finally implement AsyncRead/Write

This commit is contained in:
Toshit Chawda 2024-01-29 19:30:55 -08:00
parent 14ddecf3fd
commit be7d92b4c5
No known key found for this signature in database
GPG key ID: 91480ED99E2B3D9D
4 changed files with 103 additions and 7 deletions

2
Cargo.lock generated
View file

@ -1530,11 +1530,13 @@ checksum = "dff9641d1cd4be8d1a070daf9e3773c5f67e78b4d9d42263020c057706765c04"
name = "wisp-mux" name = "wisp-mux"
version = "0.1.0" version = "0.1.0"
dependencies = [ dependencies = [
"async_io_stream",
"bytes", "bytes",
"dashmap", "dashmap",
"fastwebsockets", "fastwebsockets",
"futures", "futures",
"futures-util", "futures-util",
"pin-project-lite",
"tokio", "tokio",
"ws_stream_wasm", "ws_stream_wasm",
] ]

View file

@ -4,11 +4,13 @@ version = "0.1.0"
edition = "2021" edition = "2021"
[dependencies] [dependencies]
async_io_stream = "0.3.3"
bytes = "1.5.0" bytes = "1.5.0"
dashmap = "5.5.3" dashmap = "5.5.3"
fastwebsockets = { version = "0.6.0", features = ["unstable-split"], optional = true } fastwebsockets = { version = "0.6.0", features = ["unstable-split"], optional = true }
futures = "0.3.30" futures = "0.3.30"
futures-util = "0.3.30" futures-util = "0.3.30"
pin-project-lite = "0.2.13"
tokio = { version = "1.35.1", optional = true } tokio = { version = "1.35.1", optional = true }
ws_stream_wasm = { version = "0.7.4", optional = true } ws_stream_wasm = { version = "0.7.4", optional = true }

View file

@ -1,11 +1,18 @@
use async_io_stream::IoStream;
use bytes::Bytes; use bytes::Bytes;
use futures::{ use futures::{
channel::{mpsc, oneshot}, channel::{mpsc, oneshot},
StreamExt, sink, stream,
task::{Context, Poll},
AsyncRead, AsyncWrite, Sink, Stream, StreamExt,
}; };
use std::sync::{ use pin_project_lite::pin_project;
atomic::{AtomicBool, Ordering}, use std::{
Arc, pin::Pin,
sync::{
atomic::{AtomicBool, Ordering},
Arc,
},
}; };
pub enum WsEvent { pub enum WsEvent {
@ -36,6 +43,19 @@ impl MuxStreamRead {
} }
} }
} }
pub(crate) fn into_stream(self) -> Pin<Box<dyn Stream<Item = Bytes>>> {
Box::pin(stream::unfold(self, |mut rx| async move {
let evt = rx.read().await?;
Some((
match evt {
WsEvent::Send(bytes) => bytes,
WsEvent::Close(_) => return None,
},
rx,
))
}))
}
} }
pub struct MuxStreamWrite<W> pub struct MuxStreamWrite<W>
@ -80,6 +100,16 @@ impl<W: crate::ws::WebSocketWrite> MuxStreamWrite<W> {
self.is_closed.store(true, Ordering::Release); self.is_closed.store(true, Ordering::Release);
Ok(()) Ok(())
} }
pub(crate) fn into_sink<'a>(self) -> Pin<Box<dyn Sink<Bytes, Error = crate::WispError> + 'a>>
where
W: 'a,
{
Box::pin(sink::unfold(self, |mut tx, data| async move {
tx.write(data).await?;
Ok(tx)
}))
}
} }
impl<W: crate::ws::WebSocketWrite> Drop for MuxStreamWrite<W> { impl<W: crate::ws::WebSocketWrite> Drop for MuxStreamWrite<W> {
@ -143,6 +173,16 @@ impl<W: crate::ws::WebSocketWrite> MuxStream<W> {
pub fn into_split(self) -> (MuxStreamRead, MuxStreamWrite<W>) { pub fn into_split(self) -> (MuxStreamRead, MuxStreamWrite<W>) {
(self.rx, self.tx) (self.rx, self.tx)
} }
pub fn into_io<'a>(self) -> MuxStreamIo<'a>
where
W: 'a,
{
MuxStreamIo {
rx: self.rx.into_stream(),
tx: self.tx.into_sink(),
}
}
} }
pub struct MuxStreamCloser { pub struct MuxStreamCloser {
@ -166,3 +206,57 @@ impl MuxStreamCloser {
Ok(()) Ok(())
} }
} }
pin_project! {
pub struct MuxStreamIo<'a> {
#[pin]
rx: Pin<Box<dyn Stream<Item = Bytes> + 'a>>,
#[pin]
tx: Pin<Box<dyn Sink<Bytes, Error = crate::WispError> + 'a>>,
}
}
impl<'a> MuxStreamIo<'a> {
pub fn into_asyncrw(self) -> impl AsyncRead + AsyncWrite + 'a {
IoStream::new(self.map(|x| Ok::<Vec<u8>, std::io::Error>(x.to_vec())))
}
}
impl Stream for MuxStreamIo<'_> {
type Item = Bytes;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
self.project().rx.poll_next(cx)
}
}
impl Sink<Bytes> for MuxStreamIo<'_> {
type Error = crate::WispError;
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.project().tx.poll_ready(cx)
}
fn start_send(self: Pin<&mut Self>, item: Bytes) -> Result<(), Self::Error> {
self.project().tx.start_send(item)
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.project().tx.poll_flush(cx)
}
fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.project().tx.poll_close(cx)
}
}
impl Sink<Vec<u8>> for MuxStreamIo<'_> {
type Error = std::io::Error;
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.project().tx.poll_ready(cx).map_err(std::io::Error::other)
}
fn start_send(self: Pin<&mut Self>, item: Vec<u8>) -> Result<(), Self::Error> {
self.project().tx.start_send(item.into()).map_err(std::io::Error::other)
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.project().tx.poll_flush(cx).map_err(std::io::Error::other)
}
fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.project().tx.poll_close(cx).map_err(std::io::Error::other)
}
}

View file

@ -57,9 +57,7 @@ pub trait WebSocketWrite {
) -> impl std::future::Future<Output = Result<(), crate::WispError>>; ) -> impl std::future::Future<Output = Result<(), crate::WispError>>;
} }
pub struct LockedWebSocketWrite<S>(Arc<Mutex<S>>) pub struct LockedWebSocketWrite<S>(Arc<Mutex<S>>);
where
S: WebSocketWrite;
impl<S: WebSocketWrite> LockedWebSocketWrite<S> { impl<S: WebSocketWrite> LockedWebSocketWrite<S> {
pub fn new(ws: S) -> Self { pub fn new(ws: S) -> Self {