diff --git a/bitcoin/src/bip32.rs b/bitcoin/src/bip32.rs index 5dd747fbf..1a32428f4 100644 --- a/bitcoin/src/bip32.rs +++ b/bitcoin/src/bip32.rs @@ -148,11 +148,11 @@ impl ChildNumber { /// [0, 2^31 - 1]. /// /// [`Normal`]: #variant.Normal - pub fn from_normal_idx(index: u32) -> Result { + pub fn from_normal_idx(index: u32) -> Result { if index & (1 << 31) == 0 { Ok(ChildNumber::Normal { index }) } else { - Err(Error::InvalidChildNumber(index)) + Err(IndexOutOfRangeError { index }) } } @@ -160,11 +160,11 @@ impl ChildNumber { /// [0, 2^31 - 1]. /// /// [`Hardened`]: #variant.Hardened - pub fn from_hardened_idx(index: u32) -> Result { + pub fn from_hardened_idx(index: u32) -> Result { if index & (1 << 31) == 0 { Ok(ChildNumber::Hardened { index }) } else { - Err(Error::InvalidChildNumber(index)) + Err(IndexOutOfRangeError { index }) } } @@ -184,7 +184,7 @@ impl ChildNumber { } /// Returns the child number that is a single increment from this one. - pub fn increment(self) -> Result { + pub fn increment(self) -> Result { match self { ChildNumber::Normal { index: idx } => ChildNumber::from_normal_idx(idx + 1), ChildNumber::Hardened { index: idx } => ChildNumber::from_hardened_idx(idx + 1), @@ -225,16 +225,18 @@ impl fmt::Display for ChildNumber { } impl FromStr for ChildNumber { - type Err = Error; + type Err = ParseChildNumberError; - fn from_str(inp: &str) -> Result { + fn from_str(inp: &str) -> Result { let is_hardened = inp.chars().last().map_or(false, |l| l == '\'' || l == 'h'); Ok(if is_hardened { ChildNumber::from_hardened_idx( - inp[0..inp.len() - 1].parse().map_err(|_| Error::InvalidChildNumberFormat)?, - )? + inp[0..inp.len() - 1].parse().map_err(ParseChildNumberError::ParseInt)?, + ) + .map_err(ParseChildNumberError::IndexOutOfRange)? } else { - ChildNumber::from_normal_idx(inp.parse().map_err(|_| Error::InvalidChildNumberFormat)?)? + ChildNumber::from_normal_idx(inp.parse().map_err(ParseChildNumberError::ParseInt)?) + .map_err(ParseChildNumberError::IndexOutOfRange)? }) } } @@ -267,7 +269,7 @@ impl serde::Serialize for ChildNumber { /// derivation path pub trait IntoDerivationPath { /// Converts a given type into a [`DerivationPath`] with possible error - fn into_derivation_path(self) -> Result; + fn into_derivation_path(self) -> Result; } /// A BIP-32 derivation path. @@ -295,15 +297,17 @@ impl IntoDerivationPath for T where T: Into, { - fn into_derivation_path(self) -> Result { Ok(self.into()) } + fn into_derivation_path(self) -> Result { + Ok(self.into()) + } } impl IntoDerivationPath for String { - fn into_derivation_path(self) -> Result { self.parse() } + fn into_derivation_path(self) -> Result { self.parse() } } impl IntoDerivationPath for &'_ str { - fn into_derivation_path(self) -> Result { self.parse() } + fn into_derivation_path(self) -> Result { self.parse() } } impl From> for DerivationPath { @@ -338,9 +342,9 @@ impl AsRef<[ChildNumber]> for DerivationPath { } impl FromStr for DerivationPath { - type Err = Error; + type Err = ParseChildNumberError; - fn from_str(path: &str) -> Result { + fn from_str(path: &str) -> Result { if path.is_empty() || path == "m" || path == "m/" { return Ok(vec![].into()); } @@ -348,7 +352,7 @@ impl FromStr for DerivationPath { let path = path.strip_prefix("m/").unwrap_or(path); let parts = path.split('/'); - let ret: Result, Error> = parts.map(str::parse).collect(); + let ret: Result, _> = parts.map(str::parse).collect(); Ok(DerivationPath(ret?)) } } @@ -499,10 +503,6 @@ pub type KeySource = (Fingerprint, DerivationPath); pub enum Error { /// A secp256k1 error occurred Secp256k1(secp256k1::Error), - /// A child number was provided that was out of range - InvalidChildNumber(u32), - /// Invalid childnumber format. - InvalidChildNumberFormat, /// Unknown version magic bytes UnknownVersion([u8; 4]), /// Encoded extended key data has wrong length @@ -529,9 +529,6 @@ impl fmt::Display for Error { match *self { Secp256k1(ref e) => write_err!(f, "secp256k1 error"; e), - InvalidChildNumber(ref n) => - write!(f, "child number {} is invalid (not within [0, 2^31 - 1])", n), - InvalidChildNumberFormat => f.write_str("invalid child number format"), UnknownVersion(ref bytes) => write!(f, "unknown version magic bytes: {:?}", bytes), WrongExtendedKeyLength(ref len) => write!(f, "encoded extended key data has wrong length {}", len), @@ -555,10 +552,7 @@ impl std::error::Error for Error { Secp256k1(ref e) => Some(e), Base58(ref e) => Some(e), InvalidBase58PayloadLength(ref e) => Some(e), - InvalidChildNumber(_) - | InvalidChildNumberFormat - | UnknownVersion(_) - | WrongExtendedKeyLength(_) => None, + UnknownVersion(_) | WrongExtendedKeyLength(_) => None, InvalidPrivateKeyPrefix => None, NonZeroParentFingerprintForMasterKey => None, NonZeroChildNumberForMasterKey => None, @@ -605,6 +599,62 @@ impl fmt::Display for DerivationError { } } +/// Out-of-range index when constructing a child number. +/// +/// *Indices* are always in the range [0, 2^31 - 1]. Normal child numbers have the +/// same range, while hardened child numbers lie in the range [2^31, 2^32 - 1]. +#[derive(Debug, Clone, PartialEq, Eq)] +#[non_exhaustive] +pub struct IndexOutOfRangeError { + /// The index that was out of range for a child number. + pub index: u32, +} + +#[cfg(feature = "std")] +impl std::error::Error for IndexOutOfRangeError {} + +impl From for IndexOutOfRangeError { + fn from(never: Infallible) -> Self { match never {} } +} + +impl fmt::Display for IndexOutOfRangeError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "index {} out of range [0, 2^31 - 1] (do you have an hardened child number, rather than an index?)", self.index) + } +} + +/// Error parsing a child number. +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum ParseChildNumberError { + /// Parsed the child number as an integer, but the integer was out of range. + IndexOutOfRange(IndexOutOfRangeError), + /// Failed to parse the child number as an integer. + ParseInt(core::num::ParseIntError), +} + +impl From for ParseChildNumberError { + fn from(never: Infallible) -> Self { match never {} } +} + +#[cfg(feature = "std")] +impl std::error::Error for ParseChildNumberError { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + match *self { + Self::IndexOutOfRange(ref e) => Some(e), + Self::ParseInt(ref e) => Some(e), + } + } +} + +impl fmt::Display for ParseChildNumberError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match *self { + Self::IndexOutOfRange(ref e) => e.fmt(f), + Self::ParseInt(ref e) => e.fmt(f), + } + } +} + impl Xpriv { /// Constructs a new master key from a seed value pub fn new_master(network: impl Into, seed: &[u8]) -> Xpriv { @@ -1040,13 +1090,25 @@ mod tests { #[test] fn parse_derivation_path() { - assert_eq!("n/0'/0".parse::(), Err(Error::InvalidChildNumberFormat)); - assert_eq!("4/m/5".parse::(), Err(Error::InvalidChildNumberFormat)); - assert_eq!("//3/0'".parse::(), Err(Error::InvalidChildNumberFormat)); - assert_eq!("0h/0x".parse::(), Err(Error::InvalidChildNumberFormat)); + assert!(matches!( + "n/0'/0".parse::(), + Err(ParseChildNumberError::ParseInt(..)), + )); + assert!(matches!( + "4/m/5".parse::(), + Err(ParseChildNumberError::ParseInt(..)), + )); + assert!(matches!( + "//3/0'".parse::(), + Err(ParseChildNumberError::ParseInt(..)), + )); + assert!(matches!( + "0h/0x".parse::(), + Err(ParseChildNumberError::ParseInt(..)), + )); assert_eq!( "2147483648".parse::(), - Err(Error::InvalidChildNumber(2147483648)) + Err(ParseChildNumberError::IndexOutOfRange(IndexOutOfRangeError { index: 2147483648 })), ); assert_eq!(DerivationPath::master(), "".parse::().unwrap()); @@ -1176,9 +1238,9 @@ mod tests { let max = (1 << 31) - 1; let cn = ChildNumber::from_normal_idx(max).unwrap(); - assert_eq!(cn.increment().err(), Some(Error::InvalidChildNumber(1 << 31))); + assert_eq!(cn.increment(), Err(IndexOutOfRangeError { index: 1 << 31 }),); let cn = ChildNumber::from_hardened_idx(max).unwrap(); - assert_eq!(cn.increment().err(), Some(Error::InvalidChildNumber(1 << 31))); + assert_eq!(cn.increment(), Err(IndexOutOfRangeError { index: 1 << 31 }),); let cn = ChildNumber::from_normal_idx(350).unwrap(); let path = "42'".parse::().unwrap();