mirror of
https://github.com/MercuryWorkshop/epoxy-tls.git
synced 2025-05-12 14:00:01 -04:00
remove appendingwebsocketread, specialcase data/close
This commit is contained in:
parent
0d12bff084
commit
14f38b28b8
5 changed files with 56 additions and 59 deletions
|
@ -12,7 +12,7 @@ use futures::channel::oneshot;
|
||||||
use crate::{
|
use crate::{
|
||||||
extensions::{udp::UdpProtocolExtension, AnyProtocolExtension},
|
extensions::{udp::UdpProtocolExtension, AnyProtocolExtension},
|
||||||
mux::send_info_packet,
|
mux::send_info_packet,
|
||||||
ws::{AppendingWebSocketRead, LockedWebSocketWrite, Payload, WebSocketRead, WebSocketWrite},
|
ws::{LockedWebSocketWrite, Payload, WebSocketRead, WebSocketWrite},
|
||||||
CloseReason, MuxProtocolExtensionStream, MuxStream, Packet, PacketType, Role, StreamType,
|
CloseReason, MuxProtocolExtensionStream, MuxStream, Packet, PacketType, Role, StreamType,
|
||||||
WispError,
|
WispError,
|
||||||
};
|
};
|
||||||
|
@ -110,10 +110,11 @@ impl ClientMux {
|
||||||
let tx = LockedWebSocketWrite::new(Box::new(tx));
|
let tx = LockedWebSocketWrite::new(Box::new(tx));
|
||||||
|
|
||||||
let (handshake_result, buffer_size) = handshake(&mut rx, &tx, wisp_v2).await?;
|
let (handshake_result, buffer_size) = handshake(&mut rx, &tx, wisp_v2).await?;
|
||||||
let (extensions, frame) = handshake_result.kind.into_parts();
|
let (extensions, extra_packet) = handshake_result.kind.into_parts();
|
||||||
|
|
||||||
let mux_inner = MuxInner::new_client(
|
let mux_inner = MuxInner::new_client(
|
||||||
AppendingWebSocketRead(frame, rx),
|
rx,
|
||||||
|
extra_packet,
|
||||||
tx.clone(),
|
tx.clone(),
|
||||||
extensions.clone(),
|
extensions.clone(),
|
||||||
buffer_size,
|
buffer_size,
|
||||||
|
|
|
@ -46,6 +46,9 @@ struct MuxMapValue {
|
||||||
pub struct MuxInner<R: WebSocketRead + Send> {
|
pub struct MuxInner<R: WebSocketRead + Send> {
|
||||||
// gets taken by the mux task
|
// gets taken by the mux task
|
||||||
rx: Option<R>,
|
rx: Option<R>,
|
||||||
|
// gets taken by the mux task
|
||||||
|
maybe_downgrade_packet: Option<Packet<'static>>,
|
||||||
|
|
||||||
tx: LockedWebSocketWrite,
|
tx: LockedWebSocketWrite,
|
||||||
extensions: Vec<AnyProtocolExtension>,
|
extensions: Vec<AnyProtocolExtension>,
|
||||||
tcp_extensions: Vec<u8>,
|
tcp_extensions: Vec<u8>,
|
||||||
|
@ -82,6 +85,7 @@ impl<R: WebSocketRead + Send> MuxInner<R> {
|
||||||
|
|
||||||
pub fn new_server(
|
pub fn new_server(
|
||||||
rx: R,
|
rx: R,
|
||||||
|
maybe_downgrade_packet: Option<Packet<'static>>,
|
||||||
tx: LockedWebSocketWrite,
|
tx: LockedWebSocketWrite,
|
||||||
extensions: Vec<AnyProtocolExtension>,
|
extensions: Vec<AnyProtocolExtension>,
|
||||||
buffer_size: u32,
|
buffer_size: u32,
|
||||||
|
@ -98,6 +102,7 @@ impl<R: WebSocketRead + Send> MuxInner<R> {
|
||||||
MuxInnerResult {
|
MuxInnerResult {
|
||||||
mux: Self {
|
mux: Self {
|
||||||
rx: Some(rx),
|
rx: Some(rx),
|
||||||
|
maybe_downgrade_packet,
|
||||||
tx,
|
tx,
|
||||||
|
|
||||||
actor_rx: Some(fut_rx),
|
actor_rx: Some(fut_rx),
|
||||||
|
@ -124,6 +129,7 @@ impl<R: WebSocketRead + Send> MuxInner<R> {
|
||||||
|
|
||||||
pub fn new_client(
|
pub fn new_client(
|
||||||
rx: R,
|
rx: R,
|
||||||
|
maybe_downgrade_packet: Option<Packet<'static>>,
|
||||||
tx: LockedWebSocketWrite,
|
tx: LockedWebSocketWrite,
|
||||||
extensions: Vec<AnyProtocolExtension>,
|
extensions: Vec<AnyProtocolExtension>,
|
||||||
buffer_size: u32,
|
buffer_size: u32,
|
||||||
|
@ -136,6 +142,7 @@ impl<R: WebSocketRead + Send> MuxInner<R> {
|
||||||
MuxInnerResult {
|
MuxInnerResult {
|
||||||
mux: Self {
|
mux: Self {
|
||||||
rx: Some(rx),
|
rx: Some(rx),
|
||||||
|
maybe_downgrade_packet,
|
||||||
tx,
|
tx,
|
||||||
|
|
||||||
actor_rx: Some(fut_rx),
|
actor_rx: Some(fut_rx),
|
||||||
|
@ -265,9 +272,17 @@ impl<R: WebSocketRead + Send> MuxInner<R> {
|
||||||
let mut next_free_stream_id: u32 = 1;
|
let mut next_free_stream_id: u32 = 1;
|
||||||
|
|
||||||
let mut rx = self.rx.take().ok_or(WispError::MuxTaskStarted)?;
|
let mut rx = self.rx.take().ok_or(WispError::MuxTaskStarted)?;
|
||||||
|
let maybe_downgrade_packet = self.maybe_downgrade_packet.take();
|
||||||
|
|
||||||
let tx = self.tx.clone();
|
let tx = self.tx.clone();
|
||||||
let fut_rx = self.actor_rx.take().ok_or(WispError::MuxTaskStarted)?;
|
let fut_rx = self.actor_rx.take().ok_or(WispError::MuxTaskStarted)?;
|
||||||
|
|
||||||
|
if let Some(downgrade_packet) = maybe_downgrade_packet {
|
||||||
|
if self.handle_packet(downgrade_packet, None).await? {
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
let mut recv_fut = fut_rx.recv_async().fuse();
|
let mut recv_fut = fut_rx.recv_async().fuse();
|
||||||
let mut read_fut = rx.wisp_read_split(&tx).fuse();
|
let mut read_fut = rx.wisp_read_split(&tx).fuse();
|
||||||
while let Some(msg) = select! {
|
while let Some(msg) = select! {
|
||||||
|
@ -342,14 +357,7 @@ impl<R: WebSocketRead + Send> MuxInner<R> {
|
||||||
}
|
}
|
||||||
WsEvent::WispMessage(packet, optional_frame) => {
|
WsEvent::WispMessage(packet, optional_frame) => {
|
||||||
if let Some(packet) = packet {
|
if let Some(packet) = packet {
|
||||||
let should_break = match self.role {
|
let should_break = self.handle_packet(packet, optional_frame).await?;
|
||||||
Role::Server => {
|
|
||||||
self.server_handle_packet(packet, optional_frame).await?
|
|
||||||
}
|
|
||||||
Role::Client => {
|
|
||||||
self.client_handle_packet(packet, optional_frame).await?
|
|
||||||
}
|
|
||||||
};
|
|
||||||
if should_break {
|
if should_break {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
@ -409,18 +417,31 @@ impl<R: WebSocketRead + Send> MuxInner<R> {
|
||||||
Ok(false)
|
Ok(false)
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn server_handle_packet(
|
async fn handle_packet(
|
||||||
&mut self,
|
&mut self,
|
||||||
packet: Packet<'static>,
|
packet: Packet<'static>,
|
||||||
optional_frame: Option<Frame<'static>>,
|
optional_frame: Option<Frame<'static>>,
|
||||||
) -> Result<bool, WispError> {
|
) -> Result<bool, WispError> {
|
||||||
use PacketType::*;
|
use PacketType as P;
|
||||||
match packet.packet_type {
|
match packet.packet_type {
|
||||||
Continue(_) | Info(_) => Err(WispError::InvalidPacketType),
|
P::Data(data) => self.handle_data_packet(packet.stream_id, optional_frame, data),
|
||||||
Data(data) => self.handle_data_packet(packet.stream_id, optional_frame, data),
|
P::Close(inner_packet) => self.handle_close_packet(packet.stream_id, inner_packet),
|
||||||
Close(inner_packet) => self.handle_close_packet(packet.stream_id, inner_packet),
|
|
||||||
|
|
||||||
Connect(inner_packet) => {
|
_ => match self.role {
|
||||||
|
Role::Server => self.server_handle_packet(packet, optional_frame).await,
|
||||||
|
Role::Client => self.client_handle_packet(packet, optional_frame).await,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn server_handle_packet(
|
||||||
|
&mut self,
|
||||||
|
packet: Packet<'static>,
|
||||||
|
_optional_frame: Option<Frame<'static>>,
|
||||||
|
) -> Result<bool, WispError> {
|
||||||
|
use PacketType as P;
|
||||||
|
match packet.packet_type {
|
||||||
|
P::Connect(inner_packet) => {
|
||||||
let (map_value, stream) = self
|
let (map_value, stream) = self
|
||||||
.create_new_stream(packet.stream_id, inner_packet.stream_type)
|
.create_new_stream(packet.stream_id, inner_packet.stream_type)
|
||||||
.await?;
|
.await?;
|
||||||
|
@ -431,21 +452,21 @@ impl<R: WebSocketRead + Send> MuxInner<R> {
|
||||||
self.stream_map.insert(packet.stream_id, map_value);
|
self.stream_map.insert(packet.stream_id, map_value);
|
||||||
Ok(false)
|
Ok(false)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Continue | Info => invalid packet type
|
||||||
|
// Data | Close => specialcased
|
||||||
|
_ => Err(WispError::InvalidPacketType),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn client_handle_packet(
|
async fn client_handle_packet(
|
||||||
&mut self,
|
&mut self,
|
||||||
packet: Packet<'static>,
|
packet: Packet<'static>,
|
||||||
optional_frame: Option<Frame<'static>>,
|
_optional_frame: Option<Frame<'static>>,
|
||||||
) -> Result<bool, WispError> {
|
) -> Result<bool, WispError> {
|
||||||
use PacketType::*;
|
use PacketType as P;
|
||||||
match packet.packet_type {
|
match packet.packet_type {
|
||||||
Connect(_) | Info(_) => Err(WispError::InvalidPacketType),
|
P::Continue(inner_packet) => {
|
||||||
Data(data) => self.handle_data_packet(packet.stream_id, optional_frame, data),
|
|
||||||
Close(inner_packet) => self.handle_close_packet(packet.stream_id, inner_packet),
|
|
||||||
|
|
||||||
Continue(inner_packet) => {
|
|
||||||
if let Some(stream) = self.stream_map.get(&packet.stream_id) {
|
if let Some(stream) = self.stream_map.get(&packet.stream_id) {
|
||||||
if stream.stream_type == StreamType::Tcp {
|
if stream.stream_type == StreamType::Tcp {
|
||||||
stream
|
stream
|
||||||
|
@ -456,6 +477,10 @@ impl<R: WebSocketRead + Send> MuxInner<R> {
|
||||||
}
|
}
|
||||||
Ok(false)
|
Ok(false)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Connect | Info => invalid packet type
|
||||||
|
// Data | Close => specialcased
|
||||||
|
_ => Err(WispError::InvalidPacketType),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -8,7 +8,7 @@ pub use server::ServerMux;
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
extensions::{udp::UdpProtocolExtension, AnyProtocolExtension, AnyProtocolExtensionBuilder},
|
extensions::{udp::UdpProtocolExtension, AnyProtocolExtension, AnyProtocolExtensionBuilder},
|
||||||
ws::{Frame, LockedWebSocketWrite},
|
ws::LockedWebSocketWrite,
|
||||||
CloseReason, Packet, PacketType, Role, WispError,
|
CloseReason, Packet, PacketType, Role, WispError,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -22,12 +22,12 @@ enum WispHandshakeResultKind {
|
||||||
extensions: Vec<AnyProtocolExtension>,
|
extensions: Vec<AnyProtocolExtension>,
|
||||||
},
|
},
|
||||||
V1 {
|
V1 {
|
||||||
frame: Option<Frame<'static>>,
|
frame: Option<Packet<'static>>,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
impl WispHandshakeResultKind {
|
impl WispHandshakeResultKind {
|
||||||
pub fn into_parts(self) -> (Vec<AnyProtocolExtension>, Option<Frame<'static>>) {
|
pub fn into_parts(self) -> (Vec<AnyProtocolExtension>, Option<Packet<'static>>) {
|
||||||
match self {
|
match self {
|
||||||
Self::V2 { extensions } => (extensions, None),
|
Self::V2 { extensions } => (extensions, None),
|
||||||
Self::V1 { frame } => (vec![UdpProtocolExtension.into()], frame),
|
Self::V1 { frame } => (vec![UdpProtocolExtension.into()], frame),
|
||||||
|
|
|
@ -11,7 +11,7 @@ use futures::channel::oneshot;
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
extensions::AnyProtocolExtension,
|
extensions::AnyProtocolExtension,
|
||||||
ws::{AppendingWebSocketRead, LockedWebSocketWrite, Payload, WebSocketRead, WebSocketWrite},
|
ws::{LockedWebSocketWrite, Payload, WebSocketRead, WebSocketWrite},
|
||||||
CloseReason, ConnectPacket, MuxProtocolExtensionStream, MuxStream, Packet, PacketType, Role,
|
CloseReason, ConnectPacket, MuxProtocolExtensionStream, MuxStream, Packet, PacketType, Role,
|
||||||
WispError,
|
WispError,
|
||||||
};
|
};
|
||||||
|
@ -109,7 +109,8 @@ impl ServerMux {
|
||||||
let (extensions, extra_packet) = handshake_result.kind.into_parts();
|
let (extensions, extra_packet) = handshake_result.kind.into_parts();
|
||||||
|
|
||||||
let (mux_result, muxstream_recv) = MuxInner::new_server(
|
let (mux_result, muxstream_recv) = MuxInner::new_server(
|
||||||
AppendingWebSocketRead(extra_packet, rx),
|
rx,
|
||||||
|
extra_packet,
|
||||||
tx.clone(),
|
tx.clone(),
|
||||||
extensions.clone(),
|
extensions.clone(),
|
||||||
buffer_size,
|
buffer_size,
|
||||||
|
|
|
@ -261,33 +261,3 @@ impl LockedWebSocketWrite {
|
||||||
self.0.lock().await.wisp_close().await
|
self.0.lock().await.wisp_close().await
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) struct AppendingWebSocketRead<R>(pub Option<Frame<'static>>, pub R)
|
|
||||||
where
|
|
||||||
R: WebSocketRead + Send;
|
|
||||||
|
|
||||||
#[async_trait]
|
|
||||||
impl<R> WebSocketRead for AppendingWebSocketRead<R>
|
|
||||||
where
|
|
||||||
R: WebSocketRead + Send,
|
|
||||||
{
|
|
||||||
async fn wisp_read_frame(
|
|
||||||
&mut self,
|
|
||||||
tx: &LockedWebSocketWrite,
|
|
||||||
) -> Result<Frame<'static>, WispError> {
|
|
||||||
if let Some(x) = self.0.take() {
|
|
||||||
return Ok(x);
|
|
||||||
}
|
|
||||||
self.1.wisp_read_frame(tx).await
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn wisp_read_split(
|
|
||||||
&mut self,
|
|
||||||
tx: &LockedWebSocketWrite,
|
|
||||||
) -> Result<(Frame<'static>, Option<Frame<'static>>), WispError> {
|
|
||||||
if let Some(x) = self.0.take() {
|
|
||||||
return Ok((x, None));
|
|
||||||
}
|
|
||||||
self.1.wisp_read_split(tx).await
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue