rewrite actor

This commit is contained in:
Toshit Chawda 2024-08-31 16:20:56 -07:00
parent b1f56c1dae
commit 9cd87b7243
No known key found for this signature in database
GPG key ID: 91480ED99E2B3D9D
5 changed files with 470 additions and 472 deletions

View file

@ -11,6 +11,7 @@ mod fastwebsockets;
#[cfg(feature = "generic_stream")]
#[cfg_attr(docsrs, doc(cfg(feature = "generic_stream")))]
pub mod generic;
mod inner;
mod packet;
mod sink_unfold;
mod stream;
@ -18,21 +19,19 @@ pub mod ws;
pub use crate::{packet::*, stream::*};
use bytes::{Bytes, BytesMut};
use event_listener::Event;
use extensions::{udp::UdpProtocolExtension, AnyProtocolExtension, ProtocolExtensionBuilder};
use flume as mpsc;
use futures::{channel::oneshot, select, Future, FutureExt};
use futures_timer::Delay;
use scc::HashMap;
use inner::{MuxInner, WsEvent};
use std::{
sync::{
atomic::{AtomicBool, AtomicU32, Ordering},
atomic::{AtomicBool, Ordering},
Arc,
},
time::Duration,
};
use ws::{AppendingWebSocketRead, LockedWebSocketWrite, Payload};
use ws::{AppendingWebSocketRead, LockedWebSocketWrite};
/// Wisp version supported by this crate.
pub const WISP_VERSION: WispVersion = WispVersion { major: 2, minor: 0 };
@ -157,363 +156,6 @@ impl std::fmt::Display for WispError {
impl std::error::Error for WispError {}
struct MuxMapValue {
stream: mpsc::Sender<Bytes>,
stream_type: StreamType,
flow_control: Arc<AtomicU32>,
flow_control_event: Arc<Event>,
is_closed: Arc<AtomicBool>,
close_reason: Arc<AtomicCloseReason>,
is_closed_event: Arc<Event>,
}
impl Drop for MuxMapValue {
fn drop(&mut self) {
self.is_closed.store(true, Ordering::Release);
self.is_closed_event.notify(usize::MAX);
}
}
struct MuxInner {
tx: ws::LockedWebSocketWrite,
stream_map: HashMap<u32, MuxMapValue>,
buffer_size: u32,
fut_exited: Arc<AtomicBool>,
}
impl MuxInner {
pub async fn server_into_future<R>(
self,
rx: R,
extensions: Vec<AnyProtocolExtension>,
close_rx: mpsc::Receiver<WsEvent>,
muxstream_sender: mpsc::Sender<(ConnectPacket, MuxStream)>,
close_tx: mpsc::Sender<WsEvent>,
) -> Result<(), WispError>
where
R: ws::WebSocketRead + Send,
{
self.as_future(
close_rx,
close_tx.clone(),
self.server_loop(rx, extensions, muxstream_sender, close_tx),
)
.await
}
pub async fn client_into_future<R>(
self,
rx: R,
extensions: Vec<AnyProtocolExtension>,
close_rx: mpsc::Receiver<WsEvent>,
close_tx: mpsc::Sender<WsEvent>,
) -> Result<(), WispError>
where
R: ws::WebSocketRead + Send,
{
self.as_future(close_rx, close_tx, self.client_loop(rx, extensions))
.await
}
async fn as_future(
&self,
close_rx: mpsc::Receiver<WsEvent>,
close_tx: mpsc::Sender<WsEvent>,
wisp_fut: impl Future<Output = Result<(), WispError>>,
) -> Result<(), WispError> {
let ret = futures::select! {
_ = self.stream_loop(close_rx, close_tx).fuse() => Ok(()),
x = wisp_fut.fuse() => x,
};
self.fut_exited.store(true, Ordering::Release);
self.stream_map.clear_async().await;
let _ = self.tx.close().await;
ret
}
async fn create_new_stream(
&self,
stream_id: u32,
stream_type: StreamType,
role: Role,
stream_tx: mpsc::Sender<WsEvent>,
tx: LockedWebSocketWrite,
target_buffer_size: u32,
) -> Result<(MuxMapValue, MuxStream), WispError> {
let (ch_tx, ch_rx) = mpsc::bounded(self.buffer_size as usize);
let flow_control_event: Arc<Event> = Event::new().into();
let flow_control: Arc<AtomicU32> = AtomicU32::new(self.buffer_size).into();
let is_closed: Arc<AtomicBool> = AtomicBool::new(false).into();
let close_reason: Arc<AtomicCloseReason> =
AtomicCloseReason::new(CloseReason::Unknown).into();
let is_closed_event: Arc<Event> = Event::new().into();
Ok((
MuxMapValue {
stream: ch_tx,
stream_type,
flow_control: flow_control.clone(),
flow_control_event: flow_control_event.clone(),
is_closed: is_closed.clone(),
close_reason: close_reason.clone(),
is_closed_event: is_closed_event.clone(),
},
MuxStream::new(
stream_id,
role,
stream_type,
ch_rx,
stream_tx,
tx,
is_closed,
is_closed_event,
close_reason,
flow_control,
flow_control_event,
target_buffer_size,
),
))
}
async fn stream_loop(
&self,
stream_rx: mpsc::Receiver<WsEvent>,
stream_tx: mpsc::Sender<WsEvent>,
) {
let mut next_free_stream_id: u32 = 1;
while let Ok(msg) = stream_rx.recv_async().await {
match msg {
WsEvent::CreateStream(stream_type, host, port, channel) => {
let ret: Result<MuxStream, WispError> = async {
let stream_id = next_free_stream_id;
let next_stream_id = next_free_stream_id
.checked_add(1)
.ok_or(WispError::MaxStreamCountReached)?;
let (map_value, stream) = self
.create_new_stream(
stream_id,
stream_type,
Role::Client,
stream_tx.clone(),
self.tx.clone(),
0,
)
.await?;
self.tx
.write_frame(
Packet::new_connect(stream_id, stream_type, port, host).into(),
)
.await?;
self.stream_map.upsert_async(stream_id, map_value).await;
next_free_stream_id = next_stream_id;
Ok(stream)
}
.await;
let _ = channel.send(ret);
}
WsEvent::Close(packet, channel) => {
if let Some((_, stream)) = self.stream_map.remove_async(&packet.stream_id).await
{
if let PacketType::Close(close) = packet.packet_type {
self.close_stream(stream, close);
}
let _ = channel.send(self.tx.write_frame(packet.into()).await);
} else {
let _ = channel.send(Err(WispError::InvalidStreamId));
}
}
WsEvent::EndFut(x) => {
if let Some(reason) = x {
let _ = self
.tx
.write_frame(Packet::new_close(0, reason).into())
.await;
}
break;
}
}
}
}
fn close_stream(&self, stream: MuxMapValue, close_packet: ClosePacket) {
stream
.close_reason
.store(close_packet.reason, Ordering::Release);
stream.is_closed.store(true, Ordering::Release);
stream.is_closed_event.notify(usize::MAX);
stream.flow_control.store(u32::MAX, Ordering::Release);
stream.flow_control_event.notify(usize::MAX);
}
async fn server_loop<R>(
&self,
mut rx: R,
mut extensions: Vec<AnyProtocolExtension>,
muxstream_sender: mpsc::Sender<(ConnectPacket, MuxStream)>,
stream_tx: mpsc::Sender<WsEvent>,
) -> Result<(), WispError>
where
R: ws::WebSocketRead + Send,
{
// will send continues once flow_control is at 10% of max
let target_buffer_size = ((self.buffer_size as u64 * 90) / 100) as u32;
loop {
let (mut frame, optional_frame) = rx.wisp_read_split(&self.tx).await?;
if frame.opcode == ws::OpCode::Close {
break Ok(());
}
if let Some(ref extra_frame) = optional_frame {
if frame.payload[0] != PacketType::Data(Payload::Bytes(BytesMut::new())).as_u8() {
let mut payload = BytesMut::from(frame.payload);
payload.extend_from_slice(&extra_frame.payload);
frame.payload = Payload::Bytes(payload);
}
}
if let Some(packet) =
Packet::maybe_handle_extension(frame, &mut extensions, &mut rx, &self.tx).await?
{
use PacketType::*;
match packet.packet_type {
Continue(_) | Info(_) => break Err(WispError::InvalidPacketType),
Connect(inner_packet) => {
let (map_value, stream) = self
.create_new_stream(
packet.stream_id,
inner_packet.stream_type,
Role::Server,
stream_tx.clone(),
self.tx.clone(),
target_buffer_size,
)
.await?;
muxstream_sender
.send_async((inner_packet, stream))
.await
.map_err(|_| WispError::MuxMessageFailedToSend)?;
self.stream_map
.upsert_async(packet.stream_id, map_value)
.await;
}
Data(data) => {
let mut data = BytesMut::from(data);
if let Some(stream) = self.stream_map.get_async(&packet.stream_id).await {
if let Some(extra_frame) = optional_frame {
if data.is_empty() {
data = extra_frame.payload.into();
} else {
data.extend_from_slice(&extra_frame.payload);
}
}
let _ = stream.stream.try_send(data.freeze());
if stream.stream_type == StreamType::Tcp {
stream.flow_control.store(
stream
.flow_control
.load(Ordering::Acquire)
.saturating_sub(1),
Ordering::Release,
);
}
}
}
Close(inner_packet) => {
if packet.stream_id == 0 {
break Ok(());
}
if let Some((_, stream)) =
self.stream_map.remove_async(&packet.stream_id).await
{
self.close_stream(stream, inner_packet)
}
}
}
}
}
}
async fn client_loop<R>(
&self,
mut rx: R,
mut extensions: Vec<AnyProtocolExtension>,
) -> Result<(), WispError>
where
R: ws::WebSocketRead + Send,
{
loop {
let (mut frame, optional_frame) = rx.wisp_read_split(&self.tx).await?;
if frame.opcode == ws::OpCode::Close {
break Ok(());
}
if let Some(ref extra_frame) = optional_frame {
if frame.payload[0] != PacketType::Data(Payload::Bytes(BytesMut::new())).as_u8() {
let mut payload = BytesMut::from(frame.payload);
payload.extend_from_slice(&extra_frame.payload);
frame.payload = Payload::Bytes(payload);
}
}
if let Some(packet) =
Packet::maybe_handle_extension(frame, &mut extensions, &mut rx, &self.tx).await?
{
use PacketType::*;
match packet.packet_type {
Connect(_) | Info(_) => break Err(WispError::InvalidPacketType),
Data(data) => {
let mut data = BytesMut::from(data);
if let Some(stream) = self.stream_map.get_async(&packet.stream_id).await {
if let Some(extra_frame) = optional_frame {
if data.is_empty() {
data = extra_frame.payload.into();
} else {
data.extend_from_slice(&extra_frame.payload);
}
}
let _ = stream.stream.send_async(data.freeze()).await;
}
}
Continue(inner_packet) => {
if let Some(stream) = self.stream_map.get_async(&packet.stream_id).await {
if stream.stream_type == StreamType::Tcp {
stream
.flow_control
.store(inner_packet.buffer_remaining, Ordering::Release);
let _ = stream.flow_control_event.notify(u32::MAX);
}
}
}
Close(inner_packet) => {
if packet.stream_id == 0 {
break Ok(());
}
if let Some((_, stream)) =
self.stream_map.remove_async(&packet.stream_id).await
{
self.close_stream(stream, inner_packet)
}
}
}
}
}
}
}
async fn maybe_wisp_v2<R>(
read: &mut R,
write: &LockedWebSocketWrite,
@ -576,7 +218,7 @@ pub struct ServerMux {
pub downgraded: bool,
/// Extensions that are supported by both sides.
pub supported_extension_ids: Vec<u8>,
close_tx: mpsc::Sender<WsEvent>,
actor_tx: mpsc::Sender<WsEvent>,
muxstream_recv: mpsc::Receiver<(ConnectPacket, MuxStream)>,
tx: ws::LockedWebSocketWrite,
fut_exited: Arc<AtomicBool>,
@ -589,8 +231,8 @@ impl ServerMux {
/// **It is not guaranteed that all extensions you specify are available.** You must manually check
/// if the extensions you need are available after the multiplexor has been created.
pub async fn create<R, W>(
mut read: R,
write: W,
mut rx: R,
tx: W,
buffer_size: u32,
extension_builders: Option<&[Box<dyn ProtocolExtensionBuilder + Send + Sync>]>,
) -> Result<ServerMuxResult<impl Future<Output = Result<(), WispError>> + Send>, WispError>
@ -598,55 +240,47 @@ impl ServerMux {
R: ws::WebSocketRead + Send,
W: ws::WebSocketWrite + Send + 'static,
{
let (close_tx, close_rx) = mpsc::bounded::<WsEvent>(256);
let (tx, rx) = mpsc::unbounded::<(ConnectPacket, MuxStream)>();
let write = ws::LockedWebSocketWrite::new(Box::new(write));
let fut_exited = Arc::new(AtomicBool::new(false));
let tx = ws::LockedWebSocketWrite::new(Box::new(tx));
write
.write_frame(Packet::new_continue(0, buffer_size).into())
tx.write_frame(Packet::new_continue(0, buffer_size).into())
.await?;
let (supported_extensions, extra_packet, downgraded) =
if let Some(builders) = extension_builders {
write
.write_frame(
Packet::new_info(
builders
.iter()
.map(|x| x.build_to_extension(Role::Client))
.collect(),
)
.into(),
tx.write_frame(
Packet::new_info(
builders
.iter()
.map(|x| x.build_to_extension(Role::Client))
.collect(),
)
.await?;
maybe_wisp_v2(&mut read, &write, builders).await?
.into(),
)
.await?;
maybe_wisp_v2(&mut rx, &tx, builders).await?
} else {
(Vec::new(), None, true)
};
let supported_extension_ids = supported_extensions.iter().map(|x| x.get_id()).collect();
let (mux_inner, fut_exited, actor_tx, muxstream_recv) = MuxInner::new_server(
AppendingWebSocketRead(extra_packet, rx),
tx.clone(),
supported_extensions,
buffer_size,
);
Ok(ServerMuxResult(
Self {
muxstream_recv: rx,
close_tx: close_tx.clone(),
muxstream_recv,
actor_tx,
downgraded,
supported_extension_ids: supported_extensions.iter().map(|x| x.get_id()).collect(),
tx: write.clone(),
supported_extension_ids,
tx,
fut_exited: fut_exited.clone(),
},
MuxInner {
tx: write,
stream_map: HashMap::new(),
buffer_size,
fut_exited,
}
.server_into_future(
AppendingWebSocketRead(extra_packet, read),
supported_extensions,
close_rx,
tx,
close_tx,
),
mux_inner.into_future(),
))
}
@ -662,7 +296,7 @@ impl ServerMux {
if self.fut_exited.load(Ordering::Acquire) {
return Err(WispError::MuxTaskEnded);
}
self.close_tx
self.actor_tx
.send_async(WsEvent::EndFut(reason))
.await
.map_err(|_| WispError::MuxMessageFailedToSend)
@ -695,7 +329,7 @@ impl ServerMux {
impl Drop for ServerMux {
fn drop(&mut self) {
let _ = self.close_tx.send(WsEvent::EndFut(None));
let _ = self.actor_tx.send(WsEvent::EndFut(None));
}
}
@ -762,7 +396,7 @@ pub struct ClientMux {
pub downgraded: bool,
/// Extensions that are supported by both sides.
pub supported_extension_ids: Vec<u8>,
stream_tx: mpsc::Sender<WsEvent>,
actor_tx: mpsc::Sender<WsEvent>,
tx: ws::LockedWebSocketWrite,
fut_exited: Arc<AtomicBool>,
}
@ -774,68 +408,61 @@ impl ClientMux {
/// **It is not guaranteed that all extensions you specify are available.** You must manually check
/// if the extensions you need are available after the multiplexor has been created.
pub async fn create<R, W>(
mut read: R,
write: W,
mut rx: R,
tx: W,
extension_builders: Option<&[Box<dyn ProtocolExtensionBuilder + Send + Sync>]>,
) -> Result<ClientMuxResult<impl Future<Output = Result<(), WispError>> + Send>, WispError>
where
R: ws::WebSocketRead + Send,
W: ws::WebSocketWrite + Send + 'static,
{
let write = ws::LockedWebSocketWrite::new(Box::new(write));
let first_packet = Packet::try_from(read.wisp_read_frame(&write).await?)?;
let fut_exited = Arc::new(AtomicBool::new(false));
let tx = ws::LockedWebSocketWrite::new(Box::new(tx));
let first_packet = Packet::try_from(rx.wisp_read_frame(&tx).await?)?;
if first_packet.stream_id != 0 {
return Err(WispError::InvalidStreamId);
}
if let PacketType::Continue(packet) = first_packet.packet_type {
let (supported_extensions, extra_packet, downgraded) =
if let Some(builders) = extension_builders {
let x = maybe_wisp_v2(&mut read, &write, builders).await?;
let x = maybe_wisp_v2(&mut rx, &tx, builders).await?;
// if not downgraded
if !x.2 {
write
.write_frame(
Packet::new_info(
builders
.iter()
.map(|x| x.build_to_extension(Role::Client))
.collect(),
)
.into(),
tx.write_frame(
Packet::new_info(
builders
.iter()
.map(|x| x.build_to_extension(Role::Client))
.collect(),
)
.await?;
.into(),
)
.await?;
}
x
} else {
(Vec::new(), None, true)
};
let (tx, rx) = mpsc::bounded::<WsEvent>(256);
let supported_extension_ids = supported_extensions.iter().map(|x| x.get_id()).collect();
let (mux_inner, fut_exited, actor_tx) = MuxInner::new_client(
AppendingWebSocketRead(extra_packet, rx),
tx.clone(),
supported_extensions,
packet.buffer_remaining,
);
Ok(ClientMuxResult(
Self {
stream_tx: tx.clone(),
actor_tx,
downgraded,
supported_extension_ids: supported_extensions
.iter()
.map(|x| x.get_id())
.collect(),
tx: write.clone(),
fut_exited: fut_exited.clone(),
},
MuxInner {
tx: write,
stream_map: HashMap::new(),
buffer_size: packet.buffer_remaining,
fut_exited,
}
.client_into_future(
AppendingWebSocketRead(extra_packet, read),
supported_extensions,
rx,
supported_extension_ids,
tx,
),
fut_exited,
},
mux_inner.into_future(),
))
} else {
Err(WispError::InvalidPacketType)
@ -863,7 +490,7 @@ impl ClientMux {
]));
}
let (tx, rx) = oneshot::channel();
self.stream_tx
self.actor_tx
.send_async(WsEvent::CreateStream(stream_type, host, port, tx))
.await
.map_err(|_| WispError::MuxMessageFailedToSend)?;
@ -874,7 +501,7 @@ impl ClientMux {
if self.fut_exited.load(Ordering::Acquire) {
return Err(WispError::MuxTaskEnded);
}
self.stream_tx
self.actor_tx
.send_async(WsEvent::EndFut(reason))
.await
.map_err(|_| WispError::MuxMessageFailedToSend)
@ -907,7 +534,7 @@ impl ClientMux {
impl Drop for ClientMux {
fn drop(&mut self) {
let _ = self.stream_tx.send(WsEvent::EndFut(None));
let _ = self.actor_tx.send(WsEvent::EndFut(None));
}
}