keyfork/crates/util/keyfork-frame/src/lib.rs

223 lines
7.2 KiB
Rust

//! Utility functions to quickly encode and decode `&[u8]` to and from framed messages.
//!
//! Framed messages consist of the following items:
//!
//! ```txt
//! | len: big-endian u32 of data.len() | data: binary data |
//! ```
//!
//! The data stored after the length consists of the following items:
//!
//! ```txt
//! | checksum: [u8; 32] sha256 hash of `raw_data` | raw_data: &[u8] |
//! ```
use std::io::{Read, Write};
#[cfg(feature = "async")]
pub mod asyncext;
use sha2::{Digest, Sha256};
/// An error encountered while decoding a frame.
///
/// The `Display` text of each variant comes from its `#[error(...)]` attribute; the positional
/// placeholders (`{0}`, `{1}`) refer to the variant's tuple fields in order.
#[derive(Debug, thiserror::Error)]
pub enum DecodeError {
/// There were not enough bytes to determine the length of the data slice.
///
/// Carries the slice-to-array conversion error that was encountered.
#[error("Invalid length: {0}")]
InvalidLength(std::array::TryFromSliceError),
/// There were not enough bytes to read a checksum of the data slice.
///
/// Carries the slice-to-array conversion error that was encountered.
// NOTE(review): the message appends the word "bytes" to the conversion error's own text, not to
// a byte count -- confirm this wording is intended.
#[error("Invalid checksum: {0} bytes")]
InvalidChecksum(std::array::TryFromSliceError),
/// There were not enough bytes to read the rest of the data.
///
/// Per the message, the first field is the length that was found and the second is the minimum
/// length that was expected.
#[error("Incorrect length of internal data: {0}, expected at least: {1}")]
IncorrectLength(usize, u32),
/// The provided checksum of the data did not match the locally-generated checksum.
///
/// The first field is the checksum taken from the frame ("Their"), the second is the checksum
/// computed locally over the frame's content ("Our"). Both are printed as uppercase hex.
#[error("Checksum did not match: Their {0:X?} != Our {1:X?}")]
BadChecksum(Vec<u8>, Vec<u8>),
/// Data could not be read from the input source.
///
/// `#[from]` generates `From<std::io::Error>`, so I/O failures can be propagated with `?`.
#[error("Data could not be read from the input source: {0}")]
Io(#[from] std::io::Error),
}
/// An error encountered while encoding a frame.
///
/// The `Display` text of each variant comes from its `#[error(...)]` attribute.
#[derive(Debug, thiserror::Error)]
pub enum EncodeError {
/// The given input was larger than could be encoded by this protocol.
///
/// Carries the length, in bytes, that could not be represented as a `u32` length prefix.
#[error("Input too large to encode: {0}")]
InputTooLarge(usize),
/// Data could not be written to the output sink.
///
/// `#[from]` generates `From<std::io::Error>`, so I/O failures can be propagated with `?`.
#[error("Data could not be written to the output sink: {0}")]
Io(#[from] std::io::Error),
}
/// Compute the SHA-256 digest of `data` and return it as an owned byte vector.
pub(crate) fn hash(data: &[u8]) -> Vec<u8> {
    // One-shot digest: equivalent to creating a hasher, feeding it `data`, and finalizing.
    Sha256::digest(data).to_vec()
}
/// Build a framed message from `data` and return it as a freshly allocated `Vec<u8>`.
///
/// This is a convenience wrapper around [`try_encode_to`] that writes into an in-memory buffer.
///
/// # Errors
/// An error may be returned if `data`, together with its checksum, does not fit in
/// [`u32::MAX`] bytes. This is a constraint on a protocol level.
///
/// # Examples
/// ```rust
/// let data = keyfork_frame::try_encode(b"hello world!".as_slice()).unwrap();
/// ```
pub fn try_encode(data: &[u8]) -> Result<Vec<u8>, EncodeError> {
    let mut frame = Vec::new();
    // Hand the filled buffer back only when encoding succeeded.
    try_encode_to(data, &mut frame).map(|()| frame)
}
/// Write `data` as a framed message to a type implementing [`Write`].
///
/// The frame is emitted as three consecutive writes: the big-endian `u32` length of
/// checksum-plus-data, the checksum of `data`, and finally `data` itself.
///
/// # Errors
/// An error may be returned if `data`, together with its checksum, does not fit in
/// [`u32::MAX`] bytes, or if the writer is unable to write data.
///
/// # Examples
/// ```rust
/// let mut output = vec![];
/// keyfork_frame::try_encode_to(b"hello world!".as_slice(), &mut output).unwrap();
/// ```
pub fn try_encode_to(data: &[u8], writable: &mut impl Write) -> Result<(), EncodeError> {
    let checksum = hash(data);
    let framed_len = checksum.len() + data.len();
    // The length prefix is a u32, so anything larger cannot be represented on the wire.
    let prefix = match u32::try_from(framed_len) {
        Ok(len) => len.to_be_bytes(),
        Err(_) => return Err(EncodeError::InputTooLarge(framed_len)),
    };
    for part in [&prefix[..], &checksum[..], data] {
        writable.write_all(part)?;
    }
    Ok(())
}
/// Split `data` into its leading 32-byte checksum and the content that follows, returning the
/// content only when the checksum matches a locally computed hash of that content.
///
/// # Errors
/// * [`DecodeError::InvalidChecksum`] if `data` is too short to contain a 32-byte checksum.
/// * [`DecodeError::BadChecksum`] if the checksum in `data` differs from the hash of the content.
pub(crate) fn verify_checksum(data: &[u8]) -> Result<&[u8], DecodeError> {
    const CHECKSUM_LEN: usize = 32;

    // Indexing with `data[..32]` panics on short input. Clamping the split point instead lets a
    // too-short prefix reach the array conversion below, which then reports it as an error.
    let (prefix, content) = data.split_at(CHECKSUM_LEN.min(data.len()));
    let checksum: &[u8; CHECKSUM_LEN] = prefix
        .try_into()
        .map_err(DecodeError::InvalidChecksum)?;
    let our_checksum = hash(content);
    if our_checksum != checksum {
        return Err(DecodeError::BadChecksum(checksum.to_vec(), our_checksum));
    }
    Ok(content)
}
/// Decode a framed message held in a byte slice into a `Vec<u8>`.
///
/// This is a convenience wrapper that treats `data` as a reader and forwards to
/// [`try_decode_from`]. Only the length prefix and the number of bytes it announces are read;
/// anything after the frame is left unread.
///
/// # Errors
/// An error may be returned if:
/// * The given `data` does not contain enough data to parse a length,
/// * The given `data` does not contain the given length's worth of data,
/// * The given `data` has a checksum that does not match what we build locally.
///
/// # Examples
/// ```rust
/// let input = b"hello world!";
/// let encoded = keyfork_frame::try_encode(input.as_slice()).unwrap();
/// let decoded = keyfork_frame::try_decode(&encoded).unwrap();
/// assert_eq!(input.as_slice(), decoded.as_slice());
/// ```
pub fn try_decode(data: &[u8]) -> Result<Vec<u8>, DecodeError> {
    // `&[u8]` implements `Read`; reading advances this local copy of the slice.
    let mut remaining = data;
    try_decode_from(&mut remaining)
}
/// Read and decode a framed message into a `Vec<u8>`.
///
/// Note that unlike [`try_encode_to`], this method does not allow writing to an object
/// implementing Write. This is because the whole frame body must be held in memory before its
/// checksum can be verified. The verified content is returned as a `Vec<u8>`, and a caller may
/// then choose to use `writable.write_all()`.
///
/// # Errors
/// An error may be returned if:
/// * The source does not contain enough data to parse a length,
/// * The source does not contain the given length's worth of data,
/// * The frame has a checksum that does not match what we build locally.
/// * The source for the data returned an error.
///
/// # Examples
/// ```rust
/// let input = b"hello world!";
/// let mut encoded = vec![];
/// keyfork_frame::try_encode_to(input.as_slice(), &mut encoded).unwrap();
/// let decoded = keyfork_frame::try_decode_from(&mut &encoded[..]).unwrap();
/// assert_eq!(input.as_slice(), decoded.as_slice());
/// ```
pub fn try_decode_from(readable: &mut impl Read) -> Result<Vec<u8>, DecodeError> {
    let mut len_bytes = [0u8; 4];
    readable.read_exact(&mut len_bytes)?;
    let len = u32::from_be_bytes(len_bytes);

    // The length prefix comes straight from the input, so do not allocate `len` bytes up front:
    // a bogus prefix could request up to 4 GiB. Reading through `take` caps the read at `len`
    // bytes while the buffer only grows as data actually arrives.
    let mut data = Vec::new();
    readable
        .by_ref()
        .take(u64::from(len))
        .read_to_end(&mut data)?;
    if data.len() < len as usize {
        // Report a truncated body with the same I/O error kind that `read_exact` uses.
        return Err(std::io::Error::from(std::io::ErrorKind::UnexpectedEof).into());
    }

    let content = verify_checksum(&data)?;
    Ok(content.to_vec())
}
#[cfg(test)]
mod tests {
    use super::{try_decode, try_encode, DecodeError};

    /// Size in bytes of the `u32` length prefix at the start of every frame.
    const LEN_SIZE: usize = (u32::BITS / 8) as usize;

    /// The encoded form of a fixed input must not change between releases.
    #[test]
    fn stable_interface() {
        let data = (0..255).collect::<Vec<u8>>();
        insta::assert_debug_snapshot!(try_encode(&data[..]));
    }

    /// Encoding followed by decoding yields the original input.
    #[test]
    fn equivalency() -> Result<(), DecodeError> {
        let data = (0..255).collect::<Vec<u8>>();
        assert_eq!(try_decode(&try_encode(&data[..]).unwrap())?, data);
        Ok(())
    }

    /// Bytes trailing a complete frame are ignored by the decoder.
    #[test]
    fn allows_extra_data() -> Result<(), DecodeError> {
        let data = (0..255).collect::<Vec<u8>>();
        let mut encoded = try_encode(&data[..]).unwrap();
        // Throw on some extra data
        encoded.extend(0..255);
        // Decode the extended buffer itself, not a fresh encoding, so the trailing bytes are
        // actually exercised.
        assert_eq!(try_decode(&encoded)?, data);
        Ok(())
    }

    /// Truncated frames are rejected.
    #[test]
    fn error_on_smaller_data() {
        // Data sliced by 1 byte
        let data = (0..255).collect::<Vec<u8>>();
        let encoded = try_encode(&data[..]).unwrap();
        let error = try_decode(&encoded[..encoded.len() - 1]);
        assert!(error.is_err());
        // Data includes length and checksum
        let error = try_decode(&encoded[..LEN_SIZE + 256 / 8]);
        assert!(error.is_err());
        // Data only includes length
        let data = 12u32.to_be_bytes();
        let error = try_decode(&data[..]);
        assert!(error.is_err());
    }

    /// A frame whose checksum bytes were altered is rejected with `BadChecksum`.
    #[test]
    fn error_on_invalid_checksum() {
        let data = (0..255).collect::<Vec<u8>>();
        let mut encoded = try_encode(&data[..]).unwrap();
        // Make sure zeroing the byte really changes the checksum.
        assert_ne!(encoded[LEN_SIZE + 1], 0);
        encoded[LEN_SIZE + 1] = 0;
        // Decode the tampered frame, not the raw input.
        let error = try_decode(&encoded[..]);
        assert!(matches!(error, Err(DecodeError::BadChecksum(..))));
    }
}