1use crate::{
8 http::headers::{HeaderName, ToHeaderValues},
9 ContentType, Error, StatusCode, Url,
10};
11use async_tungstenite::{
12 async_std::{connect_async_with_config, ConnectStream},
13 tungstenite::{
14 http::request::Builder as RequestBuilder, protocol::WebSocketConfig, Error as WsError,
15 Message,
16 },
17 WebSocketStream,
18};
19use futures::{
20 task::{Context, Poll},
21 Sink, Stream,
22};
23use serde::{de::DeserializeOwned, Deserialize, Serialize};
24use std::{collections::HashMap, pin::Pin};
25use vbs::{version::StaticVersionType, BinarySerializer, Serializer};
26
27#[must_use]
28#[derive(Debug)]
29pub struct SocketRequest<E, VER: StaticVersionType> {
30 url: Url,
31 content_type: ContentType,
32 headers: HashMap<String, Vec<String>>,
33 config: Option<WebSocketConfig>,
34 marker: std::marker::PhantomData<fn(E, VER) -> ()>,
35}
36
37impl<E: Error, VER: StaticVersionType> SocketRequest<E, VER> {
38 pub(crate) fn new(
39 mut url: Url,
40 content_type: ContentType,
41 config: Option<WebSocketConfig>,
42 ) -> Self {
43 url.set_scheme(&socket_scheme(url.scheme())).unwrap();
44 Self {
45 url,
46 content_type,
47 headers: Default::default(),
48 config,
49 marker: Default::default(),
50 }
51 }
52
53 pub fn header(mut self, key: impl Into<HeaderName>, values: impl ToHeaderValues) -> Self {
55 let name = key.into().to_string();
56 for value in values.to_header_values().unwrap() {
57 self.headers
58 .entry(name.clone())
59 .or_default()
60 .push(value.to_string());
61 }
62 self
63 }
64
65 pub async fn connect<FromServer: DeserializeOwned, ToServer: Serialize + ?Sized>(
67 mut self,
68 ) -> Result<Connection<FromServer, ToServer, E, VER>, E> {
69 loop {
71 let mut req = RequestBuilder::new().uri(self.url.to_string());
72 for (key, values) in &self.headers {
73 for value in values {
74 req = req.header(key, value);
75 }
76 }
77 let req = req
78 .body(())
79 .map_err(|err| E::catch_all(StatusCode::BAD_REQUEST, err.to_string()))?;
80
81 let err = match connect_async_with_config(req, self.config).await {
82 Ok((conn, _)) => return Ok(Connection::new(conn, self.content_type)),
83 Err(err) => err,
84 };
85 if let WsError::Http(res) = &err {
86 if (301..=308).contains(&u16::from(res.status())) {
87 if let Some(location) = res
88 .headers()
89 .get("location")
90 .and_then(|header| header.to_str().ok())
91 {
92 tracing::info!(from = %self.url, to = %location, "WS handshake following redirect");
93 self.url.set_path(location);
94 continue;
95 }
96 }
97 }
98 return Err(E::catch_all(StatusCode::BAD_REQUEST, err.to_string()));
99 }
100 }
101
102 pub async fn subscribe<FromServer: DeserializeOwned>(
107 self,
108 ) -> Result<Connection<FromServer, Unsupported, E, VER>, E> {
109 self.connect().await
110 }
111}
112
113pub struct Connection<FromServer, ToServer: ?Sized, E, VER: StaticVersionType> {
115 inner: WebSocketStream<ConnectStream>,
116 content_type: ContentType,
117 #[allow(clippy::type_complexity)]
118 marker: std::marker::PhantomData<fn(FromServer, ToServer, E, VER) -> ()>,
119}
120
121impl<FromServer, ToServer: ?Sized, E, VER: StaticVersionType>
122 Connection<FromServer, ToServer, E, VER>
123{
124 fn new(inner: WebSocketStream<ConnectStream>, content_type: ContentType) -> Self {
125 Self {
126 inner,
127 content_type,
128 marker: Default::default(),
129 }
130 }
131}
132
133impl<FromServer: DeserializeOwned, ToServer: ?Sized, E: Error, VER: StaticVersionType> Stream
134 for Connection<FromServer, ToServer, E, VER>
135{
136 type Item = Result<FromServer, E>;
137
138 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
139 match self.pinned_inner().poll_next(cx) {
142 Poll::Ready(None) => Poll::Ready(None),
143 Poll::Ready(Some(Err(err))) => match err {
144 WsError::ConnectionClosed | WsError::AlreadyClosed => Poll::Ready(None),
145 err => Poll::Ready(Some(Err(E::catch_all(
146 StatusCode::INTERNAL_SERVER_ERROR,
147 err.to_string(),
148 )))),
149 },
150 Poll::Ready(Some(Ok(msg))) => Poll::Ready(match msg {
151 Message::Binary(bytes) => {
152 Some(Serializer::<VER>::deserialize(&bytes).map_err(|err| {
153 E::catch_all(
154 StatusCode::INTERNAL_SERVER_ERROR,
155 format!("invalid binary: {}\n{bytes:?}", err),
156 )
157 }))
158 }
159 Message::Text(s) => Some(serde_json::from_str(&s).map_err(|err| {
160 E::catch_all(
161 StatusCode::INTERNAL_SERVER_ERROR,
162 format!("invalid JSON: {}\n{s}", err),
163 )
164 })),
165 Message::Close(_) => None,
166 _ => Some(Err(E::catch_all(
167 StatusCode::UNSUPPORTED_MEDIA_TYPE,
168 "unsupported WebSocket message".into(),
169 ))),
170 }),
171 Poll::Pending => Poll::Pending,
172 }
173 }
174}
175
176impl<FromServer, ToServer: Serialize + ?Sized, E: Error, VER: StaticVersionType> Sink<&ToServer>
177 for Connection<FromServer, ToServer, E, VER>
178{
179 type Error = E;
180
181 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
182 self.pinned_inner().poll_ready(cx).map_err(|err| {
183 E::catch_all(
184 StatusCode::INTERNAL_SERVER_ERROR,
185 format!("error in WebSocket connection: {}", err),
186 )
187 })
188 }
189
190 fn start_send(self: Pin<&mut Self>, item: &ToServer) -> Result<(), Self::Error> {
191 let msg = match self.content_type {
192 ContentType::Binary => {
193 Message::Binary(Serializer::<VER>::serialize(item).map_err(|err| {
194 E::catch_all(
195 StatusCode::BAD_REQUEST,
196 format!("invalid binary serialization: {}", err),
197 )
198 })?)
199 }
200 ContentType::Json => Message::Text(serde_json::to_string(item).map_err(|err| {
201 E::catch_all(
202 StatusCode::BAD_REQUEST,
203 format!("invalid JSON serialization: {}", err),
204 )
205 })?),
206 };
207 self.pinned_inner().start_send(msg).map_err(|err| {
208 E::catch_all(
209 StatusCode::INTERNAL_SERVER_ERROR,
210 format!("error sending WebSocket message: {}", err),
211 )
212 })
213 }
214
215 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
216 self.pinned_inner().poll_flush(cx).map_err(|err| {
217 E::catch_all(
218 StatusCode::INTERNAL_SERVER_ERROR,
219 format!("error in WebSocket connection: {}", err),
220 )
221 })
222 }
223
224 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
225 self.pinned_inner().poll_close(cx).map_err(|err| {
226 E::catch_all(
227 StatusCode::INTERNAL_SERVER_ERROR,
228 format!("error in WebSocket connection: {}", err),
229 )
230 })
231 }
232}
233
impl<FromServer, ToServer: ?Sized, E, VER: StaticVersionType>
    Connection<FromServer, ToServer, E, VER>
{
    /// Project a pinned `Connection` to a pinned reference to its underlying WebSocket stream.
    ///
    /// Used by the `Stream` and `Sink` impls to forward polling to `inner`.
    fn pinned_inner(self: Pin<&mut Self>) -> Pin<&mut WebSocketStream<ConnectStream>> {
        // SAFETY: `inner` is treated as structurally pinned. In this file, this projection is
        // the only way a pinned `Connection` exposes `inner`, nothing moves `inner` out of a
        // pinned `Connection`, and the `Drop` impl only handles `self` as `Pin<&mut Self>`.
        //
        // NOTE(review): the tests call `conn.inner.next()`, which (via `StreamExt::next`)
        // requires the stream to be `Unpin`. If that holds, this projection could likely be
        // done safely with `Pin::new` — confirm before changing.
        unsafe { self.map_unchecked_mut(|s| &mut s.inner) }
    }
}
259
impl<FromServer, ToServer: ?Sized, E, VER: StaticVersionType> Drop
    for Connection<FromServer, ToServer, E, VER>
{
    /// Drop the connection, routing any drop logic through a pinned reference.
    ///
    /// `drop` receives `&mut self` even when the value was used behind `Pin`. Drop logic
    /// therefore lives in `inner_drop`, which only ever sees `Pin<&mut Connection<..>>`; this
    /// is the pattern from the "Drop implementation" section of the `std::pin` docs.
    fn drop(&mut self) {
        // SAFETY: `self` is being dropped and is not used again after this call, so viewing
        // it as pinned while `inner_drop` runs cannot violate any pinning guarantee.
        inner_drop(unsafe { Pin::new_unchecked(self) });
        // Currently performs no explicit cleanup; the fields are dropped normally after
        // `drop` returns.
        fn inner_drop<FromServer, ToServer: ?Sized, E, VER: StaticVersionType>(
            _this: Pin<&mut Connection<FromServer, ToServer, E, VER>>,
        ) {
        }
    }
}
285
286#[derive(Clone, Copy, Debug, Serialize, Deserialize)]
288pub enum Unsupported {}
289
/// Map an HTTP URL scheme to the corresponding WebSocket scheme.
///
/// `http` becomes `ws` and `https` becomes `wss`. Every other scheme is returned unchanged;
/// the comparison is case-sensitive.
fn socket_scheme(scheme: &str) -> String {
    let ws_scheme = if scheme == "http" {
        "ws"
    } else if scheme == "https" {
        "wss"
    } else {
        scheme
    };
    ws_scheme.to_owned()
}
303
#[cfg(test)]
mod test {
    use super::*;
    use crate::{Client, ContentType};
    use async_compatibility_layer::logging::{setup_backtrace, setup_logging};
    use async_std::task::spawn;
    use futures::stream::{repeat, StreamExt};
    use portpicker::pick_unused_port;
    use tide_disco::{error::ServerError, App};
    use toml::toml;
    use vbs::version::StaticVersion;

