mirror of
https://github.com/MercuryWorkshop/epoxy-tls.git
synced 2025-05-13 06:20:02 -04:00
remove the mutex<hashmap> in wisp_mux, other improvements
This commit is contained in:
parent
ff2a1ad269
commit
7001ee8fa5
16 changed files with 346 additions and 309 deletions
22
Cargo.lock
generated
22
Cargo.lock
generated
|
@ -153,6 +153,12 @@ dependencies = [
|
|||
"tokio",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "atomic-counter"
|
||||
version = "1.0.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "62f447d68cfa5a9ab0c1c862a703da2a65b5ed1b7ce1153c9eb0169506d56019"
|
||||
|
||||
[[package]]
|
||||
name = "autocfg"
|
||||
version = "1.1.0"
|
||||
|
@ -516,7 +522,7 @@ checksum = "11157ac094ffbdde99aa67b23417ebdd801842852b500e395a45a9c0aac03e4a"
|
|||
|
||||
[[package]]
|
||||
name = "epoxy-client"
|
||||
version = "1.4.2"
|
||||
version = "1.5.0"
|
||||
dependencies = [
|
||||
"async-compression",
|
||||
"async_io_stream",
|
||||
|
@ -1727,18 +1733,29 @@ checksum = "f27f6278552951f1f2b8cf9da965d10969b2efdea95a6ec47987ab46edfe263a"
|
|||
name = "simple-wisp-client"
|
||||
version = "1.0.0"
|
||||
dependencies = [
|
||||
"atomic-counter",
|
||||
"bytes",
|
||||
"console-subscriber",
|
||||
"fastwebsockets 0.7.1",
|
||||
"futures",
|
||||
"http-body-util",
|
||||
"hyper 1.2.0",
|
||||
"simple_moving_average",
|
||||
"tokio",
|
||||
"tokio-native-tls",
|
||||
"tokio-util",
|
||||
"wisp-mux",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "simple_moving_average"
|
||||
version = "1.0.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3a4b144ad185430cd033299e2c93e465d5a7e65fbb858593dc57181fa13cd310"
|
||||
dependencies = [
|
||||
"num-traits",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "slab"
|
||||
version = "0.4.9"
|
||||
|
@ -2500,10 +2517,11 @@ checksum = "32b752e52a2da0ddfbdbcc6fceadfeede4c939ed16d13e648833a61dfb611ed8"
|
|||
|
||||
[[package]]
|
||||
name = "wisp-mux"
|
||||
version = "2.0.2"
|
||||
version = "3.0.0"
|
||||
dependencies = [
|
||||
"async_io_stream",
|
||||
"bytes",
|
||||
"dashmap",
|
||||
"event-listener",
|
||||
"fastwebsockets 0.7.1",
|
||||
"futures",
|
||||
|
|
|
@ -8,8 +8,10 @@ rustls-pki-types = { git = "https://github.com/r58Playz/rustls-pki-types" }
|
|||
|
||||
[profile.release]
|
||||
lto = true
|
||||
opt-level = 'z'
|
||||
strip = true
|
||||
debug = true
|
||||
panic = "abort"
|
||||
codegen-units = 1
|
||||
opt-level = 3
|
||||
|
||||
[profile.release.package.epoxy-client]
|
||||
opt-level = 'z'
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
[package]
|
||||
name = "epoxy-client"
|
||||
version = "1.4.2"
|
||||
version = "1.5.0"
|
||||
edition = "2021"
|
||||
license = "LGPL-3.0-only"
|
||||
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
{
|
||||
"name": "@mercuryworkshop/epoxy-tls",
|
||||
"version": "1.4.2",
|
||||
"version": "1.5.0",
|
||||
"description": "A wasm library for using raw encrypted tls/ssl/https/websocket streams on the browser",
|
||||
"scripts": {
|
||||
"build": "./build.sh"
|
||||
|
|
|
@ -47,8 +47,8 @@ enum EpxCompression {
|
|||
Gzip,
|
||||
}
|
||||
|
||||
type EpxIoTlsStream = TlsStream<IoStream<MuxStreamIo, Vec<u8>>>;
|
||||
type EpxIoUnencryptedStream = IoStream<MuxStreamIo, Vec<u8>>;
|
||||
type EpxIoTlsStream = TlsStream<EpxIoUnencryptedStream>;
|
||||
type EpxIoStream = Either<EpxIoTlsStream, EpxIoUnencryptedStream>;
|
||||
|
||||
#[wasm_bindgen(start)]
|
||||
|
|
|
@ -231,7 +231,7 @@ pub async fn replace_mux(
|
|||
) -> Result<(), WispError> {
|
||||
let (mux_replace, fut) = make_mux(url).await?;
|
||||
let mut mux_write = mux.write().await;
|
||||
mux_write.close().await;
|
||||
mux_write.close().await?;
|
||||
*mux_write = mux_replace;
|
||||
drop(mux_write);
|
||||
spawn_mux_fut(mux, fut, url.into());
|
||||
|
|
|
@ -56,7 +56,7 @@ impl Stream for IncomingBody {
|
|||
pub struct ServiceWrapper(pub Arc<RwLock<ClientMux<WebSocketWrapper>>>, pub String);
|
||||
|
||||
impl tower_service::Service<hyper::Uri> for ServiceWrapper {
|
||||
type Response = TokioIo<IoStream<MuxStreamIo, Vec<u8>>>;
|
||||
type Response = TokioIo<EpxIoUnencryptedStream>;
|
||||
type Error = WispError;
|
||||
type Future = impl Future<Output = Result<Self::Response, Self::Error>>;
|
||||
|
||||
|
|
|
@ -239,7 +239,7 @@ async fn accept_ws(
|
|||
|
||||
println!("{:?}: connected", addr);
|
||||
|
||||
let (mut mux, fut) = ServerMux::new(rx, tx, 128);
|
||||
let (mut mux, fut) = ServerMux::new(rx, tx, u32::MAX);
|
||||
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = fut.await {
|
||||
|
@ -247,7 +247,7 @@ async fn accept_ws(
|
|||
}
|
||||
});
|
||||
|
||||
while let Some((packet, stream)) = mux.server_new_stream().await {
|
||||
while let Some((packet, mut stream)) = mux.server_new_stream().await {
|
||||
tokio::spawn(async move {
|
||||
if block_local {
|
||||
match lookup_host(format!(
|
||||
|
@ -272,8 +272,8 @@ async fn accept_ws(
|
|||
}
|
||||
}
|
||||
}
|
||||
let close_err = stream.get_close_handle();
|
||||
let close_ok = stream.get_close_handle();
|
||||
let mut close_err = stream.get_close_handle();
|
||||
let mut close_ok = stream.get_close_handle();
|
||||
let _ = handle_mux(packet, stream)
|
||||
.or_else(|err| async move {
|
||||
let _ = close_err.close(CloseReason::Unexpected).await;
|
||||
|
|
|
@ -4,12 +4,14 @@ version = "1.0.0"
|
|||
edition = "2021"
|
||||
|
||||
[dependencies]
|
||||
atomic-counter = "1.0.1"
|
||||
bytes = "1.5.0"
|
||||
console-subscriber = { version = "0.2.0", optional = true }
|
||||
fastwebsockets = { version = "0.7.1", features = ["unstable-split", "upgrade"] }
|
||||
futures = "0.3.30"
|
||||
http-body-util = "0.1.0"
|
||||
hyper = { version = "1.1.0", features = ["http1", "client"] }
|
||||
simple_moving_average = "1.0.2"
|
||||
tokio = { version = "1.36.0", features = ["full"] }
|
||||
tokio-native-tls = "0.3.1"
|
||||
tokio-util = "0.7.10"
|
||||
|
|
|
@ -1,16 +1,25 @@
|
|||
use atomic_counter::{AtomicCounter, RelaxedCounter};
|
||||
use bytes::Bytes;
|
||||
use fastwebsockets::{handshake, FragmentCollectorRead};
|
||||
use futures::io::AsyncWriteExt;
|
||||
use futures::future::select_all;
|
||||
use http_body_util::Empty;
|
||||
use hyper::{
|
||||
header::{CONNECTION, UPGRADE},
|
||||
Request,
|
||||
};
|
||||
use std::{error::Error, future::Future};
|
||||
use tokio::net::TcpStream;
|
||||
use simple_moving_average::{SingleSumSMA, SMA};
|
||||
use std::{
|
||||
error::Error,
|
||||
future::Future,
|
||||
io::{stdout, IsTerminal, Write},
|
||||
sync::Arc,
|
||||
time::Duration,
|
||||
usize,
|
||||
};
|
||||
use tokio::{net::TcpStream, time::interval};
|
||||
use tokio_native_tls::{native_tls, TlsConnector};
|
||||
use wisp_mux::{ClientMux, StreamType};
|
||||
use tokio_util::either::Either;
|
||||
use wisp_mux::{ClientMux, StreamType, WispError};
|
||||
|
||||
#[derive(Debug)]
|
||||
struct StrError(String);
|
||||
|
@ -70,6 +79,18 @@ async fn main() -> Result<(), Box<dyn Error + Send + Sync>> {
|
|||
.nth(6)
|
||||
.ok_or(StrError::new("no should tls"))?
|
||||
.parse()?;
|
||||
let thread_cnt: usize = std::env::args().nth(7).unwrap_or("10".into()).parse()?;
|
||||
|
||||
println!(
|
||||
"connecting to {}://{}:{}{} and sending &[0; 1024] to {}:{} with threads {}",
|
||||
if should_tls { "wss" } else { "ws" },
|
||||
addr,
|
||||
addr_port,
|
||||
addr_path,
|
||||
addr_dest,
|
||||
addr_dest_port,
|
||||
thread_cnt
|
||||
);
|
||||
|
||||
let socket = TcpStream::connect(format!("{}:{}", &addr, addr_port)).await?;
|
||||
let socket = if should_tls {
|
||||
|
@ -98,23 +119,59 @@ async fn main() -> Result<(), Box<dyn Error + Send + Sync>> {
|
|||
let rx = FragmentCollectorRead::new(rx);
|
||||
|
||||
let (mux, fut) = ClientMux::new(rx, tx).await?;
|
||||
let mut threads = Vec::with_capacity(thread_cnt + 1);
|
||||
|
||||
tokio::task::spawn(async move { println!("err: {:?}", fut.await); });
|
||||
threads.push(tokio::spawn(fut));
|
||||
|
||||
let mut hi: u64 = 0;
|
||||
loop {
|
||||
let payload = Bytes::from_static(&[0; 1024]);
|
||||
|
||||
let cnt = Arc::new(RelaxedCounter::new(0));
|
||||
|
||||
for _ in 0..thread_cnt {
|
||||
let mut channel = mux
|
||||
.client_new_stream(StreamType::Tcp, addr_dest.clone(), addr_dest_port)
|
||||
.await?
|
||||
.into_io()
|
||||
.into_asyncrw();
|
||||
for _ in 0..256 {
|
||||
channel.write_all(b"hiiiiiiii").await?;
|
||||
hi += 1;
|
||||
println!("said hi {}", hi);
|
||||
.await?;
|
||||
let cnt = cnt.clone();
|
||||
let payload = payload.clone();
|
||||
threads.push(tokio::spawn(async move {
|
||||
loop {
|
||||
channel.write(payload.clone()).await?;
|
||||
channel.read().await;
|
||||
cnt.inc();
|
||||
}
|
||||
#[allow(unreachable_code)]
|
||||
Ok::<(), WispError>(())
|
||||
}));
|
||||
}
|
||||
|
||||
#[allow(unreachable_code)]
|
||||
threads.push(tokio::spawn(async move {
|
||||
let mut interval = interval(Duration::from_millis(100));
|
||||
let mut avg: SingleSumSMA<usize, usize, 100> = SingleSumSMA::new();
|
||||
let mut last_time = 0;
|
||||
let is_term = stdout().is_terminal();
|
||||
loop {
|
||||
interval.tick().await;
|
||||
let now = cnt.get();
|
||||
let stat = format!(
|
||||
"sent &[0; 1024] cnt: {:?}, +{:?}, moving average (100): {:?}",
|
||||
now,
|
||||
now - last_time,
|
||||
avg.get_average()
|
||||
);
|
||||
if is_term {
|
||||
print!("\x1b[2K{}\r", stat);
|
||||
} else {
|
||||
println!("{}", stat);
|
||||
}
|
||||
stdout().flush().unwrap();
|
||||
avg.add_sample(now - last_time);
|
||||
last_time = now;
|
||||
}
|
||||
}));
|
||||
|
||||
let out = select_all(threads.into_iter()).await;
|
||||
|
||||
println!("\n\nout: {:?}", out.0);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
[package]
|
||||
name = "wisp-mux"
|
||||
version = "2.0.2"
|
||||
version = "3.0.0"
|
||||
license = "LGPL-3.0-only"
|
||||
description = "A library for easily creating Wisp servers and clients."
|
||||
homepage = "https://github.com/MercuryWorkshop/epoxy-tls/tree/multiplexed/wisp"
|
||||
|
@ -11,6 +11,7 @@ edition = "2021"
|
|||
[dependencies]
|
||||
async_io_stream = "0.3.3"
|
||||
bytes = "1.5.0"
|
||||
dashmap = { version = "5.5.3", features = ["inline"] }
|
||||
event-listener = "5.0.0"
|
||||
fastwebsockets = { version = "0.7.1", features = ["unstable-split"], optional = true }
|
||||
futures = "0.3.30"
|
||||
|
|
|
@ -8,7 +8,9 @@ impl From<OpCode> for crate::ws::OpCode {
|
|||
fn from(opcode: OpCode) -> Self {
|
||||
use OpCode::*;
|
||||
match opcode {
|
||||
Continuation => unreachable!("continuation should never be recieved when using a fragmentcollector"),
|
||||
Continuation => {
|
||||
unreachable!("continuation should never be recieved when using a fragmentcollector")
|
||||
}
|
||||
Text => Self::Text,
|
||||
Binary => Self::Binary,
|
||||
Close => Self::Close,
|
||||
|
@ -70,8 +72,6 @@ impl<S: AsyncRead + Unpin> crate::ws::WebSocketRead for FragmentCollectorRead<S>
|
|||
|
||||
impl<S: AsyncWrite + Unpin> crate::ws::WebSocketWrite for WebSocketWrite<S> {
|
||||
async fn wisp_write_frame(&mut self, frame: crate::ws::Frame) -> Result<(), crate::WispError> {
|
||||
self.write_frame(frame.into())
|
||||
.await
|
||||
.map_err(|e| e.into())
|
||||
self.write_frame(frame.into()).await.map_err(|e| e.into())
|
||||
}
|
||||
}
|
||||
|
|
333
wisp/src/lib.rs
333
wisp/src/lib.rs
|
@ -16,14 +16,13 @@ pub use crate::packet::*;
|
|||
pub use crate::stream::*;
|
||||
|
||||
use bytes::Bytes;
|
||||
use dashmap::DashMap;
|
||||
use event_listener::Event;
|
||||
use futures::{channel::mpsc, lock::Mutex, Future, FutureExt, StreamExt};
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
sync::{
|
||||
use futures::SinkExt;
|
||||
use futures::{channel::mpsc, Future, FutureExt, StreamExt};
|
||||
use std::sync::{
|
||||
atomic::{AtomicBool, AtomicU32, Ordering},
|
||||
Arc,
|
||||
},
|
||||
};
|
||||
|
||||
/// The role of the multiplexor.
|
||||
|
@ -72,6 +71,8 @@ pub enum WispError {
|
|||
Utf8Error(std::str::Utf8Error),
|
||||
/// Other error.
|
||||
Other(Box<dyn std::error::Error + Sync + Send>),
|
||||
/// Failed to send message to multiplexor task.
|
||||
MuxMessageFailedToSend,
|
||||
}
|
||||
|
||||
impl From<std::str::Utf8Error> for WispError {
|
||||
|
@ -82,25 +83,29 @@ impl From<std::str::Utf8Error> for WispError {
|
|||
|
||||
impl std::fmt::Display for WispError {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
|
||||
use WispError::*;
|
||||
match self {
|
||||
PacketTooSmall => write!(f, "Packet too small"),
|
||||
InvalidPacketType => write!(f, "Invalid packet type"),
|
||||
InvalidStreamType => write!(f, "Invalid stream type"),
|
||||
InvalidStreamId => write!(f, "Invalid stream id"),
|
||||
InvalidCloseReason => write!(f, "Invalid close reason"),
|
||||
InvalidUri => write!(f, "Invalid URI"),
|
||||
UriHasNoHost => write!(f, "URI has no host"),
|
||||
UriHasNoPort => write!(f, "URI has no port"),
|
||||
MaxStreamCountReached => write!(f, "Maximum stream count reached"),
|
||||
StreamAlreadyClosed => write!(f, "Stream already closed"),
|
||||
WsFrameInvalidType => write!(f, "Invalid websocket frame type"),
|
||||
WsFrameNotFinished => write!(f, "Unfinished websocket frame"),
|
||||
WsImplError(err) => write!(f, "Websocket implementation error: {}", err),
|
||||
WsImplSocketClosed => write!(f, "Websocket implementation error: websocket closed"),
|
||||
WsImplNotSupported => write!(f, "Websocket implementation error: unsupported feature"),
|
||||
Utf8Error(err) => write!(f, "UTF-8 error: {}", err),
|
||||
Other(err) => write!(f, "Other error: {}", err),
|
||||
Self::PacketTooSmall => write!(f, "Packet too small"),
|
||||
Self::InvalidPacketType => write!(f, "Invalid packet type"),
|
||||
Self::InvalidStreamType => write!(f, "Invalid stream type"),
|
||||
Self::InvalidStreamId => write!(f, "Invalid stream id"),
|
||||
Self::InvalidCloseReason => write!(f, "Invalid close reason"),
|
||||
Self::InvalidUri => write!(f, "Invalid URI"),
|
||||
Self::UriHasNoHost => write!(f, "URI has no host"),
|
||||
Self::UriHasNoPort => write!(f, "URI has no port"),
|
||||
Self::MaxStreamCountReached => write!(f, "Maximum stream count reached"),
|
||||
Self::StreamAlreadyClosed => write!(f, "Stream already closed"),
|
||||
Self::WsFrameInvalidType => write!(f, "Invalid websocket frame type"),
|
||||
Self::WsFrameNotFinished => write!(f, "Unfinished websocket frame"),
|
||||
Self::WsImplError(err) => write!(f, "Websocket implementation error: {}", err),
|
||||
Self::WsImplSocketClosed => {
|
||||
write!(f, "Websocket implementation error: websocket closed")
|
||||
}
|
||||
Self::WsImplNotSupported => {
|
||||
write!(f, "Websocket implementation error: unsupported feature")
|
||||
}
|
||||
Self::Utf8Error(err) => write!(f, "UTF-8 error: {}", err),
|
||||
Self::Other(err) => write!(f, "Other error: {}", err),
|
||||
Self::MuxMessageFailedToSend => write!(f, "Failed to send multiplexor message"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -115,61 +120,74 @@ struct MuxMapValue {
|
|||
is_closed: Arc<AtomicBool>,
|
||||
}
|
||||
|
||||
struct ServerMuxInner<W>
|
||||
struct MuxInner<W>
|
||||
where
|
||||
W: ws::WebSocketWrite,
|
||||
{
|
||||
tx: ws::LockedWebSocketWrite<W>,
|
||||
stream_map: Arc<Mutex<HashMap<u32, MuxMapValue>>>,
|
||||
close_tx: mpsc::UnboundedSender<WsEvent>,
|
||||
stream_map: Arc<DashMap<u32, MuxMapValue>>,
|
||||
}
|
||||
|
||||
impl<W: ws::WebSocketWrite> ServerMuxInner<W> {
|
||||
pub async fn into_future<R>(
|
||||
impl<W: ws::WebSocketWrite> MuxInner<W> {
|
||||
pub async fn server_into_future<R>(
|
||||
self,
|
||||
rx: R,
|
||||
close_rx: mpsc::UnboundedReceiver<WsEvent>,
|
||||
close_rx: mpsc::Receiver<WsEvent>,
|
||||
muxstream_sender: mpsc::UnboundedSender<(ConnectPacket, MuxStream)>,
|
||||
buffer_size: u32,
|
||||
close_tx: mpsc::Sender<WsEvent>,
|
||||
) -> Result<(), WispError>
|
||||
where
|
||||
R: ws::WebSocketRead,
|
||||
{
|
||||
self.into_future(
|
||||
close_rx,
|
||||
self.server_loop(rx, muxstream_sender, buffer_size, close_tx),
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn client_into_future<R>(
|
||||
self,
|
||||
rx: R,
|
||||
close_rx: mpsc::Receiver<WsEvent>,
|
||||
) -> Result<(), WispError>
|
||||
where
|
||||
R: ws::WebSocketRead,
|
||||
{
|
||||
self.into_future(close_rx, self.client_loop(rx)).await
|
||||
}
|
||||
|
||||
async fn into_future(
|
||||
&self,
|
||||
close_rx: mpsc::Receiver<WsEvent>,
|
||||
wisp_fut: impl Future<Output = Result<(), WispError>>,
|
||||
) -> Result<(), WispError> {
|
||||
let ret = futures::select! {
|
||||
x = self.server_bg_loop(close_rx).fuse() => x,
|
||||
x = self.server_msg_loop(rx, muxstream_sender, buffer_size).fuse() => x
|
||||
_ = self.stream_loop(close_rx).fuse() => Ok(()),
|
||||
x = wisp_fut.fuse() => x,
|
||||
};
|
||||
self.stream_map.lock().await.drain().for_each(|mut x| {
|
||||
x.1.is_closed.store(true, Ordering::Release);
|
||||
x.1.stream.disconnect();
|
||||
x.1.stream.close_channel();
|
||||
self.stream_map.iter_mut().for_each(|mut x| {
|
||||
x.is_closed.store(true, Ordering::Release);
|
||||
x.stream.disconnect();
|
||||
x.stream.close_channel();
|
||||
});
|
||||
self.stream_map.clear();
|
||||
ret
|
||||
}
|
||||
|
||||
async fn server_bg_loop(
|
||||
&self,
|
||||
mut close_rx: mpsc::UnboundedReceiver<WsEvent>,
|
||||
) -> Result<(), WispError> {
|
||||
while let Some(msg) = close_rx.next().await {
|
||||
async fn stream_loop(&self, mut stream_rx: mpsc::Receiver<WsEvent>) {
|
||||
while let Some(msg) = stream_rx.next().await {
|
||||
match msg {
|
||||
WsEvent::SendPacket(packet, channel) => {
|
||||
if self
|
||||
.stream_map
|
||||
.lock()
|
||||
.await
|
||||
.get(&packet.stream_id)
|
||||
.is_some()
|
||||
{
|
||||
if self.stream_map.get(&packet.stream_id).is_some() {
|
||||
let _ = channel.send(self.tx.write_frame(packet.into()).await);
|
||||
} else {
|
||||
let _ = channel.send(Err(WispError::InvalidStreamId));
|
||||
}
|
||||
}
|
||||
WsEvent::Close(packet, channel) => {
|
||||
if let Some(mut stream) =
|
||||
self.stream_map.lock().await.remove(&packet.stream_id)
|
||||
{
|
||||
if let Some((_, mut stream)) = self.stream_map.remove(&packet.stream_id) {
|
||||
stream.stream.disconnect();
|
||||
stream.stream.close_channel();
|
||||
let _ = channel.send(self.tx.write_frame(packet.into()).await);
|
||||
|
@ -180,20 +198,20 @@ impl<W: ws::WebSocketWrite> ServerMuxInner<W> {
|
|||
WsEvent::EndFut => break,
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn server_msg_loop<R>(
|
||||
async fn server_loop<R>(
|
||||
&self,
|
||||
mut rx: R,
|
||||
muxstream_sender: mpsc::UnboundedSender<(ConnectPacket, MuxStream)>,
|
||||
buffer_size: u32,
|
||||
close_tx: mpsc::Sender<WsEvent>,
|
||||
) -> Result<(), WispError>
|
||||
where
|
||||
R: ws::WebSocketRead,
|
||||
{
|
||||
// will send continues once flow_control is at 10% of max
|
||||
let target_buffer_size = buffer_size * 90 / 100;
|
||||
let target_buffer_size = ((buffer_size as u64 * 90) / 100) as u32;
|
||||
self.tx
|
||||
.write_frame(Packet::new_continue(0, buffer_size).into())
|
||||
.await?;
|
||||
|
@ -214,7 +232,7 @@ impl<W: ws::WebSocketWrite> ServerMuxInner<W> {
|
|||
let flow_control_event: Arc<Event> = Event::new().into();
|
||||
let is_closed: Arc<AtomicBool> = AtomicBool::new(false).into();
|
||||
|
||||
self.stream_map.lock().await.insert(
|
||||
self.stream_map.insert(
|
||||
packet.stream_id,
|
||||
MuxMapValue {
|
||||
stream: ch_tx,
|
||||
|
@ -232,7 +250,7 @@ impl<W: ws::WebSocketWrite> ServerMuxInner<W> {
|
|||
Role::Server,
|
||||
stream_type,
|
||||
ch_rx,
|
||||
self.close_tx.clone(),
|
||||
close_tx.clone(),
|
||||
is_closed,
|
||||
flow_control,
|
||||
flow_control_event,
|
||||
|
@ -242,7 +260,7 @@ impl<W: ws::WebSocketWrite> ServerMuxInner<W> {
|
|||
.map_err(|x| WispError::Other(Box::new(x)))?;
|
||||
}
|
||||
Data(data) => {
|
||||
if let Some(stream) = self.stream_map.lock().await.get(&packet.stream_id) {
|
||||
if let Some(stream) = self.stream_map.get(&packet.stream_id) {
|
||||
let _ = stream.stream.unbounded_send(data);
|
||||
if stream.stream_type == StreamType::Tcp {
|
||||
stream.flow_control.store(
|
||||
|
@ -257,9 +275,47 @@ impl<W: ws::WebSocketWrite> ServerMuxInner<W> {
|
|||
}
|
||||
Continue(_) => break Err(WispError::InvalidPacketType),
|
||||
Close(_) => {
|
||||
if let Some(mut stream) =
|
||||
self.stream_map.lock().await.remove(&packet.stream_id)
|
||||
if let Some((_, mut stream)) = self.stream_map.remove(&packet.stream_id) {
|
||||
stream.is_closed.store(true, Ordering::Release);
|
||||
stream.stream.disconnect();
|
||||
stream.stream.close_channel();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn client_loop<R>(&self, mut rx: R) -> Result<(), WispError>
|
||||
where
|
||||
R: ws::WebSocketRead,
|
||||
{
|
||||
loop {
|
||||
let frame = rx.wisp_read_frame(&self.tx).await?;
|
||||
if frame.opcode == ws::OpCode::Close {
|
||||
break Ok(());
|
||||
}
|
||||
let packet = Packet::try_from(frame)?;
|
||||
|
||||
use PacketType::*;
|
||||
match packet.packet_type {
|
||||
Connect(_) => break Err(WispError::InvalidPacketType),
|
||||
Data(data) => {
|
||||
if let Some(stream) = self.stream_map.get(&packet.stream_id) {
|
||||
let _ = stream.stream.unbounded_send(data);
|
||||
}
|
||||
}
|
||||
Continue(inner_packet) => {
|
||||
if let Some(stream) = self.stream_map.get(&packet.stream_id) {
|
||||
if stream.stream_type == StreamType::Tcp {
|
||||
stream
|
||||
.flow_control
|
||||
.store(inner_packet.buffer_remaining, Ordering::Release);
|
||||
let _ = stream.flow_control_event.notify(u32::MAX);
|
||||
}
|
||||
}
|
||||
}
|
||||
Close(_) => {
|
||||
if let Some((_, mut stream)) = self.stream_map.remove(&packet.stream_id) {
|
||||
stream.is_closed.store(true, Ordering::Release);
|
||||
stream.stream.disconnect();
|
||||
stream.stream.close_channel();
|
||||
|
@ -290,8 +346,7 @@ impl<W: ws::WebSocketWrite> ServerMuxInner<W> {
|
|||
/// }
|
||||
/// ```
|
||||
pub struct ServerMux {
|
||||
stream_map: Arc<Mutex<HashMap<u32, MuxMapValue>>>,
|
||||
close_tx: mpsc::UnboundedSender<WsEvent>,
|
||||
close_tx: mpsc::Sender<WsEvent>,
|
||||
muxstream_recv: mpsc::UnboundedReceiver<(ConnectPacket, MuxStream)>,
|
||||
}
|
||||
|
||||
|
@ -305,22 +360,19 @@ impl ServerMux {
|
|||
where
|
||||
R: ws::WebSocketRead,
|
||||
{
|
||||
let (close_tx, close_rx) = mpsc::unbounded::<WsEvent>();
|
||||
let (close_tx, close_rx) = mpsc::channel::<WsEvent>(256);
|
||||
let (tx, rx) = mpsc::unbounded::<(ConnectPacket, MuxStream)>();
|
||||
let write = ws::LockedWebSocketWrite::new(write);
|
||||
let map = Arc::new(Mutex::new(HashMap::new()));
|
||||
(
|
||||
Self {
|
||||
muxstream_recv: rx,
|
||||
close_tx: close_tx.clone(),
|
||||
stream_map: map.clone(),
|
||||
},
|
||||
ServerMuxInner {
|
||||
MuxInner {
|
||||
tx: write,
|
||||
close_tx,
|
||||
stream_map: map.clone(),
|
||||
stream_map: DashMap::new().into(),
|
||||
}
|
||||
.into_future(read, close_rx, tx, buffer_size),
|
||||
.server_into_future(read, close_rx, tx, buffer_size, close_tx),
|
||||
)
|
||||
}
|
||||
|
||||
|
@ -333,124 +385,13 @@ impl ServerMux {
|
|||
///
|
||||
/// Also terminates the multiplexor future. Waiting for a new stream will never succeed after
|
||||
/// this function is called.
|
||||
pub async fn close(&self) {
|
||||
self.stream_map.lock().await.drain().for_each(|mut x| {
|
||||
x.1.is_closed.store(true, Ordering::Release);
|
||||
x.1.stream.disconnect();
|
||||
x.1.stream.close_channel();
|
||||
});
|
||||
let _ = self.close_tx.unbounded_send(WsEvent::EndFut);
|
||||
}
|
||||
}
|
||||
|
||||
struct ClientMuxInner<W>
|
||||
where
|
||||
W: ws::WebSocketWrite,
|
||||
{
|
||||
tx: ws::LockedWebSocketWrite<W>,
|
||||
stream_map: Arc<Mutex<HashMap<u32, MuxMapValue>>>,
|
||||
}
|
||||
|
||||
impl<W: ws::WebSocketWrite> ClientMuxInner<W> {
|
||||
pub(crate) async fn into_future<R>(
|
||||
self,
|
||||
rx: R,
|
||||
close_rx: mpsc::UnboundedReceiver<WsEvent>,
|
||||
) -> Result<(), WispError>
|
||||
where
|
||||
R: ws::WebSocketRead,
|
||||
{
|
||||
let ret = futures::select! {
|
||||
x = self.client_bg_loop(close_rx).fuse() => x,
|
||||
x = self.client_loop(rx).fuse() => x
|
||||
};
|
||||
self.stream_map.lock().await.drain().for_each(|mut x| {
|
||||
x.1.is_closed.store(true, Ordering::Release);
|
||||
x.1.stream.disconnect();
|
||||
x.1.stream.close_channel();
|
||||
});
|
||||
ret
|
||||
}
|
||||
|
||||
async fn client_bg_loop(
|
||||
&self,
|
||||
mut close_rx: mpsc::UnboundedReceiver<WsEvent>,
|
||||
) -> Result<(), WispError> {
|
||||
while let Some(msg) = close_rx.next().await {
|
||||
match msg {
|
||||
WsEvent::SendPacket(packet, channel) => {
|
||||
if self
|
||||
.stream_map
|
||||
.lock()
|
||||
pub async fn close(&mut self) -> Result<(), WispError> {
|
||||
self.close_tx
|
||||
.send(WsEvent::EndFut)
|
||||
.await
|
||||
.get(&packet.stream_id)
|
||||
.is_some()
|
||||
{
|
||||
let _ = channel.send(self.tx.write_frame(packet.into()).await);
|
||||
} else {
|
||||
let _ = channel.send(Err(WispError::InvalidStreamId));
|
||||
.map_err(|_| WispError::MuxMessageFailedToSend)
|
||||
}
|
||||
}
|
||||
WsEvent::Close(packet, channel) => {
|
||||
if let Some(mut stream) =
|
||||
self.stream_map.lock().await.remove(&packet.stream_id)
|
||||
{
|
||||
stream.stream.disconnect();
|
||||
stream.stream.close_channel();
|
||||
let _ = channel.send(self.tx.write_frame(packet.into()).await);
|
||||
} else {
|
||||
let _ = channel.send(Err(WispError::InvalidStreamId));
|
||||
}
|
||||
}
|
||||
WsEvent::EndFut => break,
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn client_loop<R>(&self, mut rx: R) -> Result<(), WispError>
|
||||
where
|
||||
R: ws::WebSocketRead,
|
||||
{
|
||||
loop {
|
||||
let frame = rx.wisp_read_frame(&self.tx).await?;
|
||||
if frame.opcode == ws::OpCode::Close {
|
||||
break Ok(());
|
||||
}
|
||||
let packet = Packet::try_from(frame)?;
|
||||
|
||||
use PacketType::*;
|
||||
match packet.packet_type {
|
||||
Connect(_) => break Err(WispError::InvalidPacketType),
|
||||
Data(data) => {
|
||||
if let Some(stream) = self.stream_map.lock().await.get(&packet.stream_id) {
|
||||
let _ = stream.stream.unbounded_send(data);
|
||||
}
|
||||
}
|
||||
Continue(inner_packet) => {
|
||||
if let Some(stream) = self.stream_map.lock().await.get(&packet.stream_id) {
|
||||
if stream.stream_type == StreamType::Tcp {
|
||||
stream
|
||||
.flow_control
|
||||
.store(inner_packet.buffer_remaining, Ordering::Release);
|
||||
let _ = stream.flow_control_event.notify(u32::MAX);
|
||||
}
|
||||
}
|
||||
}
|
||||
Close(_) => {
|
||||
if let Some(mut stream) =
|
||||
self.stream_map.lock().await.remove(&packet.stream_id)
|
||||
{
|
||||
stream.is_closed.store(true, Ordering::Release);
|
||||
stream.stream.disconnect();
|
||||
stream.stream.close_channel();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Client side multiplexor.
|
||||
///
|
||||
/// # Example
|
||||
|
@ -470,9 +411,9 @@ where
|
|||
W: ws::WebSocketWrite,
|
||||
{
|
||||
tx: ws::LockedWebSocketWrite<W>,
|
||||
stream_map: Arc<Mutex<HashMap<u32, MuxMapValue>>>,
|
||||
stream_map: Arc<DashMap<u32, MuxMapValue>>,
|
||||
next_free_stream_id: AtomicU32,
|
||||
close_tx: mpsc::UnboundedSender<WsEvent>,
|
||||
close_tx: mpsc::Sender<WsEvent>,
|
||||
buf_size: u32,
|
||||
target_buf_size: u32,
|
||||
}
|
||||
|
@ -492,23 +433,23 @@ impl<W: ws::WebSocketWrite> ClientMux<W> {
|
|||
return Err(WispError::InvalidStreamId);
|
||||
}
|
||||
if let PacketType::Continue(packet) = first_packet.packet_type {
|
||||
let (tx, rx) = mpsc::unbounded::<WsEvent>();
|
||||
let map = Arc::new(Mutex::new(HashMap::new()));
|
||||
let (tx, rx) = mpsc::channel::<WsEvent>(256);
|
||||
let map = Arc::new(DashMap::new());
|
||||
Ok((
|
||||
Self {
|
||||
tx: write.clone(),
|
||||
stream_map: map.clone(),
|
||||
next_free_stream_id: AtomicU32::new(1),
|
||||
close_tx: tx,
|
||||
close_tx: tx.clone(),
|
||||
buf_size: packet.buffer_remaining,
|
||||
// server-only
|
||||
target_buf_size: 0,
|
||||
},
|
||||
ClientMuxInner {
|
||||
MuxInner {
|
||||
tx: write.clone(),
|
||||
stream_map: map.clone(),
|
||||
}
|
||||
.into_future(read, rx),
|
||||
.client_into_future(read, rx),
|
||||
))
|
||||
} else {
|
||||
Err(WispError::InvalidPacketType)
|
||||
|
@ -540,7 +481,7 @@ impl<W: ws::WebSocketWrite> ClientMux<W> {
|
|||
self.next_free_stream_id
|
||||
.store(next_stream_id, Ordering::Release);
|
||||
|
||||
self.stream_map.lock().await.insert(
|
||||
self.stream_map.insert(
|
||||
stream_id,
|
||||
MuxMapValue {
|
||||
stream: ch_tx,
|
||||
|
@ -568,12 +509,10 @@ impl<W: ws::WebSocketWrite> ClientMux<W> {
|
|||
///
|
||||
/// Also terminates the multiplexor future. Creating a stream is UB after calling this
|
||||
/// function.
|
||||
pub async fn close(&self) {
|
||||
self.stream_map.lock().await.drain().for_each(|mut x| {
|
||||
x.1.is_closed.store(true, Ordering::Release);
|
||||
x.1.stream.disconnect();
|
||||
x.1.stream.close_channel();
|
||||
});
|
||||
let _ = self.close_tx.unbounded_send(WsEvent::EndFut);
|
||||
pub async fn close(&mut self) -> Result<(), WispError> {
|
||||
self.close_tx
|
||||
.send(WsEvent::EndFut)
|
||||
.await
|
||||
.map_err(|_| WispError::MuxMessageFailedToSend)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
use crate::{ws, WispError};
|
||||
use bytes::{Buf, BufMut, Bytes};
|
||||
use bytes::{Buf, BufMut, Bytes, BytesMut};
|
||||
|
||||
/// Wisp stream type.
|
||||
#[derive(Debug, PartialEq, Copy, Clone)]
|
||||
|
@ -115,13 +115,13 @@ impl TryFrom<Bytes> for ConnectPacket {
|
|||
}
|
||||
}
|
||||
|
||||
impl From<ConnectPacket> for Vec<u8> {
|
||||
impl From<ConnectPacket> for Bytes {
|
||||
fn from(packet: ConnectPacket) -> Self {
|
||||
let mut encoded = Self::with_capacity(1 + 2 + packet.destination_hostname.len());
|
||||
let mut encoded = BytesMut::with_capacity(1 + 2 + packet.destination_hostname.len());
|
||||
encoded.put_u8(packet.stream_type as u8);
|
||||
encoded.put_u16_le(packet.destination_port);
|
||||
encoded.extend(packet.destination_hostname.bytes());
|
||||
encoded
|
||||
encoded.freeze()
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -153,11 +153,11 @@ impl TryFrom<Bytes> for ContinuePacket {
|
|||
}
|
||||
}
|
||||
|
||||
impl From<ContinuePacket> for Vec<u8> {
|
||||
impl From<ContinuePacket> for Bytes {
|
||||
fn from(packet: ContinuePacket) -> Self {
|
||||
let mut encoded = Self::with_capacity(4);
|
||||
let mut encoded = BytesMut::with_capacity(4);
|
||||
encoded.put_u32_le(packet.buffer_remaining);
|
||||
encoded
|
||||
encoded.freeze()
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -190,11 +190,11 @@ impl TryFrom<Bytes> for ClosePacket {
|
|||
}
|
||||
}
|
||||
|
||||
impl From<ClosePacket> for Vec<u8> {
|
||||
impl From<ClosePacket> for Bytes {
|
||||
fn from(packet: ClosePacket) -> Self {
|
||||
let mut encoded = Self::with_capacity(1);
|
||||
let mut encoded = BytesMut::with_capacity(1);
|
||||
encoded.put_u8(packet.reason as u8);
|
||||
encoded
|
||||
encoded.freeze()
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -224,12 +224,12 @@ impl PacketType {
|
|||
}
|
||||
}
|
||||
|
||||
impl From<PacketType> for Vec<u8> {
|
||||
impl From<PacketType> for Bytes {
|
||||
fn from(packet: PacketType) -> Self {
|
||||
use PacketType::*;
|
||||
match packet {
|
||||
Connect(x) => x.into(),
|
||||
Data(x) => x.to_vec(),
|
||||
Data(x) => x,
|
||||
Continue(x) => x.into(),
|
||||
Close(x) => x.into(),
|
||||
}
|
||||
|
@ -250,7 +250,10 @@ impl Packet {
|
|||
///
|
||||
/// The helper functions should be used for most use cases.
|
||||
pub fn new(stream_id: u32, packet: PacketType) -> Self {
|
||||
Self { stream_id, packet_type: packet }
|
||||
Self {
|
||||
stream_id,
|
||||
packet_type: packet,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new connect packet.
|
||||
|
@ -316,13 +319,15 @@ impl TryFrom<Bytes> for Packet {
|
|||
}
|
||||
}
|
||||
|
||||
impl From<Packet> for Vec<u8> {
|
||||
impl From<Packet> for Bytes {
|
||||
fn from(packet: Packet) -> Self {
|
||||
let mut encoded = Self::with_capacity(1 + 4);
|
||||
encoded.push(packet.packet_type.as_u8());
|
||||
let inner_u8 = packet.packet_type.as_u8();
|
||||
let inner = Bytes::from(packet.packet_type);
|
||||
let mut encoded = BytesMut::with_capacity(1 + 4 + inner.len());
|
||||
encoded.put_u8(inner_u8);
|
||||
encoded.put_u32_le(packet.stream_id);
|
||||
encoded.extend(Vec::<u8>::from(packet.packet_type));
|
||||
encoded
|
||||
encoded.extend(inner);
|
||||
encoded.freeze()
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -341,6 +346,6 @@ impl TryFrom<ws::Frame> for Packet {
|
|||
|
||||
impl From<Packet> for ws::Frame {
|
||||
fn from(packet: Packet) -> Self {
|
||||
Self::binary(Vec::<u8>::from(packet).into())
|
||||
Self::binary(packet.into())
|
||||
}
|
||||
}
|
||||
|
|
|
@ -45,28 +45,42 @@ pin_project! {
|
|||
/// Sink for the [`unfold`] function.
|
||||
#[derive(Debug)]
|
||||
#[must_use = "sinks do nothing unless polled"]
|
||||
pub struct Unfold<T, F, FC, R> {
|
||||
pub struct Unfold<T, F, R, CT, CF, CR> {
|
||||
function: F,
|
||||
close_function: FC,
|
||||
close_function: CF,
|
||||
#[pin]
|
||||
state: UnfoldState<T, R>,
|
||||
#[pin]
|
||||
close_state: UnfoldState<CT, CR>
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn unfold<T, F, FC, R, Item, E>(init: T, function: F, close_function: FC) -> Unfold<T, F, FC, R>
|
||||
pub(crate) fn unfold<T, F, R, CT, CF, CR, Item, E>(
|
||||
init: T,
|
||||
function: F,
|
||||
close_init: CT,
|
||||
close_function: CF,
|
||||
) -> Unfold<T, F, R, CT, CF, CR>
|
||||
where
|
||||
F: FnMut(T, Item) -> R,
|
||||
R: Future<Output = Result<T, E>>,
|
||||
FC: Fn() -> Result<(), E>,
|
||||
CF: FnMut(CT) -> CR,
|
||||
CR: Future<Output = Result<CT, E>>,
|
||||
{
|
||||
Unfold { function, close_function, state: UnfoldState::Value { value: init } }
|
||||
Unfold {
|
||||
function,
|
||||
close_function,
|
||||
state: UnfoldState::Value { value: init },
|
||||
close_state: UnfoldState::Value { value: close_init },
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, F, FC, R, Item, E> Sink<Item> for Unfold<T, F, FC, R>
|
||||
impl<T, F, R, CT, CF, CR, Item, E> Sink<Item> for Unfold<T, F, R, CT, CF, CR>
|
||||
where
|
||||
F: FnMut(T, Item) -> R,
|
||||
R: Future<Output = Result<T, E>>,
|
||||
FC: Fn() -> Result<(), E>,
|
||||
CF: FnMut(CT) -> CR,
|
||||
CR: Future<Output = Result<CT, E>>,
|
||||
{
|
||||
type Error = E;
|
||||
|
||||
|
@ -104,6 +118,27 @@ where
|
|||
|
||||
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
ready!(self.as_mut().poll_flush(cx))?;
|
||||
Poll::Ready((self.close_function)())
|
||||
let mut this = self.project();
|
||||
Poll::Ready(
|
||||
if let Some(future) = this.close_state.as_mut().project_future() {
|
||||
match ready!(future.poll(cx)) {
|
||||
Ok(state) => {
|
||||
this.close_state.set(UnfoldState::Value { value: state });
|
||||
Ok(())
|
||||
}
|
||||
Err(err) => {
|
||||
this.close_state.set(UnfoldState::Empty);
|
||||
Err(err)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
let future = match this.close_state.as_mut().take_value() {
|
||||
Some(value) => (this.close_function)(value),
|
||||
None => panic!("start_send called without poll_ready being called first"),
|
||||
};
|
||||
this.close_state.set(UnfoldState::Future { future });
|
||||
return Poll::Pending;
|
||||
},
|
||||
)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -7,7 +7,7 @@ use futures::{
|
|||
channel::{mpsc, oneshot},
|
||||
stream,
|
||||
task::{Context, Poll},
|
||||
Sink, Stream, StreamExt,
|
||||
Sink, SinkExt, Stream, StreamExt,
|
||||
};
|
||||
use pin_project_lite::pin_project;
|
||||
use std::{
|
||||
|
@ -31,7 +31,7 @@ pub struct MuxStreamRead {
|
|||
/// Type of the stream.
|
||||
pub stream_type: StreamType,
|
||||
role: Role,
|
||||
tx: mpsc::UnboundedSender<WsEvent>,
|
||||
tx: mpsc::Sender<WsEvent>,
|
||||
rx: mpsc::UnboundedReceiver<Bytes>,
|
||||
is_closed: Arc<AtomicBool>,
|
||||
flow_control: Arc<AtomicU32>,
|
||||
|
@ -51,13 +51,14 @@ impl MuxStreamRead {
|
|||
if val > self.target_flow_control {
|
||||
let (tx, rx) = oneshot::channel::<Result<(), WispError>>();
|
||||
self.tx
|
||||
.unbounded_send(WsEvent::SendPacket(
|
||||
.send(WsEvent::SendPacket(
|
||||
Packet::new_continue(
|
||||
self.stream_id,
|
||||
self.flow_control.fetch_add(val, Ordering::AcqRel) + val,
|
||||
),
|
||||
tx,
|
||||
))
|
||||
.await
|
||||
.ok()?;
|
||||
rx.await.ok()?.ok()?;
|
||||
self.flow_control_read.store(0, Ordering::Release);
|
||||
|
@ -80,7 +81,7 @@ pub struct MuxStreamWrite {
|
|||
/// Type of the stream.
|
||||
pub stream_type: StreamType,
|
||||
role: Role,
|
||||
tx: mpsc::UnboundedSender<WsEvent>,
|
||||
tx: mpsc::Sender<WsEvent>,
|
||||
is_closed: Arc<AtomicBool>,
|
||||
continue_recieved: Arc<Event>,
|
||||
flow_control: Arc<AtomicU32>,
|
||||
|
@ -88,7 +89,7 @@ pub struct MuxStreamWrite {
|
|||
|
||||
impl MuxStreamWrite {
|
||||
/// Write data to the stream.
|
||||
pub async fn write(&self, data: Bytes) -> Result<(), WispError> {
|
||||
pub async fn write(&mut self, data: Bytes) -> Result<(), WispError> {
|
||||
if self.is_closed.load(Ordering::Acquire) {
|
||||
return Err(WispError::StreamAlreadyClosed);
|
||||
}
|
||||
|
@ -100,10 +101,11 @@ impl MuxStreamWrite {
|
|||
}
|
||||
let (tx, rx) = oneshot::channel::<Result<(), WispError>>();
|
||||
self.tx
|
||||
.unbounded_send(WsEvent::SendPacket(
|
||||
.send(WsEvent::SendPacket(
|
||||
Packet::new_data(self.stream_id, data),
|
||||
tx,
|
||||
))
|
||||
.await
|
||||
.map_err(|x| WispError::Other(Box::new(x)))?;
|
||||
rx.await.map_err(|x| WispError::Other(Box::new(x)))??;
|
||||
if self.role == Role::Client && self.stream_type == StreamType::Tcp {
|
||||
|
@ -135,7 +137,7 @@ impl MuxStreamWrite {
|
|||
}
|
||||
|
||||
/// Close the stream. You will no longer be able to write or read after this has been called.
|
||||
pub async fn close(&self, reason: CloseReason) -> Result<(), WispError> {
|
||||
pub async fn close(&mut self, reason: CloseReason) -> Result<(), WispError> {
|
||||
if self.is_closed.load(Ordering::Acquire) {
|
||||
return Err(WispError::StreamAlreadyClosed);
|
||||
}
|
||||
|
@ -143,10 +145,11 @@ impl MuxStreamWrite {
|
|||
|
||||
let (tx, rx) = oneshot::channel::<Result<(), WispError>>();
|
||||
self.tx
|
||||
.unbounded_send(WsEvent::Close(
|
||||
.send(WsEvent::Close(
|
||||
Packet::new_close(self.stream_id, reason),
|
||||
tx,
|
||||
))
|
||||
.await
|
||||
.map_err(|x| WispError::Other(Box::new(x)))?;
|
||||
rx.await.map_err(|x| WispError::Other(Box::new(x)))??;
|
||||
|
||||
|
@ -157,25 +160,19 @@ impl MuxStreamWrite {
|
|||
let handle = self.get_close_handle();
|
||||
Box::pin(sink_unfold::unfold(
|
||||
self,
|
||||
|tx, data| async move {
|
||||
|mut tx, data| async move {
|
||||
tx.write(data).await?;
|
||||
Ok(tx)
|
||||
},
|
||||
move || handle.close_sync(CloseReason::Unknown),
|
||||
handle,
|
||||
move |mut handle| async {
|
||||
handle.close(CloseReason::Unknown).await?;
|
||||
Ok(handle)
|
||||
},
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for MuxStreamWrite {
|
||||
fn drop(&mut self) {
|
||||
let (tx, _) = oneshot::channel::<Result<(), WispError>>();
|
||||
let _ = self.tx.unbounded_send(WsEvent::Close(
|
||||
Packet::new_close(self.stream_id, CloseReason::Unknown),
|
||||
tx,
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
/// Multiplexor stream.
|
||||
pub struct MuxStream {
|
||||
/// ID of the stream.
|
||||
|
@ -191,7 +188,7 @@ impl MuxStream {
|
|||
role: Role,
|
||||
stream_type: StreamType,
|
||||
rx: mpsc::UnboundedReceiver<Bytes>,
|
||||
tx: mpsc::UnboundedSender<WsEvent>,
|
||||
tx: mpsc::Sender<WsEvent>,
|
||||
is_closed: Arc<AtomicBool>,
|
||||
flow_control: Arc<AtomicU32>,
|
||||
continue_recieved: Arc<Event>,
|
||||
|
@ -228,7 +225,7 @@ impl MuxStream {
|
|||
}
|
||||
|
||||
/// Write data to the stream.
|
||||
pub async fn write(&self, data: Bytes) -> Result<(), WispError> {
|
||||
pub async fn write(&mut self, data: Bytes) -> Result<(), WispError> {
|
||||
self.tx.write(data).await
|
||||
}
|
||||
|
||||
|
@ -248,7 +245,7 @@ impl MuxStream {
|
|||
}
|
||||
|
||||
/// Close the stream. You will no longer be able to write or read after this has been called.
|
||||
pub async fn close(&self, reason: CloseReason) -> Result<(), WispError> {
|
||||
pub async fn close(&mut self, reason: CloseReason) -> Result<(), WispError> {
|
||||
self.tx.close(reason).await
|
||||
}
|
||||
|
||||
|
@ -271,13 +268,13 @@ impl MuxStream {
|
|||
pub struct MuxStreamCloser {
|
||||
/// ID of the stream.
|
||||
pub stream_id: u32,
|
||||
close_channel: mpsc::UnboundedSender<WsEvent>,
|
||||
close_channel: mpsc::Sender<WsEvent>,
|
||||
is_closed: Arc<AtomicBool>,
|
||||
}
|
||||
|
||||
impl MuxStreamCloser {
|
||||
/// Close the stream. You will no longer be able to write or read after this has been called.
|
||||
pub async fn close(&self, reason: CloseReason) -> Result<(), WispError> {
|
||||
pub async fn close(&mut self, reason: CloseReason) -> Result<(), WispError> {
|
||||
if self.is_closed.load(Ordering::Acquire) {
|
||||
return Err(WispError::StreamAlreadyClosed);
|
||||
}
|
||||
|
@ -285,32 +282,16 @@ impl MuxStreamCloser {
|
|||
|
||||
let (tx, rx) = oneshot::channel::<Result<(), WispError>>();
|
||||
self.close_channel
|
||||
.unbounded_send(WsEvent::Close(
|
||||
.send(WsEvent::Close(
|
||||
Packet::new_close(self.stream_id, reason),
|
||||
tx,
|
||||
))
|
||||
.await
|
||||
.map_err(|x| WispError::Other(Box::new(x)))?;
|
||||
rx.await.map_err(|x| WispError::Other(Box::new(x)))??;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn close_sync(&self, reason: CloseReason) -> Result<(), WispError> {
|
||||
if self.is_closed.load(Ordering::Acquire) {
|
||||
return Err(WispError::StreamAlreadyClosed);
|
||||
}
|
||||
self.is_closed.store(true, Ordering::Release);
|
||||
|
||||
let (tx, _) = oneshot::channel::<Result<(), WispError>>();
|
||||
self.close_channel
|
||||
.unbounded_send(WsEvent::Close(
|
||||
Packet::new_close(self.stream_id, reason),
|
||||
tx,
|
||||
))
|
||||
.map_err(|x| WispError::Other(Box::new(x)))?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
pin_project! {
|
||||
|
@ -336,10 +317,7 @@ impl MuxStreamIo {
|
|||
impl Stream for MuxStreamIo {
|
||||
type Item = Result<Vec<u8>, std::io::Error>;
|
||||
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
|
||||
self.project()
|
||||
.rx
|
||||
.poll_next(cx)
|
||||
.map(|x| x.map(|x| Ok(x.to_vec())))
|
||||
self.project().rx.poll_next(cx).map(|x| x.map(|x| Ok(x.to_vec())))
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue