keyfork-mnemonic-util: major refactor of Mnemonic type, remove cloned Wordlist

This commit is contained in:
Ryan Heywood 2024-02-19 05:20:33 -05:00
parent ed61d0685a
commit 44d8cf2098
Signed by: ryan
GPG Key ID: 8E401478A3FBEF72
9 changed files with 175 additions and 187 deletions

View File

@ -7,7 +7,7 @@ use aes_gcm::{
Aes256Gcm, KeyInit, Aes256Gcm, KeyInit,
}; };
use hkdf::Hkdf; use hkdf::Hkdf;
use keyfork_mnemonic_util::{Mnemonic, Wordlist}; use keyfork_mnemonic_util::{English, Mnemonic};
use keyfork_prompt::{ use keyfork_prompt::{
validators::{mnemonic::MnemonicSetValidator, Validator}, validators::{mnemonic::MnemonicSetValidator, Validator},
Message as PromptMessage, PromptHandler, Terminal, Message as PromptMessage, PromptHandler, Terminal,
@ -63,7 +63,6 @@ const QRCODE_ERROR: &str = "Unable to scan a QR code. Falling back to text entry
/// incompatible with the currently running version. /// incompatible with the currently running version.
pub fn remote_decrypt(w: &mut impl Write) -> Result<(), Box<dyn std::error::Error>> { pub fn remote_decrypt(w: &mut impl Write) -> Result<(), Box<dyn std::error::Error>> {
let mut pm = Terminal::new(stdin(), stdout())?; let mut pm = Terminal::new(stdin(), stdout())?;
let wordlist = Wordlist::default();
let mut iter_count = None; let mut iter_count = None;
let mut shares = vec![]; let mut shares = vec![];
@ -74,11 +73,9 @@ pub fn remote_decrypt(w: &mut impl Write) -> Result<(), Box<dyn std::error::Erro
while iter_count.is_none() || iter_count.is_some_and(|i| i > 0) { while iter_count.is_none() || iter_count.is_some_and(|i| i > 0) {
iter += 1; iter += 1;
let nonce = Aes256Gcm::generate_nonce(&mut OsRng); let nonce = Aes256Gcm::generate_nonce(&mut OsRng);
let nonce_mnemonic = let nonce_mnemonic = unsafe { Mnemonic::from_raw_bytes(nonce.as_slice()) };
unsafe { Mnemonic::from_raw_bytes(nonce.as_slice(), Default::default()) };
let our_key = EphemeralSecret::random(); let our_key = EphemeralSecret::random();
let key_mnemonic = let key_mnemonic = Mnemonic::from_bytes(PublicKey::from(&our_key).as_bytes())?;
Mnemonic::from_bytes(PublicKey::from(&our_key).as_bytes(), Default::default())?;
#[cfg(feature = "qrcode")] #[cfg(feature = "qrcode")]
{ {
@ -132,12 +129,12 @@ pub fn remote_decrypt(w: &mut impl Write) -> Result<(), Box<dyn std::error::Erro
word_lengths: [24, 48], word_lengths: [24, 48],
}; };
let [pubkey_mnemonic, payload_mnemonic] = pm.prompt_validated_wordlist( let [pubkey_mnemonic, payload_mnemonic] = pm
QRCODE_COULDNT_READ, .prompt_validated_wordlist::<English, _, _, _>(
&wordlist, QRCODE_COULDNT_READ,
3, 3,
validator.to_fn(), validator.to_fn(),
)?; )?;
let pubkey = pubkey_mnemonic let pubkey = pubkey_mnemonic
.as_bytes() .as_bytes()
.try_into() .try_into()

View File

@ -17,7 +17,7 @@ use keyfork_derive_openpgp::{
derive_util::{DerivationPath, PathError, VariableLengthSeed}, derive_util::{DerivationPath, PathError, VariableLengthSeed},
XPrv, XPrv,
}; };
use keyfork_mnemonic_util::{Mnemonic, MnemonicFromStrError, MnemonicGenerationError, Wordlist}; use keyfork_mnemonic_util::{English, Mnemonic, MnemonicFromStrError, MnemonicGenerationError};
use keyfork_prompt::{ use keyfork_prompt::{
validators::{mnemonic::MnemonicSetValidator, Validator}, validators::{mnemonic::MnemonicSetValidator, Validator},
Error as PromptError, Message as PromptMessage, PromptHandler, Terminal, Error as PromptError, Message as PromptMessage, PromptHandler, Terminal,
@ -471,7 +471,6 @@ pub fn decrypt(
encrypted_messages: &[EncryptedMessage], encrypted_messages: &[EncryptedMessage],
) -> Result<()> { ) -> Result<()> {
let mut pm = Terminal::new(stdin(), stdout())?; let mut pm = Terminal::new(stdin(), stdout())?;
let wordlist = Wordlist::default();
let mut nonce_data: Option<[u8; 12]> = None; let mut nonce_data: Option<[u8; 12]> = None;
let mut pubkey_data: Option<[u8; 32]> = None; let mut pubkey_data: Option<[u8; 32]> = None;
@ -496,8 +495,11 @@ pub fn decrypt(
let validator = MnemonicSetValidator { let validator = MnemonicSetValidator {
word_lengths: [9, 24], word_lengths: [9, 24],
}; };
let [nonce_mnemonic, pubkey_mnemonic] = let [nonce_mnemonic, pubkey_mnemonic] = pm.prompt_validated_wordlist::<English, _, _, _>(
pm.prompt_validated_wordlist(QRCODE_COULDNT_READ, &wordlist, 3, validator.to_fn())?; QRCODE_COULDNT_READ,
3,
validator.to_fn(),
)?;
let nonce = nonce_mnemonic let nonce = nonce_mnemonic
.as_bytes() .as_bytes()
@ -514,8 +516,7 @@ pub fn decrypt(
let nonce = Nonce::<U12>::from_slice(&nonce); let nonce = Nonce::<U12>::from_slice(&nonce);
let our_key = EphemeralSecret::random(); let our_key = EphemeralSecret::random();
let our_pubkey_mnemonic = let our_pubkey_mnemonic = Mnemonic::from_bytes(PublicKey::from(&our_key).as_bytes())?;
Mnemonic::from_bytes(PublicKey::from(&our_key).as_bytes(), Default::default())?;
let shared_secret = our_key.diffie_hellman(&PublicKey::from(pubkey)).to_bytes(); let shared_secret = our_key.diffie_hellman(&PublicKey::from(pubkey)).to_bytes();
@ -560,7 +561,7 @@ pub fn decrypt(
} }
// safety: size of out_bytes is constant and always % 4 == 0 // safety: size of out_bytes is constant and always % 4 == 0
let payload_mnemonic = unsafe { Mnemonic::from_raw_bytes(&out_bytes, Default::default()) }; let payload_mnemonic = unsafe { Mnemonic::from_raw_bytes(&out_bytes) };
#[cfg(feature = "qrcode")] #[cfg(feature = "qrcode")]
{ {

View File

@ -109,7 +109,7 @@ impl MnemonicSeedSource {
MnemonicSeedSource::Tarot => todo!(), MnemonicSeedSource::Tarot => todo!(),
MnemonicSeedSource::Dice => todo!(), MnemonicSeedSource::Dice => todo!(),
}; };
let mnemonic = keyfork_mnemonic_util::Mnemonic::from_bytes(&seed, Default::default())?; let mnemonic = keyfork_mnemonic_util::Mnemonic::from_bytes(&seed)?;
Ok(mnemonic.to_string()) Ok(mnemonic.to_string())
} }
} }

View File

@ -2,7 +2,7 @@ use super::Keyfork;
use clap::{Parser, Subcommand}; use clap::{Parser, Subcommand};
use std::path::PathBuf; use std::path::PathBuf;
use keyfork_mnemonic_util::Mnemonic; use keyfork_mnemonic_util::{English, Mnemonic};
use keyfork_shard::{ use keyfork_shard::{
openpgp::{combine, discover_certs, parse_messages}, openpgp::{combine, discover_certs, parse_messages},
remote_decrypt, remote_decrypt,
@ -69,9 +69,8 @@ impl RecoverSubcommands {
let validator = MnemonicChoiceValidator { let validator = MnemonicChoiceValidator {
word_lengths: [WordLength::Count(12), WordLength::Count(24)], word_lengths: [WordLength::Count(12), WordLength::Count(24)],
}; };
let mnemonic = term.prompt_validated_wordlist( let mnemonic = term.prompt_validated_wordlist::<English, _, _, _>(
"Mnemonic: ", "Mnemonic: ",
&Default::default(),
3, 3,
validator.to_fn(), validator.to_fn(),
)?; )?;
@ -90,7 +89,7 @@ pub struct Recover {
impl Recover { impl Recover {
pub fn handle(&self, _k: &Keyfork) -> Result<()> { pub fn handle(&self, _k: &Keyfork) -> Result<()> {
let seed = self.command.handle()?; let seed = self.command.handle()?;
let mnemonic = Mnemonic::from_bytes(&seed, Default::default())?; let mnemonic = Mnemonic::from_bytes(&seed)?;
tokio::runtime::Builder::new_multi_thread() tokio::runtime::Builder::new_multi_thread()
.enable_all() .enable_all()
.build() .build()

View File

@ -8,7 +8,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
input.read_line(&mut line)?; input.read_line(&mut line)?;
let decoded = smex::decode(line.trim())?; let decoded = smex::decode(line.trim())?;
let mnemonic = unsafe { Mnemonic::from_raw_bytes(&decoded, Default::default()) }; let mnemonic = unsafe { Mnemonic::from_raw_bytes(&decoded) };
println!("{mnemonic}"); println!("{mnemonic}");

View File

@ -1,10 +1,59 @@
//! Zero-dependency Mnemonic encoding and decoding. //! Zero-dependency mnemonic encoding and decoding of data.
//!
//! Mnemonics can be used to safely encode data of 32, 48, and 64 bytes as a phrase:
//!
//! ```rust
//! use keyfork_mnemonic_util::Mnemonic;
//! let data = b"Hello, world! I am a mnemonic :)";
//! assert_eq!(data.len(), 32);
//! let mnemonic = Mnemonic::from_bytes(data).unwrap();
//! println!("Our mnemonic is: {mnemonic}");
//! ```
//!
//! A mnemonic can also be parsed from a string:
//!
//! ```rust
//! use keyfork_mnemonic_util::Mnemonic;
//! use std::str::FromStr;
//!
//! let data = b"Hello, world! I am a mnemonic :)";
//! let words = "embody clock brand tattoo search desert saddle eternal
//! goddess animal banner dolphin bitter mother loyal asset
//! hover clock forward system normal mosquito trim credit";
//! let mnemonic = Mnemonic::from_str(words).unwrap();
//! assert_eq!(&data[..], mnemonic.as_bytes());
//! ```
//!
//! Mnemonics can also be used to store data of other lengths, but such functionality is not
//! verified to be safe:
//!
//! ```rust
//! use keyfork_mnemonic_util::Mnemonic;
//! let data = b"AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA";
//! let mnemonic = unsafe { Mnemonic::from_raw_bytes(data.as_slice()) };
//! let mnemonic_text = mnemonic.to_string();
//! ```
//!
//! If given an invalid length, undefined behavior may follow, or code may panic.
//!
//! ```rust,should_panic
//! use keyfork_mnemonic_util::Mnemonic;
//! use std::str::FromStr;
//!
//! // NOTE: Data is of invalid length, 31
//! let data = b"AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA";
//! let mnemonic = unsafe { Mnemonic::from_raw_bytes(data.as_slice()) };
//! let mnemonic_text = mnemonic.to_string();
//! // NOTE: panic happens here
//! let new_mnemonic = Mnemonic::from_str(&mnemonic_text).unwrap();
//! ```
use std::{ use std::{
error::Error, error::Error,
fmt::Display, fmt::Display,
str::FromStr, str::FromStr,
sync::{Arc, OnceLock}, sync::OnceLock,
marker::PhantomData,
}; };
use hmac::Hmac; use hmac::Hmac;
@ -44,114 +93,65 @@ impl Display for MnemonicGenerationError {
impl Error for MnemonicGenerationError {} impl Error for MnemonicGenerationError {}
/// A BIP-0039 compatible list of words. /// A trait representing a BIP-0039 wordlist, of 2048 words, with each word having a unique first
#[derive(Debug, Clone)] /// three letters.
pub struct Wordlist(Vec<String>); pub trait Wordlist: std::fmt::Debug {
/// Get a reference to a [`std::sync::OnceLock`] Self.
fn get_singleton<'a>() -> &'a Self;
static ENGLISH: OnceLock<Wordlist> = OnceLock::new(); /// Return a representation of the words in the wordlist as an array of [`str`].
fn to_str_array(&self) -> [&str; 2048];
impl Default for Wordlist {
/// Returns the English wordlist in the Bitcoin BIP-0039 specification.
fn default() -> Self {
// TODO: English is the only supported language.
ENGLISH
.get_or_init(|| {
let wordlist_file = include_str!("data/wordlist.txt");
Wordlist(
wordlist_file
.lines()
// skip 1: comment at top of file to point to BIP-0039 source.
.skip(1)
.map(|x| x.trim().to_string())
.collect(),
)
.shrank()
})
.clone()
}
} }
impl Wordlist { /// A wordlist for the English language, from the BIP-0039 dataset.
/// Return an Arced version of the Wordlist #[derive(Debug)]
#[allow(clippy::must_use_candidate)] pub struct English {
pub fn arc(self) -> Arc<Self> { words: [String; 2048],
Arc::new(self) }
static ENGLISH: OnceLock<English> = OnceLock::new();
impl Wordlist for English {
fn get_singleton<'a>() -> &'a Self {
ENGLISH.get_or_init(|| {
let wordlist_file = include_str!("data/wordlist.txt");
let mut words = wordlist_file
.lines()
.skip(1)
.map(|x| x.trim().to_string());
English {
words: std::array::from_fn(|_| words.next().expect("wordlist has 2048 words")),
}
})
} }
/// Return a shrank version of the Wordlist fn to_str_array(&self) -> [&str; 2048] {
pub fn shrank(mut self) -> Self { std::array::from_fn(|i| self.words[i].as_str())
self.0.shrink_to_fit();
self
}
/// Determine whether the Wordlist contains a given word.
pub fn contains(&self, word: &str) -> bool {
self.0.iter().any(|w| w.as_str() == word)
}
/// Given an index, get a word from the wordlist.
pub fn get_word(&self, word: usize) -> Option<&String> {
self.0.get(word)
}
/*
fn inner(&self) -> &Vec<String> {
&self.0
}
*/
#[cfg(test)]
fn into_inner(self) -> Vec<String> {
self.0
} }
} }
/// A BIP-0039 mnemonic with reference to a [`Wordlist`]. /// A BIP-0039 mnemonic with reference to a [`Wordlist`].
#[derive(Debug, Clone)] #[derive(Debug, Clone, PartialEq, Eq)]
pub struct Mnemonic { pub struct MnemonicBase<W: Wordlist> {
data: Vec<u8>, data: Vec<u8>,
// words: Vec<usize>, marker: PhantomData<W>,
wordlist: Arc<Wordlist>,
} }
impl PartialEq for Mnemonic { /// A default Mnemonic using the English language.
fn eq(&self, other: &Self) -> bool { pub type Mnemonic = MnemonicBase<English>;
self.data.eq(&other.data)
}
}
impl Eq for Mnemonic {} impl<W> Display for MnemonicBase<W>
where
impl Display for Mnemonic { W: Wordlist,
{
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let bit_count = self.data.len() * 8; let wordlist = W::get_singleton();
let mut bits = vec![false; bit_count + bit_count / 32]; let words = wordlist.to_str_array();
for byte_index in 0..bit_count / 8 { let mut iter = self
for bit_index in 0..8 { .words()
bits[byte_index * 8 + bit_index] = .into_iter()
(self.data[byte_index] & (1 << (7 - bit_index))) > 0; .filter_map(|word| words.get(word))
}
}
let mut hasher = Sha256::new();
hasher.update(&self.data);
let hash = hasher.finalize().to_vec();
for check_bit in 0..bit_count / 32 {
bits[bit_count + check_bit] = (hash[check_bit / 8] & (1 << (7 - (check_bit % 8)))) > 0;
}
let mut iter = bits
.chunks_exact(11)
.peekable()
.map(|chunk| {
let mut num = 0usize;
for i in 0..11 {
num += usize::from(chunk[10 - i]) << i;
}
num
})
.filter_map(|word| self.wordlist.get_word(word))
.peekable(); .peekable();
while let Some(word) = iter.next() { while let Some(word) = iter.next() {
f.write_str(word)?; f.write_str(word)?;
@ -196,17 +196,20 @@ impl Display for MnemonicFromStrError {
impl Error for MnemonicFromStrError {} impl Error for MnemonicFromStrError {}
impl FromStr for Mnemonic { impl<W> FromStr for MnemonicBase<W>
where
W: Wordlist,
{
type Err = MnemonicFromStrError; type Err = MnemonicFromStrError;
fn from_str(s: &str) -> Result<Self, Self::Err> { fn from_str(s: &str) -> Result<Self, Self::Err> {
let wordlist = W::get_singleton();
let wordlist_words = wordlist.to_str_array();
let words: Vec<_> = s.split_whitespace().collect(); let words: Vec<_> = s.split_whitespace().collect();
let mut usize_words = vec![]; let mut usize_words = vec![];
let wordlist = Wordlist::default().arc();
let mut bits = vec![false; words.len() * 11]; let mut bits = vec![false; words.len() * 11];
for (index, word) in words.iter().enumerate() { for (index, word) in words.iter().enumerate() {
let word = wordlist let word = wordlist_words
.0
.iter() .iter()
.position(|w| w == word) .position(|w| w == word)
.ok_or(MnemonicFromStrError::InvalidWord(index))?; .ok_or(MnemonicFromStrError::InvalidWord(index))?;
@ -244,15 +247,14 @@ impl FromStr for Mnemonic {
} }
} }
Ok(Mnemonic { Ok(MnemonicBase { data, marker: PhantomData })
data,
// words: usize_words,
wordlist,
})
} }
} }
impl Mnemonic { impl<W> MnemonicBase<W>
where
W: Wordlist,
{
/// Generate a [`Mnemonic`] from the provided data and [`Wordlist`]. The data is expected to be /// Generate a [`Mnemonic`] from the provided data and [`Wordlist`]. The data is expected to be
/// of 128, 192, or 256 bits, as per BIP-0039. /// of 128, 192, or 256 bits, as per BIP-0039.
/// ///
@ -263,12 +265,9 @@ impl Mnemonic {
/// ```rust /// ```rust
/// use keyfork_mnemonic_util::Mnemonic; /// use keyfork_mnemonic_util::Mnemonic;
/// let data = b"AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"; /// let data = b"AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA";
/// let mnemonic = Mnemonic::from_bytes(data.as_slice(), Default::default()).unwrap(); /// let mnemonic = Mnemonic::from_bytes(data.as_slice()).unwrap();
/// ``` /// ```
pub fn from_bytes( pub fn from_bytes(bytes: &[u8]) -> Result<MnemonicBase<W>, MnemonicGenerationError> {
bytes: &[u8],
wordlist: Arc<Wordlist>,
) -> Result<Mnemonic, MnemonicGenerationError> {
let bit_count = bytes.len() * 8; let bit_count = bytes.len() * 8;
if bit_count % 32 != 0 { if bit_count % 32 != 0 {
@ -279,7 +278,7 @@ impl Mnemonic {
return Err(MnemonicGenerationError::InvalidByteLength(bit_count)); return Err(MnemonicGenerationError::InvalidByteLength(bit_count));
} }
Ok(unsafe { Self::from_raw_bytes(bytes, wordlist) }) Ok(unsafe { Self::from_raw_bytes(bytes) })
} }
/// Generate a [`Mnemonic`] from the provided data and [`Wordlist`]. The data is expected to be /// Generate a [`Mnemonic`] from the provided data and [`Wordlist`]. The data is expected to be
@ -288,11 +287,8 @@ impl Mnemonic {
/// # Errors /// # Errors
/// An error may be returned if the data is not within the expected lengths. /// An error may be returned if the data is not within the expected lengths.
#[deprecated = "use Mnemonic::from_bytes"] #[deprecated = "use Mnemonic::from_bytes"]
pub fn from_entropy( pub fn from_entropy(bytes: &[u8]) -> Result<MnemonicBase<W>, MnemonicGenerationError> {
bytes: &[u8], MnemonicBase::from_bytes(bytes)
wordlist: Arc<Wordlist>,
) -> Result<Mnemonic, MnemonicGenerationError> {
Mnemonic::from_bytes(bytes, wordlist)
} }
/// Create a Mnemonic using an arbitrary length of given data. The length does not need to /// Create a Mnemonic using an arbitrary length of given data. The length does not need to
@ -308,7 +304,7 @@ impl Mnemonic {
/// ```rust /// ```rust
/// use keyfork_mnemonic_util::Mnemonic; /// use keyfork_mnemonic_util::Mnemonic;
/// let data = b"AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"; /// let data = b"AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA";
/// let mnemonic = unsafe { Mnemonic::from_raw_bytes(data.as_slice(), Default::default()) }; /// let mnemonic = unsafe { Mnemonic::from_raw_bytes(data.as_slice()) };
/// let mnemonic_text = mnemonic.to_string(); /// let mnemonic_text = mnemonic.to_string();
/// ``` /// ```
/// ///
@ -320,15 +316,15 @@ impl Mnemonic {
/// ///
/// // NOTE: Data is of invalid length, 31 /// // NOTE: Data is of invalid length, 31
/// let data = b"AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"; /// let data = b"AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA";
/// let mnemonic = unsafe { Mnemonic::from_raw_bytes(data.as_slice(), Default::default()) }; /// let mnemonic = unsafe { Mnemonic::from_raw_bytes(data.as_slice()) };
/// let mnemonic_text = mnemonic.to_string(); /// let mnemonic_text = mnemonic.to_string();
/// // NOTE: panic happens here /// // NOTE: panic happens here
/// let new_mnemonic = Mnemonic::from_str(&mnemonic_text).unwrap(); /// let new_mnemonic = Mnemonic::from_str(&mnemonic_text).unwrap();
/// ``` /// ```
pub unsafe fn from_raw_bytes(bytes: &[u8], wordlist: Arc<Wordlist>) -> Mnemonic { pub unsafe fn from_raw_bytes(bytes: &[u8]) -> MnemonicBase<W> {
Mnemonic { MnemonicBase {
data: bytes.to_vec(), data: bytes.to_vec(),
wordlist, marker: PhantomData,
} }
} }
@ -341,10 +337,10 @@ impl Mnemonic {
/// properly be encoded as a mnemonic. It is assumed the caller asserts the byte count is `% 4 /// properly be encoded as a mnemonic. It is assumed the caller asserts the byte count is `% 4
/// == 0`. If the assumption is incorrect, code may panic. /// == 0`. If the assumption is incorrect, code may panic.
#[deprecated = "use Mnemonic::from_raw_bytes"] #[deprecated = "use Mnemonic::from_raw_bytes"]
pub unsafe fn from_raw_entropy(bytes: &[u8], wordlist: Arc<Wordlist>) -> Mnemonic{ pub unsafe fn from_raw_entropy(bytes: &[u8]) -> MnemonicBase<W> {
Mnemonic { MnemonicBase {
data: bytes.to_vec(), data: bytes.to_vec(),
wordlist, marker: PhantomData,
} }
} }
@ -400,7 +396,7 @@ impl Mnemonic {
/// Encode the mnemonic into a list of integers 11 bits in length, matching the length of a /// Encode the mnemonic into a list of integers 11 bits in length, matching the length of a
/// BIP-0039 wordlist. /// BIP-0039 wordlist.
pub fn words(self) -> (Vec<usize>, Arc<Wordlist>) { pub fn words(&self) -> Vec<usize> {
let bit_count = self.data.len() * 8; let bit_count = self.data.len() * 8;
let mut bits = vec![false; bit_count + bit_count / 32]; let mut bits = vec![false; bit_count + bit_count / 32];
@ -418,14 +414,14 @@ impl Mnemonic {
bits[bit_count + check_bit] = (hash[check_bit / 8] & (1 << (7 - (check_bit % 8)))) > 0; bits[bit_count + check_bit] = (hash[check_bit / 8] & (1 << (7 - (check_bit % 8)))) > 0;
} }
let words = bits.chunks_exact(11).peekable().map(|chunk| { // TODO: find a way to not have to collect to vec
bits.chunks_exact(11).peekable().map(|chunk| {
let mut num = 0usize; let mut num = 0usize;
for i in 0..11 { for i in 0..11 {
num += usize::from(chunk[10 - i]) << i; num += usize::from(chunk[10 - i]) << i;
} }
num num
}); }).collect()
(words.collect(), self.wordlist.clone())
} }
} }
@ -436,13 +432,8 @@ mod tests {
use super::*; use super::*;
#[test] #[test]
fn wordlist_word_count_correct() { fn can_load_wordlist() {
let wordlist = Wordlist::default().into_inner(); let _wordlist = English::get_singleton();
assert_eq!(
wordlist.len(),
2usize.pow(11),
"Wordlist did not include correct word count"
);
} }
#[test] #[test]
@ -450,8 +441,7 @@ mod tests {
let mut random_handle = File::open("/dev/random").unwrap(); let mut random_handle = File::open("/dev/random").unwrap();
let entropy = &mut [0u8; 256 / 8]; let entropy = &mut [0u8; 256 / 8];
random_handle.read_exact(&mut entropy[..]).unwrap(); random_handle.read_exact(&mut entropy[..]).unwrap();
let wordlist = Wordlist::default().arc(); let mnemonic = super::Mnemonic::from_bytes(&entropy[..256 / 8]).unwrap();
let mnemonic = super::Mnemonic::from_bytes(&entropy[..256 / 8], wordlist).unwrap();
let new_entropy = mnemonic.as_bytes(); let new_entropy = mnemonic.as_bytes();
assert_eq!(new_entropy, entropy); assert_eq!(new_entropy, entropy);
} }
@ -460,7 +450,6 @@ mod tests {
fn conforms_to_trezor_tests() { fn conforms_to_trezor_tests() {
let content = include_str!("data/vectors.json"); let content = include_str!("data/vectors.json");
let jsonobj: serde_json::Value = serde_json::from_str(content).unwrap(); let jsonobj: serde_json::Value = serde_json::from_str(content).unwrap();
let wordlist = Wordlist::default().arc();
for test in jsonobj["english"].as_array().unwrap() { for test in jsonobj["english"].as_array().unwrap() {
let [ref hex_, ref seed, ..] = test.as_array().unwrap()[..] else { let [ref hex_, ref seed, ..] = test.as_array().unwrap()[..] else {
@ -468,7 +457,7 @@ mod tests {
}; };
let hex = hex::decode(hex_.as_str().unwrap()).unwrap(); let hex = hex::decode(hex_.as_str().unwrap()).unwrap();
let mnemonic = Mnemonic::from_bytes(&hex, wordlist.clone()).unwrap(); let mnemonic = Mnemonic::from_bytes(&hex).unwrap();
assert_eq!(mnemonic.to_string(), seed.as_str().unwrap()); assert_eq!(mnemonic.to_string(), seed.as_str().unwrap());
} }
@ -479,8 +468,7 @@ mod tests {
let mut random_handle = File::open("/dev/random").unwrap(); let mut random_handle = File::open("/dev/random").unwrap();
let entropy = &mut [0u8; 256 / 8]; let entropy = &mut [0u8; 256 / 8];
random_handle.read_exact(&mut entropy[..]).unwrap(); random_handle.read_exact(&mut entropy[..]).unwrap();
let wordlist = Wordlist::default().arc(); let my_mnemonic = Mnemonic::from_bytes(&entropy[..256 / 8]).unwrap();
let my_mnemonic = super::Mnemonic::from_bytes(&entropy[..256 / 8], wordlist).unwrap();
let their_mnemonic = bip39::Mnemonic::from_entropy(&entropy[..256 / 8]).unwrap(); let their_mnemonic = bip39::Mnemonic::from_entropy(&entropy[..256 / 8]).unwrap();
assert_eq!(my_mnemonic.to_string(), their_mnemonic.to_string()); assert_eq!(my_mnemonic.to_string(), their_mnemonic.to_string());
assert_eq!(my_mnemonic.generate_seed(None), their_mnemonic.to_seed("")); assert_eq!(my_mnemonic.generate_seed(None), their_mnemonic.to_seed(""));
@ -499,14 +487,13 @@ mod tests {
let tests = 100_000; let tests = 100_000;
let mut count = 0.; let mut count = 0.;
let entropy = &mut [0u8; 256 / 8]; let entropy = &mut [0u8; 256 / 8];
let wordlist = Wordlist::default().arc();
let mut random = std::fs::File::open("/dev/urandom").unwrap(); let mut random = std::fs::File::open("/dev/urandom").unwrap();
let mut hs = HashSet::<usize>::with_capacity(24); let mut hs = HashSet::<usize>::with_capacity(24);
for _ in 0..tests { for _ in 0..tests {
random.read_exact(&mut entropy[..]).unwrap(); random.read_exact(&mut entropy[..]).unwrap();
let mnemonic = Mnemonic::from_bytes(&entropy[..256 / 8], wordlist.clone()).unwrap(); let mnemonic = Mnemonic::from_bytes(&entropy[..256 / 8]).unwrap();
let (words, _) = mnemonic.words(); let words = mnemonic.words();
hs.clear(); hs.clear();
hs.extend(words); hs.extend(words);
if hs.len() != 24 { if hs.len() != 24 {
@ -534,11 +521,10 @@ mod tests {
#[test] #[test]
fn can_do_up_to_1024_bits() { fn can_do_up_to_1024_bits() {
let entropy = &mut [0u8; 128]; let entropy = &mut [0u8; 128];
let wordlist = Wordlist::default().arc();
let mut random = std::fs::File::open("/dev/urandom").unwrap(); let mut random = std::fs::File::open("/dev/urandom").unwrap();
random.read_exact(&mut entropy[..]).unwrap(); random.read_exact(&mut entropy[..]).unwrap();
let mnemonic = unsafe { Mnemonic::from_raw_bytes(&entropy[..], wordlist.clone()) }; let mnemonic = unsafe { Mnemonic::from_raw_bytes(&entropy[..]) };
let (words, _) = mnemonic.words(); let words = mnemonic.words();
assert!(words.len() == 96); assert!(words.len() == 96);
} }
} }

View File

@ -7,6 +7,8 @@ use keyfork_prompt::{
Terminal, PromptHandler, Terminal, PromptHandler,
}; };
use keyfork_mnemonic_util::English;
fn main() -> Result<(), Box<dyn std::error::Error>> { fn main() -> Result<(), Box<dyn std::error::Error>> {
let mut mgr = Terminal::new(stdin(), stdout())?; let mut mgr = Terminal::new(stdin(), stdout())?;
let transport_validator = mnemonic::MnemonicSetValidator { let transport_validator = mnemonic::MnemonicSetValidator {
@ -16,18 +18,16 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
word_lengths: [24, 48], word_lengths: [24, 48],
}; };
let mnemonics = mgr.prompt_validated_wordlist( let mnemonics = mgr.prompt_validated_wordlist::<English, _, _, _>(
"Enter a 9-word and 24-word mnemonic: ", "Enter a 9-word and 24-word mnemonic: ",
&Default::default(),
3, 3,
transport_validator.to_fn(), transport_validator.to_fn(),
)?; )?;
assert_eq!(mnemonics[0].as_bytes().len(), 12); assert_eq!(mnemonics[0].as_bytes().len(), 12);
assert_eq!(mnemonics[1].as_bytes().len(), 32); assert_eq!(mnemonics[1].as_bytes().len(), 32);
let mnemonics = mgr.prompt_validated_wordlist( let mnemonics = mgr.prompt_validated_wordlist::<English, _, _, _>(
"Enter a 24 and 48-word mnemonic: ", "Enter a 24 and 48-word mnemonic: ",
&Default::default(),
3, 3,
combine_validator.to_fn(), combine_validator.to_fn(),
)?; )?;

View File

@ -51,31 +51,33 @@ pub trait PromptHandler {
/// could not be read. /// could not be read.
fn prompt_input(&mut self, prompt: &str) -> Result<String>; fn prompt_input(&mut self, prompt: &str) -> Result<String>;
/// Prompt the user for input based on a wordlist. /// Prompt the user for input based on a wordlist. A language must be specified as the generic
/// parameter `X` (any type implementing [`Wordlist`]) when parsing a wordlist.
/// ///
/// # Errors /// # Errors
/// The method may return an error if the message was not able to be displayed or if the input /// The method may return an error if the message was not able to be displayed or if the input
/// could not be read. /// could not be read.
#[cfg(feature = "mnemonic")] #[cfg(feature = "mnemonic")]
fn prompt_wordlist(&mut self, prompt: &str, wordlist: &Wordlist) -> Result<String>; fn prompt_wordlist<X>(&mut self, prompt: &str) -> Result<String> where X: Wordlist;
/// Prompt the user for input based on a wordlist, while validating the wordlist using a /// Prompt the user for input based on a wordlist, while validating the wordlist using a
/// provided parser function, returning the type from the parser. /// provided parser function, returning the type from the parser. A language must be specified
/// as the generic parameter `X` (any type implementing [`Wordlist`]) when parsing a wordlist.
/// ///
/// # Errors /// # Errors
/// The method may return an error if the message was not able to be displayed, if the input /// The method may return an error if the message was not able to be displayed, if the input
/// could not be read, or if the parser returned an error. /// could not be read, or if the parser returned an error.
#[cfg(feature = "mnemonic")] #[cfg(feature = "mnemonic")]
fn prompt_validated_wordlist<V, F, E>( fn prompt_validated_wordlist<X, V, F, E>(
&mut self, &mut self,
prompt: &str, prompt: &str,
wordlist: &Wordlist,
retries: u8, retries: u8,
validator_fn: F, validator_fn: F,
) -> Result<V, Error> ) -> Result<V, Error>
where where
F: Fn(String) -> Result<V, E>, F: Fn(String) -> Result<V, E>,
E: std::error::Error; E: std::error::Error,
X: Wordlist;
/// Prompt the user for a passphrase, which is hidden while typing. /// Prompt the user for a passphrase, which is hidden while typing.
/// ///

View File

@ -182,20 +182,20 @@ impl<R, W> PromptHandler for Terminal<R, W> where R: Read + Sized, W: Write + As
} }
#[cfg(feature = "mnemonic")] #[cfg(feature = "mnemonic")]
fn prompt_validated_wordlist<V, F, E>( fn prompt_validated_wordlist<X, V, F, E>(
&mut self, &mut self,
prompt: &str, prompt: &str,
wordlist: &Wordlist,
retries: u8, retries: u8,
validator_fn: F, validator_fn: F,
) -> Result<V, Error> ) -> Result<V, Error>
where where
F: Fn(String) -> Result<V, E>, F: Fn(String) -> Result<V, E>,
E: std::error::Error, E: std::error::Error,
X: Wordlist,
{ {
let mut last_error = None; let mut last_error = None;
for _ in 0..retries { for _ in 0..retries {
let s = self.prompt_wordlist(prompt, wordlist)?; let s = self.prompt_wordlist::<X>(prompt)?;
match validator_fn(s) { match validator_fn(s) {
Ok(v) => return Ok(v), Ok(v) => return Ok(v),
Err(e) => { Err(e) => {
@ -214,7 +214,10 @@ impl<R, W> PromptHandler for Terminal<R, W> where R: Read + Sized, W: Write + As
#[cfg(feature = "mnemonic")] #[cfg(feature = "mnemonic")]
#[allow(clippy::too_many_lines)] #[allow(clippy::too_many_lines)]
fn prompt_wordlist(&mut self, prompt: &str, wordlist: &Wordlist) -> Result<String> { fn prompt_wordlist<X>(&mut self, prompt: &str) -> Result<String> where X: Wordlist {
let wordlist = X::get_singleton();
let words = wordlist.to_str_array();
let mut terminal = self let mut terminal = self
.lock() .lock()
.alternate_screen()? .alternate_screen()?
@ -316,7 +319,7 @@ impl<R, W> PromptHandler for Terminal<R, W> where R: Read + Sized, W: Write + As
let mut iter = printable_input.split_whitespace().peekable(); let mut iter = printable_input.split_whitespace().peekable();
while let Some(word) = iter.next() { while let Some(word) = iter.next() {
if wordlist.contains(word) { if words.contains(&word) {
terminal.queue(PrintStyledContent(word.green()))?; terminal.queue(PrintStyledContent(word.green()))?;
} else { } else {
terminal.queue(PrintStyledContent(word.red()))?; terminal.queue(PrintStyledContent(word.red()))?;