fix continue packets issue, remove requirement for Send on the websocket

This commit is contained in:
Toshit Chawda 2024-03-17 11:04:33 -07:00
parent bed942eb75
commit ce86e7b095
No known key found for this signature in database
GPG key ID: 91480ED99E2B3D9D
19 changed files with 872 additions and 235 deletions

3
.cargo/config.toml Normal file
View file

@ -0,0 +1,3 @@
[build]
rustflags = ["--cfg", "tokio_unstable"]

652
Cargo.lock generated

File diff suppressed because it is too large Load diff

View file

@ -1,6 +1,7 @@
[workspace]
resolver = "2"
members = ["server", "client", "wisp", "simple-wisp-client"]
default-members = ["server"]
[patch.crates-io]
rustls-pki-types = { git = "https://github.com/r58Playz/rustls-pki-types" }
@ -8,4 +9,7 @@ rustls-pki-types = { git = "https://github.com/r58Playz/rustls-pki-types" }
[profile.release]
lto = true
opt-level = 'z'
strip = true
panic = "abort"
codegen-units = 1

View file

@ -2,5 +2,3 @@ build.sh
Cargo.toml
serve.py
src
pkg/epoxy.wasm

View file

@ -1,6 +1,6 @@
[package]
name = "epoxy-client"
version = "1.4.1"
version = "1.4.2"
edition = "2021"
license = "LGPL-3.0-only"
@ -38,4 +38,3 @@ wasmtimer = "0.2.0"
[dependencies.ring]
features = ["wasm32_unknown_unknown_js"]

View file

