add a new Payload struct to allow for one-copy writes and cargo fmt

This commit is contained in:
Toshit Chawda 2024-07-17 16:23:58 -07:00
parent 314c1bfa75
commit d6353bd5a9
No known key found for this signature in database
GPG key ID: 91480ED99E2B3D9D
18 changed files with 3533 additions and 3395 deletions

View file

@ -1,10 +1,9 @@
use bytes::{buf::UninitSlice, BufMut, Bytes, BytesMut}; use bytes::{buf::UninitSlice, BufMut, BytesMut};
use futures_util::{ use futures_util::{io::WriteHalf, lock::Mutex, AsyncReadExt, AsyncWriteExt, SinkExt, StreamExt};
io::WriteHalf, lock::Mutex, stream::SplitSink, AsyncReadExt, AsyncWriteExt, SinkExt, StreamExt,
};
use js_sys::{Function, Uint8Array}; use js_sys::{Function, Uint8Array};
use wasm_bindgen::prelude::*; use wasm_bindgen::prelude::*;
use wasm_bindgen_futures::spawn_local; use wasm_bindgen_futures::spawn_local;
use wisp_mux::MuxStreamIoSink;
use crate::{ use crate::{
stream_provider::{ProviderAsyncRW, ProviderUnencryptedStream}, stream_provider::{ProviderAsyncRW, ProviderUnencryptedStream},
@ -105,15 +104,14 @@ impl EpoxyIoStream {
#[wasm_bindgen] #[wasm_bindgen]
pub struct EpoxyUdpStream { pub struct EpoxyUdpStream {
tx: Mutex<SplitSink<ProviderUnencryptedStream, Bytes>>, tx: Mutex<MuxStreamIoSink>,
onerror: Function, onerror: Function,
} }
#[wasm_bindgen] #[wasm_bindgen]
impl EpoxyUdpStream { impl EpoxyUdpStream {
pub(crate) fn connect(stream: ProviderUnencryptedStream, handlers: EpoxyHandlers) -> Self { pub(crate) fn connect(stream: ProviderUnencryptedStream, handlers: EpoxyHandlers) -> Self {
let (tx, mut rx) = stream.split(); let (mut rx, tx) = stream.into_split();
let tx = Mutex::new(tx);
let EpoxyHandlers { let EpoxyHandlers {
onopen, onopen,
@ -142,7 +140,7 @@ impl EpoxyUdpStream {
let _ = onopen.call0(&JsValue::null()); let _ = onopen.call0(&JsValue::null());
Self { Self {
tx, tx: tx.into(),
onerror: onerror_cloned, onerror: onerror_cloned,
} }
} }
@ -154,7 +152,7 @@ impl EpoxyUdpStream {
.map_err(|_| EpoxyError::InvalidPayload)? .map_err(|_| EpoxyError::InvalidPayload)?
.0 .0
.to_vec(); .to_vec();
Ok(self.tx.lock().await.send(payload.into()).await?) Ok(self.tx.lock().await.send(payload.as_ref()).await?)
} }
.await; .await;

View file

@ -5,9 +5,9 @@ use std::{str::FromStr, sync::Arc};
use async_compression::futures::bufread as async_comp; use async_compression::futures::bufread as async_comp;
use bytes::Bytes; use bytes::Bytes;
use cfg_if::cfg_if; use cfg_if::cfg_if;
use futures_util::TryStreamExt;
#[cfg(feature = "full")] #[cfg(feature = "full")]
use futures_util::future::Either; use futures_util::future::Either;
use futures_util::TryStreamExt;
use http::{ use http::{
header::{InvalidHeaderName, InvalidHeaderValue}, header::{InvalidHeaderName, InvalidHeaderValue},
method::InvalidMethod, method::InvalidMethod,
@ -22,7 +22,8 @@ use js_sys::{Array, Function, Object, Reflect};
use stream_provider::{StreamProvider, StreamProviderService}; use stream_provider::{StreamProvider, StreamProviderService};
use thiserror::Error; use thiserror::Error;
use utils::{ use utils::{
asyncread_to_readablestream_stream, convert_body, entries_of_object, is_null_body, is_redirect, object_get, object_set, IncomingBody, UriExt, WasmExecutor asyncread_to_readablestream_stream, convert_body, entries_of_object, is_null_body, is_redirect,
object_get, object_set, IncomingBody, UriExt, WasmExecutor,
}; };
use wasm_bindgen::prelude::*; use wasm_bindgen::prelude::*;
use wasm_streams::ReadableStream; use wasm_streams::ReadableStream;

View file

@ -5,7 +5,9 @@ use futures_rustls::{
TlsConnector, TlsStream, TlsConnector, TlsStream,
}; };
use futures_util::{ use futures_util::{
future::Either, lock::{Mutex, MutexGuard}, AsyncRead, AsyncWrite, Future future::Either,
lock::{Mutex, MutexGuard},
AsyncRead, AsyncWrite, Future,
}; };
use hyper_util_wasm::client::legacy::connect::{ConnectSvc, Connected, Connection}; use hyper_util_wasm::client::legacy::connect::{ConnectSvc, Connected, Connection};
use js_sys::{Array, Reflect, Uint8Array}; use js_sys::{Array, Reflect, Uint8Array};
@ -81,7 +83,7 @@ impl StreamProvider {
mut locked: MutexGuard<'_, Option<ClientMux>>, mut locked: MutexGuard<'_, Option<ClientMux>>,
) -> Result<(), EpoxyError> { ) -> Result<(), EpoxyError> {
let extensions_vec: Vec<Box<dyn ProtocolExtensionBuilder + Send + Sync>> = let extensions_vec: Vec<Box<dyn ProtocolExtensionBuilder + Send + Sync>> =
vec![Box::new(UdpProtocolExtensionBuilder())]; vec![Box::new(UdpProtocolExtensionBuilder)];
let extensions = if self.wisp_v2 { let extensions = if self.wisp_v2 {
Some(extensions_vec.as_slice()) Some(extensions_vec.as_slice())
} else { } else {

View file

@ -9,7 +9,8 @@ use futures_util::lock::Mutex;
use getrandom::getrandom; use getrandom::getrandom;
use http::{ use http::{
header::{ header::{
CONNECTION, HOST, SEC_WEBSOCKET_KEY, SEC_WEBSOCKET_PROTOCOL, SEC_WEBSOCKET_VERSION, UPGRADE, USER_AGENT, CONNECTION, HOST, SEC_WEBSOCKET_KEY, SEC_WEBSOCKET_PROTOCOL, SEC_WEBSOCKET_VERSION,
UPGRADE, USER_AGENT,
}, },
Method, Request, Response, StatusCode, Uri, Method, Request, Response, StatusCode, Uri,
}; };
@ -22,7 +23,9 @@ use tokio::io::WriteHalf;
use wasm_bindgen::{prelude::*, JsError, JsValue}; use wasm_bindgen::{prelude::*, JsError, JsValue};
use wasm_bindgen_futures::spawn_local; use wasm_bindgen_futures::spawn_local;
use crate::{tokioio::TokioIo, utils::entries_of_object, EpoxyClient, EpoxyError, EpoxyHandlers, HttpBody}; use crate::{
tokioio::TokioIo, utils::entries_of_object, EpoxyClient, EpoxyError, EpoxyHandlers, HttpBody,
};
#[wasm_bindgen] #[wasm_bindgen]
pub struct EpoxyWebSocket { pub struct EpoxyWebSocket {
@ -69,7 +72,9 @@ impl EpoxyWebSocket {
request = request.header(SEC_WEBSOCKET_PROTOCOL, protocols.join(",")); request = request.header(SEC_WEBSOCKET_PROTOCOL, protocols.join(","));
} }
if web_sys::Headers::instanceof(&headers) && let Ok(entries) = Object::from_entries(&headers) { if web_sys::Headers::instanceof(&headers)
&& let Ok(entries) = Object::from_entries(&headers)
{
for header in entries_of_object(&entries) { for header in entries_of_object(&entries) {
request = request.header(&header[0], &header[1]); request = request.header(&header[0], &header[1]);
} }

View file

@ -13,7 +13,7 @@ use send_wrapper::SendWrapper;
use wasm_bindgen::{closure::Closure, JsCast}; use wasm_bindgen::{closure::Closure, JsCast};
use web_sys::{BinaryType, MessageEvent, WebSocket}; use web_sys::{BinaryType, MessageEvent, WebSocket};
use wisp_mux::{ use wisp_mux::{
ws::{Frame, LockedWebSocketWrite, WebSocketRead, WebSocketWrite}, ws::{Frame, LockedWebSocketWrite, Payload, WebSocketRead, WebSocketWrite},
WispError, WispError,
}; };
@ -77,7 +77,10 @@ pub struct WebSocketReader {
#[async_trait] #[async_trait]
impl WebSocketRead for WebSocketReader { impl WebSocketRead for WebSocketReader {
async fn wisp_read_frame(&mut self, _: &LockedWebSocketWrite) -> Result<Frame, WispError> { async fn wisp_read_frame(
&mut self,
_: &LockedWebSocketWrite,
) -> Result<Frame<'static>, WispError> {
use WebSocketMessage::*; use WebSocketMessage::*;
if self.closed.load(Ordering::Acquire) { if self.closed.load(Ordering::Acquire) {
return Err(WispError::WsImplSocketClosed); return Err(WispError::WsImplSocketClosed);
@ -87,7 +90,9 @@ impl WebSocketRead for WebSocketReader {
_ = self.close_event.listen().fuse() => Some(Closed), _ = self.close_event.listen().fuse() => Some(Closed),
}; };
match res.ok_or(WispError::WsImplSocketClosed)? { match res.ok_or(WispError::WsImplSocketClosed)? {
Message(bin) => Ok(Frame::binary(BytesMut::from(bin.as_slice()))), Message(bin) => Ok(Frame::binary(Payload::Bytes(BytesMut::from(
bin.as_slice(),
)))),
Error => Err(WebSocketError::Unknown.into()), Error => Err(WebSocketError::Unknown.into()),
Closed => Err(WispError::WsImplSocketClosed), Closed => Err(WispError::WsImplSocketClosed),
} }
@ -188,7 +193,7 @@ impl WebSocketWrapper {
#[async_trait] #[async_trait]
impl WebSocketWrite for WebSocketWrapper { impl WebSocketWrite for WebSocketWrapper {
async fn wisp_write_frame(&mut self, frame: Frame) -> Result<(), WispError> { async fn wisp_write_frame(&mut self, frame: Frame<'_>) -> Result<(), WispError> {
use wisp_mux::ws::OpCode::*; use wisp_mux::ws::OpCode::*;
if self.closed.load(Ordering::Acquire) { if self.closed.load(Ordering::Acquire) {
return Err(WispError::WsImplSocketClosed); return Err(WispError::WsImplSocketClosed);

View file

@ -202,7 +202,7 @@ async fn main() -> Result<(), Error> {
block_non_http: opt.block_non_http, block_non_http: opt.block_non_http,
block_udp: opt.block_udp, block_udp: opt.block_udp,
auth: Arc::new(vec![ auth: Arc::new(vec![
Box::new(UdpProtocolExtensionBuilder()), Box::new(UdpProtocolExtensionBuilder),
Box::new(pw_ext), Box::new(pw_ext),
]), ]),
enforce_auth, enforce_auth,
@ -361,8 +361,7 @@ async fn accept_ws(
let (rx, tx) = ws.split(|x| { let (rx, tx) = ws.split(|x| {
let Parts { let Parts {
io, read_buf: buf, .. io, read_buf: buf, ..
} = x } = x.into_inner()
.into_inner()
.downcast::<TokioIo<ListenerStream>>() .downcast::<TokioIo<ListenerStream>>()
.unwrap(); .unwrap();
assert_eq!(buf.len(), 0); assert_eq!(buf.len(), 0);
@ -398,12 +397,7 @@ async fn accept_ws(
.with_required_extensions(&[PasswordProtocolExtension::ID]) .with_required_extensions(&[PasswordProtocolExtension::ID])
.await? .await?
} else { } else {
ServerMux::create( ServerMux::create(rx, tx, 512, Some(&[Box::new(UdpProtocolExtensionBuilder)]))
rx,
tx,
512,
Some(&[Box::new(UdpProtocolExtensionBuilder())]),
)
.await? .await?
.with_no_required_extensions() .with_no_required_extensions()
}; };

View file

@ -158,7 +158,7 @@ async fn main() -> Result<(), Box<dyn Error + Send + Sync>> {
let mut extensions: Vec<Box<(dyn ProtocolExtensionBuilder + Send + Sync)>> = Vec::new(); let mut extensions: Vec<Box<(dyn ProtocolExtensionBuilder + Send + Sync)>> = Vec::new();
let mut extension_ids: Vec<u8> = Vec::new(); let mut extension_ids: Vec<u8> = Vec::new();
if opts.udp { if opts.udp {
extensions.push(Box::new(UdpProtocolExtensionBuilder())); extensions.push(Box::new(UdpProtocolExtensionBuilder));
extension_ids.push(UdpProtocolExtension::ID); extension_ids.push(UdpProtocolExtension::ID);
} }
if let Some(auth) = auth { if let Some(auth) = auth {
@ -173,7 +173,8 @@ async fn main() -> Result<(), Box<dyn Error + Send + Sync>> {
} else { } else {
ClientMux::create(rx, tx, Some(extensions.as_slice())) ClientMux::create(rx, tx, Some(extensions.as_slice()))
.await? .await?
.with_required_extensions(extension_ids.as_slice()).await? .with_required_extensions(extension_ids.as_slice())
.await?
}; };
println!( println!(

View file

@ -56,7 +56,7 @@ impl From<AnyProtocolExtension> for Bytes {
/// A Wisp protocol extension. /// A Wisp protocol extension.
/// ///
/// See [the /// See [the
/// docs](https://github.com/MercuryWorkshop/wisp-protocol/blob/main/protocol.md#protocol-extensions). /// docs](https://github.com/MercuryWorkshop/wisp-protocol/blob/v2/protocol.md#protocol-extensions).
#[async_trait] #[async_trait]
pub trait ProtocolExtension: std::fmt::Debug { pub trait ProtocolExtension: std::fmt::Debug {
/// Get the protocol extension ID. /// Get the protocol extension ID.

View file

@ -29,7 +29,7 @@
//! ]) //! ])
//! ); //! );
//! ``` //! ```
//! See [the docs](https://github.com/MercuryWorkshop/wisp-protocol/blob/main/protocol.md#0x02---password-authentication) //! See [the docs](https://github.com/MercuryWorkshop/wisp-protocol/blob/v2/protocol.md#0x02---password-authentication)
use std::{collections::HashMap, error::Error, fmt::Display, string::FromUtf8Error}; use std::{collections::HashMap, error::Error, fmt::Display, string::FromUtf8Error};

View file

@ -6,10 +6,10 @@
//! rx, //! rx,
//! tx, //! tx,
//! 128, //! 128,
//! Some(&[Box::new(UdpProtocolExtensionBuilder())]) //! Some(&[Box::new(UdpProtocolExtensionBuilder)])
//! ); //! );
//! ``` //! ```
//! See [the docs](https://github.com/MercuryWorkshop/wisp-protocol/blob/main/protocol.md#0x01---udp) //! See [the docs](https://github.com/MercuryWorkshop/wisp-protocol/blob/v2/protocol.md#0x01---udp)
use async_trait::async_trait; use async_trait::async_trait;
use bytes::Bytes; use bytes::Bytes;
@ -22,7 +22,7 @@ use super::{AnyProtocolExtension, ProtocolExtension, ProtocolExtensionBuilder};
#[derive(Debug)] #[derive(Debug)]
/// UDP protocol extension. /// UDP protocol extension.
pub struct UdpProtocolExtension(); pub struct UdpProtocolExtension;
impl UdpProtocolExtension { impl UdpProtocolExtension {
/// UDP protocol extension ID. /// UDP protocol extension ID.
@ -61,7 +61,7 @@ impl ProtocolExtension for UdpProtocolExtension {
} }
fn box_clone(&self) -> Box<dyn ProtocolExtension + Sync + Send> { fn box_clone(&self) -> Box<dyn ProtocolExtension + Sync + Send> {
Box::new(Self()) Box::new(Self)
} }
} }
@ -72,7 +72,7 @@ impl From<UdpProtocolExtension> for AnyProtocolExtension {
} }
/// UDP protocol extension builder. /// UDP protocol extension builder.
pub struct UdpProtocolExtensionBuilder(); pub struct UdpProtocolExtensionBuilder;
impl ProtocolExtensionBuilder for UdpProtocolExtensionBuilder { impl ProtocolExtensionBuilder for UdpProtocolExtensionBuilder {
fn get_id(&self) -> u8 { fn get_id(&self) -> u8 {
@ -84,10 +84,10 @@ impl ProtocolExtensionBuilder for UdpProtocolExtensionBuilder {
_: Bytes, _: Bytes,
_: crate::Role, _: crate::Role,
) -> Result<AnyProtocolExtension, WispError> { ) -> Result<AnyProtocolExtension, WispError> {
Ok(UdpProtocolExtension().into()) Ok(UdpProtocolExtension.into())
} }
fn build_to_extension(&self, _: crate::Role) -> AnyProtocolExtension { fn build_to_extension(&self, _: crate::Role) -> AnyProtocolExtension {
UdpProtocolExtension().into() UdpProtocolExtension.into()
} }
} }

View file

@ -9,12 +9,19 @@ use tokio::io::{AsyncRead, AsyncWrite};
use crate::{ws::LockedWebSocketWrite, WispError}; use crate::{ws::LockedWebSocketWrite, WispError};
fn match_payload(payload: Payload) -> BytesMut { fn match_payload<'a>(payload: Payload<'a>) -> crate::ws::Payload<'a> {
match payload { match payload {
Payload::Bytes(x) => x, Payload::Bytes(x) => crate::ws::Payload::Bytes(x),
Payload::Owned(x) => BytesMut::from(x.deref()), Payload::Owned(x) => crate::ws::Payload::Bytes(BytesMut::from(x.deref())),
Payload::BorrowedMut(x) => BytesMut::from(x.deref()), Payload::BorrowedMut(x) => crate::ws::Payload::Borrowed(&*x),
Payload::Borrowed(x) => BytesMut::from(x), Payload::Borrowed(x) => crate::ws::Payload::Borrowed(x),
}
}
fn match_payload_reverse<'a>(payload: crate::ws::Payload<'a>) -> Payload<'a> {
match payload {
crate::ws::Payload::Bytes(x) => Payload::Bytes(x),
crate::ws::Payload::Borrowed(x) => Payload::Borrowed(x),
} }
} }
@ -34,8 +41,8 @@ impl From<OpCode> for crate::ws::OpCode {
} }
} }
impl From<Frame<'_>> for crate::ws::Frame { impl<'a> From<Frame<'a>> for crate::ws::Frame<'a> {
fn from(frame: Frame) -> Self { fn from(frame: Frame<'a>) -> Self {
Self { Self {
finished: frame.fin, finished: frame.fin,
opcode: frame.opcode.into(), opcode: frame.opcode.into(),
@ -44,10 +51,10 @@ impl From<Frame<'_>> for crate::ws::Frame {
} }
} }
impl<'a> From<crate::ws::Frame> for Frame<'a> { impl<'a> From<crate::ws::Frame<'a>> for Frame<'a> {
fn from(frame: crate::ws::Frame) -> Self { fn from(frame: crate::ws::Frame<'a>) -> Self {
use crate::ws::OpCode::*; use crate::ws::OpCode::*;
let payload = Payload::Bytes(frame.payload); let payload = match_payload_reverse(frame.payload);
match frame.opcode { match frame.opcode {
Text => Self::text(payload), Text => Self::text(payload),
Binary => Self::binary(payload), Binary => Self::binary(payload),
@ -73,7 +80,7 @@ impl<S: AsyncRead + Unpin + Send> crate::ws::WebSocketRead for FragmentCollector
async fn wisp_read_frame( async fn wisp_read_frame(
&mut self, &mut self,
tx: &LockedWebSocketWrite, tx: &LockedWebSocketWrite,
) -> Result<crate::ws::Frame, WispError> { ) -> Result<crate::ws::Frame<'static>, WispError> {
Ok(self Ok(self
.read_frame(&mut |frame| async { tx.write_frame(frame.into()).await }) .read_frame(&mut |frame| async { tx.write_frame(frame.into()).await })
.await? .await?
@ -83,7 +90,7 @@ impl<S: AsyncRead + Unpin + Send> crate::ws::WebSocketRead for FragmentCollector
#[async_trait] #[async_trait]
impl<S: AsyncWrite + Unpin + Send> crate::ws::WebSocketWrite for WebSocketWrite<S> { impl<S: AsyncWrite + Unpin + Send> crate::ws::WebSocketWrite for WebSocketWrite<S> {
async fn wisp_write_frame(&mut self, frame: crate::ws::Frame) -> Result<(), WispError> { async fn wisp_write_frame(&mut self, frame: crate::ws::Frame<'_>) -> Result<(), WispError> {
self.write_frame(frame.into()).await.map_err(|e| e.into()) self.write_frame(frame.into()).await.map_err(|e| e.into())
} }

View file

@ -15,7 +15,7 @@ pub mod ws;
pub use crate::{packet::*, stream::*}; pub use crate::{packet::*, stream::*};
use bytes::Bytes; use bytes::{Bytes, BytesMut};
use dashmap::DashMap; use dashmap::DashMap;
use event_listener::Event; use event_listener::Event;
use extensions::{udp::UdpProtocolExtension, AnyProtocolExtension, ProtocolExtensionBuilder}; use extensions::{udp::UdpProtocolExtension, AnyProtocolExtension, ProtocolExtensionBuilder};
@ -167,7 +167,7 @@ struct MuxInner {
tx: ws::LockedWebSocketWrite, tx: ws::LockedWebSocketWrite,
stream_map: DashMap<u32, MuxMapValue>, stream_map: DashMap<u32, MuxMapValue>,
buffer_size: u32, buffer_size: u32,
fut_exited: Arc<AtomicBool> fut_exited: Arc<AtomicBool>,
} }
impl MuxInner { impl MuxInner {
@ -381,7 +381,7 @@ impl MuxInner {
} }
Data(data) => { Data(data) => {
if let Some(stream) = self.stream_map.get(&packet.stream_id) { if let Some(stream) = self.stream_map.get(&packet.stream_id) {
let _ = stream.stream.try_send(data); let _ = stream.stream.try_send(BytesMut::from(data).freeze());
if stream.stream_type == StreamType::Tcp { if stream.stream_type == StreamType::Tcp {
stream.flow_control.store( stream.flow_control.store(
stream stream
@ -417,6 +417,7 @@ impl MuxInner {
if frame.opcode == ws::OpCode::Close { if frame.opcode == ws::OpCode::Close {
break Ok(()); break Ok(());
} }
if let Some(packet) = if let Some(packet) =
Packet::maybe_handle_extension(frame, &mut extensions, &mut rx, &self.tx).await? Packet::maybe_handle_extension(frame, &mut extensions, &mut rx, &self.tx).await?
{ {
@ -425,7 +426,10 @@ impl MuxInner {
Connect(_) | Info(_) => break Err(WispError::InvalidPacketType), Connect(_) | Info(_) => break Err(WispError::InvalidPacketType),
Data(data) => { Data(data) => {
if let Some(stream) = self.stream_map.get(&packet.stream_id) { if let Some(stream) = self.stream_map.get(&packet.stream_id) {
let _ = stream.stream.send_async(data).await; let _ = stream
.stream
.send_async(BytesMut::from(data).freeze())
.await;
} }
} }
Continue(inner_packet) => { Continue(inner_packet) => {
@ -454,12 +458,12 @@ async fn maybe_wisp_v2<R>(
read: &mut R, read: &mut R,
write: &LockedWebSocketWrite, write: &LockedWebSocketWrite,
builders: &[Box<dyn ProtocolExtensionBuilder + Sync + Send>], builders: &[Box<dyn ProtocolExtensionBuilder + Sync + Send>],
) -> Result<(Vec<AnyProtocolExtension>, Option<ws::Frame>, bool), WispError> ) -> Result<(Vec<AnyProtocolExtension>, Option<ws::Frame<'static>>, bool), WispError>
where where
R: ws::WebSocketRead + Send, R: ws::WebSocketRead + Send,
{ {
let mut supported_extensions = Vec::new(); let mut supported_extensions = Vec::new();
let mut extra_packet = None; let mut extra_packet: Option<ws::Frame<'static>> = None;
let mut downgraded = true; let mut downgraded = true;
let extension_ids: Vec<_> = builders.iter().map(|x| x.get_id()).collect(); let extension_ids: Vec<_> = builders.iter().map(|x| x.get_id()).collect();
@ -476,7 +480,7 @@ where
.collect(); .collect();
downgraded = false; downgraded = false;
} else { } else {
extra_packet.replace(packet.into()); extra_packet.replace(ws::Frame::from(packet).clone());
} }
} }
@ -574,7 +578,7 @@ impl ServerMux {
tx: write, tx: write,
stream_map: DashMap::new(), stream_map: DashMap::new(),
buffer_size, buffer_size,
fut_exited fut_exited,
} }
.server_into_future( .server_into_future(
AppendingWebSocketRead(extra_packet, read), AppendingWebSocketRead(extra_packet, read),
@ -761,7 +765,7 @@ impl ClientMux {
tx: write, tx: write,
stream_map: DashMap::new(), stream_map: DashMap::new(),
buffer_size: packet.buffer_remaining, buffer_size: packet.buffer_remaining,
fut_exited fut_exited,
} }
.client_into_future( .client_into_future(
AppendingWebSocketRead(extra_packet, read), AppendingWebSocketRead(extra_packet, read),

View file

@ -1,6 +1,6 @@
use crate::{ use crate::{
extensions::{AnyProtocolExtension, ProtocolExtensionBuilder}, extensions::{AnyProtocolExtension, ProtocolExtensionBuilder},
ws::{self, Frame, LockedWebSocketWrite, OpCode, WebSocketRead}, ws::{self, Frame, LockedWebSocketWrite, OpCode, Payload, WebSocketRead},
Role, WispError, WISP_VERSION, Role, WispError, WISP_VERSION,
}; };
use bytes::{Buf, BufMut, Bytes, BytesMut}; use bytes::{Buf, BufMut, Bytes, BytesMut};
@ -124,9 +124,9 @@ impl ConnectPacket {
} }
} }
impl TryFrom<BytesMut> for ConnectPacket { impl TryFrom<Payload<'_>> for ConnectPacket {
type Error = WispError; type Error = WispError;
fn try_from(mut bytes: BytesMut) -> Result<Self, Self::Error> { fn try_from(mut bytes: Payload<'_>) -> Result<Self, Self::Error> {
if bytes.remaining() < (1 + 2) { if bytes.remaining() < (1 + 2) {
return Err(Self::Error::PacketTooSmall); return Err(Self::Error::PacketTooSmall);
} }
@ -162,9 +162,9 @@ impl ContinuePacket {
} }
} }
impl TryFrom<BytesMut> for ContinuePacket { impl TryFrom<Payload<'_>> for ContinuePacket {
type Error = WispError; type Error = WispError;
fn try_from(mut bytes: BytesMut) -> Result<Self, Self::Error> { fn try_from(mut bytes: Payload<'_>) -> Result<Self, Self::Error> {
if bytes.remaining() < 4 { if bytes.remaining() < 4 {
return Err(Self::Error::PacketTooSmall); return Err(Self::Error::PacketTooSmall);
} }
@ -197,9 +197,9 @@ impl ClosePacket {
} }
} }
impl TryFrom<BytesMut> for ClosePacket { impl TryFrom<Payload<'_>> for ClosePacket {
type Error = WispError; type Error = WispError;
fn try_from(mut bytes: BytesMut) -> Result<Self, Self::Error> { fn try_from(mut bytes: Payload<'_>) -> Result<Self, Self::Error> {
if bytes.remaining() < 1 { if bytes.remaining() < 1 {
return Err(Self::Error::PacketTooSmall); return Err(Self::Error::PacketTooSmall);
} }
@ -247,11 +247,11 @@ impl Encode for InfoPacket {
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
/// Type of packet recieved. /// Type of packet recieved.
pub enum PacketType { pub enum PacketType<'a> {
/// Connect packet. /// Connect packet.
Connect(ConnectPacket), Connect(ConnectPacket),
/// Data packet. /// Data packet.
Data(Bytes), Data(Payload<'a>),
/// Continue packet. /// Continue packet.
Continue(ContinuePacket), Continue(ContinuePacket),
/// Close packet. /// Close packet.
@ -260,7 +260,7 @@ pub enum PacketType {
Info(InfoPacket), Info(InfoPacket),
} }
impl PacketType { impl PacketType<'_> {
/// Get the packet type used in the protocol. /// Get the packet type used in the protocol.
pub fn as_u8(&self) -> u8 { pub fn as_u8(&self) -> u8 {
use PacketType as P; use PacketType as P;
@ -285,7 +285,7 @@ impl PacketType {
} }
} }
impl Encode for PacketType { impl Encode for PacketType<'_> {
fn encode(self, bytes: &mut BytesMut) { fn encode(self, bytes: &mut BytesMut) {
use PacketType as P; use PacketType as P;
match self { match self {
@ -300,18 +300,18 @@ impl Encode for PacketType {
/// Wisp protocol packet. /// Wisp protocol packet.
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct Packet { pub struct Packet<'a> {
/// Stream this packet is associated with. /// Stream this packet is associated with.
pub stream_id: u32, pub stream_id: u32,
/// Packet type recieved. /// Packet type recieved.
pub packet_type: PacketType, pub packet_type: PacketType<'a>,
} }
impl Packet { impl<'a> Packet<'a> {
/// Create a new packet. /// Create a new packet.
/// ///
/// The helper functions should be used for most use cases. /// The helper functions should be used for most use cases.
pub fn new(stream_id: u32, packet: PacketType) -> Self { pub fn new(stream_id: u32, packet: PacketType<'a>) -> Self {
Self { Self {
stream_id, stream_id,
packet_type: packet, packet_type: packet,
@ -336,7 +336,7 @@ impl Packet {
} }
/// Create a new data packet. /// Create a new data packet.
pub fn new_data(stream_id: u32, data: Bytes) -> Self { pub fn new_data(stream_id: u32, data: Payload<'a>) -> Self {
Self { Self {
stream_id, stream_id,
packet_type: PacketType::Data(data), packet_type: PacketType::Data(data),
@ -369,13 +369,13 @@ impl Packet {
} }
} }
fn parse_packet(packet_type: u8, mut bytes: BytesMut) -> Result<Self, WispError> { fn parse_packet(packet_type: u8, mut bytes: Payload<'a>) -> Result<Self, WispError> {
use PacketType as P; use PacketType as P;
Ok(Self { Ok(Self {
stream_id: bytes.get_u32_le(), stream_id: bytes.get_u32_le(),
packet_type: match packet_type { packet_type: match packet_type {
0x01 => P::Connect(ConnectPacket::try_from(bytes)?), 0x01 => P::Connect(ConnectPacket::try_from(bytes)?),
0x02 => P::Data(bytes.freeze()), 0x02 => P::Data(bytes),
0x03 => P::Continue(ContinuePacket::try_from(bytes)?), 0x03 => P::Continue(ContinuePacket::try_from(bytes)?),
0x04 => P::Close(ClosePacket::try_from(bytes)?), 0x04 => P::Close(ClosePacket::try_from(bytes)?),
// 0x05 is handled seperately // 0x05 is handled seperately
@ -385,7 +385,7 @@ impl Packet {
} }
pub(crate) fn maybe_parse_info( pub(crate) fn maybe_parse_info(
frame: Frame, frame: Frame<'a>,
role: Role, role: Role,
extension_builders: &[Box<(dyn ProtocolExtensionBuilder + Send + Sync)>], extension_builders: &[Box<(dyn ProtocolExtensionBuilder + Send + Sync)>],
) -> Result<Self, WispError> { ) -> Result<Self, WispError> {
@ -408,7 +408,7 @@ impl Packet {
} }
pub(crate) async fn maybe_handle_extension( pub(crate) async fn maybe_handle_extension(
frame: Frame, frame: Frame<'a>,
extensions: &mut [AnyProtocolExtension], extensions: &mut [AnyProtocolExtension],
read: &mut (dyn WebSocketRead + Send), read: &mut (dyn WebSocketRead + Send),
write: &LockedWebSocketWrite, write: &LockedWebSocketWrite,
@ -431,7 +431,7 @@ impl Packet {
})), })),
0x02 => Ok(Some(Self { 0x02 => Ok(Some(Self {
stream_id: bytes.get_u32_le(), stream_id: bytes.get_u32_le(),
packet_type: PacketType::Data(bytes.freeze()), packet_type: PacketType::Data(bytes),
})), })),
0x03 => Ok(Some(Self { 0x03 => Ok(Some(Self {
stream_id: bytes.get_u32_le(), stream_id: bytes.get_u32_le(),
@ -447,7 +447,9 @@ impl Packet {
.iter_mut() .iter_mut()
.find(|x| x.get_supported_packets().iter().any(|x| *x == packet_type)) .find(|x| x.get_supported_packets().iter().any(|x| *x == packet_type))
{ {
extension.handle_packet(bytes.freeze(), read, write).await?; extension
.handle_packet(BytesMut::from(bytes).freeze(), read, write)
.await?;
Ok(None) Ok(None)
} else { } else {
Err(WispError::InvalidPacketType) Err(WispError::InvalidPacketType)
@ -457,7 +459,7 @@ impl Packet {
} }
fn parse_info( fn parse_info(
mut bytes: BytesMut, mut bytes: Payload<'a>,
role: Role, role: Role,
extension_builders: &[Box<(dyn ProtocolExtensionBuilder + Send + Sync)>], extension_builders: &[Box<(dyn ProtocolExtensionBuilder + Send + Sync)>],
) -> Result<Self, WispError> { ) -> Result<Self, WispError> {
@ -506,7 +508,7 @@ impl Packet {
} }
} }
impl Encode for Packet { impl Encode for Packet<'_> {
fn encode(self, bytes: &mut BytesMut) { fn encode(self, bytes: &mut BytesMut) {
bytes.put_u8(self.packet_type.as_u8()); bytes.put_u8(self.packet_type.as_u8());
bytes.put_u32_le(self.stream_id); bytes.put_u32_le(self.stream_id);
@ -514,9 +516,9 @@ impl Encode for Packet {
} }
} }
impl TryFrom<BytesMut> for Packet { impl<'a> TryFrom<Payload<'a>> for Packet<'a> {
type Error = WispError; type Error = WispError;
fn try_from(mut bytes: BytesMut) -> Result<Self, Self::Error> { fn try_from(mut bytes: Payload<'a>) -> Result<Self, Self::Error> {
if bytes.remaining() < 1 { if bytes.remaining() < 1 {
return Err(Self::Error::PacketTooSmall); return Err(Self::Error::PacketTooSmall);
} }
@ -525,7 +527,7 @@ impl TryFrom<BytesMut> for Packet {
} }
} }
impl From<Packet> for BytesMut { impl From<Packet<'_>> for BytesMut {
fn from(packet: Packet) -> Self { fn from(packet: Packet) -> Self {
let mut encoded = BytesMut::with_capacity(1 + 4 + packet.packet_type.get_packet_size()); let mut encoded = BytesMut::with_capacity(1 + 4 + packet.packet_type.get_packet_size());
packet.encode(&mut encoded); packet.encode(&mut encoded);
@ -533,9 +535,9 @@ impl From<Packet> for BytesMut {
} }
} }
impl TryFrom<ws::Frame> for Packet { impl<'a> TryFrom<ws::Frame<'a>> for Packet<'a> {
type Error = WispError; type Error = WispError;
fn try_from(frame: ws::Frame) -> Result<Self, Self::Error> { fn try_from(frame: ws::Frame<'a>) -> Result<Self, Self::Error> {
if !frame.finished { if !frame.finished {
return Err(Self::Error::WsFrameNotFinished); return Err(Self::Error::WsFrameNotFinished);
} }
@ -546,8 +548,8 @@ impl TryFrom<ws::Frame> for Packet {
} }
} }
impl From<Packet> for ws::Frame { impl From<Packet<'_>> for ws::Frame<'static> {
fn from(packet: Packet) -> Self { fn from(packet: Packet) -> Self {
Self::binary(BytesMut::from(packet)) Self::binary(Payload::Bytes(BytesMut::from(packet)))
} }
} }

View file

@ -1,6 +1,6 @@
use crate::{ use crate::{
sink_unfold, sink_unfold,
ws::{Frame, LockedWebSocketWrite}, ws::{Frame, LockedWebSocketWrite, Payload},
CloseReason, Packet, Role, StreamType, WispError, CloseReason, Packet, Role, StreamType, WispError,
}; };
@ -9,9 +9,10 @@ use event_listener::Event;
use flume as mpsc; use flume as mpsc;
use futures::{ use futures::{
channel::oneshot, channel::oneshot,
ready, select, stream::{self, IntoAsyncRead}, ready, select,
stream::{self, IntoAsyncRead},
task::{noop_waker_ref, Context, Poll}, task::{noop_waker_ref, Context, Poll},
AsyncBufRead, AsyncRead, AsyncWrite, FutureExt, Sink, Stream, TryStreamExt, AsyncBufRead, AsyncRead, AsyncWrite, Future, FutureExt, Sink, Stream, TryStreamExt,
}; };
use pin_project_lite::pin_project; use pin_project_lite::pin_project;
use std::{ use std::{
@ -23,7 +24,7 @@ use std::{
}; };
pub(crate) enum WsEvent { pub(crate) enum WsEvent {
Close(Packet, oneshot::Sender<Result<(), WispError>>), Close(Packet<'static>, oneshot::Sender<Result<(), WispError>>),
CreateStream( CreateStream(
StreamType, StreamType,
String, String,
@ -100,8 +101,10 @@ pub struct MuxStreamWrite {
} }
impl MuxStreamWrite { impl MuxStreamWrite {
/// Write data to the stream. pub(crate) async fn write_payload_internal(
pub async fn write(&self, data: Bytes) -> Result<(), WispError> { &self,
frame: Frame<'static>,
) -> Result<(), WispError> {
if self.role == Role::Client if self.role == Role::Client
&& self.stream_type == StreamType::Tcp && self.stream_type == StreamType::Tcp
&& self.flow_control.load(Ordering::Acquire) == 0 && self.flow_control.load(Ordering::Acquire) == 0
@ -112,9 +115,7 @@ impl MuxStreamWrite {
return Err(WispError::StreamAlreadyClosed); return Err(WispError::StreamAlreadyClosed);
} }
self.tx self.tx.write_frame(frame).await?;
.write_frame(Frame::from(Packet::new_data(self.stream_id, data)))
.await?;
if self.role == Role::Client && self.stream_type == StreamType::Tcp { if self.role == Role::Client && self.stream_type == StreamType::Tcp {
self.flow_control.store( self.flow_control.store(
@ -125,6 +126,20 @@ impl MuxStreamWrite {
Ok(()) Ok(())
} }
/// Write a payload to the stream.
pub fn write_payload<'a>(
&'a self,
data: Payload<'_>,
) -> impl Future<Output = Result<(), WispError>> + 'a {
let frame: Frame<'static> = Frame::from(Packet::new_data(self.stream_id, data));
self.write_payload_internal(frame)
}
/// Write data to the stream.
pub async fn write<D: AsRef<[u8]>>(&self, data: D) -> Result<(), WispError> {
self.write_payload(Payload::Borrowed(data.as_ref())).await
}
/// Get a handle to close the connection. /// Get a handle to close the connection.
/// ///
/// Useful to close the connection without having access to the stream. /// Useful to close the connection without having access to the stream.
@ -173,16 +188,16 @@ impl MuxStreamWrite {
Ok(()) Ok(())
} }
pub(crate) fn into_sink(self) -> Pin<Box<dyn Sink<Bytes, Error = WispError> + Send>> { pub(crate) fn into_sink(self) -> Pin<Box<dyn Sink<Frame<'static>, Error = WispError> + Send>> {
let handle = self.get_close_handle(); let handle = self.get_close_handle();
Box::pin(sink_unfold::unfold( Box::pin(sink_unfold::unfold(
self, self,
|tx, data| async move { |tx, data| async move {
tx.write(data).await?; tx.write_payload_internal(data).await?;
Ok(tx) Ok(tx)
}, },
handle, handle,
move |handle| async { |handle| async move {
handle.close(CloseReason::Unknown).await?; handle.close(CloseReason::Unknown).await?;
Ok(handle) Ok(handle)
}, },
@ -258,8 +273,13 @@ impl MuxStream {
self.rx.read().await self.rx.read().await
} }
/// Write a payload to the stream.
pub async fn write_payload(&self, data: Payload<'_>) -> Result<(), WispError> {
self.tx.write_payload(data).await
}
/// Write data to the stream. /// Write data to the stream.
pub async fn write(&self, data: Bytes) -> Result<(), WispError> { pub async fn write<D: AsRef<[u8]>>(&self, data: D) -> Result<(), WispError> {
self.tx.write(data).await self.tx.write(data).await
} }
@ -301,6 +321,7 @@ impl MuxStream {
}, },
tx: MuxStreamIoSink { tx: MuxStreamIoSink {
tx: self.tx.into_sink(), tx: self.tx.into_sink(),
stream_id: self.stream_id,
}, },
} }
} }
@ -355,7 +376,9 @@ impl MuxProtocolExtensionStream {
encoded.put_u8(packet_type); encoded.put_u8(packet_type);
encoded.put_u32_le(self.stream_id); encoded.put_u32_le(self.stream_id);
encoded.extend(data); encoded.extend(data);
self.tx.write_frame(Frame::binary(encoded)).await self.tx
.write_frame(Frame::binary(Payload::Bytes(encoded)))
.await
} }
} }
@ -391,12 +414,12 @@ impl Stream for MuxStreamIo {
} }
} }
impl Sink<Bytes> for MuxStreamIo { impl Sink<&[u8]> for MuxStreamIo {
type Error = std::io::Error; type Error = std::io::Error;
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.project().tx.poll_ready(cx) self.project().tx.poll_ready(cx)
} }
fn start_send(self: Pin<&mut Self>, item: Bytes) -> Result<(), Self::Error> { fn start_send(self: Pin<&mut Self>, item: &[u8]) -> Result<(), Self::Error> {
self.project().tx.start_send(item) self.project().tx.start_send(item)
} }
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
@ -433,7 +456,8 @@ pin_project! {
/// Write side of a multiplexor stream that implements futures `Sink`. /// Write side of a multiplexor stream that implements futures `Sink`.
pub struct MuxStreamIoSink { pub struct MuxStreamIoSink {
#[pin] #[pin]
tx: Pin<Box<dyn Sink<Bytes, Error = WispError> + Send>>, tx: Pin<Box<dyn Sink<Frame<'static>, Error = WispError> + Send>>,
stream_id: u32,
} }
} }
@ -444,7 +468,7 @@ impl MuxStreamIoSink {
} }
} }
impl Sink<Bytes> for MuxStreamIoSink { impl Sink<&[u8]> for MuxStreamIoSink {
type Error = std::io::Error; type Error = std::io::Error;
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.project() self.project()
@ -452,10 +476,14 @@ impl Sink<Bytes> for MuxStreamIoSink {
.poll_ready(cx) .poll_ready(cx)
.map_err(std::io::Error::other) .map_err(std::io::Error::other)
} }
fn start_send(self: Pin<&mut Self>, item: Bytes) -> Result<(), Self::Error> { fn start_send(self: Pin<&mut Self>, item: &[u8]) -> Result<(), Self::Error> {
let stream_id = self.stream_id;
self.project() self.project()
.tx .tx
.start_send(item) .start_send(Frame::from(Packet::new_data(
stream_id,
Payload::Borrowed(item),
)))
.map_err(std::io::Error::other) .map_err(std::io::Error::other)
} }
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
@ -582,7 +610,10 @@ pin_project! {
impl MuxStreamAsyncWrite { impl MuxStreamAsyncWrite {
pub(crate) fn new(sink: MuxStreamIoSink) -> Self { pub(crate) fn new(sink: MuxStreamIoSink) -> Self {
Self { tx: sink, error: None } Self {
tx: sink,
error: None,
}
} }
} }
@ -599,7 +630,7 @@ impl AsyncWrite for MuxStreamAsyncWrite {
let mut this = self.as_mut().project(); let mut this = self.as_mut().project();
ready!(this.tx.as_mut().poll_ready(cx))?; ready!(this.tx.as_mut().poll_ready(cx))?;
match this.tx.as_mut().start_send(Bytes::copy_from_slice(buf)) { match this.tx.as_mut().start_send(buf) {
Ok(()) => { Ok(()) => {
let mut cx = Context::from_waker(noop_waker_ref()); let mut cx = Context::from_waker(noop_waker_ref());
let cx = &mut cx; let cx = &mut cx;

View file

@ -4,13 +4,95 @@
//! for other WebSocket implementations. //! for other WebSocket implementations.
//! //!
//! [`fastwebsockets`]: https://github.com/MercuryWorkshop/epoxy-tls/blob/multiplexed/wisp/src/fastwebsockets.rs //! [`fastwebsockets`]: https://github.com/MercuryWorkshop/epoxy-tls/blob/multiplexed/wisp/src/fastwebsockets.rs
use std::sync::Arc; use std::{ops::Deref, sync::Arc};
use crate::WispError; use crate::WispError;
use async_trait::async_trait; use async_trait::async_trait;
use bytes::BytesMut; use bytes::{Buf, BytesMut};
use futures::lock::Mutex; use futures::lock::Mutex;
/// Payload of the websocket frame.
#[derive(Debug)]
pub enum Payload<'a> {
/// Borrowed payload. Currently used when writing data.
Borrowed(&'a [u8]),
/// BytesMut payload. Currently used when reading data.
Bytes(BytesMut),
}
impl From<BytesMut> for Payload<'static> {
fn from(value: BytesMut) -> Self {
Self::Bytes(value)
}
}
impl<'a> From<&'a [u8]> for Payload<'a> {
fn from(value: &'a [u8]) -> Self {
Self::Borrowed(value)
}
}
impl Payload<'_> {
/// Turn a Payload<'a> into a Payload<'static> by copying the data.
pub fn into_owned(self) -> Self {
match self {
Self::Bytes(x) => Self::Bytes(x),
Self::Borrowed(x) => Self::Bytes(BytesMut::from(x)),
}
}
}
impl From<Payload<'_>> for BytesMut {
fn from(value: Payload<'_>) -> Self {
match value {
Payload::Bytes(x) => x,
Payload::Borrowed(x) => x.into(),
}
}
}
impl Deref for Payload<'_> {
type Target = [u8];
fn deref(&self) -> &Self::Target {
match self {
Self::Bytes(x) => x.deref(),
Self::Borrowed(x) => x,
}
}
}
impl Clone for Payload<'_> {
fn clone(&self) -> Self {
match self {
Self::Bytes(x) => Self::Bytes(x.clone()),
Self::Borrowed(x) => Self::Bytes(BytesMut::from(*x)),
}
}
}
impl Buf for Payload<'_> {
fn remaining(&self) -> usize {
match self {
Self::Bytes(x) => x.remaining(),
Self::Borrowed(x) => x.remaining(),
}
}
fn chunk(&self) -> &[u8] {
match self {
Self::Bytes(x) => x.chunk(),
Self::Borrowed(x) => x.chunk(),
}
}
fn advance(&mut self, cnt: usize) {
match self {
Self::Bytes(x) => x.advance(cnt),
Self::Borrowed(x) => x.advance(cnt),
}
}
}
/// Opcode of the WebSocket frame. /// Opcode of the WebSocket frame.
#[derive(Debug, PartialEq, Clone, Copy)] #[derive(Debug, PartialEq, Clone, Copy)]
pub enum OpCode { pub enum OpCode {
@ -28,18 +110,18 @@ pub enum OpCode {
/// WebSocket frame. /// WebSocket frame.
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct Frame { pub struct Frame<'a> {
/// Whether the frame is finished or not. /// Whether the frame is finished or not.
pub finished: bool, pub finished: bool,
/// Opcode of the WebSocket frame. /// Opcode of the WebSocket frame.
pub opcode: OpCode, pub opcode: OpCode,
/// Payload of the WebSocket frame. /// Payload of the WebSocket frame.
pub payload: BytesMut, pub payload: Payload<'a>,
} }
impl Frame { impl<'a> Frame<'a> {
/// Create a new text frame. /// Create a new text frame.
pub fn text(payload: BytesMut) -> Self { pub fn text(payload: Payload<'a>) -> Self {
Self { Self {
finished: true, finished: true,
opcode: OpCode::Text, opcode: OpCode::Text,
@ -48,7 +130,7 @@ impl Frame {
} }
/// Create a new binary frame. /// Create a new binary frame.
pub fn binary(payload: BytesMut) -> Self { pub fn binary(payload: Payload<'a>) -> Self {
Self { Self {
finished: true, finished: true,
opcode: OpCode::Binary, opcode: OpCode::Binary,
@ -57,7 +139,7 @@ impl Frame {
} }
/// Create a new close frame. /// Create a new close frame.
pub fn close(payload: BytesMut) -> Self { pub fn close(payload: Payload<'a>) -> Self {
Self { Self {
finished: true, finished: true,
opcode: OpCode::Close, opcode: OpCode::Close,
@ -70,14 +152,17 @@ impl Frame {
#[async_trait] #[async_trait]
pub trait WebSocketRead { pub trait WebSocketRead {
/// Read a frame from the socket. /// Read a frame from the socket.
async fn wisp_read_frame(&mut self, tx: &LockedWebSocketWrite) -> Result<Frame, WispError>; async fn wisp_read_frame(
&mut self,
tx: &LockedWebSocketWrite,
) -> Result<Frame<'static>, WispError>;
} }
/// Generic WebSocket write trait. /// Generic WebSocket write trait.
#[async_trait] #[async_trait]
pub trait WebSocketWrite { pub trait WebSocketWrite {
/// Write a frame to the socket. /// Write a frame to the socket.
async fn wisp_write_frame(&mut self, frame: Frame) -> Result<(), WispError>; async fn wisp_write_frame(&mut self, frame: Frame<'_>) -> Result<(), WispError>;
/// Close the socket. /// Close the socket.
async fn wisp_close(&mut self) -> Result<(), WispError>; async fn wisp_close(&mut self) -> Result<(), WispError>;
@ -94,7 +179,7 @@ impl LockedWebSocketWrite {
} }
/// Write a frame to the websocket. /// Write a frame to the websocket.
pub async fn write_frame(&self, frame: Frame) -> Result<(), WispError> { pub async fn write_frame(&self, frame: Frame<'_>) -> Result<(), WispError> {
self.0.lock().await.wisp_write_frame(frame).await self.0.lock().await.wisp_write_frame(frame).await
} }
@ -104,7 +189,7 @@ impl LockedWebSocketWrite {
} }
} }
pub(crate) struct AppendingWebSocketRead<R>(pub Option<Frame>, pub R) pub(crate) struct AppendingWebSocketRead<R>(pub Option<Frame<'static>>, pub R)
where where
R: WebSocketRead + Send; R: WebSocketRead + Send;
@ -113,7 +198,10 @@ impl<R> WebSocketRead for AppendingWebSocketRead<R>
where where
R: WebSocketRead + Send, R: WebSocketRead + Send,
{ {
async fn wisp_read_frame(&mut self, tx: &LockedWebSocketWrite) -> Result<Frame, WispError> { async fn wisp_read_frame(
&mut self,
tx: &LockedWebSocketWrite,
) -> Result<Frame<'static>, WispError> {
if let Some(x) = self.0.take() { if let Some(x) = self.0.take() {
return Ok(x); return Ok(x);
} }