1use crate::{
8 http, request::reqwest_error_msg, Error, Method, Request, SocketRequest, StatusCode, Url,
9};
10use async_std::task::sleep;
11use derivative::Derivative;
12use serde::de::DeserializeOwned;
13use std::time::{Duration, Instant};
14use vbs::version::StaticVersionType;
15
16pub use tide_disco::healthcheck::{HealthCheck, HealthStatus};
17
/// Serialization format the client requests for response bodies.
///
/// The chosen variant is converted to a MIME type (see the `From<ContentType>` impl for
/// `http::Mime`) and sent as the `Accept` header of requests built by a client.
///
/// The type is a plain, field-less `Copy` enum, so it also implements `PartialEq` and `Eq`,
/// allowing callers to compare and assert on the configured format.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ContentType {
    /// JSON bodies (`http::mime::JSON`).
    Json,
    /// Opaque binary bodies (`http::mime::BYTE_STREAM`).
    Binary,
}
24
25impl From<ContentType> for http::Mime {
26 fn from(c: ContentType) -> http::Mime {
27 match c {
28 ContentType::Json => http::mime::JSON,
29 ContentType::Binary => http::mime::BYTE_STREAM,
30 }
31 }
32}
33
34#[derive(Derivative)]
36#[derivative(Clone(bound = ""), Debug(bound = ""))]
37pub struct Client<E, VER: StaticVersionType> {
38 inner: reqwest::Client,
39 base_url: Url,
40 accept: ContentType,
41 _marker: std::marker::PhantomData<fn(E, VER) -> ()>,
42}
43
44impl<E: Error, VER: StaticVersionType> Client<E, VER> {
45 pub fn new(base_url: Url) -> Self {
47 Self::builder(base_url).build()
48 }
49
50 pub fn builder(base_url: Url) -> ClientBuilder<E, VER> {
52 ClientBuilder::<E, VER>::new(base_url)
53 }
54
55 pub async fn connect(&self, timeout: Option<Duration>) -> bool {
68 let timeout = timeout.map(|d| Instant::now() + d);
69 while timeout.map(|t| Instant::now() < t).unwrap_or(true) {
70 match self
71 .inner
72 .get(self.base_url.join("/healthcheck").unwrap())
73 .send()
74 .await
75 {
76 Ok(res) if res.status() == StatusCode::OK => return true,
77 Ok(res) => {
78 tracing::info!(
79 url = %self.base_url,
80 status = %res.status(),
81 "waiting for server to become ready",
82 );
83 }
84 Err(err) => {
85 tracing::info!(
86 url = %self.base_url,
87 err = reqwest_error_msg(err),
88 "waiting for server to become ready",
89 );
90 }
91 }
92 sleep(Duration::from_secs(10)).await;
93 }
94 false
95 }
96
97 pub async fn wait_for_health<H: DeserializeOwned + HealthCheck>(
105 &self,
106 healthy: impl Fn(&H) -> bool,
107 timeout: Option<Duration>,
108 ) -> Option<H> {
109 let timeout = timeout.map(|d| Instant::now() + d);
110 while timeout.map(|t| Instant::now() < t).unwrap_or(true) {
111 match self.healthcheck::<H>().await {
112 Ok(health) if healthy(&health) => return Some(health),
113 _ => sleep(Duration::from_secs(10)).await,
114 }
115 }
116 None
117 }
118
119 pub fn get<T: DeserializeOwned>(&self, route: &str) -> Request<T, E, VER> {
121 self.request(Method::Get, route)
122 }
123
124 pub fn post<T: DeserializeOwned>(&self, route: &str) -> Request<T, E, VER> {
126 self.request(Method::Post, route)
127 }
128
129 pub async fn healthcheck<H: DeserializeOwned + HealthCheck>(&self) -> Result<H, E> {
131 self.get("healthcheck").send().await
132 }
133
134 pub fn request<T: DeserializeOwned>(&self, method: Method, route: &str) -> Request<T, E, VER> {
136 Request::from(self.inner.request(
137 method.to_string().parse().unwrap(),
138 self.base_url.join(route).unwrap(),
139 ))
140 .header("Accept", http::Mime::from(self.accept).to_string())
141 }
142
143 pub fn socket(&self, route: &str) -> SocketRequest<E, VER> {
149 SocketRequest::new(self.base_url.join(route).unwrap(), self.accept)
150 .header("Accept", http::Mime::from(self.accept).to_string())
151 }
152
153 pub fn module<ModError: Error>(
155 &self,
156 prefix: &str,
157 ) -> Result<Client<ModError, VER>, http::url::ParseError> {
158 Ok(Client {
159 inner: self.inner.clone(),
160 base_url: self.base_url.join(prefix)?,
161 accept: self.accept,
162 _marker: Default::default(),
163 })
164 }
165
166 pub fn base_url(&self) -> Url {
167 self.base_url.clone()
168 }
169}
170
171pub struct ClientBuilder<E: Error, VER: StaticVersionType> {
173 inner: reqwest::ClientBuilder,
174 accept: ContentType,
175 base_url: Url,
176 timeout: Option<Duration>,
177 _marker: std::marker::PhantomData<fn(E, VER) -> ()>,
178}
179
180impl<E: Error, VER: StaticVersionType> ClientBuilder<E, VER> {
181 fn new(mut base_url: Url) -> Self {
182 if !base_url.path().ends_with('/') {
187 base_url.set_path(&format!("{}/", base_url.path()));
188 }
189 Self {
190 inner: reqwest::Client::builder(),
191 accept: ContentType::Binary,
192 base_url,
193 timeout: Some(Duration::from_secs(60)),
194 _marker: Default::default(),
195 }
196 }
197
198 pub fn set_timeout(mut self, timeout: Option<Duration>) -> Self {
204 self.timeout = timeout;
205 self
206 }
207
208 pub fn content_type(mut self, content_type: ContentType) -> Self {
210 self.accept = content_type;
211 self
212 }
213
214 pub fn build(self) -> Client<E, VER> {
216 let mut builder = self.inner;
217
218 if let Some(timeout) = self.timeout {
219 builder = builder.timeout(timeout);
220 }
221
222 Client {
223 inner: builder.build().unwrap(),
224 base_url: self.base_url,
225 accept: self.accept,
226 _marker: Default::default(),
227 }
228 }
229}
230
231impl<E: Error, VER: StaticVersionType> From<ClientBuilder<E, VER>> for Client<E, VER> {
232 fn from(builder: ClientBuilder<E, VER>) -> Self {
233 builder.build()
234 }
235}
236
#[cfg(test)]
mod test {
    use crate::socket::Connection;

