diff --git a/progenitor-impl/src/lib.rs b/progenitor-impl/src/lib.rs index f20a699..21231ef 100644 --- a/progenitor-impl/src/lib.rs +++ b/progenitor-impl/src/lib.rs @@ -6,10 +6,12 @@ use convert_case::{Case, Casing}; use indexmap::IndexMap; use openapiv3::{ Components, OpenAPI, Parameter, ReferenceOr, RequestBody, Response, Schema, + StatusCode, }; use proc_macro2::TokenStream; use quote::{format_ident, quote}; +use template::PathTemplate; use thiserror::Error; use typify::TypeSpace; @@ -42,6 +44,44 @@ pub struct Generator { post_hook: Option, } +struct OperationMethod { + operation_id: String, + method: String, + path: PathTemplate, + doc_comment: Option, + params: Vec, + responses: Vec, +} + +#[derive(Debug, PartialEq, Eq)] +enum OperationParameterKind { + Path, + Query(bool), + Body, +} +struct OperationParameter { + name: String, + typ: OperationParameterType, + kind: OperationParameterKind, +} + +enum OperationParameterType { + TokenStream(TokenStream), + RawBody, +} +#[derive(Debug)] +struct OperationResponse { + status_code: StatusCode, + typ: OperationResponseType, +} + +#[derive(Debug)] +enum OperationResponseType { + TokenStream(TokenStream), + None, + Raw, +} + impl Generator { pub fn new() -> Self { Self::default() @@ -77,11 +117,13 @@ impl Generator { self.type_space.set_type_mod("types"); self.type_space.add_ref_types(schemas)?; - let methods = spec + let raw_methods = spec .paths .iter() .flat_map(|(path, ref_or_item)| { + // Exclude externally defined path items. let item = ref_or_item.as_item().unwrap(); + // TODO punt on paramters that apply to all path items for now. assert!(item.parameters.is_empty()); item.iter().map(move |(method, operation)| { (path.as_str(), method, operation) @@ -97,6 +139,11 @@ impl Generator { }) .collect::>>()?; + let methods = raw_methods + .iter() + .map(|method| self.process_method(method)) + .collect::>>()?; + let mut types = self .type_space .iter_types() @@ -199,12 +246,8 @@ impl Generator { components: &Option, path: &str, method: &str, - ) -> Result { - enum ParamType { - Path, - Query, - Body, - } + ) -> Result { + let operation_id = operation.operation_id.as_ref().unwrap(); let mut query: Vec<(String, bool)> = Vec::new(); let mut raw_params = operation @@ -223,10 +266,7 @@ impl Generator { let schema = parameter_data.schema()?.to_schema(); let name = format!( "{}{}", - sanitize( - operation.operation_id.as_ref().unwrap(), - Case::Pascal - ), + sanitize(operation_id, Case::Pascal), sanitize(&nam, Case::Pascal), ); let typ = self @@ -234,7 +274,11 @@ impl Generator { .add_type_with_name(&schema, Some(name))? .parameter_ident(); - Ok((ParamType::Path, nam, typ)) + Ok(OperationParameter { + name: sanitize(¶meter_data.name, Case::Snake), + typ: OperationParameterType::TokenStream(typ), + kind: OperationParameterKind::Path, + }) } openapiv3::Parameter::Query { parameter_data, @@ -266,19 +310,23 @@ impl Generator { .add_type_with_name(&schema, Some(name))? .parameter_ident(); - query.push((nam.to_string(), !parameter_data.required)); - Ok((ParamType::Query, nam, typ)) + query.push((nam, !parameter_data.required)); + Ok(OperationParameter { + name: sanitize(¶meter_data.name, Case::Snake), + typ: OperationParameterType::TokenStream(typ), + kind: OperationParameterKind::Query( + parameter_data.required, + ), + }) } x => todo!("unhandled parameter type: {:#?}", x), } }) .collect::>>()?; - let mut bounds = Vec::new(); - let (body_param, body_func) = if let Some(b) = &operation.request_body { + if let Some(b) = &operation.request_body { let b = b.item(components)?; - if b.is_binary(components)? { - bounds.push(quote! {B: Into}); - (Some(quote! {B}), Some(quote! { .body(body) })) + let typ = if b.is_binary(components)? { + OperationParameterType::RawBody } else { let mt = b.content_json()?; if !mt.encoding.is_empty() { @@ -298,154 +346,208 @@ impl Generator { .type_space .add_type_with_name(&schema, Some(name))? .parameter_ident(); - (Some(typ), Some(quote! { .json(body) })) + OperationParameterType::TokenStream(typ) } else { todo!("media type encoding, no schema: {:#?}", mt); } - } - } else { - (None, None) - }; - if let Some(body) = body_param { - raw_params.push((ParamType::Body, "body".to_string(), body)); + }; + + raw_params.push(OperationParameter { + name: "body".to_string(), + typ, + kind: OperationParameterKind::Body, + }); } let tmp = template::parse(path)?; let names = tmp.names(); - let url_path = tmp.compile(); - raw_params.sort_by(|a, b| match (&a.0, &b.0) { - // Path params are first and are in positional order. - (ParamType::Path, ParamType::Path) => { - let aa = names.iter().position(|x| x == &a.1).unwrap(); - let bb = names.iter().position(|x| x == &b.1).unwrap(); - aa.cmp(&bb) - } - (ParamType::Path, ParamType::Query) => Ordering::Less, - (ParamType::Path, ParamType::Body) => Ordering::Less, - - // Query params are in lexicographic order. - (ParamType::Query, ParamType::Body) => Ordering::Less, - (ParamType::Query, ParamType::Query) => a.1.cmp(&b.1), - (ParamType::Query, ParamType::Path) => Ordering::Greater, - - // Body params are last and should be unique - (ParamType::Body, ParamType::Path) => Ordering::Greater, - (ParamType::Body, ParamType::Query) => Ordering::Greater, - (ParamType::Body, ParamType::Body) => { - panic!("should only be one body") - } - }); - - let (response_type, decode_response) = - // TODO let's consider how we handle multiple responses - if operation.responses.responses.len() >= 1 { - let only = - operation.responses.responses.first().unwrap(); - if !matches!( - only.0, - openapiv3::StatusCode::Code(200..=299) - ) { - todo!("code? {:#?}", only); - } - - let i = only.1.item(components)?; - // TODO handle response headers. - - // Look at the response content. For now, support a - // single JSON-formatted response. - match ( - i.content.len(), - i.content.get("application/json"), - ) { - // TODO we should verify that the content length of the - // response is zero in this case; if it's not we'll want to - // do the same thing as if there were a serialization - // error. - (0, _) => (quote! { () }, quote! { () }), - (1, Some(mt)) => { - if !mt.encoding.is_empty() { - todo!( - "media type encoding not empty: {:#?}", - mt - ); - } - - let typ = if let Some(schema) = &mt.schema { - let schema = schema.to_schema(); - let name = format!( - "{}Response", - sanitize( - operation - .operation_id - .as_ref() - .unwrap(), - Case::Pascal - ) - ); - self.type_space - .add_type_with_name( - &schema, - Some(name), - )? - .ident() - } else { - todo!( - "media type encoding, no schema: {:#?}", - mt - ); - }; - (typ, quote! { res.json().await? }) + raw_params.sort_by( + |OperationParameter { + kind: a_kind, + name: a_name, + .. + }, + OperationParameter { + kind: b_kind, + name: b_name, + .. + }| { + match (a_kind, b_kind) { + // Path params are first and are in positional order. + ( + OperationParameterKind::Path, + OperationParameterKind::Path, + ) => { + let a_index = + names.iter().position(|x| x == a_name).unwrap(); + let b_index = + names.iter().position(|x| x == b_name).unwrap(); + a_index.cmp(&b_index) } - (1, None) => { - // Non-JSON response. - (quote! { reqwest::Response }, quote! { res }) - } - (_, _) => { - todo!( - "too many response contents: {:#?}", - i.content - ); + ( + OperationParameterKind::Path, + OperationParameterKind::Query(_), + ) => Ordering::Less, + ( + OperationParameterKind::Path, + OperationParameterKind::Body, + ) => Ordering::Less, + + // Query params are in lexicographic order. + ( + OperationParameterKind::Query(_), + OperationParameterKind::Body, + ) => Ordering::Less, + ( + OperationParameterKind::Query(_), + OperationParameterKind::Query(_), + ) => a_name.cmp(b_name), + ( + OperationParameterKind::Query(_), + OperationParameterKind::Path, + ) => Ordering::Greater, + + // Body params are last and should be unique + ( + OperationParameterKind::Body, + OperationParameterKind::Path, + ) => Ordering::Greater, + ( + OperationParameterKind::Body, + OperationParameterKind::Query(_), + ) => Ordering::Greater, + ( + OperationParameterKind::Body, + OperationParameterKind::Body, + ) => { + panic!("should only be one body") } } - } else if operation.responses.responses.is_empty() { - (quote! { reqwest::Response }, quote! { res }) - } else { - todo!("responses? {:#?}", operation.responses); - }; - let operation_id = format_ident!( - "{}", - sanitize(operation.operation_id.as_deref().unwrap(), Case::Snake) + }, ); - let bounds = if bounds.is_empty() { + + let mut success = false; + + let mut responses = operation + .responses + .responses + .iter() + .map(|(status_code, response_or_ref)| { + let response = response_or_ref.item(components)?; + + let typ = if let Some(mt) = + response.content.get("application/json") + { + assert!(mt.encoding.is_empty()); + + let typ = if let Some(schema) = &mt.schema { + let schema = schema.to_schema(); + let name = format!( + "{}Response", + sanitize( + operation.operation_id.as_ref().unwrap(), + Case::Pascal + ) + ); + self.type_space + .add_type_with_name(&schema, Some(name))? + .ident() + } else { + todo!("media type encoding, no schema: {:#?}", mt); + }; + + OperationResponseType::TokenStream(typ) + } else if response.content.first().is_some() { + OperationResponseType::Raw + } else { + OperationResponseType::None + }; + + if matches!( + status_code, + StatusCode::Code(200..=299) | StatusCode::Range(2) + ) { + success = true; + } + + Ok(OperationResponse { + status_code: status_code.clone(), + typ, + }) + }) + .collect::>>()?; + + // If the API has declined to specify the characteristics of a + // successful response, we cons up a generic one. + if !success { + responses.push(OperationResponse { + status_code: StatusCode::Range(2), + typ: OperationResponseType::Raw, + }); + } + + Ok(OperationMethod { + operation_id: sanitize(operation_id, Case::Snake), + method: method.to_string(), + path: tmp, + doc_comment: operation.description.clone(), + params: raw_params, + responses, + }) + } + + fn process_method(&self, method: &OperationMethod) -> Result { + let operation_id = format_ident!("{}", method.operation_id,); + let mut bounds_items: Vec = Vec::new(); + let params = method + .params + .iter() + .map(|param| { + let name = format_ident!("{}", param.name); + let typ = match ¶m.typ { + OperationParameterType::TokenStream(t) => t.clone(), + OperationParameterType::RawBody => { + bounds_items.push(quote! { B: Into}); + quote! {B} + } + }; + quote! { + #name: #typ + } + }) + .collect::>(); + let bounds = if bounds_items.is_empty() { quote! {} } else { quote! { - < #(#bounds),* > + < #(#bounds_items),* > } }; - let params = raw_params.into_iter().map(|(_, name, typ)| { - let name = format_ident!("{}", name); - quote! { - #name: #typ - } - }); - let (query_build, query_use) = if query.is_empty() { + + let query_items = method + .params + .iter() + .filter_map(|param| match ¶m.kind { + OperationParameterKind::Query(required) => { + let qn = ¶m.name; + Some(if *required { + quote! { + query.push((#qn, #qn.to_string())); + } + } else { + let qn_ident = format_ident!("{}", qn); + quote! { + if let Some(v) = & #qn_ident { + query.push((#qn, v.to_string())); + } + } + }) + } + _ => None, + }) + .collect::>(); + let (query_build, query_use) = if query_items.is_empty() { (quote! {}, quote! {}) } else { - let query_items = query.iter().map(|(qn, opt)| { - if *opt { - let qn_ident = format_ident!("{}", qn); - quote! { - if let Some(v) = & #qn_ident { - query.push((#qn, v.to_string())); - } - } - } else { - quote! { - query.push((#qn, #qn.to_string())); - } - } - }); - let query_build = quote! { let mut query = Vec::new(); #(#query_items)* @@ -456,11 +558,61 @@ impl Generator { (query_build, query_use) }; + + let url_path = method.path.compile(); + + let body_func = + method.params.iter().filter_map(|param| match ¶m.kind { + OperationParameterKind::Body => match ¶m.typ { + OperationParameterType::TokenStream(_) => { + Some(quote! { .json(body) }) + } + OperationParameterType::RawBody => { + Some(quote! { .body(body )}) + } + }, + _ => None, + }); + + assert!(body_func.clone().count() <= 1); + + let mut success_response_items = + method.responses.iter().filter(|response| { + matches!( + response.status_code, + StatusCode::Code(200..=299) | StatusCode::Range(2) + ) + }); + + assert_eq!(success_response_items.clone().count(), 1); + + let (response_type, decode_response) = success_response_items + .next() + .map(|response| match &response.typ { + OperationResponseType::TokenStream(typ) => { + (typ.clone(), quote! {res.json().await?}) + } + OperationResponseType::None => { + // TODO this doesn't seem quite right; I think we still want to return the raw response structure here. + (quote! { () }, quote! { () }) + } + OperationResponseType::Raw => { + (quote! { reqwest::Response }, quote! { res }) + } + }) + .unwrap(); + + // TODO document parameters let doc_comment = format!( - "{}: {} {}", - operation.operation_id.as_deref().unwrap(), - method.to_ascii_uppercase(), - path + "{}{}: {} {}", + method + .doc_comment + .as_ref() + .map(|s| format!("{}\n\n", s)) + .unwrap_or_else(String::new), + method.operation_id, + method.method.to_ascii_uppercase(), + method.path.to_string(), ); let pre_hook = self.pre_hook.as_ref().map(|hook| { @@ -475,8 +627,9 @@ impl Generator { }); // TODO validate that method is one of the expected methods. - let method_func = format_ident!("{}", method.to_lowercase()); - let method = quote! { + let method_func = format_ident!("{}", method.method.to_lowercase()); + + let method_impl = quote! { #[doc = #doc_comment] pub async fn #operation_id #bounds ( &self, @@ -487,7 +640,7 @@ impl Generator { let request = self.client . #method_func (url) - #body_func + #(#body_func)* #query_use .build()?; #pre_hook @@ -502,7 +655,7 @@ impl Generator { Ok(#decode_response) } }; - Ok(method) + Ok(method_impl) } pub fn generate_text(&mut self, spec: &OpenAPI) -> Result { diff --git a/progenitor-impl/src/template.rs b/progenitor-impl/src/template.rs index d55c561..a551065 100644 --- a/progenitor-impl/src/template.rs +++ b/progenitor-impl/src/template.rs @@ -12,11 +12,11 @@ enum Component { } #[derive(Eq, PartialEq, Clone, Debug)] -pub struct Template { +pub struct PathTemplate { components: Vec, } -impl Template { +impl PathTemplate { pub fn compile(&self) -> TokenStream { let mut fmt = String::new(); fmt.push_str("{}"); @@ -55,7 +55,7 @@ impl Template { } } -pub fn parse(t: &str) -> Result