From 8560baaca2b9463827353bbcea8b8778fba64b28 Mon Sep 17 00:00:00 2001 From: Riccardo Casatta Date: Tue, 25 Jul 2023 14:23:18 +0200 Subject: [PATCH 1/3] Make fields of RawNetworkMessage non public provide accessor method and new for downstream libs. This is done in order to more easily change the struct without impacting downstream and also in order to add another field while preserving struct invariant in future commit. --- bitcoin/src/network/message.rs | 22 ++++++++++++++++++---- 1 file changed, 18 insertions(+), 4 deletions(-) diff --git a/bitcoin/src/network/message.rs b/bitcoin/src/network/message.rs index c083945d..3a1717bf 100644 --- a/bitcoin/src/network/message.rs +++ b/bitcoin/src/network/message.rs @@ -149,10 +149,8 @@ crate::error::impl_std_error!(CommandStringError); /// A Network message #[derive(Clone, Debug, PartialEq, Eq)] pub struct RawNetworkMessage { - /// Magic bytes to identify the network these messages are meant for - pub magic: Magic, - /// The actual message data - pub payload: NetworkMessage, + magic: Magic, + payload: NetworkMessage, } /// A Network message payload. Proper documentation is available on at @@ -302,6 +300,22 @@ impl NetworkMessage { } impl RawNetworkMessage { + + /// Creates a [RawNetworkMessage] + pub fn new(magic: Magic, payload: NetworkMessage) -> Self { + Self { magic, payload } + } + + /// The actual message data + pub fn payload(&self) -> &NetworkMessage { + &self.payload + } + + /// Magic bytes to identify the network these messages are meant for + pub fn magic(&self) -> &Magic { + &self.magic + } + /// Return the message command as a static string reference. /// /// This returns `"unknown"` for [NetworkMessage::Unknown], From bc66ed82b2bc7363836d9bb255dd45f23b5380d1 Mon Sep 17 00:00:00 2001 From: Riccardo Casatta Date: Tue, 25 Jul 2023 14:43:47 +0200 Subject: [PATCH 2/3] Impl Encodable for NetworkMessage Using it in RawNetworkMessage encoding --- bitcoin/src/network/message.rs | 83 ++++++++++++++++++---------------- 1 file changed, 44 insertions(+), 39 deletions(-) diff --git a/bitcoin/src/network/message.rs b/bitcoin/src/network/message.rs index 3a1717bf..4c426e4d 100644 --- a/bitcoin/src/network/message.rs +++ b/bitcoin/src/network/message.rs @@ -342,51 +342,56 @@ impl<'a> Encodable for HeaderSerializationWrapper<'a> { } } -impl Encodable for RawNetworkMessage { - fn consensus_encode(&self, w: &mut W) -> Result { - let mut len = 0; - len += self.magic.consensus_encode(w)?; - len += self.command().consensus_encode(w)?; - len += CheckedData(match self.payload { - NetworkMessage::Version(ref dat) => serialize(dat), - NetworkMessage::Addr(ref dat) => serialize(dat), - NetworkMessage::Inv(ref dat) => serialize(dat), - NetworkMessage::GetData(ref dat) => serialize(dat), - NetworkMessage::NotFound(ref dat) => serialize(dat), - NetworkMessage::GetBlocks(ref dat) => serialize(dat), - NetworkMessage::GetHeaders(ref dat) => serialize(dat), - NetworkMessage::Tx(ref dat) => serialize(dat), - NetworkMessage::Block(ref dat) => serialize(dat), - NetworkMessage::Headers(ref dat) => serialize(&HeaderSerializationWrapper(dat)), - NetworkMessage::Ping(ref dat) => serialize(dat), - NetworkMessage::Pong(ref dat) => serialize(dat), - NetworkMessage::MerkleBlock(ref dat) => serialize(dat), - NetworkMessage::FilterLoad(ref dat) => serialize(dat), - NetworkMessage::FilterAdd(ref dat) => serialize(dat), - NetworkMessage::GetCFilters(ref dat) => serialize(dat), - NetworkMessage::CFilter(ref dat) => serialize(dat), - NetworkMessage::GetCFHeaders(ref dat) => serialize(dat), - NetworkMessage::CFHeaders(ref dat) => serialize(dat), - NetworkMessage::GetCFCheckpt(ref dat) => serialize(dat), - NetworkMessage::CFCheckpt(ref dat) => serialize(dat), - NetworkMessage::SendCmpct(ref dat) => serialize(dat), - NetworkMessage::CmpctBlock(ref dat) => serialize(dat), - NetworkMessage::GetBlockTxn(ref dat) => serialize(dat), - NetworkMessage::BlockTxn(ref dat) => serialize(dat), - NetworkMessage::Alert(ref dat) => serialize(dat), - NetworkMessage::Reject(ref dat) => serialize(dat), - NetworkMessage::FeeFilter(ref data) => serialize(data), - NetworkMessage::AddrV2(ref dat) => serialize(dat), +impl Encodable for NetworkMessage { + fn consensus_encode(&self, writer: &mut W) -> Result { + match self { + NetworkMessage::Version(ref dat) => dat.consensus_encode(writer), + NetworkMessage::Addr(ref dat) => dat.consensus_encode(writer), + NetworkMessage::Inv(ref dat) => dat.consensus_encode(writer), + NetworkMessage::GetData(ref dat) => dat.consensus_encode(writer), + NetworkMessage::NotFound(ref dat) => dat.consensus_encode(writer), + NetworkMessage::GetBlocks(ref dat) => dat.consensus_encode(writer), + NetworkMessage::GetHeaders(ref dat) => dat.consensus_encode(writer), + NetworkMessage::Tx(ref dat) => dat.consensus_encode(writer), + NetworkMessage::Block(ref dat) => dat.consensus_encode(writer), + NetworkMessage::Headers(ref dat) => HeaderSerializationWrapper(dat).consensus_encode(writer), + NetworkMessage::Ping(ref dat) => dat.consensus_encode(writer), + NetworkMessage::Pong(ref dat) => dat.consensus_encode(writer), + NetworkMessage::MerkleBlock(ref dat) => dat.consensus_encode(writer), + NetworkMessage::FilterLoad(ref dat) => dat.consensus_encode(writer), + NetworkMessage::FilterAdd(ref dat) => dat.consensus_encode(writer), + NetworkMessage::GetCFilters(ref dat) => dat.consensus_encode(writer), + NetworkMessage::CFilter(ref dat) => dat.consensus_encode(writer), + NetworkMessage::GetCFHeaders(ref dat) => dat.consensus_encode(writer), + NetworkMessage::CFHeaders(ref dat) => dat.consensus_encode(writer), + NetworkMessage::GetCFCheckpt(ref dat) => dat.consensus_encode(writer), + NetworkMessage::CFCheckpt(ref dat) => dat.consensus_encode(writer), + NetworkMessage::SendCmpct(ref dat) => dat.consensus_encode(writer), + NetworkMessage::CmpctBlock(ref dat) => dat.consensus_encode(writer), + NetworkMessage::GetBlockTxn(ref dat) => dat.consensus_encode(writer), + NetworkMessage::BlockTxn(ref dat) => dat.consensus_encode(writer), + NetworkMessage::Alert(ref dat) => dat.consensus_encode(writer), + NetworkMessage::Reject(ref dat) => dat.consensus_encode(writer), + NetworkMessage::FeeFilter(ref dat) => dat.consensus_encode(writer), + NetworkMessage::AddrV2(ref dat) => dat.consensus_encode(writer), NetworkMessage::Verack | NetworkMessage::SendHeaders | NetworkMessage::MemPool | NetworkMessage::GetAddr | NetworkMessage::WtxidRelay | NetworkMessage::FilterClear - | NetworkMessage::SendAddrV2 => vec![], - NetworkMessage::Unknown { payload: ref data, .. } => serialize(data), - }) - .consensus_encode(w)?; + | NetworkMessage::SendAddrV2 => Ok(0), + NetworkMessage::Unknown { payload: ref data, .. } => data.consensus_encode(writer), + } + } +} + +impl Encodable for RawNetworkMessage { + fn consensus_encode(&self, w: &mut W) -> Result { + let mut len = 0; + len += self.magic.consensus_encode(w)?; + len += self.command().consensus_encode(w)?; + len += CheckedData(serialize(self.payload())).consensus_encode(w)?; Ok(len) } } From 5c8933001c4e3dff412ec7dddda7b357e14db745 Mon Sep 17 00:00:00 2001 From: Riccardo Casatta Date: Tue, 25 Jul 2023 16:24:44 +0200 Subject: [PATCH 3/3] Avoid serialize inner data in RawNetworkMessage RawNetworkMessage keep the payload_len and its checksum in the struct, thus is not needed to serialize the inner network message pub in fields of both RawNetworkMessage and CheckedData are removed so that invariant are preserved. --- bitcoin/examples/handshake.rs | 22 +++++++------- bitcoin/src/consensus/encode.rs | 46 ++++++++++++++++++++-------- bitcoin/src/network/message.rs | 53 ++++++++++++++++++--------------- 3 files changed, 73 insertions(+), 48 deletions(-) diff --git a/bitcoin/examples/handshake.rs b/bitcoin/examples/handshake.rs index 8852b042..b3f98d18 100644 --- a/bitcoin/examples/handshake.rs +++ b/bitcoin/examples/handshake.rs @@ -28,10 +28,8 @@ fn main() { let version_message = build_version_message(address); - let first_message = message::RawNetworkMessage { - magic: constants::Network::Bitcoin.magic(), - payload: version_message, - }; + let first_message = + message::RawNetworkMessage::new(constants::Network::Bitcoin.magic(), version_message); if let Ok(mut stream) = TcpStream::connect(address) { // Send the message @@ -44,24 +42,24 @@ fn main() { loop { // Loop an retrieve new messages let reply = message::RawNetworkMessage::consensus_decode(&mut stream_reader).unwrap(); - match reply.payload { + match reply.payload() { message::NetworkMessage::Version(_) => { - println!("Received version message: {:?}", reply.payload); + println!("Received version message: {:?}", reply.payload()); - let second_message = message::RawNetworkMessage { - magic: constants::Network::Bitcoin.magic(), - payload: message::NetworkMessage::Verack, - }; + let second_message = message::RawNetworkMessage::new( + constants::Network::Bitcoin.magic(), + message::NetworkMessage::Verack, + ); let _ = stream.write_all(encode::serialize(&second_message).as_slice()); println!("Sent verack message"); } message::NetworkMessage::Verack => { - println!("Received verack message: {:?}", reply.payload); + println!("Received verack message: {:?}", reply.payload()); break; } _ => { - println!("Received unknown message: {:?}", reply.payload); + println!("Received unknown message: {:?}", reply.payload()); break; } } diff --git a/bitcoin/src/consensus/encode.rs b/bitcoin/src/consensus/encode.rs index 1ff023cc..7f5ab21e 100644 --- a/bitcoin/src/consensus/encode.rs +++ b/bitcoin/src/consensus/encode.rs @@ -15,7 +15,7 @@ //! typically big-endian decimals, etc.) //! -use core::convert::From; +use core::convert::{From, TryFrom}; use core::{fmt, mem, u32}; use hashes::{sha256, sha256d, Hash}; @@ -325,9 +325,29 @@ pub trait Decodable: Sized { #[derive(PartialEq, Eq, PartialOrd, Ord, Clone, Debug)] pub struct VarInt(pub u64); -/// Data which must be preceded by a 4-byte checksum. +/// Data and a 4-byte checksum. #[derive(PartialEq, Eq, Clone, Debug)] -pub struct CheckedData(pub Vec); +pub struct CheckedData { + data: Vec, + checksum: [u8; 4], +} + +impl CheckedData { + /// Creates a new `CheckedData` computing the checksum of given data. + pub fn new(data: Vec) -> Self { + let checksum = sha2_checksum(&data); + Self { data, checksum } + } + + /// Returns a reference to the raw data without the checksum. + pub fn data(&self) -> &[u8] { &self.data } + + /// Returns the raw data without the checksum. + pub fn into_data(self) -> Vec { self.data } + + /// Returns the checksum of the data. + pub fn checksum(&self) -> [u8; 4] { self.checksum } +} // Primitive types macro_rules! impl_int_encodable { @@ -686,10 +706,12 @@ fn sha2_checksum(data: &[u8]) -> [u8; 4] { impl Encodable for CheckedData { #[inline] fn consensus_encode(&self, w: &mut W) -> Result { - (self.0.len() as u32).consensus_encode(w)?; - sha2_checksum(&self.0).consensus_encode(w)?; - w.emit_slice(&self.0)?; - Ok(8 + self.0.len()) + u32::try_from(self.data.len()) + .expect("network message use u32 as length") + .consensus_encode(w)?; + self.checksum().consensus_encode(w)?; + w.emit_slice(&self.data)?; + Ok(8 + self.data.len()) } } @@ -700,12 +722,12 @@ impl Decodable for CheckedData { let checksum = <[u8; 4]>::consensus_decode_from_finite_reader(r)?; let opts = ReadBytesFromFiniteReaderOpts { len, chunk_size: MAX_VEC_SIZE }; - let ret = read_bytes_from_finite_reader(r, opts)?; - let expected_checksum = sha2_checksum(&ret); + let data = read_bytes_from_finite_reader(r, opts)?; + let expected_checksum = sha2_checksum(&data); if expected_checksum != checksum { Err(self::Error::InvalidChecksum { expected: expected_checksum, actual: checksum }) } else { - Ok(CheckedData(ret)) + Ok(CheckedData { data, checksum }) } } } @@ -978,7 +1000,7 @@ mod tests { #[test] fn serialize_checkeddata_test() { - let cd = CheckedData(vec![1u8, 2, 3, 4, 5]); + let cd = CheckedData::new(vec![1u8, 2, 3, 4, 5]); assert_eq!(serialize(&cd), vec![5, 0, 0, 0, 162, 107, 175, 90, 1, 2, 3, 4, 5]); } @@ -1137,7 +1159,7 @@ mod tests { fn deserialize_checkeddata_test() { let cd: Result = deserialize(&[5u8, 0, 0, 0, 162, 107, 175, 90, 1, 2, 3, 4, 5]); - assert_eq!(cd.ok(), Some(CheckedData(vec![1u8, 2, 3, 4, 5]))); + assert_eq!(cd.ok(), Some(CheckedData::new(vec![1u8, 2, 3, 4, 5]))); } #[test] diff --git a/bitcoin/src/network/message.rs b/bitcoin/src/network/message.rs index 4c426e4d..a98119ff 100644 --- a/bitcoin/src/network/message.rs +++ b/bitcoin/src/network/message.rs @@ -9,11 +9,11 @@ use core::convert::TryFrom; use core::{fmt, iter}; +use hashes::{sha256d, Hash}; use io::Read as _; use crate::blockdata::{block, transaction}; -use crate::consensus::encode::{CheckedData, Decodable, Encodable, VarInt}; -use crate::consensus::{encode, serialize}; +use crate::consensus::encode::{self, CheckedData, Decodable, Encodable, VarInt}; use crate::io; use crate::merkle_tree::MerkleBlock; use crate::network::address::{AddrV2Message, Address}; @@ -151,6 +151,8 @@ crate::error::impl_std_error!(CommandStringError); pub struct RawNetworkMessage { magic: Magic, payload: NetworkMessage, + payload_len: u32, + checksum: [u8; 4], } /// A Network message payload. Proper documentation is available on at @@ -300,21 +302,21 @@ impl NetworkMessage { } impl RawNetworkMessage { - /// Creates a [RawNetworkMessage] pub fn new(magic: Magic, payload: NetworkMessage) -> Self { - Self { magic, payload } + let mut engine = sha256d::Hash::engine(); + let payload_len = payload.consensus_encode(&mut engine).expect("engine doesn't error"); + let payload_len = u32::try_from(payload_len).expect("network message use u32 as length"); + let checksum = sha256d::Hash::from_engine(engine); + let checksum = [checksum[0], checksum[1], checksum[2], checksum[3]]; + Self { magic, payload, payload_len, checksum } } /// The actual message data - pub fn payload(&self) -> &NetworkMessage { - &self.payload - } + pub fn payload(&self) -> &NetworkMessage { &self.payload } /// Magic bytes to identify the network these messages are meant for - pub fn magic(&self) -> &Magic { - &self.magic - } + pub fn magic(&self) -> &Magic { &self.magic } /// Return the message command as a static string reference. /// @@ -354,7 +356,8 @@ impl Encodable for NetworkMessage { NetworkMessage::GetHeaders(ref dat) => dat.consensus_encode(writer), NetworkMessage::Tx(ref dat) => dat.consensus_encode(writer), NetworkMessage::Block(ref dat) => dat.consensus_encode(writer), - NetworkMessage::Headers(ref dat) => HeaderSerializationWrapper(dat).consensus_encode(writer), + NetworkMessage::Headers(ref dat) => + HeaderSerializationWrapper(dat).consensus_encode(writer), NetworkMessage::Ping(ref dat) => dat.consensus_encode(writer), NetworkMessage::Pong(ref dat) => dat.consensus_encode(writer), NetworkMessage::MerkleBlock(ref dat) => dat.consensus_encode(writer), @@ -391,7 +394,9 @@ impl Encodable for RawNetworkMessage { let mut len = 0; len += self.magic.consensus_encode(w)?; len += self.command().consensus_encode(w)?; - len += CheckedData(serialize(self.payload())).consensus_encode(w)?; + len += self.payload_len.consensus_encode(w)?; + len += self.checksum.consensus_encode(w)?; + len += self.payload().consensus_encode(w)?; Ok(len) } } @@ -430,7 +435,10 @@ impl Decodable for RawNetworkMessage { ) -> Result { let magic = Decodable::consensus_decode_from_finite_reader(r)?; let cmd = CommandString::consensus_decode_from_finite_reader(r)?; - let raw_payload = CheckedData::consensus_decode_from_finite_reader(r)?.0; + let checked_data = CheckedData::consensus_decode_from_finite_reader(r)?; + let checksum = checked_data.checksum(); + let raw_payload = checked_data.into_data(); + let payload_len = raw_payload.len() as u32; let mut mem_d = io::Cursor::new(raw_payload); let payload = match &cmd.0[..] { @@ -517,7 +525,7 @@ impl Decodable for RawNetworkMessage { "sendaddrv2" => NetworkMessage::SendAddrV2, _ => NetworkMessage::Unknown { command: cmd, payload: mem_d.into_inner() }, }; - Ok(RawNetworkMessage { magic, payload }) + Ok(RawNetworkMessage { magic, payload, payload_len, checksum }) } #[inline] @@ -661,8 +669,7 @@ mod test { ]; for msg in msgs { - let raw_msg = - RawNetworkMessage { magic: Magic::from_bytes([57, 0, 0, 0]), payload: msg }; + let raw_msg = RawNetworkMessage::new(Magic::from_bytes([57, 0, 0, 0]), msg); assert_eq!(deserialize::(&serialize(&raw_msg)).unwrap(), raw_msg); } } @@ -695,7 +702,7 @@ mod test { #[test] #[rustfmt::skip] fn serialize_verack_test() { - assert_eq!(serialize(&RawNetworkMessage { magic: Magic::from(Network::Bitcoin), payload: NetworkMessage::Verack }), + assert_eq!(serialize(&RawNetworkMessage::new(Magic::from(Network::Bitcoin), NetworkMessage::Verack)), vec![0xf9, 0xbe, 0xb4, 0xd9, 0x76, 0x65, 0x72, 0x61, 0x63, 0x6B, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x5d, 0xf6, 0xe0, 0xe2]); @@ -704,7 +711,7 @@ mod test { #[test] #[rustfmt::skip] fn serialize_ping_test() { - assert_eq!(serialize(&RawNetworkMessage { magic: Magic::from(Network::Bitcoin), payload: NetworkMessage::Ping(100) }), + assert_eq!(serialize(&RawNetworkMessage::new(Magic::from(Network::Bitcoin), NetworkMessage::Ping(100))), vec![0xf9, 0xbe, 0xb4, 0xd9, 0x70, 0x69, 0x6e, 0x67, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x24, 0x67, 0xf1, 0x1d, @@ -714,7 +721,7 @@ mod test { #[test] #[rustfmt::skip] fn serialize_mempool_test() { - assert_eq!(serialize(&RawNetworkMessage { magic: Magic::from(Network::Bitcoin), payload: NetworkMessage::MemPool }), + assert_eq!(serialize(&RawNetworkMessage::new(Magic::from(Network::Bitcoin), NetworkMessage::MemPool)), vec![0xf9, 0xbe, 0xb4, 0xd9, 0x6d, 0x65, 0x6d, 0x70, 0x6f, 0x6f, 0x6c, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x5d, 0xf6, 0xe0, 0xe2]); @@ -723,7 +730,7 @@ mod test { #[test] #[rustfmt::skip] fn serialize_getaddr_test() { - assert_eq!(serialize(&RawNetworkMessage { magic: Magic::from(Network::Bitcoin), payload: NetworkMessage::GetAddr }), + assert_eq!(serialize(&RawNetworkMessage::new(Magic::from(Network::Bitcoin), NetworkMessage::GetAddr)), vec![0xf9, 0xbe, 0xb4, 0xd9, 0x67, 0x65, 0x74, 0x61, 0x64, 0x64, 0x72, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x5d, 0xf6, 0xe0, 0xe2]); @@ -737,10 +744,8 @@ mod test { 0x64, 0x64, 0x72, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x5d, 0xf6, 0xe0, 0xe2 ]); - let preimage = RawNetworkMessage { - magic: Magic::from(Network::Bitcoin), - payload: NetworkMessage::GetAddr, - }; + let preimage = + RawNetworkMessage::new(Magic::from(Network::Bitcoin), NetworkMessage::GetAddr); assert!(msg.is_ok()); let msg: RawNetworkMessage = msg.unwrap(); assert_eq!(preimage.magic, msg.magic);