From a7feed1bcc98f0eb88b77bed50c3dfa8e7b05043 Mon Sep 17 00:00:00 2001 From: ryan Date: Thu, 7 Sep 2023 08:05:38 -0500 Subject: [PATCH] keyforkd: extract serialization logic into middleware --- keyforkd/src/main.rs | 6 ++- keyforkd/src/middleware.rs | 91 ++++++++++++++++++++++++++++++++++++++ keyforkd/src/server.rs | 49 +++++++++++--------- 3 files changed, 125 insertions(+), 21 deletions(-) create mode 100644 keyforkd/src/middleware.rs diff --git a/keyforkd/src/main.rs b/keyforkd/src/main.rs index 082809f..7d40281 100644 --- a/keyforkd/src/main.rs +++ b/keyforkd/src/main.rs @@ -3,6 +3,7 @@ use std::{collections::HashMap, path::PathBuf}; use keyfork_mnemonic_util::Mnemonic; use tokio::io::{self, AsyncBufReadExt, BufReader}; +use tower::ServiceBuilder; #[cfg(feature = "tracing")] use tracing::debug; @@ -17,6 +18,7 @@ use tracing_subscriber::{ mod error; mod server; mod service; +mod middleware; use error::KeyforkdError; use server::UnixServer; use service::Keyforkd; @@ -52,7 +54,9 @@ async fn main() -> Result<(), Box> { debug!("reading mnemonic from standard input"); let mnemonic = load_mnemonic().await?; - let service = Keyforkd::new(mnemonic); + let service = ServiceBuilder::new() + .layer(middleware::SerdeLayer::new()) + .service(Keyforkd::new(mnemonic)); let runtime_vars = std::env::vars() .filter(|(key, _)| ["XDG_RUNTIME_DIR", "KEYFORKD_SOCKET_PATH"].contains(&key.as_str())) diff --git a/keyforkd/src/middleware.rs b/keyforkd/src/middleware.rs new file mode 100644 index 0000000..9aa7578 --- /dev/null +++ b/keyforkd/src/middleware.rs @@ -0,0 +1,91 @@ +use std::{future::Future, marker::PhantomData, pin::Pin}; + +use bincode::{deserialize, serialize}; +use serde::{de::DeserializeOwned, Serialize}; +use thiserror::Error; +use tower::{Layer, Service}; + +pub struct SerdeLayer<'a, Request> { + phantom: PhantomData<&'a ()>, + phantom_request: PhantomData<&'a Request>, +} + +impl<'a, Request> SerdeLayer<'a, Request> { + pub fn new() -> Self { + Self { + phantom: PhantomData, + phantom_request: PhantomData, + } + } +} + +impl<'a, S: 'a, Request> Layer for SerdeLayer<'a, Request> { + type Service = SerdeService; + + fn layer(&self, service: S) -> Self::Service { + SerdeService { + service, + phantom_request: PhantomData, + } + } +} + +#[derive(Clone)] +pub struct SerdeService { + service: S, + phantom_request: PhantomData, +} + +#[derive(Debug, Error)] +pub enum SerdeServiceError { + #[error("Error while polling: {0}")] + Poll(String), + + #[error("Error while calling: {0}")] + Call(String), + + #[error("Error while converting: {0}")] + Convert(String), +} + +impl Service> for SerdeService +where + S: Service + Send + Sync, + Request: DeserializeOwned + Send, + >::Error: std::error::Error + Send, + >::Response: Serialize + Send, + >::Future: Send + 'static, +{ + type Response = Vec; + type Error = SerdeServiceError; + type Future = Pin> + Send>>; + + fn poll_ready( + &mut self, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + self.service + .poll_ready(cx) + .map_err(|e| SerdeServiceError::Poll(e.to_string())) + } + + fn call(&mut self, req: Vec) -> Self::Future { + let request: Request = match deserialize(&req) { + Ok(r) => r, + Err(e) => { + return Box::pin( + async move { Err(SerdeServiceError::Convert(e.to_string())) }, + ) + } + }; + + let response = self.service.call(request); + + Box::pin(async move { + let response = response + .await + .map_err(|e| SerdeServiceError::Call(e.to_string()))?; + serialize(&response).map_err(|e| SerdeServiceError::Convert(e.to_string())) + }) + } +} diff --git a/keyforkd/src/server.rs b/keyforkd/src/server.rs index 35e0fdf..9811672 100644 --- a/keyforkd/src/server.rs +++ b/keyforkd/src/server.rs @@ -1,28 +1,14 @@ -use crate::service::{DerivationError, Keyforkd}; -use keyfork_derive_util::DerivationPath; use keyfork_frame::asyncext::{try_decode_from, try_encode_to}; use std::{ io::Error, path::{Path, PathBuf}, }; -use tokio::net::{UnixListener, UnixStream}; +use tokio::net::UnixListener; use tower::{Service, ServiceExt}; #[cfg(feature = "tracing")] use tracing::debug; -async fn read_path_from_socket( - socket: &mut UnixStream, -) -> Result> { - let data = try_decode_from(socket).await.unwrap(); - let path: DerivationPath = bincode::deserialize(&data[..]).unwrap(); - Ok(path) -} - -async fn wait_and_run(app: &mut Keyforkd, path: DerivationPath) -> Result, DerivationError> { - app.ready().await?.call(path).await -} - #[allow(clippy::module_name_repetitions)] pub struct UnixServer { listener: UnixListener, @@ -54,7 +40,14 @@ impl UnixServer { }) } - pub async fn run(&mut self, app: Keyforkd) -> Result<(), Box> { + pub async fn run(&mut self, app: S) -> Result<(), Box> + where + S: Service + Clone + Send + 'static, + R: From> + Send, + >::Error: std::error::Error + Send, + >::Response: std::convert::Into> + Send, + >::Future: Send, + { #[cfg(feature = "tracing")] debug!("Listening for clients"); loop { @@ -63,8 +56,8 @@ impl UnixServer { #[cfg(feature = "tracing")] debug!("new socket connected"); tokio::spawn(async move { - let path = match read_path_from_socket(&mut socket).await { - Ok(path) => path, + let bytes = match try_decode_from(&mut socket).await { + Ok(bytes) => bytes, Err(e) => { #[cfg(feature = "tracing")] debug!(%e, "Error reading DerivationPath from socket"); @@ -78,7 +71,22 @@ impl UnixServer { } }; - let response = match wait_and_run(&mut app, path).await { + let app = match app.ready().await { + Ok(app) => app, + Err(e) => { + #[cfg(feature = "tracing")] + debug!(%e, "Could not poll ready"); + let content = e.to_string().bytes().collect::>(); + let result = try_encode_to(&content[..], &mut socket).await; + #[cfg(feature = "tracing")] + if let Err(error) = result { + debug!(%error, "Error sending error to client"); + } + return; + } + }; + + let response = match app.call(bytes.into()).await { Ok(response) => response, Err(e) => { #[cfg(feature = "tracing")] @@ -91,7 +99,8 @@ impl UnixServer { } return; } - }; + } + .into(); if let Err(e) = try_encode_to(&response[..], &mut socket).await { #[cfg(feature = "tracing")]