1use crate::{
8 http::headers::{HeaderName, ToHeaderValues},
9 Error, StatusCode,
10};
11use serde::{de::DeserializeOwned, Serialize};
12use std::{error::Error as _, fmt::Display};
13use vbs::{version::StaticVersionType, BinarySerializer, Serializer};
14
15#[must_use]
16#[derive(Debug)]
17pub struct Request<T, E, VER: StaticVersionType> {
18 inner: reqwest::RequestBuilder,
19 marker: std::marker::PhantomData<fn(T, E, VER) -> ()>,
20}
21
22impl<T, E, VER: StaticVersionType> From<reqwest::RequestBuilder> for Request<T, E, VER> {
23 fn from(inner: reqwest::RequestBuilder) -> Self {
24 Self {
25 inner,
26 marker: Default::default(),
27 }
28 }
29}
30
31impl<T: DeserializeOwned, E: Error, VER: StaticVersionType> Request<T, E, VER> {
32 pub fn header(mut self, key: impl Into<HeaderName>, values: impl ToHeaderValues) -> Self {
34 let key = reqwest::header::HeaderName::from_bytes(key.into().as_str().as_bytes()).unwrap();
35 for value in values.to_header_values().unwrap() {
36 self = self.inner.header(key.clone(), value.as_str()).into()
37 }
38 self
39 }
40
41 pub fn body_json<B: Serialize>(self, body: &B) -> Result<Self, E> {
46 Ok(self
47 .header("Content-Type", "application/json")
48 .inner
49 .body(serde_json::to_string(body).map_err(request_error)?)
50 .into())
51 }
52
53 pub fn body_binary<B: Serialize>(self, body: &B) -> Result<Self, E> {
62 Ok(self
63 .header("Content-Type", "application/octet-stream")
64 .inner
65 .body(Serializer::<VER>::serialize(body).map_err(request_error)?)
66 .into())
67 }
68
69 pub async fn send(self) -> Result<T, E> {
85 let res = self.inner.send().await.map_err(reqwest_error)?;
86 let status = res.status();
87 let content_type = res.headers().get("Content-Type").cloned();
88 if res.status() == StatusCode::OK {
89 if let Some(content_type) = content_type {
92 match content_type.to_str() {
93 Ok("application/json") => res.json().await.map_err(reqwest_error),
94 Ok("application/octet-stream") => {
95 Serializer::<VER>::deserialize(&res.bytes().await.map_err(reqwest_error)?)
96 .map_err(request_error)
97 }
98 content_type => {
99 let msg = match res.bytes().await {
102 Ok(bytes) => match std::str::from_utf8(&bytes) {
103 Ok(s) => format!("body: {}", s),
104 Err(_) => format!("body: {}", hex::encode(&bytes)),
105 },
106 Err(_) => String::default(),
107 };
108 Err(E::catch_all(
109 StatusCode::UNSUPPORTED_MEDIA_TYPE,
110 format!("unsupported content type {content_type:?} {msg}"),
111 ))
112 }
113 }
114 } else {
115 Err(E::catch_all(
116 StatusCode::UNSUPPORTED_MEDIA_TYPE,
117 "unspecified content type in response".into(),
118 ))
119 }
120 } else {
121 let bytes = match res.bytes().await {
126 Ok(bytes) => bytes,
127 Err(err) => {
128 return Err(E::catch_all(
131 status.into(),
132 format!(
133 "Request terminated with error {status}. Failed to read request body due to {err}",
134 ),
135 ));
136 }
137 };
138 if let Some(content_type) = &content_type {
139 match content_type.to_str() {
142 Ok("application/json") => {
143 if let Ok(err) = serde_json::from_slice(&bytes) {
144 return Err(err);
145 }
146 }
147 Ok("application/octet-stream") => {
148 if let Ok(err) = Serializer::<VER>::deserialize(&bytes) {
149 return Err(err);
150 }
151 }
152 _ => {}
153 }
154 }
155 if let Ok(msg) = std::str::from_utf8(&bytes) {
163 return Err(E::catch_all(status.into(), msg.to_string()));
164 }
165
166 Err(E::catch_all(
169 status.into(),
170 format!(
171 "Request terminated with error {status}. Content-Type: {}. Body: 0x{}",
172 match content_type {
173 Some(content_type) =>
174 content_type.to_str().unwrap_or("unspecified").to_owned(),
175 None => "unspecified".to_owned(),
176 },
177 hex::encode(&bytes)
178 ),
179 ))
180 }
181 }
182
183 pub async fn bytes(self) -> Result<Vec<u8>, E> {
185 let response = self.inner.send().await.map_err(reqwest_error)?;
186 let status = response.status();
187 let content_type = response.headers().get("Content-Type").cloned();
188
189 let bytes_result = response.bytes().await.map(|b| b.to_vec()).map_err(|err| E::catch_all(
190 status.into(),
191 format!(
192 "Request terminated with error {status}. Failed to read request body due to {err}",
193 ),
194 )
195 );
196
197 if status.is_success() {
198 return bytes_result;
199 }
200
201 let bytes = bytes_result?;
202
203 if let Ok(msg) = std::str::from_utf8(&bytes) {
204 return Err(E::catch_all(status.into(), msg.to_string()));
205 }
206
207 Err(E::catch_all(
208 status.into(),
209 format!(
210 "Request failed with status {status}. Content-Type: {}. Body: 0x{}",
211 content_type
212 .as_ref()
213 .and_then(|v| v.to_str().ok())
214 .unwrap_or("unspecified"),
215 hex::encode(&bytes),
216 ),
217 ))
218 }
219}
220
221fn request_error<E: Error>(source: impl Display) -> E {
222 E::catch_all(StatusCode::BAD_REQUEST, source.to_string())
223}
224
225fn reqwest_error<E: Error>(source: reqwest::Error) -> E {
226 E::catch_all(
227 source
228 .status()
229 .unwrap_or(reqwest::StatusCode::INTERNAL_SERVER_ERROR)
230 .into(),
231 reqwest_error_msg(source),
232 )
233}
234
235pub(crate) fn reqwest_error_msg(err: reqwest::Error) -> String {
236 match err.source() {
237 Some(inner) => format!("{err}: {inner}"),
238 None => err.to_string(),
239 }
240}
241
#[cfg(test)]
mod test {
    use super::*;
    use crate::{Client, ContentType};
    use async_compatibility_layer::logging::{setup_backtrace, setup_logging};
    use async_std::task::spawn;
    use futures::FutureExt;
    use portpicker::pick_unused_port;
    use tide_disco::{error::ServerError, App};
    use toml::toml;
    use vbs::version::StaticVersion;

