mirror of
https://github.com/MercuryWorkshop/epoxy-tls.git
synced 2025-05-12 14:00:01 -04:00
add a new Payload struct to allow for one-copy writes and cargo fmt
This commit is contained in:
parent
314c1bfa75
commit
d6353bd5a9
18 changed files with 3533 additions and 3395 deletions
|
@ -1,10 +1,9 @@
|
||||||
use bytes::{buf::UninitSlice, BufMut, Bytes, BytesMut};
|
use bytes::{buf::UninitSlice, BufMut, BytesMut};
|
||||||
use futures_util::{
|
use futures_util::{io::WriteHalf, lock::Mutex, AsyncReadExt, AsyncWriteExt, SinkExt, StreamExt};
|
||||||
io::WriteHalf, lock::Mutex, stream::SplitSink, AsyncReadExt, AsyncWriteExt, SinkExt, StreamExt,
|
|
||||||
};
|
|
||||||
use js_sys::{Function, Uint8Array};
|
use js_sys::{Function, Uint8Array};
|
||||||
use wasm_bindgen::prelude::*;
|
use wasm_bindgen::prelude::*;
|
||||||
use wasm_bindgen_futures::spawn_local;
|
use wasm_bindgen_futures::spawn_local;
|
||||||
|
use wisp_mux::MuxStreamIoSink;
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
stream_provider::{ProviderAsyncRW, ProviderUnencryptedStream},
|
stream_provider::{ProviderAsyncRW, ProviderUnencryptedStream},
|
||||||
|
@ -105,15 +104,14 @@ impl EpoxyIoStream {
|
||||||
|
|
||||||
#[wasm_bindgen]
|
#[wasm_bindgen]
|
||||||
pub struct EpoxyUdpStream {
|
pub struct EpoxyUdpStream {
|
||||||
tx: Mutex<SplitSink<ProviderUnencryptedStream, Bytes>>,
|
tx: Mutex<MuxStreamIoSink>,
|
||||||
onerror: Function,
|
onerror: Function,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[wasm_bindgen]
|
#[wasm_bindgen]
|
||||||
impl EpoxyUdpStream {
|
impl EpoxyUdpStream {
|
||||||
pub(crate) fn connect(stream: ProviderUnencryptedStream, handlers: EpoxyHandlers) -> Self {
|
pub(crate) fn connect(stream: ProviderUnencryptedStream, handlers: EpoxyHandlers) -> Self {
|
||||||
let (tx, mut rx) = stream.split();
|
let (mut rx, tx) = stream.into_split();
|
||||||
let tx = Mutex::new(tx);
|
|
||||||
|
|
||||||
let EpoxyHandlers {
|
let EpoxyHandlers {
|
||||||
onopen,
|
onopen,
|
||||||
|
@ -142,7 +140,7 @@ impl EpoxyUdpStream {
|
||||||
let _ = onopen.call0(&JsValue::null());
|
let _ = onopen.call0(&JsValue::null());
|
||||||
|
|
||||||
Self {
|
Self {
|
||||||
tx,
|
tx: tx.into(),
|
||||||
onerror: onerror_cloned,
|
onerror: onerror_cloned,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -154,7 +152,7 @@ impl EpoxyUdpStream {
|
||||||
.map_err(|_| EpoxyError::InvalidPayload)?
|
.map_err(|_| EpoxyError::InvalidPayload)?
|
||||||
.0
|
.0
|
||||||
.to_vec();
|
.to_vec();
|
||||||
Ok(self.tx.lock().await.send(payload.into()).await?)
|
Ok(self.tx.lock().await.send(payload.as_ref()).await?)
|
||||||
}
|
}
|
||||||
.await;
|
.await;
|
||||||
|
|
||||||
|
|
|
@ -5,9 +5,9 @@ use std::{str::FromStr, sync::Arc};
|
||||||
use async_compression::futures::bufread as async_comp;
|
use async_compression::futures::bufread as async_comp;
|
||||||
use bytes::Bytes;
|
use bytes::Bytes;
|
||||||
use cfg_if::cfg_if;
|
use cfg_if::cfg_if;
|
||||||
use futures_util::TryStreamExt;
|
|
||||||
#[cfg(feature = "full")]
|
#[cfg(feature = "full")]
|
||||||
use futures_util::future::Either;
|
use futures_util::future::Either;
|
||||||
|
use futures_util::TryStreamExt;
|
||||||
use http::{
|
use http::{
|
||||||
header::{InvalidHeaderName, InvalidHeaderValue},
|
header::{InvalidHeaderName, InvalidHeaderValue},
|
||||||
method::InvalidMethod,
|
method::InvalidMethod,
|
||||||
|
@ -22,7 +22,8 @@ use js_sys::{Array, Function, Object, Reflect};
|
||||||
use stream_provider::{StreamProvider, StreamProviderService};
|
use stream_provider::{StreamProvider, StreamProviderService};
|
||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
use utils::{
|
use utils::{
|
||||||
asyncread_to_readablestream_stream, convert_body, entries_of_object, is_null_body, is_redirect, object_get, object_set, IncomingBody, UriExt, WasmExecutor
|
asyncread_to_readablestream_stream, convert_body, entries_of_object, is_null_body, is_redirect,
|
||||||
|
object_get, object_set, IncomingBody, UriExt, WasmExecutor,
|
||||||
};
|
};
|
||||||
use wasm_bindgen::prelude::*;
|
use wasm_bindgen::prelude::*;
|
||||||
use wasm_streams::ReadableStream;
|
use wasm_streams::ReadableStream;
|
||||||
|
|
|
@ -5,7 +5,9 @@ use futures_rustls::{
|
||||||
TlsConnector, TlsStream,
|
TlsConnector, TlsStream,
|
||||||
};
|
};
|
||||||
use futures_util::{
|
use futures_util::{
|
||||||
future::Either, lock::{Mutex, MutexGuard}, AsyncRead, AsyncWrite, Future
|
future::Either,
|
||||||
|
lock::{Mutex, MutexGuard},
|
||||||
|
AsyncRead, AsyncWrite, Future,
|
||||||
};
|
};
|
||||||
use hyper_util_wasm::client::legacy::connect::{ConnectSvc, Connected, Connection};
|
use hyper_util_wasm::client::legacy::connect::{ConnectSvc, Connected, Connection};
|
||||||
use js_sys::{Array, Reflect, Uint8Array};
|
use js_sys::{Array, Reflect, Uint8Array};
|
||||||
|
@ -81,7 +83,7 @@ impl StreamProvider {
|
||||||
mut locked: MutexGuard<'_, Option<ClientMux>>,
|
mut locked: MutexGuard<'_, Option<ClientMux>>,
|
||||||
) -> Result<(), EpoxyError> {
|
) -> Result<(), EpoxyError> {
|
||||||
let extensions_vec: Vec<Box<dyn ProtocolExtensionBuilder + Send + Sync>> =
|
let extensions_vec: Vec<Box<dyn ProtocolExtensionBuilder + Send + Sync>> =
|
||||||
vec![Box::new(UdpProtocolExtensionBuilder())];
|
vec![Box::new(UdpProtocolExtensionBuilder)];
|
||||||
let extensions = if self.wisp_v2 {
|
let extensions = if self.wisp_v2 {
|
||||||
Some(extensions_vec.as_slice())
|
Some(extensions_vec.as_slice())
|
||||||
} else {
|
} else {
|
||||||
|
|
|
@ -9,7 +9,8 @@ use futures_util::lock::Mutex;
|
||||||
use getrandom::getrandom;
|
use getrandom::getrandom;
|
||||||
use http::{
|
use http::{
|
||||||
header::{
|
header::{
|
||||||
CONNECTION, HOST, SEC_WEBSOCKET_KEY, SEC_WEBSOCKET_PROTOCOL, SEC_WEBSOCKET_VERSION, UPGRADE, USER_AGENT,
|
CONNECTION, HOST, SEC_WEBSOCKET_KEY, SEC_WEBSOCKET_PROTOCOL, SEC_WEBSOCKET_VERSION,
|
||||||
|
UPGRADE, USER_AGENT,
|
||||||
},
|
},
|
||||||
Method, Request, Response, StatusCode, Uri,
|
Method, Request, Response, StatusCode, Uri,
|
||||||
};
|
};
|
||||||
|
@ -22,7 +23,9 @@ use tokio::io::WriteHalf;
|
||||||
use wasm_bindgen::{prelude::*, JsError, JsValue};
|
use wasm_bindgen::{prelude::*, JsError, JsValue};
|
||||||
use wasm_bindgen_futures::spawn_local;
|
use wasm_bindgen_futures::spawn_local;
|
||||||
|
|
||||||
use crate::{tokioio::TokioIo, utils::entries_of_object, EpoxyClient, EpoxyError, EpoxyHandlers, HttpBody};
|
use crate::{
|
||||||
|
tokioio::TokioIo, utils::entries_of_object, EpoxyClient, EpoxyError, EpoxyHandlers, HttpBody,
|
||||||
|
};
|
||||||
|
|
||||||
#[wasm_bindgen]
|
#[wasm_bindgen]
|
||||||
pub struct EpoxyWebSocket {
|
pub struct EpoxyWebSocket {
|
||||||
|
@ -69,7 +72,9 @@ impl EpoxyWebSocket {
|
||||||
request = request.header(SEC_WEBSOCKET_PROTOCOL, protocols.join(","));
|
request = request.header(SEC_WEBSOCKET_PROTOCOL, protocols.join(","));
|
||||||
}
|
}
|
||||||
|
|
||||||
if web_sys::Headers::instanceof(&headers) && let Ok(entries) = Object::from_entries(&headers) {
|
if web_sys::Headers::instanceof(&headers)
|
||||||
|
&& let Ok(entries) = Object::from_entries(&headers)
|
||||||
|
{
|
||||||
for header in entries_of_object(&entries) {
|
for header in entries_of_object(&entries) {
|
||||||
request = request.header(&header[0], &header[1]);
|
request = request.header(&header[0], &header[1]);
|
||||||
}
|
}
|
||||||
|
|
|
@ -13,7 +13,7 @@ use send_wrapper::SendWrapper;
|
||||||
use wasm_bindgen::{closure::Closure, JsCast};
|
use wasm_bindgen::{closure::Closure, JsCast};
|
||||||
use web_sys::{BinaryType, MessageEvent, WebSocket};
|
use web_sys::{BinaryType, MessageEvent, WebSocket};
|
||||||
use wisp_mux::{
|
use wisp_mux::{
|
||||||
ws::{Frame, LockedWebSocketWrite, WebSocketRead, WebSocketWrite},
|
ws::{Frame, LockedWebSocketWrite, Payload, WebSocketRead, WebSocketWrite},
|
||||||
WispError,
|
WispError,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -77,7 +77,10 @@ pub struct WebSocketReader {
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
impl WebSocketRead for WebSocketReader {
|
impl WebSocketRead for WebSocketReader {
|
||||||
async fn wisp_read_frame(&mut self, _: &LockedWebSocketWrite) -> Result<Frame, WispError> {
|
async fn wisp_read_frame(
|
||||||
|
&mut self,
|
||||||
|
_: &LockedWebSocketWrite,
|
||||||
|
) -> Result<Frame<'static>, WispError> {
|
||||||
use WebSocketMessage::*;
|
use WebSocketMessage::*;
|
||||||
if self.closed.load(Ordering::Acquire) {
|
if self.closed.load(Ordering::Acquire) {
|
||||||
return Err(WispError::WsImplSocketClosed);
|
return Err(WispError::WsImplSocketClosed);
|
||||||
|
@ -87,7 +90,9 @@ impl WebSocketRead for WebSocketReader {
|
||||||
_ = self.close_event.listen().fuse() => Some(Closed),
|
_ = self.close_event.listen().fuse() => Some(Closed),
|
||||||
};
|
};
|
||||||
match res.ok_or(WispError::WsImplSocketClosed)? {
|
match res.ok_or(WispError::WsImplSocketClosed)? {
|
||||||
Message(bin) => Ok(Frame::binary(BytesMut::from(bin.as_slice()))),
|
Message(bin) => Ok(Frame::binary(Payload::Bytes(BytesMut::from(
|
||||||
|
bin.as_slice(),
|
||||||
|
)))),
|
||||||
Error => Err(WebSocketError::Unknown.into()),
|
Error => Err(WebSocketError::Unknown.into()),
|
||||||
Closed => Err(WispError::WsImplSocketClosed),
|
Closed => Err(WispError::WsImplSocketClosed),
|
||||||
}
|
}
|
||||||
|
@ -188,7 +193,7 @@ impl WebSocketWrapper {
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
impl WebSocketWrite for WebSocketWrapper {
|
impl WebSocketWrite for WebSocketWrapper {
|
||||||
async fn wisp_write_frame(&mut self, frame: Frame) -> Result<(), WispError> {
|
async fn wisp_write_frame(&mut self, frame: Frame<'_>) -> Result<(), WispError> {
|
||||||
use wisp_mux::ws::OpCode::*;
|
use wisp_mux::ws::OpCode::*;
|
||||||
if self.closed.load(Ordering::Acquire) {
|
if self.closed.load(Ordering::Acquire) {
|
||||||
return Err(WispError::WsImplSocketClosed);
|
return Err(WispError::WsImplSocketClosed);
|
||||||
|
|
|
@ -202,7 +202,7 @@ async fn main() -> Result<(), Error> {
|
||||||
block_non_http: opt.block_non_http,
|
block_non_http: opt.block_non_http,
|
||||||
block_udp: opt.block_udp,
|
block_udp: opt.block_udp,
|
||||||
auth: Arc::new(vec![
|
auth: Arc::new(vec![
|
||||||
Box::new(UdpProtocolExtensionBuilder()),
|
Box::new(UdpProtocolExtensionBuilder),
|
||||||
Box::new(pw_ext),
|
Box::new(pw_ext),
|
||||||
]),
|
]),
|
||||||
enforce_auth,
|
enforce_auth,
|
||||||
|
@ -361,8 +361,7 @@ async fn accept_ws(
|
||||||
let (rx, tx) = ws.split(|x| {
|
let (rx, tx) = ws.split(|x| {
|
||||||
let Parts {
|
let Parts {
|
||||||
io, read_buf: buf, ..
|
io, read_buf: buf, ..
|
||||||
} = x
|
} = x.into_inner()
|
||||||
.into_inner()
|
|
||||||
.downcast::<TokioIo<ListenerStream>>()
|
.downcast::<TokioIo<ListenerStream>>()
|
||||||
.unwrap();
|
.unwrap();
|
||||||
assert_eq!(buf.len(), 0);
|
assert_eq!(buf.len(), 0);
|
||||||
|
@ -398,12 +397,7 @@ async fn accept_ws(
|
||||||
.with_required_extensions(&[PasswordProtocolExtension::ID])
|
.with_required_extensions(&[PasswordProtocolExtension::ID])
|
||||||
.await?
|
.await?
|
||||||
} else {
|
} else {
|
||||||
ServerMux::create(
|
ServerMux::create(rx, tx, 512, Some(&[Box::new(UdpProtocolExtensionBuilder)]))
|
||||||
rx,
|
|
||||||
tx,
|
|
||||||
512,
|
|
||||||
Some(&[Box::new(UdpProtocolExtensionBuilder())]),
|
|
||||||
)
|
|
||||||
.await?
|
.await?
|
||||||
.with_no_required_extensions()
|
.with_no_required_extensions()
|
||||||
};
|
};
|
||||||
|
|
|
@ -158,7 +158,7 @@ async fn main() -> Result<(), Box<dyn Error + Send + Sync>> {
|
||||||
let mut extensions: Vec<Box<(dyn ProtocolExtensionBuilder + Send + Sync)>> = Vec::new();
|
let mut extensions: Vec<Box<(dyn ProtocolExtensionBuilder + Send + Sync)>> = Vec::new();
|
||||||
let mut extension_ids: Vec<u8> = Vec::new();
|
let mut extension_ids: Vec<u8> = Vec::new();
|
||||||
if opts.udp {
|
if opts.udp {
|
||||||
extensions.push(Box::new(UdpProtocolExtensionBuilder()));
|
extensions.push(Box::new(UdpProtocolExtensionBuilder));
|
||||||
extension_ids.push(UdpProtocolExtension::ID);
|
extension_ids.push(UdpProtocolExtension::ID);
|
||||||
}
|
}
|
||||||
if let Some(auth) = auth {
|
if let Some(auth) = auth {
|
||||||
|
@ -173,7 +173,8 @@ async fn main() -> Result<(), Box<dyn Error + Send + Sync>> {
|
||||||
} else {
|
} else {
|
||||||
ClientMux::create(rx, tx, Some(extensions.as_slice()))
|
ClientMux::create(rx, tx, Some(extensions.as_slice()))
|
||||||
.await?
|
.await?
|
||||||
.with_required_extensions(extension_ids.as_slice()).await?
|
.with_required_extensions(extension_ids.as_slice())
|
||||||
|
.await?
|
||||||
};
|
};
|
||||||
|
|
||||||
println!(
|
println!(
|
||||||
|
|
|
@ -56,7 +56,7 @@ impl From<AnyProtocolExtension> for Bytes {
|
||||||
/// A Wisp protocol extension.
|
/// A Wisp protocol extension.
|
||||||
///
|
///
|
||||||
/// See [the
|
/// See [the
|
||||||
/// docs](https://github.com/MercuryWorkshop/wisp-protocol/blob/main/protocol.md#protocol-extensions).
|
/// docs](https://github.com/MercuryWorkshop/wisp-protocol/blob/v2/protocol.md#protocol-extensions).
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
pub trait ProtocolExtension: std::fmt::Debug {
|
pub trait ProtocolExtension: std::fmt::Debug {
|
||||||
/// Get the protocol extension ID.
|
/// Get the protocol extension ID.
|
||||||
|
|
|
@ -29,7 +29,7 @@
|
||||||
//! ])
|
//! ])
|
||||||
//! );
|
//! );
|
||||||
//! ```
|
//! ```
|
||||||
//! See [the docs](https://github.com/MercuryWorkshop/wisp-protocol/blob/main/protocol.md#0x02---password-authentication)
|
//! See [the docs](https://github.com/MercuryWorkshop/wisp-protocol/blob/v2/protocol.md#0x02---password-authentication)
|
||||||
|
|
||||||
use std::{collections::HashMap, error::Error, fmt::Display, string::FromUtf8Error};
|
use std::{collections::HashMap, error::Error, fmt::Display, string::FromUtf8Error};
|
||||||
|
|
||||||
|
|
|
@ -6,10 +6,10 @@
|
||||||
//! rx,
|
//! rx,
|
||||||
//! tx,
|
//! tx,
|
||||||
//! 128,
|
//! 128,
|
||||||
//! Some(&[Box::new(UdpProtocolExtensionBuilder())])
|
//! Some(&[Box::new(UdpProtocolExtensionBuilder)])
|
||||||
//! );
|
//! );
|
||||||
//! ```
|
//! ```
|
||||||
//! See [the docs](https://github.com/MercuryWorkshop/wisp-protocol/blob/main/protocol.md#0x01---udp)
|
//! See [the docs](https://github.com/MercuryWorkshop/wisp-protocol/blob/v2/protocol.md#0x01---udp)
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use bytes::Bytes;
|
use bytes::Bytes;
|
||||||
|
|
||||||
|
@ -22,7 +22,7 @@ use super::{AnyProtocolExtension, ProtocolExtension, ProtocolExtensionBuilder};
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
/// UDP protocol extension.
|
/// UDP protocol extension.
|
||||||
pub struct UdpProtocolExtension();
|
pub struct UdpProtocolExtension;
|
||||||
|
|
||||||
impl UdpProtocolExtension {
|
impl UdpProtocolExtension {
|
||||||
/// UDP protocol extension ID.
|
/// UDP protocol extension ID.
|
||||||
|
@ -61,7 +61,7 @@ impl ProtocolExtension for UdpProtocolExtension {
|
||||||
}
|
}
|
||||||
|
|
||||||
fn box_clone(&self) -> Box<dyn ProtocolExtension + Sync + Send> {
|
fn box_clone(&self) -> Box<dyn ProtocolExtension + Sync + Send> {
|
||||||
Box::new(Self())
|
Box::new(Self)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -72,7 +72,7 @@ impl From<UdpProtocolExtension> for AnyProtocolExtension {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// UDP protocol extension builder.
|
/// UDP protocol extension builder.
|
||||||
pub struct UdpProtocolExtensionBuilder();
|
pub struct UdpProtocolExtensionBuilder;
|
||||||
|
|
||||||
impl ProtocolExtensionBuilder for UdpProtocolExtensionBuilder {
|
impl ProtocolExtensionBuilder for UdpProtocolExtensionBuilder {
|
||||||
fn get_id(&self) -> u8 {
|
fn get_id(&self) -> u8 {
|
||||||
|
@ -84,10 +84,10 @@ impl ProtocolExtensionBuilder for UdpProtocolExtensionBuilder {
|
||||||
_: Bytes,
|
_: Bytes,
|
||||||
_: crate::Role,
|
_: crate::Role,
|
||||||
) -> Result<AnyProtocolExtension, WispError> {
|
) -> Result<AnyProtocolExtension, WispError> {
|
||||||
Ok(UdpProtocolExtension().into())
|
Ok(UdpProtocolExtension.into())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn build_to_extension(&self, _: crate::Role) -> AnyProtocolExtension {
|
fn build_to_extension(&self, _: crate::Role) -> AnyProtocolExtension {
|
||||||
UdpProtocolExtension().into()
|
UdpProtocolExtension.into()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -9,12 +9,19 @@ use tokio::io::{AsyncRead, AsyncWrite};
|
||||||
|
|
||||||
use crate::{ws::LockedWebSocketWrite, WispError};
|
use crate::{ws::LockedWebSocketWrite, WispError};
|
||||||
|
|
||||||
fn match_payload(payload: Payload) -> BytesMut {
|
fn match_payload<'a>(payload: Payload<'a>) -> crate::ws::Payload<'a> {
|
||||||
match payload {
|
match payload {
|
||||||
Payload::Bytes(x) => x,
|
Payload::Bytes(x) => crate::ws::Payload::Bytes(x),
|
||||||
Payload::Owned(x) => BytesMut::from(x.deref()),
|
Payload::Owned(x) => crate::ws::Payload::Bytes(BytesMut::from(x.deref())),
|
||||||
Payload::BorrowedMut(x) => BytesMut::from(x.deref()),
|
Payload::BorrowedMut(x) => crate::ws::Payload::Borrowed(&*x),
|
||||||
Payload::Borrowed(x) => BytesMut::from(x),
|
Payload::Borrowed(x) => crate::ws::Payload::Borrowed(x),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn match_payload_reverse<'a>(payload: crate::ws::Payload<'a>) -> Payload<'a> {
|
||||||
|
match payload {
|
||||||
|
crate::ws::Payload::Bytes(x) => Payload::Bytes(x),
|
||||||
|
crate::ws::Payload::Borrowed(x) => Payload::Borrowed(x),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -34,8 +41,8 @@ impl From<OpCode> for crate::ws::OpCode {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl From<Frame<'_>> for crate::ws::Frame {
|
impl<'a> From<Frame<'a>> for crate::ws::Frame<'a> {
|
||||||
fn from(frame: Frame) -> Self {
|
fn from(frame: Frame<'a>) -> Self {
|
||||||
Self {
|
Self {
|
||||||
finished: frame.fin,
|
finished: frame.fin,
|
||||||
opcode: frame.opcode.into(),
|
opcode: frame.opcode.into(),
|
||||||
|
@ -44,10 +51,10 @@ impl From<Frame<'_>> for crate::ws::Frame {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a> From<crate::ws::Frame> for Frame<'a> {
|
impl<'a> From<crate::ws::Frame<'a>> for Frame<'a> {
|
||||||
fn from(frame: crate::ws::Frame) -> Self {
|
fn from(frame: crate::ws::Frame<'a>) -> Self {
|
||||||
use crate::ws::OpCode::*;
|
use crate::ws::OpCode::*;
|
||||||
let payload = Payload::Bytes(frame.payload);
|
let payload = match_payload_reverse(frame.payload);
|
||||||
match frame.opcode {
|
match frame.opcode {
|
||||||
Text => Self::text(payload),
|
Text => Self::text(payload),
|
||||||
Binary => Self::binary(payload),
|
Binary => Self::binary(payload),
|
||||||
|
@ -73,7 +80,7 @@ impl<S: AsyncRead + Unpin + Send> crate::ws::WebSocketRead for FragmentCollector
|
||||||
async fn wisp_read_frame(
|
async fn wisp_read_frame(
|
||||||
&mut self,
|
&mut self,
|
||||||
tx: &LockedWebSocketWrite,
|
tx: &LockedWebSocketWrite,
|
||||||
) -> Result<crate::ws::Frame, WispError> {
|
) -> Result<crate::ws::Frame<'static>, WispError> {
|
||||||
Ok(self
|
Ok(self
|
||||||
.read_frame(&mut |frame| async { tx.write_frame(frame.into()).await })
|
.read_frame(&mut |frame| async { tx.write_frame(frame.into()).await })
|
||||||
.await?
|
.await?
|
||||||
|
@ -83,7 +90,7 @@ impl<S: AsyncRead + Unpin + Send> crate::ws::WebSocketRead for FragmentCollector
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
impl<S: AsyncWrite + Unpin + Send> crate::ws::WebSocketWrite for WebSocketWrite<S> {
|
impl<S: AsyncWrite + Unpin + Send> crate::ws::WebSocketWrite for WebSocketWrite<S> {
|
||||||
async fn wisp_write_frame(&mut self, frame: crate::ws::Frame) -> Result<(), WispError> {
|
async fn wisp_write_frame(&mut self, frame: crate::ws::Frame<'_>) -> Result<(), WispError> {
|
||||||
self.write_frame(frame.into()).await.map_err(|e| e.into())
|
self.write_frame(frame.into()).await.map_err(|e| e.into())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -15,7 +15,7 @@ pub mod ws;
|
||||||
|
|
||||||
pub use crate::{packet::*, stream::*};
|
pub use crate::{packet::*, stream::*};
|
||||||
|
|
||||||
use bytes::Bytes;
|
use bytes::{Bytes, BytesMut};
|
||||||
use dashmap::DashMap;
|
use dashmap::DashMap;
|
||||||
use event_listener::Event;
|
use event_listener::Event;
|
||||||
use extensions::{udp::UdpProtocolExtension, AnyProtocolExtension, ProtocolExtensionBuilder};
|
use extensions::{udp::UdpProtocolExtension, AnyProtocolExtension, ProtocolExtensionBuilder};
|
||||||
|
@ -167,7 +167,7 @@ struct MuxInner {
|
||||||
tx: ws::LockedWebSocketWrite,
|
tx: ws::LockedWebSocketWrite,
|
||||||
stream_map: DashMap<u32, MuxMapValue>,
|
stream_map: DashMap<u32, MuxMapValue>,
|
||||||
buffer_size: u32,
|
buffer_size: u32,
|
||||||
fut_exited: Arc<AtomicBool>
|
fut_exited: Arc<AtomicBool>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl MuxInner {
|
impl MuxInner {
|
||||||
|
@ -381,7 +381,7 @@ impl MuxInner {
|
||||||
}
|
}
|
||||||
Data(data) => {
|
Data(data) => {
|
||||||
if let Some(stream) = self.stream_map.get(&packet.stream_id) {
|
if let Some(stream) = self.stream_map.get(&packet.stream_id) {
|
||||||
let _ = stream.stream.try_send(data);
|
let _ = stream.stream.try_send(BytesMut::from(data).freeze());
|
||||||
if stream.stream_type == StreamType::Tcp {
|
if stream.stream_type == StreamType::Tcp {
|
||||||
stream.flow_control.store(
|
stream.flow_control.store(
|
||||||
stream
|
stream
|
||||||
|
@ -417,6 +417,7 @@ impl MuxInner {
|
||||||
if frame.opcode == ws::OpCode::Close {
|
if frame.opcode == ws::OpCode::Close {
|
||||||
break Ok(());
|
break Ok(());
|
||||||
}
|
}
|
||||||
|
|
||||||
if let Some(packet) =
|
if let Some(packet) =
|
||||||
Packet::maybe_handle_extension(frame, &mut extensions, &mut rx, &self.tx).await?
|
Packet::maybe_handle_extension(frame, &mut extensions, &mut rx, &self.tx).await?
|
||||||
{
|
{
|
||||||
|
@ -425,7 +426,10 @@ impl MuxInner {
|
||||||
Connect(_) | Info(_) => break Err(WispError::InvalidPacketType),
|
Connect(_) | Info(_) => break Err(WispError::InvalidPacketType),
|
||||||
Data(data) => {
|
Data(data) => {
|
||||||
if let Some(stream) = self.stream_map.get(&packet.stream_id) {
|
if let Some(stream) = self.stream_map.get(&packet.stream_id) {
|
||||||
let _ = stream.stream.send_async(data).await;
|
let _ = stream
|
||||||
|
.stream
|
||||||
|
.send_async(BytesMut::from(data).freeze())
|
||||||
|
.await;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Continue(inner_packet) => {
|
Continue(inner_packet) => {
|
||||||
|
@ -454,12 +458,12 @@ async fn maybe_wisp_v2<R>(
|
||||||
read: &mut R,
|
read: &mut R,
|
||||||
write: &LockedWebSocketWrite,
|
write: &LockedWebSocketWrite,
|
||||||
builders: &[Box<dyn ProtocolExtensionBuilder + Sync + Send>],
|
builders: &[Box<dyn ProtocolExtensionBuilder + Sync + Send>],
|
||||||
) -> Result<(Vec<AnyProtocolExtension>, Option<ws::Frame>, bool), WispError>
|
) -> Result<(Vec<AnyProtocolExtension>, Option<ws::Frame<'static>>, bool), WispError>
|
||||||
where
|
where
|
||||||
R: ws::WebSocketRead + Send,
|
R: ws::WebSocketRead + Send,
|
||||||
{
|
{
|
||||||
let mut supported_extensions = Vec::new();
|
let mut supported_extensions = Vec::new();
|
||||||
let mut extra_packet = None;
|
let mut extra_packet: Option<ws::Frame<'static>> = None;
|
||||||
let mut downgraded = true;
|
let mut downgraded = true;
|
||||||
|
|
||||||
let extension_ids: Vec<_> = builders.iter().map(|x| x.get_id()).collect();
|
let extension_ids: Vec<_> = builders.iter().map(|x| x.get_id()).collect();
|
||||||
|
@ -476,7 +480,7 @@ where
|
||||||
.collect();
|
.collect();
|
||||||
downgraded = false;
|
downgraded = false;
|
||||||
} else {
|
} else {
|
||||||
extra_packet.replace(packet.into());
|
extra_packet.replace(ws::Frame::from(packet).clone());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -574,7 +578,7 @@ impl ServerMux {
|
||||||
tx: write,
|
tx: write,
|
||||||
stream_map: DashMap::new(),
|
stream_map: DashMap::new(),
|
||||||
buffer_size,
|
buffer_size,
|
||||||
fut_exited
|
fut_exited,
|
||||||
}
|
}
|
||||||
.server_into_future(
|
.server_into_future(
|
||||||
AppendingWebSocketRead(extra_packet, read),
|
AppendingWebSocketRead(extra_packet, read),
|
||||||
|
@ -761,7 +765,7 @@ impl ClientMux {
|
||||||
tx: write,
|
tx: write,
|
||||||
stream_map: DashMap::new(),
|
stream_map: DashMap::new(),
|
||||||
buffer_size: packet.buffer_remaining,
|
buffer_size: packet.buffer_remaining,
|
||||||
fut_exited
|
fut_exited,
|
||||||
}
|
}
|
||||||
.client_into_future(
|
.client_into_future(
|
||||||
AppendingWebSocketRead(extra_packet, read),
|
AppendingWebSocketRead(extra_packet, read),
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
use crate::{
|
use crate::{
|
||||||
extensions::{AnyProtocolExtension, ProtocolExtensionBuilder},
|
extensions::{AnyProtocolExtension, ProtocolExtensionBuilder},
|
||||||
ws::{self, Frame, LockedWebSocketWrite, OpCode, WebSocketRead},
|
ws::{self, Frame, LockedWebSocketWrite, OpCode, Payload, WebSocketRead},
|
||||||
Role, WispError, WISP_VERSION,
|
Role, WispError, WISP_VERSION,
|
||||||
};
|
};
|
||||||
use bytes::{Buf, BufMut, Bytes, BytesMut};
|
use bytes::{Buf, BufMut, Bytes, BytesMut};
|
||||||
|
@ -124,9 +124,9 @@ impl ConnectPacket {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl TryFrom<BytesMut> for ConnectPacket {
|
impl TryFrom<Payload<'_>> for ConnectPacket {
|
||||||
type Error = WispError;
|
type Error = WispError;
|
||||||
fn try_from(mut bytes: BytesMut) -> Result<Self, Self::Error> {
|
fn try_from(mut bytes: Payload<'_>) -> Result<Self, Self::Error> {
|
||||||
if bytes.remaining() < (1 + 2) {
|
if bytes.remaining() < (1 + 2) {
|
||||||
return Err(Self::Error::PacketTooSmall);
|
return Err(Self::Error::PacketTooSmall);
|
||||||
}
|
}
|
||||||
|
@ -162,9 +162,9 @@ impl ContinuePacket {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl TryFrom<BytesMut> for ContinuePacket {
|
impl TryFrom<Payload<'_>> for ContinuePacket {
|
||||||
type Error = WispError;
|
type Error = WispError;
|
||||||
fn try_from(mut bytes: BytesMut) -> Result<Self, Self::Error> {
|
fn try_from(mut bytes: Payload<'_>) -> Result<Self, Self::Error> {
|
||||||
if bytes.remaining() < 4 {
|
if bytes.remaining() < 4 {
|
||||||
return Err(Self::Error::PacketTooSmall);
|
return Err(Self::Error::PacketTooSmall);
|
||||||
}
|
}
|
||||||
|
@ -197,9 +197,9 @@ impl ClosePacket {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl TryFrom<BytesMut> for ClosePacket {
|
impl TryFrom<Payload<'_>> for ClosePacket {
|
||||||
type Error = WispError;
|
type Error = WispError;
|
||||||
fn try_from(mut bytes: BytesMut) -> Result<Self, Self::Error> {
|
fn try_from(mut bytes: Payload<'_>) -> Result<Self, Self::Error> {
|
||||||
if bytes.remaining() < 1 {
|
if bytes.remaining() < 1 {
|
||||||
return Err(Self::Error::PacketTooSmall);
|
return Err(Self::Error::PacketTooSmall);
|
||||||
}
|
}
|
||||||
|
@ -247,11 +247,11 @@ impl Encode for InfoPacket {
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
/// Type of packet recieved.
|
/// Type of packet recieved.
|
||||||
pub enum PacketType {
|
pub enum PacketType<'a> {
|
||||||
/// Connect packet.
|
/// Connect packet.
|
||||||
Connect(ConnectPacket),
|
Connect(ConnectPacket),
|
||||||
/// Data packet.
|
/// Data packet.
|
||||||
Data(Bytes),
|
Data(Payload<'a>),
|
||||||
/// Continue packet.
|
/// Continue packet.
|
||||||
Continue(ContinuePacket),
|
Continue(ContinuePacket),
|
||||||
/// Close packet.
|
/// Close packet.
|
||||||
|
@ -260,7 +260,7 @@ pub enum PacketType {
|
||||||
Info(InfoPacket),
|
Info(InfoPacket),
|
||||||
}
|
}
|
||||||
|
|
||||||
impl PacketType {
|
impl PacketType<'_> {
|
||||||
/// Get the packet type used in the protocol.
|
/// Get the packet type used in the protocol.
|
||||||
pub fn as_u8(&self) -> u8 {
|
pub fn as_u8(&self) -> u8 {
|
||||||
use PacketType as P;
|
use PacketType as P;
|
||||||
|
@ -285,7 +285,7 @@ impl PacketType {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Encode for PacketType {
|
impl Encode for PacketType<'_> {
|
||||||
fn encode(self, bytes: &mut BytesMut) {
|
fn encode(self, bytes: &mut BytesMut) {
|
||||||
use PacketType as P;
|
use PacketType as P;
|
||||||
match self {
|
match self {
|
||||||
|
@ -300,18 +300,18 @@ impl Encode for PacketType {
|
||||||
|
|
||||||
/// Wisp protocol packet.
|
/// Wisp protocol packet.
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct Packet {
|
pub struct Packet<'a> {
|
||||||
/// Stream this packet is associated with.
|
/// Stream this packet is associated with.
|
||||||
pub stream_id: u32,
|
pub stream_id: u32,
|
||||||
/// Packet type recieved.
|
/// Packet type recieved.
|
||||||
pub packet_type: PacketType,
|
pub packet_type: PacketType<'a>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Packet {
|
impl<'a> Packet<'a> {
|
||||||
/// Create a new packet.
|
/// Create a new packet.
|
||||||
///
|
///
|
||||||
/// The helper functions should be used for most use cases.
|
/// The helper functions should be used for most use cases.
|
||||||
pub fn new(stream_id: u32, packet: PacketType) -> Self {
|
pub fn new(stream_id: u32, packet: PacketType<'a>) -> Self {
|
||||||
Self {
|
Self {
|
||||||
stream_id,
|
stream_id,
|
||||||
packet_type: packet,
|
packet_type: packet,
|
||||||
|
@ -336,7 +336,7 @@ impl Packet {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Create a new data packet.
|
/// Create a new data packet.
|
||||||
pub fn new_data(stream_id: u32, data: Bytes) -> Self {
|
pub fn new_data(stream_id: u32, data: Payload<'a>) -> Self {
|
||||||
Self {
|
Self {
|
||||||
stream_id,
|
stream_id,
|
||||||
packet_type: PacketType::Data(data),
|
packet_type: PacketType::Data(data),
|
||||||
|
@ -369,13 +369,13 @@ impl Packet {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn parse_packet(packet_type: u8, mut bytes: BytesMut) -> Result<Self, WispError> {
|
fn parse_packet(packet_type: u8, mut bytes: Payload<'a>) -> Result<Self, WispError> {
|
||||||
use PacketType as P;
|
use PacketType as P;
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
stream_id: bytes.get_u32_le(),
|
stream_id: bytes.get_u32_le(),
|
||||||
packet_type: match packet_type {
|
packet_type: match packet_type {
|
||||||
0x01 => P::Connect(ConnectPacket::try_from(bytes)?),
|
0x01 => P::Connect(ConnectPacket::try_from(bytes)?),
|
||||||
0x02 => P::Data(bytes.freeze()),
|
0x02 => P::Data(bytes),
|
||||||
0x03 => P::Continue(ContinuePacket::try_from(bytes)?),
|
0x03 => P::Continue(ContinuePacket::try_from(bytes)?),
|
||||||
0x04 => P::Close(ClosePacket::try_from(bytes)?),
|
0x04 => P::Close(ClosePacket::try_from(bytes)?),
|
||||||
// 0x05 is handled seperately
|
// 0x05 is handled seperately
|
||||||
|
@ -385,7 +385,7 @@ impl Packet {
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn maybe_parse_info(
|
pub(crate) fn maybe_parse_info(
|
||||||
frame: Frame,
|
frame: Frame<'a>,
|
||||||
role: Role,
|
role: Role,
|
||||||
extension_builders: &[Box<(dyn ProtocolExtensionBuilder + Send + Sync)>],
|
extension_builders: &[Box<(dyn ProtocolExtensionBuilder + Send + Sync)>],
|
||||||
) -> Result<Self, WispError> {
|
) -> Result<Self, WispError> {
|
||||||
|
@ -408,7 +408,7 @@ impl Packet {
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) async fn maybe_handle_extension(
|
pub(crate) async fn maybe_handle_extension(
|
||||||
frame: Frame,
|
frame: Frame<'a>,
|
||||||
extensions: &mut [AnyProtocolExtension],
|
extensions: &mut [AnyProtocolExtension],
|
||||||
read: &mut (dyn WebSocketRead + Send),
|
read: &mut (dyn WebSocketRead + Send),
|
||||||
write: &LockedWebSocketWrite,
|
write: &LockedWebSocketWrite,
|
||||||
|
@ -431,7 +431,7 @@ impl Packet {
|
||||||
})),
|
})),
|
||||||
0x02 => Ok(Some(Self {
|
0x02 => Ok(Some(Self {
|
||||||
stream_id: bytes.get_u32_le(),
|
stream_id: bytes.get_u32_le(),
|
||||||
packet_type: PacketType::Data(bytes.freeze()),
|
packet_type: PacketType::Data(bytes),
|
||||||
})),
|
})),
|
||||||
0x03 => Ok(Some(Self {
|
0x03 => Ok(Some(Self {
|
||||||
stream_id: bytes.get_u32_le(),
|
stream_id: bytes.get_u32_le(),
|
||||||
|
@ -447,7 +447,9 @@ impl Packet {
|
||||||
.iter_mut()
|
.iter_mut()
|
||||||
.find(|x| x.get_supported_packets().iter().any(|x| *x == packet_type))
|
.find(|x| x.get_supported_packets().iter().any(|x| *x == packet_type))
|
||||||
{
|
{
|
||||||
extension.handle_packet(bytes.freeze(), read, write).await?;
|
extension
|
||||||
|
.handle_packet(BytesMut::from(bytes).freeze(), read, write)
|
||||||
|
.await?;
|
||||||
Ok(None)
|
Ok(None)
|
||||||
} else {
|
} else {
|
||||||
Err(WispError::InvalidPacketType)
|
Err(WispError::InvalidPacketType)
|
||||||
|
@ -457,7 +459,7 @@ impl Packet {
|
||||||
}
|
}
|
||||||
|
|
||||||
fn parse_info(
|
fn parse_info(
|
||||||
mut bytes: BytesMut,
|
mut bytes: Payload<'a>,
|
||||||
role: Role,
|
role: Role,
|
||||||
extension_builders: &[Box<(dyn ProtocolExtensionBuilder + Send + Sync)>],
|
extension_builders: &[Box<(dyn ProtocolExtensionBuilder + Send + Sync)>],
|
||||||
) -> Result<Self, WispError> {
|
) -> Result<Self, WispError> {
|
||||||
|
@ -506,7 +508,7 @@ impl Packet {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Encode for Packet {
|
impl Encode for Packet<'_> {
|
||||||
fn encode(self, bytes: &mut BytesMut) {
|
fn encode(self, bytes: &mut BytesMut) {
|
||||||
bytes.put_u8(self.packet_type.as_u8());
|
bytes.put_u8(self.packet_type.as_u8());
|
||||||
bytes.put_u32_le(self.stream_id);
|
bytes.put_u32_le(self.stream_id);
|
||||||
|
@ -514,9 +516,9 @@ impl Encode for Packet {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl TryFrom<BytesMut> for Packet {
|
impl<'a> TryFrom<Payload<'a>> for Packet<'a> {
|
||||||
type Error = WispError;
|
type Error = WispError;
|
||||||
fn try_from(mut bytes: BytesMut) -> Result<Self, Self::Error> {
|
fn try_from(mut bytes: Payload<'a>) -> Result<Self, Self::Error> {
|
||||||
if bytes.remaining() < 1 {
|
if bytes.remaining() < 1 {
|
||||||
return Err(Self::Error::PacketTooSmall);
|
return Err(Self::Error::PacketTooSmall);
|
||||||
}
|
}
|
||||||
|
@ -525,7 +527,7 @@ impl TryFrom<BytesMut> for Packet {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl From<Packet> for BytesMut {
|
impl From<Packet<'_>> for BytesMut {
|
||||||
fn from(packet: Packet) -> Self {
|
fn from(packet: Packet) -> Self {
|
||||||
let mut encoded = BytesMut::with_capacity(1 + 4 + packet.packet_type.get_packet_size());
|
let mut encoded = BytesMut::with_capacity(1 + 4 + packet.packet_type.get_packet_size());
|
||||||
packet.encode(&mut encoded);
|
packet.encode(&mut encoded);
|
||||||
|
@ -533,9 +535,9 @@ impl From<Packet> for BytesMut {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl TryFrom<ws::Frame> for Packet {
|
impl<'a> TryFrom<ws::Frame<'a>> for Packet<'a> {
|
||||||
type Error = WispError;
|
type Error = WispError;
|
||||||
fn try_from(frame: ws::Frame) -> Result<Self, Self::Error> {
|
fn try_from(frame: ws::Frame<'a>) -> Result<Self, Self::Error> {
|
||||||
if !frame.finished {
|
if !frame.finished {
|
||||||
return Err(Self::Error::WsFrameNotFinished);
|
return Err(Self::Error::WsFrameNotFinished);
|
||||||
}
|
}
|
||||||
|
@ -546,8 +548,8 @@ impl TryFrom<ws::Frame> for Packet {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl From<Packet> for ws::Frame {
|
impl From<Packet<'_>> for ws::Frame<'static> {
|
||||||
fn from(packet: Packet) -> Self {
|
fn from(packet: Packet) -> Self {
|
||||||
Self::binary(BytesMut::from(packet))
|
Self::binary(Payload::Bytes(BytesMut::from(packet)))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
use crate::{
|
use crate::{
|
||||||
sink_unfold,
|
sink_unfold,
|
||||||
ws::{Frame, LockedWebSocketWrite},
|
ws::{Frame, LockedWebSocketWrite, Payload},
|
||||||
CloseReason, Packet, Role, StreamType, WispError,
|
CloseReason, Packet, Role, StreamType, WispError,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -9,9 +9,10 @@ use event_listener::Event;
|
||||||
use flume as mpsc;
|
use flume as mpsc;
|
||||||
use futures::{
|
use futures::{
|
||||||
channel::oneshot,
|
channel::oneshot,
|
||||||
ready, select, stream::{self, IntoAsyncRead},
|
ready, select,
|
||||||
|
stream::{self, IntoAsyncRead},
|
||||||
task::{noop_waker_ref, Context, Poll},
|
task::{noop_waker_ref, Context, Poll},
|
||||||
AsyncBufRead, AsyncRead, AsyncWrite, FutureExt, Sink, Stream, TryStreamExt,
|
AsyncBufRead, AsyncRead, AsyncWrite, Future, FutureExt, Sink, Stream, TryStreamExt,
|
||||||
};
|
};
|
||||||
use pin_project_lite::pin_project;
|
use pin_project_lite::pin_project;
|
||||||
use std::{
|
use std::{
|
||||||
|
@ -23,7 +24,7 @@ use std::{
|
||||||
};
|
};
|
||||||
|
|
||||||
pub(crate) enum WsEvent {
|
pub(crate) enum WsEvent {
|
||||||
Close(Packet, oneshot::Sender<Result<(), WispError>>),
|
Close(Packet<'static>, oneshot::Sender<Result<(), WispError>>),
|
||||||
CreateStream(
|
CreateStream(
|
||||||
StreamType,
|
StreamType,
|
||||||
String,
|
String,
|
||||||
|
@ -100,8 +101,10 @@ pub struct MuxStreamWrite {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl MuxStreamWrite {
|
impl MuxStreamWrite {
|
||||||
/// Write data to the stream.
|
pub(crate) async fn write_payload_internal(
|
||||||
pub async fn write(&self, data: Bytes) -> Result<(), WispError> {
|
&self,
|
||||||
|
frame: Frame<'static>,
|
||||||
|
) -> Result<(), WispError> {
|
||||||
if self.role == Role::Client
|
if self.role == Role::Client
|
||||||
&& self.stream_type == StreamType::Tcp
|
&& self.stream_type == StreamType::Tcp
|
||||||
&& self.flow_control.load(Ordering::Acquire) == 0
|
&& self.flow_control.load(Ordering::Acquire) == 0
|
||||||
|
@ -112,9 +115,7 @@ impl MuxStreamWrite {
|
||||||
return Err(WispError::StreamAlreadyClosed);
|
return Err(WispError::StreamAlreadyClosed);
|
||||||
}
|
}
|
||||||
|
|
||||||
self.tx
|
self.tx.write_frame(frame).await?;
|
||||||
.write_frame(Frame::from(Packet::new_data(self.stream_id, data)))
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
if self.role == Role::Client && self.stream_type == StreamType::Tcp {
|
if self.role == Role::Client && self.stream_type == StreamType::Tcp {
|
||||||
self.flow_control.store(
|
self.flow_control.store(
|
||||||
|
@ -125,6 +126,20 @@ impl MuxStreamWrite {
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Write a payload to the stream.
|
||||||
|
pub fn write_payload<'a>(
|
||||||
|
&'a self,
|
||||||
|
data: Payload<'_>,
|
||||||
|
) -> impl Future<Output = Result<(), WispError>> + 'a {
|
||||||
|
let frame: Frame<'static> = Frame::from(Packet::new_data(self.stream_id, data));
|
||||||
|
self.write_payload_internal(frame)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Write data to the stream.
|
||||||
|
pub async fn write<D: AsRef<[u8]>>(&self, data: D) -> Result<(), WispError> {
|
||||||
|
self.write_payload(Payload::Borrowed(data.as_ref())).await
|
||||||
|
}
|
||||||
|
|
||||||
/// Get a handle to close the connection.
|
/// Get a handle to close the connection.
|
||||||
///
|
///
|
||||||
/// Useful to close the connection without having access to the stream.
|
/// Useful to close the connection without having access to the stream.
|
||||||
|
@ -173,16 +188,16 @@ impl MuxStreamWrite {
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn into_sink(self) -> Pin<Box<dyn Sink<Bytes, Error = WispError> + Send>> {
|
pub(crate) fn into_sink(self) -> Pin<Box<dyn Sink<Frame<'static>, Error = WispError> + Send>> {
|
||||||
let handle = self.get_close_handle();
|
let handle = self.get_close_handle();
|
||||||
Box::pin(sink_unfold::unfold(
|
Box::pin(sink_unfold::unfold(
|
||||||
self,
|
self,
|
||||||
|tx, data| async move {
|
|tx, data| async move {
|
||||||
tx.write(data).await?;
|
tx.write_payload_internal(data).await?;
|
||||||
Ok(tx)
|
Ok(tx)
|
||||||
},
|
},
|
||||||
handle,
|
handle,
|
||||||
move |handle| async {
|
|handle| async move {
|
||||||
handle.close(CloseReason::Unknown).await?;
|
handle.close(CloseReason::Unknown).await?;
|
||||||
Ok(handle)
|
Ok(handle)
|
||||||
},
|
},
|
||||||
|
@ -258,8 +273,13 @@ impl MuxStream {
|
||||||
self.rx.read().await
|
self.rx.read().await
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Write a payload to the stream.
|
||||||
|
pub async fn write_payload(&self, data: Payload<'_>) -> Result<(), WispError> {
|
||||||
|
self.tx.write_payload(data).await
|
||||||
|
}
|
||||||
|
|
||||||
/// Write data to the stream.
|
/// Write data to the stream.
|
||||||
pub async fn write(&self, data: Bytes) -> Result<(), WispError> {
|
pub async fn write<D: AsRef<[u8]>>(&self, data: D) -> Result<(), WispError> {
|
||||||
self.tx.write(data).await
|
self.tx.write(data).await
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -301,6 +321,7 @@ impl MuxStream {
|
||||||
},
|
},
|
||||||
tx: MuxStreamIoSink {
|
tx: MuxStreamIoSink {
|
||||||
tx: self.tx.into_sink(),
|
tx: self.tx.into_sink(),
|
||||||
|
stream_id: self.stream_id,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -355,7 +376,9 @@ impl MuxProtocolExtensionStream {
|
||||||
encoded.put_u8(packet_type);
|
encoded.put_u8(packet_type);
|
||||||
encoded.put_u32_le(self.stream_id);
|
encoded.put_u32_le(self.stream_id);
|
||||||
encoded.extend(data);
|
encoded.extend(data);
|
||||||
self.tx.write_frame(Frame::binary(encoded)).await
|
self.tx
|
||||||
|
.write_frame(Frame::binary(Payload::Bytes(encoded)))
|
||||||
|
.await
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -391,12 +414,12 @@ impl Stream for MuxStreamIo {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Sink<Bytes> for MuxStreamIo {
|
impl Sink<&[u8]> for MuxStreamIo {
|
||||||
type Error = std::io::Error;
|
type Error = std::io::Error;
|
||||||
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||||
self.project().tx.poll_ready(cx)
|
self.project().tx.poll_ready(cx)
|
||||||
}
|
}
|
||||||
fn start_send(self: Pin<&mut Self>, item: Bytes) -> Result<(), Self::Error> {
|
fn start_send(self: Pin<&mut Self>, item: &[u8]) -> Result<(), Self::Error> {
|
||||||
self.project().tx.start_send(item)
|
self.project().tx.start_send(item)
|
||||||
}
|
}
|
||||||
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||||
|
@ -433,7 +456,8 @@ pin_project! {
|
||||||
/// Write side of a multiplexor stream that implements futures `Sink`.
|
/// Write side of a multiplexor stream that implements futures `Sink`.
|
||||||
pub struct MuxStreamIoSink {
|
pub struct MuxStreamIoSink {
|
||||||
#[pin]
|
#[pin]
|
||||||
tx: Pin<Box<dyn Sink<Bytes, Error = WispError> + Send>>,
|
tx: Pin<Box<dyn Sink<Frame<'static>, Error = WispError> + Send>>,
|
||||||
|
stream_id: u32,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -444,7 +468,7 @@ impl MuxStreamIoSink {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Sink<Bytes> for MuxStreamIoSink {
|
impl Sink<&[u8]> for MuxStreamIoSink {
|
||||||
type Error = std::io::Error;
|
type Error = std::io::Error;
|
||||||
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||||
self.project()
|
self.project()
|
||||||
|
@ -452,10 +476,14 @@ impl Sink<Bytes> for MuxStreamIoSink {
|
||||||
.poll_ready(cx)
|
.poll_ready(cx)
|
||||||
.map_err(std::io::Error::other)
|
.map_err(std::io::Error::other)
|
||||||
}
|
}
|
||||||
fn start_send(self: Pin<&mut Self>, item: Bytes) -> Result<(), Self::Error> {
|
fn start_send(self: Pin<&mut Self>, item: &[u8]) -> Result<(), Self::Error> {
|
||||||
|
let stream_id = self.stream_id;
|
||||||
self.project()
|
self.project()
|
||||||
.tx
|
.tx
|
||||||
.start_send(item)
|
.start_send(Frame::from(Packet::new_data(
|
||||||
|
stream_id,
|
||||||
|
Payload::Borrowed(item),
|
||||||
|
)))
|
||||||
.map_err(std::io::Error::other)
|
.map_err(std::io::Error::other)
|
||||||
}
|
}
|
||||||
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||||
|
@ -582,7 +610,10 @@ pin_project! {
|
||||||
|
|
||||||
impl MuxStreamAsyncWrite {
|
impl MuxStreamAsyncWrite {
|
||||||
pub(crate) fn new(sink: MuxStreamIoSink) -> Self {
|
pub(crate) fn new(sink: MuxStreamIoSink) -> Self {
|
||||||
Self { tx: sink, error: None }
|
Self {
|
||||||
|
tx: sink,
|
||||||
|
error: None,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -599,7 +630,7 @@ impl AsyncWrite for MuxStreamAsyncWrite {
|
||||||
let mut this = self.as_mut().project();
|
let mut this = self.as_mut().project();
|
||||||
|
|
||||||
ready!(this.tx.as_mut().poll_ready(cx))?;
|
ready!(this.tx.as_mut().poll_ready(cx))?;
|
||||||
match this.tx.as_mut().start_send(Bytes::copy_from_slice(buf)) {
|
match this.tx.as_mut().start_send(buf) {
|
||||||
Ok(()) => {
|
Ok(()) => {
|
||||||
let mut cx = Context::from_waker(noop_waker_ref());
|
let mut cx = Context::from_waker(noop_waker_ref());
|
||||||
let cx = &mut cx;
|
let cx = &mut cx;
|
||||||
|
|
114
wisp/src/ws.rs
114
wisp/src/ws.rs
|
@ -4,13 +4,95 @@
|
||||||
//! for other WebSocket implementations.
|
//! for other WebSocket implementations.
|
||||||
//!
|
//!
|
||||||
//! [`fastwebsockets`]: https://github.com/MercuryWorkshop/epoxy-tls/blob/multiplexed/wisp/src/fastwebsockets.rs
|
//! [`fastwebsockets`]: https://github.com/MercuryWorkshop/epoxy-tls/blob/multiplexed/wisp/src/fastwebsockets.rs
|
||||||
use std::sync::Arc;
|
use std::{ops::Deref, sync::Arc};
|
||||||
|
|
||||||
use crate::WispError;
|
use crate::WispError;
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use bytes::BytesMut;
|
use bytes::{Buf, BytesMut};
|
||||||
use futures::lock::Mutex;
|
use futures::lock::Mutex;
|
||||||
|
|
||||||
|
/// Payload of the websocket frame.
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub enum Payload<'a> {
|
||||||
|
/// Borrowed payload. Currently used when writing data.
|
||||||
|
Borrowed(&'a [u8]),
|
||||||
|
/// BytesMut payload. Currently used when reading data.
|
||||||
|
Bytes(BytesMut),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<BytesMut> for Payload<'static> {
|
||||||
|
fn from(value: BytesMut) -> Self {
|
||||||
|
Self::Bytes(value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a> From<&'a [u8]> for Payload<'a> {
|
||||||
|
fn from(value: &'a [u8]) -> Self {
|
||||||
|
Self::Borrowed(value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Payload<'_> {
|
||||||
|
/// Turn a Payload<'a> into a Payload<'static> by copying the data.
|
||||||
|
pub fn into_owned(self) -> Self {
|
||||||
|
match self {
|
||||||
|
Self::Bytes(x) => Self::Bytes(x),
|
||||||
|
Self::Borrowed(x) => Self::Bytes(BytesMut::from(x)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<Payload<'_>> for BytesMut {
|
||||||
|
fn from(value: Payload<'_>) -> Self {
|
||||||
|
match value {
|
||||||
|
Payload::Bytes(x) => x,
|
||||||
|
Payload::Borrowed(x) => x.into(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Deref for Payload<'_> {
|
||||||
|
type Target = [u8];
|
||||||
|
fn deref(&self) -> &Self::Target {
|
||||||
|
match self {
|
||||||
|
Self::Bytes(x) => x.deref(),
|
||||||
|
Self::Borrowed(x) => x,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Clone for Payload<'_> {
|
||||||
|
fn clone(&self) -> Self {
|
||||||
|
match self {
|
||||||
|
Self::Bytes(x) => Self::Bytes(x.clone()),
|
||||||
|
Self::Borrowed(x) => Self::Bytes(BytesMut::from(*x)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Buf for Payload<'_> {
|
||||||
|
fn remaining(&self) -> usize {
|
||||||
|
match self {
|
||||||
|
Self::Bytes(x) => x.remaining(),
|
||||||
|
Self::Borrowed(x) => x.remaining(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn chunk(&self) -> &[u8] {
|
||||||
|
match self {
|
||||||
|
Self::Bytes(x) => x.chunk(),
|
||||||
|
Self::Borrowed(x) => x.chunk(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn advance(&mut self, cnt: usize) {
|
||||||
|
match self {
|
||||||
|
Self::Bytes(x) => x.advance(cnt),
|
||||||
|
Self::Borrowed(x) => x.advance(cnt),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Opcode of the WebSocket frame.
|
/// Opcode of the WebSocket frame.
|
||||||
#[derive(Debug, PartialEq, Clone, Copy)]
|
#[derive(Debug, PartialEq, Clone, Copy)]
|
||||||
pub enum OpCode {
|
pub enum OpCode {
|
||||||
|
@ -28,18 +110,18 @@ pub enum OpCode {
|
||||||
|
|
||||||
/// WebSocket frame.
|
/// WebSocket frame.
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct Frame {
|
pub struct Frame<'a> {
|
||||||
/// Whether the frame is finished or not.
|
/// Whether the frame is finished or not.
|
||||||
pub finished: bool,
|
pub finished: bool,
|
||||||
/// Opcode of the WebSocket frame.
|
/// Opcode of the WebSocket frame.
|
||||||
pub opcode: OpCode,
|
pub opcode: OpCode,
|
||||||
/// Payload of the WebSocket frame.
|
/// Payload of the WebSocket frame.
|
||||||
pub payload: BytesMut,
|
pub payload: Payload<'a>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Frame {
|
impl<'a> Frame<'a> {
|
||||||
/// Create a new text frame.
|
/// Create a new text frame.
|
||||||
pub fn text(payload: BytesMut) -> Self {
|
pub fn text(payload: Payload<'a>) -> Self {
|
||||||
Self {
|
Self {
|
||||||
finished: true,
|
finished: true,
|
||||||
opcode: OpCode::Text,
|
opcode: OpCode::Text,
|
||||||
|
@ -48,7 +130,7 @@ impl Frame {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Create a new binary frame.
|
/// Create a new binary frame.
|
||||||
pub fn binary(payload: BytesMut) -> Self {
|
pub fn binary(payload: Payload<'a>) -> Self {
|
||||||
Self {
|
Self {
|
||||||
finished: true,
|
finished: true,
|
||||||
opcode: OpCode::Binary,
|
opcode: OpCode::Binary,
|
||||||
|
@ -57,7 +139,7 @@ impl Frame {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Create a new close frame.
|
/// Create a new close frame.
|
||||||
pub fn close(payload: BytesMut) -> Self {
|
pub fn close(payload: Payload<'a>) -> Self {
|
||||||
Self {
|
Self {
|
||||||
finished: true,
|
finished: true,
|
||||||
opcode: OpCode::Close,
|
opcode: OpCode::Close,
|
||||||
|
@ -70,14 +152,17 @@ impl Frame {
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
pub trait WebSocketRead {
|
pub trait WebSocketRead {
|
||||||
/// Read a frame from the socket.
|
/// Read a frame from the socket.
|
||||||
async fn wisp_read_frame(&mut self, tx: &LockedWebSocketWrite) -> Result<Frame, WispError>;
|
async fn wisp_read_frame(
|
||||||
|
&mut self,
|
||||||
|
tx: &LockedWebSocketWrite,
|
||||||
|
) -> Result<Frame<'static>, WispError>;
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Generic WebSocket write trait.
|
/// Generic WebSocket write trait.
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
pub trait WebSocketWrite {
|
pub trait WebSocketWrite {
|
||||||
/// Write a frame to the socket.
|
/// Write a frame to the socket.
|
||||||
async fn wisp_write_frame(&mut self, frame: Frame) -> Result<(), WispError>;
|
async fn wisp_write_frame(&mut self, frame: Frame<'_>) -> Result<(), WispError>;
|
||||||
|
|
||||||
/// Close the socket.
|
/// Close the socket.
|
||||||
async fn wisp_close(&mut self) -> Result<(), WispError>;
|
async fn wisp_close(&mut self) -> Result<(), WispError>;
|
||||||
|
@ -94,7 +179,7 @@ impl LockedWebSocketWrite {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Write a frame to the websocket.
|
/// Write a frame to the websocket.
|
||||||
pub async fn write_frame(&self, frame: Frame) -> Result<(), WispError> {
|
pub async fn write_frame(&self, frame: Frame<'_>) -> Result<(), WispError> {
|
||||||
self.0.lock().await.wisp_write_frame(frame).await
|
self.0.lock().await.wisp_write_frame(frame).await
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -104,7 +189,7 @@ impl LockedWebSocketWrite {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) struct AppendingWebSocketRead<R>(pub Option<Frame>, pub R)
|
pub(crate) struct AppendingWebSocketRead<R>(pub Option<Frame<'static>>, pub R)
|
||||||
where
|
where
|
||||||
R: WebSocketRead + Send;
|
R: WebSocketRead + Send;
|
||||||
|
|
||||||
|
@ -113,7 +198,10 @@ impl<R> WebSocketRead for AppendingWebSocketRead<R>
|
||||||
where
|
where
|
||||||
R: WebSocketRead + Send,
|
R: WebSocketRead + Send,
|
||||||
{
|
{
|
||||||
async fn wisp_read_frame(&mut self, tx: &LockedWebSocketWrite) -> Result<Frame, WispError> {
|
async fn wisp_read_frame(
|
||||||
|
&mut self,
|
||||||
|
tx: &LockedWebSocketWrite,
|
||||||
|
) -> Result<Frame<'static>, WispError> {
|
||||||
if let Some(x) = self.0.take() {
|
if let Some(x) = self.0.take() {
|
||||||
return Ok(x);
|
return Ok(x);
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue