surf_disco/
client.rs

1// Copyright (c) 2022 Espresso Systems (espressosys.com)
2// This file is part of the surf-disco library.
3
4// You should have received a copy of the MIT License
5// along with the surf-disco library. If not, see <https://mit-license.org/>.
6
7use crate::{
8    http, request::reqwest_error_msg, Error, Method, Request, SocketRequest, StatusCode, Url,
9};
10use async_std::task::sleep;
11use derivative::Derivative;
12use serde::de::DeserializeOwned;
13use std::time::{Duration, Instant};
14use vbs::version::StaticVersionType;
15
16pub use tide_disco::healthcheck::{HealthCheck, HealthStatus};
17
/// Content types supported by Tide Disco.
///
/// The selected content type is sent to the server in the `Accept` header, so it determines the
/// encoding of response bodies.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum ContentType {
    /// JSON, requested as `http::mime::JSON`.
    Json,
    /// Binary encoding, requested as `http::mime::BYTE_STREAM`.
    Binary,
}
24
25impl From<ContentType> for http::Mime {
26    fn from(c: ContentType) -> http::Mime {
27        match c {
28            ContentType::Json => http::mime::JSON,
29            ContentType::Binary => http::mime::BYTE_STREAM,
30        }
31    }
32}
33
34/// A client of a Tide Disco application.
35#[derive(Derivative)]
36#[derivative(Clone(bound = ""), Debug(bound = ""))]
37pub struct Client<E, VER: StaticVersionType> {
38    inner: reqwest::Client,
39    base_url: Url,
40    accept: ContentType,
41    _marker: std::marker::PhantomData<fn(E, VER) -> ()>,
42}
43
44impl<E: Error, VER: StaticVersionType> Client<E, VER> {
45    /// Create a client and connect to the Tide Disco server at `base_url`.
46    pub fn new(base_url: Url) -> Self {
47        Self::builder(base_url).build()
48    }
49
50    /// Create a client with customization.
51    pub fn builder(base_url: Url) -> ClientBuilder<E, VER> {
52        ClientBuilder::<E, VER>::new(base_url)
53    }
54
55    /// Connect to the server, retrying if the server is not running.
56    ///
57    /// It is not necessary to call this function when creating a new client. The client will
58    /// automatically connect when a request is made, if the server is available. However, this can
59    /// be useful to wait for the server to come up, if the server may be offline when the client is
60    /// created.
61    ///
62    /// This function will make an HTTP `GET` request to the server's `/healthcheck` endpoint, to
63    /// test if the server is available. If this request succeeds, [connect](Self::connect) returns
64    /// `true`. Otherwise, the client will continue retrying `/healthcheck` requests until `timeout`
65    /// has elapsed (or forever, if `timeout` is `None`). If the timeout expires before a
66    /// `/healthcheck` request succeeds, [connect](Self::connect) will return `false`.
67    pub async fn connect(&self, timeout: Option<Duration>) -> bool {
68        let timeout = timeout.map(|d| Instant::now() + d);
69        while timeout.map(|t| Instant::now() < t).unwrap_or(true) {
70            match self
71                .inner
72                .get(self.base_url.join("/healthcheck").unwrap())
73                .send()
74                .await
75            {
76                Ok(res) if res.status() == StatusCode::OK => return true,
77                Ok(res) => {
78                    tracing::info!(
79                        url = %self.base_url,
80                        status = %res.status(),
81                        "waiting for server to become ready",
82                    );
83                }
84                Err(err) => {
85                    tracing::info!(
86                        url = %self.base_url,
87                        err = reqwest_error_msg(err),
88                        "waiting for server to become ready",
89                    );
90                }
91            }
92            sleep(Duration::from_secs(10)).await;
93        }
94        false
95    }
96
97    /// Connect to the server, retrying until the server is `healthy`.
98    ///
99    /// This function is similar to [connect](Self::connect). It will make requests to the
100    /// `/healthcheck` endpoint until a request succeeds. However, it will then continue retrying
101    /// until the response from `/healthcheck` satisfies the `healthy` predicate.
102    ///
103    /// On success, returns the response from `/healthcheck`. On timeout, returns `None`.
104    pub async fn wait_for_health<H: DeserializeOwned + HealthCheck>(
105        &self,
106        healthy: impl Fn(&H) -> bool,
107        timeout: Option<Duration>,
108    ) -> Option<H> {
109        let timeout = timeout.map(|d| Instant::now() + d);
110        while timeout.map(|t| Instant::now() < t).unwrap_or(true) {
111            match self.healthcheck::<H>().await {
112                Ok(health) if healthy(&health) => return Some(health),
113                _ => sleep(Duration::from_secs(10)).await,
114            }
115        }
116        None
117    }
118
119    /// Build an HTTP `GET` request.
120    pub fn get<T: DeserializeOwned>(&self, route: &str) -> Request<T, E, VER> {
121        self.request(Method::Get, route)
122    }
123
124    /// Build an HTTP `POST` request.
125    pub fn post<T: DeserializeOwned>(&self, route: &str) -> Request<T, E, VER> {
126        self.request(Method::Post, route)
127    }
128
129    /// Query the server's healthcheck endpoint.
130    pub async fn healthcheck<H: DeserializeOwned + HealthCheck>(&self) -> Result<H, E> {
131        self.get("healthcheck").send().await
132    }
133
134    /// Build an HTTP request with the specified method.
135    pub fn request<T: DeserializeOwned>(&self, method: Method, route: &str) -> Request<T, E, VER> {
136        Request::from(self.inner.request(
137            method.to_string().parse().unwrap(),
138            self.base_url.join(route).unwrap(),
139        ))
140        .header("Accept", http::Mime::from(self.accept).to_string())
141    }
142
143    /// Build a streaming connection request.
144    ///
145    /// # Panics
146    ///
147    /// This will panic if a malformed URL is passed.
148    pub fn socket(&self, route: &str) -> SocketRequest<E, VER> {
149        SocketRequest::new(self.base_url.join(route).unwrap(), self.accept)
150            .header("Accept", http::Mime::from(self.accept).to_string())
151    }
152
153    /// Create a client for a sub-module of the connected application.
154    pub fn module<ModError: Error>(
155        &self,
156        prefix: &str,
157    ) -> Result<Client<ModError, VER>, http::url::ParseError> {
158        Ok(Client {
159            inner: self.inner.clone(),
160            base_url: self.base_url.join(prefix)?,
161            accept: self.accept,
162            _marker: Default::default(),
163        })
164    }
165
166    pub fn base_url(&self) -> Url {
167        self.base_url.clone()
168    }
169}
170
171/// Interface to specify optional configuration values before creating a [Client].
172pub struct ClientBuilder<E: Error, VER: StaticVersionType> {
173    inner: reqwest::ClientBuilder,
174    accept: ContentType,
175    base_url: Url,
176    timeout: Option<Duration>,
177    _marker: std::marker::PhantomData<fn(E, VER) -> ()>,
178}
179
180impl<E: Error, VER: StaticVersionType> ClientBuilder<E, VER> {
181    fn new(mut base_url: Url) -> Self {
182        // If the path part of `base_url` does not end in `/`, `join` will treat it as a filename
183        // and remove it, which is never what we want: `base_url` is _always_ a directory-like path.
184        // To avoid the annoyance of having every caller add a trailing slash if necessary, we will
185        // add a trailing slash here if there isn't one already.
186        if !base_url.path().ends_with('/') {
187            base_url.set_path(&format!("{}/", base_url.path()));
188        }
189        Self {
190            inner: reqwest::Client::builder(),
191            accept: ContentType::Binary,
192            base_url,
193            timeout: Some(Duration::from_secs(60)),
194            _marker: Default::default(),
195        }
196    }
197
198    /// Set connection timeout duration.
199    ///
200    /// Passing `None` will remove the timeout.
201    ///
202    /// Default: `Some(Duration::from_secs(60))`.
203    pub fn set_timeout(mut self, timeout: Option<Duration>) -> Self {
204        self.timeout = timeout;
205        self
206    }
207
208    /// Set the content type used for responses.
209    pub fn content_type(mut self, content_type: ContentType) -> Self {
210        self.accept = content_type;
211        self
212    }
213
214    /// Create a [Client] with the settings specified in this builder.
215    pub fn build(self) -> Client<E, VER> {
216        let mut builder = self.inner;
217
218        if let Some(timeout) = self.timeout {
219            builder = builder.timeout(timeout);
220        }
221
222        Client {
223            inner: builder.build().unwrap(),
224            base_url: self.base_url,
225            accept: self.accept,
226            _marker: Default::default(),
227        }
228    }
229}
230
231impl<E: Error, VER: StaticVersionType> From<ClientBuilder<E, VER>> for Client<E, VER> {
232    fn from(builder: ClientBuilder<E, VER>) -> Self {
233        builder.build()
234    }
235}
236
// Integration-style tests: each test serves a small Tide Disco app on an unused local port in a
// background task and talks to it through [Client].
#[cfg(test)]
mod test {
    use crate::socket::Connection;

