keyfork/crates/derive/keyfork-derive-util/src/index.rs

171 lines
5.0 KiB
Rust

use serde::{Deserialize, Serialize};
use thiserror::Error;
/// Errors associated with creating a [`DerivationIndex`].
#[derive(Error, Debug)]
pub enum Error {
/// The index was too large and should be less than 2^31.
#[error("Index is too large, must be less than 0x80000000: {0}")]
IndexTooLarge(u32),
/// An integer could not be parsed from the string.
#[error("Unable to parse integer for index")]
IntParseError(#[from] std::num::ParseIntError),
}
type Result<T, E = Error> = std::result::Result<T, E>;
/// Index of a child key in a hierarchical derivation path.
///
/// The low 31 bits hold the index value and the high bit (2^31) marks the index as hardened;
/// see [`DerivationIndex::new`] and [`DerivationIndex::is_hardened`].
///
/// Ordering compares the raw `u32`, so every non-hardened index sorts before every hardened one.
/// Deserialization accepts any `u32`, which is always a valid raw value (the high bit simply
/// means "hardened").
#[derive(Serialize, Deserialize, Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct DerivationIndex(pub(crate) u32);
impl DerivationIndex {
    /// The hardened flag: the most significant bit of the raw index (2^31).
    const HARDENED_BIT: u32 = 0b1 << 31;

    /// Creates a new [`DerivationIndex`].
    ///
    /// `index` is the 31-bit index value; when `hardened` is true the hardened bit (2^31) is set
    /// in the stored value.
    ///
    /// # Errors
    /// Returns [`Error::IndexTooLarge`] if `index` is greater than or equal to 2^31, i.e. if it
    /// already has the hardened bit set.
    ///
    /// # Examples
    /// ```rust
    /// # use keyfork_derive_util::*;
    /// let bip44 = DerivationIndex::new(44, true).unwrap();
    /// ```
    ///
    /// Using a derivation index that is 2^31 or higher returns an error:
    ///
    /// ```rust,should_panic
    /// # use keyfork_derive_util::*;
    /// let too_high = DerivationIndex::new(u32::MAX, true).unwrap();
    /// ```
    pub const fn new(index: u32, hardened: bool) -> Result<Self> {
        if index & Self::HARDENED_BIT != 0 {
            return Err(Error::IndexTooLarge(index));
        }
        Ok(Self::new_unchecked(index, hardened))
    }

    /// Creates a [`DerivationIndex`] without checking that `index` is below 2^31.
    ///
    /// If `index` already has the hardened bit set, the result is hardened even when `hardened`
    /// is false. Prefer [`DerivationIndex::new`] unless `index` is known to be in range.
    #[doc(hidden)]
    pub const fn new_unchecked(index: u32, hardened: bool) -> Self {
        if hardened {
            Self(index | Self::HARDENED_BIT)
        } else {
            Self(index)
        }
    }

    /// Return the internal derivation index. Note that if the derivation index is hardened, the
    /// highest bit will be set, and the value can't be used to create a new derivation index.
    ///
    /// # Examples
    /// ```rust
    /// # use keyfork_derive_util::*;
    /// assert_eq!(DerivationIndex::new(44, true).unwrap().inner(), 2147483692);
    /// assert_eq!(DerivationIndex::new(200, false).unwrap().inner(), 200);
    /// ```
    pub const fn inner(&self) -> u32 {
        self.0
    }

    /// Returns the raw index (hardened bit included) as big-endian bytes.
    pub(crate) const fn to_bytes(&self) -> [u8; 4] {
        self.0.to_be_bytes()
    }

    /// Whether the index is hardened, i.e. whether the high bit (2^31) is set.
    ///
    /// Hardened child keys can only be derived from the parent *private* key; only non-hardened
    /// indices allow deriving a child public key from a known parent public key.
    ///
    /// # Examples
    /// ```rust
    /// # use keyfork_derive_util::*;
    /// assert_eq!(DerivationIndex::new(0, true).unwrap().is_hardened(), true);
    /// assert_eq!(DerivationIndex::new(0, false).unwrap().is_hardened(), false);
    /// ```
    pub const fn is_hardened(&self) -> bool {
        self.0 & Self::HARDENED_BIT != 0
    }
}
impl std::str::FromStr for DerivationIndex {
    type Err = Error;

