1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
// Copyright (c) 2022 Espresso Systems (espressosys.com)
// This file is part of the surf-disco library.

// You should have received a copy of the MIT License
// along with the surf-disco library. If not, see <https://mit-license.org/>.

use crate::{
    http::headers::{HeaderName, ToHeaderValues},
    Error, StatusCode,
};
use serde::{de::DeserializeOwned, Serialize};
use std::{error::Error as _, fmt::Display};
use vbs::{version::StaticVersionType, BinarySerializer, Serializer};

/// A typed wrapper around a [`reqwest::RequestBuilder`].
///
/// The type parameters record how the eventual response is interpreted:
///
/// * `T` is the type a successful response body is deserialized into,
/// * `E` is the error type returned when the request fails,
/// * `VER` selects the version used by the binary [`Serializer`].
#[must_use]
#[derive(Debug)]
pub struct Request<T, E, VER: StaticVersionType> {
    /// The underlying builder that accumulates headers and the body.
    inner: reqwest::RequestBuilder,
    /// Mentions the type parameters without storing any value of those types.
    marker: std::marker::PhantomData<fn(T, E, VER)>,
}

impl<T, E, VER: StaticVersionType> From<reqwest::RequestBuilder> for Request<T, E, VER> {
    /// Wrap a raw `reqwest` request builder in a typed [`Request`].
    fn from(inner: reqwest::RequestBuilder) -> Self {
        let marker = std::marker::PhantomData;
        Self { inner, marker }
    }
}

impl<T: DeserializeOwned, E: Error, VER: StaticVersionType> Request<T, E, VER> {
    /// Set a header on the request.
    ///
    /// Every value yielded by `values` is passed to the underlying request builder under the
    /// header name `key`, in the order the values are produced.
    ///
    /// # Panics
    ///
    /// Panics if `key` cannot be converted into a `reqwest` header name, or if `values` cannot
    /// be converted into header values.
    pub fn header(self, key: impl Into<HeaderName>, values: impl ToHeaderValues) -> Self {
        let key: HeaderName = key.into();
        let name = reqwest::header::HeaderName::from_bytes(key.as_str().as_bytes()).unwrap();
        values
            .to_header_values()
            .unwrap()
            .fold(self, |request, value| {
                request.inner.header(name.clone(), value.as_str()).into()
            })
    }

    /// Set the request body using JSON.
    ///
    /// `body` is serialized with [serde_json] and the `Content-Type` header is set to
    /// `application/json`.
    ///
    /// # Errors
    ///
    /// Fails if `body` does not serialize successfully.
    pub fn body_json<B: Serialize>(self, body: &B) -> Result<Self, E> {
        let payload = serde_json::to_string(body).map_err(request_error)?;
        let builder = self.header("Content-Type", "application/json").inner;
        Ok(builder.body(payload).into())
    }

    /// Set the request body using the versioned binary format.
    ///
    /// `body` is serialized with the binary [`Serializer`] for version `VER` and the
    /// `Content-Type` header is set to `application/octet-stream`.
    ///
    /// # Errors
    ///
    /// Fails if `body` does not serialize successfully.
    pub fn body_binary<B: Serialize>(self, body: &B) -> Result<Self, E> {
        let payload = Serializer::<VER>::serialize(body).map_err(request_error)?;
        let builder = self.header("Content-Type", "application/octet-stream").inner;
        Ok(builder.body(payload).into())
    }

