add a new Payload struct to allow for one-copy writes and cargo fmt

This commit is contained in:
Toshit Chawda 2024-07-17 16:23:58 -07:00
parent 314c1bfa75
commit d6353bd5a9
No known key found for this signature in database
GPG key ID: 91480ED99E2B3D9D
18 changed files with 3533 additions and 3395 deletions

View file

@ -3,15 +3,15 @@ use std::fmt::Write;
use rustls_pki_types::TrustAnchor; use rustls_pki_types::TrustAnchor;
fn main() { fn main() {
let mut code = String::with_capacity(256 * 1_024); let mut code = String::with_capacity(256 * 1_024);
code.push_str("const ROOTS = ["); code.push_str("const ROOTS = [");
for anchor in webpki_roots::TLS_SERVER_ROOTS { for anchor in webpki_roots::TLS_SERVER_ROOTS {
let TrustAnchor { let TrustAnchor {
subject, subject,
subject_public_key_info, subject_public_key_info,
name_constraints, name_constraints,
} = anchor; } = anchor;
code.write_fmt(format_args!( code.write_fmt(format_args!(
"{{subject:new Uint8Array([{}]),subject_public_key_info:new Uint8Array([{}]),name_constraints:{}}},", "{{subject:new Uint8Array([{}]),subject_public_key_info:new Uint8Array([{}]),name_constraints:{}}},",
subject subject
.as_ref() .as_ref()
@ -34,8 +34,8 @@ fn main() {
} }
)) ))
.unwrap(); .unwrap();
} }
code.pop(); code.pop();
code.push_str("];"); code.push_str("];");
println!("{}", code); println!("{}", code);
} }

View file

@ -1,183 +1,181 @@
use bytes::{buf::UninitSlice, BufMut, Bytes, BytesMut}; use bytes::{buf::UninitSlice, BufMut, BytesMut};
use futures_util::{ use futures_util::{io::WriteHalf, lock::Mutex, AsyncReadExt, AsyncWriteExt, SinkExt, StreamExt};
io::WriteHalf, lock::Mutex, stream::SplitSink, AsyncReadExt, AsyncWriteExt, SinkExt, StreamExt,
};
use js_sys::{Function, Uint8Array}; use js_sys::{Function, Uint8Array};
use wasm_bindgen::prelude::*; use wasm_bindgen::prelude::*;
use wasm_bindgen_futures::spawn_local; use wasm_bindgen_futures::spawn_local;
use wisp_mux::MuxStreamIoSink;
use crate::{ use crate::{
stream_provider::{ProviderAsyncRW, ProviderUnencryptedStream}, stream_provider::{ProviderAsyncRW, ProviderUnencryptedStream},
utils::convert_body, utils::convert_body,
EpoxyError, EpoxyHandlers, EpoxyError, EpoxyHandlers,
}; };
#[wasm_bindgen] #[wasm_bindgen]
pub struct EpoxyIoStream { pub struct EpoxyIoStream {
tx: Mutex<WriteHalf<ProviderAsyncRW>>, tx: Mutex<WriteHalf<ProviderAsyncRW>>,
onerror: Function, onerror: Function,
} }
#[wasm_bindgen] #[wasm_bindgen]
impl EpoxyIoStream { impl EpoxyIoStream {
pub(crate) fn connect(stream: ProviderAsyncRW, handlers: EpoxyHandlers) -> Self { pub(crate) fn connect(stream: ProviderAsyncRW, handlers: EpoxyHandlers) -> Self {
let (mut rx, tx) = stream.split(); let (mut rx, tx) = stream.split();
let tx = Mutex::new(tx); let tx = Mutex::new(tx);
let EpoxyHandlers { let EpoxyHandlers {
onopen, onopen,
onclose, onclose,
onerror, onerror,
onmessage, onmessage,
} = handlers; } = handlers;
let onerror_cloned = onerror.clone(); let onerror_cloned = onerror.clone();
// similar to tokio_util::io::ReaderStream // similar to tokio_util::io::ReaderStream
spawn_local(async move { spawn_local(async move {
let mut buf = BytesMut::with_capacity(4096); let mut buf = BytesMut::with_capacity(4096);
loop { loop {
match rx match rx
.read(unsafe { .read(unsafe {
std::mem::transmute::<&mut UninitSlice, &mut [u8]>(buf.chunk_mut()) std::mem::transmute::<&mut UninitSlice, &mut [u8]>(buf.chunk_mut())
}) })
.await .await
{ {
Ok(cnt) => { Ok(cnt) => {
if cnt > 0 { if cnt > 0 {
unsafe { buf.advance_mut(cnt) }; unsafe { buf.advance_mut(cnt) };
let _ = onmessage let _ = onmessage
.call1(&JsValue::null(), &Uint8Array::from(buf.split().as_ref())); .call1(&JsValue::null(), &Uint8Array::from(buf.split().as_ref()));
} }
} }
Err(err) => { Err(err) => {
let _ = onerror.call1(&JsValue::null(), &JsError::from(err).into()); let _ = onerror.call1(&JsValue::null(), &JsError::from(err).into());
break; break;
} }
} }
} }
let _ = onclose.call0(&JsValue::null()); let _ = onclose.call0(&JsValue::null());
}); });
let _ = onopen.call0(&JsValue::null()); let _ = onopen.call0(&JsValue::null());
Self { Self {
tx, tx,
onerror: onerror_cloned, onerror: onerror_cloned,
} }
} }
pub async fn send(&self, payload: JsValue) -> Result<(), EpoxyError> { pub async fn send(&self, payload: JsValue) -> Result<(), EpoxyError> {
let ret: Result<(), EpoxyError> = async move { let ret: Result<(), EpoxyError> = async move {
let payload = convert_body(payload) let payload = convert_body(payload)
.await .await
.map_err(|_| EpoxyError::InvalidPayload)? .map_err(|_| EpoxyError::InvalidPayload)?
.0 .0
.to_vec(); .to_vec();
Ok(self.tx.lock().await.write_all(&payload).await?) Ok(self.tx.lock().await.write_all(&payload).await?)
} }
.await; .await;
match ret { match ret {
Ok(ok) => Ok(ok), Ok(ok) => Ok(ok),
Err(err) => { Err(err) => {
let _ = self let _ = self
.onerror .onerror
.call1(&JsValue::null(), &err.to_string().into()); .call1(&JsValue::null(), &err.to_string().into());
Err(err) Err(err)
} }
} }
} }
pub async fn close(&self) -> Result<(), EpoxyError> { pub async fn close(&self) -> Result<(), EpoxyError> {
match self.tx.lock().await.close().await { match self.tx.lock().await.close().await {
Ok(ok) => Ok(ok), Ok(ok) => Ok(ok),
Err(err) => { Err(err) => {
let _ = self let _ = self
.onerror .onerror
.call1(&JsValue::null(), &err.to_string().into()); .call1(&JsValue::null(), &err.to_string().into());
Err(err.into()) Err(err.into())
} }
} }
} }
} }
#[wasm_bindgen] #[wasm_bindgen]
pub struct EpoxyUdpStream { pub struct EpoxyUdpStream {
tx: Mutex<SplitSink<ProviderUnencryptedStream, Bytes>>, tx: Mutex<MuxStreamIoSink>,
onerror: Function, onerror: Function,
} }
#[wasm_bindgen] #[wasm_bindgen]
impl EpoxyUdpStream { impl EpoxyUdpStream {
pub(crate) fn connect(stream: ProviderUnencryptedStream, handlers: EpoxyHandlers) -> Self { pub(crate) fn connect(stream: ProviderUnencryptedStream, handlers: EpoxyHandlers) -> Self {
let (tx, mut rx) = stream.split(); let (mut rx, tx) = stream.into_split();
let tx = Mutex::new(tx);
let EpoxyHandlers { let EpoxyHandlers {
onopen, onopen,
onclose, onclose,
onerror, onerror,
onmessage, onmessage,
} = handlers; } = handlers;
let onerror_cloned = onerror.clone(); let onerror_cloned = onerror.clone();
spawn_local(async move { spawn_local(async move {
while let Some(packet) = rx.next().await { while let Some(packet) = rx.next().await {
match packet { match packet {
Ok(buf) => { Ok(buf) => {
let _ = onmessage.call1(&JsValue::null(), &Uint8Array::from(buf.as_ref())); let _ = onmessage.call1(&JsValue::null(), &Uint8Array::from(buf.as_ref()));
} }
Err(err) => { Err(err) => {
let _ = onerror.call1(&JsValue::null(), &JsError::from(err).into()); let _ = onerror.call1(&JsValue::null(), &JsError::from(err).into());
break; break;
} }
} }
} }
let _ = onclose.call0(&JsValue::null()); let _ = onclose.call0(&JsValue::null());
}); });
let _ = onopen.call0(&JsValue::null()); let _ = onopen.call0(&JsValue::null());
Self { Self {
tx, tx: tx.into(),
onerror: onerror_cloned, onerror: onerror_cloned,
} }
} }
pub async fn send(&self, payload: JsValue) -> Result<(), EpoxyError> { pub async fn send(&self, payload: JsValue) -> Result<(), EpoxyError> {
let ret: Result<(), EpoxyError> = async move { let ret: Result<(), EpoxyError> = async move {
let payload = convert_body(payload) let payload = convert_body(payload)
.await .await
.map_err(|_| EpoxyError::InvalidPayload)? .map_err(|_| EpoxyError::InvalidPayload)?
.0 .0
.to_vec(); .to_vec();
Ok(self.tx.lock().await.send(payload.into()).await?) Ok(self.tx.lock().await.send(payload.as_ref()).await?)
} }
.await; .await;
match ret { match ret {
Ok(ok) => Ok(ok), Ok(ok) => Ok(ok),
Err(err) => { Err(err) => {
let _ = self let _ = self
.onerror .onerror
.call1(&JsValue::null(), &err.to_string().into()); .call1(&JsValue::null(), &err.to_string().into());
Err(err) Err(err)
} }
} }
} }
pub async fn close(&self) -> Result<(), EpoxyError> { pub async fn close(&self) -> Result<(), EpoxyError> {
match self.tx.lock().await.close().await { match self.tx.lock().await.close().await {
Ok(ok) => Ok(ok), Ok(ok) => Ok(ok),
Err(err) => { Err(err) => {
let _ = self let _ = self
.onerror .onerror
.call1(&JsValue::null(), &err.to_string().into()); .call1(&JsValue::null(), &err.to_string().into());
Err(err.into()) Err(err.into())
} }
} }
} }
} }

View file

