From 27e7aba901c766a2d93fbea921bbc9a374871f4d Mon Sep 17 00:00:00 2001
From: ryan
Date: Tue, 26 Dec 2023 18:57:44 -0500
Subject: [PATCH] keyfork-mnemonic-util: rewrite to only process entropy on demand

---
Review notes (ignored by `git am`, not part of the commit message):

* Purpose of the patch: `Mnemonic` now stores the raw entropy bytes; the
  word indices and the checksum bits are derived from that entropy whenever
  `Display::fmt` or `Mnemonic::words` is called, and `FromStr` converts the
  words back into entropy and checks the trailing checksum bits.
* Fix in `from_str`: the checksum test was
  `!hash[i / 8] & (1 << (7 - (i % 8))) == *bit as u8`, which compares a
  masked byte (0 or the mask value) against 0/1. A checksum bit that is 1 in
  the mnemonic but 0 in the SHA-256 digest was therefore accepted at every
  position except the lowest bit of a byte. The two bits are now compared
  as booleans. The change replaces one added line with one added line, so
  the hunk headers and the diffstat below are unchanged.
* Generic arguments in the signatures and field types (`Vec<usize>`,
  `Vec<u8>`, `Vec<String>`, `Arc<Wordlist>`, `Result<Self, Self::Err>`) are
  inferred from how the values are used in the diff; confirm them against
  the tree before applying.

 keyfork-mnemonic-util/src/lib.rs | 200 ++++++++++++++++++++-----------
 1 file changed, 128 insertions(+), 72 deletions(-)

diff --git a/keyfork-mnemonic-util/src/lib.rs b/keyfork-mnemonic-util/src/lib.rs
index 3d535e2..b6a1bb2 100644
--- a/keyfork-mnemonic-util/src/lib.rs
+++ b/keyfork-mnemonic-util/src/lib.rs
@@ -1,4 +1,4 @@
-use std::{collections::HashMap, error::Error, fmt::Display, str::FromStr, sync::Arc};
+use std::{error::Error, fmt::Display, str::FromStr, sync::Arc};
 
 use hmac::Hmac;
 use pbkdf2::pbkdf2;
@@ -74,9 +74,11 @@ impl Wordlist {
         self.0.get(word)
     }
 