    use super::*;
    use async_compatibility_layer::logging::{setup_backtrace, setup_logging};
    use async_std::{sync::RwLock, task::spawn};
    use futures::{stream::iter, FutureExt, SinkExt, StreamExt};
    use portpicker::pick_unused_port;
    use serde::{Deserialize, Serialize};
    use tide_disco::{error::ServerError, App};
    use toml::toml;
    use vbs::version::StaticVersion;
    // API version used both by the served app and by the clients under test.
    type Ver01 = StaticVersion<0, 1>;
    const VER_0_1: Ver01 = StaticVersion {};

    /// Exercise plain `GET`/`POST` requests (including an error response), requesting responses
    /// in the `accept` content type.
    async fn test_basic_http_client(accept: ContentType) {
        setup_logging();
        setup_backtrace();

        // Set up a simple Tide Disco app as an example.
        let mut app: App<(), ServerError> = App::with_state(());
        let api = toml! {
            [route.get]
            PATH = ["/get"]
            METHOD = "GET"

            [route.post]
            PATH = ["/post"]
            METHOD = "POST"
        };
        app.module::<ServerError, Ver01>("mod", api)
            .unwrap()
            .get("get", |_req, _state| async move { Ok("response") }.boxed())
            .unwrap()
            // Accepts only the body "body"; anything else yields a 400 with a known message.
            .post("post", |req, _state| {
                async move {
                    if req.body_auto::<String, _>(VER_0_1).unwrap() == "body" {
                        Ok("response")
                    } else {
                        Err(ServerError::catch_all(
                            StatusCode::BAD_REQUEST,
                            "invalid body".into(),
                        ))
                    }
                }
                .boxed()
            })
            .unwrap();
        // Serve in a background task; the returned handle is not awaited.
        let port = pick_unused_port().unwrap();
        spawn(app.serve(format!("0.0.0.0:{}", port), VER_0_1));

        // Connect a client.
        let client = Client::<ServerError, Ver01>::builder(
            format!("http://localhost:{}", port).parse().unwrap(),
        )
        .content_type(accept)
        .build();
        // The server task may not be listening yet, so wait (without a timeout) for it.
        assert!(client.connect(None).await);

        // Test a couple of basic requests.
        assert_eq!(
            client.get::<String>("mod/get").send().await.unwrap(),
            "response"
        );
        assert_eq!(
            client
                .post::<String>("mod/post")
                .body_json(&"body".to_string())
                .unwrap()
                .send()
                .await
                .unwrap(),
            "response"
        );

        // Test an error response.
        // The decoded `ServerError` must carry the status and message set by the handler.
        let err = client
            .post::<String>("mod/post")
            .body_json(&"bad".to_string())
            .unwrap()
            .send()
            .await
            .unwrap_err();
        if err.status != StatusCode::BAD_REQUEST || err.message != "invalid body" {
            panic!("unexpected error {}", err);
        }
    }