@ -5,14 +5,14 @@ use std::{str::FromStr, sync::Arc};
use async_compression::futures::bufread as async_comp; use async_compression::futures::bufread as async_comp;
use bytes::Bytes; use bytes::Bytes;
use cfg_if::cfg_if; use cfg_if::cfg_if;
use futures_util::TryStreamExt;
#[cfg(feature = "full")] #[cfg(feature = "full")]
use futures_util::future::Either; use futures_util::future::Either;
use futures_util::TryStreamExt;
use http::{ use http::{
header::{InvalidHeaderName, InvalidHeaderValue}, header::{InvalidHeaderName, InvalidHeaderValue},
method::InvalidMethod, method::InvalidMethod,
uri::{InvalidUri, InvalidUriParts}, uri::{InvalidUri, InvalidUriParts},
HeaderName, HeaderValue, Method, Request, Response, HeaderName, HeaderValue, Method, Request, Response,
}; };
use hyper::{body::Incoming, Uri}; use hyper::{body::Incoming, Uri};
use hyper_util_wasm::client::legacy::Client; use hyper_util_wasm::client::legacy::Client;
@ -22,7 +22,8 @@ use js_sys::{Array, Function, Object, Reflect};
use stream_provider::{StreamProvider, StreamProviderService}; use stream_provider::{StreamProvider, StreamProviderService};
use thiserror::Error; use thiserror::Error;
use utils::{ use utils::{
asyncread_to_readablestream_stream, convert_body, entries_of_object, is_null_body, is_redirect, object_get, object_set, IncomingBody, UriExt, WasmExecutor asyncread_to_readablestream_stream, convert_body, entries_of_object, is_null_body, is_redirect,
object_get, object_set, IncomingBody, UriExt, WasmExecutor,
}; };
use wasm_bindgen::prelude::*; use wasm_bindgen::prelude::*;
use wasm_streams::ReadableStream; use wasm_streams::ReadableStream;
@ -45,409 +46,409 @@ type HttpBody = http_body_util::Full<Bytes>;
#[derive(Debug, Error)] #[derive(Debug, Error)]
pub enum EpoxyError { pub enum EpoxyError {
#[error("Invalid DNS name: {0:?}")] #[error("Invalid DNS name: {0:?}")]
InvalidDnsName(#[from] futures_rustls::rustls::pki_types::InvalidDnsNameError), InvalidDnsName(#[from] futures_rustls::rustls::pki_types::InvalidDnsNameError),
#[error("Wisp: {0:?}")] #[error("Wisp: {0:?}")]
Wisp(#[from] wisp_mux::WispError), Wisp(#[from] wisp_mux::WispError),
#[error("IO: {0:?}")] #[error("IO: {0:?}")]
Io(#[from] std::io::Error), Io(#[from] std::io::Error),
#[error("HTTP: {0:?}")] #[error("HTTP: {0:?}")]
Http(#[from] http::Error), Http(#[from] http::Error),
#[error("Hyper client: {0:?}")] #[error("Hyper client: {0:?}")]
HyperClient(#[from] hyper_util_wasm::client::legacy::Error), HyperClient(#[from] hyper_util_wasm::client::legacy::Error),
#[error("Hyper: {0:?}")] #[error("Hyper: {0:?}")]
Hyper(#[from] hyper::Error), Hyper(#[from] hyper::Error),
#[error("HTTP ToStr: {0:?}")] #[error("HTTP ToStr: {0:?}")]
ToStr(#[from] http::header::ToStrError), ToStr(#[from] http::header::ToStrError),
#[cfg(feature = "full")] #[cfg(feature = "full")]
#[error("Getrandom: {0:?}")] #[error("Getrandom: {0:?}")]
GetRandom(#[from] getrandom::Error), GetRandom(#[from] getrandom::Error),
#[cfg(feature = "full")] #[cfg(feature = "full")]
#[error("Fastwebsockets: {0:?}")] #[error("Fastwebsockets: {0:?}")]
FastWebSockets(#[from] fastwebsockets::WebSocketError), FastWebSockets(#[from] fastwebsockets::WebSocketError),
#[error("Invalid URL scheme")] #[error("Invalid URL scheme")]
InvalidUrlScheme, InvalidUrlScheme,
#[error("No URL host found")] #[error("No URL host found")]
NoUrlHost, NoUrlHost,
#[error("No URL port found")] #[error("No URL port found")]
NoUrlPort, NoUrlPort,
#[error("Invalid request body")] #[error("Invalid request body")]
InvalidRequestBody, InvalidRequestBody,
#[error("Invalid request")] #[error("Invalid request")]
InvalidRequest, InvalidRequest,
#[error("Invalid websocket response status code")] #[error("Invalid websocket response status code")]
WsInvalidStatusCode, WsInvalidStatusCode,
#[error("Invalid websocket upgrade header")] #[error("Invalid websocket upgrade header")]
WsInvalidUpgradeHeader, WsInvalidUpgradeHeader,
#[error("Invalid websocket connection header")] #[error("Invalid websocket connection header")]
WsInvalidConnectionHeader, WsInvalidConnectionHeader,
#[error("Invalid websocket payload")] #[error("Invalid websocket payload")]
WsInvalidPayload, WsInvalidPayload,
#[error("Invalid payload")] #[error("Invalid payload")]
InvalidPayload, InvalidPayload,
#[error("Invalid certificate store")] #[error("Invalid certificate store")]
InvalidCertStore, InvalidCertStore,
#[error("WebSocket failed to connect")] #[error("WebSocket failed to connect")]
WebSocketConnectFailed, WebSocketConnectFailed,
#[error("Failed to construct response headers object")] #[error("Failed to construct response headers object")]
ResponseHeadersFromEntriesFailed, ResponseHeadersFromEntriesFailed,
#[error("Failed to construct response object")] #[error("Failed to construct response object")]
ResponseNewFailed, ResponseNewFailed,
#[error("Failed to construct define_property object")] #[error("Failed to construct define_property object")]
DefinePropertyObjFailed, DefinePropertyObjFailed,
#[error("Failed to set raw header item")] #[error("Failed to set raw header item")]
RawHeaderSetFailed, RawHeaderSetFailed,
} }
impl From<EpoxyError> for JsValue { impl From<EpoxyError> for JsValue {
fn from(value: EpoxyError) -> Self { fn from(value: EpoxyError) -> Self {
JsError::from(value).into() JsError::from(value).into()
} }
} }
impl From<InvalidUri> for EpoxyError { impl From<InvalidUri> for EpoxyError {
fn from(value: InvalidUri) -> Self { fn from(value: InvalidUri) -> Self {
http::Error::from(value).into() http::Error::from(value).into()
} }
} }
impl From<InvalidUriParts> for EpoxyError { impl From<InvalidUriParts> for EpoxyError {
fn from(value: InvalidUriParts) -> Self { fn from(value: InvalidUriParts) -> Self {
http::Error::from(value).into() http::Error::from(value).into()
} }
} }
impl From<InvalidHeaderName> for EpoxyError { impl From<InvalidHeaderName> for EpoxyError {
fn from(value: InvalidHeaderName) -> Self { fn from(value: InvalidHeaderName) -> Self {
http::Error::from(value).into() http::Error::from(value).into()
} }
} }
impl From<InvalidHeaderValue> for EpoxyError { impl From<InvalidHeaderValue> for EpoxyError {
fn from(value: InvalidHeaderValue) -> Self { fn from(value: InvalidHeaderValue) -> Self {
http::Error::from(value).into() http::Error::from(value).into()
} }
} }
impl From<InvalidMethod> for EpoxyError { impl From<InvalidMethod> for EpoxyError {
fn from(value: InvalidMethod) -> Self { fn from(value: InvalidMethod) -> Self {
http::Error::from(value).into() http::Error::from(value).into()
} }
} }
#[derive(Debug)] #[derive(Debug)]
enum EpoxyResponse { enum EpoxyResponse {
Success(Response<Incoming>), Success(Response<Incoming>),
Redirect((Response<Incoming>, http::Request<HttpBody>)), Redirect((Response<Incoming>, http::Request<HttpBody>)),
} }
#[cfg(feature = "full")] #[cfg(feature = "full")]
enum EpoxyCompression { enum EpoxyCompression {
Brotli, Brotli,
Gzip, Gzip,
} }
#[wasm_bindgen] #[wasm_bindgen]
pub struct EpoxyClientOptions { pub struct EpoxyClientOptions {
pub wisp_v2: bool, pub wisp_v2: bool,
pub udp_extension_required: bool, pub udp_extension_required: bool,
#[wasm_bindgen(getter_with_clone)] #[wasm_bindgen(getter_with_clone)]
pub websocket_protocols: Vec<String>, pub websocket_protocols: Vec<String>,
pub redirect_limit: usize, pub redirect_limit: usize,
#[wasm_bindgen(getter_with_clone)] #[wasm_bindgen(getter_with_clone)]
pub user_agent: String, pub user_agent: String,
} }
#[wasm_bindgen] #[wasm_bindgen]
impl EpoxyClientOptions { impl EpoxyClientOptions {
#[wasm_bindgen(constructor)] #[wasm_bindgen(constructor)]
pub fn new_default() -> Self { pub fn new_default() -> Self {
Self::default() Self::default()
} }
} }
impl Default for EpoxyClientOptions { impl Default for EpoxyClientOptions {
fn default() -> Self { fn default() -> Self {
Self { Self {
wisp_v2: true, wisp_v2: true,
udp_extension_required: true, udp_extension_required: true,
websocket_protocols: Vec::new(), websocket_protocols: Vec::new(),
redirect_limit: 10, redirect_limit: 10,
user_agent: "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/127.0.0.0 Safari/537.36".to_string(), user_agent: "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/127.0.0.0 Safari/537.36".to_string(),
} }
} }
} }
#[wasm_bindgen(getter_with_clone)] #[wasm_bindgen(getter_with_clone)]
pub struct EpoxyHandlers { pub struct EpoxyHandlers {
pub onopen: Function, pub onopen: Function,
pub onclose: Function, pub onclose: Function,
pub onerror: Function, pub onerror: Function,
pub onmessage: Function, pub onmessage: Function,
} }
#[cfg(feature = "full")] #[cfg(feature = "full")]
#[wasm_bindgen] #[wasm_bindgen]
impl EpoxyHandlers { impl EpoxyHandlers {
#[wasm_bindgen(constructor)] #[wasm_bindgen(constructor)]
pub fn new( pub fn new(
onopen: Function, onopen: Function,
onclose: Function, onclose: Function,
onerror: Function, onerror: Function,
onmessage: Function, onmessage: Function,
) -> Self { ) -> Self {
Self { Self {
onopen, onopen,
onclose, onclose,
onerror, onerror,
onmessage, onmessage,
} }
} }
} }
#[wasm_bindgen(inspectable)] #[wasm_bindgen(inspectable)]
pub struct EpoxyClient { pub struct EpoxyClient {
stream_provider: Arc<StreamProvider>, stream_provider: Arc<StreamProvider>,
client: Client<StreamProviderService, HttpBody>, client: Client<StreamProviderService, HttpBody>,
pub redirect_limit: usize, pub redirect_limit: usize,
#[wasm_bindgen(getter_with_clone)] #[wasm_bindgen(getter_with_clone)]
pub user_agent: String, pub user_agent: String,
} }
#[wasm_bindgen] #[wasm_bindgen]
impl EpoxyClient { impl EpoxyClient {
#[wasm_bindgen(constructor)] #[wasm_bindgen(constructor)]
pub fn new( pub fn new(
wisp_url: String, wisp_url: String,
certs: Array, certs: Array,
options: EpoxyClientOptions, options: EpoxyClientOptions,
) -> Result<EpoxyClient, EpoxyError> { ) -> Result<EpoxyClient, EpoxyError> {
let wisp_url: Uri = wisp_url.try_into()?; let wisp_url: Uri = wisp_url.try_into()?;
if wisp_url.scheme_str() != Some("wss") && wisp_url.scheme_str() != Some("ws") { if wisp_url.scheme_str() != Some("wss") && wisp_url.scheme_str() != Some("ws") {
return Err(EpoxyError::InvalidUrlScheme); return Err(EpoxyError::InvalidUrlScheme);
} }
let stream_provider = Arc::new(StreamProvider::new(wisp_url.to_string(), certs, &options)?); let stream_provider = Arc::new(StreamProvider::new(wisp_url.to_string(), certs, &options)?);
let service = StreamProviderService(stream_provider.clone()); let service = StreamProviderService(stream_provider.clone());
let client = Client::builder(WasmExecutor) let client = Client::builder(WasmExecutor)
.http09_responses(true) .http09_responses(true)
.http1_title_case_headers(true) .http1_title_case_headers(true)
.http1_preserve_header_case(true) .http1_preserve_header_case(true)
.build(service); .build(service);
Ok(Self { Ok(Self {
stream_provider, stream_provider,
client, client,
redirect_limit: options.redirect_limit, redirect_limit: options.redirect_limit,
user_agent: options.user_agent, user_agent: options.user_agent,
}) })
} }
pub async fn replace_stream_provider(&self) -> Result<(), EpoxyError> { pub async fn replace_stream_provider(&self) -> Result<(), EpoxyError> {
self.stream_provider.replace_client().await self.stream_provider.replace_client().await
} }
#[cfg(feature = "full")] #[cfg(feature = "full")]
pub async fn connect_websocket( pub async fn connect_websocket(
&self, &self,
handlers: EpoxyHandlers, handlers: EpoxyHandlers,
url: String, url: String,
protocols: Vec<String>, protocols: Vec<String>,
headers: JsValue, headers: JsValue,
) -> Result<EpoxyWebSocket, EpoxyError> { ) -> Result<EpoxyWebSocket, EpoxyError> {
EpoxyWebSocket::connect(self, handlers, url, protocols, headers, &self.user_agent).await EpoxyWebSocket::connect(self, handlers, url, protocols, headers, &self.user_agent).await
} }
#[cfg(feature = "full")] #[cfg(feature = "full")]
pub async fn connect_tcp( pub async fn connect_tcp(
&self, &self,
handlers: EpoxyHandlers, handlers: EpoxyHandlers,
url: String, url: String,
) -> Result<EpoxyIoStream, EpoxyError> { ) -> Result<EpoxyIoStream, EpoxyError> {
let url: Uri = url.try_into()?; let url: Uri = url.try_into()?;
let host = url.host().ok_or(EpoxyError::NoUrlHost)?; let host = url.host().ok_or(EpoxyError::NoUrlHost)?;
let port = url.port_u16().ok_or(EpoxyError::NoUrlPort)?; let port = url.port_u16().ok_or(EpoxyError::NoUrlPort)?;
match self match self
.stream_provider .stream_provider
.get_asyncread(StreamType::Tcp, host.to_string(), port) .get_asyncread(StreamType::Tcp, host.to_string(), port)
.await .await
{ {
Ok(stream) => Ok(EpoxyIoStream::connect(Either::Right(stream), handlers)), Ok(stream) => Ok(EpoxyIoStream::connect(Either::Right(stream), handlers)),
Err(err) => { Err(err) => {
let _ = handlers let _ = handlers
.onerror .onerror
.call1(&JsValue::null(), &err.to_string().into()); .call1(&JsValue::null(), &err.to_string().into());
Err(err) Err(err)
} }
} }
} }
#[cfg(feature = "full")] #[cfg(feature = "full")]
pub async fn connect_tls( pub async fn connect_tls(
&self, &self,
handlers: EpoxyHandlers, handlers: EpoxyHandlers,
url: String, url: String,
) -> Result<EpoxyIoStream, EpoxyError> { ) -> Result<EpoxyIoStream, EpoxyError> {
let url: Uri = url.try_into()?; let url: Uri = url.try_into()?;
let host = url.host().ok_or(EpoxyError::NoUrlHost)?; let host = url.host().ok_or(EpoxyError::NoUrlHost)?;
let port = url.port_u16().ok_or(EpoxyError::NoUrlPort)?; let port = url.port_u16().ok_or(EpoxyError::NoUrlPort)?;
match self match self
.stream_provider .stream_provider
.get_tls_stream(host.to_string(), port) .get_tls_stream(host.to_string(), port)
.await .await
{ {
Ok(stream) => Ok(EpoxyIoStream::connect(Either::Left(stream), handlers)), Ok(stream) => Ok(EpoxyIoStream::connect(Either::Left(stream), handlers)),
Err(err) => { Err(err) => {
let _ = handlers let _ = handlers
.onerror .onerror
.call1(&JsValue::null(), &err.to_string().into()); .call1(&JsValue::null(), &err.to_string().into());
Err(err) Err(err)
} }
} }
} }
#[cfg(feature = "full")] #[cfg(feature = "full")]
pub async fn connect_udp( pub async fn connect_udp(
&self, &self,
handlers: EpoxyHandlers, handlers: EpoxyHandlers,
url: String, url: String,
) -> Result<EpoxyUdpStream, EpoxyError> { ) -> Result<EpoxyUdpStream, EpoxyError> {
let url: Uri = url.try_into()?; let url: Uri = url.try_into()?;
let host = url.host().ok_or(EpoxyError::NoUrlHost)?; let host = url.host().ok_or(EpoxyError::NoUrlHost)?;
let port = url.port_u16().ok_or(EpoxyError::NoUrlPort)?; let port = url.port_u16().ok_or(EpoxyError::NoUrlPort)?;
match self match self
.stream_provider .stream_provider
.get_stream(StreamType::Udp, host.to_string(), port) .get_stream(StreamType::Udp, host.to_string(), port)
.await .await
{ {
Ok(stream) => Ok(EpoxyUdpStream::connect(stream, handlers)), Ok(stream) => Ok(EpoxyUdpStream::connect(stream, handlers)),
Err(err) => { Err(err) => {
let _ = handlers let _ = handlers
.onerror .onerror
.call1(&JsValue::null(), &err.to_string().into()); .call1(&JsValue::null(), &err.to_string().into());
Err(err) Err(err)
} }
} }
} }
async fn send_req_inner( async fn send_req_inner(
&self, &self,
req: http::Request<HttpBody>, req: http::Request<HttpBody>,
should_redirect: bool, should_redirect: bool,
) -> Result<EpoxyResponse, EpoxyError> { ) -> Result<EpoxyResponse, EpoxyError> {
let new_req = if should_redirect { let new_req = if should_redirect {
Some(req.clone()) Some(req.clone())
} else { } else {
None None
}; };
let res = self.client.request(req).await; let res = self.client.request(req).await;
match res { match res {
Ok(res) => { Ok(res) => {
if is_redirect(res.status().as_u16()) if is_redirect(res.status().as_u16())
&& let Some(mut new_req) = new_req && let Some(mut new_req) = new_req
&& let Some(location) = res.headers().get("Location") && let Some(location) = res.headers().get("Location")
&& let Ok(redirect_url) = new_req.uri().get_redirect(location) && let Ok(redirect_url) = new_req.uri().get_redirect(location)
&& let Some(redirect_url_authority) = redirect_url.clone().authority() && let Some(redirect_url_authority) = redirect_url.clone().authority()
{ {
*new_req.uri_mut() = redirect_url; *new_req.uri_mut() = redirect_url;
new_req.headers_mut().insert( new_req.headers_mut().insert(
"Host", "Host",
HeaderValue::from_str(redirect_url_authority.as_str())?, HeaderValue::from_str(redirect_url_authority.as_str())?,
); );
Ok(EpoxyResponse::Redirect((res, new_req))) Ok(EpoxyResponse::Redirect((res, new_req)))
} else { } else {
Ok(EpoxyResponse::Success(res)) Ok(EpoxyResponse::Success(res))
} }
} }
Err(err) => Err(err.into()), Err(err) => Err(err.into()),
} }
} }
async fn send_req( async fn send_req(
&self, &self,
req: http::Request<HttpBody>, req: http::Request<HttpBody>,
should_redirect: bool, should_redirect: bool,
) -> Result<(hyper::Response<Incoming>, Uri, bool), EpoxyError> { ) -> Result<(hyper::Response<Incoming>, Uri, bool), EpoxyError> {
let mut redirected = false; let mut redirected = false;
let mut current_url = req.uri().clone(); let mut current_url = req.uri().clone();
let mut current_resp: EpoxyResponse = self.send_req_inner(req, should_redirect).await?; let mut current_resp: EpoxyResponse = self.send_req_inner(req, should_redirect).await?;
for _ in 0..self.redirect_limit { for _ in 0..self.redirect_limit {
match current_resp { match current_resp {
EpoxyResponse::Success(_) => break, EpoxyResponse::Success(_) => break,
EpoxyResponse::Redirect((_, req)) => { EpoxyResponse::Redirect((_, req)) => {
redirected = true; redirected = true;
current_url = req.uri().clone(); current_url = req.uri().clone();
current_resp = self.send_req_inner(req, should_redirect).await? current_resp = self.send_req_inner(req, should_redirect).await?
} }
} }
} }
match current_resp { match current_resp {
EpoxyResponse::Success(resp) => Ok((resp, current_url, redirected)), EpoxyResponse::Success(resp) => Ok((resp, current_url, redirected)),
EpoxyResponse::Redirect((resp, _)) => Ok((resp, current_url, redirected)), EpoxyResponse::Redirect((resp, _)) => Ok((resp, current_url, redirected)),
} }
} }
pub async fn fetch( pub async fn fetch(
&self, &self,
url: String, url: String,
options: Object, options: Object,
) -> Result<web_sys::Response, EpoxyError> { ) -> Result<web_sys::Response, EpoxyError> {
let url: Uri = url.try_into()?; let url: Uri = url.try_into()?;
// only valid `Scheme`s are HTTP and HTTPS, which are the ones we support // only valid `Scheme`s are HTTP and HTTPS, which are the ones we support
url.scheme().ok_or(EpoxyError::InvalidUrlScheme)?; url.scheme().ok_or(EpoxyError::InvalidUrlScheme)?;
let host = url.host().ok_or(EpoxyError::NoUrlHost)?; let host = url.host().ok_or(EpoxyError::NoUrlHost)?;
let request_method = object_get(&options, "method") let request_method = object_get(&options, "method")
.and_then(|x| x.as_string()) .and_then(|x| x.as_string())
.unwrap_or_else(|| "GET".to_string()); .unwrap_or_else(|| "GET".to_string());
let request_method: Method = Method::from_str(&request_method)?; let request_method: Method = Method::from_str(&request_method)?;
let request_redirect = object_get(&options, "redirect") let request_redirect = object_get(&options, "redirect")
.map(|x| { .map(|x| {
!matches!( !matches!(
x.as_string().unwrap_or_default().as_str(), x.as_string().unwrap_or_default().as_str(),
"error" | "manual" "error" | "manual"
) )
}) })
.unwrap_or(true); .unwrap_or(true);
let mut body_content_type: Option<String> = None; let mut body_content_type: Option<String> = None;
let body = match object_get(&options, "body") { let body = match object_get(&options, "body") {
Some(buf) => { Some(buf) => {
let (body, req) = convert_body(buf) let (body, req) = convert_body(buf)
.await .await
.map_err(|_| EpoxyError::InvalidRequestBody)?; .map_err(|_| EpoxyError::InvalidRequestBody)?;
body_content_type = req.headers().get("Content-Type").ok().flatten(); body_content_type = req.headers().get("Content-Type").ok().flatten();
Bytes::from(body.to_vec()) Bytes::from(body.to_vec())
} }
None => Bytes::new(), None => Bytes::new(),
}; };
let headers = object_get(&options, "headers").and_then(|val| { let headers = object_get(&options, "headers").and_then(|val| {
if web_sys::Headers::instanceof(&val) { if web_sys::Headers::instanceof(&val) {
Some(entries_of_object(&Object::from_entries(&val).ok()?)) Some(entries_of_object(&Object::from_entries(&val).ok()?))
} else if val.is_truthy() { } else if val.is_truthy() {
Some(entries_of_object(&Object::from(val))) Some(entries_of_object(&Object::from(val)))
} else { } else {
None None
} }
}); });
let mut request_builder = Request::builder().uri(url.clone()).method(request_method); let mut request_builder = Request::builder().uri(url.clone()).method(request_method);
// Generic InvalidRequest because this only returns None if the builder has some error // Generic InvalidRequest because this only returns None if the builder has some error
// which we don't know // which we don't know
let headers_map = request_builder let headers_map = request_builder
.headers_mut() .headers_mut()
.ok_or(EpoxyError::InvalidRequest)?; .ok_or(EpoxyError::InvalidRequest)?;
cfg_if! { cfg_if! {
if #[cfg(feature = "full")] { if #[cfg(feature = "full")] {
@ -456,54 +457,54 @@ impl EpoxyClient {
headers_map.insert("Accept-Encoding", HeaderValue::from_static("identity")); headers_map.insert("Accept-Encoding", HeaderValue::from_static("identity"));
} }
} }
headers_map.insert("Connection", HeaderValue::from_static("keep-alive")); headers_map.insert("Connection", HeaderValue::from_static("keep-alive"));
headers_map.insert("User-Agent", HeaderValue::from_str(&self.user_agent)?); headers_map.insert("User-Agent", HeaderValue::from_str(&self.user_agent)?);
headers_map.insert("Host", HeaderValue::from_str(host)?); headers_map.insert("Host", HeaderValue::from_str(host)?);
if body.is_empty() { if body.is_empty() {
headers_map.insert("Content-Length", HeaderValue::from_static("0")); headers_map.insert("Content-Length", HeaderValue::from_static("0"));
} }
if let Some(content_type) = body_content_type { if let Some(content_type) = body_content_type {
headers_map.insert("Content-Type", HeaderValue::from_str(&content_type)?); headers_map.insert("Content-Type", HeaderValue::from_str(&content_type)?);
} }
if let Some(headers) = headers { if let Some(headers) = headers {
for hdr in headers { for hdr in headers {
headers_map.insert( headers_map.insert(
HeaderName::from_str(&hdr[0])?, HeaderName::from_str(&hdr[0])?,
HeaderValue::from_str(&hdr[1])?, HeaderValue::from_str(&hdr[1])?,
); );
} }
} }
let (response, response_uri, redirected) = self let (response, response_uri, redirected) = self
.send_req(request_builder.body(HttpBody::new(body))?, request_redirect) .send_req(request_builder.body(HttpBody::new(body))?, request_redirect)
.await?; .await?;
let response_headers: Array = response let response_headers: Array = response
.headers() .headers()
.iter() .iter()
.filter_map(|val| { .filter_map(|val| {
Some(Array::of2( Some(Array::of2(
&val.0.as_str().into(), &val.0.as_str().into(),
&val.1.to_str().ok()?.into(), &val.1.to_str().ok()?.into(),
)) ))
}) })
.collect(); .collect();
let response_headers = Object::from_entries(&response_headers) let response_headers = Object::from_entries(&response_headers)
.map_err(|_| EpoxyError::ResponseHeadersFromEntriesFailed)?; .map_err(|_| EpoxyError::ResponseHeadersFromEntriesFailed)?;
let response_headers_raw = response.headers().clone(); let response_headers_raw = response.headers().clone();
let mut response_builder = ResponseInit::new(); let mut response_builder = ResponseInit::new();
response_builder response_builder
.headers(&response_headers) .headers(&response_headers)
.status(response.status().as_u16()) .status(response.status().as_u16())
.status_text(response.status().canonical_reason().unwrap_or_default()); .status_text(response.status().canonical_reason().unwrap_or_default());
cfg_if! { cfg_if! {
if #[cfg(feature = "full")] { if #[cfg(feature = "full")] {
let response_stream = if !is_null_body(response.status().as_u16()) { let response_stream = if !is_null_body(response.status().as_u16()) {
let compression = match response let compression = match response
.headers() .headers()
@ -532,59 +533,59 @@ impl EpoxyClient {
} else { } else {
None None
}; };
} else { } else {
let response_stream = if !is_null_body(response.status().as_u16()) { let response_stream = if !is_null_body(response.status().as_u16()) {
let response_body = IncomingBody::new(response.into_body()).into_async_read(); let response_body = IncomingBody::new(response.into_body()).into_async_read();
Some(ReadableStream::from_stream(asyncread_to_readablestream_stream(response_body)).into_raw()) Some(ReadableStream::from_stream(asyncread_to_readablestream_stream(response_body)).into_raw())
} else { } else {
None None
}; };
} }
} }
let resp = web_sys::Response::new_with_opt_readable_stream_and_init( let resp = web_sys::Response::new_with_opt_readable_stream_and_init(
response_stream.as_ref(), response_stream.as_ref(),
&response_builder, &response_builder,
) )
.map_err(|_| EpoxyError::ResponseNewFailed)?; .map_err(|_| EpoxyError::ResponseNewFailed)?;
Object::define_property( Object::define_property(
&resp, &resp,
&"url".into(), &"url".into(),
&utils::define_property_obj(response_uri.to_string().into(), false) &utils::define_property_obj(response_uri.to_string().into(), false)
.map_err(|_| EpoxyError::DefinePropertyObjFailed)?, .map_err(|_| EpoxyError::DefinePropertyObjFailed)?,
); );
Object::define_property( Object::define_property(
&resp, &resp,
&"redirected".into(), &"redirected".into(),
&utils::define_property_obj(redirected.into(), false) &utils::define_property_obj(redirected.into(), false)
.map_err(|_| EpoxyError::DefinePropertyObjFailed)?, .map_err(|_| EpoxyError::DefinePropertyObjFailed)?,
); );
let raw_headers = Object::new(); let raw_headers = Object::new();
for (k, v) in response_headers_raw.iter() { for (k, v) in response_headers_raw.iter() {
let k: JsValue = k.to_string().into(); let k: JsValue = k.to_string().into();
let v: JsValue = v.to_str()?.to_string().into(); let v: JsValue = v.to_str()?.to_string().into();
if let Ok(jv) = Reflect::get(&raw_headers, &k) { if let Ok(jv) = Reflect::get(&raw_headers, &k) {
if jv.is_array() { if jv.is_array() {
let arr = Array::from(&jv); let arr = Array::from(&jv);
arr.push(&v); arr.push(&v);
object_set(&raw_headers, &k, &arr)?; object_set(&raw_headers, &k, &arr)?;
} else if jv.is_truthy() { } else if jv.is_truthy() {
object_set(&raw_headers, &k, &Array::of2(&jv, &v))?; object_set(&raw_headers, &k, &Array::of2(&jv, &v))?;
} else { } else {
object_set(&raw_headers, &k, &v)?; object_set(&raw_headers, &k, &v)?;
} }
} }
} }
Object::define_property( Object::define_property(
&resp, &resp,
&"rawHeaders".into(), &"rawHeaders".into(),
&utils::define_property_obj(raw_headers.into(), false) &utils::define_property_obj(raw_headers.into(), false)
.map_err(|_| EpoxyError::DefinePropertyObjFailed)?, .map_err(|_| EpoxyError::DefinePropertyObjFailed)?,
); );
Ok(resp) Ok(resp)
} }
} }

View file

@ -5,7 +5,9 @@ use futures_rustls::{
TlsConnector, TlsStream, TlsConnector, TlsStream,
}; };
use futures_util::{ use futures_util::{
future::Either, lock::{Mutex, MutexGuard}, AsyncRead, AsyncWrite, Future future::Either,
lock::{Mutex, MutexGuard},
AsyncRead, AsyncWrite, Future,
}; };
use hyper_util_wasm::client::legacy::connect::{ConnectSvc, Connected, Connection}; use hyper_util_wasm::client::legacy::connect::{ConnectSvc, Connected, Connection};
use js_sys::{Array, Reflect, Uint8Array}; use js_sys::{Array, Reflect, Uint8Array};
@ -81,7 +83,7 @@ impl StreamProvider {
mut locked: MutexGuard<'_, Option<ClientMux>>, mut locked: MutexGuard<'_, Option<ClientMux>>,
) -> Result<(), EpoxyError> { ) -> Result<(), EpoxyError> {
let extensions_vec: Vec<Box<dyn ProtocolExtensionBuilder + Send + Sync>> = let extensions_vec: Vec<Box<dyn ProtocolExtensionBuilder + Send + Sync>> =
vec![Box::new(UdpProtocolExtensionBuilder())]; vec![Box::new(UdpProtocolExtensionBuilder)];
let extensions = if self.wisp_v2 { let extensions = if self.wisp_v2 {
Some(extensions_vec.as_slice()) Some(extensions_vec.as_slice())
} else { } else {

View file

@ -2,168 +2,168 @@
//! hyper_util::rt::tokio::TokioIo //! hyper_util::rt::tokio::TokioIo
use std::{ use std::{
pin::Pin, pin::Pin,
task::{Context, Poll}, task::{Context, Poll},
}; };
use pin_project_lite::pin_project; use pin_project_lite::pin_project;
pin_project! { pin_project! {
/// A wrapping implementing hyper IO traits for a type that /// A wrapping implementing hyper IO traits for a type that
/// implements Tokio's IO traits. /// implements Tokio's IO traits.
#[derive(Debug)] #[derive(Debug)]
pub struct TokioIo<T> { pub struct TokioIo<T> {
#[pin] #[pin]
inner: T, inner: T,
} }
} }
impl<T> TokioIo<T> { impl<T> TokioIo<T> {
/// Wrap a type implementing Tokio's IO traits. /// Wrap a type implementing Tokio's IO traits.
pub fn new(inner: T) -> Self { pub fn new(inner: T) -> Self {
Self { inner } Self { inner }
} }
/// Borrow the inner type. /// Borrow the inner type.
pub fn inner(&self) -> &T { pub fn inner(&self) -> &T {
&self.inner &self.inner
} }
/// Mut borrow the inner type. /// Mut borrow the inner type.
pub fn inner_mut(&mut self) -> &mut T { pub fn inner_mut(&mut self) -> &mut T {
&mut self.inner &mut self.inner
} }
/// Consume this wrapper and get the inner type. /// Consume this wrapper and get the inner type.
pub fn into_inner(self) -> T { pub fn into_inner(self) -> T {
self.inner self.inner
} }
} }
impl<T> hyper::rt::Read for TokioIo<T> impl<T> hyper::rt::Read for TokioIo<T>
where where
T: tokio::io::AsyncRead, T: tokio::io::AsyncRead,
{ {
fn poll_read( fn poll_read(
self: Pin<&mut Self>, self: Pin<&mut Self>,
cx: &mut Context<'_>, cx: &mut Context<'_>,
mut buf: hyper::rt::ReadBufCursor<'_>, mut buf: hyper::rt::ReadBufCursor<'_>,
) -> Poll<Result<(), std::io::Error>> { ) -> Poll<Result<(), std::io::Error>> {
let n = unsafe { let n = unsafe {
let mut tbuf = tokio::io::ReadBuf::uninit(buf.as_mut()); let mut tbuf = tokio::io::ReadBuf::uninit(buf.as_mut());
match tokio::io::AsyncRead::poll_read(self.project().inner, cx, &mut tbuf) { match tokio::io::AsyncRead::poll_read(self.project().inner, cx, &mut tbuf) {
Poll::Ready(Ok(())) => tbuf.filled().len(), Poll::Ready(Ok(())) => tbuf.filled().len(),
other => return other, other => return other,
} }
}; };
unsafe { unsafe {
buf.advance(n); buf.advance(n);
} }
Poll::Ready(Ok(())) Poll::Ready(Ok(()))
} }
} }
impl<T> hyper::rt::Write for TokioIo<T> impl<T> hyper::rt::Write for TokioIo<T>
where where
T: tokio::io::AsyncWrite, T: tokio::io::AsyncWrite,
{ {
fn poll_write( fn poll_write(
self: Pin<&mut Self>, self: Pin<&mut Self>,
cx: &mut Context<'_>, cx: &mut Context<'_>,
buf: &[u8], buf: &[u8],
) -> Poll<Result<usize, std::io::Error>> { ) -> Poll<Result<usize, std::io::Error>> {
tokio::io::AsyncWrite::poll_write(self.project().inner, cx, buf) tokio::io::AsyncWrite::poll_write(self.project().inner, cx, buf)
} }
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> { fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
tokio::io::AsyncWrite::poll_flush(self.project().inner, cx) tokio::io::AsyncWrite::poll_flush(self.project().inner, cx)
} }
fn poll_shutdown( fn poll_shutdown(
self: Pin<&mut Self>, self: Pin<&mut Self>,
cx: &mut Context<'_>, cx: &mut Context<'_>,
) -> Poll<Result<(), std::io::Error>> { ) -> Poll<Result<(), std::io::Error>> {
tokio::io::AsyncWrite::poll_shutdown(self.project().inner, cx) tokio::io::AsyncWrite::poll_shutdown(self.project().inner, cx)
} }
fn is_write_vectored(&self) -> bool { fn is_write_vectored(&self) -> bool {
tokio::io::AsyncWrite::is_write_vectored(&self.inner) tokio::io::AsyncWrite::is_write_vectored(&self.inner)
} }
fn poll_write_vectored( fn poll_write_vectored(
self: Pin<&mut Self>, self: Pin<&mut Self>,
cx: &mut Context<'_>, cx: &mut Context<'_>,
bufs: &[std::io::IoSlice<'_>], bufs: &[std::io::IoSlice<'_>],
) -> Poll<Result<usize, std::io::Error>> { ) -> Poll<Result<usize, std::io::Error>> {
tokio::io::AsyncWrite::poll_write_vectored(self.project().inner, cx, bufs) tokio::io::AsyncWrite::poll_write_vectored(self.project().inner, cx, bufs)
} }
} }
impl<T> tokio::io::AsyncRead for TokioIo<T> impl<T> tokio::io::AsyncRead for TokioIo<T>
where where
T: hyper::rt::Read, T: hyper::rt::Read,
{ {
fn poll_read( fn poll_read(
self: Pin<&mut Self>, self: Pin<&mut Self>,
cx: &mut Context<'_>, cx: &mut Context<'_>,
tbuf: &mut tokio::io::ReadBuf<'_>, tbuf: &mut tokio::io::ReadBuf<'_>,
) -> Poll<Result<(), std::io::Error>> { ) -> Poll<Result<(), std::io::Error>> {
//let init = tbuf.initialized().len(); //let init = tbuf.initialized().len();
let filled = tbuf.filled().len(); let filled = tbuf.filled().len();
let sub_filled = unsafe { let sub_filled = unsafe {
let mut buf = hyper::rt::ReadBuf::uninit(tbuf.unfilled_mut()); let mut buf = hyper::rt::ReadBuf::uninit(tbuf.unfilled_mut());
match hyper::rt::Read::poll_read(self.project().inner, cx, buf.unfilled()) { match hyper::rt::Read::poll_read(self.project().inner, cx, buf.unfilled()) {
Poll::Ready(Ok(())) => buf.filled().len(), Poll::Ready(Ok(())) => buf.filled().len(),
other => return other, other => return other,
} }
}; };
let n_filled = filled + sub_filled; let n_filled = filled + sub_filled;
// At least sub_filled bytes had to have been initialized. // At least sub_filled bytes had to have been initialized.
let n_init = sub_filled; let n_init = sub_filled;
unsafe { unsafe {
tbuf.assume_init(n_init); tbuf.assume_init(n_init);
tbuf.set_filled(n_filled); tbuf.set_filled(n_filled);
} }
Poll::Ready(Ok(())) Poll::Ready(Ok(()))
} }
} }
impl<T> tokio::io::AsyncWrite for TokioIo<T> impl<T> tokio::io::AsyncWrite for TokioIo<T>
where where
T: hyper::rt::Write, T: hyper::rt::Write,
{ {
fn poll_write( fn poll_write(
self: Pin<&mut Self>, self: Pin<&mut Self>,
cx: &mut Context<'_>, cx: &mut Context<'_>,
buf: &[u8], buf: &[u8],
) -> Poll<Result<usize, std::io::Error>> { ) -> Poll<Result<usize, std::io::Error>> {
hyper::rt::Write::poll_write(self.project().inner, cx, buf) hyper::rt::Write::poll_write(self.project().inner, cx, buf)
} }
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> { fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
hyper::rt::Write::poll_flush(self.project().inner, cx) hyper::rt::Write::poll_flush(self.project().inner, cx)
} }
fn poll_shutdown( fn poll_shutdown(
self: Pin<&mut Self>, self: Pin<&mut Self>,
cx: &mut Context<'_>, cx: &mut Context<'_>,
) -> Poll<Result<(), std::io::Error>> { ) -> Poll<Result<(), std::io::Error>> {
hyper::rt::Write::poll_shutdown(self.project().inner, cx) hyper::rt::Write::poll_shutdown(self.project().inner, cx)
} }
fn is_write_vectored(&self) -> bool { fn is_write_vectored(&self) -> bool {
hyper::rt::Write::is_write_vectored(&self.inner) hyper::rt::Write::is_write_vectored(&self.inner)
} }
fn poll_write_vectored( fn poll_write_vectored(
self: Pin<&mut Self>, self: Pin<&mut Self>,
cx: &mut Context<'_>, cx: &mut Context<'_>,
bufs: &[std::io::IoSlice<'_>], bufs: &[std::io::IoSlice<'_>],
) -> Poll<Result<usize, std::io::Error>> { ) -> Poll<Result<usize, std::io::Error>> {
hyper::rt::Write::poll_write_vectored(self.project().inner, cx, bufs) hyper::rt::Write::poll_write_vectored(self.project().inner, cx, bufs)
} }
} }

View file

@ -3,73 +3,78 @@ use std::{str::from_utf8, sync::Arc};
use base64::{prelude::BASE64_STANDARD, Engine}; use base64::{prelude::BASE64_STANDARD, Engine};
use bytes::Bytes; use bytes::Bytes;
use fastwebsockets::{ use fastwebsockets::{
FragmentCollectorRead, Frame, OpCode, Payload, Role, WebSocket, WebSocketWrite, FragmentCollectorRead, Frame, OpCode, Payload, Role, WebSocket, WebSocketWrite,
}; };
use futures_util::lock::Mutex; use futures_util::lock::Mutex;
use getrandom::getrandom; use getrandom::getrandom;
use http::{ use http::{
header::{ header::{
CONNECTION, HOST, SEC_WEBSOCKET_KEY, SEC_WEBSOCKET_PROTOCOL, SEC_WEBSOCKET_VERSION, UPGRADE, USER_AGENT, CONNECTION, HOST, SEC_WEBSOCKET_KEY, SEC_WEBSOCKET_PROTOCOL, SEC_WEBSOCKET_VERSION,
}, UPGRADE, USER_AGENT,
Method, Request, Response, StatusCode, Uri, },
Method, Request, Response, StatusCode, Uri,
}; };
use hyper::{ use hyper::{
body::Incoming, body::Incoming,
upgrade::{self, Upgraded}, upgrade::{self, Upgraded},
}; };
use js_sys::{ArrayBuffer, Function, Object, Uint8Array}; use js_sys::{ArrayBuffer, Function, Object, Uint8Array};
use tokio::io::WriteHalf; use tokio::io::WriteHalf;
use wasm_bindgen::{prelude::*, JsError, JsValue}; use wasm_bindgen::{prelude::*, JsError, JsValue};
use wasm_bindgen_futures::spawn_local; use wasm_bindgen_futures::spawn_local;
use crate::{tokioio::TokioIo, utils::entries_of_object, EpoxyClient, EpoxyError, EpoxyHandlers, HttpBody}; use crate::{
tokioio::TokioIo, utils::entries_of_object, EpoxyClient, EpoxyError, EpoxyHandlers, HttpBody,
};
#[wasm_bindgen] #[wasm_bindgen]
pub struct EpoxyWebSocket { pub struct EpoxyWebSocket {
tx: Arc<Mutex<WebSocketWrite<WriteHalf<TokioIo<Upgraded>>>>>, tx: Arc<Mutex<WebSocketWrite<WriteHalf<TokioIo<Upgraded>>>>>,
onerror: Function, onerror: Function,
} }
#[wasm_bindgen] #[wasm_bindgen]
impl EpoxyWebSocket { impl EpoxyWebSocket {
pub(crate) async fn connect( pub(crate) async fn connect(
client: &EpoxyClient, client: &EpoxyClient,
handlers: EpoxyHandlers, handlers: EpoxyHandlers,
url: String, url: String,
protocols: Vec<String>, protocols: Vec<String>,
headers: JsValue, headers: JsValue,
user_agent: &str, user_agent: &str,
) -> Result<Self, EpoxyError> { ) -> Result<Self, EpoxyError> {
let EpoxyHandlers { let EpoxyHandlers {
onopen, onopen,
onclose, onclose,
onerror, onerror,
onmessage, onmessage,
} = handlers; } = handlers;
let onerror_cloned = onerror.clone(); let onerror_cloned = onerror.clone();
let ret: Result<EpoxyWebSocket, EpoxyError> = async move { let ret: Result<EpoxyWebSocket, EpoxyError> = async move {
let url: Uri = url.try_into()?; let url: Uri = url.try_into()?;
let host = url.host().ok_or(EpoxyError::NoUrlHost)?; let host = url.host().ok_or(EpoxyError::NoUrlHost)?;
let mut rand = [0u8; 16]; let mut rand = [0u8; 16];
getrandom(&mut rand)?; getrandom(&mut rand)?;
let key = BASE64_STANDARD.encode(rand); let key = BASE64_STANDARD.encode(rand);
let mut request = Request::builder() let mut request = Request::builder()
.method(Method::GET) .method(Method::GET)
.uri(url.clone()) .uri(url.clone())
.header(HOST, host) .header(HOST, host)
.header(CONNECTION, "upgrade") .header(CONNECTION, "upgrade")
.header(UPGRADE, "websocket") .header(UPGRADE, "websocket")
.header(SEC_WEBSOCKET_KEY, key) .header(SEC_WEBSOCKET_KEY, key)
.header(SEC_WEBSOCKET_VERSION, "13") .header(SEC_WEBSOCKET_VERSION, "13")
.header(USER_AGENT, user_agent); .header(USER_AGENT, user_agent);
if !protocols.is_empty() { if !protocols.is_empty() {
request = request.header(SEC_WEBSOCKET_PROTOCOL, protocols.join(",")); request = request.header(SEC_WEBSOCKET_PROTOCOL, protocols.join(","));
} }
if web_sys::Headers::instanceof(&headers) && let Ok(entries) = Object::from_entries(&headers) { if web_sys::Headers::instanceof(&headers)
&& let Ok(entries) = Object::from_entries(&headers)
{
for header in entries_of_object(&entries) { for header in entries_of_object(&entries) {
request = request.header(&header[0], &header[1]); request = request.header(&header[0], &header[1]);
} }
@ -79,153 +84,153 @@ impl EpoxyWebSocket {
} }
} }
let request = request.body(HttpBody::new(Bytes::new()))?; let request = request.body(HttpBody::new(Bytes::new()))?;
let mut response = client.client.request(request).await?; let mut response = client.client.request(request).await?;
verify(&response)?; verify(&response)?;
let websocket = WebSocket::after_handshake( let websocket = WebSocket::after_handshake(
TokioIo::new(upgrade::on(&mut response).await?), TokioIo::new(upgrade::on(&mut response).await?),
Role::Client, Role::Client,
); );
let (rx, tx) = websocket.split(tokio::io::split); let (rx, tx) = websocket.split(tokio::io::split);
let mut rx = FragmentCollectorRead::new(rx); let mut rx = FragmentCollectorRead::new(rx);
let tx = Arc::new(Mutex::new(tx)); let tx = Arc::new(Mutex::new(tx));
let read_tx = tx.clone(); let read_tx = tx.clone();
let onerror_cloned = onerror.clone(); let onerror_cloned = onerror.clone();
spawn_local(async move { spawn_local(async move {
loop { loop {
match rx match rx
.read_frame(&mut |arg| async { .read_frame(&mut |arg| async {
read_tx.lock().await.write_frame(arg).await read_tx.lock().await.write_frame(arg).await
}) })
.await .await
{ {
Ok(frame) => match frame.opcode { Ok(frame) => match frame.opcode {
OpCode::Text => { OpCode::Text => {
if let Ok(str) = from_utf8(&frame.payload) { if let Ok(str) = from_utf8(&frame.payload) {
let _ = onmessage.call1(&JsValue::null(), &str.into()); let _ = onmessage.call1(&JsValue::null(), &str.into());
} }
} }
OpCode::Binary => { OpCode::Binary => {
let _ = onmessage.call1( let _ = onmessage.call1(
&JsValue::null(), &JsValue::null(),
&Uint8Array::from(frame.payload.to_vec().as_slice()).into(), &Uint8Array::from(frame.payload.to_vec().as_slice()).into(),
); );
} }
OpCode::Close => { OpCode::Close => {
break; break;
} }
// ping/pong/continue // ping/pong/continue
_ => {} _ => {}
}, },
Err(err) => { Err(err) => {
let _ = onerror.call1(&JsValue::null(), &JsError::from(err).into()); let _ = onerror.call1(&JsValue::null(), &JsError::from(err).into());
break; break;
} }
} }
} }
let _ = onclose.call0(&JsValue::null()); let _ = onclose.call0(&JsValue::null());
}); });
let _ = onopen.call0(&JsValue::null()); let _ = onopen.call0(&JsValue::null());
Ok(Self { Ok(Self {
tx, tx,
onerror: onerror_cloned, onerror: onerror_cloned,
}) })
} }
.await; .await;
match ret { match ret {
Ok(ok) => Ok(ok), Ok(ok) => Ok(ok),
Err(err) => { Err(err) => {
let _ = onerror_cloned.call1(&JsValue::null(), &err.to_string().into()); let _ = onerror_cloned.call1(&JsValue::null(), &err.to_string().into());
Err(err) Err(err)
} }
} }
} }
pub async fn send(&self, payload: JsValue) -> Result<(), EpoxyError> { pub async fn send(&self, payload: JsValue) -> Result<(), EpoxyError> {
let ret = if let Some(str) = payload.as_string() { let ret = if let Some(str) = payload.as_string() {
self.tx self.tx
.lock() .lock()
.await .await
.write_frame(Frame::text(Payload::Owned(str.as_bytes().to_vec()))) .write_frame(Frame::text(Payload::Owned(str.as_bytes().to_vec())))
.await .await
.map_err(EpoxyError::from) .map_err(EpoxyError::from)
} else if let Ok(binary) = payload.dyn_into::<ArrayBuffer>() { } else if let Ok(binary) = payload.dyn_into::<ArrayBuffer>() {
self.tx self.tx
.lock() .lock()
.await .await
.write_frame(Frame::binary(Payload::Owned( .write_frame(Frame::binary(Payload::Owned(
Uint8Array::new(&binary).to_vec(), Uint8Array::new(&binary).to_vec(),
))) )))
.await .await
.map_err(EpoxyError::from) .map_err(EpoxyError::from)
} else { } else {
Err(EpoxyError::WsInvalidPayload) Err(EpoxyError::WsInvalidPayload)
}; };
match ret { match ret {
Ok(ok) => Ok(ok), Ok(ok) => Ok(ok),
Err(err) => { Err(err) => {
let _ = self let _ = self
.onerror .onerror
.call1(&JsValue::null(), &err.to_string().into()); .call1(&JsValue::null(), &err.to_string().into());
Err(err) Err(err)
} }
} }
} }
pub async fn close(&self, code: u16, reason: String) -> Result<(), EpoxyError> { pub async fn close(&self, code: u16, reason: String) -> Result<(), EpoxyError> {
let ret = self let ret = self
.tx .tx
.lock() .lock()
.await .await
.write_frame(Frame::close(code, reason.as_bytes())) .write_frame(Frame::close(code, reason.as_bytes()))
.await; .await;
match ret { match ret {
Ok(ok) => Ok(ok), Ok(ok) => Ok(ok),
Err(err) => { Err(err) => {
let _ = self let _ = self
.onerror .onerror
.call1(&JsValue::null(), &err.to_string().into()); .call1(&JsValue::null(), &err.to_string().into());
Err(err.into()) Err(err.into())
} }
} }
} }
} }
// https://github.com/snapview/tungstenite-rs/blob/314feea3055a93e585882fb769854a912a7e6dae/src/handshake/client.rs#L189 // https://github.com/snapview/tungstenite-rs/blob/314feea3055a93e585882fb769854a912a7e6dae/src/handshake/client.rs#L189
fn verify(response: &Response<Incoming>) -> Result<(), EpoxyError> { fn verify(response: &Response<Incoming>) -> Result<(), EpoxyError> {
if response.status() != StatusCode::SWITCHING_PROTOCOLS { if response.status() != StatusCode::SWITCHING_PROTOCOLS {
return Err(EpoxyError::WsInvalidStatusCode); return Err(EpoxyError::WsInvalidStatusCode);
} }
let headers = response.headers(); let headers = response.headers();
if !headers if !headers
.get(UPGRADE) .get(UPGRADE)
.and_then(|h| h.to_str().ok()) .and_then(|h| h.to_str().ok())
.map(|h| h.eq_ignore_ascii_case("websocket")) .map(|h| h.eq_ignore_ascii_case("websocket"))
.unwrap_or(false) .unwrap_or(false)
{ {
return Err(EpoxyError::WsInvalidUpgradeHeader); return Err(EpoxyError::WsInvalidUpgradeHeader);
} }
if !headers if !headers
.get(CONNECTION) .get(CONNECTION)
.and_then(|h| h.to_str().ok()) .and_then(|h| h.to_str().ok())
.map(|h| h.eq_ignore_ascii_case("Upgrade")) .map(|h| h.eq_ignore_ascii_case("Upgrade"))
.unwrap_or(false) .unwrap_or(false)
{ {
return Err(EpoxyError::WsInvalidConnectionHeader); return Err(EpoxyError::WsInvalidConnectionHeader);
} }
Ok(()) Ok(())
} }

View file

@ -1,6 +1,6 @@
use std::sync::{ use std::sync::{
atomic::{AtomicBool, Ordering}, atomic::{AtomicBool, Ordering},
Arc, Arc,
}; };
use async_trait::async_trait; use async_trait::async_trait;
@ -13,214 +13,219 @@ use send_wrapper::SendWrapper;
use wasm_bindgen::{closure::Closure, JsCast}; use wasm_bindgen::{closure::Closure, JsCast};
use web_sys::{BinaryType, MessageEvent, WebSocket}; use web_sys::{BinaryType, MessageEvent, WebSocket};
use wisp_mux::{ use wisp_mux::{
ws::{Frame, LockedWebSocketWrite, WebSocketRead, WebSocketWrite}, ws::{Frame, LockedWebSocketWrite, Payload, WebSocketRead, WebSocketWrite},
WispError, WispError,
}; };
use crate::EpoxyError; use crate::EpoxyError;
#[derive(Debug)] #[derive(Debug)]
pub enum WebSocketError { pub enum WebSocketError {
Unknown, Unknown,
SendFailed, SendFailed,
CloseFailed, CloseFailed,
} }
impl std::fmt::Display for WebSocketError { impl std::fmt::Display for WebSocketError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
use WebSocketError::*; use WebSocketError::*;
match self { match self {
Unknown => write!(f, "Unknown error"), Unknown => write!(f, "Unknown error"),
SendFailed => write!(f, "Send failed"), SendFailed => write!(f, "Send failed"),
CloseFailed => write!(f, "Close failed"), CloseFailed => write!(f, "Close failed"),
} }
} }
} }
impl std::error::Error for WebSocketError {} impl std::error::Error for WebSocketError {}
impl From<WebSocketError> for WispError { impl From<WebSocketError> for WispError {
fn from(err: WebSocketError) -> Self { fn from(err: WebSocketError) -> Self {
Self::WsImplError(Box::new(err)) Self::WsImplError(Box::new(err))
} }
} }
pub enum WebSocketMessage { pub enum WebSocketMessage {
Closed, Closed,
Error, Error,
Message(Vec<u8>), Message(Vec<u8>),
} }
pub struct WebSocketWrapper { pub struct WebSocketWrapper {
inner: SendWrapper<WebSocket>, inner: SendWrapper<WebSocket>,
open_event: Arc<Event>, open_event: Arc<Event>,
error_event: Arc<Event>, error_event: Arc<Event>,
close_event: Arc<Event>, close_event: Arc<Event>,
closed: Arc<AtomicBool>, closed: Arc<AtomicBool>,
// used to retain the closures // used to retain the closures
#[allow(dead_code)] #[allow(dead_code)]
onopen: SendWrapper<Closure<dyn Fn()>>, onopen: SendWrapper<Closure<dyn Fn()>>,
#[allow(dead_code)] #[allow(dead_code)]
onclose: SendWrapper<Closure<dyn Fn()>>, onclose: SendWrapper<Closure<dyn Fn()>>,
#[allow(dead_code)] #[allow(dead_code)]
onerror: SendWrapper<Closure<dyn Fn()>>, onerror: SendWrapper<Closure<dyn Fn()>>,
#[allow(dead_code)] #[allow(dead_code)]
onmessage: SendWrapper<Closure<dyn Fn(MessageEvent)>>, onmessage: SendWrapper<Closure<dyn Fn(MessageEvent)>>,
} }
pub struct WebSocketReader { pub struct WebSocketReader {
read_rx: Receiver<WebSocketMessage>, read_rx: Receiver<WebSocketMessage>,
closed: Arc<AtomicBool>, closed: Arc<AtomicBool>,
close_event: Arc<Event>, close_event: Arc<Event>,
} }
#[async_trait] #[async_trait]
impl WebSocketRead for WebSocketReader { impl WebSocketRead for WebSocketReader {
async fn wisp_read_frame(&mut self, _: &LockedWebSocketWrite) -> Result<Frame, WispError> { async fn wisp_read_frame(
use WebSocketMessage::*; &mut self,
if self.closed.load(Ordering::Acquire) { _: &LockedWebSocketWrite,
return Err(WispError::WsImplSocketClosed); ) -> Result<Frame<'static>, WispError> {
} use WebSocketMessage::*;
let res = futures_util::select! { if self.closed.load(Ordering::Acquire) {
data = self.read_rx.recv_async() => data.ok(), return Err(WispError::WsImplSocketClosed);
_ = self.close_event.listen().fuse() => Some(Closed), }
}; let res = futures_util::select! {
match res.ok_or(WispError::WsImplSocketClosed)? { data = self.read_rx.recv_async() => data.ok(),
Message(bin) => Ok(Frame::binary(BytesMut::from(bin.as_slice()))), _ = self.close_event.listen().fuse() => Some(Closed),
Error => Err(WebSocketError::Unknown.into()), };
Closed => Err(WispError::WsImplSocketClosed), match res.ok_or(WispError::WsImplSocketClosed)? {
} Message(bin) => Ok(Frame::binary(Payload::Bytes(BytesMut::from(
} bin.as_slice(),
)))),
Error => Err(WebSocketError::Unknown.into()),
Closed => Err(WispError::WsImplSocketClosed),
}
}
} }
impl WebSocketWrapper { impl WebSocketWrapper {
pub fn connect(url: &str, protocols: &[String]) -> Result<(Self, WebSocketReader), EpoxyError> { pub fn connect(url: &str, protocols: &[String]) -> Result<(Self, WebSocketReader), EpoxyError> {
let (read_tx, read_rx) = flume::unbounded(); let (read_tx, read_rx) = flume::unbounded();
let closed = Arc::new(AtomicBool::new(false)); let closed = Arc::new(AtomicBool::new(false));
let open_event = Arc::new(Event::new()); let open_event = Arc::new(Event::new());
let close_event = Arc::new(Event::new()); let close_event = Arc::new(Event::new());
let error_event = Arc::new(Event::new()); let error_event = Arc::new(Event::new());
let onopen_event = open_event.clone(); let onopen_event = open_event.clone();
let onopen = Closure::wrap( let onopen = Closure::wrap(
Box::new(move || while onopen_event.notify(usize::MAX) == 0 {}) as Box<dyn Fn()>, Box::new(move || while onopen_event.notify(usize::MAX) == 0 {}) as Box<dyn Fn()>,
); );
let onmessage_tx = read_tx.clone(); let onmessage_tx = read_tx.clone();
let onmessage = Closure::wrap(Box::new(move |evt: MessageEvent| { let onmessage = Closure::wrap(Box::new(move |evt: MessageEvent| {
if let Ok(arr) = evt.data().dyn_into::<ArrayBuffer>() { if let Ok(arr) = evt.data().dyn_into::<ArrayBuffer>() {
let _ = let _ =
onmessage_tx.send(WebSocketMessage::Message(Uint8Array::new(&arr).to_vec())); onmessage_tx.send(WebSocketMessage::Message(Uint8Array::new(&arr).to_vec()));
} }
}) as Box<dyn Fn(MessageEvent)>); }) as Box<dyn Fn(MessageEvent)>);
let onclose_closed = closed.clone(); let onclose_closed = closed.clone();
let onclose_event = close_event.clone(); let onclose_event = close_event.clone();
let onclose = Closure::wrap(Box::new(move || { let onclose = Closure::wrap(Box::new(move || {
onclose_closed.store(true, Ordering::Release); onclose_closed.store(true, Ordering::Release);
onclose_event.notify(usize::MAX); onclose_event.notify(usize::MAX);
}) as Box<dyn Fn()>); }) as Box<dyn Fn()>);
let onerror_tx = read_tx.clone(); let onerror_tx = read_tx.clone();
let onerror_closed = closed.clone(); let onerror_closed = closed.clone();
let onerror_close = close_event.clone(); let onerror_close = close_event.clone();
let onerror_event = error_event.clone(); let onerror_event = error_event.clone();
let onerror = Closure::wrap(Box::new(move || { let onerror = Closure::wrap(Box::new(move || {
let _ = onerror_tx.send(WebSocketMessage::Error); let _ = onerror_tx.send(WebSocketMessage::Error);
onerror_closed.store(true, Ordering::Release); onerror_closed.store(true, Ordering::Release);
onerror_close.notify(usize::MAX); onerror_close.notify(usize::MAX);
onerror_event.notify(usize::MAX); onerror_event.notify(usize::MAX);
}) as Box<dyn Fn()>); }) as Box<dyn Fn()>);
let ws = if protocols.is_empty() { let ws = if protocols.is_empty() {
WebSocket::new(url) WebSocket::new(url)
} else { } else {
WebSocket::new_with_str_sequence( WebSocket::new_with_str_sequence(
url, url,
&protocols &protocols
.iter() .iter()
.fold(Array::new(), |acc, x| { .fold(Array::new(), |acc, x| {
acc.push(&x.into()); acc.push(&x.into());
acc acc
}) })
.into(), .into(),
) )
} }
.map_err(|_| EpoxyError::WebSocketConnectFailed)?; .map_err(|_| EpoxyError::WebSocketConnectFailed)?;
ws.set_binary_type(BinaryType::Arraybuffer); ws.set_binary_type(BinaryType::Arraybuffer);
ws.set_onmessage(Some(onmessage.as_ref().unchecked_ref())); ws.set_onmessage(Some(onmessage.as_ref().unchecked_ref()));
ws.set_onopen(Some(onopen.as_ref().unchecked_ref())); ws.set_onopen(Some(onopen.as_ref().unchecked_ref()));
ws.set_onclose(Some(onclose.as_ref().unchecked_ref())); ws.set_onclose(Some(onclose.as_ref().unchecked_ref()));
ws.set_onerror(Some(onerror.as_ref().unchecked_ref())); ws.set_onerror(Some(onerror.as_ref().unchecked_ref()));
Ok(( Ok((
Self { Self {
inner: SendWrapper::new(ws), inner: SendWrapper::new(ws),
open_event, open_event,
error_event, error_event,
close_event: close_event.clone(), close_event: close_event.clone(),
closed: closed.clone(), closed: closed.clone(),
onopen: SendWrapper::new(onopen), onopen: SendWrapper::new(onopen),
onclose: SendWrapper::new(onclose), onclose: SendWrapper::new(onclose),
onerror: SendWrapper::new(onerror), onerror: SendWrapper::new(onerror),
onmessage: SendWrapper::new(onmessage), onmessage: SendWrapper::new(onmessage),
}, },
WebSocketReader { WebSocketReader {
read_rx, read_rx,
closed, closed,
close_event, close_event,
}, },
)) ))
} }
pub async fn wait_for_open(&self) -> bool { pub async fn wait_for_open(&self) -> bool {
if self.closed.load(Ordering::Acquire) { if self.closed.load(Ordering::Acquire) {
return false; return false;
} }
futures_util::select! { futures_util::select! {
_ = self.open_event.listen().fuse() => true, _ = self.open_event.listen().fuse() => true,
_ = self.error_event.listen().fuse() => false, _ = self.error_event.listen().fuse() => false,
} }
} }
} }
#[async_trait] #[async_trait]
impl WebSocketWrite for WebSocketWrapper { impl WebSocketWrite for WebSocketWrapper {
async fn wisp_write_frame(&mut self, frame: Frame) -> Result<(), WispError> { async fn wisp_write_frame(&mut self, frame: Frame<'_>) -> Result<(), WispError> {
use wisp_mux::ws::OpCode::*; use wisp_mux::ws::OpCode::*;
if self.closed.load(Ordering::Acquire) { if self.closed.load(Ordering::Acquire) {
return Err(WispError::WsImplSocketClosed); return Err(WispError::WsImplSocketClosed);
} }
match frame.opcode { match frame.opcode {
Binary | Text => self Binary | Text => self
.inner .inner
.send_with_u8_array(&frame.payload) .send_with_u8_array(&frame.payload)
.map_err(|_| WebSocketError::SendFailed.into()), .map_err(|_| WebSocketError::SendFailed.into()),
Close => { Close => {
let _ = self.inner.close(); let _ = self.inner.close();
Ok(()) Ok(())
} }
_ => Err(WispError::WsImplNotSupported), _ => Err(WispError::WsImplNotSupported),
} }
} }
async fn wisp_close(&mut self) -> Result<(), WispError> { async fn wisp_close(&mut self) -> Result<(), WispError> {
self.inner self.inner
.close() .close()
.map_err(|_| WebSocketError::CloseFailed.into()) .map_err(|_| WebSocketError::CloseFailed.into())
} }
} }
impl Drop for WebSocketWrapper { impl Drop for WebSocketWrapper {
fn drop(&mut self) { fn drop(&mut self) {
self.inner.set_onopen(None); self.inner.set_onopen(None);
self.inner.set_onclose(None); self.inner.set_onclose(None);
self.inner.set_onerror(None); self.inner.set_onerror(None);
self.inner.set_onmessage(None); self.inner.set_onmessage(None);
self.closed.store(true, Ordering::Release); self.closed.store(true, Ordering::Release);
self.close_event.notify(usize::MAX); self.close_event.notify(usize::MAX);
let _ = self.inner.close(); let _ = self.inner.close();
} }
} }

File diff suppressed because it is too large Load diff

View file

@ -6,50 +6,50 @@ use futures::future::select_all;
use http_body_util::Empty; use http_body_util::Empty;
use humantime::format_duration; use humantime::format_duration;
use hyper::{ use hyper::{
header::{CONNECTION, UPGRADE}, header::{CONNECTION, UPGRADE},
Request, Uri, Request, Uri,
}; };
use simple_moving_average::{SingleSumSMA, SMA}; use simple_moving_average::{SingleSumSMA, SMA};
use std::{ use std::{
error::Error, error::Error,
future::Future, future::Future,
io::{stdout, IsTerminal, Write}, io::{stdout, IsTerminal, Write},
net::SocketAddr, net::SocketAddr,
process::exit, process::exit,
sync::Arc, sync::Arc,
time::{Duration, Instant}, time::{Duration, Instant},
}; };
use tokio::{ use tokio::{
net::TcpStream, net::TcpStream,
select, select,
signal::unix::{signal, SignalKind}, signal::unix::{signal, SignalKind},
time::{interval, sleep}, time::{interval, sleep},
}; };
use tokio_native_tls::{native_tls, TlsConnector}; use tokio_native_tls::{native_tls, TlsConnector};
use tokio_util::either::Either; use tokio_util::either::Either;
use wisp_mux::{ use wisp_mux::{
extensions::{ extensions::{
password::{PasswordProtocolExtension, PasswordProtocolExtensionBuilder}, password::{PasswordProtocolExtension, PasswordProtocolExtensionBuilder},
udp::{UdpProtocolExtension, UdpProtocolExtensionBuilder}, udp::{UdpProtocolExtension, UdpProtocolExtensionBuilder},
ProtocolExtensionBuilder, ProtocolExtensionBuilder,
}, },
ClientMux, StreamType, WispError, ClientMux, StreamType, WispError,
}; };
#[derive(Debug)] #[derive(Debug)]
enum WispClientError { enum WispClientError {
InvalidUriScheme, InvalidUriScheme,
UriHasNoHost, UriHasNoHost,
} }
impl std::fmt::Display for WispClientError { impl std::fmt::Display for WispClientError {
fn fmt(&self, fmt: &mut std::fmt::Formatter) -> Result<(), std::fmt::Error> { fn fmt(&self, fmt: &mut std::fmt::Formatter) -> Result<(), std::fmt::Error> {
use WispClientError as E; use WispClientError as E;
match self { match self {
E::InvalidUriScheme => write!(fmt, "Invalid URI scheme"), E::InvalidUriScheme => write!(fmt, "Invalid URI scheme"),
E::UriHasNoHost => write!(fmt, "URI has no host"), E::UriHasNoHost => write!(fmt, "URI has no host"),
} }
} }
} }
impl Error for WispClientError {} impl Error for WispClientError {}
@ -58,165 +58,166 @@ struct SpawnExecutor;
impl<Fut> hyper::rt::Executor<Fut> for SpawnExecutor impl<Fut> hyper::rt::Executor<Fut> for SpawnExecutor
where where
Fut: Future + Send + 'static, Fut: Future + Send + 'static,
Fut::Output: Send + 'static, Fut::Output: Send + 'static,
{ {
fn execute(&self, fut: Fut) { fn execute(&self, fut: Fut) {
tokio::task::spawn(fut); tokio::task::spawn(fut);
} }
} }
#[derive(Parser)] #[derive(Parser)]
#[command(version = clap::crate_version!())] #[command(version = clap::crate_version!())]
struct Cli { struct Cli {
/// Wisp server URL /// Wisp server URL
#[arg(short, long)] #[arg(short, long)]
wisp: Uri, wisp: Uri,
/// TCP server address /// TCP server address
#[arg(short, long)] #[arg(short, long)]
tcp: SocketAddr, tcp: SocketAddr,
/// Number of streams /// Number of streams
#[arg(short, long, default_value_t = 10)] #[arg(short, long, default_value_t = 10)]
streams: usize, streams: usize,
/// Size of packets sent, in KB /// Size of packets sent, in KB
#[arg(short, long, default_value_t = 1)] #[arg(short, long, default_value_t = 1)]
packet_size: usize, packet_size: usize,
/// Duration to run the test for /// Duration to run the test for
#[arg(short, long)] #[arg(short, long)]
duration: Option<humantime::Duration>, duration: Option<humantime::Duration>,
/// Ask for UDP /// Ask for UDP
#[arg(short, long)] #[arg(short, long)]
udp: bool, udp: bool,
/// Enable auth: format is `username:password` /// Enable auth: format is `username:password`
/// ///
/// Usernames and passwords are sent in plaintext!! /// Usernames and passwords are sent in plaintext!!
#[arg(long)] #[arg(long)]
auth: Option<String>, auth: Option<String>,
/// Make a Wisp V1 connection /// Make a Wisp V1 connection
#[arg(long)] #[arg(long)]
wisp_v1: bool, wisp_v1: bool,
} }
#[tokio::main(flavor = "multi_thread")] #[tokio::main(flavor = "multi_thread")]
async fn main() -> Result<(), Box<dyn Error + Send + Sync>> { async fn main() -> Result<(), Box<dyn Error + Send + Sync>> {
#[cfg(feature = "tokio-console")] #[cfg(feature = "tokio-console")]
console_subscriber::init(); console_subscriber::init();
let opts = Cli::parse(); let opts = Cli::parse();
let tls = match opts let tls = match opts
.wisp .wisp
.scheme_str() .scheme_str()
.ok_or(WispClientError::InvalidUriScheme)? .ok_or(WispClientError::InvalidUriScheme)?
{ {
"wss" => Ok(true), "wss" => Ok(true),
"ws" => Ok(false), "ws" => Ok(false),
_ => Err(WispClientError::InvalidUriScheme), _ => Err(WispClientError::InvalidUriScheme),
}?; }?;
let addr = opts.wisp.host().ok_or(WispClientError::UriHasNoHost)?; let addr = opts.wisp.host().ok_or(WispClientError::UriHasNoHost)?;
let addr_port = opts.wisp.port_u16().unwrap_or(if tls { 443 } else { 80 }); let addr_port = opts.wisp.port_u16().unwrap_or(if tls { 443 } else { 80 });
let addr_path = opts.wisp.path(); let addr_path = opts.wisp.path();
let addr_dest = opts.tcp.ip().to_string(); let addr_dest = opts.tcp.ip().to_string();
let addr_dest_port = opts.tcp.port(); let addr_dest_port = opts.tcp.port();
let auth = opts.auth.map(|auth| { let auth = opts.auth.map(|auth| {
let split: Vec<_> = auth.split(':').collect(); let split: Vec<_> = auth.split(':').collect();
let username = split[0].to_string(); let username = split[0].to_string();
let password = split[1..].join(":"); let password = split[1..].join(":");
PasswordProtocolExtensionBuilder::new_client(username, password) PasswordProtocolExtensionBuilder::new_client(username, password)
}); });
println!( println!(
"connecting to {} and sending &[0; 1024 * {}] to {} with threads {}", "connecting to {} and sending &[0; 1024 * {}] to {} with threads {}",
opts.wisp, opts.packet_size, opts.tcp, opts.streams, opts.wisp, opts.packet_size, opts.tcp, opts.streams,
); );
let socket = TcpStream::connect(format!("{}:{}", &addr, addr_port)).await?; let socket = TcpStream::connect(format!("{}:{}", &addr, addr_port)).await?;
let socket = if tls { let socket = if tls {
let cx = TlsConnector::from(native_tls::TlsConnector::builder().build()?); let cx = TlsConnector::from(native_tls::TlsConnector::builder().build()?);
Either::Left(cx.connect(addr, socket).await?) Either::Left(cx.connect(addr, socket).await?)
} else { } else {
Either::Right(socket) Either::Right(socket)
}; };
let req = Request::builder() let req = Request::builder()
.method("GET") .method("GET")
.uri(addr_path) .uri(addr_path)
.header("Host", addr) .header("Host", addr)
.header(UPGRADE, "websocket") .header(UPGRADE, "websocket")
.header(CONNECTION, "upgrade") .header(CONNECTION, "upgrade")
.header( .header(
"Sec-WebSocket-Key", "Sec-WebSocket-Key",
fastwebsockets::handshake::generate_key(), fastwebsockets::handshake::generate_key(),
) )
.header("Sec-WebSocket-Version", "13") .header("Sec-WebSocket-Version", "13")
.body(Empty::<Bytes>::new())?; .body(Empty::<Bytes>::new())?;
let (ws, _) = handshake::client(&SpawnExecutor, req, socket).await?; let (ws, _) = handshake::client(&SpawnExecutor, req, socket).await?;
let (rx, tx) = ws.split(tokio::io::split); let (rx, tx) = ws.split(tokio::io::split);
let rx = FragmentCollectorRead::new(rx); let rx = FragmentCollectorRead::new(rx);
let mut extensions: Vec<Box<(dyn ProtocolExtensionBuilder + Send + Sync)>> = Vec::new(); let mut extensions: Vec<Box<(dyn ProtocolExtensionBuilder + Send + Sync)>> = Vec::new();
let mut extension_ids: Vec<u8> = Vec::new(); let mut extension_ids: Vec<u8> = Vec::new();
if opts.udp { if opts.udp {
extensions.push(Box::new(UdpProtocolExtensionBuilder())); extensions.push(Box::new(UdpProtocolExtensionBuilder));
extension_ids.push(UdpProtocolExtension::ID); extension_ids.push(UdpProtocolExtension::ID);
} }
if let Some(auth) = auth { if let Some(auth) = auth {
extensions.push(Box::new(auth)); extensions.push(Box::new(auth));
extension_ids.push(PasswordProtocolExtension::ID); extension_ids.push(PasswordProtocolExtension::ID);
} }
let (mux, fut) = if opts.wisp_v1 { let (mux, fut) = if opts.wisp_v1 {
ClientMux::create(rx, tx, None) ClientMux::create(rx, tx, None)
.await? .await?
.with_no_required_extensions() .with_no_required_extensions()
} else { } else {
ClientMux::create(rx, tx, Some(extensions.as_slice())) ClientMux::create(rx, tx, Some(extensions.as_slice()))
.await? .await?
.with_required_extensions(extension_ids.as_slice()).await? .with_required_extensions(extension_ids.as_slice())
}; .await?
};
println!( println!(
"connected and created ClientMux, was downgraded {}, extensions supported {:?}", "connected and created ClientMux, was downgraded {}, extensions supported {:?}",
mux.downgraded, mux.supported_extension_ids mux.downgraded, mux.supported_extension_ids
); );
let mut threads = Vec::with_capacity(opts.streams + 4); let mut threads = Vec::with_capacity(opts.streams + 4);
let mut reads = Vec::with_capacity(opts.streams); let mut reads = Vec::with_capacity(opts.streams);
threads.push(tokio::spawn(fut)); threads.push(tokio::spawn(fut));
let payload = Bytes::from(vec![0; 1024 * opts.packet_size]); let payload = Bytes::from(vec![0; 1024 * opts.packet_size]);
let cnt = Arc::new(RelaxedCounter::new(0)); let cnt = Arc::new(RelaxedCounter::new(0));
let start_time = Instant::now(); let start_time = Instant::now();
for _ in 0..opts.streams { for _ in 0..opts.streams {
let (cr, cw) = mux let (cr, cw) = mux
.client_new_stream(StreamType::Tcp, addr_dest.clone(), addr_dest_port) .client_new_stream(StreamType::Tcp, addr_dest.clone(), addr_dest_port)
.await? .await?
.into_split(); .into_split();
let cnt = cnt.clone(); let cnt = cnt.clone();
let payload = payload.clone(); let payload = payload.clone();
threads.push(tokio::spawn(async move { threads.push(tokio::spawn(async move {
loop { loop {
cw.write(payload.clone()).await?; cw.write(payload.clone()).await?;
cnt.inc(); cnt.inc();
} }
#[allow(unreachable_code)] #[allow(unreachable_code)]
Ok::<(), WispError>(()) Ok::<(), WispError>(())
})); }));
reads.push(cr); reads.push(cr);
} }
threads.push(tokio::spawn(async move { threads.push(tokio::spawn(async move {
loop { loop {
select_all(reads.iter().map(|x| Box::pin(x.read()))).await; select_all(reads.iter().map(|x| Box::pin(x.read()))).await;
} }
})); }));
let cnt_avg = cnt.clone(); let cnt_avg = cnt.clone();
threads.push(tokio::spawn(async move { threads.push(tokio::spawn(async move {
let mut interval = interval(Duration::from_millis(100)); let mut interval = interval(Duration::from_millis(100));
let mut avg: SingleSumSMA<usize, usize, 100> = SingleSumSMA::new(); let mut avg: SingleSumSMA<usize, usize, 100> = SingleSumSMA::new();
let mut last_time = 0; let mut last_time = 0;
@ -245,48 +246,48 @@ async fn main() -> Result<(), Box<dyn Error + Send + Sync>> {
} }
})); }));
threads.push(tokio::spawn(async move { threads.push(tokio::spawn(async move {
let mut interrupt = let mut interrupt =
signal(SignalKind::interrupt()).map_err(|x| WispError::Other(Box::new(x)))?; signal(SignalKind::interrupt()).map_err(|x| WispError::Other(Box::new(x)))?;
let mut terminate = let mut terminate =
signal(SignalKind::terminate()).map_err(|x| WispError::Other(Box::new(x)))?; signal(SignalKind::terminate()).map_err(|x| WispError::Other(Box::new(x)))?;
select! { select! {
_ = interrupt.recv() => (), _ = interrupt.recv() => (),
_ = terminate.recv() => (), _ = terminate.recv() => (),
} }
Ok(()) Ok(())
})); }));
if let Some(duration) = opts.duration { if let Some(duration) = opts.duration {
threads.push(tokio::spawn(async move { threads.push(tokio::spawn(async move {
sleep(duration.into()).await; sleep(duration.into()).await;
Ok(()) Ok(())
})); }));
} }
let out = select_all(threads.into_iter()).await; let out = select_all(threads.into_iter()).await;
let duration_since = Instant::now().duration_since(start_time); let duration_since = Instant::now().duration_since(start_time);
if let Err(err) = out.0? { if let Err(err) = out.0? {
println!("\n\nerr: {:?}", err); println!("\n\nerr: {:?}", err);
exit(1); exit(1);
} }
out.2.into_iter().for_each(|x| x.abort()); out.2.into_iter().for_each(|x| x.abort());
mux.close().await?; mux.close().await?;
if duration_since.as_secs() != 0 { if duration_since.as_secs() != 0 {
println!( println!(
"\nresults: {} packets of &[0; 1024 * {}] ({} KiB) sent in {} ({} KiB/s)", "\nresults: {} packets of &[0; 1024 * {}] ({} KiB) sent in {} ({} KiB/s)",
cnt.get(), cnt.get(),
opts.packet_size, opts.packet_size,
cnt.get() * opts.packet_size, cnt.get() * opts.packet_size,
format_duration(duration_since), format_duration(duration_since),
(cnt.get() * opts.packet_size) as u64 / duration_since.as_secs(), (cnt.get() * opts.packet_size) as u64 / duration_since.as_secs(),
); );
} }
Ok(()) Ok(())
} }

View file

@ -8,8 +8,8 @@ use async_trait::async_trait;
use bytes::{BufMut, Bytes, BytesMut}; use bytes::{BufMut, Bytes, BytesMut};
use crate::{ use crate::{
ws::{LockedWebSocketWrite, WebSocketRead}, ws::{LockedWebSocketWrite, WebSocketRead},
Role, WispError, Role, WispError,
}; };
/// Type-erased protocol extension that implements Clone. /// Type-erased protocol extension that implements Clone.
@ -17,90 +17,90 @@ use crate::{
pub struct AnyProtocolExtension(Box<dyn ProtocolExtension + Sync + Send>); pub struct AnyProtocolExtension(Box<dyn ProtocolExtension + Sync + Send>);
impl AnyProtocolExtension { impl AnyProtocolExtension {
/// Create a new type-erased protocol extension. /// Create a new type-erased protocol extension.
pub fn new<T: ProtocolExtension + Sync + Send + 'static>(extension: T) -> Self { pub fn new<T: ProtocolExtension + Sync + Send + 'static>(extension: T) -> Self {
Self(Box::new(extension)) Self(Box::new(extension))
} }
} }
impl Deref for AnyProtocolExtension { impl Deref for AnyProtocolExtension {
type Target = dyn ProtocolExtension; type Target = dyn ProtocolExtension;
fn deref(&self) -> &Self::Target { fn deref(&self) -> &Self::Target {
self.0.deref() self.0.deref()
} }
} }
impl DerefMut for AnyProtocolExtension { impl DerefMut for AnyProtocolExtension {
fn deref_mut(&mut self) -> &mut Self::Target { fn deref_mut(&mut self) -> &mut Self::Target {
self.0.deref_mut() self.0.deref_mut()
} }
} }
impl Clone for AnyProtocolExtension { impl Clone for AnyProtocolExtension {
fn clone(&self) -> Self { fn clone(&self) -> Self {
Self(self.0.box_clone()) Self(self.0.box_clone())
} }
} }
impl From<AnyProtocolExtension> for Bytes { impl From<AnyProtocolExtension> for Bytes {
fn from(value: AnyProtocolExtension) -> Self { fn from(value: AnyProtocolExtension) -> Self {
let mut bytes = BytesMut::with_capacity(5); let mut bytes = BytesMut::with_capacity(5);
let payload = value.encode(); let payload = value.encode();
bytes.put_u8(value.get_id()); bytes.put_u8(value.get_id());
bytes.put_u32_le(payload.len() as u32); bytes.put_u32_le(payload.len() as u32);
bytes.extend(payload); bytes.extend(payload);
bytes.freeze() bytes.freeze()
} }
} }
/// A Wisp protocol extension. /// A Wisp protocol extension.
/// ///
/// See [the /// See [the
/// docs](https://github.com/MercuryWorkshop/wisp-protocol/blob/main/protocol.md#protocol-extensions). /// docs](https://github.com/MercuryWorkshop/wisp-protocol/blob/v2/protocol.md#protocol-extensions).
#[async_trait] #[async_trait]
pub trait ProtocolExtension: std::fmt::Debug { pub trait ProtocolExtension: std::fmt::Debug {
/// Get the protocol extension ID. /// Get the protocol extension ID.
fn get_id(&self) -> u8; fn get_id(&self) -> u8;
/// Get the protocol extension's supported packets. /// Get the protocol extension's supported packets.
/// ///
/// Used to decide whether to call the protocol extension's packet handler. /// Used to decide whether to call the protocol extension's packet handler.
fn get_supported_packets(&self) -> &'static [u8]; fn get_supported_packets(&self) -> &'static [u8];
/// Encode self into Bytes. /// Encode self into Bytes.
fn encode(&self) -> Bytes; fn encode(&self) -> Bytes;
/// Handle the handshake part of a Wisp connection. /// Handle the handshake part of a Wisp connection.
/// ///
/// This should be used to send or receive data before any streams are created. /// This should be used to send or receive data before any streams are created.
async fn handle_handshake( async fn handle_handshake(
&mut self, &mut self,
read: &mut dyn WebSocketRead, read: &mut dyn WebSocketRead,
write: &LockedWebSocketWrite, write: &LockedWebSocketWrite,
) -> Result<(), WispError>; ) -> Result<(), WispError>;
/// Handle receiving a packet. /// Handle receiving a packet.
async fn handle_packet( async fn handle_packet(
&mut self, &mut self,
packet: Bytes, packet: Bytes,
read: &mut dyn WebSocketRead, read: &mut dyn WebSocketRead,
write: &LockedWebSocketWrite, write: &LockedWebSocketWrite,
) -> Result<(), WispError>; ) -> Result<(), WispError>;
/// Clone the protocol extension. /// Clone the protocol extension.
fn box_clone(&self) -> Box<dyn ProtocolExtension + Sync + Send>; fn box_clone(&self) -> Box<dyn ProtocolExtension + Sync + Send>;
} }
/// Trait to build a Wisp protocol extension from a payload. /// Trait to build a Wisp protocol extension from a payload.
pub trait ProtocolExtensionBuilder { pub trait ProtocolExtensionBuilder {
/// Get the protocol extension ID. /// Get the protocol extension ID.
/// ///
/// Used to decide whether this builder should be used. /// Used to decide whether this builder should be used.
fn get_id(&self) -> u8; fn get_id(&self) -> u8;
/// Build a protocol extension from the extension's metadata. /// Build a protocol extension from the extension's metadata.
fn build_from_bytes(&self, bytes: Bytes, role: Role) fn build_from_bytes(&self, bytes: Bytes, role: Role)
-> Result<AnyProtocolExtension, WispError>; -> Result<AnyProtocolExtension, WispError>;
/// Build a protocol extension to send to the other side. /// Build a protocol extension to send to the other side.
fn build_to_extension(&self, role: Role) -> AnyProtocolExtension; fn build_to_extension(&self, role: Role) -> AnyProtocolExtension;
} }

View file

@ -29,7 +29,7 @@
//! ]) //! ])
//! ); //! );
//! ``` //! ```
//! See [the docs](https://github.com/MercuryWorkshop/wisp-protocol/blob/main/protocol.md#0x02---password-authentication) //! See [the docs](https://github.com/MercuryWorkshop/wisp-protocol/blob/v2/protocol.md#0x02---password-authentication)
use std::{collections::HashMap, error::Error, fmt::Display, string::FromUtf8Error}; use std::{collections::HashMap, error::Error, fmt::Display, string::FromUtf8Error};
@ -37,8 +37,8 @@ use async_trait::async_trait;
use bytes::{Buf, BufMut, Bytes, BytesMut}; use bytes::{Buf, BufMut, Bytes, BytesMut};
use crate::{ use crate::{
ws::{LockedWebSocketWrite, WebSocketRead}, ws::{LockedWebSocketWrite, WebSocketRead},
Role, WispError, Role, WispError,
}; };
use super::{AnyProtocolExtension, ProtocolExtension, ProtocolExtensionBuilder}; use super::{AnyProtocolExtension, ProtocolExtension, ProtocolExtensionBuilder};
@ -50,227 +50,227 @@ use super::{AnyProtocolExtension, ProtocolExtension, ProtocolExtensionBuilder};
/// **This extension will panic when encoding if the username's length does not fit within a u8 /// **This extension will panic when encoding if the username's length does not fit within a u8
/// or the password's length does not fit within a u16.** /// or the password's length does not fit within a u16.**
pub struct PasswordProtocolExtension { pub struct PasswordProtocolExtension {
/// The username to log in with. /// The username to log in with.
/// ///
/// This string's length must fit within a u8. /// This string's length must fit within a u8.
pub username: String, pub username: String,
/// The password to log in with. /// The password to log in with.
/// ///
/// This string's length must fit within a u16. /// This string's length must fit within a u16.
pub password: String, pub password: String,
role: Role, role: Role,
} }
impl PasswordProtocolExtension { impl PasswordProtocolExtension {
/// Password protocol extension ID. /// Password protocol extension ID.
pub const ID: u8 = 0x02; pub const ID: u8 = 0x02;
/// Create a new password protocol extension for the server. /// Create a new password protocol extension for the server.
/// ///
/// This signifies that the server requires a password. /// This signifies that the server requires a password.
pub fn new_server() -> Self { pub fn new_server() -> Self {
Self { Self {
username: String::new(), username: String::new(),
password: String::new(), password: String::new(),
role: Role::Server, role: Role::Server,
} }
} }
/// Create a new password protocol extension for the client, with a username and password. /// Create a new password protocol extension for the client, with a username and password.
/// ///
/// The username's length must fit within a u8. The password's length must fit within a /// The username's length must fit within a u8. The password's length must fit within a
/// u16. /// u16.
pub fn new_client(username: String, password: String) -> Self { pub fn new_client(username: String, password: String) -> Self {
Self { Self {
username, username,
password, password,
role: Role::Client, role: Role::Client,
} }
} }
} }
#[async_trait] #[async_trait]
impl ProtocolExtension for PasswordProtocolExtension { impl ProtocolExtension for PasswordProtocolExtension {
fn get_id(&self) -> u8 { fn get_id(&self) -> u8 {
Self::ID Self::ID
} }
fn get_supported_packets(&self) -> &'static [u8] { fn get_supported_packets(&self) -> &'static [u8] {
&[] &[]
} }
fn encode(&self) -> Bytes { fn encode(&self) -> Bytes {
match self.role { match self.role {
Role::Server => Bytes::new(), Role::Server => Bytes::new(),
Role::Client => { Role::Client => {
let username = Bytes::from(self.username.clone().into_bytes()); let username = Bytes::from(self.username.clone().into_bytes());
let password = Bytes::from(self.password.clone().into_bytes()); let password = Bytes::from(self.password.clone().into_bytes());
let username_len = u8::try_from(username.len()).expect("username was too long"); let username_len = u8::try_from(username.len()).expect("username was too long");
let password_len = u16::try_from(password.len()).expect("password was too long"); let password_len = u16::try_from(password.len()).expect("password was too long");
let mut bytes = let mut bytes =
BytesMut::with_capacity(3 + username_len as usize + password_len as usize); BytesMut::with_capacity(3 + username_len as usize + password_len as usize);
bytes.put_u8(username_len); bytes.put_u8(username_len);
bytes.put_u16_le(password_len); bytes.put_u16_le(password_len);
bytes.extend(username); bytes.extend(username);
bytes.extend(password); bytes.extend(password);
bytes.freeze() bytes.freeze()
} }
} }
} }
async fn handle_handshake( async fn handle_handshake(
&mut self, &mut self,
_: &mut dyn WebSocketRead, _: &mut dyn WebSocketRead,
_: &LockedWebSocketWrite, _: &LockedWebSocketWrite,
) -> Result<(), WispError> { ) -> Result<(), WispError> {
Ok(()) Ok(())
} }
async fn handle_packet( async fn handle_packet(
&mut self, &mut self,
_: Bytes, _: Bytes,
_: &mut dyn WebSocketRead, _: &mut dyn WebSocketRead,
_: &LockedWebSocketWrite, _: &LockedWebSocketWrite,
) -> Result<(), WispError> { ) -> Result<(), WispError> {
Ok(()) Ok(())
} }
fn box_clone(&self) -> Box<dyn ProtocolExtension + Sync + Send> { fn box_clone(&self) -> Box<dyn ProtocolExtension + Sync + Send> {
Box::new(self.clone()) Box::new(self.clone())
} }
} }
#[derive(Debug)] #[derive(Debug)]
enum PasswordProtocolExtensionError { enum PasswordProtocolExtensionError {
Utf8Error(FromUtf8Error), Utf8Error(FromUtf8Error),
InvalidUsername, InvalidUsername,
InvalidPassword, InvalidPassword,
} }
impl Display for PasswordProtocolExtensionError { impl Display for PasswordProtocolExtensionError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
use PasswordProtocolExtensionError as E; use PasswordProtocolExtensionError as E;
match self { match self {
E::Utf8Error(e) => write!(f, "{}", e), E::Utf8Error(e) => write!(f, "{}", e),
E::InvalidUsername => write!(f, "Invalid username"), E::InvalidUsername => write!(f, "Invalid username"),
E::InvalidPassword => write!(f, "Invalid password"), E::InvalidPassword => write!(f, "Invalid password"),
} }
} }
} }
impl Error for PasswordProtocolExtensionError {} impl Error for PasswordProtocolExtensionError {}
impl From<PasswordProtocolExtensionError> for WispError { impl From<PasswordProtocolExtensionError> for WispError {
fn from(value: PasswordProtocolExtensionError) -> Self { fn from(value: PasswordProtocolExtensionError) -> Self {
WispError::ExtensionImplError(Box::new(value)) WispError::ExtensionImplError(Box::new(value))
} }
} }
impl From<FromUtf8Error> for PasswordProtocolExtensionError { impl From<FromUtf8Error> for PasswordProtocolExtensionError {
fn from(value: FromUtf8Error) -> Self { fn from(value: FromUtf8Error) -> Self {
PasswordProtocolExtensionError::Utf8Error(value) PasswordProtocolExtensionError::Utf8Error(value)
} }
} }
impl From<PasswordProtocolExtension> for AnyProtocolExtension { impl From<PasswordProtocolExtension> for AnyProtocolExtension {
fn from(value: PasswordProtocolExtension) -> Self { fn from(value: PasswordProtocolExtension) -> Self {
AnyProtocolExtension(Box::new(value)) AnyProtocolExtension(Box::new(value))
} }
} }
/// Password protocol extension builder. /// Password protocol extension builder.
/// ///
/// **Passwords are sent in plain text!!** /// **Passwords are sent in plain text!!**
pub struct PasswordProtocolExtensionBuilder { pub struct PasswordProtocolExtensionBuilder {
/// Map of users and their passwords to allow. Only used on server. /// Map of users and their passwords to allow. Only used on server.
pub users: HashMap<String, String>, pub users: HashMap<String, String>,
/// Username to authenticate with. Only used on client. /// Username to authenticate with. Only used on client.
pub username: String, pub username: String,
/// Password to authenticate with. Only used on client. /// Password to authenticate with. Only used on client.
pub password: String, pub password: String,
} }
impl PasswordProtocolExtensionBuilder { impl PasswordProtocolExtensionBuilder {
/// Create a new password protocol extension builder for the server, with a map of users /// Create a new password protocol extension builder for the server, with a map of users
/// and passwords to allow. /// and passwords to allow.
pub fn new_server(users: HashMap<String, String>) -> Self { pub fn new_server(users: HashMap<String, String>) -> Self {
Self { Self {
users, users,
username: String::new(), username: String::new(),
password: String::new(), password: String::new(),
} }
} }
/// Create a new password protocol extension builder for the client, with a username and /// Create a new password protocol extension builder for the client, with a username and
/// password to authenticate with. /// password to authenticate with.
pub fn new_client(username: String, password: String) -> Self { pub fn new_client(username: String, password: String) -> Self {
Self { Self {
users: HashMap::new(), users: HashMap::new(),
username, username,
password, password,
} }
} }
} }
impl ProtocolExtensionBuilder for PasswordProtocolExtensionBuilder { impl ProtocolExtensionBuilder for PasswordProtocolExtensionBuilder {
fn get_id(&self) -> u8 { fn get_id(&self) -> u8 {
PasswordProtocolExtension::ID PasswordProtocolExtension::ID
} }
fn build_from_bytes( fn build_from_bytes(
&self, &self,
mut payload: Bytes, mut payload: Bytes,
role: crate::Role, role: crate::Role,
) -> Result<AnyProtocolExtension, WispError> { ) -> Result<AnyProtocolExtension, WispError> {
match role { match role {
Role::Server => { Role::Server => {
if payload.remaining() < 3 { if payload.remaining() < 3 {
return Err(WispError::PacketTooSmall); return Err(WispError::PacketTooSmall);
} }
let username_len = payload.get_u8(); let username_len = payload.get_u8();
let password_len = payload.get_u16_le(); let password_len = payload.get_u16_le();
if payload.remaining() < (password_len + username_len as u16) as usize { if payload.remaining() < (password_len + username_len as u16) as usize {
return Err(WispError::PacketTooSmall); return Err(WispError::PacketTooSmall);
} }
use PasswordProtocolExtensionError as EError; use PasswordProtocolExtensionError as EError;
let username = let username =
String::from_utf8(payload.copy_to_bytes(username_len as usize).to_vec()) String::from_utf8(payload.copy_to_bytes(username_len as usize).to_vec())
.map_err(|x| WispError::from(EError::from(x)))?; .map_err(|x| WispError::from(EError::from(x)))?;
let password = let password =
String::from_utf8(payload.copy_to_bytes(password_len as usize).to_vec()) String::from_utf8(payload.copy_to_bytes(password_len as usize).to_vec())
.map_err(|x| WispError::from(EError::from(x)))?; .map_err(|x| WispError::from(EError::from(x)))?;
let Some(user) = self.users.iter().find(|x| *x.0 == username) else { let Some(user) = self.users.iter().find(|x| *x.0 == username) else {
return Err(EError::InvalidUsername.into()); return Err(EError::InvalidUsername.into());
}; };
if *user.1 != password { if *user.1 != password {
return Err(EError::InvalidPassword.into()); return Err(EError::InvalidPassword.into());
} }
Ok(PasswordProtocolExtension { Ok(PasswordProtocolExtension {
username, username,
password, password,
role, role,
} }
.into()) .into())
} }
Role::Client => { Role::Client => {
Ok(PasswordProtocolExtension::new_client(String::new(), String::new()).into()) Ok(PasswordProtocolExtension::new_client(String::new(), String::new()).into())
} }
} }
} }
fn build_to_extension(&self, role: Role) -> AnyProtocolExtension { fn build_to_extension(&self, role: Role) -> AnyProtocolExtension {
match role { match role {
Role::Server => PasswordProtocolExtension::new_server(), Role::Server => PasswordProtocolExtension::new_server(),
Role::Client => { Role::Client => {
PasswordProtocolExtension::new_client(self.username.clone(), self.password.clone()) PasswordProtocolExtension::new_client(self.username.clone(), self.password.clone())
} }
} }
.into() .into()
} }
} }

