use crate::index::DerivationIndex; use serde::{Deserialize, Serialize}; use thiserror::Error; /// Errors associated with creating a [`DerivationPath`]. #[derive(Error, Debug)] pub enum Error { /// A [`DerivationIndex`] was not able to be created. #[error("Unable to create index: {0}")] UnableToCreateIndex(#[from] super::index::Error), /// The path could not be parsed due to a bad prefix. Paths must be in the format: /// /// m [/ index [']]+ /// /// The prefix for the path must be `m`, and all indices must be integers between 0 and /// 2^31. #[error("Unable to parse path due to bad path prefix")] UnknownPathPrefix, } type Result = std::result::Result; const PREFIX: &str = "m"; /// A fully qualified path to derive a key. #[derive(Serialize, Deserialize, Clone, Debug, Default, Eq, PartialEq, PartialOrd, Ord)] pub struct DerivationPath { pub(crate) path: Vec, } impl DerivationPath { /// Returns an iterator over the [`DerivationPath`]. pub fn iter(&self) -> impl Iterator { self.path.iter() } pub fn len(&self) -> usize { self.path.len() } pub fn is_empty(&self) -> bool { self.path.is_empty() } pub fn push(&mut self, index: DerivationIndex) { self.path.push(index); } pub fn chain_push(mut self, index: DerivationIndex) -> Self { self.path.push(index); self } } impl std::str::FromStr for DerivationPath { type Err = Error; fn from_str(s: &str) -> Result { let mut iter = s.split('/'); if iter.next() != Some(PREFIX) { return Err(Error::UnknownPathPrefix); } Ok(Self { path: iter .map(DerivationIndex::from_str) .map(|maybe_err| maybe_err.map_err(From::from)) .collect::>()?, }) } } impl std::fmt::Display for DerivationPath { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "{PREFIX}")?; for index in self.iter() { write!(f, "/{index}")?; } Ok(()) } } impl std::ops::Add for DerivationPath { type Output = DerivationPath; fn add(self, rhs: DerivationIndex) -> Self::Output { let mut output = self.clone(); output.path.push(rhs); output } } impl std::ops::Add<&[DerivationIndex]> for DerivationPath { type Output = DerivationPath; fn add(self, rhs: &[DerivationIndex]) -> Self::Output { let mut output = self.clone(); output.path.extend(rhs.iter().cloned()); output } } #[cfg(test)] mod tests { use super::*; use std::str::FromStr; #[test] #[should_panic] fn requires_master_path() { DerivationPath::from_str("1234/5678'").unwrap(); } #[test] fn equivalency() -> Result<()> { let paths = ["m/1234'/5678", "m/44'/0'/0'", "m"]; for path in paths { assert_eq!(&DerivationPath::from_str(path)?.to_string(), path); } Ok(()) } #[test] fn add() -> Result<(), Box> { let path = DerivationPath::from_str("m")?; let path = path + DerivationIndex::new(72, true)?; let path = path + DerivationIndex::new(47, false)?; let path = path + DerivationIndex::new((i32::MAX) as u32, false)?; assert_eq!(path, DerivationPath::from_str("m/72'/47/2147483647")?); Ok(()) } #[test] fn add_vec() -> Result<(), Box> { let path = DerivationPath::from_str("m")?; let other_path = [ DerivationIndex::new(72, true)?, DerivationIndex::new(47, false)?, DerivationIndex::new((i32::MAX) as u32, false)?, ]; let path = path + &other_path[..]; assert_eq!(path, DerivationPath::from_str("m/72'/47/2147483647")?); Ok(()) } }