1use crate::{
8 http, request::reqwest_error_msg, Error, Method, Request, SocketRequest, StatusCode, Url,
9};
10use async_std::task::sleep;
11use derivative::Derivative;
12use serde::de::DeserializeOwned;
13use std::time::{Duration, Instant};
14use vbs::version::StaticVersionType;
15
16pub use tide_disco::healthcheck::{HealthCheck, HealthStatus};
17
/// Serialization format a [`Client`] asks the server to use for response bodies.
///
/// The client converts the selected variant into an `http::Mime` and sends it in the
/// `Accept` header of every HTTP request and socket request it builds.
///
/// Besides `Clone`/`Copy`/`Debug`, the type derives `PartialEq`, `Eq` and `Hash` so callers
/// can compare configurations or use them as map/set keys.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum ContentType {
    /// Converted to `http::mime::JSON`.
    Json,
    /// Converted to `http::mime::BYTE_STREAM`.
    Binary,
}
24
25impl From<ContentType> for http::Mime {
26 fn from(c: ContentType) -> http::Mime {
27 match c {
28 ContentType::Json => http::mime::JSON,
29 ContentType::Binary => http::mime::BYTE_STREAM,
30 }
31 }
32}
33
34#[derive(Derivative)]
36#[derivative(Clone(bound = ""), Debug(bound = ""))]
37pub struct Client<E, VER: StaticVersionType> {
38 inner: reqwest::Client,
39 base_url: Url,
40 accept: ContentType,
41 _marker: std::marker::PhantomData<fn(E, VER) -> ()>,
42}
43
44impl<E: Error, VER: StaticVersionType> Client<E, VER> {
45 pub fn new(base_url: Url) -> Self {
47 Self::builder(base_url).build()
48 }
49
50 pub fn builder(base_url: Url) -> ClientBuilder<E, VER> {
52 ClientBuilder::<E, VER>::new(base_url)
53 }
54
55 pub async fn connect(&self, timeout: Option<Duration>) -> bool {
68 let timeout = timeout.map(|d| Instant::now() + d);
69 while timeout.map(|t| Instant::now() < t).unwrap_or(true) {
70 match self
71 .inner
72 .get(self.base_url.join("/healthcheck").unwrap())
73 .send()
74 .await
75 {
76 Ok(res) if res.status() == StatusCode::OK => return true,
77 Ok(res) => {
78 tracing::info!(
79 url = %self.base_url,
80 status = %res.status(),
81 "waiting for server to become ready",
82 );
83 }
84 Err(err) => {
85 tracing::info!(
86 url = %self.base_url,
87 err = reqwest_error_msg(err),
88 "waiting for server to become ready",
89 );
90 }
91 }
92 sleep(Duration::from_secs(10)).await;
93 }
94 false
95 }
96
97 pub async fn wait_for_health<H: DeserializeOwned + HealthCheck>(
105 &self,
106 healthy: impl Fn(&H) -> bool,
107 timeout: Option<Duration>,
108 ) -> Option<H> {
109 let timeout = timeout.map(|d| Instant::now() + d);
110 while timeout.map(|t| Instant::now() < t).unwrap_or(true) {
111 match self.healthcheck::<H>().await {
112 Ok(health) if healthy(&health) => return Some(health),
113 _ => sleep(Duration::from_secs(10)).await,
114 }
115 }
116 None
117 }
118
119 pub fn get<T: DeserializeOwned>(&self, route: &str) -> Request<T, E, VER> {
121 self.request(Method::Get, route)
122 }
123
124 pub fn post<T: DeserializeOwned>(&self, route: &str) -> Request<T, E, VER> {
126 self.request(Method::Post, route)
127 }
128
129 pub async fn healthcheck<H: DeserializeOwned + HealthCheck>(&self) -> Result<H, E> {
131 self.get("healthcheck").send().await
132 }
133
134 pub fn request<T: DeserializeOwned>(&self, method: Method, route: &str) -> Request<T, E, VER> {
136 Request::from(self.inner.request(
137 method.to_string().parse().unwrap(),
138 self.base_url.join(route).unwrap(),
139 ))
140 .header("Accept", http::Mime::from(self.accept).to_string())
141 }
142
143 pub fn socket(&self, route: &str) -> SocketRequest<E, VER> {
149 SocketRequest::new(self.base_url.join(route).unwrap(), self.accept)
150 .header("Accept", http::Mime::from(self.accept).to_string())
151 }
152
153 pub fn module<ModError: Error>(
155 &self,
156 prefix: &str,
157 ) -> Result<Client<ModError, VER>, http::url::ParseError> {
158 Ok(Client {
159 inner: self.inner.clone(),
160 base_url: self.base_url.join(prefix)?,
161 accept: self.accept,
162 _marker: Default::default(),
163 })
164 }
165}
166
167pub struct ClientBuilder<E: Error, VER: StaticVersionType> {
169 inner: reqwest::ClientBuilder,
170 accept: ContentType,
171 base_url: Url,
172 timeout: Option<Duration>,
173 _marker: std::marker::PhantomData<fn(E, VER) -> ()>,
174}
175
176impl<E: Error, VER: StaticVersionType> ClientBuilder<E, VER> {
177 fn new(mut base_url: Url) -> Self {
178 if !base_url.path().ends_with('/') {
183 base_url.set_path(&format!("{}/", base_url.path()));
184 }
185 Self {
186 inner: reqwest::Client::builder(),
187 accept: ContentType::Binary,
188 base_url,
189 timeout: Some(Duration::from_secs(60)),
190 _marker: Default::default(),
191 }
192 }
193
194 pub fn set_timeout(mut self, timeout: Option<Duration>) -> Self {
200 self.timeout = timeout;
201 self
202 }
203
204 pub fn content_type(mut self, content_type: ContentType) -> Self {
206 self.accept = content_type;
207 self
208 }
209
210 pub fn build(self) -> Client<E, VER> {
212 let mut builder = self.inner;
213
214 if let Some(timeout) = self.timeout {
215 builder = builder.timeout(timeout);
216 }
217
218 Client {
219 inner: builder.build().unwrap(),
220 base_url: self.base_url,
221 accept: self.accept,
222 _marker: Default::default(),
223 }
224 }
225}
226
227impl<E: Error, VER: StaticVersionType> From<ClientBuilder<E, VER>> for Client<E, VER> {
228 fn from(builder: ClientBuilder<E, VER>) -> Self {
229 builder.build()
230 }
231}
232
#[cfg(test)]
mod test {
    //! Integration tests: each test starts a real `tide_disco` app on an unused local port
    //! and drives it through `Client`.

