mirror of
https://github.com/MercuryWorkshop/epoxy-tls.git
synced 2025-05-13 06:20:02 -04:00
fix continue packets issue, remove requirement for Send on the websocket
This commit is contained in:
parent
bed942eb75
commit
ce86e7b095
19 changed files with 872 additions and 235 deletions
199
wisp/src/lib.rs
199
wisp/src/lib.rs
|
@ -15,6 +15,7 @@ pub mod ws;
|
|||
pub use crate::packet::*;
|
||||
pub use crate::stream::*;
|
||||
|
||||
use bytes::Bytes;
|
||||
use event_listener::Event;
|
||||
use futures::{channel::mpsc, lock::Mutex, Future, FutureExt, StreamExt};
|
||||
use std::{
|
||||
|
@ -95,11 +96,11 @@ impl std::fmt::Display for WispError {
|
|||
StreamAlreadyClosed => write!(f, "Stream already closed"),
|
||||
WsFrameInvalidType => write!(f, "Invalid websocket frame type"),
|
||||
WsFrameNotFinished => write!(f, "Unfinished websocket frame"),
|
||||
WsImplError(err) => write!(f, "Websocket implementation error: {:?}", err),
|
||||
WsImplError(err) => write!(f, "Websocket implementation error: {}", err),
|
||||
WsImplSocketClosed => write!(f, "Websocket implementation error: websocket closed"),
|
||||
WsImplNotSupported => write!(f, "Websocket implementation error: unsupported feature"),
|
||||
Utf8Error(err) => write!(f, "UTF-8 error: {:?}", err),
|
||||
Other(err) => write!(f, "Other error: {:?}", err),
|
||||
Utf8Error(err) => write!(f, "UTF-8 error: {}", err),
|
||||
Other(err) => write!(f, "Other error: {}", err),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -107,27 +108,28 @@ impl std::fmt::Display for WispError {
|
|||
impl std::error::Error for WispError {}
|
||||
|
||||
struct MuxMapValue {
|
||||
stream: mpsc::UnboundedSender<MuxEvent>,
|
||||
stream: mpsc::UnboundedSender<Bytes>,
|
||||
stream_type: StreamType,
|
||||
flow_control: Arc<AtomicU32>,
|
||||
flow_control_event: Arc<Event>,
|
||||
is_closed: Arc<AtomicBool>,
|
||||
}
|
||||
|
||||
struct ServerMuxInner<W>
|
||||
where
|
||||
W: ws::WebSocketWrite + Send + 'static,
|
||||
W: ws::WebSocketWrite,
|
||||
{
|
||||
tx: ws::LockedWebSocketWrite<W>,
|
||||
stream_map: Arc<Mutex<HashMap<u32, MuxMapValue>>>,
|
||||
close_tx: mpsc::UnboundedSender<WsEvent>,
|
||||
}
|
||||
|
||||
impl<W: ws::WebSocketWrite + Send + 'static> ServerMuxInner<W> {
|
||||
impl<W: ws::WebSocketWrite> ServerMuxInner<W> {
|
||||
pub async fn into_future<R>(
|
||||
self,
|
||||
rx: R,
|
||||
close_rx: mpsc::UnboundedReceiver<WsEvent>,
|
||||
muxstream_sender: mpsc::UnboundedSender<(ConnectPacket, MuxStream<W>)>,
|
||||
muxstream_sender: mpsc::UnboundedSender<(ConnectPacket, MuxStream)>,
|
||||
buffer_size: u32,
|
||||
) -> Result<(), WispError>
|
||||
where
|
||||
|
@ -137,10 +139,10 @@ impl<W: ws::WebSocketWrite + Send + 'static> ServerMuxInner<W> {
|
|||
x = self.server_bg_loop(close_rx).fuse() => x,
|
||||
x = self.server_msg_loop(rx, muxstream_sender, buffer_size).fuse() => x
|
||||
};
|
||||
self.stream_map.lock().await.drain().for_each(|x| {
|
||||
let _ =
|
||||
x.1.stream
|
||||
.unbounded_send(MuxEvent::Close(ClosePacket::new(CloseReason::Unknown)));
|
||||
self.stream_map.lock().await.drain().for_each(|mut x| {
|
||||
x.1.is_closed.store(true, Ordering::Release);
|
||||
x.1.stream.disconnect();
|
||||
x.1.stream.close_channel();
|
||||
});
|
||||
ret
|
||||
}
|
||||
|
@ -151,13 +153,26 @@ impl<W: ws::WebSocketWrite + Send + 'static> ServerMuxInner<W> {
|
|||
) -> Result<(), WispError> {
|
||||
while let Some(msg) = close_rx.next().await {
|
||||
match msg {
|
||||
WsEvent::Close(stream_id, reason, channel) => {
|
||||
if self.stream_map.lock().await.remove(&stream_id).is_some() {
|
||||
let _ = channel.send(
|
||||
self.tx
|
||||
.write_frame(Packet::new_close(stream_id, reason).into())
|
||||
.await,
|
||||
);
|
||||
WsEvent::SendPacket(packet, channel) => {
|
||||
if self
|
||||
.stream_map
|
||||
.lock()
|
||||
.await
|
||||
.get(&packet.stream_id)
|
||||
.is_some()
|
||||
{
|
||||
let _ = channel.send(self.tx.write_frame(packet.into()).await);
|
||||
} else {
|
||||
let _ = channel.send(Err(WispError::InvalidStreamId));
|
||||
}
|
||||
}
|
||||
WsEvent::Close(packet, channel) => {
|
||||
if let Some(mut stream) =
|
||||
self.stream_map.lock().await.remove(&packet.stream_id)
|
||||
{
|
||||
stream.stream.disconnect();
|
||||
stream.stream.close_channel();
|
||||
let _ = channel.send(self.tx.write_frame(packet.into()).await);
|
||||
} else {
|
||||
let _ = channel.send(Err(WispError::InvalidStreamId));
|
||||
}
|
||||
|
@ -171,12 +186,14 @@ impl<W: ws::WebSocketWrite + Send + 'static> ServerMuxInner<W> {
|
|||
async fn server_msg_loop<R>(
|
||||
&self,
|
||||
mut rx: R,
|
||||
muxstream_sender: mpsc::UnboundedSender<(ConnectPacket, MuxStream<W>)>,
|
||||
muxstream_sender: mpsc::UnboundedSender<(ConnectPacket, MuxStream)>,
|
||||
buffer_size: u32,
|
||||
) -> Result<(), WispError>
|
||||
where
|
||||
R: ws::WebSocketRead,
|
||||
{
|
||||
// will send continues once flow_control is at 10% of max
|
||||
let target_buffer_size = buffer_size * 90 / 100;
|
||||
self.tx
|
||||
.write_frame(Packet::new_continue(0, buffer_size).into())
|
||||
.await?;
|
||||
|
@ -195,6 +212,7 @@ impl<W: ws::WebSocketWrite + Send + 'static> ServerMuxInner<W> {
|
|||
let stream_type = inner_packet.stream_type;
|
||||
let flow_control: Arc<AtomicU32> = AtomicU32::new(buffer_size).into();
|
||||
let flow_control_event: Arc<Event> = Event::new().into();
|
||||
let is_closed: Arc<AtomicBool> = AtomicBool::new(false).into();
|
||||
|
||||
self.stream_map.lock().await.insert(
|
||||
packet.stream_id,
|
||||
|
@ -203,6 +221,7 @@ impl<W: ws::WebSocketWrite + Send + 'static> ServerMuxInner<W> {
|
|||
stream_type,
|
||||
flow_control: flow_control.clone(),
|
||||
flow_control_event: flow_control_event.clone(),
|
||||
is_closed: is_closed.clone(),
|
||||
},
|
||||
);
|
||||
muxstream_sender
|
||||
|
@ -213,18 +232,18 @@ impl<W: ws::WebSocketWrite + Send + 'static> ServerMuxInner<W> {
|
|||
Role::Server,
|
||||
stream_type,
|
||||
ch_rx,
|
||||
self.tx.clone(),
|
||||
self.close_tx.clone(),
|
||||
AtomicBool::new(false).into(),
|
||||
is_closed,
|
||||
flow_control,
|
||||
flow_control_event,
|
||||
target_buffer_size,
|
||||
),
|
||||
))
|
||||
.map_err(|x| WispError::Other(Box::new(x)))?;
|
||||
}
|
||||
Data(data) => {
|
||||
if let Some(stream) = self.stream_map.lock().await.get(&packet.stream_id) {
|
||||
let _ = stream.stream.unbounded_send(MuxEvent::Send(data));
|
||||
let _ = stream.stream.unbounded_send(data);
|
||||
if stream.stream_type == StreamType::Tcp {
|
||||
stream.flow_control.store(
|
||||
stream
|
||||
|
@ -237,11 +256,14 @@ impl<W: ws::WebSocketWrite + Send + 'static> ServerMuxInner<W> {
|
|||
}
|
||||
}
|
||||
Continue(_) => unreachable!(),
|
||||
Close(inner_packet) => {
|
||||
if let Some(stream) = self.stream_map.lock().await.get(&packet.stream_id) {
|
||||
let _ = stream.stream.unbounded_send(MuxEvent::Close(inner_packet));
|
||||
Close(_) => {
|
||||
if let Some(mut stream) =
|
||||
self.stream_map.lock().await.remove(&packet.stream_id)
|
||||
{
|
||||
stream.is_closed.store(true, Ordering::Release);
|
||||
stream.stream.disconnect();
|
||||
stream.stream.close_channel();
|
||||
}
|
||||
self.stream_map.lock().await.remove(&packet.stream_id);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -267,18 +289,15 @@ impl<W: ws::WebSocketWrite + Send + 'static> ServerMuxInner<W> {
|
|||
/// });
|
||||
/// }
|
||||
/// ```
|
||||
pub struct ServerMux<W>
|
||||
where
|
||||
W: ws::WebSocketWrite + Send + 'static,
|
||||
{
|
||||
pub struct ServerMux {
|
||||
stream_map: Arc<Mutex<HashMap<u32, MuxMapValue>>>,
|
||||
close_tx: mpsc::UnboundedSender<WsEvent>,
|
||||
muxstream_recv: mpsc::UnboundedReceiver<(ConnectPacket, MuxStream<W>)>,
|
||||
muxstream_recv: mpsc::UnboundedReceiver<(ConnectPacket, MuxStream)>,
|
||||
}
|
||||
|
||||
impl<W: ws::WebSocketWrite + Send + 'static> ServerMux<W> {
|
||||
impl ServerMux {
|
||||
/// Create a new server-side multiplexor.
|
||||
pub fn new<R>(
|
||||
pub fn new<R, W: ws::WebSocketWrite>(
|
||||
read: R,
|
||||
write: W,
|
||||
buffer_size: u32,
|
||||
|
@ -287,7 +306,7 @@ impl<W: ws::WebSocketWrite + Send + 'static> ServerMux<W> {
|
|||
R: ws::WebSocketRead,
|
||||
{
|
||||
let (close_tx, close_rx) = mpsc::unbounded::<WsEvent>();
|
||||
let (tx, rx) = mpsc::unbounded::<(ConnectPacket, MuxStream<W>)>();
|
||||
let (tx, rx) = mpsc::unbounded::<(ConnectPacket, MuxStream)>();
|
||||
let write = ws::LockedWebSocketWrite::new(write);
|
||||
let map = Arc::new(Mutex::new(HashMap::new()));
|
||||
(
|
||||
|
@ -306,7 +325,7 @@ impl<W: ws::WebSocketWrite + Send + 'static> ServerMux<W> {
|
|||
}
|
||||
|
||||
/// Wait for a stream to be created.
|
||||
pub async fn server_new_stream(&mut self) -> Option<(ConnectPacket, MuxStream<W>)> {
|
||||
pub async fn server_new_stream(&mut self) -> Option<(ConnectPacket, MuxStream)> {
|
||||
self.muxstream_recv.next().await
|
||||
}
|
||||
|
||||
|
@ -314,11 +333,11 @@ impl<W: ws::WebSocketWrite + Send + 'static> ServerMux<W> {
|
|||
///
|
||||
/// Also terminates the multiplexor future. Waiting for a new stream will never succeed after
|
||||
/// this function is called.
|
||||
pub async fn close(&self, reason: CloseReason) {
|
||||
self.stream_map.lock().await.drain().for_each(|x| {
|
||||
let _ =
|
||||
x.1.stream
|
||||
.unbounded_send(MuxEvent::Close(ClosePacket::new(reason)));
|
||||
pub async fn close(&self) {
|
||||
self.stream_map.lock().await.drain().for_each(|mut x| {
|
||||
x.1.is_closed.store(true, Ordering::Release);
|
||||
x.1.stream.disconnect();
|
||||
x.1.stream.close_channel();
|
||||
});
|
||||
let _ = self.close_tx.unbounded_send(WsEvent::EndFut);
|
||||
}
|
||||
|
@ -332,7 +351,7 @@ where
|
|||
stream_map: Arc<Mutex<HashMap<u32, MuxMapValue>>>,
|
||||
}
|
||||
|
||||
impl<W: ws::WebSocketWrite + Send> ClientMuxInner<W> {
|
||||
impl<W: ws::WebSocketWrite> ClientMuxInner<W> {
|
||||
pub(crate) async fn into_future<R>(
|
||||
self,
|
||||
rx: R,
|
||||
|
@ -341,10 +360,16 @@ impl<W: ws::WebSocketWrite + Send> ClientMuxInner<W> {
|
|||
where
|
||||
R: ws::WebSocketRead,
|
||||
{
|
||||
futures::select! {
|
||||
let ret = futures::select! {
|
||||
x = self.client_bg_loop(close_rx).fuse() => x,
|
||||
x = self.client_loop(rx).fuse() => x
|
||||
}
|
||||
};
|
||||
self.stream_map.lock().await.drain().for_each(|mut x| {
|
||||
x.1.is_closed.store(true, Ordering::Release);
|
||||
x.1.stream.disconnect();
|
||||
x.1.stream.close_channel();
|
||||
});
|
||||
ret
|
||||
}
|
||||
|
||||
async fn client_bg_loop(
|
||||
|
@ -353,13 +378,26 @@ impl<W: ws::WebSocketWrite + Send> ClientMuxInner<W> {
|
|||
) -> Result<(), WispError> {
|
||||
while let Some(msg) = close_rx.next().await {
|
||||
match msg {
|
||||
WsEvent::Close(stream_id, reason, channel) => {
|
||||
if self.stream_map.lock().await.remove(&stream_id).is_some() {
|
||||
let _ = channel.send(
|
||||
self.tx
|
||||
.write_frame(Packet::new_close(stream_id, reason).into())
|
||||
.await,
|
||||
);
|
||||
WsEvent::SendPacket(packet, channel) => {
|
||||
if self
|
||||
.stream_map
|
||||
.lock()
|
||||
.await
|
||||
.get(&packet.stream_id)
|
||||
.is_some()
|
||||
{
|
||||
let _ = channel.send(self.tx.write_frame(packet.into()).await);
|
||||
} else {
|
||||
let _ = channel.send(Err(WispError::InvalidStreamId));
|
||||
}
|
||||
}
|
||||
WsEvent::Close(packet, channel) => {
|
||||
if let Some(mut stream) =
|
||||
self.stream_map.lock().await.remove(&packet.stream_id)
|
||||
{
|
||||
stream.stream.disconnect();
|
||||
stream.stream.close_channel();
|
||||
let _ = channel.send(self.tx.write_frame(packet.into()).await);
|
||||
} else {
|
||||
let _ = channel.send(Err(WispError::InvalidStreamId));
|
||||
}
|
||||
|
@ -386,7 +424,7 @@ impl<W: ws::WebSocketWrite + Send> ClientMuxInner<W> {
|
|||
Connect(_) => unreachable!(),
|
||||
Data(data) => {
|
||||
if let Some(stream) = self.stream_map.lock().await.get(&packet.stream_id) {
|
||||
let _ = stream.stream.unbounded_send(MuxEvent::Send(data));
|
||||
let _ = stream.stream.unbounded_send(data);
|
||||
}
|
||||
}
|
||||
Continue(inner_packet) => {
|
||||
|
@ -399,11 +437,14 @@ impl<W: ws::WebSocketWrite + Send> ClientMuxInner<W> {
|
|||
}
|
||||
}
|
||||
}
|
||||
Close(inner_packet) => {
|
||||
if let Some(stream) = self.stream_map.lock().await.get(&packet.stream_id) {
|
||||
let _ = stream.stream.unbounded_send(MuxEvent::Close(inner_packet));
|
||||
Close(_) => {
|
||||
if let Some(mut stream) =
|
||||
self.stream_map.lock().await.remove(&packet.stream_id)
|
||||
{
|
||||
stream.is_closed.store(true, Ordering::Release);
|
||||
stream.stream.disconnect();
|
||||
stream.stream.close_channel();
|
||||
}
|
||||
self.stream_map.lock().await.remove(&packet.stream_id);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -433,9 +474,10 @@ where
|
|||
next_free_stream_id: AtomicU32,
|
||||
close_tx: mpsc::UnboundedSender<WsEvent>,
|
||||
buf_size: u32,
|
||||
target_buf_size: u32,
|
||||
}
|
||||
|
||||
impl<W: ws::WebSocketWrite + Send + 'static> ClientMux<W> {
|
||||
impl<W: ws::WebSocketWrite> ClientMux<W> {
|
||||
/// Create a new client side multiplexor.
|
||||
pub async fn new<R>(
|
||||
mut read: R,
|
||||
|
@ -459,6 +501,8 @@ impl<W: ws::WebSocketWrite + Send + 'static> ClientMux<W> {
|
|||
next_free_stream_id: AtomicU32::new(1),
|
||||
close_tx: tx,
|
||||
buf_size: packet.buffer_remaining,
|
||||
// server-only
|
||||
target_buf_size: 0,
|
||||
},
|
||||
ClientMuxInner {
|
||||
tx: write.clone(),
|
||||
|
@ -477,39 +521,46 @@ impl<W: ws::WebSocketWrite + Send + 'static> ClientMux<W> {
|
|||
stream_type: StreamType,
|
||||
host: String,
|
||||
port: u16,
|
||||
) -> Result<MuxStream<W>, WispError> {
|
||||
) -> Result<MuxStream, WispError> {
|
||||
let (ch_tx, ch_rx) = mpsc::unbounded();
|
||||
let evt: Arc<Event> = Event::new().into();
|
||||
let flow_control: Arc<AtomicU32> = AtomicU32::new(self.buf_size).into();
|
||||
let stream_id = self.next_free_stream_id.load(Ordering::Acquire);
|
||||
let next_stream_id = stream_id
|
||||
.checked_add(1)
|
||||
.ok_or(WispError::MaxStreamCountReached)?;
|
||||
|
||||
let flow_control_event: Arc<Event> = Event::new().into();
|
||||
let flow_control: Arc<AtomicU32> = AtomicU32::new(self.buf_size).into();
|
||||
|
||||
let is_closed: Arc<AtomicBool> = AtomicBool::new(false).into();
|
||||
|
||||
self.tx
|
||||
.write_frame(Packet::new_connect(stream_id, stream_type, port, host).into())
|
||||
.await?;
|
||||
self.next_free_stream_id.store(
|
||||
stream_id
|
||||
.checked_add(1)
|
||||
.ok_or(WispError::MaxStreamCountReached)?,
|
||||
Ordering::Release,
|
||||
);
|
||||
|
||||
self.next_free_stream_id
|
||||
.store(next_stream_id, Ordering::Release);
|
||||
|
||||
self.stream_map.lock().await.insert(
|
||||
stream_id,
|
||||
MuxMapValue {
|
||||
stream: ch_tx,
|
||||
stream_type,
|
||||
flow_control: flow_control.clone(),
|
||||
flow_control_event: evt.clone(),
|
||||
flow_control_event: flow_control_event.clone(),
|
||||
is_closed: is_closed.clone(),
|
||||
},
|
||||
);
|
||||
|
||||
Ok(MuxStream::new(
|
||||
stream_id,
|
||||
Role::Client,
|
||||
stream_type,
|
||||
ch_rx,
|
||||
self.tx.clone(),
|
||||
self.close_tx.clone(),
|
||||
AtomicBool::new(false).into(),
|
||||
is_closed,
|
||||
flow_control,
|
||||
evt,
|
||||
flow_control_event,
|
||||
self.target_buf_size,
|
||||
))
|
||||
}
|
||||
|
||||
|
@ -517,11 +568,11 @@ impl<W: ws::WebSocketWrite + Send + 'static> ClientMux<W> {
|
|||
///
|
||||
/// Also terminates the multiplexor future. Creating a stream is UB after calling this
|
||||
/// function.
|
||||
pub async fn close(&self, reason: CloseReason) {
|
||||
self.stream_map.lock().await.drain().for_each(|x| {
|
||||
let _ =
|
||||
x.1.stream
|
||||
.unbounded_send(MuxEvent::Close(ClosePacket::new(reason)));
|
||||
pub async fn close(&self) {
|
||||
self.stream_map.lock().await.drain().for_each(|mut x| {
|
||||
x.1.is_closed.store(true, Ordering::Release);
|
||||
x.1.stream.disconnect();
|
||||
x.1.stream.close_channel();
|
||||
});
|
||||
let _ = self.close_tx.unbounded_send(WsEvent::EndFut);
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue