send payload everywhere

This commit is contained in:
Toshit Chawda 2024-11-27 20:20:31 -08:00
parent f7be65ae74
commit 5b3fc56b38
No known key found for this signature in database
GPG key ID: 91480ED99E2B3D9D
5 changed files with 29 additions and 14 deletions

View file

@ -50,7 +50,7 @@ pub fn iostream_from_asyncrw(asyncrw: ProviderAsyncRW, buffer_size: usize) -> Ep
pub fn iostream_from_stream(stream: ProviderUnencryptedStream) -> EpoxyIoStream {
let (rx, tx) = stream.into_split();
create_iostream(
Box::pin(rx.map_err(EpoxyError::Io)),
Box::pin(rx.map_ok(Bytes::from).map_err(EpoxyError::Io)),
Box::pin(tx.sink_map_err(EpoxyError::Io)),
)
}

View file

@ -9,7 +9,7 @@ use crate::{
AtomicCloseReason, ClosePacket, CloseReason, ConnectPacket, MuxStream, Packet, PacketType,
Role, StreamType, WispError,
};
use bytes::{Bytes, BytesMut};
use bytes::BytesMut;
use event_listener::Event;
use flume as mpsc;
use futures::{channel::oneshot, select, stream::unfold, FutureExt, StreamExt};
@ -31,7 +31,7 @@ pub(crate) enum WsEvent<W: WebSocketWrite + 'static> {
}
struct MuxMapValue {
stream: mpsc::Sender<Bytes>,
stream: mpsc::Sender<Payload<'static>>,
stream_type: StreamType,
should_flow_control: bool,
@ -414,7 +414,7 @@ impl<R: WebSocketRead + 'static, W: WebSocketWrite + 'static> MuxInner<R, W> {
data.extend_from_slice(&extra_frame.payload);
}
}
let _ = stream.stream.try_send(data.freeze());
let _ = stream.stream.try_send(Payload::Bytes(data));
if self.role == Role::Server && stream.should_flow_control {
stream.flow_control.store(
stream

View file

@ -7,7 +7,7 @@ use std::{
task::{Context, Poll},
};
use bytes::{Bytes, BytesMut};
use bytes::BytesMut;
use futures::{
ready, stream::IntoAsyncRead, task::noop_waker_ref, AsyncBufRead, AsyncRead, AsyncWrite, Sink,
Stream, TryStreamExt,
@ -47,7 +47,7 @@ impl MuxStreamIo {
}
impl Stream for MuxStreamIo {
type Item = Result<Bytes, std::io::Error>;
type Item = Result<Payload<'static>, std::io::Error>;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
self.project().rx.poll_next(cx)
}
@ -73,7 +73,7 @@ pin_project! {
/// Read side of a multiplexor stream that implements futures `Stream`.
pub struct MuxStreamIoStream {
#[pin]
pub(crate) rx: Pin<Box<dyn Stream<Item = Result<Bytes, WispError>> + Send>>,
pub(crate) rx: Pin<Box<dyn Stream<Item = Result<Payload<'static>, WispError>> + Send>>,
pub(crate) is_closed: Arc<AtomicBool>,
pub(crate) close_reason: Arc<AtomicCloseReason>,
}
@ -96,7 +96,7 @@ impl MuxStreamIoStream {
}
impl Stream for MuxStreamIoStream {
type Item = Result<Bytes, std::io::Error>;
type Item = Result<Payload<'static>, std::io::Error>;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
self.project()
.rx

View file

@ -30,7 +30,7 @@ pub struct MuxStreamRead<W: WebSocketWrite + 'static> {
role: Role,
tx: LockedWebSocketWrite<W>,
rx: mpsc::Receiver<Bytes>,
rx: mpsc::Receiver<Payload<'static>>,
is_closed: Arc<AtomicBool>,
is_closed_event: Arc<Event>,
@ -44,7 +44,7 @@ pub struct MuxStreamRead<W: WebSocketWrite + 'static> {
impl<W: WebSocketWrite + 'static> MuxStreamRead<W> {
/// Read an event from the stream.
pub async fn read(&self) -> Result<Option<Bytes>, WispError> {
pub async fn read(&self) -> Result<Option<Payload<'static>>, WispError> {
if self.rx.is_empty() && self.is_closed.load(Ordering::Acquire) {
return Ok(None);
}
@ -72,7 +72,7 @@ impl<W: WebSocketWrite + 'static> MuxStreamRead<W> {
pub(crate) fn into_inner_stream(
self,
) -> Pin<Box<dyn Stream<Item = Result<Bytes, WispError>> + Send>> {
) -> Pin<Box<dyn Stream<Item = Result<Payload<'static>, WispError>> + Send>> {
Box::pin(stream::unfold(self, |rx| async move {
Some((rx.read().await.transpose()?, rx))
}))
@ -271,7 +271,7 @@ impl<W: WebSocketWrite + 'static> MuxStream<W> {
stream_id: u32,
role: Role,
stream_type: StreamType,
rx: mpsc::Receiver<Bytes>,
rx: mpsc::Receiver<Payload<'static>>,
mux_tx: mpsc::Sender<WsEvent<W>>,
tx: LockedWebSocketWrite<W>,
is_closed: Arc<AtomicBool>,
@ -320,7 +320,7 @@ impl<W: WebSocketWrite + 'static> MuxStream<W> {
}
/// Read an event from the stream.
pub async fn read(&self) -> Result<Option<Bytes>, WispError> {
pub async fn read(&self) -> Result<Option<Payload<'static>>, WispError> {
self.rx.read().await
}

View file

@ -7,7 +7,7 @@
use std::{future::Future, ops::Deref, pin::Pin, sync::Arc};
use crate::WispError;
use bytes::{Buf, BytesMut};
use bytes::{Buf, Bytes, BytesMut};
use futures::{lock::Mutex, TryFutureExt};
/// Payload of the websocket frame.
@ -51,6 +51,15 @@ impl From<Payload<'_>> for BytesMut {
}
}
impl From<Payload<'static>> for Bytes {
fn from(value: Payload<'static>) -> Self {
match value {
Payload::Bytes(x) => x.freeze(),
Payload::Borrowed(x) => x.into(),
}
}
}
impl Deref for Payload<'_> {
type Target = [u8];
fn deref(&self) -> &Self::Target {
@ -61,6 +70,12 @@ impl Deref for Payload<'_> {
}
}
impl AsRef<[u8]> for Payload<'_> {
fn as_ref(&self) -> &[u8] {
self
}
}
impl Clone for Payload<'_> {
fn clone(&self) -> Self {
match self {