surf_disco/
socket.rs

1// Copyright (c) 2022 Espresso Systems (espressosys.com)
2// This file is part of the surf-disco library.
3
4// You should have received a copy of the MIT License
5// along with the surf-disco library. If not, see <https://mit-license.org/>.
6
7use crate::{
8    http::headers::{HeaderName, ToHeaderValues},
9    ContentType, Error, StatusCode, Url,
10};
11use async_tungstenite::{
12    async_std::{connect_async, ConnectStream},
13    tungstenite::{http::request::Builder as RequestBuilder, Error as WsError, Message},
14    WebSocketStream,
15};
16use futures::{
17    task::{Context, Poll},
18    Sink, Stream,
19};
20use serde::{de::DeserializeOwned, Deserialize, Serialize};
21use std::{collections::HashMap, pin::Pin};
22use vbs::{version::StaticVersionType, BinarySerializer, Serializer};
23
24#[must_use]
25#[derive(Debug)]
26pub struct SocketRequest<E, VER: StaticVersionType> {
27    url: Url,
28    content_type: ContentType,
29    headers: HashMap<String, Vec<String>>,
30    marker: std::marker::PhantomData<fn(E, VER) -> ()>,
31}
32
33impl<E: Error, VER: StaticVersionType> SocketRequest<E, VER> {
34    pub(crate) fn new(mut url: Url, content_type: ContentType) -> Self {
35        url.set_scheme(&socket_scheme(url.scheme())).unwrap();
36        Self {
37            url,
38            content_type,
39            headers: Default::default(),
40            marker: Default::default(),
41        }
42    }
43
44    /// Set a header on the request.
45    pub fn header(mut self, key: impl Into<HeaderName>, values: impl ToHeaderValues) -> Self {
46        let name = key.into().to_string();
47        for value in values.to_header_values().unwrap() {
48            self.headers
49                .entry(name.clone())
50                .or_default()
51                .push(value.to_string());
52        }
53        self
54    }
55
56    /// Start the WebSocket handshake and initiate a connection to the server.
57    pub async fn connect<FromServer: DeserializeOwned, ToServer: Serialize + ?Sized>(
58        mut self,
59    ) -> Result<Connection<FromServer, ToServer, E, VER>, E> {
60        // Follow redirects.
61        loop {
62            let mut req = RequestBuilder::new().uri(self.url.to_string());
63            for (key, values) in &self.headers {
64                for value in values {
65                    req = req.header(key, value);
66                }
67            }
68            let req = req
69                .body(())
70                .map_err(|err| E::catch_all(StatusCode::BAD_REQUEST, err.to_string()))?;
71
72            let err = match connect_async(req).await {
73                Ok((conn, _)) => return Ok(Connection::new(conn, self.content_type)),
74                Err(err) => err,
75            };
76            if let WsError::Http(res) = &err {
77                if (301..=308).contains(&u16::from(res.status())) {
78                    if let Some(location) = res
79                        .headers()
80                        .get("location")
81                        .and_then(|header| header.to_str().ok())
82                    {
83                        tracing::info!(from = %self.url, to = %location, "WS handshake following redirect");
84                        self.url.set_path(location);
85                        continue;
86                    }
87                }
88            }
89            return Err(E::catch_all(StatusCode::BAD_REQUEST, err.to_string()));
90        }
91    }
92
93    /// Initiate a unidirectional connection to the server.
94    ///
95    /// This is equivalent to `self.connect()` with the `ToServer` message type replaced by
96    /// [Unsupported], so that you don't have to specify the type parameter if it isn't used.
97    pub async fn subscribe<FromServer: DeserializeOwned>(
98        self,
99    ) -> Result<Connection<FromServer, Unsupported, E, VER>, E> {
100        self.connect().await
101    }
102}
103
104/// A bi-directional connection to a WebSocket server.
105pub struct Connection<FromServer, ToServer: ?Sized, E, VER: StaticVersionType> {
106    inner: WebSocketStream<ConnectStream>,
107    content_type: ContentType,
108    #[allow(clippy::type_complexity)]
109    marker: std::marker::PhantomData<fn(FromServer, ToServer, E, VER) -> ()>,
110}
111
112impl<FromServer, ToServer: ?Sized, E, VER: StaticVersionType>
113    Connection<FromServer, ToServer, E, VER>
114{
115    fn new(inner: WebSocketStream<ConnectStream>, content_type: ContentType) -> Self {
116        Self {
117            inner,
118            content_type,
119            marker: Default::default(),
120        }
121    }
122}
123
124impl<FromServer: DeserializeOwned, ToServer: ?Sized, E: Error, VER: StaticVersionType> Stream
125    for Connection<FromServer, ToServer, E, VER>
126{
127    type Item = Result<FromServer, E>;
128
129    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
130        // Get a `Pin<&mut WebSocketStream>` for the underlying connection, so we can use the
131        // `Stream` implementation of that field.
132        match self.pinned_inner().poll_next(cx) {
133            Poll::Ready(None) => Poll::Ready(None),
134            Poll::Ready(Some(Err(err))) => match err {
135                WsError::ConnectionClosed | WsError::AlreadyClosed => Poll::Ready(None),
136                err => Poll::Ready(Some(Err(E::catch_all(
137                    StatusCode::INTERNAL_SERVER_ERROR,
138                    err.to_string(),
139                )))),
140            },
141            Poll::Ready(Some(Ok(msg))) => Poll::Ready(match msg {
142                Message::Binary(bytes) => {
143                    Some(Serializer::<VER>::deserialize(&bytes).map_err(|err| {
144                        E::catch_all(
145                            StatusCode::INTERNAL_SERVER_ERROR,
146                            format!("invalid binary: {}\n{bytes:?}", err),
147                        )
148                    }))
149                }
150                Message::Text(s) => Some(serde_json::from_str(&s).map_err(|err| {
151                    E::catch_all(
152                        StatusCode::INTERNAL_SERVER_ERROR,
153                        format!("invalid JSON: {}\n{s}", err),
154                    )
155                })),
156                Message::Close(_) => None,
157                _ => Some(Err(E::catch_all(
158                    StatusCode::UNSUPPORTED_MEDIA_TYPE,
159                    "unsupported WebSocket message".into(),
160                ))),
161            }),
162            Poll::Pending => Poll::Pending,
163        }
164    }
165}
166
167impl<FromServer, ToServer: Serialize + ?Sized, E: Error, VER: StaticVersionType> Sink<&ToServer>
168    for Connection<FromServer, ToServer, E, VER>
169{
170    type Error = E;
171
172    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
173        self.pinned_inner().poll_ready(cx).map_err(|err| {
174            E::catch_all(
175                StatusCode::INTERNAL_SERVER_ERROR,
176                format!("error in WebSocket connection: {}", err),
177            )
178        })
179    }
180
181    fn start_send(self: Pin<&mut Self>, item: &ToServer) -> Result<(), Self::Error> {
182        let msg = match self.content_type {
183            ContentType::Binary => {
184                Message::Binary(Serializer::<VER>::serialize(item).map_err(|err| {
185                    E::catch_all(
186                        StatusCode::BAD_REQUEST,
187                        format!("invalid binary serialization: {}", err),
188                    )
189                })?)
190            }
191            ContentType::Json => Message::Text(serde_json::to_string(item).map_err(|err| {
192                E::catch_all(
193                    StatusCode::BAD_REQUEST,
194                    format!("invalid JSON serialization: {}", err),
195                )
196            })?),
197        };
198        self.pinned_inner().start_send(msg).map_err(|err| {
199            E::catch_all(
200                StatusCode::INTERNAL_SERVER_ERROR,
201                format!("error sending WebSocket message: {}", err),
202            )
203        })
204    }
205
206    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
207        self.pinned_inner().poll_flush(cx).map_err(|err| {
208            E::catch_all(
209                StatusCode::INTERNAL_SERVER_ERROR,
210                format!("error in WebSocket connection: {}", err),
211            )
212        })
213    }
214
215    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
216        self.pinned_inner().poll_close(cx).map_err(|err| {
217            E::catch_all(
218                StatusCode::INTERNAL_SERVER_ERROR,
219                format!("error in WebSocket connection: {}", err),
220            )
221        })
222    }
223}
224
225impl<FromServer, ToServer: ?Sized, E, VER: StaticVersionType>
226    Connection<FromServer, ToServer, E, VER>
227{
228    /// Project a `Pin<&mut Self>` to a pinned reference to the underlying connection.
229    fn pinned_inner(self: Pin<&mut Self>) -> Pin<&mut WebSocketStream<ConnectStream>> {
230        // # Soundness
231        //
232        // This implements _structural pinning_ for [Connection]. This comes with some requirements
233        // to maintain safety, as described at
234        // https://doc.rust-lang.org/std/pin/index.html#pinning-is-structural-for-field:
235        //
236        // 1. The struct must only be [Unpin] if all the structural fields are [Unpin]. This is the
237        //    default, and we don't explicitly implement [Unpin] for [Connection].
238        // 2. The destructor of the struct must not move structural fields out of its argument. This
239        //    is enforced by the compiler in our [Drop] implementation, which follows the idiom for
240        //    safe [Drop] implementations for pinned structs.
241        // 3. You must make sure that you uphold the [Drop] guarantee: once your struct is pinned,
242        //    the memory that contains the content is not overwritten or deallocated without calling
243        //    the content’s destructors. This is also enforced by our [Drop] implementation.
244        // 4. You must not offer any other operations that could lead to data being moved out of the
245        //    structural fields when your type is pinned. There are no operations on this type that
246        //    move out of `inner`.
247        unsafe { self.map_unchecked_mut(|s| &mut s.inner) }
248    }
249}
250
251impl<FromServer, ToServer: ?Sized, E, VER: StaticVersionType> Drop
252    for Connection<FromServer, ToServer, E, VER>
253{
254    fn drop(&mut self) {
255        // This is the idiomatic way to implement [drop] for a type that uses pinning. Since [drop]
256        // is implicitly called with `&mut self` even on types that were pinned, we place any
257        // implementation inside [inner_drop], which takes `Pin<&mut Self>`, when the commpiler will
258        // be able to check that we do not do anything that we couldn't have done on a
259        // `Pin<&mut Self>`.
260        //
261        // The [drop] implementation for this type is trivial, and it would be safe to use the
262        // automatically generated [drop] implementation, but we nonetheless implement [drop]
263        // explicitly in the idiomatic fashion so that it is impossible to accidentally implement an
264        // unsafe version of [drop] for this type in the future.
265
266        // `new_unchecked` is okay because we know this value is never used again after being
267        // dropped.
268        inner_drop(unsafe { Pin::new_unchecked(self) });
269        fn inner_drop<FromServer, ToServer: ?Sized, E, VER: StaticVersionType>(
270            _this: Pin<&mut Connection<FromServer, ToServer, E, VER>>,
271        ) {
272            // Any logic goes here.
273        }
274    }
275}
276
277/// Unconstructable enum used to disable the [Sink] functionality of [Connection].
278#[derive(Clone, Copy, Debug, Serialize, Deserialize)]
279pub enum Unsupported {}
280
/// Map the scheme of a stateless connection to the scheme of the matching WebSocket connection.
///
/// `http` becomes `ws` and `https` becomes `wss`. Any other scheme has no known WebSocket
/// counterpart and is returned unchanged, trusting that the caller knows what they're doing.
fn socket_scheme(scheme: &str) -> String {
    let upgraded = if scheme == "http" {
        "ws"
    } else if scheme == "https" {
        "wss"
    } else {
        scheme
    };
    String::from(upgraded)
}
294
#[cfg(test)]
mod test {
    use super::*;
    use crate::{Client, ContentType};
    use async_compatibility_layer::logging::{setup_backtrace, setup_logging};
    use async_std::task::spawn;
    use futures::stream::{repeat, StreamExt};
    use portpicker::pick_unused_port;
    use tide_disco::{error::ServerError, App};
    use toml::toml;
    use vbs::version::StaticVersion;

