surf_disco/
client.rs

1// Copyright (c) 2022 Espresso Systems (espressosys.com)
2// This file is part of the surf-disco library.
3
4// You should have received a copy of the MIT License
5// along with the surf-disco library. If not, see <https://mit-license.org/>.
6
7use crate::{
8    http, request::reqwest_error_msg, Error, Method, Request, SocketRequest, StatusCode, Url,
9};
10use async_std::task::sleep;
11use async_tungstenite::tungstenite::protocol::WebSocketConfig;
12use derivative::Derivative;
13use serde::de::DeserializeOwned;
14use std::time::{Duration, Instant};
15use vbs::version::StaticVersionType;
16
17pub use tide_disco::healthcheck::{HealthCheck, HealthStatus};
18
/// Content types supported by Tide Disco.
///
/// A [Client] uses this to choose the `Accept` header it sends with its requests.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum ContentType {
    /// JSON responses (mapped to `http::mime::JSON`).
    Json,
    /// Binary responses (mapped to `http::mime::BYTE_STREAM`).
    Binary,
}
25
26impl From<ContentType> for http::Mime {
27    fn from(c: ContentType) -> http::Mime {
28        match c {
29            ContentType::Json => http::mime::JSON,
30            ContentType::Binary => http::mime::BYTE_STREAM,
31        }
32    }
33}
34
35/// A client of a Tide Disco application.
36#[derive(Derivative)]
37#[derivative(Clone(bound = ""), Debug(bound = ""))]
38pub struct Client<E, VER: StaticVersionType> {
39    inner: reqwest::Client,
40    base_url: Url,
41    accept: ContentType,
42    _marker: std::marker::PhantomData<fn(E, VER) -> ()>,
43}
44
45impl<E: Error, VER: StaticVersionType> Client<E, VER> {
46    /// Create a client and connect to the Tide Disco server at `base_url`.
47    pub fn new(base_url: Url) -> Self {
48        Self::builder(base_url).build()
49    }
50
51    /// Create a client with customization.
52    pub fn builder(base_url: Url) -> ClientBuilder<E, VER> {
53        ClientBuilder::<E, VER>::new(base_url)
54    }
55
56    /// Connect to the server, retrying if the server is not running.
57    ///
58    /// It is not necessary to call this function when creating a new client. The client will
59    /// automatically connect when a request is made, if the server is available. However, this can
60    /// be useful to wait for the server to come up, if the server may be offline when the client is
61    /// created.
62    ///
63    /// This function will make an HTTP `GET` request to the server's `/healthcheck` endpoint, to
64    /// test if the server is available. If this request succeeds, [connect](Self::connect) returns
65    /// `true`. Otherwise, the client will continue retrying `/healthcheck` requests until `timeout`
66    /// has elapsed (or forever, if `timeout` is `None`). If the timeout expires before a
67    /// `/healthcheck` request succeeds, [connect](Self::connect) will return `false`.
68    pub async fn connect(&self, timeout: Option<Duration>) -> bool {
69        let timeout = timeout.map(|d| Instant::now() + d);
70        while timeout.map(|t| Instant::now() < t).unwrap_or(true) {
71            match self
72                .inner
73                .get(self.base_url.join("/healthcheck").unwrap())
74                .send()
75                .await
76            {
77                Ok(res) if res.status() == StatusCode::OK => return true,
78                Ok(res) => {
79                    tracing::info!(
80                        url = %self.base_url,
81                        status = %res.status(),
82                        "waiting for server to become ready",
83                    );
84                }
85                Err(err) => {
86                    tracing::info!(
87                        url = %self.base_url,
88                        err = reqwest_error_msg(err),
89                        "waiting for server to become ready",
90                    );
91                }
92            }
93            sleep(Duration::from_secs(10)).await;
94        }
95        false
96    }
97
98    /// Connect to the server, retrying until the server is `healthy`.
99    ///
100    /// This function is similar to [connect](Self::connect). It will make requests to the
101    /// `/healthcheck` endpoint until a request succeeds. However, it will then continue retrying
102    /// until the response from `/healthcheck` satisfies the `healthy` predicate.
103    ///
104    /// On success, returns the response from `/healthcheck`. On timeout, returns `None`.
105    pub async fn wait_for_health<H: DeserializeOwned + HealthCheck>(
106        &self,
107        healthy: impl Fn(&H) -> bool,
108        timeout: Option<Duration>,
109    ) -> Option<H> {
110        let timeout = timeout.map(|d| Instant::now() + d);
111        while timeout.map(|t| Instant::now() < t).unwrap_or(true) {
112            match self.healthcheck::<H>().await {
113                Ok(health) if healthy(&health) => return Some(health),
114                _ => sleep(Duration::from_secs(10)).await,
115            }
116        }
117        None
118    }
119
120    /// Build an HTTP `GET` request.
121    pub fn get<T: DeserializeOwned>(&self, route: &str) -> Request<T, E, VER> {
122        self.request(Method::Get, route)
123    }
124
125    /// Build an HTTP `POST` request.
126    pub fn post<T: DeserializeOwned>(&self, route: &str) -> Request<T, E, VER> {
127        self.request(Method::Post, route)
128    }
129
130    /// Query the server's healthcheck endpoint.
131    pub async fn healthcheck<H: DeserializeOwned + HealthCheck>(&self) -> Result<H, E> {
132        self.get("healthcheck").send().await
133    }
134
135    /// Build an HTTP request with the specified method.
136    pub fn request<T: DeserializeOwned>(&self, method: Method, route: &str) -> Request<T, E, VER> {
137        Request::from(self.inner.request(
138            method.to_string().parse().unwrap(),
139            self.base_url.join(route).unwrap(),
140        ))
141        .header("Accept", http::Mime::from(self.accept).to_string())
142    }
143
144    /// Build a streaming connection request.
145    ///
146    /// # Panics
147    ///
148    /// This will panic if a malformed URL is passed.
149    pub fn socket(&self, route: &str) -> SocketRequest<E, VER> {
150        SocketRequest::new(self.base_url.join(route).unwrap(), self.accept, None)
151            .header("Accept", http::Mime::from(self.accept).to_string())
152    }
153
154    /// Build a streaming connection request using a custom [`WebSocketConfig`].
155    ///
156    /// # Panics
157    ///
158    /// This will panic if a malformed URL is passed.
159    pub fn socket_with_config(
160        &self,
161        route: &str,
162        config: WebSocketConfig,
163    ) -> SocketRequest<E, VER> {
164        SocketRequest::new(
165            self.base_url.join(route).unwrap(),
166            self.accept,
167            Some(config),
168        )
169        .header("Accept", http::Mime::from(self.accept).to_string())
170    }
171
172    /// Create a client for a sub-module of the connected application.
173    pub fn module<ModError: Error>(
174        &self,
175        prefix: &str,
176    ) -> Result<Client<ModError, VER>, http::url::ParseError> {
177        Ok(Client {
178            inner: self.inner.clone(),
179            base_url: self.base_url.join(prefix)?,
180            accept: self.accept,
181            _marker: Default::default(),
182        })
183    }
184
185    pub fn base_url(&self) -> Url {
186        self.base_url.clone()
187    }
188}
189
190/// Interface to specify optional configuration values before creating a [Client].
191pub struct ClientBuilder<E: Error, VER: StaticVersionType> {
192    inner: reqwest::ClientBuilder,
193    accept: ContentType,
194    base_url: Url,
195    timeout: Option<Duration>,
196    _marker: std::marker::PhantomData<fn(E, VER) -> ()>,
197}
198
199impl<E: Error, VER: StaticVersionType> ClientBuilder<E, VER> {
200    fn new(mut base_url: Url) -> Self {
201        // If the path part of `base_url` does not end in `/`, `join` will treat it as a filename
202        // and remove it, which is never what we want: `base_url` is _always_ a directory-like path.
203        // To avoid the annoyance of having every caller add a trailing slash if necessary, we will
204        // add a trailing slash here if there isn't one already.
205        if !base_url.path().ends_with('/') {
206            base_url.set_path(&format!("{}/", base_url.path()));
207        }
208        Self {
209            inner: reqwest::Client::builder(),
210            accept: ContentType::Binary,
211            base_url,
212            timeout: Some(Duration::from_secs(60)),
213            _marker: Default::default(),
214        }
215    }
216
217    /// Set connection timeout duration.
218    ///
219    /// Passing `None` will remove the timeout.
220    ///
221    /// Default: `Some(Duration::from_secs(60))`.
222    pub fn set_timeout(mut self, timeout: Option<Duration>) -> Self {
223        self.timeout = timeout;
224        self
225    }
226
227    /// Set the content type used for responses.
228    pub fn content_type(mut self, content_type: ContentType) -> Self {
229        self.accept = content_type;
230        self
231    }
232
233    /// Create a [Client] with the settings specified in this builder.
234    pub fn build(self) -> Client<E, VER> {
235        let mut builder = self.inner;
236
237        if let Some(timeout) = self.timeout {
238            builder = builder.timeout(timeout);
239        }
240
241        Client {
242            inner: builder.build().unwrap(),
243            base_url: self.base_url,
244            accept: self.accept,
245            _marker: Default::default(),
246        }
247    }
248}
249
250impl<E: Error, VER: StaticVersionType> From<ClientBuilder<E, VER>> for Client<E, VER> {
251    fn from(builder: ClientBuilder<E, VER>) -> Self {
252        builder.build()
253    }
254}
255
// Integration tests: each test spawns a real Tide Disco server on an unused local port and
// drives it through a `Client`.
#[cfg(test)]
mod test {
    use crate::socket::Connection;

