From 9cd87b72433a98ebfe6c127b25ae046bb838b430 Mon Sep 17 00:00:00 2001 From: Toshit Chawda Date: Sat, 31 Aug 2024 16:20:56 -0700 Subject: [PATCH] rewrite actor --- Cargo.lock | 16 -- wisp/Cargo.toml | 1 - wisp/src/inner.rs | 401 +++++++++++++++++++++++++++++++++++ wisp/src/lib.rs | 509 ++++++--------------------------------------- wisp/src/stream.rs | 15 +- 5 files changed, 470 insertions(+), 472 deletions(-) create mode 100644 wisp/src/inner.rs diff --git a/Cargo.lock b/Cargo.lock index a3ba6a1..d38abf5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1485,27 +1485,12 @@ version = "1.0.18" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f3cb5ba0dc43242ce17de99c180e96db90b235b8a9fdc9543c96d2209116bd9f" -[[package]] -name = "scc" -version = "2.1.16" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "aeb7ac86243095b70a7920639507b71d51a63390d1ba26c4f60a552fbb914a37" -dependencies = [ - "sdd", -] - [[package]] name = "scopeguard" version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" -[[package]] -name = "sdd" -version = "3.0.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0495e4577c672de8254beb68d01a9b62d0e8a13c099edecdbedccce3223cd29f" - [[package]] name = "send_wrapper" version = "0.4.0" @@ -2274,7 +2259,6 @@ dependencies = [ "futures", "futures-timer", "pin-project-lite", - "scc", "tokio", ] diff --git a/wisp/Cargo.toml b/wisp/Cargo.toml index 28b8b76..dbea286 100644 --- a/wisp/Cargo.toml +++ b/wisp/Cargo.toml @@ -18,7 +18,6 @@ flume = "0.11.0" futures = "0.3.30" futures-timer = "3.0.3" pin-project-lite = "0.2.14" -scc = "2.1.16" tokio = { version = "1.39.3", optional = true, default-features = false } [features] diff --git a/wisp/src/inner.rs b/wisp/src/inner.rs new file mode 100644 index 0000000..55219f2 --- /dev/null +++ b/wisp/src/inner.rs @@ -0,0 +1,401 @@ +use std::{ + collections::HashMap, + sync::{ + atomic::{AtomicBool, AtomicU32, Ordering}, + Arc, + }, +}; + +use crate::{ + extensions::AnyProtocolExtension, + ws::{Frame, LockedWebSocketWrite, OpCode, Payload, WebSocketRead}, + AtomicCloseReason, ClosePacket, CloseReason, ConnectPacket, MuxStream, Packet, PacketType, + Role, StreamType, WispError, +}; +use bytes::{Bytes, BytesMut}; +use event_listener::Event; +use flume as mpsc; +use futures::{channel::oneshot, FutureExt}; + +pub(crate) enum WsEvent { + Close(Packet<'static>, oneshot::Sender>), + CreateStream( + StreamType, + String, + u16, + oneshot::Sender>, + ), + WispMessage(Frame<'static>, Option>), + EndFut(Option), +} + +struct MuxMapValue { + stream: mpsc::Sender, + stream_type: StreamType, + + flow_control: Arc, + flow_control_event: Arc, + + is_closed: Arc, + close_reason: Arc, + is_closed_event: Arc, +} + +pub struct MuxInner { + rx: R, + tx: LockedWebSocketWrite, + extensions: Vec, + role: Role, + + fut_rx: mpsc::Receiver, + fut_tx: mpsc::Sender, + fut_exited: Arc, + + stream_map: HashMap, + + buffer_size: u32, + target_buffer_size: u32, + + server_tx: mpsc::Sender<(ConnectPacket, MuxStream)>, +} + +impl MuxInner { + pub fn new_server( + rx: R, + tx: LockedWebSocketWrite, + extensions: Vec, + buffer_size: u32, + ) -> ( + Self, + Arc, + mpsc::Sender, + mpsc::Receiver<(ConnectPacket, MuxStream)>, + ) { + let (fut_tx, fut_rx) = mpsc::bounded::(256); + let (server_tx, server_rx) = mpsc::unbounded::<(ConnectPacket, MuxStream)>(); + let ret_fut_tx = fut_tx.clone(); + let fut_exited = Arc::new(AtomicBool::new(false)); + + ( + Self { + rx, + tx, + + fut_rx, + fut_tx, + fut_exited: fut_exited.clone(), + + extensions, + buffer_size, + target_buffer_size: ((buffer_size as u64 * 90) / 100) as u32, + + role: Role::Server, + + stream_map: HashMap::new(), + + server_tx, + }, + fut_exited, + ret_fut_tx, + server_rx, + ) + } + + pub fn new_client( + rx: R, + tx: LockedWebSocketWrite, + extensions: Vec, + buffer_size: u32, + ) -> (Self, Arc, mpsc::Sender) { + let (fut_tx, fut_rx) = mpsc::bounded::(256); + let (server_tx, _) = mpsc::unbounded::<(ConnectPacket, MuxStream)>(); + let ret_fut_tx = fut_tx.clone(); + let fut_exited = Arc::new(AtomicBool::new(false)); + + ( + Self { + rx, + tx, + + fut_rx, + fut_tx, + fut_exited: fut_exited.clone(), + + extensions, + buffer_size, + target_buffer_size: 0, + + role: Role::Client, + + stream_map: HashMap::new(), + + server_tx, + }, + fut_exited, + ret_fut_tx, + ) + } + + pub async fn into_future(mut self) -> Result<(), WispError> { + let ret = self.stream_loop().await; + + self.fut_exited.store(true, Ordering::Release); + + for (_, stream) in self.stream_map.iter() { + self.close_stream(stream, ClosePacket::new(CloseReason::Unknown)); + } + self.stream_map.clear(); + + let _ = self.tx.close().await; + ret + } + + async fn create_new_stream( + &mut self, + stream_id: u32, + stream_type: StreamType, + ) -> Result<(MuxMapValue, MuxStream), WispError> { + let (ch_tx, ch_rx) = mpsc::bounded(self.buffer_size as usize); + + let flow_control_event: Arc = Event::new().into(); + let flow_control: Arc = AtomicU32::new(self.buffer_size).into(); + + let is_closed: Arc = AtomicBool::new(false).into(); + let close_reason: Arc = + AtomicCloseReason::new(CloseReason::Unknown).into(); + let is_closed_event: Arc = Event::new().into(); + + Ok(( + MuxMapValue { + stream: ch_tx, + stream_type, + + flow_control: flow_control.clone(), + flow_control_event: flow_control_event.clone(), + + is_closed: is_closed.clone(), + close_reason: close_reason.clone(), + is_closed_event: is_closed_event.clone(), + }, + MuxStream::new( + stream_id, + self.role, + stream_type, + ch_rx, + self.fut_tx.clone(), + self.tx.clone(), + is_closed, + is_closed_event, + close_reason, + flow_control, + flow_control_event, + self.target_buffer_size, + ), + )) + } + + fn close_stream(&self, stream: &MuxMapValue, close_packet: ClosePacket) { + stream + .close_reason + .store(close_packet.reason, Ordering::Release); + stream.is_closed.store(true, Ordering::Release); + stream.is_closed_event.notify(usize::MAX); + stream.flow_control.store(u32::MAX, Ordering::Release); + stream.flow_control_event.notify(usize::MAX); + } + + async fn get_message(&mut self) -> Result, WispError> { + futures::select! { + x = self.fut_rx.recv_async().fuse() => Ok(x.ok()), + x = self.rx.wisp_read_split(&self.tx).fuse() => { + let (mut frame, optional_frame) = x?; + if frame.opcode == OpCode::Close { + return Ok(None); + } + + if let Some(ref extra_frame) = optional_frame { + if frame.payload[0] != PacketType::Data(Payload::Bytes(BytesMut::new())).as_u8() { + let mut payload = BytesMut::from(frame.payload); + payload.extend_from_slice(&extra_frame.payload); + frame.payload = Payload::Bytes(payload); + } + } + + Ok(Some(WsEvent::WispMessage(frame, optional_frame))) + } + } + } + + async fn stream_loop(&mut self) -> Result<(), WispError> { + let mut next_free_stream_id: u32 = 1; + while let Some(msg) = self.get_message().await? { + match msg { + WsEvent::CreateStream(stream_type, host, port, channel) => { + let ret: Result = async { + let stream_id = next_free_stream_id; + let next_stream_id = next_free_stream_id + .checked_add(1) + .ok_or(WispError::MaxStreamCountReached)?; + + let (map_value, stream) = + self.create_new_stream(stream_id, stream_type).await?; + + self.tx + .write_frame( + Packet::new_connect(stream_id, stream_type, port, host).into(), + ) + .await?; + + self.stream_map.insert(stream_id, map_value); + + next_free_stream_id = next_stream_id; + + Ok(stream) + } + .await; + let _ = channel.send(ret); + } + WsEvent::Close(packet, channel) => { + if let Some(stream) = self.stream_map.remove(&packet.stream_id) { + if let PacketType::Close(close) = packet.packet_type { + self.close_stream(&stream, close); + } + let _ = channel.send(self.tx.write_frame(packet.into()).await); + } else { + let _ = channel.send(Err(WispError::InvalidStreamId)); + } + } + WsEvent::EndFut(x) => { + if let Some(reason) = x { + let _ = self + .tx + .write_frame(Packet::new_close(0, reason).into()) + .await; + } + break; + } + WsEvent::WispMessage(frame, optional_frame) => { + if let Some(packet) = Packet::maybe_handle_extension( + frame, + &mut self.extensions, + &mut self.rx, + &mut self.tx, + ) + .await? + { + let should_break = match self.role { + Role::Server => { + self.server_handle_packet(packet, optional_frame).await? + } + Role::Client => { + self.client_handle_packet(packet, optional_frame).await? + } + }; + if should_break { + break; + } + } + } + } + } + + Ok(()) + } + + fn handle_close_packet( + &mut self, + stream_id: u32, + inner_packet: ClosePacket, + ) -> Result { + if stream_id == 0 { + return Ok(true); + } + + if let Some(stream) = self.stream_map.remove(&stream_id) { + self.close_stream(&stream, inner_packet); + } + + Ok(false) + } + + fn handle_data_packet( + &mut self, + stream_id: u32, + optional_frame: Option>, + data: Payload<'static>, + ) -> Result { + let mut data = BytesMut::from(data); + + if let Some(stream) = self.stream_map.get(&stream_id) { + if let Some(extra_frame) = optional_frame { + if data.is_empty() { + data = extra_frame.payload.into(); + } else { + data.extend_from_slice(&extra_frame.payload); + } + } + let _ = stream.stream.try_send(data.freeze()); + if self.role == Role::Server && stream.stream_type == StreamType::Tcp { + stream.flow_control.store( + stream + .flow_control + .load(Ordering::Acquire) + .saturating_sub(1), + Ordering::Release, + ); + } + } + + Ok(false) + } + + async fn server_handle_packet( + &mut self, + packet: Packet<'static>, + optional_frame: Option>, + ) -> Result { + use PacketType::*; + match packet.packet_type { + Continue(_) | Info(_) => Err(WispError::InvalidPacketType), + Data(data) => self.handle_data_packet(packet.stream_id, optional_frame, data), + Close(inner_packet) => self.handle_close_packet(packet.stream_id, inner_packet), + + Connect(inner_packet) => { + let (map_value, stream) = self + .create_new_stream(packet.stream_id, inner_packet.stream_type) + .await?; + self.server_tx + .send_async((inner_packet, stream)) + .await + .map_err(|_| WispError::MuxMessageFailedToSend)?; + self.stream_map.insert(packet.stream_id, map_value); + Ok(false) + } + } + } + + async fn client_handle_packet( + &mut self, + packet: Packet<'static>, + optional_frame: Option>, + ) -> Result { + use PacketType::*; + match packet.packet_type { + Connect(_) | Info(_) => Err(WispError::InvalidPacketType), + Data(data) => self.handle_data_packet(packet.stream_id, optional_frame, data), + Close(inner_packet) => self.handle_close_packet(packet.stream_id, inner_packet), + + Continue(inner_packet) => { + if let Some(stream) = self.stream_map.get(&packet.stream_id) { + if stream.stream_type == StreamType::Tcp { + stream + .flow_control + .store(inner_packet.buffer_remaining, Ordering::Release); + let _ = stream.flow_control_event.notify(u32::MAX); + } + } + Ok(false) + } + } + } +} diff --git a/wisp/src/lib.rs b/wisp/src/lib.rs index 4f77e77..384d0a1 100644 --- a/wisp/src/lib.rs +++ b/wisp/src/lib.rs @@ -11,6 +11,7 @@ mod fastwebsockets; #[cfg(feature = "generic_stream")] #[cfg_attr(docsrs, doc(cfg(feature = "generic_stream")))] pub mod generic; +mod inner; mod packet; mod sink_unfold; mod stream; @@ -18,21 +19,19 @@ pub mod ws; pub use crate::{packet::*, stream::*}; -use bytes::{Bytes, BytesMut}; -use event_listener::Event; use extensions::{udp::UdpProtocolExtension, AnyProtocolExtension, ProtocolExtensionBuilder}; use flume as mpsc; use futures::{channel::oneshot, select, Future, FutureExt}; use futures_timer::Delay; -use scc::HashMap; +use inner::{MuxInner, WsEvent}; use std::{ sync::{ - atomic::{AtomicBool, AtomicU32, Ordering}, + atomic::{AtomicBool, Ordering}, Arc, }, time::Duration, }; -use ws::{AppendingWebSocketRead, LockedWebSocketWrite, Payload}; +use ws::{AppendingWebSocketRead, LockedWebSocketWrite}; /// Wisp version supported by this crate. pub const WISP_VERSION: WispVersion = WispVersion { major: 2, minor: 0 }; @@ -157,363 +156,6 @@ impl std::fmt::Display for WispError { impl std::error::Error for WispError {} -struct MuxMapValue { - stream: mpsc::Sender, - stream_type: StreamType, - - flow_control: Arc, - flow_control_event: Arc, - - is_closed: Arc, - close_reason: Arc, - is_closed_event: Arc, -} - -impl Drop for MuxMapValue { - fn drop(&mut self) { - self.is_closed.store(true, Ordering::Release); - self.is_closed_event.notify(usize::MAX); - } -} - -struct MuxInner { - tx: ws::LockedWebSocketWrite, - stream_map: HashMap, - buffer_size: u32, - fut_exited: Arc, -} - -impl MuxInner { - pub async fn server_into_future( - self, - rx: R, - extensions: Vec, - close_rx: mpsc::Receiver, - muxstream_sender: mpsc::Sender<(ConnectPacket, MuxStream)>, - close_tx: mpsc::Sender, - ) -> Result<(), WispError> - where - R: ws::WebSocketRead + Send, - { - self.as_future( - close_rx, - close_tx.clone(), - self.server_loop(rx, extensions, muxstream_sender, close_tx), - ) - .await - } - - pub async fn client_into_future( - self, - rx: R, - extensions: Vec, - close_rx: mpsc::Receiver, - close_tx: mpsc::Sender, - ) -> Result<(), WispError> - where - R: ws::WebSocketRead + Send, - { - self.as_future(close_rx, close_tx, self.client_loop(rx, extensions)) - .await - } - - async fn as_future( - &self, - close_rx: mpsc::Receiver, - close_tx: mpsc::Sender, - wisp_fut: impl Future>, - ) -> Result<(), WispError> { - let ret = futures::select! { - _ = self.stream_loop(close_rx, close_tx).fuse() => Ok(()), - x = wisp_fut.fuse() => x, - }; - self.fut_exited.store(true, Ordering::Release); - self.stream_map.clear_async().await; - let _ = self.tx.close().await; - ret - } - - async fn create_new_stream( - &self, - stream_id: u32, - stream_type: StreamType, - role: Role, - stream_tx: mpsc::Sender, - tx: LockedWebSocketWrite, - target_buffer_size: u32, - ) -> Result<(MuxMapValue, MuxStream), WispError> { - let (ch_tx, ch_rx) = mpsc::bounded(self.buffer_size as usize); - - let flow_control_event: Arc = Event::new().into(); - let flow_control: Arc = AtomicU32::new(self.buffer_size).into(); - - let is_closed: Arc = AtomicBool::new(false).into(); - let close_reason: Arc = - AtomicCloseReason::new(CloseReason::Unknown).into(); - let is_closed_event: Arc = Event::new().into(); - - Ok(( - MuxMapValue { - stream: ch_tx, - stream_type, - - flow_control: flow_control.clone(), - flow_control_event: flow_control_event.clone(), - - is_closed: is_closed.clone(), - close_reason: close_reason.clone(), - is_closed_event: is_closed_event.clone(), - }, - MuxStream::new( - stream_id, - role, - stream_type, - ch_rx, - stream_tx, - tx, - is_closed, - is_closed_event, - close_reason, - flow_control, - flow_control_event, - target_buffer_size, - ), - )) - } - - async fn stream_loop( - &self, - stream_rx: mpsc::Receiver, - stream_tx: mpsc::Sender, - ) { - let mut next_free_stream_id: u32 = 1; - while let Ok(msg) = stream_rx.recv_async().await { - match msg { - WsEvent::CreateStream(stream_type, host, port, channel) => { - let ret: Result = async { - let stream_id = next_free_stream_id; - let next_stream_id = next_free_stream_id - .checked_add(1) - .ok_or(WispError::MaxStreamCountReached)?; - - let (map_value, stream) = self - .create_new_stream( - stream_id, - stream_type, - Role::Client, - stream_tx.clone(), - self.tx.clone(), - 0, - ) - .await?; - - self.tx - .write_frame( - Packet::new_connect(stream_id, stream_type, port, host).into(), - ) - .await?; - - self.stream_map.upsert_async(stream_id, map_value).await; - - next_free_stream_id = next_stream_id; - - Ok(stream) - } - .await; - let _ = channel.send(ret); - } - WsEvent::Close(packet, channel) => { - if let Some((_, stream)) = self.stream_map.remove_async(&packet.stream_id).await - { - if let PacketType::Close(close) = packet.packet_type { - self.close_stream(stream, close); - } - let _ = channel.send(self.tx.write_frame(packet.into()).await); - } else { - let _ = channel.send(Err(WispError::InvalidStreamId)); - } - } - WsEvent::EndFut(x) => { - if let Some(reason) = x { - let _ = self - .tx - .write_frame(Packet::new_close(0, reason).into()) - .await; - } - break; - } - } - } - } - - fn close_stream(&self, stream: MuxMapValue, close_packet: ClosePacket) { - stream - .close_reason - .store(close_packet.reason, Ordering::Release); - stream.is_closed.store(true, Ordering::Release); - stream.is_closed_event.notify(usize::MAX); - stream.flow_control.store(u32::MAX, Ordering::Release); - stream.flow_control_event.notify(usize::MAX); - } - - async fn server_loop( - &self, - mut rx: R, - mut extensions: Vec, - muxstream_sender: mpsc::Sender<(ConnectPacket, MuxStream)>, - stream_tx: mpsc::Sender, - ) -> Result<(), WispError> - where - R: ws::WebSocketRead + Send, - { - // will send continues once flow_control is at 10% of max - let target_buffer_size = ((self.buffer_size as u64 * 90) / 100) as u32; - - loop { - let (mut frame, optional_frame) = rx.wisp_read_split(&self.tx).await?; - if frame.opcode == ws::OpCode::Close { - break Ok(()); - } - - if let Some(ref extra_frame) = optional_frame { - if frame.payload[0] != PacketType::Data(Payload::Bytes(BytesMut::new())).as_u8() { - let mut payload = BytesMut::from(frame.payload); - payload.extend_from_slice(&extra_frame.payload); - frame.payload = Payload::Bytes(payload); - } - } - - if let Some(packet) = - Packet::maybe_handle_extension(frame, &mut extensions, &mut rx, &self.tx).await? - { - use PacketType::*; - match packet.packet_type { - Continue(_) | Info(_) => break Err(WispError::InvalidPacketType), - Connect(inner_packet) => { - let (map_value, stream) = self - .create_new_stream( - packet.stream_id, - inner_packet.stream_type, - Role::Server, - stream_tx.clone(), - self.tx.clone(), - target_buffer_size, - ) - .await?; - muxstream_sender - .send_async((inner_packet, stream)) - .await - .map_err(|_| WispError::MuxMessageFailedToSend)?; - self.stream_map - .upsert_async(packet.stream_id, map_value) - .await; - } - Data(data) => { - let mut data = BytesMut::from(data); - if let Some(stream) = self.stream_map.get_async(&packet.stream_id).await { - if let Some(extra_frame) = optional_frame { - if data.is_empty() { - data = extra_frame.payload.into(); - } else { - data.extend_from_slice(&extra_frame.payload); - } - } - let _ = stream.stream.try_send(data.freeze()); - if stream.stream_type == StreamType::Tcp { - stream.flow_control.store( - stream - .flow_control - .load(Ordering::Acquire) - .saturating_sub(1), - Ordering::Release, - ); - } - } - } - Close(inner_packet) => { - if packet.stream_id == 0 { - break Ok(()); - } - - if let Some((_, stream)) = - self.stream_map.remove_async(&packet.stream_id).await - { - self.close_stream(stream, inner_packet) - } - } - } - } - } - } - - async fn client_loop( - &self, - mut rx: R, - mut extensions: Vec, - ) -> Result<(), WispError> - where - R: ws::WebSocketRead + Send, - { - loop { - let (mut frame, optional_frame) = rx.wisp_read_split(&self.tx).await?; - if frame.opcode == ws::OpCode::Close { - break Ok(()); - } - - if let Some(ref extra_frame) = optional_frame { - if frame.payload[0] != PacketType::Data(Payload::Bytes(BytesMut::new())).as_u8() { - let mut payload = BytesMut::from(frame.payload); - payload.extend_from_slice(&extra_frame.payload); - frame.payload = Payload::Bytes(payload); - } - } - - if let Some(packet) = - Packet::maybe_handle_extension(frame, &mut extensions, &mut rx, &self.tx).await? - { - use PacketType::*; - match packet.packet_type { - Connect(_) | Info(_) => break Err(WispError::InvalidPacketType), - Data(data) => { - let mut data = BytesMut::from(data); - if let Some(stream) = self.stream_map.get_async(&packet.stream_id).await { - if let Some(extra_frame) = optional_frame { - if data.is_empty() { - data = extra_frame.payload.into(); - } else { - data.extend_from_slice(&extra_frame.payload); - } - } - let _ = stream.stream.send_async(data.freeze()).await; - } - } - Continue(inner_packet) => { - if let Some(stream) = self.stream_map.get_async(&packet.stream_id).await { - if stream.stream_type == StreamType::Tcp { - stream - .flow_control - .store(inner_packet.buffer_remaining, Ordering::Release); - let _ = stream.flow_control_event.notify(u32::MAX); - } - } - } - Close(inner_packet) => { - if packet.stream_id == 0 { - break Ok(()); - } - - if let Some((_, stream)) = - self.stream_map.remove_async(&packet.stream_id).await - { - self.close_stream(stream, inner_packet) - } - } - } - } - } - } -} - async fn maybe_wisp_v2( read: &mut R, write: &LockedWebSocketWrite, @@ -576,7 +218,7 @@ pub struct ServerMux { pub downgraded: bool, /// Extensions that are supported by both sides. pub supported_extension_ids: Vec, - close_tx: mpsc::Sender, + actor_tx: mpsc::Sender, muxstream_recv: mpsc::Receiver<(ConnectPacket, MuxStream)>, tx: ws::LockedWebSocketWrite, fut_exited: Arc, @@ -589,8 +231,8 @@ impl ServerMux { /// **It is not guaranteed that all extensions you specify are available.** You must manually check /// if the extensions you need are available after the multiplexor has been created. pub async fn create( - mut read: R, - write: W, + mut rx: R, + tx: W, buffer_size: u32, extension_builders: Option<&[Box]>, ) -> Result> + Send>, WispError> @@ -598,55 +240,47 @@ impl ServerMux { R: ws::WebSocketRead + Send, W: ws::WebSocketWrite + Send + 'static, { - let (close_tx, close_rx) = mpsc::bounded::(256); - let (tx, rx) = mpsc::unbounded::<(ConnectPacket, MuxStream)>(); - let write = ws::LockedWebSocketWrite::new(Box::new(write)); - let fut_exited = Arc::new(AtomicBool::new(false)); + let tx = ws::LockedWebSocketWrite::new(Box::new(tx)); - write - .write_frame(Packet::new_continue(0, buffer_size).into()) + tx.write_frame(Packet::new_continue(0, buffer_size).into()) .await?; let (supported_extensions, extra_packet, downgraded) = if let Some(builders) = extension_builders { - write - .write_frame( - Packet::new_info( - builders - .iter() - .map(|x| x.build_to_extension(Role::Client)) - .collect(), - ) - .into(), + tx.write_frame( + Packet::new_info( + builders + .iter() + .map(|x| x.build_to_extension(Role::Client)) + .collect(), ) - .await?; - maybe_wisp_v2(&mut read, &write, builders).await? + .into(), + ) + .await?; + maybe_wisp_v2(&mut rx, &tx, builders).await? } else { (Vec::new(), None, true) }; + let supported_extension_ids = supported_extensions.iter().map(|x| x.get_id()).collect(); + + let (mux_inner, fut_exited, actor_tx, muxstream_recv) = MuxInner::new_server( + AppendingWebSocketRead(extra_packet, rx), + tx.clone(), + supported_extensions, + buffer_size, + ); + Ok(ServerMuxResult( Self { - muxstream_recv: rx, - close_tx: close_tx.clone(), + muxstream_recv, + actor_tx, downgraded, - supported_extension_ids: supported_extensions.iter().map(|x| x.get_id()).collect(), - tx: write.clone(), + supported_extension_ids, + tx, fut_exited: fut_exited.clone(), }, - MuxInner { - tx: write, - stream_map: HashMap::new(), - buffer_size, - fut_exited, - } - .server_into_future( - AppendingWebSocketRead(extra_packet, read), - supported_extensions, - close_rx, - tx, - close_tx, - ), + mux_inner.into_future(), )) } @@ -662,7 +296,7 @@ impl ServerMux { if self.fut_exited.load(Ordering::Acquire) { return Err(WispError::MuxTaskEnded); } - self.close_tx + self.actor_tx .send_async(WsEvent::EndFut(reason)) .await .map_err(|_| WispError::MuxMessageFailedToSend) @@ -695,7 +329,7 @@ impl ServerMux { impl Drop for ServerMux { fn drop(&mut self) { - let _ = self.close_tx.send(WsEvent::EndFut(None)); + let _ = self.actor_tx.send(WsEvent::EndFut(None)); } } @@ -762,7 +396,7 @@ pub struct ClientMux { pub downgraded: bool, /// Extensions that are supported by both sides. pub supported_extension_ids: Vec, - stream_tx: mpsc::Sender, + actor_tx: mpsc::Sender, tx: ws::LockedWebSocketWrite, fut_exited: Arc, } @@ -774,68 +408,61 @@ impl ClientMux { /// **It is not guaranteed that all extensions you specify are available.** You must manually check /// if the extensions you need are available after the multiplexor has been created. pub async fn create( - mut read: R, - write: W, + mut rx: R, + tx: W, extension_builders: Option<&[Box]>, ) -> Result> + Send>, WispError> where R: ws::WebSocketRead + Send, W: ws::WebSocketWrite + Send + 'static, { - let write = ws::LockedWebSocketWrite::new(Box::new(write)); - let first_packet = Packet::try_from(read.wisp_read_frame(&write).await?)?; - let fut_exited = Arc::new(AtomicBool::new(false)); + let tx = ws::LockedWebSocketWrite::new(Box::new(tx)); + let first_packet = Packet::try_from(rx.wisp_read_frame(&tx).await?)?; if first_packet.stream_id != 0 { return Err(WispError::InvalidStreamId); } + if let PacketType::Continue(packet) = first_packet.packet_type { let (supported_extensions, extra_packet, downgraded) = if let Some(builders) = extension_builders { - let x = maybe_wisp_v2(&mut read, &write, builders).await?; + let x = maybe_wisp_v2(&mut rx, &tx, builders).await?; // if not downgraded if !x.2 { - write - .write_frame( - Packet::new_info( - builders - .iter() - .map(|x| x.build_to_extension(Role::Client)) - .collect(), - ) - .into(), + tx.write_frame( + Packet::new_info( + builders + .iter() + .map(|x| x.build_to_extension(Role::Client)) + .collect(), ) - .await?; + .into(), + ) + .await?; } x } else { (Vec::new(), None, true) }; - let (tx, rx) = mpsc::bounded::(256); + let supported_extension_ids = supported_extensions.iter().map(|x| x.get_id()).collect(); + + let (mux_inner, fut_exited, actor_tx) = MuxInner::new_client( + AppendingWebSocketRead(extra_packet, rx), + tx.clone(), + supported_extensions, + packet.buffer_remaining, + ); + Ok(ClientMuxResult( Self { - stream_tx: tx.clone(), + actor_tx, downgraded, - supported_extension_ids: supported_extensions - .iter() - .map(|x| x.get_id()) - .collect(), - tx: write.clone(), - fut_exited: fut_exited.clone(), - }, - MuxInner { - tx: write, - stream_map: HashMap::new(), - buffer_size: packet.buffer_remaining, - fut_exited, - } - .client_into_future( - AppendingWebSocketRead(extra_packet, read), - supported_extensions, - rx, + supported_extension_ids, tx, - ), + fut_exited, + }, + mux_inner.into_future(), )) } else { Err(WispError::InvalidPacketType) @@ -863,7 +490,7 @@ impl ClientMux { ])); } let (tx, rx) = oneshot::channel(); - self.stream_tx + self.actor_tx .send_async(WsEvent::CreateStream(stream_type, host, port, tx)) .await .map_err(|_| WispError::MuxMessageFailedToSend)?; @@ -874,7 +501,7 @@ impl ClientMux { if self.fut_exited.load(Ordering::Acquire) { return Err(WispError::MuxTaskEnded); } - self.stream_tx + self.actor_tx .send_async(WsEvent::EndFut(reason)) .await .map_err(|_| WispError::MuxMessageFailedToSend) @@ -907,7 +534,7 @@ impl ClientMux { impl Drop for ClientMux { fn drop(&mut self) { - let _ = self.stream_tx.send(WsEvent::EndFut(None)); + let _ = self.actor_tx.send(WsEvent::EndFut(None)); } } diff --git a/wisp/src/stream.rs b/wisp/src/stream.rs index bd0982d..90e3ffc 100644 --- a/wisp/src/stream.rs +++ b/wisp/src/stream.rs @@ -1,7 +1,5 @@ use crate::{ - sink_unfold, - ws::{Frame, LockedWebSocketWrite, Payload}, - AtomicCloseReason, CloseReason, Packet, Role, StreamType, WispError, + inner::WsEvent, sink_unfold, ws::{Frame, LockedWebSocketWrite, Payload}, AtomicCloseReason, CloseReason, Packet, Role, StreamType, WispError }; use bytes::{BufMut, Bytes, BytesMut}; @@ -23,17 +21,6 @@ use std::{ }, }; -pub(crate) enum WsEvent { - Close(Packet<'static>, oneshot::Sender>), - CreateStream( - StreamType, - String, - u16, - oneshot::Sender>, - ), - EndFut(Option), -} - /// Read side of a multiplexor stream. pub struct MuxStreamRead { /// ID of the stream.