read split frames

This commit is contained in:
Toshit Chawda 2024-07-22 11:04:12 -07:00
parent 7f37c8338e
commit 76eeec87dc
No known key found for this signature in database
GPG key ID: 91480ED99E2B3D9D
5 changed files with 653 additions and 13 deletions

491
server/flamegraph.svg Normal file

File diff suppressed because one or more lines are too long

After

Width:  |  Height:  |  Size: 545 KiB

View file

@ -1,5 +1,5 @@
use anyhow::Context;
use fastwebsockets::{upgrade::UpgradeFut, FragmentCollectorRead};
use fastwebsockets::upgrade::UpgradeFut;
use futures_util::FutureExt;
use hyper_util::rt::TokioIo;
use tokio::{
@ -165,7 +165,6 @@ pub async fn handle_wisp(fut: UpgradeFut, id: String) -> anyhow::Result<()> {
assert_eq!(parts.read_buf.len(), 0);
parts.io.into_inner().split()
});
let read = FragmentCollectorRead::new(read);
let (extensions, buffer_size) = CONFIG.wisp.to_opts();

View file

@ -3,7 +3,8 @@ use std::ops::Deref;
use async_trait::async_trait;
use bytes::BytesMut;
use fastwebsockets::{
CloseCode, FragmentCollectorRead, Frame, OpCode, Payload, WebSocketError, WebSocketWrite,
CloseCode, FragmentCollectorRead, Frame, OpCode, Payload, WebSocketError, WebSocketRead,
WebSocketWrite,
};
use tokio::io::{AsyncRead, AsyncWrite};
@ -88,13 +89,114 @@ impl<S: AsyncRead + Unpin + Send> crate::ws::WebSocketRead for FragmentCollector
}
}
#[async_trait]
impl<S: AsyncRead + Unpin + Send> crate::ws::WebSocketRead for WebSocketRead<S> {
async fn wisp_read_frame(
&mut self,
tx: &LockedWebSocketWrite,
) -> Result<crate::ws::Frame<'static>, WispError> {
let mut frame = self
.read_frame(&mut |frame| async { tx.write_frame(frame.into()).await })
.await?;
if frame.opcode == OpCode::Continuation {
return Err(WispError::WsImplError(Box::new(
WebSocketError::InvalidContinuationFrame,
)));
}
let mut buf = BytesMut::from(frame.payload);
let opcode = frame.opcode;
while !frame.fin {
frame = self
.read_frame(&mut |frame| async { tx.write_frame(frame.into()).await })
.await?;
if frame.opcode != OpCode::Continuation {
return Err(WispError::WsImplError(Box::new(
WebSocketError::InvalidContinuationFrame,
)));
}
buf.extend_from_slice(&frame.payload);
}
Ok(crate::ws::Frame {
opcode: opcode.into(),
payload: crate::ws::Payload::Bytes(buf),
finished: frame.fin,
})
}
async fn wisp_read_split(
&mut self,
tx: &LockedWebSocketWrite,
) -> Result<(crate::ws::Frame<'static>, Option<crate::ws::Frame<'static>>), WispError> {
let mut frame_cnt = 1;
let mut frame = self
.read_frame(&mut |frame| async { tx.write_frame(frame.into()).await })
.await?;
let mut extra_frame = None;
if frame.opcode == OpCode::Continuation {
return Err(WispError::WsImplError(Box::new(
WebSocketError::InvalidContinuationFrame,
)));
}
let mut buf = BytesMut::from(frame.payload);
let opcode = frame.opcode;
while !frame.fin {
frame = self
.read_frame(&mut |frame| async { tx.write_frame(frame.into()).await })
.await?;
if frame.opcode != OpCode::Continuation {
return Err(WispError::WsImplError(Box::new(
WebSocketError::InvalidContinuationFrame,
)));
}
if frame_cnt == 1 {
let payload = BytesMut::from(frame.payload);
extra_frame = Some(crate::ws::Frame {
opcode: opcode.into(),
payload: crate::ws::Payload::Bytes(payload),
finished: true,
});
} else if frame_cnt == 2 {
let extra_payload = extra_frame.take().unwrap().payload;
buf.extend_from_slice(&extra_payload);
buf.extend_from_slice(&frame.payload);
} else {
buf.extend_from_slice(&frame.payload);
}
frame_cnt += 1;
}
Ok((
crate::ws::Frame {
opcode: opcode.into(),
payload: crate::ws::Payload::Bytes(buf),
finished: frame.fin,
},
extra_frame,
))
}
}
#[async_trait]
impl<S: AsyncWrite + Unpin + Send> crate::ws::WebSocketWrite for WebSocketWrite<S> {
async fn wisp_write_frame(&mut self, frame: crate::ws::Frame<'_>) -> Result<(), WispError> {
self.write_frame(frame.into()).await.map_err(|e| e.into())
}
async fn wisp_write_split(&mut self, header: crate::ws::Frame<'_>, body: crate::ws::Frame<'_>) -> Result<(), WispError> {
async fn wisp_write_split(
&mut self,
header: crate::ws::Frame<'_>,
body: crate::ws::Frame<'_>,
) -> Result<(), WispError> {
let mut header = Frame::from(header);
header.fin = false;
self.write_frame(header).await?;

View file

@ -29,7 +29,7 @@ use std::{
},
time::Duration,
};
use ws::{AppendingWebSocketRead, LockedWebSocketWrite};
use ws::{AppendingWebSocketRead, LockedWebSocketWrite, Payload};
/// Wisp version supported by this crate.
pub const WISP_VERSION: WispVersion = WispVersion { major: 2, minor: 0 };
@ -352,10 +352,19 @@ impl MuxInner {
let target_buffer_size = ((self.buffer_size as u64 * 90) / 100) as u32;
loop {
let frame = rx.wisp_read_frame(&self.tx).await?;
let (mut frame, optional_frame) = rx.wisp_read_split(&self.tx).await?;
if frame.opcode == ws::OpCode::Close {
break Ok(());
}
if let Some(ref extra_frame) = optional_frame {
if frame.payload[0] != PacketType::Data(Payload::Bytes(BytesMut::new())).as_u8() {
let mut payload = BytesMut::from(frame.payload);
payload.extend_from_slice(&extra_frame.payload);
frame.payload = Payload::Bytes(payload);
}
}
if let Some(packet) =
Packet::maybe_handle_extension(frame, &mut extensions, &mut rx, &self.tx).await?
{
@ -380,8 +389,16 @@ impl MuxInner {
self.stream_map.insert(packet.stream_id, map_value);
}
Data(data) => {
let mut data = BytesMut::from(data);
if let Some(stream) = self.stream_map.get(&packet.stream_id) {
let _ = stream.stream.try_send(BytesMut::from(data).freeze());
if let Some(extra_frame) = optional_frame {
if data.is_empty() {
data = extra_frame.payload.into();
} else {
data.extend_from_slice(&extra_frame.payload);
}
}
let _ = stream.stream.try_send(data.freeze());
if stream.stream_type == StreamType::Tcp {
stream.flow_control.store(
stream
@ -413,11 +430,19 @@ impl MuxInner {
R: ws::WebSocketRead + Send,
{
loop {
let frame = rx.wisp_read_frame(&self.tx).await?;
let (mut frame, optional_frame) = rx.wisp_read_split(&self.tx).await?;
if frame.opcode == ws::OpCode::Close {
break Ok(());
}
if let Some(ref extra_frame) = optional_frame {
if frame.payload[0] != PacketType::Data(Payload::Bytes(BytesMut::new())).as_u8() {
let mut payload = BytesMut::from(frame.payload);
payload.extend_from_slice(&extra_frame.payload);
frame.payload = Payload::Bytes(payload);
}
}
if let Some(packet) =
Packet::maybe_handle_extension(frame, &mut extensions, &mut rx, &self.tx).await?
{
@ -425,11 +450,16 @@ impl MuxInner {
match packet.packet_type {
Connect(_) | Info(_) => break Err(WispError::InvalidPacketType),
Data(data) => {
let mut data = BytesMut::from(data);
if let Some(stream) = self.stream_map.get(&packet.stream_id) {
let _ = stream
.stream
.send_async(BytesMut::from(data).freeze())
.await;
if let Some(extra_frame) = optional_frame {
if data.is_empty() {
data = extra_frame.payload.into();
} else {
data.extend_from_slice(&extra_frame.payload);
}
}
let _ = stream.stream.send_async(data.freeze()).await;
}
}
Continue(inner_packet) => {

View file

@ -156,6 +156,14 @@ pub trait WebSocketRead {
&mut self,
tx: &LockedWebSocketWrite,
) -> Result<Frame<'static>, WispError>;
/// Read a split frame from the socket.
async fn wisp_read_split(
&mut self,
tx: &LockedWebSocketWrite,
) -> Result<(Frame<'static>, Option<Frame<'static>>), WispError> {
self.wisp_read_frame(tx).await.map(|x| (x, None))
}
}
/// Generic WebSocket write trait.
@ -225,6 +233,16 @@ where
if let Some(x) = self.0.take() {
return Ok(x);
}
return self.1.wisp_read_frame(tx).await;
self.1.wisp_read_frame(tx).await
}
async fn wisp_read_split(
&mut self,
tx: &LockedWebSocketWrite,
) -> Result<(Frame<'static>, Option<Frame<'static>>), WispError> {
if let Some(x) = self.0.take() {
return Ok((x, None));
}
self.1.wisp_read_split(tx).await
}
}