    use super::*;
    use async_compatibility_layer::logging::{setup_backtrace, setup_logging};
    use async_std::{sync::RwLock, task::spawn};
    use futures::{stream::iter, FutureExt, SinkExt, StreamExt};
    use portpicker::pick_unused_port;
    use serde::{Deserialize, Serialize};
    use tide_disco::{error::ServerError, App};
    use toml::toml;
    use vbs::version::StaticVersion;
    // API version 0.1, used for both the server modules and the clients under test.
    type Ver01 = StaticVersion<0, 1>;
    const VER_0_1: Ver01 = StaticVersion {};

    /// Exercise plain `GET`/`POST` requests, including an error response, with the client
    /// requesting responses in the `accept` content type.
    async fn test_basic_http_client(accept: ContentType) {
        setup_logging();
        setup_backtrace();

        // Set up a simple Tide Disco app as an example.
        let mut app: App<(), ServerError> = App::with_state(());
        let api = toml! {
            [route.get]
            PATH = ["/get"]
            METHOD = "GET"

            [route.post]
            PATH = ["/post"]
            METHOD = "POST"
        };
        app.module::<ServerError, Ver01>("mod", api)
            .unwrap()
            .get("get", |_req, _state| async move { Ok("response") }.boxed())
            .unwrap()
            // `post` succeeds only for the body "body"; anything else yields a 400 error so the
            // client's error path can be tested.
            .post("post", |req, _state| {
                async move {
                    if req.body_auto::<String, _>(VER_0_1).unwrap() == "body" {
                        Ok("response")
                    } else {
                        Err(ServerError::catch_all(
                            StatusCode::BAD_REQUEST,
                            "invalid body".into(),
                        ))
                    }
                }
                .boxed()
            })
            .unwrap();
        let port = pick_unused_port().unwrap();
        spawn(app.serve(format!("0.0.0.0:{}", port), VER_0_1));

        // Connect a client.
        let client = Client::<ServerError, Ver01>::builder(
            format!("http://localhost:{}", port).parse().unwrap(),
        )
        .content_type(accept)
        .build();
        // No timeout: wait as long as it takes for the spawned server to start listening.
        assert!(client.connect(None).await);

        // Test a couple of basic requests.
        assert_eq!(
            client.get::<String>("mod/get").send().await.unwrap(),
            "response"
        );
        assert_eq!(
            client
                .post::<String>("mod/post")
                .body_json(&"body".to_string())
                .unwrap()
                .send()
                .await
                .unwrap(),
            "response"
        );

        // Test an error response.
        let err = client
            .post::<String>("mod/post")
            .body_json(&"bad".to_string())
            .unwrap()
            .send()
            .await
            .unwrap_err();
        // Both the status code and the message produced by the handler should reach the client.
        if err.status != StatusCode::BAD_REQUEST || err.message != "invalid body" {
            panic!("unexpected error {}", err);
        }
    }

