use std::{future::Future, marker::PhantomData, pin::Pin}; use bincode::{deserialize, serialize}; use serde::{de::DeserializeOwned, Serialize}; use thiserror::Error; use tower::{Layer, Service}; pub struct BincodeLayer<'a, Request> { phantom: PhantomData<&'a ()>, phantom_request: PhantomData<&'a Request>, } impl<'a, Request> BincodeLayer<'a, Request> { pub fn new() -> Self { Self { phantom: PhantomData, phantom_request: PhantomData, } } } impl<'a, S: 'a, Request> Layer for BincodeLayer<'a, Request> { type Service = BincodeService; fn layer(&self, service: S) -> Self::Service { BincodeService { service, phantom_request: PhantomData, } } } #[derive(Clone)] pub struct BincodeService { service: S, phantom_request: PhantomData, } #[derive(Debug, Error)] pub enum BincodeServiceError { #[error("Error while polling: {0}")] Poll(String), #[error("Error while calling: {0}")] Call(String), #[error("Error while converting: {0}")] Convert(String), } impl Service> for BincodeService where S: Service, Request: DeserializeOwned, >::Response: Serialize, >::Error: std::error::Error, >::Future: Send + 'static, { type Response = Vec; type Error = BincodeServiceError; type Future = Pin> + Send>>; fn poll_ready( &mut self, cx: &mut std::task::Context<'_>, ) -> std::task::Poll> { self.service .poll_ready(cx) .map_err(|e| BincodeServiceError::Poll(e.to_string())) } fn call(&mut self, req: Vec) -> Self::Future { let request: Request = match deserialize(&req) { Ok(r) => r, Err(e) => { return Box::pin(async move { Err(BincodeServiceError::Convert(e.to_string())) }) } }; let response = self.service.call(request); Box::pin(async move { let response = response .await .map_err(|e| BincodeServiceError::Call(e.to_string()))?; serialize(&response).map_err(|e| BincodeServiceError::Convert(e.to_string())) }) } } #[cfg(test)] mod tests { use super::*; use serde::{Deserialize, Serialize}; use std::{future::Future, pin::Pin, task::Poll}; use tower::{ServiceBuilder, ServiceExt}; #[derive(Serialize, Deserialize)] struct Test { field: String, } struct App; #[derive(Debug, thiserror::Error)] enum Infallible {} impl Service for App { type Response = Test; type Error = Infallible; type Future = Pin> + Send>>; fn poll_ready( &mut self, _cx: &mut std::task::Context<'_>, ) -> std::task::Poll> { Poll::Ready(Ok(())) } fn call(&mut self, req: Test) -> Self::Future { Box::pin(async { Ok(req) }) } } #[tokio::test] async fn can_serde_responses() { let content = serialize(&Test { field: "hello world!".to_string(), }) .unwrap(); let mut service = ServiceBuilder::new() .layer(BincodeLayer::::new()) .service(App); let result = service .ready() .await .unwrap() .call(content.clone()) .await .unwrap(); assert_eq!(result, content); } }