//! A UNIX socket server to run a Tower Service. use keyfork_frame::asyncext::{try_decode_from, try_encode_to}; use std::{ io::Error, path::{Path, PathBuf}, }; use tokio::net::UnixListener; use tower::{Service, ServiceExt}; #[cfg(feature = "tracing")] use tracing::debug; /// A UNIX Socket Server. #[allow(clippy::module_name_repetitions)] pub struct UnixServer { listener: UnixListener, } impl UnixServer { /// Bind a socket to the given `address` and create a [`UnixServer`]. This function also creates a ctrl_c handler to automatically clean up the socket file. /// /// # Errors /// This function may return an error if the socket can't be bound. pub fn bind(address: impl AsRef) -> Result { let mut path = PathBuf::new(); path.extend(address.as_ref().components()); tokio::spawn(async move { #[cfg(feature = "tracing")] debug!("Binding tokio ctrl-c handler"); let result = tokio::signal::ctrl_c().await; #[cfg(feature = "tracing")] debug!( ?result, "encountered ctrl-c, performing cleanup and exiting" ); let result = tokio::fs::remove_file(&path).await; #[cfg(feature = "tracing")] if let Err(error) = result { debug!(%error, "unable to remove path: {}", path.display()); } std::process::exit(0x80); }); Ok(Self { listener: UnixListener::bind(address)?, }) } /// Given a Service, accept clients and use their input to call the Service. /// /// # Errors /// The method may return an error if the server becomes unable to accept new connections. /// Errors while the server is running are logged using the `tracing` crate. pub async fn run(&mut self, app: S) -> Result<(), Box> where S: Service + Clone + Send + 'static, R: From> + Send, >::Error: std::error::Error + Send, >::Response: std::convert::Into> + Send, >::Future: Send, { #[cfg(feature = "tracing")] debug!("Listening for clients"); loop { let mut app = app.clone(); let (mut socket, _) = self.listener.accept().await?; #[cfg(feature = "tracing")] debug!("new socket connected"); tokio::spawn(async move { // Process requests until an error occurs or a client disconnects loop { let bytes = match try_decode_from(&mut socket).await { Ok(bytes) => bytes, Err(e) => { #[cfg(feature = "tracing")] debug!(%e, "Error reading DerivationPath from socket"); let content = e.to_string().bytes().collect::>(); let result = try_encode_to(&content[..], &mut socket).await; #[cfg(feature = "tracing")] if let Err(error) = result { debug!(%error, "Error sending error to client"); } return; } }; let app = match app.ready().await { Ok(app) => app, Err(e) => { #[cfg(feature = "tracing")] debug!(%e, "Could not poll ready"); let content = e.to_string().bytes().collect::>(); let result = try_encode_to(&content[..], &mut socket).await; #[cfg(feature = "tracing")] if let Err(error) = result { debug!(%error, "Error sending error to client"); } return; } }; let response = match app.call(bytes.into()).await { Ok(response) => response, Err(e) => { #[cfg(feature = "tracing")] debug!(%e, "Error reading DerivationPath from socket"); let content = e.to_string().bytes().collect::>(); let result = try_encode_to(&content[..], &mut socket).await; #[cfg(feature = "tracing")] if let Err(error) = result { debug!(%error, "Error sending error to client"); } return; } } .into(); if let Err(e) = try_encode_to(&response[..], &mut socket).await { #[cfg(feature = "tracing")] debug!(%e, "Error sending response to client"); } } }); } } }