diff --git a/chacha20_poly1305/src/chacha20.rs b/chacha20_poly1305/src/chacha20.rs index c611ec64e..34f041304 100644 --- a/chacha20_poly1305/src/chacha20.rs +++ b/chacha20_poly1305/src/chacha20.rs @@ -274,37 +274,64 @@ impl ChaCha20 { ChaCha20 { key, nonce, block_count: block, seek_offset_bytes: 0 } } - /// Apply the keystream to a buffer. + /// Get the keystream for a specific block. + fn keystream_at_block(&self, block: u32) -> [u8; 64] { + let mut state = State::new(self.key, self.nonce, block); + state.chacha_block(); + state.keystream() + } + + /// Apply the keystream to a buffer updating the cipher block state as necessary. pub fn apply_keystream(&mut self, buffer: &mut [u8]) { - let num_full_blocks = buffer.len() / CHACHA_BLOCKSIZE; - for block in 0..num_full_blocks { - let keystream = - keystream_at_slice(self.key, self.nonce, self.block_count, self.seek_offset_bytes); - for (buffer_byte, keystream_byte) in buffer - [block * CHACHA_BLOCKSIZE..(block + 1) * CHACHA_BLOCKSIZE] - .iter_mut() - .zip(keystream.iter()) + // If we have an initial offset, handle the first partial block to get back to alignment. + let remaining_buffer = if self.seek_offset_bytes != 0 { + let bytes_until_aligned = 64 - self.seek_offset_bytes; + let bytes_to_process = buffer.len().min(bytes_until_aligned); + + let keystream = self.keystream_at_block(self.block_count); + for (buffer_byte, keystream_byte) in + buffer[..bytes_to_process].iter_mut().zip(&keystream[self.seek_offset_bytes..]) { - *buffer_byte ^= *keystream_byte + *buffer_byte ^= *keystream_byte; + } + + if bytes_to_process < bytes_until_aligned { + self.seek_offset_bytes += bytes_to_process; + return; + } + + self.block_count += 1; + self.seek_offset_bytes = 0; + &mut buffer[bytes_to_process..] + } else { + buffer + }; + + // Process full blocks. + let mut chunks = remaining_buffer.chunks_exact_mut(CHACHA_BLOCKSIZE); + for chunk in &mut chunks { + let keystream = self.keystream_at_block(self.block_count); + for (buffer_byte, keystream_byte) in chunk.iter_mut().zip(keystream.iter()) { + *buffer_byte ^= *keystream_byte; } self.block_count += 1; } - if buffer.len() % 64 > 0 { - let keystream = - keystream_at_slice(self.key, self.nonce, self.block_count, self.seek_offset_bytes); - for (buffer_byte, keystream_byte) in - buffer[num_full_blocks * CHACHA_BLOCKSIZE..].iter_mut().zip(keystream.iter()) - { - *buffer_byte ^= *keystream_byte + + // Handle any remaining bytes as partial block. + let remainder = chunks.into_remainder(); + if !remainder.is_empty() { + let keystream = self.keystream_at_block(self.block_count); + for (buffer_byte, keystream_byte) in remainder.iter_mut().zip(keystream.iter()) { + *buffer_byte ^= *keystream_byte; } - self.block_count += 1; + self.seek_offset_bytes = remainder.len(); } } /// Get the keystream block at a specified block. pub fn get_keystream(&mut self, block: u32) -> [u8; 64] { self.block(block); - keystream_at_slice(self.key, self.nonce, self.block_count, self.seek_offset_bytes) + self.keystream_at_block(self.block_count) } /// Update the index of the keystream to the given byte. @@ -320,23 +347,6 @@ impl ChaCha20 { } } -fn keystream_at_slice(key: Key, nonce: Nonce, count: u32, seek: usize) -> [u8; 64] { - let mut keystream: [u8; 128] = [0; 128]; - let (first_half, second_half) = keystream.split_at_mut(64); - - let mut state = State::new(key, nonce, count); - state.chacha_block(); - first_half.copy_from_slice(&state.keystream()); - - let mut state = State::new(key, nonce, count + 1); - state.chacha_block(); - second_half.copy_from_slice(&state.keystream()); - - let seeked_keystream: [u8; 64] = - keystream[seek..seek + 64].try_into().expect("slicing produces 64-byte slice"); - seeked_keystream -} - #[cfg(test)] #[cfg(feature = "alloc")] mod tests { @@ -460,4 +470,35 @@ mod tests { let binding = *b"Ladies and Gentlemen of the class of '99: If I could offer you only one tip for the future, sunscreen would be it."; assert_eq!(binding, to); } + + #[test] + fn multiple_partial_applies() { + let key = + Key(Vec::from_hex("000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f") + .unwrap() + .try_into() + .unwrap()); + let nonce = Nonce(Vec::from_hex("000000000000004a00000000").unwrap().try_into().unwrap()); + + // Create two instances, one for a full single pass and one for chunked partial calls. + let mut chacha_full = ChaCha20::new(key, nonce, 0); + let mut chacha_chunked = ChaCha20::new(key, nonce, 0); + + // Test data that crosses block boundaries. + let mut full_buffer = [0u8; 100]; + let mut chunked_buffer = [0u8; 100]; + for (i, byte) in full_buffer.iter_mut().enumerate() { + *byte = i as u8; + } + chunked_buffer.copy_from_slice(&full_buffer); + + // Apply keystream to full buffer. + chacha_full.apply_keystream(&mut full_buffer); + // Apply keystream in multiple calls to chunked buffer. + chacha_chunked.apply_keystream(&mut chunked_buffer[..30]); // Partial block + chacha_chunked.apply_keystream(&mut chunked_buffer[30..82]); // Cross block boundary + chacha_chunked.apply_keystream(&mut chunked_buffer[82..]); // End with partial block + + assert_eq!(full_buffer, chunked_buffer); + } }