From 36d45bf360b4fff069203a86e3429831babe5e5d Mon Sep 17 00:00:00 2001 From: Nick Johnson Date: Wed, 19 Feb 2025 10:45:42 -0800 Subject: [PATCH 1/6] chacha20_poly1305: remove mod operator * Swaps out the mod operator for a switch statement for a 5% performance boost. --- chacha20_poly1305/src/chacha20.rs | 40 +++++++++++++++++++++---------- 1 file changed, 28 insertions(+), 12 deletions(-) diff --git a/chacha20_poly1305/src/chacha20.rs b/chacha20_poly1305/src/chacha20.rs index bf0013192..cde412d0a 100644 --- a/chacha20_poly1305/src/chacha20.rs +++ b/chacha20_poly1305/src/chacha20.rs @@ -31,6 +31,14 @@ impl Nonce { pub const fn new(nonce: [u8; 12]) -> Self { Nonce(nonce) } } +// Const validation trait for compile time check with max of 3. +trait UpTo3 {} + +impl UpTo3<0> for () {} +impl UpTo3<1> for () {} +impl UpTo3<2> for () {} +impl UpTo3<3> for () {} + /// A SIMD-friendly structure which holds 25% of the cipher state. /// /// The cipher's quarter round function is the bulk of its work @@ -81,21 +89,29 @@ impl U32x4 { } #[inline(always)] - fn rotate_elements_left(self) -> Self { - let mut result = [0u32; 4]; - (0..4).for_each(|i| { - result[i] = self.0[(i + N as usize) % 4]; - }); - U32x4(result) + fn rotate_elements_left(self) -> Self + where + (): UpTo3, + { + match N { + 1 => U32x4([self.0[1], self.0[2], self.0[3], self.0[0]]), + 2 => U32x4([self.0[2], self.0[3], self.0[0], self.0[1]]), + 3 => U32x4([self.0[3], self.0[0], self.0[1], self.0[2]]), + _ => self, // Rotate by 0 is a no-op. + } } #[inline(always)] - fn rotate_elements_right(self) -> Self { - let mut result = [0u32; 4]; - (0..4).for_each(|i| { - result[i] = self.0[(i + 4 - N as usize) % 4]; - }); - U32x4(result) + fn rotate_elements_right(self) -> Self + where + (): UpTo3, + { + match N { + 1 => U32x4([self.0[3], self.0[0], self.0[1], self.0[2]]), + 2 => U32x4([self.0[2], self.0[3], self.0[0], self.0[1]]), + 3 => U32x4([self.0[1], self.0[2], self.0[3], self.0[0]]), + _ => self, // Rotate by 0 is a no-op. + } } #[inline(always)] From dadd1d72248015450fe6f898cd999ac54ff34aec Mon Sep 17 00:00:00 2001 From: Nick Johnson Date: Wed, 19 Feb 2025 10:48:17 -0800 Subject: [PATCH 2/6] chacha20_poly1305: remove alignment * Benchmarks showed that on recent versions of the rust compiler, alignment settings could hurt and never helped. --- chacha20_poly1305/src/chacha20.rs | 2 -- 1 file changed, 2 deletions(-) diff --git a/chacha20_poly1305/src/chacha20.rs b/chacha20_poly1305/src/chacha20.rs index cde412d0a..48eb310bf 100644 --- a/chacha20_poly1305/src/chacha20.rs +++ b/chacha20_poly1305/src/chacha20.rs @@ -60,12 +60,10 @@ impl UpTo3<3> for () {} /// * For-each loops are easy for the compiler to recognize as vectorizable. /// * The type is a based on an array instead of tuple since the heterogeneous /// nature of tuples can confuse the compiler into thinking it is not vectorizable. -/// * Memory alignment lines up with SIMD size. /// /// In the future, a "blacklist" for the alignment option might be useful to /// disable it on architectures which definitely do not support SIMD in order to avoid /// needless memory inefficientcies. -#[repr(align(16))] #[derive(Clone, Copy, PartialEq)] struct U32x4([u32; 4]); From 33dc1b95fa282ec1fe9706525f1a38ee611b7915 Mon Sep 17 00:00:00 2001 From: Nick Johnson Date: Wed, 19 Feb 2025 10:55:25 -0800 Subject: [PATCH 3/6] chacha20_poly1305: swap tuple for array * While perhaps a small performance gain, < 1%, this conforms to the style used in the rest of the module. --- chacha20_poly1305/src/chacha20.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/chacha20_poly1305/src/chacha20.rs b/chacha20_poly1305/src/chacha20.rs index 48eb310bf..c611ec64e 100644 --- a/chacha20_poly1305/src/chacha20.rs +++ b/chacha20_poly1305/src/chacha20.rs @@ -177,7 +177,7 @@ impl State { /// Four quarter rounds performed on the entire state of the cipher in a vectorized SIMD friendly fashion. #[inline(always)] - fn quarter_round(a: U32x4, b: U32x4, c: U32x4, d: U32x4) -> (U32x4, U32x4, U32x4, U32x4) { + fn quarter_round(a: U32x4, b: U32x4, c: U32x4, d: U32x4) -> [U32x4; 4] { let a = a.wrapping_add(b); let d = d.bitxor(a).rotate_left(16); @@ -190,7 +190,7 @@ impl State { let c = c.wrapping_add(d); let b = b.bitxor(c).rotate_left(7); - (a, b, c, d) + [a, b, c, d] } /// Perform a round on "columns" and then "diagonals" of the state. @@ -207,13 +207,13 @@ impl State { let [mut a, mut b, mut c, mut d] = state; // Column round. - (a, b, c, d) = Self::quarter_round(a, b, c, d); + [a, b, c, d] = Self::quarter_round(a, b, c, d); // Diagonal round (with rotations). b = b.rotate_elements_left::<1>(); c = c.rotate_elements_left::<2>(); d = d.rotate_elements_left::<3>(); - (a, b, c, d) = Self::quarter_round(a, b, c, d); + [a, b, c, d] = Self::quarter_round(a, b, c, d); // Rotate the words back into their normal positions. b = b.rotate_elements_right::<1>(); c = c.rotate_elements_right::<2>(); From 415945cd2b3f4bd40f6c15e359ea241ed7b4f2d4 Mon Sep 17 00:00:00 2001 From: Nick Johnson Date: Wed, 19 Feb 2025 11:38:19 -0800 Subject: [PATCH 4/6] chacha20_poly1305: avoid duplicate block work * The keystream function is creating two state block on every call, but just to handle a corner case. Break up the function into separate methods so that the corner case is handled by itself, avoiding unnecessary work most of the time. * Handle offset state internally. While not strictly necessary due to the cipher's use in BIP324, it makes the library much easier to work with (the bug above would probably have been avoided) if the cipher handles the offset state. --- chacha20_poly1305/src/chacha20.rs | 113 ++++++++++++++++++++---------- 1 file changed, 77 insertions(+), 36 deletions(-) diff --git a/chacha20_poly1305/src/chacha20.rs b/chacha20_poly1305/src/chacha20.rs index c611ec64e..34f041304 100644 --- a/chacha20_poly1305/src/chacha20.rs +++ b/chacha20_poly1305/src/chacha20.rs @@ -274,37 +274,64 @@ impl ChaCha20 { ChaCha20 { key, nonce, block_count: block, seek_offset_bytes: 0 } } - /// Apply the keystream to a buffer. + /// Get the keystream for a specific block. + fn keystream_at_block(&self, block: u32) -> [u8; 64] { + let mut state = State::new(self.key, self.nonce, block); + state.chacha_block(); + state.keystream() + } + + /// Apply the keystream to a buffer updating the cipher block state as necessary. pub fn apply_keystream(&mut self, buffer: &mut [u8]) { - let num_full_blocks = buffer.len() / CHACHA_BLOCKSIZE; - for block in 0..num_full_blocks { - let keystream = - keystream_at_slice(self.key, self.nonce, self.block_count, self.seek_offset_bytes); - for (buffer_byte, keystream_byte) in buffer - [block * CHACHA_BLOCKSIZE..(block + 1) * CHACHA_BLOCKSIZE] - .iter_mut() - .zip(keystream.iter()) + // If we have an initial offset, handle the first partial block to get back to alignment. + let remaining_buffer = if self.seek_offset_bytes != 0 { + let bytes_until_aligned = 64 - self.seek_offset_bytes; + let bytes_to_process = buffer.len().min(bytes_until_aligned); + + let keystream = self.keystream_at_block(self.block_count); + for (buffer_byte, keystream_byte) in + buffer[..bytes_to_process].iter_mut().zip(&keystream[self.seek_offset_bytes..]) { - *buffer_byte ^= *keystream_byte + *buffer_byte ^= *keystream_byte; + } + + if bytes_to_process < bytes_until_aligned { + self.seek_offset_bytes += bytes_to_process; + return; + } + + self.block_count += 1; + self.seek_offset_bytes = 0; + &mut buffer[bytes_to_process..] + } else { + buffer + }; + + // Process full blocks. + let mut chunks = remaining_buffer.chunks_exact_mut(CHACHA_BLOCKSIZE); + for chunk in &mut chunks { + let keystream = self.keystream_at_block(self.block_count); + for (buffer_byte, keystream_byte) in chunk.iter_mut().zip(keystream.iter()) { + *buffer_byte ^= *keystream_byte; } self.block_count += 1; } - if buffer.len() % 64 > 0 { - let keystream = - keystream_at_slice(self.key, self.nonce, self.block_count, self.seek_offset_bytes); - for (buffer_byte, keystream_byte) in - buffer[num_full_blocks * CHACHA_BLOCKSIZE..].iter_mut().zip(keystream.iter()) - { - *buffer_byte ^= *keystream_byte + + // Handle any remaining bytes as partial block. + let remainder = chunks.into_remainder(); + if !remainder.is_empty() { + let keystream = self.keystream_at_block(self.block_count); + for (buffer_byte, keystream_byte) in remainder.iter_mut().zip(keystream.iter()) { + *buffer_byte ^= *keystream_byte; } - self.block_count += 1; + self.seek_offset_bytes = remainder.len(); } } /// Get the keystream block at a specified block. pub fn get_keystream(&mut self, block: u32) -> [u8; 64] { self.block(block); - keystream_at_slice(self.key, self.nonce, self.block_count, self.seek_offset_bytes) + self.keystream_at_block(self.block_count) } /// Update the index of the keystream to the given byte. @@ -320,23 +347,6 @@ impl ChaCha20 { } } -fn keystream_at_slice(key: Key, nonce: Nonce, count: u32, seek: usize) -> [u8; 64] { - let mut keystream: [u8; 128] = [0; 128]; - let (first_half, second_half) = keystream.split_at_mut(64); - - let mut state = State::new(key, nonce, count); - state.chacha_block(); - first_half.copy_from_slice(&state.keystream()); - - let mut state = State::new(key, nonce, count + 1); - state.chacha_block(); - second_half.copy_from_slice(&state.keystream()); - - let seeked_keystream: [u8; 64] = - keystream[seek..seek + 64].try_into().expect("slicing produces 64-byte slice"); - seeked_keystream -} - #[cfg(test)] #[cfg(feature = "alloc")] mod tests { @@ -460,4 +470,35 @@ mod tests { let binding = *b"Ladies and Gentlemen of the class of '99: If I could offer you only one tip for the future, sunscreen would be it."; assert_eq!(binding, to); } + + #[test] + fn multiple_partial_applies() { + let key = + Key(Vec::from_hex("000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f") + .unwrap() + .try_into() + .unwrap()); + let nonce = Nonce(Vec::from_hex("000000000000004a00000000").unwrap().try_into().unwrap()); + + // Create two instances, one for a full single pass and one for chunked partial calls. + let mut chacha_full = ChaCha20::new(key, nonce, 0); + let mut chacha_chunked = ChaCha20::new(key, nonce, 0); + + // Test data that crosses block boundaries. + let mut full_buffer = [0u8; 100]; + let mut chunked_buffer = [0u8; 100]; + for (i, byte) in full_buffer.iter_mut().enumerate() { + *byte = i as u8; + } + chunked_buffer.copy_from_slice(&full_buffer); + + // Apply keystream to full buffer. + chacha_full.apply_keystream(&mut full_buffer); + // Apply keystream in multiple calls to chunked buffer. + chacha_chunked.apply_keystream(&mut chunked_buffer[..30]); // Partial block + chacha_chunked.apply_keystream(&mut chunked_buffer[30..82]); // Cross block boundary + chacha_chunked.apply_keystream(&mut chunked_buffer[82..]); // End with partial block + + assert_eq!(full_buffer, chunked_buffer); + } } From 30920c4d849cf25ae3847b51380b90aa331f903f Mon Sep 17 00:00:00 2001 From: Nick Johnson Date: Wed, 19 Feb 2025 12:12:28 -0800 Subject: [PATCH 5/6] chacha20_poly1305: drop mutable requirement * The get_keystream method exposes the keystream for a block for special case scenarios. Generally the cipher state should only be updated with teh apply_keystream method. --- chacha20_poly1305/src/chacha20.rs | 7 ++----- chacha20_poly1305/src/lib.rs | 2 +- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/chacha20_poly1305/src/chacha20.rs b/chacha20_poly1305/src/chacha20.rs index 34f041304..a93e0f1e6 100644 --- a/chacha20_poly1305/src/chacha20.rs +++ b/chacha20_poly1305/src/chacha20.rs @@ -328,11 +328,8 @@ impl ChaCha20 { } } - /// Get the keystream block at a specified block. - pub fn get_keystream(&mut self, block: u32) -> [u8; 64] { - self.block(block); - self.keystream_at_block(self.block_count) - } + /// Get the keystream for specified block. + pub fn get_keystream(&self, block: u32) -> [u8; 64] { self.keystream_at_block(block) } /// Update the index of the keystream to the given byte. pub fn seek(&mut self, seek: u32) { diff --git a/chacha20_poly1305/src/lib.rs b/chacha20_poly1305/src/lib.rs index 42b2df3af..1627b511e 100644 --- a/chacha20_poly1305/src/lib.rs +++ b/chacha20_poly1305/src/lib.rs @@ -118,7 +118,7 @@ impl ChaCha20Poly1305 { tag: [u8; 16], aad: Option<&[u8]>, ) -> Result<(), Error> { - let mut chacha = ChaCha20::new_from_block(self.key, self.nonce, 0); + let chacha = ChaCha20::new_from_block(self.key, self.nonce, 0); let keystream = chacha.get_keystream(0); let mut poly = Poly1305::new(keystream[..32].try_into().expect("slicing produces 32-byte slice")); From 1ca55ac77db698f3816d8b7ed4051ddb5a579a29 Mon Sep 17 00:00:00 2001 From: Nick Johnson Date: Wed, 19 Feb 2025 13:04:20 -0800 Subject: [PATCH 6/6] chacha20_poly1305: inline simd functions * Inline all the block operations to give the compiler the best chance to optimize the SIMD instruction usage. This squeezed out another percent or two on the benchmarks when comparing target-cpu=native builds to standard. --- chacha20_poly1305/src/chacha20.rs | 3 +++ 1 file changed, 3 insertions(+) diff --git a/chacha20_poly1305/src/chacha20.rs b/chacha20_poly1305/src/chacha20.rs index a93e0f1e6..b9080b40e 100644 --- a/chacha20_poly1305/src/chacha20.rs +++ b/chacha20_poly1305/src/chacha20.rs @@ -223,6 +223,7 @@ impl State { } /// Transform the state by performing the ChaCha block function. + #[inline(always)] fn chacha_block(&mut self) { let mut working_state = self.matrix; @@ -237,6 +238,7 @@ impl State { } /// Expose the 512-bit state as a byte stream. + #[inline(always)] fn keystream(&self) -> [u8; 64] { let mut keystream = [0u8; 64]; for i in 0..4 { @@ -275,6 +277,7 @@ impl ChaCha20 { } /// Get the keystream for a specific block. + #[inline(always)] fn keystream_at_block(&self, block: u32) -> [u8; 64] { let mut state = State::new(self.key, self.nonce, block); state.chacha_block();