Merge rust-bitcoin/rust-bitcoin#3275: Enforce that `Hash::Bytes` is an array

be13397570 Make hmac & hkdf more robust against buggy `Hash` (Martin Habovstiak)
94c0614bda Enforce that `Hash::Bytes` is an array (Martin Habovstiak)

Pull request description:

  This makes sure `Hash::Bytes` is an array. We've discussed this somewhere but I don't remember where.

  I'm not sure if the second commit is actually valuable but hopefully shouldn't make things worse.

ACKs for top commit:
  apoelstra:
    ACK be13397570 successfully ran local tests; yep, this looks like an improvement. Agreed that the second commit has questionable value but doe not make things worse
  tcharding:
    ACK be13397570

Tree-SHA512: 0fed982084f0f98927c2b4a275cec81cb4bbc0efbf01551a0a4a8b6b39a4504830243ee8d55a5c0418d81b5d4babc7b22332dbacc0609ced8fada84d2961ae71
This commit is contained in:
merge-script 2024-09-02 00:42:39 +00:00
commit 9233eb2fa3
No known key found for this signature in database
GPG Key ID: C588D63CE41B97C1
5 changed files with 33 additions and 17 deletions

View File

@ -11,7 +11,7 @@ use alloc::vec;
use alloc::vec::Vec; use alloc::vec::Vec;
use core::fmt; use core::fmt;
use crate::{GeneralHash, HashEngine, Hmac, HmacEngine}; use crate::{GeneralHash, HashEngine, Hmac, HmacEngine, IsByteArray};
/// Output keying material max length multiple. /// Output keying material max length multiple.
const MAX_OUTPUT_BLOCKS: usize = 255; const MAX_OUTPUT_BLOCKS: usize = 255;
@ -54,14 +54,14 @@ where
/// but the info must be independent from the ikm for security. /// but the info must be independent from the ikm for security.
pub fn expand(&self, info: &[u8], okm: &mut [u8]) -> Result<(), MaxLengthError> { pub fn expand(&self, info: &[u8], okm: &mut [u8]) -> Result<(), MaxLengthError> {
// Length of output keying material in bytes must be less than 255 * hash length. // Length of output keying material in bytes must be less than 255 * hash length.
if okm.len() > (MAX_OUTPUT_BLOCKS * T::LEN) { if okm.len() > (MAX_OUTPUT_BLOCKS * T::Bytes::LEN) {
return Err(MaxLengthError { max: MAX_OUTPUT_BLOCKS * T::LEN }); return Err(MaxLengthError { max: MAX_OUTPUT_BLOCKS * T::Bytes::LEN });
} }
// Counter starts at "1" based on RFC5869 spec and is committed to in the hash. // Counter starts at "1" based on RFC5869 spec and is committed to in the hash.
let mut counter = 1u8; let mut counter = 1u8;
// Ceiling calculation for the total number of blocks (iterations) required for the expand. // Ceiling calculation for the total number of blocks (iterations) required for the expand.
let total_blocks = (okm.len() + T::LEN - 1) / T::LEN; let total_blocks = (okm.len() + T::Bytes::LEN - 1) / T::Bytes::LEN;
while counter <= total_blocks as u8 { while counter <= total_blocks as u8 {
let mut hmac_engine: HmacEngine<T> = HmacEngine::new(self.prk.as_ref()); let mut hmac_engine: HmacEngine<T> = HmacEngine::new(self.prk.as_ref());
@ -69,18 +69,18 @@ where
// First block does not have a previous block, // First block does not have a previous block,
// all other blocks include last block in the HMAC input. // all other blocks include last block in the HMAC input.
if counter != 1u8 { if counter != 1u8 {
let previous_start_index = (counter as usize - 2) * T::LEN; let previous_start_index = (counter as usize - 2) * T::Bytes::LEN;
let previous_end_index = (counter as usize - 1) * T::LEN; let previous_end_index = (counter as usize - 1) * T::Bytes::LEN;
hmac_engine.input(&okm[previous_start_index..previous_end_index]); hmac_engine.input(&okm[previous_start_index..previous_end_index]);
} }
hmac_engine.input(info); hmac_engine.input(info);
hmac_engine.input(&[counter]); hmac_engine.input(&[counter]);
let t = Hmac::from_engine(hmac_engine); let t = Hmac::from_engine(hmac_engine);
let start_index = (counter as usize - 1) * T::LEN; let start_index = (counter as usize - 1) * T::Bytes::LEN;
// Last block might not take full hash length. // Last block might not take full hash length.
let end_index = let end_index =
if counter == (total_blocks as u8) { okm.len() } else { counter as usize * T::LEN }; if counter == (total_blocks as u8) { okm.len() } else { counter as usize * T::Bytes::LEN };
okm[start_index..end_index].copy_from_slice(&t.as_ref()[0..(end_index - start_index)]); okm[start_index..end_index].copy_from_slice(&t.as_ref()[0..(end_index - start_index)]);

View File

@ -72,10 +72,11 @@ impl<T: GeneralHash> HmacEngine<T> {
if key.len() > T::Engine::BLOCK_SIZE { if key.len() > T::Engine::BLOCK_SIZE {
let hash = <T as GeneralHash>::hash(key); let hash = <T as GeneralHash>::hash(key);
for (b_i, b_h) in ipad.iter_mut().zip(hash.as_ref()) { let hash = hash.as_byte_array().as_ref();
for (b_i, b_h) in ipad.iter_mut().zip(hash) {
*b_i ^= *b_h; *b_i ^= *b_h;
} }
for (b_o, b_h) in opad.iter_mut().zip(hash.as_ref()) { for (b_o, b_h) in opad.iter_mut().zip(hash) {
*b_o ^= *b_h; *b_o ^= *b_h;
} }
} else { } else {
@ -119,7 +120,8 @@ impl<T: GeneralHash> fmt::LowerHex for Hmac<T> {
} }
impl<T: GeneralHash> convert::AsRef<[u8]> for Hmac<T> { impl<T: GeneralHash> convert::AsRef<[u8]> for Hmac<T> {
fn as_ref(&self) -> &[u8] { self.0.as_ref() } // Calling as_byte_array is more reliable
fn as_ref(&self) -> &[u8] { self.0.as_byte_array().as_ref() }
} }
impl<T: GeneralHash> GeneralHash for Hmac<T> { impl<T: GeneralHash> GeneralHash for Hmac<T> {
@ -127,7 +129,7 @@ impl<T: GeneralHash> GeneralHash for Hmac<T> {
fn from_engine(mut e: HmacEngine<T>) -> Hmac<T> { fn from_engine(mut e: HmacEngine<T>) -> Hmac<T> {
let ihash = T::from_engine(e.iengine); let ihash = T::from_engine(e.iengine);
e.oengine.input(ihash.as_ref()); e.oengine.input(ihash.as_byte_array().as_ref());
let ohash = T::from_engine(e.oengine); let ohash = T::from_engine(e.oengine);
Hmac(ohash) Hmac(ohash)
} }
@ -135,7 +137,6 @@ impl<T: GeneralHash> GeneralHash for Hmac<T> {
impl<T: GeneralHash> Hash for Hmac<T> { impl<T: GeneralHash> Hash for Hmac<T> {
type Bytes = T::Bytes; type Bytes = T::Bytes;
const LEN: usize = T::LEN;
fn from_slice(sl: &[u8]) -> Result<Hmac<T>, FromSliceError> { T::from_slice(sl).map(Hmac) } fn from_slice(sl: &[u8]) -> Result<Hmac<T>, FromSliceError> { T::from_slice(sl).map(Hmac) }

View File

@ -108,7 +108,6 @@ macro_rules! hash_trait_impls {
impl<$($gen: $gent),*> crate::Hash for Hash<$($gen),*> { impl<$($gen: $gent),*> crate::Hash for Hash<$($gen),*> {
type Bytes = [u8; $bits / 8]; type Bytes = [u8; $bits / 8];
const LEN: usize = $bits / 8;
const DISPLAY_BACKWARD: bool = $reverse; const DISPLAY_BACKWARD: bool = $reverse;
fn from_slice(sl: &[u8]) -> $crate::_export::_core::result::Result<Hash<$($gen),*>, $crate::FromSliceError> { fn from_slice(sl: &[u8]) -> $crate::_export::_core::result::Result<Hash<$($gen),*>, $crate::FromSliceError> {

View File

@ -274,10 +274,10 @@ pub trait Hash:
+ convert::AsRef<[u8]> + convert::AsRef<[u8]>
{ {
/// The byte array that represents the hash internally. /// The byte array that represents the hash internally.
type Bytes: hex::FromHex + Copy; type Bytes: hex::FromHex + Copy + IsByteArray /* <LEN={Self::LEN}> is still unsupported by Rust */;
/// Length of the hash, in bytes. /// Length of the hash, in bytes.
const LEN: usize; const LEN: usize = Self::Bytes::LEN;
/// Copies a byte slice into a hash object. /// Copies a byte slice into a hash object.
fn from_slice(sl: &[u8]) -> Result<Self, FromSliceError>; fn from_slice(sl: &[u8]) -> Result<Self, FromSliceError>;
@ -297,6 +297,23 @@ pub trait Hash:
fn from_byte_array(bytes: Self::Bytes) -> Self; fn from_byte_array(bytes: Self::Bytes) -> Self;
} }
/// Ensures that a type is an array.
pub trait IsByteArray: AsRef<[u8]> + sealed::IsByteArray {
/// The length of the array.
const LEN: usize;
}
impl<const N: usize> IsByteArray for [u8; N] {
const LEN: usize = N;
}
mod sealed {
#[doc(hidden)]
pub trait IsByteArray { }
impl<const N: usize> IsByteArray for [u8; N] { }
}
/// Attempted to create a hash from an invalid length slice. /// Attempted to create a hash from an invalid length slice.
#[derive(Debug, Clone, PartialEq, Eq)] #[derive(Debug, Clone, PartialEq, Eq)]
pub struct FromSliceError { pub struct FromSliceError {

View File

@ -231,7 +231,6 @@ macro_rules! hash_newtype {
impl $crate::Hash for $newtype { impl $crate::Hash for $newtype {
type Bytes = <$hash as $crate::Hash>::Bytes; type Bytes = <$hash as $crate::Hash>::Bytes;
const LEN: usize = <$hash as $crate::Hash>::LEN;
const DISPLAY_BACKWARD: bool = $crate::hash_newtype_get_direction!($hash, $(#[$($type_attrs)*])*); const DISPLAY_BACKWARD: bool = $crate::hash_newtype_get_direction!($hash, $(#[$($type_attrs)*])*);
#[inline] #[inline]