1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277
// Copyright (c) 2022 Espresso Systems (espressosys.com)
// This file is part of the surf-disco library.
// You should have received a copy of the MIT License
// along with the surf-disco library. If not, see <https://mit-license.org/>.
use crate::{
http::headers::{HeaderName, ToHeaderValues},
Error, StatusCode,
};
use serde::{de::DeserializeOwned, Serialize};
use std::{error::Error as _, fmt::Display};
use vbs::{version::StaticVersionType, BinarySerializer, Serializer};
/// A request that will be sent to a server.
///
/// On success the response body is decoded as a `T`. On failure the server's response (or the
/// transport failure) is turned into an `E`. `VER` selects the version of the binary serializer
/// used for `application/octet-stream` bodies.
#[must_use]
#[derive(Debug)]
pub struct Request<T, E, VER: StaticVersionType> {
    /// The `reqwest` request being built; every builder method forwards to it.
    inner: reqwest::RequestBuilder,
    /// Binds `T`, `E` and `VER` to the request without storing values of those types. A
    /// function-pointer type is used so the marker never affects `Send`/`Sync`.
    marker: std::marker::PhantomData<fn(T, E, VER)>,
}
impl<T, E, VER: StaticVersionType> From<reqwest::RequestBuilder> for Request<T, E, VER> {
    /// Wrap an existing `reqwest` builder, keeping everything already configured on it.
    fn from(inner: reqwest::RequestBuilder) -> Self {
        let marker = std::marker::PhantomData;
        Self { inner, marker }
    }
}
impl<T: DeserializeOwned, E: Error, VER: StaticVersionType> Request<T, E, VER> {
/// Set a header on the request.
pub fn header(mut self, key: impl Into<HeaderName>, values: impl ToHeaderValues) -> Self {
let key = reqwest::header::HeaderName::from_bytes(key.into().as_str().as_bytes()).unwrap();
for value in values.to_header_values().unwrap() {
self = self.inner.header(key.clone(), value.as_str()).into()
}
self
}
/// Set the request body using JSON.
///
/// Body is serialized using [serde_json] and the `Content-Type` header is set to
/// `application/json`.
pub fn body_json<B: Serialize>(self, body: &B) -> Result<Self, E> {
Ok(self
.header("Content-Type", "application/json")
.inner
.body(serde_json::to_string(body).map_err(request_error)?)
.into())
}
/// Set the request body using [bincode].
///
/// Body is serialized using [bincode] and the `Content-Type` header is set to
/// `application/octet-stream`.
///
/// # Errors
///
/// Fails if `body` does not serialize successfully.
pub fn body_binary<B: Serialize>(self, body: &B) -> Result<Self, E> {
Ok(self
.header("Content-Type", "application/octet-stream")
.inner
.body(Serializer::<VER>::serialize(body).map_err(request_error)?)
.into())
}
/// Send the request and await a response from the server.
///
/// If the request succeeds (receives a response with [StatusCode::OK]) the response body is
/// converted to a `T`, using a format determined by the `Content-Type` header of the request.
///
/// # Errors
///
/// If the client is unable to reach the server, or if the response body cannot be interpreted
/// as a `T`, an error message is created using [catch_all](Error::catch_all) and returned.
///
/// If the request completes but the response status code is not [StatusCode::OK], an error
/// message is constructed using the body of the response. If there is a body and it can be
/// converted to an `E` using the content type specified in the response's `Content-Type`
/// header, that `E` will be returned directly. Otherwise, an error message is synthesized using
/// [catch_all](Error::catch_all) that includes human-readable information about the response.
pub async fn send(self) -> Result<T, E> {
let res = self.inner.send().await.map_err(reqwest_error)?;
let status = res.status();
let content_type = res.headers().get("Content-Type").cloned();
if res.status() == StatusCode::OK {
// If the response indicates success, deserialize the body using a format determined by
// the Content-Type header.
if let Some(content_type) = content_type {
match content_type.to_str() {
Ok("application/json") => res.json().await.map_err(reqwest_error),
Ok("application/octet-stream") => {
Serializer::<VER>::deserialize(&res.bytes().await.map_err(reqwest_error)?)
.map_err(request_error)
}
content_type => {
// For help in debugging, include the body with the unexpected content type
// in the error message.
let msg = match res.bytes().await {
Ok(bytes) => match std::str::from_utf8(&bytes) {
Ok(s) => format!("body: {}", s),
Err(_) => format!("body: {}", hex::encode(&bytes)),
},
Err(_) => String::default(),
};
Err(E::catch_all(
StatusCode::UNSUPPORTED_MEDIA_TYPE,
format!("unsupported content type {content_type:?} {msg}"),
))
}
}
} else {
Err(E::catch_all(
StatusCode::UNSUPPORTED_MEDIA_TYPE,
"unspecified content type in response".into(),
))
}
} else {
// To add context to the error, try to interpret the response body as a serialized
// error. Since `body_json`, `body_string`, etc. consume the response body, we will
// extract the body as raw bytes and then try various potential decodings based on the
// response headers and the contents of the body.
let bytes = match res.bytes().await {
Ok(bytes) => bytes,
Err(err) => {
// If we are unable to even read the body, just return a generic error message
// based on the status code.
return Err(E::catch_all(
status.into(),
format!(
"Request terminated with error {status}. Failed to read request body due to {err}",
),
));
}
};
if let Some(content_type) = &content_type {
// If the response specifies a content type, check if it is one of the types we know
// how to deserialize, and if it is, we can then see if it deserializes to an `E`.
match content_type.to_str() {
Ok("application/json") => {
if let Ok(err) = serde_json::from_slice(&bytes) {
return Err(err);
}
}
Ok("application/octet-stream") => {
if let Ok(err) = Serializer::<VER>::deserialize(&bytes) {
return Err(err);
}
}
_ => {}
}
}
// If we get here, then we were not able to interpret the response body as an `E`
// directly. This can be because:
// * the content type is not supported for deserialization
// * the content type was unspecified
// * the body did not deserialize to an `E` We have one thing left we can try: if the
// body is a string, we can use the `catch_all` variant of `E` to include the
// contents of the string in the error message.
if let Ok(msg) = std::str::from_utf8(&bytes) {
return Err(E::catch_all(status.into(), msg.to_string()));
}
// The response body was not an `E` or a string. Return the most helpful error message
// we can, including the status code, content type, and raw body.
Err(E::catch_all(
status.into(),
format!(
"Request terminated with error {status}. Content-Type: {}. Body: 0x{}",
match content_type {
Some(content_type) =>
content_type.to_str().unwrap_or("unspecified").to_owned(),
None => "unspecified".to_owned(),
},
hex::encode(&bytes)
),
))
}
}
}
/// Wrap a local failure (e.g. a (de)serialization error) in an `E`.
///
/// The error is reported via `catch_all` with status `BAD_REQUEST`, using the `Display` text of
/// `source` as the message.
fn request_error<E: Error>(source: impl Display) -> E {
    let message = source.to_string();
    E::catch_all(StatusCode::BAD_REQUEST, message)
}
/// Convert a `reqwest` failure into an `E` via `catch_all`.
///
/// The status code attached to the reqwest error is used when present, otherwise
/// `INTERNAL_SERVER_ERROR`. The message comes from [`reqwest_error_msg`].
fn reqwest_error<E: Error>(source: reqwest::Error) -> E {
    let status = source
        .status()
        .unwrap_or(reqwest::StatusCode::INTERNAL_SERVER_ERROR);
    let message = reqwest_error_msg(source);
    E::catch_all(status.into(), message)
}
/// Render a `reqwest` error as text, appending its immediate source error when it has one
/// (formatted as `"{err}: {source}"`).
pub(crate) fn reqwest_error_msg(err: reqwest::Error) -> String {
    if let Some(inner) = err.source() {
        format!("{err}: {inner}")
    } else {
        err.to_string()
    }
}
#[cfg(test)]
mod test {
    use super::*;
    use crate::{Client, ContentType};
    use async_compatibility_layer::logging::{setup_backtrace, setup_logging};
    use async_std::task::spawn;
    use futures::FutureExt;
    use portpicker::pick_unused_port;
    use tide_disco::{error::ServerError, App};
    use toml::toml;
    use vbs::version::StaticVersion;

