1use crate::{
8 http, request::reqwest_error_msg, Error, Method, Request, SocketRequest, StatusCode, Url,
9};
10use async_std::task::sleep;
11use async_tungstenite::tungstenite::protocol::WebSocketConfig;
12use derivative::Derivative;
13use serde::de::DeserializeOwned;
14use std::time::{Duration, Instant};
15use vbs::version::StaticVersionType;
16
17pub use tide_disco::healthcheck::{HealthCheck, HealthStatus};
18
/// Content type a client asks the server to respond with, sent in the `Accept` header.
///
/// Clients created through the builder default to [`ContentType::Binary`].
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum ContentType {
    /// JSON responses (maps to `http::mime::JSON`).
    Json,
    /// Binary responses (maps to `http::mime::BYTE_STREAM`).
    Binary,
}
25
26impl From<ContentType> for http::Mime {
27 fn from(c: ContentType) -> http::Mime {
28 match c {
29 ContentType::Json => http::mime::JSON,
30 ContentType::Binary => http::mime::BYTE_STREAM,
31 }
32 }
33}
34
35#[derive(Derivative)]
37#[derivative(Clone(bound = ""), Debug(bound = ""))]
38pub struct Client<E, VER: StaticVersionType> {
39 inner: reqwest::Client,
40 base_url: Url,
41 accept: ContentType,
42 _marker: std::marker::PhantomData<fn(E, VER) -> ()>,
43}
44
45impl<E: Error, VER: StaticVersionType> Client<E, VER> {
46 pub fn new(base_url: Url) -> Self {
48 Self::builder(base_url).build()
49 }
50
51 pub fn builder(base_url: Url) -> ClientBuilder<E, VER> {
53 ClientBuilder::<E, VER>::new(base_url)
54 }
55
56 pub async fn connect(&self, timeout: Option<Duration>) -> bool {
69 let timeout = timeout.map(|d| Instant::now() + d);
70 while timeout.map(|t| Instant::now() < t).unwrap_or(true) {
71 match self
72 .inner
73 .get(self.base_url.join("/healthcheck").unwrap())
74 .send()
75 .await
76 {
77 Ok(res) if res.status() == StatusCode::OK => return true,
78 Ok(res) => {
79 tracing::info!(
80 url = %self.base_url,
81 status = %res.status(),
82 "waiting for server to become ready",
83 );
84 }
85 Err(err) => {
86 tracing::info!(
87 url = %self.base_url,
88 err = reqwest_error_msg(err),
89 "waiting for server to become ready",
90 );
91 }
92 }
93 sleep(Duration::from_secs(10)).await;
94 }
95 false
96 }
97
98 pub async fn wait_for_health<H: DeserializeOwned + HealthCheck>(
106 &self,
107 healthy: impl Fn(&H) -> bool,
108 timeout: Option<Duration>,
109 ) -> Option<H> {
110 let timeout = timeout.map(|d| Instant::now() + d);
111 while timeout.map(|t| Instant::now() < t).unwrap_or(true) {
112 match self.healthcheck::<H>().await {
113 Ok(health) if healthy(&health) => return Some(health),
114 _ => sleep(Duration::from_secs(10)).await,
115 }
116 }
117 None
118 }
119
120 pub fn get<T: DeserializeOwned>(&self, route: &str) -> Request<T, E, VER> {
122 self.request(Method::Get, route)
123 }
124
125 pub fn post<T: DeserializeOwned>(&self, route: &str) -> Request<T, E, VER> {
127 self.request(Method::Post, route)
128 }
129
130 pub async fn healthcheck<H: DeserializeOwned + HealthCheck>(&self) -> Result<H, E> {
132 self.get("healthcheck").send().await
133 }
134
135 pub fn request<T: DeserializeOwned>(&self, method: Method, route: &str) -> Request<T, E, VER> {
137 Request::from(self.inner.request(
138 method.to_string().parse().unwrap(),
139 self.base_url.join(route).unwrap(),
140 ))
141 .header("Accept", http::Mime::from(self.accept).to_string())
142 }
143
144 pub fn socket(&self, route: &str) -> SocketRequest<E, VER> {
150 SocketRequest::new(self.base_url.join(route).unwrap(), self.accept, None)
151 .header("Accept", http::Mime::from(self.accept).to_string())
152 }
153
154 pub fn socket_with_config(
160 &self,
161 route: &str,
162 config: WebSocketConfig,
163 ) -> SocketRequest<E, VER> {
164 SocketRequest::new(
165 self.base_url.join(route).unwrap(),
166 self.accept,
167 Some(config),
168 )
169 .header("Accept", http::Mime::from(self.accept).to_string())
170 }
171
172 pub fn module<ModError: Error>(
174 &self,
175 prefix: &str,
176 ) -> Result<Client<ModError, VER>, http::url::ParseError> {
177 Ok(Client {
178 inner: self.inner.clone(),
179 base_url: self.base_url.join(prefix)?,
180 accept: self.accept,
181 _marker: Default::default(),
182 })
183 }
184
185 pub fn base_url(&self) -> Url {
186 self.base_url.clone()
187 }
188}
189
190pub struct ClientBuilder<E: Error, VER: StaticVersionType> {
192 inner: reqwest::ClientBuilder,
193 accept: ContentType,
194 base_url: Url,
195 timeout: Option<Duration>,
196 _marker: std::marker::PhantomData<fn(E, VER) -> ()>,
197}
198
199impl<E: Error, VER: StaticVersionType> ClientBuilder<E, VER> {
200 fn new(mut base_url: Url) -> Self {
201 if !base_url.path().ends_with('/') {
206 base_url.set_path(&format!("{}/", base_url.path()));
207 }
208 Self {
209 inner: reqwest::Client::builder(),
210 accept: ContentType::Binary,
211 base_url,
212 timeout: Some(Duration::from_secs(60)),
213 _marker: Default::default(),
214 }
215 }
216
217 pub fn set_timeout(mut self, timeout: Option<Duration>) -> Self {
223 self.timeout = timeout;
224 self
225 }
226
227 pub fn content_type(mut self, content_type: ContentType) -> Self {
229 self.accept = content_type;
230 self
231 }
232
233 pub fn build(self) -> Client<E, VER> {
235 let mut builder = self.inner;
236
237 if let Some(timeout) = self.timeout {
238 builder = builder.timeout(timeout);
239 }
240
241 Client {
242 inner: builder.build().unwrap(),
243 base_url: self.base_url,
244 accept: self.accept,
245 _marker: Default::default(),
246 }
247 }
248}
249
250impl<E: Error, VER: StaticVersionType> From<ClientBuilder<E, VER>> for Client<E, VER> {
251 fn from(builder: ClientBuilder<E, VER>) -> Self {
252 builder.build()
253 }
254}
255
// Integration tests: each test starts a real tide-disco server on a free local port and
// talks to it through `Client`.
#[cfg(test)]
mod test {
    use crate::socket::Connection;

