surf_disco/
client.rs

// Copyright (c) 2022 Espresso Systems (espressosys.com)
// This file is part of the surf-disco library.

// You should have received a copy of the MIT License
// along with the surf-disco library. If not, see <https://mit-license.org/>.
6
7use crate::{
8    http, request::reqwest_error_msg, Error, Method, Request, SocketRequest, StatusCode, Url,
9};
10use async_std::task::sleep;
11use derivative::Derivative;
12use serde::de::DeserializeOwned;
13use std::time::{Duration, Instant};
14use vbs::version::StaticVersionType;
15
16pub use tide_disco::healthcheck::{HealthCheck, HealthStatus};
17
/// Content types supported by Tide Disco.
///
/// A client uses this to choose the serialization format it requests for responses; the
/// corresponding MIME type is what the client sends in the `Accept` header.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum ContentType {
    /// JSON, advertised as `http::mime::JSON`.
    Json,
    /// Binary, advertised as `http::mime::BYTE_STREAM`.
    Binary,
}
24
25impl From<ContentType> for http::Mime {
26    fn from(c: ContentType) -> http::Mime {
27        match c {
28            ContentType::Json => http::mime::JSON,
29            ContentType::Binary => http::mime::BYTE_STREAM,
30        }
31    }
32}
33
34/// A client of a Tide Disco application.
35#[derive(Derivative)]
36#[derivative(Clone(bound = ""), Debug(bound = ""))]
37pub struct Client<E, VER: StaticVersionType> {
38    inner: reqwest::Client,
39    base_url: Url,
40    accept: ContentType,
41    _marker: std::marker::PhantomData<fn(E, VER) -> ()>,
42}
43
44impl<E: Error, VER: StaticVersionType> Client<E, VER> {
45    /// Create a client and connect to the Tide Disco server at `base_url`.
46    pub fn new(base_url: Url) -> Self {
47        Self::builder(base_url).build()
48    }
49
50    /// Create a client with customization.
51    pub fn builder(base_url: Url) -> ClientBuilder<E, VER> {
52        ClientBuilder::<E, VER>::new(base_url)
53    }
54
55    /// Connect to the server, retrying if the server is not running.
56    ///
57    /// It is not necessary to call this function when creating a new client. The client will
58    /// automatically connect when a request is made, if the server is available. However, this can
59    /// be useful to wait for the server to come up, if the server may be offline when the client is
60    /// created.
61    ///
62    /// This function will make an HTTP `GET` request to the server's `/healthcheck` endpoint, to
63    /// test if the server is available. If this request succeeds, [connect](Self::connect) returns
64    /// `true`. Otherwise, the client will continue retrying `/healthcheck` requests until `timeout`
65    /// has elapsed (or forever, if `timeout` is `None`). If the timeout expires before a
66    /// `/healthcheck` request succeeds, [connect](Self::connect) will return `false`.
67    pub async fn connect(&self, timeout: Option<Duration>) -> bool {
68        let timeout = timeout.map(|d| Instant::now() + d);
69        while timeout.map(|t| Instant::now() < t).unwrap_or(true) {
70            match self
71                .inner
72                .get(self.base_url.join("/healthcheck").unwrap())
73                .send()
74                .await
75            {
76                Ok(res) if res.status() == StatusCode::OK => return true,
77                Ok(res) => {
78                    tracing::info!(
79                        url = %self.base_url,
80                        status = %res.status(),
81                        "waiting for server to become ready",
82                    );
83                }
84                Err(err) => {
85                    tracing::info!(
86                        url = %self.base_url,
87                        err = reqwest_error_msg(err),
88                        "waiting for server to become ready",
89                    );
90                }
91            }
92            sleep(Duration::from_secs(10)).await;
93        }
94        false
95    }
96
97    /// Connect to the server, retrying until the server is `healthy`.
98    ///
99    /// This function is similar to [connect](Self::connect). It will make requests to the
100    /// `/healthcheck` endpoint until a request succeeds. However, it will then continue retrying
101    /// until the response from `/healthcheck` satisfies the `healthy` predicate.
102    ///
103    /// On success, returns the response from `/healthcheck`. On timeout, returns `None`.
104    pub async fn wait_for_health<H: DeserializeOwned + HealthCheck>(
105        &self,
106        healthy: impl Fn(&H) -> bool,
107        timeout: Option<Duration>,
108    ) -> Option<H> {
109        let timeout = timeout.map(|d| Instant::now() + d);
110        while timeout.map(|t| Instant::now() < t).unwrap_or(true) {
111            match self.healthcheck::<H>().await {
112                Ok(health) if healthy(&health) => return Some(health),
113                _ => sleep(Duration::from_secs(10)).await,
114            }
115        }
116        None
117    }
118
119    /// Build an HTTP `GET` request.
120    pub fn get<T: DeserializeOwned>(&self, route: &str) -> Request<T, E, VER> {
121        self.request(Method::Get, route)
122    }
123
124    /// Build an HTTP `POST` request.
125    pub fn post<T: DeserializeOwned>(&self, route: &str) -> Request<T, E, VER> {
126        self.request(Method::Post, route)
127    }
128
129    /// Query the server's healthcheck endpoint.
130    pub async fn healthcheck<H: DeserializeOwned + HealthCheck>(&self) -> Result<H, E> {
131        self.get("healthcheck").send().await
132    }
133
134    /// Build an HTTP request with the specified method.
135    pub fn request<T: DeserializeOwned>(&self, method: Method, route: &str) -> Request<T, E, VER> {
136        Request::from(self.inner.request(
137            method.to_string().parse().unwrap(),
138            self.base_url.join(route).unwrap(),
139        ))
140        .header("Accept", http::Mime::from(self.accept).to_string())
141    }
142
143    /// Build a streaming connection request.
144    ///
145    /// # Panics
146    ///
147    /// This will panic if a malformed URL is passed.
148    pub fn socket(&self, route: &str) -> SocketRequest<E, VER> {
149        SocketRequest::new(self.base_url.join(route).unwrap(), self.accept)
150            .header("Accept", http::Mime::from(self.accept).to_string())
151    }
152
153    /// Create a client for a sub-module of the connected application.
154    pub fn module<ModError: Error>(
155        &self,
156        prefix: &str,
157    ) -> Result<Client<ModError, VER>, http::url::ParseError> {
158        Ok(Client {
159            inner: self.inner.clone(),
160            base_url: self.base_url.join(prefix)?,
161            accept: self.accept,
162            _marker: Default::default(),
163        })
164    }
165}
166
167/// Interface to specify optional configuration values before creating a [Client].
168pub struct ClientBuilder<E: Error, VER: StaticVersionType> {
169    inner: reqwest::ClientBuilder,
170    accept: ContentType,
171    base_url: Url,
172    timeout: Option<Duration>,
173    _marker: std::marker::PhantomData<fn(E, VER) -> ()>,
174}
175
176impl<E: Error, VER: StaticVersionType> ClientBuilder<E, VER> {
177    fn new(mut base_url: Url) -> Self {
178        // If the path part of `base_url` does not end in `/`, `join` will treat it as a filename
179        // and remove it, which is never what we want: `base_url` is _always_ a directory-like path.
180        // To avoid the annoyance of having every caller add a trailing slash if necessary, we will
181        // add a trailing slash here if there isn't one already.
182        if !base_url.path().ends_with('/') {
183            base_url.set_path(&format!("{}/", base_url.path()));
184        }
185        Self {
186            inner: reqwest::Client::builder(),
187            accept: ContentType::Binary,
188            base_url,
189            timeout: Some(Duration::from_secs(60)),
190            _marker: Default::default(),
191        }
192    }
193
194    /// Set connection timeout duration.
195    ///
196    /// Passing `None` will remove the timeout.
197    ///
198    /// Default: `Some(Duration::from_secs(60))`.
199    pub fn set_timeout(mut self, timeout: Option<Duration>) -> Self {
200        self.timeout = timeout;
201        self
202    }
203
204    /// Set the content type used for responses.
205    pub fn content_type(mut self, content_type: ContentType) -> Self {
206        self.accept = content_type;
207        self
208    }
209
210    /// Create a [Client] with the settings specified in this builder.
211    pub fn build(self) -> Client<E, VER> {
212        let mut builder = self.inner;
213
214        if let Some(timeout) = self.timeout {
215            builder = builder.timeout(timeout);
216        }
217
218        Client {
219            inner: builder.build().unwrap(),
220            base_url: self.base_url,
221            accept: self.accept,
222            _marker: Default::default(),
223        }
224    }
225}
226
227impl<E: Error, VER: StaticVersionType> From<ClientBuilder<E, VER>> for Client<E, VER> {
228    fn from(builder: ClientBuilder<E, VER>) -> Self {
229        builder.build()
230    }
231}
232
#[cfg(test)]
mod test {
    use super::*;
    use crate::socket::Connection;
    use async_compatibility_layer::logging::{setup_backtrace, setup_logging};
    use async_std::{sync::RwLock, task::spawn};
    use futures::{stream::iter, FutureExt, SinkExt, StreamExt};
    use portpicker::pick_unused_port;
    use serde::{Deserialize, Serialize};
    use tide_disco::{error::ServerError, App};
    use toml::toml;
    use vbs::version::StaticVersion;

    // API version shared by every test app and client in this module.
    type Ver01 = StaticVersion<0, 1>;
    const VER_0_1: Ver01 = StaticVersion {};

    /// Serve a module with one `GET` and one `POST` route, then check successful and failing
    /// requests from a client that asks for `accept` responses.
    async fn test_basic_http_client(accept: ContentType) {
        setup_logging();
        setup_backtrace();

        // Module `mod`: `GET get` always answers "response"; `POST post` answers "response" only
        // when the request body is "body", and fails with 400 otherwise.
        let spec = toml! {
            [route.get]
            PATH = ["/get"]
            METHOD = "GET"

            [route.post]
            PATH = ["/post"]
            METHOD = "POST"
        };
        let mut app: App<(), ServerError> = App::with_state(());
        app.module::<ServerError, Ver01>("mod", spec)
            .unwrap()
            .get("get", |_req, _state| async move { Ok("response") }.boxed())
            .unwrap()
            .post("post", |req, _state| {
                async move {
                    let body = req.body_auto::<String, _>(VER_0_1).unwrap();
                    if body == "body" {
                        Ok("response")
                    } else {
                        Err(ServerError::catch_all(
                            StatusCode::BAD_REQUEST,
                            "invalid body".into(),
                        ))
                    }
                }
                .boxed()
            })
            .unwrap();
        let server_port = pick_unused_port().unwrap();
        spawn(app.serve(format!("0.0.0.0:{}", server_port), VER_0_1));

        // Wait for the server to answer before issuing real requests.
        let client = Client::<ServerError, Ver01>::builder(
            format!("http://localhost:{}", server_port).parse().unwrap(),
        )
        .content_type(accept)
        .build();
        assert!(client.connect(None).await);

        // Happy paths.
        let fetched = client.get::<String>("mod/get").send().await.unwrap();
        assert_eq!(fetched, "response");
        let posted = client
            .post::<String>("mod/post")
            .body_json(&"body".to_string())
            .unwrap()
            .send()
            .await
            .unwrap();
        assert_eq!(posted, "response");

        // The server's error status and message must survive the round trip.
        let err = client
            .post::<String>("mod/post")
            .body_json(&"bad".to_string())
            .unwrap()
            .send()
            .await
            .unwrap_err();
        assert!(
            err.status == StatusCode::BAD_REQUEST && err.message == "invalid body",
            "unexpected error {}",
            err
        );
    }

