diff --git a/keyfork-mnemonic-generate/src/main.rs b/keyfork-mnemonic-generate/src/main.rs index 04766d7..e10e4ba 100644 --- a/keyfork-mnemonic-generate/src/main.rs +++ b/keyfork-mnemonic-generate/src/main.rs @@ -3,10 +3,13 @@ use std::{ error::Error, fs::{read_dir, read_to_string, File}, io::Read, + fmt::Display, }; use sha2::{Digest, Sha256}; +type Result> = std::result::Result; + /// Usage: keyfork-mnemonic-generate [bitsize: 128 or 256] /// CHECKS: /// * If the system is online @@ -19,6 +22,33 @@ use sha2::{Digest, Sha256}; static WARNING_LINKS: [&str; 1] = ["https://lore.kernel.org/lkml/20211223141113.1240679-2-Jason@zx2c4.com/"]; +#[derive(Debug)] +enum MnemonicGenerationError { + InvalidByteCount(usize), + InvalidByteLength(usize), +} + +impl MnemonicGenerationError { + fn boxed(self) -> Box { + Box::new(self) + } +} + +impl Display for MnemonicGenerationError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + MnemonicGenerationError::InvalidByteCount(count) => { + write!(f, "Invalid byte count: {count}, must be divisible by 8") + } + MnemonicGenerationError::InvalidByteLength(count) => { + write!(f, "Invalid byte length: {count}, must be 128 or 256") + } + } + } +} + +impl Error for MnemonicGenerationError {} + fn ensure_safe_kernel_version() { let kernel_version = read_to_string("/proc/version").expect("/proc/version"); let v = kernel_version @@ -61,6 +91,64 @@ fn ensure_offline() { } } +// TODO: Can a Mnemonic be formatted using a wordlist, without allocating or without storing the +// entire word list? +struct Mnemonic { + words: Vec, + wordlist: Vec, +} + +impl Mnemonic { + fn from_entropy(bytes: &[u8]) -> Result { + let bit_count = bytes.len() * 8; + let hash = generate_slice_hash(bytes); + + if bit_count % 32 != 0 { + return Err(MnemonicGenerationError::InvalidByteCount(bit_count).boxed()); + } + // 192 supported for test suite + if ![128, 192, 256].contains(&bit_count) { + return Err(MnemonicGenerationError::InvalidByteLength(bit_count).boxed()); + } + assert_eq!(bit_count % 32, 0, "bit count must be in 32 bit increments"); + + let mut bits = vec![false; bit_count + bit_count / 32]; + + for byte_index in 0..bit_count / 8 { + for bit_index in 0..8 { + bits[byte_index * 8 + bit_index] = (bytes[byte_index] & (1 << (7 - bit_index))) > 0; + } + } + for check_bit in 0..bit_count / 32 { + bits[bit_count + check_bit] = (hash[check_bit / 8] & (1 << (7 - (check_bit % 8)))) > 0; + } + + assert_eq!(bits.len() % 11, 0, "unstable bit count"); + + let words = bits + .chunks_exact(11) + .map(|chunk| bitslice_to_usize(chunk.try_into().expect("11 bit chunks"))) + .collect::>(); + + let wordlist = build_wordlist(); + Ok(Mnemonic { words, wordlist }) + } +} + +impl Display for Mnemonic { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let mut iter = self.words.iter().peekable(); + while let Some(word_index) = iter.next() { + let word = &self.wordlist[*word_index]; + write!(f, "{word}")?; + if iter.peek().is_some() { + write!(f, " ")?; + } + } + Ok(()) + } +} + fn generate_slice_hash(data: &[u8]) -> Vec { let mut hasher = Sha256::new(); hasher.update(data); @@ -84,28 +172,6 @@ fn bitslice_to_usize(bitslice: [bool; 11]) -> usize { index } -fn entropy_to_bits(bytes: &[u8]) -> Vec { - let bit_count = bytes.len() * 8; - let hash = generate_slice_hash(bytes); - - assert_eq!(bit_count % 32, 0, "bit count must be in 32 bit increments"); - - let mut bits = vec![false; bit_count + bit_count / 32]; - - for byte_index in 0..bit_count / 8 { - for bit_index in 0..8 { - bits[byte_index * 8 + bit_index] = (bytes[byte_index] & (1 << (7 - bit_index))) > 0; - } - } - for check_bit in 0..bit_count / 32 { - bits[bit_count + check_bit] = (hash[check_bit / 8] & (1 << (7 - (check_bit % 8)))) > 0; - } - - assert_eq!(bits.len() % 11, 0, "unstable bit count"); - - bits -} - fn main() -> Result<(), Box> { if !env::vars() .any(|(name, _)| name == "SHOOT_SELF_IN_FOOT" || name == "INSECURE_HARDWARE_ALLOWED") @@ -114,13 +180,6 @@ fn main() -> Result<(), Box> { ensure_offline(); } - let wordlist = build_wordlist(); - assert_eq!( - wordlist.len(), - 2usize.pow(11), - "Wordlist did not include correct word count" - ); - let bit_size: usize = env::args() .nth(1) .unwrap_or(String::from("256")) @@ -135,13 +194,9 @@ fn main() -> Result<(), Box> { let entropy = &mut [0u8; 256 / 8]; random_handle.read_exact(&mut entropy[..])?; - let seed_bits = entropy_to_bits(entropy); + let mnemonic = Mnemonic::from_entropy(entropy)?; - let words = seed_bits - .chunks_exact(11) - .map(|chunk| wordlist[bitslice_to_usize(chunk.try_into().expect("11 bit chunks"))].clone()) - .collect::>(); - println!("{}", words.join(" ")); + println!("{mnemonic}"); Ok(()) } @@ -164,7 +219,6 @@ mod tests { fn loads_mnemonics() -> Result<(), Box> { let content = include_str!("test/vectors.json"); let jsonobj: serde_json::Value = serde_json::from_str(content)?; - let wordlist = build_wordlist(); for test in jsonobj["english"].as_array().unwrap() { let [ref hex_, ref seed, ..] = test.as_array().unwrap()[..] else { @@ -172,16 +226,9 @@ mod tests { }; let hex = hex::decode(hex_.as_str().unwrap()).unwrap(); - let seed_bits = entropy_to_bits(&hex); + let mnemonic = Mnemonic::from_entropy(&hex)?; - let words = seed_bits - .chunks_exact(11) - .map(|chunk| { - wordlist[bitslice_to_usize(chunk.try_into().expect("11 bit chunks"))].clone() - }) - .collect::>(); - - assert_eq!(words.join(" "), seed.as_str().unwrap()); + assert_eq!(mnemonic.to_string(), seed.as_str().unwrap()); } Ok(()) }