psbt: Add IndexOutOfBounds error

We currently have a bunch of functions that are infallible if the
`index` argument is within-bounds however we return a `SignError`, this
obfuscates the code.

Add an `IndexOutOfBoundsErorr`. While we are at it make it an enum so
users can differentiate between which vector the out of bounds access
was attempted against.
This commit is contained in:
Tobin C. Harding 2023-08-17 13:45:07 +10:00
parent 1991b7af40
commit 66d5800ac0
No known key found for this signature in database
GPG Key ID: 40BF9E4C269D6607
1 changed files with 63 additions and 11 deletions

View File

@ -394,20 +394,29 @@ impl Psbt {
} }
/// Gets the input at `input_index` after checking that it is a valid index. /// Gets the input at `input_index` after checking that it is a valid index.
fn checked_input(&self, input_index: usize) -> Result<&Input, SignError> { fn checked_input(&self, input_index: usize) -> Result<&Input, IndexOutOfBoundsError> {
self.check_index_is_within_bounds(input_index)?; self.check_index_is_within_bounds(input_index)?;
Ok(&self.inputs[input_index]) Ok(&self.inputs[input_index])
} }
/// Checks `input_index` is within bounds for the PSBT `inputs` array and /// Checks `input_index` is within bounds for the PSBT `inputs` array and
/// for the PSBT `unsigned_tx` `input` array. /// for the PSBT `unsigned_tx` `input` array.
fn check_index_is_within_bounds(&self, input_index: usize) -> Result<(), SignError> { fn check_index_is_within_bounds(
&self,
input_index: usize,
) -> Result<(), IndexOutOfBoundsError> {
if input_index >= self.inputs.len() { if input_index >= self.inputs.len() {
return Err(SignError::IndexOutOfBounds(input_index, self.inputs.len())); return Err(IndexOutOfBoundsError::Inputs {
index: input_index,
length: self.inputs.len(),
});
} }
if input_index >= self.unsigned_tx.input.len() { if input_index >= self.unsigned_tx.input.len() {
return Err(SignError::IndexOutOfBounds(input_index, self.unsigned_tx.input.len())); return Err(IndexOutOfBoundsError::TxInput {
index: input_index,
length: self.unsigned_tx.input.len(),
});
} }
Ok(()) Ok(())
@ -676,8 +685,8 @@ pub enum SigningAlgorithm {
#[derive(Debug, Clone, PartialEq, Eq)] #[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive] #[non_exhaustive]
pub enum SignError { pub enum SignError {
/// Input index out of bounds (actual index, maximum index allowed). /// Input index out of bounds.
IndexOutOfBounds(usize, usize), IndexOutOfBounds(IndexOutOfBoundsError),
/// Invalid Sighash type. /// Invalid Sighash type.
InvalidSighashType, InvalidSighashType,
/// Missing input utxo. /// Missing input utxo.
@ -711,9 +720,7 @@ impl fmt::Display for SignError {
use self::SignError::*; use self::SignError::*;
match *self { match *self {
IndexOutOfBounds(ref ind, ref len) => { IndexOutOfBounds(ref e) => write_err!(f, "index out of bounds"; e),
write!(f, "index {}, psbt input len: {}", ind, len)
}
InvalidSighashType => write!(f, "invalid sighash type"), InvalidSighashType => write!(f, "invalid sighash type"),
MissingInputUtxo => write!(f, "missing input utxo in PBST"), MissingInputUtxo => write!(f, "missing input utxo in PBST"),
MissingRedeemScript => write!(f, "missing redeem script"), MissingRedeemScript => write!(f, "missing redeem script"),
@ -738,8 +745,7 @@ impl std::error::Error for SignError {
use self::SignError::*; use self::SignError::*;
match *self { match *self {
IndexOutOfBounds(_, _) InvalidSighashType
| InvalidSighashType
| MissingInputUtxo | MissingInputUtxo
| MissingRedeemScript | MissingRedeemScript
| MissingSpendUtxo | MissingSpendUtxo
@ -752,6 +758,7 @@ impl std::error::Error for SignError {
| WrongSigningAlgorithm | WrongSigningAlgorithm
| Unsupported => None, | Unsupported => None,
SighashComputation(ref e) => Some(e), SighashComputation(ref e) => Some(e),
IndexOutOfBounds(ref e) => Some(e),
} }
} }
} }
@ -760,6 +767,51 @@ impl From<sighash::Error> for SignError {
fn from(e: sighash::Error) -> Self { SignError::SighashComputation(e) } fn from(e: sighash::Error) -> Self { SignError::SighashComputation(e) }
} }
impl From<IndexOutOfBoundsError> for SignError {
fn from(e: IndexOutOfBoundsError) -> Self { SignError::IndexOutOfBounds(e) }
}
/// Input index out of bounds (actual index, maximum index allowed).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum IndexOutOfBoundsError {
/// The index is out of bounds for the `psbt.inputs` vector.
Inputs {
/// Attempted index access.
index: usize,
/// Length of the PBST inputs vector.
length: usize,
},
/// The index is out of bounds for the `psbt.unsigned_tx.input` vector.
TxInput {
/// Attempted index access.
index: usize,
/// Length of the PBST's unsigned transaction input vector.
length: usize,
},
}
impl fmt::Display for IndexOutOfBoundsError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
use IndexOutOfBoundsError::*;
match *self {
Inputs { ref index, ref length } => write!(
f,
"index {} is out-of-bounds for PSBT inputs vector length {}",
length, index
),
TxInput { ref index, ref length } => write!(
f,
"index {} is out-of-bounds for PSBT unsigned tx input vector length {}",
length, index
),
}
}
}
#[cfg(feature = "std")]
impl std::error::Error for IndexOutOfBoundsError {}
#[cfg(feature = "base64")] #[cfg(feature = "base64")]
mod display_from_str { mod display_from_str {
use core::fmt::{self, Display, Formatter}; use core::fmt::{self, Display, Formatter};