View file

@ -6,88 +6,88 @@
//! rx, //! rx,
//! tx, //! tx,
//! 128, //! 128,
//! Some(&[Box::new(UdpProtocolExtensionBuilder())]) //! Some(&[Box::new(UdpProtocolExtensionBuilder)])
//! ); //! );
//! ``` //! ```
//! See [the docs](https://github.com/MercuryWorkshop/wisp-protocol/blob/main/protocol.md#0x01---udp) //! See [the docs](https://github.com/MercuryWorkshop/wisp-protocol/blob/v2/protocol.md#0x01---udp)
use async_trait::async_trait; use async_trait::async_trait;
use bytes::Bytes; use bytes::Bytes;
use crate::{ use crate::{
ws::{LockedWebSocketWrite, WebSocketRead}, ws::{LockedWebSocketWrite, WebSocketRead},
WispError, WispError,
}; };
use super::{AnyProtocolExtension, ProtocolExtension, ProtocolExtensionBuilder}; use super::{AnyProtocolExtension, ProtocolExtension, ProtocolExtensionBuilder};
#[derive(Debug)] #[derive(Debug)]
/// UDP protocol extension. /// UDP protocol extension.
pub struct UdpProtocolExtension(); pub struct UdpProtocolExtension;
impl UdpProtocolExtension { impl UdpProtocolExtension {
/// UDP protocol extension ID. /// UDP protocol extension ID.
pub const ID: u8 = 0x01; pub const ID: u8 = 0x01;
} }
#[async_trait] #[async_trait]
impl ProtocolExtension for UdpProtocolExtension { impl ProtocolExtension for UdpProtocolExtension {
fn get_id(&self) -> u8 { fn get_id(&self) -> u8 {
Self::ID Self::ID
} }
fn get_supported_packets(&self) -> &'static [u8] { fn get_supported_packets(&self) -> &'static [u8] {
&[] &[]
} }
fn encode(&self) -> Bytes { fn encode(&self) -> Bytes {
Bytes::new() Bytes::new()
} }
async fn handle_handshake( async fn handle_handshake(
&mut self, &mut self,
_: &mut dyn WebSocketRead, _: &mut dyn WebSocketRead,
_: &LockedWebSocketWrite, _: &LockedWebSocketWrite,
) -> Result<(), WispError> { ) -> Result<(), WispError> {
Ok(()) Ok(())
} }
async fn handle_packet( async fn handle_packet(
&mut self, &mut self,
_: Bytes, _: Bytes,
_: &mut dyn WebSocketRead, _: &mut dyn WebSocketRead,
_: &LockedWebSocketWrite, _: &LockedWebSocketWrite,
) -> Result<(), WispError> { ) -> Result<(), WispError> {
Ok(()) Ok(())
} }
fn box_clone(&self) -> Box<dyn ProtocolExtension + Sync + Send> { fn box_clone(&self) -> Box<dyn ProtocolExtension + Sync + Send> {
Box::new(Self()) Box::new(Self)
} }
} }
impl From<UdpProtocolExtension> for AnyProtocolExtension { impl From<UdpProtocolExtension> for AnyProtocolExtension {
fn from(value: UdpProtocolExtension) -> Self { fn from(value: UdpProtocolExtension) -> Self {
AnyProtocolExtension(Box::new(value)) AnyProtocolExtension(Box::new(value))
} }
} }
/// UDP protocol extension builder. /// UDP protocol extension builder.
pub struct UdpProtocolExtensionBuilder(); pub struct UdpProtocolExtensionBuilder;
impl ProtocolExtensionBuilder for UdpProtocolExtensionBuilder { impl ProtocolExtensionBuilder for UdpProtocolExtensionBuilder {
fn get_id(&self) -> u8 { fn get_id(&self) -> u8 {
UdpProtocolExtension::ID UdpProtocolExtension::ID
} }
fn build_from_bytes( fn build_from_bytes(
&self, &self,
_: Bytes, _: Bytes,
_: crate::Role, _: crate::Role,
) -> Result<AnyProtocolExtension, WispError> { ) -> Result<AnyProtocolExtension, WispError> {
Ok(UdpProtocolExtension().into()) Ok(UdpProtocolExtension.into())
} }
fn build_to_extension(&self, _: crate::Role) -> AnyProtocolExtension { fn build_to_extension(&self, _: crate::Role) -> AnyProtocolExtension {
UdpProtocolExtension().into() UdpProtocolExtension.into()
} }
} }

