diff --git a/keyfork-frame/src/asyncext.rs b/keyfork-frame/src/asyncext.rs index 9800818..365b0cf 100644 --- a/keyfork-frame/src/asyncext.rs +++ b/keyfork-frame/src/asyncext.rs @@ -15,7 +15,8 @@ pub async fn try_decode_from( ) -> Result, DecodeError> { let len = readable.read_u32().await?; - let mut data = Vec::with_capacity(len as usize); + // Note: Pre-filling the vec is *required* as read_exact uses len, not capacity. + let mut data = vec![0u8; len as usize]; readable.read_exact(&mut data[..]).await?; let content = verify_checksum(&data[..])?; diff --git a/keyfork-frame/src/lib.rs b/keyfork-frame/src/lib.rs index 476db18..0f4d58a 100644 --- a/keyfork-frame/src/lib.rs +++ b/keyfork-frame/src/lib.rs @@ -12,6 +12,8 @@ //! | checksum: [u8; 32] sha256 hash of `raw_data` | raw_data: &[u8] | //! ``` +use std::io::{Read, Write}; + #[cfg(feature = "async")] pub mod asyncext; @@ -51,8 +53,6 @@ pub enum EncodeError { Io(#[from] std::io::Error), } -const LEN_SIZE: usize = std::mem::size_of::(); - pub(crate) fn hash(data: &[u8]) -> Vec { let mut hashobj = Sha256::new(); hashobj.update(data); @@ -65,14 +65,19 @@ pub(crate) fn hash(data: &[u8]) -> Vec { /// An error may be returned if the given `data` is more than [`u32::MAX`] bytes. This is a /// constraint on a protocol level. pub fn try_encode(data: &[u8]) -> Result, EncodeError> { + let mut output = vec![]; + try_encode_to(data, &mut output)?; + Ok(output) +} + +pub fn try_encode_to(data: &[u8], writable: &mut impl Write) -> Result<(), EncodeError> { let hash = hash(data); - let content = hash.iter().chain(data.iter()).copied().collect::>(); - let mut result = (u32::try_from(content.len()) - .map_err(|_| EncodeError::InputTooLarge(content.len()))?) - .to_be_bytes() - .to_vec(); - result.extend(content); - Ok(result) + let len = hash.len() + data.len(); + let len = u32::try_from(len).map_err(|_| EncodeError::InputTooLarge(len))?; + writable.write_all(&len.to_be_bytes())?; + writable.write_all(&hash)?; + writable.write_all(data)?; + Ok(()) } pub(crate) fn verify_checksum(data: &[u8]) -> Result<&[u8], DecodeError> { @@ -99,18 +104,16 @@ pub(crate) fn verify_checksum(data: &[u8]) -> Result<&[u8], DecodeError> { /// * The given `data` does not contain the given length's worth of data, /// * The given `data` has a checksum that does not match what we build locally. pub fn try_decode(data: &[u8]) -> Result, DecodeError> { - // check length and advance data pointer beyond length check - let len_bytes: [u8; LEN_SIZE] = data[..LEN_SIZE] - .try_into() - .map_err(DecodeError::InvalidLength)?; - let len = u32::from_be_bytes(len_bytes); - if len as usize + LEN_SIZE > data.len() { - return Err(DecodeError::IncorrectLength(data.len() - LEN_SIZE, len)); - } - let data = &data[LEN_SIZE..]; - - let content = verify_checksum(data)?; + try_decode_from(&mut &data[..]) +} +pub fn try_decode_from(readable: &mut impl Read) -> Result, DecodeError> { + let mut bytes = 0u32.to_be_bytes(); + readable.read_exact(&mut bytes)?; + let len = u32::from_be_bytes(bytes); + let mut data = vec![0u8; len as usize]; + readable.read_exact(&mut data)?; + let content = verify_checksum(&data)?; Ok(content.to_vec()) }