    /// Send the request and await a response from the server.
    ///
    /// If the request succeeds (receives a response with [StatusCode::OK]) the response body is
    /// converted to a `T`, using a format determined by the `Content-Type` header of the
    /// response. Only the `type/subtype` part of the header is considered: it is compared
    /// case-insensitively and any parameters (such as `; charset=utf-8`) are ignored.
    ///
    /// # Errors
    ///
    /// If the client is unable to reach the server, or if the response body cannot be interpreted
    /// as a `T`, an error message is created using [catch_all](Error::catch_all) and returned.
    /// A successful response whose `Content-Type` is missing or not supported is reported with
    /// [StatusCode::UNSUPPORTED_MEDIA_TYPE].
    ///
    /// If the request completes but the response status code is not [StatusCode::OK], an error
    /// message is constructed using the body of the response. If there is a body and it can be
    /// converted to an `E` using the content type specified in the response's `Content-Type`
    /// header, that `E` will be returned directly. Otherwise, an error message is synthesized using
    /// [catch_all](Error::catch_all) that includes human-readable information about the response.
    pub async fn send(self) -> Result<T, E> {
        /// Reduce a `Content-Type` header value to its lowercase `type/subtype` part, dropping
        /// any parameters that follow the first `;`.
        fn media_type(value: &str) -> String {
            value
                .split(';')
                .next()
                .unwrap_or("")
                .trim()
                .to_ascii_lowercase()
        }

        let res = self.inner.send().await.map_err(reqwest_error)?;
        let status = res.status();
        let content_type = res.headers().get("Content-Type").cloned();
        // `None` when the header is absent or is not representable as a string; in both cases no
        // supported format can be selected.
        let essence = content_type
            .as_ref()
            .and_then(|value| value.to_str().ok())
            .map(media_type);

        if status == StatusCode::OK {
            // If the response indicates success, deserialize the body using a format determined by
            // the Content-Type header.
            match essence.as_deref() {
                Some("application/json") => res.json().await.map_err(reqwest_error),
                Some("application/octet-stream") => {
                    Serializer::<VER>::deserialize(&res.bytes().await.map_err(reqwest_error)?)
                        .map_err(request_error)
                }
                _ => match &content_type {
                    Some(header) => {
                        // For help in debugging, include the body with the unexpected content type
                        // in the error message.
                        let msg = match res.bytes().await {
                            Ok(bytes) => match std::str::from_utf8(&bytes) {
                                Ok(s) => format!("body: {}", s),
                                Err(_) => format!("body: {}", hex::encode(&bytes)),
                            },
                            Err(_) => String::default(),
                        };
                        Err(E::catch_all(
                            StatusCode::UNSUPPORTED_MEDIA_TYPE,
                            format!("unsupported content type {:?} {msg}", header.to_str()),
                        ))
                    }
                    None => Err(E::catch_all(
                        StatusCode::UNSUPPORTED_MEDIA_TYPE,
                        "unspecified content type in response".into(),
                    )),
                },
            }
        } else {
            // To add context to the error, try to interpret the response body as a serialized
            // error. The body is read exactly once as raw bytes, and then we try various potential
            // decodings based on the response headers and the contents of the body.
            let bytes = match res.bytes().await {
                Ok(bytes) => bytes,
                Err(err) => {
                    // If we are unable to even read the body, just return a generic error message
                    // based on the status code.
                    return Err(E::catch_all(
                        status.into(),
                        format!(
                            "Request terminated with error {status}. Failed to read response body due to {err}",
                        ),
                    ));
                }
            };
            // If the response specifies a content type, check if it is one of the types we know
            // how to deserialize, and if it is, we can then see if it deserializes to an `E`.
            match essence.as_deref() {
                Some("application/json") => {
                    if let Ok(err) = serde_json::from_slice(&bytes) {
                        return Err(err);
                    }
                }
                Some("application/octet-stream") => {
                    if let Ok(err) = Serializer::<VER>::deserialize(&bytes) {
                        return Err(err);
                    }
                }
                _ => {}
            }
            // If we get here, then we were not able to interpret the response body as an `E`
            // directly. This can be because:
            //  * the content type is not supported for deserialization
            //  * the content type was unspecified
            //  * the body did not deserialize to an `E`
            // We have one thing left we can try: if the body is a string, we can use the
            // `catch_all` variant of `E` to include the contents of the string in the error
            // message.
            if let Ok(msg) = std::str::from_utf8(&bytes) {
                return Err(E::catch_all(status.into(), msg.to_string()));
            }

            // The response body was not an `E` or a string. Return the most helpful error message
            // we can, including the status code, content type, and raw body.
            let declared = content_type
                .as_ref()
                .and_then(|value| value.to_str().ok())
                .unwrap_or("unspecified");
            Err(E::catch_all(
                status.into(),
                format!(
                    "Request terminated with error {status}. Content-Type: {declared}. Body: 0x{}",
                    hex::encode(&bytes)
                ),
            ))
        }
    }
}