    use super::*;
    use async_compatibility_layer::logging::{setup_backtrace, setup_logging};
    use async_std::{sync::RwLock, task::spawn};
    use futures::{stream::iter, FutureExt, SinkExt, StreamExt};
    use portpicker::pick_unused_port;
    use serde::{Deserialize, Serialize};
    use tide_disco::{error::ServerError, App};
    use toml::toml;
    use vbs::version::StaticVersion;

    /// Version type shared by the servers and clients in these tests.
    type Ver01 = StaticVersion<0, 1>;
    const VER_0_1: Ver01 = StaticVersion {};

    /// Round-trips a GET and a POST through a local server using the given content type, and
    /// checks that a server-side error is reported to the client with its status and message.
    async fn test_basic_http_client(accept: ContentType) {
        setup_logging();
        setup_backtrace();

        // Server with one module exposing a GET route and a POST route.
        let mut server: App<(), ServerError> = App::with_state(());
        let spec = toml! {
            [route.get]
            PATH = ["/get"]
            METHOD = "GET"

            [route.post]
            PATH = ["/post"]
            METHOD = "POST"
        };
        server
            .module::<ServerError, Ver01>("mod", spec)
            .unwrap()
            .get("get", |_req, _state| async move { Ok("response") }.boxed())
            .unwrap()
            .post("post", |req, _state| {
                async move {
                    if req.body_auto::<String, _>(VER_0_1).unwrap() == "body" {
                        Ok("response")
                    } else {
                        Err(ServerError::catch_all(
                            StatusCode::BAD_REQUEST,
                            "invalid body".into(),
                        ))
                    }
                }
                .boxed()
            })
            .unwrap();
        let port = pick_unused_port().unwrap();
        spawn(server.serve(format!("0.0.0.0:{}", port), VER_0_1));

        // Client pointed at the server root, waiting until the server accepts connections.
        let server_url: Url = format!("http://localhost:{}", port).parse().unwrap();
        let client = Client::<ServerError, Ver01>::builder(server_url)
            .content_type(accept)
            .build();
        assert!(client.connect(None).await);

        // Successful GET.
        let get_response = client.get::<String>("mod/get").send().await.unwrap();
        assert_eq!(get_response, "response");

        // Successful POST with the body the handler expects.
        let post_response = client
            .post::<String>("mod/post")
            .body_json(&"body".to_string())
            .unwrap()
            .send()
            .await
            .unwrap();
        assert_eq!(post_response, "response");

        // POST with an unexpected body must surface the handler's error.
        let failure = client
            .post::<String>("mod/post")
            .body_json(&"bad".to_string())
            .unwrap()
            .send()
            .await
            .unwrap_err();
        assert!(
            failure.status == StatusCode::BAD_REQUEST && failure.message == "invalid body",
            "unexpected error {}",
            failure
        );
    }

    #[async_std::test]
    async fn test_basic_http_client_json() {
        test_basic_http_client(ContentType::Json).await
    }

    #[async_std::test]
    async fn test_basic_http_client_binary() {
        test_basic_http_client(ContentType::Binary).await
    }