    use crate::socket::Connection;

    use super::*;
    use async_compatibility_layer::logging::{setup_backtrace, setup_logging};
    use async_std::{sync::RwLock, task::spawn};
    use futures::{stream::iter, FutureExt, SinkExt, StreamExt};
    use portpicker::pick_unused_port;
    use serde::{Deserialize, Serialize};
    use tide_disco::{error::ServerError, App};
    use toml::toml;
    use vbs::version::StaticVersion;
    /// API version shared by the test server and clients.
    type Ver01 = StaticVersion<0, 1>;
    const VER_0_1: Ver01 = StaticVersion {};

    /// GET and POST round trips (success and error) using the given `Accept` format.
    async fn test_basic_http_client(accept: ContentType) {
        setup_logging();
        setup_backtrace();

        let mut app: App<(), ServerError> = App::with_state(());
        // Module `mod` exposes GET /mod/get and POST /mod/post.
        let api = toml! {
            [route.get]
            PATH = ["/get"]
            METHOD = "GET"

            [route.post]
            PATH = ["/post"]
            METHOD = "POST"
        };
        app.module::<ServerError, Ver01>("mod", api)
            .unwrap()
            .get("get", |_req, _state| async move { Ok("response") }.boxed())
            .unwrap()
            // Accepts only the body "body"; anything else is a 400 "invalid body".
            .post("post", |req, _state| {
                async move {
                    if req.body_auto::<String, _>(VER_0_1).unwrap() == "body" {
                        Ok("response")
                    } else {
                        Err(ServerError::catch_all(
                            StatusCode::BAD_REQUEST,
                            "invalid body".into(),
                        ))
                    }
                }
                .boxed()
            })
            .unwrap();
        let port = pick_unused_port().unwrap();
        spawn(app.serve(format!("0.0.0.0:{}", port), VER_0_1));

        let client = Client::<ServerError, Ver01>::builder(
            // Client is rooted at the server root; routes below include the `mod/` prefix.
            format!("http://localhost:{}", port).parse().unwrap(),
        )
        .content_type(accept)
        .build();
        // Block until the freshly spawned server answers its healthcheck.
        assert!(client.connect(None).await);

        assert_eq!(
            // Plain GET.
            client.get::<String>("mod/get").send().await.unwrap(),
            "response"
        );
        assert_eq!(
            client
                .post::<String>("mod/post")
                .body_json(&"body".to_string())
                .unwrap()
                .send()
                .await
                .unwrap(),
            "response"
        );

        let err = client
            // A rejected body must surface the server's status and message in the error.
            .post::<String>("mod/post")
            .body_json(&"bad".to_string())
            .unwrap()
            .send()
            .await
            .unwrap_err();
        if err.status != StatusCode::BAD_REQUEST || err.message != "invalid body" {
            panic!("unexpected error {}", err);
        }
    }

    #[async_std::test]
    async fn test_basic_http_client_json() {
        test_basic_http_client(ContentType::Json).await
    }

    #[async_std::test]
    async fn test_basic_http_client_binary() {
        test_basic_http_client(ContentType::Binary).await
    }

