diff --git a/bitcoin/src/blockdata/transaction.rs b/bitcoin/src/blockdata/transaction.rs index cc385066..8cb13064 100644 --- a/bitcoin/src/blockdata/transaction.rs +++ b/bitcoin/src/blockdata/transaction.rs @@ -850,6 +850,19 @@ impl Transaction { self.scaled_size(1) } + /// Computes the weight and checks that it matches the output of `predict_weight`. + #[cfg(test)] + fn check_weight(&self) -> Weight { + let weight1 = self.weight(); + let inputs = self.input + .iter() + .map(|txin| InputWeightPrediction::new(txin.script_sig.len(), txin.witness.iter().map(|elem| elem.len()))); + let outputs = self.output.iter().map(|txout| txout.script_pubkey.len()); + let weight2 = predict_weight(inputs, outputs); + assert_eq!(weight1, weight2); + weight1 + } + /// Returns the "virtual size" (vsize) of this transaction. /// /// Will be `ceil(weight / 4.0)`. Note this implements the virtual size as per [`BIP141`], which @@ -1148,6 +1161,111 @@ impl From<&Transaction> for Wtxid { } } +/// Predicts the weight of a to-be-constructed transaction. +/// +/// This function computes the weight of a transaction which is not fully known. All that is needed +/// is the lengths of scripts and witness elements. +/// +/// # Arguments +/// +/// * `inputs` - an iterator which returns `InputWeightPrediction` for each input of the +/// to-be-constructed transaction. +/// * `output_script_lens` - an iterator which returns the length of `script_pubkey` of each output +/// of the to-be-constructed transaction. +/// +/// Note that lengths of the scripts and witness elements must be non-serialized, IOW *without* the +/// preceding compact size. The lenght of preceding compact size is computed and added inside the +/// function for convenience. +/// +/// # Usage +/// +/// When signing a transaction one doesn't know the signature before knowing the transaction fee and +/// the transaction fee is not known before knowing the transaction size which is not known before +/// knowing the signature. This apparent dependency cycle can be broken by knowing the length of the +/// signature without knowing the contents of the signature e.g., we know all Schnorr signatures +/// are 64 bytes long. +/// +/// Additionally, some protocols may require calculating the amounts before knowing various parts +/// of the transaction (assuming their length is known). +/// +/// # Notes on integer overflow +/// +/// Overflows are intentionally not checked because one of the following holds: +/// +/// * The transaction is valid (obeys the block size limit) and the code feeds correct values to +/// this function - no overflow can happen. +/// * The transaction will be so large it doesn't fit in the memory - overflow will happen but +/// then the transaction will fail to construct and even if one serialized it on disk directly +/// it'd be invalid anyway so overflow doesn't matter. +/// * The values fed into this function are inconsistent with the actual lengths the transaction +/// will have - the code is already broken and checking overflows doesn't help. Unfortunately +/// this probably cannot be avoided. +pub fn predict_weight(inputs: I, output_script_lens: O) -> Weight + where I: IntoIterator, + O: IntoIterator, +{ + let (input_count, partial_input_weight, inputs_with_witnesses) = inputs.into_iter() + .fold((0, 0, 0), |(count, partial_input_weight, inputs_with_witnesses), prediction| { + (count + 1, partial_input_weight + prediction.script_size * 4 + prediction.witness_size, inputs_with_witnesses + (prediction.witness_size > 0) as usize) + }); + let (output_count, output_scripts_size) = output_script_lens.into_iter() + .fold((0, 0), |(output_count, total_scripts_size), script_len| { + let script_size = script_len + VarInt(script_len as u64).len(); + (output_count + 1, total_scripts_size + script_size) + }); + let input_weight = partial_input_weight + input_count * 4 * (32 + 4 + 4); + let output_size = 8 * output_count + output_scripts_size; + let non_input_size = + // version: + 4 + + // count varints: + VarInt(input_count as u64).len() + + VarInt(output_count as u64).len() + + output_size + + // lock_time + 4; + let weight = if inputs_with_witnesses == 0 { + non_input_size * 4 + input_weight + } else { + non_input_size * 4 + input_weight + input_count - inputs_with_witnesses + 2 + }; + Weight::from_wu(weight as u64) +} + +/// Weight prediction of an individual input. +/// +/// This helper type collects information about an input to be used in [`predict_weight`] function. +/// It can only be created using the [`new`](InputWeightPrediction::new) function. +#[derive(Copy, Clone, Debug)] +pub struct InputWeightPrediction { + script_size: usize, + witness_size: usize, +} + +impl InputWeightPrediction { + /// Computes the prediction for a single input. + pub fn new(input_script_len: usize, witness_element_lengths: I) -> Self + where I: IntoIterator, + { + let (count, total_size) = witness_element_lengths.into_iter() + .fold((0, 0), |(count, total_size), elem_len| { + let elem_size = elem_len + VarInt(elem_len as u64).len(); + (count + 1, total_size + elem_size) + }); + let witness_size = if count > 0 { + total_size + VarInt(count as u64).len() + } else { + 0 + }; + let script_size = input_script_len + VarInt(input_script_len as u64).len(); + + InputWeightPrediction { + script_size, + witness_size, + } + } +} + #[cfg(test)] mod tests { use super::*; @@ -1260,7 +1378,7 @@ mod tests { "a6eab3c14ab5272a58a5ba91505ba1a4b6d7a3a9fcbd187b6cd99a7b6d548cb7".to_string()); assert_eq!(format!("{:x}", realtx.wtxid()), "a6eab3c14ab5272a58a5ba91505ba1a4b6d7a3a9fcbd187b6cd99a7b6d548cb7".to_string()); - assert_eq!(realtx.weight().to_wu() as usize, tx_bytes.len()*WITNESS_SCALE_FACTOR); + assert_eq!(realtx.check_weight().to_wu() as usize, tx_bytes.len()*WITNESS_SCALE_FACTOR); assert_eq!(realtx.size(), tx_bytes.len()); assert_eq!(realtx.vsize(), tx_bytes.len()); assert_eq!(realtx.strippedsize(), tx_bytes.len()); @@ -1295,7 +1413,7 @@ mod tests { assert_eq!(format!("{:x}", realtx.wtxid()), "80b7d8a82d5d5bf92905b06f2014dd699e03837ca172e3a59d51426ebbe3e7f5".to_string()); const EXPECTED_WEIGHT: Weight = Weight::from_wu(442); - assert_eq!(realtx.weight(), EXPECTED_WEIGHT); + assert_eq!(realtx.check_weight(), EXPECTED_WEIGHT); assert_eq!(realtx.size(), tx_bytes.len()); assert_eq!(realtx.vsize(), 111); // Since @@ -1308,7 +1426,7 @@ mod tests { // Construct a transaction without the witness data. let mut tx_without_witness = realtx; tx_without_witness.input.iter_mut().for_each(|input| input.witness.clear()); - assert_eq!(tx_without_witness.weight().to_wu() as usize, expected_strippedsize*WITNESS_SCALE_FACTOR); + assert_eq!(tx_without_witness.check_weight().to_wu() as usize, expected_strippedsize*WITNESS_SCALE_FACTOR); assert_eq!(tx_without_witness.size(), expected_strippedsize); assert_eq!(tx_without_witness.vsize(), expected_strippedsize); assert_eq!(tx_without_witness.strippedsize(), expected_strippedsize); @@ -1413,7 +1531,7 @@ mod tests { assert_eq!(format!("{:x}", tx.wtxid()), "d6ac4a5e61657c4c604dcde855a1db74ec6b3e54f32695d72c5e11c7761ea1b4"); assert_eq!(format!("{:x}", tx.txid()), "9652aa62b0e748caeec40c4cb7bc17c6792435cc3dfe447dd1ca24f912a1c6ec"); - assert_eq!(tx.weight(), Weight::from_wu(2718)); + assert_eq!(tx.check_weight(), Weight::from_wu(2718)); // non-segwit tx from my mempool let tx_bytes = hex!( @@ -1445,7 +1563,7 @@ mod tests { fn test_segwit_tx_decode() { let tx_bytes = hex!("010000000001010000000000000000000000000000000000000000000000000000000000000000ffffffff3603da1b0e00045503bd5704c7dd8a0d0ced13bb5785010800000000000a636b706f6f6c122f4e696e6a61506f6f6c2f5345475749542fffffffff02b4e5a212000000001976a914876fbb82ec05caa6af7a3b5e5a983aae6c6cc6d688ac0000000000000000266a24aa21a9edf91c46b49eb8a29089980f02ee6b57e7d63d33b18b4fddac2bcd7db2a39837040120000000000000000000000000000000000000000000000000000000000000000000000000"); let tx: Transaction = deserialize(&tx_bytes).unwrap(); - assert_eq!(tx.weight(), Weight::from_wu(780)); + assert_eq!(tx.check_weight(), Weight::from_wu(780)); serde_round_trip!(tx); let consensus_encoded = serialize(&tx); @@ -1610,7 +1728,7 @@ mod tests { input: vec![], output: vec![], } - .weight(); + .check_weight(); for (is_segwit, tx) in &txs { let txin_weight = if *is_segwit { @@ -1626,7 +1744,7 @@ mod tests { + segwit_marker_weight + tx.input.iter().fold(0, |sum, i| sum + txin_weight(i)) + tx.output.iter().fold(0, |sum, o| sum + o.weight()); - assert_eq!(calculated_size, tx.weight().to_wu() as usize); + assert_eq!(calculated_size, tx.check_weight().to_wu() as usize); } } }