Merge rust-bitcoin/rust-bitcoin#1954: Avoid vector allocation in `RawNetworkMessage` encoding

5c8933001c Avoid serialize inner data in RawNetworkMessage (Riccardo Casatta)
bc66ed82b2 Impl Encodable for NetworkMessage (Riccardo Casatta)
8560baaca2 Make fields of RawNetworkMessage non public (Riccardo Casatta)

Pull request description:

  This PR removes the need to serialize the inner NetworkMessage in the RawNetworkMessage encoding, thus saving memory and reducing allocations.

  To achieve this payload_len and checksum are kept in the RawNetworkMessage and checksum kept in CheckedData, to preserve invariants fields of the struct are made non-public.

ACKs for top commit:
  apoelstra:
    ACK 5c8933001c
  tcharding:
    ACK 5c8933001c

Tree-SHA512: aca3c7ac13d2d71184288f7815449e72c4c04fc617a65effba592592ef4ec50f18b6f83dbff58e9c4237cb1fe8e7af52cd43db9036658bdaf7888c07011e46cc
This commit is contained in:
Riccardo Casatta 2023-07-27 08:54:04 +02:00
commit 0e6341d4c9
No known key found for this signature in database
GPG Key ID: FD986A969E450397
3 changed files with 125 additions and 81 deletions

View File