    use super::*;
    use async_compatibility_layer::logging::{setup_backtrace, setup_logging};
    use async_std::{sync::RwLock, task::spawn};
    use futures::{stream::iter, FutureExt, SinkExt, StreamExt};
    use portpicker::pick_unused_port;
    use serde::{Deserialize, Serialize};
    use tide_disco::{error::ServerError, App};
    use toml::toml;
    use vbs::version::StaticVersion;
    // Version type shared by the test servers and clients.
    type Ver01 = StaticVersion<0, 1>;
    const VER_0_1: Ver01 = StaticVersion {};

    /// Serve a module `mod` with a GET route and a POST route, then exercise both through a
    /// client that requests `accept` responses, including the POST error path.
    async fn test_basic_http_client(accept: ContentType) {
        setup_logging();
        setup_backtrace();

        let mut app: App<(), ServerError> = App::with_state(());
        let api = toml! {
            [route.get]
            PATH = ["/get"]
            METHOD = "GET"

            [route.post]
            PATH = ["/post"]
            METHOD = "POST"
        };
        app.module::<ServerError, Ver01>("mod", api)
            .unwrap()
            .get("get", |_req, _state| async move { Ok("response") }.boxed())
            .unwrap()
            // Succeeds only when the request body decodes to the string "body"; otherwise
            // responds 400 with the message "invalid body".
            .post("post", |req, _state| {
                async move {
                    if req.body_auto::<String, _>(VER_0_1).unwrap() == "body" {
                        Ok("response")
                    } else {
                        Err(ServerError::catch_all(
                            StatusCode::BAD_REQUEST,
                            "invalid body".into(),
                        ))
                    }
                }
                .boxed()
            })
            .unwrap();
        let port = pick_unused_port().unwrap();
        // Serve in the background; `connect` below waits until the server is up.
        spawn(app.serve(format!("0.0.0.0:{}", port), VER_0_1));

        let client = Client::<ServerError, Ver01>::builder(
            format!("http://localhost:{}", port).parse().unwrap(),
        )
        .content_type(accept)
        .build();
        assert!(client.connect(None).await);

        // Successful GET and POST.
        assert_eq!(
            client.get::<String>("mod/get").send().await.unwrap(),
            "response"
        );
        assert_eq!(
            client
                .post::<String>("mod/post")
                .body_json(&"body".to_string())
                .unwrap()
                .send()
                .await
                .unwrap(),
            "response"
        );

        // A rejected body must come back with the server's status code and message.
        let err = client
            .post::<String>("mod/post")
            .body_json(&"bad".to_string())
            .unwrap()
            .send()
            .await
            .unwrap_err();
        if err.status != StatusCode::BAD_REQUEST || err.message != "invalid body" {
            panic!("unexpected error {}", err);
        }
    }

    #[async_std::test]
    async fn test_basic_http_client_json() {
        test_basic_http_client(ContentType::Json).await
    }

    #[async_std::test]
    async fn test_basic_http_client_binary() {
        test_basic_http_client(ContentType::Binary).await
    }

