keyfork-mnemonic-util: reduce amount of generics for validated functions

This commit is contained in:
Ryan Heywood 2024-02-19 05:32:24 -05:00
parent 44d8cf2098
commit dfcf4b1740
Signed by: ryan
GPG Key ID: 8E401478A3FBEF72
7 changed files with 43 additions and 46 deletions

View File

@ -130,7 +130,7 @@ pub fn remote_decrypt(w: &mut impl Write) -> Result<(), Box<dyn std::error::Erro
}; };
let [pubkey_mnemonic, payload_mnemonic] = pm let [pubkey_mnemonic, payload_mnemonic] = pm
.prompt_validated_wordlist::<English, _, _, _>( .prompt_validated_wordlist::<English, _>(
QRCODE_COULDNT_READ, QRCODE_COULDNT_READ,
3, 3,
validator.to_fn(), validator.to_fn(),

View File

@ -495,7 +495,7 @@ pub fn decrypt(
let validator = MnemonicSetValidator { let validator = MnemonicSetValidator {
word_lengths: [9, 24], word_lengths: [9, 24],
}; };
let [nonce_mnemonic, pubkey_mnemonic] = pm.prompt_validated_wordlist::<English, _, _, _>( let [nonce_mnemonic, pubkey_mnemonic] = pm.prompt_validated_wordlist::<English, _>(
QRCODE_COULDNT_READ, QRCODE_COULDNT_READ,
3, 3,
validator.to_fn(), validator.to_fn(),

View File

@ -69,7 +69,7 @@ impl RecoverSubcommands {
let validator = MnemonicChoiceValidator { let validator = MnemonicChoiceValidator {
word_lengths: [WordLength::Count(12), WordLength::Count(24)], word_lengths: [WordLength::Count(12), WordLength::Count(24)],
}; };
let mnemonic = term.prompt_validated_wordlist::<English, _, _, _>( let mnemonic = term.prompt_validated_wordlist::<English, _>(
"Mnemonic: ", "Mnemonic: ",
3, 3,
validator.to_fn(), validator.to_fn(),

View File

@ -18,7 +18,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
word_lengths: [24, 48], word_lengths: [24, 48],
}; };
let mnemonics = mgr.prompt_validated_wordlist::<English, _, _, _>( let mnemonics = mgr.prompt_validated_wordlist::<English, _>(
"Enter a 9-word and 24-word mnemonic: ", "Enter a 9-word and 24-word mnemonic: ",
3, 3,
transport_validator.to_fn(), transport_validator.to_fn(),
@ -26,7 +26,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
assert_eq!(mnemonics[0].as_bytes().len(), 12); assert_eq!(mnemonics[0].as_bytes().len(), 12);
assert_eq!(mnemonics[1].as_bytes().len(), 32); assert_eq!(mnemonics[1].as_bytes().len(), 32);
let mnemonics = mgr.prompt_validated_wordlist::<English, _, _, _>( let mnemonics = mgr.prompt_validated_wordlist::<English, _>(
"Enter a 24 and 48-word mnemonic: ", "Enter a 24 and 48-word mnemonic: ",
3, 3,
combine_validator.to_fn(), combine_validator.to_fn(),

View File

@ -68,15 +68,13 @@ pub trait PromptHandler {
/// The method may return an error if the message was not able to be displayed, if the input /// The method may return an error if the message was not able to be displayed, if the input
/// could not be read, or if the parser returned an error. /// could not be read, or if the parser returned an error.
#[cfg(feature = "mnemonic")] #[cfg(feature = "mnemonic")]
fn prompt_validated_wordlist<X, V, F, E>( fn prompt_validated_wordlist<X, V>(
&mut self, &mut self,
prompt: &str, prompt: &str,
retries: u8, retries: u8,
validator_fn: F, validator_fn: impl Fn(String) -> Result<V, Box<dyn std::error::Error>>,
) -> Result<V, Error> ) -> Result<V, Error>
where where
F: Fn(String) -> Result<V, E>,
E: std::error::Error,
X: Wordlist; X: Wordlist;
/// Prompt the user for a passphrase, which is hidden while typing. /// Prompt the user for a passphrase, which is hidden while typing.
@ -92,15 +90,12 @@ pub trait PromptHandler {
/// # Errors /// # Errors
/// The method may return an error if the message was not able to be displayed, if the input /// The method may return an error if the message was not able to be displayed, if the input
/// could not be read, or if the parser returned an error. /// could not be read, or if the parser returned an error.
fn prompt_validated_passphrase<V, F, E>( fn prompt_validated_passphrase<V>(
&mut self, &mut self,
prompt: &str, prompt: &str,
retries: u8, retries: u8,
validator_fn: F, validator_fn: impl Fn(String) -> Result<V, Box<dyn std::error::Error>>,
) -> Result<V, Error> ) -> Result<V, Error>;
where
F: Fn(String) -> Result<V, E>,
E: std::error::Error;
/// Prompt the user with a [`Message`]. /// Prompt the user with a [`Message`].
/// ///

View File

@ -1,6 +1,7 @@
use std::{ use std::{
borrow::Borrow,
io::{stderr, stdin, BufRead, BufReader, Read, Stderr, Stdin, Write}, io::{stderr, stdin, BufRead, BufReader, Read, Stderr, Stdin, Write},
os::fd::AsRawFd, borrow::Borrow, os::fd::AsRawFd,
}; };
use keyfork_crossterm::{ use keyfork_crossterm::{
@ -12,7 +13,7 @@ use keyfork_crossterm::{
ExecutableCommand, QueueableCommand, ExecutableCommand, QueueableCommand,
}; };
use crate::{PromptHandler, Message, Wordlist, Error}; use crate::{Error, Message, PromptHandler, Wordlist};
#[allow(missing_docs)] #[allow(missing_docs)]
pub type Result<T, E = Error> = std::result::Result<T, E>; pub type Result<T, E = Error> = std::result::Result<T, E>;
@ -155,11 +156,13 @@ where
fn lock(&mut self) -> TerminalGuard<'_, R, W> { fn lock(&mut self) -> TerminalGuard<'_, R, W> {
TerminalGuard::new(&mut self.read, &mut self.write, &mut self.terminal) TerminalGuard::new(&mut self.read, &mut self.write, &mut self.terminal)
} }
} }
impl<R, W> PromptHandler for Terminal<R, W> where R: Read + Sized, W: Write + AsRawFd + Sized { impl<R, W> PromptHandler for Terminal<R, W>
where
R: Read + Sized,
W: Write + AsRawFd + Sized,
{
fn prompt_input(&mut self, prompt: &str) -> Result<String> { fn prompt_input(&mut self, prompt: &str) -> Result<String> {
let mut terminal = self.lock().alternate_screen()?; let mut terminal = self.lock().alternate_screen()?;
terminal terminal
@ -182,15 +185,13 @@ impl<R, W> PromptHandler for Terminal<R, W> where R: Read + Sized, W: Write + As
} }
#[cfg(feature = "mnemonic")] #[cfg(feature = "mnemonic")]
fn prompt_validated_wordlist<X, V, F, E>( fn prompt_validated_wordlist<X, V>(
&mut self, &mut self,
prompt: &str, prompt: &str,
retries: u8, retries: u8,
validator_fn: F, validator_fn: impl Fn(String) -> Result<V, Box<dyn std::error::Error>>,
) -> Result<V, Error> ) -> Result<V, Error>
where where
F: Fn(String) -> Result<V, E>,
E: std::error::Error,
X: Wordlist, X: Wordlist,
{ {
let mut last_error = None; let mut last_error = None;
@ -214,7 +215,10 @@ impl<R, W> PromptHandler for Terminal<R, W> where R: Read + Sized, W: Write + As
#[cfg(feature = "mnemonic")] #[cfg(feature = "mnemonic")]
#[allow(clippy::too_many_lines)] #[allow(clippy::too_many_lines)]
fn prompt_wordlist<X>(&mut self, prompt: &str) -> Result<String> where X: Wordlist { fn prompt_wordlist<X>(&mut self, prompt: &str) -> Result<String>
where
X: Wordlist,
{
let wordlist = X::get_singleton(); let wordlist = X::get_singleton();
let words = wordlist.to_str_array(); let words = wordlist.to_str_array();
@ -340,16 +344,12 @@ impl<R, W> PromptHandler for Terminal<R, W> where R: Read + Sized, W: Write + As
Ok(input) Ok(input)
} }
fn prompt_validated_passphrase<V, F, E>( fn prompt_validated_passphrase<V>(
&mut self, &mut self,
prompt: &str, prompt: &str,
retries: u8, retries: u8,
validator_fn: F, validator_fn: impl Fn(String) -> Result<V, Box<dyn std::error::Error>>,
) -> Result<V, Error> ) -> Result<V, Error> {
where
F: Fn(String) -> Result<V, E>,
E: std::error::Error,
{
let mut last_error = None; let mut last_error = None;
for _ in 0..retries { for _ in 0..retries {
let s = self.prompt_passphrase(prompt)?; let s = self.prompt_passphrase(prompt)?;

View File

@ -12,7 +12,7 @@ pub trait Validator {
type Error; type Error;
/// Create a validator function from the given parameters. /// Create a validator function from the given parameters.
fn to_fn(&self) -> Box<dyn Fn(String) -> Result<Self::Output, Self::Error>>; fn to_fn(&self) -> Box<dyn Fn(String) -> Result<Self::Output, Box<dyn std::error::Error>>>;
} }
/// A PIN could not be validated from the given input. /// A PIN could not be validated from the given input.
@ -48,7 +48,7 @@ impl Validator for PinValidator {
type Output = String; type Output = String;
type Error = PinError; type Error = PinError;
fn to_fn(&self) -> Box<dyn Fn(String) -> Result<String, PinError>> { fn to_fn(&self) -> Box<dyn Fn(String) -> Result<String, Box<dyn std::error::Error>>> {
let min_len = self.min_length.unwrap_or(usize::MIN); let min_len = self.min_length.unwrap_or(usize::MIN);
let max_len = self.max_length.unwrap_or(usize::MAX); let max_len = self.max_length.unwrap_or(usize::MAX);
let range = self.range.clone().unwrap_or('0'..='9'); let range = self.range.clone().unwrap_or('0'..='9');
@ -56,14 +56,14 @@ impl Validator for PinValidator {
s.truncate(s.trim_end().len()); s.truncate(s.trim_end().len());
let len = s.len(); let len = s.len();
if len < min_len { if len < min_len {
return Err(PinError::TooShort(len, min_len)); return Err(Box::new(PinError::TooShort(len, min_len)));
} }
if len > max_len { if len > max_len {
return Err(PinError::TooLong(len, max_len)); return Err(Box::new(PinError::TooLong(len, max_len)));
} }
for (index, ch) in s.chars().enumerate() { for (index, ch) in s.chars().enumerate() {
if !range.contains(&ch) { if !range.contains(&ch) {
return Err(PinError::InvalidCharacters(ch, index)); return Err(Box::new(PinError::InvalidCharacters(ch, index)));
} }
} }
Ok(s) Ok(s)
@ -123,13 +123,13 @@ pub mod mnemonic {
type Output = Mnemonic; type Output = Mnemonic;
type Error = MnemonicValidationError; type Error = MnemonicValidationError;
fn to_fn(&self) -> Box<dyn Fn(String) -> Result<Mnemonic, Self::Error>> { fn to_fn(&self) -> Box<dyn Fn(String) -> Result<Mnemonic, Box<dyn std::error::Error>>> {
let word_length = self.word_length.clone(); let word_length = self.word_length.clone();
Box::new(move |s: String| match word_length.as_ref() { Box::new(move |s: String| match word_length.as_ref() {
Some(wl) => { Some(wl) => {
let count = s.split_whitespace().count(); let count = s.split_whitespace().count();
if !wl.matches(count) { if !wl.matches(count) {
return Err(Self::Error::InvalidLength(count, wl.clone())); return Err(Box::new(Self::Error::InvalidLength(count, wl.clone())));
} }
let m = Mnemonic::from_str(&s)?; let m = Mnemonic::from_str(&s)?;
Ok(m) Ok(m)
@ -165,7 +165,7 @@ pub mod mnemonic {
type Output = Mnemonic; type Output = Mnemonic;
type Error = MnemonicChoiceValidationError; type Error = MnemonicChoiceValidationError;
fn to_fn(&self) -> Box<dyn Fn(String) -> Result<Self::Output, Self::Error>> { fn to_fn(&self) -> Box<dyn Fn(String) -> Result<Self::Output, Box<dyn std::error::Error>>> {
let word_lengths = self.word_lengths.clone(); let word_lengths = self.word_lengths.clone();
Box::new(move |s: String| { Box::new(move |s: String| {
let count = s.split_whitespace().count(); let count = s.split_whitespace().count();
@ -175,10 +175,10 @@ pub mod mnemonic {
return Ok(m); return Ok(m);
} }
} }
Err(MnemonicChoiceValidationError::InvalidLength( Err(Box::new(MnemonicChoiceValidationError::InvalidLength(
count, count,
word_lengths.to_vec(), word_lengths.to_vec(),
)) )))
}) })
} }
} }
@ -207,7 +207,7 @@ pub mod mnemonic {
type Output = [Mnemonic; N]; type Output = [Mnemonic; N];
type Error = MnemonicSetValidationError; type Error = MnemonicSetValidationError;
fn to_fn(&self) -> Box<dyn Fn(String) -> Result<Self::Output, Self::Error>> { fn to_fn(&self) -> Box<dyn Fn(String) -> Result<Self::Output, Box<dyn std::error::Error>>> {
let word_lengths = self.word_lengths; let word_lengths = self.word_lengths;
Box::new(move |s: String| { Box::new(move |s: String| {
let mut counter: usize = 0; let mut counter: usize = 0;
@ -219,15 +219,17 @@ pub mod mnemonic {
.take(word_length) .take(word_length)
.collect::<Vec<_>>(); .collect::<Vec<_>>();
if words.len() != word_length { if words.len() != word_length {
return Err(MnemonicSetValidationError::InvalidSetLength( return Err(Box::new(MnemonicSetValidationError::InvalidSetLength(
word_set, word_set,
words.len(), words.len(),
word_length, word_length,
)); )));
} }
let mnemonic = match Mnemonic::from_str(&words.join(" ")) { let mnemonic = match Mnemonic::from_str(&words.join(" ")) {
Ok(m) => m, Ok(m) => m,
Err(e) => return Err(Self::Error::MnemonicFromStrError(word_set, e)), Err(e) => {
return Err(Box::new(Self::Error::MnemonicFromStrError(word_set, e)))
}
}; };
output.push(mnemonic); output.push(mnemonic);
counter += word_length; counter += word_length;