    /// Socket round trips (bidirectional echo and a server-side stream) using the given
    /// `Accept` format.
    async fn test_streaming_client(accept: ContentType) {
        setup_logging();
        setup_backtrace();

        let mut app: App<(), ServerError> = App::with_state(());
        // `echo` sends back every message it receives; `naturals/:max` streams 0..max.
        let api = toml! {
            [route.echo]
            PATH = ["/echo"]
            METHOD = "SOCKET"

            [route.naturals]
            PATH = ["/naturals/:max"]
            METHOD = "SOCKET"
            ":max" = "Integer"
        };
        app.module::<ServerError, Ver01>("mod", api)
            .unwrap()
            .socket::<_, String, String>("echo", |_req, mut conn, _state| {
                async move {
                    while let Some(Ok(msg)) = conn.next().await {
                        conn.send(&msg).await.unwrap();
                    }
                    Ok(())
                }
                .boxed()
            })
            .unwrap()
            .stream("naturals", |req, _state| {
                iter(0u64..req.integer_param("max").unwrap())
                    .map(Ok)
                    .boxed()
            })
            .unwrap();
        let port = pick_unused_port().unwrap();
        spawn(app.serve(format!("0.0.0.0:{}", port), VER_0_1));

        // `VER` is left to inference and fixed by the `Connection` annotation below.
        let client: Client<ServerError, _> =
            Client::builder(format!("http://localhost:{}", port).parse().unwrap())
                .content_type(accept)
                .build();
        assert!(client.connect(None).await);

        let mut conn: Connection<_, _, _, Ver01> = client
            // Two-way socket: each message sent should be echoed back unchanged.
            .socket("mod/echo")
            .connect::<String, String>()
            .await
            .unwrap();
        conn.send(&"foo".into()).await.unwrap();
        assert_eq!(conn.next().await.unwrap().unwrap(), "foo");
        conn.send(&"bar".into()).await.unwrap();
        assert_eq!(conn.next().await.unwrap().unwrap(), "bar");

        assert_eq!(
            // Subscription: the whole server stream is collected until it ends.
            client
                .socket("mod/naturals/10")
                .subscribe::<u64>()
                .await
                .unwrap()
                .collect::<Vec<_>>()
                .await,
            (0..10).map(Ok).collect::<Vec<_>>()
        );
    }

    #[async_std::test]
    async fn test_streaming_client_json() {
        test_streaming_client(ContentType::Json).await
    }

    #[async_std::test]
    async fn test_streaming_client_binary() {
        test_streaming_client(ContentType::Binary).await
    }

    /// Test health status. This local enum shadows the glob-imported `HealthCheck` trait,
    /// which is therefore referred to as `super::HealthCheck` below.
    #[derive(Clone, Copy, Debug, PartialEq, Eq, Deserialize, Serialize)]
    enum HealthCheck {
        Ready,
        Initializing,
    }

    impl super::HealthCheck for HealthCheck {
        // Always 200 OK, so the status variant is carried only in the body.
        fn status(&self) -> StatusCode {
            StatusCode::OK
        }
    }

    #[async_std::test]
    async fn test_healthcheck() {
        setup_logging();
        setup_backtrace();

        let mut app: App<_, ServerError> = App::with_state(RwLock::new(HealthCheck::Initializing));
        // POST /mod/init flips the shared state to `Ready`.
        let api = toml! {
            [route.init]
            PATH = ["/init"]
            METHOD = "POST"
        };
        app.module::<ServerError, Ver01>("mod", api)
            .unwrap()
            // The module's health check reports the current value of the shared state.
            .with_health_check(|state| async move { *state.read().await }.boxed())
            .post("init", |_, state| {
                async move {
                    *state = HealthCheck::Ready;
                    Ok(())
                }
                .boxed()
            })
            .unwrap();
        let port = pick_unused_port().unwrap();
        spawn(app.serve(format!("0.0.0.0:{}", port), VER_0_1));

        let client = Client::<ServerError, Ver01>::new(
            // Rooted at `/mod`: `healthcheck()` and `post("init")` resolve beneath it,
            // while `connect` always probes the absolute `/healthcheck` path.
            format!("http://localhost:{}/mod", port).parse().unwrap(),
        );
        assert!(client.connect(None).await);
        assert_eq!(
            HealthCheck::Initializing,
            client.healthcheck().await.unwrap()
        );

        assert_eq!(
            // Still initializing, so a bounded wait must give up and return `None`.
            client
                .wait_for_health::<HealthCheck>(
                    |h| *h == HealthCheck::Ready,
                    Some(Duration::from_secs(1))
                )
                .await,
            None
        );

        client.post::<()>("init").send().await.unwrap();
        // From now on the module reports `Ready`.

        assert_eq!(
            // An unbounded wait should now succeed on the first poll.
            client
                .wait_for_health::<HealthCheck>(|h| *h == HealthCheck::Ready, None)
                .await,
            Some(HealthCheck::Ready)
        );
        assert_eq!(HealthCheck::Ready, client.healthcheck().await.unwrap());
    }

    /// Building with a root-only URL keeps the URL as parsed. The `url` crate presumably
    /// already gives an http URL the path `/`, so the trailing-slash normalization in
    /// `ClientBuilder::new` is a no-op here.
    #[test]
    fn test_builder() {
        let client =
            Client::<ServerError, Ver01>::builder("http://www.example.com".parse().unwrap())
                .set_timeout(None)
                .build();
        assert_eq!(client.base_url, "http://www.example.com".parse().unwrap());
    }
}