    type Ver01 = StaticVersion<0, 1>;
    const VER_0_1: Ver01 = StaticVersion {};

    /// Build a client for the local server on `port` that uses `content_type`, and assert that
    /// it manages to connect.
    async fn connected_client(port: u16, content_type: ContentType) -> Client<ServerError, Ver01> {
        let client = Client::<ServerError, Ver01>::builder(
            format!("http://localhost:{port}").parse().unwrap(),
        )
        .content_type(content_type)
        .build();
        assert!(client.connect(None).await);
        client
    }

    /// Sockets opened by a client receive messages encoded in that client's content type.
    #[async_std::test]
    async fn test_socket_accept() {
        setup_logging();
        setup_backtrace();

        // Serve a single socket route that streams the string "response" indefinitely.
        let mut app: App<(), ServerError> = App::with_state(());
        let api = toml! {
            [route.subscribe]
            PATH = ["/subscribe"]
            METHOD = "SOCKET"
        };
        app.module::<ServerError, Ver01>("mod", api)
            .unwrap()
            .stream("subscribe", |_req, _state| repeat("response").map(Ok).boxed())
            .unwrap();
        let port = pick_unused_port().unwrap();
        spawn(app.serve(format!("0.0.0.0:{port}"), VER_0_1));

        // One client per supported content type.
        let json_client = connected_client(port, ContentType::Json).await;
        let bin_client = connected_client(port, ContentType::Binary).await;

        // A JSON client receives text frames containing JSON.
        let mut conn = json_client
            .socket("mod/subscribe")
            .subscribe::<String>()
            .await
            .unwrap();
        let Message::Text(msg) = conn.inner.next().await.unwrap().unwrap() else {
            panic!("unexpected content type");
        };
        assert_eq!(serde_json::from_str::<String>(&msg).unwrap(), "response");

        // A binary client receives binary frames encoded with the versioned serializer.
        let mut conn = bin_client
            .socket("mod/subscribe")
            .subscribe::<String>()
            .await
            .unwrap();
        let Message::Binary(msg) = conn.inner.next().await.unwrap().unwrap() else {
            panic!("unexpected content type");
        };
        assert_eq!(
            Serializer::<Ver01>::deserialize::<String>(&msg).unwrap(),
            "response"
        );
    }
}
370}