Refactor GetKey to take the KeyRequest by reference

To avoid cloning when looking it up in sets.
This commit is contained in:
Nadav Ivgi 2024-09-13 11:38:38 +03:00
parent d15c57bd1f
commit 055aa9d4dc
No known key found for this signature in database
GPG Key ID: 81F6104CD0F150FC
2 changed files with 12 additions and 12 deletions

View File

@ -365,9 +365,9 @@ impl Psbt {
let mut used = vec![]; // List of pubkeys used to sign the input. let mut used = vec![]; // List of pubkeys used to sign the input.
for (pk, key_source) in input.bip32_derivation.iter() { for (pk, key_source) in input.bip32_derivation.iter() {
let sk = if let Ok(Some(sk)) = k.get_key(KeyRequest::Bip32(key_source.clone()), secp) { let sk = if let Ok(Some(sk)) = k.get_key(&KeyRequest::Bip32(key_source.clone()), secp) {
sk sk
} else if let Ok(Some(sk)) = k.get_key(KeyRequest::Pubkey(PublicKey::new(*pk)), secp) { } else if let Ok(Some(sk)) = k.get_key(&KeyRequest::Pubkey(PublicKey::new(*pk)), secp) {
sk sk
} else { } else {
continue; continue;
@ -419,7 +419,7 @@ impl Psbt {
for (&xonly, (leaf_hashes, key_source)) in input.tap_key_origins.iter() { for (&xonly, (leaf_hashes, key_source)) in input.tap_key_origins.iter() {
let sk = if let Ok(Some(secret_key)) = let sk = if let Ok(Some(secret_key)) =
k.get_key(KeyRequest::Bip32(key_source.clone()), secp) k.get_key(&KeyRequest::Bip32(key_source.clone()), secp)
{ {
secret_key secret_key
} else { } else {
@ -745,7 +745,7 @@ pub trait GetKey {
/// - `Err` if an error was encountered while looking for the key. /// - `Err` if an error was encountered while looking for the key.
fn get_key<C: Signing>( fn get_key<C: Signing>(
&self, &self,
key_request: KeyRequest, key_request: &KeyRequest,
secp: &Secp256k1<C>, secp: &Secp256k1<C>,
) -> Result<Option<PrivateKey>, Self::Error>; ) -> Result<Option<PrivateKey>, Self::Error>;
} }
@ -755,13 +755,13 @@ impl GetKey for Xpriv {
fn get_key<C: Signing>( fn get_key<C: Signing>(
&self, &self,
key_request: KeyRequest, key_request: &KeyRequest,
secp: &Secp256k1<C>, secp: &Secp256k1<C>,
) -> Result<Option<PrivateKey>, Self::Error> { ) -> Result<Option<PrivateKey>, Self::Error> {
match key_request { match key_request {
KeyRequest::Pubkey(_) => Err(GetKeyError::NotSupported), KeyRequest::Pubkey(_) => Err(GetKeyError::NotSupported),
KeyRequest::Bip32((fingerprint, path)) => { KeyRequest::Bip32((fingerprint, path)) => {
let key = if self.fingerprint(secp) == fingerprint { let key = if self.fingerprint(secp) == *fingerprint {
let k = self.derive_priv(secp, &path); let k = self.derive_priv(secp, &path);
Some(k.to_priv()) Some(k.to_priv())
} else { } else {
@ -800,13 +800,13 @@ impl GetKey for $set<Xpriv> {
fn get_key<C: Signing>( fn get_key<C: Signing>(
&self, &self,
key_request: KeyRequest, key_request: &KeyRequest,
secp: &Secp256k1<C> secp: &Secp256k1<C>
) -> Result<Option<PrivateKey>, Self::Error> { ) -> Result<Option<PrivateKey>, Self::Error> {
// OK to stop at the first error because Xpriv::get_key() can only fail // OK to stop at the first error because Xpriv::get_key() can only fail
// if this isn't a KeyRequest::Bip32, which would fail for all Xprivs. // if this isn't a KeyRequest::Bip32, which would fail for all Xprivs.
self.iter() self.iter()
.find_map(|xpriv| xpriv.get_key(key_request.clone(), secp).transpose()) .find_map(|xpriv| xpriv.get_key(key_request, secp).transpose())
.transpose() .transpose()
} }
}}} }}}
@ -823,7 +823,7 @@ impl GetKey for $map<PublicKey, PrivateKey> {
fn get_key<C: Signing>( fn get_key<C: Signing>(
&self, &self,
key_request: KeyRequest, key_request: &KeyRequest,
_: &Secp256k1<C>, _: &Secp256k1<C>,
) -> Result<Option<PrivateKey>, Self::Error> { ) -> Result<Option<PrivateKey>, Self::Error> {
match key_request { match key_request {
@ -2126,7 +2126,7 @@ mod tests {
let mut key_map = BTreeMap::new(); let mut key_map = BTreeMap::new();
key_map.insert(pk, priv_key); key_map.insert(pk, priv_key);
let got = key_map.get_key(KeyRequest::Pubkey(pk), &secp).expect("failed to get key"); let got = key_map.get_key(&KeyRequest::Pubkey(pk), &secp).expect("failed to get key");
assert_eq!(got.unwrap(), priv_key) assert_eq!(got.unwrap(), priv_key)
} }

View File

@ -27,12 +27,12 @@ fn psbt_sign_taproot() {
type Error = SignError; type Error = SignError;
fn get_key<C: Signing>( fn get_key<C: Signing>(
&self, &self,
key_request: KeyRequest, key_request: &KeyRequest,
_secp: &Secp256k1<C>, _secp: &Secp256k1<C>,
) -> Result<Option<PrivateKey>, Self::Error> { ) -> Result<Option<PrivateKey>, Self::Error> {
match key_request { match key_request {
KeyRequest::Bip32((mfp, _)) => KeyRequest::Bip32((mfp, _)) =>
if mfp == self.mfp { if *mfp == self.mfp {
Ok(Some(self.sk)) Ok(Some(self.sk))
} else { } else {
Err(SignError::KeyNotFound) Err(SignError::KeyNotFound)