From ef5ed52e7121dc84ce2abd3c494d7ddf4fdbd5b1 Mon Sep 17 00:00:00 2001 From: Toshit Chawda Date: Thu, 11 Apr 2024 19:05:14 -0700 Subject: [PATCH] preliminary support for wisp v2 --- Cargo.lock | 39 +++- client/Cargo.toml | 3 +- client/src/lib.rs | 4 +- client/src/udp_stream.rs | 2 +- client/src/utils.rs | 30 +-- client/src/websocket.rs | 10 +- client/src/wrappers.rs | 11 +- rustfmt.toml | 1 + server/src/main.rs | 21 +- simple-wisp-client/src/main.rs | 31 ++- wisp/Cargo.toml | 5 +- wisp/src/extensions.rs | 190 ++++++++++++++++ wisp/src/fastwebsockets.rs | 15 +- wisp/src/lib.rs | 383 +++++++++++++++++++++++---------- wisp/src/packet.rs | 172 ++++++++++++--- wisp/src/sink_unfold.rs | 8 +- wisp/src/stream.rs | 11 +- wisp/src/ws.rs | 42 ++-- 18 files changed, 772 insertions(+), 206 deletions(-) create mode 100644 rustfmt.toml create mode 100644 wisp/src/extensions.rs diff --git a/Cargo.lock b/Cargo.lock index 4876e52..fb57868 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -133,9 +133,9 @@ dependencies = [ [[package]] name = "async-trait" -version = "0.1.79" +version = "0.1.80" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a507401cad91ec6a857ed5513a2073c82a9b9048762b885bb98655b306964681" +checksum = "c6fa2087f2753a7da8cc1c0dbfcf89579dd57458e36769de5ac750b4671737ca" dependencies = [ "proc-macro2 1.0.79", "quote 1.0.36", @@ -525,6 +525,7 @@ name = "epoxy-client" version = "1.5.1" dependencies = [ "async-compression", + "async-trait", "async_io_stream", "base64", "bytes", @@ -542,7 +543,7 @@ dependencies = [ "pin-project-lite", "ring", "rustls-pki-types", - "send_wrapper", + "send_wrapper 0.6.0", "tokio", "tokio-rustls", "tokio-util", @@ -744,6 +745,16 @@ version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "38d84fa142264698cdce1a9f9172cf383a0c82de1bddcf3092901442c4097004" +[[package]] +name = "futures-timer" +version = "3.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f288b0a4f20f9a56b5d1da57e2227c661b7b16168e2f72365f57b63326e29b24" +dependencies = [ + "gloo-timers", + "send_wrapper 0.4.0", +] + [[package]] name = "futures-util" version = "0.3.30" @@ -791,6 +802,18 @@ version = "0.28.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4271d37baee1b8c7e4b708028c57d816cf9d2434acb33a549475f78c181f6253" +[[package]] +name = "gloo-timers" +version = "0.2.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9b995a66bb87bebce9a0f4a95aed01daca4872c050bfcb21653361c03bc35e5c" +dependencies = [ + "futures-channel", + "futures-core", + "js-sys", + "wasm-bindgen", +] + [[package]] name = "h2" version = "0.3.26" @@ -1659,6 +1682,12 @@ version = "1.0.22" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "92d43fe69e652f3df9bdc2b85b2854a0825b86e4fb76bc44d945137d053639ca" +[[package]] +name = "send_wrapper" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f638d531eccd6e23b980caf34876660d38e265409d8e99b397ab71eb3612fad0" + [[package]] name = "send_wrapper" version = "0.6.0" @@ -2531,14 +2560,16 @@ checksum = "32b752e52a2da0ddfbdbcc6fceadfeede4c939ed16d13e648833a61dfb611ed8" [[package]] name = "wisp-mux" -version = "3.0.0" +version = "4.0.0" dependencies = [ + "async-trait", "async_io_stream", "bytes", "dashmap", "event-listener", "fastwebsockets 0.7.1", "futures", + "futures-timer", "futures-util", "pin-project-lite", "tokio", diff --git a/client/Cargo.toml b/client/Cargo.toml index 21a2af1..ee1d108 100644 --- a/client/Cargo.toml +++ b/client/Cargo.toml @@ -25,7 +25,7 @@ tokio-util = { version = "0.7.10", features = ["io"] } async-compression = { version = "0.4.5", features = ["tokio", "gzip", "brotli"] } fastwebsockets = { version = "0.6.0", features = ["unstable-split"] } base64 = "0.21.7" -wisp-mux = { path = "../wisp", features = ["tokio_io"] } +wisp-mux = { path = "../wisp", features = ["tokio_io", "wasm"] } async_io_stream = { version = "0.3.3", features = ["tokio_io"] } getrandom = { version = "0.2.12", features = ["js"] } hyper-util-wasm = { version = "0.1.3", features = ["client", "client-legacy", "http1", "http2"] } @@ -35,6 +35,7 @@ console_error_panic_hook = "0.1.7" send_wrapper = "0.6.0" event-listener = "5.2.0" wasmtimer = "0.2.0" +async-trait = "0.1.80" [dependencies.ring] features = ["wasm32_unknown_unknown_js"] diff --git a/client/src/lib.rs b/client/src/lib.rs index 6a6217c..0f8f678 100644 --- a/client/src/lib.rs +++ b/client/src/lib.rs @@ -105,7 +105,7 @@ pub fn certs() -> Result { #[wasm_bindgen(inspectable)] pub struct EpoxyClient { rustls_config: Arc, - mux: Arc>>, + mux: Arc>, hyper_client: Client, #[wasm_bindgen(getter_with_clone)] pub useragent: String, @@ -164,7 +164,7 @@ impl EpoxyClient { async fn get_tls_io(&self, url_host: &str, url_port: u16) -> Result { let channel = self .mux - .read() + .write() .await .client_new_stream(StreamType::Tcp, url_host.to_string(), url_port) .await diff --git a/client/src/udp_stream.rs b/client/src/udp_stream.rs index c026ca7..877bab4 100644 --- a/client/src/udp_stream.rs +++ b/client/src/udp_stream.rs @@ -33,7 +33,7 @@ impl EpxUdpStream { let io = tcp .mux - .read() + .write() .await .client_new_stream(StreamType::Udp, url_host.to_string(), url_port) .await diff --git a/client/src/utils.rs b/client/src/utils.rs index 1fdcf2e..3b05027 100644 --- a/client/src/utils.rs +++ b/client/src/utils.rs @@ -6,7 +6,10 @@ use wasm_bindgen_futures::JsFuture; use hyper::rt::Executor; use js_sys::ArrayBuffer; use std::future::Future; -use wisp_mux::WispError; +use wisp_mux::{ + extensions::udp::{UdpProtocolExtension, UdpProtocolExtensionBuilder}, + WispError, +}; #[wasm_bindgen] extern "C" { @@ -192,25 +195,25 @@ pub fn get_url_port(url: &Uri) -> Result { pub async fn make_mux( url: &str, -) -> Result< - ( - ClientMux, - impl Future>, - ), - WispError, -> { +) -> Result<(ClientMux, impl Future> + Send), WispError> { let (wtx, wrx) = WebSocketWrapper::connect(url, vec![]) .await .map_err(|_| WispError::WsImplSocketClosed)?; wtx.wait_for_open().await; - let mux = ClientMux::new(wrx, wtx).await?; + let mux = ClientMux::new( + wrx, + wtx, + Some(vec![UdpProtocolExtension().into()]), + Some(&[&UdpProtocolExtensionBuilder()]), + ) + .await?; Ok(mux) } pub fn spawn_mux_fut( - mux: Arc>>, - fut: impl Future> + 'static, + mux: Arc>, + fut: impl Future> + Send + 'static, url: String, ) { wasm_bindgen_futures::spawn_local(async move { @@ -225,10 +228,7 @@ pub fn spawn_mux_fut( }); } -pub async fn replace_mux( - mux: Arc>>, - url: &str, -) -> Result<(), WispError> { +pub async fn replace_mux(mux: Arc>, url: &str) -> Result<(), WispError> { let (mux_replace, fut) = make_mux(url).await?; let mut mux_write = mux.write().await; mux_write.close().await?; diff --git a/client/src/websocket.rs b/client/src/websocket.rs index fff1f44..414e53e 100644 --- a/client/src/websocket.rs +++ b/client/src/websocket.rs @@ -106,7 +106,7 @@ impl EpxWebSocket { break; } // ping/pong/continue - _ => {}, + _ => {} } } }); @@ -115,7 +115,13 @@ impl EpxWebSocket { .call0(&Object::default()) .replace_err("Failed to call onopen")?; - Ok(Self { tx, onerror, origin, protocols, url: url.to_string() }) + Ok(Self { + tx, + onerror, + origin, + protocols, + url: url.to_string(), + }) } .await; if let Err(ret) = ret { diff --git a/client/src/wrappers.rs b/client/src/wrappers.rs index 9b16525..e67779e 100644 --- a/client/src/wrappers.rs +++ b/client/src/wrappers.rs @@ -53,7 +53,7 @@ impl Stream for IncomingBody { } #[derive(Clone)] -pub struct ServiceWrapper(pub Arc>>, pub String); +pub struct ServiceWrapper(pub Arc>, pub String); impl tower_service::Service for ServiceWrapper { type Response = TokioIo; @@ -69,7 +69,7 @@ impl tower_service::Service for ServiceWrapper { let mux_url = self.1.clone(); async move { let stream = mux - .read() + .write() .await .client_new_stream( StreamType::Tcp, @@ -193,11 +193,9 @@ pub struct WebSocketReader { close_event: Arc, } +#[async_trait::async_trait] impl WebSocketRead for WebSocketReader { - async fn wisp_read_frame( - &mut self, - _: &LockedWebSocketWrite, - ) -> Result { + async fn wisp_read_frame(&mut self, _: &LockedWebSocketWrite) -> Result { use WebSocketMessage::*; if self.closed.load(Ordering::Acquire) { return Err(WispError::WsImplSocketClosed); @@ -306,6 +304,7 @@ impl WebSocketWrapper { } } +#[async_trait::async_trait] impl WebSocketWrite for WebSocketWrapper { async fn wisp_write_frame(&mut self, frame: Frame) -> Result<(), WispError> { use wisp_mux::ws::OpCode::*; diff --git a/rustfmt.toml b/rustfmt.toml new file mode 100644 index 0000000..c3c8c37 --- /dev/null +++ b/rustfmt.toml @@ -0,0 +1 @@ +imports_granularity = "Crate" diff --git a/server/src/main.rs b/server/src/main.rs index 0b1f6f3..61561b7 100644 --- a/server/src/main.rs +++ b/server/src/main.rs @@ -4,13 +4,12 @@ use std::io::Error; use bytes::Bytes; use clap::Parser; use fastwebsockets::{ - upgrade::{self, UpgradeFut}, CloseCode, FragmentCollector, FragmentCollectorRead, Frame, OpCode, Payload, - WebSocketError, + upgrade::{self, UpgradeFut}, + CloseCode, FragmentCollector, FragmentCollectorRead, Frame, OpCode, Payload, WebSocketError, }; use futures_util::{SinkExt, StreamExt, TryFutureExt}; use hyper::{ - body::Incoming, server::conn::http1, service::service_fn, Request, Response, - StatusCode, + body::Incoming, server::conn::http1, service::service_fn, Request, Response, StatusCode, }; use hyper_util::rt::TokioIo; use tokio::net::{lookup_host, TcpListener, TcpStream, UdpSocket}; @@ -20,7 +19,10 @@ use tokio_util::codec::{BytesCodec, Framed}; #[cfg(unix)] use tokio_util::either::Either; -use wisp_mux::{CloseReason, ConnectPacket, MuxStream, ServerMux, StreamType, WispError}; +use wisp_mux::{ + extensions::udp::{UdpProtocolExtension, UdpProtocolExtensionBuilder}, + CloseReason, ConnectPacket, MuxStream, ServerMux, StreamType, WispError, +}; type HttpBody = http_body_util::Full; @@ -261,7 +263,14 @@ async fn accept_ws( println!("{:?}: connected", addr); - let (mut mux, fut) = ServerMux::new(rx, tx, u32::MAX); + let (mut mux, fut) = ServerMux::new( + rx, + tx, + u32::MAX, + Some(vec![UdpProtocolExtension().into()]), + Some(&[&UdpProtocolExtensionBuilder()]), + ) + .await?; tokio::spawn(async move { if let Err(e) = fut.await { diff --git a/simple-wisp-client/src/main.rs b/simple-wisp-client/src/main.rs index e626b80..4dd329a 100644 --- a/simple-wisp-client/src/main.rs +++ b/simple-wisp-client/src/main.rs @@ -11,7 +11,14 @@ use hyper::{ }; use simple_moving_average::{SingleSumSMA, SMA}; use std::{ - error::Error, future::Future, io::{stdout, IsTerminal, Write}, net::SocketAddr, process::exit, sync::Arc, time::{Duration, Instant}, usize + error::Error, + future::Future, + io::{stdout, IsTerminal, Write}, + net::SocketAddr, + process::exit, + sync::Arc, + time::{Duration, Instant}, + usize, }; use tokio::{ net::TcpStream, @@ -21,7 +28,10 @@ use tokio::{ }; use tokio_native_tls::{native_tls, TlsConnector}; use tokio_util::either::Either; -use wisp_mux::{ClientMux, StreamType, WispError}; +use wisp_mux::{ + extensions::udp::{UdpProtocolExtension, UdpProtocolExtensionBuilder}, + ClientMux, StreamType, WispError, +}; #[derive(Debug)] enum WispClientError { @@ -71,6 +81,9 @@ struct Cli { /// Duration to run the test for #[arg(short, long)] duration: Option, + /// Ask for UDP + #[arg(short, long)] + udp: bool, } #[tokio::main(flavor = "multi_thread")] @@ -117,7 +130,6 @@ async fn main() -> Result<(), Box> { fastwebsockets::handshake::generate_key(), ) .header("Sec-WebSocket-Version", "13") - .header("Sec-WebSocket-Protocol", "wisp-v1") .body(Empty::::new())?; let (ws, _) = handshake::client(&SpawnExecutor, req, socket).await?; @@ -125,7 +137,18 @@ async fn main() -> Result<(), Box> { let (rx, tx) = ws.split(tokio::io::split); let rx = FragmentCollectorRead::new(rx); - let (mux, fut) = ClientMux::new(rx, tx).await?; + let (mut mux, fut) = if opts.udp { + ClientMux::new( + rx, + tx, + Some(vec![UdpProtocolExtension().into()]), + Some(&[&UdpProtocolExtensionBuilder()]), + ) + .await? + } else { + ClientMux::new(rx, tx, Some(vec![]), Some(&[])).await? + }; + let mut threads = Vec::with_capacity(opts.streams * 2 + 3); threads.push(tokio::spawn(fut)); diff --git a/wisp/Cargo.toml b/wisp/Cargo.toml index 473640d..8cf2cba 100644 --- a/wisp/Cargo.toml +++ b/wisp/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "wisp-mux" -version = "3.0.0" +version = "4.0.0" license = "LGPL-3.0-only" description = "A library for easily creating Wisp servers and clients." homepage = "https://github.com/MercuryWorkshop/epoxy-tls/tree/multiplexed/wisp" @@ -9,12 +9,14 @@ readme = "README.md" edition = "2021" [dependencies] +async-trait = "0.1.79" async_io_stream = "0.3.3" bytes = "1.5.0" dashmap = { version = "5.5.3", features = ["inline"] } event-listener = "5.0.0" fastwebsockets = { version = "0.7.1", features = ["unstable-split"], optional = true } futures = "0.3.30" +futures-timer = "3.0.3" futures-util = "0.3.30" pin-project-lite = "0.2.13" tokio = { version = "1.35.1", optional = true, default-features = false } @@ -22,6 +24,7 @@ tokio = { version = "1.35.1", optional = true, default-features = false } [features] fastwebsockets = ["dep:fastwebsockets", "dep:tokio"] tokio_io = ["async_io_stream/tokio_io"] +wasm = ["futures-timer/wasm-bindgen"] [package.metadata.docs.rs] all-features = true diff --git a/wisp/src/extensions.rs b/wisp/src/extensions.rs new file mode 100644 index 0000000..9358c4a --- /dev/null +++ b/wisp/src/extensions.rs @@ -0,0 +1,190 @@ +//! Wisp protocol extensions. + +use std::ops::{Deref, DerefMut}; + +use async_trait::async_trait; +use bytes::{BufMut, Bytes, BytesMut}; + +use crate::{ + ws::{LockedWebSocketWrite, WebSocketRead}, + Role, WispError, +}; + +/// Type-erased protocol extension that implements Clone. +#[derive(Debug)] +pub struct AnyProtocolExtension(Box); + +impl AnyProtocolExtension { + /// Create a new type-erased protocol extension. + pub fn new(extension: T) -> Self { + Self(Box::new(extension)) + } +} + +impl Deref for AnyProtocolExtension { + type Target = dyn ProtocolExtension; + fn deref(&self) -> &Self::Target { + self.0.deref() + } +} + +impl DerefMut for AnyProtocolExtension { + fn deref_mut(&mut self) -> &mut Self::Target { + self.0.deref_mut() + } +} + +impl Clone for AnyProtocolExtension { + fn clone(&self) -> Self { + Self(self.0.box_clone()) + } +} + +impl From for Bytes { + fn from(value: AnyProtocolExtension) -> Self { + let mut bytes = BytesMut::with_capacity(5); + let payload = value.encode(); + bytes.put_u8(value.get_id()); + bytes.put_u32_le(payload.len() as u32); + bytes.extend(payload); + bytes.freeze() + } +} + +/// A Wisp protocol extension. +/// +/// See [the +/// docs](https://github.com/MercuryWorkshop/wisp-protocol/blob/main/protocol.md#protocol-extensions). +#[async_trait] +pub trait ProtocolExtension: std::fmt::Debug { + /// Get the protocol extension ID. + fn get_id(&self) -> u8; + /// Get the protocol extension's supported packets. + /// + /// Used to decide whether to call the protocol extension's packet handler. + fn get_supported_packets(&self) -> &'static [u8]; + + /// Encode self into Bytes. + fn encode(&self) -> Bytes; + + /// Handle the handshake part of a Wisp connection. + /// + /// This should be used to send or receive data before any streams are created. + async fn handle_handshake( + &mut self, + read: &mut dyn WebSocketRead, + write: &LockedWebSocketWrite, + ) -> Result<(), WispError>; + + /// Handle receiving a packet. + async fn handle_packet( + &mut self, + packet: Bytes, + read: &mut dyn WebSocketRead, + write: &LockedWebSocketWrite, + ) -> Result<(), WispError>; + + /// Clone the protocol extension. + fn box_clone(&self) -> Box; +} + +/// Trait to build a Wisp protocol extension for the client. +pub trait ProtocolExtensionBuilder { + /// Get the protocol extension ID. + /// + /// Used to decide whether this builder should be used. + fn get_id(&self) -> u8; + + /// Build a protocol extension from the extension's metadata. + fn build(&self, bytes: Bytes, role: Role) -> AnyProtocolExtension; +} + +pub mod udp { + //! UDP protocol extension. + //! + //! # Example + //! ``` + //! let (mux, fut) = ServerMux::new( + //! rx, + //! tx, + //! 128, + //! Some(vec![UdpProtocolExtension().into()]), + //! Some(&[&UdpProtocolExtensionBuilder()]) + //! ); + //! ``` + //! See [the docs](https://github.com/MercuryWorkshop/wisp-protocol/blob/main/protocol.md#0x01---udp) + use async_trait::async_trait; + use bytes::Bytes; + + use crate::{ + ws::{LockedWebSocketWrite, WebSocketRead}, + WispError, + }; + + use super::{AnyProtocolExtension, ProtocolExtension, ProtocolExtensionBuilder}; + + #[derive(Debug)] + /// UDP protocol extension. + pub struct UdpProtocolExtension(); + + impl UdpProtocolExtension { + /// UDP protocol extension ID. + pub const ID: u8 = 0x01; + } + + #[async_trait] + impl ProtocolExtension for UdpProtocolExtension { + fn get_id(&self) -> u8 { + Self::ID + } + + fn get_supported_packets(&self) -> &'static [u8] { + &[] + } + + fn encode(&self) -> Bytes { + Bytes::new() + } + + async fn handle_handshake( + &mut self, + _: &mut dyn WebSocketRead, + _: &LockedWebSocketWrite, + ) -> Result<(), WispError> { + Ok(()) + } + + /// Handle receiving a packet. + async fn handle_packet( + &mut self, + _: Bytes, + _: &mut dyn WebSocketRead, + _: &LockedWebSocketWrite, + ) -> Result<(), WispError> { + Ok(()) + } + + fn box_clone(&self) -> Box { + Box::new(Self()) + } + } + + impl From for AnyProtocolExtension { + fn from(value: UdpProtocolExtension) -> Self { + AnyProtocolExtension(Box::new(value)) + } + } + + /// UDP protocol extension builder. + pub struct UdpProtocolExtensionBuilder(); + + impl ProtocolExtensionBuilder for UdpProtocolExtensionBuilder { + fn get_id(&self) -> u8 { + 0x01 + } + + fn build(&self, _: Bytes, _: crate::Role) -> AnyProtocolExtension { + AnyProtocolExtension(Box::new(UdpProtocolExtension())) + } + } +} diff --git a/wisp/src/fastwebsockets.rs b/wisp/src/fastwebsockets.rs index 7a66908..548649f 100644 --- a/wisp/src/fastwebsockets.rs +++ b/wisp/src/fastwebsockets.rs @@ -1,9 +1,12 @@ +use async_trait::async_trait; use bytes::Bytes; use fastwebsockets::{ FragmentCollectorRead, Frame, OpCode, Payload, WebSocketError, WebSocketWrite, }; use tokio::io::{AsyncRead, AsyncWrite}; +use crate::{ws::LockedWebSocketWrite, WispError}; + impl From for crate::ws::OpCode { fn from(opcode: OpCode) -> Self { use OpCode::*; @@ -58,11 +61,12 @@ impl From for crate::WispError { } } -impl crate::ws::WebSocketRead for FragmentCollectorRead { +#[async_trait] +impl crate::ws::WebSocketRead for FragmentCollectorRead { async fn wisp_read_frame( &mut self, - tx: &crate::ws::LockedWebSocketWrite, - ) -> Result { + tx: &LockedWebSocketWrite, + ) -> Result { Ok(self .read_frame(&mut |frame| async { tx.write_frame(frame.into()).await }) .await? @@ -70,8 +74,9 @@ impl crate::ws::WebSocketRead for FragmentCollectorRead } } -impl crate::ws::WebSocketWrite for WebSocketWrite { - async fn wisp_write_frame(&mut self, frame: crate::ws::Frame) -> Result<(), crate::WispError> { +#[async_trait] +impl crate::ws::WebSocketWrite for WebSocketWrite { + async fn wisp_write_frame(&mut self, frame: crate::ws::Frame) -> Result<(), WispError> { self.write_frame(frame.into()).await.map_err(|e| e.into()) } } diff --git a/wisp/src/lib.rs b/wisp/src/lib.rs index 152be13..076e10c 100644 --- a/wisp/src/lib.rs +++ b/wisp/src/lib.rs @@ -4,6 +4,7 @@ //! //! [Wisp]: https://github.com/MercuryWorkshop/wisp-protocol +pub mod extensions; #[cfg(feature = "fastwebsockets")] #[cfg_attr(docsrs, doc(cfg(feature = "fastwebsockets")))] mod fastwebsockets; @@ -12,18 +13,28 @@ mod sink_unfold; mod stream; pub mod ws; -pub use crate::packet::*; -pub use crate::stream::*; +pub use crate::{packet::*, stream::*}; use bytes::Bytes; use dashmap::DashMap; use event_listener::Event; -use futures::SinkExt; -use futures::{channel::mpsc, Future, FutureExt, StreamExt}; -use std::sync::{ - atomic::{AtomicBool, AtomicU32, Ordering}, - Arc, +use extensions::{udp::UdpProtocolExtension, AnyProtocolExtension, ProtocolExtensionBuilder}; +use futures::{ + channel::{mpsc, oneshot}, + select, Future, FutureExt, SinkExt, StreamExt, }; +use futures_timer::Delay; +use std::{ + sync::{ + atomic::{AtomicBool, AtomicU32, Ordering}, + Arc, + }, + time::Duration, +}; +use ws::AppendingWebSocketRead; + +/// Wisp version supported by this crate. +pub const WISP_VERSION: WispVersion = WispVersion { major: 2, minor: 0 }; /// The role of the multiplexor. #[derive(Debug, PartialEq, Copy, Clone)] @@ -37,9 +48,9 @@ pub enum Role { /// Errors the Wisp implementation can return. #[derive(Debug)] pub enum WispError { - /// The packet recieved did not have enough data. + /// The packet received did not have enough data. PacketTooSmall, - /// The packet recieved had an invalid type. + /// The packet received had an invalid type. InvalidPacketType, /// The stream had an invalid type. InvalidStreamType, @@ -47,19 +58,19 @@ pub enum WispError { InvalidStreamId, /// The close packet had an invalid reason. InvalidCloseReason, - /// The URI recieved was invalid. + /// The URI received was invalid. InvalidUri, - /// The URI recieved had no host. + /// The URI received had no host. UriHasNoHost, - /// The URI recieved had no port. + /// The URI received had no port. UriHasNoPort, /// The max stream count was reached. MaxStreamCountReached, /// The stream had already been closed. StreamAlreadyClosed, - /// The websocket frame recieved had an invalid type. + /// The websocket frame received had an invalid type. WsFrameInvalidType, - /// The websocket frame recieved was not finished. + /// The websocket frame received was not finished. WsFrameNotFinished, /// Error specific to the websocket implementation. WsImplError(Box), @@ -67,17 +78,33 @@ pub enum WispError { WsImplSocketClosed, /// The websocket implementation did not support the action. WsImplNotSupported, + /// Error specific to the protocol extension implementation. + ExtensionImplError(Box), + /// The protocol extension implementation did not support the action. + ExtensionImplNotSupported, + /// The UDP protocol extension is not supported by the server. + UdpExtensionNotSupported, /// The string was invalid UTF-8. Utf8Error(std::str::Utf8Error), + /// The integer failed to convert. + TryFromIntError(std::num::TryFromIntError), /// Other error. Other(Box), /// Failed to send message to multiplexor task. MuxMessageFailedToSend, + /// Failed to receive message from multiplexor task. + MuxMessageFailedToRecv, } impl From for WispError { - fn from(err: std::str::Utf8Error) -> WispError { - WispError::Utf8Error(err) + fn from(err: std::str::Utf8Error) -> Self { + Self::Utf8Error(err) + } +} + +impl From for WispError { + fn from(value: std::num::TryFromIntError) -> Self { + Self::TryFromIntError(value) } } @@ -103,9 +130,21 @@ impl std::fmt::Display for WispError { Self::WsImplNotSupported => { write!(f, "Websocket implementation error: unsupported feature") } + Self::ExtensionImplError(err) => { + write!(f, "Protocol extension implementation error: {}", err) + } + Self::ExtensionImplNotSupported => { + write!( + f, + "Protocol extension implementation error: unsupported feature" + ) + } + Self::UdpExtensionNotSupported => write!(f, "UDP protocol extension not supported"), Self::Utf8Error(err) => write!(f, "UTF-8 error: {}", err), + Self::TryFromIntError(err) => write!(f, "Integer conversion error: {}", err), Self::Other(err) => write!(f, "Other error: {}", err), Self::MuxMessageFailedToSend => write!(f, "Failed to send multiplexor message"), + Self::MuxMessageFailedToRecv => write!(f, "Failed to receive multiplexor message"), } } } @@ -120,29 +159,27 @@ struct MuxMapValue { is_closed: Arc, } -struct MuxInner -where - W: ws::WebSocketWrite, -{ - tx: ws::LockedWebSocketWrite, - stream_map: Arc>, +struct MuxInner { + tx: ws::LockedWebSocketWrite, + stream_map: DashMap, + buffer_size: u32, } -impl MuxInner { +impl MuxInner { pub async fn server_into_future( self, rx: R, close_rx: mpsc::Receiver, muxstream_sender: mpsc::UnboundedSender<(ConnectPacket, MuxStream)>, - buffer_size: u32, close_tx: mpsc::Sender, ) -> Result<(), WispError> where R: ws::WebSocketRead, { - self.into_future( + self.as_future( close_rx, - self.server_loop(rx, muxstream_sender, buffer_size, close_tx), + close_tx.clone(), + self.server_loop(rx, muxstream_sender, close_tx), ) .await } @@ -151,20 +188,23 @@ impl MuxInner { self, rx: R, close_rx: mpsc::Receiver, + close_tx: mpsc::Sender, ) -> Result<(), WispError> where R: ws::WebSocketRead, { - self.into_future(close_rx, self.client_loop(rx)).await + self.as_future(close_rx, close_tx, self.client_loop(rx)) + .await } - async fn into_future( + async fn as_future( &self, close_rx: mpsc::Receiver, + close_tx: mpsc::Sender, wisp_fut: impl Future>, ) -> Result<(), WispError> { let ret = futures::select! { - _ = self.stream_loop(close_rx).fuse() => Ok(()), + _ = self.stream_loop(close_rx, close_tx).fuse() => Ok(()), x = wisp_fut.fuse() => x, }; self.stream_map.iter_mut().for_each(|mut x| { @@ -176,7 +216,12 @@ impl MuxInner { ret } - async fn stream_loop(&self, mut stream_rx: mpsc::Receiver) { + async fn stream_loop( + &self, + mut stream_rx: mpsc::Receiver, + stream_tx: mpsc::Sender, + ) { + let mut next_free_stream_id: u32 = 1; while let Some(msg) = stream_rx.next().await { match msg { WsEvent::SendPacket(packet, channel) => { @@ -186,6 +231,53 @@ impl MuxInner { let _ = channel.send(Err(WispError::InvalidStreamId)); } } + WsEvent::CreateStream(stream_type, host, port, channel) => { + let ret: Result = async { + let (ch_tx, ch_rx) = mpsc::unbounded(); + let stream_id = next_free_stream_id; + let next_stream_id = next_free_stream_id + .checked_add(1) + .ok_or(WispError::MaxStreamCountReached)?; + + let flow_control_event: Arc = Event::new().into(); + let flow_control: Arc = AtomicU32::new(self.buffer_size).into(); + + let is_closed: Arc = AtomicBool::new(false).into(); + + self.tx + .write_frame( + Packet::new_connect(stream_id, stream_type, port, host).into(), + ) + .await?; + + next_free_stream_id = next_stream_id; + + self.stream_map.insert( + stream_id, + MuxMapValue { + stream: ch_tx, + stream_type, + flow_control: flow_control.clone(), + flow_control_event: flow_control_event.clone(), + is_closed: is_closed.clone(), + }, + ); + + Ok(MuxStream::new( + stream_id, + Role::Client, + stream_type, + ch_rx, + stream_tx.clone(), + is_closed, + flow_control, + flow_control_event, + 0, + )) + } + .await; + let _ = channel.send(ret); + } WsEvent::Close(packet, channel) => { if let Some((_, mut stream)) = self.stream_map.remove(&packet.stream_id) { stream.stream.disconnect(); @@ -204,17 +296,13 @@ impl MuxInner { &self, mut rx: R, muxstream_sender: mpsc::UnboundedSender<(ConnectPacket, MuxStream)>, - buffer_size: u32, close_tx: mpsc::Sender, ) -> Result<(), WispError> where R: ws::WebSocketRead, { // will send continues once flow_control is at 10% of max - let target_buffer_size = ((buffer_size as u64 * 90) / 100) as u32; - self.tx - .write_frame(Packet::new_continue(0, buffer_size).into()) - .await?; + let target_buffer_size = ((self.buffer_size as u64 * 90) / 100) as u32; loop { let frame = rx.wisp_read_frame(&self.tx).await?; @@ -228,7 +316,7 @@ impl MuxInner { Connect(inner_packet) => { let (ch_tx, ch_rx) = mpsc::unbounded(); let stream_type = inner_packet.stream_type; - let flow_control: Arc = AtomicU32::new(buffer_size).into(); + let flow_control: Arc = AtomicU32::new(self.buffer_size).into(); let flow_control_event: Arc = Event::new().into(); let is_closed: Arc = AtomicBool::new(false).into(); @@ -273,7 +361,7 @@ impl MuxInner { } } } - Continue(_) => break Err(WispError::InvalidPacketType), + Continue(_) | Info(_) => break Err(WispError::InvalidPacketType), Close(_) => { if let Some((_, mut stream)) = self.stream_map.remove(&packet.stream_id) { stream.is_closed.store(true, Ordering::Release); @@ -298,7 +386,7 @@ impl MuxInner { use PacketType::*; match packet.packet_type { - Connect(_) => break Err(WispError::InvalidPacketType), + Connect(_) | Info(_) => break Err(WispError::InvalidPacketType), Data(data) => { if let Some(stream) = self.stream_map.get(&packet.stream_id) { let _ = stream.stream.unbounded_send(data); @@ -332,7 +420,7 @@ impl MuxInner { /// ``` /// use wisp_mux::ServerMux; /// -/// let (mux, fut) = ServerMux::new(rx, tx, 128); +/// let (mux, fut) = ServerMux::new(rx, tx, 128, Some(vec![]), Some([])); /// tokio::spawn(async move { /// if let Err(e) = fut.await { /// println!("error in multiplexor: {:?}", e); @@ -346,34 +434,89 @@ impl MuxInner { /// } /// ``` pub struct ServerMux { + /// Whether the connection was downgraded to Wisp v1. + /// + /// If this variable is true you must assume no extensions are supported. + pub downgraded: bool, + /// Extensions that are supported by both sides. + pub supported_extensions: Arc<[AnyProtocolExtension]>, close_tx: mpsc::Sender, muxstream_recv: mpsc::UnboundedReceiver<(ConnectPacket, MuxStream)>, } impl ServerMux { /// Create a new server-side multiplexor. - pub fn new( - read: R, + /// + /// If either extensions or extension_builders are None a Wisp v1 connection is created + /// otherwise a Wisp v2 connection is created. + pub async fn new( + mut read: R, write: W, buffer_size: u32, - ) -> (Self, impl Future>) + extensions: Option>, + extension_builders: Option<&[&(dyn ProtocolExtensionBuilder + Sync)]>, + ) -> Result<(Self, impl Future> + Send), WispError> where - R: ws::WebSocketRead, + R: ws::WebSocketRead + Send, + W: ws::WebSocketWrite + Send + 'static, { let (close_tx, close_rx) = mpsc::channel::(256); let (tx, rx) = mpsc::unbounded::<(ConnectPacket, MuxStream)>(); - let write = ws::LockedWebSocketWrite::new(write); - ( + let write = ws::LockedWebSocketWrite::new(Box::new(write)); + + write + .write_frame(Packet::new_continue(0, buffer_size).into()) + .await?; + + let mut supported_extensions = Vec::new(); + let mut extra_packet = Vec::with_capacity(1); + let mut downgraded = true; + + if let Some(extensions) = extensions { + if let Some(builders) = extension_builders { + let extension_ids: Vec<_> = extensions.iter().map(|x| x.get_id()).collect(); + write + .write_frame(Packet::new_info(extensions).into()) + .await?; + if let Some(frame) = select! { + x = read.wisp_read_frame(&write).fuse() => Some(x?), + // TODO change this to correct timeout once draft 2 is out + _ = Delay::new(Duration::from_secs(5)).fuse() => None + } { + let packet = Packet::maybe_parse_info(frame, Role::Server, builders)?; + if let PacketType::Info(info) = packet.packet_type { + supported_extensions = info + .extensions + .into_iter() + .filter(|x| extension_ids.contains(&x.get_id())) + .collect(); + downgraded = false; + } else { + extra_packet.push(packet.into()); + } + } + } + } + + Ok(( Self { muxstream_recv: rx, close_tx: close_tx.clone(), + downgraded, + supported_extensions: supported_extensions.into(), }, MuxInner { tx: write, - stream_map: DashMap::new().into(), + stream_map: DashMap::new(), + buffer_size, } - .server_into_future(read, close_rx, tx, buffer_size, close_tx), - ) + .server_into_future( + AppendingWebSocketRead(extra_packet, read), + close_rx, + tx, + close_tx, + ), + )) } /// Wait for a stream to be created. @@ -398,7 +541,7 @@ impl ServerMux { /// ``` /// use wisp_mux::{ClientMux, StreamType}; /// -/// let (mux, fut) = ClientMux::new(rx, tx).await?; +/// let (mux, fut) = ClientMux::new(rx, tx, Some(vec![]), []).await?; /// tokio::spawn(async move { /// if let Err(e) = fut.await { /// println!("error in multiplexor: {:?}", e); @@ -406,50 +549,88 @@ impl ServerMux { /// }); /// let stream = mux.client_new_stream(StreamType::Tcp, "google.com", 80); /// ``` -pub struct ClientMux -where - W: ws::WebSocketWrite, -{ - tx: ws::LockedWebSocketWrite, - stream_map: Arc>, - next_free_stream_id: AtomicU32, +pub struct ClientMux { + /// Whether the connection was downgraded to Wisp v1. + /// + /// If this variable is true you must assume no extensions are supported. + pub downgraded: bool, + /// Extensions that are supported by both sides. + pub supported_extensions: Arc<[AnyProtocolExtension]>, close_tx: mpsc::Sender, - buf_size: u32, - target_buf_size: u32, } -impl ClientMux { +impl ClientMux { /// Create a new client side multiplexor. - pub async fn new( + /// + /// If either extensions or extension_builders are None a Wisp v1 connection is created + /// otherwise a Wisp v2 connection is created. + pub async fn new( mut read: R, write: W, - ) -> Result<(Self, impl Future>), WispError> + extensions: Option>, + extension_builders: Option<&[&(dyn ProtocolExtensionBuilder + Sync)]>, + ) -> Result<(Self, impl Future> + Send), WispError> where - R: ws::WebSocketRead, + R: ws::WebSocketRead + Send, + W: ws::WebSocketWrite + Send + 'static, { - let write = ws::LockedWebSocketWrite::new(write); + let write = ws::LockedWebSocketWrite::new(Box::new(write)); let first_packet = Packet::try_from(read.wisp_read_frame(&write).await?)?; if first_packet.stream_id != 0 { return Err(WispError::InvalidStreamId); } if let PacketType::Continue(packet) = first_packet.packet_type { + let mut supported_extensions = Vec::new(); + let mut extra_packet = Vec::with_capacity(1); + let mut downgraded = true; + + if let Some(extensions) = extensions { + if let Some(builders) = extension_builders { + let extension_ids: Vec<_> = extensions.iter().map(|x| x.get_id()).collect(); + if let Some(frame) = select! { + x = read.wisp_read_frame(&write).fuse() => Some(x?), + // TODO change this to correct timeout once draft 2 is out + _ = Delay::new(Duration::from_secs(5)).fuse() => None + } { + let packet = Packet::maybe_parse_info(frame, Role::Server, builders)?; + if let PacketType::Info(info) = packet.packet_type { + supported_extensions = info + .extensions + .into_iter() + .filter(|x| extension_ids.contains(&x.get_id())) + .collect(); + write + .write_frame(Packet::new_info(extensions).into()) + .await?; + downgraded = false; + } else { + extra_packet.push(packet.into()); + } + } + } + } + + for extension in supported_extensions.iter_mut() { + extension.handle_handshake(&mut read, &write).await?; + } + let (tx, rx) = mpsc::channel::(256); - let map = Arc::new(DashMap::new()); Ok(( Self { - tx: write.clone(), - stream_map: map.clone(), - next_free_stream_id: AtomicU32::new(1), close_tx: tx.clone(), - buf_size: packet.buffer_remaining, - // server-only - target_buf_size: 0, + downgraded, + supported_extensions: supported_extensions.into(), }, MuxInner { - tx: write.clone(), - stream_map: map.clone(), + tx: write, + stream_map: DashMap::new(), + buffer_size: packet.buffer_remaining, } - .client_into_future(read, rx), + .client_into_future( + AppendingWebSocketRead(extra_packet, read), + rx, + tx, + ), )) } else { Err(WispError::InvalidPacketType) @@ -458,51 +639,25 @@ impl ClientMux { /// Create a new stream, multiplexed through Wisp. pub async fn client_new_stream( - &self, + &mut self, stream_type: StreamType, host: String, port: u16, ) -> Result { - let (ch_tx, ch_rx) = mpsc::unbounded(); - let stream_id = self.next_free_stream_id.load(Ordering::Acquire); - let next_stream_id = stream_id - .checked_add(1) - .ok_or(WispError::MaxStreamCountReached)?; - - let flow_control_event: Arc = Event::new().into(); - let flow_control: Arc = AtomicU32::new(self.buf_size).into(); - - let is_closed: Arc = AtomicBool::new(false).into(); - - self.tx - .write_frame(Packet::new_connect(stream_id, stream_type, port, host).into()) - .await?; - - self.next_free_stream_id - .store(next_stream_id, Ordering::Release); - - self.stream_map.insert( - stream_id, - MuxMapValue { - stream: ch_tx, - stream_type, - flow_control: flow_control.clone(), - flow_control_event: flow_control_event.clone(), - is_closed: is_closed.clone(), - }, - ); - - Ok(MuxStream::new( - stream_id, - Role::Client, - stream_type, - ch_rx, - self.close_tx.clone(), - is_closed, - flow_control, - flow_control_event, - self.target_buf_size, - )) + if stream_type == StreamType::Udp + && !self + .supported_extensions + .iter() + .any(|x| x.get_id() == UdpProtocolExtension::ID) + { + return Err(WispError::UdpExtensionNotSupported); + } + let (tx, rx) = oneshot::channel(); + self.close_tx + .send(WsEvent::CreateStream(stream_type, host, port, tx)) + .await + .map_err(|_| WispError::MuxMessageFailedToSend)?; + rx.await.map_err(|_| WispError::MuxMessageFailedToRecv)? } /// Close all streams. diff --git a/wisp/src/packet.rs b/wisp/src/packet.rs index d3fb8c7..c2b2459 100644 --- a/wisp/src/packet.rs +++ b/wisp/src/packet.rs @@ -1,4 +1,8 @@ -use crate::{ws, WispError}; +use crate::{ + extensions::{AnyProtocolExtension, ProtocolExtensionBuilder}, + ws::{self, Frame, OpCode}, + Role, WispError, WISP_VERSION, +}; use bytes::{Buf, BufMut, Bytes, BytesMut}; /// Wisp stream type. @@ -34,6 +38,8 @@ pub enum CloseReason { Voluntary = 0x02, /// Unexpected stream closure due to a network error. Unexpected = 0x03, + /// Incompatible extensions. Only used during the handshake. + IncompatibleExtensions = 0x04, /// Stream creation failed due to invalid information. ServerStreamInvalidInfo = 0x41, /// Stream creation failed due to an unreachable destination host. @@ -55,19 +61,20 @@ pub enum CloseReason { impl TryFrom for CloseReason { type Error = WispError; fn try_from(stream_type: u8) -> Result { - use CloseReason::*; + use CloseReason as R; match stream_type { - 0x01 => Ok(Unknown), - 0x02 => Ok(Voluntary), - 0x03 => Ok(Unexpected), - 0x41 => Ok(ServerStreamInvalidInfo), - 0x42 => Ok(ServerStreamUnreachable), - 0x43 => Ok(ServerStreamConnectionTimedOut), - 0x44 => Ok(ServerStreamConnectionRefused), - 0x47 => Ok(ServerStreamTimedOut), - 0x48 => Ok(ServerStreamBlockedAddress), - 0x49 => Ok(ServerStreamThrottled), - 0x81 => Ok(ClientUnexpected), + 0x01 => Ok(R::Unknown), + 0x02 => Ok(R::Voluntary), + 0x03 => Ok(R::Unexpected), + 0x04 => Ok(R::IncompatibleExtensions), + 0x41 => Ok(R::ServerStreamInvalidInfo), + 0x42 => Ok(R::ServerStreamUnreachable), + 0x43 => Ok(R::ServerStreamConnectionTimedOut), + 0x44 => Ok(R::ServerStreamConnectionRefused), + 0x47 => Ok(R::ServerStreamTimedOut), + 0x48 => Ok(R::ServerStreamBlockedAddress), + 0x49 => Ok(R::ServerStreamThrottled), + 0x81 => Ok(R::ClientUnexpected), _ => Err(Self::Error::InvalidStreamType), } } @@ -198,6 +205,38 @@ impl From for Bytes { } } +/// Wisp version sent in the handshake. +#[derive(Debug, Clone)] +pub struct WispVersion { + /// Major Wisp version according to semver. + pub major: u8, + /// Minor Wisp version according to semver. + pub minor: u8, +} + +/// Packet used in the initial handshake. +/// +/// See [the docs](https://github.com/MercuryWorkshop/wisp-protocol/blob/main/protocol.md#0x05---info) +#[derive(Debug, Clone)] +pub struct InfoPacket { + /// Wisp version sent in the packet. + pub version: WispVersion, + /// List of protocol extensions sent in the packet. + pub extensions: Vec, +} + +impl From for Bytes { + fn from(value: InfoPacket) -> Self { + let mut bytes = BytesMut::with_capacity(2); + bytes.put_u8(value.version.major); + bytes.put_u8(value.version.minor); + for extension in value.extensions { + bytes.extend(Bytes::from(extension)); + } + bytes.freeze() + } +} + #[derive(Debug, Clone)] /// Type of packet recieved. pub enum PacketType { @@ -209,6 +248,8 @@ pub enum PacketType { Continue(ContinuePacket), /// Close packet. Close(ClosePacket), + /// Info packet. + Info(InfoPacket), } impl PacketType { @@ -220,6 +261,7 @@ impl PacketType { Data(_) => 0x02, Continue(_) => 0x03, Close(_) => 0x04, + Info(_) => 0x05, } } } @@ -232,6 +274,7 @@ impl From for Bytes { Data(x) => x, Continue(x) => x.into(), Close(x) => x.into(), + Info(x) => x.into(), } } } @@ -296,6 +339,97 @@ impl Packet { packet_type: PacketType::Close(ClosePacket::new(reason)), } } + + pub(crate) fn new_info(extensions: Vec) -> Self { + Self { + stream_id: 0, + packet_type: PacketType::Info(InfoPacket { + version: WISP_VERSION, + extensions, + }), + } + } + + fn parse_packet(packet_type: u8, mut bytes: Bytes) -> Result { + use PacketType::*; + Ok(Self { + stream_id: bytes.get_u32_le(), + packet_type: match packet_type { + 0x01 => Connect(ConnectPacket::try_from(bytes)?), + 0x02 => Data(bytes), + 0x03 => Continue(ContinuePacket::try_from(bytes)?), + 0x04 => Close(ClosePacket::try_from(bytes)?), + // 0x05 is handled seperately + _ => return Err(WispError::InvalidPacketType), + }, + }) + } + + pub(crate) fn maybe_parse_info( + frame: Frame, + role: Role, + extension_builders: &[&(dyn ProtocolExtensionBuilder + Sync)], + ) -> Result { + if !frame.finished { + return Err(WispError::WsFrameNotFinished); + } + if frame.opcode != OpCode::Binary { + return Err(WispError::WsFrameInvalidType); + } + let mut bytes = frame.payload; + if bytes.remaining() < 1 { + return Err(WispError::PacketTooSmall); + } + let packet_type = bytes.get_u8(); + if packet_type == 0x05 { + Self::parse_info(bytes, role, extension_builders) + } else { + Self::parse_packet(packet_type, bytes) + } + } + + fn parse_info( + mut bytes: Bytes, + role: Role, + extension_builders: &[&(dyn ProtocolExtensionBuilder + Sync)], + ) -> Result { + // packet type is already read by code that calls this + if bytes.remaining() < 4 + 2 { + return Err(WispError::PacketTooSmall); + } + if bytes.get_u32_le() != 0 { + return Err(WispError::InvalidStreamId); + } + + let version = WispVersion { + major: bytes.get_u8(), + minor: bytes.get_u8(), + }; + + let mut extensions = Vec::new(); + + while bytes.remaining() > 4 { + // We have some extensions + let id = bytes.get_u8(); + let length = usize::try_from(bytes.get_u32_le())?; + if bytes.remaining() < length { + return Err(WispError::PacketTooSmall); + } + if let Some(builder) = extension_builders.iter().find(|x| x.get_id() == id) { + extensions.push(builder.build(bytes.copy_to_bytes(length), role)) + } else { + bytes.advance(length) + } + } + + Ok(Self { + stream_id: 0, + packet_type: PacketType::Info(InfoPacket { + version, + extensions, + }), + }) + } } impl TryFrom for Packet { @@ -305,17 +439,7 @@ impl TryFrom for Packet { return Err(Self::Error::PacketTooSmall); } let packet_type = bytes.get_u8(); - use PacketType::*; - Ok(Self { - stream_id: bytes.get_u32_le(), - packet_type: match packet_type { - 0x01 => Connect(ConnectPacket::try_from(bytes)?), - 0x02 => Data(bytes), - 0x03 => Continue(ContinuePacket::try_from(bytes)?), - 0x04 => Close(ClosePacket::try_from(bytes)?), - _ => return Err(Self::Error::InvalidPacketType), - }, - }) + Self::parse_packet(packet_type, bytes) } } diff --git a/wisp/src/sink_unfold.rs b/wisp/src/sink_unfold.rs index c82254a..dfb170e 100644 --- a/wisp/src/sink_unfold.rs +++ b/wisp/src/sink_unfold.rs @@ -1,8 +1,10 @@ //! futures sink unfold with a close function use core::{future::Future, pin::Pin}; -use futures::ready; -use futures::task::{Context, Poll}; -use futures::Sink; +use futures::{ + ready, + task::{Context, Poll}, + Sink, +}; use pin_project_lite::pin_project; pin_project! { diff --git a/wisp/src/stream.rs b/wisp/src/stream.rs index 109b9ab..f579140 100644 --- a/wisp/src/stream.rs +++ b/wisp/src/stream.rs @@ -21,6 +21,12 @@ use std::{ pub(crate) enum WsEvent { SendPacket(Packet, oneshot::Sender>), Close(Packet, oneshot::Sender>), + CreateStream( + StreamType, + String, + u16, + oneshot::Sender>, + ), EndFut, } @@ -317,7 +323,10 @@ impl MuxStreamIo { impl Stream for MuxStreamIo { type Item = Result, std::io::Error>; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.project().rx.poll_next(cx).map(|x| x.map(|x| Ok(x.to_vec()))) + self.project() + .rx + .poll_next(cx) + .map(|x| x.map(|x| Ok(x.to_vec()))) } } diff --git a/wisp/src/ws.rs b/wisp/src/ws.rs index af57572..7348bb8 100644 --- a/wisp/src/ws.rs +++ b/wisp/src/ws.rs @@ -4,9 +4,10 @@ //! for other WebSocket implementations. //! //! [`fastwebsockets`]: https://github.com/MercuryWorkshop/epoxy-tls/blob/multiplexed/wisp/src/fastwebsockets.rs +use crate::WispError; +use async_trait::async_trait; use bytes::Bytes; use futures::lock::Mutex; -use std::sync::Arc; /// Opcode of the WebSocket frame. #[derive(Debug, PartialEq, Clone, Copy)] @@ -64,30 +65,26 @@ impl Frame { } /// Generic WebSocket read trait. +#[async_trait] pub trait WebSocketRead { /// Read a frame from the socket. - fn wisp_read_frame( - &mut self, - tx: &crate::ws::LockedWebSocketWrite, - ) -> impl std::future::Future>; + async fn wisp_read_frame(&mut self, tx: &LockedWebSocketWrite) -> Result; } /// Generic WebSocket write trait. +#[async_trait] pub trait WebSocketWrite { /// Write a frame to the socket. - fn wisp_write_frame( - &mut self, - frame: Frame, - ) -> impl std::future::Future>; + async fn wisp_write_frame(&mut self, frame: Frame) -> Result<(), WispError>; } -/// Locked WebSocket that can be shared between threads. -pub struct LockedWebSocketWrite(Arc>); +/// Locked WebSocket. +pub struct LockedWebSocketWrite(Mutex>); -impl LockedWebSocketWrite { +impl LockedWebSocketWrite { /// Create a new locked websocket. - pub fn new(ws: S) -> Self { - Self(Arc::new(Mutex::new(ws))) + pub fn new(ws: Box) -> Self { + Self(Mutex::new(ws)) } /// Write a frame to the websocket. @@ -96,8 +93,19 @@ impl LockedWebSocketWrite { } } -impl Clone for LockedWebSocketWrite { - fn clone(&self) -> Self { - Self(self.0.clone()) +pub(crate) struct AppendingWebSocketRead(pub Vec, pub R) +where + R: WebSocketRead + Send; + +#[async_trait] +impl WebSocketRead for AppendingWebSocketRead +where + R: WebSocketRead + Send, +{ + async fn wisp_read_frame(&mut self, tx: &LockedWebSocketWrite) -> Result { + if let Some(x) = self.0.pop() { + return Ok(x); + } + return self.1.wisp_read_frame(tx).await; } }