surf_disco/
request.rs

1// Copyright (c) 2022 Espresso Systems (espressosys.com)
2// This file is part of the surf-disco library.
3
4// You should have received a copy of the MIT License
5// along with the surf-disco library. If not, see <https://mit-license.org/>.
6
7use crate::{
8    http::headers::{HeaderName, ToHeaderValues},
9    Error, StatusCode,
10};
11use serde::{de::DeserializeOwned, Serialize};
12use std::{error::Error as _, fmt::Display};
13use vbs::{version::StaticVersionType, BinarySerializer, Serializer};
14
15#[must_use]
16#[derive(Debug)]
17pub struct Request<T, E, VER: StaticVersionType> {
18    inner: reqwest::RequestBuilder,
19    marker: std::marker::PhantomData<fn(T, E, VER) -> ()>,
20}
21
22impl<T, E, VER: StaticVersionType> From<reqwest::RequestBuilder> for Request<T, E, VER> {
23    fn from(inner: reqwest::RequestBuilder) -> Self {
24        Self {
25            inner,
26            marker: Default::default(),
27        }
28    }
29}
30
31impl<T: DeserializeOwned, E: Error, VER: StaticVersionType> Request<T, E, VER> {
32    /// Set a header on the request.
33    pub fn header(mut self, key: impl Into<HeaderName>, values: impl ToHeaderValues) -> Self {
34        let key = reqwest::header::HeaderName::from_bytes(key.into().as_str().as_bytes()).unwrap();
35        for value in values.to_header_values().unwrap() {
36            self = self.inner.header(key.clone(), value.as_str()).into()
37        }
38        self
39    }
40
41    /// Set the request body using JSON.
42    ///
43    /// Body is serialized using [serde_json] and the `Content-Type` header is set to
44    /// `application/json`.
45    pub fn body_json<B: Serialize>(self, body: &B) -> Result<Self, E> {
46        Ok(self
47            .header("Content-Type", "application/json")
48            .inner
49            .body(serde_json::to_string(body).map_err(request_error)?)
50            .into())
51    }
52
53    /// Set the request body using [bincode].
54    ///
55    /// Body is serialized using [bincode] and the `Content-Type` header is set to
56    /// `application/octet-stream`.
57    ///
58    /// # Errors
59    ///
60    /// Fails if `body` does not serialize successfully.
61    pub fn body_binary<B: Serialize>(self, body: &B) -> Result<Self, E> {
62        Ok(self
63            .header("Content-Type", "application/octet-stream")
64            .inner
65            .body(Serializer::<VER>::serialize(body).map_err(request_error)?)
66            .into())
67    }
68
69    /// Send the request and await a response from the server.
70    ///
71    /// If the request succeeds (receives a response with [StatusCode::OK]) the response body is
72    /// converted to a `T`, using a format determined by the `Content-Type` header of the request.
73    ///
74    /// # Errors
75    ///
76    /// If the client is unable to reach the server, or if the response body cannot be interpreted
77    /// as a `T`, an error message is created using [catch_all](Error::catch_all) and returned.
78    ///
79    /// If the request completes but the response status code is not [StatusCode::OK], an error
80    /// message is constructed using the body of the response. If there is a body and it can be
81    /// converted to an `E` using the content type specified in the response's `Content-Type`
82    /// header, that `E` will be returned directly. Otherwise, an error message is synthesized using
83    /// [catch_all](Error::catch_all) that includes human-readable information about the response.
84    pub async fn send(self) -> Result<T, E> {
85        let res = self.inner.send().await.map_err(reqwest_error)?;
86        let status = res.status();
87        let content_type = res.headers().get("Content-Type").cloned();
88        if res.status() == StatusCode::OK {
89            // If the response indicates success, deserialize the body using a format determined by
90            // the Content-Type header.
91            if let Some(content_type) = content_type {
92                match content_type.to_str() {
93                    Ok("application/json") => res.json().await.map_err(reqwest_error),
94                    Ok("application/octet-stream") => {
95                        Serializer::<VER>::deserialize(&res.bytes().await.map_err(reqwest_error)?)
96                            .map_err(request_error)
97                    }
98                    content_type => {
99                        // For help in debugging, include the body with the unexpected content type
100                        // in the error message.
101                        let msg = match res.bytes().await {
102                            Ok(bytes) => match std::str::from_utf8(&bytes) {
103                                Ok(s) => format!("body: {}", s),
104                                Err(_) => format!("body: {}", hex::encode(&bytes)),
105                            },
106                            Err(_) => String::default(),
107                        };
108                        Err(E::catch_all(
109                            StatusCode::UNSUPPORTED_MEDIA_TYPE,
110                            format!("unsupported content type {content_type:?} {msg}"),
111                        ))
112                    }
113                }
114            } else {
115                Err(E::catch_all(
116                    StatusCode::UNSUPPORTED_MEDIA_TYPE,
117                    "unspecified content type in response".into(),
118                ))
119            }
120        } else {
121            // To add context to the error, try to interpret the response body as a serialized
122            // error. Since `body_json`, `body_string`, etc. consume the response body, we will
123            // extract the body as raw bytes and then try various potential decodings based on the
124            // response headers and the contents of the body.
125            let bytes = match res.bytes().await {
126                Ok(bytes) => bytes,
127                Err(err) => {
128                    // If we are unable to even read the body, just return a generic error message
129                    // based on the status code.
130                    return Err(E::catch_all(
131                        status.into(),
132                        format!(
133                            "Request terminated with error {status}. Failed to read request body due to {err}",
134                        ),
135                    ));
136                }
137            };
138            if let Some(content_type) = &content_type {
139                // If the response specifies a content type, check if it is one of the types we know
140                // how to deserialize, and if it is, we can then see if it deserializes to an `E`.
141                match content_type.to_str() {
142                    Ok("application/json") => {
143                        if let Ok(err) = serde_json::from_slice(&bytes) {
144                            return Err(err);
145                        }
146                    }
147                    Ok("application/octet-stream") => {
148                        if let Ok(err) = Serializer::<VER>::deserialize(&bytes) {
149                            return Err(err);
150                        }
151                    }
152                    _ => {}
153                }
154            }
155            // If we get here, then we were not able to interpret the response body as an `E`
156            // directly. This can be because:
157            //  * the content type is not supported for deserialization
158            //  * the content type was unspecified
159            //  * the body did not deserialize to an `E` We have one thing left we can try: if the
160            //    body is a string, we can use the `catch_all` variant of `E` to include the
161            //    contents of the string in the error message.
162            if let Ok(msg) = std::str::from_utf8(&bytes) {
163                return Err(E::catch_all(status.into(), msg.to_string()));
164            }
165
166            // The response body was not an `E` or a string. Return the most helpful error message
167            // we can, including the status code, content type, and raw body.
168            Err(E::catch_all(
169                status.into(),
170                format!(
171                    "Request terminated with error {status}. Content-Type: {}. Body: 0x{}",
172                    match content_type {
173                        Some(content_type) =>
174                            content_type.to_str().unwrap_or("unspecified").to_owned(),
175                        None => "unspecified".to_owned(),
176                    },
177                    hex::encode(&bytes)
178                ),
179            ))
180        }
181    }
182
183    /// Sends the request and returns the full response body as raw bytes,
184    pub async fn bytes(self) -> Result<Vec<u8>, E> {
185        let response = self.inner.send().await.map_err(reqwest_error)?;
186        let status = response.status();
187        let content_type = response.headers().get("Content-Type").cloned();
188
189        let bytes_result = response.bytes().await.map(|b| b.to_vec()).map_err(|err| E::catch_all(
190                status.into(),
191                format!(
192                    "Request terminated with error {status}. Failed to read request body due to {err}",
193                ),
194            )
195        );
196
197        if status.is_success() {
198            return bytes_result;
199        }
200
201        let bytes = bytes_result?;
202
203        if let Ok(msg) = std::str::from_utf8(&bytes) {
204            return Err(E::catch_all(status.into(), msg.to_string()));
205        }
206
207        Err(E::catch_all(
208            status.into(),
209            format!(
210                "Request failed with status {status}. Content-Type: {}. Body: 0x{}",
211                content_type
212                    .as_ref()
213                    .and_then(|v| v.to_str().ok())
214                    .unwrap_or("unspecified"),
215                hex::encode(&bytes),
216            ),
217        ))
218    }
219}
220
221fn request_error<E: Error>(source: impl Display) -> E {
222    E::catch_all(StatusCode::BAD_REQUEST, source.to_string())
223}
224
225fn reqwest_error<E: Error>(source: reqwest::Error) -> E {
226    E::catch_all(
227        source
228            .status()
229            .unwrap_or(reqwest::StatusCode::INTERNAL_SERVER_ERROR)
230            .into(),
231        reqwest_error_msg(source),
232    )
233}
234
235pub(crate) fn reqwest_error_msg(err: reqwest::Error) -> String {
236    match err.source() {
237        Some(inner) => format!("{err}: {inner}"),
238        None => err.to_string(),
239    }
240}
241
#[cfg(test)]
mod test {
    use super::*;
    use crate::{Client, ContentType};
    use async_compatibility_layer::logging::{setup_backtrace, setup_logging};
    use async_std::task::spawn;
    use futures::FutureExt;
    use portpicker::pick_unused_port;
    use tide_disco::{error::ServerError, App};
    use toml::toml;
    use vbs::version::StaticVersion;

