Add an unsafe variant of new_with_has called new_with_hash_no_panic

This commit is contained in:
Elichai Turkel 2019-11-28 00:21:51 +02:00
parent 124c1f3c7c
commit 5619f2a5df
No known key found for this signature in database
GPG Key ID: 9383CDE9E8E66A7F
1 changed files with 86 additions and 36 deletions

View File

@ -89,24 +89,26 @@ impl Deref for SharedSecret {
} }
unsafe fn callback_logic<F>(output: *mut c_uchar, x: *const c_uchar, y: *const c_uchar, data: *mut c_void) -> c_int
where F: FnMut([u8; 32], [u8; 32]) -> SharedSecret {
let callback: &mut F = &mut *(data as *mut F);
let mut x_arr = [0; 32];
let mut y_arr = [0; 32];
ptr::copy_nonoverlapping(x, x_arr.as_mut_ptr(), 32);
ptr::copy_nonoverlapping(y, y_arr.as_mut_ptr(), 32);
let secret = callback(x_arr, y_arr);
ptr::copy_nonoverlapping(secret.as_ptr(), output as *mut u8, secret.len());
secret.len() as c_int
}
#[cfg(feature = "std")] #[cfg(feature = "std")]
unsafe extern "C" fn hash_callback<F>(output: *mut c_uchar, x: *const c_uchar, y: *const c_uchar, data: *mut c_void) -> c_int unsafe extern "C" fn hash_callback_catch_unwind<F>(output: *mut c_uchar, x: *const c_uchar, y: *const c_uchar, data: *mut c_void) -> c_int
where F: FnMut([u8; 32], [u8; 32]) -> SharedSecret { where F: FnMut([u8; 32], [u8; 32]) -> SharedSecret {
use std::panic::catch_unwind; let res = ::std::panic::catch_unwind(||callback_logic::<F>(output, x, y, data));
let res = catch_unwind(|| {
let callback: &mut F = &mut *(data as *mut F);
let mut x_arr = [0; 32];
let mut y_arr = [0; 32];
ptr::copy_nonoverlapping(x, x_arr.as_mut_ptr(), 32);
ptr::copy_nonoverlapping(y, y_arr.as_mut_ptr(), 32);
let secret = callback(x_arr, y_arr);
ptr::copy_nonoverlapping(secret.as_ptr(), output as *mut u8, secret.len());
secret.len() as c_int
});
if let Ok(len) = res { if let Ok(len) = res {
len len
} else { } else {
@ -114,6 +116,11 @@ unsafe extern "C" fn hash_callback<F>(output: *mut c_uchar, x: *const c_uchar, y
} }
} }
unsafe extern "C" fn hash_callback_unsafe<F>(output: *mut c_uchar, x: *const c_uchar, y: *const c_uchar, data: *mut c_void) -> c_int
where F: FnMut([u8; 32], [u8; 32]) -> SharedSecret {
callback_logic::<F>(output, x, y, data)
}
impl SharedSecret { impl SharedSecret {
/// Creates a new shared secret from a pubkey and secret key /// Creates a new shared secret from a pubkey and secret key
@ -135,6 +142,29 @@ impl SharedSecret {
ss ss
} }
fn new_with_callback_internal<F>(point: &PublicKey, scalar: &SecretKey, mut closure: F, callback: ffi::EcdhHashFn) -> Result<SharedSecret, Error>
where F: FnMut([u8; 32], [u8; 32]) -> SharedSecret {
let mut ss = SharedSecret::empty();
let res = unsafe {
ffi::secp256k1_ecdh(
ffi::secp256k1_context_no_precomp,
ss.get_data_mut_ptr(),
point.as_ptr(),
scalar.as_ptr(),
callback,
&mut closure as *mut F as *mut c_void,
)
};
if res == -1 {
return Err(Error::CallbackPanicked);
}
debug_assert!(res >= 16); // 128 bit is the minimum for a secure hash function and the minimum we let users.
ss.set_len(res as usize);
Ok(ss)
}
/// Creates a new shared secret from a pubkey and secret key with applied custom hash function /// Creates a new shared secret from a pubkey and secret key with applied custom hash function
/// # Examples /// # Examples
/// ``` /// ```
@ -153,28 +183,42 @@ impl SharedSecret {
/// ///
/// ``` /// ```
#[cfg(feature = "std")] #[cfg(feature = "std")]
pub fn new_with_hash<F>(point: &PublicKey, scalar: &SecretKey, mut hash_function: F) -> Result<SharedSecret, Error> pub fn new_with_hash<F>(point: &PublicKey, scalar: &SecretKey, hash_function: F) -> Result<SharedSecret, Error>
where F: FnMut([u8; 32], [u8; 32]) -> SharedSecret where F: FnMut([u8; 32], [u8; 32]) -> SharedSecret {
{ Self::new_with_callback_internal(point, scalar, hash_function, hash_callback_catch_unwind::<F>)
let mut ss = SharedSecret::empty(); }
let hashfp: ffi::EcdhHashFn = hash_callback::<F>;
let res = unsafe { /// Creates a new shared secret from a pubkey and secret key with applied custom hash function
ffi::secp256k1_ecdh( /// Note that this function is the same as [`new_with_hash`]
ffi::secp256k1_context_no_precomp, ///
ss.get_data_mut_ptr(), /// # Safety
point.as_ptr(), /// The function doesn't wrap the callback with [`catch_unwind`]
scalar.as_ptr(), /// so if the callback panics it will panic through an FFI boundray which is [`Undefined Behavior`]
hashfp, /// If possible you should use [`new_with_hash`] which does wrap the callback with [`catch_unwind`] so is safe to use.
&mut hash_function as *mut F as *mut c_void, ///
) /// [`catch_unwind`]: https://doc.rust-lang.org/std/panic/fn.catch_unwind.html
}; /// [`Undefined Behavior`]: https://doc.rust-lang.org/nomicon/ffi.html#ffi-and-panics
if res == -1 { /// [`new_with_hash`]: #method.new_with_hash
return Err(Error::CallbackPanicked); /// # Examples
} /// ```
debug_assert!(res >= 16); // 128 bit is the minimum for a secure hash function and the minimum we let users. /// # use secp256k1::ecdh::SharedSecret;
ss.set_len(res as usize); /// # use secp256k1::{Secp256k1, PublicKey, SecretKey};
Ok(ss) /// # fn sha2(_a: &[u8], _b: &[u8]) -> [u8; 32] {[0u8; 32]}
/// # let secp = Secp256k1::signing_only();
/// # let secret_key = SecretKey::from_slice(&[3u8; 32]).unwrap();
/// # let secret_key2 = SecretKey::from_slice(&[7u8; 32]).unwrap();
/// # let public_key = PublicKey::from_secret_key(&secp, &secret_key2);
//
/// let secret = unsafe { SharedSecret::new_with_hash_no_panic(&public_key, &secret_key, |x,y| {
/// let hash: [u8; 32] = sha2(&x,&y);
/// hash.into()
/// })};
///
///
/// ```
pub unsafe fn new_with_hash_no_panic<F>(point: &PublicKey, scalar: &SecretKey, hash_function: F) -> Result<SharedSecret, Error>
where F: FnMut([u8; 32], [u8; 32]) -> SharedSecret {
Self::new_with_callback_internal(point, scalar, hash_function, hash_callback_unsafe::<F>)
} }
} }
@ -223,7 +267,13 @@ mod tests {
y_out = y; y_out = y;
expect_result.into() expect_result.into()
}).unwrap(); }).unwrap();
let result_unsafe = unsafe {SharedSecret::new_with_hash_no_panic(&pk1, &sk1, | x, y | {
x_out = x;
y_out = y;
expect_result.into()
}).unwrap()};
assert_eq!(&expect_result[..], &result[..]); assert_eq!(&expect_result[..], &result[..]);
assert_eq!(result, result_unsafe);
assert_ne!(x_out, [0u8; 32]); assert_ne!(x_out, [0u8; 32]);
assert_ne!(y_out, [0u8; 32]); assert_ne!(y_out, [0u8; 32]);
} }