diff --git a/Cargo.lock b/Cargo.lock index 6326ae1..dc862b4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1530,11 +1530,13 @@ checksum = "dff9641d1cd4be8d1a070daf9e3773c5f67e78b4d9d42263020c057706765c04" name = "wisp-mux" version = "0.1.0" dependencies = [ + "async_io_stream", "bytes", "dashmap", "fastwebsockets", "futures", "futures-util", + "pin-project-lite", "tokio", "ws_stream_wasm", ] diff --git a/wisp/Cargo.toml b/wisp/Cargo.toml index ae279ae..9dc0a2d 100644 --- a/wisp/Cargo.toml +++ b/wisp/Cargo.toml @@ -4,11 +4,13 @@ version = "0.1.0" edition = "2021" [dependencies] +async_io_stream = "0.3.3" bytes = "1.5.0" dashmap = "5.5.3" fastwebsockets = { version = "0.6.0", features = ["unstable-split"], optional = true } futures = "0.3.30" futures-util = "0.3.30" +pin-project-lite = "0.2.13" tokio = { version = "1.35.1", optional = true } ws_stream_wasm = { version = "0.7.4", optional = true } diff --git a/wisp/src/stream.rs b/wisp/src/stream.rs index 8c8a76b..ff86585 100644 --- a/wisp/src/stream.rs +++ b/wisp/src/stream.rs @@ -1,11 +1,18 @@ +use async_io_stream::IoStream; use bytes::Bytes; use futures::{ channel::{mpsc, oneshot}, - StreamExt, + sink, stream, + task::{Context, Poll}, + AsyncRead, AsyncWrite, Sink, Stream, StreamExt, }; -use std::sync::{ - atomic::{AtomicBool, Ordering}, - Arc, +use pin_project_lite::pin_project; +use std::{ + pin::Pin, + sync::{ + atomic::{AtomicBool, Ordering}, + Arc, + }, }; pub enum WsEvent { @@ -36,6 +43,19 @@ impl MuxStreamRead { } } } + + pub(crate) fn into_stream(self) -> Pin>> { + Box::pin(stream::unfold(self, |mut rx| async move { + let evt = rx.read().await?; + Some(( + match evt { + WsEvent::Send(bytes) => bytes, + WsEvent::Close(_) => return None, + }, + rx, + )) + })) + } } pub struct MuxStreamWrite @@ -80,6 +100,16 @@ impl MuxStreamWrite { self.is_closed.store(true, Ordering::Release); Ok(()) } + + pub(crate) fn into_sink<'a>(self) -> Pin + 'a>> + where + W: 'a, + { + Box::pin(sink::unfold(self, |mut tx, data| async move { + tx.write(data).await?; + Ok(tx) + })) + } } impl Drop for MuxStreamWrite { @@ -143,6 +173,16 @@ impl MuxStream { pub fn into_split(self) -> (MuxStreamRead, MuxStreamWrite) { (self.rx, self.tx) } + + pub fn into_io<'a>(self) -> MuxStreamIo<'a> + where + W: 'a, + { + MuxStreamIo { + rx: self.rx.into_stream(), + tx: self.tx.into_sink(), + } + } } pub struct MuxStreamCloser { @@ -166,3 +206,57 @@ impl MuxStreamCloser { Ok(()) } } + +pin_project! { + pub struct MuxStreamIo<'a> { + #[pin] + rx: Pin + 'a>>, + #[pin] + tx: Pin + 'a>>, + } +} + +impl<'a> MuxStreamIo<'a> { + pub fn into_asyncrw(self) -> impl AsyncRead + AsyncWrite + 'a { + IoStream::new(self.map(|x| Ok::, std::io::Error>(x.to_vec()))) + } +} + +impl Stream for MuxStreamIo<'_> { + type Item = Bytes; + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().rx.poll_next(cx) + } +} + +impl Sink for MuxStreamIo<'_> { + type Error = crate::WispError; + fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().tx.poll_ready(cx) + } + fn start_send(self: Pin<&mut Self>, item: Bytes) -> Result<(), Self::Error> { + self.project().tx.start_send(item) + } + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().tx.poll_flush(cx) + } + fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().tx.poll_close(cx) + } +} + +impl Sink> for MuxStreamIo<'_> { + type Error = std::io::Error; + fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().tx.poll_ready(cx).map_err(std::io::Error::other) + } + fn start_send(self: Pin<&mut Self>, item: Vec) -> Result<(), Self::Error> { + self.project().tx.start_send(item.into()).map_err(std::io::Error::other) + } + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().tx.poll_flush(cx).map_err(std::io::Error::other) + } + fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().tx.poll_close(cx).map_err(std::io::Error::other) + } +} diff --git a/wisp/src/ws.rs b/wisp/src/ws.rs index dc8bdcc..5b1243e 100644 --- a/wisp/src/ws.rs +++ b/wisp/src/ws.rs @@ -57,9 +57,7 @@ pub trait WebSocketWrite { ) -> impl std::future::Future>; } -pub struct LockedWebSocketWrite(Arc>) -where - S: WebSocketWrite; +pub struct LockedWebSocketWrite(Arc>); impl LockedWebSocketWrite { pub fn new(ws: S) -> Self {