keyfork-frame: add asyncext to allow AsyncRead/AsyncWrite

Ryan Heywood 2023-08-25 01:32:06 -05:00
parent 76c9214d73
commit fa8e6d726d
Signed by: ryan
GPG Key ID: 8E401478A3FBEF72
3 changed files with 80 additions and 15 deletions

Cargo.toml

@@ -5,10 +5,15 @@ edition = "2021"
 # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
 
+[features]
+default = ["async"]
+async = ["dep:tokio"]
+
 [dependencies]
 hex = "0.4.3"
 sha2 = "0.10.7"
 thiserror = "1.0.47"
+tokio = { version = "1.32.0", optional = true, features = ["io-util"] }
 
 [dev-dependencies]
 insta = "1.31.0"
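The new `async` feature is on by default and is the only thing that pulls in tokio (restricted to its `io-util` feature). A downstream crate that only wants the blocking API could opt out; a minimal sketch of such a dependency entry, with a hypothetical path:

```toml
# Hypothetical consumer entry: disabling default features drops the asyncext
# module and the optional tokio dependency entirely.
[dependencies]
keyfork-frame = { path = "../keyfork-frame", default-features = false }
```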

src/asyncext.rs (new file)

@@ -0,0 +1,45 @@
+use std::marker::Unpin;
+
+use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
+
+use super::{hash, verify_checksum, DecodeError, EncodeError};
+
+/// Decode a framed message into a `Vec<u8>`.
+///
+/// # Errors
+/// An error may be returned if:
+/// * The given `readable` does not contain enough data to parse a length,
+/// * The given `readable` does not contain the given length's worth of data,
+/// * The data read from `readable` has a checksum that does not match what we build locally.
+pub async fn try_decode_from(
+    readable: &mut (impl AsyncRead + Unpin),
+) -> Result<Vec<u8>, DecodeError> {
+    let len = readable.read_u32().await?;
+    // Zero-fill the buffer so `read_exact` pulls in the full frame body.
+    let mut data = vec![0u8; len as usize];
+    readable.read_exact(&mut data[..]).await?;
+    let content = verify_checksum(&data[..])?;
+    // Note: Optimizing this isn't *too* practical; we could probably pop the first 32 bytes off
+    // the front of the Vec, but it might not even be worth it as opposed to one reallocation.
+    Ok(content.to_vec())
+}
+
+/// Encode a `&[u8]` into a framed message.
+///
+/// # Errors
+/// An error may be returned if:
+/// * The given `data` is more than [`u32::MAX`] bytes. This is a constraint on a protocol level.
+/// * The resulting data was unable to be written to the given `writable`.
+pub async fn try_encode_to(
+    data: &[u8],
+    writable: &mut (impl AsyncWrite + Unpin),
+) -> Result<(), EncodeError> {
+    let hash = hash(data);
+    let len = u32::try_from(hash.len() + data.len())
+        .map_err(|_| EncodeError::InputTooLarge(hash.len() + data.len()))?;
+    writable.write_u32(len).await?;
+    writable.write_all(&hash[..]).await?;
+    writable.write_all(data).await?;
+    Ok(())
+}
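For illustration, a minimal round-trip of the two new helpers, assuming the crate is imported as `keyfork_frame` with the default `async` feature, and that the caller's own manifest enables tokio's `rt` and `macros` features for `#[tokio::main]`. tokio already implements `AsyncWrite` for `Vec<u8>` and `AsyncRead` for `&[u8]`, so no socket is needed:

```rust
use keyfork_frame::asyncext::{try_decode_from, try_encode_to};

#[tokio::main(flavor = "current_thread")]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    let message = b"hello keyfork";

    // tokio implements AsyncWrite for Vec<u8>, so a plain buffer works as the sink.
    let mut framed = Vec::new();
    try_encode_to(message, &mut framed).await?;

    // The frame is a 4-byte length, a 32-byte SHA-256 checksum, then the payload.
    assert_eq!(framed.len(), 4 + 32 + message.len());

    // tokio implements AsyncRead for &[u8], so the buffer can be replayed as the source.
    let mut reader = &framed[..];
    let decoded = try_decode_from(&mut reader).await?;
    assert_eq!(decoded, message);

    Ok(())
}
```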

src/lib.rs

@@ -12,10 +12,12 @@
 //! | checksum: [u8; 32] sha256 hash of `raw_data` | raw_data: &[u8] |
 //! ```
 
+#[cfg(feature = "async")]
+pub mod asyncext;
 use sha2::{Digest, Sha256};
 
-#[derive(Debug, Clone, thiserror::Error)]
+#[derive(Debug, thiserror::Error)]
 pub enum DecodeError {
     /// There were not enough bytes to determine the length of the data slice.
     #[error("Invalid length: {0}")]
@@ -32,18 +34,26 @@ pub enum DecodeError {
     /// The provided checksum of the data did not match the locally-generated checksum.
     #[error("Checksum did not match: Their {0} != Our {1}")]
     BadChecksum(String, String),
+
+    /// Data could not be read from the input source.
+    #[error("Data could not be read from the input source: {0}")]
+    Io(#[from] std::io::Error),
 }
 
-#[derive(Debug, Clone, thiserror::Error)]
+#[derive(Debug, thiserror::Error)]
 pub enum EncodeError {
     /// The given input was larger than could be encoded by this protocol.
     #[error("Input too large to encode: {0}")]
     InputTooLarge(usize),
+
+    /// Data could not be written to the output sink.
+    #[error("Data could not be written to the output sink: {0}")]
+    Io(#[from] std::io::Error),
 }
 
 const LEN_SIZE: usize = std::mem::size_of::<u32>();
 
-fn hash(data: &[u8]) -> Vec<u8> {
+pub(crate) fn hash(data: &[u8]) -> Vec<u8> {
     let mut hashobj = Sha256::new();
     hashobj.update(data);
     hashobj.finalize().to_vec()
@@ -65,6 +75,22 @@ pub fn try_encode(data: &[u8]) -> Result<Vec<u8>, EncodeError> {
     Ok(result)
 }
 
+pub(crate) fn verify_checksum(data: &[u8]) -> Result<&[u8], DecodeError> {
+    let checksum: &[u8; 32] = &data[..32]
+        .try_into()
+        .map_err(DecodeError::InvalidChecksum)?;
+    let content = &data[32..];
+    let our_checksum = hash(content);
+    if our_checksum != checksum {
+        return Err(DecodeError::BadChecksum(
+            hex::encode(checksum),
+            hex::encode(our_checksum),
+        ));
+    }
+    Ok(content)
+}
+
 /// Decode a framed message into a `Vec<u8>`.
 ///
 /// # Errors
@@ -83,18 +109,7 @@ pub fn try_decode(data: &[u8]) -> Result<Vec<u8>, DecodeError> {
     }
     let data = &data[LEN_SIZE..];
 
-    let checksum: &[u8; 32] = &data[..32]
-        .try_into()
-        .map_err(DecodeError::InvalidChecksum)?;
-    let content = &data[32..];
-    let our_checksum = hash(content);
-    if our_checksum != checksum {
-        return Err(DecodeError::BadChecksum(
-            hex::encode(checksum),
-            hex::encode(our_checksum),
-        ));
-    }
-
+    let content = verify_checksum(data)?;
 
     Ok(content.to_vec())
 }
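For comparison, a minimal sketch of the existing blocking API, assuming the crate is imported as `keyfork_frame` and sha2 is available to the caller; the frame layout (4-byte length, 32-byte SHA-256 checksum, then payload) is inferred from the module docs and `try_decode`:

```rust
use keyfork_frame::{try_decode, try_encode};
use sha2::{Digest, Sha256};

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let message = b"hello keyfork";

    // try_encode produces a 4-byte length, a 32-byte SHA-256 checksum, then the payload.
    let framed = try_encode(message)?;
    assert_eq!(framed.len(), 4 + 32 + message.len());

    // The checksum bytes are the SHA-256 digest of the payload alone.
    assert_eq!(&framed[4..36], Sha256::digest(message).as_slice());

    // try_decode verifies the checksum and hands back just the payload.
    let decoded = try_decode(&framed)?;
    assert_eq!(decoded, message);

    Ok(())
}
```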