View file

@ -3,93 +3,100 @@ use std::ops::Deref;
use async_trait::async_trait; use async_trait::async_trait;
use bytes::BytesMut; use bytes::BytesMut;
use fastwebsockets::{ use fastwebsockets::{
CloseCode, FragmentCollectorRead, Frame, OpCode, Payload, WebSocketError, WebSocketWrite, CloseCode, FragmentCollectorRead, Frame, OpCode, Payload, WebSocketError, WebSocketWrite,
}; };
use tokio::io::{AsyncRead, AsyncWrite}; use tokio::io::{AsyncRead, AsyncWrite};
use crate::{ws::LockedWebSocketWrite, WispError}; use crate::{ws::LockedWebSocketWrite, WispError};
fn match_payload(payload: Payload) -> BytesMut { fn match_payload<'a>(payload: Payload<'a>) -> crate::ws::Payload<'a> {
match payload { match payload {
Payload::Bytes(x) => x, Payload::Bytes(x) => crate::ws::Payload::Bytes(x),
Payload::Owned(x) => BytesMut::from(x.deref()), Payload::Owned(x) => crate::ws::Payload::Bytes(BytesMut::from(x.deref())),
Payload::BorrowedMut(x) => BytesMut::from(x.deref()), Payload::BorrowedMut(x) => crate::ws::Payload::Borrowed(&*x),
Payload::Borrowed(x) => BytesMut::from(x), Payload::Borrowed(x) => crate::ws::Payload::Borrowed(x),
}
}
fn match_payload_reverse<'a>(payload: crate::ws::Payload<'a>) -> Payload<'a> {
match payload {
crate::ws::Payload::Bytes(x) => Payload::Bytes(x),
crate::ws::Payload::Borrowed(x) => Payload::Borrowed(x),
} }
} }
impl From<OpCode> for crate::ws::OpCode { impl From<OpCode> for crate::ws::OpCode {
fn from(opcode: OpCode) -> Self { fn from(opcode: OpCode) -> Self {
use OpCode::*; use OpCode::*;
match opcode { match opcode {
Continuation => { Continuation => {
unreachable!("continuation should never be recieved when using a fragmentcollector") unreachable!("continuation should never be recieved when using a fragmentcollector")
} }
Text => Self::Text, Text => Self::Text,
Binary => Self::Binary, Binary => Self::Binary,
Close => Self::Close, Close => Self::Close,
Ping => Self::Ping, Ping => Self::Ping,
Pong => Self::Pong, Pong => Self::Pong,
} }
} }
} }
impl From<Frame<'_>> for crate::ws::Frame { impl<'a> From<Frame<'a>> for crate::ws::Frame<'a> {
fn from(frame: Frame) -> Self { fn from(frame: Frame<'a>) -> Self {
Self { Self {
finished: frame.fin, finished: frame.fin,
opcode: frame.opcode.into(), opcode: frame.opcode.into(),
payload: match_payload(frame.payload), payload: match_payload(frame.payload),
} }
} }
} }
impl<'a> From<crate::ws::Frame> for Frame<'a> { impl<'a> From<crate::ws::Frame<'a>> for Frame<'a> {
fn from(frame: crate::ws::Frame) -> Self { fn from(frame: crate::ws::Frame<'a>) -> Self {
use crate::ws::OpCode::*; use crate::ws::OpCode::*;
let payload = Payload::Bytes(frame.payload); let payload = match_payload_reverse(frame.payload);
match frame.opcode { match frame.opcode {
Text => Self::text(payload), Text => Self::text(payload),
Binary => Self::binary(payload), Binary => Self::binary(payload),
Close => Self::close_raw(payload), Close => Self::close_raw(payload),
Ping => Self::new(true, OpCode::Ping, None, payload), Ping => Self::new(true, OpCode::Ping, None, payload),
Pong => Self::pong(payload), Pong => Self::pong(payload),
} }
} }
} }
impl From<WebSocketError> for crate::WispError { impl From<WebSocketError> for crate::WispError {
fn from(err: WebSocketError) -> Self { fn from(err: WebSocketError) -> Self {
if let WebSocketError::ConnectionClosed = err { if let WebSocketError::ConnectionClosed = err {
Self::WsImplSocketClosed Self::WsImplSocketClosed
} else { } else {
Self::WsImplError(Box::new(err)) Self::WsImplError(Box::new(err))
} }
} }
} }
#[async_trait] #[async_trait]
impl<S: AsyncRead + Unpin + Send> crate::ws::WebSocketRead for FragmentCollectorRead<S> { impl<S: AsyncRead + Unpin + Send> crate::ws::WebSocketRead for FragmentCollectorRead<S> {
async fn wisp_read_frame( async fn wisp_read_frame(
&mut self, &mut self,
tx: &LockedWebSocketWrite, tx: &LockedWebSocketWrite,
) -> Result<crate::ws::Frame, WispError> { ) -> Result<crate::ws::Frame<'static>, WispError> {
Ok(self Ok(self
.read_frame(&mut |frame| async { tx.write_frame(frame.into()).await }) .read_frame(&mut |frame| async { tx.write_frame(frame.into()).await })
.await? .await?
.into()) .into())
} }
} }
#[async_trait] #[async_trait]
impl<S: AsyncWrite + Unpin + Send> crate::ws::WebSocketWrite for WebSocketWrite<S> { impl<S: AsyncWrite + Unpin + Send> crate::ws::WebSocketWrite for WebSocketWrite<S> {
async fn wisp_write_frame(&mut self, frame: crate::ws::Frame) -> Result<(), WispError> { async fn wisp_write_frame(&mut self, frame: crate::ws::Frame<'_>) -> Result<(), WispError> {
self.write_frame(frame.into()).await.map_err(|e| e.into()) self.write_frame(frame.into()).await.map_err(|e| e.into())
} }
async fn wisp_close(&mut self) -> Result<(), WispError> { async fn wisp_close(&mut self) -> Result<(), WispError> {
self.write_frame(Frame::close(CloseCode::Normal.into(), b"")) self.write_frame(Frame::close(CloseCode::Normal.into(), b""))
.await .await
.map_err(|e| e.into()) .map_err(|e| e.into())
} }
} }