    #[async_std::test]
    async fn test_basic_http_client_json() {
        test_basic_http_client(ContentType::Json).await
    }

    #[async_std::test]
    async fn test_basic_http_client_binary() {
        test_basic_http_client(ContentType::Binary).await
    }

    /// Exercise socket endpoints: a bidirectional echo socket and a server-to-client stream,
    /// with the client requesting the `accept` content type.
    async fn test_streaming_client(accept: ContentType) {
        setup_logging();
        setup_backtrace();

        // Set up a simple Tide Disco app as an example.
        let mut app: App<(), ServerError> = App::with_state(());
        let api = toml! {
            [route.echo]
            PATH = ["/echo"]
            METHOD = "SOCKET"

            [route.naturals]
            PATH = ["/naturals/:max"]
            METHOD = "SOCKET"
            ":max" = "Integer"
        };
        app.module::<ServerError, Ver01>("mod", api)
            .unwrap()
            // Echo every message received until the client closes the connection or sends
            // something that fails to decode.
            .socket::<_, String, String>("echo", |_req, mut conn, _state| {
                async move {
                    while let Some(Ok(msg)) = conn.next().await {
                        conn.send(&msg).await.unwrap();
                    }
                    Ok(())
                }
                .boxed()
            })
            .unwrap()
            // Stream the integers 0..max, then end the stream.
            .stream("naturals", |req, _state| {
                iter(0u64..req.integer_param("max").unwrap())
                    .map(Ok)
                    .boxed()
            })
            .unwrap();
        let port = pick_unused_port().unwrap();
        spawn(app.serve(format!("0.0.0.0:{}", port), VER_0_1));

        // Connect a client.
        let client: Client<ServerError, _> =
            Client::builder(format!("http://localhost:{}", port).parse().unwrap())
                .content_type(accept)
                .build();
        assert!(client.connect(None).await);

        // Test a bidirectional endpoint.
        let mut conn: Connection<_, _, _, Ver01> = client
            .socket("mod/echo")
            .connect::<String, String>()
            .await
            .unwrap();
        conn.send(&"foo".into()).await.unwrap();
        assert_eq!(conn.next().await.unwrap().unwrap(), "foo");
        conn.send(&"bar".into()).await.unwrap();
        assert_eq!(conn.next().await.unwrap().unwrap(), "bar");

        // Test a streaming endpoint.
        assert_eq!(
            client
                .socket("mod/naturals/10")
                .subscribe::<u64>()
                .await
                .unwrap()
                .collect::<Vec<_>>()
                .await,
            (0..10).map(Ok).collect::<Vec<_>>()
        );
    }

