//! Utility functions to quickly encode and decode `&[u8]` to and from framed messages.
//!
//! Framed messages consist of the following items:
//!
//! ```txt
//! | len: big-endian u32 of data.len() | data: binary data |
//! ```
//!
//! The data stored after the length consists of the following items:
//!
//! ```txt
//! | checksum: [u8; 32] sha256 hash of `raw_data` | raw_data: &[u8] |
//! ```
|
||
|
|
||
|
|
||
|
use sha2::{Digest, Sha256};
|
||
|
|
||
|
#[derive(Debug, Clone, thiserror::Error)]
|
||
|
pub enum DecodeError {
|
||
|
/// There were not enough bytes to determine the length of the data slice.
|
||
|
#[error("Invalid length: {0}")]
|
||
|
InvalidLength(std::array::TryFromSliceError),
|
||
|
|
||
|
/// There were not enough bytes to read a checksum of the data slice.
|
||
|
#[error("Invalid checksum: {0} bytes")]
|
||
|
InvalidChecksum(std::array::TryFromSliceError),
|
||
|
|
||
|
/// There were not enough bytes to read the rest of the data.
|
||
|
#[error("Incorrect length of internal data: {0}, expected at least: {1}")]
|
||
|
IncorrectLength(usize, u32),
|
||
|
|
||
|
/// The provided checksum of the data did not match the locally-generated checksum.
|
||
|
#[error("Checksum did not match: Their {0} != Our {1}")]
|
||
|
BadChecksum(String, String),
|
||
|
}
|
||
|
|
||
|
#[derive(Debug, Clone, thiserror::Error)]
|
||
|
pub enum EncodeError {
|
||
|
/// The given input was larger than could be encoded by this protocol.
|
||
|
#[error("Input too large to encode: {0}")]
|
||
|
InputTooLarge(usize),
|
||
|
}
|
||
|
|
||
|
/// Number of bytes occupied by the `u32` length prefix at the start of every frame.
const LEN_SIZE: usize = (u32::BITS / 8) as usize;
|
||
|
|
||
|
fn hash(data: &[u8]) -> Vec<u8> {
|
||
|
let mut hashobj = Sha256::new();
|
||
|
hashobj.update(data);
|
||
|
hashobj.finalize().to_vec()
|
||
|
}
|
||
|
|
||
|
/// Encode a given `&[u8]` to a framed message.
|
||
|
///
|
||
|
/// # Errors
|
||
|
/// An error may be returned if the given `data` is more than [`u32::MAX`] bytes. This is a
|
||
|
/// constraint on a protocol level.
|
||
|
pub fn try_encode(data: &[u8]) -> Result<Vec<u8>, EncodeError> {
|
||
|
let hash = hash(data);
|
||
|
let content = hash.iter().chain(data.iter()).copied().collect::<Vec<_>>();
|
||
|
let mut result = (u32::try_from(content.len())
|
||
|
.map_err(|_| EncodeError::InputTooLarge(content.len()))?)
|
||
|
.to_be_bytes()
|
||
|
.to_vec();
|
||
|
result.extend(content);
|
||
|
Ok(result)
|
||
|
}
|
||
|
|
||
|
/// Decode a framed message into a `Vec<u8>`.
|
||
|
///
|
||
|
/// # Errors
|
||
|
/// An error may be returned if:
|
||
|
/// * The given `data` does not contain enough data to parse a length,
|
||
|
/// * The given `data` does not contain the given length's worth of data,
|
||
|
/// * The given `data` has a checksum that does not match what we build locally.
|
||
|
pub fn try_decode(data: &[u8]) -> Result<Vec<u8>, DecodeError> {
|
||
|
// check length and advance data pointer beyond length check
|
||
|
let len_bytes: [u8; LEN_SIZE] = data[..LEN_SIZE]
|
||
|
.try_into()
|
||
|
.map_err(DecodeError::InvalidLength)?;
|
||
|
let len = u32::from_be_bytes(len_bytes);
|
||
|
if len as usize + LEN_SIZE > data.len() {
|
||
|
return Err(DecodeError::IncorrectLength(data.len() - LEN_SIZE, len));
|
||
|
}
|
||
|
let data = &data[LEN_SIZE..];
|
||
|
|
||
|
let checksum: &[u8; 32] = &data[..32]
|
||
|
.try_into()
|
||
|
.map_err(DecodeError::InvalidChecksum)?;
|
||
|
|
||
|
let content = &data[32..];
|
||
|
let our_checksum = hash(content);
|
||
|
if our_checksum != checksum {
|
||
|
return Err(DecodeError::BadChecksum(
|
||
|
hex::encode(checksum),
|
||
|
hex::encode(our_checksum),
|
||
|
));
|
||
|
}
|
||
|
|
||
|
Ok(content.to_vec())
|
||
|
}
|
||
|
|
||
|
#[cfg(test)]
|
||
|
mod tests {
|
||
|
use super::{try_encode, try_decode, DecodeError};
|
||
|
|
||
|
#[test]
|
||
|
fn stable_interface() {
|
||
|
let data = (0..255).collect::<Vec<u8>>();
|
||
|
insta::assert_debug_snapshot!(try_encode(&data[..]));
|
||
|
}
|
||
|
|
||
|
#[test]
|
||
|
fn equivalency() -> Result<(), DecodeError> {
|
||
|
let data = (0..255).collect::<Vec<u8>>();
|
||
|
assert_eq!(try_decode(&try_encode(&data[..]).unwrap())?, data);
|
||
|
Ok(())
|
||
|
}
|
||
|
|
||
|
#[test]
|
||
|
fn allows_extra_data() -> Result<(), DecodeError> {
|
||
|
let data = (0..255).collect::<Vec<u8>>();
|
||
|
let mut encoded = try_encode(&data[..]).unwrap();
|
||
|
// Throw on some extra data
|
||
|
encoded.extend(0..255);
|
||
|
assert_eq!(try_decode(&try_encode(&data[..]).unwrap())?, data);
|
||
|
Ok(())
|
||
|
}
|
||
|
|
||
|
#[test]
|
||
|
fn error_on_smaller_data() {
|
||
|
// Data sliced by 1 byte
|
||
|
let data = (0..255).collect::<Vec<u8>>();
|
||
|
let encoded = try_encode(&data[..]).unwrap();
|
||
|
let error = try_decode(&encoded[..data.len() - 1]);
|
||
|
assert!(error.is_err());
|
||
|
|
||
|
// Data includes length and checksum
|
||
|
let error = try_decode(&encoded[..super::LEN_SIZE + 256 / 8]);
|
||
|
assert!(error.is_err());
|
||
|
|
||
|
// Data only includes length
|
||
|
let data = 12u32.to_be_bytes();
|
||
|
let error = try_decode(&data[..]);
|
||
|
assert!(error.is_err());
|
||
|
}
|
||
|
|
||
|
#[test]
|
||
|
fn error_on_invalid_checksum() {
|
||
|
let data = (0..255).collect::<Vec<u8>>();
|
||
|
let mut encoded = try_encode(&data[..]).unwrap();
|
||
|
assert_ne!(encoded[super::LEN_SIZE + 1], 0);
|
||
|
encoded[super::LEN_SIZE + 1] = 0;
|
||
|
|
||
|
let error = try_decode(&data[..]);
|
||
|
assert!(error.is_err());
|
||
|
}
|
||
|
}
|