// Copyright 2021 Oxide Computer Company use std::cmp::Ordering; use convert_case::{Case, Casing}; use indexmap::IndexMap; use openapiv3::{ Components, OpenAPI, Parameter, ReferenceOr, RequestBody, Response, Schema, }; use proc_macro2::TokenStream; use quote::{format_ident, quote}; use thiserror::Error; use typify::TypeSpace; use crate::to_schema::ToSchema; mod template; mod to_schema; #[derive(Error, Debug)] pub enum Error { #[error("unexpected value type")] BadValue(String, serde_json::Value), #[error("type error")] TypeError(#[from] typify::Error), #[error("XXX")] BadConversion(String), #[error("invalid operation path")] InvalidPath(String), //#[error("unknown")] //Unknown, } pub type Result = std::result::Result; #[derive(Default)] pub struct Generator { type_space: TypeSpace, inner_type: Option, pre_hook: Option, post_hook: Option, } impl Generator { pub fn new() -> Self { Self::default() } pub fn with_inner_type(&mut self, inner_type: TokenStream) -> &mut Self { self.inner_type = Some(inner_type); self } pub fn with_pre_hook(&mut self, pre_hook: TokenStream) -> &mut Self { self.pre_hook = Some(pre_hook); self } pub fn with_post_hook(&mut self, post_hook: TokenStream) -> &mut Self { self.post_hook = Some(post_hook); self } pub fn generate_tokens(&mut self, spec: &OpenAPI) -> Result { // Convert our components dictionary to schemars let schemas = spec .components .iter() .flat_map(|components| { components.schemas.iter().map(|(name, ref_or_schema)| { (name.clone(), ref_or_schema.to_schema()) }) }) .collect::>(); self.type_space.set_type_mod("types"); self.type_space.add_ref_types(schemas)?; let methods = spec .paths .iter() .flat_map(|(path, ref_or_item)| { let item = ref_or_item.as_item().unwrap(); assert!(item.parameters.is_empty()); item.iter().map(move |(method, operation)| { (path.as_str(), method, operation) }) }) .map(|(path, method, operation)| { self.process_operation( operation, &spec.components, path, method, ) }) .collect::>>()?; let mut types = self .type_space .iter_types() .map(|type_entry| { ( type_entry.type_name(&self.type_space), type_entry.output(&self.type_space), ) }) .collect::>(); types.sort_by(|(a_name, _), (b_name, _)| a_name.cmp(b_name)); let types = types.into_iter().map(|(_, def)| def); let inner_property = self.inner_type.as_ref().map(|inner| { quote! { inner: #inner, } }); let inner_value = self.inner_type.as_ref().map(|_| { quote! { inner } }); let file = quote! { use anyhow::Result; mod progenitor_support { use percent_encoding::{utf8_percent_encode, AsciiSet, CONTROLS}; #[allow(dead_code)] const PATH_SET: &AsciiSet = &CONTROLS .add(b' ') .add(b'"') .add(b'#') .add(b'<') .add(b'>') .add(b'?') .add(b'`') .add(b'{') .add(b'}'); #[allow(dead_code)] pub(crate) fn encode_path(pc: &str) -> String { utf8_percent_encode(pc, PATH_SET).to_string() } } pub mod types { use serde::{Deserialize, Serialize}; #(#types)* } #[derive(Clone)] pub struct Client { baseurl: String, client: reqwest::Client, #inner_property } impl Client { pub fn new( baseurl: &str, #inner_property ) -> Self { let dur = std::time::Duration::from_secs(15); let client = reqwest::ClientBuilder::new() .connect_timeout(dur) .timeout(dur) .build() .unwrap(); Self::new_with_client(baseurl, client, #inner_value) } pub fn new_with_client( baseurl: &str, client: reqwest::Client, #inner_property ) -> Self { Self { baseurl: baseurl.to_string(), client, #inner_value } } #(#methods)* } }; Ok(file) } fn process_operation( &mut self, operation: &openapiv3::Operation, components: &Option, path: &str, method: &str, ) -> Result { enum ParamType { Path, Query, Body, } let mut query: Vec<(String, bool)> = Vec::new(); let mut raw_params = operation .parameters .iter() .map(|parameter| { match parameter.item(components)? { openapiv3::Parameter::Path { parameter_data, style: openapiv3::PathStyle::Simple, } => { // Path parameters MUST be required. assert!(parameter_data.required); let nam = parameter_data.name.clone(); let schema = parameter_data.schema()?.to_schema(); let name = format!( "{}{}", sanitize( operation.operation_id.as_ref().unwrap(), Case::Pascal ), sanitize(&nam, Case::Pascal), ); let typ = self .type_space .add_type_details_with_name(&schema, Some(name))? .parameter; Ok((ParamType::Path, nam, typ)) } openapiv3::Parameter::Query { parameter_data, allow_reserved: _, style: openapiv3::QueryStyle::Form, allow_empty_value, } => { if let Some(true) = allow_empty_value { todo!("allow empty value is a no go"); } let nam = parameter_data.name.clone(); let mut schema = parameter_data.schema()?.to_schema(); let name = format!( "{}{}", sanitize( operation.operation_id.as_ref().unwrap(), Case::Pascal ), sanitize(&nam, Case::Pascal), ); if !parameter_data.required { schema = make_optional(schema); } let typ = self .type_space .add_type_details_with_name(&schema, Some(name))? .parameter; query.push((nam.to_string(), !parameter_data.required)); Ok((ParamType::Query, nam, typ)) } x => todo!("unhandled parameter type: {:#?}", x), } }) .collect::>>()?; let mut bounds = Vec::new(); let (body_param, body_func) = if let Some(b) = &operation.request_body { let b = b.item(components)?; if b.is_binary(components)? { bounds.push(quote! {B: Into}); (Some(quote! {B}), Some(quote! { .body(body) })) } else { let mt = b.content_json()?; if !mt.encoding.is_empty() { todo!("media type encoding not empty: {:#?}", mt); } if let Some(s) = &mt.schema { let schema = s.to_schema(); let name = format!( "{}Body", sanitize( operation.operation_id.as_ref().unwrap(), Case::Pascal ) ); let typ = self .type_space .add_type_details_with_name(&schema, Some(name))? .parameter; (Some(typ), Some(quote! { .json(body) })) } else { todo!("media type encoding, no schema: {:#?}", mt); } } } else { (None, None) }; if let Some(body) = body_param { raw_params.push((ParamType::Body, "body".to_string(), body)); } let tmp = template::parse(path)?; let names = tmp.names(); let url_path = tmp.compile(); raw_params.sort_by(|a, b| match (&a.0, &b.0) { // Path params are first and are in positional order. (ParamType::Path, ParamType::Path) => { let aa = names.iter().position(|x| x == &a.1).unwrap(); let bb = names.iter().position(|x| x == &b.1).unwrap(); aa.cmp(&bb) } (ParamType::Path, ParamType::Query) => Ordering::Less, (ParamType::Path, ParamType::Body) => Ordering::Less, // Query params are in lexicographic order. (ParamType::Query, ParamType::Body) => Ordering::Less, (ParamType::Query, ParamType::Query) => a.1.cmp(&b.1), (ParamType::Query, ParamType::Path) => Ordering::Greater, // Body params are last and should be unique (ParamType::Body, ParamType::Path) => Ordering::Greater, (ParamType::Body, ParamType::Query) => Ordering::Greater, (ParamType::Body, ParamType::Body) => { panic!("should only be one body") } }); let (response_type, decode_response) = // TODO let's consider how we handle multiple responses if operation.responses.responses.len() >= 1 { let only = operation.responses.responses.first().unwrap(); if !matches!( only.0, openapiv3::StatusCode::Code(200..=299) ) { todo!("code? {:#?}", only); } let i = only.1.item(components)?; // TODO handle response headers. // Look at the response content. For now, support a // single JSON-formatted response. match ( i.content.len(), i.content.get("application/json"), ) { (0, _) => (quote! { () }, quote! { res.json().await? }), (1, Some(mt)) => { if !mt.encoding.is_empty() { todo!( "media type encoding not empty: {:#?}", mt ); } let typ = if let Some(schema) = &mt.schema { let schema = schema.to_schema(); let name = format!( "{}Response", sanitize( operation .operation_id .as_ref() .unwrap(), Case::Pascal ) ); self.type_space .add_type_details_with_name( &schema, Some(name), )? .ident } else { todo!( "media type encoding, no schema: {:#?}", mt ); }; (typ, quote! { res.json().await? }) } (1, None) => { // Non-JSON response. (quote! { reqwest::Response }, quote! { res }) } (_, _) => { todo!( "too many response contents: {:#?}", i.content ); } } } else if operation.responses.responses.is_empty() { (quote! { reqwest::Response }, quote! { res }) } else { todo!("responses? {:#?}", operation.responses); }; let operation_id = format_ident!( "{}", sanitize(operation.operation_id.as_deref().unwrap(), Case::Snake) ); let bounds = if bounds.is_empty() { quote! {} } else { quote! { < #(#bounds),* > } }; let params = raw_params.into_iter().map(|(_, name, typ)| { let name = format_ident!("{}", name); quote! { #name: #typ } }); let (query_build, query_use) = if query.is_empty() { (quote! {}, quote! {}) } else { let query_items = query.iter().map(|(qn, opt)| { if *opt { let qn_ident = format_ident!("{}", qn); quote! { if let Some(v) = & #qn_ident { query.push((#qn, v.to_string())); } } } else { quote! { query.push((#qn, #qn.to_string())); } } }); let query_build = quote! { let mut query = Vec::new(); #(#query_items)* }; let query_use = quote! { .query(&query) }; (query_build, query_use) }; let doc_comment = format!( "{}: {} {}", operation.operation_id.as_deref().unwrap(), method.to_ascii_uppercase(), path ); let pre_hook = self.pre_hook.as_ref().map(|hook| { quote! { (#hook)(&self.inner, &request); } }); let post_hook = self.post_hook.as_ref().map(|hook| { quote! { (#hook)(&self.inner, &result); } }); // TODO validate that method is one of the expected methods. let method_func = format_ident!("{}", method.to_lowercase()); let method = quote! { #[doc = #doc_comment] pub async fn #operation_id #bounds ( &self, #(#params),* ) -> Result<#response_type> { #url_path #query_build let request = self.client . #method_func (url) #body_func #query_use .build()?; #pre_hook let result = self.client .execute(request) .await; #post_hook // TODO we should do a match here for result?.status().as_u16() let res = result?.error_for_status()?; Ok(#decode_response) } }; Ok(method) } pub fn generate_text(&mut self, spec: &OpenAPI) -> Result { let output = self.generate_tokens(spec)?; // Format the file with rustfmt and some whitespace niceties. let content = rustfmt_wrapper::rustfmt(output).unwrap(); Ok(if cfg!(not(windows)) { let regex = regex::Regex::new(r#"(})(\n\s*[^} ])"#).unwrap(); regex.replace_all(&content, "$1\n$2").to_string() } else { let regex = regex::Regex::new(r#"(})(\r\n\s*[^} ])"#).unwrap(); regex.replace_all(&content, "$1\r\n$2").to_string() }) } pub fn dependencies(&self) -> Vec { let mut deps = vec![ "anyhow = \"1.0.44\"", "percent-encoding = \"2.1.0\"", "serde = { version = \"1.0.130\", features = [\"derive\"] }", "reqwest = { version = \"0.11.5\", features = [\"json\", \"stream\"] }", ]; if self.type_space.uses_uuid() { deps.push( "uuid = { version = \"0.8.2\", features = [\"serde\", \"v4\"] }", ) } if self.type_space.uses_chrono() { deps.push( "chrono = { version = \"0.4.19\", features = [\"serde\"] }", ) } deps.sort_unstable(); deps.iter().map(ToString::to_string).collect() } pub fn get_type_space(&self) -> &TypeSpace { &self.type_space } } /// Make the schema optional if it isn't already. pub fn make_optional( schema: schemars::schema::Schema, ) -> schemars::schema::Schema { match &schema { // If the instance_type already includes Null then this is already // optional. schemars::schema::Schema::Object(schemars::schema::SchemaObject { instance_type: Some(schemars::schema::SingleOrVec::Vec(types)), .. }) if types.contains(&schemars::schema::InstanceType::Null) => schema, // Otherwise, create a oneOf where one of the branches is the null // type. We could potentially check to see if the schema already // conforms to this pattern as well, but it doesn't hurt as typify will // already reduce nested Options to a single Option. _ => { let null_schema = schemars::schema::Schema::Object( schemars::schema::SchemaObject { instance_type: Some(schemars::schema::SingleOrVec::Single( Box::new(schemars::schema::InstanceType::Null), )), ..Default::default() }, ); schemars::schema::Schema::Object(schemars::schema::SchemaObject { subschemas: Some(Box::new( schemars::schema::SubschemaValidation { one_of: Some(vec![schema, null_schema]), ..Default::default() }, )), ..Default::default() }) } } } trait ParameterDataExt { fn schema(&self) -> Result<&openapiv3::ReferenceOr>; } impl ParameterDataExt for openapiv3::ParameterData { fn schema(&self) -> Result<&openapiv3::ReferenceOr> { match &self.format { openapiv3::ParameterSchemaOrContent::Schema(s) => Ok(s), x => { Err(Error::BadConversion(format!("XXX param format {:#?}", x))) } } } } trait ExtractJsonMediaType { fn is_binary(&self, components: &Option) -> Result; fn content_json(&self) -> Result; } impl ExtractJsonMediaType for openapiv3::Response { fn content_json(&self) -> Result { if self.content.len() != 1 { todo!("expected one content entry, found {}", self.content.len()); } if let Some(mt) = self.content.get("application/json") { Ok(mt.clone()) } else { todo!( "could not find application/json, only found {}", self.content.keys().next().unwrap() ); } } fn is_binary(&self, _components: &Option) -> Result { if self.content.is_empty() { /* * XXX If there are no content types, I guess it is not binary? */ return Ok(false); } if self.content.len() != 1 { todo!("expected one content entry, found {}", self.content.len()); } if let Some(mt) = self.content.get("application/octet-stream") { if !mt.encoding.is_empty() { todo!("XXX encoding"); } if let Some(s) = &mt.schema { use openapiv3::{ SchemaKind, StringFormat, Type, VariantOrUnknownOrEmpty::Item, }; let s = s.item(&None)?; if s.schema_data.nullable { todo!("XXX nullable binary?"); } if s.schema_data.default.is_some() { todo!("XXX default binary?"); } if s.schema_data.discriminator.is_some() { todo!("XXX binary discriminator?"); } match &s.schema_kind { SchemaKind::Type(Type::String(st)) => { if st.min_length.is_some() || st.max_length.is_some() { todo!("binary min/max length"); } if !matches!(st.format, Item(StringFormat::Binary)) { todo!( "expected binary format string, got {:?}", st.format ); } if st.pattern.is_some() { todo!("XXX pattern"); } if !st.enumeration.is_empty() { todo!("XXX enumeration"); } return Ok(true); } x => { todo!("XXX schemakind type {:?}", x); } } } else { todo!("binary thing had no schema?"); } } Ok(false) } } impl ExtractJsonMediaType for openapiv3::RequestBody { fn content_json(&self) -> Result { if self.content.len() != 1 { todo!("expected one content entry, found {}", self.content.len()); } if let Some(mt) = self.content.get("application/json") { Ok(mt.clone()) } else { todo!( "could not find application/json, only found {}", self.content.keys().next().unwrap() ); } } fn is_binary(&self, components: &Option) -> Result { if self.content.is_empty() { /* * XXX If there are no content types, I guess it is not binary? */ return Ok(false); } if self.content.len() != 1 { todo!("expected one content entry, found {}", self.content.len()); } if let Some(mt) = self.content.get("application/octet-stream") { if !mt.encoding.is_empty() { todo!("XXX encoding"); } if let Some(s) = &mt.schema { use openapiv3::{ SchemaKind, StringFormat, Type, VariantOrUnknownOrEmpty::Item, }; let s = s.item(components)?; if s.schema_data.nullable { todo!("XXX nullable binary?"); } if s.schema_data.default.is_some() { todo!("XXX default binary?"); } if s.schema_data.discriminator.is_some() { todo!("XXX binary discriminator?"); } match &s.schema_kind { SchemaKind::Type(Type::String(st)) => { if st.min_length.is_some() || st.max_length.is_some() { todo!("binary min/max length"); } if !matches!(st.format, Item(StringFormat::Binary)) { todo!( "expected binary format string, got {:?}", st.format ); } if st.pattern.is_some() { todo!("XXX pattern"); } if !st.enumeration.is_empty() { todo!("XXX enumeration"); } return Ok(true); } x => { todo!("XXX schemakind type {:?}", x); } } } else { todo!("binary thing had no schema?"); } } Ok(false) } } trait ReferenceOrExt { fn item<'a>(&'a self, components: &'a Option) -> Result<&'a T>; } trait ComponentLookup: Sized { fn get_components( components: &Components, ) -> &IndexMap>; } impl ReferenceOrExt for openapiv3::ReferenceOr { fn item<'a>(&'a self, components: &'a Option) -> Result<&'a T> { match self { ReferenceOr::Item(item) => Ok(item), ReferenceOr::Reference { reference } => { let idx = reference.rfind('/').unwrap(); let key = &reference[idx + 1..]; let parameters = T::get_components(components.as_ref().unwrap()); parameters.get(key).unwrap().item(components) } } } } impl ComponentLookup for Parameter { fn get_components( components: &Components, ) -> &IndexMap> { &components.parameters } } impl ComponentLookup for RequestBody { fn get_components( components: &Components, ) -> &IndexMap> { &components.request_bodies } } impl ComponentLookup for Response { fn get_components( components: &Components, ) -> &IndexMap> { &components.responses } } impl ComponentLookup for Schema { fn get_components( components: &Components, ) -> &IndexMap> { &components.schemas } } fn sanitize(input: &str, case: Case) -> String { input.replace('/', "-").to_case(case) }