From f29da57ef62480c777eecc75fcb04b5f6e321296 Mon Sep 17 00:00:00 2001 From: "Tobin C. Harding" Date: Tue, 16 Jan 2024 10:19:04 +1100 Subject: [PATCH] io: Add functions to read to the end of a reader The `std::io::Read` trait includes `read_to_end` but that method provides a denial of service attack vector since an unbounded reader will exhaust all system memory. Add a method to our `Read` trait called `read_to_limit` that does the same as `std::io::Read::read_to_end` but with memory exhaustion protection. Add a `read_to_end` method on our `Take` trait and call through to it from the new method on our `Read` trait called `read_to_limit`. --- io/src/lib.rs | 61 +++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 61 insertions(+) diff --git a/io/src/lib.rs b/io/src/lib.rs index 6d0c7ba0..a338b52e 100644 --- a/io/src/lib.rs +++ b/io/src/lib.rs @@ -18,6 +18,8 @@ extern crate alloc; mod error; mod macros; +#[cfg(feature = "alloc")] +use alloc::vec::Vec; use core::cmp; use core::convert::TryInto; @@ -45,6 +47,19 @@ pub trait Read { #[inline] fn take(&mut self, limit: u64) -> Take { Take { reader: self, remaining: limit } } + + /// Attempts to read up to limit bytes from the reader, allocating space in `buf` as needed. + /// + /// `limit` is used to prevent a denial of service attack vector since an unbounded reader will + /// exhaust all memory. + /// + /// Similar to `std::io::Read::read_to_end` but with the DOS protection. + #[doc(alias = "read_to_end")] + #[cfg(any(feature = "alloc", feature = "std"))] + #[inline] + fn read_to_limit(&mut self, buf: &mut Vec, limit: u64) -> Result { + self.take(limit).read_to_end(buf) + } } /// A trait describing an input stream that uses an internal buffer when reading. @@ -65,6 +80,27 @@ pub struct Take<'a, R: Read + ?Sized> { remaining: u64, } +impl<'a, R: Read + ?Sized> Take<'a, R> { + #[cfg(any(feature = "alloc", feature = "std"))] + #[inline] + fn read_to_end(&mut self, buf: &mut Vec) -> Result { + let mut read: usize = 0; + let mut chunk = [0u8; 64]; + loop { + match self.read(&mut chunk) { + Ok(0) => break, + Ok(n) => { + buf.extend_from_slice(&chunk[0..n]); + read += n; + } + Err(ref e) if e.kind() == ErrorKind::Interrupted => {} + Err(e) => return Err(e), + }; + } + Ok(read) + } +} + impl<'a, R: Read + ?Sized> Read for Take<'a, R> { #[inline] fn read(&mut self, buf: &mut [u8]) -> Result { @@ -293,4 +329,29 @@ mod tests { assert_eq!(fill.len(), 0); assert_eq!(fill, &[]); } + + #[test] + #[cfg(any(feature = "alloc", feature = "std"))] + fn read_to_limit_greater_than_total_length() { + let s = "16-byte-string!!".to_string(); + let mut reader = Cursor::new(&s); + let mut buf = vec![]; + + // 32 is greater than the reader length. + let read = reader.read_to_limit(&mut buf, 32).expect("failed to read to limit"); + assert_eq!(read, s.len()); + assert_eq!(&buf, s.as_bytes()) + } + + #[test] + #[cfg(any(feature = "alloc", feature = "std"))] + fn read_to_limit_less_than_total_length() { + let s = "16-byte-string!!".to_string(); + let mut reader = Cursor::new(&s); + let mut buf = vec![]; + + let read = reader.read_to_limit(&mut buf, 2).expect("failed to read to limit"); + assert_eq!(read, 2); + assert_eq!(&buf, "16".as_bytes()) + } }