keyfork-mnemonic-util: impl FromStr for Mnemonic

This changes the actual structure of Mnemonic since it requires
exclusively owned types when implementing FromStr. Now, Mnemonic
contains an Arc. Thread safety is required because of the Tokio
multithreaded runtime, hence an Arc instead of an Rc.

This does add some level of burden for people instantiating Mnemonics,
but `Wordlist::arc(self) -> Arc<Self>` has been provided as a
convenience method to make working with mnemonics easier.
This commit is contained in:
Ryan Heywood 2023-08-24 21:56:35 -05:00
parent ee15145662
commit 76c9214d73
Signed by: ryan
GPG Key ID: 8E401478A3FBEF72
2 changed files with 78 additions and 15 deletions

View File

@ -100,8 +100,8 @@ fn main() -> Result<()> {
let entropy = &mut [0u8; 256 / 8];
rng.read_into(&mut entropy[..])?;
let wordlist = Wordlist::default();
let mnemonic = Mnemonic::from_entropy(&entropy[..bit_size / 8], &wordlist)?;
let wordlist = Wordlist::default().arc();
let mnemonic = Mnemonic::from_entropy(&entropy[..bit_size / 8], wordlist)?;
println!("{mnemonic}");
@ -119,13 +119,13 @@ mod tests {
let tests = 100_000;
let mut count = 0.;
let entropy = &mut [0u8; 256 / 8];
let wordlist = Wordlist::default();
let wordlist = Wordlist::default().arc();
let mut rng = Entropy::new().unwrap();
let mut hs = HashSet::<usize>::with_capacity(24);
for _ in 0..tests {
rng.read_into(&mut entropy[..]).unwrap();
let mnemonic = Mnemonic::from_entropy(&entropy[..256 / 8], &wordlist).unwrap();
let mnemonic = Mnemonic::from_entropy(&entropy[..256 / 8], wordlist.clone()).unwrap();
let (words, _) = mnemonic.into_inner();
hs.clear();
hs.extend(words);

View File

@ -1,3 +1,5 @@
use std::{collections::HashMap, str::FromStr, sync::Arc};
use sha2::{Digest, Sha256};
use std::{error::Error, fmt::Display};
@ -49,11 +51,21 @@ impl Default for Wordlist {
}
impl Wordlist {
/// Return an Arced version of the Wordlist
#[allow(clippy::must_use_candidate)]
pub fn arc(self) -> Arc<Self> {
Arc::new(self)
}
/// Given an index, get a word from the wordlist.
fn get_word(&self, word: usize) -> Option<&String> {
self.0.get(word)
}
fn inner(&self) -> &Vec<String> {
&self.0
}
#[cfg(test)]
fn into_inner(self) -> Vec<String> {
self.0
@ -62,12 +74,12 @@ impl Wordlist {
/// A BIP-0039 mnemonic with reference to a [`Wordlist`].
#[derive(Debug, Clone)]
pub struct Mnemonic<'a> {
pub struct Mnemonic {
words: Vec<usize>,
wordlist: &'a Wordlist,
wordlist: Arc<Wordlist>,
}
impl<'a> Display for Mnemonic<'a> {
impl Display for Mnemonic {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let mut iter = self.words.iter().peekable();
while let Some(word_index) = iter.next() {
@ -81,21 +93,72 @@ impl<'a> Display for Mnemonic<'a> {
}
}
/// The error type representing a failure to parse a [`Mnemonic`]. These errors only occur during
/// [`Mnemonic`] creation.
#[derive(Debug, Clone)]
pub enum MnemonicFromStrError {
/// The amount of words used to parse a mnemonic was not correct.
InvalidWordCount(usize),
/// One of the words used to generate the mnemonic was not found in the default wordlist.
InvalidWord(usize),
}
impl Display for MnemonicFromStrError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
MnemonicFromStrError::InvalidWordCount(count) => {
write!(f, "Incorrect word count: {count}")
}
MnemonicFromStrError::InvalidWord(index) => {
write!(f, "Unknown word at index: {index}")
}
}
}
}
impl Error for MnemonicFromStrError {}
impl FromStr for Mnemonic {
type Err = MnemonicFromStrError;
fn from_str(s: &str) -> Result<Self, Self::Err> {
let wordlist = Wordlist::default().arc();
let hm: HashMap<&str, usize> = wordlist
.inner()
.iter()
.enumerate()
.map(|(a, b)| (b.as_str(), a))
.collect();
let mut words: Vec<usize> = Vec::with_capacity(24);
for (index, word) in s.split_whitespace().enumerate() {
match hm.get(&word) {
Some(id) => words.push(*id),
None => return Err(MnemonicFromStrError::InvalidWord(index)),
}
}
if ![12, 24].contains(&words.len()) {
return Err(MnemonicFromStrError::InvalidWordCount(words.len()));
}
Ok(Mnemonic { words, wordlist })
}
}
fn generate_slice_hash(data: &[u8]) -> Vec<u8> {
let mut hasher = Sha256::new();
hasher.update(data);
hasher.finalize().to_vec()
}
impl<'a> Mnemonic<'a> {
impl Mnemonic {
/// Generate a [`Mnemonic`] from the provided entropy and [`Wordlist`].
///
/// # Errors
/// An error may be returned if the entropy is not within the acceptable lengths.
pub fn from_entropy(
bytes: &[u8],
wordlist: &'a Wordlist,
) -> Result<Mnemonic<'a>, MnemonicGenerationError> {
wordlist: Arc<Wordlist>,
) -> Result<Mnemonic, MnemonicGenerationError> {
let bit_count = bytes.len() * 8;
let hash = generate_slice_hash(bytes);
@ -135,7 +198,7 @@ impl<'a> Mnemonic<'a> {
}
#[must_use]
pub fn into_inner(self) -> (Vec<usize>, &'a Wordlist) {
pub fn into_inner(self) -> (Vec<usize>, Arc<Wordlist>) {
(self.words, self.wordlist)
}
}
@ -160,6 +223,7 @@ mod tests {
fn conforms_to_trezor_tests() {
let content = include_str!("data/vectors.json");
let jsonobj: serde_json::Value = serde_json::from_str(content).unwrap();
let wordlist = Wordlist::default().arc();
for test in jsonobj["english"].as_array().unwrap() {
let [ref hex_, ref seed, ..] = test.as_array().unwrap()[..] else {
@ -167,8 +231,7 @@ mod tests {
};
let hex = hex::decode(hex_.as_str().unwrap()).unwrap();
let wordlist = Wordlist::default();
let mnemonic = Mnemonic::from_entropy(&hex, &wordlist).unwrap();
let mnemonic = Mnemonic::from_entropy(&hex, wordlist.clone()).unwrap();
assert_eq!(mnemonic.to_string(), seed.as_str().unwrap());
}
@ -179,8 +242,8 @@ mod tests {
let mut random_handle = File::open("/dev/random").unwrap();
let entropy = &mut [0u8; 256 / 8];
random_handle.read_exact(&mut entropy[..]).unwrap();
let wordlist = Wordlist::default();
let my_mnemonic = super::Mnemonic::from_entropy(&entropy[..256 / 8], &wordlist).unwrap();
let wordlist = Wordlist::default().arc();
let my_mnemonic = super::Mnemonic::from_entropy(&entropy[..256 / 8], wordlist).unwrap();
let their_mnemonic = bip39::Mnemonic::from_entropy(&entropy[..256 / 8]).unwrap();
assert_eq!(my_mnemonic.to_string(), their_mnemonic.to_string());
}