diff --git a/crates/daemon/keyforkd-client/src/tests.rs b/crates/daemon/keyforkd-client/src/tests.rs index e551b0d..5557d79 100644 --- a/crates/daemon/keyforkd-client/src/tests.rs +++ b/crates/daemon/keyforkd-client/src/tests.rs @@ -81,9 +81,6 @@ fn ed25519_test_suite() { continue; } for i in 2..chain_len { - // FIXME: Keyfork will only allow one request per session - let socket = UnixStream::connect(&socket_path).unwrap(); - let mut client = Client::new(socket); // Consistency check: ensure the server and the client can each derive the same // key using an XPrv, for all but the last XPrv, which is verified after this let path = DerivationPath::from_str(test.chain).unwrap(); @@ -95,8 +92,6 @@ fn ed25519_test_suite() { .fold(DerivationPath::default(), |p, i| p.chain_push(i.clone())); let xprv = dbg!(client.request_xprv::(&left_path)).unwrap(); let derived_xprv = xprv.derive_path(&right_path).unwrap(); - let socket = UnixStream::connect(&socket_path).unwrap(); - let mut client = Client::new(socket); let keyforkd_xprv = client.request_xprv::(&path).unwrap(); assert_eq!( derived_xprv, keyforkd_xprv, diff --git a/crates/daemon/keyforkd/src/server.rs b/crates/daemon/keyforkd/src/server.rs index 42e6e5f..3bfe8ab 100644 --- a/crates/daemon/keyforkd/src/server.rs +++ b/crates/daemon/keyforkd/src/server.rs @@ -68,55 +68,58 @@ impl UnixServer { #[cfg(feature = "tracing")] debug!("new socket connected"); tokio::spawn(async move { - let bytes = match try_decode_from(&mut socket).await { - Ok(bytes) => bytes, - Err(e) => { - #[cfg(feature = "tracing")] - debug!(%e, "Error reading DerivationPath from socket"); - let content = e.to_string().bytes().collect::>(); - let result = try_encode_to(&content[..], &mut socket).await; - #[cfg(feature = "tracing")] - if let Err(error) = result { - debug!(%error, "Error sending error to client"); + // Process requests until an error occurs or a client disconnects + loop { + let bytes = match try_decode_from(&mut socket).await { + Ok(bytes) => bytes, + Err(e) => { + #[cfg(feature = "tracing")] + debug!(%e, "Error reading DerivationPath from socket"); + let content = e.to_string().bytes().collect::>(); + let result = try_encode_to(&content[..], &mut socket).await; + #[cfg(feature = "tracing")] + if let Err(error) = result { + debug!(%error, "Error sending error to client"); + } + return; } - return; - } - }; + }; - let app = match app.ready().await { - Ok(app) => app, - Err(e) => { - #[cfg(feature = "tracing")] - debug!(%e, "Could not poll ready"); - let content = e.to_string().bytes().collect::>(); - let result = try_encode_to(&content[..], &mut socket).await; - #[cfg(feature = "tracing")] - if let Err(error) = result { - debug!(%error, "Error sending error to client"); + let app = match app.ready().await { + Ok(app) => app, + Err(e) => { + #[cfg(feature = "tracing")] + debug!(%e, "Could not poll ready"); + let content = e.to_string().bytes().collect::>(); + let result = try_encode_to(&content[..], &mut socket).await; + #[cfg(feature = "tracing")] + if let Err(error) = result { + debug!(%error, "Error sending error to client"); + } + return; } - return; - } - }; + }; - let response = match app.call(bytes.into()).await { - Ok(response) => response, - Err(e) => { - #[cfg(feature = "tracing")] - debug!(%e, "Error reading DerivationPath from socket"); - let content = e.to_string().bytes().collect::>(); - let result = try_encode_to(&content[..], &mut socket).await; - #[cfg(feature = "tracing")] - if let Err(error) = result { - debug!(%error, "Error sending error to client"); + let response = match app.call(bytes.into()).await { + Ok(response) => response, + Err(e) => { + #[cfg(feature = "tracing")] + debug!(%e, "Error reading DerivationPath from socket"); + let content = e.to_string().bytes().collect::>(); + let result = try_encode_to(&content[..], &mut socket).await; + #[cfg(feature = "tracing")] + if let Err(error) = result { + debug!(%error, "Error sending error to client"); + } + return; } - return; } - } - .into(); + .into(); - if let Err(e) = try_encode_to(&response[..], &mut socket).await { - #[cfg(feature = "tracing")] - debug!(%e, "Error sending response to client"); + if let Err(e) = try_encode_to(&response[..], &mut socket).await { + #[cfg(feature = "tracing")] + debug!(%e, "Error sending response to client"); + } } }); }