186 lines
7.2 KiB
Rust
186 lines
7.2 KiB
Rust
//! A UNIX socket server to run a Tower Service.
|
|
|
|
use keyfork_frame::{
|
|
asyncext::{try_decode_from, try_encode_to},
|
|
DecodeError, EncodeError,
|
|
};
|
|
use std::{
|
|
io::Error,
|
|
path::{Path, PathBuf},
|
|
};
|
|
use tokio::net::UnixListener;
|
|
use tower::{Service, ServiceExt};
|
|
|
|
#[cfg(feature = "tracing")]
|
|
use tracing::debug;
|
|
|
|
/// A UNIX Socket Server.
|
|
#[allow(clippy::module_name_repetitions)]
|
|
pub struct UnixServer {
|
|
listener: UnixListener,
|
|
}
|
|
|
|
/// Classifies an error as "the client went away" rather than a genuine failure.
///
/// Both the frame decoder's and the frame encoder's error types implement this, so the
/// connection loop can ask the same question of either one.
trait IsDisconnect {
    /// Returns `true` if this error means the peer has disconnected.
    fn is_disconnect(&self) -> bool;
}
|
|
|
|
impl IsDisconnect for DecodeError {
    /// Returns `true` when reading a frame failed because the client is gone.
    ///
    /// `UnexpectedEof` is what the connection loop relies on to notice a client that
    /// closed its end of the socket. A peer that resets or aborts the connection
    /// (`ConnectionReset` / `ConnectionAborted`) is also gone, and there is no one left
    /// to send an error frame to, so those kinds count as disconnects too.
    fn is_disconnect(&self) -> bool {
        matches!(
            self,
            Self::Io(e) if matches!(
                e.kind(),
                std::io::ErrorKind::UnexpectedEof
                    | std::io::ErrorKind::ConnectionReset
                    | std::io::ErrorKind::ConnectionAborted
            )
        )
    }
}
|
|
|
|
impl IsDisconnect for EncodeError {
    /// Returns `true` when writing a frame failed because the client is gone.
    ///
    /// Writing to a socket whose peer has already closed fails with `BrokenPipe`
    /// (EPIPE), or with `ConnectionReset` / `ConnectionAborted`. It never fails with
    /// `UnexpectedEof`, so checking only that kind missed real disconnects on the
    /// write side. `UnexpectedEof` is still accepted so no previously recognised case
    /// is lost.
    fn is_disconnect(&self) -> bool {
        matches!(
            self,
            Self::Io(e) if matches!(
                e.kind(),
                std::io::ErrorKind::BrokenPipe
                    | std::io::ErrorKind::ConnectionReset
                    | std::io::ErrorKind::ConnectionAborted
                    | std::io::ErrorKind::UnexpectedEof
            )
        )
    }
}
|
|
|
|
impl UnixServer {
|
|
/// Bind a socket to the given `address` and create a [`UnixServer`]. This function also creates a ctrl_c handler to automatically clean up the socket file.
|
|
///
|
|
/// # Errors
|
|
/// This function may return an error if the socket can't be bound.
|
|
pub fn bind(address: impl AsRef<Path>) -> Result<Self, Error> {
|
|
let mut path = PathBuf::new();
|
|
path.extend(address.as_ref().components());
|
|
tokio::spawn(async move {
|
|
#[cfg(feature = "tracing")]
|
|
debug!("Binding tokio ctrl-c handler");
|
|
let result = tokio::signal::ctrl_c().await;
|
|
#[cfg(feature = "tracing")]
|
|
debug!(
|
|
?result,
|
|
"encountered ctrl-c, performing cleanup and exiting"
|
|
);
|
|
let result = tokio::fs::remove_file(&path).await;
|
|
#[cfg(feature = "tracing")]
|
|
if let Err(error) = result {
|
|
debug!(%error, "unable to remove path: {}", path.display());
|
|
}
|
|
std::process::exit(0x80);
|
|
});
|
|
|
|
Ok(Self {
|
|
listener: UnixListener::bind(address)?,
|
|
})
|
|
}
|
|
|
|
/// Given a Service, accept clients and use their input to call the Service.
|
|
///
|
|
/// # Errors
|
|
/// The method may return an error if the server becomes unable to accept new connections.
|
|
/// Errors while the server is running are logged using the `tracing` crate.
|
|
pub async fn run<S, R>(&mut self, app: S) -> Result<(), Box<dyn std::error::Error>>
|
|
where
|
|
S: Service<R> + Clone + Send + 'static,
|
|
R: From<Vec<u8>> + Send,
|
|
<S as Service<R>>::Error: std::error::Error + Send,
|
|
<S as Service<R>>::Response: std::convert::Into<Vec<u8>> + Send,
|
|
<S as Service<R>>::Future: Send,
|
|
{
|
|
#[cfg(feature = "tracing")]
|
|
debug!("Listening for clients");
|
|
loop {
|
|
let mut app = app.clone();
|
|
let (mut socket, _) = self.listener.accept().await?;
|
|
#[cfg(feature = "tracing")]
|
|
debug!("new socket connected");
|
|
tokio::spawn(async move {
|
|
let mut has_processed_request = false;
|
|
// Process requests until an error occurs or a client disconnects
|
|
loop {
|
|
let bytes = match try_decode_from(&mut socket).await {
|
|
Ok(bytes) => bytes,
|
|
Err(e) => {
|
|
if e.is_disconnect() {
|
|
#[cfg(feature = "tracing")]
|
|
if !has_processed_request {
|
|
debug!("client disconnected before sending any response");
|
|
}
|
|
return;
|
|
}
|
|
#[cfg(feature = "tracing")]
|
|
debug!(%e, "Error reading DerivationPath from socket");
|
|
let content = e.to_string().bytes().collect::<Vec<_>>();
|
|
let result = try_encode_to(&content[..], &mut socket).await;
|
|
#[cfg(feature = "tracing")]
|
|
if let Err(error) = result {
|
|
debug!(%error, "Error sending error to client");
|
|
}
|
|
return;
|
|
}
|
|
};
|
|
|
|
let app = match app.ready().await {
|
|
Ok(app) => app,
|
|
Err(e) => {
|
|
#[cfg(feature = "tracing")]
|
|
debug!(%e, "Could not poll ready");
|
|
let content = e.to_string().bytes().collect::<Vec<_>>();
|
|
let result = try_encode_to(&content[..], &mut socket).await;
|
|
#[cfg(feature = "tracing")]
|
|
if let Err(error) = result {
|
|
debug!(%error, "Error sending error to client");
|
|
}
|
|
return;
|
|
}
|
|
};
|
|
|
|
let response = match app.call(bytes.into()).await {
|
|
Ok(response) => response,
|
|
Err(e) => {
|
|
#[cfg(feature = "tracing")]
|
|
debug!(%e, "Error reading DerivationPath from socket");
|
|
let content = e.to_string().bytes().collect::<Vec<_>>();
|
|
let result = try_encode_to(&content[..], &mut socket).await;
|
|
#[cfg(feature = "tracing")]
|
|
if let Err(error) = result {
|
|
if error.is_disconnect() {
|
|
#[cfg(feature = "tracing")]
|
|
if has_processed_request {
|
|
debug!("client disconnected while sending error frame");
|
|
}
|
|
return;
|
|
}
|
|
debug!(%error, "Error sending error to client");
|
|
}
|
|
has_processed_request = true;
|
|
// The error has been successfully sent, the client may perform
|
|
// another request.
|
|
continue;
|
|
}
|
|
}
|
|
.into();
|
|
|
|
if let Err(e) = try_encode_to(&response[..], &mut socket).await {
|
|
if e.is_disconnect() {
|
|
#[cfg(feature = "tracing")]
|
|
if has_processed_request {
|
|
debug!("client disconnected while sending success frame");
|
|
}
|
|
return;
|
|
}
|
|
#[cfg(feature = "tracing")]
|
|
debug!(%e, "Error sending response to client");
|
|
}
|
|
|
|
has_processed_request = true;
|
|
}
|
|
});
|
|
}
|
|
}
|
|
}
|