bzip2/
write.rs

1//! Writer-based compression/decompression streams
2
3use std::io;
4use std::io::prelude::*;
5
6use crate::{Action, Compress, Compression, Decompress, Status};
7
8/// A compression stream which will have uncompressed data written to it and
9/// will write compressed data to an output stream.
10pub struct BzEncoder<W: Write> {
11    data: Compress,
12    obj: Option<W>,
13    buf: Vec<u8>,
14    done: bool,
15    panicked: bool,
16}
17
18/// A compression stream which will have compressed data written to it and
19/// will write uncompressed data to an output stream.
20pub struct BzDecoder<W: Write> {
21    data: Decompress,
22    obj: Option<W>,
23    buf: Vec<u8>,
24    done: bool,
25    panicked: bool,
26}
27
28impl<W: Write> BzEncoder<W> {
29    /// Create a new compression stream which will compress at the given level
30    /// to write compress output to the give output stream.
31    pub fn new(obj: W, level: Compression) -> BzEncoder<W> {
32        BzEncoder {
33            data: Compress::new(level, 30),
34            obj: Some(obj),
35            buf: Vec::with_capacity(32 * 1024),
36            done: false,
37            panicked: false,
38        }
39    }
40
41    fn dump(&mut self) -> io::Result<()> {
42        while !self.buf.is_empty() {
43            self.panicked = true;
44            let r = self.obj.as_mut().unwrap().write(&self.buf);
45            self.panicked = false;
46
47            match r {
48                Ok(n) => self.buf.drain(..n),
49                Err(ref err) if err.kind() == io::ErrorKind::Interrupted => continue,
50                Err(err) => return Err(err),
51            };
52        }
53        Ok(())
54    }
55
56    /// Acquires a reference to the underlying writer.
57    pub fn get_ref(&self) -> &W {
58        self.obj.as_ref().unwrap()
59    }
60
61    /// Acquires a mutable reference to the underlying writer.
62    ///
63    /// Note that mutating the output/input state of the stream may corrupt this
64    /// object, so care must be taken when using this method.
65    pub fn get_mut(&mut self) -> &mut W {
66        self.obj.as_mut().unwrap()
67    }
68
69    /// Attempt to finish this output stream, writing out final chunks of data.
70    ///
71    /// Note that this function can only be used once data has finished being
72    /// written to the output stream. After this function is called then further
73    /// calls to [`write`] may result in a panic.
74    ///
75    /// # Panics
76    ///
77    /// Attempts to write data to this stream may result in a panic after this
78    /// function is called.
79    ///
80    /// [`write`]: Self::write
81    pub fn try_finish(&mut self) -> io::Result<()> {
82        while !self.done {
83            self.dump()?;
84            let res = self.data.compress_vec(&[], &mut self.buf, Action::Finish);
85            if res == Ok(Status::StreamEnd) {
86                self.done = true;
87                break;
88            }
89        }
90        self.dump()
91    }
92
93    /// Consumes this encoder, flushing the output stream.
94    ///
95    /// This will flush the underlying data stream and then return the contained
96    /// writer if the flush succeeded.
97    ///
98    /// Note that this function may not be suitable to call in a situation where
99    /// the underlying stream is an asynchronous I/O stream. To finish a stream
100    /// the [`try_finish`] (or `shutdown`) method should be used instead. To
101    /// re-acquire ownership of a stream it is safe to call this method after
102    /// [`try_finish`] or `shutdown` has returned `Ok`.
103    ///
104    /// [`try_finish`]: Self::try_finish
105    pub fn finish(mut self) -> io::Result<W> {
106        self.try_finish()?;
107        Ok(self.obj.take().unwrap())
108    }
109
110    /// Returns the number of bytes produced by the compressor
111    ///
112    /// Note that, due to buffering, this only bears any relation to
113    /// [`total_in`] after a call to [`flush`].  At that point,
114    /// `total_out() / total_in()` is the compression ratio.
115    ///
116    /// [`flush`]: Self::flush
117    /// [`total_in`]: Self::total_in
118    pub fn total_out(&self) -> u64 {
119        self.data.total_out()
120    }
121
122    /// Returns the number of bytes consumed by the compressor
123    /// (e.g. the number of bytes written to this stream.)
124    pub fn total_in(&self) -> u64 {
125        self.data.total_in()
126    }
127}
128
129impl<W: Write> Write for BzEncoder<W> {
130    fn write(&mut self, data: &[u8]) -> io::Result<usize> {
131        loop {
132            self.dump()?;
133
134            let total_in = self.total_in();
135            self.data
136                .compress_vec(data, &mut self.buf, Action::Run)
137                .unwrap();
138            let written = (self.total_in() - total_in) as usize;
139
140            if written > 0 || data.is_empty() {
141                return Ok(written);
142            }
143        }
144    }
145
146    fn flush(&mut self) -> io::Result<()> {
147        loop {
148            self.dump()?;
149            let before = self.total_out();
150            self.data
151                .compress_vec(&[], &mut self.buf, Action::Flush)
152                .unwrap();
153
154            if before == self.total_out() {
155                break;
156            }
157        }
158        self.obj.as_mut().unwrap().flush()
159    }
160}
161
162impl<W: Write> BzDecoder<W> {
163    /// Create a new decoding stream which will decompress all data written
164    /// to it into `obj`.
165    pub fn new(obj: W) -> BzDecoder<W> {
166        BzDecoder {
167            data: Decompress::new(false),
168            obj: Some(obj),
169            buf: Vec::with_capacity(32 * 1024),
170            done: false,
171            panicked: false,
172        }
173    }
174
175    /// Acquires a reference to the underlying writer.
176    pub fn get_ref(&self) -> &W {
177        self.obj.as_ref().unwrap()
178    }
179
180    /// Acquires a mutable reference to the underlying writer.
181    ///
182    /// Note that mutating the output/input state of the stream may corrupt this
183    /// object, so care must be taken when using this method.
184    pub fn get_mut(&mut self) -> &mut W {
185        self.obj.as_mut().unwrap()
186    }
187
188    fn dump(&mut self) -> io::Result<()> {
189        while !self.buf.is_empty() {
190            self.panicked = true;
191            let r = self.obj.as_mut().unwrap().write(&self.buf);
192            self.panicked = false;
193
194            match r {
195                Ok(n) => self.buf.drain(..n),
196                Err(ref err) if err.kind() == io::ErrorKind::Interrupted => continue,
197                Err(err) => return Err(err),
198            };
199        }
200        Ok(())
201    }
202
203    /// Attempt to finish this output stream, writing out final chunks of data.
204    ///
205    /// Note that this function can only be used once data has finished being
206    /// written to the output stream. After this function is called then further
207    /// calls to [`write`] may result in a panic.
208    ///
209    /// # Panics
210    ///
211    /// Attempts to write data to this stream may result in a panic after this
212    /// function is called.
213    ///
214    /// [`write`]: Self::write
215    pub fn try_finish(&mut self) -> io::Result<()> {
216        while !self.done {
217            // The write is effectively a `self.flush()`, but we want to know how many
218            // bytes were written. exit if no input was read and no output was written
219            if self.write(&[])? == 0 {
220                // finishing the output stream is effectively EOF of the input
221                let msg = "Input EOF reached before logical stream end";
222                return Err(io::Error::new(io::ErrorKind::UnexpectedEof, msg));
223            }
224        }
225        self.dump()
226    }
227
228    /// Unwrap the underlying writer, finishing the compression stream.
229    ///
230    /// Note that this function may not be suitable to call in a situation where
231    /// the underlying stream is an asynchronous I/O stream. To finish a stream
232    /// the [`try_finish`] (or `shutdown`) method should be used instead. To
233    /// re-acquire ownership of a stream it is safe to call this method after
234    /// [`try_finish`] or `shutdown` has returned `Ok`.
235    ///
236    /// [`try_finish`]: Self::try_finish
237    pub fn finish(&mut self) -> io::Result<W> {
238        self.try_finish()?;
239        Ok(self.obj.take().unwrap())
240    }
241
242    /// Returns the number of bytes produced by the decompressor
243    ///
244    /// Note that, due to buffering, this only bears any relation to
245    /// [`total_in`] after a call to [`flush`].  At that point,
246    /// `total_in() / total_out()` is the compression ratio.
247    ///
248    /// [`flush`]: Self::flush
249    /// [`total_in`]: Self::total_in
250    pub fn total_out(&self) -> u64 {
251        self.data.total_out()
252    }
253
254    /// Returns the number of bytes consumed by the decompressor
255    /// (e.g. the number of bytes written to this stream.)
256    pub fn total_in(&self) -> u64 {
257        self.data.total_in()
258    }
259}
260
261impl<W: Write> Write for BzDecoder<W> {
262    fn write(&mut self, data: &[u8]) -> io::Result<usize> {
263        if self.done {
264            return Ok(0);
265        }
266        loop {
267            self.dump()?;
268
269            let before = self.total_in();
270            let res = self.data.decompress_vec(data, &mut self.buf);
271            let written = (self.total_in() - before) as usize;
272
273            // make sure that a subsequent call exits early when there is nothing useful left to do
274            self.done |= matches!(res, Err(_) | Ok(Status::StreamEnd));
275
276            if let Err(e) = res {
277                return Err(io::Error::new(io::ErrorKind::InvalidInput, e));
278            }
279
280            if written > 0 || data.is_empty() || self.done {
281                return Ok(written);
282            }
283        }
284    }
285
286    fn flush(&mut self) -> io::Result<()> {
287        self.dump()?;
288        self.obj.as_mut().unwrap().flush()
289    }
290}
291
292impl<W: Write> Drop for BzDecoder<W> {
293    fn drop(&mut self) {
294        if self.obj.is_some() {
295            let _ = self.try_finish();
296        }
297    }
298}
299
300impl<W: Write> Drop for BzEncoder<W> {
301    fn drop(&mut self) {
302        if self.obj.is_some() && !self.panicked {
303            let _ = self.try_finish();
304        }
305    }
306}
307
#[cfg(test)]
mod tests {
    use super::{BzDecoder, BzEncoder};
    use crate::Compression;
    use partial_io::quickcheck_types::{GenInterrupted, PartialWithErrors};
    use partial_io::PartialWrite;
    use std::io::prelude::*;