/// Build an `E` with status [StatusCode::BAD_REQUEST] whose message is the `Display` rendering
/// of `source`.
fn request_error<E: Error>(source: impl Display) -> E {
    let message = source.to_string();
    E::catch_all(StatusCode::BAD_REQUEST, message)
}

/// Convert a [`reqwest::Error`] into an `E`.
///
/// The status code carried by `source` is used when there is one; otherwise the error is
/// reported as an internal server error. The message comes from [`reqwest_error_msg`].
fn reqwest_error<E: Error>(source: reqwest::Error) -> E {
    let status = source
        .status()
        .unwrap_or(reqwest::StatusCode::INTERNAL_SERVER_ERROR);
    E::catch_all(status.into(), reqwest_error_msg(source))
}

/// Render a [`reqwest::Error`] as a message.
///
/// When the error reports an underlying source, the source is appended after a colon so that
/// the cause is visible; otherwise only the error's own `Display` output is used.
pub(crate) fn reqwest_error_msg(err: reqwest::Error) -> String {
    if let Some(inner) = err.source() {
        format!("{err}: {inner}")
    } else {
        err.to_string()
    }
}

#[cfg(test)]
mod test {
    use super::*;
    use crate::{Client, ContentType};
    use async_compatibility_layer::logging::{setup_backtrace, setup_logging};
    use async_std::task::spawn;
    use futures::FutureExt;
    use portpicker::pick_unused_port;
    use tide_disco::{error::ServerError, App};
    use toml::toml;
    use vbs::version::StaticVersion;

    type Ver01 = StaticVersion<0, 1>;
    const VER_0_1: Ver01 = StaticVersion {};

    /// Checks that a client configured for JSON and a client configured for binary each
    /// receive a response from the same route in their own content type.
    #[async_std::test]
    async fn test_request_accept() {
        setup_logging();
        setup_backtrace();

        // Serve a Tide Disco app with a single GET route that returns a fixed string.
        let mut app: App<(), ServerError> = App::with_state(());
        let api = toml! {
            [route.get]
            PATH = ["/get"]
        };
        app.module::<ServerError, Ver01>("mod", api)
            .unwrap()
            .get("get", |_req, _state| async move { Ok("response") }.boxed())
            .unwrap();
        let port = pick_unused_port().unwrap();
        spawn(app.serve(format!("0.0.0.0:{port}"), VER_0_1));

        // Both clients talk to the same server; only the content type differs.
        let base_url = format!("http://localhost:{port}");

        let json_client = Client::<ServerError, Ver01>::builder(base_url.parse().unwrap())
            .content_type(ContentType::Json)
            .build();
        assert!(json_client.connect(None).await);

        let bin_client = Client::<ServerError, Ver01>::builder(base_url.parse().unwrap())
            .content_type(ContentType::Binary)
            .build();
        assert!(bin_client.connect(None).await);

        // The JSON client must receive a JSON response.
        let json_res = json_client
            .get::<String>("mod/get")
            .inner
            .send()
            .await
            .unwrap();
        assert_eq!(json_res.status(), StatusCode::OK);
        assert_eq!(json_res.headers()["Content-Type"], "application/json");
        let json_body = json_res.json::<String>().await.unwrap();
        assert_eq!(json_body, "response");

        // The binary client must receive a binary response.
        let bin_res = bin_client
            .get::<String>("mod/get")
            .inner
            .send()
            .await
            .unwrap();
        assert_eq!(bin_res.status(), StatusCode::OK);
        assert_eq!(
            bin_res.headers()["Content-Type"],
            "application/octet-stream"
        );
        let bin_bytes = bin_res.bytes().await.unwrap();
        let bin_body = Serializer::<Ver01>::deserialize::<String>(&bin_bytes).unwrap();
        assert_eq!(bin_body, "response");
    }
}