diff --git a/src/ecdsa/serialized_signature.rs b/src/ecdsa/serialized_signature.rs index 4fd0449..e89faf3 100644 --- a/src/ecdsa/serialized_signature.rs +++ b/src/ecdsa/serialized_signature.rs @@ -5,6 +5,8 @@ //! unable to run on platforms without allocator. We implement a special type to encapsulate //! serialized signatures and since it's a bit more complicated it has its own module. +pub use into_iter::IntoIter; + use core::{fmt, ops}; use crate::Error; use super::Signature; @@ -62,6 +64,16 @@ impl ops::Deref for SerializedSignature { impl Eq for SerializedSignature {} +impl IntoIterator for SerializedSignature { + type IntoIter = IntoIter; + type Item = u8; + + #[inline] + fn into_iter(self) -> Self::IntoIter { + IntoIter::new(self) + } +} + impl<'a> IntoIterator for &'a SerializedSignature { type IntoIter = core::slice::Iter<'a, u8>; type Item = &'a u8; @@ -107,3 +119,131 @@ impl SerializedSignature { /// Check if the space is zero. pub fn is_empty(&self) -> bool { self.len() == 0 } } + +/// Separate mod to prevent outside code accidentally breaking invariants. +mod into_iter { + use super::*; + + /// Owned iterator over the bytes of [`SerializedSignature`] + /// + /// Created by [`IntoIterator::into_iter`] method. + // allowed because of https://github.com/rust-lang/rust/issues/98348 + #[allow(missing_copy_implementations)] + #[derive(Debug, Clone)] + pub struct IntoIter { + signature: SerializedSignature, + // invariant: pos <= signature.len() + pos: usize, + } + + impl IntoIter { + #[inline] + pub(crate) fn new(signature: SerializedSignature) -> Self { + IntoIter { + signature, + // for all unsigned n: 0 <= n + pos: 0, + } + } + + /// Returns the remaining bytes as a slice. + /// + /// This method is analogous to [`core::slice::Iter::as_slice`]. + #[inline] + pub fn as_slice(&self) -> &[u8] { + &self.signature[self.pos..] + } + } + + impl Iterator for IntoIter { + type Item = u8; + + #[inline] + fn next(&mut self) -> Option { + let byte = *self.signature.get(self.pos)?; + // can't overflow or break invariant because if pos is too large we return early + self.pos += 1; + Some(byte) + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + // can't underlflow thanks to the invariant + let len = self.signature.len() - self.pos; + (len, Some(len)) + } + + // override for speed + #[inline] + fn nth(&mut self, n: usize) -> Option { + if n >= self.len() { + // upholds invariant becasue the values will be equal + self.pos = self.signature.len(); + None + } else { + // if n < signtature.len() - self.pos then n + self.pos < signature.len() which neither + // overflows nor breaks the invariant + self.pos += n; + self.next() + } + } + } + + impl ExactSizeIterator for IntoIter {} + + impl core::iter::FusedIterator for IntoIter {} + + impl DoubleEndedIterator for IntoIter { + #[inline] + fn next_back(&mut self) -> Option { + if self.pos == self.signature.len() { + return None; + } + + // if len is 0 then pos is also 0 thanks to the invariant so we would return before we + // reach this + let new_len = self.signature.len() - 1; + let byte = self.signature[new_len]; + self.signature.set_len(new_len); + Some(byte) + } + } +} + +#[cfg(test)] +mod tests { + use super::SerializedSignature; + + #[test] + fn iterator_ops_are_homomorphic() { + let mut fake_signature_data = [0; 72]; + // fill it with numbers 0 - 71 + for (i, byte) in fake_signature_data.iter_mut().enumerate() { + // up to 72 + *byte = i as u8; + } + + let fake_signature = SerializedSignature { data: fake_signature_data, len: 72 }; + + let mut iter1 = fake_signature.into_iter(); + let mut iter2 = fake_signature.iter(); + + // while let so we can compare size_hint and as_slice + while let (Some(a), Some(b)) = (iter1.next(), iter2.next()) { + assert_eq!(a, *b); + assert_eq!(iter1.size_hint(), iter2.size_hint()); + assert_eq!(iter1.as_slice(), iter2.as_slice()); + } + + let mut iter1 = fake_signature.into_iter(); + let mut iter2 = fake_signature.iter(); + + // manual next_back instead of rev() so that we can check as_slice() + // if next_back is implemented correctly then rev() is also correct - provided by `core` + while let (Some(a), Some(b)) = (iter1.next_back(), iter2.next_back()) { + assert_eq!(a, *b); + assert_eq!(iter1.size_hint(), iter2.size_hint()); + assert_eq!(iter1.as_slice(), iter2.as_slice()); + } + } +}