keyforkd: appropriately handle or debug disconnects

This commit is contained in:
Ryan Heywood 2024-02-12 03:08:54 -05:00
parent f1c24fb33e
commit e441ef520f
Signed by: ryan
GPG Key ID: 8E401478A3FBEF72
1 changed files with 60 additions and 2 deletions

View File

@ -1,6 +1,9 @@
//! A UNIX socket server to run a Tower Service.
use keyfork_frame::asyncext::{try_decode_from, try_encode_to};
use keyfork_frame::{
asyncext::{try_decode_from, try_encode_to},
DecodeError, EncodeError,
};
use std::{
io::Error,
path::{Path, PathBuf},
@ -17,6 +20,34 @@ pub struct UnixServer {
listener: UnixListener,
}
/// This feels like a hack, but this is a convenient way to use the same method to quickly verify
/// something across two different error types.
trait IsDisconnect {
fn is_disconnect(&self) -> bool;
}
impl IsDisconnect for DecodeError {
fn is_disconnect(&self) -> bool {
if let Self::Io(e) = self {
if let std::io::ErrorKind::UnexpectedEof = e.kind() {
return true;
}
}
false
}
}
impl IsDisconnect for EncodeError {
fn is_disconnect(&self) -> bool {
if let Self::Io(e) = self {
if let std::io::ErrorKind::UnexpectedEof = e.kind() {
return true;
}
}
false
}
}
impl UnixServer {
/// Bind a socket to the given `address` and create a [`UnixServer`]. This function also creates a ctrl_c handler to automatically clean up the socket file.
///
@ -68,11 +99,19 @@ impl UnixServer {
#[cfg(feature = "tracing")]
debug!("new socket connected");
tokio::spawn(async move {
let mut has_processed_request = false;
// Process requests until an error occurs or a client disconnects
loop {
let bytes = match try_decode_from(&mut socket).await {
Ok(bytes) => bytes,
Err(e) => {
if e.is_disconnect() {
#[cfg(feature = "tracing")]
if !has_processed_request {
debug!("client disconnected before sending any response");
}
return;
}
#[cfg(feature = "tracing")]
debug!(%e, "Error reading DerivationPath from socket");
let content = e.to_string().bytes().collect::<Vec<_>>();
@ -109,17 +148,36 @@ impl UnixServer {
let result = try_encode_to(&content[..], &mut socket).await;
#[cfg(feature = "tracing")]
if let Err(error) = result {
debug!(%error, "Error sending error to client");
if error.is_disconnect() {
#[cfg(feature = "tracing")]
if has_processed_request {
debug!("client disconnected while sending error frame");
}
return;
}
debug!(%error, "Error sending error to client");
}
has_processed_request = true;
// The error has been successfully sent, the client may perform
// another request.
continue;
}
}
.into();
if let Err(e) = try_encode_to(&response[..], &mut socket).await {
if e.is_disconnect() {
#[cfg(feature = "tracing")]
if has_processed_request {
debug!("client disconnected while sending success frame");
}
return;
}
#[cfg(feature = "tracing")]
debug!(%e, "Error sending response to client");
}
has_processed_request = true;
}
});
}