keyfork-mnemonic-util: rewrite to only process entropy on demand

This commit is contained in:
Ryan Heywood 2023-12-26 18:57:44 -05:00
parent 7eeb494819
commit 27e7aba901
Signed by: ryan
GPG Key ID: 8E401478A3FBEF72
1 changed files with 128 additions and 72 deletions

View File

@ -1,4 +1,4 @@
use std::{collections::HashMap, error::Error, fmt::Display, str::FromStr, sync::Arc}; use std::{error::Error, fmt::Display, str::FromStr, sync::Arc};
use hmac::Hmac; use hmac::Hmac;
use pbkdf2::pbkdf2; use pbkdf2::pbkdf2;
@ -74,9 +74,11 @@ impl Wordlist {
self.0.get(word) self.0.get(word)
} }
/*
fn inner(&self) -> &Vec<String> { fn inner(&self) -> &Vec<String> {
&self.0 &self.0
} }
*/
#[cfg(test)] #[cfg(test)]
fn into_inner(self) -> Vec<String> { fn into_inner(self) -> Vec<String> {
@ -87,18 +89,46 @@ impl Wordlist {
/// A BIP-0039 mnemonic with reference to a [`Wordlist`]. /// A BIP-0039 mnemonic with reference to a [`Wordlist`].
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct Mnemonic { pub struct Mnemonic {
words: Vec<usize>, entropy: Vec<u8>,
// words: Vec<usize>,
wordlist: Arc<Wordlist>, wordlist: Arc<Wordlist>,
} }
impl Display for Mnemonic { impl Display for Mnemonic {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let mut iter = self.words.iter().peekable(); let bit_count = self.entropy.len() * 8;
while let Some(word_index) = iter.next() { let mut bits = vec![false; bit_count + bit_count / 32];
let word = self.wordlist.get_word(*word_index).expect("word");
write!(f, "{word}")?; for byte_index in 0..bit_count / 8 {
for bit_index in 0..8 {
bits[byte_index * 8 + bit_index] =
(self.entropy[byte_index] & (1 << (7 - bit_index))) > 0;
}
}
let mut hasher = Sha256::new();
hasher.update(&self.entropy);
let hash = hasher.finalize().to_vec();
for check_bit in 0..bit_count / 32 {
bits[bit_count + check_bit] = (hash[check_bit / 8] & (1 << (7 - (check_bit % 8)))) > 0;
}
let mut iter = bits
.chunks_exact(11)
.peekable()
.map(|chunk| {
let mut num = 0usize;
for i in 0..11 {
num += usize::from(chunk[10 - i]) << i;
}
num
})
.filter_map(|word| self.wordlist.get_word(word))
.peekable();
while let Some(word) = iter.next() {
f.write_str(&word)?;
if iter.peek().is_some() { if iter.peek().is_some() {
write!(f, " ")?; f.write_str(" ")?;
} }
} }
Ok(()) Ok(())
@ -114,10 +144,14 @@ pub enum MnemonicFromStrError {
/// One of the words used to generate the mnemonic was not found in the default wordlist. /// One of the words used to generate the mnemonic was not found in the default wordlist.
InvalidWord(usize), InvalidWord(usize),
/// The checksum for the mnemonic did not match the given words.
InvalidChecksum,
} }
impl Display for MnemonicFromStrError { impl Display for MnemonicFromStrError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str("Mnemonic error: ")?;
match self { match self {
MnemonicFromStrError::InvalidWordCount(count) => { MnemonicFromStrError::InvalidWordCount(count) => {
write!(f, "Incorrect word count: {count}") write!(f, "Incorrect word count: {count}")
@ -125,6 +159,9 @@ impl Display for MnemonicFromStrError {
MnemonicFromStrError::InvalidWord(index) => { MnemonicFromStrError::InvalidWord(index) => {
write!(f, "Unknown word at index: {index}") write!(f, "Unknown word at index: {index}")
} }
MnemonicFromStrError::InvalidChecksum => {
f.write_str("Checksum of data did not match expected value")
}
} }
} }
} }
@ -135,25 +172,55 @@ impl FromStr for Mnemonic {
type Err = MnemonicFromStrError; type Err = MnemonicFromStrError;
fn from_str(s: &str) -> Result<Self, Self::Err> { fn from_str(s: &str) -> Result<Self, Self::Err> {
let words: Vec<_> = s.split_whitespace().collect();
let mut usize_words = vec![];
let wordlist = Wordlist::default().arc(); let wordlist = Wordlist::default().arc();
let hm: HashMap<&str, usize> = wordlist let mut bits = vec![false; words.len() * 11];
.inner() for (index, word) in words.iter().enumerate() {
.iter() let word = wordlist
.enumerate() .0
.map(|(a, b)| (b.as_str(), a)) .iter()
.collect(); .position(|w| &w == word)
let mut words: Vec<usize> = Vec::with_capacity(24); .ok_or(MnemonicFromStrError::InvalidWord(index))?;
for (index, word) in s.split_whitespace().enumerate() { usize_words.push(word);
match hm.get(&word) { for bit in 0..11 {
Some(id) => words.push(*id), bits[index * 11 + bit] = (word & (1 << (10 - bit))) > 0;
None => return Err(MnemonicFromStrError::InvalidWord(index)),
} }
} }
// 3 words for every 32 bits
if words.len() % 3 != 0 { let mut checksum_bits = vec![false; bits.len() - (bits.len() * 32 / 33)];
return Err(MnemonicFromStrError::InvalidWordCount(words.len())); checksum_bits.copy_from_slice(&bits[bits.len() * 32 / 33..]);
// remove checksum bits
bits.truncate(bits.len() * 32 / 33);
// bits.truncate(bits.len() - bits.len() % 32);
let entropy: Vec<u8> = bits
.chunks_exact(8)
.map(|chunk| {
let mut num = 0u8;
for i in 0..8 {
num += u8::from(chunk[7 - i]) << i;
}
num
})
.collect();
let mut hasher = Sha256::new();
hasher.update(&entropy);
let hash = hasher.finalize().to_vec();
for (i, bit) in checksum_bits.iter().enumerate() {
if !hash[i / 8] & (1 << (7 - (i % 8))) == *bit as u8 {
return Err(MnemonicFromStrError::InvalidChecksum);
}
} }
Ok(Mnemonic { words, wordlist })
Ok(Mnemonic {
entropy,
// words: usize_words,
wordlist,
})
} }
} }
@ -176,61 +243,26 @@ impl Mnemonic {
return Err(MnemonicGenerationError::InvalidByteLength(bit_count)); return Err(MnemonicGenerationError::InvalidByteLength(bit_count));
} }
Ok(unsafe {Self::from_raw_entropy(bytes, wordlist)}) Ok(unsafe { Self::from_raw_entropy(bytes, wordlist) })
} }
pub unsafe fn from_raw_entropy(bytes: &[u8], wordlist: Arc<Wordlist>) -> Mnemonic { pub unsafe fn from_raw_entropy(bytes: &[u8], wordlist: Arc<Wordlist>) -> Mnemonic {
let bit_count = bytes.len() * 8; Mnemonic {
let mut bits = vec![false; bit_count + bit_count / 32]; entropy: bytes.to_vec(),
wordlist,
for byte_index in 0..bit_count / 8 {
for bit_index in 0..8 {
bits[byte_index * 8 + bit_index] = (bytes[byte_index] & (1 << (7 - bit_index))) > 0;
}
} }
}
let mut hasher = Sha256::new(); pub fn as_bytes(&self) -> &[u8] {
hasher.update(bytes); &self.entropy
let hash = hasher.finalize().to_vec(); }
for check_bit in 0..bit_count / 32 {
bits[bit_count + check_bit] = (hash[check_bit / 8] & (1 << (7 - (check_bit % 8)))) > 0;
}
let words = bits pub fn to_bytes(self) -> Vec<u8> {
// NOTE: Tested with all approved variants. Always divisible by 11. self.entropy
.chunks_exact(11)
.map(|chunk| {
let mut num = 0usize;
for i in 0..11 {
num += usize::from(chunk[10 - i]) << i;
}
num
})
.collect::<Vec<_>>();
Mnemonic { words, wordlist }
} }
pub fn entropy(&self) -> Vec<u8> { pub fn entropy(&self) -> Vec<u8> {
let mut bits = vec![false; self.words.len() * 11]; self.entropy.clone()
for (index, word) in self.words.iter().enumerate() {
for bit in 0..11 {
bits[index * 11 + bit] = (word & (1 << (10 - bit))) > 0;
}
}
// remove checksum bits
bits.truncate(bits.len() - bits.len() % 32);
bits.chunks_exact(8)
.map(|chunk| {
let mut num = 0u8;
for i in 0..8 {
num += u8::from(chunk[7 - i]) << i;
}
num
})
.collect()
} }
pub fn seed<'a>( pub fn seed<'a>(
@ -247,8 +279,32 @@ impl Mnemonic {
Ok(seed.to_vec()) Ok(seed.to_vec())
} }
pub fn into_inner(self) -> (Vec<usize>, Arc<Wordlist>) { pub fn words(self) -> (Vec<usize>, Arc<Wordlist>) {
(self.words, self.wordlist) let bit_count = self.entropy.len() * 8;
let mut bits = vec![false; bit_count + bit_count / 32];
for byte_index in 0..bit_count / 8 {
for bit_index in 0..8 {
bits[byte_index * 8 + bit_index] =
(self.entropy[byte_index] & (1 << (7 - bit_index))) > 0;
}
}
let mut hasher = Sha256::new();
hasher.update(&self.entropy);
let hash = hasher.finalize().to_vec();
for check_bit in 0..bit_count / 32 {
bits[bit_count + check_bit] = (hash[check_bit / 8] & (1 << (7 - (check_bit % 8)))) > 0;
}
let words = bits.chunks_exact(11).peekable().map(|chunk| {
let mut num = 0usize;
for i in 0..11 {
num += usize::from(chunk[10 - i]) << i;
}
num
});
(words.collect(), self.wordlist.clone())
} }
} }
@ -329,7 +385,7 @@ mod tests {
for _ in 0..tests { for _ in 0..tests {
random.read_exact(&mut entropy[..]).unwrap(); random.read_exact(&mut entropy[..]).unwrap();
let mnemonic = Mnemonic::from_entropy(&entropy[..256 / 8], wordlist.clone()).unwrap(); let mnemonic = Mnemonic::from_entropy(&entropy[..256 / 8], wordlist.clone()).unwrap();
let (words, _) = mnemonic.into_inner(); let (words, _) = mnemonic.words();
hs.clear(); hs.clear();
hs.extend(words); hs.extend(words);
if hs.len() != 24 { if hs.len() != 24 {
@ -361,7 +417,7 @@ mod tests {
let mut random = std::fs::File::open("/dev/urandom").unwrap(); let mut random = std::fs::File::open("/dev/urandom").unwrap();
random.read_exact(&mut entropy[..]).unwrap(); random.read_exact(&mut entropy[..]).unwrap();
let mnemonic = unsafe { Mnemonic::from_raw_entropy(&entropy[..], wordlist.clone()) }; let mnemonic = unsafe { Mnemonic::from_raw_entropy(&entropy[..], wordlist.clone()) };
let (words, _) = mnemonic.into_inner(); let (words, _) = mnemonic.words();
assert!(words.len() == 96); assert!(words.len() == 96);
} }
} }