congestion stream types

This commit is contained in:
Toshit Chawda 2024-09-07 10:41:49 -07:00
parent d6c095fe7b
commit f5b50bcc98
No known key found for this signature in database
GPG key ID: 91480ED99E2B3D9D
9 changed files with 103 additions and 62 deletions

View file

@ -87,20 +87,15 @@ impl IncomingBody {
impl Stream for IncomingBody { impl Stream for IncomingBody {
type Item = std::io::Result<Bytes>; type Item = std::io::Result<Bytes>;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let this = self.project(); self.project().incoming.poll_frame(cx).map(|x| {
let ret = this.incoming.poll_frame(cx); x.map(|x| {
match ret { x.map_err(std::io::Error::other).and_then(|x| {
Poll::Ready(item) => Poll::<Option<Self::Item>>::Ready(match item { x.into_data().map_err(|_| {
Some(frame) => frame std::io::Error::other("trailer frame recieved; not implemented")
.map(|x| {
x.into_data()
.map_err(|_| std::io::Error::other("not data frame"))
}) })
.ok(), })
None => None, })
}), })
Poll::Pending => Poll::<Option<Self::Item>>::Pending,
}
} }
} }

View file

@ -37,6 +37,10 @@ impl ProtocolExtension for TWispServerProtocolExtension {
&[0xF0] &[0xF0]
} }
fn get_congestion_stream_types(&self) -> &'static [u8] {
&[0x03]
}
fn encode(&self) -> Bytes { fn encode(&self) -> Bytes {
Bytes::new() Bytes::new()
} }

View file

@ -243,7 +243,10 @@ pub async fn handle_wisp(stream: WispResult, id: String) -> anyhow::Result<()> {
let mut set: JoinSet<()> = JoinSet::new(); let mut set: JoinSet<()> = JoinSet::new();
let event: Arc<Event> = Event::new().into(); let event: Arc<Event> = Event::new().into();
set.spawn(tokio::task::unconstrained(fut.map(|_| {}))); let mux_id = id.clone();
set.spawn(tokio::task::unconstrained(fut.map(move |x| {
trace!("wisp client id {:?} multiplexor result {:?}", mux_id, x)
})));
while let Some((connect, stream)) = mux.server_new_stream().await { while let Some((connect, stream)) = mux.server_new_stream().await {
set.spawn(handle_stream( set.spawn(handle_stream(

View file

@ -65,6 +65,10 @@ pub trait ProtocolExtension: std::fmt::Debug {
/// ///
/// Used to decide whether to call the protocol extension's packet handler. /// Used to decide whether to call the protocol extension's packet handler.
fn get_supported_packets(&self) -> &'static [u8]; fn get_supported_packets(&self) -> &'static [u8];
/// Get stream types that should be treated as TCP.
///
/// Used to decide whether to handle congestion control for that stream type.
fn get_congestion_stream_types(&self) -> &'static [u8];
/// Encode self into Bytes. /// Encode self into Bytes.
fn encode(&self) -> Bytes; fn encode(&self) -> Bytes;

View file

@ -99,6 +99,10 @@ impl ProtocolExtension for PasswordProtocolExtension {
&[] &[]
} }
fn get_congestion_stream_types(&self) -> &'static [u8] {
&[]
}
fn encode(&self) -> Bytes { fn encode(&self) -> Bytes {
match self.role { match self.role {
Role::Server => Bytes::new(), Role::Server => Bytes::new(),

View file

@ -39,6 +39,10 @@ impl ProtocolExtension for UdpProtocolExtension {
&[] &[]
} }
fn get_congestion_stream_types(&self) -> &'static [u8] {
&[]
}
fn encode(&self) -> Bytes { fn encode(&self) -> Bytes {
Bytes::new() Bytes::new()
} }

View file

@ -31,6 +31,7 @@ struct MuxMapValue {
stream: mpsc::Sender<Bytes>, stream: mpsc::Sender<Bytes>,
stream_type: StreamType, stream_type: StreamType,
should_flow_control: bool,
flow_control: Arc<AtomicU32>, flow_control: Arc<AtomicU32>,
flow_control_event: Arc<Event>, flow_control_event: Arc<Event>,
@ -44,11 +45,12 @@ pub struct MuxInner<R: WebSocketRead + Send> {
rx: Option<R>, rx: Option<R>,
tx: LockedWebSocketWrite, tx: LockedWebSocketWrite,
extensions: Vec<AnyProtocolExtension>, extensions: Vec<AnyProtocolExtension>,
tcp_extensions: Vec<u8>,
role: Role, role: Role,
// gets taken by the mux task // gets taken by the mux task
fut_rx: Option<mpsc::Receiver<WsEvent>>, actor_rx: Option<mpsc::Receiver<WsEvent>>,
fut_tx: mpsc::Sender<WsEvent>, actor_tx: mpsc::Sender<WsEvent>,
fut_exited: Arc<AtomicBool>, fut_exited: Arc<AtomicBool>,
stream_map: IntMap<u32, MuxMapValue>, stream_map: IntMap<u32, MuxMapValue>,
@ -59,16 +61,29 @@ pub struct MuxInner<R: WebSocketRead + Send> {
server_tx: mpsc::Sender<(ConnectPacket, MuxStream)>, server_tx: mpsc::Sender<(ConnectPacket, MuxStream)>,
} }
pub struct MuxInnerResult<R: WebSocketRead + Send> {
pub mux: MuxInner<R>,
pub actor_exited: Arc<AtomicBool>,
pub actor_tx: mpsc::Sender<WsEvent>,
}
impl<R: WebSocketRead + Send> MuxInner<R> { impl<R: WebSocketRead + Send> MuxInner<R> {
fn get_tcp_extensions(extensions: &[AnyProtocolExtension]) -> Vec<u8> {
extensions
.iter()
.flat_map(|x| x.get_congestion_stream_types())
.copied()
.chain(std::iter::once(StreamType::Tcp.into()))
.collect()
}
pub fn new_server( pub fn new_server(
rx: R, rx: R,
tx: LockedWebSocketWrite, tx: LockedWebSocketWrite,
extensions: Vec<AnyProtocolExtension>, extensions: Vec<AnyProtocolExtension>,
buffer_size: u32, buffer_size: u32,
) -> ( ) -> (
Self, MuxInnerResult<R>,
Arc<AtomicBool>,
mpsc::Sender<WsEvent>,
mpsc::Receiver<(ConnectPacket, MuxStream)>, mpsc::Receiver<(ConnectPacket, MuxStream)>,
) { ) {
let (fut_tx, fut_rx) = mpsc::bounded::<WsEvent>(256); let (fut_tx, fut_rx) = mpsc::bounded::<WsEvent>(256);
@ -77,26 +92,29 @@ impl<R: WebSocketRead + Send> MuxInner<R> {
let fut_exited = Arc::new(AtomicBool::new(false)); let fut_exited = Arc::new(AtomicBool::new(false));
( (
Self { MuxInnerResult {
rx: Some(rx), mux: Self {
tx, rx: Some(rx),
tx,
fut_rx: Some(fut_rx), actor_rx: Some(fut_rx),
fut_tx, actor_tx: fut_tx,
fut_exited: fut_exited.clone(), fut_exited: fut_exited.clone(),
extensions, tcp_extensions: Self::get_tcp_extensions(&extensions),
buffer_size, extensions,
target_buffer_size: ((buffer_size as u64 * 90) / 100) as u32, buffer_size,
target_buffer_size: ((buffer_size as u64 * 90) / 100) as u32,
role: Role::Server, role: Role::Server,
stream_map: IntMap::default(), stream_map: IntMap::default(),
server_tx, server_tx,
},
actor_exited: fut_exited,
actor_tx: ret_fut_tx,
}, },
fut_exited,
ret_fut_tx,
server_rx, server_rx,
) )
} }
@ -106,21 +124,22 @@ impl<R: WebSocketRead + Send> MuxInner<R> {
tx: LockedWebSocketWrite, tx: LockedWebSocketWrite,
extensions: Vec<AnyProtocolExtension>, extensions: Vec<AnyProtocolExtension>,
buffer_size: u32, buffer_size: u32,
) -> (Self, Arc<AtomicBool>, mpsc::Sender<WsEvent>) { ) -> MuxInnerResult<R> {
let (fut_tx, fut_rx) = mpsc::bounded::<WsEvent>(256); let (fut_tx, fut_rx) = mpsc::bounded::<WsEvent>(256);
let (server_tx, _) = mpsc::unbounded::<(ConnectPacket, MuxStream)>(); let (server_tx, _) = mpsc::unbounded::<(ConnectPacket, MuxStream)>();
let ret_fut_tx = fut_tx.clone(); let ret_fut_tx = fut_tx.clone();
let fut_exited = Arc::new(AtomicBool::new(false)); let fut_exited = Arc::new(AtomicBool::new(false));
( MuxInnerResult {
Self { mux: Self {
rx: Some(rx), rx: Some(rx),
tx, tx,
fut_rx: Some(fut_rx), actor_rx: Some(fut_rx),
fut_tx, actor_tx: fut_tx,
fut_exited: fut_exited.clone(), fut_exited: fut_exited.clone(),
tcp_extensions: Self::get_tcp_extensions(&extensions),
extensions, extensions,
buffer_size, buffer_size,
target_buffer_size: 0, target_buffer_size: 0,
@ -131,9 +150,9 @@ impl<R: WebSocketRead + Send> MuxInner<R> {
server_tx, server_tx,
}, },
fut_exited, actor_exited: fut_exited,
ret_fut_tx, actor_tx: ret_fut_tx,
) }
} }
pub async fn into_future(mut self) -> Result<(), WispError> { pub async fn into_future(mut self) -> Result<(), WispError> {
@ -157,6 +176,7 @@ impl<R: WebSocketRead + Send> MuxInner<R> {
) -> Result<(MuxMapValue, MuxStream), WispError> { ) -> Result<(MuxMapValue, MuxStream), WispError> {
let (ch_tx, ch_rx) = mpsc::bounded(self.buffer_size as usize); let (ch_tx, ch_rx) = mpsc::bounded(self.buffer_size as usize);
let should_flow_control = self.tcp_extensions.contains(&stream_type.into());
let flow_control_event: Arc<Event> = Event::new().into(); let flow_control_event: Arc<Event> = Event::new().into();
let flow_control: Arc<AtomicU32> = AtomicU32::new(self.buffer_size).into(); let flow_control: Arc<AtomicU32> = AtomicU32::new(self.buffer_size).into();
@ -170,6 +190,7 @@ impl<R: WebSocketRead + Send> MuxInner<R> {
stream: ch_tx, stream: ch_tx,
stream_type, stream_type,
should_flow_control,
flow_control: flow_control.clone(), flow_control: flow_control.clone(),
flow_control_event: flow_control_event.clone(), flow_control_event: flow_control_event.clone(),
@ -182,11 +203,12 @@ impl<R: WebSocketRead + Send> MuxInner<R> {
self.role, self.role,
stream_type, stream_type,
ch_rx, ch_rx,
self.fut_tx.clone(), self.actor_tx.clone(),
self.tx.clone(), self.tx.clone(),
is_closed, is_closed,
is_closed_event, is_closed_event,
close_reason, close_reason,
should_flow_control,
flow_control, flow_control,
flow_control_event, flow_control_event,
self.target_buffer_size, self.target_buffer_size,
@ -233,7 +255,7 @@ impl<R: WebSocketRead + Send> MuxInner<R> {
let mut rx = self.rx.take().ok_or(WispError::MuxTaskStarted)?; let mut rx = self.rx.take().ok_or(WispError::MuxTaskStarted)?;
let tx = self.tx.clone(); let tx = self.tx.clone();
let fut_rx = self.fut_rx.take().ok_or(WispError::MuxTaskStarted)?; let fut_rx = self.actor_rx.take().ok_or(WispError::MuxTaskStarted)?;
let mut recv_fut = fut_rx.recv_async().fuse(); let mut recv_fut = fut_rx.recv_async().fuse();
let mut read_fut = rx.wisp_read_split(&tx).fuse(); let mut read_fut = rx.wisp_read_split(&tx).fuse();
@ -349,7 +371,7 @@ impl<R: WebSocketRead + Send> MuxInner<R> {
} }
} }
let _ = stream.stream.try_send(data.freeze()); let _ = stream.stream.try_send(data.freeze());
if self.role == Role::Server && stream.stream_type == StreamType::Tcp { if self.role == Role::Server && stream.should_flow_control {
stream.flow_control.store( stream.flow_control.store(
stream stream
.flow_control .flow_control

View file

@ -224,7 +224,7 @@ pub struct ServerMux {
actor_tx: mpsc::Sender<WsEvent>, actor_tx: mpsc::Sender<WsEvent>,
muxstream_recv: mpsc::Receiver<(ConnectPacket, MuxStream)>, muxstream_recv: mpsc::Receiver<(ConnectPacket, MuxStream)>,
tx: ws::LockedWebSocketWrite, tx: ws::LockedWebSocketWrite,
fut_exited: Arc<AtomicBool>, actor_exited: Arc<AtomicBool>,
} }
impl ServerMux { impl ServerMux {
@ -267,7 +267,7 @@ impl ServerMux {
let supported_extension_ids = supported_extensions.iter().map(|x| x.get_id()).collect(); let supported_extension_ids = supported_extensions.iter().map(|x| x.get_id()).collect();
let (mux_inner, fut_exited, actor_tx, muxstream_recv) = MuxInner::new_server( let (mux_result, muxstream_recv) = MuxInner::new_server(
AppendingWebSocketRead(extra_packet, rx), AppendingWebSocketRead(extra_packet, rx),
tx.clone(), tx.clone(),
supported_extensions, supported_extensions,
@ -277,26 +277,26 @@ impl ServerMux {
Ok(ServerMuxResult( Ok(ServerMuxResult(
Self { Self {
muxstream_recv, muxstream_recv,
actor_tx, actor_tx: mux_result.actor_tx,
downgraded, downgraded,
supported_extension_ids, supported_extension_ids,
tx, tx,
fut_exited: fut_exited.clone(), actor_exited: mux_result.actor_exited,
}, },
mux_inner.into_future(), mux_result.mux.into_future(),
)) ))
} }
/// Wait for a stream to be created. /// Wait for a stream to be created.
pub async fn server_new_stream(&self) -> Option<(ConnectPacket, MuxStream)> { pub async fn server_new_stream(&self) -> Option<(ConnectPacket, MuxStream)> {
if self.fut_exited.load(Ordering::Acquire) { if self.actor_exited.load(Ordering::Acquire) {
return None; return None;
} }
self.muxstream_recv.recv_async().await.ok() self.muxstream_recv.recv_async().await.ok()
} }
async fn close_internal(&self, reason: Option<CloseReason>) -> Result<(), WispError> { async fn close_internal(&self, reason: Option<CloseReason>) -> Result<(), WispError> {
if self.fut_exited.load(Ordering::Acquire) { if self.actor_exited.load(Ordering::Acquire) {
return Err(WispError::MuxTaskEnded); return Err(WispError::MuxTaskEnded);
} }
self.actor_tx self.actor_tx
@ -325,7 +325,7 @@ impl ServerMux {
MuxProtocolExtensionStream { MuxProtocolExtensionStream {
stream_id: 0, stream_id: 0,
tx: self.tx.clone(), tx: self.tx.clone(),
is_closed: self.fut_exited.clone(), is_closed: self.actor_exited.clone(),
} }
} }
} }
@ -401,7 +401,7 @@ pub struct ClientMux {
pub supported_extension_ids: Vec<u8>, pub supported_extension_ids: Vec<u8>,
actor_tx: mpsc::Sender<WsEvent>, actor_tx: mpsc::Sender<WsEvent>,
tx: ws::LockedWebSocketWrite, tx: ws::LockedWebSocketWrite,
fut_exited: Arc<AtomicBool>, actor_exited: Arc<AtomicBool>,
} }
impl ClientMux { impl ClientMux {
@ -450,7 +450,7 @@ impl ClientMux {
let supported_extension_ids = supported_extensions.iter().map(|x| x.get_id()).collect(); let supported_extension_ids = supported_extensions.iter().map(|x| x.get_id()).collect();
let (mux_inner, fut_exited, actor_tx) = MuxInner::new_client( let mux_result = MuxInner::new_client(
AppendingWebSocketRead(extra_packet, rx), AppendingWebSocketRead(extra_packet, rx),
tx.clone(), tx.clone(),
supported_extensions, supported_extensions,
@ -459,13 +459,13 @@ impl ClientMux {
Ok(ClientMuxResult( Ok(ClientMuxResult(
Self { Self {
actor_tx, actor_tx: mux_result.actor_tx,
downgraded, downgraded,
supported_extension_ids, supported_extension_ids,
tx, tx,
fut_exited, actor_exited: mux_result.actor_exited,
}, },
mux_inner.into_future(), mux_result.mux.into_future(),
)) ))
} else { } else {
Err(WispError::InvalidPacketType) Err(WispError::InvalidPacketType)
@ -479,7 +479,7 @@ impl ClientMux {
host: String, host: String,
port: u16, port: u16,
) -> Result<MuxStream, WispError> { ) -> Result<MuxStream, WispError> {
if self.fut_exited.load(Ordering::Acquire) { if self.actor_exited.load(Ordering::Acquire) {
return Err(WispError::MuxTaskEnded); return Err(WispError::MuxTaskEnded);
} }
if stream_type == StreamType::Udp if stream_type == StreamType::Udp
@ -501,7 +501,7 @@ impl ClientMux {
} }
async fn close_internal(&self, reason: Option<CloseReason>) -> Result<(), WispError> { async fn close_internal(&self, reason: Option<CloseReason>) -> Result<(), WispError> {
if self.fut_exited.load(Ordering::Acquire) { if self.actor_exited.load(Ordering::Acquire) {
return Err(WispError::MuxTaskEnded); return Err(WispError::MuxTaskEnded);
} }
self.actor_tx self.actor_tx
@ -530,7 +530,7 @@ impl ClientMux {
MuxProtocolExtensionStream { MuxProtocolExtensionStream {
stream_id: 0, stream_id: 0,
tx: self.tx.clone(), tx: self.tx.clone(),
is_closed: self.fut_exited.clone(), is_closed: self.actor_exited.clone(),
} }
} }
} }

View file

@ -40,6 +40,7 @@ pub struct MuxStreamRead {
is_closed_event: Arc<Event>, is_closed_event: Arc<Event>,
close_reason: Arc<AtomicCloseReason>, close_reason: Arc<AtomicCloseReason>,
should_flow_control: bool,
flow_control: Arc<AtomicU32>, flow_control: Arc<AtomicU32>,
flow_control_read: AtomicU32, flow_control_read: AtomicU32,
target_flow_control: u32, target_flow_control: u32,
@ -55,7 +56,7 @@ impl MuxStreamRead {
x = self.rx.recv_async() => x.ok()?, x = self.rx.recv_async() => x.ok()?,
_ = self.is_closed_event.listen().fuse() => return None _ = self.is_closed_event.listen().fuse() => return None
}; };
if self.role == Role::Server && self.stream_type == StreamType::Tcp { if self.role == Role::Server && self.should_flow_control {
let val = self.flow_control_read.fetch_add(1, Ordering::AcqRel) + 1; let val = self.flow_control_read.fetch_add(1, Ordering::AcqRel) + 1;
if val > self.target_flow_control && !self.is_closed.load(Ordering::Acquire) { if val > self.target_flow_control && !self.is_closed.load(Ordering::Acquire) {
self.tx self.tx
@ -114,6 +115,7 @@ pub struct MuxStreamWrite {
close_reason: Arc<AtomicCloseReason>, close_reason: Arc<AtomicCloseReason>,
continue_recieved: Arc<Event>, continue_recieved: Arc<Event>,
should_flow_control: bool,
flow_control: Arc<AtomicU32>, flow_control: Arc<AtomicU32>,
} }
@ -124,7 +126,7 @@ impl MuxStreamWrite {
body: Frame<'a>, body: Frame<'a>,
) -> Result<(), WispError> { ) -> Result<(), WispError> {
if self.role == Role::Client if self.role == Role::Client
&& self.stream_type == StreamType::Tcp && self.should_flow_control
&& self.flow_control.load(Ordering::Acquire) == 0 && self.flow_control.load(Ordering::Acquire) == 0
{ {
self.continue_recieved.listen().await; self.continue_recieved.listen().await;
@ -278,6 +280,7 @@ impl MuxStream {
is_closed: Arc<AtomicBool>, is_closed: Arc<AtomicBool>,
is_closed_event: Arc<Event>, is_closed_event: Arc<Event>,
close_reason: Arc<AtomicCloseReason>, close_reason: Arc<AtomicCloseReason>,
should_flow_control: bool,
flow_control: Arc<AtomicU32>, flow_control: Arc<AtomicU32>,
continue_recieved: Arc<Event>, continue_recieved: Arc<Event>,
target_flow_control: u32, target_flow_control: u32,
@ -293,6 +296,7 @@ impl MuxStream {
is_closed: is_closed.clone(), is_closed: is_closed.clone(),
is_closed_event: is_closed_event.clone(), is_closed_event: is_closed_event.clone(),
close_reason: close_reason.clone(), close_reason: close_reason.clone(),
should_flow_control,
flow_control: flow_control.clone(), flow_control: flow_control.clone(),
flow_control_read: AtomicU32::new(0), flow_control_read: AtomicU32::new(0),
target_flow_control, target_flow_control,
@ -305,6 +309,7 @@ impl MuxStream {
tx, tx,
is_closed: is_closed.clone(), is_closed: is_closed.clone(),
close_reason: close_reason.clone(), close_reason: close_reason.clone(),
should_flow_control,
flow_control: flow_control.clone(), flow_control: flow_control.clone(),
continue_recieved: continue_recieved.clone(), continue_recieved: continue_recieved.clone(),
}, },