From b3f35b232f3d655027384879556a4128cd89cfe7 Mon Sep 17 00:00:00 2001 From: Toshit Chawda Date: Sat, 27 Apr 2024 17:36:06 -0700 Subject: [PATCH] some optimizations and muxprotocolextensionstream for stream id 0 --- client/src/wrappers.rs | 7 +- simple-wisp-client/src/main.rs | 6 +- wisp/src/fastwebsockets.rs | 4 +- wisp/src/lib.rs | 245 ++++++++++++++++++++------------- wisp/src/packet.rs | 52 ++++--- wisp/src/stream.rs | 72 +++++----- wisp/src/ws.rs | 21 +-- 7 files changed, 237 insertions(+), 170 deletions(-) diff --git a/client/src/wrappers.rs b/client/src/wrappers.rs index 47ff2c3..02361ff 100644 --- a/client/src/wrappers.rs +++ b/client/src/wrappers.rs @@ -1,10 +1,9 @@ use crate::*; use std::{ - pin::Pin, - sync::atomic::{AtomicBool, Ordering}, - task::{Context, Poll}, + ops::Deref, pin::Pin, sync::atomic::{AtomicBool, Ordering}, task::{Context, Poll} }; +use bytes::BytesMut; use event_listener::Event; use futures_util::{FutureExt, Stream}; use hyper::body::Body; @@ -207,7 +206,7 @@ impl WebSocketRead for WebSocketReader { _ = self.close_event.listen().fuse() => Some(Closed), }; match res.ok_or(WispError::WsImplSocketClosed)? { - Message(bin) => Ok(Frame::binary(bin.into())), + Message(bin) => Ok(Frame::binary(BytesMut::from(bin.deref()))), Error => Err(WebSocketError::Unknown.into()), Closed => Err(WispError::WsImplSocketClosed), } diff --git a/simple-wisp-client/src/main.rs b/simple-wisp-client/src/main.rs index a478dc6..0452d65 100644 --- a/simple-wisp-client/src/main.rs +++ b/simple-wisp-client/src/main.rs @@ -225,14 +225,14 @@ async fn main() -> Result<(), Box> { interval.tick().await; let now = cnt_avg.get(); let stat = format!( - "sent &[0; 1024 * {}] cnt: {:?} ({} KiB), +{:?} ({} KiB / 100ms), moving average (10 s): {:?} ({} KiB / 10 s)", + "sent &[0; 1024 * {}] cnt: {:?} ({} KiB), +{:?} / 100ms ({} KiB / 1s), moving average (10 s): {:?} / 100ms ({} KiB / 1s)", opts.packet_size, now, now * opts.packet_size, now - last_time, - (now - last_time) * opts.packet_size, + (now - last_time) * opts.packet_size * 10, avg.get_average(), - avg.get_average() * opts.packet_size, + avg.get_average() * opts.packet_size * 10, ); if is_term { println!("\x1b[1A\x1b[2K{}\r", stat); diff --git a/wisp/src/fastwebsockets.rs b/wisp/src/fastwebsockets.rs index e05de91..f21cf75 100644 --- a/wisp/src/fastwebsockets.rs +++ b/wisp/src/fastwebsockets.rs @@ -30,7 +30,7 @@ impl From> for crate::ws::Frame { Self { finished: frame.fin, opcode: frame.opcode.into(), - payload: BytesMut::from(frame.payload.deref()).freeze(), + payload: BytesMut::from(frame.payload.deref()), } } } @@ -38,7 +38,7 @@ impl From> for crate::ws::Frame { impl<'a> From for Frame<'a> { fn from(frame: crate::ws::Frame) -> Self { use crate::ws::OpCode::*; - let payload = Payload::Owned(frame.payload.into()); + let payload = Payload::Bytes(frame.payload); match frame.opcode { Text => Self::text(payload), Binary => Self::binary(payload), diff --git a/wisp/src/lib.rs b/wisp/src/lib.rs index 1b88da8..49ee3fc 100644 --- a/wisp/src/lib.rs +++ b/wisp/src/lib.rs @@ -29,7 +29,7 @@ use std::{ }, time::Duration, }; -use ws::AppendingWebSocketRead; +use ws::{AppendingWebSocketRead, LockedWebSocketWrite}; /// Wisp version supported by this crate. pub const WISP_VERSION: WispVersion = WispVersion { major: 2, minor: 0 }; @@ -92,6 +92,8 @@ pub enum WispError { MuxMessageFailedToSend, /// Failed to receive message from multiplexor task. MuxMessageFailedToRecv, + /// Multiplexor task ended. + MuxTaskEnded, } impl From for WispError { @@ -145,6 +147,7 @@ impl std::fmt::Display for WispError { Self::Other(err) => write!(f, "Other error: {}", err), Self::MuxMessageFailedToSend => write!(f, "Failed to send multiplexor message"), Self::MuxMessageFailedToRecv => write!(f, "Failed to receive multiplexor message"), + Self::MuxTaskEnded => write!(f, "Multiplexor task ended"), } } } @@ -164,6 +167,7 @@ struct MuxInner { tx: ws::LockedWebSocketWrite, stream_map: DashMap, buffer_size: u32, + fut_exited: Arc } impl MuxInner { @@ -210,6 +214,7 @@ impl MuxInner { _ = self.stream_loop(close_rx, close_tx).fuse() => Ok(()), x = wisp_fut.fuse() => x, }; + self.fut_exited.store(true, Ordering::Release); for x in self.stream_map.iter_mut() { x.is_closed.store(true, Ordering::Release); x.is_closed_event.notify(usize::MAX); @@ -225,6 +230,7 @@ impl MuxInner { stream_type: StreamType, role: Role, stream_tx: mpsc::Sender, + tx: LockedWebSocketWrite, target_buffer_size: u32, ) -> Result<(MuxMapValue, MuxStream), WispError> { let (ch_tx, ch_rx) = mpsc::bounded(self.buffer_size as usize); @@ -249,7 +255,8 @@ impl MuxInner { role, stream_type, ch_rx, - stream_tx.clone(), + stream_tx, + tx, is_closed, is_closed_event, flow_control, @@ -267,16 +274,6 @@ impl MuxInner { let mut next_free_stream_id: u32 = 1; while let Ok(msg) = stream_rx.recv_async().await { match msg { - WsEvent::SendPacket(packet, channel) => { - if self.stream_map.get(&packet.stream_id).is_some() { - let _ = channel.send(self.tx.write_frame(packet.into()).await); - } else { - let _ = channel.send(Err(WispError::InvalidStreamId)); - } - } - WsEvent::SendBytes(packet, channel) => { - let _ = channel.send(self.tx.write_frame(ws::Frame::binary(packet)).await); - } WsEvent::CreateStream(stream_type, host, port, channel) => { let ret: Result = async { let stream_id = next_free_stream_id; @@ -290,6 +287,7 @@ impl MuxInner { stream_type, Role::Client, stream_tx.clone(), + self.tx.clone(), 0, ) .await?; @@ -330,6 +328,16 @@ impl MuxInner { } } + fn close_stream(&self, packet: Packet) { + if let Some((_, stream)) = self.stream_map.remove(&packet.stream_id) { + stream.is_closed.store(true, Ordering::Release); + stream.is_closed_event.notify(usize::MAX); + stream.flow_control.store(u32::MAX, Ordering::Release); + stream.flow_control_event.notify(usize::MAX); + drop(stream.stream) + } + } + async fn server_loop( &self, mut rx: R, @@ -353,6 +361,7 @@ impl MuxInner { { use PacketType::*; match packet.packet_type { + Continue(_) | Info(_) => break Err(WispError::InvalidPacketType), Connect(inner_packet) => { let (map_value, stream) = self .create_new_stream( @@ -360,6 +369,7 @@ impl MuxInner { inner_packet.stream_type, Role::Server, stream_tx.clone(), + self.tx.clone(), target_buffer_size, ) .await?; @@ -383,16 +393,11 @@ impl MuxInner { } } } - Continue(_) | Info(_) => break Err(WispError::InvalidPacketType), Close(_) => { if packet.stream_id == 0 { break Ok(()); } - if let Some((_, stream)) = self.stream_map.remove(&packet.stream_id) { - stream.is_closed.store(true, Ordering::Release); - stream.is_closed_event.notify(usize::MAX); - drop(stream.stream) - } + self.close_stream(packet) } } } @@ -437,11 +442,7 @@ impl MuxInner { if packet.stream_id == 0 { break Ok(()); } - if let Some((_, stream)) = self.stream_map.remove(&packet.stream_id) { - stream.is_closed.store(true, Ordering::Release); - stream.is_closed_event.notify(usize::MAX); - drop(stream.stream) - } + self.close_stream(packet) } } } @@ -449,6 +450,42 @@ impl MuxInner { } } +async fn maybe_wisp_v2( + read: &mut R, + write: &LockedWebSocketWrite, + builders: &[Box], +) -> Result<(Vec, Option, bool), WispError> +where + R: ws::WebSocketRead + Send, +{ + let mut supported_extensions = Vec::new(); + let mut extra_packet = None; + let mut downgraded = true; + + let extension_ids: Vec<_> = builders.iter().map(|x| x.get_id()).collect(); + if let Some(frame) = select! { + x = read.wisp_read_frame(write).fuse() => Some(x?), + _ = Delay::new(Duration::from_secs(5)).fuse() => None + } { + let packet = Packet::maybe_parse_info(frame, Role::Client, builders)?; + if let PacketType::Info(info) = packet.packet_type { + supported_extensions = info + .extensions + .into_iter() + .filter(|x| extension_ids.contains(&x.get_id())) + .collect(); + downgraded = false; + } else { + extra_packet.replace(packet.into()); + } + } + + for extension in supported_extensions.iter_mut() { + extension.handle_handshake(read, write).await?; + } + Ok((supported_extensions, extra_packet, downgraded)) +} + /// Server-side multiplexor. /// /// # Example @@ -477,6 +514,8 @@ pub struct ServerMux { pub supported_extension_ids: Vec, close_tx: mpsc::Sender, muxstream_recv: mpsc::Receiver<(ConnectPacket, MuxStream)>, + tx: ws::LockedWebSocketWrite, + fut_exited: Arc, } impl ServerMux { @@ -498,41 +537,29 @@ impl ServerMux { let (close_tx, close_rx) = mpsc::bounded::(256); let (tx, rx) = mpsc::unbounded::<(ConnectPacket, MuxStream)>(); let write = ws::LockedWebSocketWrite::new(Box::new(write)); + let fut_exited = Arc::new(AtomicBool::new(false)); write .write_frame(Packet::new_continue(0, buffer_size).into()) .await?; - let mut supported_extensions = Vec::new(); - let mut extra_packet = Vec::with_capacity(1); - let mut downgraded = true; - - if let Some(builders) = extension_builders { - let extensions: Vec<_> = builders - .iter() - .map(|x| x.build_to_extension(Role::Server)) - .collect(); - let extension_ids: Vec<_> = extensions.iter().map(|x| x.get_id()).collect(); - write - .write_frame(Packet::new_info(extensions).into()) - .await?; - if let Some(frame) = select! { - x = read.wisp_read_frame(&write).fuse() => Some(x?), - _ = Delay::new(Duration::from_secs(5)).fuse() => None - } { - let packet = Packet::maybe_parse_info(frame, Role::Server, builders)?; - if let PacketType::Info(info) = packet.packet_type { - supported_extensions = info - .extensions - .into_iter() - .filter(|x| extension_ids.contains(&x.get_id())) - .collect(); - downgraded = false; - } else { - extra_packet.push(packet.into()); - } - } - } + let (supported_extensions, extra_packet, downgraded) = + if let Some(builders) = extension_builders { + write + .write_frame( + Packet::new_info( + builders + .iter() + .map(|x| x.build_to_extension(Role::Client)) + .collect(), + ) + .into(), + ) + .await?; + maybe_wisp_v2(&mut read, &write, builders).await? + } else { + (Vec::new(), None, true) + }; Ok(ServerMuxResult( Self { @@ -540,11 +567,14 @@ impl ServerMux { close_tx: close_tx.clone(), downgraded, supported_extension_ids: supported_extensions.iter().map(|x| x.get_id()).collect(), + tx: write.clone(), + fut_exited: fut_exited.clone(), }, MuxInner { tx: write, stream_map: DashMap::new(), buffer_size, + fut_exited } .server_into_future( AppendingWebSocketRead(extra_packet, read), @@ -558,10 +588,16 @@ impl ServerMux { /// Wait for a stream to be created. pub async fn server_new_stream(&self) -> Option<(ConnectPacket, MuxStream)> { + if self.fut_exited.load(Ordering::Acquire) { + return None; + } self.muxstream_recv.recv_async().await.ok() } async fn close_internal(&self, reason: Option) -> Result<(), WispError> { + if self.fut_exited.load(Ordering::Acquire) { + return Err(WispError::MuxTaskEnded); + } self.close_tx .send_async(WsEvent::EndFut(reason)) .await @@ -570,20 +606,27 @@ impl ServerMux { /// Close all streams. /// - /// Also terminates the multiplexor future. Waiting for a new stream will never succeed after - /// this function is called. + /// Also terminates the multiplexor future. pub async fn close(&self) -> Result<(), WispError> { self.close_internal(None).await } /// Close all streams and send an extension incompatibility error to the client. /// - /// Also terminates the multiplexor future. Waiting for a new stream will never succed after - /// this function is called. + /// Also terminates the multiplexor future. pub async fn close_extension_incompat(&self) -> Result<(), WispError> { self.close_internal(Some(CloseReason::IncompatibleExtensions)) .await } + + /// Get a protocol extension stream for sending packets with stream id 0. + pub fn get_protocol_extension_stream(&self) -> MuxProtocolExtensionStream { + MuxProtocolExtensionStream { + stream_id: 0, + tx: self.tx.clone(), + is_closed: self.fut_exited.clone(), + } + } } impl Drop for ServerMux { @@ -656,6 +699,8 @@ pub struct ClientMux { /// Extensions that are supported by both sides. pub supported_extension_ids: Vec, stream_tx: mpsc::Sender, + tx: ws::LockedWebSocketWrite, + fut_exited: Arc, } impl ClientMux { @@ -675,44 +720,30 @@ impl ClientMux { { let write = ws::LockedWebSocketWrite::new(Box::new(write)); let first_packet = Packet::try_from(read.wisp_read_frame(&write).await?)?; + let fut_exited = Arc::new(AtomicBool::new(false)); + if first_packet.stream_id != 0 { return Err(WispError::InvalidStreamId); } if let PacketType::Continue(packet) = first_packet.packet_type { - let mut supported_extensions = Vec::new(); - let mut extra_packet = Vec::with_capacity(1); - let mut downgraded = true; - - if let Some(builders) = extension_builders { - let extensions: Vec<_> = builders - .iter() - .map(|x| x.build_to_extension(Role::Client)) - .collect(); - let extension_ids: Vec<_> = extensions.iter().map(|x| x.get_id()).collect(); - if let Some(frame) = select! { - x = read.wisp_read_frame(&write).fuse() => Some(x?), - _ = Delay::new(Duration::from_secs(5)).fuse() => None - } { - let packet = Packet::maybe_parse_info(frame, Role::Client, builders)?; - if let PacketType::Info(info) = packet.packet_type { - supported_extensions = info - .extensions - .into_iter() - .filter(|x| extension_ids.contains(&x.get_id())) - .collect(); - write - .write_frame(Packet::new_info(extensions).into()) - .await?; - downgraded = false; - } else { - extra_packet.push(packet.into()); - } - } - } - - for extension in supported_extensions.iter_mut() { - extension.handle_handshake(&mut read, &write).await?; - } + let (supported_extensions, extra_packet, downgraded) = + if let Some(builders) = extension_builders { + let x = maybe_wisp_v2(&mut read, &write, builders).await?; + write + .write_frame( + Packet::new_info( + builders + .iter() + .map(|x| x.build_to_extension(Role::Client)) + .collect(), + ) + .into(), + ) + .await?; + x + } else { + (Vec::new(), None, true) + }; let (tx, rx) = mpsc::bounded::(256); Ok(ClientMuxResult( @@ -723,11 +754,14 @@ impl ClientMux { .iter() .map(|x| x.get_id()) .collect(), + tx: write.clone(), + fut_exited: fut_exited.clone(), }, MuxInner { tx: write, stream_map: DashMap::new(), buffer_size: packet.buffer_remaining, + fut_exited } .client_into_future( AppendingWebSocketRead(extra_packet, read), @@ -748,6 +782,9 @@ impl ClientMux { host: String, port: u16, ) -> Result { + if self.fut_exited.load(Ordering::Acquire) { + return Err(WispError::MuxTaskEnded); + } if stream_type == StreamType::Udp && !self .supported_extension_ids @@ -767,6 +804,9 @@ impl ClientMux { } async fn close_internal(&self, reason: Option) -> Result<(), WispError> { + if self.fut_exited.load(Ordering::Acquire) { + return Err(WispError::MuxTaskEnded); + } self.stream_tx .send_async(WsEvent::EndFut(reason)) .await @@ -775,20 +815,27 @@ impl ClientMux { /// Close all streams. /// - /// Also terminates the multiplexor future. Creating a stream is UB after calling this - /// function. + /// Also terminates the multiplexor future. pub async fn close(&self) -> Result<(), WispError> { self.close_internal(None).await } /// Close all streams and send an extension incompatibility error to the client. /// - /// Also terminates the multiplexor future. Creating a stream is UB after calling this - /// function. + /// Also terminates the multiplexor future. pub async fn close_extension_incompat(&self) -> Result<(), WispError> { self.close_internal(Some(CloseReason::IncompatibleExtensions)) .await } + + /// Get a protocol extension stream for sending packets with stream id 0. + pub fn get_protocol_extension_stream(&self) -> MuxProtocolExtensionStream { + MuxProtocolExtensionStream { + stream_id: 0, + tx: self.tx.clone(), + is_closed: self.fut_exited.clone(), + } + } } impl Drop for ClientMux { @@ -812,7 +859,10 @@ where } /// Require protocol extensions by their ID. - pub async fn with_required_extensions(self, extensions: &[u8]) -> Result<(ClientMux, F), WispError> { + pub async fn with_required_extensions( + self, + extensions: &[u8], + ) -> Result<(ClientMux, F), WispError> { let mut unsupported_extensions = Vec::new(); for extension in extensions { if !self.0.supported_extension_ids.contains(extension) { @@ -830,6 +880,7 @@ where /// Shorthand for `with_required_extensions(&[UdpProtocolExtension::ID])` pub async fn with_udp_extension_required(self) -> Result<(ClientMux, F), WispError> { - self.with_required_extensions(&[UdpProtocolExtension::ID]).await + self.with_required_extensions(&[UdpProtocolExtension::ID]) + .await } } diff --git a/wisp/src/packet.rs b/wisp/src/packet.rs index ec933bd..62ee9f1 100644 --- a/wisp/src/packet.rs +++ b/wisp/src/packet.rs @@ -362,12 +362,12 @@ impl Packet { } } - pub(crate) fn raw_encode(packet_type: u8, stream_id: u32, bytes: Bytes) -> Bytes { + pub(crate) fn raw_encode(packet_type: u8, stream_id: u32, bytes: Bytes) -> BytesMut { let mut encoded = BytesMut::with_capacity(1 + 4 + bytes.len()); encoded.put_u8(packet_type); encoded.put_u32_le(stream_id); encoded.extend(bytes); - encoded.freeze() + encoded } fn parse_packet(packet_type: u8, mut bytes: Bytes) -> Result { @@ -396,7 +396,7 @@ impl Packet { if frame.opcode != OpCode::Binary { return Err(WispError::WsFrameInvalidType); } - let mut bytes = frame.payload; + let mut bytes = frame.payload.freeze(); if bytes.remaining() < 1 { return Err(WispError::PacketTooSmall); } @@ -420,22 +420,40 @@ impl Packet { if frame.opcode != OpCode::Binary { return Err(WispError::WsFrameInvalidType); } - let mut bytes = frame.payload; + let mut bytes = frame.payload.freeze(); if bytes.remaining() < 1 { return Err(WispError::PacketTooSmall); } let packet_type = bytes.get_u8(); - if let Some(extension) = extensions - .iter_mut() - .find(|x| x.get_supported_packets().iter().any(|x| *x == packet_type)) - { - extension.handle_packet(bytes, read, write).await?; - Ok(None) - } else if packet_type == 0x05 { - // Server may send a 0x05 in handshake since it's Wisp v2 but we may be Wisp v1 so we need to ignore 0x05 - Ok(None) - } else { - Ok(Some(Self::parse_packet(packet_type, bytes)?)) + match packet_type { + 0x01 => Ok(Some(Self { + stream_id: bytes.get_u32_le(), + packet_type: PacketType::Connect(bytes.try_into()?), + })), + 0x02 => Ok(Some(Self { + stream_id: bytes.get_u32_le(), + packet_type: PacketType::Data(bytes), + })), + 0x03 => Ok(Some(Self { + stream_id: bytes.get_u32_le(), + packet_type: PacketType::Continue(bytes.try_into()?), + })), + 0x04 => Ok(Some(Self { + stream_id: bytes.get_u32_le(), + packet_type: PacketType::Close(bytes.try_into()?), + })), + 0x05 => Ok(None), + packet_type => { + if let Some(extension) = extensions + .iter_mut() + .find(|x| x.get_supported_packets().iter().any(|x| *x == packet_type)) + { + extension.handle_packet(bytes, read, write).await?; + Ok(None) + } else { + Err(WispError::InvalidPacketType) + } + } } } @@ -500,7 +518,7 @@ impl TryFrom for Packet { } } -impl From for Bytes { +impl From for BytesMut { fn from(packet: Packet) -> Self { Packet::raw_encode( packet.packet_type.as_u8(), @@ -519,7 +537,7 @@ impl TryFrom for Packet { if frame.opcode != ws::OpCode::Binary { return Err(Self::Error::WsFrameInvalidType); } - frame.payload.try_into() + frame.payload.freeze().try_into() } } diff --git a/wisp/src/stream.rs b/wisp/src/stream.rs index a7099c9..3918edb 100644 --- a/wisp/src/stream.rs +++ b/wisp/src/stream.rs @@ -1,4 +1,8 @@ -use crate::{sink_unfold, CloseReason, Packet, Role, StreamType, WispError}; +use crate::{ + sink_unfold, + ws::{Frame, LockedWebSocketWrite}, + CloseReason, Packet, Role, StreamType, WispError, +}; pub use async_io_stream::IoStream; use bytes::Bytes; @@ -20,8 +24,6 @@ use std::{ }; pub(crate) enum WsEvent { - SendPacket(Packet, oneshot::Sender>), - SendBytes(Bytes, oneshot::Sender>), Close(Packet, oneshot::Sender>), CreateStream( StreamType, @@ -39,7 +41,7 @@ pub struct MuxStreamRead { /// Type of the stream. pub stream_type: StreamType, role: Role, - tx: mpsc::Sender, + tx: LockedWebSocketWrite, rx: mpsc::Receiver, is_closed: Arc, is_closed_event: Arc, @@ -60,19 +62,17 @@ impl MuxStreamRead { }; if self.role == Role::Server && self.stream_type == StreamType::Tcp { let val = self.flow_control_read.fetch_add(1, Ordering::AcqRel) + 1; - if val > self.target_flow_control { - let (tx, rx) = oneshot::channel::>(); + if val > self.target_flow_control && !self.is_closed.load(Ordering::Acquire) { self.tx - .send_async(WsEvent::SendPacket( + .write_frame( Packet::new_continue( self.stream_id, self.flow_control.fetch_add(val, Ordering::AcqRel) + val, - ), - tx, - )) + ) + .into(), + ) .await .ok()?; - rx.await.ok()?.ok()?; self.flow_control_read.store(0, Ordering::Release); } } @@ -93,7 +93,8 @@ pub struct MuxStreamWrite { /// Type of the stream. pub stream_type: StreamType, role: Role, - tx: mpsc::Sender, + mux_tx: mpsc::Sender, + tx: LockedWebSocketWrite, is_closed: Arc, continue_recieved: Arc, flow_control: Arc, @@ -102,24 +103,20 @@ pub struct MuxStreamWrite { impl MuxStreamWrite { /// Write data to the stream. pub async fn write(&self, data: Bytes) -> Result<(), WispError> { - if self.is_closed.load(Ordering::Acquire) { - return Err(WispError::StreamAlreadyClosed); - } if self.role == Role::Client && self.stream_type == StreamType::Tcp && self.flow_control.load(Ordering::Acquire) == 0 { self.continue_recieved.listen().await; } - let (tx, rx) = oneshot::channel::>(); + if self.is_closed.load(Ordering::Acquire) { + return Err(WispError::StreamAlreadyClosed); + } + self.tx - .send_async(WsEvent::SendPacket( - Packet::new_data(self.stream_id, data), - tx, - )) - .await - .map_err(|_| WispError::MuxMessageFailedToSend)?; - rx.await.map_err(|_| WispError::MuxMessageFailedToRecv)??; + .write_frame(Packet::new_data(self.stream_id, data).into()) + .await?; + if self.role == Role::Client && self.stream_type == StreamType::Tcp { self.flow_control.store( self.flow_control.load(Ordering::Acquire).saturating_sub(1), @@ -143,7 +140,7 @@ impl MuxStreamWrite { pub fn get_close_handle(&self) -> MuxStreamCloser { MuxStreamCloser { stream_id: self.stream_id, - close_channel: self.tx.clone(), + close_channel: self.mux_tx.clone(), is_closed: self.is_closed.clone(), } } @@ -165,7 +162,7 @@ impl MuxStreamWrite { self.is_closed.store(true, Ordering::Release); let (tx, rx) = oneshot::channel::>(); - self.tx + self.mux_tx .send_async(WsEvent::Close( Packet::new_close(self.stream_id, reason), tx, @@ -199,7 +196,7 @@ impl Drop for MuxStreamWrite { if !self.is_closed.load(Ordering::Acquire) { self.is_closed.store(true, Ordering::Release); let (tx, _) = oneshot::channel(); - let _ = self.tx.send(WsEvent::Close( + let _ = self.mux_tx.send(WsEvent::Close( Packet::new_close(self.stream_id, CloseReason::Unknown), tx, )); @@ -222,7 +219,8 @@ impl MuxStream { role: Role, stream_type: StreamType, rx: mpsc::Receiver, - tx: mpsc::Sender, + mux_tx: mpsc::Sender, + tx: LockedWebSocketWrite, is_closed: Arc, is_closed_event: Arc, flow_control: Arc, @@ -247,6 +245,7 @@ impl MuxStream { stream_id, stream_type, role, + mux_tx, tx, is_closed: is_closed.clone(), flow_control: flow_control.clone(), @@ -339,26 +338,23 @@ impl MuxStreamCloser { pub struct MuxProtocolExtensionStream { /// ID of the stream. pub stream_id: u32, - tx: mpsc::Sender, - is_closed: Arc, + pub(crate) tx: LockedWebSocketWrite, + pub(crate) is_closed: Arc, } impl MuxProtocolExtensionStream { - /// Send a protocol extension packet. + /// Send a protocol extension packet with this stream's ID. pub async fn send(&self, packet_type: u8, data: Bytes) -> Result<(), WispError> { if self.is_closed.load(Ordering::Acquire) { return Err(WispError::StreamAlreadyClosed); } - let (tx, rx) = oneshot::channel::>(); self.tx - .send_async(WsEvent::SendBytes( - Packet::raw_encode(packet_type, self.stream_id, data), - tx, - )) + .write_frame(Frame::binary(Packet::raw_encode( + packet_type, + self.stream_id, + data, + ))) .await - .map_err(|_| WispError::MuxMessageFailedToSend)?; - rx.await.map_err(|_| WispError::MuxMessageFailedToRecv)??; - Ok(()) } } diff --git a/wisp/src/ws.rs b/wisp/src/ws.rs index 258a5d1..06f55ad 100644 --- a/wisp/src/ws.rs +++ b/wisp/src/ws.rs @@ -4,9 +4,11 @@ //! for other WebSocket implementations. //! //! [`fastwebsockets`]: https://github.com/MercuryWorkshop/epoxy-tls/blob/multiplexed/wisp/src/fastwebsockets.rs +use std::sync::Arc; + use crate::WispError; use async_trait::async_trait; -use bytes::Bytes; +use bytes::BytesMut; use futures::lock::Mutex; /// Opcode of the WebSocket frame. @@ -32,12 +34,12 @@ pub struct Frame { /// Opcode of the WebSocket frame. pub opcode: OpCode, /// Payload of the WebSocket frame. - pub payload: Bytes, + pub payload: BytesMut, } impl Frame { /// Create a new text frame. - pub fn text(payload: Bytes) -> Self { + pub fn text(payload: BytesMut) -> Self { Self { finished: true, opcode: OpCode::Text, @@ -46,7 +48,7 @@ impl Frame { } /// Create a new binary frame. - pub fn binary(payload: Bytes) -> Self { + pub fn binary(payload: BytesMut) -> Self { Self { finished: true, opcode: OpCode::Binary, @@ -55,7 +57,7 @@ impl Frame { } /// Create a new close frame. - pub fn close(payload: Bytes) -> Self { + pub fn close(payload: BytesMut) -> Self { Self { finished: true, opcode: OpCode::Close, @@ -82,12 +84,13 @@ pub trait WebSocketWrite { } /// Locked WebSocket. -pub struct LockedWebSocketWrite(Mutex>); +#[derive(Clone)] +pub struct LockedWebSocketWrite(Arc>>); impl LockedWebSocketWrite { /// Create a new locked websocket. pub fn new(ws: Box) -> Self { - Self(Mutex::new(ws)) + Self(Mutex::new(ws).into()) } /// Write a frame to the websocket. @@ -101,7 +104,7 @@ impl LockedWebSocketWrite { } } -pub(crate) struct AppendingWebSocketRead(pub Vec, pub R) +pub(crate) struct AppendingWebSocketRead(pub Option, pub R) where R: WebSocketRead + Send; @@ -111,7 +114,7 @@ where R: WebSocketRead + Send, { async fn wisp_read_frame(&mut self, tx: &LockedWebSocketWrite) -> Result { - if let Some(x) = self.0.pop() { + if let Some(x) = self.0.take() { return Ok(x); } return self.1.wisp_read_frame(tx).await;