diff --git a/Cargo.lock b/Cargo.lock index f830002..66e9d0b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -244,7 +244,7 @@ dependencies = [ "futures-lite 2.2.0", "parking", "polling 3.3.2", - "rustix 0.38.30", + "rustix 0.38.31", "slab", "tracing", "windows-sys 0.52.0", @@ -1610,7 +1610,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0bad00257d07be169d870ab665980b06cdb366d792ad690bf2e76876dc503455" dependencies = [ "hermit-abi", - "rustix 0.38.30", + "rustix 0.38.31", "windows-sys 0.52.0", ] @@ -1870,6 +1870,7 @@ dependencies = [ "keyfork-slip10-test-data", "keyforkd-models", "serde", + "tempfile", "thiserror", "tokio", "tower", @@ -2533,7 +2534,7 @@ dependencies = [ "cfg-if", "concurrent-queue", "pin-project-lite", - "rustix 0.38.30", + "rustix 0.38.31", "tracing", "windows-sys 0.52.0", ] @@ -2806,9 +2807,9 @@ dependencies = [ [[package]] name = "rustix" -version = "0.38.30" +version = "0.38.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "322394588aaf33c24007e8bb3238ee3e4c5c09c084ab32bc73890b99ff326bca" +checksum = "6ea3e1a662af26cd7a3ba09c0297a31af215563ecf42817c98df621387f4e949" dependencies = [ "bitflags 2.4.2", "errno", @@ -3184,14 +3185,13 @@ dependencies = [ [[package]] name = "tempfile" -version = "3.9.0" +version = "3.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "01ce4141aa927a6d1bd34a041795abd0db1cccba5d5f24b009f694bdf3a1f3fa" +checksum = "a365e8cd18e44762ef95d87f284f4b5cd04107fec2ff3052bd6a3e6069669e67" dependencies = [ "cfg-if", "fastrand 2.0.1", - "redox_syscall", - "rustix 0.38.30", + "rustix 0.38.31", "windows-sys 0.52.0", ] @@ -3212,7 +3212,7 @@ version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "21bebf2b7c9e0a515f6e0f8c51dc0f8e4696391e6f1ff30379559f8365fb0df7" dependencies = [ - "rustix 0.38.30", + "rustix 0.38.31", "windows-sys 0.48.0", ] @@ -3606,7 +3606,7 @@ dependencies = [ "either", "home", "once_cell", - "rustix 0.38.30", + "rustix 0.38.31", ] [[package]] diff --git a/crates/daemon/keyforkd-client/src/tests.rs b/crates/daemon/keyforkd-client/src/tests.rs index d12f528..7dae50c 100644 --- a/crates/daemon/keyforkd-client/src/tests.rs +++ b/crates/daemon/keyforkd-client/src/tests.rs @@ -1,100 +1,65 @@ use crate::Client; +use keyforkd::test_util::{run_test, Infallible}; use keyfork_derive_util::{request::*, DerivationPath}; use keyfork_slip10_test_data::test_data; -use std::sync::mpsc::channel; use std::{os::unix::net::UnixStream, str::FromStr}; -use tokio::runtime::Builder; #[test] -fn secp256k1() { +fn secp256k1_test_suite() { let tests = test_data() .unwrap() .remove(&"secp256k1".to_string()) .unwrap(); - // note: since client is non async, can't be single threaded - let rt = Builder::new_multi_thread().enable_io().build().unwrap(); - let tempdir = tempfile::tempdir().unwrap(); - for (i, per_seed) in tests.into_iter().enumerate() { - let mut socket_name = i.to_string(); - socket_name.push_str("-keyforkd.sock"); - let socket_path = tempdir.path().join(socket_name); - let (tx, rx) = channel(); - let handle = rt.spawn({ - let socket_path = socket_path.clone(); - async move { - let seed = per_seed.seed.clone(); - let mut server = keyforkd::UnixServer::bind(&socket_path).unwrap(); - tx.send(()).unwrap(); - let service = keyforkd::ServiceBuilder::new() - .layer(keyforkd::middleware::BincodeLayer::new()) - .service(keyforkd::Keyforkd::new(seed)); - server.run(service).await.unwrap(); + for seed_test in tests { + let seed = seed_test.seed; + run_test(&seed, move |socket_path| { + for test in seed_test.tests { + let socket = UnixStream::connect(&socket_path).unwrap(); + let mut client = Client::new(socket); + let chain = DerivationPath::from_str(test.chain).unwrap(); + if chain.len() < 2 { + continue; + } + let req = DerivationRequest::new( + DerivationAlgorithm::Secp256k1, + &DerivationPath::from_str(test.chain).unwrap(), + ); + let response = + DerivationResponse::try_from(client.request(&req.into()).unwrap()).unwrap(); + assert_eq!(&response.data, test.private_key.as_slice()); } - }); - rx.recv().unwrap(); - - for test in &per_seed.tests { - let socket = UnixStream::connect(&socket_path).unwrap(); - let mut client = Client::new(socket); - let chain = DerivationPath::from_str(test.chain).unwrap(); - if chain.len() < 2 { - continue; - } - let req = DerivationRequest::new( - DerivationAlgorithm::Secp256k1, - &DerivationPath::from_str(test.chain).unwrap(), - ); - let response = - DerivationResponse::try_from(client.request(&req.into()).unwrap()).unwrap(); - assert_eq!(&response.data, test.private_key.as_slice()); - } - - handle.abort(); + Infallible::Ok(()) + }).unwrap(); } } #[test] -fn ed25519() { - let tests = test_data().unwrap().remove(&"ed25519".to_string()).unwrap(); +fn ed25519_test_suite() { + let tests = test_data() + .unwrap() + .remove(&"ed25519".to_string()) + .unwrap(); - let rt = Builder::new_multi_thread().enable_io().build().unwrap(); - let tempdir = tempfile::tempdir().unwrap(); - for (i, per_seed) in tests.into_iter().enumerate() { - let mut socket_name = i.to_string(); - socket_name.push_str("-keyforkd.sock"); - let socket_path = tempdir.path().join(socket_name); - let (tx, rx) = channel(); - let handle = rt.spawn({ - let socket_path = socket_path.clone(); - async move { - let seed = per_seed.seed.clone(); - let mut server = keyforkd::UnixServer::bind(&socket_path).unwrap(); - tx.send(()).unwrap(); - let service = keyforkd::ServiceBuilder::new() - .layer(keyforkd::middleware::BincodeLayer::new()) - .service(keyforkd::Keyforkd::new(seed)); - server.run(service).await.unwrap(); + for seed_test in tests { + let seed = seed_test.seed; + run_test(&seed, move |socket_path| { + for test in seed_test.tests { + let socket = UnixStream::connect(&socket_path).unwrap(); + let mut client = Client::new(socket); + let chain = DerivationPath::from_str(test.chain).unwrap(); + if chain.len() < 2 { + continue; + } + let req = DerivationRequest::new( + DerivationAlgorithm::Ed25519, + &DerivationPath::from_str(test.chain).unwrap(), + ); + let response = + DerivationResponse::try_from(client.request(&req.into()).unwrap()).unwrap(); + assert_eq!(&response.data, test.private_key.as_slice()); } - }); - rx.recv().unwrap(); - - for test in &per_seed.tests { - let socket = UnixStream::connect(&socket_path).unwrap(); - let mut client = Client::new(socket); - let chain = DerivationPath::from_str(test.chain).unwrap(); - if chain.len() < 2 { - continue; - } - let req = DerivationRequest::new( - DerivationAlgorithm::Ed25519, - &DerivationPath::from_str(test.chain).unwrap(), - ); - let response = - DerivationResponse::try_from(client.request(&req.into()).unwrap()).unwrap(); - assert_eq!(&response.data, test.private_key.as_slice()); - } - - handle.abort(); + Infallible::Ok(()) + }).unwrap(); } } diff --git a/crates/daemon/keyforkd/Cargo.toml b/crates/daemon/keyforkd/Cargo.toml index 158f810..7813fd0 100644 --- a/crates/daemon/keyforkd/Cargo.toml +++ b/crates/daemon/keyforkd/Cargo.toml @@ -31,6 +31,7 @@ tower = { version = "0.4.13", features = ["tokio", "util"] } # Personally audited thiserror = "1.0.47" serde = { version = "1.0.186", features = ["derive"] } +tempfile = { version = "3.10.0", default-features = false } [dev-dependencies] hex-literal = "0.4.1" diff --git a/crates/daemon/keyforkd/src/lib.rs b/crates/daemon/keyforkd/src/lib.rs index 4ea3a45..2fa0b20 100644 --- a/crates/daemon/keyforkd/src/lib.rs +++ b/crates/daemon/keyforkd/src/lib.rs @@ -30,6 +30,8 @@ pub use error::Keyforkd as KeyforkdError; pub use server::UnixServer; pub use service::Keyforkd; +pub mod test_util; + /// Set up a Tracing subscriber, defaulting to debug mode. #[cfg(feature = "tracing")] pub fn setup_registry() { diff --git a/crates/daemon/keyforkd/src/test_util.rs b/crates/daemon/keyforkd/src/test_util.rs new file mode 100644 index 0000000..5e7fdb7 --- /dev/null +++ b/crates/daemon/keyforkd/src/test_util.rs @@ -0,0 +1,85 @@ +//! # Keyforkd Test Utilities +//! +//! This module adds a helper to set up a Tokio runtime, start a Tokio runtime with a given seed, +//! start a Keyfork server on that runtime, and run a given test closure. + +use crate::{middleware, Keyforkd, ServiceBuilder, UnixServer}; + +use tokio::runtime::Builder; + +#[derive(Debug, thiserror::Error)] +#[error("This error can never be instantiated")] +#[doc(hidden)] +pub struct InfallibleError { + protected: (), +} + +/// An infallible result. This type can be used to represent a function that should never error. +/// +/// ```rust +/// use keyforkd::test_util::Infallible; +/// let closure = || { +/// Infallible::Ok(()) +/// }; +/// assert!(closure().is_ok()); +/// ``` +pub type Infallible = std::result::Result; + +/// Run a test making use of a Keyforkd server. The path to the socket of the Keyforkd server is +/// provided as the only argument to the closure. The closure is expected to return a Result; the +/// Error field of the Result may be an error returned by a test. +/// +/// # Panics +/// +/// The function is not expected to run in production; therefore, the function plays "fast and +/// loose" wih the usage of [`Result::expect`]. In normal usage, these should never be an issue. +#[allow(clippy::missing_errors_doc)] +pub fn run_test(seed: &[u8], closure: F) -> Result<(), E> +where + F: FnOnce(&std::path::Path) -> Result<(), E> + Send + 'static, + E: Send + 'static, +{ + let rt = Builder::new_multi_thread() + .worker_threads(2) + .enable_io() + .build() + .expect("tokio threaded IO runtime"); + let socket_dir = tempfile::tempdir().expect("can't create tempdir"); + let socket_path = socket_dir.path().join("keyforkd.sock"); + rt.block_on(async move { + let (tx, mut rx) = tokio::sync::mpsc::channel(1); + let server_handle = tokio::spawn({ + let socket_path = socket_path.clone(); + let seed = seed.to_vec(); + async move { + let mut server = UnixServer::bind(&socket_path).expect("can't bind unix socket"); + tx.send(()).await.expect("couldn't send server start signal"); + let service = ServiceBuilder::new() + .layer(middleware::BincodeLayer::new()) + .service(Keyforkd::new(seed.to_vec())); + server.run(service).await.unwrap(); + } + }); + + rx.recv() + .await + .expect("can't receive server start signal from channel"); + let test_handle = tokio::task::spawn_blocking(move || closure(&socket_path)); + + let result = test_handle.await; + server_handle.abort(); + result + }) + .expect("runtime could not join all threads") +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_run_test() { + let seed = b"beefbeef"; + run_test(seed, |_path| Infallible::Ok(())).expect("infallible"); + } +}