+    /*
     fn inner(&self) -> &Vec<String> {
         &self.0
     }
+    */
 
     #[cfg(test)]
     fn into_inner(self) -> Vec<String> {
@@ -87,18 +89,46 @@ impl Wordlist {
 /// A BIP-0039 mnemonic with reference to a [`Wordlist`].
 #[derive(Debug, Clone)]
 pub struct Mnemonic {
-    words: Vec<usize>,
+    entropy: Vec<u8>,
+    // words: Vec<usize>,
     wordlist: Arc<Wordlist>,
 }
 
 impl Display for Mnemonic {
     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-        let mut iter = self.words.iter().peekable();
-        while let Some(word_index) = iter.next() {
-            let word = self.wordlist.get_word(*word_index).expect("word");
-            write!(f, "{word}")?;
+        let bit_count = self.entropy.len() * 8;
+        let mut bits = vec![false; bit_count + bit_count / 32];
+
+        for byte_index in 0..bit_count / 8 {
+            for bit_index in 0..8 {
+                bits[byte_index * 8 + bit_index] =
+                    (self.entropy[byte_index] & (1 << (7 - bit_index))) > 0;
+            }
+        }
+
+        let mut hasher = Sha256::new();
+        hasher.update(&self.entropy);
+        let hash = hasher.finalize().to_vec();
+        for check_bit in 0..bit_count / 32 {
+            bits[bit_count + check_bit] = (hash[check_bit / 8] & (1 << (7 - (check_bit % 8)))) > 0;
+        }
+
+        let mut iter = bits
+            .chunks_exact(11)
+            .peekable()
+            .map(|chunk| {
+                let mut num = 0usize;
+                for i in 0..11 {
+                    num += usize::from(chunk[10 - i]) << i;
+                }
+                num
+            })
+            .filter_map(|word| self.wordlist.get_word(word))
+            .peekable();
+        while let Some(word) = iter.next() {
+            f.write_str(&word)?;
             if iter.peek().is_some() {
-                write!(f, " ")?;
+                f.write_str(" ")?;
             }
         }
         Ok(())
@@ -114,10 +144,14 @@ pub enum MnemonicFromStrError {
 
     /// One of the words used to generate the mnemonic was not found in the default wordlist.
     InvalidWord(usize),
+
+    /// The checksum for the mnemonic did not match the given words.
+    InvalidChecksum,
 }
 
 impl Display for MnemonicFromStrError {
     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.write_str("Mnemonic error: ")?;
         match self {
             MnemonicFromStrError::InvalidWordCount(count) => {
                 write!(f, "Incorrect word count: {count}")
@@ -125,6 +159,9 @@ impl Display for MnemonicFromStrError {
             MnemonicFromStrError::InvalidWord(index) => {
                 write!(f, "Unknown word at index: {index}")
             }
+            MnemonicFromStrError::InvalidChecksum => {
+                f.write_str("Checksum of data did not match expected value")
+            }
         }
     }
 }
@@ -135,25 +172,55 @@ impl FromStr for Mnemonic {
     type Err = MnemonicFromStrError;
 
     fn from_str(s: &str) -> Result<Self, Self::Err> {
+        let words: Vec<_> = s.split_whitespace().collect();
+        let mut usize_words = vec![];
         let wordlist = Wordlist::default().arc();
-        let hm: HashMap<&str, usize> = wordlist
-            .inner()
-            .iter()
-            .enumerate()
-            .map(|(a, b)| (b.as_str(), a))
-            .collect();
-        let mut words: Vec<usize> = Vec::with_capacity(24);
-        for (index, word) in s.split_whitespace().enumerate() {
-            match hm.get(&word) {
-                Some(id) => words.push(*id),
-                None => return Err(MnemonicFromStrError::InvalidWord(index)),
+        let mut bits = vec![false; words.len() * 11];
+        for (index, word) in words.iter().enumerate() {
+            let word = wordlist
+                .0
+                .iter()
+                .position(|w| &w == word)
+                .ok_or(MnemonicFromStrError::InvalidWord(index))?;
+            usize_words.push(word);
+            for bit in 0..11 {
+                bits[index * 11 + bit] = (word & (1 << (10 - bit))) > 0;
             }
         }
-        // 3 words for every 32 bits
-        if words.len() % 3 != 0 {
-            return Err(MnemonicFromStrError::InvalidWordCount(words.len()));
+
+        let mut checksum_bits = vec![false; bits.len() - (bits.len() * 32 / 33)];
+        checksum_bits.copy_from_slice(&bits[bits.len() * 32 / 33..]);
+
+        // remove checksum bits
+        bits.truncate(bits.len() * 32 / 33);
+        // bits.truncate(bits.len() - bits.len() % 32);
+
+        let entropy: Vec<u8> = bits
+            .chunks_exact(8)
+            .map(|chunk| {
+                let mut num = 0u8;
+                for i in 0..8 {
+                    num += u8::from(chunk[7 - i]) << i;
+                }
+                num
+            })
+            .collect();
+
+        let mut hasher = Sha256::new();
+        hasher.update(&entropy);
+        let hash = hasher.finalize().to_vec();
+
+        for (i, bit) in checksum_bits.iter().enumerate() {
+            if (hash[i / 8] & (1 << (7 - (i % 8))) != 0) != *bit {
+                return Err(MnemonicFromStrError::InvalidChecksum);
+            }
         }
-        Ok(Mnemonic { words, wordlist })
+
+        Ok(Mnemonic {
+            entropy,
+            // words: usize_words,
+            wordlist,
+        })
     }
 }
 
@@ -176,61 +243,26 @@ impl Mnemonic {
             return Err(MnemonicGenerationError::InvalidByteLength(bit_count));
         }
 
-        Ok(unsafe {Self::from_raw_entropy(bytes, wordlist)})
+        Ok(unsafe { Self::from_raw_entropy(bytes, wordlist) })
     }
 
     pub unsafe fn from_raw_entropy(bytes: &[u8], wordlist: Arc<Wordlist>) -> Mnemonic {
-        let bit_count = bytes.len() * 8;
-        let mut bits = vec![false; bit_count + bit_count / 32];
-
-        for byte_index in 0..bit_count / 8 {
-            for bit_index in 0..8 {
-                bits[byte_index * 8 + bit_index] = (bytes[byte_index] & (1 << (7 - bit_index))) > 0;
-            }
+        Mnemonic {
+            entropy: bytes.to_vec(),
+            wordlist,
         }
+    }
 
-        let mut hasher = Sha256::new();
-        hasher.update(bytes);
-        let hash = hasher.finalize().to_vec();
-        for check_bit in 0..bit_count / 32 {
-            bits[bit_count + check_bit] = (hash[check_bit / 8] & (1 << (7 - (check_bit % 8)))) > 0;
-        }
+    pub fn as_bytes(&self) -> &[u8] {
+        &self.entropy
+    }
 
-        let words = bits
-            // NOTE: Tested with all approved variants. Always divisible by 11.
-            .chunks_exact(11)
-            .map(|chunk| {
-                let mut num = 0usize;
-                for i in 0..11 {
-                    num += usize::from(chunk[10 - i]) << i;
-                }
-                num
-            })
-            .collect::<Vec<_>>();
-
-        Mnemonic { words, wordlist }
+    pub fn to_bytes(self) -> Vec<u8> {
+        self.entropy
     }
 
     pub fn entropy(&self) -> Vec<u8> {
-        let mut bits = vec![false; self.words.len() * 11];
-        for (index, word) in self.words.iter().enumerate() {
-            for bit in 0..11 {
-                bits[index * 11 + bit] = (word & (1 << (10 - bit))) > 0;
-            }
-        }
-
-        // remove checksum bits
-        bits.truncate(bits.len() - bits.len() % 32);
-
-        bits.chunks_exact(8)
-            .map(|chunk| {
-                let mut num = 0u8;
-                for i in 0..8 {
-                    num += u8::from(chunk[7 - i]) << i;
-                }
-                num
-            })
-            .collect()
+        self.entropy.clone()
     }
 
     pub fn seed<'a>(
@@ -247,8 +279,32 @@ impl Mnemonic {
         Ok(seed.to_vec())
     }
 
-    pub fn into_inner(self) -> (Vec<usize>, Arc<Wordlist>) {
-        (self.words, self.wordlist)
+    pub fn words(self) -> (Vec<usize>, Arc<Wordlist>) {
+        let bit_count = self.entropy.len() * 8;
+        let mut bits = vec![false; bit_count + bit_count / 32];
+
+        for byte_index in 0..bit_count / 8 {
+            for bit_index in 0..8 {
+                bits[byte_index * 8 + bit_index] =
+                    (self.entropy[byte_index] & (1 << (7 - bit_index))) > 0;
+            }
+        }
+
+        let mut hasher = Sha256::new();
+        hasher.update(&self.entropy);
+        let hash = hasher.finalize().to_vec();
+        for check_bit in 0..bit_count / 32 {
+            bits[bit_count + check_bit] = (hash[check_bit / 8] & (1 << (7 - (check_bit % 8)))) > 0;
+        }
+
+        let words = bits.chunks_exact(11).peekable().map(|chunk| {
+            let mut num = 0usize;
+            for i in 0..11 {
+                num += usize::from(chunk[10 - i]) << i;
+            }
+            num
+        });
+        (words.collect(), self.wordlist.clone())
     }
 }
 
@@ -329,7 +385,7 @@ mod tests {
         for _ in 0..tests {
             random.read_exact(&mut entropy[..]).unwrap();
             let mnemonic = Mnemonic::from_entropy(&entropy[..256 / 8], wordlist.clone()).unwrap();
-            let (words, _) = mnemonic.into_inner();
+            let (words, _) = mnemonic.words();
             hs.clear();
             hs.extend(words);
             if hs.len() != 24 {
@@ -361,7 +417,7 @@ mod tests {
         let mut random = std::fs::File::open("/dev/urandom").unwrap();
         random.read_exact(&mut entropy[..]).unwrap();
         let mnemonic = unsafe { Mnemonic::from_raw_entropy(&entropy[..], wordlist.clone()) };
-        let (words, _) = mnemonic.into_inner();
+        let (words, _) = mnemonic.words();
         assert!(words.len() == 96);
     }
 }