diff --git a/src/network/message.rs b/src/network/message.rs index 3ed4c25c..ca954685 100644 --- a/src/network/message.rs +++ b/src/network/message.rs @@ -19,7 +19,7 @@ //! also defines (de)serialization routines for many primitives. //! -use std::{io, iter, mem, fmt}; +use std::{fmt, io, iter, mem, str}; use std::borrow::Cow; use std::io::Cursor; @@ -42,24 +42,31 @@ pub const MAX_INV_SIZE: usize = 50_000; #[derive(PartialEq, Eq, Clone, Debug)] pub struct CommandString(Cow<'static, str>); +impl CommandString { + /// Convert from various string types into a [CommandString]. + /// + /// Supported types are: + /// - `&'static str` + /// - `String` + /// + /// Returns an empty error if and only if the string is + /// larger than 12 characters in length. + pub fn try_from>>(s: S) -> Result { + let cow = s.into(); + if cow.as_ref().len() > 12 { + Err(()) + } else { + Ok(CommandString(cow)) + } + } +} + impl fmt::Display for CommandString { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { f.write_str(self.0.as_ref()) } } -impl From<&'static str> for CommandString { - fn from(f: &'static str) -> Self { - CommandString(f.into()) - } -} - -impl From for CommandString { - fn from(f: String) -> Self { - CommandString(f.into()) - } -} - impl AsRef for CommandString { fn as_ref(&self) -> &str { self.0.as_ref() @@ -74,9 +81,7 @@ impl Encodable for CommandString { ) -> Result { let mut rawbytes = [0u8; 12]; let strbytes = self.0.as_bytes(); - if strbytes.len() > 12 { - return Err(encode::Error::NetworkCommandTooLong(self.0.clone().into_owned())); - } + debug_assert!(strbytes.len() <= 12); rawbytes[..strbytes.len()].clone_from_slice(&strbytes[..]); rawbytes.consensus_encode(s) } @@ -207,7 +212,7 @@ impl NetworkMessage { /// Return the CommandString for the message command. pub fn command(&self) -> CommandString { - self.cmd().into() + CommandString::try_from(self.cmd()).expect("cmd returns valid commands") } } @@ -356,11 +361,10 @@ impl Decodable for RawNetworkMessage { #[cfg(test)] mod test { - use std::io; use std::net::Ipv4Addr; use super::{RawNetworkMessage, NetworkMessage, CommandString}; use network::constants::ServiceFlags; - use consensus::encode::{Encodable, deserialize, deserialize_partial, serialize}; + use consensus::encode::{deserialize, deserialize_partial, serialize}; use hashes::hex::FromHex; use hashes::sha256d::Hash; use hashes::Hash as HashTrait; @@ -407,7 +411,7 @@ mod test { NetworkMessage::GetCFCheckpt(GetCFCheckpt{filter_type: 17, stop_hash: hash([25u8; 32]).into()}), NetworkMessage::CFCheckpt(CFCheckpt{filter_type: 27, stop_hash: hash([77u8; 32]).into(), filter_headers: vec![hash([3u8; 32]).into(), hash([99u8; 32]).into()]}), NetworkMessage::Alert(vec![45,66,3,2,6,8,9,12,3,130]), - NetworkMessage::Reject(Reject{message: "Test reject".into(), ccode: RejectReason::Duplicate, reason: "Cause".into(), hash: hash([255u8; 32])}), + NetworkMessage::Reject(Reject{message: CommandString::try_from("Test reject").unwrap(), ccode: RejectReason::Duplicate, reason: "Cause".into(), hash: hash([255u8; 32])}), NetworkMessage::FeeFilter(1000), NetworkMessage::WtxidRelay, NetworkMessage::AddrV2(vec![AddrV2Message{ addr: AddrV2::Ipv4(Ipv4Addr::new(127, 0, 0, 1)), port: 0, services: ServiceFlags::NONE, time: 0 }]), @@ -422,21 +426,20 @@ mod test { } #[test] - fn serialize_commandstring_test() { + fn commandstring_test() { + // Test converting. + assert_eq!(CommandString::try_from("AndrewAndrew").unwrap().as_ref(), "AndrewAndrew"); + assert!(CommandString::try_from("AndrewAndrewA").is_err()); + + // Test serializing. let cs = CommandString("Andrew".into()); assert_eq!(serialize(&cs), vec![0x41u8, 0x6e, 0x64, 0x72, 0x65, 0x77, 0, 0, 0, 0, 0, 0]); - // Test oversized one. - let mut encoder = io::Cursor::new(vec![]); - assert!(CommandString("AndrewAndrewA".into()).consensus_encode(&mut encoder).is_err()); - } - - #[test] - fn deserialize_commandstring_test() { + // Test deserializing let cs: Result = deserialize(&[0x41u8, 0x6e, 0x64, 0x72, 0x65, 0x77, 0, 0, 0, 0, 0, 0]); assert!(cs.is_ok()); assert_eq!(cs.as_ref().unwrap().to_string(), "Andrew".to_owned()); - assert_eq!(cs.unwrap(), "Andrew".into()); + assert_eq!(cs.unwrap(), CommandString::try_from("Andrew").unwrap()); let short_cs: Result = deserialize(&[0x41u8, 0x6e, 0x64, 0x72, 0x65, 0x77, 0, 0, 0, 0, 0]); assert!(short_cs.is_err());