From d6353bd5a9deafaa6d1a9f5e97e93fd9b03a054f Mon Sep 17 00:00:00 2001 From: Toshit Chawda Date: Wed, 17 Jul 2024 16:23:58 -0700 Subject: [PATCH] add a new Payload struct to allow for one-copy writes and cargo fmt --- certs-grabber/src/main.rs | 26 +- client/src/io_stream.rs | 286 +++--- client/src/lib.rs | 831 +++++++++--------- client/src/stream_provider.rs | 6 +- client/src/tokioio.rs | 242 +++--- client/src/websocket.rs | 345 ++++---- client/src/ws_wrapper.rs | 335 +++---- server/src/main.rs | 932 ++++++++++---------- simple-wisp-client/src/main.rs | 395 ++++----- wisp/src/extensions/mod.rs | 118 +-- wisp/src/extensions/password.rs | 346 ++++---- wisp/src/extensions/udp.rs | 102 +-- wisp/src/fastwebsockets.rs | 133 +-- wisp/src/lib.rs | 1438 ++++++++++++++++--------------- wisp/src/packet.rs | 852 +++++++++--------- wisp/src/sink_unfold.rs | 232 ++--- wisp/src/stream.rs | 79 +- wisp/src/ws.rs | 230 +++-- 18 files changed, 3533 insertions(+), 3395 deletions(-) diff --git a/certs-grabber/src/main.rs b/certs-grabber/src/main.rs index 8722ed3..b97f76b 100644 --- a/certs-grabber/src/main.rs +++ b/certs-grabber/src/main.rs @@ -3,15 +3,15 @@ use std::fmt::Write; use rustls_pki_types::TrustAnchor; fn main() { - let mut code = String::with_capacity(256 * 1_024); - code.push_str("const ROOTS = ["); - for anchor in webpki_roots::TLS_SERVER_ROOTS { - let TrustAnchor { - subject, - subject_public_key_info, - name_constraints, - } = anchor; - code.write_fmt(format_args!( + let mut code = String::with_capacity(256 * 1_024); + code.push_str("const ROOTS = ["); + for anchor in webpki_roots::TLS_SERVER_ROOTS { + let TrustAnchor { + subject, + subject_public_key_info, + name_constraints, + } = anchor; + code.write_fmt(format_args!( "{{subject:new Uint8Array([{}]),subject_public_key_info:new Uint8Array([{}]),name_constraints:{}}},", subject .as_ref() @@ -34,8 +34,8 @@ fn main() { } )) .unwrap(); - } - code.pop(); - code.push_str("];"); - println!("{}", code); + } + code.pop(); + code.push_str("];"); + println!("{}", code); } diff --git a/client/src/io_stream.rs b/client/src/io_stream.rs index 70c6a39..86d7be6 100644 --- a/client/src/io_stream.rs +++ b/client/src/io_stream.rs @@ -1,183 +1,181 @@ -use bytes::{buf::UninitSlice, BufMut, Bytes, BytesMut}; -use futures_util::{ - io::WriteHalf, lock::Mutex, stream::SplitSink, AsyncReadExt, AsyncWriteExt, SinkExt, StreamExt, -}; +use bytes::{buf::UninitSlice, BufMut, BytesMut}; +use futures_util::{io::WriteHalf, lock::Mutex, AsyncReadExt, AsyncWriteExt, SinkExt, StreamExt}; use js_sys::{Function, Uint8Array}; use wasm_bindgen::prelude::*; use wasm_bindgen_futures::spawn_local; +use wisp_mux::MuxStreamIoSink; use crate::{ - stream_provider::{ProviderAsyncRW, ProviderUnencryptedStream}, - utils::convert_body, - EpoxyError, EpoxyHandlers, + stream_provider::{ProviderAsyncRW, ProviderUnencryptedStream}, + utils::convert_body, + EpoxyError, EpoxyHandlers, }; #[wasm_bindgen] pub struct EpoxyIoStream { - tx: Mutex>, - onerror: Function, + tx: Mutex>, + onerror: Function, } #[wasm_bindgen] impl EpoxyIoStream { - pub(crate) fn connect(stream: ProviderAsyncRW, handlers: EpoxyHandlers) -> Self { - let (mut rx, tx) = stream.split(); - let tx = Mutex::new(tx); + pub(crate) fn connect(stream: ProviderAsyncRW, handlers: EpoxyHandlers) -> Self { + let (mut rx, tx) = stream.split(); + let tx = Mutex::new(tx); - let EpoxyHandlers { - onopen, - onclose, - onerror, - onmessage, - } = handlers; + let EpoxyHandlers { + onopen, + onclose, + onerror, + onmessage, + } = handlers; - let onerror_cloned = onerror.clone(); + let onerror_cloned = onerror.clone(); - // similar to tokio_util::io::ReaderStream - spawn_local(async move { - let mut buf = BytesMut::with_capacity(4096); - loop { - match rx - .read(unsafe { - std::mem::transmute::<&mut UninitSlice, &mut [u8]>(buf.chunk_mut()) - }) - .await - { - Ok(cnt) => { - if cnt > 0 { - unsafe { buf.advance_mut(cnt) }; + // similar to tokio_util::io::ReaderStream + spawn_local(async move { + let mut buf = BytesMut::with_capacity(4096); + loop { + match rx + .read(unsafe { + std::mem::transmute::<&mut UninitSlice, &mut [u8]>(buf.chunk_mut()) + }) + .await + { + Ok(cnt) => { + if cnt > 0 { + unsafe { buf.advance_mut(cnt) }; - let _ = onmessage - .call1(&JsValue::null(), &Uint8Array::from(buf.split().as_ref())); - } - } - Err(err) => { - let _ = onerror.call1(&JsValue::null(), &JsError::from(err).into()); - break; - } - } - } - let _ = onclose.call0(&JsValue::null()); - }); + let _ = onmessage + .call1(&JsValue::null(), &Uint8Array::from(buf.split().as_ref())); + } + } + Err(err) => { + let _ = onerror.call1(&JsValue::null(), &JsError::from(err).into()); + break; + } + } + } + let _ = onclose.call0(&JsValue::null()); + }); - let _ = onopen.call0(&JsValue::null()); + let _ = onopen.call0(&JsValue::null()); - Self { - tx, - onerror: onerror_cloned, - } - } + Self { + tx, + onerror: onerror_cloned, + } + } - pub async fn send(&self, payload: JsValue) -> Result<(), EpoxyError> { - let ret: Result<(), EpoxyError> = async move { - let payload = convert_body(payload) - .await - .map_err(|_| EpoxyError::InvalidPayload)? - .0 - .to_vec(); - Ok(self.tx.lock().await.write_all(&payload).await?) - } - .await; + pub async fn send(&self, payload: JsValue) -> Result<(), EpoxyError> { + let ret: Result<(), EpoxyError> = async move { + let payload = convert_body(payload) + .await + .map_err(|_| EpoxyError::InvalidPayload)? + .0 + .to_vec(); + Ok(self.tx.lock().await.write_all(&payload).await?) + } + .await; - match ret { - Ok(ok) => Ok(ok), - Err(err) => { - let _ = self - .onerror - .call1(&JsValue::null(), &err.to_string().into()); - Err(err) - } - } - } + match ret { + Ok(ok) => Ok(ok), + Err(err) => { + let _ = self + .onerror + .call1(&JsValue::null(), &err.to_string().into()); + Err(err) + } + } + } - pub async fn close(&self) -> Result<(), EpoxyError> { - match self.tx.lock().await.close().await { - Ok(ok) => Ok(ok), - Err(err) => { - let _ = self - .onerror - .call1(&JsValue::null(), &err.to_string().into()); - Err(err.into()) - } - } - } + pub async fn close(&self) -> Result<(), EpoxyError> { + match self.tx.lock().await.close().await { + Ok(ok) => Ok(ok), + Err(err) => { + let _ = self + .onerror + .call1(&JsValue::null(), &err.to_string().into()); + Err(err.into()) + } + } + } } #[wasm_bindgen] pub struct EpoxyUdpStream { - tx: Mutex>, - onerror: Function, + tx: Mutex, + onerror: Function, } #[wasm_bindgen] impl EpoxyUdpStream { - pub(crate) fn connect(stream: ProviderUnencryptedStream, handlers: EpoxyHandlers) -> Self { - let (tx, mut rx) = stream.split(); - let tx = Mutex::new(tx); + pub(crate) fn connect(stream: ProviderUnencryptedStream, handlers: EpoxyHandlers) -> Self { + let (mut rx, tx) = stream.into_split(); - let EpoxyHandlers { - onopen, - onclose, - onerror, - onmessage, - } = handlers; + let EpoxyHandlers { + onopen, + onclose, + onerror, + onmessage, + } = handlers; - let onerror_cloned = onerror.clone(); + let onerror_cloned = onerror.clone(); - spawn_local(async move { - while let Some(packet) = rx.next().await { - match packet { - Ok(buf) => { - let _ = onmessage.call1(&JsValue::null(), &Uint8Array::from(buf.as_ref())); - } - Err(err) => { - let _ = onerror.call1(&JsValue::null(), &JsError::from(err).into()); - break; - } - } - } - let _ = onclose.call0(&JsValue::null()); - }); + spawn_local(async move { + while let Some(packet) = rx.next().await { + match packet { + Ok(buf) => { + let _ = onmessage.call1(&JsValue::null(), &Uint8Array::from(buf.as_ref())); + } + Err(err) => { + let _ = onerror.call1(&JsValue::null(), &JsError::from(err).into()); + break; + } + } + } + let _ = onclose.call0(&JsValue::null()); + }); - let _ = onopen.call0(&JsValue::null()); + let _ = onopen.call0(&JsValue::null()); - Self { - tx, - onerror: onerror_cloned, - } - } + Self { + tx: tx.into(), + onerror: onerror_cloned, + } + } - pub async fn send(&self, payload: JsValue) -> Result<(), EpoxyError> { - let ret: Result<(), EpoxyError> = async move { - let payload = convert_body(payload) - .await - .map_err(|_| EpoxyError::InvalidPayload)? - .0 - .to_vec(); - Ok(self.tx.lock().await.send(payload.into()).await?) - } - .await; + pub async fn send(&self, payload: JsValue) -> Result<(), EpoxyError> { + let ret: Result<(), EpoxyError> = async move { + let payload = convert_body(payload) + .await + .map_err(|_| EpoxyError::InvalidPayload)? + .0 + .to_vec(); + Ok(self.tx.lock().await.send(payload.as_ref()).await?) + } + .await; - match ret { - Ok(ok) => Ok(ok), - Err(err) => { - let _ = self - .onerror - .call1(&JsValue::null(), &err.to_string().into()); - Err(err) - } - } - } + match ret { + Ok(ok) => Ok(ok), + Err(err) => { + let _ = self + .onerror + .call1(&JsValue::null(), &err.to_string().into()); + Err(err) + } + } + } - pub async fn close(&self) -> Result<(), EpoxyError> { - match self.tx.lock().await.close().await { - Ok(ok) => Ok(ok), - Err(err) => { - let _ = self - .onerror - .call1(&JsValue::null(), &err.to_string().into()); - Err(err.into()) - } - } - } + pub async fn close(&self) -> Result<(), EpoxyError> { + match self.tx.lock().await.close().await { + Ok(ok) => Ok(ok), + Err(err) => { + let _ = self + .onerror + .call1(&JsValue::null(), &err.to_string().into()); + Err(err.into()) + } + } + } } diff --git a/client/src/lib.rs b/client/src/lib.rs index f97b77c..1fe24b8 100644 --- a/client/src/lib.rs +++ b/client/src/lib.rs @@ -5,14 +5,14 @@ use std::{str::FromStr, sync::Arc}; use async_compression::futures::bufread as async_comp; use bytes::Bytes; use cfg_if::cfg_if; -use futures_util::TryStreamExt; #[cfg(feature = "full")] use futures_util::future::Either; +use futures_util::TryStreamExt; use http::{ - header::{InvalidHeaderName, InvalidHeaderValue}, - method::InvalidMethod, - uri::{InvalidUri, InvalidUriParts}, - HeaderName, HeaderValue, Method, Request, Response, + header::{InvalidHeaderName, InvalidHeaderValue}, + method::InvalidMethod, + uri::{InvalidUri, InvalidUriParts}, + HeaderName, HeaderValue, Method, Request, Response, }; use hyper::{body::Incoming, Uri}; use hyper_util_wasm::client::legacy::Client; @@ -22,7 +22,8 @@ use js_sys::{Array, Function, Object, Reflect}; use stream_provider::{StreamProvider, StreamProviderService}; use thiserror::Error; use utils::{ - asyncread_to_readablestream_stream, convert_body, entries_of_object, is_null_body, is_redirect, object_get, object_set, IncomingBody, UriExt, WasmExecutor + asyncread_to_readablestream_stream, convert_body, entries_of_object, is_null_body, is_redirect, + object_get, object_set, IncomingBody, UriExt, WasmExecutor, }; use wasm_bindgen::prelude::*; use wasm_streams::ReadableStream; @@ -45,409 +46,409 @@ type HttpBody = http_body_util::Full; #[derive(Debug, Error)] pub enum EpoxyError { - #[error("Invalid DNS name: {0:?}")] - InvalidDnsName(#[from] futures_rustls::rustls::pki_types::InvalidDnsNameError), - #[error("Wisp: {0:?}")] - Wisp(#[from] wisp_mux::WispError), - #[error("IO: {0:?}")] - Io(#[from] std::io::Error), - #[error("HTTP: {0:?}")] - Http(#[from] http::Error), - #[error("Hyper client: {0:?}")] - HyperClient(#[from] hyper_util_wasm::client::legacy::Error), - #[error("Hyper: {0:?}")] - Hyper(#[from] hyper::Error), - #[error("HTTP ToStr: {0:?}")] - ToStr(#[from] http::header::ToStrError), - #[cfg(feature = "full")] - #[error("Getrandom: {0:?}")] - GetRandom(#[from] getrandom::Error), - #[cfg(feature = "full")] - #[error("Fastwebsockets: {0:?}")] - FastWebSockets(#[from] fastwebsockets::WebSocketError), + #[error("Invalid DNS name: {0:?}")] + InvalidDnsName(#[from] futures_rustls::rustls::pki_types::InvalidDnsNameError), + #[error("Wisp: {0:?}")] + Wisp(#[from] wisp_mux::WispError), + #[error("IO: {0:?}")] + Io(#[from] std::io::Error), + #[error("HTTP: {0:?}")] + Http(#[from] http::Error), + #[error("Hyper client: {0:?}")] + HyperClient(#[from] hyper_util_wasm::client::legacy::Error), + #[error("Hyper: {0:?}")] + Hyper(#[from] hyper::Error), + #[error("HTTP ToStr: {0:?}")] + ToStr(#[from] http::header::ToStrError), + #[cfg(feature = "full")] + #[error("Getrandom: {0:?}")] + GetRandom(#[from] getrandom::Error), + #[cfg(feature = "full")] + #[error("Fastwebsockets: {0:?}")] + FastWebSockets(#[from] fastwebsockets::WebSocketError), - #[error("Invalid URL scheme")] - InvalidUrlScheme, - #[error("No URL host found")] - NoUrlHost, - #[error("No URL port found")] - NoUrlPort, - #[error("Invalid request body")] - InvalidRequestBody, - #[error("Invalid request")] - InvalidRequest, - #[error("Invalid websocket response status code")] - WsInvalidStatusCode, - #[error("Invalid websocket upgrade header")] - WsInvalidUpgradeHeader, - #[error("Invalid websocket connection header")] - WsInvalidConnectionHeader, - #[error("Invalid websocket payload")] - WsInvalidPayload, - #[error("Invalid payload")] - InvalidPayload, + #[error("Invalid URL scheme")] + InvalidUrlScheme, + #[error("No URL host found")] + NoUrlHost, + #[error("No URL port found")] + NoUrlPort, + #[error("Invalid request body")] + InvalidRequestBody, + #[error("Invalid request")] + InvalidRequest, + #[error("Invalid websocket response status code")] + WsInvalidStatusCode, + #[error("Invalid websocket upgrade header")] + WsInvalidUpgradeHeader, + #[error("Invalid websocket connection header")] + WsInvalidConnectionHeader, + #[error("Invalid websocket payload")] + WsInvalidPayload, + #[error("Invalid payload")] + InvalidPayload, - #[error("Invalid certificate store")] - InvalidCertStore, - #[error("WebSocket failed to connect")] - WebSocketConnectFailed, + #[error("Invalid certificate store")] + InvalidCertStore, + #[error("WebSocket failed to connect")] + WebSocketConnectFailed, - #[error("Failed to construct response headers object")] - ResponseHeadersFromEntriesFailed, - #[error("Failed to construct response object")] - ResponseNewFailed, - #[error("Failed to construct define_property object")] - DefinePropertyObjFailed, - #[error("Failed to set raw header item")] - RawHeaderSetFailed, + #[error("Failed to construct response headers object")] + ResponseHeadersFromEntriesFailed, + #[error("Failed to construct response object")] + ResponseNewFailed, + #[error("Failed to construct define_property object")] + DefinePropertyObjFailed, + #[error("Failed to set raw header item")] + RawHeaderSetFailed, } impl From for JsValue { - fn from(value: EpoxyError) -> Self { - JsError::from(value).into() - } + fn from(value: EpoxyError) -> Self { + JsError::from(value).into() + } } impl From for EpoxyError { - fn from(value: InvalidUri) -> Self { - http::Error::from(value).into() - } + fn from(value: InvalidUri) -> Self { + http::Error::from(value).into() + } } impl From for EpoxyError { - fn from(value: InvalidUriParts) -> Self { - http::Error::from(value).into() - } + fn from(value: InvalidUriParts) -> Self { + http::Error::from(value).into() + } } impl From for EpoxyError { - fn from(value: InvalidHeaderName) -> Self { - http::Error::from(value).into() - } + fn from(value: InvalidHeaderName) -> Self { + http::Error::from(value).into() + } } impl From for EpoxyError { - fn from(value: InvalidHeaderValue) -> Self { - http::Error::from(value).into() - } + fn from(value: InvalidHeaderValue) -> Self { + http::Error::from(value).into() + } } impl From for EpoxyError { - fn from(value: InvalidMethod) -> Self { - http::Error::from(value).into() - } + fn from(value: InvalidMethod) -> Self { + http::Error::from(value).into() + } } #[derive(Debug)] enum EpoxyResponse { - Success(Response), - Redirect((Response, http::Request)), + Success(Response), + Redirect((Response, http::Request)), } #[cfg(feature = "full")] enum EpoxyCompression { - Brotli, - Gzip, + Brotli, + Gzip, } #[wasm_bindgen] pub struct EpoxyClientOptions { - pub wisp_v2: bool, - pub udp_extension_required: bool, - #[wasm_bindgen(getter_with_clone)] - pub websocket_protocols: Vec, - pub redirect_limit: usize, - #[wasm_bindgen(getter_with_clone)] - pub user_agent: String, + pub wisp_v2: bool, + pub udp_extension_required: bool, + #[wasm_bindgen(getter_with_clone)] + pub websocket_protocols: Vec, + pub redirect_limit: usize, + #[wasm_bindgen(getter_with_clone)] + pub user_agent: String, } #[wasm_bindgen] impl EpoxyClientOptions { - #[wasm_bindgen(constructor)] - pub fn new_default() -> Self { - Self::default() - } + #[wasm_bindgen(constructor)] + pub fn new_default() -> Self { + Self::default() + } } impl Default for EpoxyClientOptions { - fn default() -> Self { - Self { + fn default() -> Self { + Self { wisp_v2: true, udp_extension_required: true, websocket_protocols: Vec::new(), redirect_limit: 10, user_agent: "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/127.0.0.0 Safari/537.36".to_string(), } - } + } } #[wasm_bindgen(getter_with_clone)] pub struct EpoxyHandlers { - pub onopen: Function, - pub onclose: Function, - pub onerror: Function, - pub onmessage: Function, + pub onopen: Function, + pub onclose: Function, + pub onerror: Function, + pub onmessage: Function, } #[cfg(feature = "full")] #[wasm_bindgen] impl EpoxyHandlers { - #[wasm_bindgen(constructor)] - pub fn new( - onopen: Function, - onclose: Function, - onerror: Function, - onmessage: Function, - ) -> Self { - Self { - onopen, - onclose, - onerror, - onmessage, - } - } + #[wasm_bindgen(constructor)] + pub fn new( + onopen: Function, + onclose: Function, + onerror: Function, + onmessage: Function, + ) -> Self { + Self { + onopen, + onclose, + onerror, + onmessage, + } + } } #[wasm_bindgen(inspectable)] pub struct EpoxyClient { - stream_provider: Arc, - client: Client, + stream_provider: Arc, + client: Client, - pub redirect_limit: usize, - #[wasm_bindgen(getter_with_clone)] - pub user_agent: String, + pub redirect_limit: usize, + #[wasm_bindgen(getter_with_clone)] + pub user_agent: String, } #[wasm_bindgen] impl EpoxyClient { - #[wasm_bindgen(constructor)] - pub fn new( - wisp_url: String, - certs: Array, - options: EpoxyClientOptions, - ) -> Result { - let wisp_url: Uri = wisp_url.try_into()?; - if wisp_url.scheme_str() != Some("wss") && wisp_url.scheme_str() != Some("ws") { - return Err(EpoxyError::InvalidUrlScheme); - } + #[wasm_bindgen(constructor)] + pub fn new( + wisp_url: String, + certs: Array, + options: EpoxyClientOptions, + ) -> Result { + let wisp_url: Uri = wisp_url.try_into()?; + if wisp_url.scheme_str() != Some("wss") && wisp_url.scheme_str() != Some("ws") { + return Err(EpoxyError::InvalidUrlScheme); + } - let stream_provider = Arc::new(StreamProvider::new(wisp_url.to_string(), certs, &options)?); + let stream_provider = Arc::new(StreamProvider::new(wisp_url.to_string(), certs, &options)?); - let service = StreamProviderService(stream_provider.clone()); - let client = Client::builder(WasmExecutor) - .http09_responses(true) - .http1_title_case_headers(true) - .http1_preserve_header_case(true) - .build(service); + let service = StreamProviderService(stream_provider.clone()); + let client = Client::builder(WasmExecutor) + .http09_responses(true) + .http1_title_case_headers(true) + .http1_preserve_header_case(true) + .build(service); - Ok(Self { - stream_provider, - client, - redirect_limit: options.redirect_limit, - user_agent: options.user_agent, - }) - } + Ok(Self { + stream_provider, + client, + redirect_limit: options.redirect_limit, + user_agent: options.user_agent, + }) + } - pub async fn replace_stream_provider(&self) -> Result<(), EpoxyError> { - self.stream_provider.replace_client().await - } + pub async fn replace_stream_provider(&self) -> Result<(), EpoxyError> { + self.stream_provider.replace_client().await + } - #[cfg(feature = "full")] - pub async fn connect_websocket( - &self, - handlers: EpoxyHandlers, - url: String, - protocols: Vec, - headers: JsValue, - ) -> Result { - EpoxyWebSocket::connect(self, handlers, url, protocols, headers, &self.user_agent).await - } + #[cfg(feature = "full")] + pub async fn connect_websocket( + &self, + handlers: EpoxyHandlers, + url: String, + protocols: Vec, + headers: JsValue, + ) -> Result { + EpoxyWebSocket::connect(self, handlers, url, protocols, headers, &self.user_agent).await + } - #[cfg(feature = "full")] - pub async fn connect_tcp( - &self, - handlers: EpoxyHandlers, - url: String, - ) -> Result { - let url: Uri = url.try_into()?; - let host = url.host().ok_or(EpoxyError::NoUrlHost)?; - let port = url.port_u16().ok_or(EpoxyError::NoUrlPort)?; - match self - .stream_provider - .get_asyncread(StreamType::Tcp, host.to_string(), port) - .await - { - Ok(stream) => Ok(EpoxyIoStream::connect(Either::Right(stream), handlers)), - Err(err) => { - let _ = handlers - .onerror - .call1(&JsValue::null(), &err.to_string().into()); - Err(err) - } - } - } + #[cfg(feature = "full")] + pub async fn connect_tcp( + &self, + handlers: EpoxyHandlers, + url: String, + ) -> Result { + let url: Uri = url.try_into()?; + let host = url.host().ok_or(EpoxyError::NoUrlHost)?; + let port = url.port_u16().ok_or(EpoxyError::NoUrlPort)?; + match self + .stream_provider + .get_asyncread(StreamType::Tcp, host.to_string(), port) + .await + { + Ok(stream) => Ok(EpoxyIoStream::connect(Either::Right(stream), handlers)), + Err(err) => { + let _ = handlers + .onerror + .call1(&JsValue::null(), &err.to_string().into()); + Err(err) + } + } + } - #[cfg(feature = "full")] - pub async fn connect_tls( - &self, - handlers: EpoxyHandlers, - url: String, - ) -> Result { - let url: Uri = url.try_into()?; - let host = url.host().ok_or(EpoxyError::NoUrlHost)?; - let port = url.port_u16().ok_or(EpoxyError::NoUrlPort)?; - match self - .stream_provider - .get_tls_stream(host.to_string(), port) - .await - { - Ok(stream) => Ok(EpoxyIoStream::connect(Either::Left(stream), handlers)), - Err(err) => { - let _ = handlers - .onerror - .call1(&JsValue::null(), &err.to_string().into()); - Err(err) - } - } - } + #[cfg(feature = "full")] + pub async fn connect_tls( + &self, + handlers: EpoxyHandlers, + url: String, + ) -> Result { + let url: Uri = url.try_into()?; + let host = url.host().ok_or(EpoxyError::NoUrlHost)?; + let port = url.port_u16().ok_or(EpoxyError::NoUrlPort)?; + match self + .stream_provider + .get_tls_stream(host.to_string(), port) + .await + { + Ok(stream) => Ok(EpoxyIoStream::connect(Either::Left(stream), handlers)), + Err(err) => { + let _ = handlers + .onerror + .call1(&JsValue::null(), &err.to_string().into()); + Err(err) + } + } + } - #[cfg(feature = "full")] - pub async fn connect_udp( - &self, - handlers: EpoxyHandlers, - url: String, - ) -> Result { - let url: Uri = url.try_into()?; - let host = url.host().ok_or(EpoxyError::NoUrlHost)?; - let port = url.port_u16().ok_or(EpoxyError::NoUrlPort)?; - match self - .stream_provider - .get_stream(StreamType::Udp, host.to_string(), port) - .await - { - Ok(stream) => Ok(EpoxyUdpStream::connect(stream, handlers)), - Err(err) => { - let _ = handlers - .onerror - .call1(&JsValue::null(), &err.to_string().into()); - Err(err) - } - } - } + #[cfg(feature = "full")] + pub async fn connect_udp( + &self, + handlers: EpoxyHandlers, + url: String, + ) -> Result { + let url: Uri = url.try_into()?; + let host = url.host().ok_or(EpoxyError::NoUrlHost)?; + let port = url.port_u16().ok_or(EpoxyError::NoUrlPort)?; + match self + .stream_provider + .get_stream(StreamType::Udp, host.to_string(), port) + .await + { + Ok(stream) => Ok(EpoxyUdpStream::connect(stream, handlers)), + Err(err) => { + let _ = handlers + .onerror + .call1(&JsValue::null(), &err.to_string().into()); + Err(err) + } + } + } - async fn send_req_inner( - &self, - req: http::Request, - should_redirect: bool, - ) -> Result { - let new_req = if should_redirect { - Some(req.clone()) - } else { - None - }; + async fn send_req_inner( + &self, + req: http::Request, + should_redirect: bool, + ) -> Result { + let new_req = if should_redirect { + Some(req.clone()) + } else { + None + }; - let res = self.client.request(req).await; - match res { - Ok(res) => { - if is_redirect(res.status().as_u16()) - && let Some(mut new_req) = new_req - && let Some(location) = res.headers().get("Location") - && let Ok(redirect_url) = new_req.uri().get_redirect(location) - && let Some(redirect_url_authority) = redirect_url.clone().authority() - { - *new_req.uri_mut() = redirect_url; - new_req.headers_mut().insert( - "Host", - HeaderValue::from_str(redirect_url_authority.as_str())?, - ); - Ok(EpoxyResponse::Redirect((res, new_req))) - } else { - Ok(EpoxyResponse::Success(res)) - } - } - Err(err) => Err(err.into()), - } - } + let res = self.client.request(req).await; + match res { + Ok(res) => { + if is_redirect(res.status().as_u16()) + && let Some(mut new_req) = new_req + && let Some(location) = res.headers().get("Location") + && let Ok(redirect_url) = new_req.uri().get_redirect(location) + && let Some(redirect_url_authority) = redirect_url.clone().authority() + { + *new_req.uri_mut() = redirect_url; + new_req.headers_mut().insert( + "Host", + HeaderValue::from_str(redirect_url_authority.as_str())?, + ); + Ok(EpoxyResponse::Redirect((res, new_req))) + } else { + Ok(EpoxyResponse::Success(res)) + } + } + Err(err) => Err(err.into()), + } + } - async fn send_req( - &self, - req: http::Request, - should_redirect: bool, - ) -> Result<(hyper::Response, Uri, bool), EpoxyError> { - let mut redirected = false; - let mut current_url = req.uri().clone(); - let mut current_resp: EpoxyResponse = self.send_req_inner(req, should_redirect).await?; - for _ in 0..self.redirect_limit { - match current_resp { - EpoxyResponse::Success(_) => break, - EpoxyResponse::Redirect((_, req)) => { - redirected = true; - current_url = req.uri().clone(); - current_resp = self.send_req_inner(req, should_redirect).await? - } - } - } + async fn send_req( + &self, + req: http::Request, + should_redirect: bool, + ) -> Result<(hyper::Response, Uri, bool), EpoxyError> { + let mut redirected = false; + let mut current_url = req.uri().clone(); + let mut current_resp: EpoxyResponse = self.send_req_inner(req, should_redirect).await?; + for _ in 0..self.redirect_limit { + match current_resp { + EpoxyResponse::Success(_) => break, + EpoxyResponse::Redirect((_, req)) => { + redirected = true; + current_url = req.uri().clone(); + current_resp = self.send_req_inner(req, should_redirect).await? + } + } + } - match current_resp { - EpoxyResponse::Success(resp) => Ok((resp, current_url, redirected)), - EpoxyResponse::Redirect((resp, _)) => Ok((resp, current_url, redirected)), - } - } + match current_resp { + EpoxyResponse::Success(resp) => Ok((resp, current_url, redirected)), + EpoxyResponse::Redirect((resp, _)) => Ok((resp, current_url, redirected)), + } + } - pub async fn fetch( - &self, - url: String, - options: Object, - ) -> Result { - let url: Uri = url.try_into()?; - // only valid `Scheme`s are HTTP and HTTPS, which are the ones we support - url.scheme().ok_or(EpoxyError::InvalidUrlScheme)?; + pub async fn fetch( + &self, + url: String, + options: Object, + ) -> Result { + let url: Uri = url.try_into()?; + // only valid `Scheme`s are HTTP and HTTPS, which are the ones we support + url.scheme().ok_or(EpoxyError::InvalidUrlScheme)?; - let host = url.host().ok_or(EpoxyError::NoUrlHost)?; + let host = url.host().ok_or(EpoxyError::NoUrlHost)?; - let request_method = object_get(&options, "method") - .and_then(|x| x.as_string()) - .unwrap_or_else(|| "GET".to_string()); - let request_method: Method = Method::from_str(&request_method)?; + let request_method = object_get(&options, "method") + .and_then(|x| x.as_string()) + .unwrap_or_else(|| "GET".to_string()); + let request_method: Method = Method::from_str(&request_method)?; - let request_redirect = object_get(&options, "redirect") - .map(|x| { - !matches!( - x.as_string().unwrap_or_default().as_str(), - "error" | "manual" - ) - }) - .unwrap_or(true); + let request_redirect = object_get(&options, "redirect") + .map(|x| { + !matches!( + x.as_string().unwrap_or_default().as_str(), + "error" | "manual" + ) + }) + .unwrap_or(true); - let mut body_content_type: Option = None; - let body = match object_get(&options, "body") { - Some(buf) => { - let (body, req) = convert_body(buf) - .await - .map_err(|_| EpoxyError::InvalidRequestBody)?; - body_content_type = req.headers().get("Content-Type").ok().flatten(); - Bytes::from(body.to_vec()) - } - None => Bytes::new(), - }; + let mut body_content_type: Option = None; + let body = match object_get(&options, "body") { + Some(buf) => { + let (body, req) = convert_body(buf) + .await + .map_err(|_| EpoxyError::InvalidRequestBody)?; + body_content_type = req.headers().get("Content-Type").ok().flatten(); + Bytes::from(body.to_vec()) + } + None => Bytes::new(), + }; - let headers = object_get(&options, "headers").and_then(|val| { - if web_sys::Headers::instanceof(&val) { - Some(entries_of_object(&Object::from_entries(&val).ok()?)) - } else if val.is_truthy() { - Some(entries_of_object(&Object::from(val))) - } else { - None - } - }); + let headers = object_get(&options, "headers").and_then(|val| { + if web_sys::Headers::instanceof(&val) { + Some(entries_of_object(&Object::from_entries(&val).ok()?)) + } else if val.is_truthy() { + Some(entries_of_object(&Object::from(val))) + } else { + None + } + }); - let mut request_builder = Request::builder().uri(url.clone()).method(request_method); + let mut request_builder = Request::builder().uri(url.clone()).method(request_method); - // Generic InvalidRequest because this only returns None if the builder has some error - // which we don't know - let headers_map = request_builder - .headers_mut() - .ok_or(EpoxyError::InvalidRequest)?; + // Generic InvalidRequest because this only returns None if the builder has some error + // which we don't know + let headers_map = request_builder + .headers_mut() + .ok_or(EpoxyError::InvalidRequest)?; cfg_if! { if #[cfg(feature = "full")] { @@ -456,54 +457,54 @@ impl EpoxyClient { headers_map.insert("Accept-Encoding", HeaderValue::from_static("identity")); } } - headers_map.insert("Connection", HeaderValue::from_static("keep-alive")); - headers_map.insert("User-Agent", HeaderValue::from_str(&self.user_agent)?); - headers_map.insert("Host", HeaderValue::from_str(host)?); + headers_map.insert("Connection", HeaderValue::from_static("keep-alive")); + headers_map.insert("User-Agent", HeaderValue::from_str(&self.user_agent)?); + headers_map.insert("Host", HeaderValue::from_str(host)?); - if body.is_empty() { - headers_map.insert("Content-Length", HeaderValue::from_static("0")); - } + if body.is_empty() { + headers_map.insert("Content-Length", HeaderValue::from_static("0")); + } - if let Some(content_type) = body_content_type { - headers_map.insert("Content-Type", HeaderValue::from_str(&content_type)?); - } + if let Some(content_type) = body_content_type { + headers_map.insert("Content-Type", HeaderValue::from_str(&content_type)?); + } - if let Some(headers) = headers { - for hdr in headers { - headers_map.insert( - HeaderName::from_str(&hdr[0])?, - HeaderValue::from_str(&hdr[1])?, - ); - } - } + if let Some(headers) = headers { + for hdr in headers { + headers_map.insert( + HeaderName::from_str(&hdr[0])?, + HeaderValue::from_str(&hdr[1])?, + ); + } + } - let (response, response_uri, redirected) = self - .send_req(request_builder.body(HttpBody::new(body))?, request_redirect) - .await?; + let (response, response_uri, redirected) = self + .send_req(request_builder.body(HttpBody::new(body))?, request_redirect) + .await?; - let response_headers: Array = response - .headers() - .iter() - .filter_map(|val| { - Some(Array::of2( - &val.0.as_str().into(), - &val.1.to_str().ok()?.into(), - )) - }) - .collect(); - let response_headers = Object::from_entries(&response_headers) - .map_err(|_| EpoxyError::ResponseHeadersFromEntriesFailed)?; + let response_headers: Array = response + .headers() + .iter() + .filter_map(|val| { + Some(Array::of2( + &val.0.as_str().into(), + &val.1.to_str().ok()?.into(), + )) + }) + .collect(); + let response_headers = Object::from_entries(&response_headers) + .map_err(|_| EpoxyError::ResponseHeadersFromEntriesFailed)?; - let response_headers_raw = response.headers().clone(); + let response_headers_raw = response.headers().clone(); - let mut response_builder = ResponseInit::new(); - response_builder - .headers(&response_headers) - .status(response.status().as_u16()) - .status_text(response.status().canonical_reason().unwrap_or_default()); + let mut response_builder = ResponseInit::new(); + response_builder + .headers(&response_headers) + .status(response.status().as_u16()) + .status_text(response.status().canonical_reason().unwrap_or_default()); - cfg_if! { - if #[cfg(feature = "full")] { + cfg_if! { + if #[cfg(feature = "full")] { let response_stream = if !is_null_body(response.status().as_u16()) { let compression = match response .headers() @@ -532,59 +533,59 @@ impl EpoxyClient { } else { None }; - } else { - let response_stream = if !is_null_body(response.status().as_u16()) { - let response_body = IncomingBody::new(response.into_body()).into_async_read(); - Some(ReadableStream::from_stream(asyncread_to_readablestream_stream(response_body)).into_raw()) - } else { - None - }; - } - } + } else { + let response_stream = if !is_null_body(response.status().as_u16()) { + let response_body = IncomingBody::new(response.into_body()).into_async_read(); + Some(ReadableStream::from_stream(asyncread_to_readablestream_stream(response_body)).into_raw()) + } else { + None + }; + } + } - let resp = web_sys::Response::new_with_opt_readable_stream_and_init( - response_stream.as_ref(), - &response_builder, - ) - .map_err(|_| EpoxyError::ResponseNewFailed)?; + let resp = web_sys::Response::new_with_opt_readable_stream_and_init( + response_stream.as_ref(), + &response_builder, + ) + .map_err(|_| EpoxyError::ResponseNewFailed)?; - Object::define_property( - &resp, - &"url".into(), - &utils::define_property_obj(response_uri.to_string().into(), false) - .map_err(|_| EpoxyError::DefinePropertyObjFailed)?, - ); + Object::define_property( + &resp, + &"url".into(), + &utils::define_property_obj(response_uri.to_string().into(), false) + .map_err(|_| EpoxyError::DefinePropertyObjFailed)?, + ); - Object::define_property( - &resp, - &"redirected".into(), - &utils::define_property_obj(redirected.into(), false) - .map_err(|_| EpoxyError::DefinePropertyObjFailed)?, - ); + Object::define_property( + &resp, + &"redirected".into(), + &utils::define_property_obj(redirected.into(), false) + .map_err(|_| EpoxyError::DefinePropertyObjFailed)?, + ); - let raw_headers = Object::new(); - for (k, v) in response_headers_raw.iter() { - let k: JsValue = k.to_string().into(); - let v: JsValue = v.to_str()?.to_string().into(); - if let Ok(jv) = Reflect::get(&raw_headers, &k) { - if jv.is_array() { - let arr = Array::from(&jv); - arr.push(&v); - object_set(&raw_headers, &k, &arr)?; - } else if jv.is_truthy() { - object_set(&raw_headers, &k, &Array::of2(&jv, &v))?; - } else { - object_set(&raw_headers, &k, &v)?; - } - } - } - Object::define_property( - &resp, - &"rawHeaders".into(), - &utils::define_property_obj(raw_headers.into(), false) - .map_err(|_| EpoxyError::DefinePropertyObjFailed)?, - ); + let raw_headers = Object::new(); + for (k, v) in response_headers_raw.iter() { + let k: JsValue = k.to_string().into(); + let v: JsValue = v.to_str()?.to_string().into(); + if let Ok(jv) = Reflect::get(&raw_headers, &k) { + if jv.is_array() { + let arr = Array::from(&jv); + arr.push(&v); + object_set(&raw_headers, &k, &arr)?; + } else if jv.is_truthy() { + object_set(&raw_headers, &k, &Array::of2(&jv, &v))?; + } else { + object_set(&raw_headers, &k, &v)?; + } + } + } + Object::define_property( + &resp, + &"rawHeaders".into(), + &utils::define_property_obj(raw_headers.into(), false) + .map_err(|_| EpoxyError::DefinePropertyObjFailed)?, + ); - Ok(resp) - } + Ok(resp) + } } diff --git a/client/src/stream_provider.rs b/client/src/stream_provider.rs index 18487e1..1e91686 100644 --- a/client/src/stream_provider.rs +++ b/client/src/stream_provider.rs @@ -5,7 +5,9 @@ use futures_rustls::{ TlsConnector, TlsStream, }; use futures_util::{ - future::Either, lock::{Mutex, MutexGuard}, AsyncRead, AsyncWrite, Future + future::Either, + lock::{Mutex, MutexGuard}, + AsyncRead, AsyncWrite, Future, }; use hyper_util_wasm::client::legacy::connect::{ConnectSvc, Connected, Connection}; use js_sys::{Array, Reflect, Uint8Array}; @@ -81,7 +83,7 @@ impl StreamProvider { mut locked: MutexGuard<'_, Option>, ) -> Result<(), EpoxyError> { let extensions_vec: Vec> = - vec![Box::new(UdpProtocolExtensionBuilder())]; + vec![Box::new(UdpProtocolExtensionBuilder)]; let extensions = if self.wisp_v2 { Some(extensions_vec.as_slice()) } else { diff --git a/client/src/tokioio.rs b/client/src/tokioio.rs index 62cdf5c..a5797ee 100644 --- a/client/src/tokioio.rs +++ b/client/src/tokioio.rs @@ -2,168 +2,168 @@ //! hyper_util::rt::tokio::TokioIo use std::{ - pin::Pin, - task::{Context, Poll}, + pin::Pin, + task::{Context, Poll}, }; use pin_project_lite::pin_project; pin_project! { - /// A wrapping implementing hyper IO traits for a type that - /// implements Tokio's IO traits. - #[derive(Debug)] - pub struct TokioIo { - #[pin] - inner: T, - } + /// A wrapping implementing hyper IO traits for a type that + /// implements Tokio's IO traits. + #[derive(Debug)] + pub struct TokioIo { + #[pin] + inner: T, + } } impl TokioIo { - /// Wrap a type implementing Tokio's IO traits. - pub fn new(inner: T) -> Self { - Self { inner } - } + /// Wrap a type implementing Tokio's IO traits. + pub fn new(inner: T) -> Self { + Self { inner } + } - /// Borrow the inner type. - pub fn inner(&self) -> &T { - &self.inner - } + /// Borrow the inner type. + pub fn inner(&self) -> &T { + &self.inner + } - /// Mut borrow the inner type. - pub fn inner_mut(&mut self) -> &mut T { - &mut self.inner - } + /// Mut borrow the inner type. + pub fn inner_mut(&mut self) -> &mut T { + &mut self.inner + } - /// Consume this wrapper and get the inner type. - pub fn into_inner(self) -> T { - self.inner - } + /// Consume this wrapper and get the inner type. + pub fn into_inner(self) -> T { + self.inner + } } impl hyper::rt::Read for TokioIo where - T: tokio::io::AsyncRead, + T: tokio::io::AsyncRead, { - fn poll_read( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - mut buf: hyper::rt::ReadBufCursor<'_>, - ) -> Poll> { - let n = unsafe { - let mut tbuf = tokio::io::ReadBuf::uninit(buf.as_mut()); - match tokio::io::AsyncRead::poll_read(self.project().inner, cx, &mut tbuf) { - Poll::Ready(Ok(())) => tbuf.filled().len(), - other => return other, - } - }; + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + mut buf: hyper::rt::ReadBufCursor<'_>, + ) -> Poll> { + let n = unsafe { + let mut tbuf = tokio::io::ReadBuf::uninit(buf.as_mut()); + match tokio::io::AsyncRead::poll_read(self.project().inner, cx, &mut tbuf) { + Poll::Ready(Ok(())) => tbuf.filled().len(), + other => return other, + } + }; - unsafe { - buf.advance(n); - } - Poll::Ready(Ok(())) - } + unsafe { + buf.advance(n); + } + Poll::Ready(Ok(())) + } } impl hyper::rt::Write for TokioIo where - T: tokio::io::AsyncWrite, + T: tokio::io::AsyncWrite, { - fn poll_write( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - tokio::io::AsyncWrite::poll_write(self.project().inner, cx, buf) - } + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + tokio::io::AsyncWrite::poll_write(self.project().inner, cx, buf) + } - fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - tokio::io::AsyncWrite::poll_flush(self.project().inner, cx) - } + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + tokio::io::AsyncWrite::poll_flush(self.project().inner, cx) + } - fn poll_shutdown( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll> { - tokio::io::AsyncWrite::poll_shutdown(self.project().inner, cx) - } + fn poll_shutdown( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + tokio::io::AsyncWrite::poll_shutdown(self.project().inner, cx) + } - fn is_write_vectored(&self) -> bool { - tokio::io::AsyncWrite::is_write_vectored(&self.inner) - } + fn is_write_vectored(&self) -> bool { + tokio::io::AsyncWrite::is_write_vectored(&self.inner) + } - fn poll_write_vectored( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - bufs: &[std::io::IoSlice<'_>], - ) -> Poll> { - tokio::io::AsyncWrite::poll_write_vectored(self.project().inner, cx, bufs) - } + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[std::io::IoSlice<'_>], + ) -> Poll> { + tokio::io::AsyncWrite::poll_write_vectored(self.project().inner, cx, bufs) + } } impl tokio::io::AsyncRead for TokioIo where - T: hyper::rt::Read, + T: hyper::rt::Read, { - fn poll_read( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - tbuf: &mut tokio::io::ReadBuf<'_>, - ) -> Poll> { - //let init = tbuf.initialized().len(); - let filled = tbuf.filled().len(); - let sub_filled = unsafe { - let mut buf = hyper::rt::ReadBuf::uninit(tbuf.unfilled_mut()); + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + tbuf: &mut tokio::io::ReadBuf<'_>, + ) -> Poll> { + //let init = tbuf.initialized().len(); + let filled = tbuf.filled().len(); + let sub_filled = unsafe { + let mut buf = hyper::rt::ReadBuf::uninit(tbuf.unfilled_mut()); - match hyper::rt::Read::poll_read(self.project().inner, cx, buf.unfilled()) { - Poll::Ready(Ok(())) => buf.filled().len(), - other => return other, - } - }; + match hyper::rt::Read::poll_read(self.project().inner, cx, buf.unfilled()) { + Poll::Ready(Ok(())) => buf.filled().len(), + other => return other, + } + }; - let n_filled = filled + sub_filled; - // At least sub_filled bytes had to have been initialized. - let n_init = sub_filled; - unsafe { - tbuf.assume_init(n_init); - tbuf.set_filled(n_filled); - } + let n_filled = filled + sub_filled; + // At least sub_filled bytes had to have been initialized. + let n_init = sub_filled; + unsafe { + tbuf.assume_init(n_init); + tbuf.set_filled(n_filled); + } - Poll::Ready(Ok(())) - } + Poll::Ready(Ok(())) + } } impl tokio::io::AsyncWrite for TokioIo where - T: hyper::rt::Write, + T: hyper::rt::Write, { - fn poll_write( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - hyper::rt::Write::poll_write(self.project().inner, cx, buf) - } + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + hyper::rt::Write::poll_write(self.project().inner, cx, buf) + } - fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - hyper::rt::Write::poll_flush(self.project().inner, cx) - } + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + hyper::rt::Write::poll_flush(self.project().inner, cx) + } - fn poll_shutdown( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll> { - hyper::rt::Write::poll_shutdown(self.project().inner, cx) - } + fn poll_shutdown( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + hyper::rt::Write::poll_shutdown(self.project().inner, cx) + } - fn is_write_vectored(&self) -> bool { - hyper::rt::Write::is_write_vectored(&self.inner) - } + fn is_write_vectored(&self) -> bool { + hyper::rt::Write::is_write_vectored(&self.inner) + } - fn poll_write_vectored( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - bufs: &[std::io::IoSlice<'_>], - ) -> Poll> { - hyper::rt::Write::poll_write_vectored(self.project().inner, cx, bufs) - } + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[std::io::IoSlice<'_>], + ) -> Poll> { + hyper::rt::Write::poll_write_vectored(self.project().inner, cx, bufs) + } } diff --git a/client/src/websocket.rs b/client/src/websocket.rs index 296ee55..9f9bd0d 100644 --- a/client/src/websocket.rs +++ b/client/src/websocket.rs @@ -3,73 +3,78 @@ use std::{str::from_utf8, sync::Arc}; use base64::{prelude::BASE64_STANDARD, Engine}; use bytes::Bytes; use fastwebsockets::{ - FragmentCollectorRead, Frame, OpCode, Payload, Role, WebSocket, WebSocketWrite, + FragmentCollectorRead, Frame, OpCode, Payload, Role, WebSocket, WebSocketWrite, }; use futures_util::lock::Mutex; use getrandom::getrandom; use http::{ - header::{ - CONNECTION, HOST, SEC_WEBSOCKET_KEY, SEC_WEBSOCKET_PROTOCOL, SEC_WEBSOCKET_VERSION, UPGRADE, USER_AGENT, - }, - Method, Request, Response, StatusCode, Uri, + header::{ + CONNECTION, HOST, SEC_WEBSOCKET_KEY, SEC_WEBSOCKET_PROTOCOL, SEC_WEBSOCKET_VERSION, + UPGRADE, USER_AGENT, + }, + Method, Request, Response, StatusCode, Uri, }; use hyper::{ - body::Incoming, - upgrade::{self, Upgraded}, + body::Incoming, + upgrade::{self, Upgraded}, }; use js_sys::{ArrayBuffer, Function, Object, Uint8Array}; use tokio::io::WriteHalf; use wasm_bindgen::{prelude::*, JsError, JsValue}; use wasm_bindgen_futures::spawn_local; -use crate::{tokioio::TokioIo, utils::entries_of_object, EpoxyClient, EpoxyError, EpoxyHandlers, HttpBody}; +use crate::{ + tokioio::TokioIo, utils::entries_of_object, EpoxyClient, EpoxyError, EpoxyHandlers, HttpBody, +}; #[wasm_bindgen] pub struct EpoxyWebSocket { - tx: Arc>>>>, - onerror: Function, + tx: Arc>>>>, + onerror: Function, } #[wasm_bindgen] impl EpoxyWebSocket { - pub(crate) async fn connect( - client: &EpoxyClient, - handlers: EpoxyHandlers, - url: String, - protocols: Vec, + pub(crate) async fn connect( + client: &EpoxyClient, + handlers: EpoxyHandlers, + url: String, + protocols: Vec, headers: JsValue, user_agent: &str, - ) -> Result { - let EpoxyHandlers { - onopen, - onclose, - onerror, - onmessage, - } = handlers; - let onerror_cloned = onerror.clone(); - let ret: Result = async move { - let url: Uri = url.try_into()?; - let host = url.host().ok_or(EpoxyError::NoUrlHost)?; + ) -> Result { + let EpoxyHandlers { + onopen, + onclose, + onerror, + onmessage, + } = handlers; + let onerror_cloned = onerror.clone(); + let ret: Result = async move { + let url: Uri = url.try_into()?; + let host = url.host().ok_or(EpoxyError::NoUrlHost)?; - let mut rand = [0u8; 16]; - getrandom(&mut rand)?; - let key = BASE64_STANDARD.encode(rand); + let mut rand = [0u8; 16]; + getrandom(&mut rand)?; + let key = BASE64_STANDARD.encode(rand); - let mut request = Request::builder() - .method(Method::GET) - .uri(url.clone()) - .header(HOST, host) - .header(CONNECTION, "upgrade") - .header(UPGRADE, "websocket") - .header(SEC_WEBSOCKET_KEY, key) - .header(SEC_WEBSOCKET_VERSION, "13") + let mut request = Request::builder() + .method(Method::GET) + .uri(url.clone()) + .header(HOST, host) + .header(CONNECTION, "upgrade") + .header(UPGRADE, "websocket") + .header(SEC_WEBSOCKET_KEY, key) + .header(SEC_WEBSOCKET_VERSION, "13") .header(USER_AGENT, user_agent); - if !protocols.is_empty() { - request = request.header(SEC_WEBSOCKET_PROTOCOL, protocols.join(",")); - } + if !protocols.is_empty() { + request = request.header(SEC_WEBSOCKET_PROTOCOL, protocols.join(",")); + } - if web_sys::Headers::instanceof(&headers) && let Ok(entries) = Object::from_entries(&headers) { + if web_sys::Headers::instanceof(&headers) + && let Ok(entries) = Object::from_entries(&headers) + { for header in entries_of_object(&entries) { request = request.header(&header[0], &header[1]); } @@ -79,153 +84,153 @@ impl EpoxyWebSocket { } } - let request = request.body(HttpBody::new(Bytes::new()))?; + let request = request.body(HttpBody::new(Bytes::new()))?; - let mut response = client.client.request(request).await?; - verify(&response)?; + let mut response = client.client.request(request).await?; + verify(&response)?; - let websocket = WebSocket::after_handshake( - TokioIo::new(upgrade::on(&mut response).await?), - Role::Client, - ); + let websocket = WebSocket::after_handshake( + TokioIo::new(upgrade::on(&mut response).await?), + Role::Client, + ); - let (rx, tx) = websocket.split(tokio::io::split); + let (rx, tx) = websocket.split(tokio::io::split); - let mut rx = FragmentCollectorRead::new(rx); - let tx = Arc::new(Mutex::new(tx)); + let mut rx = FragmentCollectorRead::new(rx); + let tx = Arc::new(Mutex::new(tx)); - let read_tx = tx.clone(); - let onerror_cloned = onerror.clone(); + let read_tx = tx.clone(); + let onerror_cloned = onerror.clone(); - spawn_local(async move { - loop { - match rx - .read_frame(&mut |arg| async { - read_tx.lock().await.write_frame(arg).await - }) - .await - { - Ok(frame) => match frame.opcode { - OpCode::Text => { - if let Ok(str) = from_utf8(&frame.payload) { - let _ = onmessage.call1(&JsValue::null(), &str.into()); - } - } - OpCode::Binary => { - let _ = onmessage.call1( - &JsValue::null(), - &Uint8Array::from(frame.payload.to_vec().as_slice()).into(), - ); - } - OpCode::Close => { - break; - } - // ping/pong/continue - _ => {} - }, - Err(err) => { - let _ = onerror.call1(&JsValue::null(), &JsError::from(err).into()); - break; - } - } - } - let _ = onclose.call0(&JsValue::null()); - }); + spawn_local(async move { + loop { + match rx + .read_frame(&mut |arg| async { + read_tx.lock().await.write_frame(arg).await + }) + .await + { + Ok(frame) => match frame.opcode { + OpCode::Text => { + if let Ok(str) = from_utf8(&frame.payload) { + let _ = onmessage.call1(&JsValue::null(), &str.into()); + } + } + OpCode::Binary => { + let _ = onmessage.call1( + &JsValue::null(), + &Uint8Array::from(frame.payload.to_vec().as_slice()).into(), + ); + } + OpCode::Close => { + break; + } + // ping/pong/continue + _ => {} + }, + Err(err) => { + let _ = onerror.call1(&JsValue::null(), &JsError::from(err).into()); + break; + } + } + } + let _ = onclose.call0(&JsValue::null()); + }); - let _ = onopen.call0(&JsValue::null()); + let _ = onopen.call0(&JsValue::null()); - Ok(Self { - tx, - onerror: onerror_cloned, - }) - } - .await; + Ok(Self { + tx, + onerror: onerror_cloned, + }) + } + .await; - match ret { - Ok(ok) => Ok(ok), - Err(err) => { - let _ = onerror_cloned.call1(&JsValue::null(), &err.to_string().into()); - Err(err) - } - } - } + match ret { + Ok(ok) => Ok(ok), + Err(err) => { + let _ = onerror_cloned.call1(&JsValue::null(), &err.to_string().into()); + Err(err) + } + } + } - pub async fn send(&self, payload: JsValue) -> Result<(), EpoxyError> { - let ret = if let Some(str) = payload.as_string() { - self.tx - .lock() - .await - .write_frame(Frame::text(Payload::Owned(str.as_bytes().to_vec()))) - .await - .map_err(EpoxyError::from) - } else if let Ok(binary) = payload.dyn_into::() { - self.tx - .lock() - .await - .write_frame(Frame::binary(Payload::Owned( - Uint8Array::new(&binary).to_vec(), - ))) - .await - .map_err(EpoxyError::from) - } else { - Err(EpoxyError::WsInvalidPayload) - }; + pub async fn send(&self, payload: JsValue) -> Result<(), EpoxyError> { + let ret = if let Some(str) = payload.as_string() { + self.tx + .lock() + .await + .write_frame(Frame::text(Payload::Owned(str.as_bytes().to_vec()))) + .await + .map_err(EpoxyError::from) + } else if let Ok(binary) = payload.dyn_into::() { + self.tx + .lock() + .await + .write_frame(Frame::binary(Payload::Owned( + Uint8Array::new(&binary).to_vec(), + ))) + .await + .map_err(EpoxyError::from) + } else { + Err(EpoxyError::WsInvalidPayload) + }; - match ret { - Ok(ok) => Ok(ok), - Err(err) => { - let _ = self - .onerror - .call1(&JsValue::null(), &err.to_string().into()); - Err(err) - } - } - } + match ret { + Ok(ok) => Ok(ok), + Err(err) => { + let _ = self + .onerror + .call1(&JsValue::null(), &err.to_string().into()); + Err(err) + } + } + } - pub async fn close(&self, code: u16, reason: String) -> Result<(), EpoxyError> { - let ret = self - .tx - .lock() - .await - .write_frame(Frame::close(code, reason.as_bytes())) - .await; - match ret { - Ok(ok) => Ok(ok), - Err(err) => { - let _ = self - .onerror - .call1(&JsValue::null(), &err.to_string().into()); - Err(err.into()) - } - } - } + pub async fn close(&self, code: u16, reason: String) -> Result<(), EpoxyError> { + let ret = self + .tx + .lock() + .await + .write_frame(Frame::close(code, reason.as_bytes())) + .await; + match ret { + Ok(ok) => Ok(ok), + Err(err) => { + let _ = self + .onerror + .call1(&JsValue::null(), &err.to_string().into()); + Err(err.into()) + } + } + } } // https://github.com/snapview/tungstenite-rs/blob/314feea3055a93e585882fb769854a912a7e6dae/src/handshake/client.rs#L189 fn verify(response: &Response) -> Result<(), EpoxyError> { - if response.status() != StatusCode::SWITCHING_PROTOCOLS { - return Err(EpoxyError::WsInvalidStatusCode); - } + if response.status() != StatusCode::SWITCHING_PROTOCOLS { + return Err(EpoxyError::WsInvalidStatusCode); + } - let headers = response.headers(); + let headers = response.headers(); - if !headers - .get(UPGRADE) - .and_then(|h| h.to_str().ok()) - .map(|h| h.eq_ignore_ascii_case("websocket")) - .unwrap_or(false) - { - return Err(EpoxyError::WsInvalidUpgradeHeader); - } + if !headers + .get(UPGRADE) + .and_then(|h| h.to_str().ok()) + .map(|h| h.eq_ignore_ascii_case("websocket")) + .unwrap_or(false) + { + return Err(EpoxyError::WsInvalidUpgradeHeader); + } - if !headers - .get(CONNECTION) - .and_then(|h| h.to_str().ok()) - .map(|h| h.eq_ignore_ascii_case("Upgrade")) - .unwrap_or(false) - { - return Err(EpoxyError::WsInvalidConnectionHeader); - } + if !headers + .get(CONNECTION) + .and_then(|h| h.to_str().ok()) + .map(|h| h.eq_ignore_ascii_case("Upgrade")) + .unwrap_or(false) + { + return Err(EpoxyError::WsInvalidConnectionHeader); + } - Ok(()) + Ok(()) } diff --git a/client/src/ws_wrapper.rs b/client/src/ws_wrapper.rs index 25d6be6..52aeb0f 100644 --- a/client/src/ws_wrapper.rs +++ b/client/src/ws_wrapper.rs @@ -1,6 +1,6 @@ use std::sync::{ - atomic::{AtomicBool, Ordering}, - Arc, + atomic::{AtomicBool, Ordering}, + Arc, }; use async_trait::async_trait; @@ -13,214 +13,219 @@ use send_wrapper::SendWrapper; use wasm_bindgen::{closure::Closure, JsCast}; use web_sys::{BinaryType, MessageEvent, WebSocket}; use wisp_mux::{ - ws::{Frame, LockedWebSocketWrite, WebSocketRead, WebSocketWrite}, - WispError, + ws::{Frame, LockedWebSocketWrite, Payload, WebSocketRead, WebSocketWrite}, + WispError, }; use crate::EpoxyError; #[derive(Debug)] pub enum WebSocketError { - Unknown, - SendFailed, - CloseFailed, + Unknown, + SendFailed, + CloseFailed, } impl std::fmt::Display for WebSocketError { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> { - use WebSocketError::*; - match self { - Unknown => write!(f, "Unknown error"), - SendFailed => write!(f, "Send failed"), - CloseFailed => write!(f, "Close failed"), - } - } + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> { + use WebSocketError::*; + match self { + Unknown => write!(f, "Unknown error"), + SendFailed => write!(f, "Send failed"), + CloseFailed => write!(f, "Close failed"), + } + } } impl std::error::Error for WebSocketError {} impl From for WispError { - fn from(err: WebSocketError) -> Self { - Self::WsImplError(Box::new(err)) - } + fn from(err: WebSocketError) -> Self { + Self::WsImplError(Box::new(err)) + } } pub enum WebSocketMessage { - Closed, - Error, - Message(Vec), + Closed, + Error, + Message(Vec), } pub struct WebSocketWrapper { - inner: SendWrapper, - open_event: Arc, - error_event: Arc, - close_event: Arc, - closed: Arc, + inner: SendWrapper, + open_event: Arc, + error_event: Arc, + close_event: Arc, + closed: Arc, - // used to retain the closures - #[allow(dead_code)] - onopen: SendWrapper>, - #[allow(dead_code)] - onclose: SendWrapper>, - #[allow(dead_code)] - onerror: SendWrapper>, - #[allow(dead_code)] - onmessage: SendWrapper>, + // used to retain the closures + #[allow(dead_code)] + onopen: SendWrapper>, + #[allow(dead_code)] + onclose: SendWrapper>, + #[allow(dead_code)] + onerror: SendWrapper>, + #[allow(dead_code)] + onmessage: SendWrapper>, } pub struct WebSocketReader { - read_rx: Receiver, - closed: Arc, - close_event: Arc, + read_rx: Receiver, + closed: Arc, + close_event: Arc, } #[async_trait] impl WebSocketRead for WebSocketReader { - async fn wisp_read_frame(&mut self, _: &LockedWebSocketWrite) -> Result { - use WebSocketMessage::*; - if self.closed.load(Ordering::Acquire) { - return Err(WispError::WsImplSocketClosed); - } - let res = futures_util::select! { - data = self.read_rx.recv_async() => data.ok(), - _ = self.close_event.listen().fuse() => Some(Closed), - }; - match res.ok_or(WispError::WsImplSocketClosed)? { - Message(bin) => Ok(Frame::binary(BytesMut::from(bin.as_slice()))), - Error => Err(WebSocketError::Unknown.into()), - Closed => Err(WispError::WsImplSocketClosed), - } - } + async fn wisp_read_frame( + &mut self, + _: &LockedWebSocketWrite, + ) -> Result, WispError> { + use WebSocketMessage::*; + if self.closed.load(Ordering::Acquire) { + return Err(WispError::WsImplSocketClosed); + } + let res = futures_util::select! { + data = self.read_rx.recv_async() => data.ok(), + _ = self.close_event.listen().fuse() => Some(Closed), + }; + match res.ok_or(WispError::WsImplSocketClosed)? { + Message(bin) => Ok(Frame::binary(Payload::Bytes(BytesMut::from( + bin.as_slice(), + )))), + Error => Err(WebSocketError::Unknown.into()), + Closed => Err(WispError::WsImplSocketClosed), + } + } } impl WebSocketWrapper { - pub fn connect(url: &str, protocols: &[String]) -> Result<(Self, WebSocketReader), EpoxyError> { - let (read_tx, read_rx) = flume::unbounded(); - let closed = Arc::new(AtomicBool::new(false)); + pub fn connect(url: &str, protocols: &[String]) -> Result<(Self, WebSocketReader), EpoxyError> { + let (read_tx, read_rx) = flume::unbounded(); + let closed = Arc::new(AtomicBool::new(false)); - let open_event = Arc::new(Event::new()); - let close_event = Arc::new(Event::new()); - let error_event = Arc::new(Event::new()); + let open_event = Arc::new(Event::new()); + let close_event = Arc::new(Event::new()); + let error_event = Arc::new(Event::new()); - let onopen_event = open_event.clone(); - let onopen = Closure::wrap( - Box::new(move || while onopen_event.notify(usize::MAX) == 0 {}) as Box, - ); + let onopen_event = open_event.clone(); + let onopen = Closure::wrap( + Box::new(move || while onopen_event.notify(usize::MAX) == 0 {}) as Box, + ); - let onmessage_tx = read_tx.clone(); - let onmessage = Closure::wrap(Box::new(move |evt: MessageEvent| { - if let Ok(arr) = evt.data().dyn_into::() { - let _ = - onmessage_tx.send(WebSocketMessage::Message(Uint8Array::new(&arr).to_vec())); - } - }) as Box); + let onmessage_tx = read_tx.clone(); + let onmessage = Closure::wrap(Box::new(move |evt: MessageEvent| { + if let Ok(arr) = evt.data().dyn_into::() { + let _ = + onmessage_tx.send(WebSocketMessage::Message(Uint8Array::new(&arr).to_vec())); + } + }) as Box); - let onclose_closed = closed.clone(); - let onclose_event = close_event.clone(); - let onclose = Closure::wrap(Box::new(move || { - onclose_closed.store(true, Ordering::Release); - onclose_event.notify(usize::MAX); - }) as Box); + let onclose_closed = closed.clone(); + let onclose_event = close_event.clone(); + let onclose = Closure::wrap(Box::new(move || { + onclose_closed.store(true, Ordering::Release); + onclose_event.notify(usize::MAX); + }) as Box); - let onerror_tx = read_tx.clone(); - let onerror_closed = closed.clone(); - let onerror_close = close_event.clone(); - let onerror_event = error_event.clone(); - let onerror = Closure::wrap(Box::new(move || { - let _ = onerror_tx.send(WebSocketMessage::Error); - onerror_closed.store(true, Ordering::Release); - onerror_close.notify(usize::MAX); - onerror_event.notify(usize::MAX); - }) as Box); + let onerror_tx = read_tx.clone(); + let onerror_closed = closed.clone(); + let onerror_close = close_event.clone(); + let onerror_event = error_event.clone(); + let onerror = Closure::wrap(Box::new(move || { + let _ = onerror_tx.send(WebSocketMessage::Error); + onerror_closed.store(true, Ordering::Release); + onerror_close.notify(usize::MAX); + onerror_event.notify(usize::MAX); + }) as Box); - let ws = if protocols.is_empty() { - WebSocket::new(url) - } else { - WebSocket::new_with_str_sequence( - url, - &protocols - .iter() - .fold(Array::new(), |acc, x| { - acc.push(&x.into()); - acc - }) - .into(), - ) - } - .map_err(|_| EpoxyError::WebSocketConnectFailed)?; - ws.set_binary_type(BinaryType::Arraybuffer); - ws.set_onmessage(Some(onmessage.as_ref().unchecked_ref())); - ws.set_onopen(Some(onopen.as_ref().unchecked_ref())); - ws.set_onclose(Some(onclose.as_ref().unchecked_ref())); - ws.set_onerror(Some(onerror.as_ref().unchecked_ref())); + let ws = if protocols.is_empty() { + WebSocket::new(url) + } else { + WebSocket::new_with_str_sequence( + url, + &protocols + .iter() + .fold(Array::new(), |acc, x| { + acc.push(&x.into()); + acc + }) + .into(), + ) + } + .map_err(|_| EpoxyError::WebSocketConnectFailed)?; + ws.set_binary_type(BinaryType::Arraybuffer); + ws.set_onmessage(Some(onmessage.as_ref().unchecked_ref())); + ws.set_onopen(Some(onopen.as_ref().unchecked_ref())); + ws.set_onclose(Some(onclose.as_ref().unchecked_ref())); + ws.set_onerror(Some(onerror.as_ref().unchecked_ref())); - Ok(( - Self { - inner: SendWrapper::new(ws), - open_event, - error_event, - close_event: close_event.clone(), - closed: closed.clone(), - onopen: SendWrapper::new(onopen), - onclose: SendWrapper::new(onclose), - onerror: SendWrapper::new(onerror), - onmessage: SendWrapper::new(onmessage), - }, - WebSocketReader { - read_rx, - closed, - close_event, - }, - )) - } + Ok(( + Self { + inner: SendWrapper::new(ws), + open_event, + error_event, + close_event: close_event.clone(), + closed: closed.clone(), + onopen: SendWrapper::new(onopen), + onclose: SendWrapper::new(onclose), + onerror: SendWrapper::new(onerror), + onmessage: SendWrapper::new(onmessage), + }, + WebSocketReader { + read_rx, + closed, + close_event, + }, + )) + } - pub async fn wait_for_open(&self) -> bool { - if self.closed.load(Ordering::Acquire) { - return false; - } - futures_util::select! { - _ = self.open_event.listen().fuse() => true, - _ = self.error_event.listen().fuse() => false, - } - } + pub async fn wait_for_open(&self) -> bool { + if self.closed.load(Ordering::Acquire) { + return false; + } + futures_util::select! { + _ = self.open_event.listen().fuse() => true, + _ = self.error_event.listen().fuse() => false, + } + } } #[async_trait] impl WebSocketWrite for WebSocketWrapper { - async fn wisp_write_frame(&mut self, frame: Frame) -> Result<(), WispError> { - use wisp_mux::ws::OpCode::*; - if self.closed.load(Ordering::Acquire) { - return Err(WispError::WsImplSocketClosed); - } - match frame.opcode { - Binary | Text => self - .inner - .send_with_u8_array(&frame.payload) - .map_err(|_| WebSocketError::SendFailed.into()), - Close => { - let _ = self.inner.close(); - Ok(()) - } - _ => Err(WispError::WsImplNotSupported), - } - } + async fn wisp_write_frame(&mut self, frame: Frame<'_>) -> Result<(), WispError> { + use wisp_mux::ws::OpCode::*; + if self.closed.load(Ordering::Acquire) { + return Err(WispError::WsImplSocketClosed); + } + match frame.opcode { + Binary | Text => self + .inner + .send_with_u8_array(&frame.payload) + .map_err(|_| WebSocketError::SendFailed.into()), + Close => { + let _ = self.inner.close(); + Ok(()) + } + _ => Err(WispError::WsImplNotSupported), + } + } - async fn wisp_close(&mut self) -> Result<(), WispError> { - self.inner - .close() - .map_err(|_| WebSocketError::CloseFailed.into()) - } + async fn wisp_close(&mut self) -> Result<(), WispError> { + self.inner + .close() + .map_err(|_| WebSocketError::CloseFailed.into()) + } } impl Drop for WebSocketWrapper { - fn drop(&mut self) { - self.inner.set_onopen(None); - self.inner.set_onclose(None); - self.inner.set_onerror(None); - self.inner.set_onmessage(None); - self.closed.store(true, Ordering::Release); - self.close_event.notify(usize::MAX); - let _ = self.inner.close(); - } + fn drop(&mut self) { + self.inner.set_onopen(None); + self.inner.set_onclose(None); + self.inner.set_onerror(None); + self.inner.set_onmessage(None); + self.closed.store(true, Ordering::Release); + self.close_event.notify(usize::MAX); + let _ = self.inner.close(); + } } diff --git a/server/src/main.rs b/server/src/main.rs index d44da1c..fd6ef4f 100644 --- a/server/src/main.rs +++ b/server/src/main.rs @@ -5,36 +5,36 @@ use bytes::Bytes; use cfg_if::cfg_if; use clap::Parser; use fastwebsockets::{ - upgrade::{self, UpgradeFut}, - CloseCode, FragmentCollector, FragmentCollectorRead, Frame, OpCode, Payload, WebSocketError, + upgrade::{self, UpgradeFut}, + CloseCode, FragmentCollector, FragmentCollectorRead, Frame, OpCode, Payload, WebSocketError, }; use futures_util::{SinkExt, StreamExt, TryFutureExt}; use hyper::{ - body::Incoming, server::conn::http1, service::service_fn, upgrade::Parts, Request, Response, - StatusCode, + body::Incoming, server::conn::http1, service::service_fn, upgrade::Parts, Request, Response, + StatusCode, }; use hyper_util::rt::TokioIo; #[cfg(unix)] use tokio::net::{UnixListener, UnixStream}; use tokio::{ - io::{copy, AsyncBufReadExt, AsyncWriteExt}, - net::{lookup_host, TcpListener, TcpStream, UdpSocket}, - select, + io::{copy, AsyncBufReadExt, AsyncWriteExt}, + net::{lookup_host, TcpListener, TcpStream, UdpSocket}, + select, }; #[cfg(unix)] use tokio_util::either::Either; use tokio_util::{ - codec::{BytesCodec, Framed}, - compat::{FuturesAsyncReadCompatExt, FuturesAsyncWriteCompatExt}, + codec::{BytesCodec, Framed}, + compat::{FuturesAsyncReadCompatExt, FuturesAsyncWriteCompatExt}, }; use wisp_mux::{ - extensions::{ - password::{PasswordProtocolExtension, PasswordProtocolExtensionBuilder}, - udp::UdpProtocolExtensionBuilder, - ProtocolExtensionBuilder, - }, - CloseReason, ConnectPacket, MuxStream, MuxStreamAsyncRW, ServerMux, StreamType, WispError, + extensions::{ + password::{PasswordProtocolExtension, PasswordProtocolExtensionBuilder}, + udp::UdpProtocolExtensionBuilder, + ProtocolExtensionBuilder, + }, + CloseReason, ConnectPacket, MuxStream, MuxStreamAsyncRW, ServerMux, StreamType, WispError, }; type HttpBody = http_body_util::Full; @@ -43,527 +43,521 @@ type HttpBody = http_body_util::Full; #[derive(Parser)] #[command(version = clap::crate_version!())] struct Cli { - /// URL prefix the server should serve on - #[arg(long)] - prefix: Option, - /// Port the server should bind to - #[arg(long, short, default_value = "4000")] - port: String, - /// Host the server should bind to - #[arg(long = "host", short, value_name = "HOST", default_value = "0.0.0.0")] - bind_host: String, - /// Whether the server should listen on a Unix socket located at the value of the host argument - #[arg(long, short)] - unix_socket: bool, - /// Whether the server should block IP addresses that are not globally reachable - /// - /// See https://doc.rust-lang.org/std/net/struct.Ipv4Addr.html#method.is_global for which IP - /// addresses are blocked - #[arg(long, short = 'B')] - block_local: bool, - /// Whether the server should block UDP - /// - /// This does nothing for wsproxy as that is always TCP - #[arg(long)] - block_udp: bool, - /// Whether the server should block ports other than 80 or 443 - #[arg(long)] - block_non_http: bool, - /// Path to a file containing `user:password` separated by newlines. This is plaintext!!! - /// - /// `user` cannot contain `:`. Whitespace will be trimmed. - #[arg(long)] - auth: Option, - /// Use Wisp V1. - #[arg(long)] - wisp_v1: bool, + /// URL prefix the server should serve on + #[arg(long)] + prefix: Option, + /// Port the server should bind to + #[arg(long, short, default_value = "4000")] + port: String, + /// Host the server should bind to + #[arg(long = "host", short, value_name = "HOST", default_value = "0.0.0.0")] + bind_host: String, + /// Whether the server should listen on a Unix socket located at the value of the host argument + #[arg(long, short)] + unix_socket: bool, + /// Whether the server should block IP addresses that are not globally reachable + /// + /// See https://doc.rust-lang.org/std/net/struct.Ipv4Addr.html#method.is_global for which IP + /// addresses are blocked + #[arg(long, short = 'B')] + block_local: bool, + /// Whether the server should block UDP + /// + /// This does nothing for wsproxy as that is always TCP + #[arg(long)] + block_udp: bool, + /// Whether the server should block ports other than 80 or 443 + #[arg(long)] + block_non_http: bool, + /// Path to a file containing `user:password` separated by newlines. This is plaintext!!! + /// + /// `user` cannot contain `:`. Whitespace will be trimmed. + #[arg(long)] + auth: Option, + /// Use Wisp V1. + #[arg(long)] + wisp_v1: bool, } #[derive(Clone)] struct MuxOptions { - pub block_local: bool, - pub block_udp: bool, - pub block_non_http: bool, - pub enforce_auth: bool, - pub auth: Arc>>, - pub wisp_v1: bool, + pub block_local: bool, + pub block_udp: bool, + pub block_non_http: bool, + pub enforce_auth: bool, + pub auth: Arc>>, + pub wisp_v1: bool, } cfg_if! { - if #[cfg(unix)] { - type ListenerStream = Either; - } else { - type ListenerStream = TcpStream; - } + if #[cfg(unix)] { + type ListenerStream = Either; + } else { + type ListenerStream = TcpStream; + } } enum Listener { - Tcp(TcpListener), - #[cfg(unix)] - Unix(UnixListener), + Tcp(TcpListener), + #[cfg(unix)] + Unix(UnixListener), } impl Listener { - pub async fn accept(&self) -> Result<(ListenerStream, String), std::io::Error> { - Ok(match self { - Listener::Tcp(listener) => { - let (stream, addr) = listener.accept().await?; - cfg_if! { - if #[cfg(unix)] { - (Either::Left(stream), addr.to_string()) - } else { - (stream, addr.to_string()) - } - } - } - #[cfg(unix)] - Listener::Unix(listener) => { - let (stream, addr) = listener.accept().await?; - ( - Either::Right(stream), - addr.as_pathname() - .map(|x| x.to_string_lossy().into()) - .unwrap_or("unknown_unix_socket".into()), - ) - } - }) - } + pub async fn accept(&self) -> Result<(ListenerStream, String), std::io::Error> { + Ok(match self { + Listener::Tcp(listener) => { + let (stream, addr) = listener.accept().await?; + cfg_if! { + if #[cfg(unix)] { + (Either::Left(stream), addr.to_string()) + } else { + (stream, addr.to_string()) + } + } + } + #[cfg(unix)] + Listener::Unix(listener) => { + let (stream, addr) = listener.accept().await?; + ( + Either::Right(stream), + addr.as_pathname() + .map(|x| x.to_string_lossy().into()) + .unwrap_or("unknown_unix_socket".into()), + ) + } + }) + } } async fn bind(addr: &str, unix: bool) -> Result { - cfg_if! { - if #[cfg(unix)] { - if unix { - if std::fs::metadata(addr).is_ok() { - println!("attempting to remove old socket {:?}", addr); - std::fs::remove_file(addr)?; - } - return Ok(Listener::Unix(UnixListener::bind(addr)?)); - } - } else { - if unix { - panic!("Unix sockets are only supported on Unix."); - } - } - } + cfg_if! { + if #[cfg(unix)] { + if unix { + if std::fs::metadata(addr).is_ok() { + println!("attempting to remove old socket {:?}", addr); + std::fs::remove_file(addr)?; + } + return Ok(Listener::Unix(UnixListener::bind(addr)?)); + } + } else { + if unix { + panic!("Unix sockets are only supported on Unix."); + } + } + } - Ok(Listener::Tcp(TcpListener::bind(addr).await?)) + Ok(Listener::Tcp(TcpListener::bind(addr).await?)) } #[tokio::main(flavor = "multi_thread")] async fn main() -> Result<(), Error> { - #[cfg(feature = "tokio-console")] - console_subscriber::init(); - let opt = Cli::parse(); - let addr = if opt.unix_socket { - opt.bind_host - } else { - format!("{}:{}", opt.bind_host, opt.port) - }; + #[cfg(feature = "tokio-console")] + console_subscriber::init(); + let opt = Cli::parse(); + let addr = if opt.unix_socket { + opt.bind_host + } else { + format!("{}:{}", opt.bind_host, opt.port) + }; - let socket = bind(&addr, opt.unix_socket).await?; + let socket = bind(&addr, opt.unix_socket).await?; - let prefix = if let Some(prefix) = opt.prefix { - match (prefix.starts_with('/'), prefix.ends_with('/')) { - (true, true) => prefix, - (true, false) => prefix + "/", - (false, true) => "/".to_string() + &prefix, - (false, false) => "/".to_string() + &prefix + "/", - } - } else { - "/".to_string() - }; + let prefix = if let Some(prefix) = opt.prefix { + match (prefix.starts_with('/'), prefix.ends_with('/')) { + (true, true) => prefix, + (true, false) => prefix + "/", + (false, true) => "/".to_string() + &prefix, + (false, false) => "/".to_string() + &prefix + "/", + } + } else { + "/".to_string() + }; - let mut auth = HashMap::new(); - let enforce_auth = opt.auth.is_some(); - if let Some(file) = opt.auth { - let file = std::fs::read_to_string(file)?; - for entry in file.split('\n').filter_map(|x| { - if x.contains(':') { - Some(x.trim()) - } else { - None - } - }) { - let split: Vec<_> = entry.split(':').collect(); - let username = split[0]; - let password = split[1..].join(":"); - println!( - "adding username {:?} password {:?} to allowed auth", - username, password - ); - auth.insert(username.to_string(), password.to_string()); - } - } - let pw_ext = PasswordProtocolExtensionBuilder::new_server(auth); + let mut auth = HashMap::new(); + let enforce_auth = opt.auth.is_some(); + if let Some(file) = opt.auth { + let file = std::fs::read_to_string(file)?; + for entry in file.split('\n').filter_map(|x| { + if x.contains(':') { + Some(x.trim()) + } else { + None + } + }) { + let split: Vec<_> = entry.split(':').collect(); + let username = split[0]; + let password = split[1..].join(":"); + println!( + "adding username {:?} password {:?} to allowed auth", + username, password + ); + auth.insert(username.to_string(), password.to_string()); + } + } + let pw_ext = PasswordProtocolExtensionBuilder::new_server(auth); - let mux_options = MuxOptions { - block_local: opt.block_local, - block_non_http: opt.block_non_http, - block_udp: opt.block_udp, - auth: Arc::new(vec![ - Box::new(UdpProtocolExtensionBuilder()), - Box::new(pw_ext), - ]), - enforce_auth, - wisp_v1: opt.wisp_v1, - }; + let mux_options = MuxOptions { + block_local: opt.block_local, + block_non_http: opt.block_non_http, + block_udp: opt.block_udp, + auth: Arc::new(vec![ + Box::new(UdpProtocolExtensionBuilder), + Box::new(pw_ext), + ]), + enforce_auth, + wisp_v1: opt.wisp_v1, + }; - println!("listening on `{}` with prefix `{}`", addr, prefix); - while let Ok((stream, addr)) = socket.accept().await { - let prefix = prefix.clone(); - let mux_options = mux_options.clone(); - tokio::spawn(async move { - let service = service_fn(move |res| { - accept_http(res, addr.clone(), prefix.clone(), mux_options.clone()) - }); - let conn = http1::Builder::new() - .serve_connection(TokioIo::new(stream), service) - .with_upgrades(); - if let Err(err) = conn.await { - println!("failed to serve conn: {:?}", err); - } - }); - } + println!("listening on `{}` with prefix `{}`", addr, prefix); + while let Ok((stream, addr)) = socket.accept().await { + let prefix = prefix.clone(); + let mux_options = mux_options.clone(); + tokio::spawn(async move { + let service = service_fn(move |res| { + accept_http(res, addr.clone(), prefix.clone(), mux_options.clone()) + }); + let conn = http1::Builder::new() + .serve_connection(TokioIo::new(stream), service) + .with_upgrades(); + if let Err(err) = conn.await { + println!("failed to serve conn: {:?}", err); + } + }); + } - Ok(()) + Ok(()) } async fn accept_http( - mut req: Request, - addr: String, - prefix: String, - mux_options: MuxOptions, + mut req: Request, + addr: String, + prefix: String, + mux_options: MuxOptions, ) -> Result, WebSocketError> { - let uri = req.uri().path().to_string(); - if upgrade::is_upgrade_request(&req) - && let Some(uri) = uri.strip_prefix(&prefix) - { - let (res, fut) = upgrade::upgrade(&mut req)?; + let uri = req.uri().path().to_string(); + if upgrade::is_upgrade_request(&req) + && let Some(uri) = uri.strip_prefix(&prefix) + { + let (res, fut) = upgrade::upgrade(&mut req)?; - if uri.is_empty() { - tokio::spawn(async move { accept_ws(fut, addr.clone(), mux_options).await }); - } else if let Some(uri) = uri.strip_prefix('/').map(|x| x.to_string()) { - tokio::spawn(async move { - accept_wsproxy( - fut, - uri, - addr.clone(), - mux_options.block_local, - mux_options.block_non_http, - ) - .await - }); - } + if uri.is_empty() { + tokio::spawn(async move { accept_ws(fut, addr.clone(), mux_options).await }); + } else if let Some(uri) = uri.strip_prefix('/').map(|x| x.to_string()) { + tokio::spawn(async move { + accept_wsproxy( + fut, + uri, + addr.clone(), + mux_options.block_local, + mux_options.block_non_http, + ) + .await + }); + } - Ok(Response::from_parts( - res.into_parts().0, - HttpBody::new(Bytes::new()), - )) - } else { - println!("random request to path {:?}", uri); - Ok(Response::builder() - .status(StatusCode::OK) - .body(HttpBody::new(":3".into())) - .unwrap()) - } + Ok(Response::from_parts( + res.into_parts().0, + HttpBody::new(Bytes::new()), + )) + } else { + println!("random request to path {:?}", uri); + Ok(Response::builder() + .status(StatusCode::OK) + .body(HttpBody::new(":3".into())) + .unwrap()) + } } async fn copy_buf(mux: MuxStreamAsyncRW, tcp: TcpStream) -> std::io::Result<()> { - let (muxrx, muxtx) = mux.into_split(); - let mut muxrx = muxrx.compat(); - let mut muxtx = muxtx.compat_write(); + let (muxrx, muxtx) = mux.into_split(); + let mut muxrx = muxrx.compat(); + let mut muxtx = muxtx.compat_write(); - let (mut tcprx, mut tcptx) = tcp.into_split(); + let (mut tcprx, mut tcptx) = tcp.into_split(); - let fast_fut = async { - loop { - let buf = muxrx.fill_buf().await?; - if buf.is_empty() { - tcptx.flush().await?; - return Ok(()); - } + let fast_fut = async { + loop { + let buf = muxrx.fill_buf().await?; + if buf.is_empty() { + tcptx.flush().await?; + return Ok(()); + } - let i = tcptx.write(buf).await?; - if i == 0 { - return Err(std::io::ErrorKind::WriteZero.into()); - } + let i = tcptx.write(buf).await?; + if i == 0 { + return Err(std::io::ErrorKind::WriteZero.into()); + } - muxrx.consume(i); - } - }; + muxrx.consume(i); + } + }; - let slow_fut = copy(&mut tcprx, &mut muxtx); + let slow_fut = copy(&mut tcprx, &mut muxtx); - select! { - x = fast_fut => x, - x = slow_fut => x.map(|_| ()), - } + select! { + x = fast_fut => x, + x = slow_fut => x.map(|_| ()), + } } async fn handle_mux( - packet: ConnectPacket, - stream: MuxStream, + packet: ConnectPacket, + stream: MuxStream, ) -> Result> { - let uri = format!( - "{}:{}", - packet.destination_hostname, packet.destination_port - ); - match packet.stream_type { - StreamType::Tcp => { - let tcp_stream = TcpStream::connect(uri).await?; - let mux = stream.into_io().into_asyncrw(); + let uri = format!( + "{}:{}", + packet.destination_hostname, packet.destination_port + ); + match packet.stream_type { + StreamType::Tcp => { + let tcp_stream = TcpStream::connect(uri).await?; + let mux = stream.into_io().into_asyncrw(); copy_buf(mux, tcp_stream).await?; - } - StreamType::Udp => { - let uri = lookup_host(uri) - .await? - .next() - .ok_or(WispError::InvalidUri)?; - let udp_socket = - UdpSocket::bind(if uri.is_ipv4() { "0.0.0.0:0" } else { "[::]:0" }).await?; - udp_socket.connect(uri).await?; - let mut data = vec![0u8; 65507]; // udp standard max datagram size - loop { - tokio::select! { - size = udp_socket.recv(&mut data) => { - let size = size?; - stream.write(Bytes::copy_from_slice(&data[..size])).await? - }, - event = stream.read() => { - match event { - Some(event) => { - let _ = udp_socket.send(&event).await?; - } - None => break, - } - } - } - } - } - StreamType::Unknown(_) => { - stream.close(CloseReason::ServerStreamInvalidInfo).await?; - return Ok(false); - } - } - Ok(true) + } + StreamType::Udp => { + let uri = lookup_host(uri) + .await? + .next() + .ok_or(WispError::InvalidUri)?; + let udp_socket = + UdpSocket::bind(if uri.is_ipv4() { "0.0.0.0:0" } else { "[::]:0" }).await?; + udp_socket.connect(uri).await?; + let mut data = vec![0u8; 65507]; // udp standard max datagram size + loop { + tokio::select! { + size = udp_socket.recv(&mut data) => { + let size = size?; + stream.write(Bytes::copy_from_slice(&data[..size])).await? + }, + event = stream.read() => { + match event { + Some(event) => { + let _ = udp_socket.send(&event).await?; + } + None => break, + } + } + } + } + } + StreamType::Unknown(_) => { + stream.close(CloseReason::ServerStreamInvalidInfo).await?; + return Ok(false); + } + } + Ok(true) } async fn accept_ws( - ws: UpgradeFut, - addr: String, - mux_options: MuxOptions, + ws: UpgradeFut, + addr: String, + mux_options: MuxOptions, ) -> Result<(), Box> { - let mut ws = ws.await?; - // to prevent memory ""leaks"" because users are sending in packets way too fast the message - // size is set to 1M - ws.set_max_message_size(1024 * 1024); - let (rx, tx) = ws.split(|x| { - let Parts { - io, read_buf: buf, .. - } = x - .into_inner() - .downcast::>() - .unwrap(); - assert_eq!(buf.len(), 0); - cfg_if! { - if #[cfg(unix)] { - match io.into_inner() { - Either::Left(x) => { - let (rx, tx) = x.into_split(); - (Either::Left(rx), Either::Left(tx)) - } - Either::Right(x) => { - let (rx, tx) = x.into_split(); - (Either::Right(rx), Either::Right(tx)) - } - } - } else { - io.into_inner().into_split() - } - } - }); - let rx = FragmentCollectorRead::new(rx); + let mut ws = ws.await?; + // to prevent memory ""leaks"" because users are sending in packets way too fast the message + // size is set to 1M + ws.set_max_message_size(1024 * 1024); + let (rx, tx) = ws.split(|x| { + let Parts { + io, read_buf: buf, .. + } = x.into_inner() + .downcast::>() + .unwrap(); + assert_eq!(buf.len(), 0); + cfg_if! { + if #[cfg(unix)] { + match io.into_inner() { + Either::Left(x) => { + let (rx, tx) = x.into_split(); + (Either::Left(rx), Either::Left(tx)) + } + Either::Right(x) => { + let (rx, tx) = x.into_split(); + (Either::Right(rx), Either::Right(tx)) + } + } + } else { + io.into_inner().into_split() + } + } + }); + let rx = FragmentCollectorRead::new(rx); - println!("{:?}: connected", addr); - // to prevent memory ""leaks"" because users are sending in packets way too fast the buffer - // size is set to 512 - let (mux, fut) = if mux_options.wisp_v1 { - ServerMux::create(rx, tx, 512, None) - .await? - .with_no_required_extensions() - } else if mux_options.enforce_auth { - ServerMux::create(rx, tx, 512, Some(mux_options.auth.as_slice())) - .await? - .with_required_extensions(&[PasswordProtocolExtension::ID]) - .await? - } else { - ServerMux::create( - rx, - tx, - 512, - Some(&[Box::new(UdpProtocolExtensionBuilder())]), - ) - .await? - .with_no_required_extensions() - }; + println!("{:?}: connected", addr); + // to prevent memory ""leaks"" because users are sending in packets way too fast the buffer + // size is set to 512 + let (mux, fut) = if mux_options.wisp_v1 { + ServerMux::create(rx, tx, 512, None) + .await? + .with_no_required_extensions() + } else if mux_options.enforce_auth { + ServerMux::create(rx, tx, 512, Some(mux_options.auth.as_slice())) + .await? + .with_required_extensions(&[PasswordProtocolExtension::ID]) + .await? + } else { + ServerMux::create(rx, tx, 512, Some(&[Box::new(UdpProtocolExtensionBuilder)])) + .await? + .with_no_required_extensions() + }; - // this results in one stream ""leaking"" a maximum of ~512M + // this results in one stream ""leaking"" a maximum of ~512M - println!( - "{:?}: downgraded: {} extensions supported: {:?}", - addr, mux.downgraded, mux.supported_extension_ids - ); + println!( + "{:?}: downgraded: {} extensions supported: {:?}", + addr, mux.downgraded, mux.supported_extension_ids + ); - tokio::spawn(async move { - if let Err(e) = fut.await { - println!("err in mux: {:?}", e); - } - }); + tokio::spawn(async move { + if let Err(e) = fut.await { + println!("err in mux: {:?}", e); + } + }); - while let Some((packet, stream)) = mux.server_new_stream().await { - tokio::spawn(async move { - if (mux_options.block_non_http - && !(packet.destination_port == 80 || packet.destination_port == 443)) - || (mux_options.block_udp && packet.stream_type == StreamType::Udp) - { - let _ = stream.close(CloseReason::ServerStreamBlockedAddress).await; - return; - } - if mux_options.block_local { - match lookup_host(format!( - "{}:{}", - packet.destination_hostname, packet.destination_port - )) - .await - .ok() - .and_then(|mut x| x.next()) - .map(|x| !x.ip().is_global()) - { - Some(true) => { - let _ = stream.close(CloseReason::ServerStreamBlockedAddress).await; - return; - } - Some(false) => {} - None => { - let _ = stream - .close(CloseReason::ServerStreamConnectionRefused) - .await; - return; - } - } - } - let close_err = stream.get_close_handle(); - let close_ok = stream.get_close_handle(); - let _ = handle_mux(packet, stream) - .or_else(|err| async move { - let _ = close_err.close(CloseReason::Unexpected).await; - Err(err) - }) - .and_then(|should_send| async move { - if should_send { - let _ = close_ok.close(CloseReason::Voluntary).await; - } - Ok(()) - }) - .await; - }); - } + while let Some((packet, stream)) = mux.server_new_stream().await { + tokio::spawn(async move { + if (mux_options.block_non_http + && !(packet.destination_port == 80 || packet.destination_port == 443)) + || (mux_options.block_udp && packet.stream_type == StreamType::Udp) + { + let _ = stream.close(CloseReason::ServerStreamBlockedAddress).await; + return; + } + if mux_options.block_local { + match lookup_host(format!( + "{}:{}", + packet.destination_hostname, packet.destination_port + )) + .await + .ok() + .and_then(|mut x| x.next()) + .map(|x| !x.ip().is_global()) + { + Some(true) => { + let _ = stream.close(CloseReason::ServerStreamBlockedAddress).await; + return; + } + Some(false) => {} + None => { + let _ = stream + .close(CloseReason::ServerStreamConnectionRefused) + .await; + return; + } + } + } + let close_err = stream.get_close_handle(); + let close_ok = stream.get_close_handle(); + let _ = handle_mux(packet, stream) + .or_else(|err| async move { + let _ = close_err.close(CloseReason::Unexpected).await; + Err(err) + }) + .and_then(|should_send| async move { + if should_send { + let _ = close_ok.close(CloseReason::Voluntary).await; + } + Ok(()) + }) + .await; + }); + } - println!("{:?}: disconnected", addr); - Ok(()) + println!("{:?}: disconnected", addr); + Ok(()) } async fn accept_wsproxy( - ws: UpgradeFut, - incoming_uri: String, - addr: String, - block_local: bool, - block_non_http: bool, + ws: UpgradeFut, + incoming_uri: String, + addr: String, + block_local: bool, + block_non_http: bool, ) -> Result<(), Box> { - let mut ws_stream = FragmentCollector::new(ws.await?); + let mut ws_stream = FragmentCollector::new(ws.await?); - println!("{:?}: connected (wsproxy): {:?}", addr, incoming_uri); + println!("{:?}: connected (wsproxy): {:?}", addr, incoming_uri); - let Some(host) = lookup_host(&incoming_uri) - .await - .ok() - .and_then(|mut x| x.next()) - else { - ws_stream - .write_frame(Frame::close( - CloseCode::Error.into(), - b"failed to resolve uri", - )) - .await?; - return Ok(()); - }; + let Some(host) = lookup_host(&incoming_uri) + .await + .ok() + .and_then(|mut x| x.next()) + else { + ws_stream + .write_frame(Frame::close( + CloseCode::Error.into(), + b"failed to resolve uri", + )) + .await?; + return Ok(()); + }; - if block_local && !host.ip().is_global() { - ws_stream - .write_frame(Frame::close(CloseCode::Error.into(), b"blocked uri")) - .await?; - return Ok(()); - } + if block_local && !host.ip().is_global() { + ws_stream + .write_frame(Frame::close(CloseCode::Error.into(), b"blocked uri")) + .await?; + return Ok(()); + } - if block_non_http && !(host.port() == 80 || host.port() == 443) { - ws_stream - .write_frame(Frame::close(CloseCode::Error.into(), b"blocked uri")) - .await?; - return Ok(()); - } + if block_non_http && !(host.port() == 80 || host.port() == 443) { + ws_stream + .write_frame(Frame::close(CloseCode::Error.into(), b"blocked uri")) + .await?; + return Ok(()); + } - let tcp_stream = match TcpStream::connect(incoming_uri).await { - Ok(stream) => stream, - Err(err) => { - ws_stream - .write_frame(Frame::close(CloseCode::Error.into(), b"failed to connect")) - .await?; - return Err(Box::new(err)); - } - }; - let mut tcp_stream_framed = Framed::new(tcp_stream, BytesCodec::new()); + let tcp_stream = match TcpStream::connect(incoming_uri).await { + Ok(stream) => stream, + Err(err) => { + ws_stream + .write_frame(Frame::close(CloseCode::Error.into(), b"failed to connect")) + .await?; + return Err(Box::new(err)); + } + }; + let mut tcp_stream_framed = Framed::new(tcp_stream, BytesCodec::new()); - loop { - tokio::select! { - event = ws_stream.read_frame() => { - match event { - Ok(frame) => { - match frame.opcode { - OpCode::Text | OpCode::Binary => { - let _ = tcp_stream_framed.send(Bytes::from(frame.payload.to_vec())).await; - } - OpCode::Close => { - // tokio closes the stream for us - drop(tcp_stream_framed); - break; - } - _ => {} - } - }, - Err(_) => { - // tokio closes the stream for us - drop(tcp_stream_framed); - break; - } - } - }, - event = tcp_stream_framed.next() => { - if let Some(res) = event { - match res { - Ok(buf) => { - let _ = ws_stream.write_frame(Frame::binary(Payload::Borrowed(&buf))).await; - } - Err(_) => { - let _ = ws_stream.write_frame(Frame::close(CloseCode::Away.into(), b"tcp side is going away")).await; - } - } - } - } - } - } + loop { + tokio::select! { + event = ws_stream.read_frame() => { + match event { + Ok(frame) => { + match frame.opcode { + OpCode::Text | OpCode::Binary => { + let _ = tcp_stream_framed.send(Bytes::from(frame.payload.to_vec())).await; + } + OpCode::Close => { + // tokio closes the stream for us + drop(tcp_stream_framed); + break; + } + _ => {} + } + }, + Err(_) => { + // tokio closes the stream for us + drop(tcp_stream_framed); + break; + } + } + }, + event = tcp_stream_framed.next() => { + if let Some(res) = event { + match res { + Ok(buf) => { + let _ = ws_stream.write_frame(Frame::binary(Payload::Borrowed(&buf))).await; + } + Err(_) => { + let _ = ws_stream.write_frame(Frame::close(CloseCode::Away.into(), b"tcp side is going away")).await; + } + } + } + } + } + } - println!("{:?}: disconnected (wsproxy)", addr); + println!("{:?}: disconnected (wsproxy)", addr); - Ok(()) + Ok(()) } diff --git a/simple-wisp-client/src/main.rs b/simple-wisp-client/src/main.rs index 0452d65..d8bacc9 100644 --- a/simple-wisp-client/src/main.rs +++ b/simple-wisp-client/src/main.rs @@ -6,50 +6,50 @@ use futures::future::select_all; use http_body_util::Empty; use humantime::format_duration; use hyper::{ - header::{CONNECTION, UPGRADE}, - Request, Uri, + header::{CONNECTION, UPGRADE}, + Request, Uri, }; use simple_moving_average::{SingleSumSMA, SMA}; use std::{ - error::Error, - future::Future, - io::{stdout, IsTerminal, Write}, - net::SocketAddr, - process::exit, - sync::Arc, - time::{Duration, Instant}, + error::Error, + future::Future, + io::{stdout, IsTerminal, Write}, + net::SocketAddr, + process::exit, + sync::Arc, + time::{Duration, Instant}, }; use tokio::{ - net::TcpStream, - select, - signal::unix::{signal, SignalKind}, - time::{interval, sleep}, + net::TcpStream, + select, + signal::unix::{signal, SignalKind}, + time::{interval, sleep}, }; use tokio_native_tls::{native_tls, TlsConnector}; use tokio_util::either::Either; use wisp_mux::{ - extensions::{ - password::{PasswordProtocolExtension, PasswordProtocolExtensionBuilder}, - udp::{UdpProtocolExtension, UdpProtocolExtensionBuilder}, - ProtocolExtensionBuilder, - }, - ClientMux, StreamType, WispError, + extensions::{ + password::{PasswordProtocolExtension, PasswordProtocolExtensionBuilder}, + udp::{UdpProtocolExtension, UdpProtocolExtensionBuilder}, + ProtocolExtensionBuilder, + }, + ClientMux, StreamType, WispError, }; #[derive(Debug)] enum WispClientError { - InvalidUriScheme, - UriHasNoHost, + InvalidUriScheme, + UriHasNoHost, } impl std::fmt::Display for WispClientError { - fn fmt(&self, fmt: &mut std::fmt::Formatter) -> Result<(), std::fmt::Error> { - use WispClientError as E; - match self { - E::InvalidUriScheme => write!(fmt, "Invalid URI scheme"), - E::UriHasNoHost => write!(fmt, "URI has no host"), - } - } + fn fmt(&self, fmt: &mut std::fmt::Formatter) -> Result<(), std::fmt::Error> { + use WispClientError as E; + match self { + E::InvalidUriScheme => write!(fmt, "Invalid URI scheme"), + E::UriHasNoHost => write!(fmt, "URI has no host"), + } + } } impl Error for WispClientError {} @@ -58,165 +58,166 @@ struct SpawnExecutor; impl hyper::rt::Executor for SpawnExecutor where - Fut: Future + Send + 'static, - Fut::Output: Send + 'static, + Fut: Future + Send + 'static, + Fut::Output: Send + 'static, { - fn execute(&self, fut: Fut) { - tokio::task::spawn(fut); - } + fn execute(&self, fut: Fut) { + tokio::task::spawn(fut); + } } #[derive(Parser)] #[command(version = clap::crate_version!())] struct Cli { - /// Wisp server URL - #[arg(short, long)] - wisp: Uri, - /// TCP server address - #[arg(short, long)] - tcp: SocketAddr, - /// Number of streams - #[arg(short, long, default_value_t = 10)] - streams: usize, - /// Size of packets sent, in KB - #[arg(short, long, default_value_t = 1)] - packet_size: usize, - /// Duration to run the test for - #[arg(short, long)] - duration: Option, - /// Ask for UDP - #[arg(short, long)] - udp: bool, - /// Enable auth: format is `username:password` - /// - /// Usernames and passwords are sent in plaintext!! - #[arg(long)] - auth: Option, - /// Make a Wisp V1 connection - #[arg(long)] - wisp_v1: bool, + /// Wisp server URL + #[arg(short, long)] + wisp: Uri, + /// TCP server address + #[arg(short, long)] + tcp: SocketAddr, + /// Number of streams + #[arg(short, long, default_value_t = 10)] + streams: usize, + /// Size of packets sent, in KB + #[arg(short, long, default_value_t = 1)] + packet_size: usize, + /// Duration to run the test for + #[arg(short, long)] + duration: Option, + /// Ask for UDP + #[arg(short, long)] + udp: bool, + /// Enable auth: format is `username:password` + /// + /// Usernames and passwords are sent in plaintext!! + #[arg(long)] + auth: Option, + /// Make a Wisp V1 connection + #[arg(long)] + wisp_v1: bool, } #[tokio::main(flavor = "multi_thread")] async fn main() -> Result<(), Box> { - #[cfg(feature = "tokio-console")] - console_subscriber::init(); - let opts = Cli::parse(); + #[cfg(feature = "tokio-console")] + console_subscriber::init(); + let opts = Cli::parse(); - let tls = match opts - .wisp - .scheme_str() - .ok_or(WispClientError::InvalidUriScheme)? - { - "wss" => Ok(true), - "ws" => Ok(false), - _ => Err(WispClientError::InvalidUriScheme), - }?; - let addr = opts.wisp.host().ok_or(WispClientError::UriHasNoHost)?; - let addr_port = opts.wisp.port_u16().unwrap_or(if tls { 443 } else { 80 }); - let addr_path = opts.wisp.path(); - let addr_dest = opts.tcp.ip().to_string(); - let addr_dest_port = opts.tcp.port(); + let tls = match opts + .wisp + .scheme_str() + .ok_or(WispClientError::InvalidUriScheme)? + { + "wss" => Ok(true), + "ws" => Ok(false), + _ => Err(WispClientError::InvalidUriScheme), + }?; + let addr = opts.wisp.host().ok_or(WispClientError::UriHasNoHost)?; + let addr_port = opts.wisp.port_u16().unwrap_or(if tls { 443 } else { 80 }); + let addr_path = opts.wisp.path(); + let addr_dest = opts.tcp.ip().to_string(); + let addr_dest_port = opts.tcp.port(); - let auth = opts.auth.map(|auth| { - let split: Vec<_> = auth.split(':').collect(); - let username = split[0].to_string(); - let password = split[1..].join(":"); - PasswordProtocolExtensionBuilder::new_client(username, password) - }); + let auth = opts.auth.map(|auth| { + let split: Vec<_> = auth.split(':').collect(); + let username = split[0].to_string(); + let password = split[1..].join(":"); + PasswordProtocolExtensionBuilder::new_client(username, password) + }); - println!( - "connecting to {} and sending &[0; 1024 * {}] to {} with threads {}", - opts.wisp, opts.packet_size, opts.tcp, opts.streams, - ); + println!( + "connecting to {} and sending &[0; 1024 * {}] to {} with threads {}", + opts.wisp, opts.packet_size, opts.tcp, opts.streams, + ); - let socket = TcpStream::connect(format!("{}:{}", &addr, addr_port)).await?; - let socket = if tls { - let cx = TlsConnector::from(native_tls::TlsConnector::builder().build()?); - Either::Left(cx.connect(addr, socket).await?) - } else { - Either::Right(socket) - }; - let req = Request::builder() - .method("GET") - .uri(addr_path) - .header("Host", addr) - .header(UPGRADE, "websocket") - .header(CONNECTION, "upgrade") - .header( - "Sec-WebSocket-Key", - fastwebsockets::handshake::generate_key(), - ) - .header("Sec-WebSocket-Version", "13") - .body(Empty::::new())?; + let socket = TcpStream::connect(format!("{}:{}", &addr, addr_port)).await?; + let socket = if tls { + let cx = TlsConnector::from(native_tls::TlsConnector::builder().build()?); + Either::Left(cx.connect(addr, socket).await?) + } else { + Either::Right(socket) + }; + let req = Request::builder() + .method("GET") + .uri(addr_path) + .header("Host", addr) + .header(UPGRADE, "websocket") + .header(CONNECTION, "upgrade") + .header( + "Sec-WebSocket-Key", + fastwebsockets::handshake::generate_key(), + ) + .header("Sec-WebSocket-Version", "13") + .body(Empty::::new())?; - let (ws, _) = handshake::client(&SpawnExecutor, req, socket).await?; + let (ws, _) = handshake::client(&SpawnExecutor, req, socket).await?; - let (rx, tx) = ws.split(tokio::io::split); - let rx = FragmentCollectorRead::new(rx); + let (rx, tx) = ws.split(tokio::io::split); + let rx = FragmentCollectorRead::new(rx); - let mut extensions: Vec> = Vec::new(); - let mut extension_ids: Vec = Vec::new(); - if opts.udp { - extensions.push(Box::new(UdpProtocolExtensionBuilder())); - extension_ids.push(UdpProtocolExtension::ID); - } - if let Some(auth) = auth { - extensions.push(Box::new(auth)); - extension_ids.push(PasswordProtocolExtension::ID); - } + let mut extensions: Vec> = Vec::new(); + let mut extension_ids: Vec = Vec::new(); + if opts.udp { + extensions.push(Box::new(UdpProtocolExtensionBuilder)); + extension_ids.push(UdpProtocolExtension::ID); + } + if let Some(auth) = auth { + extensions.push(Box::new(auth)); + extension_ids.push(PasswordProtocolExtension::ID); + } - let (mux, fut) = if opts.wisp_v1 { - ClientMux::create(rx, tx, None) - .await? - .with_no_required_extensions() - } else { - ClientMux::create(rx, tx, Some(extensions.as_slice())) - .await? - .with_required_extensions(extension_ids.as_slice()).await? - }; + let (mux, fut) = if opts.wisp_v1 { + ClientMux::create(rx, tx, None) + .await? + .with_no_required_extensions() + } else { + ClientMux::create(rx, tx, Some(extensions.as_slice())) + .await? + .with_required_extensions(extension_ids.as_slice()) + .await? + }; - println!( - "connected and created ClientMux, was downgraded {}, extensions supported {:?}", - mux.downgraded, mux.supported_extension_ids - ); + println!( + "connected and created ClientMux, was downgraded {}, extensions supported {:?}", + mux.downgraded, mux.supported_extension_ids + ); - let mut threads = Vec::with_capacity(opts.streams + 4); - let mut reads = Vec::with_capacity(opts.streams); + let mut threads = Vec::with_capacity(opts.streams + 4); + let mut reads = Vec::with_capacity(opts.streams); - threads.push(tokio::spawn(fut)); + threads.push(tokio::spawn(fut)); - let payload = Bytes::from(vec![0; 1024 * opts.packet_size]); + let payload = Bytes::from(vec![0; 1024 * opts.packet_size]); - let cnt = Arc::new(RelaxedCounter::new(0)); + let cnt = Arc::new(RelaxedCounter::new(0)); - let start_time = Instant::now(); - for _ in 0..opts.streams { - let (cr, cw) = mux - .client_new_stream(StreamType::Tcp, addr_dest.clone(), addr_dest_port) - .await? - .into_split(); - let cnt = cnt.clone(); - let payload = payload.clone(); - threads.push(tokio::spawn(async move { - loop { - cw.write(payload.clone()).await?; - cnt.inc(); - } - #[allow(unreachable_code)] - Ok::<(), WispError>(()) - })); - reads.push(cr); - } + let start_time = Instant::now(); + for _ in 0..opts.streams { + let (cr, cw) = mux + .client_new_stream(StreamType::Tcp, addr_dest.clone(), addr_dest_port) + .await? + .into_split(); + let cnt = cnt.clone(); + let payload = payload.clone(); + threads.push(tokio::spawn(async move { + loop { + cw.write(payload.clone()).await?; + cnt.inc(); + } + #[allow(unreachable_code)] + Ok::<(), WispError>(()) + })); + reads.push(cr); + } - threads.push(tokio::spawn(async move { - loop { - select_all(reads.iter().map(|x| Box::pin(x.read()))).await; - } - })); + threads.push(tokio::spawn(async move { + loop { + select_all(reads.iter().map(|x| Box::pin(x.read()))).await; + } + })); - let cnt_avg = cnt.clone(); - threads.push(tokio::spawn(async move { + let cnt_avg = cnt.clone(); + threads.push(tokio::spawn(async move { let mut interval = interval(Duration::from_millis(100)); let mut avg: SingleSumSMA = SingleSumSMA::new(); let mut last_time = 0; @@ -245,48 +246,48 @@ async fn main() -> Result<(), Box> { } })); - threads.push(tokio::spawn(async move { - let mut interrupt = - signal(SignalKind::interrupt()).map_err(|x| WispError::Other(Box::new(x)))?; - let mut terminate = - signal(SignalKind::terminate()).map_err(|x| WispError::Other(Box::new(x)))?; - select! { - _ = interrupt.recv() => (), - _ = terminate.recv() => (), - } - Ok(()) - })); + threads.push(tokio::spawn(async move { + let mut interrupt = + signal(SignalKind::interrupt()).map_err(|x| WispError::Other(Box::new(x)))?; + let mut terminate = + signal(SignalKind::terminate()).map_err(|x| WispError::Other(Box::new(x)))?; + select! { + _ = interrupt.recv() => (), + _ = terminate.recv() => (), + } + Ok(()) + })); - if let Some(duration) = opts.duration { - threads.push(tokio::spawn(async move { - sleep(duration.into()).await; - Ok(()) - })); - } + if let Some(duration) = opts.duration { + threads.push(tokio::spawn(async move { + sleep(duration.into()).await; + Ok(()) + })); + } - let out = select_all(threads.into_iter()).await; + let out = select_all(threads.into_iter()).await; - let duration_since = Instant::now().duration_since(start_time); + let duration_since = Instant::now().duration_since(start_time); - if let Err(err) = out.0? { - println!("\n\nerr: {:?}", err); - exit(1); - } + if let Err(err) = out.0? { + println!("\n\nerr: {:?}", err); + exit(1); + } - out.2.into_iter().for_each(|x| x.abort()); + out.2.into_iter().for_each(|x| x.abort()); - mux.close().await?; + mux.close().await?; - if duration_since.as_secs() != 0 { - println!( - "\nresults: {} packets of &[0; 1024 * {}] ({} KiB) sent in {} ({} KiB/s)", - cnt.get(), - opts.packet_size, - cnt.get() * opts.packet_size, - format_duration(duration_since), - (cnt.get() * opts.packet_size) as u64 / duration_since.as_secs(), - ); - } + if duration_since.as_secs() != 0 { + println!( + "\nresults: {} packets of &[0; 1024 * {}] ({} KiB) sent in {} ({} KiB/s)", + cnt.get(), + opts.packet_size, + cnt.get() * opts.packet_size, + format_duration(duration_since), + (cnt.get() * opts.packet_size) as u64 / duration_since.as_secs(), + ); + } - Ok(()) + Ok(()) } diff --git a/wisp/src/extensions/mod.rs b/wisp/src/extensions/mod.rs index 8c3ec12..7de097f 100644 --- a/wisp/src/extensions/mod.rs +++ b/wisp/src/extensions/mod.rs @@ -8,8 +8,8 @@ use async_trait::async_trait; use bytes::{BufMut, Bytes, BytesMut}; use crate::{ - ws::{LockedWebSocketWrite, WebSocketRead}, - Role, WispError, + ws::{LockedWebSocketWrite, WebSocketRead}, + Role, WispError, }; /// Type-erased protocol extension that implements Clone. @@ -17,90 +17,90 @@ use crate::{ pub struct AnyProtocolExtension(Box); impl AnyProtocolExtension { - /// Create a new type-erased protocol extension. - pub fn new(extension: T) -> Self { - Self(Box::new(extension)) - } + /// Create a new type-erased protocol extension. + pub fn new(extension: T) -> Self { + Self(Box::new(extension)) + } } impl Deref for AnyProtocolExtension { - type Target = dyn ProtocolExtension; - fn deref(&self) -> &Self::Target { - self.0.deref() - } + type Target = dyn ProtocolExtension; + fn deref(&self) -> &Self::Target { + self.0.deref() + } } impl DerefMut for AnyProtocolExtension { - fn deref_mut(&mut self) -> &mut Self::Target { - self.0.deref_mut() - } + fn deref_mut(&mut self) -> &mut Self::Target { + self.0.deref_mut() + } } impl Clone for AnyProtocolExtension { - fn clone(&self) -> Self { - Self(self.0.box_clone()) - } + fn clone(&self) -> Self { + Self(self.0.box_clone()) + } } impl From for Bytes { - fn from(value: AnyProtocolExtension) -> Self { - let mut bytes = BytesMut::with_capacity(5); - let payload = value.encode(); - bytes.put_u8(value.get_id()); - bytes.put_u32_le(payload.len() as u32); - bytes.extend(payload); - bytes.freeze() - } + fn from(value: AnyProtocolExtension) -> Self { + let mut bytes = BytesMut::with_capacity(5); + let payload = value.encode(); + bytes.put_u8(value.get_id()); + bytes.put_u32_le(payload.len() as u32); + bytes.extend(payload); + bytes.freeze() + } } /// A Wisp protocol extension. /// /// See [the -/// docs](https://github.com/MercuryWorkshop/wisp-protocol/blob/main/protocol.md#protocol-extensions). +/// docs](https://github.com/MercuryWorkshop/wisp-protocol/blob/v2/protocol.md#protocol-extensions). #[async_trait] pub trait ProtocolExtension: std::fmt::Debug { - /// Get the protocol extension ID. - fn get_id(&self) -> u8; - /// Get the protocol extension's supported packets. - /// - /// Used to decide whether to call the protocol extension's packet handler. - fn get_supported_packets(&self) -> &'static [u8]; + /// Get the protocol extension ID. + fn get_id(&self) -> u8; + /// Get the protocol extension's supported packets. + /// + /// Used to decide whether to call the protocol extension's packet handler. + fn get_supported_packets(&self) -> &'static [u8]; - /// Encode self into Bytes. - fn encode(&self) -> Bytes; + /// Encode self into Bytes. + fn encode(&self) -> Bytes; - /// Handle the handshake part of a Wisp connection. - /// - /// This should be used to send or receive data before any streams are created. - async fn handle_handshake( - &mut self, - read: &mut dyn WebSocketRead, - write: &LockedWebSocketWrite, - ) -> Result<(), WispError>; + /// Handle the handshake part of a Wisp connection. + /// + /// This should be used to send or receive data before any streams are created. + async fn handle_handshake( + &mut self, + read: &mut dyn WebSocketRead, + write: &LockedWebSocketWrite, + ) -> Result<(), WispError>; - /// Handle receiving a packet. - async fn handle_packet( - &mut self, - packet: Bytes, - read: &mut dyn WebSocketRead, - write: &LockedWebSocketWrite, - ) -> Result<(), WispError>; + /// Handle receiving a packet. + async fn handle_packet( + &mut self, + packet: Bytes, + read: &mut dyn WebSocketRead, + write: &LockedWebSocketWrite, + ) -> Result<(), WispError>; - /// Clone the protocol extension. - fn box_clone(&self) -> Box; + /// Clone the protocol extension. + fn box_clone(&self) -> Box; } /// Trait to build a Wisp protocol extension from a payload. pub trait ProtocolExtensionBuilder { - /// Get the protocol extension ID. - /// - /// Used to decide whether this builder should be used. - fn get_id(&self) -> u8; + /// Get the protocol extension ID. + /// + /// Used to decide whether this builder should be used. + fn get_id(&self) -> u8; - /// Build a protocol extension from the extension's metadata. - fn build_from_bytes(&self, bytes: Bytes, role: Role) - -> Result; + /// Build a protocol extension from the extension's metadata. + fn build_from_bytes(&self, bytes: Bytes, role: Role) + -> Result; - /// Build a protocol extension to send to the other side. - fn build_to_extension(&self, role: Role) -> AnyProtocolExtension; + /// Build a protocol extension to send to the other side. + fn build_to_extension(&self, role: Role) -> AnyProtocolExtension; } diff --git a/wisp/src/extensions/password.rs b/wisp/src/extensions/password.rs index 3fe15b3..05bd489 100644 --- a/wisp/src/extensions/password.rs +++ b/wisp/src/extensions/password.rs @@ -29,7 +29,7 @@ //! ]) //! ); //! ``` -//! See [the docs](https://github.com/MercuryWorkshop/wisp-protocol/blob/main/protocol.md#0x02---password-authentication) +//! See [the docs](https://github.com/MercuryWorkshop/wisp-protocol/blob/v2/protocol.md#0x02---password-authentication) use std::{collections::HashMap, error::Error, fmt::Display, string::FromUtf8Error}; @@ -37,8 +37,8 @@ use async_trait::async_trait; use bytes::{Buf, BufMut, Bytes, BytesMut}; use crate::{ - ws::{LockedWebSocketWrite, WebSocketRead}, - Role, WispError, + ws::{LockedWebSocketWrite, WebSocketRead}, + Role, WispError, }; use super::{AnyProtocolExtension, ProtocolExtension, ProtocolExtensionBuilder}; @@ -50,227 +50,227 @@ use super::{AnyProtocolExtension, ProtocolExtension, ProtocolExtensionBuilder}; /// **This extension will panic when encoding if the username's length does not fit within a u8 /// or the password's length does not fit within a u16.** pub struct PasswordProtocolExtension { - /// The username to log in with. - /// - /// This string's length must fit within a u8. - pub username: String, - /// The password to log in with. - /// - /// This string's length must fit within a u16. - pub password: String, - role: Role, + /// The username to log in with. + /// + /// This string's length must fit within a u8. + pub username: String, + /// The password to log in with. + /// + /// This string's length must fit within a u16. + pub password: String, + role: Role, } impl PasswordProtocolExtension { - /// Password protocol extension ID. - pub const ID: u8 = 0x02; + /// Password protocol extension ID. + pub const ID: u8 = 0x02; - /// Create a new password protocol extension for the server. - /// - /// This signifies that the server requires a password. - pub fn new_server() -> Self { - Self { - username: String::new(), - password: String::new(), - role: Role::Server, - } - } + /// Create a new password protocol extension for the server. + /// + /// This signifies that the server requires a password. + pub fn new_server() -> Self { + Self { + username: String::new(), + password: String::new(), + role: Role::Server, + } + } - /// Create a new password protocol extension for the client, with a username and password. - /// - /// The username's length must fit within a u8. The password's length must fit within a - /// u16. - pub fn new_client(username: String, password: String) -> Self { - Self { - username, - password, - role: Role::Client, - } - } + /// Create a new password protocol extension for the client, with a username and password. + /// + /// The username's length must fit within a u8. The password's length must fit within a + /// u16. + pub fn new_client(username: String, password: String) -> Self { + Self { + username, + password, + role: Role::Client, + } + } } #[async_trait] impl ProtocolExtension for PasswordProtocolExtension { - fn get_id(&self) -> u8 { - Self::ID - } + fn get_id(&self) -> u8 { + Self::ID + } - fn get_supported_packets(&self) -> &'static [u8] { - &[] - } + fn get_supported_packets(&self) -> &'static [u8] { + &[] + } - fn encode(&self) -> Bytes { - match self.role { - Role::Server => Bytes::new(), - Role::Client => { - let username = Bytes::from(self.username.clone().into_bytes()); - let password = Bytes::from(self.password.clone().into_bytes()); - let username_len = u8::try_from(username.len()).expect("username was too long"); - let password_len = u16::try_from(password.len()).expect("password was too long"); + fn encode(&self) -> Bytes { + match self.role { + Role::Server => Bytes::new(), + Role::Client => { + let username = Bytes::from(self.username.clone().into_bytes()); + let password = Bytes::from(self.password.clone().into_bytes()); + let username_len = u8::try_from(username.len()).expect("username was too long"); + let password_len = u16::try_from(password.len()).expect("password was too long"); - let mut bytes = - BytesMut::with_capacity(3 + username_len as usize + password_len as usize); - bytes.put_u8(username_len); - bytes.put_u16_le(password_len); - bytes.extend(username); - bytes.extend(password); - bytes.freeze() - } - } - } + let mut bytes = + BytesMut::with_capacity(3 + username_len as usize + password_len as usize); + bytes.put_u8(username_len); + bytes.put_u16_le(password_len); + bytes.extend(username); + bytes.extend(password); + bytes.freeze() + } + } + } - async fn handle_handshake( - &mut self, - _: &mut dyn WebSocketRead, - _: &LockedWebSocketWrite, - ) -> Result<(), WispError> { - Ok(()) - } + async fn handle_handshake( + &mut self, + _: &mut dyn WebSocketRead, + _: &LockedWebSocketWrite, + ) -> Result<(), WispError> { + Ok(()) + } - async fn handle_packet( - &mut self, - _: Bytes, - _: &mut dyn WebSocketRead, - _: &LockedWebSocketWrite, - ) -> Result<(), WispError> { - Ok(()) - } + async fn handle_packet( + &mut self, + _: Bytes, + _: &mut dyn WebSocketRead, + _: &LockedWebSocketWrite, + ) -> Result<(), WispError> { + Ok(()) + } - fn box_clone(&self) -> Box { - Box::new(self.clone()) - } + fn box_clone(&self) -> Box { + Box::new(self.clone()) + } } #[derive(Debug)] enum PasswordProtocolExtensionError { - Utf8Error(FromUtf8Error), - InvalidUsername, - InvalidPassword, + Utf8Error(FromUtf8Error), + InvalidUsername, + InvalidPassword, } impl Display for PasswordProtocolExtensionError { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - use PasswordProtocolExtensionError as E; - match self { - E::Utf8Error(e) => write!(f, "{}", e), - E::InvalidUsername => write!(f, "Invalid username"), - E::InvalidPassword => write!(f, "Invalid password"), - } - } + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + use PasswordProtocolExtensionError as E; + match self { + E::Utf8Error(e) => write!(f, "{}", e), + E::InvalidUsername => write!(f, "Invalid username"), + E::InvalidPassword => write!(f, "Invalid password"), + } + } } impl Error for PasswordProtocolExtensionError {} impl From for WispError { - fn from(value: PasswordProtocolExtensionError) -> Self { - WispError::ExtensionImplError(Box::new(value)) - } + fn from(value: PasswordProtocolExtensionError) -> Self { + WispError::ExtensionImplError(Box::new(value)) + } } impl From for PasswordProtocolExtensionError { - fn from(value: FromUtf8Error) -> Self { - PasswordProtocolExtensionError::Utf8Error(value) - } + fn from(value: FromUtf8Error) -> Self { + PasswordProtocolExtensionError::Utf8Error(value) + } } impl From for AnyProtocolExtension { - fn from(value: PasswordProtocolExtension) -> Self { - AnyProtocolExtension(Box::new(value)) - } + fn from(value: PasswordProtocolExtension) -> Self { + AnyProtocolExtension(Box::new(value)) + } } /// Password protocol extension builder. /// /// **Passwords are sent in plain text!!** pub struct PasswordProtocolExtensionBuilder { - /// Map of users and their passwords to allow. Only used on server. - pub users: HashMap, - /// Username to authenticate with. Only used on client. - pub username: String, - /// Password to authenticate with. Only used on client. - pub password: String, + /// Map of users and their passwords to allow. Only used on server. + pub users: HashMap, + /// Username to authenticate with. Only used on client. + pub username: String, + /// Password to authenticate with. Only used on client. + pub password: String, } impl PasswordProtocolExtensionBuilder { - /// Create a new password protocol extension builder for the server, with a map of users - /// and passwords to allow. - pub fn new_server(users: HashMap) -> Self { - Self { - users, - username: String::new(), - password: String::new(), - } - } + /// Create a new password protocol extension builder for the server, with a map of users + /// and passwords to allow. + pub fn new_server(users: HashMap) -> Self { + Self { + users, + username: String::new(), + password: String::new(), + } + } - /// Create a new password protocol extension builder for the client, with a username and - /// password to authenticate with. - pub fn new_client(username: String, password: String) -> Self { - Self { - users: HashMap::new(), - username, - password, - } - } + /// Create a new password protocol extension builder for the client, with a username and + /// password to authenticate with. + pub fn new_client(username: String, password: String) -> Self { + Self { + users: HashMap::new(), + username, + password, + } + } } impl ProtocolExtensionBuilder for PasswordProtocolExtensionBuilder { - fn get_id(&self) -> u8 { - PasswordProtocolExtension::ID - } + fn get_id(&self) -> u8 { + PasswordProtocolExtension::ID + } - fn build_from_bytes( - &self, - mut payload: Bytes, - role: crate::Role, - ) -> Result { - match role { - Role::Server => { - if payload.remaining() < 3 { - return Err(WispError::PacketTooSmall); - } + fn build_from_bytes( + &self, + mut payload: Bytes, + role: crate::Role, + ) -> Result { + match role { + Role::Server => { + if payload.remaining() < 3 { + return Err(WispError::PacketTooSmall); + } - let username_len = payload.get_u8(); - let password_len = payload.get_u16_le(); - if payload.remaining() < (password_len + username_len as u16) as usize { - return Err(WispError::PacketTooSmall); - } + let username_len = payload.get_u8(); + let password_len = payload.get_u16_le(); + if payload.remaining() < (password_len + username_len as u16) as usize { + return Err(WispError::PacketTooSmall); + } - use PasswordProtocolExtensionError as EError; - let username = - String::from_utf8(payload.copy_to_bytes(username_len as usize).to_vec()) - .map_err(|x| WispError::from(EError::from(x)))?; - let password = - String::from_utf8(payload.copy_to_bytes(password_len as usize).to_vec()) - .map_err(|x| WispError::from(EError::from(x)))?; + use PasswordProtocolExtensionError as EError; + let username = + String::from_utf8(payload.copy_to_bytes(username_len as usize).to_vec()) + .map_err(|x| WispError::from(EError::from(x)))?; + let password = + String::from_utf8(payload.copy_to_bytes(password_len as usize).to_vec()) + .map_err(|x| WispError::from(EError::from(x)))?; - let Some(user) = self.users.iter().find(|x| *x.0 == username) else { - return Err(EError::InvalidUsername.into()); - }; + let Some(user) = self.users.iter().find(|x| *x.0 == username) else { + return Err(EError::InvalidUsername.into()); + }; - if *user.1 != password { - return Err(EError::InvalidPassword.into()); - } + if *user.1 != password { + return Err(EError::InvalidPassword.into()); + } - Ok(PasswordProtocolExtension { - username, - password, - role, - } - .into()) - } - Role::Client => { - Ok(PasswordProtocolExtension::new_client(String::new(), String::new()).into()) - } - } - } + Ok(PasswordProtocolExtension { + username, + password, + role, + } + .into()) + } + Role::Client => { + Ok(PasswordProtocolExtension::new_client(String::new(), String::new()).into()) + } + } + } - fn build_to_extension(&self, role: Role) -> AnyProtocolExtension { - match role { - Role::Server => PasswordProtocolExtension::new_server(), - Role::Client => { - PasswordProtocolExtension::new_client(self.username.clone(), self.password.clone()) - } - } - .into() - } + fn build_to_extension(&self, role: Role) -> AnyProtocolExtension { + match role { + Role::Server => PasswordProtocolExtension::new_server(), + Role::Client => { + PasswordProtocolExtension::new_client(self.username.clone(), self.password.clone()) + } + } + .into() + } } diff --git a/wisp/src/extensions/udp.rs b/wisp/src/extensions/udp.rs index 068b5eb..8510277 100644 --- a/wisp/src/extensions/udp.rs +++ b/wisp/src/extensions/udp.rs @@ -6,88 +6,88 @@ //! rx, //! tx, //! 128, -//! Some(&[Box::new(UdpProtocolExtensionBuilder())]) +//! Some(&[Box::new(UdpProtocolExtensionBuilder)]) //! ); //! ``` -//! See [the docs](https://github.com/MercuryWorkshop/wisp-protocol/blob/main/protocol.md#0x01---udp) +//! See [the docs](https://github.com/MercuryWorkshop/wisp-protocol/blob/v2/protocol.md#0x01---udp) use async_trait::async_trait; use bytes::Bytes; use crate::{ - ws::{LockedWebSocketWrite, WebSocketRead}, - WispError, + ws::{LockedWebSocketWrite, WebSocketRead}, + WispError, }; use super::{AnyProtocolExtension, ProtocolExtension, ProtocolExtensionBuilder}; #[derive(Debug)] /// UDP protocol extension. -pub struct UdpProtocolExtension(); +pub struct UdpProtocolExtension; impl UdpProtocolExtension { - /// UDP protocol extension ID. - pub const ID: u8 = 0x01; + /// UDP protocol extension ID. + pub const ID: u8 = 0x01; } #[async_trait] impl ProtocolExtension for UdpProtocolExtension { - fn get_id(&self) -> u8 { - Self::ID - } + fn get_id(&self) -> u8 { + Self::ID + } - fn get_supported_packets(&self) -> &'static [u8] { - &[] - } + fn get_supported_packets(&self) -> &'static [u8] { + &[] + } - fn encode(&self) -> Bytes { - Bytes::new() - } + fn encode(&self) -> Bytes { + Bytes::new() + } - async fn handle_handshake( - &mut self, - _: &mut dyn WebSocketRead, - _: &LockedWebSocketWrite, - ) -> Result<(), WispError> { - Ok(()) - } + async fn handle_handshake( + &mut self, + _: &mut dyn WebSocketRead, + _: &LockedWebSocketWrite, + ) -> Result<(), WispError> { + Ok(()) + } - async fn handle_packet( - &mut self, - _: Bytes, - _: &mut dyn WebSocketRead, - _: &LockedWebSocketWrite, - ) -> Result<(), WispError> { - Ok(()) - } + async fn handle_packet( + &mut self, + _: Bytes, + _: &mut dyn WebSocketRead, + _: &LockedWebSocketWrite, + ) -> Result<(), WispError> { + Ok(()) + } - fn box_clone(&self) -> Box { - Box::new(Self()) - } + fn box_clone(&self) -> Box { + Box::new(Self) + } } impl From for AnyProtocolExtension { - fn from(value: UdpProtocolExtension) -> Self { - AnyProtocolExtension(Box::new(value)) - } + fn from(value: UdpProtocolExtension) -> Self { + AnyProtocolExtension(Box::new(value)) + } } /// UDP protocol extension builder. -pub struct UdpProtocolExtensionBuilder(); +pub struct UdpProtocolExtensionBuilder; impl ProtocolExtensionBuilder for UdpProtocolExtensionBuilder { - fn get_id(&self) -> u8 { - UdpProtocolExtension::ID - } + fn get_id(&self) -> u8 { + UdpProtocolExtension::ID + } - fn build_from_bytes( - &self, - _: Bytes, - _: crate::Role, - ) -> Result { - Ok(UdpProtocolExtension().into()) - } + fn build_from_bytes( + &self, + _: Bytes, + _: crate::Role, + ) -> Result { + Ok(UdpProtocolExtension.into()) + } - fn build_to_extension(&self, _: crate::Role) -> AnyProtocolExtension { - UdpProtocolExtension().into() - } + fn build_to_extension(&self, _: crate::Role) -> AnyProtocolExtension { + UdpProtocolExtension.into() + } } diff --git a/wisp/src/fastwebsockets.rs b/wisp/src/fastwebsockets.rs index 2525662..8604055 100644 --- a/wisp/src/fastwebsockets.rs +++ b/wisp/src/fastwebsockets.rs @@ -3,93 +3,100 @@ use std::ops::Deref; use async_trait::async_trait; use bytes::BytesMut; use fastwebsockets::{ - CloseCode, FragmentCollectorRead, Frame, OpCode, Payload, WebSocketError, WebSocketWrite, + CloseCode, FragmentCollectorRead, Frame, OpCode, Payload, WebSocketError, WebSocketWrite, }; use tokio::io::{AsyncRead, AsyncWrite}; use crate::{ws::LockedWebSocketWrite, WispError}; -fn match_payload(payload: Payload) -> BytesMut { +fn match_payload<'a>(payload: Payload<'a>) -> crate::ws::Payload<'a> { match payload { - Payload::Bytes(x) => x, - Payload::Owned(x) => BytesMut::from(x.deref()), - Payload::BorrowedMut(x) => BytesMut::from(x.deref()), - Payload::Borrowed(x) => BytesMut::from(x), + Payload::Bytes(x) => crate::ws::Payload::Bytes(x), + Payload::Owned(x) => crate::ws::Payload::Bytes(BytesMut::from(x.deref())), + Payload::BorrowedMut(x) => crate::ws::Payload::Borrowed(&*x), + Payload::Borrowed(x) => crate::ws::Payload::Borrowed(x), + } +} + +fn match_payload_reverse<'a>(payload: crate::ws::Payload<'a>) -> Payload<'a> { + match payload { + crate::ws::Payload::Bytes(x) => Payload::Bytes(x), + crate::ws::Payload::Borrowed(x) => Payload::Borrowed(x), } } impl From for crate::ws::OpCode { - fn from(opcode: OpCode) -> Self { - use OpCode::*; - match opcode { - Continuation => { - unreachable!("continuation should never be recieved when using a fragmentcollector") - } - Text => Self::Text, - Binary => Self::Binary, - Close => Self::Close, - Ping => Self::Ping, - Pong => Self::Pong, - } - } + fn from(opcode: OpCode) -> Self { + use OpCode::*; + match opcode { + Continuation => { + unreachable!("continuation should never be recieved when using a fragmentcollector") + } + Text => Self::Text, + Binary => Self::Binary, + Close => Self::Close, + Ping => Self::Ping, + Pong => Self::Pong, + } + } } -impl From> for crate::ws::Frame { - fn from(frame: Frame) -> Self { - Self { - finished: frame.fin, - opcode: frame.opcode.into(), - payload: match_payload(frame.payload), - } - } +impl<'a> From> for crate::ws::Frame<'a> { + fn from(frame: Frame<'a>) -> Self { + Self { + finished: frame.fin, + opcode: frame.opcode.into(), + payload: match_payload(frame.payload), + } + } } -impl<'a> From for Frame<'a> { - fn from(frame: crate::ws::Frame) -> Self { - use crate::ws::OpCode::*; - let payload = Payload::Bytes(frame.payload); - match frame.opcode { - Text => Self::text(payload), - Binary => Self::binary(payload), - Close => Self::close_raw(payload), - Ping => Self::new(true, OpCode::Ping, None, payload), - Pong => Self::pong(payload), - } - } +impl<'a> From> for Frame<'a> { + fn from(frame: crate::ws::Frame<'a>) -> Self { + use crate::ws::OpCode::*; + let payload = match_payload_reverse(frame.payload); + match frame.opcode { + Text => Self::text(payload), + Binary => Self::binary(payload), + Close => Self::close_raw(payload), + Ping => Self::new(true, OpCode::Ping, None, payload), + Pong => Self::pong(payload), + } + } } impl From for crate::WispError { - fn from(err: WebSocketError) -> Self { - if let WebSocketError::ConnectionClosed = err { - Self::WsImplSocketClosed - } else { - Self::WsImplError(Box::new(err)) - } - } + fn from(err: WebSocketError) -> Self { + if let WebSocketError::ConnectionClosed = err { + Self::WsImplSocketClosed + } else { + Self::WsImplError(Box::new(err)) + } + } } #[async_trait] impl crate::ws::WebSocketRead for FragmentCollectorRead { - async fn wisp_read_frame( - &mut self, - tx: &LockedWebSocketWrite, - ) -> Result { - Ok(self - .read_frame(&mut |frame| async { tx.write_frame(frame.into()).await }) - .await? - .into()) - } + async fn wisp_read_frame( + &mut self, + tx: &LockedWebSocketWrite, + ) -> Result, WispError> { + Ok(self + .read_frame(&mut |frame| async { tx.write_frame(frame.into()).await }) + .await? + .into()) + } } #[async_trait] impl crate::ws::WebSocketWrite for WebSocketWrite { - async fn wisp_write_frame(&mut self, frame: crate::ws::Frame) -> Result<(), WispError> { - self.write_frame(frame.into()).await.map_err(|e| e.into()) - } + async fn wisp_write_frame(&mut self, frame: crate::ws::Frame<'_>) -> Result<(), WispError> { + self.write_frame(frame.into()).await.map_err(|e| e.into()) + } - async fn wisp_close(&mut self) -> Result<(), WispError> { - self.write_frame(Frame::close(CloseCode::Normal.into(), b"")) - .await - .map_err(|e| e.into()) - } + async fn wisp_close(&mut self) -> Result<(), WispError> { + self.write_frame(Frame::close(CloseCode::Normal.into(), b"")) + .await + .map_err(|e| e.into()) + } } diff --git a/wisp/src/lib.rs b/wisp/src/lib.rs index ca33c3d..d7b71f8 100644 --- a/wisp/src/lib.rs +++ b/wisp/src/lib.rs @@ -15,7 +15,7 @@ pub mod ws; pub use crate::{packet::*, stream::*}; -use bytes::Bytes; +use bytes::{Bytes, BytesMut}; use dashmap::DashMap; use event_listener::Event; use extensions::{udp::UdpProtocolExtension, AnyProtocolExtension, ProtocolExtensionBuilder}; @@ -23,11 +23,11 @@ use flume as mpsc; use futures::{channel::oneshot, select, Future, FutureExt}; use futures_timer::Delay; use std::{ - sync::{ - atomic::{AtomicBool, AtomicU32, Ordering}, - Arc, - }, - time::Duration, + sync::{ + atomic::{AtomicBool, AtomicU32, Ordering}, + Arc, + }, + time::Duration, }; use ws::{AppendingWebSocketRead, LockedWebSocketWrite}; @@ -37,453 +37,457 @@ pub const WISP_VERSION: WispVersion = WispVersion { major: 2, minor: 0 }; /// The role of the multiplexor. #[derive(Debug, PartialEq, Copy, Clone)] pub enum Role { - /// Client side, can create new channels to proxy. - Client, - /// Server side, can listen for channels to proxy. - Server, + /// Client side, can create new channels to proxy. + Client, + /// Server side, can listen for channels to proxy. + Server, } /// Errors the Wisp implementation can return. #[derive(Debug)] pub enum WispError { - /// The packet received did not have enough data. - PacketTooSmall, - /// The packet received had an invalid type. - InvalidPacketType, - /// The stream had an invalid ID. - InvalidStreamId, - /// The close packet had an invalid reason. - InvalidCloseReason, - /// The URI received was invalid. - InvalidUri, - /// The URI received had no host. - UriHasNoHost, - /// The URI received had no port. - UriHasNoPort, - /// The max stream count was reached. - MaxStreamCountReached, - /// The Wisp protocol version was incompatible. - IncompatibleProtocolVersion, - /// The stream had already been closed. - StreamAlreadyClosed, - /// The websocket frame received had an invalid type. - WsFrameInvalidType, - /// The websocket frame received was not finished. - WsFrameNotFinished, - /// Error specific to the websocket implementation. - WsImplError(Box), - /// The websocket implementation socket closed. - WsImplSocketClosed, - /// The websocket implementation did not support the action. - WsImplNotSupported, - /// Error specific to the protocol extension implementation. - ExtensionImplError(Box), - /// The protocol extension implementation did not support the action. - ExtensionImplNotSupported, - /// The specified protocol extensions are not supported by the server. - ExtensionsNotSupported(Vec), - /// The string was invalid UTF-8. - Utf8Error(std::str::Utf8Error), - /// The integer failed to convert. - TryFromIntError(std::num::TryFromIntError), - /// Other error. - Other(Box), - /// Failed to send message to multiplexor task. - MuxMessageFailedToSend, - /// Failed to receive message from multiplexor task. - MuxMessageFailedToRecv, - /// Multiplexor task ended. - MuxTaskEnded, + /// The packet received did not have enough data. + PacketTooSmall, + /// The packet received had an invalid type. + InvalidPacketType, + /// The stream had an invalid ID. + InvalidStreamId, + /// The close packet had an invalid reason. + InvalidCloseReason, + /// The URI received was invalid. + InvalidUri, + /// The URI received had no host. + UriHasNoHost, + /// The URI received had no port. + UriHasNoPort, + /// The max stream count was reached. + MaxStreamCountReached, + /// The Wisp protocol version was incompatible. + IncompatibleProtocolVersion, + /// The stream had already been closed. + StreamAlreadyClosed, + /// The websocket frame received had an invalid type. + WsFrameInvalidType, + /// The websocket frame received was not finished. + WsFrameNotFinished, + /// Error specific to the websocket implementation. + WsImplError(Box), + /// The websocket implementation socket closed. + WsImplSocketClosed, + /// The websocket implementation did not support the action. + WsImplNotSupported, + /// Error specific to the protocol extension implementation. + ExtensionImplError(Box), + /// The protocol extension implementation did not support the action. + ExtensionImplNotSupported, + /// The specified protocol extensions are not supported by the server. + ExtensionsNotSupported(Vec), + /// The string was invalid UTF-8. + Utf8Error(std::str::Utf8Error), + /// The integer failed to convert. + TryFromIntError(std::num::TryFromIntError), + /// Other error. + Other(Box), + /// Failed to send message to multiplexor task. + MuxMessageFailedToSend, + /// Failed to receive message from multiplexor task. + MuxMessageFailedToRecv, + /// Multiplexor task ended. + MuxTaskEnded, } impl From for WispError { - fn from(err: std::str::Utf8Error) -> Self { - Self::Utf8Error(err) - } + fn from(err: std::str::Utf8Error) -> Self { + Self::Utf8Error(err) + } } impl From for WispError { - fn from(value: std::num::TryFromIntError) -> Self { - Self::TryFromIntError(value) - } + fn from(value: std::num::TryFromIntError) -> Self { + Self::TryFromIntError(value) + } } impl std::fmt::Display for WispError { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> { - match self { - Self::PacketTooSmall => write!(f, "Packet too small"), - Self::InvalidPacketType => write!(f, "Invalid packet type"), - Self::InvalidStreamId => write!(f, "Invalid stream id"), - Self::InvalidCloseReason => write!(f, "Invalid close reason"), - Self::InvalidUri => write!(f, "Invalid URI"), - Self::UriHasNoHost => write!(f, "URI has no host"), - Self::UriHasNoPort => write!(f, "URI has no port"), - Self::MaxStreamCountReached => write!(f, "Maximum stream count reached"), - Self::IncompatibleProtocolVersion => write!(f, "Incompatible Wisp protocol version"), - Self::StreamAlreadyClosed => write!(f, "Stream already closed"), - Self::WsFrameInvalidType => write!(f, "Invalid websocket frame type"), - Self::WsFrameNotFinished => write!(f, "Unfinished websocket frame"), - Self::WsImplError(err) => write!(f, "Websocket implementation error: {}", err), - Self::WsImplSocketClosed => { - write!(f, "Websocket implementation error: websocket closed") - } - Self::WsImplNotSupported => { - write!(f, "Websocket implementation error: unsupported feature") - } - Self::ExtensionImplError(err) => { - write!(f, "Protocol extension implementation error: {}", err) - } - Self::ExtensionImplNotSupported => { - write!( - f, - "Protocol extension implementation error: unsupported feature" - ) - } - Self::ExtensionsNotSupported(list) => { - write!(f, "Protocol extensions {:?} not supported", list) - } - Self::Utf8Error(err) => write!(f, "UTF-8 error: {}", err), - Self::TryFromIntError(err) => write!(f, "Integer conversion error: {}", err), - Self::Other(err) => write!(f, "Other error: {}", err), - Self::MuxMessageFailedToSend => write!(f, "Failed to send multiplexor message"), - Self::MuxMessageFailedToRecv => write!(f, "Failed to receive multiplexor message"), - Self::MuxTaskEnded => write!(f, "Multiplexor task ended"), - } - } + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> { + match self { + Self::PacketTooSmall => write!(f, "Packet too small"), + Self::InvalidPacketType => write!(f, "Invalid packet type"), + Self::InvalidStreamId => write!(f, "Invalid stream id"), + Self::InvalidCloseReason => write!(f, "Invalid close reason"), + Self::InvalidUri => write!(f, "Invalid URI"), + Self::UriHasNoHost => write!(f, "URI has no host"), + Self::UriHasNoPort => write!(f, "URI has no port"), + Self::MaxStreamCountReached => write!(f, "Maximum stream count reached"), + Self::IncompatibleProtocolVersion => write!(f, "Incompatible Wisp protocol version"), + Self::StreamAlreadyClosed => write!(f, "Stream already closed"), + Self::WsFrameInvalidType => write!(f, "Invalid websocket frame type"), + Self::WsFrameNotFinished => write!(f, "Unfinished websocket frame"), + Self::WsImplError(err) => write!(f, "Websocket implementation error: {}", err), + Self::WsImplSocketClosed => { + write!(f, "Websocket implementation error: websocket closed") + } + Self::WsImplNotSupported => { + write!(f, "Websocket implementation error: unsupported feature") + } + Self::ExtensionImplError(err) => { + write!(f, "Protocol extension implementation error: {}", err) + } + Self::ExtensionImplNotSupported => { + write!( + f, + "Protocol extension implementation error: unsupported feature" + ) + } + Self::ExtensionsNotSupported(list) => { + write!(f, "Protocol extensions {:?} not supported", list) + } + Self::Utf8Error(err) => write!(f, "UTF-8 error: {}", err), + Self::TryFromIntError(err) => write!(f, "Integer conversion error: {}", err), + Self::Other(err) => write!(f, "Other error: {}", err), + Self::MuxMessageFailedToSend => write!(f, "Failed to send multiplexor message"), + Self::MuxMessageFailedToRecv => write!(f, "Failed to receive multiplexor message"), + Self::MuxTaskEnded => write!(f, "Multiplexor task ended"), + } + } } impl std::error::Error for WispError {} struct MuxMapValue { - stream: mpsc::Sender, - stream_type: StreamType, - flow_control: Arc, - flow_control_event: Arc, - is_closed: Arc, - is_closed_event: Arc, + stream: mpsc::Sender, + stream_type: StreamType, + flow_control: Arc, + flow_control_event: Arc, + is_closed: Arc, + is_closed_event: Arc, } struct MuxInner { - tx: ws::LockedWebSocketWrite, - stream_map: DashMap, - buffer_size: u32, - fut_exited: Arc + tx: ws::LockedWebSocketWrite, + stream_map: DashMap, + buffer_size: u32, + fut_exited: Arc, } impl MuxInner { - pub async fn server_into_future( - self, - rx: R, - extensions: Vec, - close_rx: mpsc::Receiver, - muxstream_sender: mpsc::Sender<(ConnectPacket, MuxStream)>, - close_tx: mpsc::Sender, - ) -> Result<(), WispError> - where - R: ws::WebSocketRead + Send, - { - self.as_future( - close_rx, - close_tx.clone(), - self.server_loop(rx, extensions, muxstream_sender, close_tx), - ) - .await - } + pub async fn server_into_future( + self, + rx: R, + extensions: Vec, + close_rx: mpsc::Receiver, + muxstream_sender: mpsc::Sender<(ConnectPacket, MuxStream)>, + close_tx: mpsc::Sender, + ) -> Result<(), WispError> + where + R: ws::WebSocketRead + Send, + { + self.as_future( + close_rx, + close_tx.clone(), + self.server_loop(rx, extensions, muxstream_sender, close_tx), + ) + .await + } - pub async fn client_into_future( - self, - rx: R, - extensions: Vec, - close_rx: mpsc::Receiver, - close_tx: mpsc::Sender, - ) -> Result<(), WispError> - where - R: ws::WebSocketRead + Send, - { - self.as_future(close_rx, close_tx, self.client_loop(rx, extensions)) - .await - } + pub async fn client_into_future( + self, + rx: R, + extensions: Vec, + close_rx: mpsc::Receiver, + close_tx: mpsc::Sender, + ) -> Result<(), WispError> + where + R: ws::WebSocketRead + Send, + { + self.as_future(close_rx, close_tx, self.client_loop(rx, extensions)) + .await + } - async fn as_future( - &self, - close_rx: mpsc::Receiver, - close_tx: mpsc::Sender, - wisp_fut: impl Future>, - ) -> Result<(), WispError> { - let ret = futures::select! { - _ = self.stream_loop(close_rx, close_tx).fuse() => Ok(()), - x = wisp_fut.fuse() => x, - }; - self.fut_exited.store(true, Ordering::Release); - for x in self.stream_map.iter_mut() { - x.is_closed.store(true, Ordering::Release); - x.is_closed_event.notify(usize::MAX); - } - self.stream_map.clear(); - let _ = self.tx.close().await; - ret - } + async fn as_future( + &self, + close_rx: mpsc::Receiver, + close_tx: mpsc::Sender, + wisp_fut: impl Future>, + ) -> Result<(), WispError> { + let ret = futures::select! { + _ = self.stream_loop(close_rx, close_tx).fuse() => Ok(()), + x = wisp_fut.fuse() => x, + }; + self.fut_exited.store(true, Ordering::Release); + for x in self.stream_map.iter_mut() { + x.is_closed.store(true, Ordering::Release); + x.is_closed_event.notify(usize::MAX); + } + self.stream_map.clear(); + let _ = self.tx.close().await; + ret + } - async fn create_new_stream( - &self, - stream_id: u32, - stream_type: StreamType, - role: Role, - stream_tx: mpsc::Sender, - tx: LockedWebSocketWrite, - target_buffer_size: u32, - ) -> Result<(MuxMapValue, MuxStream), WispError> { - let (ch_tx, ch_rx) = mpsc::bounded(self.buffer_size as usize); + async fn create_new_stream( + &self, + stream_id: u32, + stream_type: StreamType, + role: Role, + stream_tx: mpsc::Sender, + tx: LockedWebSocketWrite, + target_buffer_size: u32, + ) -> Result<(MuxMapValue, MuxStream), WispError> { + let (ch_tx, ch_rx) = mpsc::bounded(self.buffer_size as usize); - let flow_control_event: Arc = Event::new().into(); - let flow_control: Arc = AtomicU32::new(self.buffer_size).into(); + let flow_control_event: Arc = Event::new().into(); + let flow_control: Arc = AtomicU32::new(self.buffer_size).into(); - let is_closed: Arc = AtomicBool::new(false).into(); - let is_closed_event: Arc = Event::new().into(); + let is_closed: Arc = AtomicBool::new(false).into(); + let is_closed_event: Arc = Event::new().into(); - Ok(( - MuxMapValue { - stream: ch_tx, - stream_type, - flow_control: flow_control.clone(), - flow_control_event: flow_control_event.clone(), - is_closed: is_closed.clone(), - is_closed_event: is_closed_event.clone(), - }, - MuxStream::new( - stream_id, - role, - stream_type, - ch_rx, - stream_tx, - tx, - is_closed, - is_closed_event, - flow_control, - flow_control_event, - target_buffer_size, - ), - )) - } + Ok(( + MuxMapValue { + stream: ch_tx, + stream_type, + flow_control: flow_control.clone(), + flow_control_event: flow_control_event.clone(), + is_closed: is_closed.clone(), + is_closed_event: is_closed_event.clone(), + }, + MuxStream::new( + stream_id, + role, + stream_type, + ch_rx, + stream_tx, + tx, + is_closed, + is_closed_event, + flow_control, + flow_control_event, + target_buffer_size, + ), + )) + } - async fn stream_loop( - &self, - stream_rx: mpsc::Receiver, - stream_tx: mpsc::Sender, - ) { - let mut next_free_stream_id: u32 = 1; - while let Ok(msg) = stream_rx.recv_async().await { - match msg { - WsEvent::CreateStream(stream_type, host, port, channel) => { - let ret: Result = async { - let stream_id = next_free_stream_id; - let next_stream_id = next_free_stream_id - .checked_add(1) - .ok_or(WispError::MaxStreamCountReached)?; + async fn stream_loop( + &self, + stream_rx: mpsc::Receiver, + stream_tx: mpsc::Sender, + ) { + let mut next_free_stream_id: u32 = 1; + while let Ok(msg) = stream_rx.recv_async().await { + match msg { + WsEvent::CreateStream(stream_type, host, port, channel) => { + let ret: Result = async { + let stream_id = next_free_stream_id; + let next_stream_id = next_free_stream_id + .checked_add(1) + .ok_or(WispError::MaxStreamCountReached)?; - let (map_value, stream) = self - .create_new_stream( - stream_id, - stream_type, - Role::Client, - stream_tx.clone(), - self.tx.clone(), - 0, - ) - .await?; + let (map_value, stream) = self + .create_new_stream( + stream_id, + stream_type, + Role::Client, + stream_tx.clone(), + self.tx.clone(), + 0, + ) + .await?; - self.tx - .write_frame( - Packet::new_connect(stream_id, stream_type, port, host).into(), - ) - .await?; + self.tx + .write_frame( + Packet::new_connect(stream_id, stream_type, port, host).into(), + ) + .await?; - self.stream_map.insert(stream_id, map_value); + self.stream_map.insert(stream_id, map_value); - next_free_stream_id = next_stream_id; + next_free_stream_id = next_stream_id; - Ok(stream) - } - .await; - let _ = channel.send(ret); - } - WsEvent::Close(packet, channel) => { - if let Some((_, stream)) = self.stream_map.remove(&packet.stream_id) { - let _ = channel.send(self.tx.write_frame(packet.into()).await); - drop(stream.stream) - } else { - let _ = channel.send(Err(WispError::InvalidStreamId)); - } - } - WsEvent::EndFut(x) => { - if let Some(reason) = x { - let _ = self - .tx - .write_frame(Packet::new_close(0, reason).into()) - .await; - } - break; - } - } - } - } + Ok(stream) + } + .await; + let _ = channel.send(ret); + } + WsEvent::Close(packet, channel) => { + if let Some((_, stream)) = self.stream_map.remove(&packet.stream_id) { + let _ = channel.send(self.tx.write_frame(packet.into()).await); + drop(stream.stream) + } else { + let _ = channel.send(Err(WispError::InvalidStreamId)); + } + } + WsEvent::EndFut(x) => { + if let Some(reason) = x { + let _ = self + .tx + .write_frame(Packet::new_close(0, reason).into()) + .await; + } + break; + } + } + } + } - fn close_stream(&self, packet: Packet) { - if let Some((_, stream)) = self.stream_map.remove(&packet.stream_id) { - stream.is_closed.store(true, Ordering::Release); - stream.is_closed_event.notify(usize::MAX); - stream.flow_control.store(u32::MAX, Ordering::Release); - stream.flow_control_event.notify(usize::MAX); - drop(stream.stream) - } - } + fn close_stream(&self, packet: Packet) { + if let Some((_, stream)) = self.stream_map.remove(&packet.stream_id) { + stream.is_closed.store(true, Ordering::Release); + stream.is_closed_event.notify(usize::MAX); + stream.flow_control.store(u32::MAX, Ordering::Release); + stream.flow_control_event.notify(usize::MAX); + drop(stream.stream) + } + } - async fn server_loop( - &self, - mut rx: R, - mut extensions: Vec, - muxstream_sender: mpsc::Sender<(ConnectPacket, MuxStream)>, - stream_tx: mpsc::Sender, - ) -> Result<(), WispError> - where - R: ws::WebSocketRead + Send, - { - // will send continues once flow_control is at 10% of max - let target_buffer_size = ((self.buffer_size as u64 * 90) / 100) as u32; + async fn server_loop( + &self, + mut rx: R, + mut extensions: Vec, + muxstream_sender: mpsc::Sender<(ConnectPacket, MuxStream)>, + stream_tx: mpsc::Sender, + ) -> Result<(), WispError> + where + R: ws::WebSocketRead + Send, + { + // will send continues once flow_control is at 10% of max + let target_buffer_size = ((self.buffer_size as u64 * 90) / 100) as u32; - loop { - let frame = rx.wisp_read_frame(&self.tx).await?; - if frame.opcode == ws::OpCode::Close { - break Ok(()); - } - if let Some(packet) = - Packet::maybe_handle_extension(frame, &mut extensions, &mut rx, &self.tx).await? - { - use PacketType::*; - match packet.packet_type { - Continue(_) | Info(_) => break Err(WispError::InvalidPacketType), - Connect(inner_packet) => { - let (map_value, stream) = self - .create_new_stream( - packet.stream_id, - inner_packet.stream_type, - Role::Server, - stream_tx.clone(), - self.tx.clone(), - target_buffer_size, - ) - .await?; - muxstream_sender - .send_async((inner_packet, stream)) - .await - .map_err(|_| WispError::MuxMessageFailedToSend)?; - self.stream_map.insert(packet.stream_id, map_value); - } - Data(data) => { - if let Some(stream) = self.stream_map.get(&packet.stream_id) { - let _ = stream.stream.try_send(data); - if stream.stream_type == StreamType::Tcp { - stream.flow_control.store( - stream - .flow_control - .load(Ordering::Acquire) - .saturating_sub(1), - Ordering::Release, - ); - } - } - } - Close(_) => { - if packet.stream_id == 0 { - break Ok(()); - } - self.close_stream(packet) - } - } - } - } - } + loop { + let frame = rx.wisp_read_frame(&self.tx).await?; + if frame.opcode == ws::OpCode::Close { + break Ok(()); + } + if let Some(packet) = + Packet::maybe_handle_extension(frame, &mut extensions, &mut rx, &self.tx).await? + { + use PacketType::*; + match packet.packet_type { + Continue(_) | Info(_) => break Err(WispError::InvalidPacketType), + Connect(inner_packet) => { + let (map_value, stream) = self + .create_new_stream( + packet.stream_id, + inner_packet.stream_type, + Role::Server, + stream_tx.clone(), + self.tx.clone(), + target_buffer_size, + ) + .await?; + muxstream_sender + .send_async((inner_packet, stream)) + .await + .map_err(|_| WispError::MuxMessageFailedToSend)?; + self.stream_map.insert(packet.stream_id, map_value); + } + Data(data) => { + if let Some(stream) = self.stream_map.get(&packet.stream_id) { + let _ = stream.stream.try_send(BytesMut::from(data).freeze()); + if stream.stream_type == StreamType::Tcp { + stream.flow_control.store( + stream + .flow_control + .load(Ordering::Acquire) + .saturating_sub(1), + Ordering::Release, + ); + } + } + } + Close(_) => { + if packet.stream_id == 0 { + break Ok(()); + } + self.close_stream(packet) + } + } + } + } + } - async fn client_loop( - &self, - mut rx: R, - mut extensions: Vec, - ) -> Result<(), WispError> - where - R: ws::WebSocketRead + Send, - { - loop { - let frame = rx.wisp_read_frame(&self.tx).await?; - if frame.opcode == ws::OpCode::Close { - break Ok(()); - } - if let Some(packet) = - Packet::maybe_handle_extension(frame, &mut extensions, &mut rx, &self.tx).await? - { - use PacketType::*; - match packet.packet_type { - Connect(_) | Info(_) => break Err(WispError::InvalidPacketType), - Data(data) => { - if let Some(stream) = self.stream_map.get(&packet.stream_id) { - let _ = stream.stream.send_async(data).await; - } - } - Continue(inner_packet) => { - if let Some(stream) = self.stream_map.get(&packet.stream_id) { - if stream.stream_type == StreamType::Tcp { - stream - .flow_control - .store(inner_packet.buffer_remaining, Ordering::Release); - let _ = stream.flow_control_event.notify(u32::MAX); - } - } - } - Close(_) => { - if packet.stream_id == 0 { - break Ok(()); - } - self.close_stream(packet) - } - } - } - } - } + async fn client_loop( + &self, + mut rx: R, + mut extensions: Vec, + ) -> Result<(), WispError> + where + R: ws::WebSocketRead + Send, + { + loop { + let frame = rx.wisp_read_frame(&self.tx).await?; + if frame.opcode == ws::OpCode::Close { + break Ok(()); + } + + if let Some(packet) = + Packet::maybe_handle_extension(frame, &mut extensions, &mut rx, &self.tx).await? + { + use PacketType::*; + match packet.packet_type { + Connect(_) | Info(_) => break Err(WispError::InvalidPacketType), + Data(data) => { + if let Some(stream) = self.stream_map.get(&packet.stream_id) { + let _ = stream + .stream + .send_async(BytesMut::from(data).freeze()) + .await; + } + } + Continue(inner_packet) => { + if let Some(stream) = self.stream_map.get(&packet.stream_id) { + if stream.stream_type == StreamType::Tcp { + stream + .flow_control + .store(inner_packet.buffer_remaining, Ordering::Release); + let _ = stream.flow_control_event.notify(u32::MAX); + } + } + } + Close(_) => { + if packet.stream_id == 0 { + break Ok(()); + } + self.close_stream(packet) + } + } + } + } + } } async fn maybe_wisp_v2( - read: &mut R, - write: &LockedWebSocketWrite, - builders: &[Box], -) -> Result<(Vec, Option, bool), WispError> + read: &mut R, + write: &LockedWebSocketWrite, + builders: &[Box], +) -> Result<(Vec, Option>, bool), WispError> where - R: ws::WebSocketRead + Send, + R: ws::WebSocketRead + Send, { - let mut supported_extensions = Vec::new(); - let mut extra_packet = None; - let mut downgraded = true; + let mut supported_extensions = Vec::new(); + let mut extra_packet: Option> = None; + let mut downgraded = true; - let extension_ids: Vec<_> = builders.iter().map(|x| x.get_id()).collect(); - if let Some(frame) = select! { - x = read.wisp_read_frame(write).fuse() => Some(x?), - _ = Delay::new(Duration::from_secs(5)).fuse() => None - } { - let packet = Packet::maybe_parse_info(frame, Role::Client, builders)?; - if let PacketType::Info(info) = packet.packet_type { - supported_extensions = info - .extensions - .into_iter() - .filter(|x| extension_ids.contains(&x.get_id())) - .collect(); - downgraded = false; - } else { - extra_packet.replace(packet.into()); - } - } + let extension_ids: Vec<_> = builders.iter().map(|x| x.get_id()).collect(); + if let Some(frame) = select! { + x = read.wisp_read_frame(write).fuse() => Some(x?), + _ = Delay::new(Duration::from_secs(5)).fuse() => None + } { + let packet = Packet::maybe_parse_info(frame, Role::Client, builders)?; + if let PacketType::Info(info) = packet.packet_type { + supported_extensions = info + .extensions + .into_iter() + .filter(|x| extension_ids.contains(&x.get_id())) + .collect(); + downgraded = false; + } else { + extra_packet.replace(ws::Frame::from(packet).clone()); + } + } - for extension in supported_extensions.iter_mut() { - extension.handle_handshake(read, write).await?; - } - Ok((supported_extensions, extra_packet, downgraded)) + for extension in supported_extensions.iter_mut() { + extension.handle_handshake(read, write).await?; + } + Ok((supported_extensions, extra_packet, downgraded)) } /// Server-side multiplexor. @@ -506,175 +510,175 @@ where /// } /// ``` pub struct ServerMux { - /// Whether the connection was downgraded to Wisp v1. - /// - /// If this variable is true you must assume no extensions are supported. - pub downgraded: bool, - /// Extensions that are supported by both sides. - pub supported_extension_ids: Vec, - close_tx: mpsc::Sender, - muxstream_recv: mpsc::Receiver<(ConnectPacket, MuxStream)>, - tx: ws::LockedWebSocketWrite, - fut_exited: Arc, + /// Whether the connection was downgraded to Wisp v1. + /// + /// If this variable is true you must assume no extensions are supported. + pub downgraded: bool, + /// Extensions that are supported by both sides. + pub supported_extension_ids: Vec, + close_tx: mpsc::Sender, + muxstream_recv: mpsc::Receiver<(ConnectPacket, MuxStream)>, + tx: ws::LockedWebSocketWrite, + fut_exited: Arc, } impl ServerMux { - /// Create a new server-side multiplexor. - /// - /// If `extension_builders` is None a Wisp v1 connection is created otherwise a Wisp v2 connection is created. - /// **It is not guaranteed that all extensions you specify are available.** You must manually check - /// if the extensions you need are available after the multiplexor has been created. - pub async fn create( - mut read: R, - write: W, - buffer_size: u32, - extension_builders: Option<&[Box]>, - ) -> Result> + Send>, WispError> - where - R: ws::WebSocketRead + Send, - W: ws::WebSocketWrite + Send + 'static, - { - let (close_tx, close_rx) = mpsc::bounded::(256); - let (tx, rx) = mpsc::unbounded::<(ConnectPacket, MuxStream)>(); - let write = ws::LockedWebSocketWrite::new(Box::new(write)); - let fut_exited = Arc::new(AtomicBool::new(false)); + /// Create a new server-side multiplexor. + /// + /// If `extension_builders` is None a Wisp v1 connection is created otherwise a Wisp v2 connection is created. + /// **It is not guaranteed that all extensions you specify are available.** You must manually check + /// if the extensions you need are available after the multiplexor has been created. + pub async fn create( + mut read: R, + write: W, + buffer_size: u32, + extension_builders: Option<&[Box]>, + ) -> Result> + Send>, WispError> + where + R: ws::WebSocketRead + Send, + W: ws::WebSocketWrite + Send + 'static, + { + let (close_tx, close_rx) = mpsc::bounded::(256); + let (tx, rx) = mpsc::unbounded::<(ConnectPacket, MuxStream)>(); + let write = ws::LockedWebSocketWrite::new(Box::new(write)); + let fut_exited = Arc::new(AtomicBool::new(false)); - write - .write_frame(Packet::new_continue(0, buffer_size).into()) - .await?; + write + .write_frame(Packet::new_continue(0, buffer_size).into()) + .await?; - let (supported_extensions, extra_packet, downgraded) = - if let Some(builders) = extension_builders { - write - .write_frame( - Packet::new_info( - builders - .iter() - .map(|x| x.build_to_extension(Role::Client)) - .collect(), - ) - .into(), - ) - .await?; - maybe_wisp_v2(&mut read, &write, builders).await? - } else { - (Vec::new(), None, true) - }; + let (supported_extensions, extra_packet, downgraded) = + if let Some(builders) = extension_builders { + write + .write_frame( + Packet::new_info( + builders + .iter() + .map(|x| x.build_to_extension(Role::Client)) + .collect(), + ) + .into(), + ) + .await?; + maybe_wisp_v2(&mut read, &write, builders).await? + } else { + (Vec::new(), None, true) + }; - Ok(ServerMuxResult( - Self { - muxstream_recv: rx, - close_tx: close_tx.clone(), - downgraded, - supported_extension_ids: supported_extensions.iter().map(|x| x.get_id()).collect(), - tx: write.clone(), - fut_exited: fut_exited.clone(), - }, - MuxInner { - tx: write, - stream_map: DashMap::new(), - buffer_size, - fut_exited - } - .server_into_future( - AppendingWebSocketRead(extra_packet, read), - supported_extensions, - close_rx, - tx, - close_tx, - ), - )) - } + Ok(ServerMuxResult( + Self { + muxstream_recv: rx, + close_tx: close_tx.clone(), + downgraded, + supported_extension_ids: supported_extensions.iter().map(|x| x.get_id()).collect(), + tx: write.clone(), + fut_exited: fut_exited.clone(), + }, + MuxInner { + tx: write, + stream_map: DashMap::new(), + buffer_size, + fut_exited, + } + .server_into_future( + AppendingWebSocketRead(extra_packet, read), + supported_extensions, + close_rx, + tx, + close_tx, + ), + )) + } - /// Wait for a stream to be created. - pub async fn server_new_stream(&self) -> Option<(ConnectPacket, MuxStream)> { - if self.fut_exited.load(Ordering::Acquire) { - return None; - } - self.muxstream_recv.recv_async().await.ok() - } + /// Wait for a stream to be created. + pub async fn server_new_stream(&self) -> Option<(ConnectPacket, MuxStream)> { + if self.fut_exited.load(Ordering::Acquire) { + return None; + } + self.muxstream_recv.recv_async().await.ok() + } - async fn close_internal(&self, reason: Option) -> Result<(), WispError> { - if self.fut_exited.load(Ordering::Acquire) { - return Err(WispError::MuxTaskEnded); - } - self.close_tx - .send_async(WsEvent::EndFut(reason)) - .await - .map_err(|_| WispError::MuxMessageFailedToSend) - } + async fn close_internal(&self, reason: Option) -> Result<(), WispError> { + if self.fut_exited.load(Ordering::Acquire) { + return Err(WispError::MuxTaskEnded); + } + self.close_tx + .send_async(WsEvent::EndFut(reason)) + .await + .map_err(|_| WispError::MuxMessageFailedToSend) + } - /// Close all streams. - /// - /// Also terminates the multiplexor future. - pub async fn close(&self) -> Result<(), WispError> { - self.close_internal(None).await - } + /// Close all streams. + /// + /// Also terminates the multiplexor future. + pub async fn close(&self) -> Result<(), WispError> { + self.close_internal(None).await + } - /// Close all streams and send an extension incompatibility error to the client. - /// - /// Also terminates the multiplexor future. - pub async fn close_extension_incompat(&self) -> Result<(), WispError> { - self.close_internal(Some(CloseReason::IncompatibleExtensions)) - .await - } + /// Close all streams and send an extension incompatibility error to the client. + /// + /// Also terminates the multiplexor future. + pub async fn close_extension_incompat(&self) -> Result<(), WispError> { + self.close_internal(Some(CloseReason::IncompatibleExtensions)) + .await + } - /// Get a protocol extension stream for sending packets with stream id 0. - pub fn get_protocol_extension_stream(&self) -> MuxProtocolExtensionStream { - MuxProtocolExtensionStream { - stream_id: 0, - tx: self.tx.clone(), - is_closed: self.fut_exited.clone(), - } - } + /// Get a protocol extension stream for sending packets with stream id 0. + pub fn get_protocol_extension_stream(&self) -> MuxProtocolExtensionStream { + MuxProtocolExtensionStream { + stream_id: 0, + tx: self.tx.clone(), + is_closed: self.fut_exited.clone(), + } + } } impl Drop for ServerMux { - fn drop(&mut self) { - let _ = self.close_tx.send(WsEvent::EndFut(None)); - } + fn drop(&mut self) { + let _ = self.close_tx.send(WsEvent::EndFut(None)); + } } /// Result of `ServerMux::new`. pub struct ServerMuxResult(ServerMux, F) where - F: Future> + Send; + F: Future> + Send; impl ServerMuxResult where - F: Future> + Send, + F: Future> + Send, { - /// Require no protocol extensions. - pub fn with_no_required_extensions(self) -> (ServerMux, F) { - (self.0, self.1) - } + /// Require no protocol extensions. + pub fn with_no_required_extensions(self) -> (ServerMux, F) { + (self.0, self.1) + } - /// Require protocol extensions by their ID. Will close the multiplexor connection if - /// extensions are not supported. - pub async fn with_required_extensions( - self, - extensions: &[u8], - ) -> Result<(ServerMux, F), WispError> { - let mut unsupported_extensions = Vec::new(); - for extension in extensions { - if !self.0.supported_extension_ids.contains(extension) { - unsupported_extensions.push(*extension); - } - } - if unsupported_extensions.is_empty() { - Ok((self.0, self.1)) - } else { - self.0.close_extension_incompat().await?; - self.1.await?; - Err(WispError::ExtensionsNotSupported(unsupported_extensions)) - } - } + /// Require protocol extensions by their ID. Will close the multiplexor connection if + /// extensions are not supported. + pub async fn with_required_extensions( + self, + extensions: &[u8], + ) -> Result<(ServerMux, F), WispError> { + let mut unsupported_extensions = Vec::new(); + for extension in extensions { + if !self.0.supported_extension_ids.contains(extension) { + unsupported_extensions.push(*extension); + } + } + if unsupported_extensions.is_empty() { + Ok((self.0, self.1)) + } else { + self.0.close_extension_incompat().await?; + self.1.await?; + Err(WispError::ExtensionsNotSupported(unsupported_extensions)) + } + } - /// Shorthand for `with_required_extensions(&[UdpProtocolExtension::ID])` - pub async fn with_udp_extension_required(self) -> Result<(ServerMux, F), WispError> { - self.with_required_extensions(&[UdpProtocolExtension::ID]) - .await - } + /// Shorthand for `with_required_extensions(&[UdpProtocolExtension::ID])` + pub async fn with_udp_extension_required(self) -> Result<(ServerMux, F), WispError> { + self.with_required_extensions(&[UdpProtocolExtension::ID]) + .await + } } /// Client side multiplexor. @@ -692,195 +696,195 @@ where /// let stream = mux.client_new_stream(StreamType::Tcp, "google.com", 80); /// ``` pub struct ClientMux { - /// Whether the connection was downgraded to Wisp v1. - /// - /// If this variable is true you must assume no extensions are supported. - pub downgraded: bool, - /// Extensions that are supported by both sides. - pub supported_extension_ids: Vec, - stream_tx: mpsc::Sender, - tx: ws::LockedWebSocketWrite, - fut_exited: Arc, + /// Whether the connection was downgraded to Wisp v1. + /// + /// If this variable is true you must assume no extensions are supported. + pub downgraded: bool, + /// Extensions that are supported by both sides. + pub supported_extension_ids: Vec, + stream_tx: mpsc::Sender, + tx: ws::LockedWebSocketWrite, + fut_exited: Arc, } impl ClientMux { - /// Create a new client side multiplexor. - /// - /// If `extension_builders` is None a Wisp v1 connection is created otherwise a Wisp v2 connection is created. - /// **It is not guaranteed that all extensions you specify are available.** You must manually check - /// if the extensions you need are available after the multiplexor has been created. - pub async fn create( - mut read: R, - write: W, - extension_builders: Option<&[Box]>, - ) -> Result> + Send>, WispError> - where - R: ws::WebSocketRead + Send, - W: ws::WebSocketWrite + Send + 'static, - { - let write = ws::LockedWebSocketWrite::new(Box::new(write)); - let first_packet = Packet::try_from(read.wisp_read_frame(&write).await?)?; - let fut_exited = Arc::new(AtomicBool::new(false)); + /// Create a new client side multiplexor. + /// + /// If `extension_builders` is None a Wisp v1 connection is created otherwise a Wisp v2 connection is created. + /// **It is not guaranteed that all extensions you specify are available.** You must manually check + /// if the extensions you need are available after the multiplexor has been created. + pub async fn create( + mut read: R, + write: W, + extension_builders: Option<&[Box]>, + ) -> Result> + Send>, WispError> + where + R: ws::WebSocketRead + Send, + W: ws::WebSocketWrite + Send + 'static, + { + let write = ws::LockedWebSocketWrite::new(Box::new(write)); + let first_packet = Packet::try_from(read.wisp_read_frame(&write).await?)?; + let fut_exited = Arc::new(AtomicBool::new(false)); - if first_packet.stream_id != 0 { - return Err(WispError::InvalidStreamId); - } - if let PacketType::Continue(packet) = first_packet.packet_type { - let (supported_extensions, extra_packet, downgraded) = - if let Some(builders) = extension_builders { - let x = maybe_wisp_v2(&mut read, &write, builders).await?; - write - .write_frame( - Packet::new_info( - builders - .iter() - .map(|x| x.build_to_extension(Role::Client)) - .collect(), - ) - .into(), - ) - .await?; - x - } else { - (Vec::new(), None, true) - }; + if first_packet.stream_id != 0 { + return Err(WispError::InvalidStreamId); + } + if let PacketType::Continue(packet) = first_packet.packet_type { + let (supported_extensions, extra_packet, downgraded) = + if let Some(builders) = extension_builders { + let x = maybe_wisp_v2(&mut read, &write, builders).await?; + write + .write_frame( + Packet::new_info( + builders + .iter() + .map(|x| x.build_to_extension(Role::Client)) + .collect(), + ) + .into(), + ) + .await?; + x + } else { + (Vec::new(), None, true) + }; - let (tx, rx) = mpsc::bounded::(256); - Ok(ClientMuxResult( - Self { - stream_tx: tx.clone(), - downgraded, - supported_extension_ids: supported_extensions - .iter() - .map(|x| x.get_id()) - .collect(), - tx: write.clone(), - fut_exited: fut_exited.clone(), - }, - MuxInner { - tx: write, - stream_map: DashMap::new(), - buffer_size: packet.buffer_remaining, - fut_exited - } - .client_into_future( - AppendingWebSocketRead(extra_packet, read), - supported_extensions, - rx, - tx, - ), - )) - } else { - Err(WispError::InvalidPacketType) - } - } + let (tx, rx) = mpsc::bounded::(256); + Ok(ClientMuxResult( + Self { + stream_tx: tx.clone(), + downgraded, + supported_extension_ids: supported_extensions + .iter() + .map(|x| x.get_id()) + .collect(), + tx: write.clone(), + fut_exited: fut_exited.clone(), + }, + MuxInner { + tx: write, + stream_map: DashMap::new(), + buffer_size: packet.buffer_remaining, + fut_exited, + } + .client_into_future( + AppendingWebSocketRead(extra_packet, read), + supported_extensions, + rx, + tx, + ), + )) + } else { + Err(WispError::InvalidPacketType) + } + } - /// Create a new stream, multiplexed through Wisp. - pub async fn client_new_stream( - &self, - stream_type: StreamType, - host: String, - port: u16, - ) -> Result { - if self.fut_exited.load(Ordering::Acquire) { - return Err(WispError::MuxTaskEnded); - } - if stream_type == StreamType::Udp - && !self - .supported_extension_ids - .iter() - .any(|x| *x == UdpProtocolExtension::ID) - { - return Err(WispError::ExtensionsNotSupported(vec![ - UdpProtocolExtension::ID, - ])); - } - let (tx, rx) = oneshot::channel(); - self.stream_tx - .send_async(WsEvent::CreateStream(stream_type, host, port, tx)) - .await - .map_err(|_| WispError::MuxMessageFailedToSend)?; - rx.await.map_err(|_| WispError::MuxMessageFailedToRecv)? - } + /// Create a new stream, multiplexed through Wisp. + pub async fn client_new_stream( + &self, + stream_type: StreamType, + host: String, + port: u16, + ) -> Result { + if self.fut_exited.load(Ordering::Acquire) { + return Err(WispError::MuxTaskEnded); + } + if stream_type == StreamType::Udp + && !self + .supported_extension_ids + .iter() + .any(|x| *x == UdpProtocolExtension::ID) + { + return Err(WispError::ExtensionsNotSupported(vec![ + UdpProtocolExtension::ID, + ])); + } + let (tx, rx) = oneshot::channel(); + self.stream_tx + .send_async(WsEvent::CreateStream(stream_type, host, port, tx)) + .await + .map_err(|_| WispError::MuxMessageFailedToSend)?; + rx.await.map_err(|_| WispError::MuxMessageFailedToRecv)? + } - async fn close_internal(&self, reason: Option) -> Result<(), WispError> { - if self.fut_exited.load(Ordering::Acquire) { - return Err(WispError::MuxTaskEnded); - } - self.stream_tx - .send_async(WsEvent::EndFut(reason)) - .await - .map_err(|_| WispError::MuxMessageFailedToSend) - } + async fn close_internal(&self, reason: Option) -> Result<(), WispError> { + if self.fut_exited.load(Ordering::Acquire) { + return Err(WispError::MuxTaskEnded); + } + self.stream_tx + .send_async(WsEvent::EndFut(reason)) + .await + .map_err(|_| WispError::MuxMessageFailedToSend) + } - /// Close all streams. - /// - /// Also terminates the multiplexor future. - pub async fn close(&self) -> Result<(), WispError> { - self.close_internal(None).await - } + /// Close all streams. + /// + /// Also terminates the multiplexor future. + pub async fn close(&self) -> Result<(), WispError> { + self.close_internal(None).await + } - /// Close all streams and send an extension incompatibility error to the client. - /// - /// Also terminates the multiplexor future. - pub async fn close_extension_incompat(&self) -> Result<(), WispError> { - self.close_internal(Some(CloseReason::IncompatibleExtensions)) - .await - } - - /// Get a protocol extension stream for sending packets with stream id 0. - pub fn get_protocol_extension_stream(&self) -> MuxProtocolExtensionStream { - MuxProtocolExtensionStream { - stream_id: 0, - tx: self.tx.clone(), - is_closed: self.fut_exited.clone(), - } - } + /// Close all streams and send an extension incompatibility error to the client. + /// + /// Also terminates the multiplexor future. + pub async fn close_extension_incompat(&self) -> Result<(), WispError> { + self.close_internal(Some(CloseReason::IncompatibleExtensions)) + .await + } + + /// Get a protocol extension stream for sending packets with stream id 0. + pub fn get_protocol_extension_stream(&self) -> MuxProtocolExtensionStream { + MuxProtocolExtensionStream { + stream_id: 0, + tx: self.tx.clone(), + is_closed: self.fut_exited.clone(), + } + } } impl Drop for ClientMux { - fn drop(&mut self) { - let _ = self.stream_tx.send(WsEvent::EndFut(None)); - } + fn drop(&mut self) { + let _ = self.stream_tx.send(WsEvent::EndFut(None)); + } } /// Result of `ClientMux::new`. pub struct ClientMuxResult(ClientMux, F) where - F: Future> + Send; + F: Future> + Send; impl ClientMuxResult where - F: Future> + Send, + F: Future> + Send, { - /// Require no protocol extensions. - pub fn with_no_required_extensions(self) -> (ClientMux, F) { - (self.0, self.1) - } + /// Require no protocol extensions. + pub fn with_no_required_extensions(self) -> (ClientMux, F) { + (self.0, self.1) + } - /// Require protocol extensions by their ID. - pub async fn with_required_extensions( - self, - extensions: &[u8], - ) -> Result<(ClientMux, F), WispError> { - let mut unsupported_extensions = Vec::new(); - for extension in extensions { - if !self.0.supported_extension_ids.contains(extension) { - unsupported_extensions.push(*extension); - } - } - if unsupported_extensions.is_empty() { - Ok((self.0, self.1)) - } else { - self.0.close_extension_incompat().await?; - self.1.await?; - Err(WispError::ExtensionsNotSupported(unsupported_extensions)) - } - } + /// Require protocol extensions by their ID. + pub async fn with_required_extensions( + self, + extensions: &[u8], + ) -> Result<(ClientMux, F), WispError> { + let mut unsupported_extensions = Vec::new(); + for extension in extensions { + if !self.0.supported_extension_ids.contains(extension) { + unsupported_extensions.push(*extension); + } + } + if unsupported_extensions.is_empty() { + Ok((self.0, self.1)) + } else { + self.0.close_extension_incompat().await?; + self.1.await?; + Err(WispError::ExtensionsNotSupported(unsupported_extensions)) + } + } - /// Shorthand for `with_required_extensions(&[UdpProtocolExtension::ID])` - pub async fn with_udp_extension_required(self) -> Result<(ClientMux, F), WispError> { - self.with_required_extensions(&[UdpProtocolExtension::ID]) - .await - } + /// Shorthand for `with_required_extensions(&[UdpProtocolExtension::ID])` + pub async fn with_udp_extension_required(self) -> Result<(ClientMux, F), WispError> { + self.with_required_extensions(&[UdpProtocolExtension::ID]) + .await + } } diff --git a/wisp/src/packet.rs b/wisp/src/packet.rs index 138e0c5..85a82e7 100644 --- a/wisp/src/packet.rs +++ b/wisp/src/packet.rs @@ -1,41 +1,41 @@ use crate::{ - extensions::{AnyProtocolExtension, ProtocolExtensionBuilder}, - ws::{self, Frame, LockedWebSocketWrite, OpCode, WebSocketRead}, - Role, WispError, WISP_VERSION, + extensions::{AnyProtocolExtension, ProtocolExtensionBuilder}, + ws::{self, Frame, LockedWebSocketWrite, OpCode, Payload, WebSocketRead}, + Role, WispError, WISP_VERSION, }; use bytes::{Buf, BufMut, Bytes, BytesMut}; /// Wisp stream type. #[derive(Debug, PartialEq, Copy, Clone)] pub enum StreamType { - /// TCP Wisp stream. - Tcp, - /// UDP Wisp stream. - Udp, - /// Unknown Wisp stream type used for custom streams by protocol extensions. - Unknown(u8), + /// TCP Wisp stream. + Tcp, + /// UDP Wisp stream. + Udp, + /// Unknown Wisp stream type used for custom streams by protocol extensions. + Unknown(u8), } impl From for StreamType { - fn from(value: u8) -> Self { - use StreamType as S; - match value { - 0x01 => S::Tcp, - 0x02 => S::Udp, - x => S::Unknown(x), - } - } + fn from(value: u8) -> Self { + use StreamType as S; + match value { + 0x01 => S::Tcp, + 0x02 => S::Udp, + x => S::Unknown(x), + } + } } impl From for u8 { - fn from(value: StreamType) -> Self { - use StreamType as S; - match value { - S::Tcp => 0x01, - S::Udp => 0x02, - S::Unknown(x) => x, - } - } + fn from(value: StreamType) -> Self { + use StreamType as S; + match value { + S::Tcp => 0x01, + S::Udp => 0x02, + S::Unknown(x) => x, + } + } } /// Close reason. @@ -44,56 +44,56 @@ impl From for u8 { /// docs](https://github.com/MercuryWorkshop/wisp-protocol/blob/main/protocol.md#clientserver-close-reasons) #[derive(Debug, PartialEq, Copy, Clone)] pub enum CloseReason { - /// Reason unspecified or unknown. - Unknown = 0x01, - /// Voluntary stream closure. - Voluntary = 0x02, - /// Unexpected stream closure due to a network error. - Unexpected = 0x03, - /// Incompatible extensions. Only used during the handshake. - IncompatibleExtensions = 0x04, - /// Stream creation failed due to invalid information. - ServerStreamInvalidInfo = 0x41, - /// Stream creation failed due to an unreachable destination host. - ServerStreamUnreachable = 0x42, - /// Stream creation timed out due to the destination server not responding. - ServerStreamConnectionTimedOut = 0x43, - /// Stream creation failed due to the destination server refusing the connection. - ServerStreamConnectionRefused = 0x44, - /// TCP data transfer timed out. - ServerStreamTimedOut = 0x47, - /// Stream destination address/domain is intentionally blocked by the proxy server. - ServerStreamBlockedAddress = 0x48, - /// Connection throttled by the server. - ServerStreamThrottled = 0x49, - /// The client has encountered an unexpected error. - ClientUnexpected = 0x81, + /// Reason unspecified or unknown. + Unknown = 0x01, + /// Voluntary stream closure. + Voluntary = 0x02, + /// Unexpected stream closure due to a network error. + Unexpected = 0x03, + /// Incompatible extensions. Only used during the handshake. + IncompatibleExtensions = 0x04, + /// Stream creation failed due to invalid information. + ServerStreamInvalidInfo = 0x41, + /// Stream creation failed due to an unreachable destination host. + ServerStreamUnreachable = 0x42, + /// Stream creation timed out due to the destination server not responding. + ServerStreamConnectionTimedOut = 0x43, + /// Stream creation failed due to the destination server refusing the connection. + ServerStreamConnectionRefused = 0x44, + /// TCP data transfer timed out. + ServerStreamTimedOut = 0x47, + /// Stream destination address/domain is intentionally blocked by the proxy server. + ServerStreamBlockedAddress = 0x48, + /// Connection throttled by the server. + ServerStreamThrottled = 0x49, + /// The client has encountered an unexpected error. + ClientUnexpected = 0x81, } impl TryFrom for CloseReason { - type Error = WispError; - fn try_from(close_reason: u8) -> Result { - use CloseReason as R; - match close_reason { - 0x01 => Ok(R::Unknown), - 0x02 => Ok(R::Voluntary), - 0x03 => Ok(R::Unexpected), - 0x04 => Ok(R::IncompatibleExtensions), - 0x41 => Ok(R::ServerStreamInvalidInfo), - 0x42 => Ok(R::ServerStreamUnreachable), - 0x43 => Ok(R::ServerStreamConnectionTimedOut), - 0x44 => Ok(R::ServerStreamConnectionRefused), - 0x47 => Ok(R::ServerStreamTimedOut), - 0x48 => Ok(R::ServerStreamBlockedAddress), - 0x49 => Ok(R::ServerStreamThrottled), - 0x81 => Ok(R::ClientUnexpected), - _ => Err(Self::Error::InvalidCloseReason), - } - } + type Error = WispError; + fn try_from(close_reason: u8) -> Result { + use CloseReason as R; + match close_reason { + 0x01 => Ok(R::Unknown), + 0x02 => Ok(R::Voluntary), + 0x03 => Ok(R::Unexpected), + 0x04 => Ok(R::IncompatibleExtensions), + 0x41 => Ok(R::ServerStreamInvalidInfo), + 0x42 => Ok(R::ServerStreamUnreachable), + 0x43 => Ok(R::ServerStreamConnectionTimedOut), + 0x44 => Ok(R::ServerStreamConnectionRefused), + 0x47 => Ok(R::ServerStreamTimedOut), + 0x48 => Ok(R::ServerStreamBlockedAddress), + 0x49 => Ok(R::ServerStreamThrottled), + 0x81 => Ok(R::ClientUnexpected), + _ => Err(Self::Error::InvalidCloseReason), + } + } } trait Encode { - fn encode(self, bytes: &mut BytesMut); + fn encode(self, bytes: &mut BytesMut); } /// Packet used to create a new stream. @@ -101,49 +101,49 @@ trait Encode { /// See [the docs](https://github.com/MercuryWorkshop/wisp-protocol/blob/main/protocol.md#0x01---connect). #[derive(Debug, Clone)] pub struct ConnectPacket { - /// Whether the new stream should use a TCP or UDP socket. - pub stream_type: StreamType, - /// Destination TCP/UDP port for the new stream. - pub destination_port: u16, - /// Destination hostname, in a UTF-8 string. - pub destination_hostname: String, + /// Whether the new stream should use a TCP or UDP socket. + pub stream_type: StreamType, + /// Destination TCP/UDP port for the new stream. + pub destination_port: u16, + /// Destination hostname, in a UTF-8 string. + pub destination_hostname: String, } impl ConnectPacket { - /// Create a new connect packet. - pub fn new( - stream_type: StreamType, - destination_port: u16, - destination_hostname: String, - ) -> Self { - Self { - stream_type, - destination_port, - destination_hostname, - } - } + /// Create a new connect packet. + pub fn new( + stream_type: StreamType, + destination_port: u16, + destination_hostname: String, + ) -> Self { + Self { + stream_type, + destination_port, + destination_hostname, + } + } } -impl TryFrom for ConnectPacket { - type Error = WispError; - fn try_from(mut bytes: BytesMut) -> Result { - if bytes.remaining() < (1 + 2) { - return Err(Self::Error::PacketTooSmall); - } - Ok(Self { - stream_type: bytes.get_u8().into(), - destination_port: bytes.get_u16_le(), - destination_hostname: std::str::from_utf8(&bytes)?.to_string(), - }) - } +impl TryFrom> for ConnectPacket { + type Error = WispError; + fn try_from(mut bytes: Payload<'_>) -> Result { + if bytes.remaining() < (1 + 2) { + return Err(Self::Error::PacketTooSmall); + } + Ok(Self { + stream_type: bytes.get_u8().into(), + destination_port: bytes.get_u16_le(), + destination_hostname: std::str::from_utf8(&bytes)?.to_string(), + }) + } } impl Encode for ConnectPacket { - fn encode(self, bytes: &mut BytesMut) { - bytes.put_u8(self.stream_type.into()); - bytes.put_u16_le(self.destination_port); - bytes.extend(self.destination_hostname.bytes()); - } + fn encode(self, bytes: &mut BytesMut) { + bytes.put_u8(self.stream_type.into()); + bytes.put_u16_le(self.destination_port); + bytes.extend(self.destination_hostname.bytes()); + } } /// Packet used for Wisp TCP stream flow control. @@ -151,33 +151,33 @@ impl Encode for ConnectPacket { /// See [the docs](https://github.com/MercuryWorkshop/wisp-protocol/blob/main/protocol.md#0x03---continue). #[derive(Debug, Copy, Clone)] pub struct ContinuePacket { - /// Number of packets that the server can buffer for the current stream. - pub buffer_remaining: u32, + /// Number of packets that the server can buffer for the current stream. + pub buffer_remaining: u32, } impl ContinuePacket { - /// Create a new continue packet. - pub fn new(buffer_remaining: u32) -> Self { - Self { buffer_remaining } - } + /// Create a new continue packet. + pub fn new(buffer_remaining: u32) -> Self { + Self { buffer_remaining } + } } -impl TryFrom for ContinuePacket { - type Error = WispError; - fn try_from(mut bytes: BytesMut) -> Result { - if bytes.remaining() < 4 { - return Err(Self::Error::PacketTooSmall); - } - Ok(Self { - buffer_remaining: bytes.get_u32_le(), - }) - } +impl TryFrom> for ContinuePacket { + type Error = WispError; + fn try_from(mut bytes: Payload<'_>) -> Result { + if bytes.remaining() < 4 { + return Err(Self::Error::PacketTooSmall); + } + Ok(Self { + buffer_remaining: bytes.get_u32_le(), + }) + } } impl Encode for ContinuePacket { - fn encode(self, bytes: &mut BytesMut) { - bytes.put_u32_le(self.buffer_remaining); - } + fn encode(self, bytes: &mut BytesMut) { + bytes.put_u32_le(self.buffer_remaining); + } } /// Packet used to close a stream. @@ -186,42 +186,42 @@ impl Encode for ContinuePacket { /// docs](https://github.com/MercuryWorkshop/wisp-protocol/blob/main/protocol.md#0x04---close). #[derive(Debug, Copy, Clone)] pub struct ClosePacket { - /// The close reason. - pub reason: CloseReason, + /// The close reason. + pub reason: CloseReason, } impl ClosePacket { - /// Create a new close packet. - pub fn new(reason: CloseReason) -> Self { - Self { reason } - } + /// Create a new close packet. + pub fn new(reason: CloseReason) -> Self { + Self { reason } + } } -impl TryFrom for ClosePacket { - type Error = WispError; - fn try_from(mut bytes: BytesMut) -> Result { - if bytes.remaining() < 1 { - return Err(Self::Error::PacketTooSmall); - } - Ok(Self { - reason: bytes.get_u8().try_into()?, - }) - } +impl TryFrom> for ClosePacket { + type Error = WispError; + fn try_from(mut bytes: Payload<'_>) -> Result { + if bytes.remaining() < 1 { + return Err(Self::Error::PacketTooSmall); + } + Ok(Self { + reason: bytes.get_u8().try_into()?, + }) + } } impl Encode for ClosePacket { - fn encode(self, bytes: &mut BytesMut) { - bytes.put_u8(self.reason as u8); - } + fn encode(self, bytes: &mut BytesMut) { + bytes.put_u8(self.reason as u8); + } } /// Wisp version sent in the handshake. #[derive(Debug, Clone)] pub struct WispVersion { - /// Major Wisp version according to semver. - pub major: u8, - /// Minor Wisp version according to semver. - pub minor: u8, + /// Major Wisp version according to semver. + pub major: u8, + /// Minor Wisp version according to semver. + pub minor: u8, } /// Packet used in the initial handshake. @@ -229,325 +229,327 @@ pub struct WispVersion { /// See [the docs](https://github.com/MercuryWorkshop/wisp-protocol/blob/main/protocol.md#0x05---info) #[derive(Debug, Clone)] pub struct InfoPacket { - /// Wisp version sent in the packet. - pub version: WispVersion, - /// List of protocol extensions sent in the packet. - pub extensions: Vec, + /// Wisp version sent in the packet. + pub version: WispVersion, + /// List of protocol extensions sent in the packet. + pub extensions: Vec, } impl Encode for InfoPacket { - fn encode(self, bytes: &mut BytesMut) { - bytes.put_u8(self.version.major); - bytes.put_u8(self.version.minor); - for extension in self.extensions { - bytes.extend_from_slice(&Bytes::from(extension)); - } - } + fn encode(self, bytes: &mut BytesMut) { + bytes.put_u8(self.version.major); + bytes.put_u8(self.version.minor); + for extension in self.extensions { + bytes.extend_from_slice(&Bytes::from(extension)); + } + } } #[derive(Debug, Clone)] /// Type of packet recieved. -pub enum PacketType { - /// Connect packet. - Connect(ConnectPacket), - /// Data packet. - Data(Bytes), - /// Continue packet. - Continue(ContinuePacket), - /// Close packet. - Close(ClosePacket), - /// Info packet. - Info(InfoPacket), +pub enum PacketType<'a> { + /// Connect packet. + Connect(ConnectPacket), + /// Data packet. + Data(Payload<'a>), + /// Continue packet. + Continue(ContinuePacket), + /// Close packet. + Close(ClosePacket), + /// Info packet. + Info(InfoPacket), } -impl PacketType { - /// Get the packet type used in the protocol. - pub fn as_u8(&self) -> u8 { - use PacketType as P; - match self { - P::Connect(_) => 0x01, - P::Data(_) => 0x02, - P::Continue(_) => 0x03, - P::Close(_) => 0x04, - P::Info(_) => 0x05, - } - } +impl PacketType<'_> { + /// Get the packet type used in the protocol. + pub fn as_u8(&self) -> u8 { + use PacketType as P; + match self { + P::Connect(_) => 0x01, + P::Data(_) => 0x02, + P::Continue(_) => 0x03, + P::Close(_) => 0x04, + P::Info(_) => 0x05, + } + } - pub(crate) fn get_packet_size(&self) -> usize { - use PacketType as P; - match self { - P::Connect(p) => 1 + 2 + p.destination_hostname.len(), - P::Data(p) => p.len(), - P::Continue(_) => 4, - P::Close(_) => 1, - P::Info(_) => 2, - } - } + pub(crate) fn get_packet_size(&self) -> usize { + use PacketType as P; + match self { + P::Connect(p) => 1 + 2 + p.destination_hostname.len(), + P::Data(p) => p.len(), + P::Continue(_) => 4, + P::Close(_) => 1, + P::Info(_) => 2, + } + } } -impl Encode for PacketType { - fn encode(self, bytes: &mut BytesMut) { - use PacketType as P; - match self { - P::Connect(x) => x.encode(bytes), - P::Data(x) => bytes.extend_from_slice(&x), - P::Continue(x) => x.encode(bytes), - P::Close(x) => x.encode(bytes), - P::Info(x) => x.encode(bytes), - }; - } +impl Encode for PacketType<'_> { + fn encode(self, bytes: &mut BytesMut) { + use PacketType as P; + match self { + P::Connect(x) => x.encode(bytes), + P::Data(x) => bytes.extend_from_slice(&x), + P::Continue(x) => x.encode(bytes), + P::Close(x) => x.encode(bytes), + P::Info(x) => x.encode(bytes), + }; + } } /// Wisp protocol packet. #[derive(Debug, Clone)] -pub struct Packet { - /// Stream this packet is associated with. - pub stream_id: u32, - /// Packet type recieved. - pub packet_type: PacketType, +pub struct Packet<'a> { + /// Stream this packet is associated with. + pub stream_id: u32, + /// Packet type recieved. + pub packet_type: PacketType<'a>, } -impl Packet { - /// Create a new packet. - /// - /// The helper functions should be used for most use cases. - pub fn new(stream_id: u32, packet: PacketType) -> Self { - Self { - stream_id, - packet_type: packet, - } - } +impl<'a> Packet<'a> { + /// Create a new packet. + /// + /// The helper functions should be used for most use cases. + pub fn new(stream_id: u32, packet: PacketType<'a>) -> Self { + Self { + stream_id, + packet_type: packet, + } + } - /// Create a new connect packet. - pub fn new_connect( - stream_id: u32, - stream_type: StreamType, - destination_port: u16, - destination_hostname: String, - ) -> Self { - Self { - stream_id, - packet_type: PacketType::Connect(ConnectPacket::new( - stream_type, - destination_port, - destination_hostname, - )), - } - } + /// Create a new connect packet. + pub fn new_connect( + stream_id: u32, + stream_type: StreamType, + destination_port: u16, + destination_hostname: String, + ) -> Self { + Self { + stream_id, + packet_type: PacketType::Connect(ConnectPacket::new( + stream_type, + destination_port, + destination_hostname, + )), + } + } - /// Create a new data packet. - pub fn new_data(stream_id: u32, data: Bytes) -> Self { - Self { - stream_id, - packet_type: PacketType::Data(data), - } - } + /// Create a new data packet. + pub fn new_data(stream_id: u32, data: Payload<'a>) -> Self { + Self { + stream_id, + packet_type: PacketType::Data(data), + } + } - /// Create a new continue packet. - pub fn new_continue(stream_id: u32, buffer_remaining: u32) -> Self { - Self { - stream_id, - packet_type: PacketType::Continue(ContinuePacket::new(buffer_remaining)), - } - } + /// Create a new continue packet. + pub fn new_continue(stream_id: u32, buffer_remaining: u32) -> Self { + Self { + stream_id, + packet_type: PacketType::Continue(ContinuePacket::new(buffer_remaining)), + } + } - /// Create a new close packet. - pub fn new_close(stream_id: u32, reason: CloseReason) -> Self { - Self { - stream_id, - packet_type: PacketType::Close(ClosePacket::new(reason)), - } - } + /// Create a new close packet. + pub fn new_close(stream_id: u32, reason: CloseReason) -> Self { + Self { + stream_id, + packet_type: PacketType::Close(ClosePacket::new(reason)), + } + } - pub(crate) fn new_info(extensions: Vec) -> Self { - Self { - stream_id: 0, - packet_type: PacketType::Info(InfoPacket { - version: WISP_VERSION, - extensions, - }), - } - } + pub(crate) fn new_info(extensions: Vec) -> Self { + Self { + stream_id: 0, + packet_type: PacketType::Info(InfoPacket { + version: WISP_VERSION, + extensions, + }), + } + } - fn parse_packet(packet_type: u8, mut bytes: BytesMut) -> Result { - use PacketType as P; - Ok(Self { - stream_id: bytes.get_u32_le(), - packet_type: match packet_type { - 0x01 => P::Connect(ConnectPacket::try_from(bytes)?), - 0x02 => P::Data(bytes.freeze()), - 0x03 => P::Continue(ContinuePacket::try_from(bytes)?), - 0x04 => P::Close(ClosePacket::try_from(bytes)?), - // 0x05 is handled seperately - _ => return Err(WispError::InvalidPacketType), - }, - }) - } + fn parse_packet(packet_type: u8, mut bytes: Payload<'a>) -> Result { + use PacketType as P; + Ok(Self { + stream_id: bytes.get_u32_le(), + packet_type: match packet_type { + 0x01 => P::Connect(ConnectPacket::try_from(bytes)?), + 0x02 => P::Data(bytes), + 0x03 => P::Continue(ContinuePacket::try_from(bytes)?), + 0x04 => P::Close(ClosePacket::try_from(bytes)?), + // 0x05 is handled seperately + _ => return Err(WispError::InvalidPacketType), + }, + }) + } - pub(crate) fn maybe_parse_info( - frame: Frame, - role: Role, - extension_builders: &[Box<(dyn ProtocolExtensionBuilder + Send + Sync)>], - ) -> Result { - if !frame.finished { - return Err(WispError::WsFrameNotFinished); - } - if frame.opcode != OpCode::Binary { - return Err(WispError::WsFrameInvalidType); - } - let mut bytes = frame.payload; - if bytes.remaining() < 1 { - return Err(WispError::PacketTooSmall); - } - let packet_type = bytes.get_u8(); - if packet_type == 0x05 { - Self::parse_info(bytes, role, extension_builders) - } else { - Self::parse_packet(packet_type, bytes) - } - } + pub(crate) fn maybe_parse_info( + frame: Frame<'a>, + role: Role, + extension_builders: &[Box<(dyn ProtocolExtensionBuilder + Send + Sync)>], + ) -> Result { + if !frame.finished { + return Err(WispError::WsFrameNotFinished); + } + if frame.opcode != OpCode::Binary { + return Err(WispError::WsFrameInvalidType); + } + let mut bytes = frame.payload; + if bytes.remaining() < 1 { + return Err(WispError::PacketTooSmall); + } + let packet_type = bytes.get_u8(); + if packet_type == 0x05 { + Self::parse_info(bytes, role, extension_builders) + } else { + Self::parse_packet(packet_type, bytes) + } + } - pub(crate) async fn maybe_handle_extension( - frame: Frame, - extensions: &mut [AnyProtocolExtension], - read: &mut (dyn WebSocketRead + Send), - write: &LockedWebSocketWrite, - ) -> Result, WispError> { - if !frame.finished { - return Err(WispError::WsFrameNotFinished); - } - if frame.opcode != OpCode::Binary { - return Err(WispError::WsFrameInvalidType); - } - let mut bytes = frame.payload; - if bytes.remaining() < 5 { - return Err(WispError::PacketTooSmall); - } - let packet_type = bytes.get_u8(); - match packet_type { - 0x01 => Ok(Some(Self { - stream_id: bytes.get_u32_le(), - packet_type: PacketType::Connect(bytes.try_into()?), - })), - 0x02 => Ok(Some(Self { - stream_id: bytes.get_u32_le(), - packet_type: PacketType::Data(bytes.freeze()), - })), - 0x03 => Ok(Some(Self { - stream_id: bytes.get_u32_le(), - packet_type: PacketType::Continue(bytes.try_into()?), - })), - 0x04 => Ok(Some(Self { - stream_id: bytes.get_u32_le(), - packet_type: PacketType::Close(bytes.try_into()?), - })), - 0x05 => Ok(None), - packet_type => { - if let Some(extension) = extensions - .iter_mut() - .find(|x| x.get_supported_packets().iter().any(|x| *x == packet_type)) - { - extension.handle_packet(bytes.freeze(), read, write).await?; - Ok(None) - } else { - Err(WispError::InvalidPacketType) - } - } - } - } + pub(crate) async fn maybe_handle_extension( + frame: Frame<'a>, + extensions: &mut [AnyProtocolExtension], + read: &mut (dyn WebSocketRead + Send), + write: &LockedWebSocketWrite, + ) -> Result, WispError> { + if !frame.finished { + return Err(WispError::WsFrameNotFinished); + } + if frame.opcode != OpCode::Binary { + return Err(WispError::WsFrameInvalidType); + } + let mut bytes = frame.payload; + if bytes.remaining() < 5 { + return Err(WispError::PacketTooSmall); + } + let packet_type = bytes.get_u8(); + match packet_type { + 0x01 => Ok(Some(Self { + stream_id: bytes.get_u32_le(), + packet_type: PacketType::Connect(bytes.try_into()?), + })), + 0x02 => Ok(Some(Self { + stream_id: bytes.get_u32_le(), + packet_type: PacketType::Data(bytes), + })), + 0x03 => Ok(Some(Self { + stream_id: bytes.get_u32_le(), + packet_type: PacketType::Continue(bytes.try_into()?), + })), + 0x04 => Ok(Some(Self { + stream_id: bytes.get_u32_le(), + packet_type: PacketType::Close(bytes.try_into()?), + })), + 0x05 => Ok(None), + packet_type => { + if let Some(extension) = extensions + .iter_mut() + .find(|x| x.get_supported_packets().iter().any(|x| *x == packet_type)) + { + extension + .handle_packet(BytesMut::from(bytes).freeze(), read, write) + .await?; + Ok(None) + } else { + Err(WispError::InvalidPacketType) + } + } + } + } - fn parse_info( - mut bytes: BytesMut, - role: Role, - extension_builders: &[Box<(dyn ProtocolExtensionBuilder + Send + Sync)>], - ) -> Result { - // packet type is already read by code that calls this - if bytes.remaining() < 4 + 2 { - return Err(WispError::PacketTooSmall); - } - if bytes.get_u32_le() != 0 { - return Err(WispError::InvalidStreamId); - } + fn parse_info( + mut bytes: Payload<'a>, + role: Role, + extension_builders: &[Box<(dyn ProtocolExtensionBuilder + Send + Sync)>], + ) -> Result { + // packet type is already read by code that calls this + if bytes.remaining() < 4 + 2 { + return Err(WispError::PacketTooSmall); + } + if bytes.get_u32_le() != 0 { + return Err(WispError::InvalidStreamId); + } - let version = WispVersion { - major: bytes.get_u8(), - minor: bytes.get_u8(), - }; + let version = WispVersion { + major: bytes.get_u8(), + minor: bytes.get_u8(), + }; - if version.major != WISP_VERSION.major { - return Err(WispError::IncompatibleProtocolVersion); - } + if version.major != WISP_VERSION.major { + return Err(WispError::IncompatibleProtocolVersion); + } - let mut extensions = Vec::new(); + let mut extensions = Vec::new(); - while bytes.remaining() > 4 { - // We have some extensions - let id = bytes.get_u8(); - let length = usize::try_from(bytes.get_u32_le())?; - if bytes.remaining() < length { - return Err(WispError::PacketTooSmall); - } - if let Some(builder) = extension_builders.iter().find(|x| x.get_id() == id) { - if let Ok(extension) = builder.build_from_bytes(bytes.copy_to_bytes(length), role) { - extensions.push(extension) - } - } else { - bytes.advance(length) - } - } + while bytes.remaining() > 4 { + // We have some extensions + let id = bytes.get_u8(); + let length = usize::try_from(bytes.get_u32_le())?; + if bytes.remaining() < length { + return Err(WispError::PacketTooSmall); + } + if let Some(builder) = extension_builders.iter().find(|x| x.get_id() == id) { + if let Ok(extension) = builder.build_from_bytes(bytes.copy_to_bytes(length), role) { + extensions.push(extension) + } + } else { + bytes.advance(length) + } + } - Ok(Self { - stream_id: 0, - packet_type: PacketType::Info(InfoPacket { - version, - extensions, - }), - }) - } + Ok(Self { + stream_id: 0, + packet_type: PacketType::Info(InfoPacket { + version, + extensions, + }), + }) + } } -impl Encode for Packet { - fn encode(self, bytes: &mut BytesMut) { - bytes.put_u8(self.packet_type.as_u8()); - bytes.put_u32_le(self.stream_id); - self.packet_type.encode(bytes); - } +impl Encode for Packet<'_> { + fn encode(self, bytes: &mut BytesMut) { + bytes.put_u8(self.packet_type.as_u8()); + bytes.put_u32_le(self.stream_id); + self.packet_type.encode(bytes); + } } -impl TryFrom for Packet { - type Error = WispError; - fn try_from(mut bytes: BytesMut) -> Result { - if bytes.remaining() < 1 { - return Err(Self::Error::PacketTooSmall); - } - let packet_type = bytes.get_u8(); - Self::parse_packet(packet_type, bytes) - } +impl<'a> TryFrom> for Packet<'a> { + type Error = WispError; + fn try_from(mut bytes: Payload<'a>) -> Result { + if bytes.remaining() < 1 { + return Err(Self::Error::PacketTooSmall); + } + let packet_type = bytes.get_u8(); + Self::parse_packet(packet_type, bytes) + } } -impl From for BytesMut { - fn from(packet: Packet) -> Self { - let mut encoded = BytesMut::with_capacity(1 + 4 + packet.packet_type.get_packet_size()); - packet.encode(&mut encoded); - encoded - } +impl From> for BytesMut { + fn from(packet: Packet) -> Self { + let mut encoded = BytesMut::with_capacity(1 + 4 + packet.packet_type.get_packet_size()); + packet.encode(&mut encoded); + encoded + } } -impl TryFrom for Packet { - type Error = WispError; - fn try_from(frame: ws::Frame) -> Result { - if !frame.finished { - return Err(Self::Error::WsFrameNotFinished); - } - if frame.opcode != ws::OpCode::Binary { - return Err(Self::Error::WsFrameInvalidType); - } - Packet::try_from(frame.payload) - } +impl<'a> TryFrom> for Packet<'a> { + type Error = WispError; + fn try_from(frame: ws::Frame<'a>) -> Result { + if !frame.finished { + return Err(Self::Error::WsFrameNotFinished); + } + if frame.opcode != ws::OpCode::Binary { + return Err(Self::Error::WsFrameInvalidType); + } + Packet::try_from(frame.payload) + } } -impl From for ws::Frame { - fn from(packet: Packet) -> Self { - Self::binary(BytesMut::from(packet)) - } +impl From> for ws::Frame<'static> { + fn from(packet: Packet) -> Self { + Self::binary(Payload::Bytes(BytesMut::from(packet))) + } } diff --git a/wisp/src/sink_unfold.rs b/wisp/src/sink_unfold.rs index dfb170e..852abce 100644 --- a/wisp/src/sink_unfold.rs +++ b/wisp/src/sink_unfold.rs @@ -1,146 +1,146 @@ //! futures sink unfold with a close function use core::{future::Future, pin::Pin}; use futures::{ - ready, - task::{Context, Poll}, - Sink, + ready, + task::{Context, Poll}, + Sink, }; use pin_project_lite::pin_project; pin_project! { - /// UnfoldState used for stream and sink unfolds - #[project = UnfoldStateProj] - #[project_replace = UnfoldStateProjReplace] - #[derive(Debug)] - pub(crate) enum UnfoldState { - Value { - value: T, - }, - Future { - #[pin] - future: Fut, - }, - Empty, - } + /// UnfoldState used for stream and sink unfolds + #[project = UnfoldStateProj] + #[project_replace = UnfoldStateProjReplace] + #[derive(Debug)] + pub(crate) enum UnfoldState { + Value { + value: T, + }, + Future { + #[pin] + future: Fut, + }, + Empty, + } } impl UnfoldState { - pub(crate) fn project_future(self: Pin<&mut Self>) -> Option> { - match self.project() { - UnfoldStateProj::Future { future } => Some(future), - _ => None, - } - } + pub(crate) fn project_future(self: Pin<&mut Self>) -> Option> { + match self.project() { + UnfoldStateProj::Future { future } => Some(future), + _ => None, + } + } - pub(crate) fn take_value(self: Pin<&mut Self>) -> Option { - match &*self { - Self::Value { .. } => match self.project_replace(Self::Empty) { - UnfoldStateProjReplace::Value { value } => Some(value), - _ => unreachable!(), - }, - _ => None, - } - } + pub(crate) fn take_value(self: Pin<&mut Self>) -> Option { + match &*self { + Self::Value { .. } => match self.project_replace(Self::Empty) { + UnfoldStateProjReplace::Value { value } => Some(value), + _ => unreachable!(), + }, + _ => None, + } + } } pin_project! { - /// Sink for the [`unfold`] function. - #[derive(Debug)] - #[must_use = "sinks do nothing unless polled"] - pub struct Unfold { - function: F, - close_function: CF, - #[pin] - state: UnfoldState, - #[pin] - close_state: UnfoldState - } + /// Sink for the [`unfold`] function. + #[derive(Debug)] + #[must_use = "sinks do nothing unless polled"] + pub struct Unfold { + function: F, + close_function: CF, + #[pin] + state: UnfoldState, + #[pin] + close_state: UnfoldState + } } pub(crate) fn unfold( - init: T, - function: F, - close_init: CT, - close_function: CF, + init: T, + function: F, + close_init: CT, + close_function: CF, ) -> Unfold where - F: FnMut(T, Item) -> R, - R: Future>, - CF: FnMut(CT) -> CR, - CR: Future>, + F: FnMut(T, Item) -> R, + R: Future>, + CF: FnMut(CT) -> CR, + CR: Future>, { - Unfold { - function, - close_function, - state: UnfoldState::Value { value: init }, - close_state: UnfoldState::Value { value: close_init }, - } + Unfold { + function, + close_function, + state: UnfoldState::Value { value: init }, + close_state: UnfoldState::Value { value: close_init }, + } } impl Sink for Unfold where - F: FnMut(T, Item) -> R, - R: Future>, - CF: FnMut(CT) -> CR, - CR: Future>, + F: FnMut(T, Item) -> R, + R: Future>, + CF: FnMut(CT) -> CR, + CR: Future>, { - type Error = E; + type Error = E; - fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.poll_flush(cx) - } + fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.poll_flush(cx) + } - fn start_send(self: Pin<&mut Self>, item: Item) -> Result<(), Self::Error> { - let mut this = self.project(); - let future = match this.state.as_mut().take_value() { - Some(value) => (this.function)(value, item), - None => panic!("start_send called without poll_ready being called first"), - }; - this.state.set(UnfoldState::Future { future }); - Ok(()) - } + fn start_send(self: Pin<&mut Self>, item: Item) -> Result<(), Self::Error> { + let mut this = self.project(); + let future = match this.state.as_mut().take_value() { + Some(value) => (this.function)(value, item), + None => panic!("start_send called without poll_ready being called first"), + }; + this.state.set(UnfoldState::Future { future }); + Ok(()) + } - fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let mut this = self.project(); - Poll::Ready(if let Some(future) = this.state.as_mut().project_future() { - match ready!(future.poll(cx)) { - Ok(state) => { - this.state.set(UnfoldState::Value { value: state }); - Ok(()) - } - Err(err) => { - this.state.set(UnfoldState::Empty); - Err(err) - } - } - } else { - Ok(()) - }) - } + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let mut this = self.project(); + Poll::Ready(if let Some(future) = this.state.as_mut().project_future() { + match ready!(future.poll(cx)) { + Ok(state) => { + this.state.set(UnfoldState::Value { value: state }); + Ok(()) + } + Err(err) => { + this.state.set(UnfoldState::Empty); + Err(err) + } + } + } else { + Ok(()) + }) + } - fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - ready!(self.as_mut().poll_flush(cx))?; - let mut this = self.project(); - Poll::Ready( - if let Some(future) = this.close_state.as_mut().project_future() { - match ready!(future.poll(cx)) { - Ok(state) => { - this.close_state.set(UnfoldState::Value { value: state }); - Ok(()) - } - Err(err) => { - this.close_state.set(UnfoldState::Empty); - Err(err) - } - } - } else { - let future = match this.close_state.as_mut().take_value() { - Some(value) => (this.close_function)(value), - None => panic!("start_send called without poll_ready being called first"), - }; - this.close_state.set(UnfoldState::Future { future }); - return Poll::Pending; - }, - ) - } + fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + ready!(self.as_mut().poll_flush(cx))?; + let mut this = self.project(); + Poll::Ready( + if let Some(future) = this.close_state.as_mut().project_future() { + match ready!(future.poll(cx)) { + Ok(state) => { + this.close_state.set(UnfoldState::Value { value: state }); + Ok(()) + } + Err(err) => { + this.close_state.set(UnfoldState::Empty); + Err(err) + } + } + } else { + let future = match this.close_state.as_mut().take_value() { + Some(value) => (this.close_function)(value), + None => panic!("start_send called without poll_ready being called first"), + }; + this.close_state.set(UnfoldState::Future { future }); + return Poll::Pending; + }, + ) + } } diff --git a/wisp/src/stream.rs b/wisp/src/stream.rs index 5d5c115..c1337b7 100644 --- a/wisp/src/stream.rs +++ b/wisp/src/stream.rs @@ -1,6 +1,6 @@ use crate::{ sink_unfold, - ws::{Frame, LockedWebSocketWrite}, + ws::{Frame, LockedWebSocketWrite, Payload}, CloseReason, Packet, Role, StreamType, WispError, }; @@ -9,9 +9,10 @@ use event_listener::Event; use flume as mpsc; use futures::{ channel::oneshot, - ready, select, stream::{self, IntoAsyncRead}, + ready, select, + stream::{self, IntoAsyncRead}, task::{noop_waker_ref, Context, Poll}, - AsyncBufRead, AsyncRead, AsyncWrite, FutureExt, Sink, Stream, TryStreamExt, + AsyncBufRead, AsyncRead, AsyncWrite, Future, FutureExt, Sink, Stream, TryStreamExt, }; use pin_project_lite::pin_project; use std::{ @@ -23,7 +24,7 @@ use std::{ }; pub(crate) enum WsEvent { - Close(Packet, oneshot::Sender>), + Close(Packet<'static>, oneshot::Sender>), CreateStream( StreamType, String, @@ -100,8 +101,10 @@ pub struct MuxStreamWrite { } impl MuxStreamWrite { - /// Write data to the stream. - pub async fn write(&self, data: Bytes) -> Result<(), WispError> { + pub(crate) async fn write_payload_internal( + &self, + frame: Frame<'static>, + ) -> Result<(), WispError> { if self.role == Role::Client && self.stream_type == StreamType::Tcp && self.flow_control.load(Ordering::Acquire) == 0 @@ -112,9 +115,7 @@ impl MuxStreamWrite { return Err(WispError::StreamAlreadyClosed); } - self.tx - .write_frame(Frame::from(Packet::new_data(self.stream_id, data))) - .await?; + self.tx.write_frame(frame).await?; if self.role == Role::Client && self.stream_type == StreamType::Tcp { self.flow_control.store( @@ -125,6 +126,20 @@ impl MuxStreamWrite { Ok(()) } + /// Write a payload to the stream. + pub fn write_payload<'a>( + &'a self, + data: Payload<'_>, + ) -> impl Future> + 'a { + let frame: Frame<'static> = Frame::from(Packet::new_data(self.stream_id, data)); + self.write_payload_internal(frame) + } + + /// Write data to the stream. + pub async fn write>(&self, data: D) -> Result<(), WispError> { + self.write_payload(Payload::Borrowed(data.as_ref())).await + } + /// Get a handle to close the connection. /// /// Useful to close the connection without having access to the stream. @@ -173,16 +188,16 @@ impl MuxStreamWrite { Ok(()) } - pub(crate) fn into_sink(self) -> Pin + Send>> { + pub(crate) fn into_sink(self) -> Pin, Error = WispError> + Send>> { let handle = self.get_close_handle(); Box::pin(sink_unfold::unfold( self, |tx, data| async move { - tx.write(data).await?; + tx.write_payload_internal(data).await?; Ok(tx) }, handle, - move |handle| async { + |handle| async move { handle.close(CloseReason::Unknown).await?; Ok(handle) }, @@ -258,8 +273,13 @@ impl MuxStream { self.rx.read().await } + /// Write a payload to the stream. + pub async fn write_payload(&self, data: Payload<'_>) -> Result<(), WispError> { + self.tx.write_payload(data).await + } + /// Write data to the stream. - pub async fn write(&self, data: Bytes) -> Result<(), WispError> { + pub async fn write>(&self, data: D) -> Result<(), WispError> { self.tx.write(data).await } @@ -301,6 +321,7 @@ impl MuxStream { }, tx: MuxStreamIoSink { tx: self.tx.into_sink(), + stream_id: self.stream_id, }, } } @@ -355,7 +376,9 @@ impl MuxProtocolExtensionStream { encoded.put_u8(packet_type); encoded.put_u32_le(self.stream_id); encoded.extend(data); - self.tx.write_frame(Frame::binary(encoded)).await + self.tx + .write_frame(Frame::binary(Payload::Bytes(encoded))) + .await } } @@ -391,12 +414,12 @@ impl Stream for MuxStreamIo { } } -impl Sink for MuxStreamIo { +impl Sink<&[u8]> for MuxStreamIo { type Error = std::io::Error; fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { self.project().tx.poll_ready(cx) } - fn start_send(self: Pin<&mut Self>, item: Bytes) -> Result<(), Self::Error> { + fn start_send(self: Pin<&mut Self>, item: &[u8]) -> Result<(), Self::Error> { self.project().tx.start_send(item) } fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { @@ -433,7 +456,8 @@ pin_project! { /// Write side of a multiplexor stream that implements futures `Sink`. pub struct MuxStreamIoSink { #[pin] - tx: Pin + Send>>, + tx: Pin, Error = WispError> + Send>>, + stream_id: u32, } } @@ -444,7 +468,7 @@ impl MuxStreamIoSink { } } -impl Sink for MuxStreamIoSink { +impl Sink<&[u8]> for MuxStreamIoSink { type Error = std::io::Error; fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { self.project() @@ -452,10 +476,14 @@ impl Sink for MuxStreamIoSink { .poll_ready(cx) .map_err(std::io::Error::other) } - fn start_send(self: Pin<&mut Self>, item: Bytes) -> Result<(), Self::Error> { + fn start_send(self: Pin<&mut Self>, item: &[u8]) -> Result<(), Self::Error> { + let stream_id = self.stream_id; self.project() .tx - .start_send(item) + .start_send(Frame::from(Packet::new_data( + stream_id, + Payload::Borrowed(item), + ))) .map_err(std::io::Error::other) } fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { @@ -564,10 +592,10 @@ impl AsyncRead for MuxStreamAsyncRead { } impl AsyncBufRead for MuxStreamAsyncRead { fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.project().rx.poll_fill_buf(cx) + self.project().rx.poll_fill_buf(cx) } fn consume(self: Pin<&mut Self>, amt: usize) { - self.project().rx.consume(amt) + self.project().rx.consume(amt) } } @@ -582,7 +610,10 @@ pin_project! { impl MuxStreamAsyncWrite { pub(crate) fn new(sink: MuxStreamIoSink) -> Self { - Self { tx: sink, error: None } + Self { + tx: sink, + error: None, + } } } @@ -599,7 +630,7 @@ impl AsyncWrite for MuxStreamAsyncWrite { let mut this = self.as_mut().project(); ready!(this.tx.as_mut().poll_ready(cx))?; - match this.tx.as_mut().start_send(Bytes::copy_from_slice(buf)) { + match this.tx.as_mut().start_send(buf) { Ok(()) => { let mut cx = Context::from_waker(noop_waker_ref()); let cx = &mut cx; diff --git a/wisp/src/ws.rs b/wisp/src/ws.rs index 06f55ad..27ae03f 100644 --- a/wisp/src/ws.rs +++ b/wisp/src/ws.rs @@ -4,83 +4,168 @@ //! for other WebSocket implementations. //! //! [`fastwebsockets`]: https://github.com/MercuryWorkshop/epoxy-tls/blob/multiplexed/wisp/src/fastwebsockets.rs -use std::sync::Arc; +use std::{ops::Deref, sync::Arc}; use crate::WispError; use async_trait::async_trait; -use bytes::BytesMut; +use bytes::{Buf, BytesMut}; use futures::lock::Mutex; +/// Payload of the websocket frame. +#[derive(Debug)] +pub enum Payload<'a> { + /// Borrowed payload. Currently used when writing data. + Borrowed(&'a [u8]), + /// BytesMut payload. Currently used when reading data. + Bytes(BytesMut), +} + +impl From for Payload<'static> { + fn from(value: BytesMut) -> Self { + Self::Bytes(value) + } +} + +impl<'a> From<&'a [u8]> for Payload<'a> { + fn from(value: &'a [u8]) -> Self { + Self::Borrowed(value) + } +} + +impl Payload<'_> { + /// Turn a Payload<'a> into a Payload<'static> by copying the data. + pub fn into_owned(self) -> Self { + match self { + Self::Bytes(x) => Self::Bytes(x), + Self::Borrowed(x) => Self::Bytes(BytesMut::from(x)), + } + } +} + +impl From> for BytesMut { + fn from(value: Payload<'_>) -> Self { + match value { + Payload::Bytes(x) => x, + Payload::Borrowed(x) => x.into(), + } + } +} + +impl Deref for Payload<'_> { + type Target = [u8]; + fn deref(&self) -> &Self::Target { + match self { + Self::Bytes(x) => x.deref(), + Self::Borrowed(x) => x, + } + } +} + +impl Clone for Payload<'_> { + fn clone(&self) -> Self { + match self { + Self::Bytes(x) => Self::Bytes(x.clone()), + Self::Borrowed(x) => Self::Bytes(BytesMut::from(*x)), + } + } +} + +impl Buf for Payload<'_> { + fn remaining(&self) -> usize { + match self { + Self::Bytes(x) => x.remaining(), + Self::Borrowed(x) => x.remaining(), + } + } + + fn chunk(&self) -> &[u8] { + match self { + Self::Bytes(x) => x.chunk(), + Self::Borrowed(x) => x.chunk(), + } + } + + fn advance(&mut self, cnt: usize) { + match self { + Self::Bytes(x) => x.advance(cnt), + Self::Borrowed(x) => x.advance(cnt), + } + } +} + /// Opcode of the WebSocket frame. #[derive(Debug, PartialEq, Clone, Copy)] pub enum OpCode { - /// Text frame. - Text, - /// Binary frame. - Binary, - /// Close frame. - Close, - /// Ping frame. - Ping, - /// Pong frame. - Pong, + /// Text frame. + Text, + /// Binary frame. + Binary, + /// Close frame. + Close, + /// Ping frame. + Ping, + /// Pong frame. + Pong, } /// WebSocket frame. #[derive(Debug, Clone)] -pub struct Frame { - /// Whether the frame is finished or not. - pub finished: bool, - /// Opcode of the WebSocket frame. - pub opcode: OpCode, - /// Payload of the WebSocket frame. - pub payload: BytesMut, +pub struct Frame<'a> { + /// Whether the frame is finished or not. + pub finished: bool, + /// Opcode of the WebSocket frame. + pub opcode: OpCode, + /// Payload of the WebSocket frame. + pub payload: Payload<'a>, } -impl Frame { - /// Create a new text frame. - pub fn text(payload: BytesMut) -> Self { - Self { - finished: true, - opcode: OpCode::Text, - payload, - } - } +impl<'a> Frame<'a> { + /// Create a new text frame. + pub fn text(payload: Payload<'a>) -> Self { + Self { + finished: true, + opcode: OpCode::Text, + payload, + } + } - /// Create a new binary frame. - pub fn binary(payload: BytesMut) -> Self { - Self { - finished: true, - opcode: OpCode::Binary, - payload, - } - } + /// Create a new binary frame. + pub fn binary(payload: Payload<'a>) -> Self { + Self { + finished: true, + opcode: OpCode::Binary, + payload, + } + } - /// Create a new close frame. - pub fn close(payload: BytesMut) -> Self { - Self { - finished: true, - opcode: OpCode::Close, - payload, - } - } + /// Create a new close frame. + pub fn close(payload: Payload<'a>) -> Self { + Self { + finished: true, + opcode: OpCode::Close, + payload, + } + } } /// Generic WebSocket read trait. #[async_trait] pub trait WebSocketRead { - /// Read a frame from the socket. - async fn wisp_read_frame(&mut self, tx: &LockedWebSocketWrite) -> Result; + /// Read a frame from the socket. + async fn wisp_read_frame( + &mut self, + tx: &LockedWebSocketWrite, + ) -> Result, WispError>; } /// Generic WebSocket write trait. #[async_trait] pub trait WebSocketWrite { - /// Write a frame to the socket. - async fn wisp_write_frame(&mut self, frame: Frame) -> Result<(), WispError>; + /// Write a frame to the socket. + async fn wisp_write_frame(&mut self, frame: Frame<'_>) -> Result<(), WispError>; - /// Close the socket. - async fn wisp_close(&mut self) -> Result<(), WispError>; + /// Close the socket. + async fn wisp_close(&mut self) -> Result<(), WispError>; } /// Locked WebSocket. @@ -88,35 +173,38 @@ pub trait WebSocketWrite { pub struct LockedWebSocketWrite(Arc>>); impl LockedWebSocketWrite { - /// Create a new locked websocket. - pub fn new(ws: Box) -> Self { - Self(Mutex::new(ws).into()) - } + /// Create a new locked websocket. + pub fn new(ws: Box) -> Self { + Self(Mutex::new(ws).into()) + } - /// Write a frame to the websocket. - pub async fn write_frame(&self, frame: Frame) -> Result<(), WispError> { - self.0.lock().await.wisp_write_frame(frame).await - } + /// Write a frame to the websocket. + pub async fn write_frame(&self, frame: Frame<'_>) -> Result<(), WispError> { + self.0.lock().await.wisp_write_frame(frame).await + } - /// Close the websocket. - pub async fn close(&self) -> Result<(), WispError> { - self.0.lock().await.wisp_close().await - } + /// Close the websocket. + pub async fn close(&self) -> Result<(), WispError> { + self.0.lock().await.wisp_close().await + } } -pub(crate) struct AppendingWebSocketRead(pub Option, pub R) +pub(crate) struct AppendingWebSocketRead(pub Option>, pub R) where - R: WebSocketRead + Send; + R: WebSocketRead + Send; #[async_trait] impl WebSocketRead for AppendingWebSocketRead where - R: WebSocketRead + Send, + R: WebSocketRead + Send, { - async fn wisp_read_frame(&mut self, tx: &LockedWebSocketWrite) -> Result { - if let Some(x) = self.0.take() { - return Ok(x); - } - return self.1.wisp_read_frame(tx).await; - } + async fn wisp_read_frame( + &mut self, + tx: &LockedWebSocketWrite, + ) -> Result, WispError> { + if let Some(x) = self.0.take() { + return Ok(x); + } + return self.1.wisp_read_frame(tx).await; + } }