File diff suppressed because it is too large Load diff

View file

@ -1,41 +1,41 @@
use crate::{ use crate::{
extensions::{AnyProtocolExtension, ProtocolExtensionBuilder}, extensions::{AnyProtocolExtension, ProtocolExtensionBuilder},
ws::{self, Frame, LockedWebSocketWrite, OpCode, WebSocketRead}, ws::{self, Frame, LockedWebSocketWrite, OpCode, Payload, WebSocketRead},
Role, WispError, WISP_VERSION, Role, WispError, WISP_VERSION,
}; };
use bytes::{Buf, BufMut, Bytes, BytesMut}; use bytes::{Buf, BufMut, Bytes, BytesMut};
/// Wisp stream type. /// Wisp stream type.
#[derive(Debug, PartialEq, Copy, Clone)] #[derive(Debug, PartialEq, Copy, Clone)]
pub enum StreamType { pub enum StreamType {
/// TCP Wisp stream. /// TCP Wisp stream.
Tcp, Tcp,
/// UDP Wisp stream. /// UDP Wisp stream.
Udp, Udp,
/// Unknown Wisp stream type used for custom streams by protocol extensions. /// Unknown Wisp stream type used for custom streams by protocol extensions.
Unknown(u8), Unknown(u8),
} }
impl From<u8> for StreamType { impl From<u8> for StreamType {
fn from(value: u8) -> Self { fn from(value: u8) -> Self {
use StreamType as S; use StreamType as S;
match value { match value {
0x01 => S::Tcp, 0x01 => S::Tcp,
0x02 => S::Udp, 0x02 => S::Udp,
x => S::Unknown(x), x => S::Unknown(x),
} }
} }
} }
impl From<StreamType> for u8 { impl From<StreamType> for u8 {
fn from(value: StreamType) -> Self { fn from(value: StreamType) -> Self {
use StreamType as S; use StreamType as S;
match value { match value {
S::Tcp => 0x01, S::Tcp => 0x01,
S::Udp => 0x02, S::Udp => 0x02,
S::Unknown(x) => x, S::Unknown(x) => x,
} }
} }
} }
/// Close reason. /// Close reason.
@ -44,56 +44,56 @@ impl From<StreamType> for u8 {
/// docs](https://github.com/MercuryWorkshop/wisp-protocol/blob/main/protocol.md#clientserver-close-reasons) /// docs](https://github.com/MercuryWorkshop/wisp-protocol/blob/main/protocol.md#clientserver-close-reasons)
#[derive(Debug, PartialEq, Copy, Clone)] #[derive(Debug, PartialEq, Copy, Clone)]
pub enum CloseReason { pub enum CloseReason {
/// Reason unspecified or unknown. /// Reason unspecified or unknown.
Unknown = 0x01, Unknown = 0x01,
/// Voluntary stream closure. /// Voluntary stream closure.
Voluntary = 0x02, Voluntary = 0x02,
/// Unexpected stream closure due to a network error. /// Unexpected stream closure due to a network error.
Unexpected = 0x03, Unexpected = 0x03,
/// Incompatible extensions. Only used during the handshake. /// Incompatible extensions. Only used during the handshake.
IncompatibleExtensions = 0x04, IncompatibleExtensions = 0x04,
/// Stream creation failed due to invalid information. /// Stream creation failed due to invalid information.
ServerStreamInvalidInfo = 0x41, ServerStreamInvalidInfo = 0x41,
/// Stream creation failed due to an unreachable destination host. /// Stream creation failed due to an unreachable destination host.
ServerStreamUnreachable = 0x42, ServerStreamUnreachable = 0x42,
/// Stream creation timed out due to the destination server not responding. /// Stream creation timed out due to the destination server not responding.
ServerStreamConnectionTimedOut = 0x43, ServerStreamConnectionTimedOut = 0x43,
/// Stream creation failed due to the destination server refusing the connection. /// Stream creation failed due to the destination server refusing the connection.
ServerStreamConnectionRefused = 0x44, ServerStreamConnectionRefused = 0x44,
/// TCP data transfer timed out. /// TCP data transfer timed out.
ServerStreamTimedOut = 0x47, ServerStreamTimedOut = 0x47,
/// Stream destination address/domain is intentionally blocked by the proxy server. /// Stream destination address/domain is intentionally blocked by the proxy server.
ServerStreamBlockedAddress = 0x48, ServerStreamBlockedAddress = 0x48,
/// Connection throttled by the server. /// Connection throttled by the server.
ServerStreamThrottled = 0x49, ServerStreamThrottled = 0x49,
/// The client has encountered an unexpected error. /// The client has encountered an unexpected error.
ClientUnexpected = 0x81, ClientUnexpected = 0x81,
} }
impl TryFrom<u8> for CloseReason { impl TryFrom<u8> for CloseReason {
type Error = WispError; type Error = WispError;
fn try_from(close_reason: u8) -> Result<Self, Self::Error> { fn try_from(close_reason: u8) -> Result<Self, Self::Error> {
use CloseReason as R; use CloseReason as R;
match close_reason { match close_reason {
0x01 => Ok(R::Unknown), 0x01 => Ok(R::Unknown),
0x02 => Ok(R::Voluntary), 0x02 => Ok(R::Voluntary),
0x03 => Ok(R::Unexpected), 0x03 => Ok(R::Unexpected),
0x04 => Ok(R::IncompatibleExtensions), 0x04 => Ok(R::IncompatibleExtensions),
0x41 => Ok(R::ServerStreamInvalidInfo), 0x41 => Ok(R::ServerStreamInvalidInfo),
0x42 => Ok(R::ServerStreamUnreachable), 0x42 => Ok(R::ServerStreamUnreachable),
0x43 => Ok(R::ServerStreamConnectionTimedOut), 0x43 => Ok(R::ServerStreamConnectionTimedOut),
0x44 => Ok(R::ServerStreamConnectionRefused), 0x44 => Ok(R::ServerStreamConnectionRefused),
0x47 => Ok(R::ServerStreamTimedOut), 0x47 => Ok(R::ServerStreamTimedOut),
0x48 => Ok(R::ServerStreamBlockedAddress), 0x48 => Ok(R::ServerStreamBlockedAddress),
0x49 => Ok(R::ServerStreamThrottled), 0x49 => Ok(R::ServerStreamThrottled),
0x81 => Ok(R::ClientUnexpected), 0x81 => Ok(R::ClientUnexpected),
_ => Err(Self::Error::InvalidCloseReason), _ => Err(Self::Error::InvalidCloseReason),
} }
} }
} }
trait Encode { trait Encode {
fn encode(self, bytes: &mut BytesMut); fn encode(self, bytes: &mut BytesMut);
} }
/// Packet used to create a new stream. /// Packet used to create a new stream.
@ -101,49 +101,49 @@ trait Encode {
/// See [the docs](https://github.com/MercuryWorkshop/wisp-protocol/blob/main/protocol.md#0x01---connect). /// See [the docs](https://github.com/MercuryWorkshop/wisp-protocol/blob/main/protocol.md#0x01---connect).
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct ConnectPacket { pub struct ConnectPacket {
/// Whether the new stream should use a TCP or UDP socket. /// Whether the new stream should use a TCP or UDP socket.
pub stream_type: StreamType, pub stream_type: StreamType,
/// Destination TCP/UDP port for the new stream. /// Destination TCP/UDP port for the new stream.
pub destination_port: u16, pub destination_port: u16,
/// Destination hostname, in a UTF-8 string. /// Destination hostname, in a UTF-8 string.
pub destination_hostname: String, pub destination_hostname: String,
} }
impl ConnectPacket { impl ConnectPacket {
/// Create a new connect packet. /// Create a new connect packet.
pub fn new( pub fn new(
stream_type: StreamType, stream_type: StreamType,
destination_port: u16, destination_port: u16,
destination_hostname: String, destination_hostname: String,
) -> Self { ) -> Self {
Self { Self {
stream_type, stream_type,
destination_port, destination_port,
destination_hostname, destination_hostname,
} }
} }
} }
impl TryFrom<BytesMut> for ConnectPacket { impl TryFrom<Payload<'_>> for ConnectPacket {
type Error = WispError; type Error = WispError;
fn try_from(mut bytes: BytesMut) -> Result<Self, Self::Error> { fn try_from(mut bytes: Payload<'_>) -> Result<Self, Self::Error> {
if bytes.remaining() < (1 + 2) { if bytes.remaining() < (1 + 2) {
return Err(Self::Error::PacketTooSmall); return Err(Self::Error::PacketTooSmall);
} }
Ok(Self { Ok(Self {
stream_type: bytes.get_u8().into(), stream_type: bytes.get_u8().into(),
destination_port: bytes.get_u16_le(), destination_port: bytes.get_u16_le(),
destination_hostname: std::str::from_utf8(&bytes)?.to_string(), destination_hostname: std::str::from_utf8(&bytes)?.to_string(),
}) })
} }
} }
impl Encode for ConnectPacket { impl Encode for ConnectPacket {
fn encode(self, bytes: &mut BytesMut) { fn encode(self, bytes: &mut BytesMut) {
bytes.put_u8(self.stream_type.into()); bytes.put_u8(self.stream_type.into());
bytes.put_u16_le(self.destination_port); bytes.put_u16_le(self.destination_port);
bytes.extend(self.destination_hostname.bytes()); bytes.extend(self.destination_hostname.bytes());
} }
} }
/// Packet used for Wisp TCP stream flow control. /// Packet used for Wisp TCP stream flow control.
@ -151,33 +151,33 @@ impl Encode for ConnectPacket {
/// See [the docs](https://github.com/MercuryWorkshop/wisp-protocol/blob/main/protocol.md#0x03---continue). /// See [the docs](https://github.com/MercuryWorkshop/wisp-protocol/blob/main/protocol.md#0x03---continue).
#[derive(Debug, Copy, Clone)] #[derive(Debug, Copy, Clone)]
pub struct ContinuePacket { pub struct ContinuePacket {
/// Number of packets that the server can buffer for the current stream. /// Number of packets that the server can buffer for the current stream.
pub buffer_remaining: u32, pub buffer_remaining: u32,
} }
impl ContinuePacket { impl ContinuePacket {
/// Create a new continue packet. /// Create a new continue packet.
pub fn new(buffer_remaining: u32) -> Self { pub fn new(buffer_remaining: u32) -> Self {
Self { buffer_remaining } Self { buffer_remaining }
} }
} }
impl TryFrom<BytesMut> for ContinuePacket { impl TryFrom<Payload<'_>> for ContinuePacket {
type Error = WispError; type Error = WispError;
fn try_from(mut bytes: BytesMut) -> Result<Self, Self::Error> { fn try_from(mut bytes: Payload<'_>) -> Result<Self, Self::Error> {
if bytes.remaining() < 4 { if bytes.remaining() < 4 {
return Err(Self::Error::PacketTooSmall); return Err(Self::Error::PacketTooSmall);
} }
Ok(Self { Ok(Self {
buffer_remaining: bytes.get_u32_le(), buffer_remaining: bytes.get_u32_le(),
}) })
} }
} }
impl Encode for ContinuePacket { impl Encode for ContinuePacket {
fn encode(self, bytes: &mut BytesMut) { fn encode(self, bytes: &mut BytesMut) {
bytes.put_u32_le(self.buffer_remaining); bytes.put_u32_le(self.buffer_remaining);
} }
} }
/// Packet used to close a stream. /// Packet used to close a stream.
@ -186,42 +186,42 @@ impl Encode for ContinuePacket {
/// docs](https://github.com/MercuryWorkshop/wisp-protocol/blob/main/protocol.md#0x04---close). /// docs](https://github.com/MercuryWorkshop/wisp-protocol/blob/main/protocol.md#0x04---close).
#[derive(Debug, Copy, Clone)] #[derive(Debug, Copy, Clone)]
pub struct ClosePacket { pub struct ClosePacket {
/// The close reason. /// The close reason.
pub reason: CloseReason, pub reason: CloseReason,
} }
impl ClosePacket { impl ClosePacket {
/// Create a new close packet. /// Create a new close packet.
pub fn new(reason: CloseReason) -> Self { pub fn new(reason: CloseReason) -> Self {
Self { reason } Self { reason }
} }
} }
impl TryFrom<BytesMut> for ClosePacket { impl TryFrom<Payload<'_>> for ClosePacket {
type Error = WispError; type Error = WispError;
fn try_from(mut bytes: BytesMut) -> Result<Self, Self::Error> { fn try_from(mut bytes: Payload<'_>) -> Result<Self, Self::Error> {
if bytes.remaining() < 1 { if bytes.remaining() < 1 {
return Err(Self::Error::PacketTooSmall); return Err(Self::Error::PacketTooSmall);
} }
Ok(Self { Ok(Self {
reason: bytes.get_u8().try_into()?, reason: bytes.get_u8().try_into()?,
}) })
} }
} }
impl Encode for ClosePacket { impl Encode for ClosePacket {
fn encode(self, bytes: &mut BytesMut) { fn encode(self, bytes: &mut BytesMut) {
bytes.put_u8(self.reason as u8); bytes.put_u8(self.reason as u8);
} }
} }
/// Wisp version sent in the handshake. /// Wisp version sent in the handshake.
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct WispVersion { pub struct WispVersion {
/// Major Wisp version according to semver. /// Major Wisp version according to semver.
pub major: u8, pub major: u8,
/// Minor Wisp version according to semver. /// Minor Wisp version according to semver.
pub minor: u8, pub minor: u8,
} }
/// Packet used in the initial handshake. /// Packet used in the initial handshake.
@ -229,325 +229,327 @@ pub struct WispVersion {
/// See [the docs](https://github.com/MercuryWorkshop/wisp-protocol/blob/main/protocol.md#0x05---info) /// See [the docs](https://github.com/MercuryWorkshop/wisp-protocol/blob/main/protocol.md#0x05---info)
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct InfoPacket { pub struct InfoPacket {
/// Wisp version sent in the packet. /// Wisp version sent in the packet.
pub version: WispVersion, pub version: WispVersion,
/// List of protocol extensions sent in the packet. /// List of protocol extensions sent in the packet.
pub extensions: Vec<AnyProtocolExtension>, pub extensions: Vec<AnyProtocolExtension>,
} }
impl Encode for InfoPacket { impl Encode for InfoPacket {
fn encode(self, bytes: &mut BytesMut) { fn encode(self, bytes: &mut BytesMut) {
bytes.put_u8(self.version.major); bytes.put_u8(self.version.major);
bytes.put_u8(self.version.minor); bytes.put_u8(self.version.minor);
for extension in self.extensions { for extension in self.extensions {
bytes.extend_from_slice(&Bytes::from(extension)); bytes.extend_from_slice(&Bytes::from(extension));
} }
} }
} }
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
/// Type of packet recieved. /// Type of packet recieved.
pub enum PacketType { pub enum PacketType<'a> {
/// Connect packet. /// Connect packet.
Connect(ConnectPacket), Connect(ConnectPacket),
/// Data packet. /// Data packet.
Data(Bytes), Data(Payload<'a>),
/// Continue packet. /// Continue packet.
Continue(ContinuePacket), Continue(ContinuePacket),
/// Close packet. /// Close packet.
Close(ClosePacket), Close(ClosePacket),
/// Info packet. /// Info packet.
Info(InfoPacket), Info(InfoPacket),
} }
impl PacketType { impl PacketType<'_> {
/// Get the packet type used in the protocol. /// Get the packet type used in the protocol.
pub fn as_u8(&self) -> u8 { pub fn as_u8(&self) -> u8 {
use PacketType as P; use PacketType as P;
match self { match self {
P::Connect(_) => 0x01, P::Connect(_) => 0x01,
P::Data(_) => 0x02, P::Data(_) => 0x02,
P::Continue(_) => 0x03, P::Continue(_) => 0x03,
P::Close(_) => 0x04, P::Close(_) => 0x04,
P::Info(_) => 0x05, P::Info(_) => 0x05,
} }
} }
pub(crate) fn get_packet_size(&self) -> usize { pub(crate) fn get_packet_size(&self) -> usize {
use PacketType as P; use PacketType as P;
match self { match self {
P::Connect(p) => 1 + 2 + p.destination_hostname.len(), P::Connect(p) => 1 + 2 + p.destination_hostname.len(),
P::Data(p) => p.len(), P::Data(p) => p.len(),
P::Continue(_) => 4, P::Continue(_) => 4,
P::Close(_) => 1, P::Close(_) => 1,
P::Info(_) => 2, P::Info(_) => 2,
} }
} }
} }
impl Encode for PacketType { impl Encode for PacketType<'_> {
fn encode(self, bytes: &mut BytesMut) { fn encode(self, bytes: &mut BytesMut) {
use PacketType as P; use PacketType as P;
match self { match self {
P::Connect(x) => x.encode(bytes), P::Connect(x) => x.encode(bytes),
P::Data(x) => bytes.extend_from_slice(&x), P::Data(x) => bytes.extend_from_slice(&x),
P::Continue(x) => x.encode(bytes), P::Continue(x) => x.encode(bytes),
P::Close(x) => x.encode(bytes), P::Close(x) => x.encode(bytes),
P::Info(x) => x.encode(bytes), P::Info(x) => x.encode(bytes),
}; };
} }
} }
/// Wisp protocol packet. /// Wisp protocol packet.
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct Packet { pub struct Packet<'a> {
/// Stream this packet is associated with. /// Stream this packet is associated with.
pub stream_id: u32, pub stream_id: u32,
/// Packet type recieved. /// Packet type recieved.
pub packet_type: PacketType, pub packet_type: PacketType<'a>,
} }
impl Packet { impl<'a> Packet<'a> {
/// Create a new packet. /// Create a new packet.
/// ///
/// The helper functions should be used for most use cases. /// The helper functions should be used for most use cases.
pub fn new(stream_id: u32, packet: PacketType) -> Self { pub fn new(stream_id: u32, packet: PacketType<'a>) -> Self {
Self { Self {
stream_id, stream_id,
packet_type: packet, packet_type: packet,
} }
} }
/// Create a new connect packet. /// Create a new connect packet.
pub fn new_connect( pub fn new_connect(
stream_id: u32, stream_id: u32,
stream_type: StreamType, stream_type: StreamType,
destination_port: u16, destination_port: u16,
destination_hostname: String, destination_hostname: String,
) -> Self { ) -> Self {
Self { Self {
stream_id, stream_id,
packet_type: PacketType::Connect(ConnectPacket::new( packet_type: PacketType::Connect(ConnectPacket::new(
stream_type, stream_type,
destination_port, destination_port,
destination_hostname, destination_hostname,
)), )),
} }
} }
/// Create a new data packet. /// Create a new data packet.
pub fn new_data(stream_id: u32, data: Bytes) -> Self { pub fn new_data(stream_id: u32, data: Payload<'a>) -> Self {
Self { Self {
stream_id, stream_id,
packet_type: PacketType::Data(data), packet_type: PacketType::Data(data),
} }
} }
/// Create a new continue packet. /// Create a new continue packet.
pub fn new_continue(stream_id: u32, buffer_remaining: u32) -> Self { pub fn new_continue(stream_id: u32, buffer_remaining: u32) -> Self {
Self { Self {
stream_id, stream_id,
packet_type: PacketType::Continue(ContinuePacket::new(buffer_remaining)), packet_type: PacketType::Continue(ContinuePacket::new(buffer_remaining)),
} }
} }
/// Create a new close packet. /// Create a new close packet.
pub fn new_close(stream_id: u32, reason: CloseReason) -> Self { pub fn new_close(stream_id: u32, reason: CloseReason) -> Self {
Self { Self {
stream_id, stream_id,
packet_type: PacketType::Close(ClosePacket::new(reason)), packet_type: PacketType::Close(ClosePacket::new(reason)),
} }
} }
pub(crate) fn new_info(extensions: Vec<AnyProtocolExtension>) -> Self { pub(crate) fn new_info(extensions: Vec<AnyProtocolExtension>) -> Self {
Self { Self {
stream_id: 0, stream_id: 0,
packet_type: PacketType::Info(InfoPacket { packet_type: PacketType::Info(InfoPacket {
version: WISP_VERSION, version: WISP_VERSION,
extensions, extensions,
}), }),
} }
} }
fn parse_packet(packet_type: u8, mut bytes: BytesMut) -> Result<Self, WispError> { fn parse_packet(packet_type: u8, mut bytes: Payload<'a>) -> Result<Self, WispError> {
use PacketType as P; use PacketType as P;
Ok(Self { Ok(Self {
stream_id: bytes.get_u32_le(), stream_id: bytes.get_u32_le(),
packet_type: match packet_type { packet_type: match packet_type {
0x01 => P::Connect(ConnectPacket::try_from(bytes)?), 0x01 => P::Connect(ConnectPacket::try_from(bytes)?),
0x02 => P::Data(bytes.freeze()), 0x02 => P::Data(bytes),
0x03 => P::Continue(ContinuePacket::try_from(bytes)?), 0x03 => P::Continue(ContinuePacket::try_from(bytes)?),
0x04 => P::Close(ClosePacket::try_from(bytes)?), 0x04 => P::Close(ClosePacket::try_from(bytes)?),
// 0x05 is handled seperately // 0x05 is handled seperately
_ => return Err(WispError::InvalidPacketType), _ => return Err(WispError::InvalidPacketType),
}, },
}) })
} }
pub(crate) fn maybe_parse_info( pub(crate) fn maybe_parse_info(
frame: Frame, frame: Frame<'a>,
role: Role, role: Role,
extension_builders: &[Box<(dyn ProtocolExtensionBuilder + Send + Sync)>], extension_builders: &[Box<(dyn ProtocolExtensionBuilder + Send + Sync)>],
) -> Result<Self, WispError> { ) -> Result<Self, WispError> {
if !frame.finished { if !frame.finished {
return Err(WispError::WsFrameNotFinished); return Err(WispError::WsFrameNotFinished);
} }
if frame.opcode != OpCode::Binary { if frame.opcode != OpCode::Binary {
return Err(WispError::WsFrameInvalidType); return Err(WispError::WsFrameInvalidType);
} }
let mut bytes = frame.payload; let mut bytes = frame.payload;
if bytes.remaining() < 1 { if bytes.remaining() < 1 {
return Err(WispError::PacketTooSmall); return Err(WispError::PacketTooSmall);
} }
let packet_type = bytes.get_u8(); let packet_type = bytes.get_u8();
if packet_type == 0x05 { if packet_type == 0x05 {
Self::parse_info(bytes, role, extension_builders) Self::parse_info(bytes, role, extension_builders)
} else { } else {
Self::parse_packet(packet_type, bytes) Self::parse_packet(packet_type, bytes)
} }
} }
pub(crate) async fn maybe_handle_extension( pub(crate) async fn maybe_handle_extension(
frame: Frame, frame: Frame<'a>,
extensions: &mut [AnyProtocolExtension], extensions: &mut [AnyProtocolExtension],
read: &mut (dyn WebSocketRead + Send), read: &mut (dyn WebSocketRead + Send),
write: &LockedWebSocketWrite, write: &LockedWebSocketWrite,
) -> Result<Option<Self>, WispError> { ) -> Result<Option<Self>, WispError> {
if !frame.finished { if !frame.finished {
return Err(WispError::WsFrameNotFinished); return Err(WispError::WsFrameNotFinished);
} }
if frame.opcode != OpCode::Binary { if frame.opcode != OpCode::Binary {
return Err(WispError::WsFrameInvalidType); return Err(WispError::WsFrameInvalidType);
} }
let mut bytes = frame.payload; let mut bytes = frame.payload;
if bytes.remaining() < 5 { if bytes.remaining() < 5 {
return Err(WispError::PacketTooSmall); return Err(WispError::PacketTooSmall);
} }
let packet_type = bytes.get_u8(); let packet_type = bytes.get_u8();
match packet_type { match packet_type {
0x01 => Ok(Some(Self { 0x01 => Ok(Some(Self {
stream_id: bytes.get_u32_le(), stream_id: bytes.get_u32_le(),
packet_type: PacketType::Connect(bytes.try_into()?), packet_type: PacketType::Connect(bytes.try_into()?),
})), })),
0x02 => Ok(Some(Self { 0x02 => Ok(Some(Self {
stream_id: bytes.get_u32_le(), stream_id: bytes.get_u32_le(),
packet_type: PacketType::Data(bytes.freeze()), packet_type: PacketType::Data(bytes),
})), })),
0x03 => Ok(Some(Self { 0x03 => Ok(Some(Self {
stream_id: bytes.get_u32_le(), stream_id: bytes.get_u32_le(),
packet_type: PacketType::Continue(bytes.try_into()?), packet_type: PacketType::Continue(bytes.try_into()?),
})), })),
0x04 => Ok(Some(Self { 0x04 => Ok(Some(Self {
stream_id: bytes.get_u32_le(), stream_id: bytes.get_u32_le(),
packet_type: PacketType::Close(bytes.try_into()?), packet_type: PacketType::Close(bytes.try_into()?),
})), })),
0x05 => Ok(None), 0x05 => Ok(None),
packet_type => { packet_type => {
if let Some(extension) = extensions if let Some(extension) = extensions
.iter_mut() .iter_mut()
.find(|x| x.get_supported_packets().iter().any(|x| *x == packet_type)) .find(|x| x.get_supported_packets().iter().any(|x| *x == packet_type))
{ {
extension.handle_packet(bytes.freeze(), read, write).await?; extension
Ok(None) .handle_packet(BytesMut::from(bytes).freeze(), read, write)
} else { .await?;
Err(WispError::InvalidPacketType) Ok(None)
} } else {
} Err(WispError::InvalidPacketType)
} }
} }
}
}
fn parse_info( fn parse_info(
mut bytes: BytesMut, mut bytes: Payload<'a>,
role: Role, role: Role,
extension_builders: &[Box<(dyn ProtocolExtensionBuilder + Send + Sync)>], extension_builders: &[Box<(dyn ProtocolExtensionBuilder + Send + Sync)>],
) -> Result<Self, WispError> { ) -> Result<Self, WispError> {
// packet type is already read by code that calls this // packet type is already read by code that calls this
if bytes.remaining() < 4 + 2 { if bytes.remaining() < 4 + 2 {
return Err(WispError::PacketTooSmall); return Err(WispError::PacketTooSmall);
} }
if bytes.get_u32_le() != 0 { if bytes.get_u32_le() != 0 {
return Err(WispError::InvalidStreamId); return Err(WispError::InvalidStreamId);
} }
let version = WispVersion { let version = WispVersion {
major: bytes.get_u8(), major: bytes.get_u8(),
minor: bytes.get_u8(), minor: bytes.get_u8(),
}; };
if version.major != WISP_VERSION.major { if version.major != WISP_VERSION.major {
return Err(WispError::IncompatibleProtocolVersion); return Err(WispError::IncompatibleProtocolVersion);
} }
let mut extensions = Vec::new(); let mut extensions = Vec::new();
while bytes.remaining() > 4 { while bytes.remaining() > 4 {
// We have some extensions // We have some extensions
let id = bytes.get_u8(); let id = bytes.get_u8();
let length = usize::try_from(bytes.get_u32_le())?; let length = usize::try_from(bytes.get_u32_le())?;
if bytes.remaining() < length { if bytes.remaining() < length {
return Err(WispError::PacketTooSmall); return Err(WispError::PacketTooSmall);
} }
if let Some(builder) = extension_builders.iter().find(|x| x.get_id() == id) { if let Some(builder) = extension_builders.iter().find(|x| x.get_id() == id) {
if let Ok(extension) = builder.build_from_bytes(bytes.copy_to_bytes(length), role) { if let Ok(extension) = builder.build_from_bytes(bytes.copy_to_bytes(length), role) {
extensions.push(extension) extensions.push(extension)
} }
} else { } else {
bytes.advance(length) bytes.advance(length)
} }
} }
Ok(Self { Ok(Self {
stream_id: 0, stream_id: 0,
packet_type: PacketType::Info(InfoPacket { packet_type: PacketType::Info(InfoPacket {
version, version,
extensions, extensions,
}), }),
}) })
} }
} }
impl Encode for Packet { impl Encode for Packet<'_> {
fn encode(self, bytes: &mut BytesMut) { fn encode(self, bytes: &mut BytesMut) {
bytes.put_u8(self.packet_type.as_u8()); bytes.put_u8(self.packet_type.as_u8());
bytes.put_u32_le(self.stream_id); bytes.put_u32_le(self.stream_id);
self.packet_type.encode(bytes); self.packet_type.encode(bytes);
} }
} }
impl TryFrom<BytesMut> for Packet { impl<'a> TryFrom<Payload<'a>> for Packet<'a> {
type Error = WispError; type Error = WispError;
fn try_from(mut bytes: BytesMut) -> Result<Self, Self::Error> { fn try_from(mut bytes: Payload<'a>) -> Result<Self, Self::Error> {
if bytes.remaining() < 1 { if bytes.remaining() < 1 {
return Err(Self::Error::PacketTooSmall); return Err(Self::Error::PacketTooSmall);
} }
let packet_type = bytes.get_u8(); let packet_type = bytes.get_u8();
Self::parse_packet(packet_type, bytes) Self::parse_packet(packet_type, bytes)
} }
} }
impl From<Packet> for BytesMut { impl From<Packet<'_>> for BytesMut {
fn from(packet: Packet) -> Self { fn from(packet: Packet) -> Self {
let mut encoded = BytesMut::with_capacity(1 + 4 + packet.packet_type.get_packet_size()); let mut encoded = BytesMut::with_capacity(1 + 4 + packet.packet_type.get_packet_size());
packet.encode(&mut encoded); packet.encode(&mut encoded);
encoded encoded
} }
} }
impl TryFrom<ws::Frame> for Packet { impl<'a> TryFrom<ws::Frame<'a>> for Packet<'a> {
type Error = WispError; type Error = WispError;
fn try_from(frame: ws::Frame) -> Result<Self, Self::Error> { fn try_from(frame: ws::Frame<'a>) -> Result<Self, Self::Error> {
if !frame.finished { if !frame.finished {
return Err(Self::Error::WsFrameNotFinished); return Err(Self::Error::WsFrameNotFinished);
} }
if frame.opcode != ws::OpCode::Binary { if frame.opcode != ws::OpCode::Binary {
return Err(Self::Error::WsFrameInvalidType); return Err(Self::Error::WsFrameInvalidType);
} }
Packet::try_from(frame.payload) Packet::try_from(frame.payload)
} }
} }
impl From<Packet> for ws::Frame { impl From<Packet<'_>> for ws::Frame<'static> {
fn from(packet: Packet) -> Self { fn from(packet: Packet) -> Self {
Self::binary(BytesMut::from(packet)) Self::binary(Payload::Bytes(BytesMut::from(packet)))
} }
} }

