Merge rust-bitcoin/rust-bitcoin#1954: Avoid vector allocation in `RawNetworkMessage` encoding

5c8933001c Avoid serialize inner data in RawNetworkMessage (Riccardo Casatta)
bc66ed82b2 Impl Encodable for NetworkMessage (Riccardo Casatta)
8560baaca2 Make fields of RawNetworkMessage non public (Riccardo Casatta)

Pull request description:

  This PR removes the need to serialize the inner NetworkMessage in the RawNetworkMessage encoding, thus saving memory and reducing allocations.

  To achieve this payload_len and checksum are kept in the RawNetworkMessage and checksum kept in CheckedData, to preserve invariants fields of the struct are made non-public.

ACKs for top commit:
  apoelstra:
    ACK 5c8933001c
  tcharding:
    ACK 5c8933001c

Tree-SHA512: aca3c7ac13d2d71184288f7815449e72c4c04fc617a65effba592592ef4ec50f18b6f83dbff58e9c4237cb1fe8e7af52cd43db9036658bdaf7888c07011e46cc
This commit is contained in:
Riccardo Casatta 2023-07-27 08:54:04 +02:00
commit 0e6341d4c9
No known key found for this signature in database
GPG Key ID: FD986A969E450397
3 changed files with 125 additions and 81 deletions

View File

