168 lines
4.6 KiB
Rust
168 lines
4.6 KiB
Rust
use std::{future::Future, marker::PhantomData, pin::Pin};
|
|
|
|
use bincode::{deserialize, serialize};
|
|
use serde::{de::DeserializeOwned, Serialize};
|
|
use thiserror::Error;
|
|
use tower::{Layer, Service};
|
|
|
|
/// Layer a [`BincodeService`] upon another Service.
///
/// The layer holds no runtime state; it only records, at the type level, the
/// `Request` type that the wrapped service receives once the incoming bytes
/// have been decoded.
pub struct BincodeLayer<'a, Request> {
    phantom: PhantomData<&'a ()>,
    phantom_request: PhantomData<&'a Request>,
}

// `Clone`/`Copy`/`Debug` are written by hand instead of derived: a derive would
// add `Request: Clone` / `Request: Debug` bounds, although the layer never
// stores a `Request` value (both fields are `PhantomData`, which is always
// `Copy`).
impl<Request> Clone for BincodeLayer<'_, Request> {
    fn clone(&self) -> Self {
        *self
    }
}

impl<Request> Copy for BincodeLayer<'_, Request> {}

impl<Request> std::fmt::Debug for BincodeLayer<'_, Request> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("BincodeLayer").finish_non_exhaustive()
    }
}
|
|
|
|
impl<'a, Request> BincodeLayer<'a, Request> {
|
|
/// Create a new [`BincodeLayer`].
|
|
pub fn new() -> Self {
|
|
Self {
|
|
phantom: PhantomData,
|
|
phantom_request: PhantomData,
|
|
}
|
|
}
|
|
}
|
|
|
|
impl<'a, Request> Default for BincodeLayer<'a, Request> {
|
|
fn default() -> Self {
|
|
Self::new()
|
|
}
|
|
}
|
|
|
|
impl<'a, S: 'a, Request> Layer<S> for BincodeLayer<'a, Request> {
|
|
type Service = BincodeService<S, Request>;
|
|
|
|
fn layer(&self, service: S) -> Self::Service {
|
|
BincodeService {
|
|
service,
|
|
phantom_request: PhantomData,
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Service adapter that speaks bincode on the outside and typed values on the
/// inside.
///
/// Incoming bytes are decoded into `Request` and handed to the wrapped service.
/// The wrapped service's whole `Result` (success or error) is then bincode-encoded
/// and returned as the response bytes.
pub struct BincodeService<S, Request> {
    service: S,
    // `fn() -> Request` keeps the type parameter covariant, exactly like
    // `PhantomData<Request>`. Because no `Request` value is ever stored, it
    // also keeps `Send`/`Sync`/`Unpin` of this service from depending on
    // `Request`.
    phantom_request: PhantomData<fn() -> Request>,
}

// Written by hand so that cloning only requires `S: Clone`; `#[derive(Clone)]`
// would also demand `Request: Clone` for a type that is never stored.
impl<S: Clone, Request> Clone for BincodeService<S, Request> {
    fn clone(&self) -> Self {
        Self {
            service: self.service.clone(),
            phantom_request: PhantomData,
        }
    }
}

impl<S: std::fmt::Debug, Request> std::fmt::Debug for BincodeService<S, Request> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("BincodeService")
            .field("service", &self.service)
            .finish_non_exhaustive()
    }
}
|
|
|
|
/// An error encountered either while transforming data or against the interior Service.
|
|
#[derive(Debug, Error)]
|
|
pub enum BincodeServiceError {
|
|
/// An error occurred while polling the internal service.
|
|
#[error("Error while polling: {0}")]
|
|
Poll(String),
|
|
|
|
/// An error occurred while calling the internal service.
|
|
#[error("Error while calling: {0}")]
|
|
Call(String),
|
|
|
|
/// An error occurred while converting to or from bincode.
|
|
#[error("Error while converting: {0}")]
|
|
Convert(String),
|
|
}
|
|
|
|
impl<S, Request> Service<Vec<u8>> for BincodeService<S, Request>
|
|
where
|
|
S: Service<Request>,
|
|
Request: DeserializeOwned,
|
|
<S as Service<Request>>::Response: Serialize,
|
|
<S as Service<Request>>::Error: std::error::Error + Serialize,
|
|
<S as Service<Request>>::Future: Send + 'static,
|
|
{
|
|
type Response = Vec<u8>;
|
|
type Error = BincodeServiceError;
|
|
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
|
|
|
|
fn poll_ready(
|
|
&mut self,
|
|
cx: &mut std::task::Context<'_>,
|
|
) -> std::task::Poll<Result<(), Self::Error>> {
|
|
self.service
|
|
.poll_ready(cx)
|
|
.map_err(|e| BincodeServiceError::Poll(e.to_string()))
|
|
}
|
|
|
|
fn call(&mut self, req: Vec<u8>) -> Self::Future {
|
|
let request: Request = match deserialize(&req) {
|
|
Ok(r) => r,
|
|
Err(e) => {
|
|
return Box::pin(async move { Err(BincodeServiceError::Convert(e.to_string())) })
|
|
}
|
|
};
|
|
|
|
let response = self.service.call(request);
|
|
|
|
Box::pin(async move {
|
|
let response = response.await;
|
|
#[cfg(feature = "tracing")]
|
|
if let Err(e) = &response {
|
|
tracing::error!("Error performing derivation: {e}");
|
|
}
|
|
serialize(&response).map_err(|e| BincodeServiceError::Convert(e.to_string()))
|
|
})
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;
    use serde::{Deserialize, Serialize};
    use std::{future::Future, pin::Pin, task::Poll};
    use tower::{ServiceBuilder, ServiceExt};

    /// Payload type that is round-tripped through the [`BincodeService`].
    #[derive(Serialize, Deserialize)]
    struct Test {
        field: String,
    }

    impl Test {
        fn new() -> Self {
            Self {
                field: "hello world!".to_string(),
            }
        }
    }

    /// Echo service: checks that it received `Test::new()` and returns it unchanged.
    struct App;

    /// Uninhabited error type, because `App` never fails.
    #[derive(Debug, thiserror::Error, Serialize)]
    enum Infallible {}

    impl Service<Test> for App {
        type Response = Test;
        type Error = Infallible;
        type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

        fn poll_ready(
            &mut self,
            _cx: &mut std::task::Context<'_>,
        ) -> std::task::Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        fn call(&mut self, req: Test) -> Self::Future {
            Box::pin(async move {
                assert_eq!(req.field, Test::new().field);
                Ok(req)
            })
        }
    }

    #[tokio::test]
    async fn can_serde_responses() {
        let test = Test::new();
        let content = serialize(&test).unwrap();
        let mut service = ServiceBuilder::new()
            .layer(BincodeLayer::<Test>::default())
            .service(App);
        let result = service
            .ready()
            .await
            .unwrap()
            .call(content)
            .await
            .unwrap();
        assert_eq!(result, serialize(&Result::<Test, Infallible>::Ok(test)).unwrap());
    }

    #[tokio::test]
    async fn rejects_malformed_requests() {
        let mut service = ServiceBuilder::new()
            .layer(BincodeLayer::<Test>::default())
            .service(App);
        // An empty buffer cannot even hold the length prefix of `Test::field`,
        // so decoding must fail before the inner service is reached.
        let err = service
            .ready()
            .await
            .unwrap()
            .call(Vec::<u8>::new())
            .await
            .unwrap_err();
        assert!(matches!(err, BincodeServiceError::Convert(_)));
    }
}
|