    #[async_std::test]
    async fn test_streaming_client_json() {
        test_streaming_client(ContentType::Json).await
    }

    #[async_std::test]
    async fn test_streaming_client_binary() {
        test_streaming_client(ContentType::Binary).await
    }

    /// Test-only health status. It deliberately shares its name with the `HealthCheck` trait
    /// from `super`, which it shadows here; the trait is referred to as `super::HealthCheck`.
    #[derive(Clone, Copy, Debug, PartialEq, Eq, Deserialize, Serialize)]
    enum HealthCheck {
        Ready,
        Initializing,
    }

    // Always reports `OK`, so readiness is conveyed only by the enum value in the response body.
    // This is what `wait_for_health`'s predicate inspects below.
    impl super::HealthCheck for HealthCheck {
        fn status(&self) -> StatusCode {
            StatusCode::OK
        }
    }

    /// Test `healthcheck` and `wait_for_health` against a module whose health changes from
    /// `Initializing` to `Ready` when its `init` endpoint is called.
    #[async_std::test]
    async fn test_healthcheck() {
        setup_logging();
        setup_backtrace();

        // Set up a simple Tide Disco app as an example.
        let mut app: App<_, ServerError> = App::with_state(RwLock::new(HealthCheck::Initializing));
        let api = toml! {
            [route.init]
            PATH = ["/init"]
            METHOD = "POST"
        };
        app.module::<ServerError, Ver01>("mod", api)
            .unwrap()
            // The module's health check reports whatever value is currently in the shared state.
            .with_health_check(|state| async move { *state.read().await }.boxed())
            // `init` flips the shared state to `Ready`.
            .post("init", |_, state| {
                async move {
                    *state = HealthCheck::Ready;
                    Ok(())
                }
                .boxed()
            })
            .unwrap();
        let port = pick_unused_port().unwrap();
        spawn(app.serve(format!("0.0.0.0:{}", port), VER_0_1));

        // Connect a client.
        // The base URL points at the module. Relative routes such as `healthcheck` and `init`
        // are therefore resolved under `/mod/`.
        let client = Client::<ServerError, Ver01>::new(
            format!("http://localhost:{}/mod", port).parse().unwrap(),
        );
        assert!(client.connect(None).await);
        assert_eq!(
            HealthCheck::Initializing,
            client.healthcheck().await.unwrap()
        );

        // Waiting for [HealthCheck::Ready] should time out.
        assert_eq!(
            client
                .wait_for_health::<HealthCheck>(
                    |h| *h == HealthCheck::Ready,
                    Some(Duration::from_secs(1))
                )
                .await,
            None
        );

        // Initialize the service.
        client.post::<()>("init").send().await.unwrap();

        // Now waiting for [HealthCheck::Ready] should succeed.
        assert_eq!(
            client
                .wait_for_health::<HealthCheck>(|h| *h == HealthCheck::Ready, None)
                .await,
            Some(HealthCheck::Ready)
        );
        assert_eq!(HealthCheck::Ready, client.healthcheck().await.unwrap());
    }

    /// The builder should accept a base URL without a path and allow disabling the timeout.
    #[test]
    fn test_builder() {
        let client =
            Client::<ServerError, Ver01>::builder("http://www.example.com".parse().unwrap())
                .set_timeout(None)
                .build();
        // `Url` parses an empty HTTP path as "/", so the base URL is unchanged by the
        // trailing-slash normalization.
        assert_eq!(client.base_url, "http://www.example.com".parse().unwrap());
    }
}