    /// Round-trips about 500 KB through an encoder chained straight into a
    /// decoder and checks the output byte for byte.
    #[test]
    fn smoke() {
        let d = BzDecoder::new(Vec::new());
        let mut c = BzEncoder::new(d, Compression::default());
        c.write_all(b"12834").unwrap();
        let s = "12345".repeat(100000);
        c.write_all(s.as_bytes()).unwrap();
        let data = c.finish().unwrap().finish().unwrap();
        assert_eq!(&data[0..5], b"12834");
        assert_eq!(data.len(), 500005);
        // `assert!` rather than `assert_eq!` keeps a failure from printing
        // both half-megabyte buffers.
        assert!(format!("12834{}", s).as_bytes() == &*data);
    }

    /// Encodes and then decodes an empty input file.
    #[test]
    fn roundtrip_empty() {
        let d = BzDecoder::new(Vec::new());
        let mut c = BzEncoder::new(d, Compression::default());
        let _ = c.write(b"").unwrap();
        let data = c.finish().unwrap().finish().unwrap();
        assert_eq!(&data[..], b"");
    }

    /// Finishing a decoder that received no input must fail, not hang.
    #[test]
    fn finish_empty_explicit() {
        // The empty sequence is not a valid .bzip2 file!
        // A valid file at least includes the magic bytes, the checksum, etc.
        //
        // This used to loop infinitely, see
        //
        // - https://github.com/trifectatechfoundation/bzip2-rs/issues/96
        // - https://github.com/trifectatechfoundation/bzip2-rs/pull/97
        let mut d = BzDecoder::new(Vec::new());
        // The byte count is irrelevant here; discard it explicitly, as
        // `roundtrip_empty` does.
        let _ = d.write(b"").unwrap();
        let e = d.finish().unwrap_err();
        assert_eq!(e.kind(), std::io::ErrorKind::UnexpectedEof);
    }

