use crate::service::{DerivablePath, DerivationError, Keyforkd}; use keyfork_frame::asyncext::{try_decode_from, try_encode_to}; use std::{io::Error, path::{Path, PathBuf}}; use tokio::net::{UnixListener, UnixStream}; use tower::{Service, ServiceExt}; #[cfg(feature = "tracing")] use tracing::debug; async fn read_path_from_socket( socket: &mut UnixStream, ) -> Result> { let data = try_decode_from(socket).await.unwrap(); let path: DerivablePath = minicbor::decode(&data[..]).unwrap(); Ok(path) } async fn wait_and_run(app: &mut Keyforkd, path: DerivablePath) -> Result, DerivationError> { app.ready().await?.call(path).await } #[allow(clippy::module_name_repetitions)] pub struct UnixServer { listener: UnixListener, } impl UnixServer { pub fn bind(address: impl AsRef) -> Result { let mut path = PathBuf::new(); path.extend(address.as_ref().components()); tokio::spawn(async move { let result = tokio::signal::ctrl_c().await; #[cfg(feature = "tracing")] debug!(?result, "encountered ctrl-c, performing cleanup and exiting"); let result = tokio::fs::remove_file(&path).await; #[cfg(feature = "tracing")] if let Err(error) = result { debug!(%error, "unable to remove path: {}", path.display()); } std::process::exit(0x80); }); Ok(Self { listener: UnixListener::bind(address)?, }) } pub async fn run(&mut self, app: Keyforkd) -> Result<(), Box> { loop { let mut app = app.clone(); let (mut socket, _) = self.listener.accept().await?; #[cfg(feature = "tracing")] debug!("new socket connected"); tokio::spawn(async move { let path = match read_path_from_socket(&mut socket).await { Ok(path) => path, Err(e) => { #[cfg(feature = "tracing")] debug!(%e, "Error reading DerivablePath from socket"); let content = e.to_string().bytes().collect::>(); let result = try_encode_to(&content[..], &mut socket).await; #[cfg(feature = "tracing")] if let Err(error) = result { debug!(%error, "Error sending error to client"); } return; } }; let response = match wait_and_run(&mut app, path).await { Ok(response) => response, Err(e) => { #[cfg(feature = "tracing")] debug!(%e, "Error reading DerivablePath from socket"); let content = e.to_string().bytes().collect::>(); let result = try_encode_to(&content[..], &mut socket).await; #[cfg(feature = "tracing")] if let Err(error) = result { debug!(%error, "Error sending error to client"); } return; } }; if let Err(e) = try_encode_to(&response[..], &mut socket).await { #[cfg(feature = "tracing")] debug!(%e, "Error sending response to client"); } }); } } }