io: Add functions to read to the end of a reader

The `std::io::Read` trait includes `read_to_end` but that method
provides a denial of service attack vector since an unbounded reader
will exhaust all system memory.

Add a method to our `Read` trait called `read_to_limit` that does the
same as `std::io::Read::read_to_end` but with memory exhaustion
protection.

Add a `read_to_end` method on our `Take` trait and call through to it
from the new method on our `Read` trait called `read_to_limit`.
This commit is contained in:
Tobin C. Harding 2024-01-16 10:19:04 +11:00
parent 2073a40c50
commit f29da57ef6
No known key found for this signature in database
GPG Key ID: 40BF9E4C269D6607
1 changed files with 61 additions and 0 deletions

View File

@ -18,6 +18,8 @@ extern crate alloc;
mod error; mod error;
mod macros; mod macros;
#[cfg(feature = "alloc")]
use alloc::vec::Vec;
use core::cmp; use core::cmp;
use core::convert::TryInto; use core::convert::TryInto;
@ -45,6 +47,19 @@ pub trait Read {
#[inline] #[inline]
fn take(&mut self, limit: u64) -> Take<Self> { Take { reader: self, remaining: limit } } fn take(&mut self, limit: u64) -> Take<Self> { Take { reader: self, remaining: limit } }
/// Attempts to read up to limit bytes from the reader, allocating space in `buf` as needed.
///
/// `limit` is used to prevent a denial of service attack vector since an unbounded reader will
/// exhaust all memory.
///
/// Similar to `std::io::Read::read_to_end` but with the DOS protection.
#[doc(alias = "read_to_end")]
#[cfg(any(feature = "alloc", feature = "std"))]
#[inline]
fn read_to_limit(&mut self, buf: &mut Vec<u8>, limit: u64) -> Result<usize> {
self.take(limit).read_to_end(buf)
}
} }
/// A trait describing an input stream that uses an internal buffer when reading. /// A trait describing an input stream that uses an internal buffer when reading.
@ -65,6 +80,27 @@ pub struct Take<'a, R: Read + ?Sized> {
remaining: u64, remaining: u64,
} }
impl<'a, R: Read + ?Sized> Take<'a, R> {
#[cfg(any(feature = "alloc", feature = "std"))]
#[inline]
fn read_to_end(&mut self, buf: &mut Vec<u8>) -> Result<usize> {
let mut read: usize = 0;
let mut chunk = [0u8; 64];
loop {
match self.read(&mut chunk) {
Ok(0) => break,
Ok(n) => {
buf.extend_from_slice(&chunk[0..n]);
read += n;
}
Err(ref e) if e.kind() == ErrorKind::Interrupted => {}
Err(e) => return Err(e),
};
}
Ok(read)
}
}
impl<'a, R: Read + ?Sized> Read for Take<'a, R> { impl<'a, R: Read + ?Sized> Read for Take<'a, R> {
#[inline] #[inline]
fn read(&mut self, buf: &mut [u8]) -> Result<usize> { fn read(&mut self, buf: &mut [u8]) -> Result<usize> {
@ -293,4 +329,29 @@ mod tests {
assert_eq!(fill.len(), 0); assert_eq!(fill.len(), 0);
assert_eq!(fill, &[]); assert_eq!(fill, &[]);
} }
#[test]
#[cfg(any(feature = "alloc", feature = "std"))]
fn read_to_limit_greater_than_total_length() {
let s = "16-byte-string!!".to_string();
let mut reader = Cursor::new(&s);
let mut buf = vec![];
// 32 is greater than the reader length.
let read = reader.read_to_limit(&mut buf, 32).expect("failed to read to limit");
assert_eq!(read, s.len());
assert_eq!(&buf, s.as_bytes())
}
#[test]
#[cfg(any(feature = "alloc", feature = "std"))]
fn read_to_limit_less_than_total_length() {
let s = "16-byte-string!!".to_string();
let mut reader = Cursor::new(&s);
let mut buf = vec![];
let read = reader.read_to_limit(&mut buf, 2).expect("failed to read to limit");
assert_eq!(read, 2);
assert_eq!(&buf, "16".as_bytes())
}
} }