Clean up CommandString

- Add length invariant.
- Siimplify constructors.
This commit is contained in:
Steven Roose 2020-10-09 17:03:43 +02:00
parent a6264cfca6
commit 944371d6a2
No known key found for this signature in database
GPG Key ID: 2F2A88D7F8D68E87
1 changed files with 32 additions and 29 deletions

View File

@ -19,7 +19,7 @@
//! also defines (de)serialization routines for many primitives. //! also defines (de)serialization routines for many primitives.
//! //!
use std::{io, iter, mem, fmt}; use std::{fmt, io, iter, mem, str};
use std::borrow::Cow; use std::borrow::Cow;
use std::io::Cursor; use std::io::Cursor;
@ -42,24 +42,31 @@ pub const MAX_INV_SIZE: usize = 50_000;
#[derive(PartialEq, Eq, Clone, Debug)] #[derive(PartialEq, Eq, Clone, Debug)]
pub struct CommandString(Cow<'static, str>); pub struct CommandString(Cow<'static, str>);
impl CommandString {
/// Convert from various string types into a [CommandString].
///
/// Supported types are:
/// - `&'static str`
/// - `String`
///
/// Returns an empty error if and only if the string is
/// larger than 12 characters in length.
pub fn try_from<S: Into<Cow<'static, str>>>(s: S) -> Result<CommandString, ()> {
let cow = s.into();
if cow.as_ref().len() > 12 {
Err(())
} else {
Ok(CommandString(cow))
}
}
}
impl fmt::Display for CommandString { impl fmt::Display for CommandString {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
f.write_str(self.0.as_ref()) f.write_str(self.0.as_ref())
} }
} }
impl From<&'static str> for CommandString {
fn from(f: &'static str) -> Self {
CommandString(f.into())
}
}
impl From<String> for CommandString {
fn from(f: String) -> Self {
CommandString(f.into())
}
}
impl AsRef<str> for CommandString { impl AsRef<str> for CommandString {
fn as_ref(&self) -> &str { fn as_ref(&self) -> &str {
self.0.as_ref() self.0.as_ref()
@ -74,9 +81,7 @@ impl Encodable for CommandString {
) -> Result<usize, encode::Error> { ) -> Result<usize, encode::Error> {
let mut rawbytes = [0u8; 12]; let mut rawbytes = [0u8; 12];
let strbytes = self.0.as_bytes(); let strbytes = self.0.as_bytes();
if strbytes.len() > 12 { debug_assert!(strbytes.len() <= 12);
return Err(encode::Error::NetworkCommandTooLong(self.0.clone().into_owned()));
}
rawbytes[..strbytes.len()].clone_from_slice(&strbytes[..]); rawbytes[..strbytes.len()].clone_from_slice(&strbytes[..]);
rawbytes.consensus_encode(s) rawbytes.consensus_encode(s)
} }
@ -207,7 +212,7 @@ impl NetworkMessage {
/// Return the CommandString for the message command. /// Return the CommandString for the message command.
pub fn command(&self) -> CommandString { pub fn command(&self) -> CommandString {
self.cmd().into() CommandString::try_from(self.cmd()).expect("cmd returns valid commands")
} }
} }
@ -356,11 +361,10 @@ impl Decodable for RawNetworkMessage {
#[cfg(test)] #[cfg(test)]
mod test { mod test {
use std::io;
use std::net::Ipv4Addr; use std::net::Ipv4Addr;
use super::{RawNetworkMessage, NetworkMessage, CommandString}; use super::{RawNetworkMessage, NetworkMessage, CommandString};
use network::constants::ServiceFlags; use network::constants::ServiceFlags;
use consensus::encode::{Encodable, deserialize, deserialize_partial, serialize}; use consensus::encode::{deserialize, deserialize_partial, serialize};
use hashes::hex::FromHex; use hashes::hex::FromHex;
use hashes::sha256d::Hash; use hashes::sha256d::Hash;
use hashes::Hash as HashTrait; use hashes::Hash as HashTrait;
@ -407,7 +411,7 @@ mod test {
NetworkMessage::GetCFCheckpt(GetCFCheckpt{filter_type: 17, stop_hash: hash([25u8; 32]).into()}), NetworkMessage::GetCFCheckpt(GetCFCheckpt{filter_type: 17, stop_hash: hash([25u8; 32]).into()}),
NetworkMessage::CFCheckpt(CFCheckpt{filter_type: 27, stop_hash: hash([77u8; 32]).into(), filter_headers: vec![hash([3u8; 32]).into(), hash([99u8; 32]).into()]}), NetworkMessage::CFCheckpt(CFCheckpt{filter_type: 27, stop_hash: hash([77u8; 32]).into(), filter_headers: vec![hash([3u8; 32]).into(), hash([99u8; 32]).into()]}),
NetworkMessage::Alert(vec![45,66,3,2,6,8,9,12,3,130]), NetworkMessage::Alert(vec![45,66,3,2,6,8,9,12,3,130]),
NetworkMessage::Reject(Reject{message: "Test reject".into(), ccode: RejectReason::Duplicate, reason: "Cause".into(), hash: hash([255u8; 32])}), NetworkMessage::Reject(Reject{message: CommandString::try_from("Test reject").unwrap(), ccode: RejectReason::Duplicate, reason: "Cause".into(), hash: hash([255u8; 32])}),
NetworkMessage::FeeFilter(1000), NetworkMessage::FeeFilter(1000),
NetworkMessage::WtxidRelay, NetworkMessage::WtxidRelay,
NetworkMessage::AddrV2(vec![AddrV2Message{ addr: AddrV2::Ipv4(Ipv4Addr::new(127, 0, 0, 1)), port: 0, services: ServiceFlags::NONE, time: 0 }]), NetworkMessage::AddrV2(vec![AddrV2Message{ addr: AddrV2::Ipv4(Ipv4Addr::new(127, 0, 0, 1)), port: 0, services: ServiceFlags::NONE, time: 0 }]),
@ -422,21 +426,20 @@ mod test {
} }
#[test] #[test]
fn serialize_commandstring_test() { fn commandstring_test() {
// Test converting.
assert_eq!(CommandString::try_from("AndrewAndrew").unwrap().as_ref(), "AndrewAndrew");
assert!(CommandString::try_from("AndrewAndrewA").is_err());
// Test serializing.
let cs = CommandString("Andrew".into()); let cs = CommandString("Andrew".into());
assert_eq!(serialize(&cs), vec![0x41u8, 0x6e, 0x64, 0x72, 0x65, 0x77, 0, 0, 0, 0, 0, 0]); assert_eq!(serialize(&cs), vec![0x41u8, 0x6e, 0x64, 0x72, 0x65, 0x77, 0, 0, 0, 0, 0, 0]);
// Test oversized one. // Test deserializing
let mut encoder = io::Cursor::new(vec![]);
assert!(CommandString("AndrewAndrewA".into()).consensus_encode(&mut encoder).is_err());
}
#[test]
fn deserialize_commandstring_test() {
let cs: Result<CommandString, _> = deserialize(&[0x41u8, 0x6e, 0x64, 0x72, 0x65, 0x77, 0, 0, 0, 0, 0, 0]); let cs: Result<CommandString, _> = deserialize(&[0x41u8, 0x6e, 0x64, 0x72, 0x65, 0x77, 0, 0, 0, 0, 0, 0]);
assert!(cs.is_ok()); assert!(cs.is_ok());
assert_eq!(cs.as_ref().unwrap().to_string(), "Andrew".to_owned()); assert_eq!(cs.as_ref().unwrap().to_string(), "Andrew".to_owned());
assert_eq!(cs.unwrap(), "Andrew".into()); assert_eq!(cs.unwrap(), CommandString::try_from("Andrew").unwrap());
let short_cs: Result<CommandString, _> = deserialize(&[0x41u8, 0x6e, 0x64, 0x72, 0x65, 0x77, 0, 0, 0, 0, 0]); let short_cs: Result<CommandString, _> = deserialize(&[0x41u8, 0x6e, 0x64, 0x72, 0x65, 0x77, 0, 0, 0, 0, 0]);
assert!(short_cs.is_err()); assert!(short_cs.is_err());