bip32: Introduce DerivationPath type

Implements Display and FromStr for easy usage with serialized types.
This commit is contained in:
Steven Roose 2019-02-07 21:12:43 +00:00
parent a80cea270a
commit b23de17d55
1 changed files with 159 additions and 3 deletions

View File

@ -167,6 +167,21 @@ impl fmt::Display for ChildNumber {
} }
} }
impl FromStr for ChildNumber {
type Err = Error;
fn from_str(inp: &str) -> Result<ChildNumber, Error> {
Ok(match inp.chars().last().map_or(false, |l| l == '\'' || l == 'h') {
true => ChildNumber::from_hardened_idx(
inp[0..inp.len() - 1].parse().map_err(|_| Error::InvalidChildNumberFormat)?
)?,
false => ChildNumber::from_normal_idx(
inp.parse().map_err(|_| Error::InvalidChildNumberFormat)?
)?,
})
}
}
#[cfg(feature = "serde")] #[cfg(feature = "serde")]
impl<'de> serde::Deserialize<'de> for ChildNumber { impl<'de> serde::Deserialize<'de> for ChildNumber {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error> fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
@ -187,6 +202,92 @@ impl serde::Serialize for ChildNumber {
} }
} }
/// A BIP-32 derivation path.
#[derive(Clone, PartialEq, Eq)]
pub struct DerivationPath(pub Vec<ChildNumber>);
impl From<Vec<ChildNumber>> for DerivationPath {
fn from(numbers: Vec<ChildNumber>) -> Self {
DerivationPath(numbers)
}
}
impl Into<Vec<ChildNumber>> for DerivationPath {
fn into(self) -> Vec<ChildNumber> {
self.0
}
}
impl FromStr for DerivationPath {
type Err = Error;
fn from_str(path: &str) -> Result<DerivationPath, Error> {
let mut parts = path.split("/");
// First parts must be `m`.
if parts.next().unwrap() != "m" {
return Err(Error::InvalidDerivationPathFormat);
}
let ret: Result<Vec<ChildNumber>, Error> = parts.map(str::parse).collect();
Ok(DerivationPath(ret?))
}
}
impl fmt::Display for DerivationPath {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
f.write_str("m")?;
for cn in self.0.iter() {
f.write_str("/")?;
fmt::Display::fmt(cn, f)?;
}
Ok(())
}
}
impl fmt::Debug for DerivationPath {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
fmt::Display::fmt(&self, f)
}
}
#[cfg(feature = "serde")]
impl<'de> serde::Deserialize<'de> for DerivationPath {
fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
use std::fmt;
use serde::de;
struct Visitor;
impl<'de> de::Visitor<'de> for Visitor {
type Value = DerivationPath;
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
formatter.write_str("a Bitcoin address")
}
fn visit_str<E: de::Error>(self, v: &str) -> Result<Self::Value, E> {
DerivationPath::from_str(v).map_err(E::custom)
}
fn visit_borrowed_str<E: de::Error>(self, v: &'de str) -> Result<Self::Value, E> {
self.visit_str(v)
}
fn visit_string<E: de::Error>(self, v: String) -> Result<Self::Value, E> {
self.visit_str(&v)
}
}
deserializer.deserialize_str(Visitor)
}
}
#[cfg(feature = "serde")]
impl serde::Serialize for DerivationPath {
fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
serializer.serialize_str(&self.to_string())
}
}
/// A BIP32 error /// A BIP32 error
#[derive(Clone, PartialEq, Eq, Debug)] #[derive(Clone, PartialEq, Eq, Debug)]
pub enum Error { pub enum Error {
@ -197,7 +298,11 @@ pub enum Error {
/// A child number was provided that was out of range /// A child number was provided that was out of range
InvalidChildNumber(u32), InvalidChildNumber(u32),
/// Error creating a master seed --- for application use /// Error creating a master seed --- for application use
RngError(String) RngError(String),
/// Invalid childnumber format.
InvalidChildNumberFormat,
/// Invalid derivation path format.
InvalidDerivationPathFormat,
} }
impl fmt::Display for Error { impl fmt::Display for Error {
@ -207,6 +312,8 @@ impl fmt::Display for Error {
Error::Ecdsa(ref e) => fmt::Display::fmt(e, f), Error::Ecdsa(ref e) => fmt::Display::fmt(e, f),
Error::InvalidChildNumber(ref n) => write!(f, "child number {} is invalid (not within [0, 2^31 - 1])", n), Error::InvalidChildNumber(ref n) => write!(f, "child number {} is invalid (not within [0, 2^31 - 1])", n),
Error::RngError(ref s) => write!(f, "rng error {}", s), Error::RngError(ref s) => write!(f, "rng error {}", s),
Error::InvalidChildNumberFormat => f.write_str("invalid child number format"),
Error::InvalidDerivationPathFormat => f.write_str("invalid derivation path format"),
} }
} }
} }
@ -225,7 +332,9 @@ impl error::Error for Error {
Error::CannotDeriveFromHardenedKey => "cannot derive hardened key from public key", Error::CannotDeriveFromHardenedKey => "cannot derive hardened key from public key",
Error::Ecdsa(ref e) => error::Error::description(e), Error::Ecdsa(ref e) => error::Error::description(e),
Error::InvalidChildNumber(_) => "child number is invalid", Error::InvalidChildNumber(_) => "child number is invalid",
Error::RngError(_) => "rng error" Error::RngError(_) => "rng error",
Error::InvalidChildNumberFormat => "invalid child number format",
Error::InvalidDerivationPathFormat => "invalid derivation path format",
} }
} }
} }
@ -497,10 +606,57 @@ mod tests {
use network::constants::Network::{self, Bitcoin}; use network::constants::Network::{self, Bitcoin};
use super::{ChildNumber, ExtendedPrivKey, ExtendedPubKey}; use super::{ChildNumber, DerivationPath, ExtendedPrivKey, ExtendedPubKey};
use super::ChildNumber::{Hardened, Normal}; use super::ChildNumber::{Hardened, Normal};
use super::Error; use super::Error;
#[test]
fn test_parse_derivation_path() {
assert_eq!(DerivationPath::from_str("42"), Err(Error::InvalidDerivationPathFormat));
assert_eq!(DerivationPath::from_str("n/0'/0"), Err(Error::InvalidDerivationPathFormat));
assert_eq!(DerivationPath::from_str("4/m/5"), Err(Error::InvalidDerivationPathFormat));
assert_eq!(DerivationPath::from_str("m//3/0'"), Err(Error::InvalidChildNumberFormat));
assert_eq!(DerivationPath::from_str("m/0h/0x"), Err(Error::InvalidChildNumberFormat));
assert_eq!(DerivationPath::from_str("m/2147483648"), Err(Error::InvalidChildNumber(2147483648)));
assert_eq!(DerivationPath::from_str("m"), Ok(vec![].into()));
assert_eq!(
DerivationPath::from_str("m/0'"),
Ok(vec![ChildNumber::from_hardened_idx(0).unwrap()].into())
);
assert_eq!(
DerivationPath::from_str("m/0'/1"),
Ok(vec![ChildNumber::from_hardened_idx(0).unwrap(), ChildNumber::from_normal_idx(1).unwrap()].into())
);
assert_eq!(
DerivationPath::from_str("m/0h/1/2'"),
Ok(vec![
ChildNumber::from_hardened_idx(0).unwrap(),
ChildNumber::from_normal_idx(1).unwrap(),
ChildNumber::from_hardened_idx(2).unwrap(),
].into())
);
assert_eq!(
DerivationPath::from_str("m/0'/1/2h/2"),
Ok(vec![
ChildNumber::from_hardened_idx(0).unwrap(),
ChildNumber::from_normal_idx(1).unwrap(),
ChildNumber::from_hardened_idx(2).unwrap(),
ChildNumber::from_normal_idx(2).unwrap(),
].into())
);
assert_eq!(
DerivationPath::from_str("m/0'/1/2'/2/1000000000"),
Ok(vec![
ChildNumber::from_hardened_idx(0).unwrap(),
ChildNumber::from_normal_idx(1).unwrap(),
ChildNumber::from_hardened_idx(2).unwrap(),
ChildNumber::from_normal_idx(2).unwrap(),
ChildNumber::from_normal_idx(1000000000).unwrap(),
].into())
);
}
fn test_path<C: secp256k1::Signing + secp256k1::Verification>(secp: &Secp256k1<C>, fn test_path<C: secp256k1::Signing + secp256k1::Verification>(secp: &Secp256k1<C>,
network: Network, network: Network,
seed: &[u8], seed: &[u8],