add wasm ws support

This commit is contained in:
Toshit Chawda 2024-01-28 11:20:41 -08:00
parent 8f85828e73
commit e95d148488
No known key found for this signature in database
GPG key ID: 91480ED99E2B3D9D
7 changed files with 277 additions and 25 deletions

2
Cargo.lock generated
View file

@ -271,6 +271,7 @@ dependencies = [
"wasm-streams", "wasm-streams",
"web-sys", "web-sys",
"webpki-roots", "webpki-roots",
"wisp-mux",
"ws_stream_wasm", "ws_stream_wasm",
] ]
@ -1535,6 +1536,7 @@ dependencies = [
"futures", "futures",
"futures-util", "futures-util",
"tokio", "tokio",
"ws_stream_wasm",
] ]
[[package]] [[package]]

View file

@ -33,6 +33,7 @@ async-compression = { version = "0.4.5", features = ["tokio", "gzip", "brotli"]
fastwebsockets = { version = "0.6.0", features = ["simdutf8", "unstable-split"] } fastwebsockets = { version = "0.6.0", features = ["simdutf8", "unstable-split"] }
rand = "0.8.5" rand = "0.8.5"
base64 = "0.21.7" base64 = "0.21.7"
wisp-mux = { path = "../wisp", features = ["ws_stream_wasm"] }
[dependencies.getrandom] [dependencies.getrandom]
features = ["js"] features = ["js"]

View file

@ -16,7 +16,7 @@ use tokio::net::{TcpListener, TcpStream, UdpSocket};
use tokio_native_tls::{native_tls, TlsAcceptor}; use tokio_native_tls::{native_tls, TlsAcceptor};
use tokio_util::codec::{BytesCodec, Framed}; use tokio_util::codec::{BytesCodec, Framed};
use wisp_mux::{ws, ConnectPacket, MuxStream, Packet, ServerMux, StreamType, WispError, WsEvent}; use wisp_mux::{ws, ConnectPacket, MuxStream, ServerMux, StreamType, WispError, WsEvent};
type HttpBody = http_body_util::Empty<hyper::body::Bytes>; type HttpBody = http_body_util::Empty<hyper::body::Bytes>;
@ -162,7 +162,7 @@ async fn handle_mux(
} }
} }
} }
Ok(false) Ok(true)
} }
async fn accept_ws( async fn accept_ws(
@ -177,22 +177,17 @@ async fn accept_ws(
let mut mux = ServerMux::new(rx, tx); let mut mux = ServerMux::new(rx, tx);
mux.server_loop(&mut |packet, stream| async move { mux.server_loop(&mut |packet, stream| async move {
let tx_cloned_err = stream.get_write_half(); let mut close_err = stream.get_close_handle();
let tx_cloned_ok = stream.get_write_half(); let mut close_ok = stream.get_close_handle();
let stream_id = stream.stream_id;
tokio::spawn(async move { tokio::spawn(async move {
let _ = handle_mux(packet, stream) let _ = handle_mux(packet, stream)
.or_else(|err| async move { .or_else(|err| async move {
let _ = tx_cloned_err let _ = close_err.close(0x03).await;
.write_frame(ws::Frame::from(Packet::new_close(stream_id, 0x03)))
.await;
Err(err) Err(err)
}) })
.and_then(|should_send| async move { .and_then(|should_send| async move {
if should_send { if should_send {
tx_cloned_ok close_ok.close(0x02).await
.write_frame(ws::Frame::from(Packet::new_close(stream_id, 0x02)))
.await
} else { } else {
Ok(()) Ok(())
} }

View file

@ -10,6 +10,8 @@ fastwebsockets = { version = "0.6.0", features = ["unstable-split"], optional =
futures = "0.3.30" futures = "0.3.30"
futures-util = "0.3.30" futures-util = "0.3.30"
tokio = { version = "1.35.1", optional = true } tokio = { version = "1.35.1", optional = true }
ws_stream_wasm = { version = "0.7.4", optional = true }
[features] [features]
fastwebsockets = ["dep:fastwebsockets", "dep:tokio"] fastwebsockets = ["dep:fastwebsockets", "dep:tokio"]
ws_stream_wasm = ["dep:ws_stream_wasm"]

View file

@ -28,10 +28,11 @@ impl From<Frame<'_>> for crate::ws::Frame {
} }
} }
impl From<crate::ws::Frame> for Frame<'_> { impl TryFrom<crate::ws::Frame> for Frame<'_> {
fn from(frame: crate::ws::Frame) -> Self { type Error = crate::WispError;
fn try_from(frame: crate::ws::Frame) -> Result<Self, Self::Error> {
use crate::ws::OpCode::*; use crate::ws::OpCode::*;
match frame.opcode { Ok(match frame.opcode {
Text => Self::text(Payload::Owned(frame.payload.to_vec())), Text => Self::text(Payload::Owned(frame.payload.to_vec())),
Binary => Self::binary(Payload::Owned(frame.payload.to_vec())), Binary => Self::binary(Payload::Owned(frame.payload.to_vec())),
Close => Self::close_raw(Payload::Owned(frame.payload.to_vec())), Close => Self::close_raw(Payload::Owned(frame.payload.to_vec())),
@ -42,7 +43,7 @@ impl From<crate::ws::Frame> for Frame<'_> {
Payload::Owned(frame.payload.to_vec()), Payload::Owned(frame.payload.to_vec()),
), ),
Pong => Self::pong(Payload::Owned(frame.payload.to_vec())), Pong => Self::pong(Payload::Owned(frame.payload.to_vec())),
} })
} }
} }
@ -66,6 +67,6 @@ impl<S: AsyncRead + Unpin> crate::ws::WebSocketRead for FragmentCollectorRead<S>
impl<S: AsyncWrite + Unpin> crate::ws::WebSocketWrite for WebSocketWrite<S> { impl<S: AsyncWrite + Unpin> crate::ws::WebSocketWrite for WebSocketWrite<S> {
async fn wisp_write_frame(&mut self, frame: crate::ws::Frame) -> Result<(), crate::WispError> { async fn wisp_write_frame(&mut self, frame: crate::ws::Frame) -> Result<(), crate::WispError> {
self.write_frame(frame.into()).await.map_err(|e| e.into()) self.write_frame(frame.try_into()?).await.map_err(|e| e.into())
} }
} }

View file

@ -2,13 +2,18 @@
mod fastwebsockets; mod fastwebsockets;
mod packet; mod packet;
pub mod ws; pub mod ws;
#[cfg(feature = "ws_stream_wasm")]
mod ws_stream_wasm;
pub use crate::packet::*; pub use crate::packet::*;
use bytes::Bytes; use bytes::Bytes;
use dashmap::DashMap; use dashmap::DashMap;
use futures::{channel::mpsc, StreamExt}; use futures::{channel::mpsc, channel::oneshot, SinkExt, StreamExt};
use std::sync::Arc; use std::sync::{
atomic::{AtomicBool, AtomicU32, Ordering},
Arc,
};
#[derive(Debug, PartialEq)] #[derive(Debug, PartialEq)]
pub enum Role { pub enum Role {
@ -21,9 +26,13 @@ pub enum WispError {
PacketTooSmall, PacketTooSmall,
InvalidPacketType, InvalidPacketType,
InvalidStreamType, InvalidStreamType,
InvalidStreamId,
MaxStreamCountReached,
StreamAlreadyClosed,
WsFrameInvalidType, WsFrameInvalidType,
WsFrameNotFinished, WsFrameNotFinished,
WsImplError(Box<dyn std::error::Error + Sync + Send>), WsImplError(Box<dyn std::error::Error + Sync + Send>),
WsImplSocketClosed,
WsImplNotSupported, WsImplNotSupported,
Utf8Error(std::str::Utf8Error), Utf8Error(std::str::Utf8Error),
Other(Box<dyn std::error::Error + Sync + Send>), Other(Box<dyn std::error::Error + Sync + Send>),
@ -42,9 +51,13 @@ impl std::fmt::Display for WispError {
PacketTooSmall => write!(f, "Packet too small"), PacketTooSmall => write!(f, "Packet too small"),
InvalidPacketType => write!(f, "Invalid packet type"), InvalidPacketType => write!(f, "Invalid packet type"),
InvalidStreamType => write!(f, "Invalid stream type"), InvalidStreamType => write!(f, "Invalid stream type"),
InvalidStreamId => write!(f, "Invalid stream id"),
MaxStreamCountReached => write!(f, "Maximum stream count reached"),
StreamAlreadyClosed => write!(f, "Stream already closed"),
WsFrameInvalidType => write!(f, "Invalid websocket frame type"), WsFrameInvalidType => write!(f, "Invalid websocket frame type"),
WsFrameNotFinished => write!(f, "Unfinished websocket frame"), WsFrameNotFinished => write!(f, "Unfinished websocket frame"),
WsImplError(err) => write!(f, "Websocket implementation error: {:?}", err), WsImplError(err) => write!(f, "Websocket implementation error: {:?}", err),
WsImplSocketClosed => write!(f, "Websocket implementation error: websocket closed"),
WsImplNotSupported => write!(f, "Websocket implementation error: unsupported feature"), WsImplNotSupported => write!(f, "Websocket implementation error: unsupported feature"),
Utf8Error(err) => write!(f, "UTF-8 error: {:?}", err), Utf8Error(err) => write!(f, "UTF-8 error: {:?}", err),
Other(err) => write!(f, "Other error: {:?}", err), Other(err) => write!(f, "Other error: {:?}", err),
@ -59,6 +72,10 @@ pub enum WsEvent {
Close(ClosePacket), Close(ClosePacket),
} }
pub enum MuxEvent {
Close(u32, u8, oneshot::Sender<Result<(), WispError>>),
}
pub struct MuxStream<W> pub struct MuxStream<W>
where where
W: ws::WebSocketWrite, W: ws::WebSocketWrite,
@ -66,21 +83,75 @@ where
pub stream_id: u32, pub stream_id: u32,
rx: mpsc::UnboundedReceiver<WsEvent>, rx: mpsc::UnboundedReceiver<WsEvent>,
tx: ws::LockedWebSocketWrite<W>, tx: ws::LockedWebSocketWrite<W>,
close_channel: mpsc::UnboundedSender<MuxEvent>,
is_closed: Arc<AtomicBool>,
} }
impl<W: ws::WebSocketWrite> MuxStream<W> { impl<W: ws::WebSocketWrite> MuxStream<W> {
pub async fn read(&mut self) -> Option<WsEvent> { pub async fn read(&mut self) -> Option<WsEvent> {
self.rx.next().await if self.is_closed.load(Ordering::Acquire) {
return None;
}
match self.rx.next().await? {
WsEvent::Send(bytes) => Some(WsEvent::Send(bytes)),
WsEvent::Close(packet) => {
self.is_closed.store(true, Ordering::Release);
Some(WsEvent::Close(packet))
}
}
} }
pub async fn write(&mut self, data: Bytes) -> Result<(), WispError> { pub async fn write(&mut self, data: Bytes) -> Result<(), WispError> {
if self.is_closed.load(Ordering::Acquire) {
return Err(WispError::StreamAlreadyClosed);
}
self.tx self.tx
.write_frame(ws::Frame::from(Packet::new_data(self.stream_id, data))) .write_frame(Packet::new_data(self.stream_id, data).into())
.await .await
} }
pub fn get_write_half(&self) -> ws::LockedWebSocketWrite<W> { pub fn get_close_handle(&self) -> MuxStreamCloser {
self.tx.clone() MuxStreamCloser {
stream_id: self.stream_id,
close_channel: self.close_channel.clone(),
is_closed: self.is_closed.clone(),
}
}
pub async fn close(&mut self, reason: u8) -> Result<(), WispError> {
if self.is_closed.load(Ordering::Acquire) {
return Err(WispError::StreamAlreadyClosed);
}
let (tx, rx) = oneshot::channel::<Result<(), WispError>>();
self.close_channel
.send(MuxEvent::Close(self.stream_id, reason, tx))
.await
.map_err(|x| WispError::Other(Box::new(x)))?;
rx.await.map_err(|x| WispError::Other(Box::new(x)))??;
self.is_closed.store(true, Ordering::Release);
Ok(())
}
}
pub struct MuxStreamCloser {
stream_id: u32,
close_channel: mpsc::UnboundedSender<MuxEvent>,
is_closed: Arc<AtomicBool>,
}
impl MuxStreamCloser {
pub async fn close(&mut self, reason: u8) -> Result<(), WispError> {
if self.is_closed.load(Ordering::Acquire) {
return Err(WispError::StreamAlreadyClosed);
}
let (tx, rx) = oneshot::channel::<Result<(), WispError>>();
self.close_channel
.send(MuxEvent::Close(self.stream_id, reason, tx))
.await
.map_err(|x| WispError::Other(Box::new(x)))?;
rx.await.map_err(|x| WispError::Other(Box::new(x)))??;
self.is_closed.store(true, Ordering::Release);
Ok(())
} }
} }
@ -92,14 +163,37 @@ where
rx: R, rx: R,
tx: ws::LockedWebSocketWrite<W>, tx: ws::LockedWebSocketWrite<W>,
stream_map: Arc<DashMap<u32, mpsc::UnboundedSender<WsEvent>>>, stream_map: Arc<DashMap<u32, mpsc::UnboundedSender<WsEvent>>>,
close_rx: mpsc::UnboundedReceiver<MuxEvent>,
close_tx: mpsc::UnboundedSender<MuxEvent>,
} }
impl<R: ws::WebSocketRead, W: ws::WebSocketWrite> ServerMux<R, W> { impl<R: ws::WebSocketRead, W: ws::WebSocketWrite> ServerMux<R, W> {
pub fn new(read: R, write: W) -> Self { pub fn new(read: R, write: W) -> Self {
let (tx, rx) = mpsc::unbounded::<MuxEvent>();
Self { Self {
rx: read, rx: read,
tx: ws::LockedWebSocketWrite::new(write), tx: ws::LockedWebSocketWrite::new(write),
stream_map: Arc::new(DashMap::new()), stream_map: Arc::new(DashMap::new()),
close_rx: rx,
close_tx: tx,
}
}
pub async fn server_bg_loop(&mut self) {
while let Some(msg) = self.close_rx.next().await {
match msg {
MuxEvent::Close(stream_id, reason, channel) => {
if self.stream_map.clone().remove(&stream_id).is_some() {
let _ = channel.send(
self.tx
.write_frame(Packet::new_close(stream_id, reason).into())
.await,
);
} else {
let _ = channel.send(Err(WispError::InvalidStreamId));
}
}
}
} }
} }
@ -111,7 +205,7 @@ impl<R: ws::WebSocketRead, W: ws::WebSocketWrite> ServerMux<R, W> {
FR: std::future::Future<Output = Result<(), crate::WispError>>, FR: std::future::Future<Output = Result<(), crate::WispError>>,
{ {
self.tx self.tx
.write_frame(ws::Frame::from(Packet::new_continue(0, u32::MAX))) .write_frame(Packet::new_continue(0, u32::MAX).into())
.await?; .await?;
while let Ok(frame) = self.rx.wisp_read_frame(&mut self.tx).await { while let Ok(frame) = self.rx.wisp_read_frame(&mut self.tx).await {
@ -127,14 +221,19 @@ impl<R: ws::WebSocketRead, W: ws::WebSocketWrite> ServerMux<R, W> {
stream_id: packet.stream_id, stream_id: packet.stream_id,
rx: ch_rx, rx: ch_rx,
tx: self.tx.clone(), tx: self.tx.clone(),
close_channel: self.close_tx.clone(),
is_closed: AtomicBool::new(false).into(),
}, },
).await; )
.await;
} }
Data(data) => { Data(data) => {
if let Some(stream) = self.stream_map.clone().get(&packet.stream_id) { if let Some(stream) = self.stream_map.clone().get(&packet.stream_id) {
let _ = stream.unbounded_send(WsEvent::Send(data)); let _ = stream.unbounded_send(WsEvent::Send(data));
self.tx self.tx
.write_frame(ws::Frame::from(Packet::new_continue(packet.stream_id, u32::MAX))) .write_frame(
Packet::new_continue(packet.stream_id, u32::MAX).into(),
)
.await?; .await?;
} }
} }
@ -142,6 +241,7 @@ impl<R: ws::WebSocketRead, W: ws::WebSocketWrite> ServerMux<R, W> {
Close(inner_packet) => { Close(inner_packet) => {
if let Some(stream) = self.stream_map.clone().get(&packet.stream_id) { if let Some(stream) = self.stream_map.clone().get(&packet.stream_id) {
let _ = stream.unbounded_send(WsEvent::Close(inner_packet)); let _ = stream.unbounded_send(WsEvent::Close(inner_packet));
self.stream_map.clone().remove(&packet.stream_id);
} }
} }
} }
@ -150,3 +250,97 @@ impl<R: ws::WebSocketRead, W: ws::WebSocketWrite> ServerMux<R, W> {
Ok(()) Ok(())
} }
} }
pub struct ClientMux<R, W>
where
R: ws::WebSocketRead,
W: ws::WebSocketWrite,
{
rx: R,
tx: ws::LockedWebSocketWrite<W>,
stream_map: Arc<DashMap<u32, mpsc::UnboundedSender<WsEvent>>>,
next_free_stream_id: AtomicU32,
close_rx: mpsc::UnboundedReceiver<MuxEvent>,
close_tx: mpsc::UnboundedSender<MuxEvent>,
}
impl<R: ws::WebSocketRead, W: ws::WebSocketWrite> ClientMux<R, W> {
pub fn new(read: R, write: W) -> Self {
let (tx, rx) = mpsc::unbounded::<MuxEvent>();
Self {
rx: read,
tx: ws::LockedWebSocketWrite::new(write),
stream_map: Arc::new(DashMap::new()),
next_free_stream_id: AtomicU32::new(1),
close_rx: rx,
close_tx: tx,
}
}
pub async fn client_bg_loop(&mut self) {
while let Some(msg) = self.close_rx.next().await {
match msg {
MuxEvent::Close(stream_id, reason, channel) => {
if self.stream_map.clone().remove(&stream_id).is_some() {
let _ = channel.send(
self.tx
.write_frame(Packet::new_close(stream_id, reason).into())
.await,
);
} else {
let _ = channel.send(Err(WispError::InvalidStreamId));
}
}
}
}
}
pub async fn client_loop(&mut self) -> Result<(), WispError> {
self.tx
.write_frame(Packet::new_continue(0, u32::MAX).into())
.await?;
while let Ok(frame) = self.rx.wisp_read_frame(&mut self.tx).await {
if let Ok(packet) = Packet::try_from(frame) {
use PacketType::*;
match packet.packet {
Connect(_) => unreachable!(),
Data(data) => {
if let Some(stream) = self.stream_map.clone().get(&packet.stream_id) {
let _ = stream.unbounded_send(WsEvent::Send(data));
}
}
Continue(_) => {}
Close(inner_packet) => {
if let Some(stream) = self.stream_map.clone().get(&packet.stream_id) {
let _ = stream.unbounded_send(WsEvent::Close(inner_packet));
self.stream_map.clone().remove(&packet.stream_id);
}
}
}
}
}
Ok(())
}
pub async fn client_new_stream(
&mut self,
) -> Result<MuxStream<impl ws::WebSocketWrite>, WispError> {
let (ch_tx, ch_rx) = mpsc::unbounded();
let stream_id = self.next_free_stream_id.load(Ordering::Acquire);
self.next_free_stream_id.store(
stream_id
.checked_add(1)
.ok_or(WispError::MaxStreamCountReached)?,
Ordering::Release,
);
self.stream_map.clone().insert(stream_id, ch_tx);
Ok(MuxStream {
stream_id,
rx: ch_rx,
tx: self.tx.clone(),
close_channel: self.close_tx.clone(),
is_closed: AtomicBool::new(false).into(),
})
}
}

View file

@ -0,0 +1,57 @@
use futures::{SinkExt, StreamExt};
use ws_stream_wasm::{WsErr, WsMessage, WsStream};
impl From<WsMessage> for crate::ws::Frame {
fn from(msg: WsMessage) -> Self {
use crate::ws::OpCode;
match msg {
WsMessage::Text(str) => Self {
finished: true,
opcode: OpCode::Text,
payload: str.into(),
},
WsMessage::Binary(bin) => Self {
finished: true,
opcode: OpCode::Binary,
payload: bin.into(),
},
}
}
}
impl TryFrom<crate::ws::Frame> for WsMessage {
type Error = crate::WispError;
fn try_from(msg: crate::ws::Frame) -> Result<Self, Self::Error> {
use crate::ws::OpCode;
match msg.opcode {
OpCode::Text => Ok(Self::Text(std::str::from_utf8(&msg.payload)?.to_string())),
OpCode::Binary => Ok(Self::Binary(msg.payload.to_vec())),
_ => Err(Self::Error::WsImplNotSupported),
}
}
}
impl From<WsErr> for crate::WispError {
fn from(err: WsErr) -> Self {
Self::WsImplError(Box::new(err))
}
}
impl crate::ws::WebSocketRead for WsStream {
async fn wisp_read_frame(
&mut self,
_: &mut crate::ws::LockedWebSocketWrite<impl crate::ws::WebSocketWrite>,
) -> Result<crate::ws::Frame, crate::WispError> {
Ok(self
.next()
.await
.ok_or(crate::WispError::WsImplSocketClosed)?
.into())
}
}
impl crate::ws::WebSocketWrite for WsStream {
async fn wisp_write_frame(&mut self, frame: crate::ws::Frame) -> Result<(), crate::WispError> {
self.send(frame.try_into()?).await.map_err(|e| e.into())
}
}