//! Utility functions to quickly encode and decode `&[u8]` to and from framed messages. //! //! Framed messages consist of the following items: //! //! ```txt //! | len: u32 of data.len() | data: binary data | //! ``` //! //! The data stored after the length consists of the following items: //! //! ```txt //! | checksum: [u8; 32] sha256 hash of `raw_data` | raw_data: &[u8] | //! ``` use sha2::{Digest, Sha256}; #[derive(Debug, Clone, thiserror::Error)] pub enum DecodeError { /// There were not enough bytes to determine the length of the data slice. #[error("Invalid length: {0}")] InvalidLength(std::array::TryFromSliceError), /// There were not enough bytes to read a checksum of the data slice. #[error("Invalid checksum: {0} bytes")] InvalidChecksum(std::array::TryFromSliceError), /// There were not enough bytes to read the rest of the data. #[error("Incorrect length of internal data: {0}, expected at least: {1}")] IncorrectLength(usize, u32), /// The provided checksum of the data did not match the locally-generated checksum. #[error("Checksum did not match: Their {0} != Our {1}")] BadChecksum(String, String), } #[derive(Debug, Clone, thiserror::Error)] pub enum EncodeError { /// The given input was larger than could be encoded by this protocol. #[error("Input too large to encode: {0}")] InputTooLarge(usize), } const LEN_SIZE: usize = std::mem::size_of::(); fn hash(data: &[u8]) -> Vec { let mut hashobj = Sha256::new(); hashobj.update(data); hashobj.finalize().to_vec() } /// Encode a given `&[u8]` to a framed message. /// /// # Errors /// An error may be returned if the given `data` is more than [`u32::MAX`] bytes. This is a /// constraint on a protocol level. pub fn try_encode(data: &[u8]) -> Result, EncodeError> { let hash = hash(data); let content = hash.iter().chain(data.iter()).copied().collect::>(); let mut result = (u32::try_from(content.len()) .map_err(|_| EncodeError::InputTooLarge(content.len()))?) .to_be_bytes() .to_vec(); result.extend(content); Ok(result) } /// Decode a framed message into a `Vec`. /// /// # Errors /// An error may be returned if: /// * The given `data` does not contain enough data to parse a length, /// * The given `data` does not contain the given length's worth of data, /// * The given `data` has a checksum that does not match what we build locally. pub fn try_decode(data: &[u8]) -> Result, DecodeError> { // check length and advance data pointer beyond length check let len_bytes: [u8; LEN_SIZE] = data[..LEN_SIZE] .try_into() .map_err(DecodeError::InvalidLength)?; let len = u32::from_be_bytes(len_bytes); if len as usize + LEN_SIZE > data.len() { return Err(DecodeError::IncorrectLength(data.len() - LEN_SIZE, len)); } let data = &data[LEN_SIZE..]; let checksum: &[u8; 32] = &data[..32] .try_into() .map_err(DecodeError::InvalidChecksum)?; let content = &data[32..]; let our_checksum = hash(content); if our_checksum != checksum { return Err(DecodeError::BadChecksum( hex::encode(checksum), hex::encode(our_checksum), )); } Ok(content.to_vec()) } #[cfg(test)] mod tests { use super::{try_encode, try_decode, DecodeError}; #[test] fn stable_interface() { let data = (0..255).collect::>(); insta::assert_debug_snapshot!(try_encode(&data[..])); } #[test] fn equivalency() -> Result<(), DecodeError> { let data = (0..255).collect::>(); assert_eq!(try_decode(&try_encode(&data[..]).unwrap())?, data); Ok(()) } #[test] fn allows_extra_data() -> Result<(), DecodeError> { let data = (0..255).collect::>(); let mut encoded = try_encode(&data[..]).unwrap(); // Throw on some extra data encoded.extend(0..255); assert_eq!(try_decode(&try_encode(&data[..]).unwrap())?, data); Ok(()) } #[test] fn error_on_smaller_data() { // Data sliced by 1 byte let data = (0..255).collect::>(); let encoded = try_encode(&data[..]).unwrap(); let error = try_decode(&encoded[..data.len() - 1]); assert!(error.is_err()); // Data includes length and checksum let error = try_decode(&encoded[..super::LEN_SIZE + 256 / 8]); assert!(error.is_err()); // Data only includes length let data = 12u32.to_be_bytes(); let error = try_decode(&data[..]); assert!(error.is_err()); } #[test] fn error_on_invalid_checksum() { let data = (0..255).collect::>(); let mut encoded = try_encode(&data[..]).unwrap(); assert_ne!(encoded[super::LEN_SIZE + 1], 0); encoded[super::LEN_SIZE + 1] = 0; let error = try_decode(&data[..]); assert!(error.is_err()); } }