keyforkd: allow performing multiple requests on the same socket

This commit is contained in:
Ryan Heywood 2024-02-12 02:36:54 -05:00
parent a24a0166cc
commit f1c24fb33e
Signed by: ryan
GPG Key ID: 8E401478A3FBEF72
2 changed files with 45 additions and 47 deletions

View File

@ -81,9 +81,6 @@ fn ed25519_test_suite() {
continue; continue;
} }
for i in 2..chain_len { for i in 2..chain_len {
// FIXME: Keyfork will only allow one request per session
let socket = UnixStream::connect(&socket_path).unwrap();
let mut client = Client::new(socket);
// Consistency check: ensure the server and the client can each derive the same // Consistency check: ensure the server and the client can each derive the same
// key using an XPrv, for all but the last XPrv, which is verified after this // key using an XPrv, for all but the last XPrv, which is verified after this
let path = DerivationPath::from_str(test.chain).unwrap(); let path = DerivationPath::from_str(test.chain).unwrap();
@ -95,8 +92,6 @@ fn ed25519_test_suite() {
.fold(DerivationPath::default(), |p, i| p.chain_push(i.clone())); .fold(DerivationPath::default(), |p, i| p.chain_push(i.clone()));
let xprv = dbg!(client.request_xprv::<SigningKey>(&left_path)).unwrap(); let xprv = dbg!(client.request_xprv::<SigningKey>(&left_path)).unwrap();
let derived_xprv = xprv.derive_path(&right_path).unwrap(); let derived_xprv = xprv.derive_path(&right_path).unwrap();
let socket = UnixStream::connect(&socket_path).unwrap();
let mut client = Client::new(socket);
let keyforkd_xprv = client.request_xprv::<SigningKey>(&path).unwrap(); let keyforkd_xprv = client.request_xprv::<SigningKey>(&path).unwrap();
assert_eq!( assert_eq!(
derived_xprv, keyforkd_xprv, derived_xprv, keyforkd_xprv,

View File

@ -68,55 +68,58 @@ impl UnixServer {
#[cfg(feature = "tracing")] #[cfg(feature = "tracing")]
debug!("new socket connected"); debug!("new socket connected");
tokio::spawn(async move { tokio::spawn(async move {
let bytes = match try_decode_from(&mut socket).await { // Process requests until an error occurs or a client disconnects
Ok(bytes) => bytes, loop {
Err(e) => { let bytes = match try_decode_from(&mut socket).await {
#[cfg(feature = "tracing")] Ok(bytes) => bytes,
debug!(%e, "Error reading DerivationPath from socket"); Err(e) => {
let content = e.to_string().bytes().collect::<Vec<_>>(); #[cfg(feature = "tracing")]
let result = try_encode_to(&content[..], &mut socket).await; debug!(%e, "Error reading DerivationPath from socket");
#[cfg(feature = "tracing")] let content = e.to_string().bytes().collect::<Vec<_>>();
if let Err(error) = result { let result = try_encode_to(&content[..], &mut socket).await;
debug!(%error, "Error sending error to client"); #[cfg(feature = "tracing")]
if let Err(error) = result {
debug!(%error, "Error sending error to client");
}
return;
} }
return; };
}
};
let app = match app.ready().await { let app = match app.ready().await {
Ok(app) => app, Ok(app) => app,
Err(e) => { Err(e) => {
#[cfg(feature = "tracing")] #[cfg(feature = "tracing")]
debug!(%e, "Could not poll ready"); debug!(%e, "Could not poll ready");
let content = e.to_string().bytes().collect::<Vec<_>>(); let content = e.to_string().bytes().collect::<Vec<_>>();
let result = try_encode_to(&content[..], &mut socket).await; let result = try_encode_to(&content[..], &mut socket).await;
#[cfg(feature = "tracing")] #[cfg(feature = "tracing")]
if let Err(error) = result { if let Err(error) = result {
debug!(%error, "Error sending error to client"); debug!(%error, "Error sending error to client");
}
return;
} }
return; };
}
};
let response = match app.call(bytes.into()).await { let response = match app.call(bytes.into()).await {
Ok(response) => response, Ok(response) => response,
Err(e) => { Err(e) => {
#[cfg(feature = "tracing")] #[cfg(feature = "tracing")]
debug!(%e, "Error reading DerivationPath from socket"); debug!(%e, "Error reading DerivationPath from socket");
let content = e.to_string().bytes().collect::<Vec<_>>(); let content = e.to_string().bytes().collect::<Vec<_>>();
let result = try_encode_to(&content[..], &mut socket).await; let result = try_encode_to(&content[..], &mut socket).await;
#[cfg(feature = "tracing")] #[cfg(feature = "tracing")]
if let Err(error) = result { if let Err(error) = result {
debug!(%error, "Error sending error to client"); debug!(%error, "Error sending error to client");
}
return;
} }
return;
} }
} .into();
.into();
if let Err(e) = try_encode_to(&response[..], &mut socket).await { if let Err(e) = try_encode_to(&response[..], &mut socket).await {
#[cfg(feature = "tracing")] #[cfg(feature = "tracing")]
debug!(%e, "Error sending response to client"); debug!(%e, "Error sending response to client");
}
} }
}); });
} }