diff --git a/src/util/sighash.rs b/src/util/sighash.rs index 43a70a37..33cb5bb0 100644 --- a/src/util/sighash.rs +++ b/src/util/sighash.rs @@ -20,10 +20,12 @@ //! and legacy (before Bip143). //! -pub use blockdata::transaction::EcdsaSigHashType; +use prelude::*; + +pub use blockdata::transaction::{EcdsaSigHashType, SigHashTypeParseError}; use blockdata::witness::Witness; use consensus::{encode, Encodable}; -use core::fmt; +use core::{str, fmt}; use core::ops::{Deref, DerefMut}; use core::borrow::Borrow; use hashes::{sha256, sha256d, Hash}; @@ -105,7 +107,6 @@ pub struct ScriptPath<'s> { /// Hashtype of an input's signature, encoded in the last byte of the signature /// Fixed values so they can be casted as integer types for encoding #[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Debug)] -#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub enum SchnorrSigHashType { /// 0x0: Used when not explicitly specified, defaulting to [`SchnorrSigHashType::All`] Default = 0x00, @@ -128,6 +129,41 @@ pub enum SchnorrSigHashType { /// Reserved for future use, `#[non_exhaustive]` is not available with current MSRV Reserved = 0xFF, } +serde_string_impl!(SchnorrSigHashType, "a SchnorrSigHashType data"); + +impl fmt::Display for SchnorrSigHashType { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let s = match self { + SchnorrSigHashType::Default => "SIGHASH_DEFAULT", + SchnorrSigHashType::All => "SIGHASH_ALL", + SchnorrSigHashType::None => "SIGHASH_NONE", + SchnorrSigHashType::Single => "SIGHASH_SINGLE", + SchnorrSigHashType::AllPlusAnyoneCanPay => "SIGHASH_ALL|SIGHASH_ANYONECANPAY", + SchnorrSigHashType::NonePlusAnyoneCanPay => "SIGHASH_NONE|SIGHASH_ANYONECANPAY", + SchnorrSigHashType::SinglePlusAnyoneCanPay => "SIGHASH_SINGLE|SIGHASH_ANYONECANPAY", + SchnorrSigHashType::Reserved => "SIGHASH_RESERVED", + }; + f.write_str(s) + } +} + +impl str::FromStr for SchnorrSigHashType { + type Err = SigHashTypeParseError; + + fn from_str(s: &str) -> Result { + match s { + "SIGHASH_DEFAULT" => Ok(SchnorrSigHashType::Default), + "SIGHASH_ALL" => Ok(SchnorrSigHashType::All), + "SIGHASH_NONE" => Ok(SchnorrSigHashType::None), + "SIGHASH_SINGLE" => Ok(SchnorrSigHashType::Single), + "SIGHASH_ALL|SIGHASH_ANYONECANPAY" => Ok(SchnorrSigHashType::AllPlusAnyoneCanPay), + "SIGHASH_NONE|SIGHASH_ANYONECANPAY" => Ok(SchnorrSigHashType::NonePlusAnyoneCanPay), + "SIGHASH_SINGLE|SIGHASH_ANYONECANPAY" => Ok(SchnorrSigHashType::SinglePlusAnyoneCanPay), + "SIGHASH_RESERVED" => Ok(SchnorrSigHashType::Reserved), + _ => Err(SigHashTypeParseError{ unrecognized: s.to_owned() }), + } + } +} /// Possible errors in computing the signature message #[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Debug)] @@ -1107,4 +1143,39 @@ mod tests { let json_str = include_str!("../../test_data/bip341_tests.json"); serde_json::from_str(json_str).expect("JSON was not well-formatted") } + + #[test] + fn sighashtype_fromstr_display() { + let sighashtypes = vec![ + ("SIGHASH_DEFAULT", SchnorrSigHashType::Default), + ("SIGHASH_ALL", SchnorrSigHashType::All), + ("SIGHASH_NONE", SchnorrSigHashType::None), + ("SIGHASH_SINGLE", SchnorrSigHashType::Single), + ("SIGHASH_ALL|SIGHASH_ANYONECANPAY", SchnorrSigHashType::AllPlusAnyoneCanPay), + ("SIGHASH_NONE|SIGHASH_ANYONECANPAY", SchnorrSigHashType::NonePlusAnyoneCanPay), + ("SIGHASH_SINGLE|SIGHASH_ANYONECANPAY", SchnorrSigHashType::SinglePlusAnyoneCanPay), + ("SIGHASH_RESERVED", SchnorrSigHashType::Reserved), + ]; + for (s, sht) in sighashtypes { + assert_eq!(sht.to_string(), s); + assert_eq!(SchnorrSigHashType::from_str(s).unwrap(), sht); + } + let sht_mistakes = vec![ + "SIGHASH_ALL | SIGHASH_ANYONECANPAY", + "SIGHASH_NONE |SIGHASH_ANYONECANPAY", + "SIGHASH_SINGLE| SIGHASH_ANYONECANPAY", + "SIGHASH_ALL SIGHASH_ANYONECANPAY", + "SIGHASH_NONE |", + "SIGHASH_SIGNLE", + "DEFAULT", + "ALL", + "sighash_none", + "Sighash_none", + "SigHash_None", + "SigHash_NONE", + ]; + for s in sht_mistakes { + assert_eq!(SchnorrSigHashType::from_str(s).unwrap_err().to_string(), format!("Unrecognized SIGHASH string '{}'", s)); + } + } }