1use crate::{
8 http::headers::{HeaderName, ToHeaderValues},
9 ContentType, Error, StatusCode, Url,
10};
11use async_tungstenite::{
12 async_std::{connect_async, ConnectStream},
13 tungstenite::{http::request::Builder as RequestBuilder, Error as WsError, Message},
14 WebSocketStream,
15};
16use futures::{
17 task::{Context, Poll},
18 Sink, Stream,
19};
20use serde::{de::DeserializeOwned, Deserialize, Serialize};
21use std::{collections::HashMap, pin::Pin};
22use vbs::{version::StaticVersionType, BinarySerializer, Serializer};
23
24#[must_use]
25#[derive(Debug)]
26pub struct SocketRequest<E, VER: StaticVersionType> {
27 url: Url,
28 content_type: ContentType,
29 headers: HashMap<String, Vec<String>>,
30 marker: std::marker::PhantomData<fn(E, VER) -> ()>,
31}
32
33impl<E: Error, VER: StaticVersionType> SocketRequest<E, VER> {
34 pub(crate) fn new(mut url: Url, content_type: ContentType) -> Self {
35 url.set_scheme(&socket_scheme(url.scheme())).unwrap();
36 Self {
37 url,
38 content_type,
39 headers: Default::default(),
40 marker: Default::default(),
41 }
42 }
43
44 pub fn header(mut self, key: impl Into<HeaderName>, values: impl ToHeaderValues) -> Self {
46 let name = key.into().to_string();
47 for value in values.to_header_values().unwrap() {
48 self.headers
49 .entry(name.clone())
50 .or_default()
51 .push(value.to_string());
52 }
53 self
54 }
55
56 pub async fn connect<FromServer: DeserializeOwned, ToServer: Serialize + ?Sized>(
58 mut self,
59 ) -> Result<Connection<FromServer, ToServer, E, VER>, E> {
60 loop {
62 let mut req = RequestBuilder::new().uri(self.url.to_string());
63 for (key, values) in &self.headers {
64 for value in values {
65 req = req.header(key, value);
66 }
67 }
68 let req = req
69 .body(())
70 .map_err(|err| E::catch_all(StatusCode::BAD_REQUEST, err.to_string()))?;
71
72 let err = match connect_async(req).await {
73 Ok((conn, _)) => return Ok(Connection::new(conn, self.content_type)),
74 Err(err) => err,
75 };
76 if let WsError::Http(res) = &err {
77 if (301..=308).contains(&u16::from(res.status())) {
78 if let Some(location) = res
79 .headers()
80 .get("location")
81 .and_then(|header| header.to_str().ok())
82 {
83 tracing::info!(from = %self.url, to = %location, "WS handshake following redirect");
84 self.url.set_path(location);
85 continue;
86 }
87 }
88 }
89 return Err(E::catch_all(StatusCode::BAD_REQUEST, err.to_string()));
90 }
91 }
92
93 pub async fn subscribe<FromServer: DeserializeOwned>(
98 self,
99 ) -> Result<Connection<FromServer, Unsupported, E, VER>, E> {
100 self.connect().await
101 }
102}
103
104pub struct Connection<FromServer, ToServer: ?Sized, E, VER: StaticVersionType> {
106 inner: WebSocketStream<ConnectStream>,
107 content_type: ContentType,
108 #[allow(clippy::type_complexity)]
109 marker: std::marker::PhantomData<fn(FromServer, ToServer, E, VER) -> ()>,
110}
111
112impl<FromServer, ToServer: ?Sized, E, VER: StaticVersionType>
113 Connection<FromServer, ToServer, E, VER>
114{
115 fn new(inner: WebSocketStream<ConnectStream>, content_type: ContentType) -> Self {
116 Self {
117 inner,
118 content_type,
119 marker: Default::default(),
120 }
121 }
122}
123
124impl<FromServer: DeserializeOwned, ToServer: ?Sized, E: Error, VER: StaticVersionType> Stream
125 for Connection<FromServer, ToServer, E, VER>
126{
127 type Item = Result<FromServer, E>;
128
129 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
130 match self.pinned_inner().poll_next(cx) {
133 Poll::Ready(None) => Poll::Ready(None),
134 Poll::Ready(Some(Err(err))) => match err {
135 WsError::ConnectionClosed | WsError::AlreadyClosed => Poll::Ready(None),
136 err => Poll::Ready(Some(Err(E::catch_all(
137 StatusCode::INTERNAL_SERVER_ERROR,
138 err.to_string(),
139 )))),
140 },
141 Poll::Ready(Some(Ok(msg))) => Poll::Ready(match msg {
142 Message::Binary(bytes) => {
143 Some(Serializer::<VER>::deserialize(&bytes).map_err(|err| {
144 E::catch_all(
145 StatusCode::INTERNAL_SERVER_ERROR,
146 format!("invalid binary: {}\n{bytes:?}", err),
147 )
148 }))
149 }
150 Message::Text(s) => Some(serde_json::from_str(&s).map_err(|err| {
151 E::catch_all(
152 StatusCode::INTERNAL_SERVER_ERROR,
153 format!("invalid JSON: {}\n{s}", err),
154 )
155 })),
156 Message::Close(_) => None,
157 _ => Some(Err(E::catch_all(
158 StatusCode::UNSUPPORTED_MEDIA_TYPE,
159 "unsupported WebSocket message".into(),
160 ))),
161 }),
162 Poll::Pending => Poll::Pending,
163 }
164 }
165}
166
167impl<FromServer, ToServer: Serialize + ?Sized, E: Error, VER: StaticVersionType> Sink<&ToServer>
168 for Connection<FromServer, ToServer, E, VER>
169{
170 type Error = E;
171
172 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
173 self.pinned_inner().poll_ready(cx).map_err(|err| {
174 E::catch_all(
175 StatusCode::INTERNAL_SERVER_ERROR,
176 format!("error in WebSocket connection: {}", err),
177 )
178 })
179 }
180
181 fn start_send(self: Pin<&mut Self>, item: &ToServer) -> Result<(), Self::Error> {
182 let msg = match self.content_type {
183 ContentType::Binary => {
184 Message::Binary(Serializer::<VER>::serialize(item).map_err(|err| {
185 E::catch_all(
186 StatusCode::BAD_REQUEST,
187 format!("invalid binary serialization: {}", err),
188 )
189 })?)
190 }
191 ContentType::Json => Message::Text(serde_json::to_string(item).map_err(|err| {
192 E::catch_all(
193 StatusCode::BAD_REQUEST,
194 format!("invalid JSON serialization: {}", err),
195 )
196 })?),
197 };
198 self.pinned_inner().start_send(msg).map_err(|err| {
199 E::catch_all(
200 StatusCode::INTERNAL_SERVER_ERROR,
201 format!("error sending WebSocket message: {}", err),
202 )
203 })
204 }
205
206 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
207 self.pinned_inner().poll_flush(cx).map_err(|err| {
208 E::catch_all(
209 StatusCode::INTERNAL_SERVER_ERROR,
210 format!("error in WebSocket connection: {}", err),
211 )
212 })
213 }
214
215 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
216 self.pinned_inner().poll_close(cx).map_err(|err| {
217 E::catch_all(
218 StatusCode::INTERNAL_SERVER_ERROR,
219 format!("error in WebSocket connection: {}", err),
220 )
221 })
222 }
223}
224
225impl<FromServer, ToServer: ?Sized, E, VER: StaticVersionType>
226 Connection<FromServer, ToServer, E, VER>
227{
228 fn pinned_inner(self: Pin<&mut Self>) -> Pin<&mut WebSocketStream<ConnectStream>> {
230 unsafe { self.map_unchecked_mut(|s| &mut s.inner) }
248 }
249}
250
251impl<FromServer, ToServer: ?Sized, E, VER: StaticVersionType> Drop
252 for Connection<FromServer, ToServer, E, VER>
253{
254 fn drop(&mut self) {
255 inner_drop(unsafe { Pin::new_unchecked(self) });
269 fn inner_drop<FromServer, ToServer: ?Sized, E, VER: StaticVersionType>(
270 _this: Pin<&mut Connection<FromServer, ToServer, E, VER>>,
271 ) {
272 }
274 }
275}
276
277#[derive(Clone, Copy, Debug, Serialize, Deserialize)]
279pub enum Unsupported {}
280
/// Map an HTTP scheme to its WebSocket counterpart.
///
/// `"http"` becomes `"ws"` and `"https"` becomes `"wss"`; any other scheme is returned unchanged.
fn socket_scheme(scheme: &str) -> String {
    let mapped = if scheme == "http" {
        "ws"
    } else if scheme == "https" {
        "wss"
    } else {
        scheme
    };
    String::from(mapped)
}
294
#[cfg(test)]
mod test {
    use super::*;
    use crate::{Client, ContentType};
    use async_compatibility_layer::logging::{setup_backtrace, setup_logging};
    use async_std::task::spawn;
    use futures::stream::{repeat, StreamExt};
    use portpicker::pick_unused_port;
    use tide_disco::{error::ServerError, App};
    use toml::toml;
    use vbs::version::StaticVersion;

