use std::{future::Future, marker::PhantomData, pin::Pin}; use bincode::{deserialize, serialize}; use serde::{de::DeserializeOwned, Serialize}; use thiserror::Error; use tower::{Layer, Service}; pub struct SerdeLayer<'a, Request> { phantom: PhantomData<&'a ()>, phantom_request: PhantomData<&'a Request>, } impl<'a, Request> SerdeLayer<'a, Request> { pub fn new() -> Self { Self { phantom: PhantomData, phantom_request: PhantomData, } } } impl<'a, S: 'a, Request> Layer for SerdeLayer<'a, Request> { type Service = SerdeService; fn layer(&self, service: S) -> Self::Service { SerdeService { service, phantom_request: PhantomData, } } } #[derive(Clone)] pub struct SerdeService { service: S, phantom_request: PhantomData, } #[derive(Debug, Error)] pub enum SerdeServiceError { #[error("Error while polling: {0}")] Poll(String), #[error("Error while calling: {0}")] Call(String), #[error("Error while converting: {0}")] Convert(String), } impl Service> for SerdeService where S: Service + Send + Sync, Request: DeserializeOwned + Send, >::Error: std::error::Error + Send, >::Response: Serialize + Send, >::Future: Send + 'static, { type Response = Vec; type Error = SerdeServiceError; type Future = Pin> + Send>>; fn poll_ready( &mut self, cx: &mut std::task::Context<'_>, ) -> std::task::Poll> { self.service .poll_ready(cx) .map_err(|e| SerdeServiceError::Poll(e.to_string())) } fn call(&mut self, req: Vec) -> Self::Future { let request: Request = match deserialize(&req) { Ok(r) => r, Err(e) => { return Box::pin( async move { Err(SerdeServiceError::Convert(e.to_string())) }, ) } }; let response = self.service.call(request); Box::pin(async move { let response = response .await .map_err(|e| SerdeServiceError::Call(e.to_string()))?; serialize(&response).map_err(|e| SerdeServiceError::Convert(e.to_string())) }) } }