diff --git a/src/util/bip32.rs b/src/util/bip32.rs index 7566be89..34c2067f 100644 --- a/src/util/bip32.rs +++ b/src/util/bip32.rs @@ -137,6 +137,13 @@ impl ChildNumber { } } + /// Returns the child number that is a single increment from this one. + pub fn increment(self) -> Result { + match self { + ChildNumber::Normal{ index: idx } => ChildNumber::from_normal_idx(idx+1), + ChildNumber::Hardened{ index: idx } => ChildNumber::from_hardened_idx(idx+1), + } + } } impl From for ChildNumber { @@ -260,20 +267,69 @@ impl FromStr for DerivationPath { } } +/// An iterator over children of a [DerivationPath]. +/// +/// It is returned by the methods [DerivationPath::children_since], +/// [DerivationPath::normal_children] and [DerivationPath::hardened_children]. +pub struct DerivationPathIterator<'a> { + base: &'a DerivationPath, + next_child: Option, +} + +impl<'a> DerivationPathIterator<'a> { + /// Start a new [DerivationPathIterator] at the given child. + pub fn start_from(path: &'a DerivationPath, start: ChildNumber) -> DerivationPathIterator<'a> { + DerivationPathIterator { + base: path, + next_child: Some(start), + } + } +} + +impl<'a> Iterator for DerivationPathIterator<'a> { + type Item = DerivationPath; + + fn next(&mut self) -> Option { + if self.next_child.is_none() { + return None; + } + + let ret = self.next_child.unwrap(); + self.next_child = ret.increment().ok(); + Some(self.base.child(ret)) + } +} + impl DerivationPath { - /// Create a new DerivationPath that is a child of this one. + /// Create a new [DerivationPath] that is a child of this one. pub fn child(&self, cn: ChildNumber) -> DerivationPath { let mut path = self.0.clone(); path.push(cn); DerivationPath(path) } - /// Convert into a DerivationPath that is a child of this one. + /// Convert into a [DerivationPath] that is a child of this one. pub fn into_child(self, cn: ChildNumber) -> DerivationPath { let mut path = self.0; path.push(cn); DerivationPath(path) } + + /// Get an [Iterator] over the children of this [DerivationPath] + /// starting with the given [ChildNumber]. + pub fn children_from(&self, cn: ChildNumber) -> DerivationPathIterator { + DerivationPathIterator::start_from(&self, cn) + } + + /// Get an [Iterator] over the unhardened children of this [DerivationPath]. + pub fn normal_children(&self) -> DerivationPathIterator { + DerivationPathIterator::start_from(&self, ChildNumber::Normal{ index: 0 }) + } + + /// Get an [Iterator] over the hardened children of this [DerivationPath]. + pub fn hardened_children(&self) -> DerivationPathIterator { + DerivationPathIterator::start_from(&self, ChildNumber::Hardened{ index: 0 }) + } } impl fmt::Display for DerivationPath { @@ -799,6 +855,49 @@ mod tests { assert_eq!(Ok(pk), decoded_pk); } + #[test] + fn test_increment() { + let idx = 9345497; // randomly generated, I promise + let cn = ChildNumber::from_normal_idx(idx).unwrap(); + assert_eq!(cn.increment().ok(), Some(ChildNumber::from_normal_idx(idx+1).unwrap())); + let cn = ChildNumber::from_hardened_idx(idx).unwrap(); + assert_eq!(cn.increment().ok(), Some(ChildNumber::from_hardened_idx(idx+1).unwrap())); + + let max = (1<<31)-1; + let cn = ChildNumber::from_normal_idx(max).unwrap(); + assert_eq!(cn.increment().err(), Some(Error::InvalidChildNumber(1<<31))); + let cn = ChildNumber::from_hardened_idx(max).unwrap(); + assert_eq!(cn.increment().err(), Some(Error::InvalidChildNumber(1<<31))); + + let cn = ChildNumber::from_normal_idx(350).unwrap(); + let path = DerivationPath::from_str("m/42'").unwrap(); + let mut iter = path.children_from(cn); + assert_eq!(iter.next(), Some("m/42'/350".parse().unwrap())); + assert_eq!(iter.next(), Some("m/42'/351".parse().unwrap())); + + let path = DerivationPath::from_str("m/42'/350'").unwrap(); + let mut iter = path.normal_children(); + assert_eq!(iter.next(), Some("m/42'/350'/0".parse().unwrap())); + assert_eq!(iter.next(), Some("m/42'/350'/1".parse().unwrap())); + + let path = DerivationPath::from_str("m/42'/350'").unwrap(); + let mut iter = path.hardened_children(); + assert_eq!(iter.next(), Some("m/42'/350'/0'".parse().unwrap())); + assert_eq!(iter.next(), Some("m/42'/350'/1'".parse().unwrap())); + + let cn = ChildNumber::from_hardened_idx(42350).unwrap(); + let path = DerivationPath::from_str("m/42'").unwrap(); + let mut iter = path.children_from(cn); + assert_eq!(iter.next(), Some("m/42'/42350'".parse().unwrap())); + assert_eq!(iter.next(), Some("m/42'/42351'".parse().unwrap())); + + let cn = ChildNumber::from_hardened_idx(max).unwrap(); + let path = DerivationPath::from_str("m/42'").unwrap(); + let mut iter = path.children_from(cn); + assert!(iter.next().is_some()); + assert!(iter.next().is_none()); + } + #[test] fn test_vector_1() { let secp = Secp256k1::new();