From 7934be90b9738e4ea98159d80077739312d9ced0 Mon Sep 17 00:00:00 2001 From: Adam Leventhal Date: Fri, 29 Oct 2021 07:16:39 -0700 Subject: [PATCH] a variety of improvements to support omicron clients (#12) - add the start of a client support crate - add support for pre/post request hooks with consumer-specific data - suggest type names for parameter and response types in case those types are unnamed - handle more reference types by resolving them properly - improve optional parameter generation --- Cargo.lock | 27 +- Cargo.toml | 6 +- example-build/Cargo.toml | 14 +- example-macro/Cargo.toml | 12 +- example-macro/src/main.rs | 11 +- progenitor-client/Cargo.toml | 11 + progenitor-client/src/lib.rs | 43 ++ progenitor-impl/Cargo.toml | 23 +- progenitor-impl/src/lib.rs | 757 +++++++++++++-------- progenitor-impl/tests/output/buildomat.out | 128 ++-- progenitor-impl/tests/output/keeper.out | 56 +- progenitor-macro/Cargo.toml | 9 +- progenitor-macro/src/lib.rs | 88 ++- progenitor/Cargo.toml | 13 +- 14 files changed, 740 insertions(+), 458 deletions(-) create mode 100644 progenitor-client/Cargo.toml create mode 100644 progenitor-client/src/lib.rs diff --git a/Cargo.lock b/Cargo.lock index 275668f..72bf6a6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -609,9 +609,9 @@ checksum = "c3ca011bd0129ff4ae15cd04c4eef202cadf6c51c21e47aba319b4e0501db741" [[package]] name = "proc-macro2" -version = "1.0.30" +version = "1.0.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "edc3358ebc67bc8b7fa0c007f945b0b18226f78437d61bec735a9eb96b61ee70" +checksum = "ba508cc11742c0dc5c1659771673afbab7a0efab23aa17e854cbab0837ed0b43" dependencies = [ "unicode-xid", ] @@ -622,21 +622,27 @@ version = "0.0.0" dependencies = [ "anyhow", "getopts", - "indexmap", "openapiv3", "progenitor-impl", "progenitor-macro", - "regex", - "rustfmt-wrapper", "serde", "serde_json", ] +[[package]] +name = "progenitor-client" +version = "0.0.0" +dependencies = [ + "reqwest", + "serde_json", +] + [[package]] name = "progenitor-impl" version = "0.0.0" dependencies = [ "anyhow", + "convert_case", "expectorate", "getopts", "indexmap", @@ -657,6 +663,7 @@ name = "progenitor-macro" version = "0.0.0" dependencies = [ "openapiv3", + "proc-macro2", "progenitor-impl", "quote", "serde_json", @@ -749,9 +756,9 @@ dependencies = [ [[package]] name = "reqwest" -version = "0.11.5" +version = "0.11.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "51c732d463dd300362ffb44b7b125f299c23d2990411a4253824630ebc7467fb" +checksum = "66d2927ca2f685faf0fc620ac4834690d29e7abb153add10f5812eef20b5e280" dependencies = [ "base64", "bytes", @@ -1123,7 +1130,7 @@ checksum = "59547bce71d9c38b83d9c0e92b6066c4253371f15005def0c30d9657f50c7642" [[package]] name = "typify" version = "0.0.1" -source = "git+https://github.com/oxidecomputer/typify#3cf8c1bb87b29cec4615d124362fb8f0872ffebd" +source = "git+https://github.com/oxidecomputer/typify#6de8074425c1c0090a85efc6115dfa4605d123e6" dependencies = [ "typify-impl", "typify-macro", @@ -1132,7 +1139,7 @@ dependencies = [ [[package]] name = "typify-impl" version = "0.0.1" -source = "git+https://github.com/oxidecomputer/typify#3cf8c1bb87b29cec4615d124362fb8f0872ffebd" +source = "git+https://github.com/oxidecomputer/typify#6de8074425c1c0090a85efc6115dfa4605d123e6" dependencies = [ "convert_case", "proc-macro2", @@ -1146,7 +1153,7 @@ dependencies = [ [[package]] name = "typify-macro" version = "0.0.1" -source = "git+https://github.com/oxidecomputer/typify#3cf8c1bb87b29cec4615d124362fb8f0872ffebd" +source = "git+https://github.com/oxidecomputer/typify#6de8074425c1c0090a85efc6115dfa4605d123e6" dependencies = [ "proc-macro2", "quote", diff --git a/Cargo.toml b/Cargo.toml index 5a483a0..9eb513d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,14 +1,16 @@ [workspace] members = [ "progenitor", - "progenitor-macro", + "progenitor-client", "progenitor-impl", + "progenitor-macro", "example-build", "example-macro", ] default-members = [ "progenitor", - "progenitor-macro", + "progenitor-client", "progenitor-impl", + "progenitor-macro", ] diff --git a/example-build/Cargo.toml b/example-build/Cargo.toml index caad60b..07c06b5 100644 --- a/example-build/Cargo.toml +++ b/example-build/Cargo.toml @@ -5,13 +5,13 @@ authors = ["Adam H. Leventhal "] edition = "2018" [dependencies] -anyhow = "1.0.44" -percent-encoding = "2.1.0" -serde = { version = "1.0.130", features = ["derive"] } -reqwest = { version = "0.11.5", features = ["json", "stream"] } -uuid = { version = "0.8.2", features = ["serde", "v4"] } -chrono = { version = "0.4.19", features = ["serde"] } +anyhow = "1.0" +percent-encoding = "2.1" +serde = { version = "1.0", features = ["derive"] } +reqwest = { version = "0.11", features = ["json", "stream"] } +uuid = { version = "0.8", features = ["serde", "v4"] } +chrono = { version = "0.4", features = ["serde"] } [build-dependencies] progenitor = { path = "../progenitor" } -serde_json = "1.0.68" +serde_json = "1.0" diff --git a/example-macro/Cargo.toml b/example-macro/Cargo.toml index 2a9634e..5a8df34 100644 --- a/example-macro/Cargo.toml +++ b/example-macro/Cargo.toml @@ -6,9 +6,9 @@ edition = "2018" [dependencies] progenitor = { path = "../progenitor" } -anyhow = "1.0.44" -percent-encoding = "2.1.0" -serde = { version = "1.0.130", features = ["derive"] } -reqwest = { version = "0.11.5", features = ["json", "stream"] } -uuid = { version = "0.8.2", features = ["serde", "v4"] } -chrono = { version = "0.4.19", features = ["serde"] } \ No newline at end of file +anyhow = "1.0" +percent-encoding = "2.1" +serde = { version = "1.0", features = ["derive"] } +reqwest = { version = "0.11", features = ["json", "stream"] } +uuid = { version = "0.8", features = ["serde", "v4"] } +chrono = { version = "0.4", features = ["serde"] } diff --git a/example-macro/src/main.rs b/example-macro/src/main.rs index ec19fd5..61e9b84 100644 --- a/example-macro/src/main.rs +++ b/example-macro/src/main.rs @@ -2,6 +2,15 @@ use progenitor::generate_api; -generate_api!("../sample_openapi/keeper.json"); +generate_api!( + "../sample_openapi/keeper.json", + (), + |_, request| { + println!("doing this {:?}", request); + }, + crate::all_done +); + +fn all_done(_: &(), _result: &reqwest::Result) {} fn main() {} diff --git a/progenitor-client/Cargo.toml b/progenitor-client/Cargo.toml new file mode 100644 index 0000000..bfd7bfd --- /dev/null +++ b/progenitor-client/Cargo.toml @@ -0,0 +1,11 @@ +[package] +name = "progenitor-client" +version = "0.0.0" +edition = "2018" +license = "MPL-2.0" +repository = "https://github.com/oxidecomputer/progenitor.git" +description = "An OpenAPI client generator - client support" + +[dependencies] +reqwest = "0.11" +serde_json = "1.0" diff --git a/progenitor-client/src/lib.rs b/progenitor-client/src/lib.rs new file mode 100644 index 0000000..636fad4 --- /dev/null +++ b/progenitor-client/src/lib.rs @@ -0,0 +1,43 @@ +// Copyright 2021 Oxide Computer Company + +//! Support code for generated clients. + +use std::ops::Deref; + +/// Error produced by generated client methods. +pub enum Error { + /// Indicates an error from the server, with the data, or with the + /// connection. + CommunicationError(reqwest::Error), + + /// A documented error response. + ErrorResponse(ResponseValue), + + /// A response not listed in the API description. This may represent a + /// success or failure response; check `status()::is_success()`. + UnexpectedResponse(reqwest::Response), +} + +pub struct ResponseValue { + inner: T, + response: reqwest::Response, +} + +impl ResponseValue { + #[doc(hidden)] + pub fn new(inner: T, response: reqwest::Response) -> Self { + Self { inner, response } + } + + pub fn request(&self) -> &reqwest::Response { + &self.response + } +} + +impl Deref for ResponseValue { + type Target = T; + + fn deref(&self) -> &Self::Target { + &self.inner + } +} diff --git a/progenitor-impl/Cargo.toml b/progenitor-impl/Cargo.toml index 412b73f..906a3e5 100644 --- a/progenitor-impl/Cargo.toml +++ b/progenitor-impl/Cargo.toml @@ -7,19 +7,20 @@ repository = "https://github.com/oxidecomputer/progenitor.git" description = "An OpenAPI client generator - core implementation" [dependencies] -anyhow = "1" +anyhow = "1.0" getopts = "0.2" -indexmap = "1.7.0" +indexmap = "1.7" openapiv3 = "1.0.0-beta.2" -proc-macro2 = "1.0.29" -quote = "1.0.9" -regex = "1.5.4" -rustfmt-wrapper = "0.1.0" -schemars = "0.8.5" -serde = { version = "1", features = [ "derive" ] } -serde_json = "1.0.68" +proc-macro2 = "1.0" +quote = "1.0" +regex = "1.5" +rustfmt-wrapper = "0.1" +schemars = "0.8" +serde = { version = "1.0", features = [ "derive" ] } +serde_json = "1.0" +convert_case = "0.4" typify = { git = "https://github.com/oxidecomputer/typify" } -thiserror = "1.0.30" +thiserror = "1.0" [dev-dependencies] -expectorate = "1.0.4" \ No newline at end of file +expectorate = "1.0" diff --git a/progenitor-impl/src/lib.rs b/progenitor-impl/src/lib.rs index 0252941..2aa7e25 100644 --- a/progenitor-impl/src/lib.rs +++ b/progenitor-impl/src/lib.rs @@ -2,7 +2,11 @@ use std::cmp::Ordering; -use openapiv3::{OpenAPI, ReferenceOr}; +use convert_case::{Case, Casing}; +use indexmap::IndexMap; +use openapiv3::{ + Components, OpenAPI, Parameter, ReferenceOr, RequestBody, Response, Schema, +}; use proc_macro2::TokenStream; use quote::{format_ident, quote}; @@ -33,6 +37,9 @@ pub type Result = std::result::Result; #[derive(Default)] pub struct Generator { type_space: TypeSpace, + inner_type: Option, + pre_hook: Option, + post_hook: Option, } impl Generator { @@ -40,6 +47,21 @@ impl Generator { Self::default() } + pub fn with_inner_type(&mut self, inner_type: TokenStream) -> &mut Self { + self.inner_type = Some(inner_type); + self + } + + pub fn with_pre_hook(&mut self, pre_hook: TokenStream) -> &mut Self { + self.pre_hook = Some(pre_hook); + self + } + + pub fn with_post_hook(&mut self, post_hook: TokenStream) -> &mut Self { + self.post_hook = Some(post_hook); + self + } + pub fn generate_tokens(&mut self, spec: &OpenAPI) -> Result { // Convert our components dictionary to schemars let schemas = spec @@ -55,283 +77,23 @@ impl Generator { self.type_space.set_type_mod("types"); self.type_space.add_ref_types(schemas)?; - enum ParamType { - Path, - Query, - Body, - } - let methods = spec - .operations() + .paths + .iter() + .flat_map(|(path, ref_or_item)| { + let item = ref_or_item.as_item().unwrap(); + assert!(item.parameters.is_empty()); + item.iter().map(move |(method, operation)| { + (path.as_str(), method, operation) + }) + }) .map(|(path, method, operation)| { - let mut query: Vec<(String, bool)> = Vec::new(); - let mut raw_params = operation - .parameters - .iter() - .map(|parameter| { - match parameter.item()? { - openapiv3::Parameter::Path { - parameter_data, - style: openapiv3::PathStyle::Simple, - } => { - // Path parameters MUST be required. - assert!(parameter_data.required); - - let nam = parameter_data.name.clone(); - let schema = - parameter_data.schema()?.to_schema(); - let typ = self - .type_space - .add_type_details(&schema)? - .parameter; - - Ok((ParamType::Path, nam, typ)) - } - openapiv3::Parameter::Query { - parameter_data, - allow_reserved: _, - style: openapiv3::QueryStyle::Form, - allow_empty_value, - } => { - if let Some(aev) = allow_empty_value { - if *aev { - todo!("allow empty value is a no go"); - } - } - - let nam = parameter_data.name.clone(); - let schema = - parameter_data.schema()?.to_schema(); - let mut typ = self - .type_space - .add_type_details(&schema)? - .parameter; - if !parameter_data.required { - typ = quote! { Option<#typ> }; - } - query.push(( - nam.to_string(), - !parameter_data.required, - )); - Ok((ParamType::Query, nam, typ)) - } - x => todo!("unhandled parameter type: {:#?}", x), - } - }) - .collect::>>()?; - - let mut bounds = Vec::new(); - - let (body_param, body_func) = if let Some(b) = - &operation.request_body - { - let b = b.item()?; - if b.is_binary()? { - bounds.push(quote! {B: Into}); - (Some(quote! {B}), Some(quote! { .body(body) })) - } else { - let mt = b.content_json()?; - if !mt.encoding.is_empty() { - todo!("media type encoding not empty: {:#?}", mt); - } - - if let Some(s) = &mt.schema { - let schema = s.to_schema(); - let typ = self - .type_space - .add_type_details(&schema)? - .parameter; - (Some(typ), Some(quote! { .json(body) })) - } else { - todo!("media type encoding, no schema: {:#?}", mt); - } - } - } else { - (None, None) - }; - - if let Some(body) = body_param { - raw_params.push(( - ParamType::Body, - "body".to_string(), - body, - )); - } - - let tmp = template::parse(path)?; - let names = tmp.names(); - let url_path = tmp.compile(); - - // Put parameters in a deterministic order. - raw_params.sort_by(|a, b| match (&a.0, &b.0) { - // Path params are first and are in positional order. - (ParamType::Path, ParamType::Path) => { - let aa = names.iter().position(|x| x == &a.1).unwrap(); - let bb = names.iter().position(|x| x == &b.1).unwrap(); - aa.cmp(&bb) - } - (ParamType::Path, ParamType::Query) => Ordering::Less, - (ParamType::Path, ParamType::Body) => Ordering::Less, - - // Query params are in lexicographic order. - (ParamType::Query, ParamType::Body) => Ordering::Less, - (ParamType::Query, ParamType::Query) => a.1.cmp(&b.1), - (ParamType::Query, ParamType::Path) => Ordering::Greater, - - // Body params are last and should be unique - (ParamType::Body, ParamType::Path) => Ordering::Greater, - (ParamType::Body, ParamType::Query) => Ordering::Greater, - (ParamType::Body, ParamType::Body) => { - panic!("should only be one body") - } - }); - - let (response_type, decode_response) = if operation - .responses - .responses - .len() - == 1 - { - let only = operation.responses.responses.first().unwrap(); - if !matches!(only.0, openapiv3::StatusCode::Code(200..=299)) - { - todo!("code? {:#?}", only); - } - - let i = only.1.item()?; - if !i.headers.is_empty() { - todo!("no response headers for now"); - } - - if !i.links.is_empty() { - todo!("no response links for now"); - } - - // Look at the response content. For now, support a - // single JSON-formatted response. - let typ = match ( - i.content.len(), - i.content.get("application/json"), - ) { - (0, _) => quote! { () }, - (1, Some(mt)) => { - if !mt.encoding.is_empty() { - todo!( - "media type encoding not empty: {:#?}", - mt - ); - } - - if let Some(schema) = &mt.schema { - let schema = schema.to_schema(); - self.type_space.add_type_details(&schema)?.ident - } else { - todo!( - "media type encoding, no schema: {:#?}", - mt - ); - } - } - (1, None) => { - todo!( - "response content not JSON: {:#?}", - i.content - ); - } - (_, _) => { - todo!( - "too many response contents: {:#?}", - i.content - ); - } - }; - (typ, quote! { res.json().await? }) - } else if operation.responses.responses.is_empty() { - (quote! { reqwest::Response }, quote! { res }) - } else { - todo!("responses? {:#?}", operation.responses); - }; - - let operation_id = format_ident!( - "{}", - operation.operation_id.as_deref().unwrap() - ); - - let bounds = if bounds.is_empty() { - quote! {} - } else { - quote! { - < #(#bounds),* > - } - }; - - let params = raw_params.into_iter().map(|(_, name, typ)| { - let name = format_ident!("{}", name); - quote! { - #name: #typ - } - }); - - let (query_build, query_use) = if query.is_empty() { - (quote! {}, quote! {}) - } else { - let query_items = query.iter().map(|(qn, opt)| { - if *opt { - let qn_ident = format_ident!("{}", qn); - quote! { - if let Some(v) = & #qn_ident { - query.push((#qn, v.to_string())); - } - } - } else { - quote! { - query.push((#qn, #qn.to_string())); - } - } - }); - - let query_build = quote! { - let mut query = Vec::new(); - #(#query_items)* - }; - let query_use = quote! { - .query(&query) - }; - - (query_build, query_use) - }; - - let doc_comment = format!( - "{}: {} {}", - operation.operation_id.as_deref().unwrap(), - method.to_ascii_uppercase(), - path - ); - - let method_func = format_ident!("{}", method); - - let method = quote! { - #[doc = #doc_comment] - pub async fn #operation_id #bounds ( - &self, - #(#params),* - ) -> Result<#response_type> { - #url_path - #query_build - - let res = self.client - . #method_func (url) - #body_func - #query_use - .send() - .await? - .error_for_status()?; - - Ok(#decode_response) - } - }; - - Ok(method) + self.process_operation( + operation, + &spec.components, + path, + method, + ) }) .collect::>>()?; @@ -345,9 +107,20 @@ impl Generator { ) }) .collect::>(); - types.sort_by(|a, b| a.0.cmp(&b.0)); + types.sort_by(|(a_name, _), (b_name, _)| a_name.cmp(b_name)); let types = types.into_iter().map(|(_, def)| def); + let inner_property = self.inner_type.as_ref().map(|inner| { + quote! { + inner: #inner, + } + }); + let inner_value = self.inner_type.as_ref().map(|_| { + quote! { + inner + } + }); + let file = quote! { use anyhow::Result; @@ -381,27 +154,32 @@ impl Generator { pub struct Client { baseurl: String, client: reqwest::Client, + #inner_property } impl Client { - pub fn new(baseurl: &str) -> Client { + pub fn new( + baseurl: &str, + #inner_property + ) -> Self { let dur = std::time::Duration::from_secs(15); let client = reqwest::ClientBuilder::new() .connect_timeout(dur) .timeout(dur) .build() .unwrap(); - - Client::new_with_client(baseurl, client) + Self::new_with_client(baseurl, client, #inner_value) } pub fn new_with_client( baseurl: &str, client: reqwest::Client, - ) -> Client { - Client { + #inner_property + ) -> Self { + Self { baseurl: baseurl.to_string(), client, + #inner_value } } @@ -412,6 +190,314 @@ impl Generator { Ok(file) } + fn process_operation( + &mut self, + operation: &openapiv3::Operation, + components: &Option, + path: &str, + method: &str, + ) -> Result { + enum ParamType { + Path, + Query, + Body, + } + + let mut query: Vec<(String, bool)> = Vec::new(); + let mut raw_params = operation + .parameters + .iter() + .map(|parameter| { + match parameter.item(components)? { + openapiv3::Parameter::Path { + parameter_data, + style: openapiv3::PathStyle::Simple, + } => { + // Path parameters MUST be required. + assert!(parameter_data.required); + + let nam = parameter_data.name.clone(); + let schema = parameter_data.schema()?.to_schema(); + let name = format!( + "{}{}", + sanitize( + operation.operation_id.as_ref().unwrap(), + Case::Pascal + ), + sanitize(&nam, Case::Pascal), + ); + let typ = self + .type_space + .add_type_details_with_name(&schema, Some(name))? + .parameter; + + Ok((ParamType::Path, nam, typ)) + } + openapiv3::Parameter::Query { + parameter_data, + allow_reserved: _, + style: openapiv3::QueryStyle::Form, + allow_empty_value, + } => { + if let Some(true) = allow_empty_value { + todo!("allow empty value is a no go"); + } + + let nam = parameter_data.name.clone(); + let mut schema = parameter_data.schema()?.to_schema(); + let name = format!( + "{}{}", + sanitize( + operation.operation_id.as_ref().unwrap(), + Case::Pascal + ), + sanitize(&nam, Case::Pascal), + ); + + if !parameter_data.required { + schema = make_optional(schema); + } + + let typ = self + .type_space + .add_type_details_with_name(&schema, Some(name))? + .parameter; + + query.push((nam.to_string(), !parameter_data.required)); + Ok((ParamType::Query, nam, typ)) + } + x => todo!("unhandled parameter type: {:#?}", x), + } + }) + .collect::>>()?; + let mut bounds = Vec::new(); + let (body_param, body_func) = if let Some(b) = &operation.request_body { + let b = b.item(components)?; + if b.is_binary(components)? { + bounds.push(quote! {B: Into}); + (Some(quote! {B}), Some(quote! { .body(body) })) + } else { + let mt = b.content_json()?; + if !mt.encoding.is_empty() { + todo!("media type encoding not empty: {:#?}", mt); + } + + if let Some(s) = &mt.schema { + let schema = s.to_schema(); + let name = format!( + "{}Body", + sanitize( + operation.operation_id.as_ref().unwrap(), + Case::Pascal + ) + ); + let typ = self + .type_space + .add_type_details_with_name(&schema, Some(name))? + .parameter; + (Some(typ), Some(quote! { .json(body) })) + } else { + todo!("media type encoding, no schema: {:#?}", mt); + } + } + } else { + (None, None) + }; + if let Some(body) = body_param { + raw_params.push((ParamType::Body, "body".to_string(), body)); + } + let tmp = template::parse(path)?; + let names = tmp.names(); + let url_path = tmp.compile(); + raw_params.sort_by(|a, b| match (&a.0, &b.0) { + // Path params are first and are in positional order. + (ParamType::Path, ParamType::Path) => { + let aa = names.iter().position(|x| x == &a.1).unwrap(); + let bb = names.iter().position(|x| x == &b.1).unwrap(); + aa.cmp(&bb) + } + (ParamType::Path, ParamType::Query) => Ordering::Less, + (ParamType::Path, ParamType::Body) => Ordering::Less, + + // Query params are in lexicographic order. + (ParamType::Query, ParamType::Body) => Ordering::Less, + (ParamType::Query, ParamType::Query) => a.1.cmp(&b.1), + (ParamType::Query, ParamType::Path) => Ordering::Greater, + + // Body params are last and should be unique + (ParamType::Body, ParamType::Path) => Ordering::Greater, + (ParamType::Body, ParamType::Query) => Ordering::Greater, + (ParamType::Body, ParamType::Body) => { + panic!("should only be one body") + } + }); + + let (response_type, decode_response) = + // TODO let's consider how we handle multiple responses + if operation.responses.responses.len() >= 1 { + let only = + operation.responses.responses.first().unwrap(); + if !matches!( + only.0, + openapiv3::StatusCode::Code(200..=299) + ) { + todo!("code? {:#?}", only); + } + + let i = only.1.item(components)?; + // TODO handle response headers. + + // Look at the response content. For now, support a + // single JSON-formatted response. + match ( + i.content.len(), + i.content.get("application/json"), + ) { + (0, _) => (quote! { () }, quote! { res.json().await? }), + (1, Some(mt)) => { + if !mt.encoding.is_empty() { + todo!( + "media type encoding not empty: {:#?}", + mt + ); + } + + let typ = if let Some(schema) = &mt.schema { + let schema = schema.to_schema(); + let name = format!( + "{}Response", + sanitize( + operation + .operation_id + .as_ref() + .unwrap(), + Case::Pascal + ) + ); + self.type_space + .add_type_details_with_name( + &schema, + Some(name), + )? + .ident + } else { + todo!( + "media type encoding, no schema: {:#?}", + mt + ); + }; + (typ, quote! { res.json().await? }) + } + (1, None) => { + // Non-JSON response. + (quote! { reqwest::Response }, quote! { res }) + } + (_, _) => { + todo!( + "too many response contents: {:#?}", + i.content + ); + } + } + } else if operation.responses.responses.is_empty() { + (quote! { reqwest::Response }, quote! { res }) + } else { + todo!("responses? {:#?}", operation.responses); + }; + let operation_id = format_ident!( + "{}", + sanitize(operation.operation_id.as_deref().unwrap(), Case::Snake) + ); + let bounds = if bounds.is_empty() { + quote! {} + } else { + quote! { + < #(#bounds),* > + } + }; + let params = raw_params.into_iter().map(|(_, name, typ)| { + let name = format_ident!("{}", name); + quote! { + #name: #typ + } + }); + let (query_build, query_use) = if query.is_empty() { + (quote! {}, quote! {}) + } else { + let query_items = query.iter().map(|(qn, opt)| { + if *opt { + let qn_ident = format_ident!("{}", qn); + quote! { + if let Some(v) = & #qn_ident { + query.push((#qn, v.to_string())); + } + } + } else { + quote! { + query.push((#qn, #qn.to_string())); + } + } + }); + + let query_build = quote! { + let mut query = Vec::new(); + #(#query_items)* + }; + let query_use = quote! { + .query(&query) + }; + + (query_build, query_use) + }; + let doc_comment = format!( + "{}: {} {}", + operation.operation_id.as_deref().unwrap(), + method.to_ascii_uppercase(), + path + ); + + let pre_hook = self.pre_hook.as_ref().map(|hook| { + quote! { + (#hook)(&self.inner, &request); + } + }); + let post_hook = self.post_hook.as_ref().map(|hook| { + quote! { + (#hook)(&self.inner, &result); + } + }); + + // TODO validate that method is one of the expected methods. + let method_func = format_ident!("{}", method.to_lowercase()); + let method = quote! { + #[doc = #doc_comment] + pub async fn #operation_id #bounds ( + &self, + #(#params),* + ) -> Result<#response_type> { + #url_path + #query_build + + let request = self.client + . #method_func (url) + #body_func + #query_use + .build()?; + #pre_hook + let result = self.client + .execute(request) + .await; + #post_hook + + // TODO we should do a match here for result?.status().as_u16() + let res = result?.error_for_status()?; + + Ok(#decode_response) + } + }; + Ok(method) + } + pub fn generate_text(&mut self, spec: &OpenAPI) -> Result { let output = self.generate_tokens(spec)?; @@ -453,6 +539,44 @@ impl Generator { } } +/// Make the schema optional if it isn't already. +pub fn make_optional( + schema: schemars::schema::Schema, +) -> schemars::schema::Schema { + match &schema { + // If the instance_type already includes Null then this is already + // optional. + schemars::schema::Schema::Object(schemars::schema::SchemaObject { + instance_type: Some(schemars::schema::SingleOrVec::Vec(types)), + .. + }) if types.contains(&schemars::schema::InstanceType::Null) => schema, + + // Otherwise, create a oneOf where one of the branches is the null + // type. We could potentially check to see if the schema already + // conforms to this pattern as well, but it doesn't hurt as typify will + // already reduce nested Options to a single Option. + _ => { + let null_schema = schemars::schema::Schema::Object( + schemars::schema::SchemaObject { + instance_type: Some(schemars::schema::SingleOrVec::Single( + Box::new(schemars::schema::InstanceType::Null), + )), + ..Default::default() + }, + ); + schemars::schema::Schema::Object(schemars::schema::SchemaObject { + subschemas: Some(Box::new( + schemars::schema::SubschemaValidation { + one_of: Some(vec![schema, null_schema]), + ..Default::default() + }, + )), + ..Default::default() + }) + } + } +} + trait ParameterDataExt { fn schema(&self) -> Result<&openapiv3::ReferenceOr>; } @@ -469,7 +593,7 @@ impl ParameterDataExt for openapiv3::ParameterData { } trait ExtractJsonMediaType { - fn is_binary(&self) -> Result; + fn is_binary(&self, components: &Option) -> Result; fn content_json(&self) -> Result; } @@ -489,7 +613,7 @@ impl ExtractJsonMediaType for openapiv3::Response { } } - fn is_binary(&self) -> Result { + fn is_binary(&self, _components: &Option) -> Result { if self.content.is_empty() { /* * XXX If there are no content types, I guess it is not binary? @@ -512,7 +636,7 @@ impl ExtractJsonMediaType for openapiv3::Response { VariantOrUnknownOrEmpty::Item, }; - let s = s.item()?; + let s = s.item(&None)?; if s.schema_data.nullable { todo!("XXX nullable binary?"); } @@ -570,7 +694,7 @@ impl ExtractJsonMediaType for openapiv3::RequestBody { } } - fn is_binary(&self) -> Result { + fn is_binary(&self, components: &Option) -> Result { if self.content.is_empty() { /* * XXX If there are no content types, I guess it is not binary? @@ -593,7 +717,7 @@ impl ExtractJsonMediaType for openapiv3::RequestBody { VariantOrUnknownOrEmpty::Item, }; - let s = s.item()?; + let s = s.item(components)?; if s.schema_data.nullable { todo!("XXX nullable binary?"); } @@ -635,17 +759,62 @@ impl ExtractJsonMediaType for openapiv3::RequestBody { } } -trait ReferenceOrExt { - fn item(&self) -> Result<&T>; +trait ReferenceOrExt { + fn item<'a>(&'a self, components: &'a Option) -> Result<&'a T>; +} +trait ComponentLookup: Sized { + fn get_components( + components: &Components, + ) -> &IndexMap>; } -impl ReferenceOrExt for openapiv3::ReferenceOr { - fn item(&self) -> Result<&T> { +impl ReferenceOrExt for openapiv3::ReferenceOr { + fn item<'a>(&'a self, components: &'a Option) -> Result<&'a T> { match self { - ReferenceOr::Reference { .. } => { - Err(Error::BadConversion("unexpected reference".to_string())) - } ReferenceOr::Item(item) => Ok(item), + ReferenceOr::Reference { reference } => { + let idx = reference.rfind('/').unwrap(); + let key = &reference[idx + 1..]; + let parameters = + T::get_components(components.as_ref().unwrap()); + parameters.get(key).unwrap().item(components) + } } } } + +impl ComponentLookup for Parameter { + fn get_components( + components: &Components, + ) -> &IndexMap> { + &components.parameters + } +} + +impl ComponentLookup for RequestBody { + fn get_components( + components: &Components, + ) -> &IndexMap> { + &components.request_bodies + } +} + +impl ComponentLookup for Response { + fn get_components( + components: &Components, + ) -> &IndexMap> { + &components.responses + } +} + +impl ComponentLookup for Schema { + fn get_components( + components: &Components, + ) -> &IndexMap> { + &components.schemas + } +} + +fn sanitize(input: &str, case: Case) -> String { + input.replace('/', "-").to_case(case) +} diff --git a/progenitor-impl/tests/output/buildomat.out b/progenitor-impl/tests/output/buildomat.out index 75a3833..4471267 100644 --- a/progenitor-impl/tests/output/buildomat.out +++ b/progenitor-impl/tests/output/buildomat.out @@ -156,18 +156,18 @@ pub struct Client { } impl Client { - pub fn new(baseurl: &str) -> Client { + pub fn new(baseurl: &str) -> Self { let dur = std::time::Duration::from_secs(15); let client = reqwest::ClientBuilder::new() .connect_timeout(dur) .timeout(dur) .build() .unwrap(); - Client::new_with_client(baseurl, client) + Self::new_with_client(baseurl, client) } - pub fn new_with_client(baseurl: &str, client: reqwest::Client) -> Client { - Client { + pub fn new_with_client(baseurl: &str, client: reqwest::Client) -> Self { + Self { baseurl: baseurl.to_string(), client, } @@ -176,14 +176,18 @@ impl Client { #[doc = "control_hold: POST /v1/control/hold"] pub async fn control_hold(&self) -> Result<()> { let url = format!("{}/v1/control/hold", self.baseurl,); - let res = self.client.post(url).send().await?.error_for_status()?; + let request = self.client.post(url).build()?; + let result = self.client.execute(request).await; + let res = result?.error_for_status()?; Ok(res.json().await?) } #[doc = "control_resume: POST /v1/control/resume"] pub async fn control_resume(&self) -> Result<()> { let url = format!("{}/v1/control/resume", self.baseurl,); - let res = self.client.post(url).send().await?.error_for_status()?; + let request = self.client.post(url).build()?; + let result = self.client.execute(request).await; + let res = result?.error_for_status()?; Ok(res.json().await?) } @@ -194,27 +198,27 @@ impl Client { self.baseurl, progenitor_support::encode_path(&task.to_string()), ); - let res = self.client.get(url).send().await?.error_for_status()?; + let request = self.client.get(url).build()?; + let result = self.client.execute(request).await; + let res = result?.error_for_status()?; Ok(res.json().await?) } #[doc = "tasks_get: GET /v1/tasks"] pub async fn tasks_get(&self) -> Result> { let url = format!("{}/v1/tasks", self.baseurl,); - let res = self.client.get(url).send().await?.error_for_status()?; + let request = self.client.get(url).build()?; + let result = self.client.execute(request).await; + let res = result?.error_for_status()?; Ok(res.json().await?) } #[doc = "task_submit: POST /v1/tasks"] pub async fn task_submit(&self, body: &types::TaskSubmit) -> Result { let url = format!("{}/v1/tasks", self.baseurl,); - let res = self - .client - .post(url) - .json(body) - .send() - .await? - .error_for_status()?; + let request = self.client.post(url).json(body).build()?; + let result = self.client.execute(request).await; + let res = result?.error_for_status()?; Ok(res.json().await?) } @@ -234,13 +238,9 @@ impl Client { query.push(("minseq", v.to_string())); } - let res = self - .client - .get(url) - .query(&query) - .send() - .await? - .error_for_status()?; + let request = self.client.get(url).query(&query).build()?; + let result = self.client.execute(request).await; + let res = result?.error_for_status()?; Ok(res.json().await?) } @@ -251,7 +251,9 @@ impl Client { self.baseurl, progenitor_support::encode_path(&task.to_string()), ); - let res = self.client.get(url).send().await?.error_for_status()?; + let request = self.client.get(url).build()?; + let result = self.client.execute(request).await; + let res = result?.error_for_status()?; Ok(res.json().await?) } @@ -267,27 +269,27 @@ impl Client { progenitor_support::encode_path(&task.to_string()), progenitor_support::encode_path(&output.to_string()), ); - let res = self.client.get(url).send().await?.error_for_status()?; + let request = self.client.get(url).build()?; + let result = self.client.execute(request).await; + let res = result?.error_for_status()?; Ok(res) } #[doc = "user_create: POST /v1/users"] pub async fn user_create(&self, body: &types::UserCreate) -> Result { let url = format!("{}/v1/users", self.baseurl,); - let res = self - .client - .post(url) - .json(body) - .send() - .await? - .error_for_status()?; + let request = self.client.post(url).json(body).build()?; + let result = self.client.execute(request).await; + let res = result?.error_for_status()?; Ok(res.json().await?) } #[doc = "whoami: GET /v1/whoami"] pub async fn whoami(&self) -> Result { let url = format!("{}/v1/whoami", self.baseurl,); - let res = self.client.get(url).send().await?.error_for_status()?; + let request = self.client.get(url).build()?; + let result = self.client.execute(request).await; + let res = result?.error_for_status()?; Ok(res.json().await?) } @@ -297,20 +299,18 @@ impl Client { body: &types::WorkerBootstrap, ) -> Result { let url = format!("{}/v1/worker/bootstrap", self.baseurl,); - let res = self - .client - .post(url) - .json(body) - .send() - .await? - .error_for_status()?; + let request = self.client.post(url).json(body).build()?; + let result = self.client.execute(request).await; + let res = result?.error_for_status()?; Ok(res.json().await?) } #[doc = "worker_ping: GET /v1/worker/ping"] pub async fn worker_ping(&self) -> Result { let url = format!("{}/v1/worker/ping", self.baseurl,); - let res = self.client.get(url).send().await?.error_for_status()?; + let request = self.client.get(url).build()?; + let result = self.client.execute(request).await; + let res = result?.error_for_status()?; Ok(res.json().await?) } @@ -325,13 +325,9 @@ impl Client { self.baseurl, progenitor_support::encode_path(&task.to_string()), ); - let res = self - .client - .post(url) - .json(body) - .send() - .await? - .error_for_status()?; + let request = self.client.post(url).json(body).build()?; + let result = self.client.execute(request).await; + let res = result?.error_for_status()?; Ok(res.json().await?) } @@ -346,13 +342,9 @@ impl Client { self.baseurl, progenitor_support::encode_path(&task.to_string()), ); - let res = self - .client - .post(url) - .body(body) - .send() - .await? - .error_for_status()?; + let request = self.client.post(url).body(body).build()?; + let result = self.client.execute(request).await; + let res = result?.error_for_status()?; Ok(res.json().await?) } @@ -367,13 +359,9 @@ impl Client { self.baseurl, progenitor_support::encode_path(&task.to_string()), ); - let res = self - .client - .post(url) - .json(body) - .send() - .await? - .error_for_status()?; + let request = self.client.post(url).json(body).build()?; + let result = self.client.execute(request).await; + let res = result?.error_for_status()?; Ok(res.json().await?) } @@ -388,27 +376,27 @@ impl Client { self.baseurl, progenitor_support::encode_path(&task.to_string()), ); - let res = self - .client - .post(url) - .json(body) - .send() - .await? - .error_for_status()?; + let request = self.client.post(url).json(body).build()?; + let result = self.client.execute(request).await; + let res = result?.error_for_status()?; Ok(res.json().await?) } #[doc = "workers_list: GET /v1/workers"] pub async fn workers_list(&self) -> Result { let url = format!("{}/v1/workers", self.baseurl,); - let res = self.client.get(url).send().await?.error_for_status()?; + let request = self.client.get(url).build()?; + let result = self.client.execute(request).await; + let res = result?.error_for_status()?; Ok(res.json().await?) } #[doc = "workers_recycle: POST /v1/workers/recycle"] pub async fn workers_recycle(&self) -> Result<()> { let url = format!("{}/v1/workers/recycle", self.baseurl,); - let res = self.client.post(url).send().await?.error_for_status()?; + let request = self.client.post(url).build()?; + let result = self.client.execute(request).await; + let res = result?.error_for_status()?; Ok(res.json().await?) } } diff --git a/progenitor-impl/tests/output/keeper.out b/progenitor-impl/tests/output/keeper.out index 0a11687..621710a 100644 --- a/progenitor-impl/tests/output/keeper.out +++ b/progenitor-impl/tests/output/keeper.out @@ -97,18 +97,18 @@ pub struct Client { } impl Client { - pub fn new(baseurl: &str) -> Client { + pub fn new(baseurl: &str) -> Self { let dur = std::time::Duration::from_secs(15); let client = reqwest::ClientBuilder::new() .connect_timeout(dur) .timeout(dur) .build() .unwrap(); - Client::new_with_client(baseurl, client) + Self::new_with_client(baseurl, client) } - pub fn new_with_client(baseurl: &str, client: reqwest::Client) -> Client { - Client { + pub fn new_with_client(baseurl: &str, client: reqwest::Client) -> Self { + Self { baseurl: baseurl.to_string(), client, } @@ -117,27 +117,27 @@ impl Client { #[doc = "enrol: POST /enrol"] pub async fn enrol(&self, body: &types::EnrolBody) -> Result<()> { let url = format!("{}/enrol", self.baseurl,); - let res = self - .client - .post(url) - .json(body) - .send() - .await? - .error_for_status()?; + let request = self.client.post(url).json(body).build()?; + let result = self.client.execute(request).await; + let res = result?.error_for_status()?; Ok(res.json().await?) } #[doc = "global_jobs: GET /global/jobs"] pub async fn global_jobs(&self) -> Result { let url = format!("{}/global/jobs", self.baseurl,); - let res = self.client.get(url).send().await?.error_for_status()?; + let request = self.client.get(url).build()?; + let result = self.client.execute(request).await; + let res = result?.error_for_status()?; Ok(res.json().await?) } #[doc = "ping: GET /ping"] pub async fn ping(&self) -> Result { let url = format!("{}/ping", self.baseurl,); - let res = self.client.get(url).send().await?.error_for_status()?; + let request = self.client.get(url).build()?; + let result = self.client.execute(request).await; + let res = result?.error_for_status()?; Ok(res.json().await?) } @@ -147,13 +147,9 @@ impl Client { body: &types::ReportFinishBody, ) -> Result { let url = format!("{}/report/finish", self.baseurl,); - let res = self - .client - .post(url) - .json(body) - .send() - .await? - .error_for_status()?; + let request = self.client.post(url).json(body).build()?; + let result = self.client.execute(request).await; + let res = result?.error_for_status()?; Ok(res.json().await?) } @@ -163,26 +159,18 @@ impl Client { body: &types::ReportOutputBody, ) -> Result { let url = format!("{}/report/output", self.baseurl,); - let res = self - .client - .post(url) - .json(body) - .send() - .await? - .error_for_status()?; + let request = self.client.post(url).json(body).build()?; + let result = self.client.execute(request).await; + let res = result?.error_for_status()?; Ok(res.json().await?) } #[doc = "report_start: POST /report/start"] pub async fn report_start(&self, body: &types::ReportStartBody) -> Result { let url = format!("{}/report/start", self.baseurl,); - let res = self - .client - .post(url) - .json(body) - .send() - .await? - .error_for_status()?; + let request = self.client.post(url).json(body).build()?; + let result = self.client.execute(request).await; + let res = result?.error_for_status()?; Ok(res.json().await?) } } diff --git a/progenitor-macro/Cargo.toml b/progenitor-macro/Cargo.toml index c17f671..a87e86a 100644 --- a/progenitor-macro/Cargo.toml +++ b/progenitor-macro/Cargo.toml @@ -9,9 +9,10 @@ description = "An OpenAPI client generator - macros" [dependencies] openapiv3 = "1.0.0-beta.2" progenitor-impl = { path = "../progenitor-impl" } -quote = "1.0.10" -serde_json = "1.0.68" -syn = "1.0.80" +quote = "1.0" +proc-macro2 = "1.0" +serde_json = "1.0" +syn = "1.0" [lib] -proc-macro = true \ No newline at end of file +proc-macro = true diff --git a/progenitor-macro/src/lib.rs b/progenitor-macro/src/lib.rs index 9a1ce47..ff09445 100644 --- a/progenitor-macro/src/lib.rs +++ b/progenitor-macro/src/lib.rs @@ -5,7 +5,11 @@ use std::path::Path; use openapiv3::OpenAPI; use proc_macro::TokenStream; use progenitor_impl::Generator; -use syn::LitStr; +use quote::ToTokens; +use syn::{ + parse::{Parse, ParseStream}, + ExprClosure, LitStr, Token, +}; #[proc_macro] pub fn generate_api(item: TokenStream) -> TokenStream { @@ -15,34 +19,100 @@ pub fn generate_api(item: TokenStream) -> TokenStream { } } +struct Settings { + file: LitStr, + inner: Option, + pre: Option, + post: Option, +} + +impl Parse for Settings { + fn parse(input: ParseStream) -> Result { + let file = input.parse::()?; + let inner = parse_inner(input)?; + let pre = parse_hook(input)?; + let post = parse_hook(input)?; + + // Optional trailing comma. + if input.peek(Token!(,)) { + let _ = input.parse::(); + } + + Ok(Settings { + file, + inner, + pre, + post, + }) + } +} + +fn parse_inner( + input: ParseStream, +) -> Result, syn::Error> { + if input.is_empty() { + return Ok(None); + } + let _: Token!(,) = input.parse()?; + if input.is_empty() { + return Ok(None); + } + Ok(Some(input.parse::()?.to_token_stream())) +} + +fn parse_hook( + input: ParseStream, +) -> Result, syn::Error> { + if input.is_empty() { + return Ok(None); + } + let _: Token!(,) = input.parse()?; + if input.is_empty() { + return Ok(None); + } + if let Ok(closure) = input.parse::() { + Ok(Some(closure.to_token_stream())) + } else { + Ok(Some(input.parse::()?.to_token_stream())) + } +} + fn do_generate_api(item: TokenStream) -> Result { - let arg = syn::parse::(item)?; + let Settings { + file, + inner, + pre, + post, + } = syn::parse::(item)?; let dir = std::env::var("CARGO_MANIFEST_DIR").map_or_else( |_| std::env::current_dir().unwrap(), |s| Path::new(&s).to_path_buf(), ); - let path = dir.join(arg.value()); + let path = dir.join(file.value()); let content = std::fs::read_to_string(&path).map_err(|e| { syn::Error::new( - arg.span(), - format!("couldn't read file {}: {}", arg.value(), e.to_string()), + file.span(), + format!("couldn't read file {}: {}", file.value(), e.to_string()), ) })?; let spec = serde_json::from_str::(&content).map_err(|e| { syn::Error::new( - arg.span(), - format!("failed to parse {}: {}", arg.value(), e.to_string()), + file.span(), + format!("failed to parse {}: {}", file.value(), e.to_string()), ) })?; let mut builder = Generator::new(); + inner.map(|inner_type| builder.with_inner_type(inner_type)); + pre.map(|pre_hook| builder.with_pre_hook(pre_hook)); + post.map(|post_hook| builder.with_post_hook(post_hook)); let ret = builder.generate_tokens(&spec).map_err(|e| { syn::Error::new( - arg.span(), - format!("generation error for {}: {}", arg.value(), e.to_string()), + file.span(), + format!("generation error for {}: {}", file.value(), e.to_string()), ) })?; diff --git a/progenitor/Cargo.toml b/progenitor/Cargo.toml index a517f00..087887d 100644 --- a/progenitor/Cargo.toml +++ b/progenitor/Cargo.toml @@ -9,15 +9,8 @@ description = "An OpenAPI client generator" [dependencies] progenitor-macro = { path = "../progenitor-macro" } progenitor-impl = { path = "../progenitor-impl" } -rustfmt-wrapper = "0.1.0" -anyhow = "1" +anyhow = "1.0" getopts = "0.2" -indexmap = "1.7.0" openapiv3 = "1.0.0-beta.2" -#proc-macro2 = "1.0.29" -#quote = "1.0.9" -regex = "1.5.4" -#schemars = "0.8.5" -serde = { version = "1", features = [ "derive" ] } -serde_json = "1.0.68" -#typify = { git = "https://github.com/oxidecomputer/typify" } +serde = { version = "1.0", features = [ "derive" ] } +serde_json = "1.0"