    type Ver01 = StaticVersion<0, 1>;
    const VER_0_1: Ver01 = StaticVersion {};

    /// A `SOCKET` route accepts clients of either content type, and the raw frames each client
    /// receives use that client's encoding (text/JSON or binary).
    #[async_std::test]
    async fn test_socket_accept() {
        setup_logging();
        setup_backtrace();

        // Server: `mod/subscribe` streams the string "response" forever.
        let mut app: App<(), ServerError> = App::with_state(());
        let api = toml! {
            [route.subscribe]
            PATH = ["/subscribe"]
            METHOD = "SOCKET"
        };
        app.module::<ServerError, Ver01>("mod", api)
            .unwrap()
            .stream("subscribe", |_req, _state| repeat("response").map(Ok).boxed())
            .unwrap();
        let port = pick_unused_port().unwrap();
        spawn(app.serve(format!("0.0.0.0:{port}"), VER_0_1));

        // Clients differ only in their content type.
        let base_url = format!("http://localhost:{port}");
        let make_client = |content_type: ContentType| {
            Client::<ServerError, Ver01>::builder(base_url.parse().unwrap())
                .content_type(content_type)
                .build()
        };

        let json_client = make_client(ContentType::Json);
        assert!(json_client.connect(None).await);

        let bin_client = make_client(ContentType::Binary);
        assert!(bin_client.connect(None).await);

        // Read raw frames from `inner` so the wire encoding itself is checked.
        let mut json_conn = json_client
            .socket("mod/subscribe")
            .subscribe::<String>()
            .await
            .unwrap();
        match json_conn.inner.next().await.unwrap().unwrap() {
            Message::Text(msg) => {
                assert_eq!(serde_json::from_str::<String>(&msg).unwrap(), "response")
            }
            _ => panic!("unexpected content type"),
        }

        let mut bin_conn = bin_client
            .socket("mod/subscribe")
            .subscribe::<String>()
            .await
            .unwrap();
        match bin_conn.inner.next().await.unwrap().unwrap() {
            Message::Binary(msg) => assert_eq!(
                Serializer::<Ver01>::deserialize::<String>(&msg).unwrap(),
                "response"
            ),
            _ => panic!("unexpected content type"),
        }
    }
}