keyfork/crates/daemon/keyforkd/src/middleware.rs

168 lines
4.6 KiB
Rust

use std::{future::Future, marker::PhantomData, pin::Pin};
use bincode::{deserialize, serialize};
use serde::{de::DeserializeOwned, Serialize};
use thiserror::Error;
use tower::{Layer, Service};
/// Layer a [`BincodeService`] upon another Service.
///
/// The layer holds no runtime state. `Request` names the type that incoming bincode payloads
/// are decoded into before they are handed to the wrapped service.
pub struct BincodeLayer<'a, Request> {
    // Marker only: ties the layer to `Request` (and to `'a`) without storing a value.
    // A single marker suffices; a separate `PhantomData<&'a ()>` would be redundant.
    phantom_request: PhantomData<&'a Request>,
}

impl<'a, Request> BincodeLayer<'a, Request> {
    /// Create a new [`BincodeLayer`].
    pub fn new() -> Self {
        Self {
            phantom_request: PhantomData,
        }
    }
}

impl<'a, Request> Default for BincodeLayer<'a, Request> {
    fn default() -> Self {
        Self::new()
    }
}

// Implemented by hand: `#[derive(Clone)]` would add a spurious `Request: Clone` bound even
// though no `Request` value is ever stored in the layer.
impl<'a, Request> Clone for BincodeLayer<'a, Request> {
    fn clone(&self) -> Self {
        Self::new()
    }
}

// No `S: 'a` bound: the produced `BincodeService<S, Request>` borrows nothing for `'a`.
impl<'a, S, Request> Layer<S> for BincodeLayer<'a, Request> {
    type Service = BincodeService<S, Request>;

    /// Wrap `service` so that it receives decoded `Request`s instead of raw bytes.
    fn layer(&self, service: S) -> Self::Service {
        BincodeService {
            service,
            phantom_request: PhantomData,
        }
    }
}
/// A [`Service`] that decodes bincode-encoded `Vec<u8>` requests into `Request`, forwards them
/// to the wrapped service `S`, and bincode-encodes the wrapped service's `Result` as the
/// response bytes.
pub struct BincodeService<S, Request> {
    service: S,
    phantom_request: PhantomData<Request>,
}

// Implemented by hand rather than derived: `#[derive(Clone)]` would also require
// `Request: Clone`, although only the inner service is actually cloned.
impl<S: Clone, Request> Clone for BincodeService<S, Request> {
    fn clone(&self) -> Self {
        Self {
            service: self.service.clone(),
            phantom_request: PhantomData,
        }
    }
}
/// An error encountered either while transforming data or against the interior Service.
#[derive(Debug, Error)]
pub enum BincodeServiceError {
/// An error occurred while polling the internal service.
#[error("Error while polling: {0}")]
Poll(String),
/// An error occurred while calling the internal service.
#[error("Error while calling: {0}")]
Call(String),
/// An error occurred while converting to or from bincode.
#[error("Error while converting: {0}")]
Convert(String),
}
impl<S, Request> Service<Vec<u8>> for BincodeService<S, Request>
where
    S: Service<Request>,
    Request: DeserializeOwned,
    <S as Service<Request>>::Response: Serialize,
    <S as Service<Request>>::Error: std::error::Error + Serialize,
    <S as Service<Request>>::Future: Send + 'static,
{
    type Response = Vec<u8>;
    type Error = BincodeServiceError;
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

    /// Delegate readiness to the wrapped service.
    ///
    /// An inner error is flattened to its string form inside [`BincodeServiceError::Poll`].
    fn poll_ready(
        &mut self,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), Self::Error>> {
        let readiness = self.service.poll_ready(cx);
        readiness.map_err(|inner_error| BincodeServiceError::Poll(inner_error.to_string()))
    }

    /// Decode `req` as a `Request`, dispatch it to the wrapped service, and encode the
    /// wrapped service's entire `Result` (success or failure) as the response bytes.
    ///
    /// A payload that cannot be decoded resolves to [`BincodeServiceError::Convert`]
    /// without calling the wrapped service.
    fn call(&mut self, req: Vec<u8>) -> Self::Future {
        let pending = match deserialize::<Request>(&req) {
            Ok(request) => self.service.call(request),
            Err(decode_error) => {
                let failure = BincodeServiceError::Convert(decode_error.to_string());
                return Box::pin(std::future::ready(Err(failure)));
            }
        };

        Box::pin(async move {
            let response = pending.await;
            // Inner-service errors are still encoded and returned below; this only logs them.
            #[cfg(feature = "tracing")]
            if let Err(e) = &response {
                tracing::error!("Error performing derivation: {e}");
            }
            serialize(&response)
                .map_err(|encode_error| BincodeServiceError::Convert(encode_error.to_string()))
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde::{Deserialize, Serialize};
    use std::{future::Future, pin::Pin, task::Poll};
    use tower::{ServiceBuilder, ServiceExt};

    /// Request/response payload round-tripped through the middleware.
    #[derive(Serialize, Deserialize)]
    struct Test {
        field: String,
    }

    impl Test {
        fn new() -> Self {
            Self {
                field: "hello world!".to_string(),
            }
        }
    }

    /// Echo service: asserts the decoded request equals `Test::new()` and returns it unchanged.
    struct App;

    /// Uninhabited error type, so `App` can never fail.
    #[derive(Debug, thiserror::Error, Serialize)]
    enum Infallible {}

    impl Service<Test> for App {
        type Response = Test;
        type Error = Infallible;
        type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

        fn poll_ready(
            &mut self,
            _cx: &mut std::task::Context<'_>,
        ) -> std::task::Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        fn call(&mut self, req: Test) -> Self::Future {
            Box::pin(async {
                assert_eq!(req.field, Test::new().field);
                Ok(req)
            })
        }
    }

    #[tokio::test]
    async fn can_serde_responses() {
        let test = Test::new();
        let content = serialize(&test).unwrap();
        let mut service = ServiceBuilder::new()
            .layer(BincodeLayer::<Test>::default())
            .service(App);
        let result = service
            .ready()
            .await
            .unwrap()
            .call(content)
            .await
            .unwrap();
        // The middleware encodes the inner service's whole `Result`, not just the `Ok` value.
        assert_eq!(result, serialize(&Result::<Test, Infallible>::Ok(test)).unwrap());
    }

    #[tokio::test]
    async fn rejects_malformed_requests() {
        let mut service = ServiceBuilder::new()
            .layer(BincodeLayer::<Test>::default())
            .service(App);
        // An empty payload is too short to decode a `Test`, so the middleware must fail
        // with a conversion error instead of reaching `App`.
        let error = service
            .ready()
            .await
            .unwrap()
            .call(Vec::new())
            .await
            .unwrap_err();
        assert!(matches!(error, BincodeServiceError::Convert(_)));
    }
}