    type Ver01 = StaticVersion<0, 1>;
    const VER_0_1: Ver01 = StaticVersion {};

    /// JSON and binary clients can both subscribe to a socket route, and each receives
    /// messages in the encoding it asked for.
    #[async_std::test]
    async fn test_socket_accept() {
        setup_logging();
        setup_backtrace();

        // Serve a socket route that streams the string "response" forever.
        let mut app: App<(), ServerError> = App::with_state(());
        let api = toml! {
            [route.subscribe]
            PATH = ["/subscribe"]
            METHOD = "SOCKET"
        };
        app.module::<ServerError, Ver01>("mod", api)
            .unwrap()
            .stream("subscribe", |_req, _state| repeat("response").map(Ok).boxed())
            .unwrap();
        let port = pick_unused_port().unwrap();
        spawn(app.serve(format!("0.0.0.0:{port}"), VER_0_1));

        // Clients differ only in the content type they request.
        let client_with = |content_type: ContentType| {
            Client::<ServerError, Ver01>::builder(
                format!("http://localhost:{port}").parse().unwrap(),
            )
            .content_type(content_type)
            .build()
        };

        let json_client = client_with(ContentType::Json);
        assert!(json_client.connect(None).await);

        let bin_client = client_with(ContentType::Binary);
        assert!(bin_client.connect(None).await);

        // The JSON client must receive text frames holding JSON.
        let mut json_conn = json_client
            .socket("mod/subscribe")
            .subscribe::<String>()
            .await
            .unwrap();
        let Message::Text(text) = json_conn.inner.next().await.unwrap().unwrap() else {
            panic!("unexpected content type");
        };
        assert_eq!(serde_json::from_str::<String>(&text).unwrap(), "response");

        // The binary client must receive binary frames in the versioned binary format.
        let mut bin_conn = bin_client
            .socket("mod/subscribe")
            .subscribe::<String>()
            .await
            .unwrap();
        let Message::Binary(payload) = bin_conn.inner.next().await.unwrap().unwrap() else {
            panic!("unexpected content type");
        };
        assert_eq!(
            Serializer::<Ver01>::deserialize::<String>(&payload).unwrap(),
            "response"
        );
    }
}