mirror of
https://github.com/MercuryWorkshop/epoxy-tls.git
synced 2025-05-12 22:10:01 -04:00
wisp-mux... SEVEN!!
This commit is contained in:
parent
194ad4e5c8
commit
3f381d6b39
53 changed files with 3721 additions and 4821 deletions
|
@ -1,493 +1,427 @@
|
|||
use std::{collections::HashMap, sync::{
|
||||
atomic::{AtomicBool, AtomicU32, Ordering},
|
||||
Arc,
|
||||
}};
|
||||
use std::{
|
||||
pin::pin,
|
||||
sync::{
|
||||
atomic::{AtomicU32, AtomicU8, Ordering},
|
||||
Arc, Mutex,
|
||||
},
|
||||
task::Context,
|
||||
};
|
||||
|
||||
use futures::{
|
||||
channel::oneshot,
|
||||
stream::{select, unfold},
|
||||
SinkExt, StreamExt,
|
||||
};
|
||||
use rustc_hash::FxHashMap;
|
||||
|
||||
use crate::{
|
||||
extensions::AnyProtocolExtension,
|
||||
ws::{Frame, LockedWebSocketWrite, OpCode, Payload, WebSocketRead, WebSocketWrite},
|
||||
AtomicCloseReason, ClosePacket, CloseReason, ConnectPacket, MuxStream, Packet, PacketType,
|
||||
Role, StreamType, WispError,
|
||||
locked_sink::Waiter,
|
||||
packet::{
|
||||
ClosePacket, CloseReason, ConnectPacket, ContinuePacket, MaybeExtensionPacket, Packet,
|
||||
PacketType, StreamType,
|
||||
},
|
||||
stream::MuxStream,
|
||||
ws::{Payload, WebSocketRead, WebSocketWrite},
|
||||
LockedWebSocketWrite, WispError,
|
||||
};
|
||||
use bytes::BytesMut;
|
||||
use event_listener::Event;
|
||||
use flume as mpsc;
|
||||
use futures::{channel::oneshot, select, stream::unfold, FutureExt, StreamExt};
|
||||
use rustc_hash::FxHashMap;
|
||||
|
||||
pub(crate) enum WsEvent<W: WebSocketWrite + 'static> {
|
||||
Close(Packet<'static>, oneshot::Sender<Result<(), WispError>>),
|
||||
pub(crate) enum WsEvent<W: WebSocketWrite> {
|
||||
Close(u32, ClosePacket, oneshot::Sender<Result<(), WispError>>),
|
||||
CreateStream(
|
||||
StreamType,
|
||||
String,
|
||||
u16,
|
||||
ConnectPacket,
|
||||
oneshot::Sender<Result<MuxStream<W>, WispError>>,
|
||||
),
|
||||
SendPing(Payload<'static>, oneshot::Sender<Result<(), WispError>>),
|
||||
SendPong(Payload<'static>),
|
||||
WispMessage(Option<Packet<'static>>, Option<Frame<'static>>),
|
||||
WispMessage(Packet<'static>),
|
||||
EndFut(Option<CloseReason>),
|
||||
Noop,
|
||||
}
|
||||
|
||||
struct MuxMapValue {
|
||||
stream: mpsc::Sender<Payload<'static>>,
|
||||
stream_type: StreamType,
|
||||
pub(crate) type StreamMap = FxHashMap<u32, StreamMapValue>;
|
||||
|
||||
should_flow_control: bool,
|
||||
flow_control: Arc<AtomicU32>,
|
||||
flow_control_event: Arc<Event>,
|
||||
|
||||
is_closed: Arc<AtomicBool>,
|
||||
close_reason: Arc<AtomicCloseReason>,
|
||||
is_closed_event: Arc<Event>,
|
||||
pub(crate) struct StreamMapValue {
|
||||
pub info: Arc<StreamInfo>,
|
||||
pub stream: flume::Sender<Payload>,
|
||||
}
|
||||
|
||||
pub(crate) struct MuxInner<R: WebSocketRead + 'static, W: WebSocketWrite + 'static> {
|
||||
// gets taken by the mux task
|
||||
rx: Option<R>,
|
||||
// gets taken by the mux task
|
||||
maybe_downgrade_packet: Option<Packet<'static>>,
|
||||
#[derive(Copy, Clone, Eq, PartialEq, Debug)]
|
||||
pub(crate) enum FlowControl {
|
||||
/// flow control completely disabled
|
||||
Disabled,
|
||||
/// flow control enabled
|
||||
/// - incoming: do not send buffer updates and no buffer
|
||||
/// - outgoing: track sent amount and wait
|
||||
EnabledTrackAmount,
|
||||
/// flow control enabled
|
||||
/// - incoming: send buffer updates and force buffer
|
||||
/// - outgoing: do not track sent amount and do not wait
|
||||
EnabledSendMessages,
|
||||
}
|
||||
|
||||
pub(crate) struct StreamInfo {
|
||||
pub id: u32,
|
||||
|
||||
pub flow_status: FlowControl,
|
||||
pub target_flow_control: u32,
|
||||
flow_control: AtomicU32,
|
||||
close_reason: AtomicU8,
|
||||
flow_waker: Mutex<Waiter>,
|
||||
}
|
||||
|
||||
impl StreamInfo {
|
||||
pub fn new(id: u32, flow_status: FlowControl, buffer_size: u32) -> Self {
|
||||
debug_assert_ne!(id, 0);
|
||||
|
||||
// 90%
|
||||
#[expect(clippy::cast_possible_truncation)]
|
||||
let target = ((u64::from(buffer_size) * 90) / 100) as u32;
|
||||
|
||||
Self {
|
||||
id,
|
||||
|
||||
flow_status,
|
||||
target_flow_control: target,
|
||||
flow_control: AtomicU32::new(buffer_size),
|
||||
flow_waker: Mutex::new(Waiter::Woken),
|
||||
close_reason: AtomicU8::new(CloseReason::Unknown.into()),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn flow_set(&self, amt: u32) {
|
||||
self.flow_control.store(amt, Ordering::Relaxed);
|
||||
}
|
||||
pub fn flow_add(&self, amt: u32) -> u32 {
|
||||
let new = self
|
||||
.flow_control
|
||||
.load(Ordering::Relaxed)
|
||||
.saturating_add(amt);
|
||||
self.flow_control.store(new, Ordering::Relaxed);
|
||||
new
|
||||
}
|
||||
pub fn flow_sub(&self, amt: u32) -> u32 {
|
||||
let new = self
|
||||
.flow_control
|
||||
.load(Ordering::Relaxed)
|
||||
.saturating_sub(amt);
|
||||
self.flow_control.store(new, Ordering::Relaxed);
|
||||
new
|
||||
}
|
||||
pub fn flow_dec(&self) {
|
||||
self.flow_sub(1);
|
||||
}
|
||||
pub fn flow_empty(&self) -> bool {
|
||||
self.flow_control.load(Ordering::Relaxed) == 0
|
||||
}
|
||||
|
||||
pub fn flow_register(&self, cx: &mut Context<'_>) {
|
||||
self.flow_waker
|
||||
.lock()
|
||||
.expect("flow_waker was poisoned")
|
||||
.register(cx);
|
||||
}
|
||||
pub fn flow_wake(&self) {
|
||||
let mut waiter = self.flow_waker.lock().expect("flow_waker was poisoned");
|
||||
if let Some(waker) = waiter.wake() {
|
||||
drop(waiter);
|
||||
|
||||
waker.wake();
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_reason(&self) -> CloseReason {
|
||||
self.close_reason.load(Ordering::Relaxed).into()
|
||||
}
|
||||
pub fn set_reason(&self, reason: CloseReason) {
|
||||
self.close_reason.store(reason.into(), Ordering::Relaxed);
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) trait MultiplexorActor<W: WebSocketWrite>: Send {
|
||||
fn handle_connect_packet(
|
||||
&mut self,
|
||||
stream: MuxStream<W>,
|
||||
pkt: ConnectPacket,
|
||||
) -> Result<(), WispError>;
|
||||
|
||||
fn handle_data_packet(
|
||||
&mut self,
|
||||
id: u32,
|
||||
pkt: Payload,
|
||||
streams: &mut StreamMap,
|
||||
) -> Result<(), WispError> {
|
||||
if let Some(stream) = streams.get(&id) {
|
||||
let _ = stream.stream.try_send(pkt);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn handle_continue_packet(
|
||||
&mut self,
|
||||
id: u32,
|
||||
pkt: ContinuePacket,
|
||||
streams: &mut StreamMap,
|
||||
) -> Result<(), WispError>;
|
||||
|
||||
fn get_flow_control(ty: StreamType, flow_stream_types: &[u8]) -> FlowControl;
|
||||
}
|
||||
|
||||
struct MuxStart<R: WebSocketRead, W: WebSocketWrite> {
|
||||
rx: R,
|
||||
downgrade: Option<Packet<'static>>,
|
||||
extensions: Vec<AnyProtocolExtension>,
|
||||
actor_rx: flume::Receiver<WsEvent<W>>,
|
||||
}
|
||||
|
||||
pub(crate) struct MuxInner<R: WebSocketRead, W: WebSocketWrite, M: MultiplexorActor<W>> {
|
||||
start: Option<MuxStart<R, W>>,
|
||||
tx: LockedWebSocketWrite<W>,
|
||||
// gets taken by the mux task
|
||||
extensions: Option<Vec<AnyProtocolExtension>>,
|
||||
tcp_extensions: Vec<u8>,
|
||||
role: Role,
|
||||
flow_stream_types: Box<[u8]>,
|
||||
|
||||
// gets taken by the mux task
|
||||
actor_rx: Option<mpsc::Receiver<WsEvent<W>>>,
|
||||
actor_tx: mpsc::Sender<WsEvent<W>>,
|
||||
fut_exited: Arc<AtomicBool>,
|
||||
|
||||
stream_map: FxHashMap<u32, MuxMapValue>,
|
||||
mux: M,
|
||||
|
||||
streams: StreamMap,
|
||||
current_id: u32,
|
||||
buffer_size: u32,
|
||||
target_buffer_size: u32,
|
||||
|
||||
server_tx: mpsc::Sender<(ConnectPacket, MuxStream<W>)>,
|
||||
actor_tx: flume::Sender<WsEvent<W>>,
|
||||
}
|
||||
|
||||
pub(crate) struct MuxInnerResult<R: WebSocketRead + 'static, W: WebSocketWrite + 'static> {
|
||||
pub mux: MuxInner<R, W>,
|
||||
pub actor_exited: Arc<AtomicBool>,
|
||||
pub actor_tx: mpsc::Sender<WsEvent<W>>,
|
||||
pub(crate) struct MuxInnerResult<R: WebSocketRead, W: WebSocketWrite, M: MultiplexorActor<W>> {
|
||||
pub mux: MuxInner<R, W, M>,
|
||||
pub actor_tx: flume::Sender<WsEvent<W>>,
|
||||
}
|
||||
|
||||
impl<R: WebSocketRead + 'static, W: WebSocketWrite + 'static> MuxInner<R, W> {
|
||||
fn get_tcp_extensions(extensions: &[AnyProtocolExtension]) -> Vec<u8> {
|
||||
extensions
|
||||
impl<R: WebSocketRead, W: WebSocketWrite, M: MultiplexorActor<W>> MuxInner<R, W, M> {
|
||||
#[expect(clippy::new_ret_no_self)]
|
||||
pub fn new(
|
||||
rx: R,
|
||||
tx: LockedWebSocketWrite<W>,
|
||||
mux: M,
|
||||
downgrade: Option<Packet<'static>>,
|
||||
extensions: Vec<AnyProtocolExtension>,
|
||||
buffer_size: u32,
|
||||
) -> MuxInnerResult<R, W, M> {
|
||||
let (actor_tx, actor_rx) = flume::unbounded();
|
||||
|
||||
let flow_extensions = extensions
|
||||
.iter()
|
||||
.flat_map(|x| x.get_congestion_stream_types())
|
||||
.copied()
|
||||
.chain(std::iter::once(StreamType::Tcp.into()))
|
||||
.collect()
|
||||
}
|
||||
|
||||
#[expect(clippy::type_complexity)]
|
||||
pub fn new_server(
|
||||
rx: R,
|
||||
maybe_downgrade_packet: Option<Packet<'static>>,
|
||||
tx: LockedWebSocketWrite<W>,
|
||||
extensions: Vec<AnyProtocolExtension>,
|
||||
buffer_size: u32,
|
||||
) -> (
|
||||
MuxInnerResult<R, W>,
|
||||
mpsc::Receiver<(ConnectPacket, MuxStream<W>)>,
|
||||
) {
|
||||
let (fut_tx, fut_rx) = mpsc::bounded::<WsEvent<W>>(256);
|
||||
let (server_tx, server_rx) = mpsc::unbounded::<(ConnectPacket, MuxStream<W>)>();
|
||||
let ret_fut_tx = fut_tx.clone();
|
||||
let fut_exited = Arc::new(AtomicBool::new(false));
|
||||
|
||||
// 90% of the buffer size, not possible to overflow
|
||||
#[expect(clippy::cast_possible_truncation)]
|
||||
let target_buffer_size = ((u64::from(buffer_size) * 90) / 100) as u32;
|
||||
|
||||
(
|
||||
MuxInnerResult {
|
||||
mux: Self {
|
||||
rx: Some(rx),
|
||||
maybe_downgrade_packet,
|
||||
tx,
|
||||
|
||||
actor_rx: Some(fut_rx),
|
||||
actor_tx: fut_tx,
|
||||
fut_exited: fut_exited.clone(),
|
||||
|
||||
tcp_extensions: Self::get_tcp_extensions(&extensions),
|
||||
extensions: Some(extensions),
|
||||
buffer_size,
|
||||
target_buffer_size,
|
||||
|
||||
role: Role::Server,
|
||||
|
||||
stream_map: HashMap::default(),
|
||||
|
||||
server_tx,
|
||||
},
|
||||
actor_exited: fut_exited,
|
||||
actor_tx: ret_fut_tx,
|
||||
},
|
||||
server_rx,
|
||||
)
|
||||
}
|
||||
|
||||
pub fn new_client(
|
||||
rx: R,
|
||||
maybe_downgrade_packet: Option<Packet<'static>>,
|
||||
tx: LockedWebSocketWrite<W>,
|
||||
extensions: Vec<AnyProtocolExtension>,
|
||||
buffer_size: u32,
|
||||
) -> MuxInnerResult<R, W> {
|
||||
let (fut_tx, fut_rx) = mpsc::bounded::<WsEvent<W>>(256);
|
||||
let (server_tx, _) = mpsc::unbounded::<(ConnectPacket, MuxStream<W>)>();
|
||||
let ret_fut_tx = fut_tx.clone();
|
||||
let fut_exited = Arc::new(AtomicBool::new(false));
|
||||
.collect();
|
||||
|
||||
MuxInnerResult {
|
||||
actor_tx: actor_tx.clone(),
|
||||
mux: Self {
|
||||
rx: Some(rx),
|
||||
maybe_downgrade_packet,
|
||||
start: Some(MuxStart {
|
||||
rx,
|
||||
downgrade,
|
||||
extensions,
|
||||
actor_rx,
|
||||
}),
|
||||
tx,
|
||||
flow_stream_types: flow_extensions,
|
||||
|
||||
actor_rx: Some(fut_rx),
|
||||
actor_tx: fut_tx,
|
||||
fut_exited: fut_exited.clone(),
|
||||
mux,
|
||||
|
||||
tcp_extensions: Self::get_tcp_extensions(&extensions),
|
||||
extensions: Some(extensions),
|
||||
streams: StreamMap::default(),
|
||||
current_id: 0,
|
||||
buffer_size,
|
||||
target_buffer_size: 0,
|
||||
|
||||
role: Role::Client,
|
||||
|
||||
stream_map: HashMap::default(),
|
||||
|
||||
server_tx,
|
||||
actor_tx,
|
||||
},
|
||||
actor_exited: fut_exited,
|
||||
actor_tx: ret_fut_tx,
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn into_future(mut self) -> Result<(), WispError> {
|
||||
let ret = self.stream_loop().await;
|
||||
let ret = self.entry().await;
|
||||
|
||||
self.fut_exited.store(true, Ordering::Release);
|
||||
|
||||
for stream in self.stream_map.values() {
|
||||
Self::close_stream(stream, ClosePacket::new(CloseReason::Unknown));
|
||||
for stream in self.streams.drain() {
|
||||
Self::close_stream(
|
||||
stream.1,
|
||||
ClosePacket {
|
||||
reason: CloseReason::Unknown,
|
||||
},
|
||||
);
|
||||
}
|
||||
self.stream_map.clear();
|
||||
|
||||
let _ = self.tx.close().await;
|
||||
self.tx.lock().await;
|
||||
let _ = self.tx.get().close().await;
|
||||
self.tx.unlock();
|
||||
|
||||
ret
|
||||
}
|
||||
|
||||
fn create_new_stream(
|
||||
&mut self,
|
||||
stream_id: u32,
|
||||
stream_type: StreamType,
|
||||
) -> (MuxMapValue, MuxStream<W>) {
|
||||
let (ch_tx, ch_rx) = mpsc::bounded(if self.role == Role::Server {
|
||||
self.buffer_size as usize
|
||||
} else {
|
||||
usize::MAX - 8
|
||||
});
|
||||
async fn entry(&mut self) -> Result<(), WispError> {
|
||||
let MuxStart {
|
||||
rx,
|
||||
downgrade,
|
||||
extensions,
|
||||
actor_rx,
|
||||
} = self.start.take().ok_or(WispError::MuxTaskStarted)?;
|
||||
|
||||
let should_flow_control = self.tcp_extensions.contains(&stream_type.into());
|
||||
let flow_control_event: Arc<Event> = Event::new().into();
|
||||
let flow_control: Arc<AtomicU32> = AtomicU32::new(self.buffer_size).into();
|
||||
|
||||
let is_closed: Arc<AtomicBool> = AtomicBool::new(false).into();
|
||||
let close_reason: Arc<AtomicCloseReason> =
|
||||
AtomicCloseReason::new(CloseReason::Unknown).into();
|
||||
let is_closed_event: Arc<Event> = Event::new().into();
|
||||
|
||||
(
|
||||
MuxMapValue {
|
||||
stream: ch_tx,
|
||||
stream_type,
|
||||
|
||||
should_flow_control,
|
||||
flow_control: flow_control.clone(),
|
||||
flow_control_event: flow_control_event.clone(),
|
||||
|
||||
is_closed: is_closed.clone(),
|
||||
close_reason: close_reason.clone(),
|
||||
is_closed_event: is_closed_event.clone(),
|
||||
},
|
||||
MuxStream::new(
|
||||
stream_id,
|
||||
self.role,
|
||||
stream_type,
|
||||
ch_rx,
|
||||
self.actor_tx.clone(),
|
||||
self.tx.clone(),
|
||||
is_closed,
|
||||
is_closed_event,
|
||||
close_reason,
|
||||
should_flow_control,
|
||||
flow_control,
|
||||
flow_control_event,
|
||||
self.target_buffer_size,
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
fn close_stream(stream: &MuxMapValue, close_packet: ClosePacket) {
|
||||
stream
|
||||
.close_reason
|
||||
.store(close_packet.reason, Ordering::Release);
|
||||
stream.is_closed.store(true, Ordering::Release);
|
||||
stream.is_closed_event.notify(usize::MAX);
|
||||
stream.flow_control.store(u32::MAX, Ordering::Release);
|
||||
stream.flow_control_event.notify(usize::MAX);
|
||||
}
|
||||
|
||||
async fn process_wisp_message(
|
||||
rx: &mut R,
|
||||
tx: &LockedWebSocketWrite<W>,
|
||||
extensions: &mut [AnyProtocolExtension],
|
||||
msg: (Frame<'static>, Option<Frame<'static>>),
|
||||
) -> Result<Option<WsEvent<W>>, WispError> {
|
||||
let (mut frame, optional_frame) = msg;
|
||||
if frame.opcode == OpCode::Close {
|
||||
return Ok(None);
|
||||
} else if frame.opcode == OpCode::Ping {
|
||||
return Ok(Some(WsEvent::SendPong(frame.payload)));
|
||||
} else if frame.opcode == OpCode::Pong {
|
||||
return Ok(Some(WsEvent::Noop));
|
||||
}
|
||||
|
||||
if let Some(ref extra_frame) = optional_frame {
|
||||
if frame.payload[0] != PacketType::Data(Payload::Bytes(BytesMut::new())).as_u8() {
|
||||
let mut payload = BytesMut::from(frame.payload);
|
||||
payload.extend_from_slice(&extra_frame.payload);
|
||||
frame.payload = Payload::Bytes(payload);
|
||||
}
|
||||
}
|
||||
|
||||
let packet = Packet::maybe_handle_extension(frame, extensions, rx, tx).await?;
|
||||
|
||||
Ok(Some(WsEvent::WispMessage(packet, optional_frame)))
|
||||
}
|
||||
|
||||
async fn stream_loop(&mut self) -> Result<(), WispError> {
|
||||
let mut next_free_stream_id: u32 = 1;
|
||||
|
||||
let rx = self.rx.take().ok_or(WispError::MuxTaskStarted)?;
|
||||
let maybe_downgrade_packet = self.maybe_downgrade_packet.take();
|
||||
|
||||
let tx = self.tx.clone();
|
||||
let fut_rx = self.actor_rx.take().ok_or(WispError::MuxTaskStarted)?;
|
||||
|
||||
let extensions = self.extensions.take().ok_or(WispError::MuxTaskStarted)?;
|
||||
|
||||
if let Some(downgrade_packet) = maybe_downgrade_packet {
|
||||
if self.handle_packet(downgrade_packet, None).await? {
|
||||
if let Some(packet) = downgrade {
|
||||
if self.handle_packet(packet)? {
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
|
||||
let mut read_stream = Box::pin(unfold(
|
||||
(rx, tx.clone(), extensions),
|
||||
|(mut rx, tx, mut extensions)| async {
|
||||
let ret = async {
|
||||
let msg = rx.wisp_read_split(&tx).await?;
|
||||
Self::process_wisp_message(&mut rx, &tx, &mut extensions, msg).await
|
||||
let read_stream = pin!(unfold(
|
||||
(rx, self.tx.clone(), extensions),
|
||||
|(mut rx, mut tx, mut extensions)| async {
|
||||
let ret: Result<Option<WsEvent<W>>, WispError> = async {
|
||||
if let Some(msg) = rx.next().await {
|
||||
match MaybeExtensionPacket::decode(msg?, &mut extensions, &mut rx, &mut tx)
|
||||
.await?
|
||||
{
|
||||
MaybeExtensionPacket::Packet(x) => Ok(Some(WsEvent::WispMessage(x))),
|
||||
MaybeExtensionPacket::ExtensionHandled => Ok(None),
|
||||
}
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
.await;
|
||||
ret.transpose().map(|x| (x, (rx, tx, extensions)))
|
||||
},
|
||||
))
|
||||
.fuse();
|
||||
));
|
||||
|
||||
let mut recv_fut = fut_rx.recv_async().fuse();
|
||||
while let Some(msg) = select! {
|
||||
x = recv_fut => {
|
||||
drop(recv_fut);
|
||||
recv_fut = fut_rx.recv_async().fuse();
|
||||
Ok(x.ok())
|
||||
},
|
||||
x = read_stream.next() => {
|
||||
x.transpose()
|
||||
}
|
||||
}? {
|
||||
match msg {
|
||||
WsEvent::CreateStream(stream_type, host, port, channel) => {
|
||||
let mut stream = select(read_stream, actor_rx.into_stream().map(Ok));
|
||||
|
||||
while let Some(msg) = stream.next().await {
|
||||
match msg? {
|
||||
WsEvent::CreateStream(connect, channel) => {
|
||||
let ret: Result<MuxStream<W>, WispError> = async {
|
||||
let stream_id = next_free_stream_id;
|
||||
let next_stream_id = next_free_stream_id
|
||||
.checked_add(1)
|
||||
.ok_or(WispError::MaxStreamCountReached)?;
|
||||
|
||||
let (map_value, stream) = self.create_new_stream(stream_id, stream_type);
|
||||
let (stream, stream_id) = self.create_stream(connect.stream_type)?;
|
||||
|
||||
self.tx.lock().await;
|
||||
self.tx
|
||||
.write_frame(
|
||||
Packet::new_connect(stream_id, stream_type, port, host).into(),
|
||||
.get()
|
||||
.send(
|
||||
Packet {
|
||||
stream_id,
|
||||
packet_type: PacketType::Connect(connect),
|
||||
}
|
||||
.encode(),
|
||||
)
|
||||
.await?;
|
||||
|
||||
self.stream_map.insert(stream_id, map_value);
|
||||
|
||||
next_free_stream_id = next_stream_id;
|
||||
self.tx.unlock();
|
||||
|
||||
Ok(stream)
|
||||
}
|
||||
.await;
|
||||
let _ = channel.send(ret);
|
||||
}
|
||||
WsEvent::Close(packet, channel) => {
|
||||
if let Some(stream) = self.stream_map.remove(&packet.stream_id) {
|
||||
if let PacketType::Close(close) = packet.packet_type {
|
||||
Self::close_stream(&stream, close);
|
||||
WsEvent::Close(id, close, channel) => {
|
||||
if let Some(stream) = self.streams.remove(&id) {
|
||||
Self::close_stream(stream, close);
|
||||
let pkt = Packet {
|
||||
stream_id: id,
|
||||
packet_type: PacketType::Close(close),
|
||||
}
|
||||
let _ = channel.send(self.tx.write_frame(packet.into()).await);
|
||||
.encode();
|
||||
|
||||
self.tx.lock().await;
|
||||
let ret = self.tx.get().send(pkt).await;
|
||||
self.tx.unlock();
|
||||
|
||||
let _ = channel.send(ret);
|
||||
} else {
|
||||
let _ = channel.send(Err(WispError::InvalidStreamId));
|
||||
let _ = channel.send(Err(WispError::InvalidStreamId(id)));
|
||||
}
|
||||
}
|
||||
WsEvent::SendPing(payload, channel) => {
|
||||
let _ = channel.send(
|
||||
self.tx
|
||||
.write_frame(Frame::new(OpCode::Ping, payload, true))
|
||||
.await,
|
||||
);
|
||||
}
|
||||
WsEvent::SendPong(payload) => {
|
||||
self.tx
|
||||
.write_frame(Frame::new(OpCode::Pong, payload, true))
|
||||
.await?;
|
||||
}
|
||||
WsEvent::EndFut(x) => {
|
||||
if let Some(reason) = x {
|
||||
self.tx.lock().await;
|
||||
let _ = self
|
||||
.tx
|
||||
.write_frame(Packet::new_close(0, reason).into())
|
||||
.get()
|
||||
.send(Packet::new_close(0, reason).encode())
|
||||
.await;
|
||||
self.tx.unlock();
|
||||
}
|
||||
break;
|
||||
}
|
||||
WsEvent::WispMessage(packet, optional_frame) => {
|
||||
if let Some(packet) = packet {
|
||||
let should_break = self.handle_packet(packet, optional_frame).await?;
|
||||
if should_break {
|
||||
break;
|
||||
}
|
||||
WsEvent::WispMessage(packet) => {
|
||||
let should_break = self.handle_packet(packet)?;
|
||||
if should_break {
|
||||
break;
|
||||
}
|
||||
}
|
||||
WsEvent::Noop => {}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn handle_close_packet(&mut self, stream_id: u32, inner_packet: ClosePacket) -> bool {
|
||||
fn create_stream(&mut self, ty: StreamType) -> Result<(MuxStream<W>, u32), WispError> {
|
||||
let id = self
|
||||
.current_id
|
||||
.checked_add(1)
|
||||
.ok_or(WispError::MaxStreamCountReached)?;
|
||||
self.current_id = id;
|
||||
Ok((self.add_stream(id, ty), id))
|
||||
}
|
||||
|
||||
fn add_stream(&mut self, id: u32, ty: StreamType) -> MuxStream<W> {
|
||||
let flow = M::get_flow_control(ty, &self.flow_stream_types);
|
||||
let (data_tx, data_rx) = if flow == FlowControl::EnabledSendMessages {
|
||||
flume::bounded(self.buffer_size as usize)
|
||||
} else {
|
||||
flume::unbounded()
|
||||
};
|
||||
|
||||
let info = Arc::new(StreamInfo::new(id, flow, self.buffer_size));
|
||||
let val = StreamMapValue {
|
||||
info: info.clone(),
|
||||
stream: data_tx,
|
||||
};
|
||||
self.streams.insert(id, val);
|
||||
|
||||
MuxStream::new(data_rx, self.actor_tx.clone(), self.tx.clone(), info)
|
||||
}
|
||||
|
||||
fn close_stream(stream: StreamMapValue, close: ClosePacket) {
|
||||
drop(stream.stream);
|
||||
stream.info.set_reason(close.reason);
|
||||
}
|
||||
|
||||
fn handle_packet(&mut self, packet: Packet<'static>) -> Result<bool, WispError> {
|
||||
use PacketType as P;
|
||||
match packet.packet_type {
|
||||
P::Connect(connect) => {
|
||||
let stream = self.add_stream(packet.stream_id, connect.stream_type);
|
||||
self.mux.handle_connect_packet(stream, connect)?;
|
||||
Ok(false)
|
||||
}
|
||||
|
||||
P::Data(data) => {
|
||||
self.mux.handle_data_packet(
|
||||
packet.stream_id,
|
||||
data.into_owned(),
|
||||
&mut self.streams,
|
||||
)?;
|
||||
Ok(false)
|
||||
}
|
||||
|
||||
P::Continue(cont) => {
|
||||
self.mux
|
||||
.handle_continue_packet(packet.stream_id, cont, &mut self.streams)?;
|
||||
Ok(false)
|
||||
}
|
||||
|
||||
P::Close(close) => Ok(self.handle_close_packet(packet.stream_id, close)),
|
||||
}
|
||||
}
|
||||
|
||||
fn handle_close_packet(&mut self, stream_id: u32, close: ClosePacket) -> bool {
|
||||
if stream_id == 0 {
|
||||
return true;
|
||||
}
|
||||
|
||||
if let Some(stream) = self.stream_map.remove(&stream_id) {
|
||||
Self::close_stream(&stream, inner_packet);
|
||||
if let Some(stream) = self.streams.remove(&stream_id) {
|
||||
Self::close_stream(stream, close);
|
||||
}
|
||||
|
||||
false
|
||||
}
|
||||
|
||||
fn handle_data_packet(
|
||||
&mut self,
|
||||
stream_id: u32,
|
||||
optional_frame: Option<Frame<'static>>,
|
||||
data: Payload<'static>,
|
||||
) -> bool {
|
||||
let mut data = BytesMut::from(data);
|
||||
|
||||
if let Some(stream) = self.stream_map.get(&stream_id) {
|
||||
if let Some(extra_frame) = optional_frame {
|
||||
if data.is_empty() {
|
||||
data = extra_frame.payload.into();
|
||||
} else {
|
||||
data.extend_from_slice(&extra_frame.payload);
|
||||
}
|
||||
}
|
||||
let _ = stream.stream.try_send(Payload::Bytes(data));
|
||||
if self.role == Role::Server && stream.should_flow_control {
|
||||
stream.flow_control.store(
|
||||
stream
|
||||
.flow_control
|
||||
.load(Ordering::Acquire)
|
||||
.saturating_sub(1),
|
||||
Ordering::Release,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
false
|
||||
}
|
||||
|
||||
async fn handle_packet(
|
||||
&mut self,
|
||||
packet: Packet<'static>,
|
||||
optional_frame: Option<Frame<'static>>,
|
||||
) -> Result<bool, WispError> {
|
||||
use PacketType as P;
|
||||
match packet.packet_type {
|
||||
P::Data(data) => Ok(self.handle_data_packet(packet.stream_id, optional_frame, data)),
|
||||
P::Close(inner_packet) => Ok(self.handle_close_packet(packet.stream_id, inner_packet)),
|
||||
|
||||
_ => match self.role {
|
||||
Role::Server => self.server_handle_packet(packet, optional_frame).await,
|
||||
Role::Client => self.client_handle_packet(&packet),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
async fn server_handle_packet(
|
||||
&mut self,
|
||||
packet: Packet<'static>,
|
||||
_optional_frame: Option<Frame<'static>>,
|
||||
) -> Result<bool, WispError> {
|
||||
use PacketType as P;
|
||||
match packet.packet_type {
|
||||
P::Connect(inner_packet) => {
|
||||
let (map_value, stream) =
|
||||
self.create_new_stream(packet.stream_id, inner_packet.stream_type);
|
||||
self.server_tx
|
||||
.send_async((inner_packet, stream))
|
||||
.await
|
||||
.map_err(|_| WispError::MuxMessageFailedToSend)?;
|
||||
self.stream_map.insert(packet.stream_id, map_value);
|
||||
Ok(false)
|
||||
}
|
||||
|
||||
// Continue | Info => invalid packet type
|
||||
// Data | Close => specialcased
|
||||
_ => Err(WispError::InvalidPacketType),
|
||||
}
|
||||
}
|
||||
|
||||
fn client_handle_packet(&mut self, packet: &Packet<'static>) -> Result<bool, WispError> {
|
||||
use PacketType as P;
|
||||
match packet.packet_type {
|
||||
P::Continue(inner_packet) => {
|
||||
if let Some(stream) = self.stream_map.get(&packet.stream_id) {
|
||||
if stream.stream_type == StreamType::Tcp {
|
||||
stream
|
||||
.flow_control
|
||||
.store(inner_packet.buffer_remaining, Ordering::Release);
|
||||
let _ = stream.flow_control_event.notify(u32::MAX);
|
||||
}
|
||||
}
|
||||
Ok(false)
|
||||
}
|
||||
|
||||
// Connect | Info => invalid packet type
|
||||
// Data | Close => specialcased
|
||||
_ => Err(WispError::InvalidPacketType),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue