surf_disco/
socket.rs

1// Copyright (c) 2022 Espresso Systems (espressosys.com)
2// This file is part of the surf-disco library.
3
4// You should have received a copy of the MIT License
5// along with the surf-disco library. If not, see <https://mit-license.org/>.
6
7use crate::{
8    http::headers::{HeaderName, ToHeaderValues},
9    ContentType, Error, StatusCode, Url,
10};
11use async_tungstenite::{
12    async_std::{connect_async_with_config, ConnectStream},
13    tungstenite::{
14        http::request::Builder as RequestBuilder, protocol::WebSocketConfig, Error as WsError,
15        Message,
16    },
17    WebSocketStream,
18};
19use futures::{
20    task::{Context, Poll},
21    Sink, Stream,
22};
23use serde::{de::DeserializeOwned, Deserialize, Serialize};
24use std::{collections::HashMap, pin::Pin};
25use vbs::{version::StaticVersionType, BinarySerializer, Serializer};
26
27#[must_use]
28#[derive(Debug)]
29pub struct SocketRequest<E, VER: StaticVersionType> {
30    url: Url,
31    content_type: ContentType,
32    headers: HashMap<String, Vec<String>>,
33    config: Option<WebSocketConfig>,
34    marker: std::marker::PhantomData<fn(E, VER) -> ()>,
35}
36
37impl<E: Error, VER: StaticVersionType> SocketRequest<E, VER> {
38    pub(crate) fn new(
39        mut url: Url,
40        content_type: ContentType,
41        config: Option<WebSocketConfig>,
42    ) -> Self {
43        url.set_scheme(&socket_scheme(url.scheme())).unwrap();
44        Self {
45            url,
46            content_type,
47            headers: Default::default(),
48            config,
49            marker: Default::default(),
50        }
51    }
52
53    /// Set a header on the request.
54    pub fn header(mut self, key: impl Into<HeaderName>, values: impl ToHeaderValues) -> Self {
55        let name = key.into().to_string();
56        for value in values.to_header_values().unwrap() {
57            self.headers
58                .entry(name.clone())
59                .or_default()
60                .push(value.to_string());
61        }
62        self
63    }
64
65    /// Start the WebSocket handshake and initiate a connection to the server.
66    pub async fn connect<FromServer: DeserializeOwned, ToServer: Serialize + ?Sized>(
67        mut self,
68    ) -> Result<Connection<FromServer, ToServer, E, VER>, E> {
69        // Follow redirects.
70        loop {
71            let mut req = RequestBuilder::new().uri(self.url.to_string());
72            for (key, values) in &self.headers {
73                for value in values {
74                    req = req.header(key, value);
75                }
76            }
77            let req = req
78                .body(())
79                .map_err(|err| E::catch_all(StatusCode::BAD_REQUEST, err.to_string()))?;
80
81            let err = match connect_async_with_config(req, self.config).await {
82                Ok((conn, _)) => return Ok(Connection::new(conn, self.content_type)),
83                Err(err) => err,
84            };
85            if let WsError::Http(res) = &err {
86                if (301..=308).contains(&u16::from(res.status())) {
87                    if let Some(location) = res
88                        .headers()
89                        .get("location")
90                        .and_then(|header| header.to_str().ok())
91                    {
92                        tracing::info!(from = %self.url, to = %location, "WS handshake following redirect");
93                        self.url.set_path(location);
94                        continue;
95                    }
96                }
97            }
98            return Err(E::catch_all(StatusCode::BAD_REQUEST, err.to_string()));
99        }
100    }
101
102    /// Initiate a unidirectional connection to the server.
103    ///
104    /// This is equivalent to `self.connect()` with the `ToServer` message type replaced by
105    /// [Unsupported], so that you don't have to specify the type parameter if it isn't used.
106    pub async fn subscribe<FromServer: DeserializeOwned>(
107        self,
108    ) -> Result<Connection<FromServer, Unsupported, E, VER>, E> {
109        self.connect().await
110    }
111}
112
113/// A bi-directional connection to a WebSocket server.
114pub struct Connection<FromServer, ToServer: ?Sized, E, VER: StaticVersionType> {
115    inner: WebSocketStream<ConnectStream>,
116    content_type: ContentType,
117    #[allow(clippy::type_complexity)]
118    marker: std::marker::PhantomData<fn(FromServer, ToServer, E, VER) -> ()>,
119}
120
121impl<FromServer, ToServer: ?Sized, E, VER: StaticVersionType>
122    Connection<FromServer, ToServer, E, VER>
123{
124    fn new(inner: WebSocketStream<ConnectStream>, content_type: ContentType) -> Self {
125        Self {
126            inner,
127            content_type,
128            marker: Default::default(),
129        }
130    }
131}
132
133impl<FromServer: DeserializeOwned, ToServer: ?Sized, E: Error, VER: StaticVersionType> Stream
134    for Connection<FromServer, ToServer, E, VER>
135{
136    type Item = Result<FromServer, E>;
137
138    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
139        // Get a `Pin<&mut WebSocketStream>` for the underlying connection, so we can use the
140        // `Stream` implementation of that field.
141        match self.pinned_inner().poll_next(cx) {
142            Poll::Ready(None) => Poll::Ready(None),
143            Poll::Ready(Some(Err(err))) => match err {
144                WsError::ConnectionClosed | WsError::AlreadyClosed => Poll::Ready(None),
145                err => Poll::Ready(Some(Err(E::catch_all(
146                    StatusCode::INTERNAL_SERVER_ERROR,
147                    err.to_string(),
148                )))),
149            },
150            Poll::Ready(Some(Ok(msg))) => Poll::Ready(match msg {
151                Message::Binary(bytes) => {
152                    Some(Serializer::<VER>::deserialize(&bytes).map_err(|err| {
153                        E::catch_all(
154                            StatusCode::INTERNAL_SERVER_ERROR,
155                            format!("invalid binary: {}\n{bytes:?}", err),
156                        )
157                    }))
158                }
159                Message::Text(s) => Some(serde_json::from_str(&s).map_err(|err| {
160                    E::catch_all(
161                        StatusCode::INTERNAL_SERVER_ERROR,
162                        format!("invalid JSON: {}\n{s}", err),
163                    )
164                })),
165                Message::Close(_) => None,
166                _ => Some(Err(E::catch_all(
167                    StatusCode::UNSUPPORTED_MEDIA_TYPE,
168                    "unsupported WebSocket message".into(),
169                ))),
170            }),
171            Poll::Pending => Poll::Pending,
172        }
173    }
174}
175
176impl<FromServer, ToServer: Serialize + ?Sized, E: Error, VER: StaticVersionType> Sink<&ToServer>
177    for Connection<FromServer, ToServer, E, VER>
178{
179    type Error = E;
180
181    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
182        self.pinned_inner().poll_ready(cx).map_err(|err| {
183            E::catch_all(
184                StatusCode::INTERNAL_SERVER_ERROR,
185                format!("error in WebSocket connection: {}", err),
186            )
187        })
188    }
189
190    fn start_send(self: Pin<&mut Self>, item: &ToServer) -> Result<(), Self::Error> {
191        let msg = match self.content_type {
192            ContentType::Binary => {
193                Message::Binary(Serializer::<VER>::serialize(item).map_err(|err| {
194                    E::catch_all(
195                        StatusCode::BAD_REQUEST,
196                        format!("invalid binary serialization: {}", err),
197                    )
198                })?)
199            }
200            ContentType::Json => Message::Text(serde_json::to_string(item).map_err(|err| {
201                E::catch_all(
202                    StatusCode::BAD_REQUEST,
203                    format!("invalid JSON serialization: {}", err),
204                )
205            })?),
206        };
207        self.pinned_inner().start_send(msg).map_err(|err| {
208            E::catch_all(
209                StatusCode::INTERNAL_SERVER_ERROR,
210                format!("error sending WebSocket message: {}", err),
211            )
212        })
213    }
214
215    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
216        self.pinned_inner().poll_flush(cx).map_err(|err| {
217            E::catch_all(
218                StatusCode::INTERNAL_SERVER_ERROR,
219                format!("error in WebSocket connection: {}", err),
220            )
221        })
222    }
223
224    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
225        self.pinned_inner().poll_close(cx).map_err(|err| {
226            E::catch_all(
227                StatusCode::INTERNAL_SERVER_ERROR,
228                format!("error in WebSocket connection: {}", err),
229            )
230        })
231    }
232}
233
234impl<FromServer, ToServer: ?Sized, E, VER: StaticVersionType>
235    Connection<FromServer, ToServer, E, VER>
236{
237    /// Project a `Pin<&mut Self>` to a pinned reference to the underlying connection.
238    fn pinned_inner(self: Pin<&mut Self>) -> Pin<&mut WebSocketStream<ConnectStream>> {
239        // # Soundness
240        //
241        // This implements _structural pinning_ for [Connection]. This comes with some requirements
242        // to maintain safety, as described at
243        // https://doc.rust-lang.org/std/pin/index.html#pinning-is-structural-for-field:
244        //
245        // 1. The struct must only be [Unpin] if all the structural fields are [Unpin]. This is the
246        //    default, and we don't explicitly implement [Unpin] for [Connection].
247        // 2. The destructor of the struct must not move structural fields out of its argument. This
248        //    is enforced by the compiler in our [Drop] implementation, which follows the idiom for
249        //    safe [Drop] implementations for pinned structs.
250        // 3. You must make sure that you uphold the [Drop] guarantee: once your struct is pinned,
251        //    the memory that contains the content is not overwritten or deallocated without calling
252        //    the content’s destructors. This is also enforced by our [Drop] implementation.
253        // 4. You must not offer any other operations that could lead to data being moved out of the
254        //    structural fields when your type is pinned. There are no operations on this type that
255        //    move out of `inner`.
256        unsafe { self.map_unchecked_mut(|s| &mut s.inner) }
257    }
258}
259
260impl<FromServer, ToServer: ?Sized, E, VER: StaticVersionType> Drop
261    for Connection<FromServer, ToServer, E, VER>
262{
263    fn drop(&mut self) {
264        // This is the idiomatic way to implement [drop] for a type that uses pinning. Since [drop]
265        // is implicitly called with `&mut self` even on types that were pinned, we place any
266        // implementation inside [inner_drop], which takes `Pin<&mut Self>`, when the commpiler will
267        // be able to check that we do not do anything that we couldn't have done on a
268        // `Pin<&mut Self>`.
269        //
270        // The [drop] implementation for this type is trivial, and it would be safe to use the
271        // automatically generated [drop] implementation, but we nonetheless implement [drop]
272        // explicitly in the idiomatic fashion so that it is impossible to accidentally implement an
273        // unsafe version of [drop] for this type in the future.
274
275        // `new_unchecked` is okay because we know this value is never used again after being
276        // dropped.
277        inner_drop(unsafe { Pin::new_unchecked(self) });
278        fn inner_drop<FromServer, ToServer: ?Sized, E, VER: StaticVersionType>(
279            _this: Pin<&mut Connection<FromServer, ToServer, E, VER>>,
280        ) {
281            // Any logic goes here.
282        }
283    }
284}
285
286/// Unconstructable enum used to disable the [Sink] functionality of [Connection].
287#[derive(Clone, Copy, Debug, Serialize, Deserialize)]
288pub enum Unsupported {}
289
/// Map the scheme of a stateless connection to the scheme of the WebSocket connection
/// upgraded from it.
///
/// `http` becomes `ws` and `https` becomes `wss`. Any other scheme has no known WebSocket
/// counterpart, so it is returned unchanged on the assumption that the caller knows what they
/// are doing.
fn socket_scheme(scheme: &str) -> String {
    let upgraded = if scheme == "https" {
        "wss"
    } else if scheme == "http" {
        "ws"
    } else {
        scheme
    };
    String::from(upgraded)
}
303
#[cfg(test)]
mod test {
    use super::*;
    use crate::{Client, ContentType};
    use async_compatibility_layer::logging::{setup_backtrace, setup_logging};
    use async_std::task::spawn;
    use futures::stream::{repeat, StreamExt};
    use portpicker::pick_unused_port;
    use tide_disco::{error::ServerError, App};
    use toml::toml;
    use vbs::version::StaticVersion;

