use std::{future::Future, marker::PhantomData, pin::Pin}; use bincode::{deserialize, serialize}; use serde::{de::DeserializeOwned, Serialize}; use thiserror::Error; use tower::{Layer, Service}; /// Layer a [`BincodeService`] upon another Service. pub struct BincodeLayer<'a, Request> { phantom: PhantomData<&'a ()>, phantom_request: PhantomData<&'a Request>, } impl<'a, Request> BincodeLayer<'a, Request> { /// Create a new [`BincodeLayer`]. pub fn new() -> Self { Self { phantom: PhantomData, phantom_request: PhantomData, } } } impl<'a, Request> Default for BincodeLayer<'a, Request> { fn default() -> Self { Self::new() } } impl<'a, S: 'a, Request> Layer for BincodeLayer<'a, Request> { type Service = BincodeService; fn layer(&self, service: S) -> Self::Service { BincodeService { service, phantom_request: PhantomData, } } } /// Transform a Bincode-serialized type to a Rust type. #[derive(Clone)] pub struct BincodeService { service: S, phantom_request: PhantomData, } /// An error encountered either while transforming data or against the interior Service. #[derive(Debug, Error)] pub enum BincodeServiceError { /// An error occurred while polling the internal service. #[error("Error while polling: {0}")] Poll(String), /// An error occurred while calling the internal service. #[error("Error while calling: {0}")] Call(String), /// An error occurred while converting to or from bincode. #[error("Error while converting: {0}")] Convert(String), } impl Service> for BincodeService where S: Service, Request: DeserializeOwned, >::Response: Serialize, >::Error: std::error::Error + Serialize, >::Future: Send + 'static, { type Response = Vec; type Error = BincodeServiceError; type Future = Pin> + Send>>; fn poll_ready( &mut self, cx: &mut std::task::Context<'_>, ) -> std::task::Poll> { self.service .poll_ready(cx) .map_err(|e| BincodeServiceError::Poll(e.to_string())) } fn call(&mut self, req: Vec) -> Self::Future { let request: Request = match deserialize(&req) { Ok(r) => r, Err(e) => { return Box::pin(async move { Err(BincodeServiceError::Convert(e.to_string())) }) } }; let response = self.service.call(request); Box::pin(async move { let response = response.await; #[cfg(feature = "tracing")] if let Err(e) = &response { tracing::error!("Error performing derivation: {e}"); } serialize(&response).map_err(|e| BincodeServiceError::Convert(e.to_string())) }) } } #[cfg(test)] mod tests { use super::*; use serde::{Deserialize, Serialize}; use std::{future::Future, pin::Pin, task::Poll}; use tower::{ServiceBuilder, ServiceExt}; #[derive(Serialize, Deserialize)] struct Test { field: String, } impl Test { fn new() -> Self { Self { field: "hello world!".to_string(), } } } struct App; #[derive(Debug, thiserror::Error, Serialize)] enum Infallible {} impl Service for App { type Response = Test; type Error = Infallible; type Future = Pin> + Send>>; fn poll_ready( &mut self, _cx: &mut std::task::Context<'_>, ) -> std::task::Poll> { Poll::Ready(Ok(())) } fn call(&mut self, req: Test) -> Self::Future { Box::pin(async { assert_eq!(req.field, Test::new().field); Ok(req) }) } } #[tokio::test] async fn can_serde_responses() { let test = Test::new(); let content = serialize(&test).unwrap(); let mut service = ServiceBuilder::new() .layer(BincodeLayer::::default()) .service(App); let result = service .ready() .await .unwrap() .call(content.clone()) .await .unwrap(); assert_eq!(result, serialize(&Result::::Ok(test)).unwrap()); } }