remove the mutex<hashmap> in wisp_mux, other improvements

This commit is contained in:
Toshit Chawda 2024-03-26 18:55:54 -07:00
parent ff2a1ad269
commit 7001ee8fa5
No known key found for this signature in database
GPG key ID: 91480ED99E2B3D9D
16 changed files with 346 additions and 309 deletions

22
Cargo.lock generated
View file

@ -153,6 +153,12 @@ dependencies = [
"tokio", "tokio",
] ]
[[package]]
name = "atomic-counter"
version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "62f447d68cfa5a9ab0c1c862a703da2a65b5ed1b7ce1153c9eb0169506d56019"
[[package]] [[package]]
name = "autocfg" name = "autocfg"
version = "1.1.0" version = "1.1.0"
@ -516,7 +522,7 @@ checksum = "11157ac094ffbdde99aa67b23417ebdd801842852b500e395a45a9c0aac03e4a"
[[package]] [[package]]
name = "epoxy-client" name = "epoxy-client"
version = "1.4.2" version = "1.5.0"
dependencies = [ dependencies = [
"async-compression", "async-compression",
"async_io_stream", "async_io_stream",
@ -1727,18 +1733,29 @@ checksum = "f27f6278552951f1f2b8cf9da965d10969b2efdea95a6ec47987ab46edfe263a"
name = "simple-wisp-client" name = "simple-wisp-client"
version = "1.0.0" version = "1.0.0"
dependencies = [ dependencies = [
"atomic-counter",
"bytes", "bytes",
"console-subscriber", "console-subscriber",
"fastwebsockets 0.7.1", "fastwebsockets 0.7.1",
"futures", "futures",
"http-body-util", "http-body-util",
"hyper 1.2.0", "hyper 1.2.0",
"simple_moving_average",
"tokio", "tokio",
"tokio-native-tls", "tokio-native-tls",
"tokio-util", "tokio-util",
"wisp-mux", "wisp-mux",
] ]
[[package]]
name = "simple_moving_average"
version = "1.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3a4b144ad185430cd033299e2c93e465d5a7e65fbb858593dc57181fa13cd310"
dependencies = [
"num-traits",
]
[[package]] [[package]]
name = "slab" name = "slab"
version = "0.4.9" version = "0.4.9"
@ -2500,10 +2517,11 @@ checksum = "32b752e52a2da0ddfbdbcc6fceadfeede4c939ed16d13e648833a61dfb611ed8"
[[package]] [[package]]
name = "wisp-mux" name = "wisp-mux"
version = "2.0.2" version = "3.0.0"
dependencies = [ dependencies = [
"async_io_stream", "async_io_stream",
"bytes", "bytes",
"dashmap",
"event-listener", "event-listener",
"fastwebsockets 0.7.1", "fastwebsockets 0.7.1",
"futures", "futures",

View file

@ -8,8 +8,10 @@ rustls-pki-types = { git = "https://github.com/r58Playz/rustls-pki-types" }
[profile.release] [profile.release]
lto = true lto = true
opt-level = 'z' debug = true
strip = true
panic = "abort" panic = "abort"
codegen-units = 1 codegen-units = 1
opt-level = 3
[profile.release.package.epoxy-client]
opt-level = 'z'

View file

@ -1,6 +1,6 @@
[package] [package]
name = "epoxy-client" name = "epoxy-client"
version = "1.4.2" version = "1.5.0"
edition = "2021" edition = "2021"
license = "LGPL-3.0-only" license = "LGPL-3.0-only"

View file

@ -1,6 +1,6 @@
{ {
"name": "@mercuryworkshop/epoxy-tls", "name": "@mercuryworkshop/epoxy-tls",
"version": "1.4.2", "version": "1.5.0",
"description": "A wasm library for using raw encrypted tls/ssl/https/websocket streams on the browser", "description": "A wasm library for using raw encrypted tls/ssl/https/websocket streams on the browser",
"scripts": { "scripts": {
"build": "./build.sh" "build": "./build.sh"

View file

@ -47,8 +47,8 @@ enum EpxCompression {
Gzip, Gzip,
} }
type EpxIoTlsStream = TlsStream<IoStream<MuxStreamIo, Vec<u8>>>;
type EpxIoUnencryptedStream = IoStream<MuxStreamIo, Vec<u8>>; type EpxIoUnencryptedStream = IoStream<MuxStreamIo, Vec<u8>>;
type EpxIoTlsStream = TlsStream<EpxIoUnencryptedStream>;
type EpxIoStream = Either<EpxIoTlsStream, EpxIoUnencryptedStream>; type EpxIoStream = Either<EpxIoTlsStream, EpxIoUnencryptedStream>;
#[wasm_bindgen(start)] #[wasm_bindgen(start)]

View file

@ -231,7 +231,7 @@ pub async fn replace_mux(
) -> Result<(), WispError> { ) -> Result<(), WispError> {
let (mux_replace, fut) = make_mux(url).await?; let (mux_replace, fut) = make_mux(url).await?;
let mut mux_write = mux.write().await; let mut mux_write = mux.write().await;
mux_write.close().await; mux_write.close().await?;
*mux_write = mux_replace; *mux_write = mux_replace;
drop(mux_write); drop(mux_write);
spawn_mux_fut(mux, fut, url.into()); spawn_mux_fut(mux, fut, url.into());

View file

@ -56,7 +56,7 @@ impl Stream for IncomingBody {
pub struct ServiceWrapper(pub Arc<RwLock<ClientMux<WebSocketWrapper>>>, pub String); pub struct ServiceWrapper(pub Arc<RwLock<ClientMux<WebSocketWrapper>>>, pub String);
impl tower_service::Service<hyper::Uri> for ServiceWrapper { impl tower_service::Service<hyper::Uri> for ServiceWrapper {
type Response = TokioIo<IoStream<MuxStreamIo, Vec<u8>>>; type Response = TokioIo<EpxIoUnencryptedStream>;
type Error = WispError; type Error = WispError;
type Future = impl Future<Output = Result<Self::Response, Self::Error>>; type Future = impl Future<Output = Result<Self::Response, Self::Error>>;

View file

@ -239,7 +239,7 @@ async fn accept_ws(
println!("{:?}: connected", addr); println!("{:?}: connected", addr);
let (mut mux, fut) = ServerMux::new(rx, tx, 128); let (mut mux, fut) = ServerMux::new(rx, tx, u32::MAX);
tokio::spawn(async move { tokio::spawn(async move {
if let Err(e) = fut.await { if let Err(e) = fut.await {
@ -247,7 +247,7 @@ async fn accept_ws(
} }
}); });
while let Some((packet, stream)) = mux.server_new_stream().await { while let Some((packet, mut stream)) = mux.server_new_stream().await {
tokio::spawn(async move { tokio::spawn(async move {
if block_local { if block_local {
match lookup_host(format!( match lookup_host(format!(
@ -272,8 +272,8 @@ async fn accept_ws(
} }
} }
} }
let close_err = stream.get_close_handle(); let mut close_err = stream.get_close_handle();
let close_ok = stream.get_close_handle(); let mut close_ok = stream.get_close_handle();
let _ = handle_mux(packet, stream) let _ = handle_mux(packet, stream)
.or_else(|err| async move { .or_else(|err| async move {
let _ = close_err.close(CloseReason::Unexpected).await; let _ = close_err.close(CloseReason::Unexpected).await;

View file

@ -4,12 +4,14 @@ version = "1.0.0"
edition = "2021" edition = "2021"
[dependencies] [dependencies]
atomic-counter = "1.0.1"
bytes = "1.5.0" bytes = "1.5.0"
console-subscriber = { version = "0.2.0", optional = true } console-subscriber = { version = "0.2.0", optional = true }
fastwebsockets = { version = "0.7.1", features = ["unstable-split", "upgrade"] } fastwebsockets = { version = "0.7.1", features = ["unstable-split", "upgrade"] }
futures = "0.3.30" futures = "0.3.30"
http-body-util = "0.1.0" http-body-util = "0.1.0"
hyper = { version = "1.1.0", features = ["http1", "client"] } hyper = { version = "1.1.0", features = ["http1", "client"] }
simple_moving_average = "1.0.2"
tokio = { version = "1.36.0", features = ["full"] } tokio = { version = "1.36.0", features = ["full"] }
tokio-native-tls = "0.3.1" tokio-native-tls = "0.3.1"
tokio-util = "0.7.10" tokio-util = "0.7.10"

View file

@ -1,16 +1,25 @@
use atomic_counter::{AtomicCounter, RelaxedCounter};
use bytes::Bytes; use bytes::Bytes;
use fastwebsockets::{handshake, FragmentCollectorRead}; use fastwebsockets::{handshake, FragmentCollectorRead};
use futures::io::AsyncWriteExt; use futures::future::select_all;
use http_body_util::Empty; use http_body_util::Empty;
use hyper::{ use hyper::{
header::{CONNECTION, UPGRADE}, header::{CONNECTION, UPGRADE},
Request, Request,
}; };
use std::{error::Error, future::Future}; use simple_moving_average::{SingleSumSMA, SMA};
use tokio::net::TcpStream; use std::{
error::Error,
future::Future,
io::{stdout, IsTerminal, Write},
sync::Arc,
time::Duration,
usize,
};
use tokio::{net::TcpStream, time::interval};
use tokio_native_tls::{native_tls, TlsConnector}; use tokio_native_tls::{native_tls, TlsConnector};
use wisp_mux::{ClientMux, StreamType};
use tokio_util::either::Either; use tokio_util::either::Either;
use wisp_mux::{ClientMux, StreamType, WispError};
#[derive(Debug)] #[derive(Debug)]
struct StrError(String); struct StrError(String);
@ -70,6 +79,18 @@ async fn main() -> Result<(), Box<dyn Error + Send + Sync>> {
.nth(6) .nth(6)
.ok_or(StrError::new("no should tls"))? .ok_or(StrError::new("no should tls"))?
.parse()?; .parse()?;
let thread_cnt: usize = std::env::args().nth(7).unwrap_or("10".into()).parse()?;
println!(
"connecting to {}://{}:{}{} and sending &[0; 1024] to {}:{} with threads {}",
if should_tls { "wss" } else { "ws" },
addr,
addr_port,
addr_path,
addr_dest,
addr_dest_port,
thread_cnt
);
let socket = TcpStream::connect(format!("{}:{}", &addr, addr_port)).await?; let socket = TcpStream::connect(format!("{}:{}", &addr, addr_port)).await?;
let socket = if should_tls { let socket = if should_tls {
@ -98,23 +119,59 @@ async fn main() -> Result<(), Box<dyn Error + Send + Sync>> {
let rx = FragmentCollectorRead::new(rx); let rx = FragmentCollectorRead::new(rx);
let (mux, fut) = ClientMux::new(rx, tx).await?; let (mux, fut) = ClientMux::new(rx, tx).await?;
let mut threads = Vec::with_capacity(thread_cnt + 1);
tokio::task::spawn(async move { println!("err: {:?}", fut.await); }); threads.push(tokio::spawn(fut));
let mut hi: u64 = 0; let payload = Bytes::from_static(&[0; 1024]);
loop {
let cnt = Arc::new(RelaxedCounter::new(0));
for _ in 0..thread_cnt {
let mut channel = mux let mut channel = mux
.client_new_stream(StreamType::Tcp, addr_dest.clone(), addr_dest_port) .client_new_stream(StreamType::Tcp, addr_dest.clone(), addr_dest_port)
.await? .await?;
.into_io() let cnt = cnt.clone();
.into_asyncrw(); let payload = payload.clone();
for _ in 0..256 { threads.push(tokio::spawn(async move {
channel.write_all(b"hiiiiiiii").await?; loop {
hi += 1; channel.write(payload.clone()).await?;
println!("said hi {}", hi); channel.read().await;
} cnt.inc();
}
#[allow(unreachable_code)]
Ok::<(), WispError>(())
}));
} }
#[allow(unreachable_code)] threads.push(tokio::spawn(async move {
let mut interval = interval(Duration::from_millis(100));
let mut avg: SingleSumSMA<usize, usize, 100> = SingleSumSMA::new();
let mut last_time = 0;
let is_term = stdout().is_terminal();
loop {
interval.tick().await;
let now = cnt.get();
let stat = format!(
"sent &[0; 1024] cnt: {:?}, +{:?}, moving average (100): {:?}",
now,
now - last_time,
avg.get_average()
);
if is_term {
print!("\x1b[2K{}\r", stat);
} else {
println!("{}", stat);
}
stdout().flush().unwrap();
avg.add_sample(now - last_time);
last_time = now;
}
}));
let out = select_all(threads.into_iter()).await;
println!("\n\nout: {:?}", out.0);
Ok(()) Ok(())
} }

View file

@ -1,6 +1,6 @@
[package] [package]
name = "wisp-mux" name = "wisp-mux"
version = "2.0.2" version = "3.0.0"
license = "LGPL-3.0-only" license = "LGPL-3.0-only"
description = "A library for easily creating Wisp servers and clients." description = "A library for easily creating Wisp servers and clients."
homepage = "https://github.com/MercuryWorkshop/epoxy-tls/tree/multiplexed/wisp" homepage = "https://github.com/MercuryWorkshop/epoxy-tls/tree/multiplexed/wisp"
@ -11,6 +11,7 @@ edition = "2021"
[dependencies] [dependencies]
async_io_stream = "0.3.3" async_io_stream = "0.3.3"
bytes = "1.5.0" bytes = "1.5.0"
dashmap = { version = "5.5.3", features = ["inline"] }
event-listener = "5.0.0" event-listener = "5.0.0"
fastwebsockets = { version = "0.7.1", features = ["unstable-split"], optional = true } fastwebsockets = { version = "0.7.1", features = ["unstable-split"], optional = true }
futures = "0.3.30" futures = "0.3.30"

View file

@ -8,7 +8,9 @@ impl From<OpCode> for crate::ws::OpCode {
fn from(opcode: OpCode) -> Self { fn from(opcode: OpCode) -> Self {
use OpCode::*; use OpCode::*;
match opcode { match opcode {
Continuation => unreachable!("continuation should never be recieved when using a fragmentcollector"), Continuation => {
unreachable!("continuation should never be recieved when using a fragmentcollector")
}
Text => Self::Text, Text => Self::Text,
Binary => Self::Binary, Binary => Self::Binary,
Close => Self::Close, Close => Self::Close,
@ -70,8 +72,6 @@ impl<S: AsyncRead + Unpin> crate::ws::WebSocketRead for FragmentCollectorRead<S>
impl<S: AsyncWrite + Unpin> crate::ws::WebSocketWrite for WebSocketWrite<S> { impl<S: AsyncWrite + Unpin> crate::ws::WebSocketWrite for WebSocketWrite<S> {
async fn wisp_write_frame(&mut self, frame: crate::ws::Frame) -> Result<(), crate::WispError> { async fn wisp_write_frame(&mut self, frame: crate::ws::Frame) -> Result<(), crate::WispError> {
self.write_frame(frame.into()) self.write_frame(frame.into()).await.map_err(|e| e.into())
.await
.map_err(|e| e.into())
} }
} }

View file

@ -16,14 +16,13 @@ pub use crate::packet::*;
pub use crate::stream::*; pub use crate::stream::*;
use bytes::Bytes; use bytes::Bytes;
use dashmap::DashMap;
use event_listener::Event; use event_listener::Event;
use futures::{channel::mpsc, lock::Mutex, Future, FutureExt, StreamExt}; use futures::SinkExt;
use std::{ use futures::{channel::mpsc, Future, FutureExt, StreamExt};
collections::HashMap, use std::sync::{
sync::{ atomic::{AtomicBool, AtomicU32, Ordering},
atomic::{AtomicBool, AtomicU32, Ordering}, Arc,
Arc,
},
}; };
/// The role of the multiplexor. /// The role of the multiplexor.
@ -72,6 +71,8 @@ pub enum WispError {
Utf8Error(std::str::Utf8Error), Utf8Error(std::str::Utf8Error),
/// Other error. /// Other error.
Other(Box<dyn std::error::Error + Sync + Send>), Other(Box<dyn std::error::Error + Sync + Send>),
/// Failed to send message to multiplexor task.
MuxMessageFailedToSend,
} }
impl From<std::str::Utf8Error> for WispError { impl From<std::str::Utf8Error> for WispError {
@ -82,25 +83,29 @@ impl From<std::str::Utf8Error> for WispError {
impl std::fmt::Display for WispError { impl std::fmt::Display for WispError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
use WispError::*;
match self { match self {
PacketTooSmall => write!(f, "Packet too small"), Self::PacketTooSmall => write!(f, "Packet too small"),
InvalidPacketType => write!(f, "Invalid packet type"), Self::InvalidPacketType => write!(f, "Invalid packet type"),
InvalidStreamType => write!(f, "Invalid stream type"), Self::InvalidStreamType => write!(f, "Invalid stream type"),
InvalidStreamId => write!(f, "Invalid stream id"), Self::InvalidStreamId => write!(f, "Invalid stream id"),
InvalidCloseReason => write!(f, "Invalid close reason"), Self::InvalidCloseReason => write!(f, "Invalid close reason"),
InvalidUri => write!(f, "Invalid URI"), Self::InvalidUri => write!(f, "Invalid URI"),
UriHasNoHost => write!(f, "URI has no host"), Self::UriHasNoHost => write!(f, "URI has no host"),
UriHasNoPort => write!(f, "URI has no port"), Self::UriHasNoPort => write!(f, "URI has no port"),
MaxStreamCountReached => write!(f, "Maximum stream count reached"), Self::MaxStreamCountReached => write!(f, "Maximum stream count reached"),
StreamAlreadyClosed => write!(f, "Stream already closed"), Self::StreamAlreadyClosed => write!(f, "Stream already closed"),
WsFrameInvalidType => write!(f, "Invalid websocket frame type"), Self::WsFrameInvalidType => write!(f, "Invalid websocket frame type"),
WsFrameNotFinished => write!(f, "Unfinished websocket frame"), Self::WsFrameNotFinished => write!(f, "Unfinished websocket frame"),
WsImplError(err) => write!(f, "Websocket implementation error: {}", err), Self::WsImplError(err) => write!(f, "Websocket implementation error: {}", err),
WsImplSocketClosed => write!(f, "Websocket implementation error: websocket closed"), Self::WsImplSocketClosed => {
WsImplNotSupported => write!(f, "Websocket implementation error: unsupported feature"), write!(f, "Websocket implementation error: websocket closed")
Utf8Error(err) => write!(f, "UTF-8 error: {}", err), }
Other(err) => write!(f, "Other error: {}", err), Self::WsImplNotSupported => {
write!(f, "Websocket implementation error: unsupported feature")
}
Self::Utf8Error(err) => write!(f, "UTF-8 error: {}", err),
Self::Other(err) => write!(f, "Other error: {}", err),
Self::MuxMessageFailedToSend => write!(f, "Failed to send multiplexor message"),
} }
} }
} }
@ -115,61 +120,74 @@ struct MuxMapValue {
is_closed: Arc<AtomicBool>, is_closed: Arc<AtomicBool>,
} }
struct ServerMuxInner<W> struct MuxInner<W>
where where
W: ws::WebSocketWrite, W: ws::WebSocketWrite,
{ {
tx: ws::LockedWebSocketWrite<W>, tx: ws::LockedWebSocketWrite<W>,
stream_map: Arc<Mutex<HashMap<u32, MuxMapValue>>>, stream_map: Arc<DashMap<u32, MuxMapValue>>,
close_tx: mpsc::UnboundedSender<WsEvent>,
} }
impl<W: ws::WebSocketWrite> ServerMuxInner<W> { impl<W: ws::WebSocketWrite> MuxInner<W> {
pub async fn into_future<R>( pub async fn server_into_future<R>(
self, self,
rx: R, rx: R,
close_rx: mpsc::UnboundedReceiver<WsEvent>, close_rx: mpsc::Receiver<WsEvent>,
muxstream_sender: mpsc::UnboundedSender<(ConnectPacket, MuxStream)>, muxstream_sender: mpsc::UnboundedSender<(ConnectPacket, MuxStream)>,
buffer_size: u32, buffer_size: u32,
close_tx: mpsc::Sender<WsEvent>,
) -> Result<(), WispError> ) -> Result<(), WispError>
where where
R: ws::WebSocketRead, R: ws::WebSocketRead,
{ {
self.into_future(
close_rx,
self.server_loop(rx, muxstream_sender, buffer_size, close_tx),
)
.await
}
pub async fn client_into_future<R>(
self,
rx: R,
close_rx: mpsc::Receiver<WsEvent>,
) -> Result<(), WispError>
where
R: ws::WebSocketRead,
{
self.into_future(close_rx, self.client_loop(rx)).await
}
async fn into_future(
&self,
close_rx: mpsc::Receiver<WsEvent>,
wisp_fut: impl Future<Output = Result<(), WispError>>,
) -> Result<(), WispError> {
let ret = futures::select! { let ret = futures::select! {
x = self.server_bg_loop(close_rx).fuse() => x, _ = self.stream_loop(close_rx).fuse() => Ok(()),
x = self.server_msg_loop(rx, muxstream_sender, buffer_size).fuse() => x x = wisp_fut.fuse() => x,
}; };
self.stream_map.lock().await.drain().for_each(|mut x| { self.stream_map.iter_mut().for_each(|mut x| {
x.1.is_closed.store(true, Ordering::Release); x.is_closed.store(true, Ordering::Release);
x.1.stream.disconnect(); x.stream.disconnect();
x.1.stream.close_channel(); x.stream.close_channel();
}); });
self.stream_map.clear();
ret ret
} }
async fn server_bg_loop( async fn stream_loop(&self, mut stream_rx: mpsc::Receiver<WsEvent>) {
&self, while let Some(msg) = stream_rx.next().await {
mut close_rx: mpsc::UnboundedReceiver<WsEvent>,
) -> Result<(), WispError> {
while let Some(msg) = close_rx.next().await {
match msg { match msg {
WsEvent::SendPacket(packet, channel) => { WsEvent::SendPacket(packet, channel) => {
if self if self.stream_map.get(&packet.stream_id).is_some() {
.stream_map
.lock()
.await
.get(&packet.stream_id)
.is_some()
{
let _ = channel.send(self.tx.write_frame(packet.into()).await); let _ = channel.send(self.tx.write_frame(packet.into()).await);
} else { } else {
let _ = channel.send(Err(WispError::InvalidStreamId)); let _ = channel.send(Err(WispError::InvalidStreamId));
} }
} }
WsEvent::Close(packet, channel) => { WsEvent::Close(packet, channel) => {
if let Some(mut stream) = if let Some((_, mut stream)) = self.stream_map.remove(&packet.stream_id) {
self.stream_map.lock().await.remove(&packet.stream_id)
{
stream.stream.disconnect(); stream.stream.disconnect();
stream.stream.close_channel(); stream.stream.close_channel();
let _ = channel.send(self.tx.write_frame(packet.into()).await); let _ = channel.send(self.tx.write_frame(packet.into()).await);
@ -180,20 +198,20 @@ impl<W: ws::WebSocketWrite> ServerMuxInner<W> {
WsEvent::EndFut => break, WsEvent::EndFut => break,
} }
} }
Ok(())
} }
async fn server_msg_loop<R>( async fn server_loop<R>(
&self, &self,
mut rx: R, mut rx: R,
muxstream_sender: mpsc::UnboundedSender<(ConnectPacket, MuxStream)>, muxstream_sender: mpsc::UnboundedSender<(ConnectPacket, MuxStream)>,
buffer_size: u32, buffer_size: u32,
close_tx: mpsc::Sender<WsEvent>,
) -> Result<(), WispError> ) -> Result<(), WispError>
where where
R: ws::WebSocketRead, R: ws::WebSocketRead,
{ {
// will send continues once flow_control is at 10% of max // will send continues once flow_control is at 10% of max
let target_buffer_size = buffer_size * 90 / 100; let target_buffer_size = ((buffer_size as u64 * 90) / 100) as u32;
self.tx self.tx
.write_frame(Packet::new_continue(0, buffer_size).into()) .write_frame(Packet::new_continue(0, buffer_size).into())
.await?; .await?;
@ -214,7 +232,7 @@ impl<W: ws::WebSocketWrite> ServerMuxInner<W> {
let flow_control_event: Arc<Event> = Event::new().into(); let flow_control_event: Arc<Event> = Event::new().into();
let is_closed: Arc<AtomicBool> = AtomicBool::new(false).into(); let is_closed: Arc<AtomicBool> = AtomicBool::new(false).into();
self.stream_map.lock().await.insert( self.stream_map.insert(
packet.stream_id, packet.stream_id,
MuxMapValue { MuxMapValue {
stream: ch_tx, stream: ch_tx,
@ -232,7 +250,7 @@ impl<W: ws::WebSocketWrite> ServerMuxInner<W> {
Role::Server, Role::Server,
stream_type, stream_type,
ch_rx, ch_rx,
self.close_tx.clone(), close_tx.clone(),
is_closed, is_closed,
flow_control, flow_control,
flow_control_event, flow_control_event,
@ -242,7 +260,7 @@ impl<W: ws::WebSocketWrite> ServerMuxInner<W> {
.map_err(|x| WispError::Other(Box::new(x)))?; .map_err(|x| WispError::Other(Box::new(x)))?;
} }
Data(data) => { Data(data) => {
if let Some(stream) = self.stream_map.lock().await.get(&packet.stream_id) { if let Some(stream) = self.stream_map.get(&packet.stream_id) {
let _ = stream.stream.unbounded_send(data); let _ = stream.stream.unbounded_send(data);
if stream.stream_type == StreamType::Tcp { if stream.stream_type == StreamType::Tcp {
stream.flow_control.store( stream.flow_control.store(
@ -257,9 +275,47 @@ impl<W: ws::WebSocketWrite> ServerMuxInner<W> {
} }
Continue(_) => break Err(WispError::InvalidPacketType), Continue(_) => break Err(WispError::InvalidPacketType),
Close(_) => { Close(_) => {
if let Some(mut stream) = if let Some((_, mut stream)) = self.stream_map.remove(&packet.stream_id) {
self.stream_map.lock().await.remove(&packet.stream_id) stream.is_closed.store(true, Ordering::Release);
{ stream.stream.disconnect();
stream.stream.close_channel();
}
}
}
}
}
async fn client_loop<R>(&self, mut rx: R) -> Result<(), WispError>
where
R: ws::WebSocketRead,
{
loop {
let frame = rx.wisp_read_frame(&self.tx).await?;
if frame.opcode == ws::OpCode::Close {
break Ok(());
}
let packet = Packet::try_from(frame)?;
use PacketType::*;
match packet.packet_type {
Connect(_) => break Err(WispError::InvalidPacketType),
Data(data) => {
if let Some(stream) = self.stream_map.get(&packet.stream_id) {
let _ = stream.stream.unbounded_send(data);
}
}
Continue(inner_packet) => {
if let Some(stream) = self.stream_map.get(&packet.stream_id) {
if stream.stream_type == StreamType::Tcp {
stream
.flow_control
.store(inner_packet.buffer_remaining, Ordering::Release);
let _ = stream.flow_control_event.notify(u32::MAX);
}
}
}
Close(_) => {
if let Some((_, mut stream)) = self.stream_map.remove(&packet.stream_id) {
stream.is_closed.store(true, Ordering::Release); stream.is_closed.store(true, Ordering::Release);
stream.stream.disconnect(); stream.stream.disconnect();
stream.stream.close_channel(); stream.stream.close_channel();
@ -290,8 +346,7 @@ impl<W: ws::WebSocketWrite> ServerMuxInner<W> {
/// } /// }
/// ``` /// ```
pub struct ServerMux { pub struct ServerMux {
stream_map: Arc<Mutex<HashMap<u32, MuxMapValue>>>, close_tx: mpsc::Sender<WsEvent>,
close_tx: mpsc::UnboundedSender<WsEvent>,
muxstream_recv: mpsc::UnboundedReceiver<(ConnectPacket, MuxStream)>, muxstream_recv: mpsc::UnboundedReceiver<(ConnectPacket, MuxStream)>,
} }
@ -305,22 +360,19 @@ impl ServerMux {
where where
R: ws::WebSocketRead, R: ws::WebSocketRead,
{ {
let (close_tx, close_rx) = mpsc::unbounded::<WsEvent>(); let (close_tx, close_rx) = mpsc::channel::<WsEvent>(256);
let (tx, rx) = mpsc::unbounded::<(ConnectPacket, MuxStream)>(); let (tx, rx) = mpsc::unbounded::<(ConnectPacket, MuxStream)>();
let write = ws::LockedWebSocketWrite::new(write); let write = ws::LockedWebSocketWrite::new(write);
let map = Arc::new(Mutex::new(HashMap::new()));
( (
Self { Self {
muxstream_recv: rx, muxstream_recv: rx,
close_tx: close_tx.clone(), close_tx: close_tx.clone(),
stream_map: map.clone(),
}, },
ServerMuxInner { MuxInner {
tx: write, tx: write,
close_tx, stream_map: DashMap::new().into(),
stream_map: map.clone(),
} }
.into_future(read, close_rx, tx, buffer_size), .server_into_future(read, close_rx, tx, buffer_size, close_tx),
) )
} }
@ -333,124 +385,13 @@ impl ServerMux {
/// ///
/// Also terminates the multiplexor future. Waiting for a new stream will never succeed after /// Also terminates the multiplexor future. Waiting for a new stream will never succeed after
/// this function is called. /// this function is called.
pub async fn close(&self) { pub async fn close(&mut self) -> Result<(), WispError> {
self.stream_map.lock().await.drain().for_each(|mut x| { self.close_tx
x.1.is_closed.store(true, Ordering::Release); .send(WsEvent::EndFut)
x.1.stream.disconnect(); .await
x.1.stream.close_channel(); .map_err(|_| WispError::MuxMessageFailedToSend)
});
let _ = self.close_tx.unbounded_send(WsEvent::EndFut);
} }
} }
struct ClientMuxInner<W>
where
W: ws::WebSocketWrite,
{
tx: ws::LockedWebSocketWrite<W>,
stream_map: Arc<Mutex<HashMap<u32, MuxMapValue>>>,
}
impl<W: ws::WebSocketWrite> ClientMuxInner<W> {
pub(crate) async fn into_future<R>(
self,
rx: R,
close_rx: mpsc::UnboundedReceiver<WsEvent>,
) -> Result<(), WispError>
where
R: ws::WebSocketRead,
{
let ret = futures::select! {
x = self.client_bg_loop(close_rx).fuse() => x,
x = self.client_loop(rx).fuse() => x
};
self.stream_map.lock().await.drain().for_each(|mut x| {
x.1.is_closed.store(true, Ordering::Release);
x.1.stream.disconnect();
x.1.stream.close_channel();
});
ret
}
async fn client_bg_loop(
&self,
mut close_rx: mpsc::UnboundedReceiver<WsEvent>,
) -> Result<(), WispError> {
while let Some(msg) = close_rx.next().await {
match msg {
WsEvent::SendPacket(packet, channel) => {
if self
.stream_map
.lock()
.await
.get(&packet.stream_id)
.is_some()
{
let _ = channel.send(self.tx.write_frame(packet.into()).await);
} else {
let _ = channel.send(Err(WispError::InvalidStreamId));
}
}
WsEvent::Close(packet, channel) => {
if let Some(mut stream) =
self.stream_map.lock().await.remove(&packet.stream_id)
{
stream.stream.disconnect();
stream.stream.close_channel();
let _ = channel.send(self.tx.write_frame(packet.into()).await);
} else {
let _ = channel.send(Err(WispError::InvalidStreamId));
}
}
WsEvent::EndFut => break,
}
}
Ok(())
}
async fn client_loop<R>(&self, mut rx: R) -> Result<(), WispError>
where
R: ws::WebSocketRead,
{
loop {
let frame = rx.wisp_read_frame(&self.tx).await?;
if frame.opcode == ws::OpCode::Close {
break Ok(());
}
let packet = Packet::try_from(frame)?;
use PacketType::*;
match packet.packet_type {
Connect(_) => break Err(WispError::InvalidPacketType),
Data(data) => {
if let Some(stream) = self.stream_map.lock().await.get(&packet.stream_id) {
let _ = stream.stream.unbounded_send(data);
}
}
Continue(inner_packet) => {
if let Some(stream) = self.stream_map.lock().await.get(&packet.stream_id) {
if stream.stream_type == StreamType::Tcp {
stream
.flow_control
.store(inner_packet.buffer_remaining, Ordering::Release);
let _ = stream.flow_control_event.notify(u32::MAX);
}
}
}
Close(_) => {
if let Some(mut stream) =
self.stream_map.lock().await.remove(&packet.stream_id)
{
stream.is_closed.store(true, Ordering::Release);
stream.stream.disconnect();
stream.stream.close_channel();
}
}
}
}
}
}
/// Client side multiplexor. /// Client side multiplexor.
/// ///
/// # Example /// # Example
@ -470,9 +411,9 @@ where
W: ws::WebSocketWrite, W: ws::WebSocketWrite,
{ {
tx: ws::LockedWebSocketWrite<W>, tx: ws::LockedWebSocketWrite<W>,
stream_map: Arc<Mutex<HashMap<u32, MuxMapValue>>>, stream_map: Arc<DashMap<u32, MuxMapValue>>,
next_free_stream_id: AtomicU32, next_free_stream_id: AtomicU32,
close_tx: mpsc::UnboundedSender<WsEvent>, close_tx: mpsc::Sender<WsEvent>,
buf_size: u32, buf_size: u32,
target_buf_size: u32, target_buf_size: u32,
} }
@ -492,23 +433,23 @@ impl<W: ws::WebSocketWrite> ClientMux<W> {
return Err(WispError::InvalidStreamId); return Err(WispError::InvalidStreamId);
} }
if let PacketType::Continue(packet) = first_packet.packet_type { if let PacketType::Continue(packet) = first_packet.packet_type {
let (tx, rx) = mpsc::unbounded::<WsEvent>(); let (tx, rx) = mpsc::channel::<WsEvent>(256);
let map = Arc::new(Mutex::new(HashMap::new())); let map = Arc::new(DashMap::new());
Ok(( Ok((
Self { Self {
tx: write.clone(), tx: write.clone(),
stream_map: map.clone(), stream_map: map.clone(),
next_free_stream_id: AtomicU32::new(1), next_free_stream_id: AtomicU32::new(1),
close_tx: tx, close_tx: tx.clone(),
buf_size: packet.buffer_remaining, buf_size: packet.buffer_remaining,
// server-only // server-only
target_buf_size: 0, target_buf_size: 0,
}, },
ClientMuxInner { MuxInner {
tx: write.clone(), tx: write.clone(),
stream_map: map.clone(), stream_map: map.clone(),
} }
.into_future(read, rx), .client_into_future(read, rx),
)) ))
} else { } else {
Err(WispError::InvalidPacketType) Err(WispError::InvalidPacketType)
@ -540,7 +481,7 @@ impl<W: ws::WebSocketWrite> ClientMux<W> {
self.next_free_stream_id self.next_free_stream_id
.store(next_stream_id, Ordering::Release); .store(next_stream_id, Ordering::Release);
self.stream_map.lock().await.insert( self.stream_map.insert(
stream_id, stream_id,
MuxMapValue { MuxMapValue {
stream: ch_tx, stream: ch_tx,
@ -568,12 +509,10 @@ impl<W: ws::WebSocketWrite> ClientMux<W> {
/// ///
/// Also terminates the multiplexor future. Creating a stream is UB after calling this /// Also terminates the multiplexor future. Creating a stream is UB after calling this
/// function. /// function.
pub async fn close(&self) { pub async fn close(&mut self) -> Result<(), WispError> {
self.stream_map.lock().await.drain().for_each(|mut x| { self.close_tx
x.1.is_closed.store(true, Ordering::Release); .send(WsEvent::EndFut)
x.1.stream.disconnect(); .await
x.1.stream.close_channel(); .map_err(|_| WispError::MuxMessageFailedToSend)
});
let _ = self.close_tx.unbounded_send(WsEvent::EndFut);
} }
} }

View file

@ -1,5 +1,5 @@
use crate::{ws, WispError}; use crate::{ws, WispError};
use bytes::{Buf, BufMut, Bytes}; use bytes::{Buf, BufMut, Bytes, BytesMut};
/// Wisp stream type. /// Wisp stream type.
#[derive(Debug, PartialEq, Copy, Clone)] #[derive(Debug, PartialEq, Copy, Clone)]
@ -115,13 +115,13 @@ impl TryFrom<Bytes> for ConnectPacket {
} }
} }
impl From<ConnectPacket> for Vec<u8> { impl From<ConnectPacket> for Bytes {
fn from(packet: ConnectPacket) -> Self { fn from(packet: ConnectPacket) -> Self {
let mut encoded = Self::with_capacity(1 + 2 + packet.destination_hostname.len()); let mut encoded = BytesMut::with_capacity(1 + 2 + packet.destination_hostname.len());
encoded.put_u8(packet.stream_type as u8); encoded.put_u8(packet.stream_type as u8);
encoded.put_u16_le(packet.destination_port); encoded.put_u16_le(packet.destination_port);
encoded.extend(packet.destination_hostname.bytes()); encoded.extend(packet.destination_hostname.bytes());
encoded encoded.freeze()
} }
} }
@ -153,11 +153,11 @@ impl TryFrom<Bytes> for ContinuePacket {
} }
} }
impl From<ContinuePacket> for Vec<u8> { impl From<ContinuePacket> for Bytes {
fn from(packet: ContinuePacket) -> Self { fn from(packet: ContinuePacket) -> Self {
let mut encoded = Self::with_capacity(4); let mut encoded = BytesMut::with_capacity(4);
encoded.put_u32_le(packet.buffer_remaining); encoded.put_u32_le(packet.buffer_remaining);
encoded encoded.freeze()
} }
} }
@ -190,11 +190,11 @@ impl TryFrom<Bytes> for ClosePacket {
} }
} }
impl From<ClosePacket> for Vec<u8> { impl From<ClosePacket> for Bytes {
fn from(packet: ClosePacket) -> Self { fn from(packet: ClosePacket) -> Self {
let mut encoded = Self::with_capacity(1); let mut encoded = BytesMut::with_capacity(1);
encoded.put_u8(packet.reason as u8); encoded.put_u8(packet.reason as u8);
encoded encoded.freeze()
} }
} }
@ -224,12 +224,12 @@ impl PacketType {
} }
} }
impl From<PacketType> for Vec<u8> { impl From<PacketType> for Bytes {
fn from(packet: PacketType) -> Self { fn from(packet: PacketType) -> Self {
use PacketType::*; use PacketType::*;
match packet { match packet {
Connect(x) => x.into(), Connect(x) => x.into(),
Data(x) => x.to_vec(), Data(x) => x,
Continue(x) => x.into(), Continue(x) => x.into(),
Close(x) => x.into(), Close(x) => x.into(),
} }
@ -250,7 +250,10 @@ impl Packet {
/// ///
/// The helper functions should be used for most use cases. /// The helper functions should be used for most use cases.
pub fn new(stream_id: u32, packet: PacketType) -> Self { pub fn new(stream_id: u32, packet: PacketType) -> Self {
Self { stream_id, packet_type: packet } Self {
stream_id,
packet_type: packet,
}
} }
/// Create a new connect packet. /// Create a new connect packet.
@ -316,13 +319,15 @@ impl TryFrom<Bytes> for Packet {
} }
} }
impl From<Packet> for Vec<u8> { impl From<Packet> for Bytes {
fn from(packet: Packet) -> Self { fn from(packet: Packet) -> Self {
let mut encoded = Self::with_capacity(1 + 4); let inner_u8 = packet.packet_type.as_u8();
encoded.push(packet.packet_type.as_u8()); let inner = Bytes::from(packet.packet_type);
let mut encoded = BytesMut::with_capacity(1 + 4 + inner.len());
encoded.put_u8(inner_u8);
encoded.put_u32_le(packet.stream_id); encoded.put_u32_le(packet.stream_id);
encoded.extend(Vec::<u8>::from(packet.packet_type)); encoded.extend(inner);
encoded encoded.freeze()
} }
} }
@ -341,6 +346,6 @@ impl TryFrom<ws::Frame> for Packet {
impl From<Packet> for ws::Frame { impl From<Packet> for ws::Frame {
fn from(packet: Packet) -> Self { fn from(packet: Packet) -> Self {
Self::binary(Vec::<u8>::from(packet).into()) Self::binary(packet.into())
} }
} }

View file

@ -45,28 +45,42 @@ pin_project! {
/// Sink for the [`unfold`] function. /// Sink for the [`unfold`] function.
#[derive(Debug)] #[derive(Debug)]
#[must_use = "sinks do nothing unless polled"] #[must_use = "sinks do nothing unless polled"]
pub struct Unfold<T, F, FC, R> { pub struct Unfold<T, F, R, CT, CF, CR> {
function: F, function: F,
close_function: FC, close_function: CF,
#[pin] #[pin]
state: UnfoldState<T, R>, state: UnfoldState<T, R>,
#[pin]
close_state: UnfoldState<CT, CR>
} }
} }
pub(crate) fn unfold<T, F, FC, R, Item, E>(init: T, function: F, close_function: FC) -> Unfold<T, F, FC, R> pub(crate) fn unfold<T, F, R, CT, CF, CR, Item, E>(
init: T,
function: F,
close_init: CT,
close_function: CF,
) -> Unfold<T, F, R, CT, CF, CR>
where where
F: FnMut(T, Item) -> R, F: FnMut(T, Item) -> R,
R: Future<Output = Result<T, E>>, R: Future<Output = Result<T, E>>,
FC: Fn() -> Result<(), E>, CF: FnMut(CT) -> CR,
CR: Future<Output = Result<CT, E>>,
{ {
Unfold { function, close_function, state: UnfoldState::Value { value: init } } Unfold {
function,
close_function,
state: UnfoldState::Value { value: init },
close_state: UnfoldState::Value { value: close_init },
}
} }
impl<T, F, FC, R, Item, E> Sink<Item> for Unfold<T, F, FC, R> impl<T, F, R, CT, CF, CR, Item, E> Sink<Item> for Unfold<T, F, R, CT, CF, CR>
where where
F: FnMut(T, Item) -> R, F: FnMut(T, Item) -> R,
R: Future<Output = Result<T, E>>, R: Future<Output = Result<T, E>>,
FC: Fn() -> Result<(), E>, CF: FnMut(CT) -> CR,
CR: Future<Output = Result<CT, E>>,
{ {
type Error = E; type Error = E;
@ -104,6 +118,27 @@ where
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
ready!(self.as_mut().poll_flush(cx))?; ready!(self.as_mut().poll_flush(cx))?;
Poll::Ready((self.close_function)()) let mut this = self.project();
Poll::Ready(
if let Some(future) = this.close_state.as_mut().project_future() {
match ready!(future.poll(cx)) {
Ok(state) => {
this.close_state.set(UnfoldState::Value { value: state });
Ok(())
}
Err(err) => {
this.close_state.set(UnfoldState::Empty);
Err(err)
}
}
} else {
let future = match this.close_state.as_mut().take_value() {
Some(value) => (this.close_function)(value),
None => panic!("start_send called without poll_ready being called first"),
};
this.close_state.set(UnfoldState::Future { future });
return Poll::Pending;
},
)
} }
} }

View file

@ -7,7 +7,7 @@ use futures::{
channel::{mpsc, oneshot}, channel::{mpsc, oneshot},
stream, stream,
task::{Context, Poll}, task::{Context, Poll},
Sink, Stream, StreamExt, Sink, SinkExt, Stream, StreamExt,
}; };
use pin_project_lite::pin_project; use pin_project_lite::pin_project;
use std::{ use std::{
@ -31,7 +31,7 @@ pub struct MuxStreamRead {
/// Type of the stream. /// Type of the stream.
pub stream_type: StreamType, pub stream_type: StreamType,
role: Role, role: Role,
tx: mpsc::UnboundedSender<WsEvent>, tx: mpsc::Sender<WsEvent>,
rx: mpsc::UnboundedReceiver<Bytes>, rx: mpsc::UnboundedReceiver<Bytes>,
is_closed: Arc<AtomicBool>, is_closed: Arc<AtomicBool>,
flow_control: Arc<AtomicU32>, flow_control: Arc<AtomicU32>,
@ -51,13 +51,14 @@ impl MuxStreamRead {
if val > self.target_flow_control { if val > self.target_flow_control {
let (tx, rx) = oneshot::channel::<Result<(), WispError>>(); let (tx, rx) = oneshot::channel::<Result<(), WispError>>();
self.tx self.tx
.unbounded_send(WsEvent::SendPacket( .send(WsEvent::SendPacket(
Packet::new_continue( Packet::new_continue(
self.stream_id, self.stream_id,
self.flow_control.fetch_add(val, Ordering::AcqRel) + val, self.flow_control.fetch_add(val, Ordering::AcqRel) + val,
), ),
tx, tx,
)) ))
.await
.ok()?; .ok()?;
rx.await.ok()?.ok()?; rx.await.ok()?.ok()?;
self.flow_control_read.store(0, Ordering::Release); self.flow_control_read.store(0, Ordering::Release);
@ -80,7 +81,7 @@ pub struct MuxStreamWrite {
/// Type of the stream. /// Type of the stream.
pub stream_type: StreamType, pub stream_type: StreamType,
role: Role, role: Role,
tx: mpsc::UnboundedSender<WsEvent>, tx: mpsc::Sender<WsEvent>,
is_closed: Arc<AtomicBool>, is_closed: Arc<AtomicBool>,
continue_recieved: Arc<Event>, continue_recieved: Arc<Event>,
flow_control: Arc<AtomicU32>, flow_control: Arc<AtomicU32>,
@ -88,7 +89,7 @@ pub struct MuxStreamWrite {
impl MuxStreamWrite { impl MuxStreamWrite {
/// Write data to the stream. /// Write data to the stream.
pub async fn write(&self, data: Bytes) -> Result<(), WispError> { pub async fn write(&mut self, data: Bytes) -> Result<(), WispError> {
if self.is_closed.load(Ordering::Acquire) { if self.is_closed.load(Ordering::Acquire) {
return Err(WispError::StreamAlreadyClosed); return Err(WispError::StreamAlreadyClosed);
} }
@ -100,10 +101,11 @@ impl MuxStreamWrite {
} }
let (tx, rx) = oneshot::channel::<Result<(), WispError>>(); let (tx, rx) = oneshot::channel::<Result<(), WispError>>();
self.tx self.tx
.unbounded_send(WsEvent::SendPacket( .send(WsEvent::SendPacket(
Packet::new_data(self.stream_id, data), Packet::new_data(self.stream_id, data),
tx, tx,
)) ))
.await
.map_err(|x| WispError::Other(Box::new(x)))?; .map_err(|x| WispError::Other(Box::new(x)))?;
rx.await.map_err(|x| WispError::Other(Box::new(x)))??; rx.await.map_err(|x| WispError::Other(Box::new(x)))??;
if self.role == Role::Client && self.stream_type == StreamType::Tcp { if self.role == Role::Client && self.stream_type == StreamType::Tcp {
@ -135,7 +137,7 @@ impl MuxStreamWrite {
} }
/// Close the stream. You will no longer be able to write or read after this has been called. /// Close the stream. You will no longer be able to write or read after this has been called.
pub async fn close(&self, reason: CloseReason) -> Result<(), WispError> { pub async fn close(&mut self, reason: CloseReason) -> Result<(), WispError> {
if self.is_closed.load(Ordering::Acquire) { if self.is_closed.load(Ordering::Acquire) {
return Err(WispError::StreamAlreadyClosed); return Err(WispError::StreamAlreadyClosed);
} }
@ -143,10 +145,11 @@ impl MuxStreamWrite {
let (tx, rx) = oneshot::channel::<Result<(), WispError>>(); let (tx, rx) = oneshot::channel::<Result<(), WispError>>();
self.tx self.tx
.unbounded_send(WsEvent::Close( .send(WsEvent::Close(
Packet::new_close(self.stream_id, reason), Packet::new_close(self.stream_id, reason),
tx, tx,
)) ))
.await
.map_err(|x| WispError::Other(Box::new(x)))?; .map_err(|x| WispError::Other(Box::new(x)))?;
rx.await.map_err(|x| WispError::Other(Box::new(x)))??; rx.await.map_err(|x| WispError::Other(Box::new(x)))??;
@ -157,25 +160,19 @@ impl MuxStreamWrite {
let handle = self.get_close_handle(); let handle = self.get_close_handle();
Box::pin(sink_unfold::unfold( Box::pin(sink_unfold::unfold(
self, self,
|tx, data| async move { |mut tx, data| async move {
tx.write(data).await?; tx.write(data).await?;
Ok(tx) Ok(tx)
}, },
move || handle.close_sync(CloseReason::Unknown), handle,
move |mut handle| async {
handle.close(CloseReason::Unknown).await?;
Ok(handle)
},
)) ))
} }
} }
impl Drop for MuxStreamWrite {
fn drop(&mut self) {
let (tx, _) = oneshot::channel::<Result<(), WispError>>();
let _ = self.tx.unbounded_send(WsEvent::Close(
Packet::new_close(self.stream_id, CloseReason::Unknown),
tx,
));
}
}
/// Multiplexor stream. /// Multiplexor stream.
pub struct MuxStream { pub struct MuxStream {
/// ID of the stream. /// ID of the stream.
@ -191,7 +188,7 @@ impl MuxStream {
role: Role, role: Role,
stream_type: StreamType, stream_type: StreamType,
rx: mpsc::UnboundedReceiver<Bytes>, rx: mpsc::UnboundedReceiver<Bytes>,
tx: mpsc::UnboundedSender<WsEvent>, tx: mpsc::Sender<WsEvent>,
is_closed: Arc<AtomicBool>, is_closed: Arc<AtomicBool>,
flow_control: Arc<AtomicU32>, flow_control: Arc<AtomicU32>,
continue_recieved: Arc<Event>, continue_recieved: Arc<Event>,
@ -228,7 +225,7 @@ impl MuxStream {
} }
/// Write data to the stream. /// Write data to the stream.
pub async fn write(&self, data: Bytes) -> Result<(), WispError> { pub async fn write(&mut self, data: Bytes) -> Result<(), WispError> {
self.tx.write(data).await self.tx.write(data).await
} }
@ -248,7 +245,7 @@ impl MuxStream {
} }
/// Close the stream. You will no longer be able to write or read after this has been called. /// Close the stream. You will no longer be able to write or read after this has been called.
pub async fn close(&self, reason: CloseReason) -> Result<(), WispError> { pub async fn close(&mut self, reason: CloseReason) -> Result<(), WispError> {
self.tx.close(reason).await self.tx.close(reason).await
} }
@ -271,13 +268,13 @@ impl MuxStream {
pub struct MuxStreamCloser { pub struct MuxStreamCloser {
/// ID of the stream. /// ID of the stream.
pub stream_id: u32, pub stream_id: u32,
close_channel: mpsc::UnboundedSender<WsEvent>, close_channel: mpsc::Sender<WsEvent>,
is_closed: Arc<AtomicBool>, is_closed: Arc<AtomicBool>,
} }
impl MuxStreamCloser { impl MuxStreamCloser {
/// Close the stream. You will no longer be able to write or read after this has been called. /// Close the stream. You will no longer be able to write or read after this has been called.
pub async fn close(&self, reason: CloseReason) -> Result<(), WispError> { pub async fn close(&mut self, reason: CloseReason) -> Result<(), WispError> {
if self.is_closed.load(Ordering::Acquire) { if self.is_closed.load(Ordering::Acquire) {
return Err(WispError::StreamAlreadyClosed); return Err(WispError::StreamAlreadyClosed);
} }
@ -285,32 +282,16 @@ impl MuxStreamCloser {
let (tx, rx) = oneshot::channel::<Result<(), WispError>>(); let (tx, rx) = oneshot::channel::<Result<(), WispError>>();
self.close_channel self.close_channel
.unbounded_send(WsEvent::Close( .send(WsEvent::Close(
Packet::new_close(self.stream_id, reason), Packet::new_close(self.stream_id, reason),
tx, tx,
)) ))
.await
.map_err(|x| WispError::Other(Box::new(x)))?; .map_err(|x| WispError::Other(Box::new(x)))?;
rx.await.map_err(|x| WispError::Other(Box::new(x)))??; rx.await.map_err(|x| WispError::Other(Box::new(x)))??;
Ok(()) Ok(())
} }
pub(crate) fn close_sync(&self, reason: CloseReason) -> Result<(), WispError> {
if self.is_closed.load(Ordering::Acquire) {
return Err(WispError::StreamAlreadyClosed);
}
self.is_closed.store(true, Ordering::Release);
let (tx, _) = oneshot::channel::<Result<(), WispError>>();
self.close_channel
.unbounded_send(WsEvent::Close(
Packet::new_close(self.stream_id, reason),
tx,
))
.map_err(|x| WispError::Other(Box::new(x)))?;
Ok(())
}
} }
pin_project! { pin_project! {
@ -336,10 +317,7 @@ impl MuxStreamIo {
impl Stream for MuxStreamIo { impl Stream for MuxStreamIo {
type Item = Result<Vec<u8>, std::io::Error>; type Item = Result<Vec<u8>, std::io::Error>;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
self.project() self.project().rx.poll_next(cx).map(|x| x.map(|x| Ok(x.to_vec())))
.rx
.poll_next(cx)
.map(|x| x.map(|x| Ok(x.to_vec())))
} }
} }