keyfork/keyfork-derive-util/src/path.rs

139 lines
3.7 KiB
Rust

use crate::index::DerivationIndex;
use serde::{Deserialize, Serialize};
use thiserror::Error;
/// Errors associated with creating a [`DerivationPath`].
#[derive(Error, Debug)]
pub enum Error {
/// A [`DerivationIndex`] was not able to be created.
///
/// Wraps the underlying index error. The `#[from]` attribute provides the conversion, so
/// index errors can be propagated with `?`.
#[error("Unable to create index: {0}")]
UnableToCreateIndex(#[from] super::index::Error),
/// The path could not be parsed due to a bad prefix. Paths must be in the format:
///
/// ```text
/// m [/ index [']]*
/// ```
///
/// The prefix for the path must be `m`, and all indices must be integers between 0 and
/// 2^31. A bare `m` with no indices is accepted.
#[error("Unable to parse path due to bad path prefix")]
UnknownPathPrefix,
}
/// Module-local `Result` alias whose error type defaults to this module's [`Error`].
type Result<T, E = Error> = std::result::Result<T, E>;
/// First segment of every textual derivation path; parsing requires it and formatting emits it.
const PREFIX: &str = "m";
/// A fully qualified path to derive a key.
///
/// Holds an ordered list of [`DerivationIndex`] values. In textual form the path is written as
/// the `m` prefix followed by one `/`-separated segment per index; that form is read by the
/// `FromStr` implementation and written by the `Display` implementation. The derived `Default`
/// is the path with no indices.
#[derive(Serialize, Deserialize, Clone, Debug, Default, Eq, PartialEq, PartialOrd, Ord)]
pub struct DerivationPath {
/// Indices making up the path, in order. The `m` prefix is not stored.
pub(crate) path: Vec<DerivationIndex>,
}
impl DerivationPath {
    /// Iterates over the indices of this [`DerivationPath`] by reference, in order.
    pub fn iter(&self) -> impl Iterator<Item = &DerivationIndex> {
        let indices: &[DerivationIndex] = &self.path;
        indices.iter()
    }

    /// Reports how many indices the path holds; the `m` prefix is not counted.
    pub fn len(&self) -> usize {
        let Self { path } = self;
        path.len()
    }

    /// Reports whether the path holds no indices at all.
    pub fn is_empty(&self) -> bool {
        self.path.as_slice().is_empty()
    }
}
impl std::str::FromStr for DerivationPath {
type Err = Error;
fn from_str(s: &str) -> Result<Self, Self::Err> {
let mut iter = s.split('/');
if iter.next() != Some(PREFIX) {
return Err(Error::UnknownPathPrefix);
}
Ok(Self {
path: iter
.map(DerivationIndex::from_str)
.map(|maybe_err| maybe_err.map_err(From::from))
.collect::<Result<_>>()?,
})
}
}
impl std::fmt::Display for DerivationPath {
    /// Writes the `m` prefix, then `/` followed by the index for every element of the path.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(PREFIX)?;
        // Stops at, and returns, the first formatting error.
        self.iter().try_for_each(|index| write!(f, "/{index}"))
    }
}
impl std::ops::Add<DerivationIndex> for DerivationPath {
    type Output = DerivationPath;

    /// Returns the path extended by one trailing index, `rhs`.
    ///
    /// `self` is taken by value, so its existing storage is reused; no copy of the path is
    /// made.
    fn add(mut self, rhs: DerivationIndex) -> Self::Output {
        self.path.push(rhs);
        self
    }
}
impl std::ops::Add<&[DerivationIndex]> for DerivationPath {
    type Output = DerivationPath;

    /// Returns the path extended by clones of every index in `rhs`, in order.
    ///
    /// `self` is taken by value, so its existing storage is reused; only the elements of `rhs`
    /// are cloned. An empty `rhs` returns the path unchanged.
    fn add(mut self, rhs: &[DerivationIndex]) -> Self::Output {
        self.path.extend_from_slice(rhs);
        self
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::str::FromStr;

    /// A path that does not start with the `m` prefix is rejected with the prefix error.
    #[test]
    fn requires_master_path() {
        // Check the specific error rather than relying on `#[should_panic]`, which would also
        // pass if parsing panicked for an unrelated reason.
        let parsed = DerivationPath::from_str("1234/5678'");
        assert!(matches!(parsed, Err(Error::UnknownPathPrefix)));
    }

    /// Parsing a path and formatting it again yields the original text.
    #[test]
    fn equivalency() -> Result<()> {
        let paths = ["m/1234'/5678", "m/44'/0'/0'", "m"];
        for path in paths {
            assert_eq!(&DerivationPath::from_str(path)?.to_string(), path);
        }
        Ok(())
    }

    /// Adding indices one at a time builds the same path as parsing the full text.
    #[test]
    fn add() -> Result<(), Box<dyn std::error::Error>> {
        let path = DerivationPath::from_str("m")?;
        let path = path + DerivationIndex::new(72, true)?;
        let path = path + DerivationIndex::new(47, false)?;
        let path = path + DerivationIndex::new(i32::MAX as u32, false)?;
        assert_eq!(path, DerivationPath::from_str("m/72'/47/2147483647")?);
        Ok(())
    }

    /// Adding a slice of indices builds the same path as parsing the full text.
    #[test]
    fn add_vec() -> Result<(), Box<dyn std::error::Error>> {
        let path = DerivationPath::from_str("m")?;
        let other_path = [
            DerivationIndex::new(72, true)?,
            DerivationIndex::new(47, false)?,
            DerivationIndex::new(i32::MAX as u32, false)?,
        ];
        let path = path + &other_path[..];
        assert_eq!(path, DerivationPath::from_str("m/72'/47/2147483647")?);
        Ok(())
    }
}