remove appendingwebsocketread, specialcase data/close

This commit is contained in:
Toshit Chawda 2024-11-04 21:33:40 -08:00
parent 0d12bff084
commit 14f38b28b8
No known key found for this signature in database
GPG key ID: 91480ED99E2B3D9D
5 changed files with 56 additions and 59 deletions

View file

@ -12,7 +12,7 @@ use futures::channel::oneshot;
use crate::{
extensions::{udp::UdpProtocolExtension, AnyProtocolExtension},
mux::send_info_packet,
ws::{AppendingWebSocketRead, LockedWebSocketWrite, Payload, WebSocketRead, WebSocketWrite},
ws::{LockedWebSocketWrite, Payload, WebSocketRead, WebSocketWrite},
CloseReason, MuxProtocolExtensionStream, MuxStream, Packet, PacketType, Role, StreamType,
WispError,
};
@ -110,10 +110,11 @@ impl ClientMux {
let tx = LockedWebSocketWrite::new(Box::new(tx));
let (handshake_result, buffer_size) = handshake(&mut rx, &tx, wisp_v2).await?;
let (extensions, frame) = handshake_result.kind.into_parts();
let (extensions, extra_packet) = handshake_result.kind.into_parts();
let mux_inner = MuxInner::new_client(
AppendingWebSocketRead(frame, rx),
rx,
extra_packet,
tx.clone(),
extensions.clone(),
buffer_size,

View file

@ -46,6 +46,9 @@ struct MuxMapValue {
pub struct MuxInner<R: WebSocketRead + Send> {
// gets taken by the mux task
rx: Option<R>,
// gets taken by the mux task
maybe_downgrade_packet: Option<Packet<'static>>,
tx: LockedWebSocketWrite,
extensions: Vec<AnyProtocolExtension>,
tcp_extensions: Vec<u8>,
@ -82,6 +85,7 @@ impl<R: WebSocketRead + Send> MuxInner<R> {
pub fn new_server(
rx: R,
maybe_downgrade_packet: Option<Packet<'static>>,
tx: LockedWebSocketWrite,
extensions: Vec<AnyProtocolExtension>,
buffer_size: u32,
@ -98,6 +102,7 @@ impl<R: WebSocketRead + Send> MuxInner<R> {
MuxInnerResult {
mux: Self {
rx: Some(rx),
maybe_downgrade_packet,
tx,
actor_rx: Some(fut_rx),
@ -124,6 +129,7 @@ impl<R: WebSocketRead + Send> MuxInner<R> {
pub fn new_client(
rx: R,
maybe_downgrade_packet: Option<Packet<'static>>,
tx: LockedWebSocketWrite,
extensions: Vec<AnyProtocolExtension>,
buffer_size: u32,
@ -136,6 +142,7 @@ impl<R: WebSocketRead + Send> MuxInner<R> {
MuxInnerResult {
mux: Self {
rx: Some(rx),
maybe_downgrade_packet,
tx,
actor_rx: Some(fut_rx),
@ -265,9 +272,17 @@ impl<R: WebSocketRead + Send> MuxInner<R> {
let mut next_free_stream_id: u32 = 1;
let mut rx = self.rx.take().ok_or(WispError::MuxTaskStarted)?;
let maybe_downgrade_packet = self.maybe_downgrade_packet.take();
let tx = self.tx.clone();
let fut_rx = self.actor_rx.take().ok_or(WispError::MuxTaskStarted)?;
if let Some(downgrade_packet) = maybe_downgrade_packet {
if self.handle_packet(downgrade_packet, None).await? {
return Ok(());
}
}
let mut recv_fut = fut_rx.recv_async().fuse();
let mut read_fut = rx.wisp_read_split(&tx).fuse();
while let Some(msg) = select! {
@ -342,14 +357,7 @@ impl<R: WebSocketRead + Send> MuxInner<R> {
}
WsEvent::WispMessage(packet, optional_frame) => {
if let Some(packet) = packet {
let should_break = match self.role {
Role::Server => {
self.server_handle_packet(packet, optional_frame).await?
}
Role::Client => {
self.client_handle_packet(packet, optional_frame).await?
}
};
let should_break = self.handle_packet(packet, optional_frame).await?;
if should_break {
break;
}
@ -409,18 +417,31 @@ impl<R: WebSocketRead + Send> MuxInner<R> {
Ok(false)
}
async fn server_handle_packet(
async fn handle_packet(
&mut self,
packet: Packet<'static>,
optional_frame: Option<Frame<'static>>,
) -> Result<bool, WispError> {
use PacketType::*;
use PacketType as P;
match packet.packet_type {
Continue(_) | Info(_) => Err(WispError::InvalidPacketType),
Data(data) => self.handle_data_packet(packet.stream_id, optional_frame, data),
Close(inner_packet) => self.handle_close_packet(packet.stream_id, inner_packet),
P::Data(data) => self.handle_data_packet(packet.stream_id, optional_frame, data),
P::Close(inner_packet) => self.handle_close_packet(packet.stream_id, inner_packet),
Connect(inner_packet) => {
_ => match self.role {
Role::Server => self.server_handle_packet(packet, optional_frame).await,
Role::Client => self.client_handle_packet(packet, optional_frame).await,
},
}
}
async fn server_handle_packet(
&mut self,
packet: Packet<'static>,
_optional_frame: Option<Frame<'static>>,
) -> Result<bool, WispError> {
use PacketType as P;
match packet.packet_type {
P::Connect(inner_packet) => {
let (map_value, stream) = self
.create_new_stream(packet.stream_id, inner_packet.stream_type)
.await?;
@ -431,21 +452,21 @@ impl<R: WebSocketRead + Send> MuxInner<R> {
self.stream_map.insert(packet.stream_id, map_value);
Ok(false)
}
// Continue | Info => invalid packet type
// Data | Close => specialcased
_ => Err(WispError::InvalidPacketType),
}
}
async fn client_handle_packet(
&mut self,
packet: Packet<'static>,
optional_frame: Option<Frame<'static>>,
_optional_frame: Option<Frame<'static>>,
) -> Result<bool, WispError> {
use PacketType::*;
use PacketType as P;
match packet.packet_type {
Connect(_) | Info(_) => Err(WispError::InvalidPacketType),
Data(data) => self.handle_data_packet(packet.stream_id, optional_frame, data),
Close(inner_packet) => self.handle_close_packet(packet.stream_id, inner_packet),
Continue(inner_packet) => {
P::Continue(inner_packet) => {
if let Some(stream) = self.stream_map.get(&packet.stream_id) {
if stream.stream_type == StreamType::Tcp {
stream
@ -456,6 +477,10 @@ impl<R: WebSocketRead + Send> MuxInner<R> {
}
Ok(false)
}
// Connect | Info => invalid packet type
// Data | Close => specialcased
_ => Err(WispError::InvalidPacketType),
}
}
}

View file

@ -8,7 +8,7 @@ pub use server::ServerMux;
use crate::{
extensions::{udp::UdpProtocolExtension, AnyProtocolExtension, AnyProtocolExtensionBuilder},
ws::{Frame, LockedWebSocketWrite},
ws::LockedWebSocketWrite,
CloseReason, Packet, PacketType, Role, WispError,
};
@ -22,12 +22,12 @@ enum WispHandshakeResultKind {
extensions: Vec<AnyProtocolExtension>,
},
V1 {
frame: Option<Frame<'static>>,
frame: Option<Packet<'static>>,
},
}
impl WispHandshakeResultKind {
pub fn into_parts(self) -> (Vec<AnyProtocolExtension>, Option<Frame<'static>>) {
pub fn into_parts(self) -> (Vec<AnyProtocolExtension>, Option<Packet<'static>>) {
match self {
Self::V2 { extensions } => (extensions, None),
Self::V1 { frame } => (vec![UdpProtocolExtension.into()], frame),

View file

@ -11,7 +11,7 @@ use futures::channel::oneshot;
use crate::{
extensions::AnyProtocolExtension,
ws::{AppendingWebSocketRead, LockedWebSocketWrite, Payload, WebSocketRead, WebSocketWrite},
ws::{LockedWebSocketWrite, Payload, WebSocketRead, WebSocketWrite},
CloseReason, ConnectPacket, MuxProtocolExtensionStream, MuxStream, Packet, PacketType, Role,
WispError,
};
@ -109,7 +109,8 @@ impl ServerMux {
let (extensions, extra_packet) = handshake_result.kind.into_parts();
let (mux_result, muxstream_recv) = MuxInner::new_server(
AppendingWebSocketRead(extra_packet, rx),
rx,
extra_packet,
tx.clone(),
extensions.clone(),
buffer_size,

View file

@ -261,33 +261,3 @@ impl LockedWebSocketWrite {
self.0.lock().await.wisp_close().await
}
}
pub(crate) struct AppendingWebSocketRead<R>(pub Option<Frame<'static>>, pub R)
where
R: WebSocketRead + Send;
#[async_trait]
impl<R> WebSocketRead for AppendingWebSocketRead<R>
where
R: WebSocketRead + Send,
{
async fn wisp_read_frame(
&mut self,
tx: &LockedWebSocketWrite,
) -> Result<Frame<'static>, WispError> {
if let Some(x) = self.0.take() {
return Ok(x);
}
self.1.wisp_read_frame(tx).await
}
async fn wisp_read_split(
&mut self,
tx: &LockedWebSocketWrite,
) -> Result<(Frame<'static>, Option<Frame<'static>>), WispError> {
if let Some(x) = self.0.take() {
return Ok((x, None));
}
self.1.wisp_read_split(tx).await
}
}