View file

@ -1,146 +1,146 @@
//! futures sink unfold with a close function //! futures sink unfold with a close function
use core::{future::Future, pin::Pin}; use core::{future::Future, pin::Pin};
use futures::{ use futures::{
ready, ready,
task::{Context, Poll}, task::{Context, Poll},
Sink, Sink,
}; };
use pin_project_lite::pin_project; use pin_project_lite::pin_project;
pin_project! { pin_project! {
/// UnfoldState used for stream and sink unfolds /// UnfoldState used for stream and sink unfolds
#[project = UnfoldStateProj] #[project = UnfoldStateProj]
#[project_replace = UnfoldStateProjReplace] #[project_replace = UnfoldStateProjReplace]
#[derive(Debug)] #[derive(Debug)]
pub(crate) enum UnfoldState<T, Fut> { pub(crate) enum UnfoldState<T, Fut> {
Value { Value {
value: T, value: T,
}, },
Future { Future {
#[pin] #[pin]
future: Fut, future: Fut,
}, },
Empty, Empty,
} }
} }
impl<T, Fut> UnfoldState<T, Fut> { impl<T, Fut> UnfoldState<T, Fut> {
pub(crate) fn project_future(self: Pin<&mut Self>) -> Option<Pin<&mut Fut>> { pub(crate) fn project_future(self: Pin<&mut Self>) -> Option<Pin<&mut Fut>> {
match self.project() { match self.project() {
UnfoldStateProj::Future { future } => Some(future), UnfoldStateProj::Future { future } => Some(future),
_ => None, _ => None,
} }
} }
pub(crate) fn take_value(self: Pin<&mut Self>) -> Option<T> { pub(crate) fn take_value(self: Pin<&mut Self>) -> Option<T> {
match &*self { match &*self {
Self::Value { .. } => match self.project_replace(Self::Empty) { Self::Value { .. } => match self.project_replace(Self::Empty) {
UnfoldStateProjReplace::Value { value } => Some(value), UnfoldStateProjReplace::Value { value } => Some(value),
_ => unreachable!(), _ => unreachable!(),
}, },
_ => None, _ => None,
} }
} }
} }
pin_project! { pin_project! {
/// Sink for the [`unfold`] function. /// Sink for the [`unfold`] function.
#[derive(Debug)] #[derive(Debug)]
#[must_use = "sinks do nothing unless polled"] #[must_use = "sinks do nothing unless polled"]
pub struct Unfold<T, F, R, CT, CF, CR> { pub struct Unfold<T, F, R, CT, CF, CR> {
function: F, function: F,
close_function: CF, close_function: CF,
#[pin] #[pin]
state: UnfoldState<T, R>, state: UnfoldState<T, R>,
#[pin] #[pin]
close_state: UnfoldState<CT, CR> close_state: UnfoldState<CT, CR>
} }
} }
pub(crate) fn unfold<T, F, R, CT, CF, CR, Item, E>( pub(crate) fn unfold<T, F, R, CT, CF, CR, Item, E>(
init: T, init: T,
function: F, function: F,
close_init: CT, close_init: CT,
close_function: CF, close_function: CF,
) -> Unfold<T, F, R, CT, CF, CR> ) -> Unfold<T, F, R, CT, CF, CR>
where where
F: FnMut(T, Item) -> R, F: FnMut(T, Item) -> R,
R: Future<Output = Result<T, E>>, R: Future<Output = Result<T, E>>,
CF: FnMut(CT) -> CR, CF: FnMut(CT) -> CR,
CR: Future<Output = Result<CT, E>>, CR: Future<Output = Result<CT, E>>,
{ {
Unfold { Unfold {
function, function,
close_function, close_function,
state: UnfoldState::Value { value: init }, state: UnfoldState::Value { value: init },
close_state: UnfoldState::Value { value: close_init }, close_state: UnfoldState::Value { value: close_init },
} }
} }
impl<T, F, R, CT, CF, CR, Item, E> Sink<Item> for Unfold<T, F, R, CT, CF, CR> impl<T, F, R, CT, CF, CR, Item, E> Sink<Item> for Unfold<T, F, R, CT, CF, CR>
where where
F: FnMut(T, Item) -> R, F: FnMut(T, Item) -> R,
R: Future<Output = Result<T, E>>, R: Future<Output = Result<T, E>>,
CF: FnMut(CT) -> CR, CF: FnMut(CT) -> CR,
CR: Future<Output = Result<CT, E>>, CR: Future<Output = Result<CT, E>>,
{ {
type Error = E; type Error = E;
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.poll_flush(cx) self.poll_flush(cx)
} }
fn start_send(self: Pin<&mut Self>, item: Item) -> Result<(), Self::Error> { fn start_send(self: Pin<&mut Self>, item: Item) -> Result<(), Self::Error> {
let mut this = self.project(); let mut this = self.project();
let future = match this.state.as_mut().take_value() { let future = match this.state.as_mut().take_value() {
Some(value) => (this.function)(value, item), Some(value) => (this.function)(value, item),
None => panic!("start_send called without poll_ready being called first"), None => panic!("start_send called without poll_ready being called first"),
}; };
this.state.set(UnfoldState::Future { future }); this.state.set(UnfoldState::Future { future });
Ok(()) Ok(())
} }
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
let mut this = self.project(); let mut this = self.project();
Poll::Ready(if let Some(future) = this.state.as_mut().project_future() { Poll::Ready(if let Some(future) = this.state.as_mut().project_future() {
match ready!(future.poll(cx)) { match ready!(future.poll(cx)) {
Ok(state) => { Ok(state) => {
this.state.set(UnfoldState::Value { value: state }); this.state.set(UnfoldState::Value { value: state });
Ok(()) Ok(())
} }
Err(err) => { Err(err) => {
this.state.set(UnfoldState::Empty); this.state.set(UnfoldState::Empty);
Err(err) Err(err)
} }
} }
} else { } else {
Ok(()) Ok(())
}) })
} }
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
ready!(self.as_mut().poll_flush(cx))?; ready!(self.as_mut().poll_flush(cx))?;
let mut this = self.project(); let mut this = self.project();
Poll::Ready( Poll::Ready(
if let Some(future) = this.close_state.as_mut().project_future() { if let Some(future) = this.close_state.as_mut().project_future() {
match ready!(future.poll(cx)) { match ready!(future.poll(cx)) {
Ok(state) => { Ok(state) => {
this.close_state.set(UnfoldState::Value { value: state }); this.close_state.set(UnfoldState::Value { value: state });
Ok(()) Ok(())
} }
Err(err) => { Err(err) => {
this.close_state.set(UnfoldState::Empty); this.close_state.set(UnfoldState::Empty);
Err(err) Err(err)
} }
} }
} else { } else {
let future = match this.close_state.as_mut().take_value() { let future = match this.close_state.as_mut().take_value() {
Some(value) => (this.close_function)(value), Some(value) => (this.close_function)(value),
None => panic!("start_send called without poll_ready being called first"), None => panic!("start_send called without poll_ready being called first"),
}; };
this.close_state.set(UnfoldState::Future { future }); this.close_state.set(UnfoldState::Future { future });
return Poll::Pending; return Poll::Pending;
}, },
) )
} }
} }