    type Ver01 = StaticVersion<0, 1>;
    const VER_0_1: Ver01 = StaticVersion {};

    /// A client configured for a content type receives responses in that content type.
    #[async_std::test]
    async fn test_request_accept() {
        setup_logging();
        setup_backtrace();

        // Serve a Tide Disco app with a single route returning a fixed string.
        let mut app: App<(), ServerError> = App::with_state(());
        let api = toml! {
            [route.get]
            PATH = ["/get"]
        };
        app.module::<ServerError, Ver01>("mod", api)
            .unwrap()
            .get("get", |_req, _state| async move { Ok("response") }.boxed())
            .unwrap();
        let port = pick_unused_port().unwrap();
        spawn(app.serve(format!("0.0.0.0:{port}"), VER_0_1));
        let base_url = format!("http://localhost:{port}");

        // One client per supported content type.
        let json_client = Client::<ServerError, Ver01>::builder(base_url.parse().unwrap())
            .content_type(ContentType::Json)
            .build();
        assert!(json_client.connect(None).await);

        let bin_client = Client::<ServerError, Ver01>::builder(base_url.parse().unwrap())
            .content_type(ContentType::Binary)
            .build();
        assert!(bin_client.connect(None).await);

        // The JSON client gets a JSON response.
        let json_res = json_client
            .get::<String>("mod/get")
            .inner
            .send()
            .await
            .unwrap();
        assert_eq!(json_res.status(), StatusCode::OK);
        assert_eq!(json_res.headers()["Content-Type"], "application/json");
        let json_body: String = json_res.json().await.unwrap();
        assert_eq!(json_body, "response");

        // The binary client gets a binary response.
        let bin_res = bin_client
            .get::<String>("mod/get")
            .inner
            .send()
            .await
            .unwrap();
        assert_eq!(bin_res.status(), StatusCode::OK);
        assert_eq!(bin_res.headers()["Content-Type"], "application/octet-stream");
        let bin_body = bin_res.bytes().await.unwrap();
        let decoded: String = Serializer::<Ver01>::deserialize(&bin_body).unwrap();
        assert_eq!(decoded, "response");
    }

    /// `bytes()` reports an error when the server handler fails.
    #[async_std::test]
    async fn test_bad_response_bytes() {
        setup_logging();
        setup_backtrace();

        // Serve a Tide Disco app whose only route always fails.
        let mut app: App<(), ServerError> = App::with_state(());
        let api = toml! {
            [route.integer]
            PATH = ["/integer"]
        };

        app.module::<ServerError, Ver01>("app", api)
            .unwrap()
            .get::<_, u64>("integer", |_req, _state| {
                async move {
                    Err(ServerError::catch_all(
                        StatusCode::NOT_FOUND,
                        "not found".to_string(),
                    ))
                }
                .boxed()
            })
            .unwrap();

        let port = pick_unused_port().unwrap();
        spawn(app.serve(format!("0.0.0.0:{port}"), VER_0_1));
        let base_url = format!("http://localhost:{port}");

        // Connect a client with the default settings.
        let client = Client::<ServerError, Ver01>::builder(base_url.parse().unwrap()).build();
        assert!(client.connect(None).await);

        // The failing route must surface as an error from `.bytes()`.
        let outcome = client.get::<()>("app/integer").bytes().await;

        assert!(
            outcome.is_err(),
            "Expected error from .bytes() due to server-side failure"
        );
    }
}