@ -28,10 +28,8 @@ fn main() {
let version_message = build_version_message(address); let version_message = build_version_message(address);
let first_message = message::RawNetworkMessage { let first_message =
magic: constants::Network::Bitcoin.magic(), message::RawNetworkMessage::new(constants::Network::Bitcoin.magic(), version_message);
payload: version_message,
};
if let Ok(mut stream) = TcpStream::connect(address) { if let Ok(mut stream) = TcpStream::connect(address) {
// Send the message // Send the message
@ -44,24 +42,24 @@ fn main() {
loop { loop {
// Loop an retrieve new messages // Loop an retrieve new messages
let reply = message::RawNetworkMessage::consensus_decode(&mut stream_reader).unwrap(); let reply = message::RawNetworkMessage::consensus_decode(&mut stream_reader).unwrap();
match reply.payload { match reply.payload() {
message::NetworkMessage::Version(_) => { message::NetworkMessage::Version(_) => {
println!("Received version message: {:?}", reply.payload); println!("Received version message: {:?}", reply.payload());
let second_message = message::RawNetworkMessage { let second_message = message::RawNetworkMessage::new(
magic: constants::Network::Bitcoin.magic(), constants::Network::Bitcoin.magic(),
payload: message::NetworkMessage::Verack, message::NetworkMessage::Verack,
}; );
let _ = stream.write_all(encode::serialize(&second_message).as_slice()); let _ = stream.write_all(encode::serialize(&second_message).as_slice());
println!("Sent verack message"); println!("Sent verack message");
} }
message::NetworkMessage::Verack => { message::NetworkMessage::Verack => {
println!("Received verack message: {:?}", reply.payload); println!("Received verack message: {:?}", reply.payload());
break; break;
} }
_ => { _ => {
println!("Received unknown message: {:?}", reply.payload); println!("Received unknown message: {:?}", reply.payload());
break; break;
} }
} }

View File

@ -15,7 +15,7 @@
//! typically big-endian decimals, etc.) //! typically big-endian decimals, etc.)
//! //!
use core::convert::From; use core::convert::{From, TryFrom};
use core::{fmt, mem, u32}; use core::{fmt, mem, u32};
use hashes::{sha256, sha256d, Hash}; use hashes::{sha256, sha256d, Hash};
@ -325,9 +325,29 @@ pub trait Decodable: Sized {
#[derive(PartialEq, Eq, PartialOrd, Ord, Clone, Debug)] #[derive(PartialEq, Eq, PartialOrd, Ord, Clone, Debug)]
pub struct VarInt(pub u64); pub struct VarInt(pub u64);
/// Data which must be preceded by a 4-byte checksum. /// Data and a 4-byte checksum.
#[derive(PartialEq, Eq, Clone, Debug)] #[derive(PartialEq, Eq, Clone, Debug)]
pub struct CheckedData(pub Vec<u8>); pub struct CheckedData {
data: Vec<u8>,
checksum: [u8; 4],
}
impl CheckedData {
/// Creates a new `CheckedData` computing the checksum of given data.
pub fn new(data: Vec<u8>) -> Self {
let checksum = sha2_checksum(&data);
Self { data, checksum }
}
/// Returns a reference to the raw data without the checksum.
pub fn data(&self) -> &[u8] { &self.data }
/// Returns the raw data without the checksum.
pub fn into_data(self) -> Vec<u8> { self.data }
/// Returns the checksum of the data.
pub fn checksum(&self) -> [u8; 4] { self.checksum }
}
// Primitive types // Primitive types
macro_rules! impl_int_encodable { macro_rules! impl_int_encodable {
@ -686,10 +706,12 @@ fn sha2_checksum(data: &[u8]) -> [u8; 4] {
impl Encodable for CheckedData { impl Encodable for CheckedData {
#[inline] #[inline]
fn consensus_encode<W: io::Write + ?Sized>(&self, w: &mut W) -> Result<usize, io::Error> { fn consensus_encode<W: io::Write + ?Sized>(&self, w: &mut W) -> Result<usize, io::Error> {
(self.0.len() as u32).consensus_encode(w)?; u32::try_from(self.data.len())
sha2_checksum(&self.0).consensus_encode(w)?; .expect("network message use u32 as length")
w.emit_slice(&self.0)?; .consensus_encode(w)?;
Ok(8 + self.0.len()) self.checksum().consensus_encode(w)?;
w.emit_slice(&self.data)?;
Ok(8 + self.data.len())
} }
} }
@ -700,12 +722,12 @@ impl Decodable for CheckedData {
let checksum = <[u8; 4]>::consensus_decode_from_finite_reader(r)?; let checksum = <[u8; 4]>::consensus_decode_from_finite_reader(r)?;
let opts = ReadBytesFromFiniteReaderOpts { len, chunk_size: MAX_VEC_SIZE }; let opts = ReadBytesFromFiniteReaderOpts { len, chunk_size: MAX_VEC_SIZE };
let ret = read_bytes_from_finite_reader(r, opts)?; let data = read_bytes_from_finite_reader(r, opts)?;
let expected_checksum = sha2_checksum(&ret); let expected_checksum = sha2_checksum(&data);
if expected_checksum != checksum { if expected_checksum != checksum {
Err(self::Error::InvalidChecksum { expected: expected_checksum, actual: checksum }) Err(self::Error::InvalidChecksum { expected: expected_checksum, actual: checksum })
} else { } else {
Ok(CheckedData(ret)) Ok(CheckedData { data, checksum })
} }
} }
} }
@ -978,7 +1000,7 @@ mod tests {
#[test] #[test]
fn serialize_checkeddata_test() { fn serialize_checkeddata_test() {
let cd = CheckedData(vec![1u8, 2, 3, 4, 5]); let cd = CheckedData::new(vec![1u8, 2, 3, 4, 5]);
assert_eq!(serialize(&cd), vec![5, 0, 0, 0, 162, 107, 175, 90, 1, 2, 3, 4, 5]); assert_eq!(serialize(&cd), vec![5, 0, 0, 0, 162, 107, 175, 90, 1, 2, 3, 4, 5]);
} }
@ -1137,7 +1159,7 @@ mod tests {
fn deserialize_checkeddata_test() { fn deserialize_checkeddata_test() {
let cd: Result<CheckedData, _> = let cd: Result<CheckedData, _> =
deserialize(&[5u8, 0, 0, 0, 162, 107, 175, 90, 1, 2, 3, 4, 5]); deserialize(&[5u8, 0, 0, 0, 162, 107, 175, 90, 1, 2, 3, 4, 5]);
assert_eq!(cd.ok(), Some(CheckedData(vec![1u8, 2, 3, 4, 5]))); assert_eq!(cd.ok(), Some(CheckedData::new(vec![1u8, 2, 3, 4, 5])));
} }
#[test] #[test]

View File

@ -9,11 +9,11 @@
use core::convert::TryFrom; use core::convert::TryFrom;
use core::{fmt, iter}; use core::{fmt, iter};
use hashes::{sha256d, Hash};
use io::Read as _; use io::Read as _;
use crate::blockdata::{block, transaction}; use crate::blockdata::{block, transaction};
use crate::consensus::encode::{CheckedData, Decodable, Encodable, VarInt}; use crate::consensus::encode::{self, CheckedData, Decodable, Encodable, VarInt};
use crate::consensus::{encode, serialize};
use crate::io; use crate::io;
use crate::merkle_tree::MerkleBlock; use crate::merkle_tree::MerkleBlock;
use crate::network::address::{AddrV2Message, Address}; use crate::network::address::{AddrV2Message, Address};
@ -149,10 +149,10 @@ crate::error::impl_std_error!(CommandStringError);
/// A Network message /// A Network message
#[derive(Clone, Debug, PartialEq, Eq)] #[derive(Clone, Debug, PartialEq, Eq)]
pub struct RawNetworkMessage { pub struct RawNetworkMessage {
/// Magic bytes to identify the network these messages are meant for magic: Magic,
pub magic: Magic, payload: NetworkMessage,
/// The actual message data payload_len: u32,
pub payload: NetworkMessage, checksum: [u8; 4],
} }
/// A Network message payload. Proper documentation is available on at /// A Network message payload. Proper documentation is available on at
@ -302,6 +302,22 @@ impl NetworkMessage {
} }
impl RawNetworkMessage { impl RawNetworkMessage {
/// Creates a [RawNetworkMessage]
pub fn new(magic: Magic, payload: NetworkMessage) -> Self {
let mut engine = sha256d::Hash::engine();
let payload_len = payload.consensus_encode(&mut engine).expect("engine doesn't error");
let payload_len = u32::try_from(payload_len).expect("network message use u32 as length");
let checksum = sha256d::Hash::from_engine(engine);
let checksum = [checksum[0], checksum[1], checksum[2], checksum[3]];
Self { magic, payload, payload_len, checksum }
}
/// The actual message data
pub fn payload(&self) -> &NetworkMessage { &self.payload }
/// Magic bytes to identify the network these messages are meant for
pub fn magic(&self) -> &Magic { &self.magic }
/// Return the message command as a static string reference. /// Return the message command as a static string reference.
/// ///
/// This returns `"unknown"` for [NetworkMessage::Unknown], /// This returns `"unknown"` for [NetworkMessage::Unknown],
@ -328,51 +344,59 @@ impl<'a> Encodable for HeaderSerializationWrapper<'a> {
} }
} }
impl Encodable for RawNetworkMessage { impl Encodable for NetworkMessage {
fn consensus_encode<W: io::Write + ?Sized>(&self, w: &mut W) -> Result<usize, io::Error> { fn consensus_encode<W: io::Write + ?Sized>(&self, writer: &mut W) -> Result<usize, io::Error> {
let mut len = 0; match self {
len += self.magic.consensus_encode(w)?; NetworkMessage::Version(ref dat) => dat.consensus_encode(writer),
len += self.command().consensus_encode(w)?; NetworkMessage::Addr(ref dat) => dat.consensus_encode(writer),
len += CheckedData(match self.payload { NetworkMessage::Inv(ref dat) => dat.consensus_encode(writer),
NetworkMessage::Version(ref dat) => serialize(dat), NetworkMessage::GetData(ref dat) => dat.consensus_encode(writer),
NetworkMessage::Addr(ref dat) => serialize(dat), NetworkMessage::NotFound(ref dat) => dat.consensus_encode(writer),
NetworkMessage::Inv(ref dat) => serialize(dat), NetworkMessage::GetBlocks(ref dat) => dat.consensus_encode(writer),
NetworkMessage::GetData(ref dat) => serialize(dat), NetworkMessage::GetHeaders(ref dat) => dat.consensus_encode(writer),
NetworkMessage::NotFound(ref dat) => serialize(dat), NetworkMessage::Tx(ref dat) => dat.consensus_encode(writer),
NetworkMessage::GetBlocks(ref dat) => serialize(dat), NetworkMessage::Block(ref dat) => dat.consensus_encode(writer),
NetworkMessage::GetHeaders(ref dat) => serialize(dat), NetworkMessage::Headers(ref dat) =>
NetworkMessage::Tx(ref dat) => serialize(dat), HeaderSerializationWrapper(dat).consensus_encode(writer),
NetworkMessage::Block(ref dat) => serialize(dat), NetworkMessage::Ping(ref dat) => dat.consensus_encode(writer),
NetworkMessage::Headers(ref dat) => serialize(&HeaderSerializationWrapper(dat)), NetworkMessage::Pong(ref dat) => dat.consensus_encode(writer),
NetworkMessage::Ping(ref dat) => serialize(dat), NetworkMessage::MerkleBlock(ref dat) => dat.consensus_encode(writer),
NetworkMessage::Pong(ref dat) => serialize(dat), NetworkMessage::FilterLoad(ref dat) => dat.consensus_encode(writer),
NetworkMessage::MerkleBlock(ref dat) => serialize(dat), NetworkMessage::FilterAdd(ref dat) => dat.consensus_encode(writer),
NetworkMessage::FilterLoad(ref dat) => serialize(dat), NetworkMessage::GetCFilters(ref dat) => dat.consensus_encode(writer),
NetworkMessage::FilterAdd(ref dat) => serialize(dat), NetworkMessage::CFilter(ref dat) => dat.consensus_encode(writer),
NetworkMessage::GetCFilters(ref dat) => serialize(dat), NetworkMessage::GetCFHeaders(ref dat) => dat.consensus_encode(writer),
NetworkMessage::CFilter(ref dat) => serialize(dat), NetworkMessage::CFHeaders(ref dat) => dat.consensus_encode(writer),
NetworkMessage::GetCFHeaders(ref dat) => serialize(dat), NetworkMessage::GetCFCheckpt(ref dat) => dat.consensus_encode(writer),
NetworkMessage::CFHeaders(ref dat) => serialize(dat), NetworkMessage::CFCheckpt(ref dat) => dat.consensus_encode(writer),
NetworkMessage::GetCFCheckpt(ref dat) => serialize(dat), NetworkMessage::SendCmpct(ref dat) => dat.consensus_encode(writer),
NetworkMessage::CFCheckpt(ref dat) => serialize(dat), NetworkMessage::CmpctBlock(ref dat) => dat.consensus_encode(writer),
NetworkMessage::SendCmpct(ref dat) => serialize(dat), NetworkMessage::GetBlockTxn(ref dat) => dat.consensus_encode(writer),
NetworkMessage::CmpctBlock(ref dat) => serialize(dat), NetworkMessage::BlockTxn(ref dat) => dat.consensus_encode(writer),
NetworkMessage::GetBlockTxn(ref dat) => serialize(dat), NetworkMessage::Alert(ref dat) => dat.consensus_encode(writer),
NetworkMessage::BlockTxn(ref dat) => serialize(dat), NetworkMessage::Reject(ref dat) => dat.consensus_encode(writer),
NetworkMessage::Alert(ref dat) => serialize(dat), NetworkMessage::FeeFilter(ref dat) => dat.consensus_encode(writer),
NetworkMessage::Reject(ref dat) => serialize(dat), NetworkMessage::AddrV2(ref dat) => dat.consensus_encode(writer),
NetworkMessage::FeeFilter(ref data) => serialize(data),
NetworkMessage::AddrV2(ref dat) => serialize(dat),
NetworkMessage::Verack NetworkMessage::Verack
| NetworkMessage::SendHeaders | NetworkMessage::SendHeaders
| NetworkMessage::MemPool | NetworkMessage::MemPool
| NetworkMessage::GetAddr | NetworkMessage::GetAddr
| NetworkMessage::WtxidRelay | NetworkMessage::WtxidRelay
| NetworkMessage::FilterClear | NetworkMessage::FilterClear
| NetworkMessage::SendAddrV2 => vec![], | NetworkMessage::SendAddrV2 => Ok(0),
NetworkMessage::Unknown { payload: ref data, .. } => serialize(data), NetworkMessage::Unknown { payload: ref data, .. } => data.consensus_encode(writer),
}) }
.consensus_encode(w)?; }
}
impl Encodable for RawNetworkMessage {
fn consensus_encode<W: io::Write + ?Sized>(&self, w: &mut W) -> Result<usize, io::Error> {
let mut len = 0;
len += self.magic.consensus_encode(w)?;
len += self.command().consensus_encode(w)?;
len += self.payload_len.consensus_encode(w)?;
len += self.checksum.consensus_encode(w)?;
len += self.payload().consensus_encode(w)?;
Ok(len) Ok(len)
} }
} }
@ -411,7 +435,10 @@ impl Decodable for RawNetworkMessage {
) -> Result<Self, encode::Error> { ) -> Result<Self, encode::Error> {
let magic = Decodable::consensus_decode_from_finite_reader(r)?; let magic = Decodable::consensus_decode_from_finite_reader(r)?;
let cmd = CommandString::consensus_decode_from_finite_reader(r)?; let cmd = CommandString::consensus_decode_from_finite_reader(r)?;
let raw_payload = CheckedData::consensus_decode_from_finite_reader(r)?.0; let checked_data = CheckedData::consensus_decode_from_finite_reader(r)?;
let checksum = checked_data.checksum();
let raw_payload = checked_data.into_data();
let payload_len = raw_payload.len() as u32;
let mut mem_d = io::Cursor::new(raw_payload); let mut mem_d = io::Cursor::new(raw_payload);
let payload = match &cmd.0[..] { let payload = match &cmd.0[..] {
@ -498,7 +525,7 @@ impl Decodable for RawNetworkMessage {
"sendaddrv2" => NetworkMessage::SendAddrV2, "sendaddrv2" => NetworkMessage::SendAddrV2,
_ => NetworkMessage::Unknown { command: cmd, payload: mem_d.into_inner() }, _ => NetworkMessage::Unknown { command: cmd, payload: mem_d.into_inner() },
}; };
Ok(RawNetworkMessage { magic, payload }) Ok(RawNetworkMessage { magic, payload, payload_len, checksum })
} }
#[inline] #[inline]
@ -642,8 +669,7 @@ mod test {
]; ];
for msg in msgs { for msg in msgs {
let raw_msg = let raw_msg = RawNetworkMessage::new(Magic::from_bytes([57, 0, 0, 0]), msg);
RawNetworkMessage { magic: Magic::from_bytes([57, 0, 0, 0]), payload: msg };
assert_eq!(deserialize::<RawNetworkMessage>(&serialize(&raw_msg)).unwrap(), raw_msg); assert_eq!(deserialize::<RawNetworkMessage>(&serialize(&raw_msg)).unwrap(), raw_msg);
} }
} }
@ -676,7 +702,7 @@ mod test {
#[test] #[test]
#[rustfmt::skip] #[rustfmt::skip]
fn serialize_verack_test() { fn serialize_verack_test() {
assert_eq!(serialize(&RawNetworkMessage { magic: Magic::from(Network::Bitcoin), payload: NetworkMessage::Verack }), assert_eq!(serialize(&RawNetworkMessage::new(Magic::from(Network::Bitcoin), NetworkMessage::Verack)),
vec![0xf9, 0xbe, 0xb4, 0xd9, 0x76, 0x65, 0x72, 0x61, vec![0xf9, 0xbe, 0xb4, 0xd9, 0x76, 0x65, 0x72, 0x61,
0x63, 0x6B, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x63, 0x6B, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x5d, 0xf6, 0xe0, 0xe2]); 0x00, 0x00, 0x00, 0x00, 0x5d, 0xf6, 0xe0, 0xe2]);
@ -685,7 +711,7 @@ mod test {
#[test] #[test]
#[rustfmt::skip] #[rustfmt::skip]
fn serialize_ping_test() { fn serialize_ping_test() {
assert_eq!(serialize(&RawNetworkMessage { magic: Magic::from(Network::Bitcoin), payload: NetworkMessage::Ping(100) }), assert_eq!(serialize(&RawNetworkMessage::new(Magic::from(Network::Bitcoin), NetworkMessage::Ping(100))),
vec![0xf9, 0xbe, 0xb4, 0xd9, 0x70, 0x69, 0x6e, 0x67, vec![0xf9, 0xbe, 0xb4, 0xd9, 0x70, 0x69, 0x6e, 0x67,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x08, 0x00, 0x00, 0x00, 0x24, 0x67, 0xf1, 0x1d, 0x08, 0x00, 0x00, 0x00, 0x24, 0x67, 0xf1, 0x1d,
@ -695,7 +721,7 @@ mod test {
#[test] #[test]
#[rustfmt::skip] #[rustfmt::skip]
fn serialize_mempool_test() { fn serialize_mempool_test() {
assert_eq!(serialize(&RawNetworkMessage { magic: Magic::from(Network::Bitcoin), payload: NetworkMessage::MemPool }), assert_eq!(serialize(&RawNetworkMessage::new(Magic::from(Network::Bitcoin), NetworkMessage::MemPool)),
vec![0xf9, 0xbe, 0xb4, 0xd9, 0x6d, 0x65, 0x6d, 0x70, vec![0xf9, 0xbe, 0xb4, 0xd9, 0x6d, 0x65, 0x6d, 0x70,
0x6f, 0x6f, 0x6c, 0x00, 0x00, 0x00, 0x00, 0x00, 0x6f, 0x6f, 0x6c, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x5d, 0xf6, 0xe0, 0xe2]); 0x00, 0x00, 0x00, 0x00, 0x5d, 0xf6, 0xe0, 0xe2]);
@ -704,7 +730,7 @@ mod test {
#[test] #[test]
#[rustfmt::skip] #[rustfmt::skip]
fn serialize_getaddr_test() { fn serialize_getaddr_test() {
assert_eq!(serialize(&RawNetworkMessage { magic: Magic::from(Network::Bitcoin), payload: NetworkMessage::GetAddr }), assert_eq!(serialize(&RawNetworkMessage::new(Magic::from(Network::Bitcoin), NetworkMessage::GetAddr)),
vec![0xf9, 0xbe, 0xb4, 0xd9, 0x67, 0x65, 0x74, 0x61, vec![0xf9, 0xbe, 0xb4, 0xd9, 0x67, 0x65, 0x74, 0x61,
0x64, 0x64, 0x72, 0x00, 0x00, 0x00, 0x00, 0x00, 0x64, 0x64, 0x72, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x5d, 0xf6, 0xe0, 0xe2]); 0x00, 0x00, 0x00, 0x00, 0x5d, 0xf6, 0xe0, 0xe2]);
@ -718,10 +744,8 @@ mod test {
0x64, 0x64, 0x72, 0x00, 0x00, 0x00, 0x00, 0x00, 0x64, 0x64, 0x72, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x5d, 0xf6, 0xe0, 0xe2 0x00, 0x00, 0x00, 0x00, 0x5d, 0xf6, 0xe0, 0xe2
]); ]);
let preimage = RawNetworkMessage { let preimage =
magic: Magic::from(Network::Bitcoin), RawNetworkMessage::new(Magic::from(Network::Bitcoin), NetworkMessage::GetAddr);
payload: NetworkMessage::GetAddr,
};
assert!(msg.is_ok()); assert!(msg.is_ok());
let msg: RawNetworkMessage = msg.unwrap(); let msg: RawNetworkMessage = msg.unwrap();
assert_eq!(preimage.magic, msg.magic); assert_eq!(preimage.magic, msg.magic);