    /// Serve an echo socket and a `naturals` stream in module `mod`, then exercise both
    /// through a client that requests `accept` responses.
    async fn test_streaming_client(accept: ContentType) {
        setup_logging();
        setup_backtrace();

        let mut app: App<(), ServerError> = App::with_state(());
        let api = toml! {
            [route.echo]
            PATH = ["/echo"]
            METHOD = "SOCKET"

            [route.naturals]
            PATH = ["/naturals/:max"]
            METHOD = "SOCKET"
            ":max" = "Integer"
        };
        app.module::<ServerError, Ver01>("mod", api)
            .unwrap()
            // Sends every received message straight back until the stream ends or errors.
            .socket::<_, String, String>("echo", |_req, mut conn, _state| {
                async move {
                    while let Some(Ok(msg)) = conn.next().await {
                        conn.send(&msg).await.unwrap();
                    }
                    Ok(())
                }
                .boxed()
            })
            .unwrap()
            // Streams the integers 0..max, then closes.
            .stream("naturals", |req, _state| {
                iter(0u64..req.integer_param("max").unwrap())
                    .map(Ok)
                    .boxed()
            })
            .unwrap();
        let port = pick_unused_port().unwrap();
        spawn(app.serve(format!("0.0.0.0:{}", port), VER_0_1));

        // The version parameter is left to inference; it is fixed by the `Connection`
        // annotation below.
        let client: Client<ServerError, _> =
            Client::builder(format!("http://localhost:{}", port).parse().unwrap())
                .content_type(accept)
                .build();
        assert!(client.connect(None).await);

        // Bidirectional socket: each message sent is echoed back.
        let mut conn: Connection<_, _, _, Ver01> = client
            .socket("mod/echo")
            .connect::<String, String>()
            .await
            .unwrap();
        conn.send(&"foo".into()).await.unwrap();
        assert_eq!(conn.next().await.unwrap().unwrap(), "foo");
        conn.send(&"bar".into()).await.unwrap();
        assert_eq!(conn.next().await.unwrap().unwrap(), "bar");

        // Server-push stream: subscribing yields exactly 0..10.
        assert_eq!(
            client
                .socket("mod/naturals/10")
                .subscribe::<u64>()
                .await
                .unwrap()
                .collect::<Vec<_>>()
                .await,
            (0..10).map(Ok).collect::<Vec<_>>()
        );
    }

    #[async_std::test]
    async fn test_streaming_client_json() {
        test_streaming_client(ContentType::Json).await
    }

    #[async_std::test]
    async fn test_streaming_client_binary() {
        test_streaming_client(ContentType::Binary).await
    }

    // Test-local health status. It shadows the imported `HealthCheck` trait, which is why
    // the trait is named via `super::` in the impl below.
    #[derive(Clone, Copy, Debug, PartialEq, Eq, Deserialize, Serialize)]
    enum HealthCheck {
        Ready,
        Initializing,
    }

    impl super::HealthCheck for HealthCheck {
        // Always reports OK; the tests distinguish states by the body, not the status.
        fn status(&self) -> StatusCode {
            StatusCode::OK
        }
    }

    #[async_std::test]
    async fn test_healthcheck() {
        setup_logging();
        setup_backtrace();

        // The module's health check reports the current state; POST `mod/init` sets it
        // to `Ready`.
        let mut app: App<_, ServerError> = App::with_state(RwLock::new(HealthCheck::Initializing));
        let api = toml! {
            [route.init]
            PATH = ["/init"]
            METHOD = "POST"
        };
        app.module::<ServerError, Ver01>("mod", api)
            .unwrap()
            .with_health_check(|state| async move { *state.read().await }.boxed())
            .post("init", |_, state| {
                async move {
                    *state = HealthCheck::Ready;
                    Ok(())
                }
                .boxed()
            })
            .unwrap();
        let port = pick_unused_port().unwrap();
        spawn(app.serve(format!("0.0.0.0:{}", port), VER_0_1));

        // Base URL points at the module, so `healthcheck()` and relative routes are
        // module-scoped, while `connect` still probes the server-root `/healthcheck`.
        let client = Client::<ServerError, Ver01>::new(
            format!("http://localhost:{}/mod", port).parse().unwrap(),
        );
        assert!(client.connect(None).await);
        assert_eq!(
            HealthCheck::Initializing,
            client.healthcheck().await.unwrap()
        );

        // Still initializing: waiting with a short timeout gives up and returns `None`.
        assert_eq!(
            client
                .wait_for_health::<HealthCheck>(
                    |h| *h == HealthCheck::Ready,
                    Some(Duration::from_secs(1))
                )
                .await,
            None
        );

        client.post::<()>("init").send().await.unwrap();

        // After `init`, waiting without a timeout observes `Ready`.
        assert_eq!(
            client
                .wait_for_health::<HealthCheck>(|h| *h == HealthCheck::Ready, None)
                .await,
            Some(HealthCheck::Ready)
        );
        assert_eq!(HealthCheck::Ready, client.healthcheck().await.unwrap());
    }

    // Building without a timeout succeeds, and the base URL survives the builder unchanged.
    #[test]
    fn test_builder() {
        let client =
            Client::<ServerError, Ver01>::builder("http://www.example.com".parse().unwrap())
                .set_timeout(None)
                .build();
        assert_eq!(client.base_url, "http://www.example.com".parse().unwrap());
    }
}