move the wisp logic into wisp lib

This commit is contained in:
Toshit Chawda 2024-01-27 18:57:04 -08:00
parent 379e07d643
commit 2a5684192a
No known key found for this signature in database
GPG key ID: 91480ED99E2B3D9D
8 changed files with 314 additions and 198 deletions

2
Cargo.lock generated
View file

@ -1530,9 +1530,11 @@ name = "wisp-mux"
version = "0.1.0" version = "0.1.0"
dependencies = [ dependencies = [
"bytes", "bytes",
"dashmap",
"fastwebsockets", "fastwebsockets",
"futures", "futures",
"futures-util", "futures-util",
"tokio",
] ]
[[package]] [[package]]

View file

@ -1,20 +0,0 @@
use fastwebsockets::{WebSocketWrite, Frame, WebSocketError};
use hyper::upgrade::Upgraded;
use hyper_util::rt::TokioIo;
use std::sync::Arc;
use tokio::{io::WriteHalf, sync::Mutex};
type Ws = WebSocketWrite<WriteHalf<TokioIo<Upgraded>>>;
#[derive(Clone)]
pub struct LockedWebSocketWrite(Arc<Mutex<Ws>>);
impl LockedWebSocketWrite {
pub fn new(ws: Ws) -> Self {
Self(Arc::new(Mutex::new(ws)))
}
pub async fn write_frame(&self, frame: Frame<'_>) -> Result<(), WebSocketError> {
self.0.lock().await.write_frame(frame).await
}
}

View file