    type Ver01 = StaticVersion<0, 1>;
    const VER_0_1: Ver01 = StaticVersion {};

    /// The server must answer in the format named by the client's configured content type:
    /// JSON for a JSON client, the binary format for a binary client.
    #[async_std::test]
    async fn test_request_accept() {
        setup_logging();
        setup_backtrace();

        // One GET route that always answers with the string "response".
        let mut server: App<(), ServerError> = App::with_state(());
        let routes = toml! {
            [route.get]
            PATH = ["/get"]
        };
        server
            .module::<ServerError, Ver01>("mod", routes)
            .unwrap()
            .get("get", |_req, _state| async move { Ok("response") }.boxed())
            .unwrap();
        let port = pick_unused_port().unwrap();
        spawn(server.serve(format!("0.0.0.0:{port}"), VER_0_1));
        let base_url = format!("http://localhost:{port}");

        let json_client = Client::<ServerError, Ver01>::builder(base_url.parse().unwrap())
            .content_type(ContentType::Json)
            .build();
        assert!(json_client.connect(None).await);

        let bin_client = Client::<ServerError, Ver01>::builder(base_url.parse().unwrap())
            .content_type(ContentType::Binary)
            .build();
        assert!(bin_client.connect(None).await);

        // Go through the raw reqwest builder so the response headers can be inspected.
        let json_response = json_client
            .get::<String>("mod/get")
            .inner
            .send()
            .await
            .unwrap();
        assert_eq!(json_response.status(), StatusCode::OK);
        assert_eq!(json_response.headers()["Content-Type"], "application/json");
        assert_eq!(json_response.json::<String>().await.unwrap(), "response");

        let bin_response = bin_client
            .get::<String>("mod/get")
            .inner
            .send()
            .await
            .unwrap();
        assert_eq!(bin_response.status(), StatusCode::OK);
        assert_eq!(
            bin_response.headers()["Content-Type"],
            "application/octet-stream"
        );
        let raw_body = bin_response.bytes().await.unwrap();
        assert_eq!(
            Serializer::<Ver01>::deserialize::<String>(&raw_body).unwrap(),
            "response"
        );
    }

    /// `Request::bytes` must report an error when the server answers with an error status.
    #[async_std::test]
    async fn test_bad_response_bytes() {
        setup_logging();
        setup_backtrace();

        // One GET route that always fails with 404 Not Found.
        let mut server: App<(), ServerError> = App::with_state(());
        let routes = toml! {
            [route.integer]
            PATH = ["/integer"]
        };
        server
            .module::<ServerError, Ver01>("app", routes)
            .unwrap()
            .get::<_, u64>("integer", |_req, _state| {
                async move {
                    Err(ServerError::catch_all(
                        StatusCode::NOT_FOUND,
                        "not found".to_string(),
                    ))
                }
                .boxed()
            })
            .unwrap();
        let port = pick_unused_port().unwrap();
        spawn(server.serve(format!("0.0.0.0:{port}"), VER_0_1));
        let base_url = format!("http://localhost:{port}");

        let client = Client::<ServerError, Ver01>::builder(base_url.parse().unwrap()).build();
        assert!(client.connect(None).await);

        let outcome = client.get::<()>("app/integer").bytes().await;
        assert!(
            outcome.is_err(),
            "Expected error from .bytes() due to server-side failure"
        );
    }
}