diff --git a/keyfork-prompt/src/bin/test-basic-prompt.rs b/keyfork-prompt/src/bin/test-basic-prompt.rs index 5d529e8..8bf1f3b 100644 --- a/keyfork-prompt/src/bin/test-basic-prompt.rs +++ b/keyfork-prompt/src/bin/test-basic-prompt.rs @@ -2,7 +2,7 @@ use std::io::{stdin, stdout}; use keyfork_prompt::{ validators::{mnemonic, Validator}, - Terminal, + Terminal, PromptHandler, }; fn main() -> Result<(), Box> { diff --git a/keyfork-prompt/src/lib.rs b/keyfork-prompt/src/lib.rs index 4fc00e2..69fa38a 100644 --- a/keyfork-prompt/src/lib.rs +++ b/keyfork-prompt/src/lib.rs @@ -1,6 +1,6 @@ use std::{ io::{stderr, stdin, BufRead, BufReader, Read, Stderr, Stdin, Write}, - os::fd::AsRawFd, + os::fd::AsRawFd, borrow::Borrow, }; #[cfg(feature = "mnemonic")] @@ -10,9 +10,9 @@ use keyfork_crossterm::{ cursor, event::{read, DisableBracketedPaste, EnableBracketedPaste, Event, KeyCode, KeyModifiers}, style::{Print, PrintStyledContent, Stylize}, - terminal::{self, TerminalIoctl, FdTerminal, EnterAlternateScreen, LeaveAlternateScreen}, + terminal::{self, EnterAlternateScreen, FdTerminal, LeaveAlternateScreen, TerminalIoctl}, tty::IsTty, - QueueableCommand, ExecutableCommand + ExecutableCommand, QueueableCommand, }; pub mod validators; @@ -39,18 +39,57 @@ pub enum Message { Data(String), } -struct TerminalGuard<'a, R, W> where W: Write + AsRawFd { +pub trait PromptHandler { + fn prompt_input(&mut self, prompt: &str) -> Result; + + fn prompt_wordlist(&mut self, prompt: &str, wordlist: &Wordlist) -> Result; + + #[cfg(feature = "mnemonic")] + fn prompt_validated_wordlist( + &mut self, + prompt: &str, + wordlist: &Wordlist, + retries: u8, + validator_fn: F, + ) -> Result + where + F: Fn(String) -> Result, + E: std::error::Error; + + fn prompt_passphrase(&mut self, prompt: &str) -> Result; + + fn prompt_validated_passphrase( + &mut self, + prompt: &str, + retries: u8, + validator_fn: F, + ) -> Result + where + F: Fn(String) -> Result, + E: std::error::Error; + + fn prompt_message(&mut self, prompt: impl Borrow) -> Result<()>; +} + +struct TerminalGuard<'a, R, W> +where + W: Write + AsRawFd, +{ read: &'a mut BufReader, write: &'a mut W, terminal: &'a mut FdTerminal, } -impl<'a, R, W> TerminalGuard<'a, R, W> where W: Write + AsRawFd, R: Read { +impl<'a, R, W> TerminalGuard<'a, R, W> +where + W: Write + AsRawFd, + R: Read, +{ fn new(read: &'a mut BufReader, write: &'a mut W, terminal: &'a mut FdTerminal) -> Self { Self { read, write, - terminal + terminal, } } @@ -70,7 +109,11 @@ impl<'a, R, W> TerminalGuard<'a, R, W> where W: Write + AsRawFd, R: Read { } } -impl TerminalIoctl for TerminalGuard<'_, R, W> where R: Read, W: Write + AsRawFd { +impl TerminalIoctl for TerminalGuard<'_, R, W> +where + R: Read, + W: Write + AsRawFd, +{ fn enable_raw_mode(&mut self) -> std::io::Result<()> { self.terminal.enable_raw_mode() } @@ -88,13 +131,21 @@ impl TerminalIoctl for TerminalGuard<'_, R, W> where R: Read, W: Write + A } } -impl Read for TerminalGuard<'_, R, W> where R: Read, W: Write + AsRawFd { +impl Read for TerminalGuard<'_, R, W> +where + R: Read, + W: Write + AsRawFd, +{ fn read(&mut self, buf: &mut [u8]) -> std::io::Result { self.read.read(buf) } } -impl BufRead for TerminalGuard<'_, R, W> where R: Read, W: Write + AsRawFd { +impl BufRead for TerminalGuard<'_, R, W> +where + R: Read, + W: Write + AsRawFd, +{ fn fill_buf(&mut self) -> std::io::Result<&[u8]> { self.read.fill_buf() } @@ -104,7 +155,10 @@ impl BufRead for TerminalGuard<'_, R, W> where R: Read, W: Write + AsRawFd } } -impl Write for TerminalGuard<'_, R, W> where W: Write + AsRawFd { +impl Write for TerminalGuard<'_, R, W> +where + W: Write + AsRawFd, +{ fn write(&mut self, buf: &[u8]) -> std::io::Result { self.write.write(buf) } @@ -114,7 +168,10 @@ impl Write for TerminalGuard<'_, R, W> where W: Write + AsRawFd { } } -impl Drop for TerminalGuard<'_, R, W> where W: Write + AsRawFd { +impl Drop for TerminalGuard<'_, R, W> +where + W: Write + AsRawFd, +{ fn drop(&mut self) { self.write.execute(DisableBracketedPaste).unwrap(); self.write.execute(LeaveAlternateScreen).unwrap(); @@ -148,7 +205,11 @@ where TerminalGuard::new(&mut self.read, &mut self.write, &mut self.terminal) } - pub fn prompt_input(&mut self, prompt: &str) -> Result { +} + +impl PromptHandler for Terminal where R: Read + Sized, W: Write + AsRawFd + Sized { + + fn prompt_input(&mut self, prompt: &str) -> Result { let mut terminal = self.lock().alternate_screen()?; terminal .queue(terminal::Clear(terminal::ClearType::All))? @@ -170,7 +231,7 @@ where } #[cfg(feature = "mnemonic")] - pub fn prompt_validated_wordlist( + fn prompt_validated_wordlist( &mut self, prompt: &str, wordlist: &Wordlist, @@ -202,8 +263,12 @@ where #[cfg(feature = "mnemonic")] #[allow(clippy::too_many_lines)] - pub fn prompt_wordlist(&mut self, prompt: &str, wordlist: &Wordlist) -> Result { - let mut terminal = self.lock().alternate_screen()?.raw_mode()?.bracketed_paste()?; + fn prompt_wordlist(&mut self, prompt: &str, wordlist: &Wordlist) -> Result { + let mut terminal = self + .lock() + .alternate_screen()? + .raw_mode()? + .bracketed_paste()?; terminal .queue(terminal::Clear(terminal::ClearType::All))? @@ -318,8 +383,7 @@ where Ok(input) } - #[cfg(feature = "mnemonic")] - pub fn prompt_validated_passphrase( + fn prompt_validated_passphrase( &mut self, prompt: &str, retries: u8, @@ -335,7 +399,9 @@ where match validator_fn(s) { Ok(v) => return Ok(v), Err(e) => { - self.prompt_message(&Message::Text(format!("Error validating passphrase: {e}")))?; + self.prompt_message(&Message::Text(format!( + "Error validating passphrase: {e}" + )))?; let _ = last_error.insert(e); } } @@ -349,7 +415,7 @@ where } // TODO: return secrecy::Secret - pub fn prompt_passphrase(&mut self, prompt: &str) -> Result { + fn prompt_passphrase(&mut self, prompt: &str) -> Result { let mut terminal = self.lock().alternate_screen()?.raw_mode()?; terminal @@ -407,7 +473,7 @@ where Ok(passphrase) } - pub fn prompt_message(&mut self, prompt: &Message) -> Result<()> { + fn prompt_message(&mut self, prompt: impl Borrow) -> Result<()> { let mut terminal = self.lock().alternate_screen()?.raw_mode()?; loop { @@ -417,7 +483,7 @@ where .queue(terminal::Clear(terminal::ClearType::All))? .queue(cursor::MoveTo(0, 0))?; - match &prompt { + match prompt.borrow() { Message::Text(text) => { for line in text.lines() { let mut written_chars = 0; diff --git a/keyfork-shard/src/lib.rs b/keyfork-shard/src/lib.rs index 3c00078..16a1774 100644 --- a/keyfork-shard/src/lib.rs +++ b/keyfork-shard/src/lib.rs @@ -9,7 +9,7 @@ use keyfork_mnemonic_util::{Mnemonic, Wordlist}; use keyfork_prompt::{ qrencode, validators::{mnemonic::MnemonicSetValidator, Validator}, - Message as PromptMessage, Terminal, + Message as PromptMessage, Terminal, PromptHandler }; use sha2::Sha256; use sharks::{Share, Sharks}; @@ -59,12 +59,12 @@ pub fn remote_decrypt(w: &mut impl Write) -> Result<(), Box Result { let card_backend = loop { - self.pm.prompt_message(&Message::Text( + self.pm.prompt_message(Message::Text( "Please plug in a smart card and press enter".to_string(), ))?; if let Some(c) = PcscBackend::cards(None)?.next().transpose()? { break c; } self.pm - .prompt_message(&Message::Text("No smart card was found".to_string()))?; + .prompt_message(Message::Text("No smart card was found".to_string()))?; }; let mut card = Card::::new(card_backend).map_err(Error::OpenSmartCard)?; let transaction = card.transaction().map_err(Error::Transaction)?; @@ -154,7 +154,7 @@ impl SmartcardManager { } } - self.pm.prompt_message(&Message::Text( + self.pm.prompt_message(Message::Text( "Please plug in a smart card and press enter".to_string(), ))?; } @@ -266,7 +266,7 @@ impl DecryptionHelper for &mut SmartcardManager { } // NOTE: This should not be hit, because of the above validator. Err(CardError::CardStatus(StatusBytes::IncorrectParametersCommandDataField)) => { - self.pm.prompt_message(&Message::Text( + self.pm.prompt_message(Message::Text( "Invalid PIN length entered.".to_string(), ))?; } diff --git a/keyfork/src/cli/wizard.rs b/keyfork/src/cli/wizard.rs index b19dda8..eefe02d 100644 --- a/keyfork/src/cli/wizard.rs +++ b/keyfork/src/cli/wizard.rs @@ -12,7 +12,7 @@ use keyfork_derive_util::{ }; use keyfork_prompt::{ validators::{PinValidator, Validator}, - Message, Terminal, + Message, Terminal, PromptHandler, }; #[derive(thiserror::Error, Debug)] @@ -125,7 +125,7 @@ fn generate_shard_secret(threshold: u8, max: u8, keys_per_shard: u8) -> Result<( for index in 0..max { let cert = derive_key(&seed, index)?; for i in 0..keys_per_shard { - pm.prompt_message(&Message::Text(format!( + pm.prompt_message(Message::Text(format!( "Please insert key #{} for user #{}", i + 1, index + 1,