send payload everywhere

This commit is contained in:
Toshit Chawda 2024-11-27 20:20:31 -08:00
parent f7be65ae74
commit 5b3fc56b38
No known key found for this signature in database
GPG key ID: 91480ED99E2B3D9D
5 changed files with 29 additions and 14 deletions

View file

@ -50,7 +50,7 @@ pub fn iostream_from_asyncrw(asyncrw: ProviderAsyncRW, buffer_size: usize) -> Ep
pub fn iostream_from_stream(stream: ProviderUnencryptedStream) -> EpoxyIoStream { pub fn iostream_from_stream(stream: ProviderUnencryptedStream) -> EpoxyIoStream {
let (rx, tx) = stream.into_split(); let (rx, tx) = stream.into_split();
create_iostream( create_iostream(
Box::pin(rx.map_err(EpoxyError::Io)), Box::pin(rx.map_ok(Bytes::from).map_err(EpoxyError::Io)),
Box::pin(tx.sink_map_err(EpoxyError::Io)), Box::pin(tx.sink_map_err(EpoxyError::Io)),
) )
} }

View file

@ -9,7 +9,7 @@ use crate::{
AtomicCloseReason, ClosePacket, CloseReason, ConnectPacket, MuxStream, Packet, PacketType, AtomicCloseReason, ClosePacket, CloseReason, ConnectPacket, MuxStream, Packet, PacketType,
Role, StreamType, WispError, Role, StreamType, WispError,
}; };
use bytes::{Bytes, BytesMut}; use bytes::BytesMut;
use event_listener::Event; use event_listener::Event;
use flume as mpsc; use flume as mpsc;
use futures::{channel::oneshot, select, stream::unfold, FutureExt, StreamExt}; use futures::{channel::oneshot, select, stream::unfold, FutureExt, StreamExt};
@ -31,7 +31,7 @@ pub(crate) enum WsEvent<W: WebSocketWrite + 'static> {
} }
struct MuxMapValue { struct MuxMapValue {
stream: mpsc::Sender<Bytes>, stream: mpsc::Sender<Payload<'static>>,
stream_type: StreamType, stream_type: StreamType,
should_flow_control: bool, should_flow_control: bool,
@ -414,7 +414,7 @@ impl<R: WebSocketRead + 'static, W: WebSocketWrite + 'static> MuxInner<R, W> {
data.extend_from_slice(&extra_frame.payload); data.extend_from_slice(&extra_frame.payload);
} }
} }
let _ = stream.stream.try_send(data.freeze()); let _ = stream.stream.try_send(Payload::Bytes(data));
if self.role == Role::Server && stream.should_flow_control { if self.role == Role::Server && stream.should_flow_control {
stream.flow_control.store( stream.flow_control.store(
stream stream

View file

@ -7,7 +7,7 @@ use std::{
task::{Context, Poll}, task::{Context, Poll},
}; };
use bytes::{Bytes, BytesMut}; use bytes::BytesMut;
use futures::{ use futures::{
ready, stream::IntoAsyncRead, task::noop_waker_ref, AsyncBufRead, AsyncRead, AsyncWrite, Sink, ready, stream::IntoAsyncRead, task::noop_waker_ref, AsyncBufRead, AsyncRead, AsyncWrite, Sink,
Stream, TryStreamExt, Stream, TryStreamExt,
@ -47,7 +47,7 @@ impl MuxStreamIo {
} }
impl Stream for MuxStreamIo { impl Stream for MuxStreamIo {
type Item = Result<Bytes, std::io::Error>; type Item = Result<Payload<'static>, std::io::Error>;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
self.project().rx.poll_next(cx) self.project().rx.poll_next(cx)
} }
@ -73,7 +73,7 @@ pin_project! {
/// Read side of a multiplexor stream that implements futures `Stream`. /// Read side of a multiplexor stream that implements futures `Stream`.
pub struct MuxStreamIoStream { pub struct MuxStreamIoStream {
#[pin] #[pin]
pub(crate) rx: Pin<Box<dyn Stream<Item = Result<Bytes, WispError>> + Send>>, pub(crate) rx: Pin<Box<dyn Stream<Item = Result<Payload<'static>, WispError>> + Send>>,
pub(crate) is_closed: Arc<AtomicBool>, pub(crate) is_closed: Arc<AtomicBool>,
pub(crate) close_reason: Arc<AtomicCloseReason>, pub(crate) close_reason: Arc<AtomicCloseReason>,
} }
@ -96,7 +96,7 @@ impl MuxStreamIoStream {
} }
impl Stream for MuxStreamIoStream { impl Stream for MuxStreamIoStream {
type Item = Result<Bytes, std::io::Error>; type Item = Result<Payload<'static>, std::io::Error>;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
self.project() self.project()
.rx .rx

View file

@ -30,7 +30,7 @@ pub struct MuxStreamRead<W: WebSocketWrite + 'static> {
role: Role, role: Role,
tx: LockedWebSocketWrite<W>, tx: LockedWebSocketWrite<W>,
rx: mpsc::Receiver<Bytes>, rx: mpsc::Receiver<Payload<'static>>,
is_closed: Arc<AtomicBool>, is_closed: Arc<AtomicBool>,
is_closed_event: Arc<Event>, is_closed_event: Arc<Event>,
@ -44,7 +44,7 @@ pub struct MuxStreamRead<W: WebSocketWrite + 'static> {
impl<W: WebSocketWrite + 'static> MuxStreamRead<W> { impl<W: WebSocketWrite + 'static> MuxStreamRead<W> {
/// Read an event from the stream. /// Read an event from the stream.
pub async fn read(&self) -> Result<Option<Bytes>, WispError> { pub async fn read(&self) -> Result<Option<Payload<'static>>, WispError> {
if self.rx.is_empty() && self.is_closed.load(Ordering::Acquire) { if self.rx.is_empty() && self.is_closed.load(Ordering::Acquire) {
return Ok(None); return Ok(None);
} }
@ -72,7 +72,7 @@ impl<W: WebSocketWrite + 'static> MuxStreamRead<W> {
pub(crate) fn into_inner_stream( pub(crate) fn into_inner_stream(
self, self,
) -> Pin<Box<dyn Stream<Item = Result<Bytes, WispError>> + Send>> { ) -> Pin<Box<dyn Stream<Item = Result<Payload<'static>, WispError>> + Send>> {
Box::pin(stream::unfold(self, |rx| async move { Box::pin(stream::unfold(self, |rx| async move {
Some((rx.read().await.transpose()?, rx)) Some((rx.read().await.transpose()?, rx))
})) }))
@ -271,7 +271,7 @@ impl<W: WebSocketWrite + 'static> MuxStream<W> {
stream_id: u32, stream_id: u32,
role: Role, role: Role,
stream_type: StreamType, stream_type: StreamType,
rx: mpsc::Receiver<Bytes>, rx: mpsc::Receiver<Payload<'static>>,
mux_tx: mpsc::Sender<WsEvent<W>>, mux_tx: mpsc::Sender<WsEvent<W>>,
tx: LockedWebSocketWrite<W>, tx: LockedWebSocketWrite<W>,
is_closed: Arc<AtomicBool>, is_closed: Arc<AtomicBool>,
@ -320,7 +320,7 @@ impl<W: WebSocketWrite + 'static> MuxStream<W> {
} }
/// Read an event from the stream. /// Read an event from the stream.
pub async fn read(&self) -> Result<Option<Bytes>, WispError> { pub async fn read(&self) -> Result<Option<Payload<'static>>, WispError> {
self.rx.read().await self.rx.read().await
} }

View file

@ -7,7 +7,7 @@
use std::{future::Future, ops::Deref, pin::Pin, sync::Arc}; use std::{future::Future, ops::Deref, pin::Pin, sync::Arc};
use crate::WispError; use crate::WispError;
use bytes::{Buf, BytesMut}; use bytes::{Buf, Bytes, BytesMut};
use futures::{lock::Mutex, TryFutureExt}; use futures::{lock::Mutex, TryFutureExt};
/// Payload of the websocket frame. /// Payload of the websocket frame.
@ -51,6 +51,15 @@ impl From<Payload<'_>> for BytesMut {
} }
} }
impl From<Payload<'static>> for Bytes {
fn from(value: Payload<'static>) -> Self {
match value {
Payload::Bytes(x) => x.freeze(),
Payload::Borrowed(x) => x.into(),
}
}
}
impl Deref for Payload<'_> { impl Deref for Payload<'_> {
type Target = [u8]; type Target = [u8];
fn deref(&self) -> &Self::Target { fn deref(&self) -> &Self::Target {
@ -61,6 +70,12 @@ impl Deref for Payload<'_> {
} }
} }
impl AsRef<[u8]> for Payload<'_> {
fn as_ref(&self) -> &[u8] {
self
}
}
impl Clone for Payload<'_> { impl Clone for Payload<'_> {
fn clone(&self) -> Self { fn clone(&self) -> Self {
match self { match self {