    #[async_std::test]
    async fn test_basic_http_client_json() {
        test_basic_http_client(ContentType::Json).await
    }

    #[async_std::test]
    async fn test_basic_http_client_binary() {
        test_basic_http_client(ContentType::Binary).await
    }

    /// Serve a bidirectional `echo` socket and a server-to-client `naturals` stream, then check
    /// both from a client that asks for `accept` messages.
    async fn test_streaming_client(accept: ContentType) {
        setup_logging();
        setup_backtrace();

        // Module `mod`: `echo` sends every message straight back; `naturals/:max` streams
        // `0..max` and then ends.
        let spec = toml! {
            [route.echo]
            PATH = ["/echo"]
            METHOD = "SOCKET"

            [route.naturals]
            PATH = ["/naturals/:max"]
            METHOD = "SOCKET"
            ":max" = "Integer"
        };
        let mut app: App<(), ServerError> = App::with_state(());
        app.module::<ServerError, Ver01>("mod", spec)
            .unwrap()
            .socket::<_, String, String>("echo", |_req, mut conn, _state| {
                async move {
                    while let Some(Ok(msg)) = conn.next().await {
                        conn.send(&msg).await.unwrap();
                    }
                    Ok(())
                }
                .boxed()
            })
            .unwrap()
            .stream("naturals", |req, _state| {
                iter(0u64..req.integer_param("max").unwrap())
                    .map(Ok)
                    .boxed()
            })
            .unwrap();
        let server_port = pick_unused_port().unwrap();
        spawn(app.serve(format!("0.0.0.0:{}", server_port), VER_0_1));

        let client: Client<ServerError, _> =
            Client::builder(format!("http://localhost:{}", server_port).parse().unwrap())
                .content_type(accept)
                .build();
        assert!(client.connect(None).await);

        // Bidirectional endpoint: each message should be echoed back, in order.
        let mut echo: Connection<_, _, _, Ver01> = client
            .socket("mod/echo")
            .connect::<String, String>()
            .await
            .unwrap();
        for word in ["foo", "bar"] {
            echo.send(&word.into()).await.unwrap();
            assert_eq!(echo.next().await.unwrap().unwrap(), word);
        }

        // Streaming endpoint: the stream should yield exactly 0 through 9 and then end.
        let naturals = client
            .socket("mod/naturals/10")
            .subscribe::<u64>()
            .await
            .unwrap()
            .collect::<Vec<_>>()
            .await;
        assert_eq!(naturals, (0..10).map(Ok).collect::<Vec<_>>());
    }

