keyforkd: appropriately handle or debug disconnects

This commit is contained in:
Ryan Heywood 2024-02-12 03:08:54 -05:00
parent f1c24fb33e
commit e441ef520f
Signed by: ryan
GPG Key ID: 8E401478A3FBEF72
1 changed files with 60 additions and 2 deletions

View File

@ -1,6 +1,9 @@
//! A UNIX socket server to run a Tower Service. //! A UNIX socket server to run a Tower Service.
use keyfork_frame::asyncext::{try_decode_from, try_encode_to}; use keyfork_frame::{
asyncext::{try_decode_from, try_encode_to},
DecodeError, EncodeError,
};
use std::{ use std::{
io::Error, io::Error,
path::{Path, PathBuf}, path::{Path, PathBuf},
@ -17,6 +20,34 @@ pub struct UnixServer {
listener: UnixListener, listener: UnixListener,
} }
/// This feels like a hack, but this is a convenient way to use the same method to quickly verify
/// something across two different error types.
trait IsDisconnect {
fn is_disconnect(&self) -> bool;
}
impl IsDisconnect for DecodeError {
fn is_disconnect(&self) -> bool {
if let Self::Io(e) = self {
if let std::io::ErrorKind::UnexpectedEof = e.kind() {
return true;
}
}
false
}
}
impl IsDisconnect for EncodeError {
fn is_disconnect(&self) -> bool {
if let Self::Io(e) = self {
if let std::io::ErrorKind::UnexpectedEof = e.kind() {
return true;
}
}
false
}
}
impl UnixServer { impl UnixServer {
/// Bind a socket to the given `address` and create a [`UnixServer`]. This function also creates a ctrl_c handler to automatically clean up the socket file. /// Bind a socket to the given `address` and create a [`UnixServer`]. This function also creates a ctrl_c handler to automatically clean up the socket file.
/// ///
@ -68,11 +99,19 @@ impl UnixServer {
#[cfg(feature = "tracing")] #[cfg(feature = "tracing")]
debug!("new socket connected"); debug!("new socket connected");
tokio::spawn(async move { tokio::spawn(async move {
let mut has_processed_request = false;
// Process requests until an error occurs or a client disconnects // Process requests until an error occurs or a client disconnects
loop { loop {
let bytes = match try_decode_from(&mut socket).await { let bytes = match try_decode_from(&mut socket).await {
Ok(bytes) => bytes, Ok(bytes) => bytes,
Err(e) => { Err(e) => {
if e.is_disconnect() {
#[cfg(feature = "tracing")]
if !has_processed_request {
debug!("client disconnected before sending any response");
}
return;
}
#[cfg(feature = "tracing")] #[cfg(feature = "tracing")]
debug!(%e, "Error reading DerivationPath from socket"); debug!(%e, "Error reading DerivationPath from socket");
let content = e.to_string().bytes().collect::<Vec<_>>(); let content = e.to_string().bytes().collect::<Vec<_>>();
@ -109,17 +148,36 @@ impl UnixServer {
let result = try_encode_to(&content[..], &mut socket).await; let result = try_encode_to(&content[..], &mut socket).await;
#[cfg(feature = "tracing")] #[cfg(feature = "tracing")]
if let Err(error) = result { if let Err(error) = result {
if error.is_disconnect() {
#[cfg(feature = "tracing")]
if has_processed_request {
debug!("client disconnected while sending error frame");
}
return;
}
debug!(%error, "Error sending error to client"); debug!(%error, "Error sending error to client");
} }
return; has_processed_request = true;
// The error has been successfully sent, the client may perform
// another request.
continue;
} }
} }
.into(); .into();
if let Err(e) = try_encode_to(&response[..], &mut socket).await { if let Err(e) = try_encode_to(&response[..], &mut socket).await {
if e.is_disconnect() {
#[cfg(feature = "tracing")]
if has_processed_request {
debug!("client disconnected while sending success frame");
}
return;
}
#[cfg(feature = "tracing")] #[cfg(feature = "tracing")]
debug!(%e, "Error sending response to client"); debug!(%e, "Error sending response to client");
} }
has_processed_request = true;
} }
}); });
} }