wisp_mux documentation

This commit is contained in:
Toshit Chawda 2024-02-07 18:53:27 -08:00
parent f574163991
commit 747ec0eb12
No known key found for this signature in database
GPG key ID: 91480ED99E2B3D9D
8 changed files with 269 additions and 66 deletions

View file

@ -17,7 +17,7 @@ use tokio::net::{TcpListener, TcpStream, UdpSocket};
use tokio_native_tls::{native_tls, TlsAcceptor}; use tokio_native_tls::{native_tls, TlsAcceptor};
use tokio_util::codec::{BytesCodec, Framed}; use tokio_util::codec::{BytesCodec, Framed};
use wisp_mux::{ws, ConnectPacket, MuxStream, ServerMux, StreamType, WispError, WsEvent}; use wisp_mux::{ws, ConnectPacket, MuxStream, ServerMux, StreamType, WispError, MuxEvent};
type HttpBody = http_body_util::Full<hyper::body::Bytes>; type HttpBody = http_body_util::Full<hyper::body::Bytes>;
@ -97,7 +97,7 @@ async fn accept_http(
.collect::<Vec<&str>>(), .collect::<Vec<&str>>(),
) )
}) && protocols.contains(&"wisp-v1") }) && protocols.contains(&"wisp-v1")
&& (uri == "" || uri == "/") && (uri.is_empty() || uri == "/")
{ {
tokio::spawn(async move { accept_ws(fut, addr.clone()).await }); tokio::spawn(async move { accept_ws(fut, addr.clone()).await });
res.headers_mut().insert( res.headers_mut().insert(
@ -105,7 +105,7 @@ async fn accept_http(
HeaderValue::from_str("wisp-v1").unwrap(), HeaderValue::from_str("wisp-v1").unwrap(),
); );
} else { } else {
let uri = uri.strip_prefix("/").unwrap_or(uri).to_string(); let uri = uri.strip_prefix('/').unwrap_or(uri).to_string();
tokio::spawn(async move { accept_wsproxy(fut, uri, addr.clone()).await }); tokio::spawn(async move { accept_wsproxy(fut, uri, addr.clone()).await });
} }
@ -154,10 +154,10 @@ async fn handle_mux(
event = stream.read() => { event = stream.read() => {
match event { match event {
Some(event) => match event { Some(event) => match event {
WsEvent::Send(data) => { MuxEvent::Send(data) => {
udp_socket.send(&data).await.map_err(|x| WispError::Other(Box::new(x)))?; udp_socket.send(&data).await.map_err(|x| WispError::Other(Box::new(x)))?;
} }
WsEvent::Close(_) => return Ok(false), MuxEvent::Close(_) => return Ok(false),
}, },
None => break, None => break,
} }

2
wisp/README.md Normal file
View file

@ -0,0 +1,2 @@
# wisp-mux
A library for easily creating [Wisp](https://github.com/MercuryWorkshop/wisp-protocol) servers and clients.

View file

@ -1,4 +1,9 @@
#![deny(missing_docs)]
#![feature(impl_trait_in_assoc_type)] #![feature(impl_trait_in_assoc_type)]
//! A library for easily creating [Wisp] clients and servers.
//!
//! [Wisp]: https://github.com/MercuryWorkshop/wisp-protocol
#[cfg(feature = "fastwebsockets")] #[cfg(feature = "fastwebsockets")]
mod fastwebsockets; mod fastwebsockets;
mod packet; mod packet;
@ -24,29 +29,49 @@ use std::{
}, },
}; };
/// The role of the multiplexor.
#[derive(Debug, PartialEq, Copy, Clone)] #[derive(Debug, PartialEq, Copy, Clone)]
pub enum Role { pub enum Role {
/// Client side, can create new channels to proxy.
Client, Client,
/// Server side, can listen for channels to proxy.
Server, Server,
} }
/// Errors the Wisp implementation can return.
#[derive(Debug)] #[derive(Debug)]
pub enum WispError { pub enum WispError {
/// The packet recieved did not have enough data.
PacketTooSmall, PacketTooSmall,
/// The packet recieved had an invalid type.
InvalidPacketType, InvalidPacketType,
/// The stream had an invalid type.
InvalidStreamType, InvalidStreamType,
/// The stream had an invalid ID.
InvalidStreamId, InvalidStreamId,
/// The URI recieved was invalid.
InvalidUri, InvalidUri,
/// The URI recieved had no host.
UriHasNoHost, UriHasNoHost,
/// The URI recieved had no port.
UriHasNoPort, UriHasNoPort,
/// The max stream count was reached.
MaxStreamCountReached, MaxStreamCountReached,
/// The stream had already been closed.
StreamAlreadyClosed, StreamAlreadyClosed,
/// The websocket frame recieved had an invalid type.
WsFrameInvalidType, WsFrameInvalidType,
/// The websocket frame recieved was not finished.
WsFrameNotFinished, WsFrameNotFinished,
/// Error specific to the websocket implementation.
WsImplError(Box<dyn std::error::Error + Sync + Send>), WsImplError(Box<dyn std::error::Error + Sync + Send>),
/// The websocket implementation socket closed.
WsImplSocketClosed, WsImplSocketClosed,
/// The websocket implementation did not support the action.
WsImplNotSupported, WsImplNotSupported,
/// The string was invalid UTF-8.
Utf8Error(std::str::Utf8Error), Utf8Error(std::str::Utf8Error),
/// Other error.
Other(Box<dyn std::error::Error + Sync + Send>), Other(Box<dyn std::error::Error + Sync + Send>),
} }
@ -87,17 +112,17 @@ where
W: ws::WebSocketWrite + Send + 'static, W: ws::WebSocketWrite + Send + 'static,
{ {
tx: ws::LockedWebSocketWrite<W>, tx: ws::LockedWebSocketWrite<W>,
stream_map: Arc<Mutex<HashMap<u32, mpsc::UnboundedSender<WsEvent>>>>, stream_map: Arc<Mutex<HashMap<u32, mpsc::UnboundedSender<MuxEvent>>>>,
close_tx: mpsc::UnboundedSender<MuxEvent>, close_tx: mpsc::UnboundedSender<WsEvent>,
} }
impl<W: ws::WebSocketWrite + Send + 'static> ServerMuxInner<W> { impl<W: ws::WebSocketWrite + Send + 'static> ServerMuxInner<W> {
pub async fn into_future<R>( pub async fn into_future<R>(
self, self,
rx: R, rx: R,
close_rx: mpsc::UnboundedReceiver<MuxEvent>, close_rx: mpsc::UnboundedReceiver<WsEvent>,
muxstream_sender: mpsc::UnboundedSender<(ConnectPacket, MuxStream<W>)>, muxstream_sender: mpsc::UnboundedSender<(ConnectPacket, MuxStream<W>)>,
buffer_size: u32 buffer_size: u32,
) -> Result<(), WispError> ) -> Result<(), WispError>
where where
R: ws::WebSocketRead, R: ws::WebSocketRead,
@ -107,20 +132,20 @@ impl<W: ws::WebSocketWrite + Send + 'static> ServerMuxInner<W> {
x = self.server_msg_loop(rx, muxstream_sender, buffer_size).fuse() => x x = self.server_msg_loop(rx, muxstream_sender, buffer_size).fuse() => x
}; };
self.stream_map.lock().await.iter().for_each(|x| { self.stream_map.lock().await.iter().for_each(|x| {
let _ = x.1.unbounded_send(WsEvent::Close(ClosePacket::new(0x01))); let _ = x.1.unbounded_send(MuxEvent::Close(ClosePacket::new(0x01)));
}); });
ret ret
} }
async fn server_close_loop( async fn server_close_loop(
&self, &self,
mut close_rx: mpsc::UnboundedReceiver<MuxEvent>, mut close_rx: mpsc::UnboundedReceiver<WsEvent>,
stream_map: Arc<Mutex<HashMap<u32, mpsc::UnboundedSender<WsEvent>>>>, stream_map: Arc<Mutex<HashMap<u32, mpsc::UnboundedSender<MuxEvent>>>>,
tx: ws::LockedWebSocketWrite<W>, tx: ws::LockedWebSocketWrite<W>,
) -> Result<(), WispError> { ) -> Result<(), WispError> {
while let Some(msg) = close_rx.next().await { while let Some(msg) = close_rx.next().await {
match msg { match msg {
MuxEvent::Close(stream_id, reason, channel) => { WsEvent::Close(stream_id, reason, channel) => {
if stream_map.lock().await.remove(&stream_id).is_some() { if stream_map.lock().await.remove(&stream_id).is_some() {
let _ = channel.send( let _ = channel.send(
tx.write_frame(Packet::new_close(stream_id, reason).into()) tx.write_frame(Packet::new_close(stream_id, reason).into())
@ -154,6 +179,7 @@ impl<W: ws::WebSocketWrite + Send + 'static> ServerMuxInner<W> {
match packet.packet { match packet.packet {
Connect(inner_packet) => { Connect(inner_packet) => {
let (ch_tx, ch_rx) = mpsc::unbounded(); let (ch_tx, ch_rx) = mpsc::unbounded();
let stream_type = inner_packet.stream_type;
self.stream_map.lock().await.insert(packet.stream_id, ch_tx); self.stream_map.lock().await.insert(packet.stream_id, ch_tx);
muxstream_sender muxstream_sender
.unbounded_send(( .unbounded_send((
@ -161,6 +187,7 @@ impl<W: ws::WebSocketWrite + Send + 'static> ServerMuxInner<W> {
MuxStream::new( MuxStream::new(
packet.stream_id, packet.stream_id,
Role::Server, Role::Server,
stream_type,
ch_rx, ch_rx,
self.tx.clone(), self.tx.clone(),
self.close_tx.clone(), self.close_tx.clone(),
@ -173,13 +200,13 @@ impl<W: ws::WebSocketWrite + Send + 'static> ServerMuxInner<W> {
} }
Data(data) => { Data(data) => {
if let Some(stream) = self.stream_map.lock().await.get(&packet.stream_id) { if let Some(stream) = self.stream_map.lock().await.get(&packet.stream_id) {
let _ = stream.unbounded_send(WsEvent::Send(data)); let _ = stream.unbounded_send(MuxEvent::Send(data));
} }
} }
Continue(_) => unreachable!(), Continue(_) => unreachable!(),
Close(inner_packet) => { Close(inner_packet) => {
if let Some(stream) = self.stream_map.lock().await.get(&packet.stream_id) { if let Some(stream) = self.stream_map.lock().await.get(&packet.stream_id) {
let _ = stream.unbounded_send(WsEvent::Close(inner_packet)); let _ = stream.unbounded_send(MuxEvent::Close(inner_packet));
} }
self.stream_map.lock().await.remove(&packet.stream_id); self.stream_map.lock().await.remove(&packet.stream_id);
} }
@ -193,6 +220,25 @@ impl<W: ws::WebSocketWrite + Send + 'static> ServerMuxInner<W> {
} }
} }
/// Server-side multiplexor.
///
/// # Example
/// ```
/// use wisp_mux::ServerMux;
///
/// let (mux, fut) = ServerMux::new(rx, tx, 128);
/// tokio::spawn(async move {
/// if let Err(e) = fut.await {
/// println!("error in multiplexor: {:?}", e);
/// }
/// });
/// while let Some((packet, stream)) = mux.server_new_stream().await {
/// tokio::spawn(async move {
/// let url = format!("{}:{}", packet.destination_hostname, packet.destination_port);
/// // do something with `url` and `packet.stream_type`
/// });
/// }
/// ```
pub struct ServerMux<W> pub struct ServerMux<W>
where where
W: ws::WebSocketWrite + Send + 'static, W: ws::WebSocketWrite + Send + 'static,
@ -201,11 +247,16 @@ where
} }
impl<W: ws::WebSocketWrite + Send + 'static> ServerMux<W> { impl<W: ws::WebSocketWrite + Send + 'static> ServerMux<W> {
pub fn new<R>(read: R, write: W, buffer_size: u32) -> (Self, impl Future<Output = Result<(), WispError>>) /// Create a new server-side multiplexor.
pub fn new<R>(
read: R,
write: W,
buffer_size: u32,
) -> (Self, impl Future<Output = Result<(), WispError>>)
where where
R: ws::WebSocketRead, R: ws::WebSocketRead,
{ {
let (close_tx, close_rx) = mpsc::unbounded::<MuxEvent>(); let (close_tx, close_rx) = mpsc::unbounded::<WsEvent>();
let (tx, rx) = mpsc::unbounded::<(ConnectPacket, MuxStream<W>)>(); let (tx, rx) = mpsc::unbounded::<(ConnectPacket, MuxStream<W>)>();
let write = ws::LockedWebSocketWrite::new(write); let write = ws::LockedWebSocketWrite::new(write);
let map = Arc::new(Mutex::new(HashMap::new())); let map = Arc::new(Mutex::new(HashMap::new()));
@ -220,25 +271,31 @@ impl<W: ws::WebSocketWrite + Send + 'static> ServerMux<W> {
) )
} }
/// Wait for a stream to be created.
pub async fn server_new_stream(&mut self) -> Option<(ConnectPacket, MuxStream<W>)> { pub async fn server_new_stream(&mut self) -> Option<(ConnectPacket, MuxStream<W>)> {
self.muxstream_recv.next().await self.muxstream_recv.next().await
} }
} }
pub struct ClientMuxInner<W> struct ClientMuxMapValue {
stream: mpsc::UnboundedSender<MuxEvent>,
flow_control: Arc<AtomicU32>,
flow_control_event: Arc<Event>,
}
struct ClientMuxInner<W>
where where
W: ws::WebSocketWrite, W: ws::WebSocketWrite,
{ {
tx: ws::LockedWebSocketWrite<W>, tx: ws::LockedWebSocketWrite<W>,
stream_map: stream_map: Arc<Mutex<HashMap<u32, ClientMuxMapValue>>>,
Arc<Mutex<HashMap<u32, (mpsc::UnboundedSender<WsEvent>, Arc<AtomicU32>, Arc<Event>)>>>,
} }
impl<W: ws::WebSocketWrite + Send> ClientMuxInner<W> { impl<W: ws::WebSocketWrite + Send> ClientMuxInner<W> {
pub async fn into_future<R>( pub(crate) async fn into_future<R>(
self, self,
rx: R, rx: R,
close_rx: mpsc::UnboundedReceiver<MuxEvent>, close_rx: mpsc::UnboundedReceiver<WsEvent>,
) -> Result<(), WispError> ) -> Result<(), WispError>
where where
R: ws::WebSocketRead, R: ws::WebSocketRead,
@ -251,11 +308,11 @@ impl<W: ws::WebSocketWrite + Send> ClientMuxInner<W> {
async fn client_bg_loop( async fn client_bg_loop(
&self, &self,
mut close_rx: mpsc::UnboundedReceiver<MuxEvent>, mut close_rx: mpsc::UnboundedReceiver<WsEvent>,
) -> Result<(), WispError> { ) -> Result<(), WispError> {
while let Some(msg) = close_rx.next().await { while let Some(msg) = close_rx.next().await {
match msg { match msg {
MuxEvent::Close(stream_id, reason, channel) => { WsEvent::Close(stream_id, reason, channel) => {
if self.stream_map.lock().await.remove(&stream_id).is_some() { if self.stream_map.lock().await.remove(&stream_id).is_some() {
let _ = channel.send( let _ = channel.send(
self.tx self.tx
@ -282,20 +339,20 @@ impl<W: ws::WebSocketWrite + Send> ClientMuxInner<W> {
Connect(_) => unreachable!(), Connect(_) => unreachable!(),
Data(data) => { Data(data) => {
if let Some(stream) = self.stream_map.lock().await.get(&packet.stream_id) { if let Some(stream) = self.stream_map.lock().await.get(&packet.stream_id) {
let _ = stream.0.unbounded_send(WsEvent::Send(data)); let _ = stream.stream.unbounded_send(MuxEvent::Send(data));
} }
} }
Continue(inner_packet) => { Continue(inner_packet) => {
if let Some(stream) = self.stream_map.lock().await.get(&packet.stream_id) { if let Some(stream) = self.stream_map.lock().await.get(&packet.stream_id) {
stream stream
.1 .flow_control
.store(inner_packet.buffer_remaining, Ordering::Release); .store(inner_packet.buffer_remaining, Ordering::Release);
let _ = stream.2.notify(u32::MAX); let _ = stream.flow_control_event.notify(u32::MAX);
} }
} }
Close(inner_packet) => { Close(inner_packet) => {
if let Some(stream) = self.stream_map.lock().await.get(&packet.stream_id) { if let Some(stream) = self.stream_map.lock().await.get(&packet.stream_id) {
let _ = stream.0.unbounded_send(WsEvent::Close(inner_packet)); let _ = stream.stream.unbounded_send(MuxEvent::Close(inner_packet));
} }
self.stream_map.lock().await.remove(&packet.stream_id); self.stream_map.lock().await.remove(&packet.stream_id);
} }
@ -306,19 +363,33 @@ impl<W: ws::WebSocketWrite + Send> ClientMuxInner<W> {
} }
} }
/// Client side multiplexor.
///
/// # Example
/// ```
/// use wisp_mux::{ClientMux, StreamType};
///
/// let (mux, fut) = ClientMux::new(rx, tx).await?;
/// tokio::spawn(async move {
/// if let Err(e) = fut.await {
/// println!("error in multiplexor: {:?}", e);
/// }
/// });
/// let stream = mux.client_new_stream(StreamType::Tcp, "google.com", 80);
/// ```
pub struct ClientMux<W> pub struct ClientMux<W>
where where
W: ws::WebSocketWrite, W: ws::WebSocketWrite,
{ {
tx: ws::LockedWebSocketWrite<W>, tx: ws::LockedWebSocketWrite<W>,
stream_map: stream_map: Arc<Mutex<HashMap<u32, ClientMuxMapValue>>>,
Arc<Mutex<HashMap<u32, (mpsc::UnboundedSender<WsEvent>, Arc<AtomicU32>, Arc<Event>)>>>,
next_free_stream_id: AtomicU32, next_free_stream_id: AtomicU32,
close_tx: mpsc::UnboundedSender<MuxEvent>, close_tx: mpsc::UnboundedSender<WsEvent>,
buf_size: u32, buf_size: u32,
} }
impl<W: ws::WebSocketWrite + Send + 'static> ClientMux<W> { impl<W: ws::WebSocketWrite + Send + 'static> ClientMux<W> {
/// Create a new client side multiplexor.
pub async fn new<R>( pub async fn new<R>(
mut read: R, mut read: R,
write: W, write: W,
@ -332,7 +403,7 @@ impl<W: ws::WebSocketWrite + Send + 'static> ClientMux<W> {
return Err(WispError::InvalidStreamId); return Err(WispError::InvalidStreamId);
} }
if let PacketType::Continue(packet) = first_packet.packet { if let PacketType::Continue(packet) = first_packet.packet {
let (tx, rx) = mpsc::unbounded::<MuxEvent>(); let (tx, rx) = mpsc::unbounded::<WsEvent>();
let map = Arc::new(Mutex::new(HashMap::new())); let map = Arc::new(Mutex::new(HashMap::new()));
Ok(( Ok((
Self { Self {
@ -353,6 +424,7 @@ impl<W: ws::WebSocketWrite + Send + 'static> ClientMux<W> {
} }
} }
/// Create a new stream, multiplexed through Wisp.
pub async fn client_new_stream( pub async fn client_new_stream(
&self, &self,
stream_type: StreamType, stream_type: StreamType,
@ -372,13 +444,18 @@ impl<W: ws::WebSocketWrite + Send + 'static> ClientMux<W> {
.ok_or(WispError::MaxStreamCountReached)?, .ok_or(WispError::MaxStreamCountReached)?,
Ordering::Release, Ordering::Release,
); );
self.stream_map self.stream_map.lock().await.insert(
.lock() stream_id,
.await ClientMuxMapValue {
.insert(stream_id, (ch_tx, flow_control.clone(), evt.clone())); stream: ch_tx,
flow_control: flow_control.clone(),
flow_control_event: evt.clone(),
},
);
Ok(MuxStream::new( Ok(MuxStream::new(
stream_id, stream_id,
Role::Client, Role::Client,
stream_type,
ch_rx, ch_rx,
self.tx.clone(), self.tx.clone(),
self.close_tx.clone(), self.close_tx.clone(),

View file

@ -2,9 +2,12 @@ use crate::ws;
use crate::WispError; use crate::WispError;
use bytes::{Buf, BufMut, Bytes}; use bytes::{Buf, BufMut, Bytes};
#[derive(Debug)] /// Wisp stream type.
#[derive(Debug, PartialEq, Copy, Clone)]
pub enum StreamType { pub enum StreamType {
/// TCP Wisp stream.
Tcp = 0x01, Tcp = 0x01,
/// UDP Wisp stream.
Udp = 0x02, Udp = 0x02,
} }
@ -20,15 +23,26 @@ impl TryFrom<u8> for StreamType {
} }
} }
#[derive(Debug)] /// Packet used to create a new stream.
///
/// See [the docs](https://github.com/MercuryWorkshop/wisp-protocol/blob/main/protocol.md#0x01---connect).
#[derive(Debug, Clone)]
pub struct ConnectPacket { pub struct ConnectPacket {
/// Whether the new stream should use a TCP or UDP socket.
pub stream_type: StreamType, pub stream_type: StreamType,
/// Destination TCP/UDP port for the new stream.
pub destination_port: u16, pub destination_port: u16,
/// Destination hostname, in a UTF-8 string.
pub destination_hostname: String, pub destination_hostname: String,
} }
impl ConnectPacket { impl ConnectPacket {
pub fn new(stream_type: StreamType, destination_port: u16, destination_hostname: String) -> Self { /// Create a new connect packet.
pub fn new(
stream_type: StreamType,
destination_port: u16,
destination_hostname: String,
) -> Self {
Self { Self {
stream_type, stream_type,
destination_port, destination_port,
@ -61,12 +75,17 @@ impl From<ConnectPacket> for Vec<u8> {
} }
} }
#[derive(Debug)] /// Packet used for Wisp TCP stream flow control.
///
/// See [the docs](https://github.com/MercuryWorkshop/wisp-protocol/blob/main/protocol.md#0x03---continue).
#[derive(Debug, Copy, Clone)]
pub struct ContinuePacket { pub struct ContinuePacket {
/// Number of packets that the server can buffer for the current stream.
pub buffer_remaining: u32, pub buffer_remaining: u32,
} }
impl ContinuePacket { impl ContinuePacket {
/// Create a new continue packet.
pub fn new(buffer_remaining: u32) -> Self { pub fn new(buffer_remaining: u32) -> Self {
Self { buffer_remaining } Self { buffer_remaining }
} }
@ -92,12 +111,21 @@ impl From<ContinuePacket> for Vec<u8> {
} }
} }
#[derive(Debug)] /// Packet used to close a stream.
///
/// See [the
/// docs](https://github.com/MercuryWorkshop/wisp-protocol/blob/main/protocol.md#0x04---close).
#[derive(Debug, Copy, Clone)]
pub struct ClosePacket { pub struct ClosePacket {
/// The close reason.
///
/// See [the
/// docs](https://github.com/MercuryWorkshop/wisp-protocol/blob/main/protocol.md#clientserver-close-reasons).
pub reason: u8, pub reason: u8,
} }
impl ClosePacket { impl ClosePacket {
/// Create a new close packet.
pub fn new(reason: u8) -> Self { pub fn new(reason: u8) -> Self {
Self { reason } Self { reason }
} }
@ -123,15 +151,21 @@ impl From<ClosePacket> for Vec<u8> {
} }
} }
#[derive(Debug)] #[derive(Debug, Clone)]
/// Type of packet recieved.
pub enum PacketType { pub enum PacketType {
/// Connect packet.
Connect(ConnectPacket), Connect(ConnectPacket),
/// Data packet.
Data(Bytes), Data(Bytes),
/// Continue packet.
Continue(ContinuePacket), Continue(ContinuePacket),
/// Close packet.
Close(ClosePacket), Close(ClosePacket),
} }
impl PacketType { impl PacketType {
/// Get the packet type used in the protocol.
pub fn as_u8(&self) -> u8 { pub fn as_u8(&self) -> u8 {
use PacketType::*; use PacketType::*;
match self { match self {
@ -155,17 +189,24 @@ impl From<PacketType> for Vec<u8> {
} }
} }
#[derive(Debug)] /// Wisp protocol packet.
#[derive(Debug, Clone)]
pub struct Packet { pub struct Packet {
/// Stream this packet is associated with.
pub stream_id: u32, pub stream_id: u32,
/// Packet recieved.
pub packet: PacketType, pub packet: PacketType,
} }
impl Packet { impl Packet {
/// Create a new packet.
///
/// The helper functions should be used for most use cases.
pub fn new(stream_id: u32, packet: PacketType) -> Self { pub fn new(stream_id: u32, packet: PacketType) -> Self {
Self { stream_id, packet } Self { stream_id, packet }
} }
/// Create a new connect packet.
pub fn new_connect( pub fn new_connect(
stream_id: u32, stream_id: u32,
stream_type: StreamType, stream_type: StreamType,
@ -182,6 +223,7 @@ impl Packet {
} }
} }
/// Create a new data packet.
pub fn new_data(stream_id: u32, data: Bytes) -> Self { pub fn new_data(stream_id: u32, data: Bytes) -> Self {
Self { Self {
stream_id, stream_id,
@ -189,6 +231,7 @@ impl Packet {
} }
} }
/// Create a new continue packet.
pub fn new_continue(stream_id: u32, buffer_remaining: u32) -> Self { pub fn new_continue(stream_id: u32, buffer_remaining: u32) -> Self {
Self { Self {
stream_id, stream_id,
@ -196,6 +239,7 @@ impl Packet {
} }
} }
/// Create a new close packet.
pub fn new_close(stream_id: u32, reason: u8) -> Self { pub fn new_close(stream_id: u32, reason: u8) -> Self {
Self { Self {
stream_id, stream_id,

View file

@ -16,35 +16,43 @@ use std::{
}, },
}; };
pub enum WsEvent { /// Multiplexor event recieved from a Wisp stream.
pub enum MuxEvent {
/// The other side has sent data.
Send(Bytes), Send(Bytes),
/// The other side has closed.
Close(crate::ClosePacket), Close(crate::ClosePacket),
} }
pub enum MuxEvent { pub(crate) enum WsEvent {
Close(u32, u8, oneshot::Sender<Result<(), crate::WispError>>), Close(u32, u8, oneshot::Sender<Result<(), crate::WispError>>),
} }
/// Read side of a multiplexor stream.
pub struct MuxStreamRead<W> pub struct MuxStreamRead<W>
where where
W: crate::ws::WebSocketWrite, W: crate::ws::WebSocketWrite,
{ {
/// ID of the stream.
pub stream_id: u32, pub stream_id: u32,
/// Type of the stream.
pub stream_type: crate::StreamType,
role: crate::Role, role: crate::Role,
tx: crate::ws::LockedWebSocketWrite<W>, tx: crate::ws::LockedWebSocketWrite<W>,
rx: mpsc::UnboundedReceiver<WsEvent>, rx: mpsc::UnboundedReceiver<MuxEvent>,
is_closed: Arc<AtomicBool>, is_closed: Arc<AtomicBool>,
flow_control: Arc<AtomicU32>, flow_control: Arc<AtomicU32>,
} }
impl<W: crate::ws::WebSocketWrite + Send + 'static> MuxStreamRead<W> { impl<W: crate::ws::WebSocketWrite + Send + 'static> MuxStreamRead<W> {
pub async fn read(&mut self) -> Option<WsEvent> { /// Read an event from the stream.
pub async fn read(&mut self) -> Option<MuxEvent> {
if self.is_closed.load(Ordering::Acquire) { if self.is_closed.load(Ordering::Acquire) {
return None; return None;
} }
match self.rx.next().await? { match self.rx.next().await? {
WsEvent::Send(bytes) => { MuxEvent::Send(bytes) => {
if self.role == crate::Role::Server { if self.role == crate::Role::Server && self.stream_type == crate::StreamType::Tcp {
let old_val = self.flow_control.fetch_add(1, Ordering::SeqCst); let old_val = self.flow_control.fetch_add(1, Ordering::SeqCst);
self.tx self.tx
.write_frame( .write_frame(
@ -53,11 +61,11 @@ impl<W: crate::ws::WebSocketWrite + Send + 'static> MuxStreamRead<W> {
.await .await
.ok()?; .ok()?;
} }
Some(WsEvent::Send(bytes)) Some(MuxEvent::Send(bytes))
} }
WsEvent::Close(packet) => { MuxEvent::Close(packet) => {
self.is_closed.store(true, Ordering::Release); self.is_closed.store(true, Ordering::Release);
Some(WsEvent::Close(packet)) Some(MuxEvent::Close(packet))
} }
} }
} }
@ -67,8 +75,8 @@ impl<W: crate::ws::WebSocketWrite + Send + 'static> MuxStreamRead<W> {
let evt = rx.read().await?; let evt = rx.read().await?;
Some(( Some((
match evt { match evt {
WsEvent::Send(bytes) => bytes, MuxEvent::Send(bytes) => bytes,
WsEvent::Close(_) => return None, MuxEvent::Close(_) => return None,
}, },
rx, rx,
)) ))
@ -76,25 +84,28 @@ impl<W: crate::ws::WebSocketWrite + Send + 'static> MuxStreamRead<W> {
} }
} }
/// Write side of a multiplexor stream.
pub struct MuxStreamWrite<W> pub struct MuxStreamWrite<W>
where where
W: crate::ws::WebSocketWrite, W: crate::ws::WebSocketWrite,
{ {
/// ID of the stream.
pub stream_id: u32, pub stream_id: u32,
role: crate::Role, role: crate::Role,
tx: crate::ws::LockedWebSocketWrite<W>, tx: crate::ws::LockedWebSocketWrite<W>,
close_channel: mpsc::UnboundedSender<MuxEvent>, close_channel: mpsc::UnboundedSender<WsEvent>,
is_closed: Arc<AtomicBool>, is_closed: Arc<AtomicBool>,
continue_recieved: Arc<Event>, continue_recieved: Arc<Event>,
flow_control: Arc<AtomicU32>, flow_control: Arc<AtomicU32>,
} }
impl<W: crate::ws::WebSocketWrite + Send + 'static> MuxStreamWrite<W> { impl<W: crate::ws::WebSocketWrite + Send + 'static> MuxStreamWrite<W> {
/// Write data to the stream.
pub async fn write(&self, data: Bytes) -> Result<(), crate::WispError> { pub async fn write(&self, data: Bytes) -> Result<(), crate::WispError> {
if self.is_closed.load(Ordering::Acquire) { if self.is_closed.load(Ordering::Acquire) {
return Err(crate::WispError::StreamAlreadyClosed); return Err(crate::WispError::StreamAlreadyClosed);
} }
if self.role == crate::Role::Client && self.flow_control.load(Ordering::Acquire) <= 0 { if self.role == crate::Role::Client && self.flow_control.load(Ordering::Acquire) == 0 {
self.continue_recieved.listen().await; self.continue_recieved.listen().await;
} }
self.tx self.tx
@ -112,6 +123,17 @@ impl<W: crate::ws::WebSocketWrite + Send + 'static> MuxStreamWrite<W> {
Ok(()) Ok(())
} }
/// Get a handle to close the connection.
///
/// Useful to close the connection without having access to the stream.
///
/// # Example
/// ```
/// let handle = stream.get_close_handle();
/// if let Err(error) = handle_stream(stream) {
/// handle.close(0x01);
/// }
/// ```
pub fn get_close_handle(&self) -> MuxStreamCloser { pub fn get_close_handle(&self) -> MuxStreamCloser {
MuxStreamCloser { MuxStreamCloser {
stream_id: self.stream_id, stream_id: self.stream_id,
@ -120,13 +142,14 @@ impl<W: crate::ws::WebSocketWrite + Send + 'static> MuxStreamWrite<W> {
} }
} }
/// Close the stream. You will no longer be able to write or read after this has been called.
pub async fn close(&self, reason: u8) -> Result<(), crate::WispError> { pub async fn close(&self, reason: u8) -> Result<(), crate::WispError> {
if self.is_closed.load(Ordering::Acquire) { if self.is_closed.load(Ordering::Acquire) {
return Err(crate::WispError::StreamAlreadyClosed); return Err(crate::WispError::StreamAlreadyClosed);
} }
let (tx, rx) = oneshot::channel::<Result<(), crate::WispError>>(); let (tx, rx) = oneshot::channel::<Result<(), crate::WispError>>();
self.close_channel self.close_channel
.unbounded_send(MuxEvent::Close(self.stream_id, reason, tx)) .unbounded_send(WsEvent::Close(self.stream_id, reason, tx))
.map_err(|x| crate::WispError::Other(Box::new(x)))?; .map_err(|x| crate::WispError::Other(Box::new(x)))?;
rx.await rx.await
.map_err(|x| crate::WispError::Other(Box::new(x)))??; .map_err(|x| crate::WispError::Other(Box::new(x)))??;
@ -148,26 +171,30 @@ impl<W: crate::ws::WebSocketWrite> Drop for MuxStreamWrite<W> {
let (tx, _) = oneshot::channel::<Result<(), crate::WispError>>(); let (tx, _) = oneshot::channel::<Result<(), crate::WispError>>();
let _ = self let _ = self
.close_channel .close_channel
.unbounded_send(MuxEvent::Close(self.stream_id, 0x01, tx)); .unbounded_send(WsEvent::Close(self.stream_id, 0x01, tx));
} }
} }
/// Multiplexor stream.
pub struct MuxStream<W> pub struct MuxStream<W>
where where
W: crate::ws::WebSocketWrite, W: crate::ws::WebSocketWrite,
{ {
/// ID of the stream.
pub stream_id: u32, pub stream_id: u32,
rx: MuxStreamRead<W>, rx: MuxStreamRead<W>,
tx: MuxStreamWrite<W>, tx: MuxStreamWrite<W>,
} }
impl<W: crate::ws::WebSocketWrite + Send + 'static> MuxStream<W> { impl<W: crate::ws::WebSocketWrite + Send + 'static> MuxStream<W> {
#[allow(clippy::too_many_arguments)]
pub(crate) fn new( pub(crate) fn new(
stream_id: u32, stream_id: u32,
role: crate::Role, role: crate::Role,
rx: mpsc::UnboundedReceiver<WsEvent>, stream_type: crate::StreamType,
rx: mpsc::UnboundedReceiver<MuxEvent>,
tx: crate::ws::LockedWebSocketWrite<W>, tx: crate::ws::LockedWebSocketWrite<W>,
close_channel: mpsc::UnboundedSender<MuxEvent>, close_channel: mpsc::UnboundedSender<WsEvent>,
is_closed: Arc<AtomicBool>, is_closed: Arc<AtomicBool>,
flow_control: Arc<AtomicU32>, flow_control: Arc<AtomicU32>,
continue_recieved: Arc<Event> continue_recieved: Arc<Event>
@ -176,6 +203,7 @@ impl<W: crate::ws::WebSocketWrite + Send + 'static> MuxStream<W> {
stream_id, stream_id,
rx: MuxStreamRead { rx: MuxStreamRead {
stream_id, stream_id,
stream_type,
role, role,
tx: tx.clone(), tx: tx.clone(),
rx, rx,
@ -194,26 +222,42 @@ impl<W: crate::ws::WebSocketWrite + Send + 'static> MuxStream<W> {
} }
} }
pub async fn read(&mut self) -> Option<WsEvent> { /// Read an event from the stream.
pub async fn read(&mut self) -> Option<MuxEvent> {
self.rx.read().await self.rx.read().await
} }
/// Write data to the stream.
pub async fn write(&self, data: Bytes) -> Result<(), crate::WispError> { pub async fn write(&self, data: Bytes) -> Result<(), crate::WispError> {
self.tx.write(data).await self.tx.write(data).await
} }
/// Get a handle to close the connection.
///
/// Useful to close the connection without having access to the stream.
///
/// # Example
/// ```
/// let handle = stream.get_close_handle();
/// if let Err(error) = handle_stream(stream) {
/// handle.close(0x01);
/// }
/// ```
pub fn get_close_handle(&self) -> MuxStreamCloser { pub fn get_close_handle(&self) -> MuxStreamCloser {
self.tx.get_close_handle() self.tx.get_close_handle()
} }
/// Close the stream. You will no longer be able to write or read after this has been called.
pub async fn close(&self, reason: u8) -> Result<(), crate::WispError> { pub async fn close(&self, reason: u8) -> Result<(), crate::WispError> {
self.tx.close(reason).await self.tx.close(reason).await
} }
/// Split the stream into read and write parts, consuming it.
pub fn into_split(self) -> (MuxStreamRead<W>, MuxStreamWrite<W>) { pub fn into_split(self) -> (MuxStreamRead<W>, MuxStreamWrite<W>) {
(self.rx, self.tx) (self.rx, self.tx)
} }
/// Turn the stream into one that implements futures `Stream + Sink`, consuming it.
pub fn into_io(self) -> MuxStreamIo { pub fn into_io(self) -> MuxStreamIo {
MuxStreamIo { MuxStreamIo {
rx: self.rx.into_stream(), rx: self.rx.into_stream(),
@ -222,20 +266,23 @@ impl<W: crate::ws::WebSocketWrite + Send + 'static> MuxStream<W> {
} }
} }
/// Close handle for a multiplexor stream.
pub struct MuxStreamCloser { pub struct MuxStreamCloser {
stream_id: u32, /// ID of the stream.
close_channel: mpsc::UnboundedSender<MuxEvent>, pub stream_id: u32,
close_channel: mpsc::UnboundedSender<WsEvent>,
is_closed: Arc<AtomicBool>, is_closed: Arc<AtomicBool>,
} }
impl MuxStreamCloser { impl MuxStreamCloser {
/// Close the stream. You will no longer be able to write or read after this has been called.
pub async fn close(&self, reason: u8) -> Result<(), crate::WispError> { pub async fn close(&self, reason: u8) -> Result<(), crate::WispError> {
if self.is_closed.load(Ordering::Acquire) { if self.is_closed.load(Ordering::Acquire) {
return Err(crate::WispError::StreamAlreadyClosed); return Err(crate::WispError::StreamAlreadyClosed);
} }
let (tx, rx) = oneshot::channel::<Result<(), crate::WispError>>(); let (tx, rx) = oneshot::channel::<Result<(), crate::WispError>>();
self.close_channel self.close_channel
.unbounded_send(MuxEvent::Close(self.stream_id, reason, tx)) .unbounded_send(WsEvent::Close(self.stream_id, reason, tx))
.map_err(|x| crate::WispError::Other(Box::new(x)))?; .map_err(|x| crate::WispError::Other(Box::new(x)))?;
rx.await rx.await
.map_err(|x| crate::WispError::Other(Box::new(x)))??; .map_err(|x| crate::WispError::Other(Box::new(x)))??;
@ -245,6 +292,7 @@ impl MuxStreamCloser {
} }
pin_project! { pin_project! {
/// Multiplexor stream that implements futures `Stream + Sink`.
pub struct MuxStreamIo { pub struct MuxStreamIo {
#[pin] #[pin]
rx: Pin<Box<dyn Stream<Item = Bytes> + Send>>, rx: Pin<Box<dyn Stream<Item = Bytes> + Send>>,
@ -254,6 +302,10 @@ pin_project! {
} }
impl MuxStreamIo { impl MuxStreamIo {
/// Turn the stream into one that implements futures `AsyncRead + AsyncWrite`.
///
/// Enable the `tokio_io` feature to implement the tokio version of `AsyncRead` and
/// `AsyncWrite`.
pub fn into_asyncrw(self) -> IoStream<MuxStreamIo, Vec<u8>> { pub fn into_asyncrw(self) -> IoStream<MuxStreamIo, Vec<u8>> {
IoStream::new(self) IoStream::new(self)
} }

View file

@ -1,6 +1,5 @@
#![allow(dead_code)] #![allow(dead_code)]
// Taken from https://github.com/hyperium/hyper-util/blob/master/src/rt/tokio.rs //! hyper_util::rt::tokio::TokioIo
// hyper-util fails to compile on WASM as it has a dependency on socket2
use std::{ use std::{
pin::Pin, pin::Pin,

View file

@ -1,3 +1,4 @@
//! Helper that implements a Tower Service for a client multiplexor.
use crate::{tokioio::TokioIo, ws::WebSocketWrite, ClientMux, MuxStreamIo, StreamType, WispError}; use crate::{tokioio::TokioIo, ws::WebSocketWrite, ClientMux, MuxStreamIo, StreamType, WispError};
use async_io_stream::IoStream; use async_io_stream::IoStream;
use futures::{ use futures::{
@ -6,6 +7,7 @@ use futures::{
}; };
use std::sync::Arc; use std::sync::Arc;
/// Wrapper struct that implements a Tower Service sfor a client multiplexor.
pub struct ServiceWrapper<W: WebSocketWrite + Send + 'static>(pub Arc<ClientMux<W>>); pub struct ServiceWrapper<W: WebSocketWrite + Send + 'static>(pub Arc<ClientMux<W>>);
impl<W: WebSocketWrite + Send + 'static> tower_service::Service<hyper::Uri> for ServiceWrapper<W> { impl<W: WebSocketWrite + Send + 'static> tower_service::Service<hyper::Uri> for ServiceWrapper<W> {

View file

@ -1,23 +1,41 @@
//! Abstraction over WebSocket implementations.
//!
//! Use the [`fastwebsockets`] and [`ws_stream_wasm`] implementations of these traits as an example
//! for implementing them for other WebSocket implementations.
//!
//! [`fastwebsockets`]: https://github.com/MercuryWorkshop/epoxy-tls/blob/multiplexed/wisp/src/fastwebsockets.rs
//! [`ws_stream_wasm`]: https://github.com/MercuryWorkshop/epoxy-tls/blob/multiplexed/wisp/src/ws_stream_wasm.rs
use bytes::Bytes; use bytes::Bytes;
use futures::lock::Mutex; use futures::lock::Mutex;
use std::sync::Arc; use std::sync::Arc;
/// Opcode of the WebSocket frame.
#[derive(Debug, PartialEq, Clone, Copy)] #[derive(Debug, PartialEq, Clone, Copy)]
pub enum OpCode { pub enum OpCode {
/// Text frame.
Text, Text,
/// Binary frame.
Binary, Binary,
/// Close frame.
Close, Close,
/// Ping frame.
Ping, Ping,
/// Pong frame.
Pong, Pong,
} }
/// WebSocket frame.
pub struct Frame { pub struct Frame {
/// Whether the frame is finished or not.
pub finished: bool, pub finished: bool,
/// Opcode of the WebSocket frame.
pub opcode: OpCode, pub opcode: OpCode,
/// Payload of the WebSocket frame.
pub payload: Bytes, pub payload: Bytes,
} }
impl Frame { impl Frame {
/// Create a new text frame.
pub fn text(payload: Bytes) -> Self { pub fn text(payload: Bytes) -> Self {
Self { Self {
finished: true, finished: true,
@ -26,6 +44,7 @@ impl Frame {
} }
} }
/// Create a new binary frame.
pub fn binary(payload: Bytes) -> Self { pub fn binary(payload: Bytes) -> Self {
Self { Self {
finished: true, finished: true,
@ -34,6 +53,7 @@ impl Frame {
} }
} }
/// Create a new close frame.
pub fn close(payload: Bytes) -> Self { pub fn close(payload: Bytes) -> Self {
Self { Self {
finished: true, finished: true,
@ -43,27 +63,34 @@ impl Frame {
} }
} }
/// Generic WebSocket read trait.
pub trait WebSocketRead { pub trait WebSocketRead {
/// Read a frame from the socket.
fn wisp_read_frame( fn wisp_read_frame(
&mut self, &mut self,
tx: &crate::ws::LockedWebSocketWrite<impl crate::ws::WebSocketWrite + Send>, tx: &crate::ws::LockedWebSocketWrite<impl crate::ws::WebSocketWrite + Send>,
) -> impl std::future::Future<Output = Result<Frame, crate::WispError>> + Send; ) -> impl std::future::Future<Output = Result<Frame, crate::WispError>> + Send;
} }
/// Generic WebSocket write trait.
pub trait WebSocketWrite { pub trait WebSocketWrite {
/// Write a frame to the socket.
fn wisp_write_frame( fn wisp_write_frame(
&mut self, &mut self,
frame: Frame, frame: Frame,
) -> impl std::future::Future<Output = Result<(), crate::WispError>> + Send; ) -> impl std::future::Future<Output = Result<(), crate::WispError>> + Send;
} }
/// Locked WebSocket that can be shared between threads.
pub struct LockedWebSocketWrite<S>(Arc<Mutex<S>>); pub struct LockedWebSocketWrite<S>(Arc<Mutex<S>>);
impl<S: WebSocketWrite + Send> LockedWebSocketWrite<S> { impl<S: WebSocketWrite + Send> LockedWebSocketWrite<S> {
/// Create a new locked websocket.
pub fn new(ws: S) -> Self { pub fn new(ws: S) -> Self {
Self(Arc::new(Mutex::new(ws))) Self(Arc::new(Mutex::new(ws)))
} }
/// Write a frame to the websocket.
pub async fn write_frame(&self, frame: Frame) -> Result<(), crate::WispError> { pub async fn write_frame(&self, frame: Frame) -> Result<(), crate::WispError> {
self.0.lock().await.wisp_write_frame(frame).await self.0.lock().await.wisp_write_frame(frame).await
} }