176 lines
5.1 KiB
Rust
176 lines
5.1 KiB
Rust
use crate::index::DerivationIndex;
|
|
use serde::{Deserialize, Serialize};
|
|
use thiserror::Error;
|
|
|
|
/// Errors associated with creating a [`DerivationPath`].
|
|
#[derive(Error, Debug)]
|
|
pub enum Error {
|
|
/// A [`DerivationIndex`] was not able to be created.
|
|
#[error("Unable to create index: {0}")]
|
|
UnableToCreateIndex(#[from] super::index::Error),
|
|
|
|
/// The path could not be parsed due to a bad prefix. Paths must be in the format:
|
|
///
|
|
/// m [/ index [']]+
|
|
///
|
|
/// The prefix for the path must be `m`, and all indices must be integers between 0 and
|
|
/// 2^31.
|
|
#[error("Unable to parse path due to bad path prefix")]
|
|
UnknownPathPrefix,
|
|
}
|
|
|
|
type Result<T, E = Error> = std::result::Result<T, E>;
|
|
|
|
const PREFIX: &str = "m";
|
|
|
|
/// A fully qualified path to derive a key.
|
|
#[derive(Serialize, Deserialize, Clone, Debug, Default, Eq, PartialEq, PartialOrd, Ord)]
|
|
pub struct DerivationPath {
|
|
pub(crate) path: Vec<DerivationIndex>,
|
|
}
|
|
|
|
impl DerivationPath {
|
|
/// Returns an iterator over the [`DerivationPath`].
|
|
pub fn iter(&self) -> impl Iterator<Item = &DerivationIndex> {
|
|
self.path.iter()
|
|
}
|
|
|
|
/// The amount of segments in the DerivationPath. For consistency, a [`usize`] is returned, but
|
|
/// BIP-0032 dictates that the depth should be no larger than `255`, [`u8::MAX`].
|
|
pub fn len(&self) -> usize {
|
|
self.path.len()
|
|
}
|
|
|
|
/// Returns true if there are no path segments.
|
|
pub fn is_empty(&self) -> bool {
|
|
self.path.is_empty()
|
|
}
|
|
|
|
/// Append an index to the path.
|
|
pub fn push(&mut self, index: DerivationIndex) {
|
|
self.path.push(index);
|
|
}
|
|
|
|
/// Return the inner path.
|
|
pub fn inner(&self) -> &Vec<DerivationIndex> {
|
|
&self.path
|
|
}
|
|
|
|
/// Append an index to the path, returning self to allow chaining method calls.
|
|
///
|
|
/// # Examples
|
|
/// ```rust
|
|
/// # use keyfork_derive_util::*;
|
|
/// # fn discover_wallet(_p: DerivationPath) -> Result<bool, std::io::Error> { Ok(true) }
|
|
/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|
/// let account = 0;
|
|
/// let path = DerivationPath::default()
|
|
/// .chain_push(DerivationIndex::new(44, true)?)
|
|
/// .chain_push(DerivationIndex::new(0, true)?)
|
|
/// .chain_push(DerivationIndex::new(account, true)?);
|
|
/// let mut has_wallet = false;
|
|
/// for index in (0..20).map(|i| DerivationIndex::new(i, true).unwrap()) {
|
|
/// has_wallet = has_wallet || discover_wallet(path.clone().chain_push(index))?;
|
|
/// }
|
|
/// # Ok(())
|
|
/// # }
|
|
/// ```
|
|
pub fn chain_push(mut self, index: DerivationIndex) -> Self {
|
|
self.path.push(index);
|
|
self
|
|
}
|
|
}
|
|
|
|
impl std::str::FromStr for DerivationPath {
|
|
type Err = Error;
|
|
|
|
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
|
let mut iter = s.split('/');
|
|
if iter.next() != Some(PREFIX) {
|
|
return Err(Error::UnknownPathPrefix);
|
|
}
|
|
Ok(Self {
|
|
path: iter
|
|
.map(DerivationIndex::from_str)
|
|
.map(|maybe_err| maybe_err.map_err(From::from))
|
|
.collect::<Result<_>>()?,
|
|
})
|
|
}
|
|
}
|
|
|
|
impl std::fmt::Display for DerivationPath {
|
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
|
write!(f, "{PREFIX}")?;
|
|
for index in self.iter() {
|
|
write!(f, "/{index}")?;
|
|
}
|
|
Ok(())
|
|
}
|
|
}
|
|
|
|
/// Extend a path by a single index: `path + index`.
impl std::ops::Add<DerivationIndex> for DerivationPath {
    type Output = DerivationPath;

    /// Append `rhs` to the end of `self` and return the extended path.
    fn add(mut self, rhs: DerivationIndex) -> Self::Output {
        // `self` is already owned by value, so extend it in place instead of cloning first.
        self.path.push(rhs);
        self
    }
}
|
|
|
|
/// Extend a path by a slice of indices: `path + &indices[..]`.
impl std::ops::Add<&[DerivationIndex]> for DerivationPath {
    type Output = DerivationPath;

    /// Append clones of every index in `rhs`, in order, to the end of `self`.
    fn add(mut self, rhs: &[DerivationIndex]) -> Self::Output {
        // `self` is already owned by value, so extend it in place instead of cloning first.
        self.path.extend_from_slice(rhs);
        self
    }
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;
    use std::str::FromStr;

    #[test]
    fn requires_master_path() {
        // Assert the specific error: a bare `#[should_panic]` + `unwrap()` would also pass if
        // parsing failed for an unrelated reason (e.g. a bad index).
        assert!(matches!(
            DerivationPath::from_str("1234/5678'"),
            Err(Error::UnknownPathPrefix)
        ));
    }

    #[test]
    fn rejects_missing_or_wrong_prefix() {
        for path in ["", "M", "M/0", "/0"] {
            assert!(
                matches!(DerivationPath::from_str(path), Err(Error::UnknownPathPrefix)),
                "{path:?} should be rejected with UnknownPathPrefix"
            );
        }
    }

    #[test]
    fn equivalency() -> Result<()> {
        let paths = ["m/1234'/5678", "m/44'/0'/0'", "m"];
        for path in paths {
            assert_eq!(&DerivationPath::from_str(path)?.to_string(), path);
        }
        Ok(())
    }

    #[test]
    fn add() -> Result<(), Box<dyn std::error::Error>> {
        let path = DerivationPath::from_str("m")?;
        let path = path + DerivationIndex::new(72, true)?;
        let path = path + DerivationIndex::new(47, false)?;
        let path = path + DerivationIndex::new((i32::MAX) as u32, false)?;
        assert_eq!(path, DerivationPath::from_str("m/72'/47/2147483647")?);

        Ok(())
    }

    #[test]
    fn add_vec() -> Result<(), Box<dyn std::error::Error>> {
        let path = DerivationPath::from_str("m")?;
        let other_path = [
            DerivationIndex::new(72, true)?,
            DerivationIndex::new(47, false)?,
            DerivationIndex::new((i32::MAX) as u32, false)?,
        ];
        let path = path + &other_path[..];
        assert_eq!(path, DerivationPath::from_str("m/72'/47/2147483647")?);

        Ok(())
    }

    #[test]
    fn add_empty_slice_is_identity() -> Result<(), Box<dyn std::error::Error>> {
        let path = DerivationPath::from_str("m/44'")?;
        let none: &[DerivationIndex] = &[];
        assert_eq!(path.clone() + none, path);

        Ok(())
    }
}
|