Avoid serialize inner data in RawNetworkMessage

RawNetworkMessage keep the payload_len and its checksum in the struct, thus
is not needed to serialize the inner network message

pub in fields of both RawNetworkMessage and CheckedData are removed so that
invariant are preserved.
This commit is contained in:
Riccardo Casatta 2023-07-25 16:24:44 +02:00
parent bc66ed82b2
commit 5c8933001c
No known key found for this signature in database
GPG Key ID: FD986A969E450397
3 changed files with 73 additions and 48 deletions

View File

@ -28,10 +28,8 @@ fn main() {
let version_message = build_version_message(address);
let first_message = message::RawNetworkMessage {
magic: constants::Network::Bitcoin.magic(),
payload: version_message,
};
let first_message =
message::RawNetworkMessage::new(constants::Network::Bitcoin.magic(), version_message);
if let Ok(mut stream) = TcpStream::connect(address) {
// Send the message
@ -44,24 +42,24 @@ fn main() {
loop {
// Loop an retrieve new messages
let reply = message::RawNetworkMessage::consensus_decode(&mut stream_reader).unwrap();
match reply.payload {
match reply.payload() {
message::NetworkMessage::Version(_) => {
println!("Received version message: {:?}", reply.payload);
println!("Received version message: {:?}", reply.payload());
let second_message = message::RawNetworkMessage {
magic: constants::Network::Bitcoin.magic(),
payload: message::NetworkMessage::Verack,
};
let second_message = message::RawNetworkMessage::new(
constants::Network::Bitcoin.magic(),
message::NetworkMessage::Verack,
);
let _ = stream.write_all(encode::serialize(&second_message).as_slice());
println!("Sent verack message");
}
message::NetworkMessage::Verack => {
println!("Received verack message: {:?}", reply.payload);
println!("Received verack message: {:?}", reply.payload());
break;
}
_ => {
println!("Received unknown message: {:?}", reply.payload);
println!("Received unknown message: {:?}", reply.payload());
break;
}
}

View File

@ -15,7 +15,7 @@
//! typically big-endian decimals, etc.)
//!
use core::convert::From;
use core::convert::{From, TryFrom};
use core::{fmt, mem, u32};
use hashes::{sha256, sha256d, Hash};
@ -325,9 +325,29 @@ pub trait Decodable: Sized {
#[derive(PartialEq, Eq, PartialOrd, Ord, Clone, Debug)]
pub struct VarInt(pub u64);
/// Data which must be preceded by a 4-byte checksum.
/// Data and a 4-byte checksum.
#[derive(PartialEq, Eq, Clone, Debug)]
pub struct CheckedData(pub Vec<u8>);
pub struct CheckedData {
data: Vec<u8>,
checksum: [u8; 4],
}
impl CheckedData {
/// Creates a new `CheckedData` computing the checksum of given data.
pub fn new(data: Vec<u8>) -> Self {
let checksum = sha2_checksum(&data);
Self { data, checksum }
}
/// Returns a reference to the raw data without the checksum.
pub fn data(&self) -> &[u8] { &self.data }
/// Returns the raw data without the checksum.
pub fn into_data(self) -> Vec<u8> { self.data }
/// Returns the checksum of the data.
pub fn checksum(&self) -> [u8; 4] { self.checksum }
}
// Primitive types
macro_rules! impl_int_encodable {
@ -686,10 +706,12 @@ fn sha2_checksum(data: &[u8]) -> [u8; 4] {
impl Encodable for CheckedData {
#[inline]
fn consensus_encode<W: io::Write + ?Sized>(&self, w: &mut W) -> Result<usize, io::Error> {
(self.0.len() as u32).consensus_encode(w)?;
sha2_checksum(&self.0).consensus_encode(w)?;
w.emit_slice(&self.0)?;
Ok(8 + self.0.len())
u32::try_from(self.data.len())
.expect("network message use u32 as length")
.consensus_encode(w)?;
self.checksum().consensus_encode(w)?;
w.emit_slice(&self.data)?;
Ok(8 + self.data.len())
}
}
@ -700,12 +722,12 @@ impl Decodable for CheckedData {
let checksum = <[u8; 4]>::consensus_decode_from_finite_reader(r)?;
let opts = ReadBytesFromFiniteReaderOpts { len, chunk_size: MAX_VEC_SIZE };
let ret = read_bytes_from_finite_reader(r, opts)?;
let expected_checksum = sha2_checksum(&ret);
let data = read_bytes_from_finite_reader(r, opts)?;
let expected_checksum = sha2_checksum(&data);
if expected_checksum != checksum {
Err(self::Error::InvalidChecksum { expected: expected_checksum, actual: checksum })
} else {
Ok(CheckedData(ret))
Ok(CheckedData { data, checksum })
}
}
}
@ -978,7 +1000,7 @@ mod tests {
#[test]
fn serialize_checkeddata_test() {
let cd = CheckedData(vec![1u8, 2, 3, 4, 5]);
let cd = CheckedData::new(vec![1u8, 2, 3, 4, 5]);
assert_eq!(serialize(&cd), vec![5, 0, 0, 0, 162, 107, 175, 90, 1, 2, 3, 4, 5]);
}
@ -1137,7 +1159,7 @@ mod tests {
fn deserialize_checkeddata_test() {
let cd: Result<CheckedData, _> =
deserialize(&[5u8, 0, 0, 0, 162, 107, 175, 90, 1, 2, 3, 4, 5]);
assert_eq!(cd.ok(), Some(CheckedData(vec![1u8, 2, 3, 4, 5])));
assert_eq!(cd.ok(), Some(CheckedData::new(vec![1u8, 2, 3, 4, 5])));
}
#[test]

View File

@ -9,11 +9,11 @@
use core::convert::TryFrom;
use core::{fmt, iter};
use hashes::{sha256d, Hash};
use io::Read as _;
use crate::blockdata::{block, transaction};
use crate::consensus::encode::{CheckedData, Decodable, Encodable, VarInt};
use crate::consensus::{encode, serialize};
use crate::consensus::encode::{self, CheckedData, Decodable, Encodable, VarInt};
use crate::io;
use crate::merkle_tree::MerkleBlock;
use crate::network::address::{AddrV2Message, Address};
@ -151,6 +151,8 @@ crate::error::impl_std_error!(CommandStringError);
pub struct RawNetworkMessage {
magic: Magic,
payload: NetworkMessage,
payload_len: u32,
checksum: [u8; 4],
}
/// A Network message payload. Proper documentation is available on at
@ -300,21 +302,21 @@ impl NetworkMessage {
}
impl RawNetworkMessage {
/// Creates a [RawNetworkMessage]
pub fn new(magic: Magic, payload: NetworkMessage) -> Self {
Self { magic, payload }
let mut engine = sha256d::Hash::engine();
let payload_len = payload.consensus_encode(&mut engine).expect("engine doesn't error");
let payload_len = u32::try_from(payload_len).expect("network message use u32 as length");
let checksum = sha256d::Hash::from_engine(engine);
let checksum = [checksum[0], checksum[1], checksum[2], checksum[3]];
Self { magic, payload, payload_len, checksum }
}
/// The actual message data
pub fn payload(&self) -> &NetworkMessage {
&self.payload
}
pub fn payload(&self) -> &NetworkMessage { &self.payload }
/// Magic bytes to identify the network these messages are meant for
pub fn magic(&self) -> &Magic {
&self.magic
}
pub fn magic(&self) -> &Magic { &self.magic }
/// Return the message command as a static string reference.
///
@ -354,7 +356,8 @@ impl Encodable for NetworkMessage {
NetworkMessage::GetHeaders(ref dat) => dat.consensus_encode(writer),
NetworkMessage::Tx(ref dat) => dat.consensus_encode(writer),
NetworkMessage::Block(ref dat) => dat.consensus_encode(writer),
NetworkMessage::Headers(ref dat) => HeaderSerializationWrapper(dat).consensus_encode(writer),
NetworkMessage::Headers(ref dat) =>
HeaderSerializationWrapper(dat).consensus_encode(writer),
NetworkMessage::Ping(ref dat) => dat.consensus_encode(writer),
NetworkMessage::Pong(ref dat) => dat.consensus_encode(writer),
NetworkMessage::MerkleBlock(ref dat) => dat.consensus_encode(writer),
@ -391,7 +394,9 @@ impl Encodable for RawNetworkMessage {
let mut len = 0;
len += self.magic.consensus_encode(w)?;
len += self.command().consensus_encode(w)?;
len += CheckedData(serialize(self.payload())).consensus_encode(w)?;
len += self.payload_len.consensus_encode(w)?;
len += self.checksum.consensus_encode(w)?;
len += self.payload().consensus_encode(w)?;
Ok(len)
}
}
@ -430,7 +435,10 @@ impl Decodable for RawNetworkMessage {
) -> Result<Self, encode::Error> {
let magic = Decodable::consensus_decode_from_finite_reader(r)?;
let cmd = CommandString::consensus_decode_from_finite_reader(r)?;
let raw_payload = CheckedData::consensus_decode_from_finite_reader(r)?.0;
let checked_data = CheckedData::consensus_decode_from_finite_reader(r)?;
let checksum = checked_data.checksum();
let raw_payload = checked_data.into_data();
let payload_len = raw_payload.len() as u32;
let mut mem_d = io::Cursor::new(raw_payload);
let payload = match &cmd.0[..] {
@ -517,7 +525,7 @@ impl Decodable for RawNetworkMessage {
"sendaddrv2" => NetworkMessage::SendAddrV2,
_ => NetworkMessage::Unknown { command: cmd, payload: mem_d.into_inner() },
};
Ok(RawNetworkMessage { magic, payload })
Ok(RawNetworkMessage { magic, payload, payload_len, checksum })
}
#[inline]
@ -661,8 +669,7 @@ mod test {
];
for msg in msgs {
let raw_msg =
RawNetworkMessage { magic: Magic::from_bytes([57, 0, 0, 0]), payload: msg };
let raw_msg = RawNetworkMessage::new(Magic::from_bytes([57, 0, 0, 0]), msg);
assert_eq!(deserialize::<RawNetworkMessage>(&serialize(&raw_msg)).unwrap(), raw_msg);
}
}
@ -695,7 +702,7 @@ mod test {
#[test]
#[rustfmt::skip]
fn serialize_verack_test() {
assert_eq!(serialize(&RawNetworkMessage { magic: Magic::from(Network::Bitcoin), payload: NetworkMessage::Verack }),
assert_eq!(serialize(&RawNetworkMessage::new(Magic::from(Network::Bitcoin), NetworkMessage::Verack)),
vec![0xf9, 0xbe, 0xb4, 0xd9, 0x76, 0x65, 0x72, 0x61,
0x63, 0x6B, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x5d, 0xf6, 0xe0, 0xe2]);
@ -704,7 +711,7 @@ mod test {
#[test]
#[rustfmt::skip]
fn serialize_ping_test() {
assert_eq!(serialize(&RawNetworkMessage { magic: Magic::from(Network::Bitcoin), payload: NetworkMessage::Ping(100) }),
assert_eq!(serialize(&RawNetworkMessage::new(Magic::from(Network::Bitcoin), NetworkMessage::Ping(100))),
vec![0xf9, 0xbe, 0xb4, 0xd9, 0x70, 0x69, 0x6e, 0x67,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x08, 0x00, 0x00, 0x00, 0x24, 0x67, 0xf1, 0x1d,
@ -714,7 +721,7 @@ mod test {
#[test]
#[rustfmt::skip]
fn serialize_mempool_test() {
assert_eq!(serialize(&RawNetworkMessage { magic: Magic::from(Network::Bitcoin), payload: NetworkMessage::MemPool }),
assert_eq!(serialize(&RawNetworkMessage::new(Magic::from(Network::Bitcoin), NetworkMessage::MemPool)),
vec![0xf9, 0xbe, 0xb4, 0xd9, 0x6d, 0x65, 0x6d, 0x70,
0x6f, 0x6f, 0x6c, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x5d, 0xf6, 0xe0, 0xe2]);
@ -723,7 +730,7 @@ mod test {
#[test]
#[rustfmt::skip]
fn serialize_getaddr_test() {
assert_eq!(serialize(&RawNetworkMessage { magic: Magic::from(Network::Bitcoin), payload: NetworkMessage::GetAddr }),
assert_eq!(serialize(&RawNetworkMessage::new(Magic::from(Network::Bitcoin), NetworkMessage::GetAddr)),
vec![0xf9, 0xbe, 0xb4, 0xd9, 0x67, 0x65, 0x74, 0x61,
0x64, 0x64, 0x72, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x5d, 0xf6, 0xe0, 0xe2]);
@ -737,10 +744,8 @@ mod test {
0x64, 0x64, 0x72, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x5d, 0xf6, 0xe0, 0xe2
]);
let preimage = RawNetworkMessage {
magic: Magic::from(Network::Bitcoin),
payload: NetworkMessage::GetAddr,
};
let preimage =
RawNetworkMessage::new(Magic::from(Network::Bitcoin), NetworkMessage::GetAddr);
assert!(msg.is_ok());
let msg: RawNetworkMessage = msg.unwrap();
assert_eq!(preimage.magic, msg.magic);