keyforkd: require hardened derivation on two highest indexes #36

Manually merged
ryan merged 1 commits from ryan/harden-derivation-on-highest-level-keys into main 2024-04-10 19:36:05 +00:00
6 changed files with 32 additions and 5 deletions

2
Cargo.lock generated
View File

@ -1879,7 +1879,7 @@ dependencies = [
[[package]] [[package]]
name = "keyforkd" name = "keyforkd"
version = "0.1.0" version = "0.1.1"
dependencies = [ dependencies = [
"bincode", "bincode",
"hex-literal", "hex-literal",

View File

@ -25,6 +25,9 @@ fn secp256k1_test_suite() {
if chain_len < 2 { if chain_len < 2 {
continue; continue;
} }
if chain.iter().take(2).any(|index| !index.is_hardened()) {
continue;
}
// Consistency check: ensure the server and the client can each derive the same // Consistency check: ensure the server and the client can each derive the same
// key using an XPrv, for all but the last XPrv, which is verified after this // key using an XPrv, for all but the last XPrv, which is verified after this
for i in 2..chain_len { for i in 2..chain_len {

View File

@ -43,6 +43,10 @@ pub enum DerivationError {
#[error("Invalid derivation length: Expected at least 2, actual: {0}")] #[error("Invalid derivation length: Expected at least 2, actual: {0}")]
InvalidDerivationLength(usize), InvalidDerivationLength(usize),
/// The derivation request did not use hardened derivation on the 2 highest indexes.
#[error("Invalid derivation paths: expected index #{0} (1) to be hardened")]
InvalidDerivationPath(usize, u32),
/// An error occurred while deriving data. /// An error occurred while deriving data.
#[error("Derivation error: {0}")] #[error("Derivation error: {0}")]
Derivation(String), Derivation(String),

View File

@ -1,6 +1,6 @@
[package] [package]
name = "keyforkd" name = "keyforkd"
version = "0.1.0" version = "0.1.1"
edition = "2021" edition = "2021"
license = "AGPL-3.0-only" license = "AGPL-3.0-only"

View File

@ -69,6 +69,18 @@ impl Service<Request> for Keyforkd {
return Err(DerivationError::InvalidDerivationLength(len).into()); return Err(DerivationError::InvalidDerivationLength(len).into());
} }
if let Some((i, unhardened_index)) = req
.path()
.iter()
.take(2)
.enumerate()
.find(|(_, index)| {
!index.is_hardened()
})
{
return Err(DerivationError::InvalidDerivationPath(i, unhardened_index.inner()).into())
}
#[cfg(feature = "tracing")] #[cfg(feature = "tracing")]
if let Some(target) = guess_target(req.path()) { if let Some(target) = guess_target(req.path()) {
info!("Deriving path: {target}"); info!("Deriving path: {target}");
@ -110,6 +122,9 @@ mod tests {
if chain.len() < 2 { if chain.len() < 2 {
continue; continue;
} }
if chain.iter().take(2).any(|index| !index.is_hardened()) {
continue;
}
let req = DerivationRequest::new(DerivationAlgorithm::Secp256k1, &chain); let req = DerivationRequest::new(DerivationAlgorithm::Secp256k1, &chain);
let response: DerivationResponse = keyforkd let response: DerivationResponse = keyforkd
.ready() .ready()

View File

@ -61,7 +61,7 @@ where
)); ));
let socket_dir = tempfile::tempdir().expect(bug!("can't create tempdir")); let socket_dir = tempfile::tempdir().expect(bug!("can't create tempdir"));
let socket_path = socket_dir.path().join("keyforkd.sock"); let socket_path = socket_dir.path().join("keyforkd.sock");
rt.block_on(async move { let result = rt.block_on(async move {
let (tx, mut rx) = tokio::sync::mpsc::channel(1); let (tx, mut rx) = tokio::sync::mpsc::channel(1);
let server_handle = tokio::spawn({ let server_handle = tokio::spawn({
let socket_path = socket_path.clone(); let socket_path = socket_path.clone();
@ -87,8 +87,13 @@ where
let result = test_handle.await; let result = test_handle.await;
server_handle.abort(); server_handle.abort();
result result
}) });
.expect(bug!("runtime could not join all threads")) if let Err(e) = result {
if e.is_panic() {
std::panic::resume_unwind(e.into_panic());
}
}
Ok(())
} }
#[cfg(test)] #[cfg(test)]