    type Ver01 = StaticVersion<0, 1>;
    const VER_0_1: Ver01 = StaticVersion {};

    #[test]
    fn test_socket_scheme() {
        // Stateless schemes with a known WebSocket counterpart are upgraded.
        assert_eq!(socket_scheme("http"), "ws");
        assert_eq!(socket_scheme("https"), "wss");
        // Anything else, including schemes that already are WebSocket schemes, passes through.
        assert_eq!(socket_scheme("ws"), "ws");
        assert_eq!(socket_scheme("wss"), "wss");
        assert_eq!(socket_scheme("custom"), "custom");
    }

    #[async_std::test]
    async fn test_socket_accept() {
        setup_logging();
        setup_backtrace();

        // Set up a simple Tide Disco app.
        let mut app: App<(), ServerError> = App::with_state(());
        let api = toml! {
            [route.subscribe]
            PATH = ["/subscribe"]
            METHOD = "SOCKET"
        };
        app.module::<ServerError, Ver01>("mod", api)
            .unwrap()
            .stream("subscribe", |_req, _state| {
                repeat("response").map(Ok).boxed()
            })
            .unwrap();
        let port = pick_unused_port().unwrap();
        spawn(app.serve(format!("0.0.0.0:{port}"), VER_0_1));
        let base_url = format!("http://localhost:{port}");

        // Connect one client with each supported content type.
        let json_client = Client::<ServerError, Ver01>::builder(base_url.parse().unwrap())
            .content_type(ContentType::Json)
            .build();
        assert!(json_client.connect(None).await);

        let bin_client = Client::<ServerError, Ver01>::builder(base_url.parse().unwrap())
            .content_type(ContentType::Binary)
            .build();
        assert!(bin_client.connect(None).await);

        // Check that connections built with each client get messages in the desired content type.
        let mut conn = json_client
            .socket("mod/subscribe")
            .subscribe::<String>()
            .await
            .unwrap();
        let Message::Text(msg) = conn.inner.next().await.unwrap().unwrap() else {
            panic!("unexpected content type");
        };
        assert_eq!(serde_json::from_str::<String>(&msg).unwrap(), "response");

        let mut conn = bin_client
            .socket("mod/subscribe")
            .subscribe::<String>()
            .await
            .unwrap();
        let Message::Binary(msg) = conn.inner.next().await.unwrap().unwrap() else {
            panic!("unexpected content type");
        };
        assert_eq!(
            Serializer::<Ver01>::deserialize::<String>(&msg).unwrap(),
            "response"
        );
    }
}