1use std::io;
4use std::io::prelude::*;
5
6use crate::{Action, Compress, Compression, Decompress, Status};
7
8pub struct BzEncoder<W: Write> {
11 data: Compress,
12 obj: Option<W>,
13 buf: Vec<u8>,
14 done: bool,
15 panicked: bool,
16}
17
18pub struct BzDecoder<W: Write> {
21 data: Decompress,
22 obj: Option<W>,
23 buf: Vec<u8>,
24 done: bool,
25 panicked: bool,
26}
27
28impl<W: Write> BzEncoder<W> {
29 pub fn new(obj: W, level: Compression) -> BzEncoder<W> {
32 BzEncoder {
33 data: Compress::new(level, 30),
34 obj: Some(obj),
35 buf: Vec::with_capacity(32 * 1024),
36 done: false,
37 panicked: false,
38 }
39 }
40
41 fn dump(&mut self) -> io::Result<()> {
42 while !self.buf.is_empty() {
43 self.panicked = true;
44 let r = self.obj.as_mut().unwrap().write(&self.buf);
45 self.panicked = false;
46
47 match r {
48 Ok(n) => self.buf.drain(..n),
49 Err(ref err) if err.kind() == io::ErrorKind::Interrupted => continue,
50 Err(err) => return Err(err),
51 };
52 }
53 Ok(())
54 }
55
56 pub fn get_ref(&self) -> &W {
58 self.obj.as_ref().unwrap()
59 }
60
61 pub fn get_mut(&mut self) -> &mut W {
66 self.obj.as_mut().unwrap()
67 }
68
69 pub fn try_finish(&mut self) -> io::Result<()> {
82 while !self.done {
83 self.dump()?;
84 let res = self.data.compress_vec(&[], &mut self.buf, Action::Finish);
85 if res == Ok(Status::StreamEnd) {
86 self.done = true;
87 break;
88 }
89 }
90 self.dump()
91 }
92
93 pub fn finish(mut self) -> io::Result<W> {
106 self.try_finish()?;
107 Ok(self.obj.take().unwrap())
108 }
109
110 pub fn total_out(&self) -> u64 {
119 self.data.total_out()
120 }
121
122 pub fn total_in(&self) -> u64 {
125 self.data.total_in()
126 }
127}
128
129impl<W: Write> Write for BzEncoder<W> {
130 fn write(&mut self, data: &[u8]) -> io::Result<usize> {
131 loop {
132 self.dump()?;
133
134 let total_in = self.total_in();
135 self.data
136 .compress_vec(data, &mut self.buf, Action::Run)
137 .unwrap();
138 let written = (self.total_in() - total_in) as usize;
139
140 if written > 0 || data.is_empty() {
141 return Ok(written);
142 }
143 }
144 }
145
146 fn flush(&mut self) -> io::Result<()> {
147 loop {
148 self.dump()?;
149 let before = self.total_out();
150 self.data
151 .compress_vec(&[], &mut self.buf, Action::Flush)
152 .unwrap();
153
154 if before == self.total_out() {
155 break;
156 }
157 }
158 self.obj.as_mut().unwrap().flush()
159 }
160}
161
162impl<W: Write> BzDecoder<W> {
163 pub fn new(obj: W) -> BzDecoder<W> {
166 BzDecoder {
167 data: Decompress::new(false),
168 obj: Some(obj),
169 buf: Vec::with_capacity(32 * 1024),
170 done: false,
171 panicked: false,
172 }
173 }
174
175 pub fn get_ref(&self) -> &W {
177 self.obj.as_ref().unwrap()
178 }
179
180 pub fn get_mut(&mut self) -> &mut W {
185 self.obj.as_mut().unwrap()
186 }
187
188 fn dump(&mut self) -> io::Result<()> {
189 while !self.buf.is_empty() {
190 self.panicked = true;
191 let r = self.obj.as_mut().unwrap().write(&self.buf);
192 self.panicked = false;
193
194 match r {
195 Ok(n) => self.buf.drain(..n),
196 Err(ref err) if err.kind() == io::ErrorKind::Interrupted => continue,
197 Err(err) => return Err(err),
198 };
199 }
200 Ok(())
201 }
202
203 pub fn try_finish(&mut self) -> io::Result<()> {
216 while !self.done {
217 if self.write(&[])? == 0 {
220 let msg = "Input EOF reached before logical stream end";
222 return Err(io::Error::new(io::ErrorKind::UnexpectedEof, msg));
223 }
224 }
225 self.dump()
226 }
227
228 pub fn finish(&mut self) -> io::Result<W> {
238 self.try_finish()?;
239 Ok(self.obj.take().unwrap())
240 }
241
242 pub fn total_out(&self) -> u64 {
251 self.data.total_out()
252 }
253
254 pub fn total_in(&self) -> u64 {
257 self.data.total_in()
258 }
259}
260
261impl<W: Write> Write for BzDecoder<W> {
262 fn write(&mut self, data: &[u8]) -> io::Result<usize> {
263 if self.done {
264 return Ok(0);
265 }
266 loop {
267 self.dump()?;
268
269 let before = self.total_in();
270 let res = self.data.decompress_vec(data, &mut self.buf);
271 let written = (self.total_in() - before) as usize;
272
273 self.done |= matches!(res, Err(_) | Ok(Status::StreamEnd));
275
276 if let Err(e) = res {
277 return Err(io::Error::new(io::ErrorKind::InvalidInput, e));
278 }
279
280 if written > 0 || data.is_empty() || self.done {
281 return Ok(written);
282 }
283 }
284 }
285
286 fn flush(&mut self) -> io::Result<()> {
287 self.dump()?;
288 self.obj.as_mut().unwrap().flush()
289 }
290}
291
292impl<W: Write> Drop for BzDecoder<W> {
293 fn drop(&mut self) {
294 if self.obj.is_some() {
295 let _ = self.try_finish();
296 }
297 }
298}
299
300impl<W: Write> Drop for BzEncoder<W> {
301 fn drop(&mut self) {
302 if self.obj.is_some() && !self.panicked {
303 let _ = self.try_finish();
304 }
305 }
306}
307
#[cfg(test)]
mod tests {
    use super::{BzDecoder, BzEncoder};
    use crate::Compression;
    use partial_io::quickcheck_types::{GenInterrupted, PartialWithErrors};
    use partial_io::PartialWrite;
    use std::io::prelude::*;

