mirror of
https://github.com/MercuryWorkshop/epoxy-tls.git
synced 2025-05-13 06:20:02 -04:00
make server mux more like client mux and fix deadlock
This commit is contained in:
parent
fa2b84d646
commit
9f1561fa76
4 changed files with 101 additions and 77 deletions
1
Cargo.lock
generated
1
Cargo.lock
generated
|
@ -1449,7 +1449,6 @@ version = "0.1.0"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"async_io_stream",
|
"async_io_stream",
|
||||||
"bytes",
|
"bytes",
|
||||||
"dashmap",
|
|
||||||
"fastwebsockets",
|
"fastwebsockets",
|
||||||
"futures",
|
"futures",
|
||||||
"futures-util",
|
"futures-util",
|
||||||
|
|
|
@ -20,7 +20,7 @@ use wisp_mux::{ws, ConnectPacket, MuxStream, ServerMux, StreamType, WispError, W
|
||||||
|
|
||||||
type HttpBody = http_body_util::Empty<hyper::body::Bytes>;
|
type HttpBody = http_body_util::Empty<hyper::body::Bytes>;
|
||||||
|
|
||||||
#[tokio::main(flavor = "multi_thread")]
|
#[tokio::main]
|
||||||
async fn main() -> Result<(), Error> {
|
async fn main() -> Result<(), Error> {
|
||||||
let pem = include_bytes!("./pem.pem");
|
let pem = include_bytes!("./pem.pem");
|
||||||
let key = include_bytes!("./key.pem");
|
let key = include_bytes!("./key.pem");
|
||||||
|
@ -117,7 +117,6 @@ async fn handle_mux(
|
||||||
loop {
|
loop {
|
||||||
tokio::select! {
|
tokio::select! {
|
||||||
event = stream.read() => {
|
event = stream.read() => {
|
||||||
println!("ws rx");
|
|
||||||
match event {
|
match event {
|
||||||
Some(event) => match event {
|
Some(event) => match event {
|
||||||
WsEvent::Send(data) => {
|
WsEvent::Send(data) => {
|
||||||
|
@ -129,10 +128,9 @@ async fn handle_mux(
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
event = tcp_stream_framed.next() => {
|
event = tcp_stream_framed.next() => {
|
||||||
println!("tcp rx");
|
|
||||||
match event.and_then(|x| x.ok()) {
|
match event.and_then(|x| x.ok()) {
|
||||||
Some(event) => stream.write(event.into()).await?,
|
Some(event) => stream.write(event.into()).await?,
|
||||||
None => return Ok(true),
|
None => break,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -176,10 +174,18 @@ async fn accept_ws(
|
||||||
|
|
||||||
println!("{:?}: connected", addr);
|
println!("{:?}: connected", addr);
|
||||||
|
|
||||||
ServerMux::handle(rx, tx, &mut |packet, stream| async move {
|
let (mut mux, fut) = ServerMux::new(rx, tx);
|
||||||
let close_err = stream.get_close_handle();
|
|
||||||
let close_ok = stream.get_close_handle();
|
tokio::spawn(async move {
|
||||||
|
if let Err(e) = fut.await {
|
||||||
|
println!("err in mux: {:?}", e);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
while let Some((packet, stream)) = mux.server_new_stream().await {
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
|
let close_err = stream.get_close_handle();
|
||||||
|
let close_ok = stream.get_close_handle();
|
||||||
let _ = handle_mux(packet, stream)
|
let _ = handle_mux(packet, stream)
|
||||||
.or_else(|err| async move {
|
.or_else(|err| async move {
|
||||||
let _ = close_err.close(0x03).await;
|
let _ = close_err.close(0x03).await;
|
||||||
|
@ -194,9 +200,7 @@ async fn accept_ws(
|
||||||
})
|
})
|
||||||
.await;
|
.await;
|
||||||
});
|
});
|
||||||
Ok(())
|
}
|
||||||
})
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
println!("{:?}: disconnected", addr);
|
println!("{:?}: disconnected", addr);
|
||||||
Ok(())
|
Ok(())
|
||||||
|
|
|
@ -6,7 +6,6 @@ edition = "2021"
|
||||||
[dependencies]
|
[dependencies]
|
||||||
async_io_stream = "0.3.3"
|
async_io_stream = "0.3.3"
|
||||||
bytes = "1.5.0"
|
bytes = "1.5.0"
|
||||||
dashmap = "5.5.3"
|
|
||||||
fastwebsockets = { version = "0.6.0", features = ["unstable-split"], optional = true }
|
fastwebsockets = { version = "0.6.0", features = ["unstable-split"], optional = true }
|
||||||
futures = "0.3.30"
|
futures = "0.3.30"
|
||||||
futures-util = "0.3.30"
|
futures-util = "0.3.30"
|
||||||
|
|
152
wisp/src/lib.rs
152
wisp/src/lib.rs
|
@ -9,11 +9,13 @@ mod ws_stream_wasm;
|
||||||
pub use crate::packet::*;
|
pub use crate::packet::*;
|
||||||
pub use crate::stream::*;
|
pub use crate::stream::*;
|
||||||
|
|
||||||
use dashmap::DashMap;
|
use futures::{channel::mpsc, lock::Mutex, Future, FutureExt, StreamExt};
|
||||||
use futures::{channel::mpsc, Future, FutureExt, StreamExt};
|
use std::{
|
||||||
use std::sync::{
|
collections::HashMap,
|
||||||
atomic::{AtomicBool, AtomicU32, Ordering},
|
sync::{
|
||||||
Arc,
|
atomic::{AtomicBool, AtomicU32, Ordering},
|
||||||
|
Arc,
|
||||||
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
#[derive(Debug, PartialEq)]
|
#[derive(Debug, PartialEq)]
|
||||||
|
@ -68,63 +70,45 @@ impl std::fmt::Display for WispError {
|
||||||
|
|
||||||
impl std::error::Error for WispError {}
|
impl std::error::Error for WispError {}
|
||||||
|
|
||||||
pub struct ServerMux<W>
|
struct ServerMuxInner<W>
|
||||||
where
|
where
|
||||||
W: ws::WebSocketWrite,
|
W: ws::WebSocketWrite + Send + 'static,
|
||||||
{
|
{
|
||||||
tx: ws::LockedWebSocketWrite<W>,
|
tx: ws::LockedWebSocketWrite<W>,
|
||||||
stream_map: Arc<DashMap<u32, mpsc::UnboundedSender<WsEvent>>>,
|
stream_map: Arc<Mutex<HashMap<u32, mpsc::UnboundedSender<WsEvent>>>>,
|
||||||
close_tx: mpsc::UnboundedSender<MuxEvent>,
|
close_tx: mpsc::UnboundedSender<MuxEvent>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<W: ws::WebSocketWrite + Send + 'static> ServerMux<W> {
|
impl<W: ws::WebSocketWrite + Send + 'static> ServerMuxInner<W> {
|
||||||
pub fn handle<'a, FR, R>(
|
pub async fn into_future<R>(
|
||||||
read: R,
|
|
||||||
write: W,
|
|
||||||
handler_fn: &'a mut impl Fn(ConnectPacket, MuxStream<W>) -> FR,
|
|
||||||
) -> impl Future<Output = Result<(), WispError>> + 'a
|
|
||||||
where
|
|
||||||
FR: std::future::Future<Output = Result<(), WispError>> + 'a,
|
|
||||||
R: ws::WebSocketRead + 'a,
|
|
||||||
W: ws::WebSocketWrite + 'a,
|
|
||||||
{
|
|
||||||
let (tx, rx) = mpsc::unbounded::<MuxEvent>();
|
|
||||||
let write = ws::LockedWebSocketWrite::new(write);
|
|
||||||
let map = Arc::new(DashMap::new());
|
|
||||||
let inner = ServerMux {
|
|
||||||
stream_map: map.clone(),
|
|
||||||
tx: write.clone(),
|
|
||||||
close_tx: tx,
|
|
||||||
};
|
|
||||||
inner.into_future(read, rx, handler_fn)
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn into_future<R, FR>(
|
|
||||||
self,
|
self,
|
||||||
rx: R,
|
rx: R,
|
||||||
close_rx: mpsc::UnboundedReceiver<MuxEvent>,
|
close_rx: mpsc::UnboundedReceiver<MuxEvent>,
|
||||||
handler_fn: &mut impl Fn(ConnectPacket, MuxStream<W>) -> FR,
|
muxstream_sender: mpsc::UnboundedSender<(ConnectPacket, MuxStream<W>)>,
|
||||||
) -> Result<(), WispError>
|
) -> Result<(), WispError>
|
||||||
where
|
where
|
||||||
R: ws::WebSocketRead,
|
R: ws::WebSocketRead,
|
||||||
FR: std::future::Future<Output = Result<(), WispError>>,
|
|
||||||
{
|
{
|
||||||
futures::select! {
|
let ret = futures::select! {
|
||||||
x = self.server_close_loop(close_rx, self.stream_map.clone(), self.tx.clone()).fuse() => x,
|
x = self.server_close_loop(close_rx, self.stream_map.clone(), self.tx.clone()).fuse() => x,
|
||||||
x = self.server_msg_loop(rx, handler_fn).fuse() => x
|
x = self.server_msg_loop(rx, muxstream_sender).fuse() => x
|
||||||
}
|
};
|
||||||
|
self.stream_map.lock().await.iter().for_each(|x| {
|
||||||
|
let _ = x.1.unbounded_send(WsEvent::Close(ClosePacket::new(0x01)));
|
||||||
|
});
|
||||||
|
ret
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn server_close_loop(
|
async fn server_close_loop(
|
||||||
&self,
|
&self,
|
||||||
mut close_rx: mpsc::UnboundedReceiver<MuxEvent>,
|
mut close_rx: mpsc::UnboundedReceiver<MuxEvent>,
|
||||||
stream_map: Arc<DashMap<u32, mpsc::UnboundedSender<WsEvent>>>,
|
stream_map: Arc<Mutex<HashMap<u32, mpsc::UnboundedSender<WsEvent>>>>,
|
||||||
tx: ws::LockedWebSocketWrite<W>,
|
tx: ws::LockedWebSocketWrite<W>,
|
||||||
) -> Result<(), WispError> {
|
) -> Result<(), WispError> {
|
||||||
while let Some(msg) = close_rx.next().await {
|
while let Some(msg) = close_rx.next().await {
|
||||||
match msg {
|
match msg {
|
||||||
MuxEvent::Close(stream_id, reason, channel) => {
|
MuxEvent::Close(stream_id, reason, channel) => {
|
||||||
if stream_map.clone().remove(&stream_id).is_some() {
|
if stream_map.lock().await.remove(&stream_id).is_some() {
|
||||||
let _ = channel.send(
|
let _ = channel.send(
|
||||||
tx.write_frame(Packet::new_close(stream_id, reason).into())
|
tx.write_frame(Packet::new_close(stream_id, reason).into())
|
||||||
.await,
|
.await,
|
||||||
|
@ -138,14 +122,13 @@ impl<W: ws::WebSocketWrite + Send + 'static> ServerMux<W> {
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn server_msg_loop<R, FR>(
|
async fn server_msg_loop<R>(
|
||||||
&self,
|
&self,
|
||||||
mut rx: R,
|
mut rx: R,
|
||||||
handler_fn: &mut impl Fn(ConnectPacket, MuxStream<W>) -> FR,
|
muxstream_sender: mpsc::UnboundedSender<(ConnectPacket, MuxStream<W>)>,
|
||||||
) -> Result<(), WispError>
|
) -> Result<(), WispError>
|
||||||
where
|
where
|
||||||
R: ws::WebSocketRead,
|
R: ws::WebSocketRead,
|
||||||
FR: std::future::Future<Output = Result<(), WispError>>,
|
|
||||||
{
|
{
|
||||||
self.tx
|
self.tx
|
||||||
.write_frame(Packet::new_continue(0, u32::MAX).into())
|
.write_frame(Packet::new_continue(0, u32::MAX).into())
|
||||||
|
@ -157,21 +140,22 @@ impl<W: ws::WebSocketWrite + Send + 'static> ServerMux<W> {
|
||||||
match packet.packet {
|
match packet.packet {
|
||||||
Connect(inner_packet) => {
|
Connect(inner_packet) => {
|
||||||
let (ch_tx, ch_rx) = mpsc::unbounded();
|
let (ch_tx, ch_rx) = mpsc::unbounded();
|
||||||
self.stream_map.clone().insert(packet.stream_id, ch_tx);
|
self.stream_map.lock().await.insert(packet.stream_id, ch_tx);
|
||||||
let _ = handler_fn(
|
muxstream_sender
|
||||||
inner_packet,
|
.unbounded_send((
|
||||||
MuxStream::new(
|
inner_packet,
|
||||||
packet.stream_id,
|
MuxStream::new(
|
||||||
ch_rx,
|
packet.stream_id,
|
||||||
self.tx.clone(),
|
ch_rx,
|
||||||
self.close_tx.clone(),
|
self.tx.clone(),
|
||||||
AtomicBool::new(false).into(),
|
self.close_tx.clone(),
|
||||||
),
|
AtomicBool::new(false).into(),
|
||||||
)
|
),
|
||||||
.await;
|
))
|
||||||
|
.map_err(|x| WispError::Other(Box::new(x)))?;
|
||||||
}
|
}
|
||||||
Data(data) => {
|
Data(data) => {
|
||||||
if let Some(stream) = self.stream_map.clone().get(&packet.stream_id) {
|
if let Some(stream) = self.stream_map.lock().await.get(&packet.stream_id) {
|
||||||
let _ = stream.unbounded_send(WsEvent::Send(data));
|
let _ = stream.unbounded_send(WsEvent::Send(data));
|
||||||
self.tx
|
self.tx
|
||||||
.write_frame(
|
.write_frame(
|
||||||
|
@ -182,24 +166,59 @@ impl<W: ws::WebSocketWrite + Send + 'static> ServerMux<W> {
|
||||||
}
|
}
|
||||||
Continue(_) => unreachable!(),
|
Continue(_) => unreachable!(),
|
||||||
Close(inner_packet) => {
|
Close(inner_packet) => {
|
||||||
if let Some(stream) = self.stream_map.clone().get(&packet.stream_id) {
|
if let Some(stream) = self.stream_map.lock().await.get(&packet.stream_id) {
|
||||||
let _ = stream.unbounded_send(WsEvent::Close(inner_packet));
|
let _ = stream.unbounded_send(WsEvent::Close(inner_packet));
|
||||||
self.stream_map.clone().remove(&packet.stream_id);
|
self.stream_map.lock().await.remove(&packet.stream_id);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
drop(muxstream_sender);
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub struct ServerMux<W>
|
||||||
|
where
|
||||||
|
W: ws::WebSocketWrite + Send + 'static,
|
||||||
|
{
|
||||||
|
muxstream_recv: mpsc::UnboundedReceiver<(ConnectPacket, MuxStream<W>)>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<W: ws::WebSocketWrite + Send + 'static> ServerMux<W> {
|
||||||
|
pub fn new<R>(read: R, write: W) -> (Self, impl Future<Output = Result<(), WispError>>)
|
||||||
|
where
|
||||||
|
R: ws::WebSocketRead,
|
||||||
|
{
|
||||||
|
let (close_tx, close_rx) = mpsc::unbounded::<MuxEvent>();
|
||||||
|
let (tx, rx) = mpsc::unbounded::<(ConnectPacket, MuxStream<W>)>();
|
||||||
|
let write = ws::LockedWebSocketWrite::new(write);
|
||||||
|
let map = Arc::new(Mutex::new(HashMap::new()));
|
||||||
|
(
|
||||||
|
Self { muxstream_recv: rx },
|
||||||
|
ServerMuxInner {
|
||||||
|
tx: write,
|
||||||
|
close_tx,
|
||||||
|
stream_map: map.clone(),
|
||||||
|
}
|
||||||
|
.into_future(read, close_rx, tx),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn server_new_stream(&mut self) -> Option<(ConnectPacket, MuxStream<W>)> {
|
||||||
|
self.muxstream_recv.next().await
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub struct ClientMuxInner<W>
|
pub struct ClientMuxInner<W>
|
||||||
where
|
where
|
||||||
W: ws::WebSocketWrite,
|
W: ws::WebSocketWrite,
|
||||||
{
|
{
|
||||||
tx: ws::LockedWebSocketWrite<W>,
|
tx: ws::LockedWebSocketWrite<W>,
|
||||||
stream_map: Arc<DashMap<u32, mpsc::UnboundedSender<WsEvent>>>,
|
stream_map: Arc<Mutex<HashMap<u32, mpsc::UnboundedSender<WsEvent>>>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<W: ws::WebSocketWrite + Send> ClientMuxInner<W> {
|
impl<W: ws::WebSocketWrite + Send> ClientMuxInner<W> {
|
||||||
|
@ -211,7 +230,10 @@ impl<W: ws::WebSocketWrite + Send> ClientMuxInner<W> {
|
||||||
where
|
where
|
||||||
R: ws::WebSocketRead,
|
R: ws::WebSocketRead,
|
||||||
{
|
{
|
||||||
futures::try_join!(self.client_bg_loop(close_rx), self.client_loop(rx)).map(|_| ())
|
futures::select! {
|
||||||
|
x = self.client_bg_loop(close_rx).fuse() => x,
|
||||||
|
x = self.client_loop(rx).fuse() => x
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn client_bg_loop(
|
async fn client_bg_loop(
|
||||||
|
@ -221,7 +243,7 @@ impl<W: ws::WebSocketWrite + Send> ClientMuxInner<W> {
|
||||||
while let Some(msg) = close_rx.next().await {
|
while let Some(msg) = close_rx.next().await {
|
||||||
match msg {
|
match msg {
|
||||||
MuxEvent::Close(stream_id, reason, channel) => {
|
MuxEvent::Close(stream_id, reason, channel) => {
|
||||||
if self.stream_map.clone().remove(&stream_id).is_some() {
|
if self.stream_map.lock().await.remove(&stream_id).is_some() {
|
||||||
let _ = channel.send(
|
let _ = channel.send(
|
||||||
self.tx
|
self.tx
|
||||||
.write_frame(Packet::new_close(stream_id, reason).into())
|
.write_frame(Packet::new_close(stream_id, reason).into())
|
||||||
|
@ -246,15 +268,15 @@ impl<W: ws::WebSocketWrite + Send> ClientMuxInner<W> {
|
||||||
match packet.packet {
|
match packet.packet {
|
||||||
Connect(_) => unreachable!(),
|
Connect(_) => unreachable!(),
|
||||||
Data(data) => {
|
Data(data) => {
|
||||||
if let Some(stream) = self.stream_map.clone().get(&packet.stream_id) {
|
if let Some(stream) = self.stream_map.lock().await.get(&packet.stream_id) {
|
||||||
let _ = stream.unbounded_send(WsEvent::Send(data));
|
let _ = stream.unbounded_send(WsEvent::Send(data));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Continue(_) => {}
|
Continue(_) => {}
|
||||||
Close(inner_packet) => {
|
Close(inner_packet) => {
|
||||||
if let Some(stream) = self.stream_map.clone().get(&packet.stream_id) {
|
if let Some(stream) = self.stream_map.lock().await.get(&packet.stream_id) {
|
||||||
let _ = stream.unbounded_send(WsEvent::Close(inner_packet));
|
let _ = stream.unbounded_send(WsEvent::Close(inner_packet));
|
||||||
self.stream_map.clone().remove(&packet.stream_id);
|
self.stream_map.lock().await.remove(&packet.stream_id);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -269,7 +291,7 @@ where
|
||||||
W: ws::WebSocketWrite,
|
W: ws::WebSocketWrite,
|
||||||
{
|
{
|
||||||
tx: ws::LockedWebSocketWrite<W>,
|
tx: ws::LockedWebSocketWrite<W>,
|
||||||
stream_map: Arc<DashMap<u32, mpsc::UnboundedSender<WsEvent>>>,
|
stream_map: Arc<Mutex<HashMap<u32, mpsc::UnboundedSender<WsEvent>>>>,
|
||||||
next_free_stream_id: AtomicU32,
|
next_free_stream_id: AtomicU32,
|
||||||
close_tx: mpsc::UnboundedSender<MuxEvent>,
|
close_tx: mpsc::UnboundedSender<MuxEvent>,
|
||||||
}
|
}
|
||||||
|
@ -280,7 +302,7 @@ impl<W: ws::WebSocketWrite + Send + 'static> ClientMux<W> {
|
||||||
R: ws::WebSocketRead,
|
R: ws::WebSocketRead,
|
||||||
{
|
{
|
||||||
let (tx, rx) = mpsc::unbounded::<MuxEvent>();
|
let (tx, rx) = mpsc::unbounded::<MuxEvent>();
|
||||||
let map = Arc::new(DashMap::new());
|
let map = Arc::new(Mutex::new(HashMap::new()));
|
||||||
let write = ws::LockedWebSocketWrite::new(write);
|
let write = ws::LockedWebSocketWrite::new(write);
|
||||||
(
|
(
|
||||||
Self {
|
Self {
|
||||||
|
@ -314,7 +336,7 @@ impl<W: ws::WebSocketWrite + Send + 'static> ClientMux<W> {
|
||||||
.ok_or(WispError::MaxStreamCountReached)?,
|
.ok_or(WispError::MaxStreamCountReached)?,
|
||||||
Ordering::Release,
|
Ordering::Release,
|
||||||
);
|
);
|
||||||
self.stream_map.clone().insert(stream_id, ch_tx);
|
self.stream_map.lock().await.insert(stream_id, ch_tx);
|
||||||
Ok(MuxStream::new(
|
Ok(MuxStream::new(
|
||||||
stream_id,
|
stream_id,
|
||||||
ch_rx,
|
ch_rx,
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue