progenitor/progenitor-impl/tests/test_specific.rs

379 lines
12 KiB
Rust

// Copyright 2023 Oxide Computer Company
use dropshot::{
endpoint, ApiDescription, ConfigDropshot, ConfigLogging,
ConfigLoggingLevel, EmptyScanParams, HttpError, HttpResponseOk,
HttpResponseUpdatedNoContent, HttpServerStarter, PaginationParams, Path,
Query, RequestContext, ResultsPage, TypedBody,
};
use futures::StreamExt;
use http::Response;
use hyper::Body;
use openapiv3::OpenAPI;
use progenitor_impl::{
space_out_items, GenerationSettings, Generator, InterfaceStyle,
};
use schemars::JsonSchema;
use serde::Deserialize;
use std::{
net::{Ipv4Addr, SocketAddr},
str::from_utf8,
sync::{Arc, Mutex},
};
/// Generates a client for `spec` with `generator` and returns it as formatted
/// Rust source.
///
/// The generated token stream is run through rustfmt (with doc-attribute
/// normalization and comment wrapping enabled) and then through
/// `space_out_items`.
///
/// # Panics
///
/// Panics if code generation, rustfmt, or `space_out_items` fails. This is
/// test-only code, so any such failure should fail the calling test, and the
/// message identifies which stage broke.
fn generate_formatted(generator: &mut Generator, spec: &OpenAPI) -> String {
    let content = generator
        .generate_tokens(spec)
        .expect("failed to generate client tokens");
    let rustfmt_config = rustfmt_wrapper::config::Config {
        normalize_doc_attributes: Some(true),
        wrap_comments: Some(true),
        ..Default::default()
    };
    let formatted = rustfmt_wrapper::rustfmt_config(rustfmt_config, content)
        .expect("rustfmt failed on generated code");
    space_out_items(formatted).expect("failed to space out generated items")
}
// Path parameters for `renamed_parameters`. The wire names (`ref`, `type`,
// `trait`) are Rust keywords, so the Rust fields use other identifiers and
// serde renames them. The generated client must likewise avoid the keywords.
// NOTE(review): plain `//` comments are used on purpose. `///` here would
// presumably flow into the JSON schema as descriptions and change the
// expected generated output.
// `dead_code` is allowed because the fields are only deserialized; the
// handler never reads them.
#[allow(dead_code)]
#[derive(Deserialize, JsonSchema)]
struct CursedPath {
    #[serde(rename = "ref")]
    reef: String,
    #[serde(rename = "type")]
    tripe: String,
    #[serde(rename = "trait")]
    trade: String,
}
// Query parameters for `renamed_parameters`. As with `CursedPath`, the wire
// names (`if`, `in`, `use`) are Rust keywords and are mapped to non-keyword
// field names via serde renames.
// NOTE(review): `//` rather than `///` so no schema descriptions are
// (presumably) introduced into the expected generated output.
// `dead_code` is allowed because the fields are only deserialized.
#[allow(dead_code)]
#[derive(Deserialize, JsonSchema)]
struct CursedQuery {
    #[serde(rename = "if")]
    iffy: String,
    #[serde(rename = "in")]
    inn: String,
    #[serde(rename = "use")]
    youse: String,
}
// Endpoint whose path segments and query parameters are all named after Rust
// keywords. `test_renamed_parameters` registers it only to produce an OpenAPI
// document, so the handler body is never expected to run.
// NOTE(review): plain `//` comments are used deliberately. A `///` doc comment
// on an endpoint presumably becomes its OpenAPI description and would change
// the expected generated output.
#[endpoint {
    method = GET,
    path = "/{ref}/{type}/{trait}",
}]
async fn renamed_parameters(
    _rqctx: RequestContext<()>,
    _path: Path<CursedPath>,
    _query: Query<CursedQuery>,
) -> Result<HttpResponseUpdatedNoContent, HttpError> {
    unreachable!();
}
/// Test parameters whose names collide with Rust reserved words and must
/// therefore be renamed in the generated client.
#[test]
fn test_renamed_parameters() {
    const TEST_NAME: &str = "test_renamed_parameters";

    let mut api = ApiDescription::new();
    api.register(renamed_parameters).unwrap();

    // Serialize Dropshot's OpenAPI document to JSON and parse it back as the
    // `openapiv3` type that progenitor consumes.
    let spec = {
        let mut json = Vec::new();
        api.openapi("pagination-demo", "9000")
            .write(&mut json)
            .unwrap();
        let text = from_utf8(&json).unwrap();
        serde_json::from_str::<OpenAPI>(text).unwrap()
    };

    let mut generator = Generator::default();
    let generated = generate_formatted(&mut generator, &spec);
    expectorate::assert_contents(
        format!("tests/output/src/{TEST_NAME}.rs"),
        &generated,
    );
}
// Endpoint that returns a raw `Response<Body>` instead of a typed Dropshot
// response, used by `test_freeform_response` to exercise freeform responses.
// It is only registered to produce an OpenAPI document, so the body is never
// expected to run.
// NOTE(review): `//` rather than `///` because an endpoint doc comment
// presumably lands in the OpenAPI description and the generated output.
#[endpoint {
    method = GET,
    path = "/",
}]
async fn freeform_response(
    _rqctx: RequestContext<()>,
) -> Result<Response<Body>, HttpError> {
    unreachable!();
}
/// Test that an endpoint returning a raw (freeform) `Response<Body>`
/// generates the expected client code.
#[test]
fn test_freeform_response() {
    const TEST_NAME: &str = "test_freeform_response";

    let mut api = ApiDescription::new();
    api.register(freeform_response).unwrap();

    let mut raw = Vec::new();
    api.openapi("pagination-demo", "9000")
        .write(&mut raw)
        .unwrap();
    let spec: OpenAPI =
        serde_json::from_str(from_utf8(&raw).unwrap()).unwrap();

    let output = generate_formatted(&mut Generator::default(), &spec);
    expectorate::assert_contents(
        format!("tests/output/src/{TEST_NAME}.rs"),
        &output,
    );
}
// Request body for `default_params`. It covers several flavours of serde
// field defaults whose handling in the generated code is checked by
// `test_default_params`.
// NOTE(review): `//` rather than `///` on purpose. Doc comments here would
// presumably become schema descriptions and alter the expected output.
// `dead_code` is allowed because the fields are only deserialized.
#[derive(Deserialize, JsonSchema)]
#[allow(dead_code)]
struct BodyWithDefaults {
    // No serde default: a required field.
    s: String,
    // Defaults to `bool::default()` (false).
    #[serde(default)]
    yes: bool,
    // Custom default function plus a hyphenated wire name.
    #[serde(default = "forty_two", rename = "forty-two")]
    forty_two: u32,
    // Optional field whose default is `Some(true)` rather than `None`.
    #[serde(default = "yes_yes")]
    something: Option<bool>,
}
/// Serde default for the `forty-two` field of `BodyWithDefaults`.
fn forty_two() -> u32 {
    const ANSWER: u32 = 42;
    ANSWER
}
/// Serde default for `BodyWithDefaults::something`. It is deliberately
/// `Some(true)` rather than the `Option` default of `None`.
fn yes_yes() -> Option<bool> {
    let affirmative = true;
    Option::from(affirmative)
}
// POST endpoint whose JSON body (`BodyWithDefaults`) has fields with serde
// defaults. `test_default_params` registers it only to produce an OpenAPI
// document, so the body is never expected to run.
// NOTE(review): `//` rather than `///`. An endpoint doc comment presumably
// lands in the OpenAPI description and the generated output.
#[endpoint {
    method = POST,
    path = "/",
}]
async fn default_params(
    _rqctx: RequestContext<()>,
    _body: TypedBody<BodyWithDefaults>,
) -> Result<Response<Body>, HttpError> {
    unreachable!();
}
/// Test that body fields with serde defaults generate the expected code for
/// both the positional and the builder interface styles.
#[test]
fn test_default_params() {
    const TEST_NAME: &str = "test_default_params";

    let mut api = ApiDescription::new();
    api.register(default_params).unwrap();

    let spec: OpenAPI = {
        let mut json = Vec::new();
        api.openapi("pagination-demo", "9000")
            .write(&mut json)
            .unwrap();
        serde_json::from_str(from_utf8(&json).unwrap()).unwrap()
    };

    // `Generator::default()` output is checked against the `_positional`
    // file; the explicitly configured builder generator against `_builder`.
    let variants = [
        ("positional", Generator::default()),
        (
            "builder",
            Generator::new(
                GenerationSettings::default()
                    .with_interface(InterfaceStyle::Builder),
            ),
        ),
    ];
    for (style, mut generator) in variants {
        let output = generate_formatted(&mut generator, &spec);
        expectorate::assert_contents(
            format!("tests/output/src/{TEST_NAME}_{style}.rs"),
            &output,
        );
    }
}
/// Shared server state for the `paginated_u32s` endpoint.
#[derive(Debug)]
struct PaginatedU32sContext {
    /// The complete sequence of values served, in order, across all pages.
    all_values: core::ops::Range<u32>,
    /// Every `(offset, limit)` pair the endpoint has received, in arrival
    /// order, so the test can verify how the client paged.
    page_pairs: std::sync::Mutex<Vec<(usize, usize)>>,
}
// Serves `all_values` from the shared context one page at a time, recording
// each `(offset, limit)` request so `test_stream_pagination` can check how the
// generated clients page through results.
// NOTE(review): plain `//` comments are used so nothing is (presumably) added
// to the endpoint's OpenAPI description and the expected generated output.
#[endpoint {
    method = GET,
    path = "/",
}]
async fn paginated_u32s(
    rqctx: RequestContext<Arc<PaginatedU32sContext>>,
    query_params: Query<PaginationParams<EmptyScanParams, u32>>,
) -> Result<HttpResponseOk<ResultsPage<u32>>, HttpError> {
    let state = rqctx.context();
    let pagination = query_params.into_inner();

    let page_size = rqctx
        .page_limit(&pagination)
        .expect("invalid page limit")
        .get();
    let limit = usize::try_from(page_size).expect("non-usize limit");

    // Each item is its own page token (see the `ResultsPage::new` closure
    // below), so a follow-up request resumes one past the last value seen.
    // This treats values as indices, i.e. it assumes `all_values` starts at 0.
    let offset = match pagination.page {
        dropshot::WhichPage::Next(last_seen) => {
            usize::try_from(last_seen + 1).expect("non-usize offset")
        }
        dropshot::WhichPage::First(EmptyScanParams {}) => 0,
    };

    state.page_pairs.lock().unwrap().push((offset, limit));

    let page: Vec<u32> =
        state.all_values.clone().skip(offset).take(limit).collect();
    let results = ResultsPage::new(page, &(), |last: &u32, _: &()| *last)
        .expect("bad results page");
    Ok(HttpResponseOk(results))
}
/// Drains a paginated client stream to completion and returns every item it
/// yielded, in order.
///
/// # Panics
///
/// Panics on the first `Err` item, reporting it via `Display`.
async fn drain_stream<S, T, E>(mut stream: S) -> Vec<T>
where
    S: futures::Stream<Item = Result<T, E>> + Unpin,
    E: std::fmt::Display,
{
    let mut items = Vec::new();
    while let Some(result) = stream.next().await {
        match result {
            Ok(item) => items.push(item),
            Err(err) => panic!("unexpected error: {err}"),
        }
    }
    items
}

