use std::{future::Future, pin::Pin, sync::Arc, task::Poll}; use keyfork_derive_util::request::{DerivationError, DerivationRequest, DerivationResponse}; use keyfork_mnemonic_util::Mnemonic; use tower::Service; // NOTE: All values implemented in Keyforkd must implement Clone with low overhead, either by // using an Arc or by having a small signature. This is because Service takes &mut self. #[derive(thiserror::Error, Debug)] pub enum KeyforkdRequestError { #[error("Invalid derivation length: Expected: 2, actual: {0}")] InvalidDerivationLength(usize), #[error("Derivation error: {0}")] Derivation(#[from] DerivationError), } #[derive(Clone, Debug)] pub struct Keyforkd { mnemonic: Arc, } impl Keyforkd { pub fn new(mnemonic: Mnemonic) -> Self { Self { mnemonic: Arc::new(mnemonic), } } } impl Service for Keyforkd { type Response = DerivationResponse; type Error = KeyforkdRequestError; type Future = Pin> + Send>>; fn poll_ready( &mut self, _cx: &mut std::task::Context<'_>, ) -> std::task::Poll> { Poll::Ready(Ok(())) } #[cfg_attr(feature = "tracing", tracing::instrument(skip(self)))] fn call(&mut self, req: DerivationRequest) -> Self::Future { let mnemonic = self.mnemonic.clone(); Box::pin(async move { let len = req.path().len(); if len < 2 { return Err(KeyforkdRequestError::InvalidDerivationLength(len)); } req.derive_with_mnemonic(&mnemonic) .map_err(KeyforkdRequestError::from) }) } } #[cfg(test)] mod tests { use super::*; use hex_literal::hex; use keyfork_derive_util::{request::*, DerivationPath}; use keyfork_mnemonic_util::Wordlist; use std::str::FromStr; use tower::ServiceExt; #[tokio::test] async fn properly_derives_data() { // Pulled from keyfork-derive-util's tests, which is more extensively tested. let tests = [ /* * Note: Tests excluded because the derivation path is not deep enough * for the API's preferences. ( &hex!("000102030405060708090a0b0c0d0e0f")[..], DerivationPath::from_str("m").unwrap(), hex!("90046a93de5380a72b5e45010748567d5ea02bbf6522f979e05c0d8d8ca9fffb"), hex!("2b4be7f19ee27bbf30c667b642d5f4aa69fd169872f8fc3059c08ebae2eb19e7"), hex!("00a4b2856bfec510abab89753fac1ac0e1112364e7d250545963f135f2a33188ed"), ), ( &hex!("000102030405060708090a0b0c0d0e0f")[..], DerivationPath::from_str("m/0'").unwrap(), hex!("8b59aa11380b624e81507a27fedda59fea6d0b779a778918a2fd3590e16e9c69"), hex!("68e0fe46dfb67e368c75379acec591dad19df3cde26e63b93a8e704f1dade7a3"), hex!("008c8a13df77a28f3445213a0f432fde644acaa215fc72dcdf300d5efaa85d350c"), ), */ ( &hex!("000102030405060708090a0b0c0d0e0f")[..], DerivationPath::from_str("m/0'/1'/2'/2'/1000000000'").unwrap(), hex!("68789923a0cac2cd5a29172a475fe9e0fb14cd6adb5ad98a3fa70333e7afa230"), hex!("8f94d394a8e8fd6b1bc2f3f49f5c47e385281d5c17e65324b0f62483e37e8793"), hex!("003c24da049451555d51a7014a37337aa4e12d41e485abccfa46b47dfb2af54b7a"), ), ]; let wordlist = Wordlist::default().arc(); for (seed, path, _, private_key, _) in tests { let mnemonic = Mnemonic::from_entropy(&seed[..], wordlist.clone()).unwrap(); assert_eq!(mnemonic.seed(), seed); let req = DerivationRequest::new(DerivationAlgorithm::Ed25519, path); let mut keyforkd = Keyforkd::new(mnemonic); let response = keyforkd.ready().await.unwrap().call(req).await.unwrap(); assert_eq!(response.data, private_key) } } #[should_panic] #[tokio::test] async fn errors_on_no_path() { let tests = [( &hex!("000102030405060708090a0b0c0d0e0f")[..], DerivationPath::from_str("m").unwrap(), hex!("90046a93de5380a72b5e45010748567d5ea02bbf6522f979e05c0d8d8ca9fffb"), hex!("2b4be7f19ee27bbf30c667b642d5f4aa69fd169872f8fc3059c08ebae2eb19e7"), hex!("00a4b2856bfec510abab89753fac1ac0e1112364e7d250545963f135f2a33188ed"), )]; let wordlist = Wordlist::default().arc(); for (seed, path, _, private_key, _) in tests { let mnemonic = Mnemonic::from_entropy(&seed[..], wordlist.clone()).unwrap(); assert_eq!(mnemonic.seed(), seed); let req = DerivationRequest::new(DerivationAlgorithm::Ed25519, path); let mut keyforkd = Keyforkd::new(mnemonic); let response = keyforkd.ready().await.unwrap().call(req).await.unwrap(); assert_eq!(response.data, private_key) } } #[should_panic] #[tokio::test] async fn errors_on_short_path() { let tests = [( &hex!("000102030405060708090a0b0c0d0e0f")[..], DerivationPath::from_str("m/0'").unwrap(), hex!("8b59aa11380b624e81507a27fedda59fea6d0b779a778918a2fd3590e16e9c69"), hex!("68e0fe46dfb67e368c75379acec591dad19df3cde26e63b93a8e704f1dade7a3"), hex!("008c8a13df77a28f3445213a0f432fde644acaa215fc72dcdf300d5efaa85d350c"), )]; let wordlist = Wordlist::default().arc(); for (seed, path, _, private_key, _) in tests { let mnemonic = Mnemonic::from_entropy(&seed[..], wordlist.clone()).unwrap(); assert_eq!(mnemonic.seed(), seed); let req = DerivationRequest::new(DerivationAlgorithm::Ed25519, path); let mut keyforkd = Keyforkd::new(mnemonic); let response = keyforkd.ready().await.unwrap().call(req).await.unwrap(); assert_eq!(response.data, private_key) } } }