use serde::{Deserialize, Serialize}; use thiserror::Error; /// Errors associated with creating a [`DerivationIndex`]. #[derive(Error, Debug)] pub enum Error { /// The index was too large and should be less than 2^31. #[error("Index is too large, must be less than 0x80000000: {0}")] IndexTooLarge(u32), /// An integer could not be parsed from the string. #[error("Unable to parse integer for index")] IntParseError(#[from] std::num::ParseIntError), } type Result = std::result::Result; /// Index for a given extended private key. #[derive(Serialize, Deserialize, Clone, Debug, PartialEq, Eq, PartialOrd, Ord)] pub struct DerivationIndex(pub(crate) u32); impl DerivationIndex { /// Creates a new [`DerivationIndex`]. /// /// # Errors /// /// Returns an error if the index is larger than the hardened flag. pub const fn new(index: u32, hardened: bool) -> Result { if index & (0b1 << 31) > 0 { return Err(Error::IndexTooLarge(index)); } Ok(Self(index | ((hardened as u32) << 31))) } #[doc(hidden)] pub const fn new_unchecked(index: u32, hardened: bool) -> Self { Self(index | ((hardened as u32) << 31)) } /* * Probably never used. pub(crate) fn from_bytes(bytes: [u8; 4]) -> Self { Self(u32::from_be_bytes(bytes)) } */ /// Return the internal derivation index. Note that if the derivation index is hardened, the /// highest bit will be set, and the value can't be used to create a new derivation index. pub fn inner(&self) -> u32 { self.0 } pub(crate) fn to_bytes(&self) -> [u8; 4] { self.0.to_be_bytes() } /// Whether or not the index is hardened, allowing deriving the key from a known parent key. pub fn is_hardened(&self) -> bool { self.0 & (0b1 << 31) != 0 } } impl std::str::FromStr for DerivationIndex { type Err = Error; fn from_str(s: &str) -> std::result::Result { // Returns &str without suffix if suffix is found let (s, is_hardened) = match s.strip_suffix('\'') { Some(subslice) => (subslice, true), None => (s, false), }; let index: u32 = s.parse()?; Self::new(index, is_hardened) } } impl std::fmt::Display for DerivationIndex { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "{}", self.0 & (u32::MAX >> 1))?; if self.0 & (0b1 << 31) != 0 { write!(f, "'")?; } Ok(()) } } #[cfg(test)] mod tests { use super::*; use std::str::FromStr; #[test] #[should_panic] fn fails_on_high_index() { DerivationIndex::new(0x8000_0001, false).unwrap(); } #[test] fn has_hardened_bit() { assert_eq!(DerivationIndex::new(0x0, true).unwrap().0, 0b1 << 31); } #[test] fn misc_values() -> Result<()> { assert_eq!(DerivationIndex::new(0x8000_0000 - 1, true)?.0, u32::MAX); assert_eq!(DerivationIndex::new(0x2, false)?.0, 2); assert_eq!(DerivationIndex::new(0x00AB_CDEF, true)?.0, 0x80AB_CDEF); assert_eq!(DerivationIndex::new(0x00AB_CDEF, false)?.0, 0x00AB_CDEF); Ok(()) } #[test] fn from_str() -> Result<()> { assert_eq!(DerivationIndex::from_str("100000")?.0, 100_000); assert_eq!( DerivationIndex::from_str("100000'")?.0, (0b1 << 31) + 100_000 ); Ok(()) } #[test] fn display() -> Result<()> { assert_eq!(&DerivationIndex::new(3232, false)?.to_string(), "3232"); assert_eq!(&DerivationIndex::new(3232, true)?.to_string(), "3232'"); Ok(()) } #[test] fn equivalency() -> Result<()> { let values = ["123456'", "123456", "1726562", "0'", "0"]; for value in values { assert_eq!(value, DerivationIndex::from_str(value)?.to_string()); } Ok(()) } #[test] #[should_panic] fn from_str_fails_on_high_index() { DerivationIndex::from_str(&0x8000_0001u32.to_string()).unwrap(); } }