    /// Dropping a decoder that received no input must terminate.
    #[test]
    fn finish_empty_drop() {
        // the drop implementation used to loop infinitely for empty input
        //
        // see https://github.com/trifectatechfoundation/bzip2-rs/pull/118
        let d = BzDecoder::new(Vec::new());
        drop(d);
    }

    /// Invalid compressed input is reported as `InvalidInput`.
    #[test]
    fn write_invalid() {
        // see https://github.com/trifectatechfoundation/bzip2-rs/issues/98
        let mut d = BzDecoder::new(Vec::new());
        let e = d.write(b"BZh\xfb").unwrap_err();
        assert_eq!(e.kind(), std::io::ErrorKind::InvalidInput);
    }

    /// Property test: arbitrary byte vectors survive an encode/decode
    /// round trip.
    #[test]
    fn qc() {
        ::quickcheck::quickcheck(test as fn(_) -> _);

        fn test(v: Vec<u8>) -> bool {
            let w = BzDecoder::new(Vec::new());
            let mut w = BzEncoder::new(w, Compression::default());
            w.write_all(&v).unwrap();
            v == w.finish().unwrap().finish().unwrap()
        }
    }

    /// Property test: the round trip also holds when both underlying writers
    /// accept partial writes and report interruptions.
    #[test]
    fn qc_partial() {
        quickcheck::quickcheck(test as fn(_, _, _) -> _);

        fn test(
            v: Vec<u8>,
            encode_ops: PartialWithErrors<GenInterrupted>,
            decode_ops: PartialWithErrors<GenInterrupted>,
        ) -> bool {
            let w = BzDecoder::new(PartialWrite::new(Vec::new(), decode_ops));
            let mut w = BzEncoder::new(PartialWrite::new(w, encode_ops), Compression::default());
            w.write_all(&v).unwrap();
            v == w
                .finish()
                .unwrap()
                .into_inner()
                .finish()
                .unwrap()
                .into_inner()
        }
    }

    /// Dropping an encoder without calling `finish` still yields a complete
    /// stream.
    #[test]
    fn terminate_on_drop() {
        // Test that dropping the BzEncoder flushes bytes to the output, so that
        // we get a valid, decompressable datastream
        //
        // see https://github.com/trifectatechfoundation/bzip2-rs/pull/121
        let s = "12345".repeat(100);

        let mut compressed = Vec::new();
        {
            // Boxing as `dyn Write` hides `finish`, so only `Drop` can
            // terminate the stream.
            let mut c: Box<dyn std::io::Write> =
                Box::new(BzEncoder::new(&mut compressed, Compression::default()));
            c.write_all(b"12834").unwrap();
            c.write_all(s.as_bytes()).unwrap();
            c.flush().unwrap();
        }
        assert!(!compressed.is_empty());

        let uncompressed = {
            let mut d = BzDecoder::new(Vec::new());
            d.write_all(&compressed).unwrap();
            d.finish().unwrap()
        };
        assert_eq!(&uncompressed[0..5], b"12834");
        assert_eq!(uncompressed.len(), s.len() + "12834".len());
        assert!(format!("12834{}", s).as_bytes() == &*uncompressed);
    }
}