mirror of
https://github.com/MercuryWorkshop/epoxy-tls.git
synced 2025-05-12 14:00:01 -04:00
use blazingly fast flume channels 🚀
This commit is contained in:
parent
5af56fe582
commit
5e741d3808
11 changed files with 225 additions and 135 deletions
25
Cargo.lock
generated
25
Cargo.lock
generated
|
@ -861,6 +861,18 @@ dependencies = [
|
|||
"miniz_oxide",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "flume"
|
||||
version = "0.11.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "55ac459de2512911e4b674ce33cf20befaba382d05b62b008afc1c8b57cbf181"
|
||||
dependencies = [
|
||||
"futures-core",
|
||||
"futures-sink",
|
||||
"nanorand",
|
||||
"spin",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "fnv"
|
||||
version = "1.0.7"
|
||||
|
@ -1487,6 +1499,15 @@ dependencies = [
|
|||
"windows-sys 0.48.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "nanorand"
|
||||
version = "0.7.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6a51313c5820b0b02bd422f4b44776fbf47961755c74ce64afc73bfad10226c3"
|
||||
dependencies = [
|
||||
"getrandom",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "native-tls"
|
||||
version = "0.2.11"
|
||||
|
@ -2273,6 +2294,9 @@ name = "spin"
|
|||
version = "0.9.8"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67"
|
||||
dependencies = [
|
||||
"lock_api",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "strsim"
|
||||
|
@ -3203,6 +3227,7 @@ dependencies = [
|
|||
"dashmap",
|
||||
"event-listener",
|
||||
"fastwebsockets 0.7.1",
|
||||
"flume",
|
||||
"futures",
|
||||
"futures-timer",
|
||||
"futures-util",
|
||||
|
|
|
@ -238,9 +238,9 @@ onmessage = async (msg) => {
|
|||
log(`total avg mux (${num_outer_tests} tests of ${num_inner_tests} reqs): ${total_mux_multi} ms or ${total_mux_multi / 1000} s`);
|
||||
|
||||
} else {
|
||||
let resp = await epoxy_client.fetch("https://httpbin.org/get");
|
||||
let resp = await epoxy_client.fetch("https://www.example.com/");
|
||||
console.log(resp, Object.fromEntries(resp.headers));
|
||||
plog(await resp.json());
|
||||
log(await resp.text());
|
||||
}
|
||||
log("done");
|
||||
};
|
||||
|
|
|
@ -200,13 +200,10 @@ pub async fn make_mux(
|
|||
),
|
||||
WispError,
|
||||
> {
|
||||
let (wtx, wrx) = WebSocketWrapper::connect(url, vec![])
|
||||
.await
|
||||
.map_err(|_| WispError::WsImplSocketClosed)?;
|
||||
let (wtx, wrx) =
|
||||
WebSocketWrapper::connect(url, vec![]).map_err(|_| WispError::WsImplSocketClosed)?;
|
||||
wtx.wait_for_open().await;
|
||||
let mux = ClientMux::new(wrx, wtx, Some(&[Box::new(UdpProtocolExtensionBuilder())])).await?;
|
||||
|
||||
Ok(mux)
|
||||
ClientMux::new(wrx, wtx, Some(&[Box::new(UdpProtocolExtensionBuilder())])).await
|
||||
}
|
||||
|
||||
pub fn spawn_mux_fut(
|
||||
|
@ -215,6 +212,7 @@ pub fn spawn_mux_fut(
|
|||
url: String,
|
||||
) {
|
||||
wasm_bindgen_futures::spawn_local(async move {
|
||||
debug!("epoxy: mux future started");
|
||||
if let Err(e) = fut.await {
|
||||
log!("epoxy: error in mux future, restarting: {:?}", e);
|
||||
while let Err(e) = replace_mux(mux.clone(), &url).await {
|
||||
|
@ -229,7 +227,7 @@ pub fn spawn_mux_fut(
|
|||
pub async fn replace_mux(mux: Arc<RwLock<ClientMux>>, url: &str) -> Result<(), WispError> {
|
||||
let (mux_replace, fut) = make_mux(url).await?;
|
||||
let mut mux_write = mux.write().await;
|
||||
mux_write.close().await?;
|
||||
let _ = mux_write.close().await;
|
||||
*mux_write = mux_replace;
|
||||
drop(mux_write);
|
||||
spawn_mux_fut(mux, fut, url.into());
|
||||
|
|
|
@ -123,6 +123,7 @@ impl tower_service::Service<hyper::Uri> for TlsWispService {
|
|||
let stream = service.call(uri_parsed).await?.into_inner();
|
||||
if utils::get_is_secure(&req).map_err(|_| WispError::InvalidUri)? {
|
||||
let connector = TlsConnector::from(rustls_config);
|
||||
log!("got stream");
|
||||
Ok(TokioIo::new(Either::Left(
|
||||
connector
|
||||
.connect(
|
||||
|
@ -143,6 +144,7 @@ impl tower_service::Service<hyper::Uri> for TlsWispService {
|
|||
pub enum WebSocketError {
|
||||
Unknown,
|
||||
SendFailed,
|
||||
CloseFailed,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for WebSocketError {
|
||||
|
@ -151,6 +153,7 @@ impl std::fmt::Display for WebSocketError {
|
|||
match self {
|
||||
Unknown => write!(f, "Unknown error"),
|
||||
SendFailed => write!(f, "Send failed"),
|
||||
CloseFailed => write!(f, "Close failed"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -213,7 +216,7 @@ impl WebSocketRead for WebSocketReader {
|
|||
}
|
||||
|
||||
impl WebSocketWrapper {
|
||||
pub async fn connect(
|
||||
pub fn connect(
|
||||
url: &str,
|
||||
protocols: Vec<String>,
|
||||
) -> Result<(Self, WebSocketReader), JsValue> {
|
||||
|
@ -327,6 +330,12 @@ impl WebSocketWrite for WebSocketWrapper {
|
|||
_ => Err(WispError::WsImplNotSupported),
|
||||
}
|
||||
}
|
||||
|
||||
async fn wisp_close(&mut self) -> Result<(), WispError> {
|
||||
self.inner
|
||||
.close()
|
||||
.map_err(|_| WebSocketError::CloseFailed.into())
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for WebSocketWrapper {
|
||||
|
|
|
@ -12,9 +12,13 @@ use hyper::{
|
|||
body::Incoming, server::conn::http1, service::service_fn, Request, Response, StatusCode,
|
||||
};
|
||||
use hyper_util::rt::TokioIo;
|
||||
use tokio::net::{lookup_host, TcpListener, TcpStream, UdpSocket};
|
||||
#[cfg(unix)]
|
||||
use tokio::net::{UnixListener, UnixStream};
|
||||
use tokio::{
|
||||
io::{copy_bidirectional, split, BufReader, BufWriter},
|
||||
net::{lookup_host, TcpListener, TcpStream, UdpSocket},
|
||||
select,
|
||||
};
|
||||
use tokio_util::codec::{BytesCodec, Framed};
|
||||
#[cfg(unix)]
|
||||
use tokio_util::either::Either;
|
||||
|
@ -22,9 +26,10 @@ use tokio_util::either::Either;
|
|||
use wisp_mux::{
|
||||
extensions::{
|
||||
password::{PasswordProtocolExtension, PasswordProtocolExtensionBuilder},
|
||||
udp::UdpProtocolExtensionBuilder, ProtocolExtensionBuilder,
|
||||
udp::UdpProtocolExtensionBuilder,
|
||||
ProtocolExtensionBuilder,
|
||||
},
|
||||
CloseReason, ConnectPacket, MuxStream, ServerMux, StreamType, WispError,
|
||||
CloseReason, ConnectPacket, IoStream, MuxStream, MuxStreamIo, ServerMux, StreamType, WispError,
|
||||
};
|
||||
|
||||
type HttpBody = http_body_util::Full<hyper::body::Bytes>;
|
||||
|
@ -182,7 +187,10 @@ async fn main() -> Result<(), Error> {
|
|||
block_local: opt.block_local,
|
||||
block_non_http: opt.block_non_http,
|
||||
block_udp: opt.block_udp,
|
||||
auth: Arc::new(vec![Box::new(UdpProtocolExtensionBuilder()), Box::new(pw_ext)]),
|
||||
auth: Arc::new(vec![
|
||||
Box::new(UdpProtocolExtensionBuilder()),
|
||||
Box::new(pw_ext),
|
||||
]),
|
||||
enforce_auth,
|
||||
};
|
||||
|
||||
|
@ -257,7 +265,7 @@ async fn handle_mux(packet: ConnectPacket, mut stream: MuxStream) -> Result<bool
|
|||
.await
|
||||
.map_err(|x| WispError::Other(Box::new(x)))?;
|
||||
let mut mux_stream = stream.into_io().into_asyncrw();
|
||||
tokio::io::copy_bidirectional(&mut tcp_stream, &mut mux_stream)
|
||||
copy_bidirectional(&mut mux_stream, &mut tcp_stream)
|
||||
.await
|
||||
.map_err(|x| WispError::Other(Box::new(x)))?;
|
||||
}
|
||||
|
@ -312,13 +320,7 @@ async fn accept_ws(
|
|||
// to prevent memory ""leaks"" because users are sending in packets way too fast the buffer
|
||||
// size is set to 128
|
||||
let (mut mux, fut) = if mux_options.enforce_auth {
|
||||
let (mut mux, fut) = ServerMux::new(
|
||||
rx,
|
||||
tx,
|
||||
128,
|
||||
Some(mux_options.auth.as_slice()),
|
||||
)
|
||||
.await?;
|
||||
let (mut mux, fut) = ServerMux::new(rx, tx, 128, Some(mux_options.auth.as_slice())).await?;
|
||||
if !mux
|
||||
.supported_extension_ids
|
||||
.iter()
|
||||
|
@ -333,7 +335,13 @@ async fn accept_ws(
|
|||
}
|
||||
(mux, fut)
|
||||
} else {
|
||||
ServerMux::new(rx, tx, 128, Some(&[Box::new(UdpProtocolExtensionBuilder())])).await?
|
||||
ServerMux::new(
|
||||
rx,
|
||||
tx,
|
||||
128,
|
||||
Some(&[Box::new(UdpProtocolExtensionBuilder())]),
|
||||
)
|
||||
.await?
|
||||
};
|
||||
|
||||
println!(
|
||||
|
@ -388,10 +396,9 @@ async fn accept_ws(
|
|||
})
|
||||
.and_then(|should_send| async move {
|
||||
if should_send {
|
||||
close_ok.close(CloseReason::Voluntary).await
|
||||
} else {
|
||||
Ok(())
|
||||
let _ = close_ok.close(CloseReason::Voluntary).await;
|
||||
}
|
||||
Ok(())
|
||||
})
|
||||
.await;
|
||||
});
|
||||
|
|
|
@ -253,7 +253,7 @@ async fn main() -> Result<(), Box<dyn Error + Send + Sync>> {
|
|||
avg.get_average() * opts.packet_size,
|
||||
);
|
||||
if is_term {
|
||||
print!("\x1b[2K{}\r", stat);
|
||||
println!("\x1b[1A\x1b[2K{}\r", stat);
|
||||
} else {
|
||||
println!("{}", stat);
|
||||
}
|
||||
|
@ -284,6 +284,8 @@ async fn main() -> Result<(), Box<dyn Error + Send + Sync>> {
|
|||
|
||||
let out = select_all(threads.into_iter()).await;
|
||||
|
||||
let duration_since = Instant::now().duration_since(start_time);
|
||||
|
||||
if let Err(err) = out.0? {
|
||||
println!("\n\nerr: {:?}", err);
|
||||
exit(1);
|
||||
|
@ -291,10 +293,10 @@ async fn main() -> Result<(), Box<dyn Error + Send + Sync>> {
|
|||
|
||||
out.2.into_iter().for_each(|x| x.abort());
|
||||
|
||||
let duration_since = Instant::now().duration_since(start_time);
|
||||
mux.close().await?;
|
||||
|
||||
println!(
|
||||
"\n\nresults: {} packets of &[0; 1024 * {}] ({} KiB) sent in {} ({} KiB/s)",
|
||||
"\nresults: {} packets of &[0; 1024 * {}] ({} KiB) sent in {} ({} KiB/s)",
|
||||
cnt.get(),
|
||||
opts.packet_size,
|
||||
cnt.get() * opts.packet_size,
|
||||
|
|
|
@ -15,6 +15,7 @@ bytes = "1.5.0"
|
|||
dashmap = { version = "5.5.3", features = ["inline"] }
|
||||
event-listener = "5.0.0"
|
||||
fastwebsockets = { version = "0.7.1", features = ["unstable-split"], optional = true }
|
||||
flume = "0.11.0"
|
||||
futures = "0.3.30"
|
||||
futures-timer = "3.0.3"
|
||||
futures-util = "0.3.30"
|
||||
|
|
|
@ -3,7 +3,7 @@ use std::ops::Deref;
|
|||
use async_trait::async_trait;
|
||||
use bytes::BytesMut;
|
||||
use fastwebsockets::{
|
||||
FragmentCollectorRead, Frame, OpCode, Payload, WebSocketError, WebSocketWrite,
|
||||
CloseCode, FragmentCollectorRead, Frame, OpCode, Payload, WebSocketError, WebSocketWrite
|
||||
};
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
|
||||
|
@ -77,4 +77,8 @@ impl<S: AsyncWrite + Unpin + Send> crate::ws::WebSocketWrite for WebSocketWrite<
|
|||
async fn wisp_write_frame(&mut self, frame: crate::ws::Frame) -> Result<(), WispError> {
|
||||
self.write_frame(frame.into()).await.map_err(|e| e.into())
|
||||
}
|
||||
|
||||
async fn wisp_close(&mut self) -> Result<(), WispError> {
|
||||
self.write_frame(Frame::close(CloseCode::Normal.into(), b"")).await.map_err(|e| e.into())
|
||||
}
|
||||
}
|
||||
|
|
193
wisp/src/lib.rs
193
wisp/src/lib.rs
|
@ -1,4 +1,4 @@
|
|||
#![deny(missing_docs)]
|
||||
#![deny(missing_docs, warnings)]
|
||||
#![cfg_attr(docsrs, feature(doc_cfg))]
|
||||
//! A library for easily creating [Wisp] clients and servers.
|
||||
//!
|
||||
|
@ -19,9 +19,8 @@ use bytes::Bytes;
|
|||
use dashmap::DashMap;
|
||||
use event_listener::Event;
|
||||
use extensions::{udp::UdpProtocolExtension, AnyProtocolExtension, ProtocolExtensionBuilder};
|
||||
use futures::{
|
||||
channel::{mpsc, oneshot}, lock::Mutex, select, Future, FutureExt, SinkExt, StreamExt
|
||||
};
|
||||
use flume as mpsc;
|
||||
use futures::{channel::oneshot, select, Future, FutureExt};
|
||||
use futures_timer::Delay;
|
||||
use std::{
|
||||
sync::{
|
||||
|
@ -151,11 +150,12 @@ impl std::fmt::Display for WispError {
|
|||
impl std::error::Error for WispError {}
|
||||
|
||||
struct MuxMapValue {
|
||||
stream: Mutex<mpsc::Sender<Bytes>>,
|
||||
stream: mpsc::Sender<Bytes>,
|
||||
stream_type: StreamType,
|
||||
flow_control: Arc<AtomicU32>,
|
||||
flow_control_event: Arc<Event>,
|
||||
is_closed: Arc<AtomicBool>,
|
||||
is_closed_event: Arc<Event>,
|
||||
}
|
||||
|
||||
struct MuxInner {
|
||||
|
@ -170,7 +170,7 @@ impl MuxInner {
|
|||
rx: R,
|
||||
extensions: Vec<AnyProtocolExtension>,
|
||||
close_rx: mpsc::Receiver<WsEvent>,
|
||||
muxstream_sender: mpsc::UnboundedSender<(ConnectPacket, MuxStream)>,
|
||||
muxstream_sender: mpsc::Sender<(ConnectPacket, MuxStream)>,
|
||||
close_tx: mpsc::Sender<WsEvent>,
|
||||
) -> Result<(), WispError>
|
||||
where
|
||||
|
@ -210,20 +210,60 @@ impl MuxInner {
|
|||
};
|
||||
for x in self.stream_map.iter_mut() {
|
||||
x.is_closed.store(true, Ordering::Release);
|
||||
x.stream.lock().await.disconnect();
|
||||
x.stream.lock().await.close_channel();
|
||||
x.is_closed_event.notify(usize::MAX);
|
||||
}
|
||||
self.stream_map.clear();
|
||||
let _ = self.tx.close().await;
|
||||
ret
|
||||
}
|
||||
|
||||
async fn create_new_stream(
|
||||
&self,
|
||||
stream_id: u32,
|
||||
stream_type: StreamType,
|
||||
role: Role,
|
||||
stream_tx: mpsc::Sender<WsEvent>,
|
||||
target_buffer_size: u32,
|
||||
) -> Result<(MuxMapValue, MuxStream), WispError> {
|
||||
let (ch_tx, ch_rx) = mpsc::bounded(self.buffer_size as usize);
|
||||
|
||||
let flow_control_event: Arc<Event> = Event::new().into();
|
||||
let flow_control: Arc<AtomicU32> = AtomicU32::new(self.buffer_size).into();
|
||||
|
||||
let is_closed: Arc<AtomicBool> = AtomicBool::new(false).into();
|
||||
let is_closed_event: Arc<Event> = Event::new().into();
|
||||
|
||||
Ok((
|
||||
MuxMapValue {
|
||||
stream: ch_tx,
|
||||
stream_type,
|
||||
flow_control: flow_control.clone(),
|
||||
flow_control_event: flow_control_event.clone(),
|
||||
is_closed: is_closed.clone(),
|
||||
is_closed_event: is_closed_event.clone(),
|
||||
},
|
||||
MuxStream::new(
|
||||
stream_id,
|
||||
role,
|
||||
stream_type,
|
||||
ch_rx,
|
||||
stream_tx.clone(),
|
||||
is_closed,
|
||||
is_closed_event,
|
||||
flow_control,
|
||||
flow_control_event,
|
||||
target_buffer_size,
|
||||
),
|
||||
))
|
||||
}
|
||||
|
||||
async fn stream_loop(
|
||||
&self,
|
||||
mut stream_rx: mpsc::Receiver<WsEvent>,
|
||||
stream_rx: mpsc::Receiver<WsEvent>,
|
||||
stream_tx: mpsc::Sender<WsEvent>,
|
||||
) {
|
||||
let mut next_free_stream_id: u32 = 1;
|
||||
while let Some(msg) = stream_rx.next().await {
|
||||
while let Ok(msg) = stream_rx.recv_async().await {
|
||||
match msg {
|
||||
WsEvent::SendPacket(packet, channel) => {
|
||||
if self.stream_map.get(&packet.stream_id).is_some() {
|
||||
|
@ -234,16 +274,20 @@ impl MuxInner {
|
|||
}
|
||||
WsEvent::CreateStream(stream_type, host, port, channel) => {
|
||||
let ret: Result<MuxStream, WispError> = async {
|
||||
let (ch_tx, ch_rx) = mpsc::channel(self.buffer_size as usize);
|
||||
let stream_id = next_free_stream_id;
|
||||
let next_stream_id = next_free_stream_id
|
||||
.checked_add(1)
|
||||
.ok_or(WispError::MaxStreamCountReached)?;
|
||||
|
||||
let flow_control_event: Arc<Event> = Event::new().into();
|
||||
let flow_control: Arc<AtomicU32> = AtomicU32::new(self.buffer_size).into();
|
||||
|
||||
let is_closed: Arc<AtomicBool> = AtomicBool::new(false).into();
|
||||
let (map_value, stream) = self
|
||||
.create_new_stream(
|
||||
stream_id,
|
||||
stream_type,
|
||||
Role::Client,
|
||||
stream_tx.clone(),
|
||||
0,
|
||||
)
|
||||
.await?;
|
||||
|
||||
self.tx
|
||||
.write_frame(
|
||||
|
@ -251,39 +295,19 @@ impl MuxInner {
|
|||
)
|
||||
.await?;
|
||||
|
||||
self.stream_map.insert(stream_id, map_value);
|
||||
|
||||
next_free_stream_id = next_stream_id;
|
||||
|
||||
self.stream_map.insert(
|
||||
stream_id,
|
||||
MuxMapValue {
|
||||
stream: ch_tx.into(),
|
||||
stream_type,
|
||||
flow_control: flow_control.clone(),
|
||||
flow_control_event: flow_control_event.clone(),
|
||||
is_closed: is_closed.clone(),
|
||||
},
|
||||
);
|
||||
|
||||
Ok(MuxStream::new(
|
||||
stream_id,
|
||||
Role::Client,
|
||||
stream_type,
|
||||
ch_rx,
|
||||
stream_tx.clone(),
|
||||
is_closed,
|
||||
flow_control,
|
||||
flow_control_event,
|
||||
0,
|
||||
))
|
||||
Ok(stream)
|
||||
}
|
||||
.await;
|
||||
let _ = channel.send(ret);
|
||||
}
|
||||
WsEvent::Close(packet, channel) => {
|
||||
if let Some((_, stream)) = self.stream_map.remove(&packet.stream_id) {
|
||||
stream.stream.lock().await.disconnect();
|
||||
stream.stream.lock().await.close_channel();
|
||||
let _ = channel.send(self.tx.write_frame(packet.into()).await);
|
||||
drop(stream.stream)
|
||||
} else {
|
||||
let _ = channel.send(Err(WispError::InvalidStreamId));
|
||||
}
|
||||
|
@ -305,8 +329,8 @@ impl MuxInner {
|
|||
&self,
|
||||
mut rx: R,
|
||||
mut extensions: Vec<AnyProtocolExtension>,
|
||||
muxstream_sender: mpsc::UnboundedSender<(ConnectPacket, MuxStream)>,
|
||||
close_tx: mpsc::Sender<WsEvent>,
|
||||
muxstream_sender: mpsc::Sender<(ConnectPacket, MuxStream)>,
|
||||
stream_tx: mpsc::Sender<WsEvent>,
|
||||
) -> Result<(), WispError>
|
||||
where
|
||||
R: ws::WebSocketRead + Send,
|
||||
|
@ -325,42 +349,24 @@ impl MuxInner {
|
|||
use PacketType::*;
|
||||
match packet.packet_type {
|
||||
Connect(inner_packet) => {
|
||||
let (ch_tx, ch_rx) = mpsc::channel(self.buffer_size as usize);
|
||||
let stream_type = inner_packet.stream_type;
|
||||
let flow_control: Arc<AtomicU32> = AtomicU32::new(self.buffer_size).into();
|
||||
let flow_control_event: Arc<Event> = Event::new().into();
|
||||
let is_closed: Arc<AtomicBool> = AtomicBool::new(false).into();
|
||||
|
||||
self.stream_map.insert(
|
||||
packet.stream_id,
|
||||
MuxMapValue {
|
||||
stream: ch_tx.into(),
|
||||
stream_type,
|
||||
flow_control: flow_control.clone(),
|
||||
flow_control_event: flow_control_event.clone(),
|
||||
is_closed: is_closed.clone(),
|
||||
},
|
||||
);
|
||||
muxstream_sender
|
||||
.unbounded_send((
|
||||
inner_packet,
|
||||
MuxStream::new(
|
||||
let (map_value, stream) = self
|
||||
.create_new_stream(
|
||||
packet.stream_id,
|
||||
inner_packet.stream_type,
|
||||
Role::Server,
|
||||
stream_type,
|
||||
ch_rx,
|
||||
close_tx.clone(),
|
||||
is_closed,
|
||||
flow_control,
|
||||
flow_control_event,
|
||||
stream_tx.clone(),
|
||||
target_buffer_size,
|
||||
),
|
||||
))
|
||||
.map_err(|x| WispError::Other(Box::new(x)))?;
|
||||
)
|
||||
.await?;
|
||||
muxstream_sender
|
||||
.send_async((inner_packet, stream))
|
||||
.await
|
||||
.map_err(|_| WispError::MuxMessageFailedToSend)?;
|
||||
self.stream_map.insert(packet.stream_id, map_value);
|
||||
}
|
||||
Data(data) => {
|
||||
if let Some(stream) = self.stream_map.get(&packet.stream_id) {
|
||||
let _ = stream.stream.lock().await.send(data).await;
|
||||
let _ = stream.stream.send_async(data).await;
|
||||
if stream.stream_type == StreamType::Tcp {
|
||||
stream.flow_control.store(
|
||||
stream
|
||||
|
@ -379,8 +385,8 @@ impl MuxInner {
|
|||
}
|
||||
if let Some((_, stream)) = self.stream_map.remove(&packet.stream_id) {
|
||||
stream.is_closed.store(true, Ordering::Release);
|
||||
stream.stream.lock().await.disconnect();
|
||||
stream.stream.lock().await.close_channel();
|
||||
stream.is_closed_event.notify(usize::MAX);
|
||||
drop(stream.stream)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -409,7 +415,7 @@ impl MuxInner {
|
|||
Connect(_) | Info(_) => break Err(WispError::InvalidPacketType),
|
||||
Data(data) => {
|
||||
if let Some(stream) = self.stream_map.get(&packet.stream_id) {
|
||||
let _ = stream.stream.lock().await.send(data).await;
|
||||
let _ = stream.stream.send_async(data).await;
|
||||
}
|
||||
}
|
||||
Continue(inner_packet) => {
|
||||
|
@ -428,8 +434,8 @@ impl MuxInner {
|
|||
}
|
||||
if let Some((_, stream)) = self.stream_map.remove(&packet.stream_id) {
|
||||
stream.is_closed.store(true, Ordering::Release);
|
||||
stream.stream.lock().await.disconnect();
|
||||
stream.stream.lock().await.close_channel();
|
||||
stream.is_closed_event.notify(usize::MAX);
|
||||
drop(stream.stream)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -465,7 +471,7 @@ pub struct ServerMux {
|
|||
/// Extensions that are supported by both sides.
|
||||
pub supported_extension_ids: Vec<u8>,
|
||||
close_tx: mpsc::Sender<WsEvent>,
|
||||
muxstream_recv: mpsc::UnboundedReceiver<(ConnectPacket, MuxStream)>,
|
||||
muxstream_recv: mpsc::Receiver<(ConnectPacket, MuxStream)>,
|
||||
}
|
||||
|
||||
impl ServerMux {
|
||||
|
@ -484,7 +490,7 @@ impl ServerMux {
|
|||
R: ws::WebSocketRead + Send,
|
||||
W: ws::WebSocketWrite + Send + 'static,
|
||||
{
|
||||
let (close_tx, close_rx) = mpsc::channel::<WsEvent>(256);
|
||||
let (close_tx, close_rx) = mpsc::bounded::<WsEvent>(256);
|
||||
let (tx, rx) = mpsc::unbounded::<(ConnectPacket, MuxStream)>();
|
||||
let write = ws::LockedWebSocketWrite::new(Box::new(write));
|
||||
|
||||
|
@ -547,12 +553,12 @@ impl ServerMux {
|
|||
|
||||
/// Wait for a stream to be created.
|
||||
pub async fn server_new_stream(&mut self) -> Option<(ConnectPacket, MuxStream)> {
|
||||
self.muxstream_recv.next().await
|
||||
self.muxstream_recv.recv_async().await.ok()
|
||||
}
|
||||
|
||||
async fn close_internal(&mut self, reason: Option<CloseReason>) -> Result<(), WispError> {
|
||||
self.close_tx
|
||||
.send(WsEvent::EndFut(reason))
|
||||
.send_async(WsEvent::EndFut(reason))
|
||||
.await
|
||||
.map_err(|_| WispError::MuxMessageFailedToSend)
|
||||
}
|
||||
|
@ -574,6 +580,13 @@ impl ServerMux {
|
|||
.await
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for ServerMux {
|
||||
fn drop(&mut self) {
|
||||
let _ = self.close_tx.send(WsEvent::EndFut(None));
|
||||
}
|
||||
}
|
||||
|
||||
/// Client side multiplexor.
|
||||
///
|
||||
/// # Example
|
||||
|
@ -595,7 +608,7 @@ pub struct ClientMux {
|
|||
pub downgraded: bool,
|
||||
/// Extensions that are supported by both sides.
|
||||
pub supported_extension_ids: Vec<u8>,
|
||||
close_tx: mpsc::Sender<WsEvent>,
|
||||
stream_tx: mpsc::Sender<WsEvent>,
|
||||
}
|
||||
|
||||
impl ClientMux {
|
||||
|
@ -654,10 +667,10 @@ impl ClientMux {
|
|||
extension.handle_handshake(&mut read, &write).await?;
|
||||
}
|
||||
|
||||
let (tx, rx) = mpsc::channel::<WsEvent>(256);
|
||||
let (tx, rx) = mpsc::bounded::<WsEvent>(256);
|
||||
Ok((
|
||||
Self {
|
||||
close_tx: tx.clone(),
|
||||
stream_tx: tx.clone(),
|
||||
downgraded,
|
||||
supported_extension_ids: supported_extensions
|
||||
.iter()
|
||||
|
@ -697,16 +710,16 @@ impl ClientMux {
|
|||
return Err(WispError::UdpExtensionNotSupported);
|
||||
}
|
||||
let (tx, rx) = oneshot::channel();
|
||||
self.close_tx
|
||||
.send(WsEvent::CreateStream(stream_type, host, port, tx))
|
||||
self.stream_tx
|
||||
.send_async(WsEvent::CreateStream(stream_type, host, port, tx))
|
||||
.await
|
||||
.map_err(|_| WispError::MuxMessageFailedToSend)?;
|
||||
rx.await.map_err(|_| WispError::MuxMessageFailedToRecv)?
|
||||
}
|
||||
|
||||
async fn close_internal(&mut self, reason: Option<CloseReason>) -> Result<(), WispError> {
|
||||
self.close_tx
|
||||
.send(WsEvent::EndFut(reason))
|
||||
self.stream_tx
|
||||
.send_async(WsEvent::EndFut(reason))
|
||||
.await
|
||||
.map_err(|_| WispError::MuxMessageFailedToSend)
|
||||
}
|
||||
|
@ -728,3 +741,9 @@ impl ClientMux {
|
|||
.await
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for ClientMux {
|
||||
fn drop(&mut self) {
|
||||
let _ = self.stream_tx.send(WsEvent::EndFut(None));
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,13 +1,14 @@
|
|||
use crate::{sink_unfold, CloseReason, Packet, Role, StreamType, WispError};
|
||||
|
||||
use async_io_stream::IoStream;
|
||||
pub use async_io_stream::IoStream;
|
||||
use bytes::Bytes;
|
||||
use event_listener::Event;
|
||||
use flume as mpsc;
|
||||
use futures::{
|
||||
channel::{mpsc, oneshot},
|
||||
stream,
|
||||
channel::oneshot,
|
||||
select, stream,
|
||||
task::{Context, Poll},
|
||||
Sink, SinkExt, Stream, StreamExt,
|
||||
FutureExt, Sink, Stream,
|
||||
};
|
||||
use pin_project_lite::pin_project;
|
||||
use std::{
|
||||
|
@ -40,6 +41,7 @@ pub struct MuxStreamRead {
|
|||
tx: mpsc::Sender<WsEvent>,
|
||||
rx: mpsc::Receiver<Bytes>,
|
||||
is_closed: Arc<AtomicBool>,
|
||||
is_closed_event: Arc<Event>,
|
||||
flow_control: Arc<AtomicU32>,
|
||||
flow_control_read: AtomicU32,
|
||||
target_flow_control: u32,
|
||||
|
@ -51,13 +53,16 @@ impl MuxStreamRead {
|
|||
if self.is_closed.load(Ordering::Acquire) {
|
||||
return None;
|
||||
}
|
||||
let bytes = self.rx.next().await?;
|
||||
let bytes = select! {
|
||||
x = self.rx.recv_async() => x.ok()?,
|
||||
_ = self.is_closed_event.listen().fuse() => return None
|
||||
};
|
||||
if self.role == Role::Server && self.stream_type == StreamType::Tcp {
|
||||
let val = self.flow_control_read.fetch_add(1, Ordering::AcqRel) + 1;
|
||||
if val > self.target_flow_control {
|
||||
let (tx, rx) = oneshot::channel::<Result<(), WispError>>();
|
||||
self.tx
|
||||
.send(WsEvent::SendPacket(
|
||||
.send_async(WsEvent::SendPacket(
|
||||
Packet::new_continue(
|
||||
self.stream_id,
|
||||
self.flow_control.fetch_add(val, Ordering::AcqRel) + val,
|
||||
|
@ -107,13 +112,13 @@ impl MuxStreamWrite {
|
|||
}
|
||||
let (tx, rx) = oneshot::channel::<Result<(), WispError>>();
|
||||
self.tx
|
||||
.send(WsEvent::SendPacket(
|
||||
.send_async(WsEvent::SendPacket(
|
||||
Packet::new_data(self.stream_id, data),
|
||||
tx,
|
||||
))
|
||||
.await
|
||||
.map_err(|x| WispError::Other(Box::new(x)))?;
|
||||
rx.await.map_err(|x| WispError::Other(Box::new(x)))??;
|
||||
.map_err(|_| WispError::MuxMessageFailedToSend)?;
|
||||
rx.await.map_err(|_| WispError::MuxMessageFailedToRecv)??;
|
||||
if self.role == Role::Client && self.stream_type == StreamType::Tcp {
|
||||
self.flow_control.store(
|
||||
self.flow_control.load(Ordering::Acquire).saturating_sub(1),
|
||||
|
@ -151,13 +156,13 @@ impl MuxStreamWrite {
|
|||
|
||||
let (tx, rx) = oneshot::channel::<Result<(), WispError>>();
|
||||
self.tx
|
||||
.send(WsEvent::Close(
|
||||
.send_async(WsEvent::Close(
|
||||
Packet::new_close(self.stream_id, reason),
|
||||
tx,
|
||||
))
|
||||
.await
|
||||
.map_err(|x| WispError::Other(Box::new(x)))?;
|
||||
rx.await.map_err(|x| WispError::Other(Box::new(x)))??;
|
||||
.map_err(|_| WispError::MuxMessageFailedToSend)?;
|
||||
rx.await.map_err(|_| WispError::MuxMessageFailedToRecv)??;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
@ -179,6 +184,16 @@ impl MuxStreamWrite {
|
|||
}
|
||||
}
|
||||
|
||||
impl Drop for MuxStreamWrite {
|
||||
fn drop(&mut self) {
|
||||
if !self.is_closed.load(Ordering::Acquire) {
|
||||
self.is_closed.store(true, Ordering::Release);
|
||||
let (tx, _) = oneshot::channel();
|
||||
let _ = self.tx.send(WsEvent::Close(Packet::new_close(self.stream_id, CloseReason::Unknown), tx));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Multiplexor stream.
|
||||
pub struct MuxStream {
|
||||
/// ID of the stream.
|
||||
|
@ -196,6 +211,7 @@ impl MuxStream {
|
|||
rx: mpsc::Receiver<Bytes>,
|
||||
tx: mpsc::Sender<WsEvent>,
|
||||
is_closed: Arc<AtomicBool>,
|
||||
is_closed_event: Arc<Event>,
|
||||
flow_control: Arc<AtomicU32>,
|
||||
continue_recieved: Arc<Event>,
|
||||
target_flow_control: u32,
|
||||
|
@ -209,6 +225,7 @@ impl MuxStream {
|
|||
tx: tx.clone(),
|
||||
rx,
|
||||
is_closed: is_closed.clone(),
|
||||
is_closed_event: is_closed_event.clone(),
|
||||
flow_control: flow_control.clone(),
|
||||
flow_control_read: AtomicU32::new(0),
|
||||
target_flow_control,
|
||||
|
@ -288,13 +305,13 @@ impl MuxStreamCloser {
|
|||
|
||||
let (tx, rx) = oneshot::channel::<Result<(), WispError>>();
|
||||
self.close_channel
|
||||
.send(WsEvent::Close(
|
||||
.send_async(WsEvent::Close(
|
||||
Packet::new_close(self.stream_id, reason),
|
||||
tx,
|
||||
))
|
||||
.await
|
||||
.map_err(|x| WispError::Other(Box::new(x)))?;
|
||||
rx.await.map_err(|x| WispError::Other(Box::new(x)))??;
|
||||
.map_err(|_| WispError::MuxMessageFailedToSend)?;
|
||||
rx.await.map_err(|_| WispError::MuxMessageFailedToRecv)??;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
|
|
@ -76,6 +76,9 @@ pub trait WebSocketRead {
|
|||
pub trait WebSocketWrite {
|
||||
/// Write a frame to the socket.
|
||||
async fn wisp_write_frame(&mut self, frame: Frame) -> Result<(), WispError>;
|
||||
|
||||
/// Close the socket.
|
||||
async fn wisp_close(&mut self) -> Result<(), WispError>;
|
||||
}
|
||||
|
||||
/// Locked WebSocket.
|
||||
|
@ -88,9 +91,14 @@ impl LockedWebSocketWrite {
|
|||
}
|
||||
|
||||
/// Write a frame to the websocket.
|
||||
pub async fn write_frame(&self, frame: Frame) -> Result<(), crate::WispError> {
|
||||
pub async fn write_frame(&self, frame: Frame) -> Result<(), WispError> {
|
||||
self.0.lock().await.wisp_write_frame(frame).await
|
||||
}
|
||||
|
||||
/// Close the websocket.
|
||||
pub async fn close(&self) -> Result<(), WispError> {
|
||||
self.0.lock().await.wisp_close().await
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) struct AppendingWebSocketRead<R>(pub Vec<Frame>, pub R)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue