diff --git a/src/util/psbt/map/input.rs b/src/util/psbt/map/input.rs index 3bcbea1f..6d90d559 100644 --- a/src/util/psbt/map/input.rs +++ b/src/util/psbt/map/input.rs @@ -525,3 +525,71 @@ where btree_map::Entry::Occupied(_) => Err(psbt::Error::DuplicateKey(raw_key).into()), } } + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn psbt_sighash_type_ecdsa() { + for ecdsa in &[ + EcdsaSigHashType::All, + EcdsaSigHashType::None, + EcdsaSigHashType::Single, + EcdsaSigHashType::AllPlusAnyoneCanPay, + EcdsaSigHashType::NonePlusAnyoneCanPay, + EcdsaSigHashType::SinglePlusAnyoneCanPay, + ] { + let sighash = PsbtSigHashType::from(*ecdsa); + let s = format!("{}", sighash); + let back = PsbtSigHashType::from_str(&s).unwrap(); + assert_eq!(back, sighash); + assert_eq!(back.ecdsa_hash_ty().unwrap(), *ecdsa); + } + } + + #[test] + fn psbt_sighash_type_schnorr() { + for schnorr in &[ + SchnorrSigHashType::Default, + SchnorrSigHashType::All, + SchnorrSigHashType::None, + SchnorrSigHashType::Single, + SchnorrSigHashType::AllPlusAnyoneCanPay, + SchnorrSigHashType::NonePlusAnyoneCanPay, + SchnorrSigHashType::SinglePlusAnyoneCanPay, + ] { + let sighash = PsbtSigHashType::from(*schnorr); + let s = format!("{}", sighash); + let back = PsbtSigHashType::from_str(&s).unwrap(); + assert_eq!(back, sighash); + assert_eq!(back.schnorr_hash_ty().unwrap(), *schnorr); + } + } + + #[test] + fn psbt_sighash_type_schnorr_notstd() { + for (schnorr, schnorr_str) in &[ + (SchnorrSigHashType::Reserved, "0xff"), + ] { + let sighash = PsbtSigHashType::from(*schnorr); + let s = format!("{}", sighash); + assert_eq!(&s, schnorr_str); + let back = PsbtSigHashType::from_str(&s).unwrap(); + assert_eq!(back, sighash); + assert_eq!(back.schnorr_hash_ty().unwrap(), *schnorr); + } + } + + #[test] + fn psbt_sighash_type_notstd() { + let nonstd = 0xdddddddd; + let sighash = PsbtSigHashType { inner: nonstd }; + let s = format!("{}", sighash); + let back = PsbtSigHashType::from_str(&s).unwrap(); + + assert_eq!(back, sighash); + assert_eq!(back.ecdsa_hash_ty(), Err(NonStandardSigHashType(nonstd))); + assert_eq!(back.schnorr_hash_ty(), Err(sighash::Error::InvalidSigHashType(nonstd))); + } +}