Compare commits

...

3 Commits

5 changed files with 76 additions and 23 deletions

View File

@ -45,7 +45,6 @@ pub trait PrivateKey: Sized {
fn requires_hardened_derivation() -> bool { fn requires_hardened_derivation() -> bool {
false false
} }
} }
/// Errors associated with creating and arithmetic on private keys. This specific error is only /// Errors associated with creating and arithmetic on private keys. This specific error is only

View File

@ -1,4 +1,6 @@
use crate::{extended_key::private_key::Error as XPrvError, DerivationPath, ExtendedPrivateKey, PrivateKey}; use crate::{
extended_key::private_key::Error as XPrvError, DerivationPath, ExtendedPrivateKey, PrivateKey,
};
use keyfork_mnemonic_util::Mnemonic; use keyfork_mnemonic_util::Mnemonic;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};

View File

@ -214,7 +214,7 @@ impl Mnemonic {
} }
num num
}) })
.collect::<Vec<_>>() .collect::<Vec<_>>()
} }
#[must_use] #[must_use]

View File

@ -16,7 +16,7 @@ use service::Keyforkd;
pub async fn start_and_run_server(mnemonic: Mnemonic) -> Result<(), Box<dyn std::error::Error>> { pub async fn start_and_run_server(mnemonic: Mnemonic) -> Result<(), Box<dyn std::error::Error>> {
let service = ServiceBuilder::new() let service = ServiceBuilder::new()
.layer(middleware::SerdeLayer::new()) .layer(middleware::BincodeLayer::new())
.service(Keyforkd::new(mnemonic)); .service(Keyforkd::new(mnemonic));
let runtime_vars = std::env::vars() let runtime_vars = std::env::vars()

View File

@ -5,12 +5,12 @@ use serde::{de::DeserializeOwned, Serialize};
use thiserror::Error; use thiserror::Error;
use tower::{Layer, Service}; use tower::{Layer, Service};
pub struct SerdeLayer<'a, Request> { pub struct BincodeLayer<'a, Request> {
phantom: PhantomData<&'a ()>, phantom: PhantomData<&'a ()>,
phantom_request: PhantomData<&'a Request>, phantom_request: PhantomData<&'a Request>,
} }
impl<'a, Request> SerdeLayer<'a, Request> { impl<'a, Request> BincodeLayer<'a, Request> {
pub fn new() -> Self { pub fn new() -> Self {
Self { Self {
phantom: PhantomData, phantom: PhantomData,
@ -19,11 +19,11 @@ impl<'a, Request> SerdeLayer<'a, Request> {
} }
} }
impl<'a, S: 'a, Request> Layer<S> for SerdeLayer<'a, Request> { impl<'a, S: 'a, Request> Layer<S> for BincodeLayer<'a, Request> {
type Service = SerdeService<S, Request>; type Service = BincodeService<S, Request>;
fn layer(&self, service: S) -> Self::Service { fn layer(&self, service: S) -> Self::Service {
SerdeService { BincodeService {
service, service,
phantom_request: PhantomData, phantom_request: PhantomData,
} }
@ -31,13 +31,13 @@ impl<'a, S: 'a, Request> Layer<S> for SerdeLayer<'a, Request> {
} }
#[derive(Clone)] #[derive(Clone)]
pub struct SerdeService<S, Request> { pub struct BincodeService<S, Request> {
service: S, service: S,
phantom_request: PhantomData<Request>, phantom_request: PhantomData<Request>,
} }
#[derive(Debug, Error)] #[derive(Debug, Error)]
pub enum SerdeServiceError { pub enum BincodeServiceError {
#[error("Error while polling: {0}")] #[error("Error while polling: {0}")]
Poll(String), Poll(String),
@ -48,16 +48,16 @@ pub enum SerdeServiceError {
Convert(String), Convert(String),
} }
impl<S, Request> Service<Vec<u8>> for SerdeService<S, Request> impl<S, Request> Service<Vec<u8>> for BincodeService<S, Request>
where where
S: Service<Request> + Send + Sync, S: Service<Request>,
Request: DeserializeOwned + Send, Request: DeserializeOwned,
<S as Service<Request>>::Error: std::error::Error + Send, <S as Service<Request>>::Response: Serialize,
<S as Service<Request>>::Response: Serialize + Send, <S as Service<Request>>::Error: std::error::Error,
<S as Service<Request>>::Future: Send + 'static, <S as Service<Request>>::Future: Send + 'static,
{ {
type Response = Vec<u8>; type Response = Vec<u8>;
type Error = SerdeServiceError; type Error = BincodeServiceError;
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>; type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
fn poll_ready( fn poll_ready(
@ -66,16 +66,14 @@ where
) -> std::task::Poll<Result<(), Self::Error>> { ) -> std::task::Poll<Result<(), Self::Error>> {
self.service self.service
.poll_ready(cx) .poll_ready(cx)
.map_err(|e| SerdeServiceError::Poll(e.to_string())) .map_err(|e| BincodeServiceError::Poll(e.to_string()))
} }
fn call(&mut self, req: Vec<u8>) -> Self::Future { fn call(&mut self, req: Vec<u8>) -> Self::Future {
let request: Request = match deserialize(&req) { let request: Request = match deserialize(&req) {
Ok(r) => r, Ok(r) => r,
Err(e) => { Err(e) => {
return Box::pin( return Box::pin(async move { Err(BincodeServiceError::Convert(e.to_string())) })
async move { Err(SerdeServiceError::Convert(e.to_string())) },
)
} }
}; };
@ -84,8 +82,62 @@ where
Box::pin(async move { Box::pin(async move {
let response = response let response = response
.await .await
.map_err(|e| SerdeServiceError::Call(e.to_string()))?; .map_err(|e| BincodeServiceError::Call(e.to_string()))?;
serialize(&response).map_err(|e| SerdeServiceError::Convert(e.to_string())) serialize(&response).map_err(|e| BincodeServiceError::Convert(e.to_string()))
}) })
} }
} }
#[cfg(test)]
mod tests {
use super::*;
use serde::{Deserialize, Serialize};
use std::{future::Future, pin::Pin, task::Poll};
use tower::{ServiceBuilder, ServiceExt};
#[derive(Serialize, Deserialize)]
struct Test {
field: String,
}
struct App;
#[derive(Debug, thiserror::Error)]
enum Infallible {}
impl Service<Test> for App {
type Response = Test;
type Error = Infallible;
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
fn poll_ready(
&mut self,
_cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn call(&mut self, req: Test) -> Self::Future {
Box::pin(async { Ok(req) })
}
}
#[tokio::test]
async fn can_serde_responses() {
let content = serialize(&Test {
field: "hello world!".to_string(),
})
.unwrap();
let mut service = ServiceBuilder::new()
.layer(BincodeLayer::<Test>::new())
.service(App);
let result = service
.ready()
.await
.unwrap()
.call(content.clone())
.await
.unwrap();
assert_eq!(result, content);
}
}