Implement transaction weight prediction

When creating a transaction one must know the the fee beforehand to set
appropriate amounts for outputs and to know the fee, weight is required.
So far we only had a method on an already-constructed transaction. This
method clearly wasn't helpful when constructing the transaction except
for hacks like temporarily adding an all-zeroes signature.

This change adds a function that can compute the transaction weight
without knowing individual bytes of the scripts, witnesses and other
elements. It only needs to know their sizes.

To make the API less error-prone a special, trivial, type is also added
for computing the lengths of witnesses.
This commit is contained in:
Martin Habovstiak 2023-02-09 13:59:06 +01:00
parent b6387db47f
commit cf068d16b0
1 changed files with 125 additions and 7 deletions

View File

@ -850,6 +850,19 @@ impl Transaction {
self.scaled_size(1) self.scaled_size(1)
} }
/// Computes the weight and checks that it matches the output of `predict_weight`.
#[cfg(test)]
fn check_weight(&self) -> Weight {
let weight1 = self.weight();
let inputs = self.input
.iter()
.map(|txin| InputWeightPrediction::new(txin.script_sig.len(), txin.witness.iter().map(|elem| elem.len())));
let outputs = self.output.iter().map(|txout| txout.script_pubkey.len());
let weight2 = predict_weight(inputs, outputs);
assert_eq!(weight1, weight2);
weight1
}
/// Returns the "virtual size" (vsize) of this transaction. /// Returns the "virtual size" (vsize) of this transaction.
/// ///
/// Will be `ceil(weight / 4.0)`. Note this implements the virtual size as per [`BIP141`], which /// Will be `ceil(weight / 4.0)`. Note this implements the virtual size as per [`BIP141`], which
@ -1148,6 +1161,111 @@ impl From<&Transaction> for Wtxid {
} }
} }
/// Predicts the weight of a to-be-constructed transaction.
///
/// This function computes the weight of a transaction which is not fully known. All that is needed
/// is the lengths of scripts and witness elements.
///
/// # Arguments
///
/// * `inputs` - an iterator which returns `InputWeightPrediction` for each input of the
/// to-be-constructed transaction.
/// * `output_script_lens` - an iterator which returns the length of `script_pubkey` of each output
/// of the to-be-constructed transaction.
///
/// Note that lengths of the scripts and witness elements must be non-serialized, IOW *without* the
/// preceding compact size. The lenght of preceding compact size is computed and added inside the
/// function for convenience.
///
/// # Usage
///
/// When signing a transaction one doesn't know the signature before knowing the transaction fee and
/// the transaction fee is not known before knowing the transaction size which is not known before
/// knowing the signature. This apparent dependency cycle can be broken by knowing the length of the
/// signature without knowing the contents of the signature e.g., we know all Schnorr signatures
/// are 64 bytes long.
///
/// Additionally, some protocols may require calculating the amounts before knowing various parts
/// of the transaction (assuming their length is known).
///
/// # Notes on integer overflow
///
/// Overflows are intentionally not checked because one of the following holds:
///
/// * The transaction is valid (obeys the block size limit) and the code feeds correct values to
/// this function - no overflow can happen.
/// * The transaction will be so large it doesn't fit in the memory - overflow will happen but
/// then the transaction will fail to construct and even if one serialized it on disk directly
/// it'd be invalid anyway so overflow doesn't matter.
/// * The values fed into this function are inconsistent with the actual lengths the transaction
/// will have - the code is already broken and checking overflows doesn't help. Unfortunately
/// this probably cannot be avoided.
pub fn predict_weight<I, O>(inputs: I, output_script_lens: O) -> Weight
where I: IntoIterator<Item = InputWeightPrediction>,
O: IntoIterator<Item = usize>,
{
let (input_count, partial_input_weight, inputs_with_witnesses) = inputs.into_iter()
.fold((0, 0, 0), |(count, partial_input_weight, inputs_with_witnesses), prediction| {
(count + 1, partial_input_weight + prediction.script_size * 4 + prediction.witness_size, inputs_with_witnesses + (prediction.witness_size > 0) as usize)
});
let (output_count, output_scripts_size) = output_script_lens.into_iter()
.fold((0, 0), |(output_count, total_scripts_size), script_len| {
let script_size = script_len + VarInt(script_len as u64).len();
(output_count + 1, total_scripts_size + script_size)
});
let input_weight = partial_input_weight + input_count * 4 * (32 + 4 + 4);
let output_size = 8 * output_count + output_scripts_size;
let non_input_size =
// version:
4 +
// count varints:
VarInt(input_count as u64).len() +
VarInt(output_count as u64).len() +
output_size +
// lock_time
4;
let weight = if inputs_with_witnesses == 0 {
non_input_size * 4 + input_weight
} else {
non_input_size * 4 + input_weight + input_count - inputs_with_witnesses + 2
};
Weight::from_wu(weight as u64)
}
/// Weight prediction of an individual input.
///
/// This helper type collects information about an input to be used in [`predict_weight`] function.
/// It can only be created using the [`new`](InputWeightPrediction::new) function.
#[derive(Copy, Clone, Debug)]
pub struct InputWeightPrediction {
script_size: usize,
witness_size: usize,
}
impl InputWeightPrediction {
/// Computes the prediction for a single input.
pub fn new<I>(input_script_len: usize, witness_element_lengths: I) -> Self
where I: IntoIterator<Item = usize>,
{
let (count, total_size) = witness_element_lengths.into_iter()
.fold((0, 0), |(count, total_size), elem_len| {
let elem_size = elem_len + VarInt(elem_len as u64).len();
(count + 1, total_size + elem_size)
});
let witness_size = if count > 0 {
total_size + VarInt(count as u64).len()
} else {
0
};
let script_size = input_script_len + VarInt(input_script_len as u64).len();
InputWeightPrediction {
script_size,
witness_size,
}
}
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
@ -1260,7 +1378,7 @@ mod tests {
"a6eab3c14ab5272a58a5ba91505ba1a4b6d7a3a9fcbd187b6cd99a7b6d548cb7".to_string()); "a6eab3c14ab5272a58a5ba91505ba1a4b6d7a3a9fcbd187b6cd99a7b6d548cb7".to_string());
assert_eq!(format!("{:x}", realtx.wtxid()), assert_eq!(format!("{:x}", realtx.wtxid()),
"a6eab3c14ab5272a58a5ba91505ba1a4b6d7a3a9fcbd187b6cd99a7b6d548cb7".to_string()); "a6eab3c14ab5272a58a5ba91505ba1a4b6d7a3a9fcbd187b6cd99a7b6d548cb7".to_string());
assert_eq!(realtx.weight().to_wu() as usize, tx_bytes.len()*WITNESS_SCALE_FACTOR); assert_eq!(realtx.check_weight().to_wu() as usize, tx_bytes.len()*WITNESS_SCALE_FACTOR);
assert_eq!(realtx.size(), tx_bytes.len()); assert_eq!(realtx.size(), tx_bytes.len());
assert_eq!(realtx.vsize(), tx_bytes.len()); assert_eq!(realtx.vsize(), tx_bytes.len());
assert_eq!(realtx.strippedsize(), tx_bytes.len()); assert_eq!(realtx.strippedsize(), tx_bytes.len());
@ -1295,7 +1413,7 @@ mod tests {
assert_eq!(format!("{:x}", realtx.wtxid()), assert_eq!(format!("{:x}", realtx.wtxid()),
"80b7d8a82d5d5bf92905b06f2014dd699e03837ca172e3a59d51426ebbe3e7f5".to_string()); "80b7d8a82d5d5bf92905b06f2014dd699e03837ca172e3a59d51426ebbe3e7f5".to_string());
const EXPECTED_WEIGHT: Weight = Weight::from_wu(442); const EXPECTED_WEIGHT: Weight = Weight::from_wu(442);
assert_eq!(realtx.weight(), EXPECTED_WEIGHT); assert_eq!(realtx.check_weight(), EXPECTED_WEIGHT);
assert_eq!(realtx.size(), tx_bytes.len()); assert_eq!(realtx.size(), tx_bytes.len());
assert_eq!(realtx.vsize(), 111); assert_eq!(realtx.vsize(), 111);
// Since // Since
@ -1308,7 +1426,7 @@ mod tests {
// Construct a transaction without the witness data. // Construct a transaction without the witness data.
let mut tx_without_witness = realtx; let mut tx_without_witness = realtx;
tx_without_witness.input.iter_mut().for_each(|input| input.witness.clear()); tx_without_witness.input.iter_mut().for_each(|input| input.witness.clear());
assert_eq!(tx_without_witness.weight().to_wu() as usize, expected_strippedsize*WITNESS_SCALE_FACTOR); assert_eq!(tx_without_witness.check_weight().to_wu() as usize, expected_strippedsize*WITNESS_SCALE_FACTOR);
assert_eq!(tx_without_witness.size(), expected_strippedsize); assert_eq!(tx_without_witness.size(), expected_strippedsize);
assert_eq!(tx_without_witness.vsize(), expected_strippedsize); assert_eq!(tx_without_witness.vsize(), expected_strippedsize);
assert_eq!(tx_without_witness.strippedsize(), expected_strippedsize); assert_eq!(tx_without_witness.strippedsize(), expected_strippedsize);
@ -1413,7 +1531,7 @@ mod tests {
assert_eq!(format!("{:x}", tx.wtxid()), "d6ac4a5e61657c4c604dcde855a1db74ec6b3e54f32695d72c5e11c7761ea1b4"); assert_eq!(format!("{:x}", tx.wtxid()), "d6ac4a5e61657c4c604dcde855a1db74ec6b3e54f32695d72c5e11c7761ea1b4");
assert_eq!(format!("{:x}", tx.txid()), "9652aa62b0e748caeec40c4cb7bc17c6792435cc3dfe447dd1ca24f912a1c6ec"); assert_eq!(format!("{:x}", tx.txid()), "9652aa62b0e748caeec40c4cb7bc17c6792435cc3dfe447dd1ca24f912a1c6ec");
assert_eq!(tx.weight(), Weight::from_wu(2718)); assert_eq!(tx.check_weight(), Weight::from_wu(2718));
// non-segwit tx from my mempool // non-segwit tx from my mempool
let tx_bytes = hex!( let tx_bytes = hex!(
@ -1445,7 +1563,7 @@ mod tests {
fn test_segwit_tx_decode() { fn test_segwit_tx_decode() {
let tx_bytes = hex!("010000000001010000000000000000000000000000000000000000000000000000000000000000ffffffff3603da1b0e00045503bd5704c7dd8a0d0ced13bb5785010800000000000a636b706f6f6c122f4e696e6a61506f6f6c2f5345475749542fffffffff02b4e5a212000000001976a914876fbb82ec05caa6af7a3b5e5a983aae6c6cc6d688ac0000000000000000266a24aa21a9edf91c46b49eb8a29089980f02ee6b57e7d63d33b18b4fddac2bcd7db2a39837040120000000000000000000000000000000000000000000000000000000000000000000000000"); let tx_bytes = hex!("010000000001010000000000000000000000000000000000000000000000000000000000000000ffffffff3603da1b0e00045503bd5704c7dd8a0d0ced13bb5785010800000000000a636b706f6f6c122f4e696e6a61506f6f6c2f5345475749542fffffffff02b4e5a212000000001976a914876fbb82ec05caa6af7a3b5e5a983aae6c6cc6d688ac0000000000000000266a24aa21a9edf91c46b49eb8a29089980f02ee6b57e7d63d33b18b4fddac2bcd7db2a39837040120000000000000000000000000000000000000000000000000000000000000000000000000");
let tx: Transaction = deserialize(&tx_bytes).unwrap(); let tx: Transaction = deserialize(&tx_bytes).unwrap();
assert_eq!(tx.weight(), Weight::from_wu(780)); assert_eq!(tx.check_weight(), Weight::from_wu(780));
serde_round_trip!(tx); serde_round_trip!(tx);
let consensus_encoded = serialize(&tx); let consensus_encoded = serialize(&tx);
@ -1610,7 +1728,7 @@ mod tests {
input: vec![], input: vec![],
output: vec![], output: vec![],
} }
.weight(); .check_weight();
for (is_segwit, tx) in &txs { for (is_segwit, tx) in &txs {
let txin_weight = if *is_segwit { let txin_weight = if *is_segwit {
@ -1626,7 +1744,7 @@ mod tests {
+ segwit_marker_weight + segwit_marker_weight
+ tx.input.iter().fold(0, |sum, i| sum + txin_weight(i)) + tx.input.iter().fold(0, |sum, i| sum + txin_weight(i))
+ tx.output.iter().fold(0, |sum, o| sum + o.weight()); + tx.output.iter().fold(0, |sum, o| sum + o.weight());
assert_eq!(calculated_size, tx.weight().to_wu() as usize); assert_eq!(calculated_size, tx.check_weight().to_wu() as usize);
} }
} }
} }