    type Ver01 = StaticVersion<0, 1>;
    const VER_0_1: Ver01 = StaticVersion {};

    /// Build a client for the local server on `port` that asks for `content_type`, and assert
    /// that it manages to connect.
    async fn connected_client(port: u16, content_type: ContentType) -> Client<ServerError, Ver01> {
        let client = Client::<ServerError, Ver01>::builder(
            format!("http://localhost:{port}").parse().unwrap(),
        )
        .content_type(content_type)
        .build();
        assert!(client.connect(None).await);
        client
    }

    /// A request built by a client configured for a content type is answered in that content
    /// type.
    #[async_std::test]
    async fn test_request_accept() {
        setup_logging();
        setup_backtrace();

        // Serve a single-route Tide Disco app on a free local port.
        let api = toml! {
            [route.get]
            PATH = ["/get"]
        };
        let mut app: App<(), ServerError> = App::with_state(());
        app.module::<ServerError, Ver01>("mod", api)
            .unwrap()
            .get("get", |_req, _state| async move { Ok("response") }.boxed())
            .unwrap();
        let port = pick_unused_port().unwrap();
        spawn(app.serve(format!("0.0.0.0:{port}"), VER_0_1));

        // One client per supported content type.
        let json_client = connected_client(port, ContentType::Json).await;
        let bin_client = connected_client(port, ContentType::Binary).await;

        // The JSON client gets a JSON-encoded body.
        let json_res = json_client
            .get::<String>("mod/get")
            .inner
            .send()
            .await
            .unwrap();
        assert_eq!(json_res.status(), StatusCode::OK);
        assert_eq!(json_res.headers()["Content-Type"], "application/json");
        let json_body = json_res.json::<String>().await.unwrap();
        assert_eq!(json_body, "response");

        // The binary client gets a body in the versioned binary format.
        let bin_res = bin_client
            .get::<String>("mod/get")
            .inner
            .send()
            .await
            .unwrap();
        assert_eq!(bin_res.status(), StatusCode::OK);
        assert_eq!(bin_res.headers()["Content-Type"], "application/octet-stream");
        let raw = bin_res.bytes().await.unwrap();
        let bin_body = Serializer::<Ver01>::deserialize::<String>(&raw).unwrap();
        assert_eq!(bin_body, "response");
    }
}