From b23de17d5554613e44b75b75e776b0dab82f9c62 Mon Sep 17 00:00:00 2001 From: Steven Roose Date: Thu, 7 Feb 2019 21:12:43 +0000 Subject: [PATCH] bip32: Introduce DerivationPath type Implements Display and FromStr for easy usage with serialized types. --- src/util/bip32.rs | 162 +++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 159 insertions(+), 3 deletions(-) diff --git a/src/util/bip32.rs b/src/util/bip32.rs index 1062f6de..107426e3 100644 --- a/src/util/bip32.rs +++ b/src/util/bip32.rs @@ -167,6 +167,21 @@ impl fmt::Display for ChildNumber { } } +impl FromStr for ChildNumber { + type Err = Error; + + fn from_str(inp: &str) -> Result { + Ok(match inp.chars().last().map_or(false, |l| l == '\'' || l == 'h') { + true => ChildNumber::from_hardened_idx( + inp[0..inp.len() - 1].parse().map_err(|_| Error::InvalidChildNumberFormat)? + )?, + false => ChildNumber::from_normal_idx( + inp.parse().map_err(|_| Error::InvalidChildNumberFormat)? + )?, + }) + } +} + #[cfg(feature = "serde")] impl<'de> serde::Deserialize<'de> for ChildNumber { fn deserialize(deserializer: D) -> Result @@ -187,6 +202,92 @@ impl serde::Serialize for ChildNumber { } } +/// A BIP-32 derivation path. +#[derive(Clone, PartialEq, Eq)] +pub struct DerivationPath(pub Vec); + +impl From> for DerivationPath { + fn from(numbers: Vec) -> Self { + DerivationPath(numbers) + } +} + +impl Into> for DerivationPath { + fn into(self) -> Vec { + self.0 + } +} + +impl FromStr for DerivationPath { + type Err = Error; + + fn from_str(path: &str) -> Result { + let mut parts = path.split("/"); + // First parts must be `m`. + if parts.next().unwrap() != "m" { + return Err(Error::InvalidDerivationPathFormat); + } + + let ret: Result, Error> = parts.map(str::parse).collect(); + Ok(DerivationPath(ret?)) + } +} + +impl fmt::Display for DerivationPath { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.write_str("m")?; + for cn in self.0.iter() { + f.write_str("/")?; + fmt::Display::fmt(cn, f)?; + } + Ok(()) + } +} + +impl fmt::Debug for DerivationPath { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + fmt::Display::fmt(&self, f) + } +} + +#[cfg(feature = "serde")] +impl<'de> serde::Deserialize<'de> for DerivationPath { + fn deserialize>(deserializer: D) -> Result { + use std::fmt; + use serde::de; + + struct Visitor; + impl<'de> de::Visitor<'de> for Visitor { + type Value = DerivationPath; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("a Bitcoin address") + } + + fn visit_str(self, v: &str) -> Result { + DerivationPath::from_str(v).map_err(E::custom) + } + + fn visit_borrowed_str(self, v: &'de str) -> Result { + self.visit_str(v) + } + + fn visit_string(self, v: String) -> Result { + self.visit_str(&v) + } + } + + deserializer.deserialize_str(Visitor) + } +} + +#[cfg(feature = "serde")] +impl serde::Serialize for DerivationPath { + fn serialize(&self, serializer: S) -> Result { + serializer.serialize_str(&self.to_string()) + } +} + /// A BIP32 error #[derive(Clone, PartialEq, Eq, Debug)] pub enum Error { @@ -197,7 +298,11 @@ pub enum Error { /// A child number was provided that was out of range InvalidChildNumber(u32), /// Error creating a master seed --- for application use - RngError(String) + RngError(String), + /// Invalid childnumber format. + InvalidChildNumberFormat, + /// Invalid derivation path format. + InvalidDerivationPathFormat, } impl fmt::Display for Error { @@ -207,6 +312,8 @@ impl fmt::Display for Error { Error::Ecdsa(ref e) => fmt::Display::fmt(e, f), Error::InvalidChildNumber(ref n) => write!(f, "child number {} is invalid (not within [0, 2^31 - 1])", n), Error::RngError(ref s) => write!(f, "rng error {}", s), + Error::InvalidChildNumberFormat => f.write_str("invalid child number format"), + Error::InvalidDerivationPathFormat => f.write_str("invalid derivation path format"), } } } @@ -225,7 +332,9 @@ impl error::Error for Error { Error::CannotDeriveFromHardenedKey => "cannot derive hardened key from public key", Error::Ecdsa(ref e) => error::Error::description(e), Error::InvalidChildNumber(_) => "child number is invalid", - Error::RngError(_) => "rng error" + Error::RngError(_) => "rng error", + Error::InvalidChildNumberFormat => "invalid child number format", + Error::InvalidDerivationPathFormat => "invalid derivation path format", } } } @@ -497,10 +606,57 @@ mod tests { use network::constants::Network::{self, Bitcoin}; - use super::{ChildNumber, ExtendedPrivKey, ExtendedPubKey}; + use super::{ChildNumber, DerivationPath, ExtendedPrivKey, ExtendedPubKey}; use super::ChildNumber::{Hardened, Normal}; use super::Error; + #[test] + fn test_parse_derivation_path() { + assert_eq!(DerivationPath::from_str("42"), Err(Error::InvalidDerivationPathFormat)); + assert_eq!(DerivationPath::from_str("n/0'/0"), Err(Error::InvalidDerivationPathFormat)); + assert_eq!(DerivationPath::from_str("4/m/5"), Err(Error::InvalidDerivationPathFormat)); + assert_eq!(DerivationPath::from_str("m//3/0'"), Err(Error::InvalidChildNumberFormat)); + assert_eq!(DerivationPath::from_str("m/0h/0x"), Err(Error::InvalidChildNumberFormat)); + assert_eq!(DerivationPath::from_str("m/2147483648"), Err(Error::InvalidChildNumber(2147483648))); + + assert_eq!(DerivationPath::from_str("m"), Ok(vec![].into())); + assert_eq!( + DerivationPath::from_str("m/0'"), + Ok(vec![ChildNumber::from_hardened_idx(0).unwrap()].into()) + ); + assert_eq!( + DerivationPath::from_str("m/0'/1"), + Ok(vec![ChildNumber::from_hardened_idx(0).unwrap(), ChildNumber::from_normal_idx(1).unwrap()].into()) + ); + assert_eq!( + DerivationPath::from_str("m/0h/1/2'"), + Ok(vec![ + ChildNumber::from_hardened_idx(0).unwrap(), + ChildNumber::from_normal_idx(1).unwrap(), + ChildNumber::from_hardened_idx(2).unwrap(), + ].into()) + ); + assert_eq!( + DerivationPath::from_str("m/0'/1/2h/2"), + Ok(vec![ + ChildNumber::from_hardened_idx(0).unwrap(), + ChildNumber::from_normal_idx(1).unwrap(), + ChildNumber::from_hardened_idx(2).unwrap(), + ChildNumber::from_normal_idx(2).unwrap(), + ].into()) + ); + assert_eq!( + DerivationPath::from_str("m/0'/1/2'/2/1000000000"), + Ok(vec![ + ChildNumber::from_hardened_idx(0).unwrap(), + ChildNumber::from_normal_idx(1).unwrap(), + ChildNumber::from_hardened_idx(2).unwrap(), + ChildNumber::from_normal_idx(2).unwrap(), + ChildNumber::from_normal_idx(1000000000).unwrap(), + ].into()) + ); + } + fn test_path(secp: &Secp256k1, network: Network, seed: &[u8],