remove the mutex<hashmap> in wisp_mux, other improvements

This commit is contained in:
Toshit Chawda 2024-03-26 18:55:54 -07:00
parent ff2a1ad269
commit 7001ee8fa5
No known key found for this signature in database
GPG key ID: 91480ED99E2B3D9D
16 changed files with 346 additions and 309 deletions

22
Cargo.lock generated
View file

@ -153,6 +153,12 @@ dependencies = [
"tokio",
]
[[package]]
name = "atomic-counter"
version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "62f447d68cfa5a9ab0c1c862a703da2a65b5ed1b7ce1153c9eb0169506d56019"
[[package]]
name = "autocfg"
version = "1.1.0"
@ -516,7 +522,7 @@ checksum = "11157ac094ffbdde99aa67b23417ebdd801842852b500e395a45a9c0aac03e4a"
[[package]]
name = "epoxy-client"
version = "1.4.2"
version = "1.5.0"
dependencies = [
"async-compression",
"async_io_stream",
@ -1727,18 +1733,29 @@ checksum = "f27f6278552951f1f2b8cf9da965d10969b2efdea95a6ec47987ab46edfe263a"
name = "simple-wisp-client"
version = "1.0.0"
dependencies = [
"atomic-counter",
"bytes",
"console-subscriber",
"fastwebsockets 0.7.1",
"futures",
"http-body-util",
"hyper 1.2.0",
"simple_moving_average",
"tokio",
"tokio-native-tls",
"tokio-util",
"wisp-mux",
]
[[package]]
name = "simple_moving_average"
version = "1.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3a4b144ad185430cd033299e2c93e465d5a7e65fbb858593dc57181fa13cd310"
dependencies = [
"num-traits",
]
[[package]]
name = "slab"
version = "0.4.9"
@ -2500,10 +2517,11 @@ checksum = "32b752e52a2da0ddfbdbcc6fceadfeede4c939ed16d13e648833a61dfb611ed8"
[[package]]
name = "wisp-mux"
version = "2.0.2"
version = "3.0.0"
dependencies = [
"async_io_stream",
"bytes",
"dashmap",
"event-listener",
"fastwebsockets 0.7.1",
"futures",

View file

@ -8,8 +8,10 @@ rustls-pki-types = { git = "https://github.com/r58Playz/rustls-pki-types" }
[profile.release]
lto = true
opt-level = 'z'
strip = true
debug = true
panic = "abort"
codegen-units = 1
opt-level = 3
[profile.release.package.epoxy-client]
opt-level = 'z'

View file

@ -1,6 +1,6 @@
[package]
name = "epoxy-client"
version = "1.4.2"
version = "1.5.0"
edition = "2021"
license = "LGPL-3.0-only"

View file

@ -1,6 +1,6 @@
{
"name": "@mercuryworkshop/epoxy-tls",
"version": "1.4.2",
"version": "1.5.0",
"description": "A wasm library for using raw encrypted tls/ssl/https/websocket streams on the browser",
"scripts": {
"build": "./build.sh"

View file

@ -47,8 +47,8 @@ enum EpxCompression {
Gzip,
}
type EpxIoTlsStream = TlsStream<IoStream<MuxStreamIo, Vec<u8>>>;
type EpxIoUnencryptedStream = IoStream<MuxStreamIo, Vec<u8>>;
type EpxIoTlsStream = TlsStream<EpxIoUnencryptedStream>;
type EpxIoStream = Either<EpxIoTlsStream, EpxIoUnencryptedStream>;
#[wasm_bindgen(start)]

View file

@ -231,7 +231,7 @@ pub async fn replace_mux(
) -> Result<(), WispError> {
let (mux_replace, fut) = make_mux(url).await?;
let mut mux_write = mux.write().await;
mux_write.close().await;
mux_write.close().await?;
*mux_write = mux_replace;
drop(mux_write);
spawn_mux_fut(mux, fut, url.into());

View file

@ -56,7 +56,7 @@ impl Stream for IncomingBody {
pub struct ServiceWrapper(pub Arc<RwLock<ClientMux<WebSocketWrapper>>>, pub String);
impl tower_service::Service<hyper::Uri> for ServiceWrapper {
type Response = TokioIo<IoStream<MuxStreamIo, Vec<u8>>>;
type Response = TokioIo<EpxIoUnencryptedStream>;
type Error = WispError;
type Future = impl Future<Output = Result<Self::Response, Self::Error>>;

View file

@ -239,7 +239,7 @@ async fn accept_ws(
println!("{:?}: connected", addr);
let (mut mux, fut) = ServerMux::new(rx, tx, 128);
let (mut mux, fut) = ServerMux::new(rx, tx, u32::MAX);
tokio::spawn(async move {
if let Err(e) = fut.await {
@ -247,7 +247,7 @@ async fn accept_ws(
}
});
while let Some((packet, stream)) = mux.server_new_stream().await {
while let Some((packet, mut stream)) = mux.server_new_stream().await {
tokio::spawn(async move {
if block_local {
match lookup_host(format!(
@ -272,8 +272,8 @@ async fn accept_ws(
}
}
}
let close_err = stream.get_close_handle();
let close_ok = stream.get_close_handle();
let mut close_err = stream.get_close_handle();
let mut close_ok = stream.get_close_handle();
let _ = handle_mux(packet, stream)
.or_else(|err| async move {
let _ = close_err.close(CloseReason::Unexpected).await;

View file

@ -4,12 +4,14 @@ version = "1.0.0"
edition = "2021"
[dependencies]
atomic-counter = "1.0.1"
bytes = "1.5.0"
console-subscriber = { version = "0.2.0", optional = true }
fastwebsockets = { version = "0.7.1", features = ["unstable-split", "upgrade"] }
futures = "0.3.30"
http-body-util = "0.1.0"
hyper = { version = "1.1.0", features = ["http1", "client"] }
simple_moving_average = "1.0.2"
tokio = { version = "1.36.0", features = ["full"] }
tokio-native-tls = "0.3.1"
tokio-util = "0.7.10"

View file

@ -1,16 +1,25 @@
use atomic_counter::{AtomicCounter, RelaxedCounter};
use bytes::Bytes;
use fastwebsockets::{handshake, FragmentCollectorRead};
use futures::io::AsyncWriteExt;
use futures::future::select_all;
use http_body_util::Empty;
use hyper::{
header::{CONNECTION, UPGRADE},
Request,
};
use std::{error::Error, future::Future};
use tokio::net::TcpStream;
use simple_moving_average::{SingleSumSMA, SMA};
use std::{
error::Error,
future::Future,
io::{stdout, IsTerminal, Write},
sync::Arc,
time::Duration,
usize,
};
use tokio::{net::TcpStream, time::interval};
use tokio_native_tls::{native_tls, TlsConnector};
use wisp_mux::{ClientMux, StreamType};
use tokio_util::either::Either;
use wisp_mux::{ClientMux, StreamType, WispError};
#[derive(Debug)]
struct StrError(String);
@ -70,6 +79,18 @@ async fn main() -> Result<(), Box<dyn Error + Send + Sync>> {
.nth(6)
.ok_or(StrError::new("no should tls"))?
.parse()?;
let thread_cnt: usize = std::env::args().nth(7).unwrap_or("10".into()).parse()?;
println!(
"connecting to {}://{}:{}{} and sending &[0; 1024] to {}:{} with threads {}",
if should_tls { "wss" } else { "ws" },
addr,
addr_port,
addr_path,
addr_dest,
addr_dest_port,
thread_cnt
);
let socket = TcpStream::connect(format!("{}:{}", &addr, addr_port)).await?;
let socket = if should_tls {
@ -98,23 +119,59 @@ async fn main() -> Result<(), Box<dyn Error + Send + Sync>> {
let rx = FragmentCollectorRead::new(rx);
let (mux, fut) = ClientMux::new(rx, tx).await?;
let mut threads = Vec::with_capacity(thread_cnt + 1);
tokio::task::spawn(async move { println!("err: {:?}", fut.await); });
threads.push(tokio::spawn(fut));
let mut hi: u64 = 0;
loop {
let payload = Bytes::from_static(&[0; 1024]);
let cnt = Arc::new(RelaxedCounter::new(0));
for _ in 0..thread_cnt {
let mut channel = mux
.client_new_stream(StreamType::Tcp, addr_dest.clone(), addr_dest_port)
.await?
.into_io()
.into_asyncrw();
for _ in 0..256 {
channel.write_all(b"hiiiiiiii").await?;
hi += 1;
println!("said hi {}", hi);
.await?;
let cnt = cnt.clone();
let payload = payload.clone();
threads.push(tokio::spawn(async move {
loop {
channel.write(payload.clone()).await?;
channel.read().await;
cnt.inc();
}
#[allow(unreachable_code)]
Ok::<(), WispError>(())
}));
}
#[allow(unreachable_code)]
threads.push(tokio::spawn(async move {
let mut interval = interval(Duration::from_millis(100));
let mut avg: SingleSumSMA<usize, usize, 100> = SingleSumSMA::new();
let mut last_time = 0;
let is_term = stdout().is_terminal();
loop {
interval.tick().await;
let now = cnt.get();
let stat = format!(
"sent &[0; 1024] cnt: {:?}, +{:?}, moving average (100): {:?}",
now,
now - last_time,
avg.get_average()
);
if is_term {
print!("\x1b[2K{}\r", stat);
} else {
println!("{}", stat);
}
stdout().flush().unwrap();
avg.add_sample(now - last_time);
last_time = now;
}
}));
let out = select_all(threads.into_iter()).await;
println!("\n\nout: {:?}", out.0);
Ok(())
}

View file

@ -1,6 +1,6 @@
[package]
name = "wisp-mux"
version = "2.0.2"
version = "3.0.0"
license = "LGPL-3.0-only"
description = "A library for easily creating Wisp servers and clients."
homepage = "https://github.com/MercuryWorkshop/epoxy-tls/tree/multiplexed/wisp"
@ -11,6 +11,7 @@ edition = "2021"
[dependencies]
async_io_stream = "0.3.3"
bytes = "1.5.0"
dashmap = { version = "5.5.3", features = ["inline"] }
event-listener = "5.0.0"
fastwebsockets = { version = "0.7.1", features = ["unstable-split"], optional = true }
futures = "0.3.30"

View file

@ -8,7 +8,9 @@ impl From<OpCode> for crate::ws::OpCode {
fn from(opcode: OpCode) -> Self {
use OpCode::*;
match opcode {
Continuation => unreachable!("continuation should never be recieved when using a fragmentcollector"),
Continuation => {
unreachable!("continuation should never be recieved when using a fragmentcollector")
}
Text => Self::Text,
Binary => Self::Binary,
Close => Self::Close,
@ -70,8 +72,6 @@ impl<S: AsyncRead + Unpin> crate::ws::WebSocketRead for FragmentCollectorRead<S>
impl<S: AsyncWrite + Unpin> crate::ws::WebSocketWrite for WebSocketWrite<S> {
async fn wisp_write_frame(&mut self, frame: crate::ws::Frame) -> Result<(), crate::WispError> {
self.write_frame(frame.into())
.await
.map_err(|e| e.into())
self.write_frame(frame.into()).await.map_err(|e| e.into())
}
}

View file

@ -16,14 +16,13 @@ pub use crate::packet::*;
pub use crate::stream::*;
use bytes::Bytes;
use dashmap::DashMap;
use event_listener::Event;
use futures::{channel::mpsc, lock::Mutex, Future, FutureExt, StreamExt};
use std::{
collections::HashMap,
sync::{
use futures::SinkExt;
use futures::{channel::mpsc, Future, FutureExt, StreamExt};
use std::sync::{
atomic::{AtomicBool, AtomicU32, Ordering},
Arc,
},
};
/// The role of the multiplexor.
@ -72,6 +71,8 @@ pub enum WispError {
Utf8Error(std::str::Utf8Error),
/// Other error.
Other(Box<dyn std::error::Error + Sync + Send>),
/// Failed to send message to multiplexor task.
MuxMessageFailedToSend,
}
impl From<std::str::Utf8Error> for WispError {
@ -82,25 +83,29 @@ impl From<std::str::Utf8Error> for WispError {
impl std::fmt::Display for WispError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
use WispError::*;
match self {
PacketTooSmall => write!(f, "Packet too small"),
InvalidPacketType => write!(f, "Invalid packet type"),
InvalidStreamType => write!(f, "Invalid stream type"),
InvalidStreamId => write!(f, "Invalid stream id"),
InvalidCloseReason => write!(f, "Invalid close reason"),
InvalidUri => write!(f, "Invalid URI"),
UriHasNoHost => write!(f, "URI has no host"),
UriHasNoPort => write!(f, "URI has no port"),
MaxStreamCountReached => write!(f, "Maximum stream count reached"),
StreamAlreadyClosed => write!(f, "Stream already closed"),
WsFrameInvalidType => write!(f, "Invalid websocket frame type"),
WsFrameNotFinished => write!(f, "Unfinished websocket frame"),
WsImplError(err) => write!(f, "Websocket implementation error: {}", err),
WsImplSocketClosed => write!(f, "Websocket implementation error: websocket closed"),
WsImplNotSupported => write!(f, "Websocket implementation error: unsupported feature"),
Utf8Error(err) => write!(f, "UTF-8 error: {}", err),
Other(err) => write!(f, "Other error: {}", err),
Self::PacketTooSmall => write!(f, "Packet too small"),
Self::InvalidPacketType => write!(f, "Invalid packet type"),
Self::InvalidStreamType => write!(f, "Invalid stream type"),
Self::InvalidStreamId => write!(f, "Invalid stream id"),
Self::InvalidCloseReason => write!(f, "Invalid close reason"),
Self::InvalidUri => write!(f, "Invalid URI"),
Self::UriHasNoHost => write!(f, "URI has no host"),
Self::UriHasNoPort => write!(f, "URI has no port"),
Self::MaxStreamCountReached => write!(f, "Maximum stream count reached"),
Self::StreamAlreadyClosed => write!(f, "Stream already closed"),
Self::WsFrameInvalidType => write!(f, "Invalid websocket frame type"),
Self::WsFrameNotFinished => write!(f, "Unfinished websocket frame"),
Self::WsImplError(err) => write!(f, "Websocket implementation error: {}", err),
Self::WsImplSocketClosed => {
write!(f, "Websocket implementation error: websocket closed")
}
Self::WsImplNotSupported => {
write!(f, "Websocket implementation error: unsupported feature")
}
Self::Utf8Error(err) => write!(f, "UTF-8 error: {}", err),
Self::Other(err) => write!(f, "Other error: {}", err),
Self::MuxMessageFailedToSend => write!(f, "Failed to send multiplexor message"),
}
}
}
@ -115,61 +120,74 @@ struct MuxMapValue {
is_closed: Arc<AtomicBool>,
}
struct ServerMuxInner<W>
struct MuxInner<W>
where
W: ws::WebSocketWrite,
{
tx: ws::LockedWebSocketWrite<W>,
stream_map: Arc<Mutex<HashMap<u32, MuxMapValue>>>,
close_tx: mpsc::UnboundedSender<WsEvent>,
stream_map: Arc<DashMap<u32, MuxMapValue>>,
}
impl<W: ws::WebSocketWrite> ServerMuxInner<W> {
pub async fn into_future<R>(
impl<W: ws::WebSocketWrite> MuxInner<W> {
pub async fn server_into_future<R>(
self,
rx: R,
close_rx: mpsc::UnboundedReceiver<WsEvent>,
close_rx: mpsc::Receiver<WsEvent>,
muxstream_sender: mpsc::UnboundedSender<(ConnectPacket, MuxStream)>,
buffer_size: u32,
close_tx: mpsc::Sender<WsEvent>,
) -> Result<(), WispError>
where
R: ws::WebSocketRead,
{
self.into_future(
close_rx,
self.server_loop(rx, muxstream_sender, buffer_size, close_tx),
)
.await
}
pub async fn client_into_future<R>(
self,
rx: R,
close_rx: mpsc::Receiver<WsEvent>,
) -> Result<(), WispError>
where
R: ws::WebSocketRead,
{
self.into_future(close_rx, self.client_loop(rx)).await
}
async fn into_future(
&self,
close_rx: mpsc::Receiver<WsEvent>,
wisp_fut: impl Future<Output = Result<(), WispError>>,
) -> Result<(), WispError> {
let ret = futures::select! {
x = self.server_bg_loop(close_rx).fuse() => x,
x = self.server_msg_loop(rx, muxstream_sender, buffer_size).fuse() => x
_ = self.stream_loop(close_rx).fuse() => Ok(()),
x = wisp_fut.fuse() => x,
};
self.stream_map.lock().await.drain().for_each(|mut x| {
x.1.is_closed.store(true, Ordering::Release);
x.1.stream.disconnect();
x.1.stream.close_channel();
self.stream_map.iter_mut().for_each(|mut x| {
x.is_closed.store(true, Ordering::Release);
x.stream.disconnect();
x.stream.close_channel();
});
self.stream_map.clear();
ret
}
async fn server_bg_loop(
&self,
mut close_rx: mpsc::UnboundedReceiver<WsEvent>,
) -> Result<(), WispError> {
while let Some(msg) = close_rx.next().await {
async fn stream_loop(&self, mut stream_rx: mpsc::Receiver<WsEvent>) {
while let Some(msg) = stream_rx.next().await {
match msg {
WsEvent::SendPacket(packet, channel) => {
if self
.stream_map
.lock()
.await
.get(&packet.stream_id)
.is_some()
{
if self.stream_map.get(&packet.stream_id).is_some() {
let _ = channel.send(self.tx.write_frame(packet.into()).await);
} else {
let _ = channel.send(Err(WispError::InvalidStreamId));
}
}
WsEvent::Close(packet, channel) => {
if let Some(mut stream) =
self.stream_map.lock().await.remove(&packet.stream_id)
{
if let Some((_, mut stream)) = self.stream_map.remove(&packet.stream_id) {
stream.stream.disconnect();
stream.stream.close_channel();
let _ = channel.send(self.tx.write_frame(packet.into()).await);
@ -180,20 +198,20 @@ impl<W: ws::WebSocketWrite> ServerMuxInner<W> {
WsEvent::EndFut => break,
}
}
Ok(())
}
async fn server_msg_loop<R>(
async fn server_loop<R>(
&self,
mut rx: R,
muxstream_sender: mpsc::UnboundedSender<(ConnectPacket, MuxStream)>,
buffer_size: u32,
close_tx: mpsc::Sender<WsEvent>,
) -> Result<(), WispError>
where
R: ws::WebSocketRead,
{
// will send continues once flow_control is at 10% of max
let target_buffer_size = buffer_size * 90 / 100;
let target_buffer_size = ((buffer_size as u64 * 90) / 100) as u32;
self.tx
.write_frame(Packet::new_continue(0, buffer_size).into())
.await?;
@ -214,7 +232,7 @@ impl<W: ws::WebSocketWrite> ServerMuxInner<W> {
let flow_control_event: Arc<Event> = Event::new().into();
let is_closed: Arc<AtomicBool> = AtomicBool::new(false).into();
self.stream_map.lock().await.insert(
self.stream_map.insert(
packet.stream_id,
MuxMapValue {
stream: ch_tx,
@ -232,7 +250,7 @@ impl<W: ws::WebSocketWrite> ServerMuxInner<W> {
Role::Server,
stream_type,
ch_rx,
self.close_tx.clone(),
close_tx.clone(),
is_closed,
flow_control,
flow_control_event,
@ -242,7 +260,7 @@ impl<W: ws::WebSocketWrite> ServerMuxInner<W> {
.map_err(|x| WispError::Other(Box::new(x)))?;
}
Data(data) => {
if let Some(stream) = self.stream_map.lock().await.get(&packet.stream_id) {
if let Some(stream) = self.stream_map.get(&packet.stream_id) {
let _ = stream.stream.unbounded_send(data);
if stream.stream_type == StreamType::Tcp {
stream.flow_control.store(
@ -257,9 +275,47 @@ impl<W: ws::WebSocketWrite> ServerMuxInner<W> {
}
Continue(_) => break Err(WispError::InvalidPacketType),
Close(_) => {
if let Some(mut stream) =
self.stream_map.lock().await.remove(&packet.stream_id)
if let Some((_, mut stream)) = self.stream_map.remove(&packet.stream_id) {
stream.is_closed.store(true, Ordering::Release);
stream.stream.disconnect();
stream.stream.close_channel();
}
}
}
}
}
async fn client_loop<R>(&self, mut rx: R) -> Result<(), WispError>
where
R: ws::WebSocketRead,
{
loop {
let frame = rx.wisp_read_frame(&self.tx).await?;
if frame.opcode == ws::OpCode::Close {
break Ok(());
}
let packet = Packet::try_from(frame)?;
use PacketType::*;
match packet.packet_type {
Connect(_) => break Err(WispError::InvalidPacketType),
Data(data) => {
if let Some(stream) = self.stream_map.get(&packet.stream_id) {
let _ = stream.stream.unbounded_send(data);
}
}
Continue(inner_packet) => {
if let Some(stream) = self.stream_map.get(&packet.stream_id) {
if stream.stream_type == StreamType::Tcp {
stream
.flow_control
.store(inner_packet.buffer_remaining, Ordering::Release);
let _ = stream.flow_control_event.notify(u32::MAX);
}
}
}
Close(_) => {
if let Some((_, mut stream)) = self.stream_map.remove(&packet.stream_id) {
stream.is_closed.store(true, Ordering::Release);
stream.stream.disconnect();
stream.stream.close_channel();
@ -290,8 +346,7 @@ impl<W: ws::WebSocketWrite> ServerMuxInner<W> {
/// }
/// ```
pub struct ServerMux {
stream_map: Arc<Mutex<HashMap<u32, MuxMapValue>>>,
close_tx: mpsc::UnboundedSender<WsEvent>,
close_tx: mpsc::Sender<WsEvent>,
muxstream_recv: mpsc::UnboundedReceiver<(ConnectPacket, MuxStream)>,
}
@ -305,22 +360,19 @@ impl ServerMux {
where
R: ws::WebSocketRead,
{
let (close_tx, close_rx) = mpsc::unbounded::<WsEvent>();
let (close_tx, close_rx) = mpsc::channel::<WsEvent>(256);
let (tx, rx) = mpsc::unbounded::<(ConnectPacket, MuxStream)>();
let write = ws::LockedWebSocketWrite::new(write);
let map = Arc::new(Mutex::new(HashMap::new()));
(
Self {
muxstream_recv: rx,
close_tx: close_tx.clone(),
stream_map: map.clone(),
},
ServerMuxInner {
MuxInner {
tx: write,
close_tx,
stream_map: map.clone(),
stream_map: DashMap::new().into(),
}
.into_future(read, close_rx, tx, buffer_size),
.server_into_future(read, close_rx, tx, buffer_size, close_tx),
)
}
@ -333,124 +385,13 @@ impl ServerMux {
///
/// Also terminates the multiplexor future. Waiting for a new stream will never succeed after
/// this function is called.
pub async fn close(&self) {
self.stream_map.lock().await.drain().for_each(|mut x| {
x.1.is_closed.store(true, Ordering::Release);
x.1.stream.disconnect();
x.1.stream.close_channel();
});
let _ = self.close_tx.unbounded_send(WsEvent::EndFut);
}
}
struct ClientMuxInner<W>
where
W: ws::WebSocketWrite,
{
tx: ws::LockedWebSocketWrite<W>,
stream_map: Arc<Mutex<HashMap<u32, MuxMapValue>>>,
}
impl<W: ws::WebSocketWrite> ClientMuxInner<W> {
pub(crate) async fn into_future<R>(
self,
rx: R,
close_rx: mpsc::UnboundedReceiver<WsEvent>,
) -> Result<(), WispError>
where
R: ws::WebSocketRead,
{
let ret = futures::select! {
x = self.client_bg_loop(close_rx).fuse() => x,
x = self.client_loop(rx).fuse() => x
};
self.stream_map.lock().await.drain().for_each(|mut x| {
x.1.is_closed.store(true, Ordering::Release);
x.1.stream.disconnect();
x.1.stream.close_channel();
});
ret
}
async fn client_bg_loop(
&self,
mut close_rx: mpsc::UnboundedReceiver<WsEvent>,
) -> Result<(), WispError> {
while let Some(msg) = close_rx.next().await {
match msg {
WsEvent::SendPacket(packet, channel) => {
if self
.stream_map
.lock()
pub async fn close(&mut self) -> Result<(), WispError> {
self.close_tx
.send(WsEvent::EndFut)
.await
.get(&packet.stream_id)
.is_some()
{
let _ = channel.send(self.tx.write_frame(packet.into()).await);
} else {
let _ = channel.send(Err(WispError::InvalidStreamId));
}
}
WsEvent::Close(packet, channel) => {
if let Some(mut stream) =
self.stream_map.lock().await.remove(&packet.stream_id)
{
stream.stream.disconnect();
stream.stream.close_channel();
let _ = channel.send(self.tx.write_frame(packet.into()).await);
} else {
let _ = channel.send(Err(WispError::InvalidStreamId));
}
}
WsEvent::EndFut => break,
}
}
Ok(())
}
async fn client_loop<R>(&self, mut rx: R) -> Result<(), WispError>
where
R: ws::WebSocketRead,
{
loop {
let frame = rx.wisp_read_frame(&self.tx).await?;
if frame.opcode == ws::OpCode::Close {
break Ok(());
}
let packet = Packet::try_from(frame)?;
use PacketType::*;
match packet.packet_type {
Connect(_) => break Err(WispError::InvalidPacketType),
Data(data) => {
if let Some(stream) = self.stream_map.lock().await.get(&packet.stream_id) {
let _ = stream.stream.unbounded_send(data);
}
}
Continue(inner_packet) => {
if let Some(stream) = self.stream_map.lock().await.get(&packet.stream_id) {
if stream.stream_type == StreamType::Tcp {
stream
.flow_control
.store(inner_packet.buffer_remaining, Ordering::Release);
let _ = stream.flow_control_event.notify(u32::MAX);
}
}
}
Close(_) => {
if let Some(mut stream) =
self.stream_map.lock().await.remove(&packet.stream_id)
{
stream.is_closed.store(true, Ordering::Release);
stream.stream.disconnect();
stream.stream.close_channel();
}
}
}
}
.map_err(|_| WispError::MuxMessageFailedToSend)
}
}
/// Client side multiplexor.
///
/// # Example
@ -470,9 +411,9 @@ where
W: ws::WebSocketWrite,
{
tx: ws::LockedWebSocketWrite<W>,
stream_map: Arc<Mutex<HashMap<u32, MuxMapValue>>>,
stream_map: Arc<DashMap<u32, MuxMapValue>>,
next_free_stream_id: AtomicU32,
close_tx: mpsc::UnboundedSender<WsEvent>,
close_tx: mpsc::Sender<WsEvent>,
buf_size: u32,
target_buf_size: u32,
}
@ -492,23 +433,23 @@ impl<W: ws::WebSocketWrite> ClientMux<W> {
return Err(WispError::InvalidStreamId);
}
if let PacketType::Continue(packet) = first_packet.packet_type {
let (tx, rx) = mpsc::unbounded::<WsEvent>();
let map = Arc::new(Mutex::new(HashMap::new()));
let (tx, rx) = mpsc::channel::<WsEvent>(256);
let map = Arc::new(DashMap::new());
Ok((
Self {
tx: write.clone(),
stream_map: map.clone(),
next_free_stream_id: AtomicU32::new(1),
close_tx: tx,
close_tx: tx.clone(),
buf_size: packet.buffer_remaining,
// server-only
target_buf_size: 0,
},
ClientMuxInner {
MuxInner {
tx: write.clone(),
stream_map: map.clone(),
}
.into_future(read, rx),
.client_into_future(read, rx),
))
} else {
Err(WispError::InvalidPacketType)
@ -540,7 +481,7 @@ impl<W: ws::WebSocketWrite> ClientMux<W> {
self.next_free_stream_id
.store(next_stream_id, Ordering::Release);
self.stream_map.lock().await.insert(
self.stream_map.insert(
stream_id,
MuxMapValue {
stream: ch_tx,
@ -568,12 +509,10 @@ impl<W: ws::WebSocketWrite> ClientMux<W> {
///
/// Also terminates the multiplexor future. Creating a stream is UB after calling this
/// function.
pub async fn close(&self) {
self.stream_map.lock().await.drain().for_each(|mut x| {
x.1.is_closed.store(true, Ordering::Release);
x.1.stream.disconnect();
x.1.stream.close_channel();
});
let _ = self.close_tx.unbounded_send(WsEvent::EndFut);
pub async fn close(&mut self) -> Result<(), WispError> {
self.close_tx
.send(WsEvent::EndFut)
.await
.map_err(|_| WispError::MuxMessageFailedToSend)
}
}

View file

@ -1,5 +1,5 @@
use crate::{ws, WispError};
use bytes::{Buf, BufMut, Bytes};
use bytes::{Buf, BufMut, Bytes, BytesMut};
/// Wisp stream type.
#[derive(Debug, PartialEq, Copy, Clone)]
@ -115,13 +115,13 @@ impl TryFrom<Bytes> for ConnectPacket {
}
}
impl From<ConnectPacket> for Vec<u8> {
impl From<ConnectPacket> for Bytes {
fn from(packet: ConnectPacket) -> Self {
let mut encoded = Self::with_capacity(1 + 2 + packet.destination_hostname.len());
let mut encoded = BytesMut::with_capacity(1 + 2 + packet.destination_hostname.len());
encoded.put_u8(packet.stream_type as u8);
encoded.put_u16_le(packet.destination_port);
encoded.extend(packet.destination_hostname.bytes());
encoded
encoded.freeze()
}
}
@ -153,11 +153,11 @@ impl TryFrom<Bytes> for ContinuePacket {
}
}
impl From<ContinuePacket> for Vec<u8> {
impl From<ContinuePacket> for Bytes {
fn from(packet: ContinuePacket) -> Self {
let mut encoded = Self::with_capacity(4);
let mut encoded = BytesMut::with_capacity(4);
encoded.put_u32_le(packet.buffer_remaining);
encoded
encoded.freeze()
}
}
@ -190,11 +190,11 @@ impl TryFrom<Bytes> for ClosePacket {
}
}
impl From<ClosePacket> for Vec<u8> {
impl From<ClosePacket> for Bytes {
fn from(packet: ClosePacket) -> Self {
let mut encoded = Self::with_capacity(1);
let mut encoded = BytesMut::with_capacity(1);
encoded.put_u8(packet.reason as u8);
encoded
encoded.freeze()
}
}
@ -224,12 +224,12 @@ impl PacketType {
}
}
impl From<PacketType> for Vec<u8> {
impl From<PacketType> for Bytes {
fn from(packet: PacketType) -> Self {
use PacketType::*;
match packet {
Connect(x) => x.into(),
Data(x) => x.to_vec(),
Data(x) => x,
Continue(x) => x.into(),
Close(x) => x.into(),
}
@ -250,7 +250,10 @@ impl Packet {
///
/// The helper functions should be used for most use cases.
pub fn new(stream_id: u32, packet: PacketType) -> Self {
Self { stream_id, packet_type: packet }
Self {
stream_id,
packet_type: packet,
}
}
/// Create a new connect packet.
@ -316,13 +319,15 @@ impl TryFrom<Bytes> for Packet {
}
}
impl From<Packet> for Vec<u8> {
impl From<Packet> for Bytes {
fn from(packet: Packet) -> Self {
let mut encoded = Self::with_capacity(1 + 4);
encoded.push(packet.packet_type.as_u8());
let inner_u8 = packet.packet_type.as_u8();
let inner = Bytes::from(packet.packet_type);
let mut encoded = BytesMut::with_capacity(1 + 4 + inner.len());
encoded.put_u8(inner_u8);
encoded.put_u32_le(packet.stream_id);
encoded.extend(Vec::<u8>::from(packet.packet_type));
encoded
encoded.extend(inner);
encoded.freeze()
}
}
@ -341,6 +346,6 @@ impl TryFrom<ws::Frame> for Packet {
impl From<Packet> for ws::Frame {
fn from(packet: Packet) -> Self {
Self::binary(Vec::<u8>::from(packet).into())
Self::binary(packet.into())
}
}

View file

@ -45,28 +45,42 @@ pin_project! {
/// Sink for the [`unfold`] function.
#[derive(Debug)]
#[must_use = "sinks do nothing unless polled"]
pub struct Unfold<T, F, FC, R> {
pub struct Unfold<T, F, R, CT, CF, CR> {
function: F,
close_function: FC,
close_function: CF,
#[pin]
state: UnfoldState<T, R>,
#[pin]
close_state: UnfoldState<CT, CR>
}
}
pub(crate) fn unfold<T, F, FC, R, Item, E>(init: T, function: F, close_function: FC) -> Unfold<T, F, FC, R>
pub(crate) fn unfold<T, F, R, CT, CF, CR, Item, E>(
init: T,
function: F,
close_init: CT,
close_function: CF,
) -> Unfold<T, F, R, CT, CF, CR>
where
F: FnMut(T, Item) -> R,
R: Future<Output = Result<T, E>>,
FC: Fn() -> Result<(), E>,
CF: FnMut(CT) -> CR,
CR: Future<Output = Result<CT, E>>,
{
Unfold { function, close_function, state: UnfoldState::Value { value: init } }
Unfold {
function,
close_function,
state: UnfoldState::Value { value: init },
close_state: UnfoldState::Value { value: close_init },
}
}
impl<T, F, FC, R, Item, E> Sink<Item> for Unfold<T, F, FC, R>
impl<T, F, R, CT, CF, CR, Item, E> Sink<Item> for Unfold<T, F, R, CT, CF, CR>
where
F: FnMut(T, Item) -> R,
R: Future<Output = Result<T, E>>,
FC: Fn() -> Result<(), E>,
CF: FnMut(CT) -> CR,
CR: Future<Output = Result<CT, E>>,
{
type Error = E;
@ -104,6 +118,27 @@ where
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
ready!(self.as_mut().poll_flush(cx))?;
Poll::Ready((self.close_function)())
let mut this = self.project();
Poll::Ready(
if let Some(future) = this.close_state.as_mut().project_future() {
match ready!(future.poll(cx)) {
Ok(state) => {
this.close_state.set(UnfoldState::Value { value: state });
Ok(())
}
Err(err) => {
this.close_state.set(UnfoldState::Empty);
Err(err)
}
}
} else {
let future = match this.close_state.as_mut().take_value() {
Some(value) => (this.close_function)(value),
None => panic!("start_send called without poll_ready being called first"),
};
this.close_state.set(UnfoldState::Future { future });
return Poll::Pending;
},
)
}
}

View file

@ -7,7 +7,7 @@ use futures::{
channel::{mpsc, oneshot},
stream,
task::{Context, Poll},
Sink, Stream, StreamExt,
Sink, SinkExt, Stream, StreamExt,
};
use pin_project_lite::pin_project;
use std::{
@ -31,7 +31,7 @@ pub struct MuxStreamRead {
/// Type of the stream.
pub stream_type: StreamType,
role: Role,
tx: mpsc::UnboundedSender<WsEvent>,
tx: mpsc::Sender<WsEvent>,
rx: mpsc::UnboundedReceiver<Bytes>,
is_closed: Arc<AtomicBool>,
flow_control: Arc<AtomicU32>,
@ -51,13 +51,14 @@ impl MuxStreamRead {
if val > self.target_flow_control {
let (tx, rx) = oneshot::channel::<Result<(), WispError>>();
self.tx
.unbounded_send(WsEvent::SendPacket(
.send(WsEvent::SendPacket(
Packet::new_continue(
self.stream_id,
self.flow_control.fetch_add(val, Ordering::AcqRel) + val,
),
tx,
))
.await
.ok()?;
rx.await.ok()?.ok()?;
self.flow_control_read.store(0, Ordering::Release);
@ -80,7 +81,7 @@ pub struct MuxStreamWrite {
/// Type of the stream.
pub stream_type: StreamType,
role: Role,
tx: mpsc::UnboundedSender<WsEvent>,
tx: mpsc::Sender<WsEvent>,
is_closed: Arc<AtomicBool>,
continue_recieved: Arc<Event>,
flow_control: Arc<AtomicU32>,
@ -88,7 +89,7 @@ pub struct MuxStreamWrite {
impl MuxStreamWrite {
/// Write data to the stream.
pub async fn write(&self, data: Bytes) -> Result<(), WispError> {
pub async fn write(&mut self, data: Bytes) -> Result<(), WispError> {
if self.is_closed.load(Ordering::Acquire) {
return Err(WispError::StreamAlreadyClosed);
}
@ -100,10 +101,11 @@ impl MuxStreamWrite {
}
let (tx, rx) = oneshot::channel::<Result<(), WispError>>();
self.tx
.unbounded_send(WsEvent::SendPacket(
.send(WsEvent::SendPacket(
Packet::new_data(self.stream_id, data),
tx,
))
.await
.map_err(|x| WispError::Other(Box::new(x)))?;
rx.await.map_err(|x| WispError::Other(Box::new(x)))??;
if self.role == Role::Client && self.stream_type == StreamType::Tcp {
@ -135,7 +137,7 @@ impl MuxStreamWrite {
}
/// Close the stream. You will no longer be able to write or read after this has been called.
pub async fn close(&self, reason: CloseReason) -> Result<(), WispError> {
pub async fn close(&mut self, reason: CloseReason) -> Result<(), WispError> {
if self.is_closed.load(Ordering::Acquire) {
return Err(WispError::StreamAlreadyClosed);
}
@ -143,10 +145,11 @@ impl MuxStreamWrite {
let (tx, rx) = oneshot::channel::<Result<(), WispError>>();
self.tx
.unbounded_send(WsEvent::Close(
.send(WsEvent::Close(
Packet::new_close(self.stream_id, reason),
tx,
))
.await
.map_err(|x| WispError::Other(Box::new(x)))?;
rx.await.map_err(|x| WispError::Other(Box::new(x)))??;
@ -157,25 +160,19 @@ impl MuxStreamWrite {
let handle = self.get_close_handle();
Box::pin(sink_unfold::unfold(
self,
|tx, data| async move {
|mut tx, data| async move {
tx.write(data).await?;
Ok(tx)
},
move || handle.close_sync(CloseReason::Unknown),
handle,
move |mut handle| async {
handle.close(CloseReason::Unknown).await?;
Ok(handle)
},
))
}
}
impl Drop for MuxStreamWrite {
fn drop(&mut self) {
let (tx, _) = oneshot::channel::<Result<(), WispError>>();
let _ = self.tx.unbounded_send(WsEvent::Close(
Packet::new_close(self.stream_id, CloseReason::Unknown),
tx,
));
}
}
/// Multiplexor stream.
pub struct MuxStream {
/// ID of the stream.
@ -191,7 +188,7 @@ impl MuxStream {
role: Role,
stream_type: StreamType,
rx: mpsc::UnboundedReceiver<Bytes>,
tx: mpsc::UnboundedSender<WsEvent>,
tx: mpsc::Sender<WsEvent>,
is_closed: Arc<AtomicBool>,
flow_control: Arc<AtomicU32>,
continue_recieved: Arc<Event>,
@ -228,7 +225,7 @@ impl MuxStream {
}
/// Write data to the stream.
pub async fn write(&self, data: Bytes) -> Result<(), WispError> {
pub async fn write(&mut self, data: Bytes) -> Result<(), WispError> {
self.tx.write(data).await
}
@ -248,7 +245,7 @@ impl MuxStream {
}
/// Close the stream. You will no longer be able to write or read after this has been called.
pub async fn close(&self, reason: CloseReason) -> Result<(), WispError> {
pub async fn close(&mut self, reason: CloseReason) -> Result<(), WispError> {
self.tx.close(reason).await
}
@ -271,13 +268,13 @@ impl MuxStream {
pub struct MuxStreamCloser {
/// ID of the stream.
pub stream_id: u32,
close_channel: mpsc::UnboundedSender<WsEvent>,
close_channel: mpsc::Sender<WsEvent>,
is_closed: Arc<AtomicBool>,
}
impl MuxStreamCloser {
/// Close the stream. You will no longer be able to write or read after this has been called.
pub async fn close(&self, reason: CloseReason) -> Result<(), WispError> {
pub async fn close(&mut self, reason: CloseReason) -> Result<(), WispError> {
if self.is_closed.load(Ordering::Acquire) {
return Err(WispError::StreamAlreadyClosed);
}
@ -285,32 +282,16 @@ impl MuxStreamCloser {
let (tx, rx) = oneshot::channel::<Result<(), WispError>>();
self.close_channel
.unbounded_send(WsEvent::Close(
.send(WsEvent::Close(
Packet::new_close(self.stream_id, reason),
tx,
))
.await
.map_err(|x| WispError::Other(Box::new(x)))?;
rx.await.map_err(|x| WispError::Other(Box::new(x)))??;
Ok(())
}
pub(crate) fn close_sync(&self, reason: CloseReason) -> Result<(), WispError> {
if self.is_closed.load(Ordering::Acquire) {
return Err(WispError::StreamAlreadyClosed);
}
self.is_closed.store(true, Ordering::Release);
let (tx, _) = oneshot::channel::<Result<(), WispError>>();
self.close_channel
.unbounded_send(WsEvent::Close(
Packet::new_close(self.stream_id, reason),
tx,
))
.map_err(|x| WispError::Other(Box::new(x)))?;
Ok(())
}
}
pin_project! {
@ -336,10 +317,7 @@ impl MuxStreamIo {
impl Stream for MuxStreamIo {
type Item = Result<Vec<u8>, std::io::Error>;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
self.project()
.rx
.poll_next(cx)
.map(|x| x.map(|x| Ok(x.to_vec())))
self.project().rx.poll_next(cx).map(|x| x.map(|x| Ok(x.to_vec())))
}
}