From 14f38b28b8305ebd9d7303098229fe4606d14a46 Mon Sep 17 00:00:00 2001 From: Toshit Chawda Date: Mon, 4 Nov 2024 21:33:40 -0800 Subject: [PATCH] remove appendingwebsocketread, specialcase data/close --- wisp/src/mux/client.rs | 7 +++-- wisp/src/mux/inner.rs | 67 +++++++++++++++++++++++++++++------------- wisp/src/mux/mod.rs | 6 ++-- wisp/src/mux/server.rs | 5 ++-- wisp/src/ws.rs | 30 ------------------- 5 files changed, 56 insertions(+), 59 deletions(-) diff --git a/wisp/src/mux/client.rs b/wisp/src/mux/client.rs index 750a4c5..bf00b93 100644 --- a/wisp/src/mux/client.rs +++ b/wisp/src/mux/client.rs @@ -12,7 +12,7 @@ use futures::channel::oneshot; use crate::{ extensions::{udp::UdpProtocolExtension, AnyProtocolExtension}, mux::send_info_packet, - ws::{AppendingWebSocketRead, LockedWebSocketWrite, Payload, WebSocketRead, WebSocketWrite}, + ws::{LockedWebSocketWrite, Payload, WebSocketRead, WebSocketWrite}, CloseReason, MuxProtocolExtensionStream, MuxStream, Packet, PacketType, Role, StreamType, WispError, }; @@ -110,10 +110,11 @@ impl ClientMux { let tx = LockedWebSocketWrite::new(Box::new(tx)); let (handshake_result, buffer_size) = handshake(&mut rx, &tx, wisp_v2).await?; - let (extensions, frame) = handshake_result.kind.into_parts(); + let (extensions, extra_packet) = handshake_result.kind.into_parts(); let mux_inner = MuxInner::new_client( - AppendingWebSocketRead(frame, rx), + rx, + extra_packet, tx.clone(), extensions.clone(), buffer_size, diff --git a/wisp/src/mux/inner.rs b/wisp/src/mux/inner.rs index 265bc48..1b41386 100644 --- a/wisp/src/mux/inner.rs +++ b/wisp/src/mux/inner.rs @@ -46,6 +46,9 @@ struct MuxMapValue { pub struct MuxInner { // gets taken by the mux task rx: Option, + // gets taken by the mux task + maybe_downgrade_packet: Option>, + tx: LockedWebSocketWrite, extensions: Vec, tcp_extensions: Vec, @@ -82,6 +85,7 @@ impl MuxInner { pub fn new_server( rx: R, + maybe_downgrade_packet: Option>, tx: LockedWebSocketWrite, extensions: Vec, buffer_size: u32, @@ -98,6 +102,7 @@ impl MuxInner { MuxInnerResult { mux: Self { rx: Some(rx), + maybe_downgrade_packet, tx, actor_rx: Some(fut_rx), @@ -124,6 +129,7 @@ impl MuxInner { pub fn new_client( rx: R, + maybe_downgrade_packet: Option>, tx: LockedWebSocketWrite, extensions: Vec, buffer_size: u32, @@ -136,6 +142,7 @@ impl MuxInner { MuxInnerResult { mux: Self { rx: Some(rx), + maybe_downgrade_packet, tx, actor_rx: Some(fut_rx), @@ -265,9 +272,17 @@ impl MuxInner { let mut next_free_stream_id: u32 = 1; let mut rx = self.rx.take().ok_or(WispError::MuxTaskStarted)?; + let maybe_downgrade_packet = self.maybe_downgrade_packet.take(); + let tx = self.tx.clone(); let fut_rx = self.actor_rx.take().ok_or(WispError::MuxTaskStarted)?; + if let Some(downgrade_packet) = maybe_downgrade_packet { + if self.handle_packet(downgrade_packet, None).await? { + return Ok(()); + } + } + let mut recv_fut = fut_rx.recv_async().fuse(); let mut read_fut = rx.wisp_read_split(&tx).fuse(); while let Some(msg) = select! { @@ -342,14 +357,7 @@ impl MuxInner { } WsEvent::WispMessage(packet, optional_frame) => { if let Some(packet) = packet { - let should_break = match self.role { - Role::Server => { - self.server_handle_packet(packet, optional_frame).await? - } - Role::Client => { - self.client_handle_packet(packet, optional_frame).await? - } - }; + let should_break = self.handle_packet(packet, optional_frame).await?; if should_break { break; } @@ -409,18 +417,31 @@ impl MuxInner { Ok(false) } - async fn server_handle_packet( + async fn handle_packet( &mut self, packet: Packet<'static>, optional_frame: Option>, ) -> Result { - use PacketType::*; + use PacketType as P; match packet.packet_type { - Continue(_) | Info(_) => Err(WispError::InvalidPacketType), - Data(data) => self.handle_data_packet(packet.stream_id, optional_frame, data), - Close(inner_packet) => self.handle_close_packet(packet.stream_id, inner_packet), + P::Data(data) => self.handle_data_packet(packet.stream_id, optional_frame, data), + P::Close(inner_packet) => self.handle_close_packet(packet.stream_id, inner_packet), - Connect(inner_packet) => { + _ => match self.role { + Role::Server => self.server_handle_packet(packet, optional_frame).await, + Role::Client => self.client_handle_packet(packet, optional_frame).await, + }, + } + } + + async fn server_handle_packet( + &mut self, + packet: Packet<'static>, + _optional_frame: Option>, + ) -> Result { + use PacketType as P; + match packet.packet_type { + P::Connect(inner_packet) => { let (map_value, stream) = self .create_new_stream(packet.stream_id, inner_packet.stream_type) .await?; @@ -431,21 +452,21 @@ impl MuxInner { self.stream_map.insert(packet.stream_id, map_value); Ok(false) } + + // Continue | Info => invalid packet type + // Data | Close => specialcased + _ => Err(WispError::InvalidPacketType), } } async fn client_handle_packet( &mut self, packet: Packet<'static>, - optional_frame: Option>, + _optional_frame: Option>, ) -> Result { - use PacketType::*; + use PacketType as P; match packet.packet_type { - Connect(_) | Info(_) => Err(WispError::InvalidPacketType), - Data(data) => self.handle_data_packet(packet.stream_id, optional_frame, data), - Close(inner_packet) => self.handle_close_packet(packet.stream_id, inner_packet), - - Continue(inner_packet) => { + P::Continue(inner_packet) => { if let Some(stream) = self.stream_map.get(&packet.stream_id) { if stream.stream_type == StreamType::Tcp { stream @@ -456,6 +477,10 @@ impl MuxInner { } Ok(false) } + + // Connect | Info => invalid packet type + // Data | Close => specialcased + _ => Err(WispError::InvalidPacketType), } } } diff --git a/wisp/src/mux/mod.rs b/wisp/src/mux/mod.rs index e947908..b784704 100644 --- a/wisp/src/mux/mod.rs +++ b/wisp/src/mux/mod.rs @@ -8,7 +8,7 @@ pub use server::ServerMux; use crate::{ extensions::{udp::UdpProtocolExtension, AnyProtocolExtension, AnyProtocolExtensionBuilder}, - ws::{Frame, LockedWebSocketWrite}, + ws::LockedWebSocketWrite, CloseReason, Packet, PacketType, Role, WispError, }; @@ -22,12 +22,12 @@ enum WispHandshakeResultKind { extensions: Vec, }, V1 { - frame: Option>, + frame: Option>, }, } impl WispHandshakeResultKind { - pub fn into_parts(self) -> (Vec, Option>) { + pub fn into_parts(self) -> (Vec, Option>) { match self { Self::V2 { extensions } => (extensions, None), Self::V1 { frame } => (vec![UdpProtocolExtension.into()], frame), diff --git a/wisp/src/mux/server.rs b/wisp/src/mux/server.rs index 974c87e..cd628ab 100644 --- a/wisp/src/mux/server.rs +++ b/wisp/src/mux/server.rs @@ -11,7 +11,7 @@ use futures::channel::oneshot; use crate::{ extensions::AnyProtocolExtension, - ws::{AppendingWebSocketRead, LockedWebSocketWrite, Payload, WebSocketRead, WebSocketWrite}, + ws::{LockedWebSocketWrite, Payload, WebSocketRead, WebSocketWrite}, CloseReason, ConnectPacket, MuxProtocolExtensionStream, MuxStream, Packet, PacketType, Role, WispError, }; @@ -109,7 +109,8 @@ impl ServerMux { let (extensions, extra_packet) = handshake_result.kind.into_parts(); let (mux_result, muxstream_recv) = MuxInner::new_server( - AppendingWebSocketRead(extra_packet, rx), + rx, + extra_packet, tx.clone(), extensions.clone(), buffer_size, diff --git a/wisp/src/ws.rs b/wisp/src/ws.rs index d75b694..1d1ca78 100644 --- a/wisp/src/ws.rs +++ b/wisp/src/ws.rs @@ -261,33 +261,3 @@ impl LockedWebSocketWrite { self.0.lock().await.wisp_close().await } } - -pub(crate) struct AppendingWebSocketRead(pub Option>, pub R) -where - R: WebSocketRead + Send; - -#[async_trait] -impl WebSocketRead for AppendingWebSocketRead -where - R: WebSocketRead + Send, -{ - async fn wisp_read_frame( - &mut self, - tx: &LockedWebSocketWrite, - ) -> Result, WispError> { - if let Some(x) = self.0.take() { - return Ok(x); - } - self.1.wisp_read_frame(tx).await - } - - async fn wisp_read_split( - &mut self, - tx: &LockedWebSocketWrite, - ) -> Result<(Frame<'static>, Option>), WispError> { - if let Some(x) = self.0.take() { - return Ok((x, None)); - } - self.1.wisp_read_split(tx).await - } -}