diff --git a/src/ecdh.rs b/src/ecdh.rs index 4b7b95a..166907b 100644 --- a/src/ecdh.rs +++ b/src/ecdh.rs @@ -89,24 +89,26 @@ impl Deref for SharedSecret { } +unsafe fn callback_logic(output: *mut c_uchar, x: *const c_uchar, y: *const c_uchar, data: *mut c_void) -> c_int + where F: FnMut([u8; 32], [u8; 32]) -> SharedSecret { + let callback: &mut F = &mut *(data as *mut F); + + let mut x_arr = [0; 32]; + let mut y_arr = [0; 32]; + ptr::copy_nonoverlapping(x, x_arr.as_mut_ptr(), 32); + ptr::copy_nonoverlapping(y, y_arr.as_mut_ptr(), 32); + + let secret = callback(x_arr, y_arr); + ptr::copy_nonoverlapping(secret.as_ptr(), output as *mut u8, secret.len()); + + secret.len() as c_int +} + #[cfg(feature = "std")] -unsafe extern "C" fn hash_callback(output: *mut c_uchar, x: *const c_uchar, y: *const c_uchar, data: *mut c_void) -> c_int +unsafe extern "C" fn hash_callback_catch_unwind(output: *mut c_uchar, x: *const c_uchar, y: *const c_uchar, data: *mut c_void) -> c_int where F: FnMut([u8; 32], [u8; 32]) -> SharedSecret { - use std::panic::catch_unwind; - let res = catch_unwind(|| { - let callback: &mut F = &mut *(data as *mut F); - - let mut x_arr = [0; 32]; - let mut y_arr = [0; 32]; - ptr::copy_nonoverlapping(x, x_arr.as_mut_ptr(), 32); - ptr::copy_nonoverlapping(y, y_arr.as_mut_ptr(), 32); - - let secret = callback(x_arr, y_arr); - ptr::copy_nonoverlapping(secret.as_ptr(), output as *mut u8, secret.len()); - - secret.len() as c_int - }); + let res = ::std::panic::catch_unwind(||callback_logic::(output, x, y, data)); if let Ok(len) = res { len } else { @@ -114,6 +116,11 @@ unsafe extern "C" fn hash_callback(output: *mut c_uchar, x: *const c_uchar, y } } +unsafe extern "C" fn hash_callback_unsafe(output: *mut c_uchar, x: *const c_uchar, y: *const c_uchar, data: *mut c_void) -> c_int + where F: FnMut([u8; 32], [u8; 32]) -> SharedSecret { + callback_logic::(output, x, y, data) +} + impl SharedSecret { /// Creates a new shared secret from a pubkey and secret key @@ -135,6 +142,29 @@ impl SharedSecret { ss } + fn new_with_callback_internal(point: &PublicKey, scalar: &SecretKey, mut closure: F, callback: ffi::EcdhHashFn) -> Result + where F: FnMut([u8; 32], [u8; 32]) -> SharedSecret { + let mut ss = SharedSecret::empty(); + + let res = unsafe { + ffi::secp256k1_ecdh( + ffi::secp256k1_context_no_precomp, + ss.get_data_mut_ptr(), + point.as_ptr(), + scalar.as_ptr(), + callback, + &mut closure as *mut F as *mut c_void, + ) + }; + if res == -1 { + return Err(Error::CallbackPanicked); + } + debug_assert!(res >= 16); // 128 bit is the minimum for a secure hash function and the minimum we let users. + ss.set_len(res as usize); + Ok(ss) + + } + /// Creates a new shared secret from a pubkey and secret key with applied custom hash function /// # Examples /// ``` @@ -153,28 +183,42 @@ impl SharedSecret { /// /// ``` #[cfg(feature = "std")] - pub fn new_with_hash(point: &PublicKey, scalar: &SecretKey, mut hash_function: F) -> Result - where F: FnMut([u8; 32], [u8; 32]) -> SharedSecret - { - let mut ss = SharedSecret::empty(); - let hashfp: ffi::EcdhHashFn = hash_callback::; + pub fn new_with_hash(point: &PublicKey, scalar: &SecretKey, hash_function: F) -> Result + where F: FnMut([u8; 32], [u8; 32]) -> SharedSecret { + Self::new_with_callback_internal(point, scalar, hash_function, hash_callback_catch_unwind::) + } - let res = unsafe { - ffi::secp256k1_ecdh( - ffi::secp256k1_context_no_precomp, - ss.get_data_mut_ptr(), - point.as_ptr(), - scalar.as_ptr(), - hashfp, - &mut hash_function as *mut F as *mut c_void, - ) - }; - if res == -1 { - return Err(Error::CallbackPanicked); - } - debug_assert!(res >= 16); // 128 bit is the minimum for a secure hash function and the minimum we let users. - ss.set_len(res as usize); - Ok(ss) + /// Creates a new shared secret from a pubkey and secret key with applied custom hash function + /// Note that this function is the same as [`new_with_hash`] + /// + /// # Safety + /// The function doesn't wrap the callback with [`catch_unwind`] + /// so if the callback panics it will panic through an FFI boundray which is [`Undefined Behavior`] + /// If possible you should use [`new_with_hash`] which does wrap the callback with [`catch_unwind`] so is safe to use. + /// + /// [`catch_unwind`]: https://doc.rust-lang.org/std/panic/fn.catch_unwind.html + /// [`Undefined Behavior`]: https://doc.rust-lang.org/nomicon/ffi.html#ffi-and-panics + /// [`new_with_hash`]: #method.new_with_hash + /// # Examples + /// ``` + /// # use secp256k1::ecdh::SharedSecret; + /// # use secp256k1::{Secp256k1, PublicKey, SecretKey}; + /// # fn sha2(_a: &[u8], _b: &[u8]) -> [u8; 32] {[0u8; 32]} + /// # let secp = Secp256k1::signing_only(); + /// # let secret_key = SecretKey::from_slice(&[3u8; 32]).unwrap(); + /// # let secret_key2 = SecretKey::from_slice(&[7u8; 32]).unwrap(); + /// # let public_key = PublicKey::from_secret_key(&secp, &secret_key2); + // + /// let secret = unsafe { SharedSecret::new_with_hash_no_panic(&public_key, &secret_key, |x,y| { + /// let hash: [u8; 32] = sha2(&x,&y); + /// hash.into() + /// })}; + /// + /// + /// ``` + pub unsafe fn new_with_hash_no_panic(point: &PublicKey, scalar: &SecretKey, hash_function: F) -> Result + where F: FnMut([u8; 32], [u8; 32]) -> SharedSecret { + Self::new_with_callback_internal(point, scalar, hash_function, hash_callback_unsafe::) } } @@ -223,7 +267,13 @@ mod tests { y_out = y; expect_result.into() }).unwrap(); + let result_unsafe = unsafe {SharedSecret::new_with_hash_no_panic(&pk1, &sk1, | x, y | { + x_out = x; + y_out = y; + expect_result.into() + }).unwrap()}; assert_eq!(&expect_result[..], &result[..]); + assert_eq!(result, result_unsafe); assert_ne!(x_out, [0u8; 32]); assert_ne!(y_out, [0u8; 32]); }