    #[async_std::test]
    async fn test_basic_http_client_json() {
        test_basic_http_client(ContentType::Json).await
    }

    #[async_std::test]
    async fn test_basic_http_client_binary() {
        test_basic_http_client(ContentType::Binary).await
    }

    /// Exercise socket endpoints (a bidirectional echo and a one-way stream), with messages in
    /// the `accept` content type.
    async fn test_streaming_client(accept: ContentType) {
        setup_logging();
        setup_backtrace();

        // Set up a simple Tide Disco app as an example.
        let mut app: App<(), ServerError> = App::with_state(());
        let api = toml! {
            [route.echo]
            PATH = ["/echo"]
            METHOD = "SOCKET"

            [route.naturals]
            PATH = ["/naturals/:max"]
            METHOD = "SOCKET"
            ":max" = "Integer"
        };
        app.module::<ServerError, Ver01>("mod", api)
            .unwrap()
            // Send every received message straight back until the client stops sending.
            .socket::<_, String, String>("echo", |_req, mut conn, _state| {
                async move {
                    while let Some(Ok(msg)) = conn.next().await {
                        conn.send(&msg).await.unwrap();
                    }
                    Ok(())
                }
                .boxed()
            })
            .unwrap()
            // Stream the integers `0..max`, then end the stream.
            .stream("naturals", |req, _state| {
                iter(0u64..req.integer_param("max").unwrap())
                    .map(Ok)
                    .boxed()
            })
            .unwrap();
        let port = pick_unused_port().unwrap();
        spawn(app.serve(format!("0.0.0.0:{}", port), VER_0_1));

        // Connect a client.
        // `VER` is left to inference; it is fixed by the `Connection` annotation below.
        let client: Client<ServerError, _> =
            Client::builder(format!("http://localhost:{}", port).parse().unwrap())
                .content_type(accept)
                .build();
        assert!(client.connect(None).await);

        // Test a bidirectional endpoint.
        let mut conn: Connection<_, _, _, Ver01> = client
            .socket("mod/echo")
            .connect::<String, String>()
            .await
            .unwrap();
        conn.send(&"foo".into()).await.unwrap();
        assert_eq!(conn.next().await.unwrap().unwrap(), "foo");
        conn.send(&"bar".into()).await.unwrap();
        assert_eq!(conn.next().await.unwrap().unwrap(), "bar");

        // Test a streaming endpoint.
        // Collecting the whole stream also checks that the server closes it after the last item.
        assert_eq!(
            client
                .socket("mod/naturals/10")
                .subscribe::<u64>()
                .await
                .unwrap()
                .collect::<Vec<_>>()
                .await,
            (0..10).map(Ok).collect::<Vec<_>>()
        );
    }