@ -19,6 +19,7 @@ echo "[epx] wasm-opt finished"
AUTOGENERATED_SOURCE=$(<"out/epoxy_client.js")
# patch for websocket sharedarraybuffer error
AUTOGENERATED_SOURCE=${AUTOGENERATED_SOURCE//getObject(arg0).send(getArrayU8FromWasm0(arg1, arg2)/getObject(arg0).send(new Uint8Array(getArrayU8FromWasm0(arg1, arg2))}
echo "$AUTOGENERATED_SOURCE" > pkg/epoxy.js
WASM_BASE64=$(base64 -w0 out/epoxy_client_bg.wasm)
AUTOGENERATED_SOURCE=${AUTOGENERATED_SOURCE//__wbg_init(input, maybe_memory) \{/__wbg_init(maybe_memory) \{$'\n'let input=\'data:application/wasm;base64,$WASM_BASE64\'}
@ -37,6 +38,7 @@ echo "}\nexport default function epoxy(maybe_memory?: WebAssembly.Memory): Promi
echo "$AUTOGENERATED_TYPEDEFS" > pkg/epoxy-bundled.d.ts
echo "}\ndeclare function epoxy(maybe_memory?: WebAssembly.Memory): Promise<typeof wasm_bindgen>;" >> pkg/epoxy-bundled.d.ts
cp out/epoxy_client.d.ts pkg/epoxy.d.ts
cp out/epoxy_client_bg.wasm pkg/epoxy.wasm
rm -r out/

View file

@ -12,6 +12,7 @@ onmessage = async (msg) => {
should_tls_test,
should_udp_test,
should_reconnect_test,
should_perf2_test,
] = msg.data;
console.log(
"%cWASM is significantly slower with DevTools open!",
@ -217,6 +218,24 @@ onmessage = async (msg) => {
log("sent req");
await (new Promise((res, _) => setTimeout(res, 500)));
}
} else if (should_perf2_test) {
const num_outer_tests = 10;
const num_inner_tests = 50;
let total_mux_multi = 0;
for (const _ of Array(num_outer_tests).keys()) {
let total_mux = 0;
await Promise.all([...Array(num_inner_tests).keys()].map(async i => {
log(`running mux test ${i}`);
return await test_mux("https://httpbin.org/get");
})).then((vals) => { total_mux = vals.reduce((acc, x) => acc + x, 0) });
total_mux = total_mux / num_inner_tests;
log(`avg mux (${num_inner_tests}) took ${total_mux} ms or ${total_mux / 1000} s`);
total_mux_multi += total_mux;
}
total_mux_multi = total_mux_multi / num_outer_tests;
log(`total avg mux (${num_outer_tests} tests of ${num_inner_tests} reqs): ${total_mux_multi} ms or ${total_mux_multi / 1000} s`);
} else {
let resp = await epoxy_client.fetch("https://httpbin.org/get");
console.log(resp, Object.fromEntries(resp.headers));

View file

@ -18,6 +18,7 @@
const should_tls_test = params.has("rawtls_test");
const should_udp_test = params.has("udp_test");
const should_reconnect_test = params.has("reconnect_test");
const should_perf2_test = params.has("perf2_test");
const worker = new Worker("demo.js", {type:'module'});
worker.onmessage = (msg) => {
let el = document.createElement("pre");
@ -35,6 +36,7 @@
should_tls_test,
should_udp_test,
should_reconnect_test,
should_perf2_test,
]);
</script>
</head>

View file

@ -1,6 +1,6 @@
{
"name": "@mercuryworkshop/epoxy-tls",
"version": "1.4.1",
"version": "1.4.2",
"description": "A wasm library for using raw encrypted tls/ssl/https/websocket streams on the browser",
"scripts": {
"build": "./build.sh"

View file

@ -6,7 +6,7 @@ use wasm_bindgen_futures::JsFuture;
use hyper::rt::Executor;
use js_sys::ArrayBuffer;
use std::future::Future;
use wisp_mux::{CloseReason, WispError};
use wisp_mux::WispError;
#[wasm_bindgen]
extern "C" {
@ -22,11 +22,11 @@ macro_rules! debug {
($($t:tt)*) => (utils::console_debug(&format_args!($($t)*).to_string()))
}
#[allow(unused_macros)]
macro_rules! log {
($($t:tt)*) => (utils::console_log(&format_args!($($t)*).to_string()))
}
#[allow(unused_macros)]
macro_rules! error {
($($t:tt)*) => (utils::console_error(&format_args!($($t)*).to_string()))
}
@ -215,9 +215,9 @@ pub fn spawn_mux_fut(
) {
wasm_bindgen_futures::spawn_local(async move {
if let Err(e) = fut.await {
error!("epoxy: error in mux future, restarting: {:?}", e);
log!("epoxy: error in mux future, restarting: {:?}", e);
while let Err(e) = replace_mux(mux.clone(), &url).await {
error!("epoxy: failed to restart mux future: {:?}", e);
log!("epoxy: failed to restart mux future: {:?}", e);
wasmtimer::tokio::sleep(std::time::Duration::from_millis(500)).await;
}
}
@ -231,7 +231,7 @@ pub async fn replace_mux(
) -> Result<(), WispError> {
let (mux_replace, fut) = make_mux(url).await?;
let mut mux_write = mux.write().await;
mux_write.close(CloseReason::Unknown).await;
mux_write.close().await;
*mux_write = mux_replace;
drop(mux_write);
spawn_mux_fut(mux, fut, url.into());

View file

@ -7,6 +7,7 @@ edition = "2021"
bytes = "1.5.0"
clap = { version = "4.4.18", features = ["derive", "help", "usage", "color", "wrap_help", "cargo"] }
clio = { version = "0.3.5", features = ["clap-parse"] }
console-subscriber = { version = "0.2.0", optional = true }
dashmap = "5.5.3"
fastwebsockets = { version = "0.6.0", features = ["upgrade", "simdutf8", "unstable-split"] }
futures-util = { version = "0.3.30", features = ["sink"] }
@ -16,3 +17,6 @@ hyper-util = { version = "0.1.2", features = ["tokio"] }
tokio = { version = "1.5.1", features = ["rt-multi-thread", "macros"] }
tokio-util = { version = "0.7.10", features = ["codec"] }
wisp-mux = { path = "../wisp", features = ["fastwebsockets", "tokio_io"] }
[features]
tokio-console = ["tokio/tracing", "dep:console-subscriber"]

View file

@ -19,14 +19,12 @@ use tokio_util::codec::{BytesCodec, Framed};
#[cfg(unix)]
use tokio_util::either::Either;
use wisp_mux::{
ws, CloseReason, ConnectPacket, MuxEvent, MuxStream, ServerMux, StreamType, WispError,
};
use wisp_mux::{CloseReason, ConnectPacket, MuxStream, ServerMux, StreamType, WispError};
type HttpBody = http_body_util::Full<hyper::body::Bytes>;
#[derive(Parser)]
#[command(version = clap::crate_version!(), about = "Implementation of the Wisp protocol in Rust, made for epoxy.")]
#[command(version = clap::crate_version!(), about = "Server implementation of the Wisp protocol in Rust, made for epoxy.")]
struct Cli {
#[arg(long, default_value = "")]
prefix: String,
@ -96,6 +94,8 @@ async fn bind(addr: &str, unix: bool) -> Result<Listener, std::io::Error> {
#[tokio::main(flavor = "multi_thread")]
async fn main() -> Result<(), Error> {
#[cfg(feature = "tokio-console")]
console_subscriber::init();
let opt = Cli::parse();
let addr = if opt.unix_socket {
opt.bind_host
@ -137,8 +137,7 @@ async fn accept_http(
if uri.is_empty() || uri == "/" {
tokio::spawn(async move { accept_ws(fut, addr.clone()).await });
} else {
let uri = uri.strip_prefix('/').unwrap_or(uri).to_string();
} else if let Some(uri) = uri.strip_prefix('/').map(|x| x.to_string()) {
tokio::spawn(async move { accept_wsproxy(fut, uri, addr.clone()).await });
}
@ -155,10 +154,7 @@ async fn accept_http(
}
}
async fn handle_mux(
packet: ConnectPacket,
mut stream: MuxStream<impl ws::WebSocketWrite + Send + 'static>,
) -> Result<bool, WispError> {
async fn handle_mux(packet: ConnectPacket, mut stream: MuxStream) -> Result<bool, WispError> {
let uri = format!(
"{}:{}",
packet.destination_hostname, packet.destination_port
@ -190,12 +186,9 @@ async fn handle_mux(
},
event = stream.read() => {
match event {
Some(event) => match event {
MuxEvent::Send(data) => {
udp_socket.send(&data).await.map_err(|x| WispError::Other(Box::new(x)))?;
}
MuxEvent::Close(_) => return Ok(false),
},
Some(event) => {
let _ = udp_socket.send(&event).await.map_err(|x| WispError::Other(Box::new(x)))?;
}
None => break,
}
}

View file

@ -5,6 +5,7 @@ edition = "2021"
[dependencies]
bytes = "1.5.0"
console-subscriber = { version = "0.2.0", optional = true }
fastwebsockets = { version = "0.6.0", features = ["unstable-split", "upgrade"] }
futures = "0.3.30"
http-body-util = "0.1.0"
@ -14,3 +15,6 @@ tokio-native-tls = "0.3.1"
tokio-util = "0.7.10"
wisp-mux = { path = "../wisp", features = ["fastwebsockets"]}
[features]
tokio-console = ["tokio/tracing", "dep:console-subscriber"]

View file

@ -41,8 +41,10 @@ where
}
}
#[tokio::main]
#[tokio::main(flavor = "multi_thread")]
async fn main() -> Result<(), Box<dyn Error + Send + Sync>> {
#[cfg(feature = "tokio-console")]
console_subscriber::init();
let addr = std::env::args()
.nth(1)
.ok_or(StrError::new("no src addr"))?;
@ -106,7 +108,7 @@ async fn main() -> Result<(), Box<dyn Error + Send + Sync>> {
.await?
.into_io()
.into_asyncrw();
for _ in 0..10 {
for _ in 0..256 {
channel.write_all(b"hiiiiiiii").await?;
hi += 1;
println!("said hi {}", hi);

View file

@ -1,6 +1,6 @@
[package]
name = "wisp-mux"
version = "1.2.2"
version = "2.0.0"
license = "LGPL-3.0-only"
description = "A library for easily creating Wisp servers and clients."
homepage = "https://github.com/MercuryWorkshop/epoxy-tls/tree/multiplexed/wisp"

View file

@ -56,10 +56,10 @@ impl From<WebSocketError> for crate::WispError {
}
}
impl<S: AsyncRead + Unpin + Send> crate::ws::WebSocketRead for FragmentCollectorRead<S> {
impl<S: AsyncRead + Unpin> crate::ws::WebSocketRead for FragmentCollectorRead<S> {
async fn wisp_read_frame(
&mut self,
tx: &crate::ws::LockedWebSocketWrite<impl crate::ws::WebSocketWrite + Send>,
tx: &crate::ws::LockedWebSocketWrite<impl crate::ws::WebSocketWrite>,
) -> Result<crate::ws::Frame, crate::WispError> {
Ok(self
.read_frame(&mut |frame| async { tx.write_frame(frame.into()).await })
@ -68,7 +68,7 @@ impl<S: AsyncRead + Unpin + Send> crate::ws::WebSocketRead for FragmentCollector
}
}
impl<S: AsyncWrite + Unpin + Send> crate::ws::WebSocketWrite for WebSocketWrite<S> {
impl<S: AsyncWrite + Unpin> crate::ws::WebSocketWrite for WebSocketWrite<S> {
async fn wisp_write_frame(&mut self, frame: crate::ws::Frame) -> Result<(), crate::WispError> {
self.write_frame(frame.into())
.await

View file

@ -15,6 +15,7 @@ pub mod ws;
pub use crate::packet::*;
pub use crate::stream::*;
use bytes::Bytes;
use event_listener::Event;
use futures::{channel::mpsc, lock::Mutex, Future, FutureExt, StreamExt};
use std::{
@ -95,11 +96,11 @@ impl std::fmt::Display for WispError {
StreamAlreadyClosed => write!(f, "Stream already closed"),
WsFrameInvalidType => write!(f, "Invalid websocket frame type"),
WsFrameNotFinished => write!(f, "Unfinished websocket frame"),
WsImplError(err) => write!(f, "Websocket implementation error: {:?}", err),
WsImplError(err) => write!(f, "Websocket implementation error: {}", err),
WsImplSocketClosed => write!(f, "Websocket implementation error: websocket closed"),
WsImplNotSupported => write!(f, "Websocket implementation error: unsupported feature"),
Utf8Error(err) => write!(f, "UTF-8 error: {:?}", err),
Other(err) => write!(f, "Other error: {:?}", err),
Utf8Error(err) => write!(f, "UTF-8 error: {}", err),
Other(err) => write!(f, "Other error: {}", err),
}
}
}
@ -107,27 +108,28 @@ impl std::fmt::Display for WispError {
impl std::error::Error for WispError {}
struct MuxMapValue {
stream: mpsc::UnboundedSender<MuxEvent>,
stream: mpsc::UnboundedSender<Bytes>,
stream_type: StreamType,
flow_control: Arc<AtomicU32>,
flow_control_event: Arc<Event>,
is_closed: Arc<AtomicBool>,
}
struct ServerMuxInner<W>
where
W: ws::WebSocketWrite + Send + 'static,
W: ws::WebSocketWrite,
{
tx: ws::LockedWebSocketWrite<W>,
stream_map: Arc<Mutex<HashMap<u32, MuxMapValue>>>,
close_tx: mpsc::UnboundedSender<WsEvent>,
}
impl<W: ws::WebSocketWrite + Send + 'static> ServerMuxInner<W> {
impl<W: ws::WebSocketWrite> ServerMuxInner<W> {
pub async fn into_future<R>(
self,
rx: R,
close_rx: mpsc::UnboundedReceiver<WsEvent>,
muxstream_sender: mpsc::UnboundedSender<(ConnectPacket, MuxStream<W>)>,
muxstream_sender: mpsc::UnboundedSender<(ConnectPacket, MuxStream)>,
buffer_size: u32,
) -> Result<(), WispError>
where
@ -137,10 +139,10 @@ impl<W: ws::WebSocketWrite + Send + 'static> ServerMuxInner<W> {
x = self.server_bg_loop(close_rx).fuse() => x,
x = self.server_msg_loop(rx, muxstream_sender, buffer_size).fuse() => x
};
self.stream_map.lock().await.drain().for_each(|x| {
let _ =
x.1.stream
.unbounded_send(MuxEvent::Close(ClosePacket::new(CloseReason::Unknown)));
self.stream_map.lock().await.drain().for_each(|mut x| {
x.1.is_closed.store(true, Ordering::Release);
x.1.stream.disconnect();
x.1.stream.close_channel();
});
ret
}
@ -151,13 +153,26 @@ impl<W: ws::WebSocketWrite + Send + 'static> ServerMuxInner<W> {
) -> Result<(), WispError> {
while let Some(msg) = close_rx.next().await {
match msg {
WsEvent::Close(stream_id, reason, channel) => {
if self.stream_map.lock().await.remove(&stream_id).is_some() {
let _ = channel.send(
self.tx
.write_frame(Packet::new_close(stream_id, reason).into())
.await,
);
WsEvent::SendPacket(packet, channel) => {
if self
.stream_map
.lock()
.await
.get(&packet.stream_id)
.is_some()
{
let _ = channel.send(self.tx.write_frame(packet.into()).await);
} else {
let _ = channel.send(Err(WispError::InvalidStreamId));
}
}
WsEvent::Close(packet, channel) => {
if let Some(mut stream) =
self.stream_map.lock().await.remove(&packet.stream_id)
{
stream.stream.disconnect();
stream.stream.close_channel();
let _ = channel.send(self.tx.write_frame(packet.into()).await);
} else {
let _ = channel.send(Err(WispError::InvalidStreamId));
}
@ -171,12 +186,14 @@ impl<W: ws::WebSocketWrite + Send + 'static> ServerMuxInner<W> {
async fn server_msg_loop<R>(
&self,
mut rx: R,
muxstream_sender: mpsc::UnboundedSender<(ConnectPacket, MuxStream<W>)>,
muxstream_sender: mpsc::UnboundedSender<(ConnectPacket, MuxStream)>,
buffer_size: u32,
) -> Result<(), WispError>
where
R: ws::WebSocketRead,
{
// will send continues once flow_control is at 10% of max
let target_buffer_size = buffer_size * 90 / 100;
self.tx
.write_frame(Packet::new_continue(0, buffer_size).into())
.await?;
@ -195,6 +212,7 @@ impl<W: ws::WebSocketWrite + Send + 'static> ServerMuxInner<W> {
let stream_type = inner_packet.stream_type;
let flow_control: Arc<AtomicU32> = AtomicU32::new(buffer_size).into();
let flow_control_event: Arc<Event> = Event::new().into();
let is_closed: Arc<AtomicBool> = AtomicBool::new(false).into();
self.stream_map.lock().await.insert(
packet.stream_id,
@ -203,6 +221,7 @@ impl<W: ws::WebSocketWrite + Send + 'static> ServerMuxInner<W> {
stream_type,
flow_control: flow_control.clone(),
flow_control_event: flow_control_event.clone(),
is_closed: is_closed.clone(),
},
);
muxstream_sender
@ -213,18 +232,18 @@ impl<W: ws::WebSocketWrite + Send + 'static> ServerMuxInner<W> {
Role::Server,
stream_type,
ch_rx,
self.tx.clone(),
self.close_tx.clone(),
AtomicBool::new(false).into(),
is_closed,
flow_control,
flow_control_event,
target_buffer_size,
),
))
.map_err(|x| WispError::Other(Box::new(x)))?;
}
Data(data) => {
if let Some(stream) = self.stream_map.lock().await.get(&packet.stream_id) {
let _ = stream.stream.unbounded_send(MuxEvent::Send(data));
let _ = stream.stream.unbounded_send(data);
if stream.stream_type == StreamType::Tcp {
stream.flow_control.store(
stream
@ -237,11 +256,14 @@ impl<W: ws::WebSocketWrite + Send + 'static> ServerMuxInner<W> {
}
}
Continue(_) => unreachable!(),
Close(inner_packet) => {
if let Some(stream) = self.stream_map.lock().await.get(&packet.stream_id) {
let _ = stream.stream.unbounded_send(MuxEvent::Close(inner_packet));
Close(_) => {
if let Some(mut stream) =
self.stream_map.lock().await.remove(&packet.stream_id)
{
stream.is_closed.store(true, Ordering::Release);
stream.stream.disconnect();
stream.stream.close_channel();
}
self.stream_map.lock().await.remove(&packet.stream_id);
}
}
}
@ -267,18 +289,15 @@ impl<W: ws::WebSocketWrite + Send + 'static> ServerMuxInner<W> {
/// });
/// }
/// ```
pub struct ServerMux<W>
where
W: ws::WebSocketWrite + Send + 'static,
{
pub struct ServerMux {
stream_map: Arc<Mutex<HashMap<u32, MuxMapValue>>>,
close_tx: mpsc::UnboundedSender<WsEvent>,
muxstream_recv: mpsc::UnboundedReceiver<(ConnectPacket, MuxStream<W>)>,
muxstream_recv: mpsc::UnboundedReceiver<(ConnectPacket, MuxStream)>,
}
impl<W: ws::WebSocketWrite + Send + 'static> ServerMux<W> {
impl ServerMux {
/// Create a new server-side multiplexor.
pub fn new<R>(
pub fn new<R, W: ws::WebSocketWrite>(
read: R,
write: W,
buffer_size: u32,
@ -287,7 +306,7 @@ impl<W: ws::WebSocketWrite + Send + 'static> ServerMux<W> {
R: ws::WebSocketRead,
{
let (close_tx, close_rx) = mpsc::unbounded::<WsEvent>();
let (tx, rx) = mpsc::unbounded::<(ConnectPacket, MuxStream<W>)>();
let (tx, rx) = mpsc::unbounded::<(ConnectPacket, MuxStream)>();
let write = ws::LockedWebSocketWrite::new(write);
let map = Arc::new(Mutex::new(HashMap::new()));
(
@ -306,7 +325,7 @@ impl<W: ws::WebSocketWrite + Send + 'static> ServerMux<W> {
}
/// Wait for a stream to be created.
pub async fn server_new_stream(&mut self) -> Option<(ConnectPacket, MuxStream<W>)> {
pub async fn server_new_stream(&mut self) -> Option<(ConnectPacket, MuxStream)> {
self.muxstream_recv.next().await
}
@ -314,11 +333,11 @@ impl<W: ws::WebSocketWrite + Send + 'static> ServerMux<W> {
///
/// Also terminates the multiplexor future. Waiting for a new stream will never succeed after
/// this function is called.
pub async fn close(&self, reason: CloseReason) {
self.stream_map.lock().await.drain().for_each(|x| {
let _ =
x.1.stream
.unbounded_send(MuxEvent::Close(ClosePacket::new(reason)));
pub async fn close(&self) {
self.stream_map.lock().await.drain().for_each(|mut x| {
x.1.is_closed.store(true, Ordering::Release);
x.1.stream.disconnect();
x.1.stream.close_channel();
});
let _ = self.close_tx.unbounded_send(WsEvent::EndFut);
}
@ -332,7 +351,7 @@ where
stream_map: Arc<Mutex<HashMap<u32, MuxMapValue>>>,
}
impl<W: ws::WebSocketWrite + Send> ClientMuxInner<W> {
impl<W: ws::WebSocketWrite> ClientMuxInner<W> {
pub(crate) async fn into_future<R>(
self,
rx: R,
@ -341,10 +360,16 @@ impl<W: ws::WebSocketWrite + Send> ClientMuxInner<W> {
where
R: ws::WebSocketRead,
{
futures::select! {
let ret = futures::select! {
x = self.client_bg_loop(close_rx).fuse() => x,
x = self.client_loop(rx).fuse() => x
}
};
self.stream_map.lock().await.drain().for_each(|mut x| {
x.1.is_closed.store(true, Ordering::Release);
x.1.stream.disconnect();
x.1.stream.close_channel();
});
ret
}
async fn client_bg_loop(
@ -353,13 +378,26 @@ impl<W: ws::WebSocketWrite + Send> ClientMuxInner<W> {
) -> Result<(), WispError> {
while let Some(msg) = close_rx.next().await {
match msg {
WsEvent::Close(stream_id, reason, channel) => {
if self.stream_map.lock().await.remove(&stream_id).is_some() {
let _ = channel.send(
self.tx
.write_frame(Packet::new_close(stream_id, reason).into())
.await,
);
WsEvent::SendPacket(packet, channel) => {
if self
.stream_map
.lock()
.await
.get(&packet.stream_id)
.is_some()
{
let _ = channel.send(self.tx.write_frame(packet.into()).await);
} else {
let _ = channel.send(Err(WispError::InvalidStreamId));
}
}
WsEvent::Close(packet, channel) => {
if let Some(mut stream) =
self.stream_map.lock().await.remove(&packet.stream_id)
{
stream.stream.disconnect();
stream.stream.close_channel();
let _ = channel.send(self.tx.write_frame(packet.into()).await);
} else {
let _ = channel.send(Err(WispError::InvalidStreamId));
}
@ -386,7 +424,7 @@ impl<W: ws::WebSocketWrite + Send> ClientMuxInner<W> {
Connect(_) => unreachable!(),
Data(data) => {
if let Some(stream) = self.stream_map.lock().await.get(&packet.stream_id) {
let _ = stream.stream.unbounded_send(MuxEvent::Send(data));
let _ = stream.stream.unbounded_send(data);
}
}
Continue(inner_packet) => {
@ -399,11 +437,14 @@ impl<W: ws::WebSocketWrite + Send> ClientMuxInner<W> {
}
}
}
Close(inner_packet) => {
if let Some(stream) = self.stream_map.lock().await.get(&packet.stream_id) {
let _ = stream.stream.unbounded_send(MuxEvent::Close(inner_packet));
Close(_) => {
if let Some(mut stream) =
self.stream_map.lock().await.remove(&packet.stream_id)
{
stream.is_closed.store(true, Ordering::Release);
stream.stream.disconnect();
stream.stream.close_channel();
}
self.stream_map.lock().await.remove(&packet.stream_id);
}
}
}
@ -433,9 +474,10 @@ where
next_free_stream_id: AtomicU32,
close_tx: mpsc::UnboundedSender<WsEvent>,
buf_size: u32,
target_buf_size: u32,
}
impl<W: ws::WebSocketWrite + Send + 'static> ClientMux<W> {
impl<W: ws::WebSocketWrite> ClientMux<W> {
/// Create a new client side multiplexor.
pub async fn new<R>(
mut read: R,
@ -459,6 +501,8 @@ impl<W: ws::WebSocketWrite + Send + 'static> ClientMux<W> {
next_free_stream_id: AtomicU32::new(1),
close_tx: tx,
buf_size: packet.buffer_remaining,
// server-only
target_buf_size: 0,
},
ClientMuxInner {
tx: write.clone(),
@ -477,39 +521,46 @@ impl<W: ws::WebSocketWrite + Send + 'static> ClientMux<W> {
stream_type: StreamType,
host: String,
port: u16,
) -> Result<MuxStream<W>, WispError> {
) -> Result<MuxStream, WispError> {
let (ch_tx, ch_rx) = mpsc::unbounded();
let evt: Arc<Event> = Event::new().into();
let flow_control: Arc<AtomicU32> = AtomicU32::new(self.buf_size).into();
let stream_id = self.next_free_stream_id.load(Ordering::Acquire);
let next_stream_id = stream_id
.checked_add(1)
.ok_or(WispError::MaxStreamCountReached)?;
let flow_control_event: Arc<Event> = Event::new().into();
let flow_control: Arc<AtomicU32> = AtomicU32::new(self.buf_size).into();
let is_closed: Arc<AtomicBool> = AtomicBool::new(false).into();
self.tx
.write_frame(Packet::new_connect(stream_id, stream_type, port, host).into())
.await?;
self.next_free_stream_id.store(
stream_id
.checked_add(1)
.ok_or(WispError::MaxStreamCountReached)?,
Ordering::Release,
);
self.next_free_stream_id
.store(next_stream_id, Ordering::Release);
self.stream_map.lock().await.insert(
stream_id,
MuxMapValue {
stream: ch_tx,
stream_type,
flow_control: flow_control.clone(),
flow_control_event: evt.clone(),
flow_control_event: flow_control_event.clone(),
is_closed: is_closed.clone(),
},
);
Ok(MuxStream::new(
stream_id,
Role::Client,
stream_type,
ch_rx,
self.tx.clone(),
self.close_tx.clone(),
AtomicBool::new(false).into(),
is_closed,
flow_control,
evt,
flow_control_event,
self.target_buf_size,
))
}
@ -517,11 +568,11 @@ impl<W: ws::WebSocketWrite + Send + 'static> ClientMux<W> {
///
/// Also terminates the multiplexor future. Creating a stream is UB after calling this
/// function.
pub async fn close(&self, reason: CloseReason) {
self.stream_map.lock().await.drain().for_each(|x| {
let _ =
x.1.stream
.unbounded_send(MuxEvent::Close(ClosePacket::new(reason)));
pub async fn close(&self) {
self.stream_map.lock().await.drain().for_each(|mut x| {
x.1.is_closed.store(true, Ordering::Release);
x.1.stream.disconnect();
x.1.stream.close_channel();
});
let _ = self.close_tx.unbounded_send(WsEvent::EndFut);
}

View file

@ -1,4 +1,4 @@
use crate::{sink_unfold, ws, ClosePacket, CloseReason, Packet, Role, StreamType, WispError};
use crate::{sink_unfold, CloseReason, Packet, Role, StreamType, WispError};
use async_io_stream::IoStream;
use bytes::Bytes;
@ -18,91 +18,75 @@ use std::{
},
};
/// Multiplexor event recieved from a Wisp stream.
pub enum MuxEvent {
/// The other side has sent data.
Send(Bytes),
/// The other side has closed.
Close(ClosePacket),
}
pub(crate) enum WsEvent {
Close(u32, CloseReason, oneshot::Sender<Result<(), WispError>>),
SendPacket(Packet, oneshot::Sender<Result<(), WispError>>),
Close(Packet, oneshot::Sender<Result<(), WispError>>),
EndFut,
}
/// Read side of a multiplexor stream.
pub struct MuxStreamRead<W>
where
W: ws::WebSocketWrite,
{
pub struct MuxStreamRead {
/// ID of the stream.
pub stream_id: u32,
/// Type of the stream.
pub stream_type: StreamType,
role: Role,
tx: ws::LockedWebSocketWrite<W>,
rx: mpsc::UnboundedReceiver<MuxEvent>,
tx: mpsc::UnboundedSender<WsEvent>,
rx: mpsc::UnboundedReceiver<Bytes>,
is_closed: Arc<AtomicBool>,
flow_control: Arc<AtomicU32>,
flow_control_read: AtomicU32,
target_flow_control: u32,
}
impl<W: ws::WebSocketWrite + Send + 'static> MuxStreamRead<W> {
impl MuxStreamRead {
/// Read an event from the stream.
pub async fn read(&mut self) -> Option<MuxEvent> {
pub async fn read(&mut self) -> Option<Bytes> {
if self.is_closed.load(Ordering::Acquire) {
return None;
}
match self.rx.next().await? {
MuxEvent::Send(bytes) => {
if self.role == Role::Server && self.stream_type == StreamType::Tcp {
let old_val = self.flow_control.fetch_add(1, Ordering::AcqRel);
self.tx
.write_frame(Packet::new_continue(self.stream_id, old_val + 1).into())
.await
.ok()?;
}
Some(MuxEvent::Send(bytes))
}
MuxEvent::Close(packet) => {
self.is_closed.store(true, Ordering::Release);
Some(MuxEvent::Close(packet))
let bytes = self.rx.next().await?;
if self.role == Role::Server && self.stream_type == StreamType::Tcp {
let val = self.flow_control_read.fetch_add(1, Ordering::AcqRel) + 1;
if val > self.target_flow_control {
let (tx, rx) = oneshot::channel::<Result<(), WispError>>();
self.tx
.unbounded_send(WsEvent::SendPacket(
Packet::new_continue(
self.stream_id,
self.flow_control.fetch_add(val, Ordering::AcqRel) + val,
),
tx,
))
.ok()?;
rx.await.ok()?.ok()?;
self.flow_control_read.store(0, Ordering::Release);
}
}
Some(bytes)
}
pub(crate) fn into_stream(self) -> Pin<Box<dyn Stream<Item = Bytes> + Send>> {
Box::pin(stream::unfold(self, |mut rx| async move {
let evt = rx.read().await?;
Some((
match evt {
MuxEvent::Send(bytes) => bytes,
MuxEvent::Close(_) => return None,
},
rx,
))
Some((rx.read().await?, rx))
}))
}
}
/// Write side of a multiplexor stream.
pub struct MuxStreamWrite<W>
where
W: ws::WebSocketWrite,
{
pub struct MuxStreamWrite {
/// ID of the stream.
pub stream_id: u32,
/// Type of the stream.
pub stream_type: StreamType,
role: Role,
tx: ws::LockedWebSocketWrite<W>,
close_channel: mpsc::UnboundedSender<WsEvent>,
tx: mpsc::UnboundedSender<WsEvent>,
is_closed: Arc<AtomicBool>,
continue_recieved: Arc<Event>,
flow_control: Arc<AtomicU32>,
}
impl<W: ws::WebSocketWrite + Send + 'static> MuxStreamWrite<W> {
impl MuxStreamWrite {
/// Write data to the stream.
pub async fn write(&self, data: Bytes) -> Result<(), WispError> {
if self.is_closed.load(Ordering::Acquire) {
@ -114,9 +98,14 @@ impl<W: ws::WebSocketWrite + Send + 'static> MuxStreamWrite<W> {
{
self.continue_recieved.listen().await;
}
let (tx, rx) = oneshot::channel::<Result<(), WispError>>();
self.tx
.write_frame(Packet::new_data(self.stream_id, data).into())
.await?;
.unbounded_send(WsEvent::SendPacket(
Packet::new_data(self.stream_id, data),
tx,
))
.map_err(|x| WispError::Other(Box::new(x)))?;
rx.await.map_err(|x| WispError::Other(Box::new(x)))??;
if self.role == Role::Client && self.stream_type == StreamType::Tcp {
self.flow_control.store(
self.flow_control.load(Ordering::Acquire).saturating_sub(1),
@ -140,7 +129,7 @@ impl<W: ws::WebSocketWrite + Send + 'static> MuxStreamWrite<W> {
pub fn get_close_handle(&self) -> MuxStreamCloser {
MuxStreamCloser {
stream_id: self.stream_id,
close_channel: self.close_channel.clone(),
close_channel: self.tx.clone(),
is_closed: self.is_closed.clone(),
}
}
@ -150,13 +139,17 @@ impl<W: ws::WebSocketWrite + Send + 'static> MuxStreamWrite<W> {
if self.is_closed.load(Ordering::Acquire) {
return Err(WispError::StreamAlreadyClosed);
}
self.is_closed.store(true, Ordering::Release);
let (tx, rx) = oneshot::channel::<Result<(), WispError>>();
self.close_channel
.unbounded_send(WsEvent::Close(self.stream_id, reason, tx))
self.tx
.unbounded_send(WsEvent::Close(
Packet::new_close(self.stream_id, reason),
tx,
))
.map_err(|x| WispError::Other(Box::new(x)))?;
rx.await.map_err(|x| WispError::Other(Box::new(x)))??;
self.is_closed.store(true, Ordering::Release);
Ok(())
}
@ -173,40 +166,36 @@ impl<W: ws::WebSocketWrite + Send + 'static> MuxStreamWrite<W> {
}
}
impl<W: ws::WebSocketWrite> Drop for MuxStreamWrite<W> {
impl Drop for MuxStreamWrite {
fn drop(&mut self) {
let (tx, _) = oneshot::channel::<Result<(), WispError>>();
let _ = self.close_channel.unbounded_send(WsEvent::Close(
self.stream_id,
CloseReason::Unknown,
let _ = self.tx.unbounded_send(WsEvent::Close(
Packet::new_close(self.stream_id, CloseReason::Unknown),
tx,
));
}
}
/// Multiplexor stream.
pub struct MuxStream<W>
where
W: ws::WebSocketWrite,
{
pub struct MuxStream {
/// ID of the stream.
pub stream_id: u32,
rx: MuxStreamRead<W>,
tx: MuxStreamWrite<W>,
rx: MuxStreamRead,
tx: MuxStreamWrite,
}
impl<W: ws::WebSocketWrite + Send + 'static> MuxStream<W> {
impl MuxStream {
#[allow(clippy::too_many_arguments)]
pub(crate) fn new(
stream_id: u32,
role: Role,
stream_type: StreamType,
rx: mpsc::UnboundedReceiver<MuxEvent>,
tx: ws::LockedWebSocketWrite<W>,
close_channel: mpsc::UnboundedSender<WsEvent>,
rx: mpsc::UnboundedReceiver<Bytes>,
tx: mpsc::UnboundedSender<WsEvent>,
is_closed: Arc<AtomicBool>,
flow_control: Arc<AtomicU32>,
continue_recieved: Arc<Event>,
target_flow_control: u32,
) -> Self {
Self {
stream_id,
@ -218,13 +207,14 @@ impl<W: ws::WebSocketWrite + Send + 'static> MuxStream<W> {
rx,
is_closed: is_closed.clone(),
flow_control: flow_control.clone(),
flow_control_read: AtomicU32::new(0),
target_flow_control,
},
tx: MuxStreamWrite {
stream_id,
stream_type,
role,
tx,
close_channel,
is_closed: is_closed.clone(),
flow_control: flow_control.clone(),
continue_recieved: continue_recieved.clone(),
@ -233,7 +223,7 @@ impl<W: ws::WebSocketWrite + Send + 'static> MuxStream<W> {
}
/// Read an event from the stream.
pub async fn read(&mut self) -> Option<MuxEvent> {
pub async fn read(&mut self) -> Option<Bytes> {
self.rx.read().await
}
@ -263,7 +253,7 @@ impl<W: ws::WebSocketWrite + Send + 'static> MuxStream<W> {
}
/// Split the stream into read and write parts, consuming it.
pub fn into_split(self) -> (MuxStreamRead<W>, MuxStreamWrite<W>) {
pub fn into_split(self) -> (MuxStreamRead, MuxStreamWrite) {
(self.rx, self.tx)
}
@ -291,25 +281,34 @@ impl MuxStreamCloser {
if self.is_closed.load(Ordering::Acquire) {
return Err(WispError::StreamAlreadyClosed);
}
self.is_closed.store(true, Ordering::Release);
let (tx, rx) = oneshot::channel::<Result<(), WispError>>();
self.close_channel
.unbounded_send(WsEvent::Close(self.stream_id, reason, tx))
.unbounded_send(WsEvent::Close(
Packet::new_close(self.stream_id, reason),
tx,
))
.map_err(|x| WispError::Other(Box::new(x)))?;
rx.await.map_err(|x| WispError::Other(Box::new(x)))??;
self.is_closed.store(true, Ordering::Release);
Ok(())
}
/// Close the stream. This function does not check if it was actually closed.
pub(crate) fn close_sync(&self, reason: CloseReason) -> Result<(), WispError> {
if self.is_closed.load(Ordering::Acquire) {
return Err(WispError::StreamAlreadyClosed);
}
self.is_closed.store(true, Ordering::Release);
let (tx, _) = oneshot::channel::<Result<(), WispError>>();
self.close_channel
.unbounded_send(WsEvent::Close(self.stream_id, reason, tx))
.unbounded_send(WsEvent::Close(
Packet::new_close(self.stream_id, reason),
tx,
))
.map_err(|x| WispError::Other(Box::new(x)))?;
self.is_closed.store(true, Ordering::Release);
Ok(())
}
}

View file

@ -1,10 +1,9 @@
//! Abstraction over WebSocket implementations.
//!
//! Use the [`fastwebsockets`] and [`ws_stream_wasm`] implementations of these traits as an example
//! for implementing them for other WebSocket implementations.
//! Use the [`fastwebsockets`] implementation of these traits as an example for implementing them
//! for other WebSocket implementations.
//!
//! [`fastwebsockets`]: https://github.com/MercuryWorkshop/epoxy-tls/blob/multiplexed/wisp/src/fastwebsockets.rs
//! [`ws_stream_wasm`]: https://github.com/MercuryWorkshop/epoxy-tls/blob/multiplexed/wisp/src/ws_stream_wasm.rs
use bytes::Bytes;
use futures::lock::Mutex;
use std::sync::Arc;
@ -68,8 +67,8 @@ pub trait WebSocketRead {
/// Read a frame from the socket.
fn wisp_read_frame(
&mut self,
tx: &crate::ws::LockedWebSocketWrite<impl crate::ws::WebSocketWrite + Send>,
) -> impl std::future::Future<Output = Result<Frame, crate::WispError>> + Send;
tx: &crate::ws::LockedWebSocketWrite<impl crate::ws::WebSocketWrite>,
) -> impl std::future::Future<Output = Result<Frame, crate::WispError>>;
}
/// Generic WebSocket write trait.
@ -78,13 +77,13 @@ pub trait WebSocketWrite {
fn wisp_write_frame(
&mut self,
frame: Frame,
) -> impl std::future::Future<Output = Result<(), crate::WispError>> + Send;
) -> impl std::future::Future<Output = Result<(), crate::WispError>>;
}
/// Locked WebSocket that can be shared between threads.
pub struct LockedWebSocketWrite<S>(Arc<Mutex<S>>);
impl<S: WebSocketWrite + Send> LockedWebSocketWrite<S> {
impl<S: WebSocketWrite> LockedWebSocketWrite<S> {
/// Create a new locked websocket.
pub fn new(ws: S) -> Self {
Self(Arc::new(Mutex::new(ws)))