diff --git a/bitcoin/src/crypto/key.rs b/bitcoin/src/crypto/key.rs index 9e9f1aa8c..9ba23615a 100644 --- a/bitcoin/src/crypto/key.rs +++ b/bitcoin/src/crypto/key.rs @@ -497,7 +497,12 @@ impl PrivateKey { let compressed = match data.len() { 33 => false, - 34 => true, + 34 => { + if data[33] != 1 { + return Err(InvalidWifCompressionFlagError{ invalid: data[33] }.into()); + } + true + }, length => { return Err(InvalidBase58PayloadLengthError { length }.into()); } @@ -963,6 +968,8 @@ pub enum FromWifError { InvalidAddressVersion(InvalidAddressVersionError), /// A secp256k1 error. Secp256k1(secp256k1::Error), + /// Invalid WIF compression flag. + InvalidWifCompressionFlag(InvalidWifCompressionFlagError), } impl From for FromWifError { @@ -980,6 +987,8 @@ impl fmt::Display for FromWifError { InvalidAddressVersion(ref e) => write_err!(f, "decoded base58 data contained an invalid address version btye"; e), Secp256k1(ref e) => write_err!(f, "private key validation failed"; e), + InvalidWifCompressionFlag(ref e) => + write_err!(f, "invalid WIF compression flag";e), } } } @@ -994,6 +1003,7 @@ impl std::error::Error for FromWifError { InvalidBase58PayloadLength(ref e) => Some(e), InvalidAddressVersion(ref e) => Some(e), Secp256k1(ref e) => Some(e), + InvalidWifCompressionFlag(ref e) => Some(e), } } } @@ -1016,6 +1026,12 @@ impl From for FromWifError { fn from(e: InvalidAddressVersionError) -> FromWifError { Self::InvalidAddressVersion(e) } } +impl From for FromWifError { + fn from(e: InvalidWifCompressionFlagError) -> FromWifError { + Self::InvalidWifCompressionFlag(e) + } +} + /// Error returned while constructing public key from string. #[derive(Debug, Clone, PartialEq, Eq)] pub enum ParsePublicKeyError { @@ -1161,6 +1177,27 @@ impl fmt::Display for InvalidAddressVersionError { #[cfg(feature = "std")] impl std::error::Error for InvalidAddressVersionError {} +/// Invalid compression flag for a WIF key +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct InvalidWifCompressionFlagError{ + /// The invalid compression flag. + pub(crate) invalid: u8, +} + +impl InvalidWifCompressionFlagError { + /// Returns the invalid compression flag. + pub fn invalid_compression_flag(&self) -> u8 { self.invalid } +} + +impl fmt::Display for InvalidWifCompressionFlagError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "invalid WIF compression flag. Expected a 0x01 byte at the end of the key but found: {}", self.invalid) + } +} + +#[cfg(feature = "std")] +impl std::error::Error for InvalidWifCompressionFlagError {} + #[cfg(test)] mod tests { use super::*; @@ -1168,6 +1205,11 @@ mod tests { #[test] fn key_derivation() { + // mainnet compressed WIF with invalid compression flag. + let sk = + PrivateKey::from_wif("L2x4uC2YgfFWZm9tF4pjDnVR6nJkheizFhEr2KvDNnTEmEqVzPJY"); + assert!(matches!(sk, Err(FromWifError::InvalidWifCompressionFlag(InvalidWifCompressionFlagError { invalid: 49 })))); + // testnet compressed let sk = PrivateKey::from_wif("cVt4o7BGAig1UXywgGSmARhxMdzP5qvQsxKkSsc1XEkw3tDTQFpy").unwrap();