    #[async_std::test]
    async fn test_streaming_client_json() {
        test_streaming_client(ContentType::Json).await
    }

    #[async_std::test]
    async fn test_streaming_client_binary() {
        test_streaming_client(ContentType::Binary).await
    }

    // Health states reported by the app in `test_healthcheck`.
    #[derive(Clone, Copy, Debug, PartialEq, Eq, Deserialize, Serialize)]
    enum HealthCheck {
        Ready,
        Initializing,
    }

    impl super::HealthCheck for HealthCheck {
        // Both states are reported with `200 OK`; readiness is distinguished by the body alone.
        fn status(&self) -> StatusCode {
            StatusCode::OK
        }
    }

    #[async_std::test]
    async fn test_healthcheck() {
        setup_logging();
        setup_backtrace();

        // The app state is the module's current health, starting as `Initializing`; `POST init`
        // sets it to `Ready`.
        let spec = toml! {
            [route.init]
            PATH = ["/init"]
            METHOD = "POST"
        };
        let mut app: App<_, ServerError> = App::with_state(RwLock::new(HealthCheck::Initializing));
        app.module::<ServerError, Ver01>("mod", spec)
            .unwrap()
            .with_health_check(|state| async move { *state.read().await }.boxed())
            .post("init", |_, state| {
                async move {
                    *state = HealthCheck::Ready;
                    Ok(())
                }
                .boxed()
            })
            .unwrap();
        let server_port = pick_unused_port().unwrap();
        spawn(app.serve(format!("0.0.0.0:{}", server_port), VER_0_1));

        // This client is rooted at the module, so `healthcheck` hits the module's route.
        let client = Client::<ServerError, Ver01>::new(
            format!("http://localhost:{}/mod", server_port).parse().unwrap(),
        );
        assert!(client.connect(None).await);
        let initial = client.healthcheck::<HealthCheck>().await.unwrap();
        assert_eq!(HealthCheck::Initializing, initial);

        // Before `init`, waiting for `Ready` with a short timeout must give up.
        let early = client
            .wait_for_health::<HealthCheck>(
                |h| *h == HealthCheck::Ready,
                Some(Duration::from_secs(1)),
            )
            .await;
        assert_eq!(early, None);

        // Initialize the service; after that, waiting for `Ready` must succeed.
        client.post::<()>("init").send().await.unwrap();
        let ready = client
            .wait_for_health::<HealthCheck>(|h| *h == HealthCheck::Ready, None)
            .await;
        assert_eq!(ready, Some(HealthCheck::Ready));
        assert_eq!(
            HealthCheck::Ready,
            client.healthcheck::<HealthCheck>().await.unwrap()
        );
    }

    // Building a client keeps the parsed base URL as-is when it already has a directory path.
    #[test]
    fn test_builder() {
        let client =
            Client::<ServerError, Ver01>::builder("http://www.example.com".parse().unwrap())
                .set_timeout(None)
                .build();
        let expected: Url = "http://www.example.com".parse().unwrap();
        assert_eq!(client.base_url, expected);
    }
}