    /// Pipes an encoder straight into a decoder and checks a large round trip.
    #[test]
    fn smoke() {
        let sink = BzDecoder::new(Vec::new());
        let mut enc = BzEncoder::new(sink, Compression::default());
        enc.write_all(b"12834").unwrap();
        let tail = "12345".repeat(100000);
        enc.write_all(tail.as_bytes()).unwrap();

        let decoded = enc.finish().unwrap().finish().unwrap();

        assert_eq!(&decoded[0..5], b"12834");
        assert_eq!(decoded.len(), 500005);
        let expected = format!("12834{}", tail);
        assert!(expected.as_bytes() == &*decoded);
    }

    /// An empty write followed by finishing both halves yields empty output.
    #[test]
    fn roundtrip_empty() {
        let sink = BzDecoder::new(Vec::new());
        let mut enc = BzEncoder::new(sink, Compression::default());
        let _ = enc.write(b"").unwrap();
        let decoded = enc.finish().unwrap().finish().unwrap();
        assert_eq!(&decoded[..], b"");
    }

    /// Explicitly finishing a decoder that never saw a complete stream
    /// reports `UnexpectedEof`.
    #[test]
    fn finish_empty_explicit() {
        let mut dec = BzDecoder::new(Vec::new());
        let _ = dec.write(b"").unwrap();
        let err = dec.finish().unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::UnexpectedEof);
    }

    /// Dropping a decoder that never saw a complete stream must not panic,
    /// even though finishing it would fail.
    #[test]
    fn finish_empty_drop() {
        let dec = BzDecoder::new(Vec::new());
        drop(dec);
    }

    /// An invalid byte after the `BZh` magic is reported as `InvalidInput`.
    #[test]
    fn write_invalid() {
        let mut dec = BzDecoder::new(Vec::new());
        let err = dec.write(b"BZh\xfb").unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidInput);
    }

    /// Arbitrary byte vectors survive an encoder -> decoder round trip.
    #[test]
    fn qc() {
        ::quickcheck::quickcheck(prop as fn(_) -> _);

        fn prop(input: Vec<u8>) -> bool {
            let sink = BzDecoder::new(Vec::new());
            let mut enc = BzEncoder::new(sink, Compression::default());
            enc.write_all(&input).unwrap();
            let decoded = enc.finish().unwrap().finish().unwrap();
            input == decoded
        }
    }

    /// Round trip through writers that accept partial writes and inject
    /// `Interrupted` errors on both the encoder and decoder side.
    #[test]
    fn qc_partial() {
        quickcheck::quickcheck(prop as fn(_, _, _) -> _);

        fn prop(
            input: Vec<u8>,
            encode_ops: PartialWithErrors<GenInterrupted>,
            decode_ops: PartialWithErrors<GenInterrupted>,
        ) -> bool {
            let sink = BzDecoder::new(PartialWrite::new(Vec::new(), decode_ops));
            let mut enc = BzEncoder::new(
                PartialWrite::new(sink, encode_ops),
                Compression::default(),
            );
            enc.write_all(&input).unwrap();
            let decoded = enc
                .finish()
                .unwrap()
                .into_inner()
                .finish()
                .unwrap()
                .into_inner();
            input == decoded
        }
    }

    /// Dropping an encoder (here behind a trait object) terminates the
    /// stream, so the output decodes completely.
    #[test]
    fn terminate_on_drop() {
        let tail = "12345".repeat(100);

        let mut compressed = Vec::new();
        {
            let mut enc: Box<dyn std::io::Write> =
                Box::new(BzEncoder::new(&mut compressed, Compression::default()));
            enc.write_all(b"12834").unwrap();
            enc.write_all(tail.as_bytes()).unwrap();
            enc.flush().unwrap();
        }
        assert!(!compressed.is_empty());

        let decoded = {
            let mut dec = BzDecoder::new(Vec::new());
            dec.write_all(&compressed).unwrap();
            dec.finish().unwrap()
        };
        assert_eq!(&decoded[0..5], b"12834");
        assert_eq!(decoded.len(), tail.len() + "12834".len());
        assert!(format!("12834{}", tail).as_bytes() == &*decoded);
    }
}