diff --git a/bitcoin/src/psbt/mod.rs b/bitcoin/src/psbt/mod.rs index da9cd3efd..dd27611af 100644 --- a/bitcoin/src/psbt/mod.rs +++ b/bitcoin/src/psbt/mod.rs @@ -20,7 +20,7 @@ use std::collections::{HashMap, HashSet}; use internals::write_err; use secp256k1::{Keypair, Message, Secp256k1, Signing, Verification}; -use crate::bip32::{self, KeySource, Xpriv, Xpub}; +use crate::bip32::{self, DerivationPath, KeySource, Xpriv, Xpub}; use crate::crypto::key::{PrivateKey, PublicKey}; use crate::crypto::{ecdsa, taproot}; use crate::key::{TapTweak, XOnlyPublicKey}; @@ -365,9 +365,9 @@ impl Psbt { let mut used = vec![]; // List of pubkeys used to sign the input. for (pk, key_source) in input.bip32_derivation.iter() { - let sk = if let Ok(Some(sk)) = k.get_key(KeyRequest::Bip32(key_source.clone()), secp) { + let sk = if let Ok(Some(sk)) = k.get_key(&KeyRequest::Bip32(key_source.clone()), secp) { sk - } else if let Ok(Some(sk)) = k.get_key(KeyRequest::Pubkey(PublicKey::new(*pk)), secp) { + } else if let Ok(Some(sk)) = k.get_key(&KeyRequest::Pubkey(PublicKey::new(*pk)), secp) { sk } else { continue; @@ -419,7 +419,7 @@ impl Psbt { for (&xonly, (leaf_hashes, key_source)) in input.tap_key_origins.iter() { let sk = if let Ok(Some(secret_key)) = - k.get_key(KeyRequest::Bip32(key_source.clone()), secp) + k.get_key(&KeyRequest::Bip32(key_source.clone()), secp) { secret_key } else { @@ -745,7 +745,7 @@ pub trait GetKey { /// - `Err` if an error was encountered while looking for the key. fn get_key( &self, - key_request: KeyRequest, + key_request: &KeyRequest, secp: &Secp256k1, ) -> Result, Self::Error>; } @@ -755,13 +755,20 @@ impl GetKey for Xpriv { fn get_key( &self, - key_request: KeyRequest, + key_request: &KeyRequest, secp: &Secp256k1, ) -> Result, Self::Error> { match key_request { KeyRequest::Pubkey(_) => Err(GetKeyError::NotSupported), KeyRequest::Bip32((fingerprint, path)) => { - let key = if self.fingerprint(secp) == fingerprint { + let key = if self.fingerprint(secp) == *fingerprint { + let k = self.derive_priv(secp, &path); + Some(k.to_priv()) + } else if self.parent_fingerprint == *fingerprint + && !path.is_empty() + && path[0] == self.child_number + { + let path = DerivationPath::from_iter(path.into_iter().skip(1).copied()); let k = self.derive_priv(secp, &path); Some(k.to_priv()) } else { @@ -800,21 +807,14 @@ impl GetKey for $set { fn get_key( &self, - key_request: KeyRequest, + key_request: &KeyRequest, secp: &Secp256k1 ) -> Result, Self::Error> { - match key_request { - KeyRequest::Pubkey(_) => Err(GetKeyError::NotSupported), - KeyRequest::Bip32((fingerprint, path)) => { - for xpriv in self.iter() { - if xpriv.parent_fingerprint == fingerprint { - let k = xpriv.derive_priv(secp, &path); - return Ok(Some(k.to_priv())); - } - } - Ok(None) - } - } + // OK to stop at the first error because Xpriv::get_key() can only fail + // if this isn't a KeyRequest::Bip32, which would fail for all Xprivs. + self.iter() + .find_map(|xpriv| xpriv.get_key(key_request, secp).transpose()) + .transpose() } }}} impl_get_key_for_set!(BTreeSet); @@ -830,7 +830,7 @@ impl GetKey for $map { fn get_key( &self, - key_request: KeyRequest, + key_request: &KeyRequest, _: &Secp256k1, ) -> Result, Self::Error> { match key_request { @@ -2133,7 +2133,7 @@ mod tests { let mut key_map = BTreeMap::new(); key_map.insert(pk, priv_key); - let got = key_map.get_key(KeyRequest::Pubkey(pk), &secp).expect("failed to get key"); + let got = key_map.get_key(&KeyRequest::Pubkey(pk), &secp).expect("failed to get key"); assert_eq!(got.unwrap(), priv_key) } diff --git a/bitcoin/tests/psbt-sign-taproot.rs b/bitcoin/tests/psbt-sign-taproot.rs index e79a1fe1a..6f4563988 100644 --- a/bitcoin/tests/psbt-sign-taproot.rs +++ b/bitcoin/tests/psbt-sign-taproot.rs @@ -27,12 +27,12 @@ fn psbt_sign_taproot() { type Error = SignError; fn get_key( &self, - key_request: KeyRequest, + key_request: &KeyRequest, _secp: &Secp256k1, ) -> Result, Self::Error> { match key_request { KeyRequest::Bip32((mfp, _)) => - if mfp == self.mfp { + if *mfp == self.mfp { Ok(Some(self.sk)) } else { Err(SignError::KeyNotFound)