Merge rust-bitcoin/rust-bitcoin#1954: Avoid vector allocation in `RawNetworkMessage` encoding
5c8933001c
Avoid serialize inner data in RawNetworkMessage (Riccardo Casatta)bc66ed82b2
Impl Encodable for NetworkMessage (Riccardo Casatta)8560baaca2
Make fields of RawNetworkMessage non public (Riccardo Casatta) Pull request description: This PR removes the need to serialize the inner NetworkMessage in the RawNetworkMessage encoding, thus saving memory and reducing allocations. To achieve this payload_len and checksum are kept in the RawNetworkMessage and checksum kept in CheckedData, to preserve invariants fields of the struct are made non-public. ACKs for top commit: apoelstra: ACK5c8933001c
tcharding: ACK5c8933001c
Tree-SHA512: aca3c7ac13d2d71184288f7815449e72c4c04fc617a65effba592592ef4ec50f18b6f83dbff58e9c4237cb1fe8e7af52cd43db9036658bdaf7888c07011e46cc
This commit is contained in:
commit
0e6341d4c9
|
@ -28,10 +28,8 @@ fn main() {
|
|||
|
||||
let version_message = build_version_message(address);
|
||||
|
||||
let first_message = message::RawNetworkMessage {
|
||||
magic: constants::Network::Bitcoin.magic(),
|
||||
payload: version_message,
|
||||
};
|
||||
let first_message =
|
||||
message::RawNetworkMessage::new(constants::Network::Bitcoin.magic(), version_message);
|
||||
|
||||
if let Ok(mut stream) = TcpStream::connect(address) {
|
||||
// Send the message
|
||||
|
@ -44,24 +42,24 @@ fn main() {
|
|||
loop {
|
||||
// Loop an retrieve new messages
|
||||
let reply = message::RawNetworkMessage::consensus_decode(&mut stream_reader).unwrap();
|
||||
match reply.payload {
|
||||
match reply.payload() {
|
||||
message::NetworkMessage::Version(_) => {
|
||||
println!("Received version message: {:?}", reply.payload);
|
||||
println!("Received version message: {:?}", reply.payload());
|
||||
|
||||
let second_message = message::RawNetworkMessage {
|
||||
magic: constants::Network::Bitcoin.magic(),
|
||||
payload: message::NetworkMessage::Verack,
|
||||
};
|
||||
let second_message = message::RawNetworkMessage::new(
|
||||
constants::Network::Bitcoin.magic(),
|
||||
message::NetworkMessage::Verack,
|
||||
);
|
||||
|
||||
let _ = stream.write_all(encode::serialize(&second_message).as_slice());
|
||||
println!("Sent verack message");
|
||||
}
|
||||
message::NetworkMessage::Verack => {
|
||||
println!("Received verack message: {:?}", reply.payload);
|
||||
println!("Received verack message: {:?}", reply.payload());
|
||||
break;
|
||||
}
|
||||
_ => {
|
||||
println!("Received unknown message: {:?}", reply.payload);
|
||||
println!("Received unknown message: {:?}", reply.payload());
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -15,7 +15,7 @@
|
|||
//! typically big-endian decimals, etc.)
|
||||
//!
|
||||
|
||||
use core::convert::From;
|
||||
use core::convert::{From, TryFrom};
|
||||
use core::{fmt, mem, u32};
|
||||
|
||||
use hashes::{sha256, sha256d, Hash};
|
||||
|
@ -325,9 +325,29 @@ pub trait Decodable: Sized {
|
|||
#[derive(PartialEq, Eq, PartialOrd, Ord, Clone, Debug)]
|
||||
pub struct VarInt(pub u64);
|
||||
|
||||
/// Data which must be preceded by a 4-byte checksum.
|
||||
/// Data and a 4-byte checksum.
|
||||
#[derive(PartialEq, Eq, Clone, Debug)]
|
||||
pub struct CheckedData(pub Vec<u8>);
|
||||
pub struct CheckedData {
|
||||
data: Vec<u8>,
|
||||
checksum: [u8; 4],
|
||||
}
|
||||
|
||||
impl CheckedData {
|
||||
/// Creates a new `CheckedData` computing the checksum of given data.
|
||||
pub fn new(data: Vec<u8>) -> Self {
|
||||
let checksum = sha2_checksum(&data);
|
||||
Self { data, checksum }
|
||||
}
|
||||
|
||||
/// Returns a reference to the raw data without the checksum.
|
||||
pub fn data(&self) -> &[u8] { &self.data }
|
||||
|
||||
/// Returns the raw data without the checksum.
|
||||
pub fn into_data(self) -> Vec<u8> { self.data }
|
||||
|
||||
/// Returns the checksum of the data.
|
||||
pub fn checksum(&self) -> [u8; 4] { self.checksum }
|
||||
}
|
||||
|
||||
// Primitive types
|
||||
macro_rules! impl_int_encodable {
|
||||
|
@ -686,10 +706,12 @@ fn sha2_checksum(data: &[u8]) -> [u8; 4] {
|
|||
impl Encodable for CheckedData {
|
||||
#[inline]
|
||||
fn consensus_encode<W: io::Write + ?Sized>(&self, w: &mut W) -> Result<usize, io::Error> {
|
||||
(self.0.len() as u32).consensus_encode(w)?;
|
||||
sha2_checksum(&self.0).consensus_encode(w)?;
|
||||
w.emit_slice(&self.0)?;
|
||||
Ok(8 + self.0.len())
|
||||
u32::try_from(self.data.len())
|
||||
.expect("network message use u32 as length")
|
||||
.consensus_encode(w)?;
|
||||
self.checksum().consensus_encode(w)?;
|
||||
w.emit_slice(&self.data)?;
|
||||
Ok(8 + self.data.len())
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -700,12 +722,12 @@ impl Decodable for CheckedData {
|
|||
|
||||
let checksum = <[u8; 4]>::consensus_decode_from_finite_reader(r)?;
|
||||
let opts = ReadBytesFromFiniteReaderOpts { len, chunk_size: MAX_VEC_SIZE };
|
||||
let ret = read_bytes_from_finite_reader(r, opts)?;
|
||||
let expected_checksum = sha2_checksum(&ret);
|
||||
let data = read_bytes_from_finite_reader(r, opts)?;
|
||||
let expected_checksum = sha2_checksum(&data);
|
||||
if expected_checksum != checksum {
|
||||
Err(self::Error::InvalidChecksum { expected: expected_checksum, actual: checksum })
|
||||
} else {
|
||||
Ok(CheckedData(ret))
|
||||
Ok(CheckedData { data, checksum })
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -978,7 +1000,7 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn serialize_checkeddata_test() {
|
||||
let cd = CheckedData(vec![1u8, 2, 3, 4, 5]);
|
||||
let cd = CheckedData::new(vec![1u8, 2, 3, 4, 5]);
|
||||
assert_eq!(serialize(&cd), vec![5, 0, 0, 0, 162, 107, 175, 90, 1, 2, 3, 4, 5]);
|
||||
}
|
||||
|
||||
|
@ -1137,7 +1159,7 @@ mod tests {
|
|||
fn deserialize_checkeddata_test() {
|
||||
let cd: Result<CheckedData, _> =
|
||||
deserialize(&[5u8, 0, 0, 0, 162, 107, 175, 90, 1, 2, 3, 4, 5]);
|
||||
assert_eq!(cd.ok(), Some(CheckedData(vec![1u8, 2, 3, 4, 5])));
|
||||
assert_eq!(cd.ok(), Some(CheckedData::new(vec![1u8, 2, 3, 4, 5])));
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
|
@ -9,11 +9,11 @@
|
|||
use core::convert::TryFrom;
|
||||
use core::{fmt, iter};
|
||||
|
||||
use hashes::{sha256d, Hash};
|
||||
use io::Read as _;
|
||||
|
||||
use crate::blockdata::{block, transaction};
|
||||
use crate::consensus::encode::{CheckedData, Decodable, Encodable, VarInt};
|
||||
use crate::consensus::{encode, serialize};
|
||||
use crate::consensus::encode::{self, CheckedData, Decodable, Encodable, VarInt};
|
||||
use crate::io;
|
||||
use crate::merkle_tree::MerkleBlock;
|
||||
use crate::network::address::{AddrV2Message, Address};
|
||||
|
@ -149,10 +149,10 @@ crate::error::impl_std_error!(CommandStringError);
|
|||
/// A Network message
|
||||
#[derive(Clone, Debug, PartialEq, Eq)]
|
||||
pub struct RawNetworkMessage {
|
||||
/// Magic bytes to identify the network these messages are meant for
|
||||
pub magic: Magic,
|
||||
/// The actual message data
|
||||
pub payload: NetworkMessage,
|
||||
magic: Magic,
|
||||
payload: NetworkMessage,
|
||||
payload_len: u32,
|
||||
checksum: [u8; 4],
|
||||
}
|
||||
|
||||
/// A Network message payload. Proper documentation is available on at
|
||||
|
@ -302,6 +302,22 @@ impl NetworkMessage {
|
|||
}
|
||||
|
||||
impl RawNetworkMessage {
|
||||
/// Creates a [RawNetworkMessage]
|
||||
pub fn new(magic: Magic, payload: NetworkMessage) -> Self {
|
||||
let mut engine = sha256d::Hash::engine();
|
||||
let payload_len = payload.consensus_encode(&mut engine).expect("engine doesn't error");
|
||||
let payload_len = u32::try_from(payload_len).expect("network message use u32 as length");
|
||||
let checksum = sha256d::Hash::from_engine(engine);
|
||||
let checksum = [checksum[0], checksum[1], checksum[2], checksum[3]];
|
||||
Self { magic, payload, payload_len, checksum }
|
||||
}
|
||||
|
||||
/// The actual message data
|
||||
pub fn payload(&self) -> &NetworkMessage { &self.payload }
|
||||
|
||||
/// Magic bytes to identify the network these messages are meant for
|
||||
pub fn magic(&self) -> &Magic { &self.magic }
|
||||
|
||||
/// Return the message command as a static string reference.
|
||||
///
|
||||
/// This returns `"unknown"` for [NetworkMessage::Unknown],
|
||||
|
@ -328,51 +344,59 @@ impl<'a> Encodable for HeaderSerializationWrapper<'a> {
|
|||
}
|
||||
}
|
||||
|
||||
impl Encodable for RawNetworkMessage {
|
||||
fn consensus_encode<W: io::Write + ?Sized>(&self, w: &mut W) -> Result<usize, io::Error> {
|
||||
let mut len = 0;
|
||||
len += self.magic.consensus_encode(w)?;
|
||||
len += self.command().consensus_encode(w)?;
|
||||
len += CheckedData(match self.payload {
|
||||
NetworkMessage::Version(ref dat) => serialize(dat),
|
||||
NetworkMessage::Addr(ref dat) => serialize(dat),
|
||||
NetworkMessage::Inv(ref dat) => serialize(dat),
|
||||
NetworkMessage::GetData(ref dat) => serialize(dat),
|
||||
NetworkMessage::NotFound(ref dat) => serialize(dat),
|
||||
NetworkMessage::GetBlocks(ref dat) => serialize(dat),
|
||||
NetworkMessage::GetHeaders(ref dat) => serialize(dat),
|
||||
NetworkMessage::Tx(ref dat) => serialize(dat),
|
||||
NetworkMessage::Block(ref dat) => serialize(dat),
|
||||
NetworkMessage::Headers(ref dat) => serialize(&HeaderSerializationWrapper(dat)),
|
||||
NetworkMessage::Ping(ref dat) => serialize(dat),
|
||||
NetworkMessage::Pong(ref dat) => serialize(dat),
|
||||
NetworkMessage::MerkleBlock(ref dat) => serialize(dat),
|
||||
NetworkMessage::FilterLoad(ref dat) => serialize(dat),
|
||||
NetworkMessage::FilterAdd(ref dat) => serialize(dat),
|
||||
NetworkMessage::GetCFilters(ref dat) => serialize(dat),
|
||||
NetworkMessage::CFilter(ref dat) => serialize(dat),
|
||||
NetworkMessage::GetCFHeaders(ref dat) => serialize(dat),
|
||||
NetworkMessage::CFHeaders(ref dat) => serialize(dat),
|
||||
NetworkMessage::GetCFCheckpt(ref dat) => serialize(dat),
|
||||
NetworkMessage::CFCheckpt(ref dat) => serialize(dat),
|
||||
NetworkMessage::SendCmpct(ref dat) => serialize(dat),
|
||||
NetworkMessage::CmpctBlock(ref dat) => serialize(dat),
|
||||
NetworkMessage::GetBlockTxn(ref dat) => serialize(dat),
|
||||
NetworkMessage::BlockTxn(ref dat) => serialize(dat),
|
||||
NetworkMessage::Alert(ref dat) => serialize(dat),
|
||||
NetworkMessage::Reject(ref dat) => serialize(dat),
|
||||
NetworkMessage::FeeFilter(ref data) => serialize(data),
|
||||
NetworkMessage::AddrV2(ref dat) => serialize(dat),
|
||||
impl Encodable for NetworkMessage {
|
||||
fn consensus_encode<W: io::Write + ?Sized>(&self, writer: &mut W) -> Result<usize, io::Error> {
|
||||
match self {
|
||||
NetworkMessage::Version(ref dat) => dat.consensus_encode(writer),
|
||||
NetworkMessage::Addr(ref dat) => dat.consensus_encode(writer),
|
||||
NetworkMessage::Inv(ref dat) => dat.consensus_encode(writer),
|
||||
NetworkMessage::GetData(ref dat) => dat.consensus_encode(writer),
|
||||
NetworkMessage::NotFound(ref dat) => dat.consensus_encode(writer),
|
||||
NetworkMessage::GetBlocks(ref dat) => dat.consensus_encode(writer),
|
||||
NetworkMessage::GetHeaders(ref dat) => dat.consensus_encode(writer),
|
||||
NetworkMessage::Tx(ref dat) => dat.consensus_encode(writer),
|
||||
NetworkMessage::Block(ref dat) => dat.consensus_encode(writer),
|
||||
NetworkMessage::Headers(ref dat) =>
|
||||
HeaderSerializationWrapper(dat).consensus_encode(writer),
|
||||
NetworkMessage::Ping(ref dat) => dat.consensus_encode(writer),
|
||||
NetworkMessage::Pong(ref dat) => dat.consensus_encode(writer),
|
||||
NetworkMessage::MerkleBlock(ref dat) => dat.consensus_encode(writer),
|
||||
NetworkMessage::FilterLoad(ref dat) => dat.consensus_encode(writer),
|
||||
NetworkMessage::FilterAdd(ref dat) => dat.consensus_encode(writer),
|
||||
NetworkMessage::GetCFilters(ref dat) => dat.consensus_encode(writer),
|
||||
NetworkMessage::CFilter(ref dat) => dat.consensus_encode(writer),
|
||||
NetworkMessage::GetCFHeaders(ref dat) => dat.consensus_encode(writer),
|
||||
NetworkMessage::CFHeaders(ref dat) => dat.consensus_encode(writer),
|
||||
NetworkMessage::GetCFCheckpt(ref dat) => dat.consensus_encode(writer),
|
||||
NetworkMessage::CFCheckpt(ref dat) => dat.consensus_encode(writer),
|
||||
NetworkMessage::SendCmpct(ref dat) => dat.consensus_encode(writer),
|
||||
NetworkMessage::CmpctBlock(ref dat) => dat.consensus_encode(writer),
|
||||
NetworkMessage::GetBlockTxn(ref dat) => dat.consensus_encode(writer),
|
||||
NetworkMessage::BlockTxn(ref dat) => dat.consensus_encode(writer),
|
||||
NetworkMessage::Alert(ref dat) => dat.consensus_encode(writer),
|
||||
NetworkMessage::Reject(ref dat) => dat.consensus_encode(writer),
|
||||
NetworkMessage::FeeFilter(ref dat) => dat.consensus_encode(writer),
|
||||
NetworkMessage::AddrV2(ref dat) => dat.consensus_encode(writer),
|
||||
NetworkMessage::Verack
|
||||
| NetworkMessage::SendHeaders
|
||||
| NetworkMessage::MemPool
|
||||
| NetworkMessage::GetAddr
|
||||
| NetworkMessage::WtxidRelay
|
||||
| NetworkMessage::FilterClear
|
||||
| NetworkMessage::SendAddrV2 => vec![],
|
||||
NetworkMessage::Unknown { payload: ref data, .. } => serialize(data),
|
||||
})
|
||||
.consensus_encode(w)?;
|
||||
| NetworkMessage::SendAddrV2 => Ok(0),
|
||||
NetworkMessage::Unknown { payload: ref data, .. } => data.consensus_encode(writer),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Encodable for RawNetworkMessage {
|
||||
fn consensus_encode<W: io::Write + ?Sized>(&self, w: &mut W) -> Result<usize, io::Error> {
|
||||
let mut len = 0;
|
||||
len += self.magic.consensus_encode(w)?;
|
||||
len += self.command().consensus_encode(w)?;
|
||||
len += self.payload_len.consensus_encode(w)?;
|
||||
len += self.checksum.consensus_encode(w)?;
|
||||
len += self.payload().consensus_encode(w)?;
|
||||
Ok(len)
|
||||
}
|
||||
}
|
||||
|
@ -411,7 +435,10 @@ impl Decodable for RawNetworkMessage {
|
|||
) -> Result<Self, encode::Error> {
|
||||
let magic = Decodable::consensus_decode_from_finite_reader(r)?;
|
||||
let cmd = CommandString::consensus_decode_from_finite_reader(r)?;
|
||||
let raw_payload = CheckedData::consensus_decode_from_finite_reader(r)?.0;
|
||||
let checked_data = CheckedData::consensus_decode_from_finite_reader(r)?;
|
||||
let checksum = checked_data.checksum();
|
||||
let raw_payload = checked_data.into_data();
|
||||
let payload_len = raw_payload.len() as u32;
|
||||
|
||||
let mut mem_d = io::Cursor::new(raw_payload);
|
||||
let payload = match &cmd.0[..] {
|
||||
|
@ -498,7 +525,7 @@ impl Decodable for RawNetworkMessage {
|
|||
"sendaddrv2" => NetworkMessage::SendAddrV2,
|
||||
_ => NetworkMessage::Unknown { command: cmd, payload: mem_d.into_inner() },
|
||||
};
|
||||
Ok(RawNetworkMessage { magic, payload })
|
||||
Ok(RawNetworkMessage { magic, payload, payload_len, checksum })
|
||||
}
|
||||
|
||||
#[inline]
|
||||
|
@ -642,8 +669,7 @@ mod test {
|
|||
];
|
||||
|
||||
for msg in msgs {
|
||||
let raw_msg =
|
||||
RawNetworkMessage { magic: Magic::from_bytes([57, 0, 0, 0]), payload: msg };
|
||||
let raw_msg = RawNetworkMessage::new(Magic::from_bytes([57, 0, 0, 0]), msg);
|
||||
assert_eq!(deserialize::<RawNetworkMessage>(&serialize(&raw_msg)).unwrap(), raw_msg);
|
||||
}
|
||||
}
|
||||
|
@ -676,7 +702,7 @@ mod test {
|
|||
#[test]
|
||||
#[rustfmt::skip]
|
||||
fn serialize_verack_test() {
|
||||
assert_eq!(serialize(&RawNetworkMessage { magic: Magic::from(Network::Bitcoin), payload: NetworkMessage::Verack }),
|
||||
assert_eq!(serialize(&RawNetworkMessage::new(Magic::from(Network::Bitcoin), NetworkMessage::Verack)),
|
||||
vec![0xf9, 0xbe, 0xb4, 0xd9, 0x76, 0x65, 0x72, 0x61,
|
||||
0x63, 0x6B, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||
0x00, 0x00, 0x00, 0x00, 0x5d, 0xf6, 0xe0, 0xe2]);
|
||||
|
@ -685,7 +711,7 @@ mod test {
|
|||
#[test]
|
||||
#[rustfmt::skip]
|
||||
fn serialize_ping_test() {
|
||||
assert_eq!(serialize(&RawNetworkMessage { magic: Magic::from(Network::Bitcoin), payload: NetworkMessage::Ping(100) }),
|
||||
assert_eq!(serialize(&RawNetworkMessage::new(Magic::from(Network::Bitcoin), NetworkMessage::Ping(100))),
|
||||
vec![0xf9, 0xbe, 0xb4, 0xd9, 0x70, 0x69, 0x6e, 0x67,
|
||||
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||
0x08, 0x00, 0x00, 0x00, 0x24, 0x67, 0xf1, 0x1d,
|
||||
|
@ -695,7 +721,7 @@ mod test {
|
|||
#[test]
|
||||
#[rustfmt::skip]
|
||||
fn serialize_mempool_test() {
|
||||
assert_eq!(serialize(&RawNetworkMessage { magic: Magic::from(Network::Bitcoin), payload: NetworkMessage::MemPool }),
|
||||
assert_eq!(serialize(&RawNetworkMessage::new(Magic::from(Network::Bitcoin), NetworkMessage::MemPool)),
|
||||
vec![0xf9, 0xbe, 0xb4, 0xd9, 0x6d, 0x65, 0x6d, 0x70,
|
||||
0x6f, 0x6f, 0x6c, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||
0x00, 0x00, 0x00, 0x00, 0x5d, 0xf6, 0xe0, 0xe2]);
|
||||
|
@ -704,7 +730,7 @@ mod test {
|
|||
#[test]
|
||||
#[rustfmt::skip]
|
||||
fn serialize_getaddr_test() {
|
||||
assert_eq!(serialize(&RawNetworkMessage { magic: Magic::from(Network::Bitcoin), payload: NetworkMessage::GetAddr }),
|
||||
assert_eq!(serialize(&RawNetworkMessage::new(Magic::from(Network::Bitcoin), NetworkMessage::GetAddr)),
|
||||
vec![0xf9, 0xbe, 0xb4, 0xd9, 0x67, 0x65, 0x74, 0x61,
|
||||
0x64, 0x64, 0x72, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||
0x00, 0x00, 0x00, 0x00, 0x5d, 0xf6, 0xe0, 0xe2]);
|
||||
|
@ -718,10 +744,8 @@ mod test {
|
|||
0x64, 0x64, 0x72, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||
0x00, 0x00, 0x00, 0x00, 0x5d, 0xf6, 0xe0, 0xe2
|
||||
]);
|
||||
let preimage = RawNetworkMessage {
|
||||
magic: Magic::from(Network::Bitcoin),
|
||||
payload: NetworkMessage::GetAddr,
|
||||
};
|
||||
let preimage =
|
||||
RawNetworkMessage::new(Magic::from(Network::Bitcoin), NetworkMessage::GetAddr);
|
||||
assert!(msg.is_ok());
|
||||
let msg: RawNetworkMessage = msg.unwrap();
|
||||
assert_eq!(preimage.magic, msg.magic);
|
||||
|
|
Loading…
Reference in New Issue