Factor out `io::Error` from sighash errors

The hadnling of `io::Error` in sighash had a few problems:

* It used `io::ErrorKind` instead of `io::Error` losing inforation
* Changing `io::ErrorKind` to `io::Error` would disable `PartialEq`&co
* The `Io` error wariants were duplicated

It turns out all of these can be solved by moving the `Io` variant into
a separate error.
This commit is contained in:
Martin Habovstiak 2024-01-19 17:03:32 +01:00
parent 111094ca9e
commit a4d01d0b6c
1 changed files with 72 additions and 46 deletions

View File

@ -582,8 +582,8 @@ impl<R: Borrow<Transaction>> SighashCache<R> {
annex: Option<Annex>,
leaf_hash_code_separator: Option<(TapLeafHash, u32)>,
sighash_type: TapSighashType,
) -> Result<(), TaprootError> {
prevouts.check_all(self.tx.borrow())?;
) -> Result<(), SigningDataError<TaprootError>> {
prevouts.check_all(self.tx.borrow()).map_err(SigningDataError::sighash)?;
let (sighash, anyone_can_pay) = sighash_type.split_anyonecanpay_flag();
@ -608,8 +608,8 @@ impl<R: Borrow<Transaction>> SighashCache<R> {
// sha_sequences (32): the SHA256 of the serialization of all input nSequence.
if !anyone_can_pay {
self.common_cache().prevouts.consensus_encode(writer)?;
self.taproot_cache(prevouts.get_all()?).amounts.consensus_encode(writer)?;
self.taproot_cache(prevouts.get_all()?).script_pubkeys.consensus_encode(writer)?;
self.taproot_cache(prevouts.get_all().map_err(SigningDataError::sighash)?).amounts.consensus_encode(writer)?;
self.taproot_cache(prevouts.get_all().map_err(SigningDataError::sighash)?).script_pubkeys.consensus_encode(writer)?;
self.common_cache().sequences.consensus_encode(writer)?;
}
@ -637,8 +637,8 @@ impl<R: Borrow<Transaction>> SighashCache<R> {
// scriptPubKey (35): scriptPubKey of the previous output spent by this input, serialized as script inside CTxOut. Its size is always 35 bytes.
// nSequence (4): nSequence of this input.
if anyone_can_pay {
let txin = &self.tx.borrow().tx_in(input_index)?;
let previous_output = prevouts.get(input_index)?;
let txin = &self.tx.borrow().tx_in(input_index).map_err(SigningDataError::sighash)?;
let previous_output = prevouts.get(input_index).map_err(SigningDataError::sighash)?;
txin.previous_output.consensus_encode(writer)?;
previous_output.value.consensus_encode(writer)?;
previous_output.script_pubkey.consensus_encode(writer)?;
@ -669,7 +669,7 @@ impl<R: Borrow<Transaction>> SighashCache<R> {
.ok_or(TaprootError::SingleMissingOutput(SingleMissingOutputError {
input_index,
outputs_length: self.tx.borrow().output.len(),
}))?
})).map_err(SigningDataError::Sighash)?
.consensus_encode(&mut enc)?;
let hash = sha256::Hash::from_engine(enc);
hash.consensus_encode(writer)?;
@ -705,7 +705,7 @@ impl<R: Borrow<Transaction>> SighashCache<R> {
annex,
leaf_hash_code_separator,
sighash_type,
)?;
).map_err(SigningDataError::unwrap_sighash)?;
Ok(TapSighash::from_engine(enc))
}
@ -724,7 +724,7 @@ impl<R: Borrow<Transaction>> SighashCache<R> {
None,
None,
sighash_type,
)?;
).map_err(SigningDataError::unwrap_sighash)?;
Ok(TapSighash::from_engine(enc))
}
@ -747,7 +747,7 @@ impl<R: Borrow<Transaction>> SighashCache<R> {
None,
Some((leaf_hash.into(), 0xFFFFFFFF)),
sighash_type,
)?;
).map_err(SigningDataError::unwrap_sighash)?;
Ok(TapSighash::from_engine(enc))
}
@ -764,7 +764,7 @@ impl<R: Borrow<Transaction>> SighashCache<R> {
script_code: &Script,
value: Amount,
sighash_type: EcdsaSighashType,
) -> Result<(), SegwitV0Error> {
) -> Result<(), SigningDataError<SegwitV0Error>> {
let zero_hash = sha256d::Hash::all_zeros();
let (sighash, anyone_can_pay) = sighash_type.split_anyonecanpay_flag();
@ -787,7 +787,7 @@ impl<R: Borrow<Transaction>> SighashCache<R> {
}
{
let txin = &self.tx.borrow().tx_in(input_index)?;
let txin = &self.tx.borrow().tx_in(input_index).map_err(SigningDataError::sighash)?;
txin.previous_output.consensus_encode(writer)?;
script_code.consensus_encode(writer)?;
value.consensus_encode(writer)?;
@ -831,7 +831,7 @@ impl<R: Borrow<Transaction>> SighashCache<R> {
&script_code,
value,
sighash_type,
)?;
).map_err(SigningDataError::unwrap_sighash)?;
Ok(SegwitV0Sighash::from_engine(enc))
}
@ -850,7 +850,7 @@ impl<R: Borrow<Transaction>> SighashCache<R> {
witness_script,
value,
sighash_type,
)?;
).map_err(SigningDataError::unwrap_sighash)?;
Ok(SegwitV0Sighash::from_engine(enc))
}
@ -882,10 +882,10 @@ impl<R: Borrow<Transaction>> SighashCache<R> {
input_index: usize,
script_pubkey: &Script,
sighash_type: U,
) -> EncodeSigningDataResult<LegacyError> {
) -> EncodeSigningDataResult<SigningDataError<LegacyError>> {
// Validate input_index.
if let Err(e) = self.tx.borrow().tx_in(input_index) {
return EncodeSigningDataResult::WriteResult(Err(e.into()));
return EncodeSigningDataResult::WriteResult(Err(SigningDataError::Sighash(e.into())));
}
let sighash_type: u32 = sighash_type.into();
@ -977,7 +977,7 @@ impl<R: Borrow<Transaction>> SighashCache<R> {
script_pubkey,
sighash_type,
)
.map_err(|e| LegacyError::Io(e.kind())),
.map_err(Into::into),
)
}
@ -1013,10 +1013,7 @@ impl<R: Borrow<Transaction>> SighashCache<R> {
{
Ok(true) => Ok(LegacySighash::from_byte_array(UINT256_ONE)),
Ok(false) => Ok(LegacySighash::from_engine(engine)),
Err(e) => match e {
LegacyError::InputsIndex(e) => Err(e),
LegacyError::Io(_) => unreachable!("engines don't error"),
},
Err(e) => Err(match e.unwrap_sighash() { LegacyError::InputsIndex(error) => error }),
}
}
@ -1145,9 +1142,6 @@ impl<'a> Encodable for Annex<'a> {
pub enum TaprootError {
/// Index out of bounds when accessing transaction input vector.
InputsIndex(transaction::InputsIndexError),
/// Can happen only when using `*_encode_signing_*` methods with custom writers, engines
/// like those used in `*_signature_hash` methods do not error.
Io(io::ErrorKind),
/// Using `SIGHASH_SINGLE` requires an output at the same index is the input.
SingleMissingOutput(SingleMissingOutputError),
/// Prevouts size error.
@ -1166,7 +1160,6 @@ impl fmt::Display for TaprootError {
match *self {
InputsIndex(ref e) => write_err!(f, "inputs index"; e),
Io(error_kind) => write!(f, "write failed: {:?}", error_kind),
SingleMissingOutput(ref e) => write_err!(f, "sighash single"; e),
PrevoutsSize(ref e) => write_err!(f, "prevouts size"; e),
PrevoutsIndex(ref e) => write_err!(f, "prevouts index"; e),
@ -1187,7 +1180,7 @@ impl std::error::Error for TaprootError {
PrevoutsSize(ref e) => Some(e),
PrevoutsIndex(ref e) => Some(e),
PrevoutsKind(ref e) => Some(e),
Io(_) | InvalidSighashType(_) => None,
InvalidSighashType(_) => None,
}
}
}
@ -1196,10 +1189,6 @@ impl From<transaction::InputsIndexError> for TaprootError {
fn from(e: transaction::InputsIndexError) -> Self { Self::InputsIndex(e) }
}
impl From<io::Error> for TaprootError {
fn from(e: io::Error) -> Self { Self::Io(e.kind()) }
}
impl From<PrevoutsSizeError> for TaprootError {
fn from(e: PrevoutsSizeError) -> Self { Self::PrevoutsSize(e) }
}
@ -1255,8 +1244,6 @@ impl From<SegwitV0Error> for P2wpkhError {
pub enum SegwitV0Error {
/// Index out of bounds when accessing transaction input vector.
InputsIndex(transaction::InputsIndexError),
/// Writer errored during consensus encoding.
Io(io::ErrorKind),
}
impl fmt::Display for SegwitV0Error {
@ -1265,7 +1252,6 @@ impl fmt::Display for SegwitV0Error {
match *self {
InputsIndex(ref e) => write_err!(f, "inputs index"; e),
Io(error_kind) => write!(f, "write failed: {:?}", error_kind),
}
}
}
@ -1278,7 +1264,6 @@ impl std::error::Error for SegwitV0Error {
match *self {
InputsIndex(ref e) => Some(e),
Io(_) => None,
}
}
}
@ -1287,10 +1272,6 @@ impl From<transaction::InputsIndexError> for SegwitV0Error {
fn from(e: transaction::InputsIndexError) -> Self { Self::InputsIndex(e) }
}
impl From<io::Error> for SegwitV0Error {
fn from(e: io::Error) -> Self { Self::Io(e.kind()) }
}
/// Using `SIGHASH_SINGLE` requires an output at the same index as the input.
#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
@ -1356,8 +1337,6 @@ impl std::error::Error for AnnexError {
pub enum LegacyError {
/// Index out of bounds when accessing transaction input vector.
InputsIndex(transaction::InputsIndexError),
/// Writer errored during consensus encoding.
Io(io::ErrorKind),
}
impl fmt::Display for LegacyError {
@ -1366,7 +1345,6 @@ impl fmt::Display for LegacyError {
match *self {
InputsIndex(ref e) => write_err!(f, "inputs index"; e),
Io(error_kind) => write!(f, "write failed: {:?}", error_kind),
}
}
}
@ -1378,7 +1356,6 @@ impl std::error::Error for LegacyError {
match *self {
InputsIndex(ref e) => Some(e),
Io(_) => None,
}
}
}
@ -1387,10 +1364,6 @@ impl From<transaction::InputsIndexError> for LegacyError {
fn from(e: transaction::InputsIndexError) -> Self { Self::InputsIndex(e) }
}
impl From<io::Error> for LegacyError {
fn from(e: io::Error) -> Self { Self::Io(e.kind()) }
}
fn is_invalid_use_of_sighash_single(sighash: u32, input_index: usize, outputs_len: usize) -> bool {
let ty = EcdsaSighashType::from_consensus(sighash);
ty == EcdsaSighashType::Single && input_index >= outputs_len
@ -1464,6 +1437,59 @@ impl<E> EncodeSigningDataResult<E> {
}
}
/// Error returned when writing signing data fails.
#[derive(Debug)]
pub enum SigningDataError<E> {
/// Can happen only when using `*_encode_signing_*` methods with custom writers, engines
/// like those used in `*_signature_hash` methods do not error.
Io(io::Error),
/// An argument to the called sighash function was invalid.
Sighash(E),
}
impl<E> SigningDataError<E> {
/// Returns the sighash variant, panicking if it's IO.
///
/// This is used when encoding to hash engine when we know that IO doesn't fail.
fn unwrap_sighash(self) -> E {
match self {
Self::Sighash(error) => error,
Self::Io(error) => panic!("hash engine error {}", error),
}
}
fn sighash<E2: Into<E>>(error: E2) -> Self {
Self::Sighash(error.into())
}
}
// We cannot simultaneously impl `From<E>`. it was determined that this alternative requires less
// manual `map_err` calls.
impl<E> From<io::Error> for SigningDataError<E> {
fn from(value: io::Error) -> Self {
Self::Io(value)
}
}
impl<E: fmt::Display> fmt::Display for SigningDataError<E> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
Self::Io(error) => write_err!(f, "failed to write sighash data"; error),
Self::Sighash(error) => write_err!(f, "failed to compute sighash data"; error),
}
}
}
#[cfg(feature = "std")]
impl<E: std::error::Error + 'static> std::error::Error for SigningDataError<E> {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
match self {
SigningDataError::Io(error) => Some(error),
SigningDataError::Sighash(error) => Some(error),
}
}
}
#[cfg(test)]
mod tests {
use std::str::FromStr;