191 lines
6.4 KiB
Rust
191 lines
6.4 KiB
Rust
#![allow(clippy::implicit_clone)]
|
|
|
|
use std::{future::Future, pin::Pin, sync::Arc, task::Poll};
|
|
|
|
use keyfork_derive_path_data::guess_target;
|
|
// use keyfork_derive_util::request::{DerivationError, DerivationRequest, DerivationResponse};
|
|
use keyforkd_models::{DerivationError, Error, Request, Response};
|
|
use tower::Service;
|
|
use tracing::info;
|
|
|
|
// NOTE: All values implemented in Keyforkd must implement Clone with low overhead, either by
// using an Arc or by being small enough to copy cheaply. This is because Service<T>::call takes
// &mut self, so every concurrent user of the service needs its own clone.
|
|
/// The Keyforkd service state: the master seed that derivation requests are run against.
///
/// The seed is stored behind an [`Arc`], so cloning the service only bumps a reference count
/// and never copies the seed bytes.
#[derive(Clone)]
pub struct Keyforkd {
    /// Master seed shared by every clone of the service.
    seed: Arc<Vec<u8>>,
}

// `Debug` is implemented by hand instead of derived so that the seed bytes can never reach logs,
// panic messages, or tracing spans through `{:?}` formatting.
impl std::fmt::Debug for Keyforkd {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Keyforkd")
            .field("seed", &format_args!("<redacted>"))
            .finish()
    }
}
|
|
|
|
impl Keyforkd {
|
|
pub fn new(seed: Vec<u8>) -> Self {
|
|
Self {
|
|
seed: Arc::new(seed),
|
|
}
|
|
}
|
|
}
|
|
|
|
impl Service<Request> for Keyforkd {
|
|
type Response = Response;
|
|
|
|
// TODO: indicate serialize in BincodeLayer
|
|
type Error = Error;
|
|
|
|
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
|
|
|
|
fn poll_ready(
|
|
&mut self,
|
|
_cx: &mut std::task::Context<'_>,
|
|
) -> std::task::Poll<Result<(), Self::Error>> {
|
|
Poll::Ready(Ok(()))
|
|
}
|
|
|
|
#[cfg_attr(feature = "tracing", tracing::instrument(skip(self)))]
|
|
fn call(&mut self, req: Request) -> Self::Future {
|
|
let seed = self.seed.clone();
|
|
match req {
|
|
Request::Derivation(req) => Box::pin(async move {
|
|
let len = req.path().len();
|
|
if len < 2 {
|
|
return Err(DerivationError::InvalidDerivationLength(len).into());
|
|
}
|
|
|
|
#[cfg(feature = "tracing")]
|
|
if let Some(target) = guess_target(req.path()) {
|
|
info!("Deriving path: {target}");
|
|
} else {
|
|
info!("Deriving path: {}", req.path());
|
|
}
|
|
|
|
req.derive_with_master_seed((*seed).clone())
|
|
.map(Response::Derivation)
|
|
.map_err(|e| DerivationError::Derivation(e.to_string()).into())
|
|
}),
|
|
Request::DerivationWithTTY(_, _) => todo!(),
|
|
}
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;
    use hex_literal::hex;
    use keyfork_derive_util::{request::*, DerivationPath};
    use keyfork_slip10_test_data::test_data;
    use std::str::FromStr;
    use tower::ServiceExt;

    /// Runs every test vector that `test_data()` lists under `curve` through the service and
    /// compares the derived private key and chain code with the expected values.
    ///
    /// `new_request` builds the request for a given path, which lets the caller choose the
    /// algorithm. Chains with fewer than two components are skipped because the service rejects
    /// them.
    async fn assert_derives_test_vectors(
        curve: &str,
        new_request: impl Fn(&DerivationPath) -> DerivationRequest,
    ) {
        let tests = test_data().unwrap().remove(&curve.to_string()).unwrap();

        for per_seed in tests {
            let seed = &per_seed.seed;

            let mut keyforkd = Keyforkd::new(seed.to_vec());
            for test in &per_seed.tests {
                let chain = DerivationPath::from_str(test.chain).unwrap();
                if chain.len() < 2 {
                    continue;
                }
                let req = new_request(&chain);
                let response: DerivationResponse = keyforkd
                    .ready()
                    .await
                    .unwrap()
                    .call(Request::Derivation(req))
                    .await
                    .unwrap()
                    .try_into()
                    .unwrap();
                assert_eq!(response.data, test.private_key);
                assert_eq!(response.chain_code.as_slice(), test.chain_code);
            }
        }
    }

    /// Asserts that the service answers a derivation request for `path` with an error.
    ///
    /// Only the result of the service call is inspected, so a failure while building the
    /// fixture (for example an unparsable path) fails the test instead of being mistaken for the
    /// expected rejection.
    async fn assert_rejects_path(path: &str) {
        let seed = hex!("000102030405060708090a0b0c0d0e0f");
        let path = DerivationPath::from_str(path).unwrap();
        let req = DerivationRequest::new(DerivationAlgorithm::Ed25519, &path);
        let mut keyforkd = Keyforkd::new(seed.to_vec());

        let result = keyforkd
            .ready()
            .await
            .unwrap()
            .call(Request::Derivation(req))
            .await;

        assert!(
            result.is_err(),
            "paths with fewer than two components must be rejected"
        );
    }

    #[tokio::test]
    async fn properly_derives_secp256k1() {
        assert_derives_test_vectors("secp256k1", |path| {
            DerivationRequest::new(DerivationAlgorithm::Secp256k1, path)
        })
        .await;
    }

    #[tokio::test]
    async fn properly_derives_ed25519() {
        assert_derives_test_vectors("ed25519", |path| {
            DerivationRequest::new(DerivationAlgorithm::Ed25519, path)
        })
        .await;
    }

    #[tokio::test]
    async fn errors_on_no_path() {
        assert_rejects_path("m").await;
    }

    #[tokio::test]
    async fn errors_on_short_path() {
        assert_rejects_path("m/0'").await;
    }
}
|