@ -1,25 +1,22 @@
mod lockedws; #![feature(let_chains)]
use std::io::Error;
use std::{io::Error, sync::Arc};
use bytes::Bytes; use bytes::Bytes;
use fastwebsockets::{ use fastwebsockets::{
upgrade, CloseCode, FragmentCollector, Frame, OpCode, Payload, WebSocketError, upgrade, CloseCode, FragmentCollector, FragmentCollectorRead, Frame, OpCode, Payload,
WebSocketError,
}; };
use futures_util::{SinkExt, StreamExt}; use futures_util::{SinkExt, StreamExt, TryFutureExt};
use hyper::{ use hyper::{
body::Incoming, header::HeaderValue, server::conn::http1, service::service_fn, Request, body::Incoming, header::HeaderValue, server::conn::http1, service::service_fn, Request,
Response, StatusCode, Response, StatusCode,
}; };
use hyper_util::rt::TokioIo; use hyper_util::rt::TokioIo;
use tokio::{ use tokio::net::{TcpListener, TcpStream};
net::{TcpListener, TcpStream},
sync::mpsc,
};
use tokio_native_tls::{native_tls, TlsAcceptor}; use tokio_native_tls::{native_tls, TlsAcceptor};
use tokio_util::codec::{BytesCodec, Framed}; use tokio_util::codec::{BytesCodec, Framed};
use wisp_mux::{ws, Packet, PacketType}; use wisp_mux::{ws, ConnectPacket, MuxStream, Packet, ServerMux, StreamType, WispError, WsEvent};
type HttpBody = http_body_util::Empty<hyper::body::Bytes>; type HttpBody = http_body_util::Empty<hyper::body::Bytes>;
@ -73,37 +70,26 @@ async fn accept_http(
addr: String, addr: String,
prefix: String, prefix: String,
) -> Result<Response<HttpBody>, WebSocketError> { ) -> Result<Response<HttpBody>, WebSocketError> {
if upgrade::is_upgrade_request(&req) && req.uri().path().to_string().starts_with(&prefix) { if upgrade::is_upgrade_request(&req)
&& req.uri().path().to_string().starts_with(&prefix)
&& let Some(protocol) = req.headers().get("Sec-Websocket-Protocol")
&& protocol == "wisp-v1"
{
let uri = req.uri().clone(); let uri = req.uri().clone();
let (mut res, fut) = upgrade::upgrade(&mut req)?; let (mut res, fut) = upgrade::upgrade(&mut req)?;
tokio::spawn(async move {
if *uri.path() != prefix { if *uri.path() != prefix {
if let Err(e) = tokio::spawn(async move {
accept_wsproxy(fut, uri.path().strip_prefix(&prefix).unwrap(), addr.clone()) accept_wsproxy(fut, uri.path().strip_prefix(&prefix).unwrap(), addr.clone()).await
.await
{
println!("{:?}: error in ws handling: {:?}", addr, e);
}
} else if let Err(e) = accept_ws(fut, addr.clone()).await {
println!("{:?}: error in ws handling: {:?}", addr, e);
}
}); });
} else {
tokio::spawn(async move { accept_ws(fut, addr.clone()).await });
}
if let Some(protocol) = req.headers().get("Sec-Websocket-Protocol") {
let first_protocol = protocol
.to_str()
.expect("failed to get protocol")
.split(',')
.next()
.expect("failed to get first protocol")
.trim();
res.headers_mut().insert( res.headers_mut().insert(
"Sec-Websocket-Protocol", "Sec-Websocket-Protocol",
HeaderValue::from_str(first_protocol).unwrap(), HeaderValue::from_str("wisp-v1").unwrap(),
); );
}
Ok(res) Ok(res)
} else { } else {
Ok(Response::builder() Ok(Response::builder()
@ -113,127 +99,80 @@ async fn accept_http(
} }
} }
enum WsEvent { async fn handle_mux(
Send(Bytes), packet: ConnectPacket,
Close, mut stream: MuxStream<impl ws::WebSocketWrite>,
) -> Result<(), WispError> {
let uri = format!(
"{}:{}",
packet.destination_hostname, packet.destination_port
);
match packet.stream_type {
StreamType::Tcp => {
let tcp_stream = TcpStream::connect(uri)
.await
.map_err(|x| WispError::Other(Box::new(x)))?;
let mut tcp_stream_framed = Framed::new(tcp_stream, BytesCodec::new());
loop {
tokio::select! {
event = stream.read() => {
match event {
Some(event) => match event {
WsEvent::Send(data) => {
tcp_stream_framed.send(data).await.map_err(|x| WispError::Other(Box::new(x)))?;
}
WsEvent::Close(_) => break,
},
None => break
}
},
event = tcp_stream_framed.next() => {
match event.and_then(|x| x.ok()) {
Some(event) => stream.write(event.into()).await?,
None => break
}
}
}
}
}
StreamType::Udp => todo!(),
}
Ok(())
} }
async fn accept_ws( async fn accept_ws(
fut: upgrade::UpgradeFut, fut: upgrade::UpgradeFut,
addr: String, addr: String,
) -> Result<(), Box<dyn std::error::Error>> { ) -> Result<(), Box<dyn std::error::Error + Sync + Send>> {
let (mut rx, tx) = fut.await?.split(tokio::io::split); let (rx, tx) = fut.await?.split(tokio::io::split);
let tx = lockedws::LockedWebSocketWrite::new(tx); let rx = FragmentCollectorRead::new(rx);
let stream_map = Arc::new(dashmap::DashMap::<u32, mpsc::UnboundedSender<WsEvent>>::new());
println!("{:?}: connected", addr); println!("{:?}: connected", addr);
tx.write_frame(ws::Frame::from(Packet::new_continue(0, u32::MAX)).into()) let mut mux = ServerMux::new(rx, tx);
.await?;
while let Ok(frame) = rx.read_frame(&mut |x| tx.write_frame(x)).await { mux.server_loop(&mut |packet, stream| async move {
if let Ok(packet) = Packet::try_from(ws::Frame::try_from(frame)?) { let tx_cloned = stream.get_write_half();
use PacketType::*; let stream_id = stream.stream_id;
match packet.packet {
Connect(inner_packet) => {
let (ch_tx, mut ch_rx) = mpsc::unbounded_channel::<WsEvent>();
stream_map.clone().insert(packet.stream_id, ch_tx);
let tx_cloned = tx.clone();
tokio::spawn(async move { tokio::spawn(async move {
let tcp_stream = match TcpStream::connect(format!( let _ = handle_mux(packet, stream)
"{}:{}", .or_else(|err| async {
inner_packet.destination_hostname, inner_packet.destination_port let _ = tx_cloned
)) .write_frame(ws::Frame::from(Packet::new_close(stream_id, 0x03)))
.await .await;
{ Err(err)
Ok(stream) => stream, })
Err(err) => { .and_then(|_| async {
tx_cloned tx_cloned
.write_frame( .write_frame(ws::Frame::from(Packet::new_close(stream_id, 0x02)))
ws::Frame::from(Packet::new_close(packet.stream_id, 0x03))
.into(),
)
.await .await
.map_err(std::io::Error::other)?; })
return Err(Box::new(err)); .await;
}
};
let mut tcp_stream = Framed::new(tcp_stream, BytesCodec::new());
loop {
tokio::select! {
event = tcp_stream.next() => {
if let Some(res) = event {
match res {
Ok(buf) => {
tx_cloned.write_frame(
ws::Frame::from(
Packet::new_data(
packet.stream_id,
buf.to_vec()
)
).into()
).await.map_err(std::io::Error::other)?;
}
Err(err) => {
tx_cloned
.write_frame(
ws::Frame::from(Packet::new_close(
packet.stream_id,
0x03,
))
.into(),
)
.await
.map_err(std::io::Error::other)?;
return Err(Box::new(err));
}
}
}
}
event = ch_rx.recv() => {
if let Some(event) = event {
match event {
WsEvent::Send(buf) => {
tcp_stream.send(buf).await?;
tx_cloned
.write_frame(
ws::Frame::from(
Packet::new_continue(
packet.stream_id,
u32::MAX
)
).into()
).await.map_err(std::io::Error::other)?;
}
WsEvent::Close => {
break;
}
}
} else {
break;
}
}
};
}
Ok(())
}); });
} Ok(())
Data(inner_packet) => { })
println!("recieved data for {:?}", packet.stream_id); .await?;
if let Some(stream) = stream_map.clone().get(&packet.stream_id) {
let _ = stream.send(WsEvent::Send(inner_packet.into()));
}
}
Continue(_) => unreachable!(),
Close(_) => {
if let Some(stream) = stream_map.clone().get(&packet.stream_id) {
let _ = stream.send(WsEvent::Close);
}
}
}
}
}
println!("{:?}: disconnected", addr); println!("{:?}: disconnected", addr);
Ok(()) Ok(())
@ -243,7 +182,7 @@ async fn accept_wsproxy(
fut: upgrade::UpgradeFut, fut: upgrade::UpgradeFut,
incoming_uri: &str, incoming_uri: &str,
addr: String, addr: String,
) -> Result<(), Box<dyn std::error::Error>> { ) -> Result<(), Box<dyn std::error::Error + Sync + Send>> {
let mut ws_stream = FragmentCollector::new(fut.await?); let mut ws_stream = FragmentCollector::new(fut.await?);
println!("{:?}: connected (wsproxy)", addr); println!("{:?}: connected (wsproxy)", addr);

View file

@ -5,9 +5,11 @@ edition = "2021"
[dependencies] [dependencies]
bytes = "1.5.0" bytes = "1.5.0"
fastwebsockets = { version = "0.6.0", optional = true } dashmap = "5.5.3"
fastwebsockets = { version = "0.6.0", features = ["unstable-split"], optional = true }
futures = "0.3.30" futures = "0.3.30"
futures-util = "0.3.30" futures-util = "0.3.30"
tokio = { version = "1.35.1", optional = true }
[features] [features]
fastwebsockets = ["dep:fastwebsockets"] fastwebsockets = ["dep:fastwebsockets", "dep:tokio"]

View file

@ -1,30 +1,30 @@
use bytes::Bytes; use bytes::Bytes;
use fastwebsockets::{Payload, Frame, OpCode}; use fastwebsockets::{
FragmentCollectorRead, Frame, OpCode, Payload, WebSocketError, WebSocketWrite,
};
use tokio::io::{AsyncRead, AsyncWrite};
impl TryFrom<OpCode> for crate::ws::OpCode { impl From<OpCode> for crate::ws::OpCode {
type Error = crate::WispError; fn from(opcode: OpCode) -> Self {
fn try_from(opcode: OpCode) -> Result<Self, Self::Error> {
use OpCode::*; use OpCode::*;
match opcode { match opcode {
Continuation => Err(Self::Error::WsImplNotSupported), Continuation => unreachable!(),
Text => Ok(Self::Text), Text => Self::Text,
Binary => Ok(Self::Binary), Binary => Self::Binary,
Close => Ok(Self::Close), Close => Self::Close,
Ping => Err(Self::Error::WsImplNotSupported), Ping => Self::Ping,
Pong => Err(Self::Error::WsImplNotSupported), Pong => Self::Pong,
} }
} }
} }
impl TryFrom<Frame<'_>> for crate::ws::Frame { impl From<Frame<'_>> for crate::ws::Frame {
type Error = crate::WispError; fn from(mut frame: Frame) -> Self {
fn try_from(mut frame: Frame) -> Result<Self, Self::Error> { Self {
let opcode = frame.opcode.try_into()?;
Ok(Self {
finished: frame.fin, finished: frame.fin,
opcode, opcode: frame.opcode.into(),
payload: Bytes::copy_from_slice(frame.payload.to_mut()), payload: Bytes::copy_from_slice(frame.payload.to_mut()),
}) }
} }
} }
@ -34,7 +34,38 @@ impl From<crate::ws::Frame> for Frame<'_> {
match frame.opcode { match frame.opcode {
Text => Self::text(Payload::Owned(frame.payload.to_vec())), Text => Self::text(Payload::Owned(frame.payload.to_vec())),
Binary => Self::binary(Payload::Owned(frame.payload.to_vec())), Binary => Self::binary(Payload::Owned(frame.payload.to_vec())),
Close => Self::close_raw(Payload::Owned(frame.payload.to_vec())) Close => Self::close_raw(Payload::Owned(frame.payload.to_vec())),
Ping => Self::new(
true,
OpCode::Ping,
None,
Payload::Owned(frame.payload.to_vec()),
),
Pong => Self::pong(Payload::Owned(frame.payload.to_vec())),
} }
} }
} }
impl From<WebSocketError> for crate::WispError {
fn from(err: WebSocketError) -> Self {
Self::WsImplError(Box::new(err))
}
}
impl<S: AsyncRead + Unpin> crate::ws::WebSocketRead for FragmentCollectorRead<S> {
async fn wisp_read_frame(
&mut self,
tx: &mut crate::ws::LockedWebSocketWrite<impl crate::ws::WebSocketWrite>,
) -> Result<crate::ws::Frame, crate::WispError> {
Ok(self
.read_frame(&mut |frame| async { tx.write_frame(frame.into()).await })
.await?
.into())
}
}
impl<S: AsyncWrite + Unpin> crate::ws::WebSocketWrite for WebSocketWrite<S> {
async fn wisp_write_frame(&mut self, frame: crate::ws::Frame) -> Result<(), crate::WispError> {
self.write_frame(frame.into()).await.map_err(|e| e.into())
}
}

