diff --git a/keyfork-prompt/src/bin/test-basic-prompt.rs b/keyfork-prompt/src/bin/test-basic-prompt.rs index e02b68c..c12cd62 100644 --- a/keyfork-prompt/src/bin/test-basic-prompt.rs +++ b/keyfork-prompt/src/bin/test-basic-prompt.rs @@ -1,19 +1,36 @@ -use std::{ - io::{stdin, stdout}, - str::FromStr, -}; +use std::io::{stdin, stdout}; -use keyfork_mnemonic_util::Mnemonic; -use keyfork_prompt::{qrencode, Message, PromptManager}; +use keyfork_prompt::{ + validators::{mnemonic, Validator}, + PromptManager, +}; fn main() -> Result<(), Box> { let mut mgr = PromptManager::new(stdin(), stdout())?; - mgr.prompt_passphrase("Passphrase: ")?; - let string = mgr.prompt_wordlist("Mnemonic: ", &Default::default())?; - let mnemonic = Mnemonic::from_str(&string).unwrap(); - let entropy = mnemonic.entropy(); - mgr.prompt_message(&Message::Text(format!("Your entropy is: {entropy:X?}")))?; - let qrcode = qrencode::qrencode(&string)?; - mgr.prompt_message(&Message::Data(qrcode))?; + let transport_validator = mnemonic::MnemonicSetValidator { + word_lengths: [9, 24], + }; + let combine_validator = mnemonic::MnemonicSetValidator { + word_lengths: [24, 48], + }; + + let mnemonics = mgr.prompt_validated_wordlist( + "Enter a 9-word and 24-word mnemonic: ", + &Default::default(), + 3, + transport_validator.to_fn(), + )?; + assert_eq!(mnemonics[0].entropy().len(), 12); + assert_eq!(mnemonics[1].entropy().len(), 32); + + let mnemonics = mgr.prompt_validated_wordlist( + "Enter a 24 and 48-word mnemonic: ", + &Default::default(), + 3, + combine_validator.to_fn(), + )?; + assert_eq!(mnemonics[0].entropy().len(), 32); + assert_eq!(mnemonics[1].entropy().len(), 64); + Ok(()) } diff --git a/keyfork-prompt/src/lib.rs b/keyfork-prompt/src/lib.rs index 81e2ff1..af60c0a 100644 --- a/keyfork-prompt/src/lib.rs +++ b/keyfork-prompt/src/lib.rs @@ -155,7 +155,6 @@ where } Event::Key(k) => match k.code { KeyCode::Enter => { - input.push('\n'); break; } KeyCode::Backspace => { @@ -302,7 +301,6 @@ where } Event::Key(k) => match k.code { KeyCode::Enter => { - passphrase.push('\n'); break; } KeyCode::Backspace => { diff --git a/keyfork-prompt/src/validators.rs b/keyfork-prompt/src/validators.rs index d602a05..7fc4c9e 100644 --- a/keyfork-prompt/src/validators.rs +++ b/keyfork-prompt/src/validators.rs @@ -1,4 +1,5 @@ #![allow(clippy::type_complexity)] +use std::ops::RangeInclusive; pub trait Validator { type Output; @@ -19,11 +20,12 @@ pub enum PinError { InvalidCharacters(char, usize), } +/// Validate that a PIN is of a certain length and matches a range of characters. #[derive(Default, Clone)] pub struct PinValidator { pub min_length: Option, pub max_length: Option, - pub range: Option>, + pub range: Option>, } impl Validator for PinValidator { @@ -52,3 +54,117 @@ impl Validator for PinValidator { }) } } + +#[cfg(feature = "mnemonic")] +pub mod mnemonic { + use std::{mem::MaybeUninit, ops::Range, str::FromStr}; + + use super::Validator; + + use keyfork_mnemonic_util::{Mnemonic, MnemonicFromStrError}; + + #[derive(thiserror::Error, Debug)] + pub enum MnemonicValidationError { + #[error("Invalid word length: {0} does not match {1:?}")] + InvalidLength(usize, WordLength), + + #[error("{0}")] + MnemonicFromStrError(#[from] MnemonicFromStrError), + } + + #[derive(Clone, Debug)] + pub enum WordLength { + Range(Range), + Count(usize), + } + + impl WordLength { + fn matches(&self, word_count: usize) -> bool { + match self { + WordLength::Range(r) => r.contains(&word_count), + WordLength::Count(c) => c == &word_count, + } + } + } + + /// Validate a mnemonic of a range of word lengths or a specific length. + #[derive(Default, Clone)] + pub struct MnemonicValidator { + pub word_length: Option, + } + + impl Validator for MnemonicValidator { + type Output = Mnemonic; + type Error = MnemonicValidationError; + + fn to_fn(&self) -> Box Result> { + let word_length = self.word_length.clone(); + Box::new(move |s: String| match word_length.as_ref() { + Some(wl) => { + let count = s.split_whitespace().count(); + if !wl.matches(count) { + return Err(Self::Error::InvalidLength(count, wl.clone())); + } + let m = Mnemonic::from_str(&s)?; + Ok(m) + } + None => { + let m = Mnemonic::from_str(&s)?; + Ok(m) + } + }) + } + } + + #[derive(thiserror::Error, Debug)] + pub enum MnemonicSetValidationError { + #[error("Invalid word length in set {0}: {1} != expected {2}")] + InvalidSetLength(usize, usize, usize), + + #[error("Error parsing mnemonic set {0}: {1}")] + MnemonicFromStrError(usize, MnemonicFromStrError), + } + + /// Validate a set of mnemonics of a specific word length. + #[derive(Clone)] + pub struct MnemonicSetValidator { + pub word_lengths: [usize; N], + } + + impl Validator for MnemonicSetValidator { + type Output = [Mnemonic; N]; + type Error = MnemonicSetValidationError; + + fn to_fn(&self) -> Box Result> { + let word_lengths = self.word_lengths; + Box::new(move |s: String| { + let mut counter: usize = 0; + let mut output = Vec::with_capacity(N); + for (word_set, word_length) in word_lengths.into_iter().enumerate() { + let words = s + .split_whitespace() + .skip(counter) + .take(word_length) + .collect::>(); + if words.len() != word_length { + return Err(MnemonicSetValidationError::InvalidSetLength( + word_set, + words.len(), + word_length, + )); + } + let mnemonic = match Mnemonic::from_str(&words.join(" ")) { + Ok(m) => m, + Err(e) => return Err(Self::Error::MnemonicFromStrError(word_set, e)), + }; + output.push(mnemonic); + counter += word_length; + } + + Ok(output + .try_into() + .expect("vec with capacity of const N was not filled")) + }) + } + } +}