View file

@ -1,6 +1,6 @@
use crate::{ use crate::{
sink_unfold, sink_unfold,
ws::{Frame, LockedWebSocketWrite}, ws::{Frame, LockedWebSocketWrite, Payload},
CloseReason, Packet, Role, StreamType, WispError, CloseReason, Packet, Role, StreamType, WispError,
}; };
@ -9,9 +9,10 @@ use event_listener::Event;
use flume as mpsc; use flume as mpsc;
use futures::{ use futures::{
channel::oneshot, channel::oneshot,
ready, select, stream::{self, IntoAsyncRead}, ready, select,
stream::{self, IntoAsyncRead},
task::{noop_waker_ref, Context, Poll}, task::{noop_waker_ref, Context, Poll},
AsyncBufRead, AsyncRead, AsyncWrite, FutureExt, Sink, Stream, TryStreamExt, AsyncBufRead, AsyncRead, AsyncWrite, Future, FutureExt, Sink, Stream, TryStreamExt,
}; };
use pin_project_lite::pin_project; use pin_project_lite::pin_project;
use std::{ use std::{
@ -23,7 +24,7 @@ use std::{
}; };
pub(crate) enum WsEvent { pub(crate) enum WsEvent {
Close(Packet, oneshot::Sender<Result<(), WispError>>), Close(Packet<'static>, oneshot::Sender<Result<(), WispError>>),
CreateStream( CreateStream(
StreamType, StreamType,
String, String,
@ -100,8 +101,10 @@ pub struct MuxStreamWrite {
} }
impl MuxStreamWrite { impl MuxStreamWrite {
/// Write data to the stream. pub(crate) async fn write_payload_internal(
pub async fn write(&self, data: Bytes) -> Result<(), WispError> { &self,
frame: Frame<'static>,
) -> Result<(), WispError> {
if self.role == Role::Client if self.role == Role::Client
&& self.stream_type == StreamType::Tcp && self.stream_type == StreamType::Tcp
&& self.flow_control.load(Ordering::Acquire) == 0 && self.flow_control.load(Ordering::Acquire) == 0
@ -112,9 +115,7 @@ impl MuxStreamWrite {
return Err(WispError::StreamAlreadyClosed); return Err(WispError::StreamAlreadyClosed);
} }
self.tx self.tx.write_frame(frame).await?;
.write_frame(Frame::from(Packet::new_data(self.stream_id, data)))
.await?;
if self.role == Role::Client && self.stream_type == StreamType::Tcp { if self.role == Role::Client && self.stream_type == StreamType::Tcp {
self.flow_control.store( self.flow_control.store(
@ -125,6 +126,20 @@ impl MuxStreamWrite {
Ok(()) Ok(())
} }
/// Write a payload to the stream.
pub fn write_payload<'a>(
&'a self,
data: Payload<'_>,
) -> impl Future<Output = Result<(), WispError>> + 'a {
let frame: Frame<'static> = Frame::from(Packet::new_data(self.stream_id, data));
self.write_payload_internal(frame)
}
/// Write data to the stream.
pub async fn write<D: AsRef<[u8]>>(&self, data: D) -> Result<(), WispError> {
self.write_payload(Payload::Borrowed(data.as_ref())).await
}
/// Get a handle to close the connection. /// Get a handle to close the connection.
/// ///
/// Useful to close the connection without having access to the stream. /// Useful to close the connection without having access to the stream.
@ -173,16 +188,16 @@ impl MuxStreamWrite {
Ok(()) Ok(())
} }
pub(crate) fn into_sink(self) -> Pin<Box<dyn Sink<Bytes, Error = WispError> + Send>> { pub(crate) fn into_sink(self) -> Pin<Box<dyn Sink<Frame<'static>, Error = WispError> + Send>> {
let handle = self.get_close_handle(); let handle = self.get_close_handle();
Box::pin(sink_unfold::unfold( Box::pin(sink_unfold::unfold(
self, self,
|tx, data| async move { |tx, data| async move {
tx.write(data).await?; tx.write_payload_internal(data).await?;
Ok(tx) Ok(tx)
}, },
handle, handle,
move |handle| async { |handle| async move {
handle.close(CloseReason::Unknown).await?; handle.close(CloseReason::Unknown).await?;
Ok(handle) Ok(handle)
}, },
@ -258,8 +273,13 @@ impl MuxStream {
self.rx.read().await self.rx.read().await
} }
/// Write a payload to the stream.
pub async fn write_payload(&self, data: Payload<'_>) -> Result<(), WispError> {
self.tx.write_payload(data).await
}
/// Write data to the stream. /// Write data to the stream.
pub async fn write(&self, data: Bytes) -> Result<(), WispError> { pub async fn write<D: AsRef<[u8]>>(&self, data: D) -> Result<(), WispError> {
self.tx.write(data).await self.tx.write(data).await
} }
@ -301,6 +321,7 @@ impl MuxStream {
}, },
tx: MuxStreamIoSink { tx: MuxStreamIoSink {
tx: self.tx.into_sink(), tx: self.tx.into_sink(),
stream_id: self.stream_id,
}, },
} }
} }
@ -355,7 +376,9 @@ impl MuxProtocolExtensionStream {
encoded.put_u8(packet_type); encoded.put_u8(packet_type);
encoded.put_u32_le(self.stream_id); encoded.put_u32_le(self.stream_id);
encoded.extend(data); encoded.extend(data);
self.tx.write_frame(Frame::binary(encoded)).await self.tx
.write_frame(Frame::binary(Payload::Bytes(encoded)))
.await
} }
} }
@ -391,12 +414,12 @@ impl Stream for MuxStreamIo {
} }
} }
impl Sink<Bytes> for MuxStreamIo { impl Sink<&[u8]> for MuxStreamIo {
type Error = std::io::Error; type Error = std::io::Error;
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.project().tx.poll_ready(cx) self.project().tx.poll_ready(cx)
} }
fn start_send(self: Pin<&mut Self>, item: Bytes) -> Result<(), Self::Error> { fn start_send(self: Pin<&mut Self>, item: &[u8]) -> Result<(), Self::Error> {
self.project().tx.start_send(item) self.project().tx.start_send(item)
} }
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
@ -433,7 +456,8 @@ pin_project! {
/// Write side of a multiplexor stream that implements futures `Sink`. /// Write side of a multiplexor stream that implements futures `Sink`.
pub struct MuxStreamIoSink { pub struct MuxStreamIoSink {
#[pin] #[pin]
tx: Pin<Box<dyn Sink<Bytes, Error = WispError> + Send>>, tx: Pin<Box<dyn Sink<Frame<'static>, Error = WispError> + Send>>,
stream_id: u32,
} }
} }
@ -444,7 +468,7 @@ impl MuxStreamIoSink {
} }
} }
impl Sink<Bytes> for MuxStreamIoSink { impl Sink<&[u8]> for MuxStreamIoSink {
type Error = std::io::Error; type Error = std::io::Error;
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.project() self.project()
@ -452,10 +476,14 @@ impl Sink<Bytes> for MuxStreamIoSink {
.poll_ready(cx) .poll_ready(cx)
.map_err(std::io::Error::other) .map_err(std::io::Error::other)
} }
fn start_send(self: Pin<&mut Self>, item: Bytes) -> Result<(), Self::Error> { fn start_send(self: Pin<&mut Self>, item: &[u8]) -> Result<(), Self::Error> {
let stream_id = self.stream_id;
self.project() self.project()
.tx .tx
.start_send(item) .start_send(Frame::from(Packet::new_data(
stream_id,
Payload::Borrowed(item),
)))
.map_err(std::io::Error::other) .map_err(std::io::Error::other)
} }
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
@ -564,10 +592,10 @@ impl AsyncRead for MuxStreamAsyncRead {
} }
impl AsyncBufRead for MuxStreamAsyncRead { impl AsyncBufRead for MuxStreamAsyncRead {
fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<&[u8]>> { fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<&[u8]>> {
self.project().rx.poll_fill_buf(cx) self.project().rx.poll_fill_buf(cx)
} }
fn consume(self: Pin<&mut Self>, amt: usize) { fn consume(self: Pin<&mut Self>, amt: usize) {
self.project().rx.consume(amt) self.project().rx.consume(amt)
} }
} }
@ -582,7 +610,10 @@ pin_project! {
impl MuxStreamAsyncWrite { impl MuxStreamAsyncWrite {
pub(crate) fn new(sink: MuxStreamIoSink) -> Self { pub(crate) fn new(sink: MuxStreamIoSink) -> Self {
Self { tx: sink, error: None } Self {
tx: sink,
error: None,
}
} }
} }
@ -599,7 +630,7 @@ impl AsyncWrite for MuxStreamAsyncWrite {
let mut this = self.as_mut().project(); let mut this = self.as_mut().project();
ready!(this.tx.as_mut().poll_ready(cx))?; ready!(this.tx.as_mut().poll_ready(cx))?;
match this.tx.as_mut().start_send(Bytes::copy_from_slice(buf)) { match this.tx.as_mut().start_send(buf) {
Ok(()) => { Ok(()) => {
let mut cx = Context::from_waker(noop_waker_ref()); let mut cx = Context::from_waker(noop_waker_ref());
let cx = &mut cx; let cx = &mut cx;

View file

@ -4,83 +4,168 @@
//! for other WebSocket implementations. //! for other WebSocket implementations.
//! //!
//! [`fastwebsockets`]: https://github.com/MercuryWorkshop/epoxy-tls/blob/multiplexed/wisp/src/fastwebsockets.rs //! [`fastwebsockets`]: https://github.com/MercuryWorkshop/epoxy-tls/blob/multiplexed/wisp/src/fastwebsockets.rs
use std::sync::Arc; use std::{ops::Deref, sync::Arc};
use crate::WispError; use crate::WispError;
use async_trait::async_trait; use async_trait::async_trait;
use bytes::BytesMut; use bytes::{Buf, BytesMut};
use futures::lock::Mutex; use futures::lock::Mutex;
/// Payload of the websocket frame.
#[derive(Debug)]
pub enum Payload<'a> {
/// Borrowed payload. Currently used when writing data.
Borrowed(&'a [u8]),
/// BytesMut payload. Currently used when reading data.
Bytes(BytesMut),
}
impl From<BytesMut> for Payload<'static> {
fn from(value: BytesMut) -> Self {
Self::Bytes(value)
}
}
impl<'a> From<&'a [u8]> for Payload<'a> {
fn from(value: &'a [u8]) -> Self {
Self::Borrowed(value)
}
}
impl Payload<'_> {
/// Turn a Payload<'a> into a Payload<'static> by copying the data.
pub fn into_owned(self) -> Self {
match self {
Self::Bytes(x) => Self::Bytes(x),
Self::Borrowed(x) => Self::Bytes(BytesMut::from(x)),
}
}
}
impl From<Payload<'_>> for BytesMut {
fn from(value: Payload<'_>) -> Self {
match value {
Payload::Bytes(x) => x,
Payload::Borrowed(x) => x.into(),
}
}
}
impl Deref for Payload<'_> {
type Target = [u8];
fn deref(&self) -> &Self::Target {
match self {
Self::Bytes(x) => x.deref(),
Self::Borrowed(x) => x,
}
}
}
impl Clone for Payload<'_> {
fn clone(&self) -> Self {
match self {
Self::Bytes(x) => Self::Bytes(x.clone()),
Self::Borrowed(x) => Self::Bytes(BytesMut::from(*x)),
}
}
}
impl Buf for Payload<'_> {
fn remaining(&self) -> usize {
match self {
Self::Bytes(x) => x.remaining(),
Self::Borrowed(x) => x.remaining(),
}
}
fn chunk(&self) -> &[u8] {
match self {
Self::Bytes(x) => x.chunk(),
Self::Borrowed(x) => x.chunk(),
}
}
fn advance(&mut self, cnt: usize) {
match self {
Self::Bytes(x) => x.advance(cnt),
Self::Borrowed(x) => x.advance(cnt),
}
}
}
/// Opcode of the WebSocket frame. /// Opcode of the WebSocket frame.
#[derive(Debug, PartialEq, Clone, Copy)] #[derive(Debug, PartialEq, Clone, Copy)]
pub enum OpCode { pub enum OpCode {
/// Text frame. /// Text frame.
Text, Text,
/// Binary frame. /// Binary frame.
Binary, Binary,
/// Close frame. /// Close frame.
Close, Close,
/// Ping frame. /// Ping frame.
Ping, Ping,
/// Pong frame. /// Pong frame.
Pong, Pong,
} }
/// WebSocket frame. /// WebSocket frame.
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct Frame { pub struct Frame<'a> {
/// Whether the frame is finished or not. /// Whether the frame is finished or not.
pub finished: bool, pub finished: bool,
/// Opcode of the WebSocket frame. /// Opcode of the WebSocket frame.
pub opcode: OpCode, pub opcode: OpCode,
/// Payload of the WebSocket frame. /// Payload of the WebSocket frame.
pub payload: BytesMut, pub payload: Payload<'a>,
} }
impl Frame { impl<'a> Frame<'a> {
/// Create a new text frame. /// Create a new text frame.
pub fn text(payload: BytesMut) -> Self { pub fn text(payload: Payload<'a>) -> Self {
Self { Self {
finished: true, finished: true,
opcode: OpCode::Text, opcode: OpCode::Text,
payload, payload,
} }
} }
/// Create a new binary frame. /// Create a new binary frame.
pub fn binary(payload: BytesMut) -> Self { pub fn binary(payload: Payload<'a>) -> Self {
Self { Self {
finished: true, finished: true,
opcode: OpCode::Binary, opcode: OpCode::Binary,
payload, payload,
} }
} }
/// Create a new close frame. /// Create a new close frame.
pub fn close(payload: BytesMut) -> Self { pub fn close(payload: Payload<'a>) -> Self {
Self { Self {
finished: true, finished: true,
opcode: OpCode::Close, opcode: OpCode::Close,
payload, payload,
} }
} }
} }
/// Generic WebSocket read trait. /// Generic WebSocket read trait.
#[async_trait] #[async_trait]
pub trait WebSocketRead { pub trait WebSocketRead {
/// Read a frame from the socket. /// Read a frame from the socket.
async fn wisp_read_frame(&mut self, tx: &LockedWebSocketWrite) -> Result<Frame, WispError>; async fn wisp_read_frame(
&mut self,
tx: &LockedWebSocketWrite,
) -> Result<Frame<'static>, WispError>;
} }
/// Generic WebSocket write trait. /// Generic WebSocket write trait.
#[async_trait] #[async_trait]
pub trait WebSocketWrite { pub trait WebSocketWrite {
/// Write a frame to the socket. /// Write a frame to the socket.
async fn wisp_write_frame(&mut self, frame: Frame) -> Result<(), WispError>; async fn wisp_write_frame(&mut self, frame: Frame<'_>) -> Result<(), WispError>;
/// Close the socket. /// Close the socket.
async fn wisp_close(&mut self) -> Result<(), WispError>; async fn wisp_close(&mut self) -> Result<(), WispError>;
} }
/// Locked WebSocket. /// Locked WebSocket.
@ -88,35 +173,38 @@ pub trait WebSocketWrite {
pub struct LockedWebSocketWrite(Arc<Mutex<Box<dyn WebSocketWrite + Send>>>); pub struct LockedWebSocketWrite(Arc<Mutex<Box<dyn WebSocketWrite + Send>>>);
impl LockedWebSocketWrite { impl LockedWebSocketWrite {
/// Create a new locked websocket. /// Create a new locked websocket.
pub fn new(ws: Box<dyn WebSocketWrite + Send>) -> Self { pub fn new(ws: Box<dyn WebSocketWrite + Send>) -> Self {
Self(Mutex::new(ws).into()) Self(Mutex::new(ws).into())
} }
/// Write a frame to the websocket. /// Write a frame to the websocket.
pub async fn write_frame(&self, frame: Frame) -> Result<(), WispError> { pub async fn write_frame(&self, frame: Frame<'_>) -> Result<(), WispError> {
self.0.lock().await.wisp_write_frame(frame).await self.0.lock().await.wisp_write_frame(frame).await
} }
/// Close the websocket. /// Close the websocket.
pub async fn close(&self) -> Result<(), WispError> { pub async fn close(&self) -> Result<(), WispError> {
self.0.lock().await.wisp_close().await self.0.lock().await.wisp_close().await
} }
} }
pub(crate) struct AppendingWebSocketRead<R>(pub Option<Frame>, pub R) pub(crate) struct AppendingWebSocketRead<R>(pub Option<Frame<'static>>, pub R)
where where
R: WebSocketRead + Send; R: WebSocketRead + Send;
#[async_trait] #[async_trait]
impl<R> WebSocketRead for AppendingWebSocketRead<R> impl<R> WebSocketRead for AppendingWebSocketRead<R>
where where
R: WebSocketRead + Send, R: WebSocketRead + Send,
{ {
async fn wisp_read_frame(&mut self, tx: &LockedWebSocketWrite) -> Result<Frame, WispError> { async fn wisp_read_frame(
if let Some(x) = self.0.take() { &mut self,
return Ok(x); tx: &LockedWebSocketWrite,
} ) -> Result<Frame<'static>, WispError> {
return self.1.wisp_read_frame(tx).await; if let Some(x) = self.0.take() {
} return Ok(x);
}
return self.1.wisp_read_frame(tx).await;
}
} }