/// Asserts that a client paging through the server's values with a page
/// limit of 10 received `0..35` in order and issued exactly the expected
/// sequence of page requests.
fn check_paging(all_values: &[u32], server_ctx: &PaginatedU32sContext) {
    // Ensure we got all the results we expected.
    let expected_values: Vec<u32> = (0..35).collect();
    assert_eq!(expected_values, all_values);

    // Ensure the server saw the page requests we expect: we should always see
    // a limit of 10, and we should see offsets increasing by 10 until we get
    // to (30, 10); that will return 5 items, so we should see one final
    // (35, 10) for the client to confirm there are no more results.
    let expected_pages: Vec<(usize, usize)> =
        vec![(0, 10), (10, 10), (20, 10), (30, 10), (35, 10)];
    assert_eq!(expected_pages, *server_ctx.page_pairs.lock().unwrap());
}

/// Checks the generated positional and builder clients for a paginated
/// endpoint against the expected output. It then runs both clients' stream
/// helpers against a live Dropshot server and verifies the values received
/// and the pages requested.
#[tokio::test]
async fn test_stream_pagination() {
    const TEST_NAME: &str = "test_stream_pagination";

    let mut api = ApiDescription::new();
    api.register(paginated_u32s).unwrap();

    let mut out = Vec::new();
    api.openapi(TEST_NAME, "1").write(&mut out).unwrap();
    let out = from_utf8(&out).unwrap();
    let spec = serde_json::from_str::<OpenAPI>(out).unwrap();

    // Test both interface styles.
    let mut generator = Generator::new(
        GenerationSettings::new().with_interface(InterfaceStyle::Positional),
    );
    let output = generate_formatted(&mut generator, &spec);
    expectorate::assert_contents(
        format!("tests/output/src/{TEST_NAME}_positional.rs"),
        &output,
    );

    let mut generator = Generator::new(
        GenerationSettings::new().with_interface(InterfaceStyle::Builder),
    );
    let output = generate_formatted(&mut generator, &spec);
    expectorate::assert_contents(
        format!("tests/output/src/{TEST_NAME}_builder.rs"),
        &output,
    );

    // Run the Dropshot server.
    let config_dropshot = ConfigDropshot {
        bind_address: SocketAddr::from((Ipv4Addr::LOCALHOST, 0)),
        ..Default::default()
    };
    let config_logging = ConfigLogging::StderrTerminal {
        level: ConfigLoggingLevel::Debug,
    };
    let log = config_logging
        .to_logger(TEST_NAME)
        .expect("failed to create logger");
    let server_ctx = Arc::new(PaginatedU32sContext {
        all_values: 0..35,
        page_pairs: Mutex::default(),
    });
    let server = HttpServerStarter::new(
        &config_dropshot,
        api,
        Arc::clone(&server_ctx),
        &log,
    )
    .expect("failed to create server")
    .start();
    let server_addr = format!("http://{}", server.local_addr());

    // Both clients below are brought in by `include!`ing the very files that
    // were just checked above. That is subtle: the `assert_contents` checks
    // compare against what progenitor produces *now*, while `include!` reads
    // the file when this test is *compiled*. When running with
    // `EXPECTORATE=overwrite`, the checks above rewrite the files on disk,
    // but this run still exercises the previous contents; only a later run
    // (after recompilation) uses the freshly generated clients.

    // Test the positional client.
    #[allow(dead_code)]
    mod gen_client_positional {
        include!("output/src/test_stream_pagination_positional.rs");
    }
    let client = gen_client_positional::Client::new(&server_addr);
    let page_limit = 10.try_into().unwrap();
    let all_values =
        drain_stream(client.paginated_u32s_stream(Some(page_limit))).await;
    check_paging(&all_values, &server_ctx);

    // Repeat the test with the builder client, starting from a clean record
    // of page requests.
    server_ctx.page_pairs.lock().unwrap().clear();
    #[allow(dead_code, unused_imports)]
    mod gen_client_builder {
        include!("output/src/test_stream_pagination_builder.rs");
    }
    let client = gen_client_builder::Client::new(&server_addr);
    let all_values =
        drain_stream(client.paginated_u32s().limit(page_limit).stream()).await;
    check_paging(&all_values, &server_ctx);

    server.close().await.expect("failed to close server");
}