View file

@ -5,6 +5,11 @@ pub mod ws;
pub use crate::packet::*; pub use crate::packet::*;
use bytes::Bytes;
use dashmap::DashMap;
use futures::{channel::mpsc, StreamExt};
use std::sync::Arc;
#[derive(Debug, PartialEq)] #[derive(Debug, PartialEq)]
pub enum Role { pub enum Role {
Client, Client,
@ -15,11 +20,13 @@ pub enum Role {
pub enum WispError { pub enum WispError {
PacketTooSmall, PacketTooSmall,
InvalidPacketType, InvalidPacketType,
InvalidStreamType,
WsFrameInvalidType, WsFrameInvalidType,
WsFrameNotFinished, WsFrameNotFinished,
WsImplError(Box<dyn std::error::Error>), WsImplError(Box<dyn std::error::Error + Sync + Send>),
WsImplNotSupported, WsImplNotSupported,
Utf8Error(std::str::Utf8Error), Utf8Error(std::str::Utf8Error),
Other(Box<dyn std::error::Error + Sync + Send>),
} }
impl From<std::str::Utf8Error> for WispError { impl From<std::str::Utf8Error> for WispError {
@ -34,13 +41,112 @@ impl std::fmt::Display for WispError {
match self { match self {
PacketTooSmall => write!(f, "Packet too small"), PacketTooSmall => write!(f, "Packet too small"),
InvalidPacketType => write!(f, "Invalid packet type"), InvalidPacketType => write!(f, "Invalid packet type"),
InvalidStreamType => write!(f, "Invalid stream type"),
WsFrameInvalidType => write!(f, "Invalid websocket frame type"), WsFrameInvalidType => write!(f, "Invalid websocket frame type"),
WsFrameNotFinished => write!(f, "Unfinished websocket frame"), WsFrameNotFinished => write!(f, "Unfinished websocket frame"),
WsImplError(err) => write!(f, "Websocket implementation error: {:?}", err), WsImplError(err) => write!(f, "Websocket implementation error: {:?}", err),
WsImplNotSupported => write!(f, "Websocket implementation error: unsupported feature"), WsImplNotSupported => write!(f, "Websocket implementation error: unsupported feature"),
Utf8Error(err) => write!(f, "UTF-8 error: {:?}", err), Utf8Error(err) => write!(f, "UTF-8 error: {:?}", err),
Other(err) => write!(f, "Other error: {:?}", err),
} }
} }
} }
impl std::error::Error for WispError {} impl std::error::Error for WispError {}
pub enum WsEvent {
Send(Bytes),
Close(ClosePacket),
}
pub struct MuxStream<W>
where
W: ws::WebSocketWrite,
{
pub stream_id: u32,
rx: mpsc::UnboundedReceiver<WsEvent>,
tx: ws::LockedWebSocketWrite<W>,
}
impl<W: ws::WebSocketWrite> MuxStream<W> {
pub async fn read(&mut self) -> Option<WsEvent> {
self.rx.next().await
}
pub async fn write(&mut self, data: Bytes) -> Result<(), WispError> {
self.tx
.write_frame(ws::Frame::from(Packet::new_data(self.stream_id, data)))
.await
}
pub fn get_write_half(&self) -> ws::LockedWebSocketWrite<W> {
self.tx.clone()
}
}
pub struct ServerMux<R, W>
where
R: ws::WebSocketRead,
W: ws::WebSocketWrite,
{
rx: R,
tx: ws::LockedWebSocketWrite<W>,
stream_map: Arc<DashMap<u32, mpsc::UnboundedSender<WsEvent>>>,
}
impl<R: ws::WebSocketRead, W: ws::WebSocketWrite> ServerMux<R, W> {
pub fn new(read: R, write: W) -> Self {
Self {
rx: read,
tx: ws::LockedWebSocketWrite::new(write),
stream_map: Arc::new(DashMap::new()),
}
}
pub async fn server_loop<FR>(
&mut self,
handler_fn: &mut impl Fn(ConnectPacket, MuxStream<W>) -> FR,
) -> Result<(), WispError>
where
FR: std::future::Future<Output = Result<(), crate::WispError>>,
{
self.tx
.write_frame(ws::Frame::from(Packet::new_continue(0, u32::MAX)))
.await?;
while let Ok(frame) = self.rx.wisp_read_frame(&mut self.tx).await {
if let Ok(packet) = Packet::try_from(frame) {
use PacketType::*;
match packet.packet {
Connect(inner_packet) => {
let (ch_tx, ch_rx) = mpsc::unbounded();
self.stream_map.clone().insert(packet.stream_id, ch_tx);
let _ = handler_fn(
inner_packet,
MuxStream {
stream_id: packet.stream_id,
rx: ch_rx,
tx: self.tx.clone(),
},
).await;
}
Data(data) => {
if let Some(stream) = self.stream_map.clone().get(&packet.stream_id) {
let _ = stream.unbounded_send(WsEvent::Send(data));
self.tx
.write_frame(ws::Frame::from(Packet::new_continue(packet.stream_id, u32::MAX)))
.await?;
}
}
Continue(_) => unreachable!(),
Close(inner_packet) => {
if let Some(stream) = self.stream_map.clone().get(&packet.stream_id) {
let _ = stream.unbounded_send(WsEvent::Close(inner_packet));
}
}
}
}
}
Ok(())
}
}

View file

@ -2,15 +2,33 @@ use crate::ws;
use crate::WispError; use crate::WispError;
use bytes::{Buf, BufMut, Bytes}; use bytes::{Buf, BufMut, Bytes};
#[derive(Debug)]
pub enum StreamType {
Tcp = 0x01,
Udp = 0x02,
}
impl TryFrom<u8> for StreamType {
type Error = WispError;
fn try_from(stream_type: u8) -> Result<Self, Self::Error> {
use StreamType::*;
match stream_type {
0x01 => Ok(Tcp),
0x02 => Ok(Udp),
_ => Err(Self::Error::InvalidStreamType),
}
}
}
#[derive(Debug)] #[derive(Debug)]
pub struct ConnectPacket { pub struct ConnectPacket {
pub stream_type: u8, pub stream_type: StreamType,
pub destination_port: u16, pub destination_port: u16,
pub destination_hostname: String, pub destination_hostname: String,
} }
impl ConnectPacket { impl ConnectPacket {
pub fn new(stream_type: u8, destination_port: u16, destination_hostname: String) -> Self { pub fn new(stream_type: StreamType, destination_port: u16, destination_hostname: String) -> Self {
Self { Self {
stream_type, stream_type,
destination_port, destination_port,
@ -26,7 +44,7 @@ impl TryFrom<Bytes> for ConnectPacket {
return Err(Self::Error::PacketTooSmall); return Err(Self::Error::PacketTooSmall);
} }
Ok(Self { Ok(Self {
stream_type: bytes.get_u8(), stream_type: bytes.get_u8().try_into()?,
destination_port: bytes.get_u16_le(), destination_port: bytes.get_u16_le(),
destination_hostname: std::str::from_utf8(&bytes)?.to_string(), destination_hostname: std::str::from_utf8(&bytes)?.to_string(),
}) })
@ -36,7 +54,7 @@ impl TryFrom<Bytes> for ConnectPacket {
impl From<ConnectPacket> for Vec<u8> { impl From<ConnectPacket> for Vec<u8> {
fn from(packet: ConnectPacket) -> Self { fn from(packet: ConnectPacket) -> Self {
let mut encoded = Self::with_capacity(1 + 2 + packet.destination_hostname.len()); let mut encoded = Self::with_capacity(1 + 2 + packet.destination_hostname.len());
encoded.put_u8(packet.stream_type); encoded.put_u8(packet.stream_type as u8);
encoded.put_u16_le(packet.destination_port); encoded.put_u16_le(packet.destination_port);
encoded.extend(packet.destination_hostname.bytes()); encoded.extend(packet.destination_hostname.bytes());
encoded encoded
@ -108,7 +126,7 @@ impl From<ClosePacket> for Vec<u8> {
#[derive(Debug)] #[derive(Debug)]
pub enum PacketType { pub enum PacketType {
Connect(ConnectPacket), Connect(ConnectPacket),
Data(Vec<u8>), Data(Bytes),
Continue(ContinuePacket), Continue(ContinuePacket),
Close(ClosePacket), Close(ClosePacket),
} }
@ -130,7 +148,7 @@ impl From<PacketType> for Vec<u8> {
use PacketType::*; use PacketType::*;
match packet { match packet {
Connect(x) => x.into(), Connect(x) => x.into(),
Data(x) => x, Data(x) => x.to_vec(),
Continue(x) => x.into(), Continue(x) => x.into(),
Close(x) => x.into(), Close(x) => x.into(),
} }
@ -150,7 +168,7 @@ impl Packet {
pub fn new_connect( pub fn new_connect(
stream_id: u32, stream_id: u32,
stream_type: u8, stream_type: StreamType,
destination_port: u16, destination_port: u16,
destination_hostname: String, destination_hostname: String,
) -> Self { ) -> Self {
@ -164,7 +182,7 @@ impl Packet {
} }
} }
pub fn new_data(stream_id: u32, data: Vec<u8>) -> Self { pub fn new_data(stream_id: u32, data: Bytes) -> Self {
Self { Self {
stream_id, stream_id,
packet: PacketType::Data(data), packet: PacketType::Data(data),
@ -198,7 +216,7 @@ impl TryFrom<Bytes> for Packet {
stream_id: bytes.get_u32_le(), stream_id: bytes.get_u32_le(),
packet: match packet_type { packet: match packet_type {
0x01 => Connect(ConnectPacket::try_from(bytes)?), 0x01 => Connect(ConnectPacket::try_from(bytes)?),
0x02 => Data(bytes.to_vec()), 0x02 => Data(bytes),
0x03 => Continue(ContinuePacket::try_from(bytes)?), 0x03 => Continue(ContinuePacket::try_from(bytes)?),
0x04 => Close(ClosePacket::try_from(bytes)?), 0x04 => Close(ClosePacket::try_from(bytes)?),
_ => return Err(Self::Error::InvalidPacketType), _ => return Err(Self::Error::InvalidPacketType),

View file

@ -1,10 +1,14 @@
use bytes::Bytes; use bytes::Bytes;
use futures::lock::Mutex;
use std::sync::Arc;
#[derive(Debug, PartialEq, Clone, Copy)] #[derive(Debug, PartialEq, Clone, Copy)]
pub enum OpCode { pub enum OpCode {
Text, Text,
Binary, Binary,
Close, Close,
Ping,
Pong,
} }
pub struct Frame { pub struct Frame {
@ -38,3 +42,37 @@ impl Frame {
} }
} }
} }
pub trait WebSocketRead {
fn wisp_read_frame(
&mut self,
tx: &mut crate::ws::LockedWebSocketWrite<impl crate::ws::WebSocketWrite>,
) -> impl std::future::Future<Output = Result<Frame, crate::WispError>>;
}
pub trait WebSocketWrite {
fn wisp_write_frame(
&mut self,
frame: Frame,
) -> impl std::future::Future<Output = Result<(), crate::WispError>>;
}
pub struct LockedWebSocketWrite<S>(Arc<Mutex<S>>)
where
S: WebSocketWrite;
impl<S: WebSocketWrite> LockedWebSocketWrite<S> {
pub fn new(ws: S) -> Self {
Self(Arc::new(Mutex::new(ws)))
}
pub async fn write_frame(&self, frame: Frame) -> Result<(), crate::WispError> {
self.0.lock().await.wisp_write_frame(frame).await
}
}
impl<S: WebSocketWrite> Clone for LockedWebSocketWrite<S> {
fn clone(&self) -> Self {
Self(self.0.clone())
}
}