Add support for serde (de)serialization; add unit tests

This commit is contained in:
Andrew Poelstra 2015-04-10 00:32:12 -05:00
parent 1b2858bc8a
commit ac61baf040
4 changed files with 167 additions and 1 deletions

View File

@ -16,4 +16,5 @@ git = "https://github.com/DaGenix/rust-crypto.git"
rand = "*" rand = "*"
libc = "*" libc = "*"
rustc-serialize = "*" rustc-serialize = "*"
serde = "*"

View File

@ -16,9 +16,10 @@
//! Public/Private keys //! Public/Private keys
use std::intrinsics::copy_nonoverlapping; use std::intrinsics::copy_nonoverlapping;
use std::{fmt, ops}; use std::{fmt, marker, ops};
use rand::Rng; use rand::Rng;
use serialize::{Decoder, Decodable, Encoder, Encodable}; use serialize::{Decoder, Decodable, Encoder, Encodable};
use serde::{Serialize, Deserialize, Serializer, Deserializer};
use super::init; use super::init;
use super::Error::{self, InvalidPublicKey, InvalidSecretKey, Unknown}; use super::Error::{self, InvalidPublicKey, InvalidSecretKey, Unknown};
@ -410,6 +411,64 @@ impl Encodable for PublicKey {
} }
} }
impl Deserialize for PublicKey {
fn deserialize<D>(d: &mut D) -> Result<PublicKey, D::Error>
where D: Deserializer
{
use serde::de;
struct Visitor {
marker: marker::PhantomData<PublicKey>,
}
impl de::Visitor for Visitor {
type Value = PublicKey;
#[inline]
fn visit_seq<V>(&mut self, mut v: V) -> Result<PublicKey, V::Error>
where V: de::SeqVisitor
{
assert!(constants::UNCOMPRESSED_PUBLIC_KEY_SIZE >= constants::COMPRESSED_PUBLIC_KEY_SIZE);
unsafe {
use std::mem;
let mut ret_u: [u8; constants::UNCOMPRESSED_PUBLIC_KEY_SIZE] = mem::uninitialized();
let mut ret_c: [u8; constants::COMPRESSED_PUBLIC_KEY_SIZE] = mem::uninitialized();
let mut read_len = 0;
while read_len < constants::UNCOMPRESSED_PUBLIC_KEY_SIZE {
let read_ch = match try!(v.visit()) {
Some(c) => c,
None => break
};
ret_u[read_len] = read_ch;
if read_len < constants::COMPRESSED_PUBLIC_KEY_SIZE { ret_c[read_len] = read_ch; }
read_len += 1;
}
try!(v.end());
if read_len == constants::UNCOMPRESSED_PUBLIC_KEY_SIZE {
Ok(PublicKey(PublicKeyData::Uncompressed(ret_u)))
} else if read_len == constants::COMPRESSED_PUBLIC_KEY_SIZE {
Ok(PublicKey(PublicKeyData::Compressed(ret_c)))
} else {
return Err(de::Error::syntax_error());
}
}
}
}
// Begin actual function
d.visit(Visitor { marker: ::std::marker::PhantomData })
}
}
impl Serialize for PublicKey {
fn serialize<S>(&self, s: &mut S) -> Result<(), S::Error>
where S: Serializer
{
(&self.0[..]).serialize(s)
}
}
impl fmt::Debug for SecretKey { impl fmt::Debug for SecretKey {
fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> { fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
(&self[..]).fmt(f) (&self[..]).fmt(f)
@ -478,6 +537,66 @@ mod test {
0xBF, 0xD2, 0x5E, 0x8C, 0xD0, 0x36, 0x41, 0x41]).is_err()); 0xBF, 0xD2, 0x5E, 0x8C, 0xD0, 0x36, 0x41, 0x41]).is_err());
} }
#[test]
fn test_serialize() {
use std::io::Cursor;
use serialize::{json, Decodable, Encodable};
macro_rules! round_trip (
($var:ident) => ({
let start = $var;
let mut encoded = String::new();
{
let mut encoder = json::Encoder::new(&mut encoded);
start.encode(&mut encoder).unwrap();
}
let json = json::Json::from_reader(&mut Cursor::new(encoded.as_bytes())).unwrap();
let mut decoder = json::Decoder::new(json);
let decoded = Decodable::decode(&mut decoder);
assert_eq!(Some(start), decoded.ok());
})
);
let mut s = Secp256k1::new().unwrap();
for _ in 0..500 {
let (sk, pk) = s.generate_keypair(false);
round_trip!(sk);
round_trip!(pk);
let (sk, pk) = s.generate_keypair(true);
round_trip!(sk);
round_trip!(pk);
}
}
#[test]
fn test_serialize_serde() {
use serde::{json, Serialize, Deserialize};
macro_rules! round_trip (
($var:ident) => ({
let start = $var;
let mut encoded = Vec::new();
{
let mut serializer = json::ser::Serializer::new(&mut encoded);
start.serialize(&mut serializer).unwrap();
}
let mut deserializer = json::de::Deserializer::new(encoded.iter().map(|c| Ok(*c))).unwrap();
let decoded = Deserialize::deserialize(&mut deserializer);
assert_eq!(Some(start), decoded.ok());
})
);
let mut s = Secp256k1::new().unwrap();
for _ in 0..500 {
let (sk, pk) = s.generate_keypair(false);
round_trip!(sk);
round_trip!(pk);
let (sk, pk) = s.generate_keypair(true);
round_trip!(sk);
round_trip!(pk);
}
}
#[test] #[test]
fn test_addition() { fn test_addition() {
let mut s = Secp256k1::new().unwrap(); let mut s = Secp256k1::new().unwrap();

View File

@ -37,6 +37,7 @@
extern crate crypto; extern crate crypto;
extern crate rustc_serialize as serialize; extern crate rustc_serialize as serialize;
extern crate serde;
#[cfg(test)] extern crate test; #[cfg(test)] extern crate test;
extern crate libc; extern crate libc;

View File

@ -141,6 +141,51 @@ macro_rules! impl_array_newtype {
self[..].encode(s) self[..].encode(s)
} }
} }
impl ::serde::Deserialize for $thing {
fn deserialize<D>(d: &mut D) -> Result<$thing, D::Error>
where D: ::serde::Deserializer
{
// We have to define the Visitor struct inside the function
// to make it local ... all we really need is that it's
// local to the macro, but this works too :)
struct Visitor {
marker: ::std::marker::PhantomData<$thing>,
}
impl ::serde::de::Visitor for Visitor {
type Value = $thing;
#[inline]
fn visit_seq<V>(&mut self, mut v: V) -> Result<$thing, V::Error>
where V: ::serde::de::SeqVisitor
{
unsafe {
use std::mem;
let mut ret: [$ty; $len] = mem::uninitialized();
for i in 0..$len {
ret[i] = match try!(v.visit()) {
Some(c) => c,
None => return Err(::serde::de::Error::end_of_stream_error())
};
}
try!(v.end());
Ok($thing(ret))
}
}
}
// Begin actual function
d.visit(Visitor { marker: ::std::marker::PhantomData })
}
}
impl ::serde::Serialize for $thing {
fn serialize<S>(&self, s: &mut S) -> Result<(), S::Error>
where S: ::serde::Serializer
{
(&self.0[..]).serialize(s)
}
}
} }
} }