From 1f08a313e5b210c55c2d79e9e69d62abc6b0d99b Mon Sep 17 00:00:00 2001 From: Elichai Turkel Date: Fri, 15 May 2020 15:37:08 +0300 Subject: [PATCH] Replace serde macros with generic visitor module Co-authored-by: Elichai Turkel Co-authored-by: Sebastian Geisler --- src/ecdh.rs | 2 +- src/key.rs | 66 +++++++++++++++++++++++++- src/lib.rs | 24 ++++++---- src/macros.rs | 118 ---------------------------------------------- src/schnorrsig.rs | 55 ++++++++++++++++++++- src/serde_util.rs | 76 +++++++++++++++++++++++++++++ 6 files changed, 210 insertions(+), 131 deletions(-) create mode 100644 src/serde_util.rs diff --git a/src/ecdh.rs b/src/ecdh.rs index 51ec005..a2aa121 100644 --- a/src/ecdh.rs +++ b/src/ecdh.rs @@ -194,7 +194,7 @@ mod tests { let s = Secp256k1::signing_only(); let (sk1, pk1) = s.generate_keypair(&mut thread_rng()); let (sk2, pk2) = s.generate_keypair(&mut thread_rng()); - + let sec1 = SharedSecret::new_with_hash(&pk1, &sk2, |x,_| x.into()); let sec2 = SharedSecret::new_with_hash(&pk2, &sk1, |x,_| x.into()); let sec_odd = SharedSecret::new_with_hash(&pk1, &sk1, |x,_| x.into()); diff --git a/src/key.rs b/src/key.rs index 2492f04..b7f854f 100644 --- a/src/key.rs +++ b/src/key.rs @@ -213,7 +213,32 @@ impl SecretKey { } } -serde_impl!(SecretKey, constants::SECRET_KEY_SIZE); +#[cfg(feature = "serde")] +impl ::serde::Serialize for SecretKey { + fn serialize(&self, s: S) -> Result { + if s.is_human_readable() { + s.collect_str(self) + } else { + s.serialize_bytes(&self[..]) + } + } +} + +#[cfg(feature = "serde")] +impl<'de> ::serde::Deserialize<'de> for SecretKey { + fn deserialize>(d: D) -> Result { + if d.is_human_readable() { + d.deserialize_str(super::serde_util::HexVisitor::new( + "a hex string representing 32 byte SecretKey" + )) + } else { + d.deserialize_bytes(super::serde_util::BytesVisitor::new( + "raw 32 bytes SecretKey", + SecretKey::from_slice + )) + } + } +} impl PublicKey { /// Obtains a raw const pointer suitable for use with FFI functions @@ -402,7 +427,32 @@ impl From for PublicKey { } } -serde_impl_from_slice!(PublicKey); +#[cfg(feature = "serde")] +impl ::serde::Serialize for PublicKey { + fn serialize(&self, s: S) -> Result { + if s.is_human_readable() { + s.collect_str(self) + } else { + s.serialize_bytes(&self.serialize()) + } + } +} + +#[cfg(feature = "serde")] +impl<'de> ::serde::Deserialize<'de> for PublicKey { + fn deserialize>(d: D) -> Result { + if d.is_human_readable() { + d.deserialize_str(super::serde_util::HexVisitor::new( + "an ASCII hex string representing a public key" + )) + } else { + d.deserialize_bytes(super::serde_util::BytesVisitor::new( + "a bytestring representing a public key", + PublicKey::from_slice + )) + } + } +} impl PartialOrd for PublicKey { fn partial_cmp(&self, other: &PublicKey) -> Option<::core::cmp::Ordering> { @@ -846,8 +896,20 @@ mod test { let pk = PublicKey::from_secret_key(&s, &sk); assert_tokens(&sk.compact(), &[Token::BorrowedBytes(&SK_BYTES[..])]); + assert_tokens(&sk.compact(), &[Token::Bytes(&SK_BYTES)]); + assert_tokens(&sk.compact(), &[Token::ByteBuf(&SK_BYTES)]); + assert_tokens(&sk.readable(), &[Token::BorrowedStr(SK_STR)]); + assert_tokens(&sk.readable(), &[Token::Str(SK_STR)]); + assert_tokens(&sk.readable(), &[Token::String(SK_STR)]); + assert_tokens(&pk.compact(), &[Token::BorrowedBytes(&PK_BYTES[..])]); + assert_tokens(&pk.compact(), &[Token::Bytes(&PK_BYTES)]); + assert_tokens(&pk.compact(), &[Token::ByteBuf(&PK_BYTES)]); + assert_tokens(&pk.readable(), &[Token::BorrowedStr(PK_STR)]); + assert_tokens(&pk.readable(), &[Token::Str(PK_STR)]); + assert_tokens(&pk.readable(), &[Token::String(PK_STR)]); + } } diff --git a/src/lib.rs b/src/lib.rs index 675e956..c787eb2 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -147,6 +147,8 @@ pub mod key; pub mod schnorrsig; #[cfg(feature = "recovery")] pub mod recovery; +#[cfg(feature = "serde")] +mod serde_util; pub use key::SecretKey; pub use key::PublicKey; @@ -435,21 +437,21 @@ impl ::serde::Serialize for Signature { } else { s.serialize_bytes(&self.serialize_der()) } - } } #[cfg(feature = "serde")] impl<'de> ::serde::Deserialize<'de> for Signature { - fn deserialize>(d: D) -> Result { - use ::serde::de::Error; - use str::FromStr; + fn deserialize>(d: D) -> Result { if d.is_human_readable() { - let sl: &str = ::serde::Deserialize::deserialize(d)?; - Signature::from_str(sl).map_err(D::Error::custom) + d.deserialize_str(serde_util::HexVisitor::new( + "a hex string representing a DER encoded Signature" + )) } else { - let sl: &[u8] = ::serde::Deserialize::deserialize(d)?; - Signature::from_der(sl).map_err(D::Error::custom) + d.deserialize_bytes(serde_util::BytesVisitor::new( + "raw byte stream, that represents a DER encoded Signature", + Signature::from_der + )) } } } @@ -1260,7 +1262,13 @@ mod tests { "; assert_tokens(&sig.compact(), &[Token::BorrowedBytes(&SIG_BYTES[..])]); + assert_tokens(&sig.compact(), &[Token::Bytes(&SIG_BYTES)]); + assert_tokens(&sig.compact(), &[Token::ByteBuf(&SIG_BYTES)]); + assert_tokens(&sig.readable(), &[Token::BorrowedStr(SIG_STR)]); + assert_tokens(&sig.readable(), &[Token::Str(SIG_STR)]); + assert_tokens(&sig.readable(), &[Token::String(SIG_STR)]); + } #[cfg(feature = "global-context")] diff --git a/src/macros.rs b/src/macros.rs index a8258be..bfd41b7 100644 --- a/src/macros.rs +++ b/src/macros.rs @@ -43,121 +43,3 @@ macro_rules! impl_from_array_len { )+ } } - -#[cfg(feature="serde")] -/// Implements `Serialize` and `Deserialize` for a type `$t` which represents -/// a newtype over a byte-slice over length `$len`. Type `$t` must implement -/// the `FromStr` and `Display` trait. -macro_rules! serde_impl( - ($t:ident, $len:expr) => ( - impl ::serde::Serialize for $t { - fn serialize(&self, s: S) -> Result { - if s.is_human_readable() { - s.collect_str(self) - } else { - s.serialize_bytes(&self[..]) - } - } - } - - impl<'de> ::serde::Deserialize<'de> for $t { - fn deserialize>(d: D) -> Result<$t, D::Error> { - use ::serde::de::Error; - use core::str::FromStr; - - if d.is_human_readable() { - let sl: &str = ::serde::Deserialize::deserialize(d)?; - $t::from_str(sl).map_err(D::Error::custom) - } else { - let sl: &[u8] = ::serde::Deserialize::deserialize(d)?; - if sl.len() != $len { - Err(D::Error::invalid_length(sl.len(), &stringify!($len))) - } else { - let mut ret = [0; $len]; - ret.copy_from_slice(sl); - Ok($t(ret)) - } - } - } - } - ) -); - -#[cfg(not(feature="serde"))] -macro_rules! serde_impl( - ($t:ident, $len:expr) => () -); - -#[cfg(feature = "serde")] -macro_rules! serde_impl_from_slice { - ($t: ident) => { - impl ::serde::Serialize for $t { - fn serialize(&self, s: S) -> Result { - if s.is_human_readable() { - s.collect_str(self) - } else { - s.serialize_bytes(&self.serialize()) - } - } - } - - impl<'de> ::serde::Deserialize<'de> for $t { - fn deserialize>(d: D) -> Result<$t, D::Error> { - if d.is_human_readable() { - struct HexVisitor; - - impl<'de> ::serde::de::Visitor<'de> for HexVisitor { - type Value = $t; - - fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { - formatter.write_str("an ASCII hex string") - } - - fn visit_bytes(self, v: &[u8]) -> Result - where - E: ::serde::de::Error, - { - if let Ok(hex) = str::from_utf8(v) { - str::FromStr::from_str(hex).map_err(E::custom) - } else { - Err(E::invalid_value(::serde::de::Unexpected::Bytes(v), &self)) - } - } - - fn visit_str(self, v: &str) -> Result - where - E: ::serde::de::Error, - { - str::FromStr::from_str(v).map_err(E::custom) - } - } - d.deserialize_str(HexVisitor) - } else { - struct BytesVisitor; - - impl<'de> ::serde::de::Visitor<'de> for BytesVisitor { - type Value = $t; - - fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { - formatter.write_str("a bytestring") - } - - fn visit_bytes(self, v: &[u8]) -> Result - where - E: ::serde::de::Error, - { - $t::from_slice(v).map_err(E::custom) - } - } - - d.deserialize_bytes(BytesVisitor) - } - } - } - }; -} - -#[cfg(not(feature = "serde"))] -macro_rules! serde_impl_from_slice( - ($t:ident) => () -); diff --git a/src/schnorrsig.rs b/src/schnorrsig.rs index 12ab2f2..3cd80e1 100644 --- a/src/schnorrsig.rs +++ b/src/schnorrsig.rs @@ -18,7 +18,33 @@ use {Message, Signing, Verification}; pub struct Signature([u8; constants::SCHNORRSIG_SIGNATURE_SIZE]); impl_array_newtype!(Signature, u8, constants::SCHNORRSIG_SIGNATURE_SIZE); impl_pretty_debug!(Signature); -serde_impl!(Signature, constants::SCHNORRSIG_SIGNATURE_SIZE); + +#[cfg(feature = "serde")] +impl ::serde::Serialize for Signature { + fn serialize(&self, s: S) -> Result { + if s.is_human_readable() { + s.collect_str(self) + } else { + s.serialize_bytes(&self[..]) + } + } +} + +#[cfg(feature = "serde")] +impl<'de> ::serde::Deserialize<'de> for Signature { + fn deserialize>(d: D) -> Result { + if d.is_human_readable() { + d.deserialize_str(super::serde_util::HexVisitor::new( + "a hex string representing 64 byte schnorr signature" + )) + } else { + d.deserialize_bytes(super::serde_util::BytesVisitor::new( + "raw 64 bytes schnorr signature", + Signature::from_slice + )) + } + } +} impl fmt::LowerHex for Signature { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { @@ -376,7 +402,32 @@ impl From<::key::PublicKey> for PublicKey { } } -serde_impl_from_slice!(PublicKey); +#[cfg(feature = "serde")] +impl ::serde::Serialize for PublicKey { + fn serialize(&self, s: S) -> Result { + if s.is_human_readable() { + s.collect_str(self) + } else { + s.serialize_bytes(&self.serialize()) + } + } +} + +#[cfg(feature = "serde")] +impl<'de> ::serde::Deserialize<'de> for PublicKey { + fn deserialize>(d: D) -> Result { + if d.is_human_readable() { + d.deserialize_str(super::serde_util::HexVisitor::new( + "a hex string representing 32 byte schnorr public key" + )) + } else { + d.deserialize_bytes(super::serde_util::BytesVisitor::new( + "raw 32 bytes schnorr public key", + PublicKey::from_slice + )) + } + } +} impl Secp256k1 { fn schnorrsig_sign_helper( diff --git a/src/serde_util.rs b/src/serde_util.rs new file mode 100644 index 0000000..5034416 --- /dev/null +++ b/src/serde_util.rs @@ -0,0 +1,76 @@ +use core::fmt; +use core::marker::PhantomData; +use core::str::{self, FromStr}; +use serde::de; + +pub struct HexVisitor { + expectation: &'static str, + _pd: PhantomData, +} + +impl HexVisitor { + pub fn new(expectation: &'static str) -> Self { + HexVisitor { + expectation, + _pd: PhantomData, + } + } +} + +impl<'de, T> de::Visitor<'de> for HexVisitor +where + T: FromStr, + ::Err: fmt::Display, +{ + type Value = T; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str(self.expectation) + } + + fn visit_bytes(self, v: &[u8]) -> Result { + if let Ok(hex) = str::from_utf8(v) { + FromStr::from_str(hex).map_err(E::custom) + } else { + Err(E::invalid_value(de::Unexpected::Bytes(v), &self)) + } + } + + fn visit_str(self, v: &str) -> Result { + FromStr::from_str(v).map_err(E::custom) + } +} + +pub struct BytesVisitor { + expectation: &'static str, + parse_fn: F, +} + +impl BytesVisitor +where + F: FnOnce(&[u8]) -> Result, + Err: fmt::Display, +{ + pub fn new(expectation: &'static str, parse_fn: F) -> Self { + BytesVisitor { + expectation, + parse_fn, + } + } +} + +impl<'de, F, T, Err> de::Visitor<'de> for BytesVisitor +where + F: FnOnce(&[u8]) -> Result, + Err: fmt::Display, +{ + type Value = T; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str(self.expectation) + } + + fn visit_bytes(self, v: &[u8]) -> Result { + (self.parse_fn)(v).map_err(E::custom) + } +}