preliminary support for wisp v2

This commit is contained in:
Toshit Chawda 2024-04-11 19:05:14 -07:00
parent 98072be3d4
commit ef5ed52e71
No known key found for this signature in database
GPG key ID: 91480ED99E2B3D9D
18 changed files with 772 additions and 206 deletions

39
Cargo.lock generated
View file

@ -133,9 +133,9 @@ dependencies = [
[[package]] [[package]]
name = "async-trait" name = "async-trait"
version = "0.1.79" version = "0.1.80"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a507401cad91ec6a857ed5513a2073c82a9b9048762b885bb98655b306964681" checksum = "c6fa2087f2753a7da8cc1c0dbfcf89579dd57458e36769de5ac750b4671737ca"
dependencies = [ dependencies = [
"proc-macro2 1.0.79", "proc-macro2 1.0.79",
"quote 1.0.36", "quote 1.0.36",
@ -525,6 +525,7 @@ name = "epoxy-client"
version = "1.5.1" version = "1.5.1"
dependencies = [ dependencies = [
"async-compression", "async-compression",
"async-trait",
"async_io_stream", "async_io_stream",
"base64", "base64",
"bytes", "bytes",
@ -542,7 +543,7 @@ dependencies = [
"pin-project-lite", "pin-project-lite",
"ring", "ring",
"rustls-pki-types", "rustls-pki-types",
"send_wrapper", "send_wrapper 0.6.0",
"tokio", "tokio",
"tokio-rustls", "tokio-rustls",
"tokio-util", "tokio-util",
@ -744,6 +745,16 @@ version = "0.3.30"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "38d84fa142264698cdce1a9f9172cf383a0c82de1bddcf3092901442c4097004" checksum = "38d84fa142264698cdce1a9f9172cf383a0c82de1bddcf3092901442c4097004"
[[package]]
name = "futures-timer"
version = "3.0.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f288b0a4f20f9a56b5d1da57e2227c661b7b16168e2f72365f57b63326e29b24"
dependencies = [
"gloo-timers",
"send_wrapper 0.4.0",
]
[[package]] [[package]]
name = "futures-util" name = "futures-util"
version = "0.3.30" version = "0.3.30"
@ -791,6 +802,18 @@ version = "0.28.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4271d37baee1b8c7e4b708028c57d816cf9d2434acb33a549475f78c181f6253" checksum = "4271d37baee1b8c7e4b708028c57d816cf9d2434acb33a549475f78c181f6253"
[[package]]
name = "gloo-timers"
version = "0.2.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9b995a66bb87bebce9a0f4a95aed01daca4872c050bfcb21653361c03bc35e5c"
dependencies = [
"futures-channel",
"futures-core",
"js-sys",
"wasm-bindgen",
]
[[package]] [[package]]
name = "h2" name = "h2"
version = "0.3.26" version = "0.3.26"
@ -1659,6 +1682,12 @@ version = "1.0.22"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "92d43fe69e652f3df9bdc2b85b2854a0825b86e4fb76bc44d945137d053639ca" checksum = "92d43fe69e652f3df9bdc2b85b2854a0825b86e4fb76bc44d945137d053639ca"
[[package]]
name = "send_wrapper"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f638d531eccd6e23b980caf34876660d38e265409d8e99b397ab71eb3612fad0"
[[package]] [[package]]
name = "send_wrapper" name = "send_wrapper"
version = "0.6.0" version = "0.6.0"
@ -2531,14 +2560,16 @@ checksum = "32b752e52a2da0ddfbdbcc6fceadfeede4c939ed16d13e648833a61dfb611ed8"
[[package]] [[package]]
name = "wisp-mux" name = "wisp-mux"
version = "3.0.0" version = "4.0.0"
dependencies = [ dependencies = [
"async-trait",
"async_io_stream", "async_io_stream",
"bytes", "bytes",
"dashmap", "dashmap",
"event-listener", "event-listener",
"fastwebsockets 0.7.1", "fastwebsockets 0.7.1",
"futures", "futures",
"futures-timer",
"futures-util", "futures-util",
"pin-project-lite", "pin-project-lite",
"tokio", "tokio",

View file

@ -25,7 +25,7 @@ tokio-util = { version = "0.7.10", features = ["io"] }
async-compression = { version = "0.4.5", features = ["tokio", "gzip", "brotli"] } async-compression = { version = "0.4.5", features = ["tokio", "gzip", "brotli"] }
fastwebsockets = { version = "0.6.0", features = ["unstable-split"] } fastwebsockets = { version = "0.6.0", features = ["unstable-split"] }
base64 = "0.21.7" base64 = "0.21.7"
wisp-mux = { path = "../wisp", features = ["tokio_io"] } wisp-mux = { path = "../wisp", features = ["tokio_io", "wasm"] }
async_io_stream = { version = "0.3.3", features = ["tokio_io"] } async_io_stream = { version = "0.3.3", features = ["tokio_io"] }
getrandom = { version = "0.2.12", features = ["js"] } getrandom = { version = "0.2.12", features = ["js"] }
hyper-util-wasm = { version = "0.1.3", features = ["client", "client-legacy", "http1", "http2"] } hyper-util-wasm = { version = "0.1.3", features = ["client", "client-legacy", "http1", "http2"] }
@ -35,6 +35,7 @@ console_error_panic_hook = "0.1.7"
send_wrapper = "0.6.0" send_wrapper = "0.6.0"
event-listener = "5.2.0" event-listener = "5.2.0"
wasmtimer = "0.2.0" wasmtimer = "0.2.0"
async-trait = "0.1.80"
[dependencies.ring] [dependencies.ring]
features = ["wasm32_unknown_unknown_js"] features = ["wasm32_unknown_unknown_js"]

View file

@ -105,7 +105,7 @@ pub fn certs() -> Result<JsValue, JsValue> {
#[wasm_bindgen(inspectable)] #[wasm_bindgen(inspectable)]
pub struct EpoxyClient { pub struct EpoxyClient {
rustls_config: Arc<rustls::ClientConfig>, rustls_config: Arc<rustls::ClientConfig>,
mux: Arc<RwLock<ClientMux<WebSocketWrapper>>>, mux: Arc<RwLock<ClientMux>>,
hyper_client: Client<TlsWispService, HttpBody>, hyper_client: Client<TlsWispService, HttpBody>,
#[wasm_bindgen(getter_with_clone)] #[wasm_bindgen(getter_with_clone)]
pub useragent: String, pub useragent: String,
@ -164,7 +164,7 @@ impl EpoxyClient {
async fn get_tls_io(&self, url_host: &str, url_port: u16) -> Result<EpxIoTlsStream, JsError> { async fn get_tls_io(&self, url_host: &str, url_port: u16) -> Result<EpxIoTlsStream, JsError> {
let channel = self let channel = self
.mux .mux
.read() .write()
.await .await
.client_new_stream(StreamType::Tcp, url_host.to_string(), url_port) .client_new_stream(StreamType::Tcp, url_host.to_string(), url_port)
.await .await

View file

@ -33,7 +33,7 @@ impl EpxUdpStream {
let io = tcp let io = tcp
.mux .mux
.read() .write()
.await .await
.client_new_stream(StreamType::Udp, url_host.to_string(), url_port) .client_new_stream(StreamType::Udp, url_host.to_string(), url_port)
.await .await

View file

@ -6,7 +6,10 @@ use wasm_bindgen_futures::JsFuture;
use hyper::rt::Executor; use hyper::rt::Executor;
use js_sys::ArrayBuffer; use js_sys::ArrayBuffer;
use std::future::Future; use std::future::Future;
use wisp_mux::WispError; use wisp_mux::{
extensions::udp::{UdpProtocolExtension, UdpProtocolExtensionBuilder},
WispError,
};
#[wasm_bindgen] #[wasm_bindgen]
extern "C" { extern "C" {
@ -192,25 +195,25 @@ pub fn get_url_port(url: &Uri) -> Result<u16, JsError> {
pub async fn make_mux( pub async fn make_mux(
url: &str, url: &str,
) -> Result< ) -> Result<(ClientMux, impl Future<Output = Result<(), WispError>> + Send), WispError> {
(
ClientMux<WebSocketWrapper>,
impl Future<Output = Result<(), WispError>>,
),
WispError,
> {
let (wtx, wrx) = WebSocketWrapper::connect(url, vec![]) let (wtx, wrx) = WebSocketWrapper::connect(url, vec![])
.await .await
.map_err(|_| WispError::WsImplSocketClosed)?; .map_err(|_| WispError::WsImplSocketClosed)?;
wtx.wait_for_open().await; wtx.wait_for_open().await;
let mux = ClientMux::new(wrx, wtx).await?; let mux = ClientMux::new(
wrx,
wtx,
Some(vec![UdpProtocolExtension().into()]),
Some(&[&UdpProtocolExtensionBuilder()]),
)
.await?;
Ok(mux) Ok(mux)
} }
pub fn spawn_mux_fut( pub fn spawn_mux_fut(
mux: Arc<RwLock<ClientMux<WebSocketWrapper>>>, mux: Arc<RwLock<ClientMux>>,
fut: impl Future<Output = Result<(), WispError>> + 'static, fut: impl Future<Output = Result<(), WispError>> + Send + 'static,
url: String, url: String,
) { ) {
wasm_bindgen_futures::spawn_local(async move { wasm_bindgen_futures::spawn_local(async move {
@ -225,10 +228,7 @@ pub fn spawn_mux_fut(
}); });
} }
pub async fn replace_mux( pub async fn replace_mux(mux: Arc<RwLock<ClientMux>>, url: &str) -> Result<(), WispError> {
mux: Arc<RwLock<ClientMux<WebSocketWrapper>>>,
url: &str,
) -> Result<(), WispError> {
let (mux_replace, fut) = make_mux(url).await?; let (mux_replace, fut) = make_mux(url).await?;
let mut mux_write = mux.write().await; let mut mux_write = mux.write().await;
mux_write.close().await?; mux_write.close().await?;

View file

@ -106,7 +106,7 @@ impl EpxWebSocket {
break; break;
} }
// ping/pong/continue // ping/pong/continue
_ => {}, _ => {}
} }
} }
}); });
@ -115,7 +115,13 @@ impl EpxWebSocket {
.call0(&Object::default()) .call0(&Object::default())
.replace_err("Failed to call onopen")?; .replace_err("Failed to call onopen")?;
Ok(Self { tx, onerror, origin, protocols, url: url.to_string() }) Ok(Self {
tx,
onerror,
origin,
protocols,
url: url.to_string(),
})
} }
.await; .await;
if let Err(ret) = ret { if let Err(ret) = ret {

View file

@ -53,7 +53,7 @@ impl Stream for IncomingBody {
} }
#[derive(Clone)] #[derive(Clone)]
pub struct ServiceWrapper(pub Arc<RwLock<ClientMux<WebSocketWrapper>>>, pub String); pub struct ServiceWrapper(pub Arc<RwLock<ClientMux>>, pub String);
impl tower_service::Service<hyper::Uri> for ServiceWrapper { impl tower_service::Service<hyper::Uri> for ServiceWrapper {
type Response = TokioIo<EpxIoUnencryptedStream>; type Response = TokioIo<EpxIoUnencryptedStream>;
@ -69,7 +69,7 @@ impl tower_service::Service<hyper::Uri> for ServiceWrapper {
let mux_url = self.1.clone(); let mux_url = self.1.clone();
async move { async move {
let stream = mux let stream = mux
.read() .write()
.await .await
.client_new_stream( .client_new_stream(
StreamType::Tcp, StreamType::Tcp,
@ -193,11 +193,9 @@ pub struct WebSocketReader {
close_event: Arc<Event>, close_event: Arc<Event>,
} }
#[async_trait::async_trait]
impl WebSocketRead for WebSocketReader { impl WebSocketRead for WebSocketReader {
async fn wisp_read_frame( async fn wisp_read_frame(&mut self, _: &LockedWebSocketWrite) -> Result<Frame, WispError> {
&mut self,
_: &LockedWebSocketWrite<impl WebSocketWrite>,
) -> Result<Frame, WispError> {
use WebSocketMessage::*; use WebSocketMessage::*;
if self.closed.load(Ordering::Acquire) { if self.closed.load(Ordering::Acquire) {
return Err(WispError::WsImplSocketClosed); return Err(WispError::WsImplSocketClosed);
@ -306,6 +304,7 @@ impl WebSocketWrapper {
} }
} }
#[async_trait::async_trait]
impl WebSocketWrite for WebSocketWrapper { impl WebSocketWrite for WebSocketWrapper {
async fn wisp_write_frame(&mut self, frame: Frame) -> Result<(), WispError> { async fn wisp_write_frame(&mut self, frame: Frame) -> Result<(), WispError> {
use wisp_mux::ws::OpCode::*; use wisp_mux::ws::OpCode::*;

1
rustfmt.toml Normal file
View file

@ -0,0 +1 @@
imports_granularity = "Crate"

View file

@ -4,13 +4,12 @@ use std::io::Error;
use bytes::Bytes; use bytes::Bytes;
use clap::Parser; use clap::Parser;
use fastwebsockets::{ use fastwebsockets::{
upgrade::{self, UpgradeFut}, CloseCode, FragmentCollector, FragmentCollectorRead, Frame, OpCode, Payload, upgrade::{self, UpgradeFut},
WebSocketError, CloseCode, FragmentCollector, FragmentCollectorRead, Frame, OpCode, Payload, WebSocketError,
}; };
use futures_util::{SinkExt, StreamExt, TryFutureExt}; use futures_util::{SinkExt, StreamExt, TryFutureExt};
use hyper::{ use hyper::{
body::Incoming, server::conn::http1, service::service_fn, Request, Response, body::Incoming, server::conn::http1, service::service_fn, Request, Response, StatusCode,
StatusCode,
}; };
use hyper_util::rt::TokioIo; use hyper_util::rt::TokioIo;
use tokio::net::{lookup_host, TcpListener, TcpStream, UdpSocket}; use tokio::net::{lookup_host, TcpListener, TcpStream, UdpSocket};
@ -20,7 +19,10 @@ use tokio_util::codec::{BytesCodec, Framed};
#[cfg(unix)] #[cfg(unix)]
use tokio_util::either::Either; use tokio_util::either::Either;
use wisp_mux::{CloseReason, ConnectPacket, MuxStream, ServerMux, StreamType, WispError}; use wisp_mux::{
extensions::udp::{UdpProtocolExtension, UdpProtocolExtensionBuilder},
CloseReason, ConnectPacket, MuxStream, ServerMux, StreamType, WispError,
};
type HttpBody = http_body_util::Full<hyper::body::Bytes>; type HttpBody = http_body_util::Full<hyper::body::Bytes>;
@ -261,7 +263,14 @@ async fn accept_ws(
println!("{:?}: connected", addr); println!("{:?}: connected", addr);
let (mut mux, fut) = ServerMux::new(rx, tx, u32::MAX); let (mut mux, fut) = ServerMux::new(
rx,
tx,
u32::MAX,
Some(vec![UdpProtocolExtension().into()]),
Some(&[&UdpProtocolExtensionBuilder()]),
)
.await?;
tokio::spawn(async move { tokio::spawn(async move {
if let Err(e) = fut.await { if let Err(e) = fut.await {

View file

@ -11,7 +11,14 @@ use hyper::{
}; };
use simple_moving_average::{SingleSumSMA, SMA}; use simple_moving_average::{SingleSumSMA, SMA};
use std::{ use std::{
error::Error, future::Future, io::{stdout, IsTerminal, Write}, net::SocketAddr, process::exit, sync::Arc, time::{Duration, Instant}, usize error::Error,
future::Future,
io::{stdout, IsTerminal, Write},
net::SocketAddr,
process::exit,
sync::Arc,
time::{Duration, Instant},
usize,
}; };
use tokio::{ use tokio::{
net::TcpStream, net::TcpStream,
@ -21,7 +28,10 @@ use tokio::{
}; };
use tokio_native_tls::{native_tls, TlsConnector}; use tokio_native_tls::{native_tls, TlsConnector};
use tokio_util::either::Either; use tokio_util::either::Either;
use wisp_mux::{ClientMux, StreamType, WispError}; use wisp_mux::{
extensions::udp::{UdpProtocolExtension, UdpProtocolExtensionBuilder},
ClientMux, StreamType, WispError,
};
#[derive(Debug)] #[derive(Debug)]
enum WispClientError { enum WispClientError {
@ -71,6 +81,9 @@ struct Cli {
/// Duration to run the test for /// Duration to run the test for
#[arg(short, long)] #[arg(short, long)]
duration: Option<humantime::Duration>, duration: Option<humantime::Duration>,
/// Ask for UDP
#[arg(short, long)]
udp: bool,
} }
#[tokio::main(flavor = "multi_thread")] #[tokio::main(flavor = "multi_thread")]
@ -117,7 +130,6 @@ async fn main() -> Result<(), Box<dyn Error + Send + Sync>> {
fastwebsockets::handshake::generate_key(), fastwebsockets::handshake::generate_key(),
) )
.header("Sec-WebSocket-Version", "13") .header("Sec-WebSocket-Version", "13")
.header("Sec-WebSocket-Protocol", "wisp-v1")
.body(Empty::<Bytes>::new())?; .body(Empty::<Bytes>::new())?;
let (ws, _) = handshake::client(&SpawnExecutor, req, socket).await?; let (ws, _) = handshake::client(&SpawnExecutor, req, socket).await?;
@ -125,7 +137,18 @@ async fn main() -> Result<(), Box<dyn Error + Send + Sync>> {
let (rx, tx) = ws.split(tokio::io::split); let (rx, tx) = ws.split(tokio::io::split);
let rx = FragmentCollectorRead::new(rx); let rx = FragmentCollectorRead::new(rx);
let (mux, fut) = ClientMux::new(rx, tx).await?; let (mut mux, fut) = if opts.udp {
ClientMux::new(
rx,
tx,
Some(vec![UdpProtocolExtension().into()]),
Some(&[&UdpProtocolExtensionBuilder()]),
)
.await?
} else {
ClientMux::new(rx, tx, Some(vec![]), Some(&[])).await?
};
let mut threads = Vec::with_capacity(opts.streams * 2 + 3); let mut threads = Vec::with_capacity(opts.streams * 2 + 3);
threads.push(tokio::spawn(fut)); threads.push(tokio::spawn(fut));

View file

@ -1,6 +1,6 @@
[package] [package]
name = "wisp-mux" name = "wisp-mux"
version = "3.0.0" version = "4.0.0"
license = "LGPL-3.0-only" license = "LGPL-3.0-only"
description = "A library for easily creating Wisp servers and clients." description = "A library for easily creating Wisp servers and clients."
homepage = "https://github.com/MercuryWorkshop/epoxy-tls/tree/multiplexed/wisp" homepage = "https://github.com/MercuryWorkshop/epoxy-tls/tree/multiplexed/wisp"
@ -9,12 +9,14 @@ readme = "README.md"
edition = "2021" edition = "2021"
[dependencies] [dependencies]
async-trait = "0.1.79"
async_io_stream = "0.3.3" async_io_stream = "0.3.3"
bytes = "1.5.0" bytes = "1.5.0"
dashmap = { version = "5.5.3", features = ["inline"] } dashmap = { version = "5.5.3", features = ["inline"] }
event-listener = "5.0.0" event-listener = "5.0.0"
fastwebsockets = { version = "0.7.1", features = ["unstable-split"], optional = true } fastwebsockets = { version = "0.7.1", features = ["unstable-split"], optional = true }
futures = "0.3.30" futures = "0.3.30"
futures-timer = "3.0.3"
futures-util = "0.3.30" futures-util = "0.3.30"
pin-project-lite = "0.2.13" pin-project-lite = "0.2.13"
tokio = { version = "1.35.1", optional = true, default-features = false } tokio = { version = "1.35.1", optional = true, default-features = false }
@ -22,6 +24,7 @@ tokio = { version = "1.35.1", optional = true, default-features = false }
[features] [features]
fastwebsockets = ["dep:fastwebsockets", "dep:tokio"] fastwebsockets = ["dep:fastwebsockets", "dep:tokio"]
tokio_io = ["async_io_stream/tokio_io"] tokio_io = ["async_io_stream/tokio_io"]
wasm = ["futures-timer/wasm-bindgen"]
[package.metadata.docs.rs] [package.metadata.docs.rs]
all-features = true all-features = true

190
wisp/src/extensions.rs Normal file
View file

@ -0,0 +1,190 @@
//! Wisp protocol extensions.
use std::ops::{Deref, DerefMut};
use async_trait::async_trait;
use bytes::{BufMut, Bytes, BytesMut};
use crate::{
ws::{LockedWebSocketWrite, WebSocketRead},
Role, WispError,
};
/// Type-erased protocol extension that implements Clone.
#[derive(Debug)]
pub struct AnyProtocolExtension(Box<dyn ProtocolExtension + Sync + Send>);
impl AnyProtocolExtension {
/// Create a new type-erased protocol extension.
pub fn new<T: ProtocolExtension + Sync + Send + 'static>(extension: T) -> Self {
Self(Box::new(extension))
}
}
impl Deref for AnyProtocolExtension {
type Target = dyn ProtocolExtension;
fn deref(&self) -> &Self::Target {
self.0.deref()
}
}
impl DerefMut for AnyProtocolExtension {
fn deref_mut(&mut self) -> &mut Self::Target {
self.0.deref_mut()
}
}
impl Clone for AnyProtocolExtension {
fn clone(&self) -> Self {
Self(self.0.box_clone())
}
}
impl From<AnyProtocolExtension> for Bytes {
fn from(value: AnyProtocolExtension) -> Self {
let mut bytes = BytesMut::with_capacity(5);
let payload = value.encode();
bytes.put_u8(value.get_id());
bytes.put_u32_le(payload.len() as u32);
bytes.extend(payload);
bytes.freeze()
}
}
/// A Wisp protocol extension.
///
/// See [the
/// docs](https://github.com/MercuryWorkshop/wisp-protocol/blob/main/protocol.md#protocol-extensions).
#[async_trait]
pub trait ProtocolExtension: std::fmt::Debug {
/// Get the protocol extension ID.
fn get_id(&self) -> u8;
/// Get the protocol extension's supported packets.
///
/// Used to decide whether to call the protocol extension's packet handler.
fn get_supported_packets(&self) -> &'static [u8];
/// Encode self into Bytes.
fn encode(&self) -> Bytes;
/// Handle the handshake part of a Wisp connection.
///
/// This should be used to send or receive data before any streams are created.
async fn handle_handshake(
&mut self,
read: &mut dyn WebSocketRead,
write: &LockedWebSocketWrite,
) -> Result<(), WispError>;
/// Handle receiving a packet.
async fn handle_packet(
&mut self,
packet: Bytes,
read: &mut dyn WebSocketRead,
write: &LockedWebSocketWrite,
) -> Result<(), WispError>;
/// Clone the protocol extension.
fn box_clone(&self) -> Box<dyn ProtocolExtension + Sync + Send>;
}
/// Trait to build a Wisp protocol extension for the client.
pub trait ProtocolExtensionBuilder {
/// Get the protocol extension ID.
///
/// Used to decide whether this builder should be used.
fn get_id(&self) -> u8;
/// Build a protocol extension from the extension's metadata.
fn build(&self, bytes: Bytes, role: Role) -> AnyProtocolExtension;
}
pub mod udp {
//! UDP protocol extension.
//!
//! # Example
//! ```
//! let (mux, fut) = ServerMux::new(
//! rx,
//! tx,
//! 128,
//! Some(vec![UdpProtocolExtension().into()]),
//! Some(&[&UdpProtocolExtensionBuilder()])
//! );
//! ```
//! See [the docs](https://github.com/MercuryWorkshop/wisp-protocol/blob/main/protocol.md#0x01---udp)
use async_trait::async_trait;
use bytes::Bytes;
use crate::{
ws::{LockedWebSocketWrite, WebSocketRead},
WispError,
};
use super::{AnyProtocolExtension, ProtocolExtension, ProtocolExtensionBuilder};
#[derive(Debug)]
/// UDP protocol extension.
pub struct UdpProtocolExtension();
impl UdpProtocolExtension {
/// UDP protocol extension ID.
pub const ID: u8 = 0x01;
}
#[async_trait]
impl ProtocolExtension for UdpProtocolExtension {
fn get_id(&self) -> u8 {
Self::ID
}
fn get_supported_packets(&self) -> &'static [u8] {
&[]
}
fn encode(&self) -> Bytes {
Bytes::new()
}
async fn handle_handshake(
&mut self,
_: &mut dyn WebSocketRead,
_: &LockedWebSocketWrite,
) -> Result<(), WispError> {
Ok(())
}
/// Handle receiving a packet.
async fn handle_packet(
&mut self,
_: Bytes,
_: &mut dyn WebSocketRead,
_: &LockedWebSocketWrite,
) -> Result<(), WispError> {
Ok(())
}
fn box_clone(&self) -> Box<dyn ProtocolExtension + Sync + Send> {
Box::new(Self())
}
}
impl From<UdpProtocolExtension> for AnyProtocolExtension {
fn from(value: UdpProtocolExtension) -> Self {
AnyProtocolExtension(Box::new(value))
}
}
/// UDP protocol extension builder.
pub struct UdpProtocolExtensionBuilder();
impl ProtocolExtensionBuilder for UdpProtocolExtensionBuilder {
fn get_id(&self) -> u8 {
0x01
}
fn build(&self, _: Bytes, _: crate::Role) -> AnyProtocolExtension {
AnyProtocolExtension(Box::new(UdpProtocolExtension()))
}
}
}

View file

@ -1,9 +1,12 @@
use async_trait::async_trait;
use bytes::Bytes; use bytes::Bytes;
use fastwebsockets::{ use fastwebsockets::{
FragmentCollectorRead, Frame, OpCode, Payload, WebSocketError, WebSocketWrite, FragmentCollectorRead, Frame, OpCode, Payload, WebSocketError, WebSocketWrite,
}; };
use tokio::io::{AsyncRead, AsyncWrite}; use tokio::io::{AsyncRead, AsyncWrite};
use crate::{ws::LockedWebSocketWrite, WispError};
impl From<OpCode> for crate::ws::OpCode { impl From<OpCode> for crate::ws::OpCode {
fn from(opcode: OpCode) -> Self { fn from(opcode: OpCode) -> Self {
use OpCode::*; use OpCode::*;
@ -58,11 +61,12 @@ impl From<WebSocketError> for crate::WispError {
} }
} }
impl<S: AsyncRead + Unpin> crate::ws::WebSocketRead for FragmentCollectorRead<S> { #[async_trait]
impl<S: AsyncRead + Unpin + Send> crate::ws::WebSocketRead for FragmentCollectorRead<S> {
async fn wisp_read_frame( async fn wisp_read_frame(
&mut self, &mut self,
tx: &crate::ws::LockedWebSocketWrite<impl crate::ws::WebSocketWrite>, tx: &LockedWebSocketWrite,
) -> Result<crate::ws::Frame, crate::WispError> { ) -> Result<crate::ws::Frame, WispError> {
Ok(self Ok(self
.read_frame(&mut |frame| async { tx.write_frame(frame.into()).await }) .read_frame(&mut |frame| async { tx.write_frame(frame.into()).await })
.await? .await?
@ -70,8 +74,9 @@ impl<S: AsyncRead + Unpin> crate::ws::WebSocketRead for FragmentCollectorRead<S>
} }
} }
impl<S: AsyncWrite + Unpin> crate::ws::WebSocketWrite for WebSocketWrite<S> { #[async_trait]
async fn wisp_write_frame(&mut self, frame: crate::ws::Frame) -> Result<(), crate::WispError> { impl<S: AsyncWrite + Unpin + Send> crate::ws::WebSocketWrite for WebSocketWrite<S> {
async fn wisp_write_frame(&mut self, frame: crate::ws::Frame) -> Result<(), WispError> {
self.write_frame(frame.into()).await.map_err(|e| e.into()) self.write_frame(frame.into()).await.map_err(|e| e.into())
} }
} }

View file

@ -4,6 +4,7 @@
//! //!
//! [Wisp]: https://github.com/MercuryWorkshop/wisp-protocol //! [Wisp]: https://github.com/MercuryWorkshop/wisp-protocol
pub mod extensions;
#[cfg(feature = "fastwebsockets")] #[cfg(feature = "fastwebsockets")]
#[cfg_attr(docsrs, doc(cfg(feature = "fastwebsockets")))] #[cfg_attr(docsrs, doc(cfg(feature = "fastwebsockets")))]
mod fastwebsockets; mod fastwebsockets;
@ -12,18 +13,28 @@ mod sink_unfold;
mod stream; mod stream;
pub mod ws; pub mod ws;
pub use crate::packet::*; pub use crate::{packet::*, stream::*};
pub use crate::stream::*;
use bytes::Bytes; use bytes::Bytes;
use dashmap::DashMap; use dashmap::DashMap;
use event_listener::Event; use event_listener::Event;
use futures::SinkExt; use extensions::{udp::UdpProtocolExtension, AnyProtocolExtension, ProtocolExtensionBuilder};
use futures::{channel::mpsc, Future, FutureExt, StreamExt}; use futures::{
use std::sync::{ channel::{mpsc, oneshot},
select, Future, FutureExt, SinkExt, StreamExt,
};
use futures_timer::Delay;
use std::{
sync::{
atomic::{AtomicBool, AtomicU32, Ordering}, atomic::{AtomicBool, AtomicU32, Ordering},
Arc, Arc,
},
time::Duration,
}; };
use ws::AppendingWebSocketRead;
/// Wisp version supported by this crate.
pub const WISP_VERSION: WispVersion = WispVersion { major: 2, minor: 0 };
/// The role of the multiplexor. /// The role of the multiplexor.
#[derive(Debug, PartialEq, Copy, Clone)] #[derive(Debug, PartialEq, Copy, Clone)]
@ -37,9 +48,9 @@ pub enum Role {
/// Errors the Wisp implementation can return. /// Errors the Wisp implementation can return.
#[derive(Debug)] #[derive(Debug)]
pub enum WispError { pub enum WispError {
/// The packet recieved did not have enough data. /// The packet received did not have enough data.
PacketTooSmall, PacketTooSmall,
/// The packet recieved had an invalid type. /// The packet received had an invalid type.
InvalidPacketType, InvalidPacketType,
/// The stream had an invalid type. /// The stream had an invalid type.
InvalidStreamType, InvalidStreamType,
@ -47,19 +58,19 @@ pub enum WispError {
InvalidStreamId, InvalidStreamId,
/// The close packet had an invalid reason. /// The close packet had an invalid reason.
InvalidCloseReason, InvalidCloseReason,
/// The URI recieved was invalid. /// The URI received was invalid.
InvalidUri, InvalidUri,
/// The URI recieved had no host. /// The URI received had no host.
UriHasNoHost, UriHasNoHost,
/// The URI recieved had no port. /// The URI received had no port.
UriHasNoPort, UriHasNoPort,
/// The max stream count was reached. /// The max stream count was reached.
MaxStreamCountReached, MaxStreamCountReached,
/// The stream had already been closed. /// The stream had already been closed.
StreamAlreadyClosed, StreamAlreadyClosed,
/// The websocket frame recieved had an invalid type. /// The websocket frame received had an invalid type.
WsFrameInvalidType, WsFrameInvalidType,
/// The websocket frame recieved was not finished. /// The websocket frame received was not finished.
WsFrameNotFinished, WsFrameNotFinished,
/// Error specific to the websocket implementation. /// Error specific to the websocket implementation.
WsImplError(Box<dyn std::error::Error + Sync + Send>), WsImplError(Box<dyn std::error::Error + Sync + Send>),
@ -67,17 +78,33 @@ pub enum WispError {
WsImplSocketClosed, WsImplSocketClosed,
/// The websocket implementation did not support the action. /// The websocket implementation did not support the action.
WsImplNotSupported, WsImplNotSupported,
/// Error specific to the protocol extension implementation.
ExtensionImplError(Box<dyn std::error::Error + Sync + Send>),
/// The protocol extension implementation did not support the action.
ExtensionImplNotSupported,
/// The UDP protocol extension is not supported by the server.
UdpExtensionNotSupported,
/// The string was invalid UTF-8. /// The string was invalid UTF-8.
Utf8Error(std::str::Utf8Error), Utf8Error(std::str::Utf8Error),
/// The integer failed to convert.
TryFromIntError(std::num::TryFromIntError),
/// Other error. /// Other error.
Other(Box<dyn std::error::Error + Sync + Send>), Other(Box<dyn std::error::Error + Sync + Send>),
/// Failed to send message to multiplexor task. /// Failed to send message to multiplexor task.
MuxMessageFailedToSend, MuxMessageFailedToSend,
/// Failed to receive message from multiplexor task.
MuxMessageFailedToRecv,
} }
impl From<std::str::Utf8Error> for WispError { impl From<std::str::Utf8Error> for WispError {
fn from(err: std::str::Utf8Error) -> WispError { fn from(err: std::str::Utf8Error) -> Self {
WispError::Utf8Error(err) Self::Utf8Error(err)
}
}
impl From<std::num::TryFromIntError> for WispError {
fn from(value: std::num::TryFromIntError) -> Self {
Self::TryFromIntError(value)
} }
} }
@ -103,9 +130,21 @@ impl std::fmt::Display for WispError {
Self::WsImplNotSupported => { Self::WsImplNotSupported => {
write!(f, "Websocket implementation error: unsupported feature") write!(f, "Websocket implementation error: unsupported feature")
} }
Self::ExtensionImplError(err) => {
write!(f, "Protocol extension implementation error: {}", err)
}
Self::ExtensionImplNotSupported => {
write!(
f,
"Protocol extension implementation error: unsupported feature"
)
}
Self::UdpExtensionNotSupported => write!(f, "UDP protocol extension not supported"),
Self::Utf8Error(err) => write!(f, "UTF-8 error: {}", err), Self::Utf8Error(err) => write!(f, "UTF-8 error: {}", err),
Self::TryFromIntError(err) => write!(f, "Integer conversion error: {}", err),
Self::Other(err) => write!(f, "Other error: {}", err), Self::Other(err) => write!(f, "Other error: {}", err),
Self::MuxMessageFailedToSend => write!(f, "Failed to send multiplexor message"), Self::MuxMessageFailedToSend => write!(f, "Failed to send multiplexor message"),
Self::MuxMessageFailedToRecv => write!(f, "Failed to receive multiplexor message"),
} }
} }
} }
@ -120,29 +159,27 @@ struct MuxMapValue {
is_closed: Arc<AtomicBool>, is_closed: Arc<AtomicBool>,
} }
struct MuxInner<W> struct MuxInner {
where tx: ws::LockedWebSocketWrite,
W: ws::WebSocketWrite, stream_map: DashMap<u32, MuxMapValue>,
{ buffer_size: u32,
tx: ws::LockedWebSocketWrite<W>,
stream_map: Arc<DashMap<u32, MuxMapValue>>,
} }
impl<W: ws::WebSocketWrite> MuxInner<W> { impl MuxInner {
pub async fn server_into_future<R>( pub async fn server_into_future<R>(
self, self,
rx: R, rx: R,
close_rx: mpsc::Receiver<WsEvent>, close_rx: mpsc::Receiver<WsEvent>,
muxstream_sender: mpsc::UnboundedSender<(ConnectPacket, MuxStream)>, muxstream_sender: mpsc::UnboundedSender<(ConnectPacket, MuxStream)>,
buffer_size: u32,
close_tx: mpsc::Sender<WsEvent>, close_tx: mpsc::Sender<WsEvent>,
) -> Result<(), WispError> ) -> Result<(), WispError>
where where
R: ws::WebSocketRead, R: ws::WebSocketRead,
{ {
self.into_future( self.as_future(
close_rx, close_rx,
self.server_loop(rx, muxstream_sender, buffer_size, close_tx), close_tx.clone(),
self.server_loop(rx, muxstream_sender, close_tx),
) )
.await .await
} }
@ -151,20 +188,23 @@ impl<W: ws::WebSocketWrite> MuxInner<W> {
self, self,
rx: R, rx: R,
close_rx: mpsc::Receiver<WsEvent>, close_rx: mpsc::Receiver<WsEvent>,
close_tx: mpsc::Sender<WsEvent>,
) -> Result<(), WispError> ) -> Result<(), WispError>
where where
R: ws::WebSocketRead, R: ws::WebSocketRead,
{ {
self.into_future(close_rx, self.client_loop(rx)).await self.as_future(close_rx, close_tx, self.client_loop(rx))
.await
} }
async fn into_future( async fn as_future(
&self, &self,
close_rx: mpsc::Receiver<WsEvent>, close_rx: mpsc::Receiver<WsEvent>,
close_tx: mpsc::Sender<WsEvent>,
wisp_fut: impl Future<Output = Result<(), WispError>>, wisp_fut: impl Future<Output = Result<(), WispError>>,
) -> Result<(), WispError> { ) -> Result<(), WispError> {
let ret = futures::select! { let ret = futures::select! {
_ = self.stream_loop(close_rx).fuse() => Ok(()), _ = self.stream_loop(close_rx, close_tx).fuse() => Ok(()),
x = wisp_fut.fuse() => x, x = wisp_fut.fuse() => x,
}; };
self.stream_map.iter_mut().for_each(|mut x| { self.stream_map.iter_mut().for_each(|mut x| {
@ -176,7 +216,12 @@ impl<W: ws::WebSocketWrite> MuxInner<W> {
ret ret
} }
async fn stream_loop(&self, mut stream_rx: mpsc::Receiver<WsEvent>) { async fn stream_loop(
&self,
mut stream_rx: mpsc::Receiver<WsEvent>,
stream_tx: mpsc::Sender<WsEvent>,
) {
let mut next_free_stream_id: u32 = 1;
while let Some(msg) = stream_rx.next().await { while let Some(msg) = stream_rx.next().await {
match msg { match msg {
WsEvent::SendPacket(packet, channel) => { WsEvent::SendPacket(packet, channel) => {
@ -186,6 +231,53 @@ impl<W: ws::WebSocketWrite> MuxInner<W> {
let _ = channel.send(Err(WispError::InvalidStreamId)); let _ = channel.send(Err(WispError::InvalidStreamId));
} }
} }
WsEvent::CreateStream(stream_type, host, port, channel) => {
let ret: Result<MuxStream, WispError> = async {
let (ch_tx, ch_rx) = mpsc::unbounded();
let stream_id = next_free_stream_id;
let next_stream_id = next_free_stream_id
.checked_add(1)
.ok_or(WispError::MaxStreamCountReached)?;
let flow_control_event: Arc<Event> = Event::new().into();
let flow_control: Arc<AtomicU32> = AtomicU32::new(self.buffer_size).into();
let is_closed: Arc<AtomicBool> = AtomicBool::new(false).into();
self.tx
.write_frame(
Packet::new_connect(stream_id, stream_type, port, host).into(),
)
.await?;
next_free_stream_id = next_stream_id;
self.stream_map.insert(
stream_id,
MuxMapValue {
stream: ch_tx,
stream_type,
flow_control: flow_control.clone(),
flow_control_event: flow_control_event.clone(),
is_closed: is_closed.clone(),
},
);
Ok(MuxStream::new(
stream_id,
Role::Client,
stream_type,
ch_rx,
stream_tx.clone(),
is_closed,
flow_control,
flow_control_event,
0,
))
}
.await;
let _ = channel.send(ret);
}
WsEvent::Close(packet, channel) => { WsEvent::Close(packet, channel) => {
if let Some((_, mut stream)) = self.stream_map.remove(&packet.stream_id) { if let Some((_, mut stream)) = self.stream_map.remove(&packet.stream_id) {
stream.stream.disconnect(); stream.stream.disconnect();
@ -204,17 +296,13 @@ impl<W: ws::WebSocketWrite> MuxInner<W> {
&self, &self,
mut rx: R, mut rx: R,
muxstream_sender: mpsc::UnboundedSender<(ConnectPacket, MuxStream)>, muxstream_sender: mpsc::UnboundedSender<(ConnectPacket, MuxStream)>,
buffer_size: u32,
close_tx: mpsc::Sender<WsEvent>, close_tx: mpsc::Sender<WsEvent>,
) -> Result<(), WispError> ) -> Result<(), WispError>
where where
R: ws::WebSocketRead, R: ws::WebSocketRead,
{ {
// will send continues once flow_control is at 10% of max // will send continues once flow_control is at 10% of max
let target_buffer_size = ((buffer_size as u64 * 90) / 100) as u32; let target_buffer_size = ((self.buffer_size as u64 * 90) / 100) as u32;
self.tx
.write_frame(Packet::new_continue(0, buffer_size).into())
.await?;
loop { loop {
let frame = rx.wisp_read_frame(&self.tx).await?; let frame = rx.wisp_read_frame(&self.tx).await?;
@ -228,7 +316,7 @@ impl<W: ws::WebSocketWrite> MuxInner<W> {
Connect(inner_packet) => { Connect(inner_packet) => {
let (ch_tx, ch_rx) = mpsc::unbounded(); let (ch_tx, ch_rx) = mpsc::unbounded();
let stream_type = inner_packet.stream_type; let stream_type = inner_packet.stream_type;
let flow_control: Arc<AtomicU32> = AtomicU32::new(buffer_size).into(); let flow_control: Arc<AtomicU32> = AtomicU32::new(self.buffer_size).into();
let flow_control_event: Arc<Event> = Event::new().into(); let flow_control_event: Arc<Event> = Event::new().into();
let is_closed: Arc<AtomicBool> = AtomicBool::new(false).into(); let is_closed: Arc<AtomicBool> = AtomicBool::new(false).into();
@ -273,7 +361,7 @@ impl<W: ws::WebSocketWrite> MuxInner<W> {
} }
} }
} }
Continue(_) => break Err(WispError::InvalidPacketType), Continue(_) | Info(_) => break Err(WispError::InvalidPacketType),
Close(_) => { Close(_) => {
if let Some((_, mut stream)) = self.stream_map.remove(&packet.stream_id) { if let Some((_, mut stream)) = self.stream_map.remove(&packet.stream_id) {
stream.is_closed.store(true, Ordering::Release); stream.is_closed.store(true, Ordering::Release);
@ -298,7 +386,7 @@ impl<W: ws::WebSocketWrite> MuxInner<W> {
use PacketType::*; use PacketType::*;
match packet.packet_type { match packet.packet_type {
Connect(_) => break Err(WispError::InvalidPacketType), Connect(_) | Info(_) => break Err(WispError::InvalidPacketType),
Data(data) => { Data(data) => {
if let Some(stream) = self.stream_map.get(&packet.stream_id) { if let Some(stream) = self.stream_map.get(&packet.stream_id) {
let _ = stream.stream.unbounded_send(data); let _ = stream.stream.unbounded_send(data);
@ -332,7 +420,7 @@ impl<W: ws::WebSocketWrite> MuxInner<W> {
/// ``` /// ```
/// use wisp_mux::ServerMux; /// use wisp_mux::ServerMux;
/// ///
/// let (mux, fut) = ServerMux::new(rx, tx, 128); /// let (mux, fut) = ServerMux::new(rx, tx, 128, Some(vec![]), Some([]));
/// tokio::spawn(async move { /// tokio::spawn(async move {
/// if let Err(e) = fut.await { /// if let Err(e) = fut.await {
/// println!("error in multiplexor: {:?}", e); /// println!("error in multiplexor: {:?}", e);
@ -346,34 +434,89 @@ impl<W: ws::WebSocketWrite> MuxInner<W> {
/// } /// }
/// ``` /// ```
pub struct ServerMux { pub struct ServerMux {
/// Whether the connection was downgraded to Wisp v1.
///
/// If this variable is true you must assume no extensions are supported.
pub downgraded: bool,
/// Extensions that are supported by both sides.
pub supported_extensions: Arc<[AnyProtocolExtension]>,
close_tx: mpsc::Sender<WsEvent>, close_tx: mpsc::Sender<WsEvent>,
muxstream_recv: mpsc::UnboundedReceiver<(ConnectPacket, MuxStream)>, muxstream_recv: mpsc::UnboundedReceiver<(ConnectPacket, MuxStream)>,
} }
impl ServerMux { impl ServerMux {
/// Create a new server-side multiplexor. /// Create a new server-side multiplexor.
pub fn new<R, W: ws::WebSocketWrite>( ///
read: R, /// If either extensions or extension_builders are None a Wisp v1 connection is created
/// otherwise a Wisp v2 connection is created.
pub async fn new<R, W>(
mut read: R,
write: W, write: W,
buffer_size: u32, buffer_size: u32,
) -> (Self, impl Future<Output = Result<(), WispError>>) extensions: Option<Vec<AnyProtocolExtension>>,
extension_builders: Option<&[&(dyn ProtocolExtensionBuilder + Sync)]>,
) -> Result<(Self, impl Future<Output = Result<(), WispError>> + Send), WispError>
where where
R: ws::WebSocketRead, R: ws::WebSocketRead + Send,
W: ws::WebSocketWrite + Send + 'static,
{ {
let (close_tx, close_rx) = mpsc::channel::<WsEvent>(256); let (close_tx, close_rx) = mpsc::channel::<WsEvent>(256);
let (tx, rx) = mpsc::unbounded::<(ConnectPacket, MuxStream)>(); let (tx, rx) = mpsc::unbounded::<(ConnectPacket, MuxStream)>();
let write = ws::LockedWebSocketWrite::new(write); let write = ws::LockedWebSocketWrite::new(Box::new(write));
(
write
.write_frame(Packet::new_continue(0, buffer_size).into())
.await?;
let mut supported_extensions = Vec::new();
let mut extra_packet = Vec::with_capacity(1);
let mut downgraded = true;
if let Some(extensions) = extensions {
if let Some(builders) = extension_builders {
let extension_ids: Vec<_> = extensions.iter().map(|x| x.get_id()).collect();
write
.write_frame(Packet::new_info(extensions).into())
.await?;
if let Some(frame) = select! {
x = read.wisp_read_frame(&write).fuse() => Some(x?),
// TODO change this to correct timeout once draft 2 is out
_ = Delay::new(Duration::from_secs(5)).fuse() => None
} {
let packet = Packet::maybe_parse_info(frame, Role::Server, builders)?;
if let PacketType::Info(info) = packet.packet_type {
supported_extensions = info
.extensions
.into_iter()
.filter(|x| extension_ids.contains(&x.get_id()))
.collect();
downgraded = false;
} else {
extra_packet.push(packet.into());
}
}
}
}
Ok((
Self { Self {
muxstream_recv: rx, muxstream_recv: rx,
close_tx: close_tx.clone(), close_tx: close_tx.clone(),
downgraded,
supported_extensions: supported_extensions.into(),
}, },
MuxInner { MuxInner {
tx: write, tx: write,
stream_map: DashMap::new().into(), stream_map: DashMap::new(),
buffer_size,
} }
.server_into_future(read, close_rx, tx, buffer_size, close_tx), .server_into_future(
) AppendingWebSocketRead(extra_packet, read),
close_rx,
tx,
close_tx,
),
))
} }
/// Wait for a stream to be created. /// Wait for a stream to be created.
@ -398,7 +541,7 @@ impl ServerMux {
/// ``` /// ```
/// use wisp_mux::{ClientMux, StreamType}; /// use wisp_mux::{ClientMux, StreamType};
/// ///
/// let (mux, fut) = ClientMux::new(rx, tx).await?; /// let (mux, fut) = ClientMux::new(rx, tx, Some(vec![]), []).await?;
/// tokio::spawn(async move { /// tokio::spawn(async move {
/// if let Err(e) = fut.await { /// if let Err(e) = fut.await {
/// println!("error in multiplexor: {:?}", e); /// println!("error in multiplexor: {:?}", e);
@ -406,50 +549,88 @@ impl ServerMux {
/// }); /// });
/// let stream = mux.client_new_stream(StreamType::Tcp, "google.com", 80); /// let stream = mux.client_new_stream(StreamType::Tcp, "google.com", 80);
/// ``` /// ```
pub struct ClientMux<W> pub struct ClientMux {
where /// Whether the connection was downgraded to Wisp v1.
W: ws::WebSocketWrite, ///
{ /// If this variable is true you must assume no extensions are supported.
tx: ws::LockedWebSocketWrite<W>, pub downgraded: bool,
stream_map: Arc<DashMap<u32, MuxMapValue>>, /// Extensions that are supported by both sides.
next_free_stream_id: AtomicU32, pub supported_extensions: Arc<[AnyProtocolExtension]>,
close_tx: mpsc::Sender<WsEvent>, close_tx: mpsc::Sender<WsEvent>,
buf_size: u32,
target_buf_size: u32,
} }
impl<W: ws::WebSocketWrite> ClientMux<W> { impl ClientMux {
/// Create a new client side multiplexor. /// Create a new client side multiplexor.
pub async fn new<R>( ///
/// If either extensions or extension_builders are None a Wisp v1 connection is created
/// otherwise a Wisp v2 connection is created.
pub async fn new<R, W>(
mut read: R, mut read: R,
write: W, write: W,
) -> Result<(Self, impl Future<Output = Result<(), WispError>>), WispError> extensions: Option<Vec<AnyProtocolExtension>>,
extension_builders: Option<&[&(dyn ProtocolExtensionBuilder + Sync)]>,
) -> Result<(Self, impl Future<Output = Result<(), WispError>> + Send), WispError>
where where
R: ws::WebSocketRead, R: ws::WebSocketRead + Send,
W: ws::WebSocketWrite + Send + 'static,
{ {
let write = ws::LockedWebSocketWrite::new(write); let write = ws::LockedWebSocketWrite::new(Box::new(write));
let first_packet = Packet::try_from(read.wisp_read_frame(&write).await?)?; let first_packet = Packet::try_from(read.wisp_read_frame(&write).await?)?;
if first_packet.stream_id != 0 { if first_packet.stream_id != 0 {
return Err(WispError::InvalidStreamId); return Err(WispError::InvalidStreamId);
} }
if let PacketType::Continue(packet) = first_packet.packet_type { if let PacketType::Continue(packet) = first_packet.packet_type {
let mut supported_extensions = Vec::new();
let mut extra_packet = Vec::with_capacity(1);
let mut downgraded = true;
if let Some(extensions) = extensions {
if let Some(builders) = extension_builders {
let extension_ids: Vec<_> = extensions.iter().map(|x| x.get_id()).collect();
if let Some(frame) = select! {
x = read.wisp_read_frame(&write).fuse() => Some(x?),
// TODO change this to correct timeout once draft 2 is out
_ = Delay::new(Duration::from_secs(5)).fuse() => None
} {
let packet = Packet::maybe_parse_info(frame, Role::Server, builders)?;
if let PacketType::Info(info) = packet.packet_type {
supported_extensions = info
.extensions
.into_iter()
.filter(|x| extension_ids.contains(&x.get_id()))
.collect();
write
.write_frame(Packet::new_info(extensions).into())
.await?;
downgraded = false;
} else {
extra_packet.push(packet.into());
}
}
}
}
for extension in supported_extensions.iter_mut() {
extension.handle_handshake(&mut read, &write).await?;
}
let (tx, rx) = mpsc::channel::<WsEvent>(256); let (tx, rx) = mpsc::channel::<WsEvent>(256);
let map = Arc::new(DashMap::new());
Ok(( Ok((
Self { Self {
tx: write.clone(),
stream_map: map.clone(),
next_free_stream_id: AtomicU32::new(1),
close_tx: tx.clone(), close_tx: tx.clone(),
buf_size: packet.buffer_remaining, downgraded,
// server-only supported_extensions: supported_extensions.into(),
target_buf_size: 0,
}, },
MuxInner { MuxInner {
tx: write.clone(), tx: write,
stream_map: map.clone(), stream_map: DashMap::new(),
buffer_size: packet.buffer_remaining,
} }
.client_into_future(read, rx), .client_into_future(
AppendingWebSocketRead(extra_packet, read),
rx,
tx,
),
)) ))
} else { } else {
Err(WispError::InvalidPacketType) Err(WispError::InvalidPacketType)
@ -458,51 +639,25 @@ impl<W: ws::WebSocketWrite> ClientMux<W> {
/// Create a new stream, multiplexed through Wisp. /// Create a new stream, multiplexed through Wisp.
pub async fn client_new_stream( pub async fn client_new_stream(
&self, &mut self,
stream_type: StreamType, stream_type: StreamType,
host: String, host: String,
port: u16, port: u16,
) -> Result<MuxStream, WispError> { ) -> Result<MuxStream, WispError> {
let (ch_tx, ch_rx) = mpsc::unbounded(); if stream_type == StreamType::Udp
let stream_id = self.next_free_stream_id.load(Ordering::Acquire); && !self
let next_stream_id = stream_id .supported_extensions
.checked_add(1) .iter()
.ok_or(WispError::MaxStreamCountReached)?; .any(|x| x.get_id() == UdpProtocolExtension::ID)
{
let flow_control_event: Arc<Event> = Event::new().into(); return Err(WispError::UdpExtensionNotSupported);
let flow_control: Arc<AtomicU32> = AtomicU32::new(self.buf_size).into(); }
let (tx, rx) = oneshot::channel();
let is_closed: Arc<AtomicBool> = AtomicBool::new(false).into(); self.close_tx
.send(WsEvent::CreateStream(stream_type, host, port, tx))
self.tx .await
.write_frame(Packet::new_connect(stream_id, stream_type, port, host).into()) .map_err(|_| WispError::MuxMessageFailedToSend)?;
.await?; rx.await.map_err(|_| WispError::MuxMessageFailedToRecv)?
self.next_free_stream_id
.store(next_stream_id, Ordering::Release);
self.stream_map.insert(
stream_id,
MuxMapValue {
stream: ch_tx,
stream_type,
flow_control: flow_control.clone(),
flow_control_event: flow_control_event.clone(),
is_closed: is_closed.clone(),
},
);
Ok(MuxStream::new(
stream_id,
Role::Client,
stream_type,
ch_rx,
self.close_tx.clone(),
is_closed,
flow_control,
flow_control_event,
self.target_buf_size,
))
} }
/// Close all streams. /// Close all streams.

View file

@ -1,4 +1,8 @@
use crate::{ws, WispError}; use crate::{
extensions::{AnyProtocolExtension, ProtocolExtensionBuilder},
ws::{self, Frame, OpCode},
Role, WispError, WISP_VERSION,
};
use bytes::{Buf, BufMut, Bytes, BytesMut}; use bytes::{Buf, BufMut, Bytes, BytesMut};
/// Wisp stream type. /// Wisp stream type.
@ -34,6 +38,8 @@ pub enum CloseReason {
Voluntary = 0x02, Voluntary = 0x02,
/// Unexpected stream closure due to a network error. /// Unexpected stream closure due to a network error.
Unexpected = 0x03, Unexpected = 0x03,
/// Incompatible extensions. Only used during the handshake.
IncompatibleExtensions = 0x04,
/// Stream creation failed due to invalid information. /// Stream creation failed due to invalid information.
ServerStreamInvalidInfo = 0x41, ServerStreamInvalidInfo = 0x41,
/// Stream creation failed due to an unreachable destination host. /// Stream creation failed due to an unreachable destination host.
@ -55,19 +61,20 @@ pub enum CloseReason {
impl TryFrom<u8> for CloseReason { impl TryFrom<u8> for CloseReason {
type Error = WispError; type Error = WispError;
fn try_from(stream_type: u8) -> Result<Self, Self::Error> { fn try_from(stream_type: u8) -> Result<Self, Self::Error> {
use CloseReason::*; use CloseReason as R;
match stream_type { match stream_type {
0x01 => Ok(Unknown), 0x01 => Ok(R::Unknown),
0x02 => Ok(Voluntary), 0x02 => Ok(R::Voluntary),
0x03 => Ok(Unexpected), 0x03 => Ok(R::Unexpected),
0x41 => Ok(ServerStreamInvalidInfo), 0x04 => Ok(R::IncompatibleExtensions),
0x42 => Ok(ServerStreamUnreachable), 0x41 => Ok(R::ServerStreamInvalidInfo),
0x43 => Ok(ServerStreamConnectionTimedOut), 0x42 => Ok(R::ServerStreamUnreachable),
0x44 => Ok(ServerStreamConnectionRefused), 0x43 => Ok(R::ServerStreamConnectionTimedOut),
0x47 => Ok(ServerStreamTimedOut), 0x44 => Ok(R::ServerStreamConnectionRefused),
0x48 => Ok(ServerStreamBlockedAddress), 0x47 => Ok(R::ServerStreamTimedOut),
0x49 => Ok(ServerStreamThrottled), 0x48 => Ok(R::ServerStreamBlockedAddress),
0x81 => Ok(ClientUnexpected), 0x49 => Ok(R::ServerStreamThrottled),
0x81 => Ok(R::ClientUnexpected),
_ => Err(Self::Error::InvalidStreamType), _ => Err(Self::Error::InvalidStreamType),
} }
} }
@ -198,6 +205,38 @@ impl From<ClosePacket> for Bytes {
} }
} }
/// Wisp version sent in the handshake.
#[derive(Debug, Clone)]
pub struct WispVersion {
/// Major Wisp version according to semver.
pub major: u8,
/// Minor Wisp version according to semver.
pub minor: u8,
}
/// Packet used in the initial handshake.
///
/// See [the docs](https://github.com/MercuryWorkshop/wisp-protocol/blob/main/protocol.md#0x05---info)
#[derive(Debug, Clone)]
pub struct InfoPacket {
/// Wisp version sent in the packet.
pub version: WispVersion,
/// List of protocol extensions sent in the packet.
pub extensions: Vec<AnyProtocolExtension>,
}
impl From<InfoPacket> for Bytes {
fn from(value: InfoPacket) -> Self {
let mut bytes = BytesMut::with_capacity(2);
bytes.put_u8(value.version.major);
bytes.put_u8(value.version.minor);
for extension in value.extensions {
bytes.extend(Bytes::from(extension));
}
bytes.freeze()
}
}
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
/// Type of packet recieved. /// Type of packet recieved.
pub enum PacketType { pub enum PacketType {
@ -209,6 +248,8 @@ pub enum PacketType {
Continue(ContinuePacket), Continue(ContinuePacket),
/// Close packet. /// Close packet.
Close(ClosePacket), Close(ClosePacket),
/// Info packet.
Info(InfoPacket),
} }
impl PacketType { impl PacketType {
@ -220,6 +261,7 @@ impl PacketType {
Data(_) => 0x02, Data(_) => 0x02,
Continue(_) => 0x03, Continue(_) => 0x03,
Close(_) => 0x04, Close(_) => 0x04,
Info(_) => 0x05,
} }
} }
} }
@ -232,6 +274,7 @@ impl From<PacketType> for Bytes {
Data(x) => x, Data(x) => x,
Continue(x) => x.into(), Continue(x) => x.into(),
Close(x) => x.into(), Close(x) => x.into(),
Info(x) => x.into(),
} }
} }
} }
@ -296,6 +339,97 @@ impl Packet {
packet_type: PacketType::Close(ClosePacket::new(reason)), packet_type: PacketType::Close(ClosePacket::new(reason)),
} }
} }
pub(crate) fn new_info(extensions: Vec<AnyProtocolExtension>) -> Self {
Self {
stream_id: 0,
packet_type: PacketType::Info(InfoPacket {
version: WISP_VERSION,
extensions,
}),
}
}
fn parse_packet(packet_type: u8, mut bytes: Bytes) -> Result<Self, WispError> {
use PacketType::*;
Ok(Self {
stream_id: bytes.get_u32_le(),
packet_type: match packet_type {
0x01 => Connect(ConnectPacket::try_from(bytes)?),
0x02 => Data(bytes),
0x03 => Continue(ContinuePacket::try_from(bytes)?),
0x04 => Close(ClosePacket::try_from(bytes)?),
// 0x05 is handled seperately
_ => return Err(WispError::InvalidPacketType),
},
})
}
pub(crate) fn maybe_parse_info(
frame: Frame,
role: Role,
extension_builders: &[&(dyn ProtocolExtensionBuilder + Sync)],
) -> Result<Self, WispError> {
if !frame.finished {
return Err(WispError::WsFrameNotFinished);
}
if frame.opcode != OpCode::Binary {
return Err(WispError::WsFrameInvalidType);
}
let mut bytes = frame.payload;
if bytes.remaining() < 1 {
return Err(WispError::PacketTooSmall);
}
let packet_type = bytes.get_u8();
if packet_type == 0x05 {
Self::parse_info(bytes, role, extension_builders)
} else {
Self::parse_packet(packet_type, bytes)
}
}
fn parse_info(
mut bytes: Bytes,
role: Role,
extension_builders: &[&(dyn ProtocolExtensionBuilder + Sync)],
) -> Result<Self, WispError> {
// packet type is already read by code that calls this
if bytes.remaining() < 4 + 2 {
return Err(WispError::PacketTooSmall);
}
if bytes.get_u32_le() != 0 {
return Err(WispError::InvalidStreamId);
}
let version = WispVersion {
major: bytes.get_u8(),
minor: bytes.get_u8(),
};
let mut extensions = Vec::new();
while bytes.remaining() > 4 {
// We have some extensions
let id = bytes.get_u8();
let length = usize::try_from(bytes.get_u32_le())?;
if bytes.remaining() < length {
return Err(WispError::PacketTooSmall);
}
if let Some(builder) = extension_builders.iter().find(|x| x.get_id() == id) {
extensions.push(builder.build(bytes.copy_to_bytes(length), role))
} else {
bytes.advance(length)
}
}
Ok(Self {
stream_id: 0,
packet_type: PacketType::Info(InfoPacket {
version,
extensions,
}),
})
}
} }
impl TryFrom<Bytes> for Packet { impl TryFrom<Bytes> for Packet {
@ -305,17 +439,7 @@ impl TryFrom<Bytes> for Packet {
return Err(Self::Error::PacketTooSmall); return Err(Self::Error::PacketTooSmall);
} }
let packet_type = bytes.get_u8(); let packet_type = bytes.get_u8();
use PacketType::*; Self::parse_packet(packet_type, bytes)
Ok(Self {
stream_id: bytes.get_u32_le(),
packet_type: match packet_type {
0x01 => Connect(ConnectPacket::try_from(bytes)?),
0x02 => Data(bytes),
0x03 => Continue(ContinuePacket::try_from(bytes)?),
0x04 => Close(ClosePacket::try_from(bytes)?),
_ => return Err(Self::Error::InvalidPacketType),
},
})
} }
} }

View file

@ -1,8 +1,10 @@
//! futures sink unfold with a close function //! futures sink unfold with a close function
use core::{future::Future, pin::Pin}; use core::{future::Future, pin::Pin};
use futures::ready; use futures::{
use futures::task::{Context, Poll}; ready,
use futures::Sink; task::{Context, Poll},
Sink,
};
use pin_project_lite::pin_project; use pin_project_lite::pin_project;
pin_project! { pin_project! {

View file

@ -21,6 +21,12 @@ use std::{
pub(crate) enum WsEvent { pub(crate) enum WsEvent {
SendPacket(Packet, oneshot::Sender<Result<(), WispError>>), SendPacket(Packet, oneshot::Sender<Result<(), WispError>>),
Close(Packet, oneshot::Sender<Result<(), WispError>>), Close(Packet, oneshot::Sender<Result<(), WispError>>),
CreateStream(
StreamType,
String,
u16,
oneshot::Sender<Result<MuxStream, WispError>>,
),
EndFut, EndFut,
} }
@ -317,7 +323,10 @@ impl MuxStreamIo {
impl Stream for MuxStreamIo { impl Stream for MuxStreamIo {
type Item = Result<Vec<u8>, std::io::Error>; type Item = Result<Vec<u8>, std::io::Error>;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
self.project().rx.poll_next(cx).map(|x| x.map(|x| Ok(x.to_vec()))) self.project()
.rx
.poll_next(cx)
.map(|x| x.map(|x| Ok(x.to_vec())))
} }
} }

View file

@ -4,9 +4,10 @@
//! for other WebSocket implementations. //! for other WebSocket implementations.
//! //!
//! [`fastwebsockets`]: https://github.com/MercuryWorkshop/epoxy-tls/blob/multiplexed/wisp/src/fastwebsockets.rs //! [`fastwebsockets`]: https://github.com/MercuryWorkshop/epoxy-tls/blob/multiplexed/wisp/src/fastwebsockets.rs
use crate::WispError;
use async_trait::async_trait;
use bytes::Bytes; use bytes::Bytes;
use futures::lock::Mutex; use futures::lock::Mutex;
use std::sync::Arc;
/// Opcode of the WebSocket frame. /// Opcode of the WebSocket frame.
#[derive(Debug, PartialEq, Clone, Copy)] #[derive(Debug, PartialEq, Clone, Copy)]
@ -64,30 +65,26 @@ impl Frame {
} }
/// Generic WebSocket read trait. /// Generic WebSocket read trait.
#[async_trait]
pub trait WebSocketRead { pub trait WebSocketRead {
/// Read a frame from the socket. /// Read a frame from the socket.
fn wisp_read_frame( async fn wisp_read_frame(&mut self, tx: &LockedWebSocketWrite) -> Result<Frame, WispError>;
&mut self,
tx: &crate::ws::LockedWebSocketWrite<impl crate::ws::WebSocketWrite>,
) -> impl std::future::Future<Output = Result<Frame, crate::WispError>>;
} }
/// Generic WebSocket write trait. /// Generic WebSocket write trait.
#[async_trait]
pub trait WebSocketWrite { pub trait WebSocketWrite {
/// Write a frame to the socket. /// Write a frame to the socket.
fn wisp_write_frame( async fn wisp_write_frame(&mut self, frame: Frame) -> Result<(), WispError>;
&mut self,
frame: Frame,
) -> impl std::future::Future<Output = Result<(), crate::WispError>>;
} }
/// Locked WebSocket that can be shared between threads. /// Locked WebSocket.
pub struct LockedWebSocketWrite<S>(Arc<Mutex<S>>); pub struct LockedWebSocketWrite(Mutex<Box<dyn WebSocketWrite + Send>>);
impl<S: WebSocketWrite> LockedWebSocketWrite<S> { impl LockedWebSocketWrite {
/// Create a new locked websocket. /// Create a new locked websocket.
pub fn new(ws: S) -> Self { pub fn new(ws: Box<dyn WebSocketWrite + Send>) -> Self {
Self(Arc::new(Mutex::new(ws))) Self(Mutex::new(ws))
} }
/// Write a frame to the websocket. /// Write a frame to the websocket.
@ -96,8 +93,19 @@ impl<S: WebSocketWrite> LockedWebSocketWrite<S> {
} }
} }
impl<S: WebSocketWrite> Clone for LockedWebSocketWrite<S> { pub(crate) struct AppendingWebSocketRead<R>(pub Vec<Frame>, pub R)
fn clone(&self) -> Self { where
Self(self.0.clone()) R: WebSocketRead + Send;
#[async_trait]
impl<R> WebSocketRead for AppendingWebSocketRead<R>
where
R: WebSocketRead + Send,
{
async fn wisp_read_frame(&mut self, tx: &LockedWebSocketWrite) -> Result<Frame, WispError> {
if let Some(x) = self.0.pop() {
return Ok(x);
}
return self.1.wisp_read_frame(tx).await;
} }
} }