diff --git a/bitcoin/examples/handshake.rs b/bitcoin/examples/handshake.rs index 8852b042..b3f98d18 100644 --- a/bitcoin/examples/handshake.rs +++ b/bitcoin/examples/handshake.rs @@ -28,10 +28,8 @@ fn main() { let version_message = build_version_message(address); - let first_message = message::RawNetworkMessage { - magic: constants::Network::Bitcoin.magic(), - payload: version_message, - }; + let first_message = + message::RawNetworkMessage::new(constants::Network::Bitcoin.magic(), version_message); if let Ok(mut stream) = TcpStream::connect(address) { // Send the message @@ -44,24 +42,24 @@ fn main() { loop { // Loop an retrieve new messages let reply = message::RawNetworkMessage::consensus_decode(&mut stream_reader).unwrap(); - match reply.payload { + match reply.payload() { message::NetworkMessage::Version(_) => { - println!("Received version message: {:?}", reply.payload); + println!("Received version message: {:?}", reply.payload()); - let second_message = message::RawNetworkMessage { - magic: constants::Network::Bitcoin.magic(), - payload: message::NetworkMessage::Verack, - }; + let second_message = message::RawNetworkMessage::new( + constants::Network::Bitcoin.magic(), + message::NetworkMessage::Verack, + ); let _ = stream.write_all(encode::serialize(&second_message).as_slice()); println!("Sent verack message"); } message::NetworkMessage::Verack => { - println!("Received verack message: {:?}", reply.payload); + println!("Received verack message: {:?}", reply.payload()); break; } _ => { - println!("Received unknown message: {:?}", reply.payload); + println!("Received unknown message: {:?}", reply.payload()); break; } } diff --git a/bitcoin/src/consensus/encode.rs b/bitcoin/src/consensus/encode.rs index 1ff023cc..7f5ab21e 100644 --- a/bitcoin/src/consensus/encode.rs +++ b/bitcoin/src/consensus/encode.rs @@ -15,7 +15,7 @@ //! typically big-endian decimals, etc.) //! -use core::convert::From; +use core::convert::{From, TryFrom}; use core::{fmt, mem, u32}; use hashes::{sha256, sha256d, Hash}; @@ -325,9 +325,29 @@ pub trait Decodable: Sized { #[derive(PartialEq, Eq, PartialOrd, Ord, Clone, Debug)] pub struct VarInt(pub u64); -/// Data which must be preceded by a 4-byte checksum. +/// Data and a 4-byte checksum. #[derive(PartialEq, Eq, Clone, Debug)] -pub struct CheckedData(pub Vec); +pub struct CheckedData { + data: Vec, + checksum: [u8; 4], +} + +impl CheckedData { + /// Creates a new `CheckedData` computing the checksum of given data. + pub fn new(data: Vec) -> Self { + let checksum = sha2_checksum(&data); + Self { data, checksum } + } + + /// Returns a reference to the raw data without the checksum. + pub fn data(&self) -> &[u8] { &self.data } + + /// Returns the raw data without the checksum. + pub fn into_data(self) -> Vec { self.data } + + /// Returns the checksum of the data. + pub fn checksum(&self) -> [u8; 4] { self.checksum } +} // Primitive types macro_rules! impl_int_encodable { @@ -686,10 +706,12 @@ fn sha2_checksum(data: &[u8]) -> [u8; 4] { impl Encodable for CheckedData { #[inline] fn consensus_encode(&self, w: &mut W) -> Result { - (self.0.len() as u32).consensus_encode(w)?; - sha2_checksum(&self.0).consensus_encode(w)?; - w.emit_slice(&self.0)?; - Ok(8 + self.0.len()) + u32::try_from(self.data.len()) + .expect("network message use u32 as length") + .consensus_encode(w)?; + self.checksum().consensus_encode(w)?; + w.emit_slice(&self.data)?; + Ok(8 + self.data.len()) } } @@ -700,12 +722,12 @@ impl Decodable for CheckedData { let checksum = <[u8; 4]>::consensus_decode_from_finite_reader(r)?; let opts = ReadBytesFromFiniteReaderOpts { len, chunk_size: MAX_VEC_SIZE }; - let ret = read_bytes_from_finite_reader(r, opts)?; - let expected_checksum = sha2_checksum(&ret); + let data = read_bytes_from_finite_reader(r, opts)?; + let expected_checksum = sha2_checksum(&data); if expected_checksum != checksum { Err(self::Error::InvalidChecksum { expected: expected_checksum, actual: checksum }) } else { - Ok(CheckedData(ret)) + Ok(CheckedData { data, checksum }) } } } @@ -978,7 +1000,7 @@ mod tests { #[test] fn serialize_checkeddata_test() { - let cd = CheckedData(vec![1u8, 2, 3, 4, 5]); + let cd = CheckedData::new(vec![1u8, 2, 3, 4, 5]); assert_eq!(serialize(&cd), vec![5, 0, 0, 0, 162, 107, 175, 90, 1, 2, 3, 4, 5]); } @@ -1137,7 +1159,7 @@ mod tests { fn deserialize_checkeddata_test() { let cd: Result = deserialize(&[5u8, 0, 0, 0, 162, 107, 175, 90, 1, 2, 3, 4, 5]); - assert_eq!(cd.ok(), Some(CheckedData(vec![1u8, 2, 3, 4, 5]))); + assert_eq!(cd.ok(), Some(CheckedData::new(vec![1u8, 2, 3, 4, 5]))); } #[test] diff --git a/bitcoin/src/network/message.rs b/bitcoin/src/network/message.rs index c083945d..a98119ff 100644 --- a/bitcoin/src/network/message.rs +++ b/bitcoin/src/network/message.rs @@ -9,11 +9,11 @@ use core::convert::TryFrom; use core::{fmt, iter}; +use hashes::{sha256d, Hash}; use io::Read as _; use crate::blockdata::{block, transaction}; -use crate::consensus::encode::{CheckedData, Decodable, Encodable, VarInt}; -use crate::consensus::{encode, serialize}; +use crate::consensus::encode::{self, CheckedData, Decodable, Encodable, VarInt}; use crate::io; use crate::merkle_tree::MerkleBlock; use crate::network::address::{AddrV2Message, Address}; @@ -149,10 +149,10 @@ crate::error::impl_std_error!(CommandStringError); /// A Network message #[derive(Clone, Debug, PartialEq, Eq)] pub struct RawNetworkMessage { - /// Magic bytes to identify the network these messages are meant for - pub magic: Magic, - /// The actual message data - pub payload: NetworkMessage, + magic: Magic, + payload: NetworkMessage, + payload_len: u32, + checksum: [u8; 4], } /// A Network message payload. Proper documentation is available on at @@ -302,6 +302,22 @@ impl NetworkMessage { } impl RawNetworkMessage { + /// Creates a [RawNetworkMessage] + pub fn new(magic: Magic, payload: NetworkMessage) -> Self { + let mut engine = sha256d::Hash::engine(); + let payload_len = payload.consensus_encode(&mut engine).expect("engine doesn't error"); + let payload_len = u32::try_from(payload_len).expect("network message use u32 as length"); + let checksum = sha256d::Hash::from_engine(engine); + let checksum = [checksum[0], checksum[1], checksum[2], checksum[3]]; + Self { magic, payload, payload_len, checksum } + } + + /// The actual message data + pub fn payload(&self) -> &NetworkMessage { &self.payload } + + /// Magic bytes to identify the network these messages are meant for + pub fn magic(&self) -> &Magic { &self.magic } + /// Return the message command as a static string reference. /// /// This returns `"unknown"` for [NetworkMessage::Unknown], @@ -328,51 +344,59 @@ impl<'a> Encodable for HeaderSerializationWrapper<'a> { } } -impl Encodable for RawNetworkMessage { - fn consensus_encode(&self, w: &mut W) -> Result { - let mut len = 0; - len += self.magic.consensus_encode(w)?; - len += self.command().consensus_encode(w)?; - len += CheckedData(match self.payload { - NetworkMessage::Version(ref dat) => serialize(dat), - NetworkMessage::Addr(ref dat) => serialize(dat), - NetworkMessage::Inv(ref dat) => serialize(dat), - NetworkMessage::GetData(ref dat) => serialize(dat), - NetworkMessage::NotFound(ref dat) => serialize(dat), - NetworkMessage::GetBlocks(ref dat) => serialize(dat), - NetworkMessage::GetHeaders(ref dat) => serialize(dat), - NetworkMessage::Tx(ref dat) => serialize(dat), - NetworkMessage::Block(ref dat) => serialize(dat), - NetworkMessage::Headers(ref dat) => serialize(&HeaderSerializationWrapper(dat)), - NetworkMessage::Ping(ref dat) => serialize(dat), - NetworkMessage::Pong(ref dat) => serialize(dat), - NetworkMessage::MerkleBlock(ref dat) => serialize(dat), - NetworkMessage::FilterLoad(ref dat) => serialize(dat), - NetworkMessage::FilterAdd(ref dat) => serialize(dat), - NetworkMessage::GetCFilters(ref dat) => serialize(dat), - NetworkMessage::CFilter(ref dat) => serialize(dat), - NetworkMessage::GetCFHeaders(ref dat) => serialize(dat), - NetworkMessage::CFHeaders(ref dat) => serialize(dat), - NetworkMessage::GetCFCheckpt(ref dat) => serialize(dat), - NetworkMessage::CFCheckpt(ref dat) => serialize(dat), - NetworkMessage::SendCmpct(ref dat) => serialize(dat), - NetworkMessage::CmpctBlock(ref dat) => serialize(dat), - NetworkMessage::GetBlockTxn(ref dat) => serialize(dat), - NetworkMessage::BlockTxn(ref dat) => serialize(dat), - NetworkMessage::Alert(ref dat) => serialize(dat), - NetworkMessage::Reject(ref dat) => serialize(dat), - NetworkMessage::FeeFilter(ref data) => serialize(data), - NetworkMessage::AddrV2(ref dat) => serialize(dat), +impl Encodable for NetworkMessage { + fn consensus_encode(&self, writer: &mut W) -> Result { + match self { + NetworkMessage::Version(ref dat) => dat.consensus_encode(writer), + NetworkMessage::Addr(ref dat) => dat.consensus_encode(writer), + NetworkMessage::Inv(ref dat) => dat.consensus_encode(writer), + NetworkMessage::GetData(ref dat) => dat.consensus_encode(writer), + NetworkMessage::NotFound(ref dat) => dat.consensus_encode(writer), + NetworkMessage::GetBlocks(ref dat) => dat.consensus_encode(writer), + NetworkMessage::GetHeaders(ref dat) => dat.consensus_encode(writer), + NetworkMessage::Tx(ref dat) => dat.consensus_encode(writer), + NetworkMessage::Block(ref dat) => dat.consensus_encode(writer), + NetworkMessage::Headers(ref dat) => + HeaderSerializationWrapper(dat).consensus_encode(writer), + NetworkMessage::Ping(ref dat) => dat.consensus_encode(writer), + NetworkMessage::Pong(ref dat) => dat.consensus_encode(writer), + NetworkMessage::MerkleBlock(ref dat) => dat.consensus_encode(writer), + NetworkMessage::FilterLoad(ref dat) => dat.consensus_encode(writer), + NetworkMessage::FilterAdd(ref dat) => dat.consensus_encode(writer), + NetworkMessage::GetCFilters(ref dat) => dat.consensus_encode(writer), + NetworkMessage::CFilter(ref dat) => dat.consensus_encode(writer), + NetworkMessage::GetCFHeaders(ref dat) => dat.consensus_encode(writer), + NetworkMessage::CFHeaders(ref dat) => dat.consensus_encode(writer), + NetworkMessage::GetCFCheckpt(ref dat) => dat.consensus_encode(writer), + NetworkMessage::CFCheckpt(ref dat) => dat.consensus_encode(writer), + NetworkMessage::SendCmpct(ref dat) => dat.consensus_encode(writer), + NetworkMessage::CmpctBlock(ref dat) => dat.consensus_encode(writer), + NetworkMessage::GetBlockTxn(ref dat) => dat.consensus_encode(writer), + NetworkMessage::BlockTxn(ref dat) => dat.consensus_encode(writer), + NetworkMessage::Alert(ref dat) => dat.consensus_encode(writer), + NetworkMessage::Reject(ref dat) => dat.consensus_encode(writer), + NetworkMessage::FeeFilter(ref dat) => dat.consensus_encode(writer), + NetworkMessage::AddrV2(ref dat) => dat.consensus_encode(writer), NetworkMessage::Verack | NetworkMessage::SendHeaders | NetworkMessage::MemPool | NetworkMessage::GetAddr | NetworkMessage::WtxidRelay | NetworkMessage::FilterClear - | NetworkMessage::SendAddrV2 => vec![], - NetworkMessage::Unknown { payload: ref data, .. } => serialize(data), - }) - .consensus_encode(w)?; + | NetworkMessage::SendAddrV2 => Ok(0), + NetworkMessage::Unknown { payload: ref data, .. } => data.consensus_encode(writer), + } + } +} + +impl Encodable for RawNetworkMessage { + fn consensus_encode(&self, w: &mut W) -> Result { + let mut len = 0; + len += self.magic.consensus_encode(w)?; + len += self.command().consensus_encode(w)?; + len += self.payload_len.consensus_encode(w)?; + len += self.checksum.consensus_encode(w)?; + len += self.payload().consensus_encode(w)?; Ok(len) } } @@ -411,7 +435,10 @@ impl Decodable for RawNetworkMessage { ) -> Result { let magic = Decodable::consensus_decode_from_finite_reader(r)?; let cmd = CommandString::consensus_decode_from_finite_reader(r)?; - let raw_payload = CheckedData::consensus_decode_from_finite_reader(r)?.0; + let checked_data = CheckedData::consensus_decode_from_finite_reader(r)?; + let checksum = checked_data.checksum(); + let raw_payload = checked_data.into_data(); + let payload_len = raw_payload.len() as u32; let mut mem_d = io::Cursor::new(raw_payload); let payload = match &cmd.0[..] { @@ -498,7 +525,7 @@ impl Decodable for RawNetworkMessage { "sendaddrv2" => NetworkMessage::SendAddrV2, _ => NetworkMessage::Unknown { command: cmd, payload: mem_d.into_inner() }, }; - Ok(RawNetworkMessage { magic, payload }) + Ok(RawNetworkMessage { magic, payload, payload_len, checksum }) } #[inline] @@ -642,8 +669,7 @@ mod test { ]; for msg in msgs { - let raw_msg = - RawNetworkMessage { magic: Magic::from_bytes([57, 0, 0, 0]), payload: msg }; + let raw_msg = RawNetworkMessage::new(Magic::from_bytes([57, 0, 0, 0]), msg); assert_eq!(deserialize::(&serialize(&raw_msg)).unwrap(), raw_msg); } } @@ -676,7 +702,7 @@ mod test { #[test] #[rustfmt::skip] fn serialize_verack_test() { - assert_eq!(serialize(&RawNetworkMessage { magic: Magic::from(Network::Bitcoin), payload: NetworkMessage::Verack }), + assert_eq!(serialize(&RawNetworkMessage::new(Magic::from(Network::Bitcoin), NetworkMessage::Verack)), vec![0xf9, 0xbe, 0xb4, 0xd9, 0x76, 0x65, 0x72, 0x61, 0x63, 0x6B, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x5d, 0xf6, 0xe0, 0xe2]); @@ -685,7 +711,7 @@ mod test { #[test] #[rustfmt::skip] fn serialize_ping_test() { - assert_eq!(serialize(&RawNetworkMessage { magic: Magic::from(Network::Bitcoin), payload: NetworkMessage::Ping(100) }), + assert_eq!(serialize(&RawNetworkMessage::new(Magic::from(Network::Bitcoin), NetworkMessage::Ping(100))), vec![0xf9, 0xbe, 0xb4, 0xd9, 0x70, 0x69, 0x6e, 0x67, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x24, 0x67, 0xf1, 0x1d, @@ -695,7 +721,7 @@ mod test { #[test] #[rustfmt::skip] fn serialize_mempool_test() { - assert_eq!(serialize(&RawNetworkMessage { magic: Magic::from(Network::Bitcoin), payload: NetworkMessage::MemPool }), + assert_eq!(serialize(&RawNetworkMessage::new(Magic::from(Network::Bitcoin), NetworkMessage::MemPool)), vec![0xf9, 0xbe, 0xb4, 0xd9, 0x6d, 0x65, 0x6d, 0x70, 0x6f, 0x6f, 0x6c, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x5d, 0xf6, 0xe0, 0xe2]); @@ -704,7 +730,7 @@ mod test { #[test] #[rustfmt::skip] fn serialize_getaddr_test() { - assert_eq!(serialize(&RawNetworkMessage { magic: Magic::from(Network::Bitcoin), payload: NetworkMessage::GetAddr }), + assert_eq!(serialize(&RawNetworkMessage::new(Magic::from(Network::Bitcoin), NetworkMessage::GetAddr)), vec![0xf9, 0xbe, 0xb4, 0xd9, 0x67, 0x65, 0x74, 0x61, 0x64, 0x64, 0x72, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x5d, 0xf6, 0xe0, 0xe2]); @@ -718,10 +744,8 @@ mod test { 0x64, 0x64, 0x72, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x5d, 0xf6, 0xe0, 0xe2 ]); - let preimage = RawNetworkMessage { - magic: Magic::from(Network::Bitcoin), - payload: NetworkMessage::GetAddr, - }; + let preimage = + RawNetworkMessage::new(Magic::from(Network::Bitcoin), NetworkMessage::GetAddr); assert!(msg.is_ok()); let msg: RawNetworkMessage = msg.unwrap(); assert_eq!(preimage.magic, msg.magic);