    /// Parses one path component, e.g. `44` or `44'`.
    ///
    /// Exactly one trailing `'` is stripped and marks the index as hardened. The remainder must
    /// parse as a `u32` and is then range-checked by [`DerivationIndex::new`].
    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
        let (digits, hardened) = s
            .strip_suffix('\'')
            .map_or((s, false), |unmarked| (unmarked, true));
        let value = digits.parse::<u32>()?;
        Self::new(value, hardened)
    }
}
impl std::fmt::Display for DerivationIndex {
    /// Renders the index in path notation: the 31-bit value, followed by `'` when hardened.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        const HARDENED: u32 = 0b1 << 31;
        let value = self.0 & !HARDENED;
        let marker = if self.0 & HARDENED == 0 { "" } else { "'" };
        write!(f, "{}{}", value, marker)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::str::FromStr;

    /// Indices with the hardened bit already set are rejected, reporting the offending value.
    /// (Checked by variant rather than `#[should_panic]`, which would accept any panic.)
    #[test]
    fn fails_on_high_index() {
        assert!(matches!(
            DerivationIndex::new(0x8000_0001, false),
            Err(Error::IndexTooLarge(0x8000_0001))
        ));
    }

    /// 2^31 is the smallest rejected value; 2^31 - 1 is the largest accepted one.
    #[test]
    fn boundary_values() {
        assert!(matches!(
            DerivationIndex::new(0x8000_0000, true),
            Err(Error::IndexTooLarge(0x8000_0000))
        ));
        assert!(DerivationIndex::new(0x7FFF_FFFF, false).is_ok());
    }

    #[test]
    fn has_hardened_bit() {
        assert_eq!(DerivationIndex::new(0x0, true).unwrap().0, 0b1 << 31);
    }

    #[test]
    fn misc_values() -> Result<()> {
        assert_eq!(DerivationIndex::new(0x8000_0000 - 1, true)?.0, u32::MAX);
        assert_eq!(DerivationIndex::new(0x2, false)?.0, 2);
        assert_eq!(DerivationIndex::new(0x00AB_CDEF, true)?.0, 0x80AB_CDEF);
        assert_eq!(DerivationIndex::new(0x00AB_CDEF, false)?.0, 0x00AB_CDEF);
        Ok(())
    }

    /// `to_bytes` emits the raw value, hardened bit included, most significant byte first.
    #[test]
    fn to_bytes_is_big_endian() -> Result<()> {
        assert_eq!(
            DerivationIndex::new(1, true)?.to_bytes(),
            [0x80, 0x00, 0x00, 0x01]
        );
        assert_eq!(
            DerivationIndex::new(0x0102_0304, false)?.to_bytes(),
            [0x01, 0x02, 0x03, 0x04]
        );
        Ok(())
    }

    #[test]
    fn from_str() -> Result<()> {
        assert_eq!(DerivationIndex::from_str("100000")?.0, 100_000);
        assert_eq!(
            DerivationIndex::from_str("100000'")?.0,
            (0b1 << 31) + 100_000
        );
        Ok(())
    }

    /// Anything that is not a (single-apostrophe-suffixed) `u32` is a parse error.
    #[test]
    fn from_str_rejects_malformed_input() {
        let inputs = ["", "'", "abc", "1''", "-1", "1'2", " 1"];
        for input in inputs {
            assert!(
                matches!(
                    DerivationIndex::from_str(input),
                    Err(Error::IntParseError(_))
                ),
                "expected a parse error for {:?}",
                input
            );
        }
    }

    #[test]
    fn display() -> Result<()> {
        assert_eq!(&DerivationIndex::new(3232, false)?.to_string(), "3232");
        assert_eq!(&DerivationIndex::new(3232, true)?.to_string(), "3232'");
        Ok(())
    }

    #[test]
    fn equivalency() -> Result<()> {
        let values = ["123456'", "123456", "1726562", "0'", "0"];
        for value in values {
            assert_eq!(value, DerivationIndex::from_str(value)?.to_string());
        }
        Ok(())
    }

    #[test]
    fn from_str_fails_on_high_index() {
        assert!(matches!(
            DerivationIndex::from_str(&0x8000_0001u32.to_string()),
            Err(Error::IndexTooLarge(0x8000_0001))
        ));
    }
}