// keyfork/keyforkd/src/middleware.rs

use std::{future::Future, marker::PhantomData, pin::Pin};
use bincode::{deserialize, serialize};
use serde::{de::DeserializeOwned, Serialize};
use thiserror::Error;
use tower::{Layer, Service};
/// A layer that wraps an inner service in a [`SerdeService`].
///
/// The layer stores no data: both fields are [`PhantomData`] markers that tie
/// the `'a` lifetime and the `Request` type parameter to the layer type.
pub struct SerdeLayer<'a, Request> {
    phantom: PhantomData<&'a ()>,
    phantom_request: PhantomData<&'a Request>,
}

impl<'a, Request> SerdeLayer<'a, Request> {
    /// Creates a new layer. No arguments are needed because the layer only
    /// carries type-level information.
    pub fn new() -> Self {
        Self {
            phantom: PhantomData,
            phantom_request: PhantomData,
        }
    }
}

// `new()` takes no arguments, so `Default` is provided as well. It is written
// by hand (not derived) so that no `Request: Default` bound is imposed.
impl<'a, Request> Default for SerdeLayer<'a, Request> {
    /// Equivalent to [`SerdeLayer::new`].
    fn default() -> Self {
        Self::new()
    }
}
/// Lets [`SerdeLayer`] wrap any service `S` in a [`SerdeService`] that is
/// parameterised by the same `Request` type as the layer.
impl<'a, S, Request> Layer<S> for SerdeLayer<'a, Request>
where
    S: 'a,
{
    type Service = SerdeService<S, Request>;

    /// Takes ownership of `inner` and returns it wrapped in a `SerdeService`.
    fn layer(&self, inner: S) -> Self::Service {
        let phantom_request = PhantomData;
        SerdeService {
            service: inner,
            phantom_request,
        }
    }
}
/// Service wrapper that owns an inner service `S` and remembers, at the type
/// level only, the `Request` type associated with it.
///
/// No `Request` value is ever stored; `phantom_request` is a zero-sized marker.
pub struct SerdeService<S, Request> {
    service: S,
    phantom_request: PhantomData<Request>,
}

// Implemented by hand rather than derived: `#[derive(Clone)]` would add a
// `Request: Clone` bound even though only a `PhantomData<Request>` is stored,
// making the wrapper non-cloneable for request types that are not `Clone`.
impl<S: Clone, Request> Clone for SerdeService<S, Request> {
    /// Clones the inner service; the marker field is recreated.
    fn clone(&self) -> Self {
        Self {
            service: self.service.clone(),
            phantom_request: PhantomData,
        }
    }
}
/// Errors returned by the `Service<Vec<u8>>` implementation of `SerdeService`.
///
/// Each variant carries the display text (`to_string()`) of the underlying
/// error rather than the error value itself.
#[derive(Debug, Error)]
pub enum SerdeServiceError {
/// The inner service returned an error from `poll_ready`.
#[error("Error while polling: {0}")]
Poll(String),
/// The future returned by the inner service's `call` resolved to an error.
#[error("Error while calling: {0}")]
Call(String),
/// Deserializing the request bytes or serializing the response failed.
#[error("Error while converting: {0}")]
Convert(String),
}
/// Byte-oriented [`Service`] front-end for an inner service `S`.
///
/// A request arrives as bincode-encoded bytes, is decoded into `Request`, and
/// is handed to the inner service. The inner response is then bincode-encoded
/// into the returned bytes. Every failure is reported as a
/// [`SerdeServiceError`] holding the display text of the underlying error.
impl<S, Request> Service<Vec<u8>> for SerdeService<S, Request>
where
    S: Service<Request> + Send + Sync,
    Request: DeserializeOwned + Send,
    S::Error: std::error::Error + Send,
    S::Response: Serialize + Send,
    S::Future: Send + 'static,
{
    type Response = Vec<u8>;
    type Error = SerdeServiceError;
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

    /// Forwards readiness polling to the inner service, converting an inner
    /// error into [`SerdeServiceError::Poll`].
    fn poll_ready(
        &mut self,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), Self::Error>> {
        use std::task::Poll;

        match self.service.poll_ready(cx) {
            Poll::Pending => Poll::Pending,
            Poll::Ready(Ok(())) => Poll::Ready(Ok(())),
            Poll::Ready(Err(inner)) => {
                Poll::Ready(Err(SerdeServiceError::Poll(inner.to_string())))
            }
        }
    }

    /// Decodes `req`, calls the inner service, and encodes its response.
    ///
    /// Decoding happens synchronously; if it fails, the inner service is not
    /// called and the returned future resolves to
    /// [`SerdeServiceError::Convert`].
    fn call(&mut self, req: Vec<u8>) -> Self::Future {
        match deserialize::<Request>(&req) {
            Err(decode_err) => {
                let failure: Result<Vec<u8>, SerdeServiceError> =
                    Err(SerdeServiceError::Convert(decode_err.to_string()));
                Box::pin(std::future::ready(failure))
            }
            Ok(request) => {
                // The inner call is started here, outside the async block, so the
                // returned future does not borrow `self`.
                let pending = self.service.call(request);
                Box::pin(async move {
                    match pending.await {
                        Ok(reply) => serialize(&reply)
                            .map_err(|err| SerdeServiceError::Convert(err.to_string())),
                        Err(err) => Err(SerdeServiceError::Call(err.to_string())),
                    }
                })
            }
        }
    }
}