keyfork-mnemonic-util: rewrite to only process entropy on demand

This commit is contained in:
Ryan Heywood 2023-12-26 18:57:44 -05:00
parent 7eeb494819
commit 27e7aba901
Signed by: ryan
GPG Key ID: 8E401478A3FBEF72
1 changed files with 128 additions and 72 deletions

View File

@ -1,4 +1,4 @@
use std::{collections::HashMap, error::Error, fmt::Display, str::FromStr, sync::Arc};
use std::{error::Error, fmt::Display, str::FromStr, sync::Arc};
use hmac::Hmac;
use pbkdf2::pbkdf2;
@ -74,9 +74,11 @@ impl Wordlist {
self.0.get(word)
}
/*
fn inner(&self) -> &Vec<String> {
&self.0
}
*/
#[cfg(test)]
fn into_inner(self) -> Vec<String> {
@ -87,18 +89,46 @@ impl Wordlist {
/// A BIP-0039 mnemonic with reference to a [`Wordlist`].
#[derive(Debug, Clone)]
pub struct Mnemonic {
words: Vec<usize>,
entropy: Vec<u8>,
// words: Vec<usize>,
wordlist: Arc<Wordlist>,
}
impl Display for Mnemonic {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let mut iter = self.words.iter().peekable();
while let Some(word_index) = iter.next() {
let word = self.wordlist.get_word(*word_index).expect("word");
write!(f, "{word}")?;
let bit_count = self.entropy.len() * 8;
let mut bits = vec![false; bit_count + bit_count / 32];
for byte_index in 0..bit_count / 8 {
for bit_index in 0..8 {
bits[byte_index * 8 + bit_index] =
(self.entropy[byte_index] & (1 << (7 - bit_index))) > 0;
}
}
let mut hasher = Sha256::new();
hasher.update(&self.entropy);
let hash = hasher.finalize().to_vec();
for check_bit in 0..bit_count / 32 {
bits[bit_count + check_bit] = (hash[check_bit / 8] & (1 << (7 - (check_bit % 8)))) > 0;
}
let mut iter = bits
.chunks_exact(11)
.peekable()
.map(|chunk| {
let mut num = 0usize;
for i in 0..11 {
num += usize::from(chunk[10 - i]) << i;
}
num
})
.filter_map(|word| self.wordlist.get_word(word))
.peekable();
while let Some(word) = iter.next() {
f.write_str(&word)?;
if iter.peek().is_some() {
write!(f, " ")?;
f.write_str(" ")?;
}
}
Ok(())
@ -114,10 +144,14 @@ pub enum MnemonicFromStrError {
/// One of the words used to generate the mnemonic was not found in the default wordlist.
InvalidWord(usize),
/// The checksum for the mnemonic did not match the given words.
InvalidChecksum,
}
impl Display for MnemonicFromStrError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str("Mnemonic error: ")?;
match self {
MnemonicFromStrError::InvalidWordCount(count) => {
write!(f, "Incorrect word count: {count}")
@ -125,6 +159,9 @@ impl Display for MnemonicFromStrError {
MnemonicFromStrError::InvalidWord(index) => {
write!(f, "Unknown word at index: {index}")
}
MnemonicFromStrError::InvalidChecksum => {
f.write_str("Checksum of data did not match expected value")
}
}
}
}
@ -135,25 +172,55 @@ impl FromStr for Mnemonic {
type Err = MnemonicFromStrError;
fn from_str(s: &str) -> Result<Self, Self::Err> {
let words: Vec<_> = s.split_whitespace().collect();
let mut usize_words = vec![];
let wordlist = Wordlist::default().arc();
let hm: HashMap<&str, usize> = wordlist
.inner()
.iter()
.enumerate()
.map(|(a, b)| (b.as_str(), a))
.collect();
let mut words: Vec<usize> = Vec::with_capacity(24);
for (index, word) in s.split_whitespace().enumerate() {
match hm.get(&word) {
Some(id) => words.push(*id),
None => return Err(MnemonicFromStrError::InvalidWord(index)),
let mut bits = vec![false; words.len() * 11];
for (index, word) in words.iter().enumerate() {
let word = wordlist
.0
.iter()
.position(|w| &w == word)
.ok_or(MnemonicFromStrError::InvalidWord(index))?;
usize_words.push(word);
for bit in 0..11 {
bits[index * 11 + bit] = (word & (1 << (10 - bit))) > 0;
}
}
// 3 words for every 32 bits
if words.len() % 3 != 0 {
return Err(MnemonicFromStrError::InvalidWordCount(words.len()));
let mut checksum_bits = vec![false; bits.len() - (bits.len() * 32 / 33)];
checksum_bits.copy_from_slice(&bits[bits.len() * 32 / 33..]);
// remove checksum bits
bits.truncate(bits.len() * 32 / 33);
// bits.truncate(bits.len() - bits.len() % 32);
let entropy: Vec<u8> = bits
.chunks_exact(8)
.map(|chunk| {
let mut num = 0u8;
for i in 0..8 {
num += u8::from(chunk[7 - i]) << i;
}
num
})
.collect();
let mut hasher = Sha256::new();
hasher.update(&entropy);
let hash = hasher.finalize().to_vec();
for (i, bit) in checksum_bits.iter().enumerate() {
if !hash[i / 8] & (1 << (7 - (i % 8))) == *bit as u8 {
return Err(MnemonicFromStrError::InvalidChecksum);
}
}
Ok(Mnemonic { words, wordlist })
Ok(Mnemonic {
entropy,
// words: usize_words,
wordlist,
})
}
}
@ -176,61 +243,26 @@ impl Mnemonic {
return Err(MnemonicGenerationError::InvalidByteLength(bit_count));
}
Ok(unsafe {Self::from_raw_entropy(bytes, wordlist)})
Ok(unsafe { Self::from_raw_entropy(bytes, wordlist) })
}
pub unsafe fn from_raw_entropy(bytes: &[u8], wordlist: Arc<Wordlist>) -> Mnemonic {
let bit_count = bytes.len() * 8;
let mut bits = vec![false; bit_count + bit_count / 32];
for byte_index in 0..bit_count / 8 {
for bit_index in 0..8 {
bits[byte_index * 8 + bit_index] = (bytes[byte_index] & (1 << (7 - bit_index))) > 0;
}
Mnemonic {
entropy: bytes.to_vec(),
wordlist,
}
}
let mut hasher = Sha256::new();
hasher.update(bytes);
let hash = hasher.finalize().to_vec();
for check_bit in 0..bit_count / 32 {
bits[bit_count + check_bit] = (hash[check_bit / 8] & (1 << (7 - (check_bit % 8)))) > 0;
}
pub fn as_bytes(&self) -> &[u8] {
&self.entropy
}
let words = bits
// NOTE: Tested with all approved variants. Always divisible by 11.
.chunks_exact(11)
.map(|chunk| {
let mut num = 0usize;
for i in 0..11 {
num += usize::from(chunk[10 - i]) << i;
}
num
})
.collect::<Vec<_>>();
Mnemonic { words, wordlist }
pub fn to_bytes(self) -> Vec<u8> {
self.entropy
}
pub fn entropy(&self) -> Vec<u8> {
let mut bits = vec![false; self.words.len() * 11];
for (index, word) in self.words.iter().enumerate() {
for bit in 0..11 {
bits[index * 11 + bit] = (word & (1 << (10 - bit))) > 0;
}
}
// remove checksum bits
bits.truncate(bits.len() - bits.len() % 32);
bits.chunks_exact(8)
.map(|chunk| {
let mut num = 0u8;
for i in 0..8 {
num += u8::from(chunk[7 - i]) << i;
}
num
})
.collect()
self.entropy.clone()
}
pub fn seed<'a>(
@ -247,8 +279,32 @@ impl Mnemonic {
Ok(seed.to_vec())
}
pub fn into_inner(self) -> (Vec<usize>, Arc<Wordlist>) {
(self.words, self.wordlist)
pub fn words(self) -> (Vec<usize>, Arc<Wordlist>) {
let bit_count = self.entropy.len() * 8;
let mut bits = vec![false; bit_count + bit_count / 32];
for byte_index in 0..bit_count / 8 {
for bit_index in 0..8 {
bits[byte_index * 8 + bit_index] =
(self.entropy[byte_index] & (1 << (7 - bit_index))) > 0;
}
}
let mut hasher = Sha256::new();
hasher.update(&self.entropy);
let hash = hasher.finalize().to_vec();
for check_bit in 0..bit_count / 32 {
bits[bit_count + check_bit] = (hash[check_bit / 8] & (1 << (7 - (check_bit % 8)))) > 0;
}
let words = bits.chunks_exact(11).peekable().map(|chunk| {
let mut num = 0usize;
for i in 0..11 {
num += usize::from(chunk[10 - i]) << i;
}
num
});
(words.collect(), self.wordlist.clone())
}
}
@ -329,7 +385,7 @@ mod tests {
for _ in 0..tests {
random.read_exact(&mut entropy[..]).unwrap();
let mnemonic = Mnemonic::from_entropy(&entropy[..256 / 8], wordlist.clone()).unwrap();
let (words, _) = mnemonic.into_inner();
let (words, _) = mnemonic.words();
hs.clear();
hs.extend(words);
if hs.len() != 24 {
@ -361,7 +417,7 @@ mod tests {
let mut random = std::fs::File::open("/dev/urandom").unwrap();
random.read_exact(&mut entropy[..]).unwrap();
let mnemonic = unsafe { Mnemonic::from_raw_entropy(&entropy[..], wordlist.clone()) };
let (words, _) = mnemonic.into_inner();
let (words, _) = mnemonic.words();
assert!(words.len() == 96);
}
}