diff --git a/mnemonic-hash-checker/src/main.rs b/mnemonic-hash-checker/src/main.rs index 6afd33a..c0a436a 100644 --- a/mnemonic-hash-checker/src/main.rs +++ b/mnemonic-hash-checker/src/main.rs @@ -2,6 +2,7 @@ use std::net::SocketAddr; use std::sync::Arc; use axum::{ + http::StatusCode, extract::{Path, State}, routing::get, Json, Router, @@ -42,6 +43,18 @@ struct CliConfig { #[clap(long, env, default_value = "postgres")] postgres_user: String, + + #[clap(long, env, default_value = "mnemonics")] + database_table: String, + + #[clap(long, env, default_value = "hash")] + database_column: String, +} + +struct AppState { + pool: Pool, + database_table: String, + database_column: String, } fn setup_registry() { @@ -55,11 +68,14 @@ fn setup_registry() { .init(); } -#[tracing::instrument(skip(pool))] -async fn check_hash(hash: &str, pool: &Pool) -> Result { - let client = pool.get().await?; +#[tracing::instrument(skip(state))] +async fn check_hash(hash: &str, state: &AppState) -> Result { + let client = state.pool.get().await?; + let column = &state.database_column; + let table = &state.database_table; + let formatted_query = format!("SELECT {column} FROM {table} WHERE hash = $1"); let query = client - .prepare_cached("SELECT hash FROM testing WHERE hash = $1") + .prepare_cached(formatted_query.as_str()) .await?; let rows = client.query(&query, &[&hash]).await?; if let Some(row) = rows.get(0) { @@ -71,16 +87,21 @@ async fn check_hash(hash: &str, pool: &Pool) -> Result { } // Note: Exposes *zero* information of potential errors to clients. -#[tracing::instrument(skip(hash, pool))] +#[tracing::instrument(skip(hash, state))] async fn check_hash_slug( Path(hash): Path, - State(pool): State>, -) -> Json> { - let result = check_hash(&hash, &pool).await; - if let Err(e) = &result { - debug!(%e, "Error while performing lookup"); - } - Json(result.ok()) + State(state): State>, +) -> (StatusCode, Json>) { + let result = check_hash(&hash, &state).await; + let status_code = match result.as_ref() { + Ok(true) => StatusCode::OK, + Ok(false) => StatusCode::NOT_FOUND, + Err(e) => { + debug!(%e, "Error while performing lookup"); + StatusCode::INTERNAL_SERVER_ERROR + } + }; + (status_code, Json(result.ok())) } #[tokio::main] @@ -99,13 +120,17 @@ async fn main() -> Result<()> { recycling_method: RecyclingMethod::Fast, }); let pool = config.create_pool(Some(Runtime::Tokio1), NoTls)?; - let addr = cli_config.bind_address.unwrap_or_else(|| { - (std::net::Ipv4Addr::new(127, 0, 0, 1), 8000).into() - }); + let addr = cli_config + .bind_address + .unwrap_or_else(|| (std::net::Ipv4Addr::new(127, 0, 0, 1), 8000).into()); let app = Router::new() .route("/check/:hash", get(check_hash_slug)) - .with_state(Arc::new(pool)) + .with_state(Arc::new(AppState { + pool, + database_table: cli_config.database_table, + database_column: cli_config.database_column, + })) .layer(CatchPanicLayer::new()) .layer(TraceLayer::new_for_http());