diff --git a/bitcoin/src/blockdata/witness.rs b/bitcoin/src/blockdata/witness.rs index 40311121..fd47f9dd 100644 --- a/bitcoin/src/blockdata/witness.rs +++ b/bitcoin/src/blockdata/witness.rs @@ -5,6 +5,9 @@ //! This module contains the [`Witness`] struct and related methods to operate on it //! +use core::convert::TryInto; +use core::ops::Index; + use secp256k1::ecdsa; use crate::consensus::encode::{Error, MAX_VEC_SIZE}; @@ -14,6 +17,8 @@ use crate::io::{self, Read, Write}; use crate::prelude::*; use crate::VarInt; +const U32_SIZE: usize = core::mem::size_of::(); + /// The Witness is the data used to unlock bitcoins since the [segwit upgrade](https://github.com/bitcoin/bips/blob/master/bip-0143.mediawiki) /// /// Can be logically seen as an array of byte-arrays `Vec>` and indeed you can convert from @@ -33,38 +38,42 @@ pub struct Witness { /// like [`Witness::push`] doesn't have case requiring to shift the entire array witness_elements: usize, - /// If `witness_elements > 0` it's a valid index pointing to the last witness element in `content` - /// (Including the varint specifying the length of the element) - last: usize, - - /// If `witness_elements > 1` it's a valid index pointing to the second-to-last witness element in `content` - /// (Including the varint specifying the length of the element) - second_to_last: usize, + /// This is the valid index pointing to the beginning of the index area. This area is 4 * stack_size bytes + /// at the end of the content vector which stores the indices of each item. + indices_start: usize, } /// Support structure to allow efficient and convenient iteration over the Witness elements pub struct Iter<'a> { - inner: core::slice::Iter<'a, u8>, - remaining: usize, + inner: &'a [u8], + indices_start: usize, + current_index: usize, } impl Decodable for Witness { fn consensus_decode(r: &mut R) -> Result { let witness_elements = VarInt::consensus_decode(r)?.0 as usize; + // Minimum size of witness element is 1 byte, so if the count is + // greater than MAX_VEC_SIZE we must return an error. + if witness_elements > MAX_VEC_SIZE { + return Err(self::Error::OversizedVectorAllocation { + requested: witness_elements, + max: MAX_VEC_SIZE, + }); + } if witness_elements == 0 { Ok(Witness::default()) } else { - let mut cursor = 0usize; - let mut last = 0usize; - let mut second_to_last = 0usize; + // Leave space at the head for element positions. + // We will rotate them to the end of the Vec later. + let witness_index_space = witness_elements * U32_SIZE; + let mut cursor = witness_index_space; // this number should be determined as high enough to cover most witness, and low enough // to avoid wasting space without reallocating - let mut content = vec![0u8; 128]; + let mut content = vec![0u8; cursor + 128]; - for _ in 0..witness_elements { - second_to_last = last; - last = cursor; + for i in 0..witness_elements { let element_size_varint = VarInt::consensus_decode(r)?; let element_size_varint_len = element_size_varint.len(); let element_size = element_size_varint.0 as usize; @@ -80,13 +89,17 @@ impl Decodable for Witness { max: MAX_VEC_SIZE, })?; - if required_len > MAX_VEC_SIZE { + if required_len > MAX_VEC_SIZE + witness_index_space { return Err(self::Error::OversizedVectorAllocation { requested: required_len, max: MAX_VEC_SIZE, }); } + // We will do content.rotate_left(witness_index_space) later. + // Encode the position's value AFTER we rotate left. + encode_cursor(&mut content, 0, i, cursor - witness_index_space); + resize_if_needed(&mut content, required_len); element_size_varint .consensus_encode(&mut &mut content[cursor..cursor + element_size_varint_len])?; @@ -95,16 +108,37 @@ impl Decodable for Witness { cursor += element_size; } content.truncate(cursor); + // Index space is now at the end of the Vec + content.rotate_left(witness_index_space); Ok(Witness { content, witness_elements, - last, - second_to_last, + indices_start: cursor - witness_index_space, }) } } } + +/// Safety Requirements: value must always fit within u32 +#[inline] +fn encode_cursor(bytes: &mut [u8], start_of_indices: usize, index: usize, value: usize) { + let start = start_of_indices + index * U32_SIZE; + let end = start + U32_SIZE; + bytes[start..end].copy_from_slice(&(value as u32).to_ne_bytes()[..]); +} + +#[inline] +fn decode_cursor(bytes: &[u8], start_of_indices: usize, index: usize) -> Option { + let start = start_of_indices + index * U32_SIZE; + let end = start + U32_SIZE; + if end > bytes.len() { + None + } else { + Some(u32::from_ne_bytes(bytes[start..end].try_into().expect("is u32 size")) as usize) + } +} + fn resize_if_needed(vec: &mut Vec, required_len: usize) { if required_len >= vec.len() { let mut new_len = vec.len().max(1); @@ -119,8 +153,11 @@ impl Encodable for Witness { fn consensus_encode(&self, w: &mut W) -> Result { let len = VarInt(self.witness_elements as u64); len.consensus_encode(w)?; - w.emit_slice(&self.content[..])?; - Ok(self.content.len() + len.len()) + let content_with_indices_len = self.content.len(); + let indices_size = self.witness_elements * U32_SIZE; + let content_len = content_with_indices_len - indices_size; + w.emit_slice(&self.content[..content_len])?; + Ok(content_len + len.len()) } } @@ -133,18 +170,17 @@ impl Witness { /// Creates [`Witness`] object from an array of byte-arrays pub fn from_vec(vec: Vec>) -> Self { let witness_elements = vec.len(); + let index_size = witness_elements * U32_SIZE; let content_size: usize = vec .iter() .map(|el| el.len() + VarInt(el.len() as u64).len()) .sum(); - let mut content = vec![0u8; content_size]; + let mut content = vec![0u8; content_size + index_size]; let mut cursor = 0usize; - let mut last = 0; - let mut second_to_last = 0; - for el in vec { - second_to_last = last; - last = cursor; + for (i, el) in vec.into_iter().enumerate() { + encode_cursor(&mut content, content_size, i, cursor); + let el_len_varint = VarInt(el.len() as u64); el_len_varint .consensus_encode(&mut &mut content[cursor..cursor + el_len_varint.len()]) @@ -157,8 +193,7 @@ impl Witness { Witness { witness_elements, content, - last, - second_to_last, + indices_start: content_size, } } @@ -174,7 +209,11 @@ impl Witness { /// Returns a struct implementing [`Iterator`] pub fn iter(&self) -> Iter { - Iter { inner: self.content.iter(), remaining: self.witness_elements } + Iter { + inner: self.content.as_slice(), + indices_start: self.indices_start, + current_index: 0, + } } /// Returns the number of elements this witness holds @@ -194,25 +233,29 @@ impl Witness { pub fn clear(&mut self) { self.content.clear(); self.witness_elements = 0; - self.last = 0; - self.second_to_last = 0; + self.indices_start = 0; } /// Push a new element on the witness, requires an allocation pub fn push>(&mut self, new_element: T) { let new_element = new_element.as_ref(); self.witness_elements += 1; - self.second_to_last = self.last; - self.last = self.content.len(); + let previous_content_end = self.indices_start; let element_len_varint = VarInt(new_element.len() as u64); let current_content_len = self.content.len(); + let new_item_total_len = element_len_varint.len() + new_element.len(); self.content - .resize(current_content_len + element_len_varint.len() + new_element.len(), 0); - let end_varint = current_content_len + element_len_varint.len(); + .resize(current_content_len + new_item_total_len + U32_SIZE, 0); + + self.content[previous_content_end..].rotate_right(new_item_total_len); + self.indices_start += new_item_total_len; + encode_cursor(&mut self.content, self.indices_start, self.witness_elements - 1, previous_content_end); + + let end_varint = previous_content_end + element_len_varint.len(); element_len_varint - .consensus_encode(&mut &mut self.content[current_content_len..end_varint]) + .consensus_encode(&mut &mut self.content[previous_content_end..end_varint]) .expect("writers on vec don't error, space granted through previous resize"); - self.content[end_varint..].copy_from_slice(new_element); + self.content[end_varint..end_varint + new_element.len()].copy_from_slice(new_element); } /// Pushes a DER-encoded ECDSA signature with a signature hash type as a new element on the @@ -237,7 +280,7 @@ impl Witness { if self.witness_elements == 0 { None } else { - self.element_at(self.last) + self.nth(self.witness_elements - 1) } } @@ -246,29 +289,42 @@ impl Witness { if self.witness_elements <= 1 { None } else { - self.element_at(self.second_to_last) + self.nth(self.witness_elements - 2) } } + + /// Return the nth element in the witness, if any + pub fn nth(&self, index: usize) -> Option<&[u8]> { + let pos = decode_cursor(&self.content, self.indices_start, index)?; + self.element_at(pos) + } +} + +impl Index for Witness { + type Output = [u8]; + + fn index(&self, index: usize) -> &Self::Output { + self.nth(index).expect("Out of Bounds") + } } impl<'a> Iterator for Iter<'a> { type Item = &'a [u8]; fn next(&mut self) -> Option { - let varint = VarInt::consensus_decode(&mut self.inner.as_slice()).ok()?; - self.inner.nth(varint.len() - 1)?; // VarInt::len returns at least 1 - let len = varint.0 as usize; - let slice = &self.inner.as_slice()[..len]; - if len > 0 { - // we don't need to advance if the element is empty - self.inner.nth(len - 1)?; - } - self.remaining -= 1; + let index = decode_cursor(self.inner, self.indices_start, self.current_index)?; + let varint = VarInt::consensus_decode(&mut &self.inner[index..]).ok()?; + let start = index + varint.len(); + let end = start + varint.0 as usize; + let slice = &self.inner[start..end]; + self.current_index += 1; Some(slice) } fn size_hint(&self) -> (usize, Option) { - (self.remaining, Some(self.remaining)) + let total_count = (self.inner.len() - self.indices_start) / U32_SIZE; + let remaining = total_count - self.current_index; + (remaining, Some(remaining)) } } @@ -373,31 +429,67 @@ mod test { use crate::Transaction; use crate::secp256k1::ecdsa; + fn append_u32_vec(mut v: Vec, n: &[u32]) -> Vec { + for &num in n { + v.extend_from_slice(&num.to_ne_bytes()[..]); + } + v + } + #[test] fn test_push() { let mut witness = Witness::default(); assert_eq!(witness.last(), None); assert_eq!(witness.second_to_last(), None); + assert_eq!(witness.nth(0), None); + assert_eq!(witness.nth(1), None); + assert_eq!(witness.nth(2), None); + assert_eq!(witness.nth(3), None); witness.push(&vec![0u8]); let expected = Witness { witness_elements: 1, - content: vec![1u8, 0], - last: 0, - second_to_last: 0, + content: append_u32_vec(vec![1u8, 0], &[0]), + indices_start: 2, }; assert_eq!(witness, expected); assert_eq!(witness.last(), Some(&[0u8][..])); assert_eq!(witness.second_to_last(), None); + assert_eq!(witness.nth(0), Some(&[0u8][..])); + assert_eq!(witness.nth(1), None); + assert_eq!(witness.nth(2), None); + assert_eq!(witness.nth(3), None); + assert_eq!(&witness[0], &[0u8][..]); witness.push(&vec![2u8, 3u8]); let expected = Witness { witness_elements: 2, - content: vec![1u8, 0, 2, 2, 3], - last: 2, - second_to_last: 0, + content: append_u32_vec(vec![1u8, 0, 2, 2, 3], &[0, 2]), + indices_start: 5, }; assert_eq!(witness, expected); assert_eq!(witness.last(), Some(&[2u8, 3u8][..])); assert_eq!(witness.second_to_last(), Some(&[0u8][..])); + assert_eq!(witness.nth(0), Some(&[0u8][..])); + assert_eq!(witness.nth(1), Some(&[2u8, 3u8][..])); + assert_eq!(witness.nth(2), None); + assert_eq!(witness.nth(3), None); + assert_eq!(&witness[0], &[0u8][..]); + assert_eq!(&witness[1], &[2u8, 3u8][..]); + witness.push(&vec![4u8, 5u8]); + let expected = Witness { + witness_elements: 3, + content: append_u32_vec(vec![1u8, 0, 2, 2, 3, 2, 4, 5], &[0, 2, 5]), + indices_start: 8, + }; + assert_eq!(witness, expected); + assert_eq!(witness.last(), Some(&[4u8, 5u8][..])); + assert_eq!(witness.second_to_last(), Some(&[2u8, 3u8][..])); + assert_eq!(witness.nth(0), Some(&[0u8][..])); + assert_eq!(witness.nth(1), Some(&[2u8, 3u8][..])); + assert_eq!(witness.nth(2), Some(&[4u8, 5u8][..])); + assert_eq!(witness.nth(3), None); + assert_eq!(&witness[0], &[0u8][..]); + assert_eq!(&witness[1], &[2u8, 3u8][..]); + assert_eq!(&witness[2], &[4u8, 5u8][..]); } @@ -438,16 +530,20 @@ mod test { let witness_vec = vec![w0.clone(), w1.clone()]; let witness_serialized: Vec = serialize(&witness_vec); let witness = Witness { - content: witness_serialized[1..].to_vec(), + content: append_u32_vec(witness_serialized[1..].to_vec(), &[0, 34]), witness_elements: 2, - last: 34, - second_to_last: 0, + indices_start: 38, }; for (i, el) in witness.iter().enumerate() { assert_eq!(witness_vec[i], el); } assert_eq!(witness.last(), Some(&w1[..])); assert_eq!(witness.second_to_last(), Some(&w0[..])); + assert_eq!(witness.nth(0), Some(&w0[..])); + assert_eq!(witness.nth(1), Some(&w1[..])); + assert_eq!(witness.nth(2), None); + assert_eq!(&witness[0], &w0[..]); + assert_eq!(&witness[1], &w1[..]); let w_into = Witness::from_vec(witness_vec); assert_eq!(w_into, witness); @@ -467,6 +563,11 @@ mod test { } assert_eq!(expected_wit[1], tx.input[0].witness.last().unwrap().to_hex()); assert_eq!(expected_wit[0], tx.input[0].witness.second_to_last().unwrap().to_hex()); + assert_eq!(expected_wit[0], tx.input[0].witness.nth(0).unwrap().to_hex()); + assert_eq!(expected_wit[1], tx.input[0].witness.nth(1).unwrap().to_hex()); + assert_eq!(None, tx.input[0].witness.nth(2)); + assert_eq!(expected_wit[0], tx.input[0].witness[0].to_hex()); + assert_eq!(expected_wit[1], tx.input[0].witness[1].to_hex()); let tx_bytes_back = serialize(&tx); assert_eq!(tx_bytes_back, tx_bytes);