diff --git a/src/ecdh.rs b/src/ecdh.rs index 3f97990..4b7b95a 100644 --- a/src/ecdh.rs +++ b/src/ecdh.rs @@ -22,6 +22,7 @@ use core::ops::{FnMut, Deref}; use key::{SecretKey, PublicKey}; use ffi::{self, CPtr}; use secp256k1_sys::types::{c_int, c_uchar, c_void}; +use Error; /// A tag used for recovering the public key from a compact signature #[derive(Copy, Clone)] @@ -63,6 +64,7 @@ impl SharedSecret { /// Set the length of the object. pub(crate) fn set_len(&mut self, len: usize) { + debug_assert!(len <= self.data.len()); self.len = len; } } @@ -87,19 +89,29 @@ impl Deref for SharedSecret { } +#[cfg(feature = "std")] unsafe extern "C" fn hash_callback(output: *mut c_uchar, x: *const c_uchar, y: *const c_uchar, data: *mut c_void) -> c_int where F: FnMut([u8; 32], [u8; 32]) -> SharedSecret { - let callback: &mut F = &mut *(data as *mut F); - let mut x_arr = [0; 32]; - let mut y_arr = [0; 32]; - ptr::copy_nonoverlapping(x, x_arr.as_mut_ptr(), 32); - ptr::copy_nonoverlapping(y, y_arr.as_mut_ptr(), 32); + use std::panic::catch_unwind; + let res = catch_unwind(|| { + let callback: &mut F = &mut *(data as *mut F); - let secret = callback(x_arr, y_arr); - ptr::copy_nonoverlapping(secret.as_ptr(), output as *mut u8, secret.len()); + let mut x_arr = [0; 32]; + let mut y_arr = [0; 32]; + ptr::copy_nonoverlapping(x, x_arr.as_mut_ptr(), 32); + ptr::copy_nonoverlapping(y, y_arr.as_mut_ptr(), 32); - secret.len() as c_int + let secret = callback(x_arr, y_arr); + ptr::copy_nonoverlapping(secret.as_ptr(), output as *mut u8, secret.len()); + + secret.len() as c_int + }); + if let Ok(len) = res { + len + } else { + -1 + } } @@ -140,7 +152,8 @@ impl SharedSecret { /// }); /// /// ``` - pub fn new_with_hash(point: &PublicKey, scalar: &SecretKey, mut hash_function: F) -> SharedSecret + #[cfg(feature = "std")] + pub fn new_with_hash(point: &PublicKey, scalar: &SecretKey, mut hash_function: F) -> Result where F: FnMut([u8; 32], [u8; 32]) -> SharedSecret { let mut ss = SharedSecret::empty(); @@ -156,9 +169,12 @@ impl SharedSecret { &mut hash_function as *mut F as *mut c_void, ) }; + if res == -1 { + return Err(Error::CallbackPanicked); + } debug_assert!(res >= 16); // 128 bit is the minimum for a secure hash function and the minimum we let users. ss.set_len(res as usize); - ss + Ok(ss) } } @@ -167,6 +183,7 @@ mod tests { use rand::thread_rng; use super::SharedSecret; use super::super::Secp256k1; + use Error; #[test] fn ecdh() { @@ -187,9 +204,9 @@ mod tests { let (sk1, pk1) = s.generate_keypair(&mut thread_rng()); let (sk2, pk2) = s.generate_keypair(&mut thread_rng()); - let sec1 = SharedSecret::new_with_hash(&pk1, &sk2, |x,_| x.into()); - let sec2 = SharedSecret::new_with_hash(&pk2, &sk1, |x,_| x.into()); - let sec_odd = SharedSecret::new_with_hash(&pk1, &sk1, |x,_| x.into()); + let sec1 = SharedSecret::new_with_hash(&pk1, &sk2, |x,_| x.into()).unwrap(); + let sec2 = SharedSecret::new_with_hash(&pk2, &sk1, |x,_| x.into()).unwrap(); + let sec_odd = SharedSecret::new_with_hash(&pk1, &sk1, |x,_| x.into()).unwrap(); assert_eq!(sec1, sec2); assert_ne!(sec_odd, sec2); } @@ -205,11 +222,23 @@ mod tests { x_out = x; y_out = y; expect_result.into() - }); + }).unwrap(); assert_eq!(&expect_result[..], &result[..]); assert_ne!(x_out, [0u8; 32]); assert_ne!(y_out, [0u8; 32]); } + + #[test] + fn ecdh_with_hash_callback_panic() { + let s = Secp256k1::signing_only(); + let (sk1, pk1) = s.generate_keypair(&mut thread_rng()); + let mut res = [0u8; 48]; + let result = SharedSecret::new_with_hash(&pk1, &sk1, | x, _ | { + res.copy_from_slice(&x); // res.len() != x.len(). this will panic. + res.into() + }); + assert_eq!(result, Err(Error::CallbackPanicked)); + } } #[cfg(all(test, feature = "unstable"))] diff --git a/src/lib.rs b/src/lib.rs index 74254f2..8b61372 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -496,7 +496,8 @@ pub enum Error { InvalidTweak, /// Didn't pass enough memory to context creation with preallocated memory NotEnoughMemory, - + /// The callback has panicked. + CallbackPanicked, } impl Error { @@ -510,6 +511,7 @@ impl Error { Error::InvalidRecoveryId => "secp: bad recovery id", Error::InvalidTweak => "secp: bad tweak", Error::NotEnoughMemory => "secp: not enough memory allocated", + Error::CallbackPanicked => "secp: a callback passed has panicked", } } }