@ -28,10 +28,8 @@ fn main() {
let version_message = build_version_message(address);
let first_message = message::RawNetworkMessage {
magic: constants::Network::Bitcoin.magic(),
payload: version_message,
};
let first_message =
message::RawNetworkMessage::new(constants::Network::Bitcoin.magic(), version_message);
if let Ok(mut stream) = TcpStream::connect(address) {
// Send the message
@ -44,24 +42,24 @@ fn main() {
loop {
// Loop an retrieve new messages
let reply = message::RawNetworkMessage::consensus_decode(&mut stream_reader).unwrap();
match reply.payload {
match reply.payload() {
message::NetworkMessage::Version(_) => {
println!("Received version message: {:?}", reply.payload);
println!("Received version message: {:?}", reply.payload());
let second_message = message::RawNetworkMessage {
magic: constants::Network::Bitcoin.magic(),
payload: message::NetworkMessage::Verack,
};
let second_message = message::RawNetworkMessage::new(
constants::Network::Bitcoin.magic(),
message::NetworkMessage::Verack,
);
let _ = stream.write_all(encode::serialize(&second_message).as_slice());
println!("Sent verack message");
}
message::NetworkMessage::Verack => {
println!("Received verack message: {:?}", reply.payload);
println!("Received verack message: {:?}", reply.payload());
break;
}
_ => {
println!("Received unknown message: {:?}", reply.payload);
println!("Received unknown message: {:?}", reply.payload());
break;
}
}

View File

@ -15,7 +15,7 @@
//! typically big-endian decimals, etc.)
//!
use core::convert::From;
use core::convert::{From, TryFrom};
use core::{fmt, mem, u32};
use hashes::{sha256, sha256d, Hash};
@ -325,9 +325,29 @@ pub trait Decodable: Sized {
#[derive(PartialEq, Eq, PartialOrd, Ord, Clone, Debug)]
pub struct VarInt(pub u64);
/// Data which must be preceded by a 4-byte checksum.
/// Data and a 4-byte checksum.
#[derive(PartialEq, Eq, Clone, Debug)]
pub struct CheckedData(pub Vec<u8>);
pub struct CheckedData {
data: Vec<u8>,
checksum: [u8; 4],
}
impl CheckedData {
/// Creates a new `CheckedData` computing the checksum of given data.
pub fn new(data: Vec<u8>) -> Self {
let checksum = sha2_checksum(&data);
Self { data, checksum }
}
/// Returns a reference to the raw data without the checksum.
pub fn data(&self) -> &[u8] { &self.data }
/// Returns the raw data without the checksum.
pub fn into_data(self) -> Vec<u8> { self.data }
/// Returns the checksum of the data.
pub fn checksum(&self) -> [u8; 4] { self.checksum }
}
// Primitive types
macro_rules! impl_int_encodable {
@ -686,10 +706,12 @@ fn sha2_checksum(data: &[u8]) -> [u8; 4] {
impl Encodable for CheckedData {
#[inline]
fn consensus_encode<W: io::Write + ?Sized>(&self, w: &mut W) -> Result<usize, io::Error> {
(self.0.len() as u32).consensus_encode(w)?;
sha2_checksum(&self.0).consensus_encode(w)?;
w.emit_slice(&self.0)?;
Ok(8 + self.0.len())
u32::try_from(self.data.len())
.expect("network message use u32 as length")
.consensus_encode(w)?;
self.checksum().consensus_encode(w)?;
w.emit_slice(&self.data)?;
Ok(8 + self.data.len())
}
}
@ -700,12 +722,12 @@ impl Decodable for CheckedData {
let checksum = <[u8; 4]>::consensus_decode_from_finite_reader(r)?;
let opts = ReadBytesFromFiniteReaderOpts { len, chunk_size: MAX_VEC_SIZE };
let ret = read_bytes_from_finite_reader(r, opts)?;
let expected_checksum = sha2_checksum(&ret);
let data = read_bytes_from_finite_reader(r, opts)?;
let expected_checksum = sha2_checksum(&data);
if expected_checksum != checksum {
Err(self::Error::InvalidChecksum { expected: expected_checksum, actual: checksum })
} else {
Ok(CheckedData(ret))
Ok(CheckedData { data, checksum })
}
}
}
@ -978,7 +1000,7 @@ mod tests {
#[test]
fn serialize_checkeddata_test() {
let cd = CheckedData(vec![1u8, 2, 3, 4, 5]);
let cd = CheckedData::new(vec![1u8, 2, 3, 4, 5]);
assert_eq!(serialize(&cd), vec![5, 0, 0, 0, 162, 107, 175, 90, 1, 2, 3, 4, 5]);
}
@ -1137,7 +1159,7 @@ mod tests {
fn deserialize_checkeddata_test() {
let cd: Result<CheckedData, _> =
deserialize(&[5u8, 0, 0, 0, 162, 107, 175, 90, 1, 2, 3, 4, 5]);
assert_eq!(cd.ok(), Some(CheckedData(vec![1u8, 2, 3, 4, 5])));
assert_eq!(cd.ok(), Some(CheckedData::new(vec![1u8, 2, 3, 4, 5])));
}
#[test]

View File

@ -9,11 +9,11 @@
use core::convert::TryFrom;
use core::{fmt, iter};
use hashes::{sha256d, Hash};
use io::Read as _;
use crate::blockdata::{block, transaction};
use crate::consensus::encode::{CheckedData, Decodable, Encodable, VarInt};
use crate::consensus::{encode, serialize};
use crate::consensus::encode::{self, CheckedData, Decodable, Encodable, VarInt};
use crate::io;
use crate::merkle_tree::MerkleBlock;
use crate::network::address::{AddrV2Message, Address};
@ -149,10 +149,10 @@ crate::error::impl_std_error!(CommandStringError);
/// A Network message
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct RawNetworkMessage {
/// Magic bytes to identify the network these messages are meant for
pub magic: Magic,
/// The actual message data
pub payload: NetworkMessage,
magic: Magic,
payload: NetworkMessage,
payload_len: u32,
checksum: [u8; 4],
}
/// A Network message payload. Proper documentation is available on at
@ -302,6 +302,22 @@ impl NetworkMessage {
}
impl RawNetworkMessage {
/// Creates a [RawNetworkMessage]
pub fn new(magic: Magic, payload: NetworkMessage) -> Self {
let mut engine = sha256d::Hash::engine();
let payload_len = payload.consensus_encode(&mut engine).expect("engine doesn't error");
let payload_len = u32::try_from(payload_len).expect("network message use u32 as length");
let checksum = sha256d::Hash::from_engine(engine);
let checksum = [checksum[0], checksum[1], checksum[2], checksum[3]];
Self { magic, payload, payload_len, checksum }
}
/// The actual message data
pub fn payload(&self) -> &NetworkMessage { &self.payload }
/// Magic bytes to identify the network these messages are meant for
pub fn magic(&self) -> &Magic { &self.magic }
/// Return the message command as a static string reference.
///
/// This returns `"unknown"` for [NetworkMessage::Unknown],
@ -328,51 +344,59 @@ impl<'a> Encodable for HeaderSerializationWrapper<'a> {
}
}
impl Encodable for RawNetworkMessage {
fn consensus_encode<W: io::Write + ?Sized>(&self, w: &mut W) -> Result<usize, io::Error> {
let mut len = 0;
len += self.magic.consensus_encode(w)?;
len += self.command().consensus_encode(w)?;
len += CheckedData(match self.payload {
NetworkMessage::Version(ref dat) => serialize(dat),
NetworkMessage::Addr(ref dat) => serialize(dat),
NetworkMessage::Inv(ref dat) => serialize(dat),
NetworkMessage::GetData(ref dat) => serialize(dat),
NetworkMessage::NotFound(ref dat) => serialize(dat),
NetworkMessage::GetBlocks(ref dat) => serialize(dat),
NetworkMessage::GetHeaders(ref dat) => serialize(dat),
NetworkMessage::Tx(ref dat) => serialize(dat),
NetworkMessage::Block(ref dat) => serialize(dat),
NetworkMessage::Headers(ref dat) => serialize(&HeaderSerializationWrapper(dat)),
NetworkMessage::Ping(ref dat) => serialize(dat),
NetworkMessage::Pong(ref dat) => serialize(dat),
NetworkMessage::MerkleBlock(ref dat) => serialize(dat),
NetworkMessage::FilterLoad(ref dat) => serialize(dat),
NetworkMessage::FilterAdd(ref dat) => serialize(dat),
NetworkMessage::GetCFilters(ref dat) => serialize(dat),
NetworkMessage::CFilter(ref dat) => serialize(dat),
NetworkMessage::GetCFHeaders(ref dat) => serialize(dat),
NetworkMessage::CFHeaders(ref dat) => serialize(dat),
NetworkMessage::GetCFCheckpt(ref dat) => serialize(dat),
NetworkMessage::CFCheckpt(ref dat) => serialize(dat),
NetworkMessage::SendCmpct(ref dat) => serialize(dat),
NetworkMessage::CmpctBlock(ref dat) => serialize(dat),
NetworkMessage::GetBlockTxn(ref dat) => serialize(dat),
NetworkMessage::BlockTxn(ref dat) => serialize(dat),
NetworkMessage::Alert(ref dat) => serialize(dat),
NetworkMessage::Reject(ref dat) => serialize(dat),
NetworkMessage::FeeFilter(ref data) => serialize(data),
NetworkMessage::AddrV2(ref dat) => serialize(dat),
impl Encodable for NetworkMessage {
fn consensus_encode<W: io::Write + ?Sized>(&self, writer: &mut W) -> Result<usize, io::Error> {
match self {
NetworkMessage::Version(ref dat) => dat.consensus_encode(writer),
NetworkMessage::Addr(ref dat) => dat.consensus_encode(writer),
NetworkMessage::Inv(ref dat) => dat.consensus_encode(writer),
NetworkMessage::GetData(ref dat) => dat.consensus_encode(writer),
NetworkMessage::NotFound(ref dat) => dat.consensus_encode(writer),
NetworkMessage::GetBlocks(ref dat) => dat.consensus_encode(writer),
NetworkMessage::GetHeaders(ref dat) => dat.consensus_encode(writer),
NetworkMessage::Tx(ref dat) => dat.consensus_encode(writer),
NetworkMessage::Block(ref dat) => dat.consensus_encode(writer),
NetworkMessage::Headers(ref dat) =>
HeaderSerializationWrapper(dat).consensus_encode(writer),
NetworkMessage::Ping(ref dat) => dat.consensus_encode(writer),
NetworkMessage::Pong(ref dat) => dat.consensus_encode(writer),
NetworkMessage::MerkleBlock(ref dat) => dat.consensus_encode(writer),
NetworkMessage::FilterLoad(ref dat) => dat.consensus_encode(writer),
NetworkMessage::FilterAdd(ref dat) => dat.consensus_encode(writer),
NetworkMessage::GetCFilters(ref dat) => dat.consensus_encode(writer),
NetworkMessage::CFilter(ref dat) => dat.consensus_encode(writer),
NetworkMessage::GetCFHeaders(ref dat) => dat.consensus_encode(writer),
NetworkMessage::CFHeaders(ref dat) => dat.consensus_encode(writer),
NetworkMessage::GetCFCheckpt(ref dat) => dat.consensus_encode(writer),
NetworkMessage::CFCheckpt(ref dat) => dat.consensus_encode(writer),
NetworkMessage::SendCmpct(ref dat) => dat.consensus_encode(writer),
NetworkMessage::CmpctBlock(ref dat) => dat.consensus_encode(writer),
NetworkMessage::GetBlockTxn(ref dat) => dat.consensus_encode(writer),
NetworkMessage::BlockTxn(ref dat) => dat.consensus_encode(writer),
NetworkMessage::Alert(ref dat) => dat.consensus_encode(writer),
NetworkMessage::Reject(ref dat) => dat.consensus_encode(writer),
NetworkMessage::FeeFilter(ref dat) => dat.consensus_encode(writer),
NetworkMessage::AddrV2(ref dat) => dat.consensus_encode(writer),
NetworkMessage::Verack
| NetworkMessage::SendHeaders
| NetworkMessage::MemPool
| NetworkMessage::GetAddr
| NetworkMessage::WtxidRelay
| NetworkMessage::FilterClear
| NetworkMessage::SendAddrV2 => vec![],
NetworkMessage::Unknown { payload: ref data, .. } => serialize(data),
})
.consensus_encode(w)?;
| NetworkMessage::SendAddrV2 => Ok(0),
NetworkMessage::Unknown { payload: ref data, .. } => data.consensus_encode(writer),
}
}
}
impl Encodable for RawNetworkMessage {
fn consensus_encode<W: io::Write + ?Sized>(&self, w: &mut W) -> Result<usize, io::Error> {
let mut len = 0;
len += self.magic.consensus_encode(w)?;
len += self.command().consensus_encode(w)?;
len += self.payload_len.consensus_encode(w)?;
len += self.checksum.consensus_encode(w)?;
len += self.payload().consensus_encode(w)?;
Ok(len)
}
}
@ -411,7 +435,10 @@ impl Decodable for RawNetworkMessage {
) -> Result<Self, encode::Error> {
let magic = Decodable::consensus_decode_from_finite_reader(r)?;
let cmd = CommandString::consensus_decode_from_finite_reader(r)?;
let raw_payload = CheckedData::consensus_decode_from_finite_reader(r)?.0;
let checked_data = CheckedData::consensus_decode_from_finite_reader(r)?;
let checksum = checked_data.checksum();
let raw_payload = checked_data.into_data();
let payload_len = raw_payload.len() as u32;
let mut mem_d = io::Cursor::new(raw_payload);
let payload = match &cmd.0[..] {
@ -498,7 +525,7 @@ impl Decodable for RawNetworkMessage {
"sendaddrv2" => NetworkMessage::SendAddrV2,
_ => NetworkMessage::Unknown { command: cmd, payload: mem_d.into_inner() },
};
Ok(RawNetworkMessage { magic, payload })
Ok(RawNetworkMessage { magic, payload, payload_len, checksum })
}
#[inline]
@ -642,8 +669,7 @@ mod test {
];
for msg in msgs {
let raw_msg =
RawNetworkMessage { magic: Magic::from_bytes([57, 0, 0, 0]), payload: msg };
let raw_msg = RawNetworkMessage::new(Magic::from_bytes([57, 0, 0, 0]), msg);
assert_eq!(deserialize::<RawNetworkMessage>(&serialize(&raw_msg)).unwrap(), raw_msg);
}
}
@ -676,7 +702,7 @@ mod test {
#[test]
#[rustfmt::skip]
fn serialize_verack_test() {
assert_eq!(serialize(&RawNetworkMessage { magic: Magic::from(Network::Bitcoin), payload: NetworkMessage::Verack }),
assert_eq!(serialize(&RawNetworkMessage::new(Magic::from(Network::Bitcoin), NetworkMessage::Verack)),
vec![0xf9, 0xbe, 0xb4, 0xd9, 0x76, 0x65, 0x72, 0x61,
0x63, 0x6B, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x5d, 0xf6, 0xe0, 0xe2]);
@ -685,7 +711,7 @@ mod test {
#[test]
#[rustfmt::skip]
fn serialize_ping_test() {
assert_eq!(serialize(&RawNetworkMessage { magic: Magic::from(Network::Bitcoin), payload: NetworkMessage::Ping(100) }),
assert_eq!(serialize(&RawNetworkMessage::new(Magic::from(Network::Bitcoin), NetworkMessage::Ping(100))),
vec![0xf9, 0xbe, 0xb4, 0xd9, 0x70, 0x69, 0x6e, 0x67,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x08, 0x00, 0x00, 0x00, 0x24, 0x67, 0xf1, 0x1d,
@ -695,7 +721,7 @@ mod test {
#[test]
#[rustfmt::skip]
fn serialize_mempool_test() {
assert_eq!(serialize(&RawNetworkMessage { magic: Magic::from(Network::Bitcoin), payload: NetworkMessage::MemPool }),
assert_eq!(serialize(&RawNetworkMessage::new(Magic::from(Network::Bitcoin), NetworkMessage::MemPool)),
vec![0xf9, 0xbe, 0xb4, 0xd9, 0x6d, 0x65, 0x6d, 0x70,
0x6f, 0x6f, 0x6c, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x5d, 0xf6, 0xe0, 0xe2]);
@ -704,7 +730,7 @@ mod test {
#[test]
#[rustfmt::skip]
fn serialize_getaddr_test() {
assert_eq!(serialize(&RawNetworkMessage { magic: Magic::from(Network::Bitcoin), payload: NetworkMessage::GetAddr }),
assert_eq!(serialize(&RawNetworkMessage::new(Magic::from(Network::Bitcoin), NetworkMessage::GetAddr)),
vec![0xf9, 0xbe, 0xb4, 0xd9, 0x67, 0x65, 0x74, 0x61,
0x64, 0x64, 0x72, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x5d, 0xf6, 0xe0, 0xe2]);
@ -718,10 +744,8 @@ mod test {
0x64, 0x64, 0x72, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x5d, 0xf6, 0xe0, 0xe2
]);
let preimage = RawNetworkMessage {
magic: Magic::from(Network::Bitcoin),
payload: NetworkMessage::GetAddr,
};
let preimage =
RawNetworkMessage::new(Magic::from(Network::Bitcoin), NetworkMessage::GetAddr);
assert!(msg.is_ok());
let msg: RawNetworkMessage = msg.unwrap();
assert_eq!(preimage.magic, msg.magic);