diff --git a/progenitor-client/src/lib.rs b/progenitor-client/src/lib.rs index 6590ecd..8675045 100644 --- a/progenitor-client/src/lib.rs +++ b/progenitor-client/src/lib.rs @@ -4,6 +4,10 @@ mod progenitor_client; pub use crate::progenitor_client::*; +// For stand-alone crates, rather than adding a dependency on +// progenitor-client, we simply dump the code right in. This means we don't +// need to determine the provenance of progenitor (crates.io, github, etc.) +// when generating the stand-alone crate. #[doc(hidden)] pub fn code() -> &'static str { include_str!("progenitor_client.rs") diff --git a/progenitor-client/src/progenitor_client.rs b/progenitor-client/src/progenitor_client.rs index a4d3e76..1a1a6cb 100644 --- a/progenitor-client/src/progenitor_client.rs +++ b/progenitor-client/src/progenitor_client.rs @@ -15,8 +15,23 @@ use reqwest::RequestBuilder; use serde::{de::DeserializeOwned, Serialize}; /// Represents an untyped byte stream for both success and error responses. -pub type ByteStream = - Pin> + Send>>; +pub struct ByteStream( + Pin> + Send>>, +); + +impl Deref for ByteStream { + type Target = Pin> + Send>>; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl DerefMut for ByteStream { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} /// Success value returned by generated client methods. pub struct ResponseValue { @@ -52,7 +67,7 @@ impl ResponseValue { let status = response.status(); let headers = response.headers().clone(); Self { - inner: Box::pin(response.bytes_stream()), + inner: ByteStream(Box::pin(response.bytes_stream())), status, headers, } @@ -75,6 +90,19 @@ impl ResponseValue<()> { } impl ResponseValue { + /// Create an instance for testing + pub fn new( + inner: T, + status: reqwest::StatusCode, + headers: reqwest::header::HeaderMap, + ) -> Self { + Self { + inner, + status, + headers, + } + } + /// Consumes the ResponseValue, returning the wrapped value. pub fn into_inner(self) -> T { self.inner @@ -197,7 +225,10 @@ impl From for Error { } } -impl std::fmt::Display for Error { +impl std::fmt::Display for Error +where + ResponseValue: ErrorFormat, +{ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { Error::InvalidRequest(s) => { @@ -206,8 +237,9 @@ impl std::fmt::Display for Error { Error::CommunicationError(e) => { write!(f, "Communication Error: {}", e) } - Error::ErrorResponse(_) => { - write!(f, "Error Response") + Error::ErrorResponse(rve) => { + write!(f, "Error Response: ")?; + rve.fmt_info(f) } Error::InvalidResponsePayload(e) => { write!(f, "Invalid Response Payload: {}", e) @@ -218,12 +250,46 @@ impl std::fmt::Display for Error { } } } -impl std::fmt::Debug for Error { + +trait ErrorFormat { + fn fmt_info(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result; +} + +impl ErrorFormat for ResponseValue +where + E: std::fmt::Debug, +{ + fn fmt_info(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "status: {}; headers: {:?}; value: {:?}", + self.status, self.headers, self.inner, + ) + } +} + +impl ErrorFormat for ResponseValue { + fn fmt_info(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "status: {}; headers: {:?}; value: ", + self.status, self.headers, + ) + } +} + +impl std::fmt::Debug for Error +where + ResponseValue: ErrorFormat, +{ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { std::fmt::Display::fmt(self, f) } } -impl std::error::Error for Error { +impl std::error::Error for Error +where + ResponseValue: ErrorFormat, +{ fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { match self { Error::CommunicationError(e) => Some(e), diff --git a/progenitor/tests/test_client.rs b/progenitor/tests/test_client.rs new file mode 100644 index 0000000..8539132 --- /dev/null +++ b/progenitor/tests/test_client.rs @@ -0,0 +1,28 @@ +// Copyright 2022 Oxide Computer Company + +// Validate that we get useful output from a user-typed error. +#[test] +#[should_panic = "Error Response: \ + status: 403 Forbidden; \ + headers: {}; \ + value: MyErr { msg: \"things went bad\" }"] +fn test_error() { + #[derive(Debug)] + struct MyErr { + #[allow(dead_code)] + msg: String, + } + + let mine = MyErr { + msg: "things went bad".to_string(), + }; + let e = progenitor_client::Error::ErrorResponse( + progenitor_client::ResponseValue::new( + mine, + reqwest::StatusCode::FORBIDDEN, + reqwest::header::HeaderMap::default(), + ), + ); + + (Err(e) as Result<(), progenitor_client::Error>).unwrap(); +}