    /// Exercises a bidirectional echo socket and a server-to-client stream using the given
    /// content type.
    async fn test_streaming_client(accept: ContentType) {
        setup_logging();
        setup_backtrace();

        // Server with an echo socket and a stream of the naturals below `:max`.
        let mut server: App<(), ServerError> = App::with_state(());
        let spec = toml! {
            [route.echo]
            PATH = ["/echo"]
            METHOD = "SOCKET"

            [route.naturals]
            PATH = ["/naturals/:max"]
            METHOD = "SOCKET"
            ":max" = "Integer"
        };
        server
            .module::<ServerError, Ver01>("mod", spec)
            .unwrap()
            .socket::<_, String, String>("echo", |_req, mut conn, _state| {
                async move {
                    while let Some(Ok(msg)) = conn.next().await {
                        conn.send(&msg).await.unwrap();
                    }
                    Ok(())
                }
                .boxed()
            })
            .unwrap()
            .stream("naturals", |req, _state| {
                iter(0u64..req.integer_param("max").unwrap())
                    .map(Ok)
                    .boxed()
            })
            .unwrap();
        let port = pick_unused_port().unwrap();
        spawn(server.serve(format!("0.0.0.0:{}", port), VER_0_1));

        // Client pointed at the server root, waiting until the server accepts connections.
        let client: Client<ServerError, _> =
            Client::builder(format!("http://localhost:{}", port).parse().unwrap())
                .content_type(accept)
                .build();
        assert!(client.connect(None).await);

        // Every message sent over the echo socket must come straight back.
        let mut echo: Connection<_, _, _, Ver01> = client
            .socket("mod/echo")
            .connect::<String, String>()
            .await
            .unwrap();
        for message in ["foo", "bar"] {
            echo.send(&message.to_string()).await.unwrap();
            assert_eq!(echo.next().await.unwrap().unwrap(), message);
        }

        // The naturals stream must yield 0..10 in order and then end.
        let naturals = client
            .socket("mod/naturals/10")
            .subscribe::<u64>()
            .await
            .unwrap()
            .collect::<Vec<_>>()
            .await;
        assert_eq!(naturals, (0..10).map(Ok).collect::<Vec<_>>());
    }

    #[async_std::test]
    async fn test_streaming_client_json() {
        test_streaming_client(ContentType::Json).await
    }

    #[async_std::test]
    async fn test_streaming_client_binary() {
        test_streaming_client(ContentType::Binary).await
    }

    /// Health state reported by the test server; both variants report status `OK`.
    #[derive(Clone, Copy, Debug, PartialEq, Eq, Deserialize, Serialize)]
    enum HealthCheck {
        Ready,
        Initializing,
    }

    impl super::HealthCheck for HealthCheck {
        fn status(&self) -> StatusCode {
            StatusCode::OK
        }
    }

    /// Checks `healthcheck` and `wait_for_health` against a server whose health state flips
    /// from `Initializing` to `Ready` when its `init` route is posted to.
    #[async_std::test]
    async fn test_healthcheck() {
        setup_logging();
        setup_backtrace();

        let mut server: App<_, ServerError> =
            App::with_state(RwLock::new(HealthCheck::Initializing));
        let spec = toml! {
            [route.init]
            PATH = ["/init"]
            METHOD = "POST"
        };
        server
            .module::<ServerError, Ver01>("mod", spec)
            .unwrap()
            .with_health_check(|state| async move { *state.read().await }.boxed())
            .post("init", |_, state| {
                async move {
                    *state = HealthCheck::Ready;
                    Ok(())
                }
                .boxed()
            })
            .unwrap();
        let port = pick_unused_port().unwrap();
        spawn(server.serve(format!("0.0.0.0:{}", port), VER_0_1));

        // This client is rooted at the module rather than at the server root.
        let module_url: Url = format!("http://localhost:{}/mod", port).parse().unwrap();
        let client = Client::<ServerError, Ver01>::new(module_url);
        assert!(client.connect(None).await);
        assert_eq!(
            HealthCheck::Initializing,
            client.healthcheck().await.unwrap()
        );

        // The server is not ready yet, so waiting with a timeout must give up.
        let timed_out = client
            .wait_for_health::<HealthCheck>(
                |h| *h == HealthCheck::Ready,
                Some(Duration::from_secs(1)),
            )
            .await;
        assert_eq!(timed_out, None);

        // Flip the server to ready.
        client.post::<()>("init").send().await.unwrap();

        // Now waiting must succeed and report the ready state.
        let ready = client
            .wait_for_health::<HealthCheck>(|h| *h == HealthCheck::Ready, None)
            .await;
        assert_eq!(ready, Some(HealthCheck::Ready));
        assert_eq!(HealthCheck::Ready, client.healthcheck().await.unwrap());
    }

    /// A client built without a timeout keeps the base URL it was given.
    #[test]
    fn test_builder() {
        let client =
            Client::<ServerError, Ver01>::builder("http://www.example.com".parse().unwrap())
                .set_timeout(None)
                .build();
        let expected: Url = "http://www.example.com".parse().unwrap();
        assert_eq!(client.base_url, expected);
    }
}