    #[async_std::test]
    async fn test_streaming_client_json() {
        test_streaming_client(ContentType::Json).await
    }

    #[async_std::test]
    async fn test_streaming_client_binary() {
        test_streaming_client(ContentType::Binary).await
    }

    // Test health state. This enum shadows the `HealthCheck` trait brought in by `use super::*`,
    // which is why the trait is named explicitly as `super::HealthCheck` below.
    #[derive(Clone, Copy, Debug, PartialEq, Eq, Deserialize, Serialize)]
    enum HealthCheck {
        Ready,
        Initializing,
    }

    impl super::HealthCheck for HealthCheck {
        // Report 200 in every state, so readiness is distinguished by the decoded value rather
        // than by the HTTP status.
        fn status(&self) -> StatusCode {
            StatusCode::OK
        }
    }

    #[async_std::test]
    async fn test_healthcheck() {
        setup_logging();
        setup_backtrace();

        // Set up a simple Tide Disco app as an example.
        let mut app: App<_, ServerError> = App::with_state(RwLock::new(HealthCheck::Initializing));
        let api = toml! {
            [route.init]
            PATH = ["/init"]
            METHOD = "POST"
        };
        app.module::<ServerError, Ver01>("mod", api)
            .unwrap()
            // The module reports whatever health value is currently stored in the app state.
            .with_health_check(|state| async move { *state.read().await }.boxed())
            // `POST init` flips the stored health to `Ready`.
            .post("init", |_, state| {
                async move {
                    *state = HealthCheck::Ready;
                    Ok(())
                }
                .boxed()
            })
            .unwrap();
        let port = pick_unused_port().unwrap();
        spawn(app.serve(format!("0.0.0.0:{}", port), VER_0_1));

        // Connect a client.
        // The base URL includes the `/mod` prefix, so `client.healthcheck()` requests
        // `/mod/healthcheck` and relative routes such as `init` resolve under the module.
        let client = Client::<ServerError, Ver01>::new(
            format!("http://localhost:{}/mod", port).parse().unwrap(),
        );
        assert!(client.connect(None).await);
        assert_eq!(
            HealthCheck::Initializing,
            client.healthcheck().await.unwrap()
        );

        // Waiting for [HealthCheck::Ready] should time out.
        // Nothing has called `init` yet, so the state is still `Initializing`.
        assert_eq!(
            client
                .wait_for_health::<HealthCheck>(
                    |h| *h == HealthCheck::Ready,
                    Some(Duration::from_secs(1))
                )
                .await,
            None
        );

        // Initialize the service.
        client.post::<()>("init").send().await.unwrap();

        // Now waiting for [HealthCheck::Ready] should succeed.
        assert_eq!(
            client
                .wait_for_health::<HealthCheck>(|h| *h == HealthCheck::Ready, None)
                .await,
            Some(HealthCheck::Ready)
        );
        assert_eq!(HealthCheck::Ready, client.healthcheck().await.unwrap());
    }

    // Building without a timeout must succeed and keep the base URL. The URL already parses with
    // a `/` path, so the trailing-slash normalization leaves it unchanged.
    #[test]
    fn test_builder() {
        let client =
            Client::<ServerError, Ver01>::builder("http://www.example.com".parse().unwrap())
                .set_timeout(None)
                .build();
        assert_eq!(client.base_url, "http://www.example.com".parse().unwrap());
    }
}