diff --git a/chacha20_poly1305/src/chacha20.rs b/chacha20_poly1305/src/chacha20.rs
index 0f332fbab..3902c7c0f 100644
--- a/chacha20_poly1305/src/chacha20.rs
+++ b/chacha20_poly1305/src/chacha20.rs
@@ -31,6 +31,14 @@ impl Nonce {
     pub const fn new(nonce: [u8; 12]) -> Self { Nonce(nonce) }
 }
 
+// Const generic validation trait enforcing at compile time that N is at most 3.
+trait UpTo3<const N: u32> {}
+
+impl UpTo3<0> for () {}
+impl UpTo3<1> for () {}
+impl UpTo3<2> for () {}
+impl UpTo3<3> for () {}
+
 /// A SIMD-friendly structure which holds 25% of the cipher state.
 ///
 /// The cipher's quarter round function is the bulk of its work
@@ -52,12 +60,10 @@ impl Nonce {
 /// * For-each loops are easy for the compiler to recognize as vectorizable.
 /// * The type is based on an array instead of a tuple since the heterogeneous
 /// nature of tuples can confuse the compiler into thinking it is not vectorizable.
-/// * Memory alignment lines up with SIMD size.
 ///
 /// In the future, a "blacklist" for the alignment option might be useful to
 /// disable it on architectures which definitely do not support SIMD in order to avoid
 /// needless memory inefficiencies.
-#[repr(align(16))]
 #[derive(Clone, Copy, PartialEq)]
 struct U32x4([u32; 4]);
 
@@ -81,21 +87,29 @@ impl U32x4 {
     }
 
     #[inline(always)]
-    fn rotate_elements_left<const N: u32>(self) -> Self {
-        let mut result = [0u32; 4];
-        (0..4).for_each(|i| {
-            result[i] = self.0[(i + N as usize) % 4];
-        });
-        U32x4(result)
+    fn rotate_elements_left<const N: u32>(self) -> Self
+    where
+        (): UpTo3<N>,
+    {
+        match N {
+            1 => U32x4([self.0[1], self.0[2], self.0[3], self.0[0]]),
+            2 => U32x4([self.0[2], self.0[3], self.0[0], self.0[1]]),
+            3 => U32x4([self.0[3], self.0[0], self.0[1], self.0[2]]),
+            _ => self, // Rotate by 0 is a no-op.
+        }
     }
 
     #[inline(always)]
-    fn rotate_elements_right<const N: u32>(self) -> Self {
-        let mut result = [0u32; 4];
-        (0..4).for_each(|i| {
-            result[i] = self.0[(i + 4 - N as usize) % 4];
-        });
-        U32x4(result)
+    fn rotate_elements_right<const N: u32>(self) -> Self
+    where
+        (): UpTo3<N>,
+    {
+        match N {
+            1 => U32x4([self.0[3], self.0[0], self.0[1], self.0[2]]),
+            2 => U32x4([self.0[2], self.0[3], self.0[0], self.0[1]]),
+            3 => U32x4([self.0[1], self.0[2], self.0[3], self.0[0]]),
+            _ => self, // Rotate by 0 is a no-op.
+        }
     }
 
     #[inline(always)]
@@ -163,7 +177,7 @@ impl State {
 
     /// Four quarter rounds performed on the entire state of the cipher in a vectorized SIMD-friendly fashion.
     #[inline(always)]
-    fn quarter_round(a: U32x4, b: U32x4, c: U32x4, d: U32x4) -> (U32x4, U32x4, U32x4, U32x4) {
+    fn quarter_round(a: U32x4, b: U32x4, c: U32x4, d: U32x4) -> [U32x4; 4] {
         let a = a.wrapping_add(b);
         let d = d.bitxor(a).rotate_left(16);
@@ -176,7 +190,7 @@
         let c = c.wrapping_add(d);
         let b = b.bitxor(c).rotate_left(7);
 
-        (a, b, c, d)
+        [a, b, c, d]
     }
 
     /// Perform a round on "columns" and then "diagonals" of the state.
@@ -193,13 +207,13 @@
         let [mut a, mut b, mut c, mut d] = state;
 
         // Column round.
-        (a, b, c, d) = Self::quarter_round(a, b, c, d);
+        [a, b, c, d] = Self::quarter_round(a, b, c, d);
         // Diagonal round (with rotations).
         b = b.rotate_elements_left::<1>();
         c = c.rotate_elements_left::<2>();
         d = d.rotate_elements_left::<3>();
-        (a, b, c, d) = Self::quarter_round(a, b, c, d);
+        [a, b, c, d] = Self::quarter_round(a, b, c, d);
         // Rotate the words back into their normal positions.
         b = b.rotate_elements_right::<1>();
         c = c.rotate_elements_right::<2>();
         d = d.rotate_elements_right::<3>();
@@ -209,6 +223,7 @@
     }
 
     /// Transform the state by performing the ChaCha block function.
+    #[inline(always)]
     fn chacha_block(&mut self) {
         let mut working_state = self.matrix;
 
@@ -223,6 +238,7 @@ impl State {
     }
 
     /// Expose the 512-bit state as a byte stream.
+    #[inline(always)]
     fn keystream(&self) -> [u8; 64] {
         let mut keystream = [0u8; 64];
         for i in 0..4 {
@@ -260,38 +276,63 @@ impl ChaCha20 {
         ChaCha20 { key, nonce, block_count: block, seek_offset_bytes: 0 }
     }
 
-    /// Apply the keystream to a buffer.
+    /// Get the keystream for a specific block.
+    #[inline(always)]
+    fn keystream_at_block(&self, block: u32) -> [u8; 64] {
+        let mut state = State::new(self.key, self.nonce, block);
+        state.chacha_block();
+        state.keystream()
+    }
+
+    /// Apply the keystream to a buffer, updating the cipher block state as necessary.
     pub fn apply_keystream(&mut self, buffer: &mut [u8]) {
-        let num_full_blocks = buffer.len() / CHACHA_BLOCKSIZE;
-        for block in 0..num_full_blocks {
-            let keystream =
-                keystream_at_slice(self.key, self.nonce, self.block_count, self.seek_offset_bytes);
-            for (buffer_byte, keystream_byte) in buffer
-                [block * CHACHA_BLOCKSIZE..(block + 1) * CHACHA_BLOCKSIZE]
-                .iter_mut()
-                .zip(keystream.iter())
+        // If we have an initial offset, handle the first partial block to get back to alignment.
+        let remaining_buffer = if self.seek_offset_bytes != 0 {
+            let bytes_until_aligned = 64 - self.seek_offset_bytes;
+            let bytes_to_process = buffer.len().min(bytes_until_aligned);
+
+            let keystream = self.keystream_at_block(self.block_count);
+            for (buffer_byte, keystream_byte) in
+                buffer[..bytes_to_process].iter_mut().zip(&keystream[self.seek_offset_bytes..])
             {
-                *buffer_byte ^= *keystream_byte
+                *buffer_byte ^= *keystream_byte;
             }
+
+            if bytes_to_process < bytes_until_aligned {
+                self.seek_offset_bytes += bytes_to_process;
+                return;
+            }
+
+            self.block_count += 1;
+            self.seek_offset_bytes = 0;
+            &mut buffer[bytes_to_process..]
+        } else {
+            buffer
+        };
+
+        // Process full blocks.
+        let mut chunks = remaining_buffer.chunks_exact_mut(CHACHA_BLOCKSIZE);
+        for chunk in &mut chunks {
+            let keystream = self.keystream_at_block(self.block_count);
+            for (buffer_byte, keystream_byte) in chunk.iter_mut().zip(keystream.iter()) {
+                *buffer_byte ^= *keystream_byte;
+            }
             self.block_count += 1;
         }
-        if buffer.len() % 64 > 0 {
-            let keystream =
-                keystream_at_slice(self.key, self.nonce, self.block_count, self.seek_offset_bytes);
-            for (buffer_byte, keystream_byte) in
-                buffer[num_full_blocks * CHACHA_BLOCKSIZE..].iter_mut().zip(keystream.iter())
-            {
-                *buffer_byte ^= *keystream_byte
+
+        // Handle any remaining bytes as a partial block.
+        let remainder = chunks.into_remainder();
+        if !remainder.is_empty() {
+            let keystream = self.keystream_at_block(self.block_count);
+            for (buffer_byte, keystream_byte) in remainder.iter_mut().zip(keystream.iter()) {
+                *buffer_byte ^= *keystream_byte;
             }
-            self.block_count += 1;
+            self.seek_offset_bytes = remainder.len();
         }
     }
 
-    /// Get the keystream block at a specified block.
-    pub fn get_keystream(&mut self, block: u32) -> [u8; 64] {
-        self.block(block);
-        keystream_at_slice(self.key, self.nonce, self.block_count, self.seek_offset_bytes)
-    }
+    /// Get the keystream for the specified block.
+    pub fn get_keystream(&self, block: u32) -> [u8; 64] { self.keystream_at_block(block) }
 
     /// Update the index of the keystream to the given byte.
     pub fn seek(&mut self, seek: u32) {
@@ -306,23 +347,6 @@ impl ChaCha20 {
     }
 }
 
-fn keystream_at_slice(key: Key, nonce: Nonce, count: u32, seek: usize) -> [u8; 64] {
-    let mut keystream: [u8; 128] = [0; 128];
-    let (first_half, second_half) = keystream.split_at_mut(64);
-
-    let mut state = State::new(key, nonce, count);
-    state.chacha_block();
-    first_half.copy_from_slice(&state.keystream());
-
-    let mut state = State::new(key, nonce, count + 1);
-    state.chacha_block();
-    second_half.copy_from_slice(&state.keystream());
-
-    let seeked_keystream: [u8; 64] =
-        keystream[seek..seek + 64].try_into().expect("slicing produces 64-byte slice");
-    seeked_keystream
-}
-
 #[cfg(test)]
 #[cfg(feature = "alloc")]
 mod tests {
@@ -446,4 +470,35 @@ mod tests {
         let binding = *b"Ladies and Gentlemen of the class of '99: If I could offer you only one tip for the future, sunscreen would be it.";
         assert_eq!(binding, to);
     }
+
+    #[test]
+    fn multiple_partial_applies() {
+        let key =
+            Key(Vec::from_hex("000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f")
+                .unwrap()
+                .try_into()
+                .unwrap());
+        let nonce = Nonce(Vec::from_hex("000000000000004a00000000").unwrap().try_into().unwrap());
+
+        // Create two instances, one for a full single pass and one for chunked partial calls.
+        let mut chacha_full = ChaCha20::new(key, nonce, 0);
+        let mut chacha_chunked = ChaCha20::new(key, nonce, 0);
+
+        // Test data that crosses block boundaries.
+        let mut full_buffer = [0u8; 100];
+        let mut chunked_buffer = [0u8; 100];
+        for (i, byte) in full_buffer.iter_mut().enumerate() {
+            *byte = i as u8;
+        }
+        chunked_buffer.copy_from_slice(&full_buffer);
+
+        // Apply keystream to full buffer.
+        chacha_full.apply_keystream(&mut full_buffer);
+        // Apply keystream in multiple calls to chunked buffer.
+        chacha_chunked.apply_keystream(&mut chunked_buffer[..30]); // Partial block
+        chacha_chunked.apply_keystream(&mut chunked_buffer[30..82]); // Cross block boundary
+        chacha_chunked.apply_keystream(&mut chunked_buffer[82..]); // End with partial block
+
+        assert_eq!(full_buffer, chunked_buffer);
+    }
 }
diff --git a/chacha20_poly1305/src/lib.rs b/chacha20_poly1305/src/lib.rs
index 043f6877b..045a5b374 100644
--- a/chacha20_poly1305/src/lib.rs
+++ b/chacha20_poly1305/src/lib.rs
@@ -122,7 +122,7 @@ impl ChaCha20Poly1305 {
         tag: [u8; 16],
         aad: Option<&[u8]>,
     ) -> Result<(), Error> {
-        let mut chacha = ChaCha20::new_from_block(self.key, self.nonce, 0);
+        let chacha = ChaCha20::new_from_block(self.key, self.nonce, 0);
         let keystream = chacha.get_keystream(0);
         let mut poly =
             Poly1305::new(keystream[..32].try_into().expect("slicing produces 32-byte slice"));
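Reviewer note: the UpTo3 bound is the least obvious part of this change, so here is a standalone sketch of the technique. This is not code from the PR; `rotate_left` below is a hypothetical free-function stand-in for `U32x4::rotate_elements_left`. Implementing a marker trait only for the accepted const values turns an out-of-range rotation into a compile error instead of a silent modulo wrap:

    // Marker trait implemented only for N in 0..=3; any other constant has no impl.
    trait UpTo3<const N: u32> {}
    impl UpTo3<0> for () {}
    impl UpTo3<1> for () {}
    impl UpTo3<2> for () {}
    impl UpTo3<3> for () {}

    // Hypothetical free-function version of the bounded rotation in the diff.
    fn rotate_left<const N: u32>(xs: [u32; 4]) -> [u32; 4]
    where
        (): UpTo3<N>, // Rejected at compile time unless N is 0, 1, 2, or 3.
    {
        let n = N as usize;
        [xs[n % 4], xs[(n + 1) % 4], xs[(n + 2) % 4], xs[(n + 3) % 4]]
    }

    fn main() {
        assert_eq!(rotate_left::<1>([1, 2, 3, 4]), [2, 3, 4, 1]);
        // rotate_left::<4>([1, 2, 3, 4]); // Error: the trait `UpTo3<4>` is not implemented for `()`.
    }

The same bound is what lets the new `match N` bodies omit a panic arm: every reachable N is one of the four listed cases, so the compiler can emit a branch-free shuffle.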