From e441ef520f09b0e371ab19890c069c0c17bca674 Mon Sep 17 00:00:00 2001 From: ryan Date: Mon, 12 Feb 2024 03:08:54 -0500 Subject: [PATCH] keyforkd: appropriately handle or debug disconnects --- crates/daemon/keyforkd/src/server.rs | 62 +++++++++++++++++++++++++++- 1 file changed, 60 insertions(+), 2 deletions(-) diff --git a/crates/daemon/keyforkd/src/server.rs b/crates/daemon/keyforkd/src/server.rs index 3bfe8ab..780b1aa 100644 --- a/crates/daemon/keyforkd/src/server.rs +++ b/crates/daemon/keyforkd/src/server.rs @@ -1,6 +1,9 @@ //! A UNIX socket server to run a Tower Service. -use keyfork_frame::asyncext::{try_decode_from, try_encode_to}; +use keyfork_frame::{ + asyncext::{try_decode_from, try_encode_to}, + DecodeError, EncodeError, +}; use std::{ io::Error, path::{Path, PathBuf}, @@ -17,6 +20,34 @@ pub struct UnixServer { listener: UnixListener, } +/// This feels like a hack, but this is a convenient way to use the same method to quickly verify +/// something across two different error types. +trait IsDisconnect { + fn is_disconnect(&self) -> bool; +} + +impl IsDisconnect for DecodeError { + fn is_disconnect(&self) -> bool { + if let Self::Io(e) = self { + if let std::io::ErrorKind::UnexpectedEof = e.kind() { + return true; + } + } + false + } +} + +impl IsDisconnect for EncodeError { + fn is_disconnect(&self) -> bool { + if let Self::Io(e) = self { + if let std::io::ErrorKind::UnexpectedEof = e.kind() { + return true; + } + } + false + } +} + impl UnixServer { /// Bind a socket to the given `address` and create a [`UnixServer`]. This function also creates a ctrl_c handler to automatically clean up the socket file. /// @@ -68,11 +99,19 @@ impl UnixServer { #[cfg(feature = "tracing")] debug!("new socket connected"); tokio::spawn(async move { + let mut has_processed_request = false; // Process requests until an error occurs or a client disconnects loop { let bytes = match try_decode_from(&mut socket).await { Ok(bytes) => bytes, Err(e) => { + if e.is_disconnect() { + #[cfg(feature = "tracing")] + if !has_processed_request { + debug!("client disconnected before sending any response"); + } + return; + } #[cfg(feature = "tracing")] debug!(%e, "Error reading DerivationPath from socket"); let content = e.to_string().bytes().collect::>(); @@ -109,17 +148,36 @@ impl UnixServer { let result = try_encode_to(&content[..], &mut socket).await; #[cfg(feature = "tracing")] if let Err(error) = result { + if error.is_disconnect() { + #[cfg(feature = "tracing")] + if has_processed_request { + debug!("client disconnected while sending error frame"); + } + return; + } debug!(%error, "Error sending error to client"); } - return; + has_processed_request = true; + // The error has been successfully sent, the client may perform + // another request. + continue; } } .into(); if let Err(e) = try_encode_to(&response[..], &mut socket).await { + if e.is_disconnect() { + #[cfg(feature = "tracing")] + if has_processed_request { + debug!("client disconnected while sending success frame"); + } + return; + } #[cfg(feature = "tracing")] debug!(%e, "Error sending response to client"); } + + has_processed_request = true; } }); }