mirror of
https://github.com/MercuryWorkshop/epoxy-tls.git
synced 2025-05-12 22:10:01 -04:00
rewrite actor
This commit is contained in:
parent
b1f56c1dae
commit
9cd87b7243
5 changed files with 470 additions and 472 deletions
16
Cargo.lock
generated
16
Cargo.lock
generated
|
@ -1485,27 +1485,12 @@ version = "1.0.18"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "f3cb5ba0dc43242ce17de99c180e96db90b235b8a9fdc9543c96d2209116bd9f"
|
checksum = "f3cb5ba0dc43242ce17de99c180e96db90b235b8a9fdc9543c96d2209116bd9f"
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "scc"
|
|
||||||
version = "2.1.16"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "aeb7ac86243095b70a7920639507b71d51a63390d1ba26c4f60a552fbb914a37"
|
|
||||||
dependencies = [
|
|
||||||
"sdd",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "scopeguard"
|
name = "scopeguard"
|
||||||
version = "1.2.0"
|
version = "1.2.0"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49"
|
checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49"
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "sdd"
|
|
||||||
version = "3.0.2"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "0495e4577c672de8254beb68d01a9b62d0e8a13c099edecdbedccce3223cd29f"
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "send_wrapper"
|
name = "send_wrapper"
|
||||||
version = "0.4.0"
|
version = "0.4.0"
|
||||||
|
@ -2274,7 +2259,6 @@ dependencies = [
|
||||||
"futures",
|
"futures",
|
||||||
"futures-timer",
|
"futures-timer",
|
||||||
"pin-project-lite",
|
"pin-project-lite",
|
||||||
"scc",
|
|
||||||
"tokio",
|
"tokio",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
|
@ -18,7 +18,6 @@ flume = "0.11.0"
|
||||||
futures = "0.3.30"
|
futures = "0.3.30"
|
||||||
futures-timer = "3.0.3"
|
futures-timer = "3.0.3"
|
||||||
pin-project-lite = "0.2.14"
|
pin-project-lite = "0.2.14"
|
||||||
scc = "2.1.16"
|
|
||||||
tokio = { version = "1.39.3", optional = true, default-features = false }
|
tokio = { version = "1.39.3", optional = true, default-features = false }
|
||||||
|
|
||||||
[features]
|
[features]
|
||||||
|
|
401
wisp/src/inner.rs
Normal file
401
wisp/src/inner.rs
Normal file
|
@ -0,0 +1,401 @@
|
||||||
|
use std::{
|
||||||
|
collections::HashMap,
|
||||||
|
sync::{
|
||||||
|
atomic::{AtomicBool, AtomicU32, Ordering},
|
||||||
|
Arc,
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
use crate::{
|
||||||
|
extensions::AnyProtocolExtension,
|
||||||
|
ws::{Frame, LockedWebSocketWrite, OpCode, Payload, WebSocketRead},
|
||||||
|
AtomicCloseReason, ClosePacket, CloseReason, ConnectPacket, MuxStream, Packet, PacketType,
|
||||||
|
Role, StreamType, WispError,
|
||||||
|
};
|
||||||
|
use bytes::{Bytes, BytesMut};
|
||||||
|
use event_listener::Event;
|
||||||
|
use flume as mpsc;
|
||||||
|
use futures::{channel::oneshot, FutureExt};
|
||||||
|
|
||||||
|
pub(crate) enum WsEvent {
|
||||||
|
Close(Packet<'static>, oneshot::Sender<Result<(), WispError>>),
|
||||||
|
CreateStream(
|
||||||
|
StreamType,
|
||||||
|
String,
|
||||||
|
u16,
|
||||||
|
oneshot::Sender<Result<MuxStream, WispError>>,
|
||||||
|
),
|
||||||
|
WispMessage(Frame<'static>, Option<Frame<'static>>),
|
||||||
|
EndFut(Option<CloseReason>),
|
||||||
|
}
|
||||||
|
|
||||||
|
struct MuxMapValue {
|
||||||
|
stream: mpsc::Sender<Bytes>,
|
||||||
|
stream_type: StreamType,
|
||||||
|
|
||||||
|
flow_control: Arc<AtomicU32>,
|
||||||
|
flow_control_event: Arc<Event>,
|
||||||
|
|
||||||
|
is_closed: Arc<AtomicBool>,
|
||||||
|
close_reason: Arc<AtomicCloseReason>,
|
||||||
|
is_closed_event: Arc<Event>,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct MuxInner<R: WebSocketRead + Send> {
|
||||||
|
rx: R,
|
||||||
|
tx: LockedWebSocketWrite,
|
||||||
|
extensions: Vec<AnyProtocolExtension>,
|
||||||
|
role: Role,
|
||||||
|
|
||||||
|
fut_rx: mpsc::Receiver<WsEvent>,
|
||||||
|
fut_tx: mpsc::Sender<WsEvent>,
|
||||||
|
fut_exited: Arc<AtomicBool>,
|
||||||
|
|
||||||
|
stream_map: HashMap<u32, MuxMapValue>,
|
||||||
|
|
||||||
|
buffer_size: u32,
|
||||||
|
target_buffer_size: u32,
|
||||||
|
|
||||||
|
server_tx: mpsc::Sender<(ConnectPacket, MuxStream)>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<R: WebSocketRead + Send> MuxInner<R> {
|
||||||
|
pub fn new_server(
|
||||||
|
rx: R,
|
||||||
|
tx: LockedWebSocketWrite,
|
||||||
|
extensions: Vec<AnyProtocolExtension>,
|
||||||
|
buffer_size: u32,
|
||||||
|
) -> (
|
||||||
|
Self,
|
||||||
|
Arc<AtomicBool>,
|
||||||
|
mpsc::Sender<WsEvent>,
|
||||||
|
mpsc::Receiver<(ConnectPacket, MuxStream)>,
|
||||||
|
) {
|
||||||
|
let (fut_tx, fut_rx) = mpsc::bounded::<WsEvent>(256);
|
||||||
|
let (server_tx, server_rx) = mpsc::unbounded::<(ConnectPacket, MuxStream)>();
|
||||||
|
let ret_fut_tx = fut_tx.clone();
|
||||||
|
let fut_exited = Arc::new(AtomicBool::new(false));
|
||||||
|
|
||||||
|
(
|
||||||
|
Self {
|
||||||
|
rx,
|
||||||
|
tx,
|
||||||
|
|
||||||
|
fut_rx,
|
||||||
|
fut_tx,
|
||||||
|
fut_exited: fut_exited.clone(),
|
||||||
|
|
||||||
|
extensions,
|
||||||
|
buffer_size,
|
||||||
|
target_buffer_size: ((buffer_size as u64 * 90) / 100) as u32,
|
||||||
|
|
||||||
|
role: Role::Server,
|
||||||
|
|
||||||
|
stream_map: HashMap::new(),
|
||||||
|
|
||||||
|
server_tx,
|
||||||
|
},
|
||||||
|
fut_exited,
|
||||||
|
ret_fut_tx,
|
||||||
|
server_rx,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn new_client(
|
||||||
|
rx: R,
|
||||||
|
tx: LockedWebSocketWrite,
|
||||||
|
extensions: Vec<AnyProtocolExtension>,
|
||||||
|
buffer_size: u32,
|
||||||
|
) -> (Self, Arc<AtomicBool>, mpsc::Sender<WsEvent>) {
|
||||||
|
let (fut_tx, fut_rx) = mpsc::bounded::<WsEvent>(256);
|
||||||
|
let (server_tx, _) = mpsc::unbounded::<(ConnectPacket, MuxStream)>();
|
||||||
|
let ret_fut_tx = fut_tx.clone();
|
||||||
|
let fut_exited = Arc::new(AtomicBool::new(false));
|
||||||
|
|
||||||
|
(
|
||||||
|
Self {
|
||||||
|
rx,
|
||||||
|
tx,
|
||||||
|
|
||||||
|
fut_rx,
|
||||||
|
fut_tx,
|
||||||
|
fut_exited: fut_exited.clone(),
|
||||||
|
|
||||||
|
extensions,
|
||||||
|
buffer_size,
|
||||||
|
target_buffer_size: 0,
|
||||||
|
|
||||||
|
role: Role::Client,
|
||||||
|
|
||||||
|
stream_map: HashMap::new(),
|
||||||
|
|
||||||
|
server_tx,
|
||||||
|
},
|
||||||
|
fut_exited,
|
||||||
|
ret_fut_tx,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn into_future(mut self) -> Result<(), WispError> {
|
||||||
|
let ret = self.stream_loop().await;
|
||||||
|
|
||||||
|
self.fut_exited.store(true, Ordering::Release);
|
||||||
|
|
||||||
|
for (_, stream) in self.stream_map.iter() {
|
||||||
|
self.close_stream(stream, ClosePacket::new(CloseReason::Unknown));
|
||||||
|
}
|
||||||
|
self.stream_map.clear();
|
||||||
|
|
||||||
|
let _ = self.tx.close().await;
|
||||||
|
ret
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn create_new_stream(
|
||||||
|
&mut self,
|
||||||
|
stream_id: u32,
|
||||||
|
stream_type: StreamType,
|
||||||
|
) -> Result<(MuxMapValue, MuxStream), WispError> {
|
||||||
|
let (ch_tx, ch_rx) = mpsc::bounded(self.buffer_size as usize);
|
||||||
|
|
||||||
|
let flow_control_event: Arc<Event> = Event::new().into();
|
||||||
|
let flow_control: Arc<AtomicU32> = AtomicU32::new(self.buffer_size).into();
|
||||||
|
|
||||||
|
let is_closed: Arc<AtomicBool> = AtomicBool::new(false).into();
|
||||||
|
let close_reason: Arc<AtomicCloseReason> =
|
||||||
|
AtomicCloseReason::new(CloseReason::Unknown).into();
|
||||||
|
let is_closed_event: Arc<Event> = Event::new().into();
|
||||||
|
|
||||||
|
Ok((
|
||||||
|
MuxMapValue {
|
||||||
|
stream: ch_tx,
|
||||||
|
stream_type,
|
||||||
|
|
||||||
|
flow_control: flow_control.clone(),
|
||||||
|
flow_control_event: flow_control_event.clone(),
|
||||||
|
|
||||||
|
is_closed: is_closed.clone(),
|
||||||
|
close_reason: close_reason.clone(),
|
||||||
|
is_closed_event: is_closed_event.clone(),
|
||||||
|
},
|
||||||
|
MuxStream::new(
|
||||||
|
stream_id,
|
||||||
|
self.role,
|
||||||
|
stream_type,
|
||||||
|
ch_rx,
|
||||||
|
self.fut_tx.clone(),
|
||||||
|
self.tx.clone(),
|
||||||
|
is_closed,
|
||||||
|
is_closed_event,
|
||||||
|
close_reason,
|
||||||
|
flow_control,
|
||||||
|
flow_control_event,
|
||||||
|
self.target_buffer_size,
|
||||||
|
),
|
||||||
|
))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn close_stream(&self, stream: &MuxMapValue, close_packet: ClosePacket) {
|
||||||
|
stream
|
||||||
|
.close_reason
|
||||||
|
.store(close_packet.reason, Ordering::Release);
|
||||||
|
stream.is_closed.store(true, Ordering::Release);
|
||||||
|
stream.is_closed_event.notify(usize::MAX);
|
||||||
|
stream.flow_control.store(u32::MAX, Ordering::Release);
|
||||||
|
stream.flow_control_event.notify(usize::MAX);
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn get_message(&mut self) -> Result<Option<WsEvent>, WispError> {
|
||||||
|
futures::select! {
|
||||||
|
x = self.fut_rx.recv_async().fuse() => Ok(x.ok()),
|
||||||
|
x = self.rx.wisp_read_split(&self.tx).fuse() => {
|
||||||
|
let (mut frame, optional_frame) = x?;
|
||||||
|
if frame.opcode == OpCode::Close {
|
||||||
|
return Ok(None);
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(ref extra_frame) = optional_frame {
|
||||||
|
if frame.payload[0] != PacketType::Data(Payload::Bytes(BytesMut::new())).as_u8() {
|
||||||
|
let mut payload = BytesMut::from(frame.payload);
|
||||||
|
payload.extend_from_slice(&extra_frame.payload);
|
||||||
|
frame.payload = Payload::Bytes(payload);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(Some(WsEvent::WispMessage(frame, optional_frame)))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn stream_loop(&mut self) -> Result<(), WispError> {
|
||||||
|
let mut next_free_stream_id: u32 = 1;
|
||||||
|
while let Some(msg) = self.get_message().await? {
|
||||||
|
match msg {
|
||||||
|
WsEvent::CreateStream(stream_type, host, port, channel) => {
|
||||||
|
let ret: Result<MuxStream, WispError> = async {
|
||||||
|
let stream_id = next_free_stream_id;
|
||||||
|
let next_stream_id = next_free_stream_id
|
||||||
|
.checked_add(1)
|
||||||
|
.ok_or(WispError::MaxStreamCountReached)?;
|
||||||
|
|
||||||
|
let (map_value, stream) =
|
||||||
|
self.create_new_stream(stream_id, stream_type).await?;
|
||||||
|
|
||||||
|
self.tx
|
||||||
|
.write_frame(
|
||||||
|
Packet::new_connect(stream_id, stream_type, port, host).into(),
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
self.stream_map.insert(stream_id, map_value);
|
||||||
|
|
||||||
|
next_free_stream_id = next_stream_id;
|
||||||
|
|
||||||
|
Ok(stream)
|
||||||
|
}
|
||||||
|
.await;
|
||||||
|
let _ = channel.send(ret);
|
||||||
|
}
|
||||||
|
WsEvent::Close(packet, channel) => {
|
||||||
|
if let Some(stream) = self.stream_map.remove(&packet.stream_id) {
|
||||||
|
if let PacketType::Close(close) = packet.packet_type {
|
||||||
|
self.close_stream(&stream, close);
|
||||||
|
}
|
||||||
|
let _ = channel.send(self.tx.write_frame(packet.into()).await);
|
||||||
|
} else {
|
||||||
|
let _ = channel.send(Err(WispError::InvalidStreamId));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
WsEvent::EndFut(x) => {
|
||||||
|
if let Some(reason) = x {
|
||||||
|
let _ = self
|
||||||
|
.tx
|
||||||
|
.write_frame(Packet::new_close(0, reason).into())
|
||||||
|
.await;
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
WsEvent::WispMessage(frame, optional_frame) => {
|
||||||
|
if let Some(packet) = Packet::maybe_handle_extension(
|
||||||
|
frame,
|
||||||
|
&mut self.extensions,
|
||||||
|
&mut self.rx,
|
||||||
|
&mut self.tx,
|
||||||
|
)
|
||||||
|
.await?
|
||||||
|
{
|
||||||
|
let should_break = match self.role {
|
||||||
|
Role::Server => {
|
||||||
|
self.server_handle_packet(packet, optional_frame).await?
|
||||||
|
}
|
||||||
|
Role::Client => {
|
||||||
|
self.client_handle_packet(packet, optional_frame).await?
|
||||||
|
}
|
||||||
|
};
|
||||||
|
if should_break {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn handle_close_packet(
|
||||||
|
&mut self,
|
||||||
|
stream_id: u32,
|
||||||
|
inner_packet: ClosePacket,
|
||||||
|
) -> Result<bool, WispError> {
|
||||||
|
if stream_id == 0 {
|
||||||
|
return Ok(true);
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(stream) = self.stream_map.remove(&stream_id) {
|
||||||
|
self.close_stream(&stream, inner_packet);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(false)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn handle_data_packet(
|
||||||
|
&mut self,
|
||||||
|
stream_id: u32,
|
||||||
|
optional_frame: Option<Frame<'static>>,
|
||||||
|
data: Payload<'static>,
|
||||||
|
) -> Result<bool, WispError> {
|
||||||
|
let mut data = BytesMut::from(data);
|
||||||
|
|
||||||
|
if let Some(stream) = self.stream_map.get(&stream_id) {
|
||||||
|
if let Some(extra_frame) = optional_frame {
|
||||||
|
if data.is_empty() {
|
||||||
|
data = extra_frame.payload.into();
|
||||||
|
} else {
|
||||||
|
data.extend_from_slice(&extra_frame.payload);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
let _ = stream.stream.try_send(data.freeze());
|
||||||
|
if self.role == Role::Server && stream.stream_type == StreamType::Tcp {
|
||||||
|
stream.flow_control.store(
|
||||||
|
stream
|
||||||
|
.flow_control
|
||||||
|
.load(Ordering::Acquire)
|
||||||
|
.saturating_sub(1),
|
||||||
|
Ordering::Release,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(false)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn server_handle_packet(
|
||||||
|
&mut self,
|
||||||
|
packet: Packet<'static>,
|
||||||
|
optional_frame: Option<Frame<'static>>,
|
||||||
|
) -> Result<bool, WispError> {
|
||||||
|
use PacketType::*;
|
||||||
|
match packet.packet_type {
|
||||||
|
Continue(_) | Info(_) => Err(WispError::InvalidPacketType),
|
||||||
|
Data(data) => self.handle_data_packet(packet.stream_id, optional_frame, data),
|
||||||
|
Close(inner_packet) => self.handle_close_packet(packet.stream_id, inner_packet),
|
||||||
|
|
||||||
|
Connect(inner_packet) => {
|
||||||
|
let (map_value, stream) = self
|
||||||
|
.create_new_stream(packet.stream_id, inner_packet.stream_type)
|
||||||
|
.await?;
|
||||||
|
self.server_tx
|
||||||
|
.send_async((inner_packet, stream))
|
||||||
|
.await
|
||||||
|
.map_err(|_| WispError::MuxMessageFailedToSend)?;
|
||||||
|
self.stream_map.insert(packet.stream_id, map_value);
|
||||||
|
Ok(false)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn client_handle_packet(
|
||||||
|
&mut self,
|
||||||
|
packet: Packet<'static>,
|
||||||
|
optional_frame: Option<Frame<'static>>,
|
||||||
|
) -> Result<bool, WispError> {
|
||||||
|
use PacketType::*;
|
||||||
|
match packet.packet_type {
|
||||||
|
Connect(_) | Info(_) => Err(WispError::InvalidPacketType),
|
||||||
|
Data(data) => self.handle_data_packet(packet.stream_id, optional_frame, data),
|
||||||
|
Close(inner_packet) => self.handle_close_packet(packet.stream_id, inner_packet),
|
||||||
|
|
||||||
|
Continue(inner_packet) => {
|
||||||
|
if let Some(stream) = self.stream_map.get(&packet.stream_id) {
|
||||||
|
if stream.stream_type == StreamType::Tcp {
|
||||||
|
stream
|
||||||
|
.flow_control
|
||||||
|
.store(inner_packet.buffer_remaining, Ordering::Release);
|
||||||
|
let _ = stream.flow_control_event.notify(u32::MAX);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(false)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
477
wisp/src/lib.rs
477
wisp/src/lib.rs
|
@ -11,6 +11,7 @@ mod fastwebsockets;
|
||||||
#[cfg(feature = "generic_stream")]
|
#[cfg(feature = "generic_stream")]
|
||||||
#[cfg_attr(docsrs, doc(cfg(feature = "generic_stream")))]
|
#[cfg_attr(docsrs, doc(cfg(feature = "generic_stream")))]
|
||||||
pub mod generic;
|
pub mod generic;
|
||||||
|
mod inner;
|
||||||
mod packet;
|
mod packet;
|
||||||
mod sink_unfold;
|
mod sink_unfold;
|
||||||
mod stream;
|
mod stream;
|
||||||
|
@ -18,21 +19,19 @@ pub mod ws;
|
||||||
|
|
||||||
pub use crate::{packet::*, stream::*};
|
pub use crate::{packet::*, stream::*};
|
||||||
|
|
||||||
use bytes::{Bytes, BytesMut};
|
|
||||||
use event_listener::Event;
|
|
||||||
use extensions::{udp::UdpProtocolExtension, AnyProtocolExtension, ProtocolExtensionBuilder};
|
use extensions::{udp::UdpProtocolExtension, AnyProtocolExtension, ProtocolExtensionBuilder};
|
||||||
use flume as mpsc;
|
use flume as mpsc;
|
||||||
use futures::{channel::oneshot, select, Future, FutureExt};
|
use futures::{channel::oneshot, select, Future, FutureExt};
|
||||||
use futures_timer::Delay;
|
use futures_timer::Delay;
|
||||||
use scc::HashMap;
|
use inner::{MuxInner, WsEvent};
|
||||||
use std::{
|
use std::{
|
||||||
sync::{
|
sync::{
|
||||||
atomic::{AtomicBool, AtomicU32, Ordering},
|
atomic::{AtomicBool, Ordering},
|
||||||
Arc,
|
Arc,
|
||||||
},
|
},
|
||||||
time::Duration,
|
time::Duration,
|
||||||
};
|
};
|
||||||
use ws::{AppendingWebSocketRead, LockedWebSocketWrite, Payload};
|
use ws::{AppendingWebSocketRead, LockedWebSocketWrite};
|
||||||
|
|
||||||
/// Wisp version supported by this crate.
|
/// Wisp version supported by this crate.
|
||||||
pub const WISP_VERSION: WispVersion = WispVersion { major: 2, minor: 0 };
|
pub const WISP_VERSION: WispVersion = WispVersion { major: 2, minor: 0 };
|
||||||
|
@ -157,363 +156,6 @@ impl std::fmt::Display for WispError {
|
||||||
|
|
||||||
impl std::error::Error for WispError {}
|
impl std::error::Error for WispError {}
|
||||||
|
|
||||||
struct MuxMapValue {
|
|
||||||
stream: mpsc::Sender<Bytes>,
|
|
||||||
stream_type: StreamType,
|
|
||||||
|
|
||||||
flow_control: Arc<AtomicU32>,
|
|
||||||
flow_control_event: Arc<Event>,
|
|
||||||
|
|
||||||
is_closed: Arc<AtomicBool>,
|
|
||||||
close_reason: Arc<AtomicCloseReason>,
|
|
||||||
is_closed_event: Arc<Event>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Drop for MuxMapValue {
|
|
||||||
fn drop(&mut self) {
|
|
||||||
self.is_closed.store(true, Ordering::Release);
|
|
||||||
self.is_closed_event.notify(usize::MAX);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
struct MuxInner {
|
|
||||||
tx: ws::LockedWebSocketWrite,
|
|
||||||
stream_map: HashMap<u32, MuxMapValue>,
|
|
||||||
buffer_size: u32,
|
|
||||||
fut_exited: Arc<AtomicBool>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl MuxInner {
|
|
||||||
pub async fn server_into_future<R>(
|
|
||||||
self,
|
|
||||||
rx: R,
|
|
||||||
extensions: Vec<AnyProtocolExtension>,
|
|
||||||
close_rx: mpsc::Receiver<WsEvent>,
|
|
||||||
muxstream_sender: mpsc::Sender<(ConnectPacket, MuxStream)>,
|
|
||||||
close_tx: mpsc::Sender<WsEvent>,
|
|
||||||
) -> Result<(), WispError>
|
|
||||||
where
|
|
||||||
R: ws::WebSocketRead + Send,
|
|
||||||
{
|
|
||||||
self.as_future(
|
|
||||||
close_rx,
|
|
||||||
close_tx.clone(),
|
|
||||||
self.server_loop(rx, extensions, muxstream_sender, close_tx),
|
|
||||||
)
|
|
||||||
.await
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn client_into_future<R>(
|
|
||||||
self,
|
|
||||||
rx: R,
|
|
||||||
extensions: Vec<AnyProtocolExtension>,
|
|
||||||
close_rx: mpsc::Receiver<WsEvent>,
|
|
||||||
close_tx: mpsc::Sender<WsEvent>,
|
|
||||||
) -> Result<(), WispError>
|
|
||||||
where
|
|
||||||
R: ws::WebSocketRead + Send,
|
|
||||||
{
|
|
||||||
self.as_future(close_rx, close_tx, self.client_loop(rx, extensions))
|
|
||||||
.await
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn as_future(
|
|
||||||
&self,
|
|
||||||
close_rx: mpsc::Receiver<WsEvent>,
|
|
||||||
close_tx: mpsc::Sender<WsEvent>,
|
|
||||||
wisp_fut: impl Future<Output = Result<(), WispError>>,
|
|
||||||
) -> Result<(), WispError> {
|
|
||||||
let ret = futures::select! {
|
|
||||||
_ = self.stream_loop(close_rx, close_tx).fuse() => Ok(()),
|
|
||||||
x = wisp_fut.fuse() => x,
|
|
||||||
};
|
|
||||||
self.fut_exited.store(true, Ordering::Release);
|
|
||||||
self.stream_map.clear_async().await;
|
|
||||||
let _ = self.tx.close().await;
|
|
||||||
ret
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn create_new_stream(
|
|
||||||
&self,
|
|
||||||
stream_id: u32,
|
|
||||||
stream_type: StreamType,
|
|
||||||
role: Role,
|
|
||||||
stream_tx: mpsc::Sender<WsEvent>,
|
|
||||||
tx: LockedWebSocketWrite,
|
|
||||||
target_buffer_size: u32,
|
|
||||||
) -> Result<(MuxMapValue, MuxStream), WispError> {
|
|
||||||
let (ch_tx, ch_rx) = mpsc::bounded(self.buffer_size as usize);
|
|
||||||
|
|
||||||
let flow_control_event: Arc<Event> = Event::new().into();
|
|
||||||
let flow_control: Arc<AtomicU32> = AtomicU32::new(self.buffer_size).into();
|
|
||||||
|
|
||||||
let is_closed: Arc<AtomicBool> = AtomicBool::new(false).into();
|
|
||||||
let close_reason: Arc<AtomicCloseReason> =
|
|
||||||
AtomicCloseReason::new(CloseReason::Unknown).into();
|
|
||||||
let is_closed_event: Arc<Event> = Event::new().into();
|
|
||||||
|
|
||||||
Ok((
|
|
||||||
MuxMapValue {
|
|
||||||
stream: ch_tx,
|
|
||||||
stream_type,
|
|
||||||
|
|
||||||
flow_control: flow_control.clone(),
|
|
||||||
flow_control_event: flow_control_event.clone(),
|
|
||||||
|
|
||||||
is_closed: is_closed.clone(),
|
|
||||||
close_reason: close_reason.clone(),
|
|
||||||
is_closed_event: is_closed_event.clone(),
|
|
||||||
},
|
|
||||||
MuxStream::new(
|
|
||||||
stream_id,
|
|
||||||
role,
|
|
||||||
stream_type,
|
|
||||||
ch_rx,
|
|
||||||
stream_tx,
|
|
||||||
tx,
|
|
||||||
is_closed,
|
|
||||||
is_closed_event,
|
|
||||||
close_reason,
|
|
||||||
flow_control,
|
|
||||||
flow_control_event,
|
|
||||||
target_buffer_size,
|
|
||||||
),
|
|
||||||
))
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn stream_loop(
|
|
||||||
&self,
|
|
||||||
stream_rx: mpsc::Receiver<WsEvent>,
|
|
||||||
stream_tx: mpsc::Sender<WsEvent>,
|
|
||||||
) {
|
|
||||||
let mut next_free_stream_id: u32 = 1;
|
|
||||||
while let Ok(msg) = stream_rx.recv_async().await {
|
|
||||||
match msg {
|
|
||||||
WsEvent::CreateStream(stream_type, host, port, channel) => {
|
|
||||||
let ret: Result<MuxStream, WispError> = async {
|
|
||||||
let stream_id = next_free_stream_id;
|
|
||||||
let next_stream_id = next_free_stream_id
|
|
||||||
.checked_add(1)
|
|
||||||
.ok_or(WispError::MaxStreamCountReached)?;
|
|
||||||
|
|
||||||
let (map_value, stream) = self
|
|
||||||
.create_new_stream(
|
|
||||||
stream_id,
|
|
||||||
stream_type,
|
|
||||||
Role::Client,
|
|
||||||
stream_tx.clone(),
|
|
||||||
self.tx.clone(),
|
|
||||||
0,
|
|
||||||
)
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
self.tx
|
|
||||||
.write_frame(
|
|
||||||
Packet::new_connect(stream_id, stream_type, port, host).into(),
|
|
||||||
)
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
self.stream_map.upsert_async(stream_id, map_value).await;
|
|
||||||
|
|
||||||
next_free_stream_id = next_stream_id;
|
|
||||||
|
|
||||||
Ok(stream)
|
|
||||||
}
|
|
||||||
.await;
|
|
||||||
let _ = channel.send(ret);
|
|
||||||
}
|
|
||||||
WsEvent::Close(packet, channel) => {
|
|
||||||
if let Some((_, stream)) = self.stream_map.remove_async(&packet.stream_id).await
|
|
||||||
{
|
|
||||||
if let PacketType::Close(close) = packet.packet_type {
|
|
||||||
self.close_stream(stream, close);
|
|
||||||
}
|
|
||||||
let _ = channel.send(self.tx.write_frame(packet.into()).await);
|
|
||||||
} else {
|
|
||||||
let _ = channel.send(Err(WispError::InvalidStreamId));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
WsEvent::EndFut(x) => {
|
|
||||||
if let Some(reason) = x {
|
|
||||||
let _ = self
|
|
||||||
.tx
|
|
||||||
.write_frame(Packet::new_close(0, reason).into())
|
|
||||||
.await;
|
|
||||||
}
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn close_stream(&self, stream: MuxMapValue, close_packet: ClosePacket) {
|
|
||||||
stream
|
|
||||||
.close_reason
|
|
||||||
.store(close_packet.reason, Ordering::Release);
|
|
||||||
stream.is_closed.store(true, Ordering::Release);
|
|
||||||
stream.is_closed_event.notify(usize::MAX);
|
|
||||||
stream.flow_control.store(u32::MAX, Ordering::Release);
|
|
||||||
stream.flow_control_event.notify(usize::MAX);
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn server_loop<R>(
|
|
||||||
&self,
|
|
||||||
mut rx: R,
|
|
||||||
mut extensions: Vec<AnyProtocolExtension>,
|
|
||||||
muxstream_sender: mpsc::Sender<(ConnectPacket, MuxStream)>,
|
|
||||||
stream_tx: mpsc::Sender<WsEvent>,
|
|
||||||
) -> Result<(), WispError>
|
|
||||||
where
|
|
||||||
R: ws::WebSocketRead + Send,
|
|
||||||
{
|
|
||||||
// will send continues once flow_control is at 10% of max
|
|
||||||
let target_buffer_size = ((self.buffer_size as u64 * 90) / 100) as u32;
|
|
||||||
|
|
||||||
loop {
|
|
||||||
let (mut frame, optional_frame) = rx.wisp_read_split(&self.tx).await?;
|
|
||||||
if frame.opcode == ws::OpCode::Close {
|
|
||||||
break Ok(());
|
|
||||||
}
|
|
||||||
|
|
||||||
if let Some(ref extra_frame) = optional_frame {
|
|
||||||
if frame.payload[0] != PacketType::Data(Payload::Bytes(BytesMut::new())).as_u8() {
|
|
||||||
let mut payload = BytesMut::from(frame.payload);
|
|
||||||
payload.extend_from_slice(&extra_frame.payload);
|
|
||||||
frame.payload = Payload::Bytes(payload);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if let Some(packet) =
|
|
||||||
Packet::maybe_handle_extension(frame, &mut extensions, &mut rx, &self.tx).await?
|
|
||||||
{
|
|
||||||
use PacketType::*;
|
|
||||||
match packet.packet_type {
|
|
||||||
Continue(_) | Info(_) => break Err(WispError::InvalidPacketType),
|
|
||||||
Connect(inner_packet) => {
|
|
||||||
let (map_value, stream) = self
|
|
||||||
.create_new_stream(
|
|
||||||
packet.stream_id,
|
|
||||||
inner_packet.stream_type,
|
|
||||||
Role::Server,
|
|
||||||
stream_tx.clone(),
|
|
||||||
self.tx.clone(),
|
|
||||||
target_buffer_size,
|
|
||||||
)
|
|
||||||
.await?;
|
|
||||||
muxstream_sender
|
|
||||||
.send_async((inner_packet, stream))
|
|
||||||
.await
|
|
||||||
.map_err(|_| WispError::MuxMessageFailedToSend)?;
|
|
||||||
self.stream_map
|
|
||||||
.upsert_async(packet.stream_id, map_value)
|
|
||||||
.await;
|
|
||||||
}
|
|
||||||
Data(data) => {
|
|
||||||
let mut data = BytesMut::from(data);
|
|
||||||
if let Some(stream) = self.stream_map.get_async(&packet.stream_id).await {
|
|
||||||
if let Some(extra_frame) = optional_frame {
|
|
||||||
if data.is_empty() {
|
|
||||||
data = extra_frame.payload.into();
|
|
||||||
} else {
|
|
||||||
data.extend_from_slice(&extra_frame.payload);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
let _ = stream.stream.try_send(data.freeze());
|
|
||||||
if stream.stream_type == StreamType::Tcp {
|
|
||||||
stream.flow_control.store(
|
|
||||||
stream
|
|
||||||
.flow_control
|
|
||||||
.load(Ordering::Acquire)
|
|
||||||
.saturating_sub(1),
|
|
||||||
Ordering::Release,
|
|
||||||
);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Close(inner_packet) => {
|
|
||||||
if packet.stream_id == 0 {
|
|
||||||
break Ok(());
|
|
||||||
}
|
|
||||||
|
|
||||||
if let Some((_, stream)) =
|
|
||||||
self.stream_map.remove_async(&packet.stream_id).await
|
|
||||||
{
|
|
||||||
self.close_stream(stream, inner_packet)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn client_loop<R>(
|
|
||||||
&self,
|
|
||||||
mut rx: R,
|
|
||||||
mut extensions: Vec<AnyProtocolExtension>,
|
|
||||||
) -> Result<(), WispError>
|
|
||||||
where
|
|
||||||
R: ws::WebSocketRead + Send,
|
|
||||||
{
|
|
||||||
loop {
|
|
||||||
let (mut frame, optional_frame) = rx.wisp_read_split(&self.tx).await?;
|
|
||||||
if frame.opcode == ws::OpCode::Close {
|
|
||||||
break Ok(());
|
|
||||||
}
|
|
||||||
|
|
||||||
if let Some(ref extra_frame) = optional_frame {
|
|
||||||
if frame.payload[0] != PacketType::Data(Payload::Bytes(BytesMut::new())).as_u8() {
|
|
||||||
let mut payload = BytesMut::from(frame.payload);
|
|
||||||
payload.extend_from_slice(&extra_frame.payload);
|
|
||||||
frame.payload = Payload::Bytes(payload);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if let Some(packet) =
|
|
||||||
Packet::maybe_handle_extension(frame, &mut extensions, &mut rx, &self.tx).await?
|
|
||||||
{
|
|
||||||
use PacketType::*;
|
|
||||||
match packet.packet_type {
|
|
||||||
Connect(_) | Info(_) => break Err(WispError::InvalidPacketType),
|
|
||||||
Data(data) => {
|
|
||||||
let mut data = BytesMut::from(data);
|
|
||||||
if let Some(stream) = self.stream_map.get_async(&packet.stream_id).await {
|
|
||||||
if let Some(extra_frame) = optional_frame {
|
|
||||||
if data.is_empty() {
|
|
||||||
data = extra_frame.payload.into();
|
|
||||||
} else {
|
|
||||||
data.extend_from_slice(&extra_frame.payload);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
let _ = stream.stream.send_async(data.freeze()).await;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Continue(inner_packet) => {
|
|
||||||
if let Some(stream) = self.stream_map.get_async(&packet.stream_id).await {
|
|
||||||
if stream.stream_type == StreamType::Tcp {
|
|
||||||
stream
|
|
||||||
.flow_control
|
|
||||||
.store(inner_packet.buffer_remaining, Ordering::Release);
|
|
||||||
let _ = stream.flow_control_event.notify(u32::MAX);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Close(inner_packet) => {
|
|
||||||
if packet.stream_id == 0 {
|
|
||||||
break Ok(());
|
|
||||||
}
|
|
||||||
|
|
||||||
if let Some((_, stream)) =
|
|
||||||
self.stream_map.remove_async(&packet.stream_id).await
|
|
||||||
{
|
|
||||||
self.close_stream(stream, inner_packet)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn maybe_wisp_v2<R>(
|
async fn maybe_wisp_v2<R>(
|
||||||
read: &mut R,
|
read: &mut R,
|
||||||
write: &LockedWebSocketWrite,
|
write: &LockedWebSocketWrite,
|
||||||
|
@ -576,7 +218,7 @@ pub struct ServerMux {
|
||||||
pub downgraded: bool,
|
pub downgraded: bool,
|
||||||
/// Extensions that are supported by both sides.
|
/// Extensions that are supported by both sides.
|
||||||
pub supported_extension_ids: Vec<u8>,
|
pub supported_extension_ids: Vec<u8>,
|
||||||
close_tx: mpsc::Sender<WsEvent>,
|
actor_tx: mpsc::Sender<WsEvent>,
|
||||||
muxstream_recv: mpsc::Receiver<(ConnectPacket, MuxStream)>,
|
muxstream_recv: mpsc::Receiver<(ConnectPacket, MuxStream)>,
|
||||||
tx: ws::LockedWebSocketWrite,
|
tx: ws::LockedWebSocketWrite,
|
||||||
fut_exited: Arc<AtomicBool>,
|
fut_exited: Arc<AtomicBool>,
|
||||||
|
@ -589,8 +231,8 @@ impl ServerMux {
|
||||||
/// **It is not guaranteed that all extensions you specify are available.** You must manually check
|
/// **It is not guaranteed that all extensions you specify are available.** You must manually check
|
||||||
/// if the extensions you need are available after the multiplexor has been created.
|
/// if the extensions you need are available after the multiplexor has been created.
|
||||||
pub async fn create<R, W>(
|
pub async fn create<R, W>(
|
||||||
mut read: R,
|
mut rx: R,
|
||||||
write: W,
|
tx: W,
|
||||||
buffer_size: u32,
|
buffer_size: u32,
|
||||||
extension_builders: Option<&[Box<dyn ProtocolExtensionBuilder + Send + Sync>]>,
|
extension_builders: Option<&[Box<dyn ProtocolExtensionBuilder + Send + Sync>]>,
|
||||||
) -> Result<ServerMuxResult<impl Future<Output = Result<(), WispError>> + Send>, WispError>
|
) -> Result<ServerMuxResult<impl Future<Output = Result<(), WispError>> + Send>, WispError>
|
||||||
|
@ -598,19 +240,14 @@ impl ServerMux {
|
||||||
R: ws::WebSocketRead + Send,
|
R: ws::WebSocketRead + Send,
|
||||||
W: ws::WebSocketWrite + Send + 'static,
|
W: ws::WebSocketWrite + Send + 'static,
|
||||||
{
|
{
|
||||||
let (close_tx, close_rx) = mpsc::bounded::<WsEvent>(256);
|
let tx = ws::LockedWebSocketWrite::new(Box::new(tx));
|
||||||
let (tx, rx) = mpsc::unbounded::<(ConnectPacket, MuxStream)>();
|
|
||||||
let write = ws::LockedWebSocketWrite::new(Box::new(write));
|
|
||||||
let fut_exited = Arc::new(AtomicBool::new(false));
|
|
||||||
|
|
||||||
write
|
tx.write_frame(Packet::new_continue(0, buffer_size).into())
|
||||||
.write_frame(Packet::new_continue(0, buffer_size).into())
|
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
let (supported_extensions, extra_packet, downgraded) =
|
let (supported_extensions, extra_packet, downgraded) =
|
||||||
if let Some(builders) = extension_builders {
|
if let Some(builders) = extension_builders {
|
||||||
write
|
tx.write_frame(
|
||||||
.write_frame(
|
|
||||||
Packet::new_info(
|
Packet::new_info(
|
||||||
builders
|
builders
|
||||||
.iter()
|
.iter()
|
||||||
|
@ -620,33 +257,30 @@ impl ServerMux {
|
||||||
.into(),
|
.into(),
|
||||||
)
|
)
|
||||||
.await?;
|
.await?;
|
||||||
maybe_wisp_v2(&mut read, &write, builders).await?
|
maybe_wisp_v2(&mut rx, &tx, builders).await?
|
||||||
} else {
|
} else {
|
||||||
(Vec::new(), None, true)
|
(Vec::new(), None, true)
|
||||||
};
|
};
|
||||||
|
|
||||||
|
let supported_extension_ids = supported_extensions.iter().map(|x| x.get_id()).collect();
|
||||||
|
|
||||||
|
let (mux_inner, fut_exited, actor_tx, muxstream_recv) = MuxInner::new_server(
|
||||||
|
AppendingWebSocketRead(extra_packet, rx),
|
||||||
|
tx.clone(),
|
||||||
|
supported_extensions,
|
||||||
|
buffer_size,
|
||||||
|
);
|
||||||
|
|
||||||
Ok(ServerMuxResult(
|
Ok(ServerMuxResult(
|
||||||
Self {
|
Self {
|
||||||
muxstream_recv: rx,
|
muxstream_recv,
|
||||||
close_tx: close_tx.clone(),
|
actor_tx,
|
||||||
downgraded,
|
downgraded,
|
||||||
supported_extension_ids: supported_extensions.iter().map(|x| x.get_id()).collect(),
|
supported_extension_ids,
|
||||||
tx: write.clone(),
|
tx,
|
||||||
fut_exited: fut_exited.clone(),
|
fut_exited: fut_exited.clone(),
|
||||||
},
|
},
|
||||||
MuxInner {
|
mux_inner.into_future(),
|
||||||
tx: write,
|
|
||||||
stream_map: HashMap::new(),
|
|
||||||
buffer_size,
|
|
||||||
fut_exited,
|
|
||||||
}
|
|
||||||
.server_into_future(
|
|
||||||
AppendingWebSocketRead(extra_packet, read),
|
|
||||||
supported_extensions,
|
|
||||||
close_rx,
|
|
||||||
tx,
|
|
||||||
close_tx,
|
|
||||||
),
|
|
||||||
))
|
))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -662,7 +296,7 @@ impl ServerMux {
|
||||||
if self.fut_exited.load(Ordering::Acquire) {
|
if self.fut_exited.load(Ordering::Acquire) {
|
||||||
return Err(WispError::MuxTaskEnded);
|
return Err(WispError::MuxTaskEnded);
|
||||||
}
|
}
|
||||||
self.close_tx
|
self.actor_tx
|
||||||
.send_async(WsEvent::EndFut(reason))
|
.send_async(WsEvent::EndFut(reason))
|
||||||
.await
|
.await
|
||||||
.map_err(|_| WispError::MuxMessageFailedToSend)
|
.map_err(|_| WispError::MuxMessageFailedToSend)
|
||||||
|
@ -695,7 +329,7 @@ impl ServerMux {
|
||||||
|
|
||||||
impl Drop for ServerMux {
|
impl Drop for ServerMux {
|
||||||
fn drop(&mut self) {
|
fn drop(&mut self) {
|
||||||
let _ = self.close_tx.send(WsEvent::EndFut(None));
|
let _ = self.actor_tx.send(WsEvent::EndFut(None));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -762,7 +396,7 @@ pub struct ClientMux {
|
||||||
pub downgraded: bool,
|
pub downgraded: bool,
|
||||||
/// Extensions that are supported by both sides.
|
/// Extensions that are supported by both sides.
|
||||||
pub supported_extension_ids: Vec<u8>,
|
pub supported_extension_ids: Vec<u8>,
|
||||||
stream_tx: mpsc::Sender<WsEvent>,
|
actor_tx: mpsc::Sender<WsEvent>,
|
||||||
tx: ws::LockedWebSocketWrite,
|
tx: ws::LockedWebSocketWrite,
|
||||||
fut_exited: Arc<AtomicBool>,
|
fut_exited: Arc<AtomicBool>,
|
||||||
}
|
}
|
||||||
|
@ -774,29 +408,28 @@ impl ClientMux {
|
||||||
/// **It is not guaranteed that all extensions you specify are available.** You must manually check
|
/// **It is not guaranteed that all extensions you specify are available.** You must manually check
|
||||||
/// if the extensions you need are available after the multiplexor has been created.
|
/// if the extensions you need are available after the multiplexor has been created.
|
||||||
pub async fn create<R, W>(
|
pub async fn create<R, W>(
|
||||||
mut read: R,
|
mut rx: R,
|
||||||
write: W,
|
tx: W,
|
||||||
extension_builders: Option<&[Box<dyn ProtocolExtensionBuilder + Send + Sync>]>,
|
extension_builders: Option<&[Box<dyn ProtocolExtensionBuilder + Send + Sync>]>,
|
||||||
) -> Result<ClientMuxResult<impl Future<Output = Result<(), WispError>> + Send>, WispError>
|
) -> Result<ClientMuxResult<impl Future<Output = Result<(), WispError>> + Send>, WispError>
|
||||||
where
|
where
|
||||||
R: ws::WebSocketRead + Send,
|
R: ws::WebSocketRead + Send,
|
||||||
W: ws::WebSocketWrite + Send + 'static,
|
W: ws::WebSocketWrite + Send + 'static,
|
||||||
{
|
{
|
||||||
let write = ws::LockedWebSocketWrite::new(Box::new(write));
|
let tx = ws::LockedWebSocketWrite::new(Box::new(tx));
|
||||||
let first_packet = Packet::try_from(read.wisp_read_frame(&write).await?)?;
|
let first_packet = Packet::try_from(rx.wisp_read_frame(&tx).await?)?;
|
||||||
let fut_exited = Arc::new(AtomicBool::new(false));
|
|
||||||
|
|
||||||
if first_packet.stream_id != 0 {
|
if first_packet.stream_id != 0 {
|
||||||
return Err(WispError::InvalidStreamId);
|
return Err(WispError::InvalidStreamId);
|
||||||
}
|
}
|
||||||
|
|
||||||
if let PacketType::Continue(packet) = first_packet.packet_type {
|
if let PacketType::Continue(packet) = first_packet.packet_type {
|
||||||
let (supported_extensions, extra_packet, downgraded) =
|
let (supported_extensions, extra_packet, downgraded) =
|
||||||
if let Some(builders) = extension_builders {
|
if let Some(builders) = extension_builders {
|
||||||
let x = maybe_wisp_v2(&mut read, &write, builders).await?;
|
let x = maybe_wisp_v2(&mut rx, &tx, builders).await?;
|
||||||
// if not downgraded
|
// if not downgraded
|
||||||
if !x.2 {
|
if !x.2 {
|
||||||
write
|
tx.write_frame(
|
||||||
.write_frame(
|
|
||||||
Packet::new_info(
|
Packet::new_info(
|
||||||
builders
|
builders
|
||||||
.iter()
|
.iter()
|
||||||
|
@ -812,30 +445,24 @@ impl ClientMux {
|
||||||
(Vec::new(), None, true)
|
(Vec::new(), None, true)
|
||||||
};
|
};
|
||||||
|
|
||||||
let (tx, rx) = mpsc::bounded::<WsEvent>(256);
|
let supported_extension_ids = supported_extensions.iter().map(|x| x.get_id()).collect();
|
||||||
|
|
||||||
|
let (mux_inner, fut_exited, actor_tx) = MuxInner::new_client(
|
||||||
|
AppendingWebSocketRead(extra_packet, rx),
|
||||||
|
tx.clone(),
|
||||||
|
supported_extensions,
|
||||||
|
packet.buffer_remaining,
|
||||||
|
);
|
||||||
|
|
||||||
Ok(ClientMuxResult(
|
Ok(ClientMuxResult(
|
||||||
Self {
|
Self {
|
||||||
stream_tx: tx.clone(),
|
actor_tx,
|
||||||
downgraded,
|
downgraded,
|
||||||
supported_extension_ids: supported_extensions
|
supported_extension_ids,
|
||||||
.iter()
|
|
||||||
.map(|x| x.get_id())
|
|
||||||
.collect(),
|
|
||||||
tx: write.clone(),
|
|
||||||
fut_exited: fut_exited.clone(),
|
|
||||||
},
|
|
||||||
MuxInner {
|
|
||||||
tx: write,
|
|
||||||
stream_map: HashMap::new(),
|
|
||||||
buffer_size: packet.buffer_remaining,
|
|
||||||
fut_exited,
|
|
||||||
}
|
|
||||||
.client_into_future(
|
|
||||||
AppendingWebSocketRead(extra_packet, read),
|
|
||||||
supported_extensions,
|
|
||||||
rx,
|
|
||||||
tx,
|
tx,
|
||||||
),
|
fut_exited,
|
||||||
|
},
|
||||||
|
mux_inner.into_future(),
|
||||||
))
|
))
|
||||||
} else {
|
} else {
|
||||||
Err(WispError::InvalidPacketType)
|
Err(WispError::InvalidPacketType)
|
||||||
|
@ -863,7 +490,7 @@ impl ClientMux {
|
||||||
]));
|
]));
|
||||||
}
|
}
|
||||||
let (tx, rx) = oneshot::channel();
|
let (tx, rx) = oneshot::channel();
|
||||||
self.stream_tx
|
self.actor_tx
|
||||||
.send_async(WsEvent::CreateStream(stream_type, host, port, tx))
|
.send_async(WsEvent::CreateStream(stream_type, host, port, tx))
|
||||||
.await
|
.await
|
||||||
.map_err(|_| WispError::MuxMessageFailedToSend)?;
|
.map_err(|_| WispError::MuxMessageFailedToSend)?;
|
||||||
|
@ -874,7 +501,7 @@ impl ClientMux {
|
||||||
if self.fut_exited.load(Ordering::Acquire) {
|
if self.fut_exited.load(Ordering::Acquire) {
|
||||||
return Err(WispError::MuxTaskEnded);
|
return Err(WispError::MuxTaskEnded);
|
||||||
}
|
}
|
||||||
self.stream_tx
|
self.actor_tx
|
||||||
.send_async(WsEvent::EndFut(reason))
|
.send_async(WsEvent::EndFut(reason))
|
||||||
.await
|
.await
|
||||||
.map_err(|_| WispError::MuxMessageFailedToSend)
|
.map_err(|_| WispError::MuxMessageFailedToSend)
|
||||||
|
@ -907,7 +534,7 @@ impl ClientMux {
|
||||||
|
|
||||||
impl Drop for ClientMux {
|
impl Drop for ClientMux {
|
||||||
fn drop(&mut self) {
|
fn drop(&mut self) {
|
||||||
let _ = self.stream_tx.send(WsEvent::EndFut(None));
|
let _ = self.actor_tx.send(WsEvent::EndFut(None));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,7 +1,5 @@
|
||||||
use crate::{
|
use crate::{
|
||||||
sink_unfold,
|
inner::WsEvent, sink_unfold, ws::{Frame, LockedWebSocketWrite, Payload}, AtomicCloseReason, CloseReason, Packet, Role, StreamType, WispError
|
||||||
ws::{Frame, LockedWebSocketWrite, Payload},
|
|
||||||
AtomicCloseReason, CloseReason, Packet, Role, StreamType, WispError,
|
|
||||||
};
|
};
|
||||||
|
|
||||||
use bytes::{BufMut, Bytes, BytesMut};
|
use bytes::{BufMut, Bytes, BytesMut};
|
||||||
|
@ -23,17 +21,6 @@ use std::{
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
pub(crate) enum WsEvent {
|
|
||||||
Close(Packet<'static>, oneshot::Sender<Result<(), WispError>>),
|
|
||||||
CreateStream(
|
|
||||||
StreamType,
|
|
||||||
String,
|
|
||||||
u16,
|
|
||||||
oneshot::Sender<Result<MuxStream, WispError>>,
|
|
||||||
),
|
|
||||||
EndFut(Option<CloseReason>),
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Read side of a multiplexor stream.
|
/// Read side of a multiplexor stream.
|
||||||
pub struct MuxStreamRead {
|
pub struct MuxStreamRead {
|
||||||
/// ID of the stream.
|
/// ID of the stream.
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue