mirror of
https://github.com/MercuryWorkshop/epoxy-tls.git
synced 2025-05-13 06:20:02 -04:00
move the wisp logic into wisp lib
This commit is contained in:
parent
379e07d643
commit
2a5684192a
8 changed files with 314 additions and 198 deletions
|
@ -1,30 +1,30 @@
|
|||
use bytes::Bytes;
|
||||
use fastwebsockets::{Payload, Frame, OpCode};
|
||||
use fastwebsockets::{
|
||||
FragmentCollectorRead, Frame, OpCode, Payload, WebSocketError, WebSocketWrite,
|
||||
};
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
|
||||
impl TryFrom<OpCode> for crate::ws::OpCode {
|
||||
type Error = crate::WispError;
|
||||
fn try_from(opcode: OpCode) -> Result<Self, Self::Error> {
|
||||
impl From<OpCode> for crate::ws::OpCode {
|
||||
fn from(opcode: OpCode) -> Self {
|
||||
use OpCode::*;
|
||||
match opcode {
|
||||
Continuation => Err(Self::Error::WsImplNotSupported),
|
||||
Text => Ok(Self::Text),
|
||||
Binary => Ok(Self::Binary),
|
||||
Close => Ok(Self::Close),
|
||||
Ping => Err(Self::Error::WsImplNotSupported),
|
||||
Pong => Err(Self::Error::WsImplNotSupported),
|
||||
Continuation => unreachable!(),
|
||||
Text => Self::Text,
|
||||
Binary => Self::Binary,
|
||||
Close => Self::Close,
|
||||
Ping => Self::Ping,
|
||||
Pong => Self::Pong,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<Frame<'_>> for crate::ws::Frame {
|
||||
type Error = crate::WispError;
|
||||
fn try_from(mut frame: Frame) -> Result<Self, Self::Error> {
|
||||
let opcode = frame.opcode.try_into()?;
|
||||
Ok(Self {
|
||||
impl From<Frame<'_>> for crate::ws::Frame {
|
||||
fn from(mut frame: Frame) -> Self {
|
||||
Self {
|
||||
finished: frame.fin,
|
||||
opcode,
|
||||
opcode: frame.opcode.into(),
|
||||
payload: Bytes::copy_from_slice(frame.payload.to_mut()),
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -34,7 +34,38 @@ impl From<crate::ws::Frame> for Frame<'_> {
|
|||
match frame.opcode {
|
||||
Text => Self::text(Payload::Owned(frame.payload.to_vec())),
|
||||
Binary => Self::binary(Payload::Owned(frame.payload.to_vec())),
|
||||
Close => Self::close_raw(Payload::Owned(frame.payload.to_vec()))
|
||||
Close => Self::close_raw(Payload::Owned(frame.payload.to_vec())),
|
||||
Ping => Self::new(
|
||||
true,
|
||||
OpCode::Ping,
|
||||
None,
|
||||
Payload::Owned(frame.payload.to_vec()),
|
||||
),
|
||||
Pong => Self::pong(Payload::Owned(frame.payload.to_vec())),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<WebSocketError> for crate::WispError {
|
||||
fn from(err: WebSocketError) -> Self {
|
||||
Self::WsImplError(Box::new(err))
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: AsyncRead + Unpin> crate::ws::WebSocketRead for FragmentCollectorRead<S> {
|
||||
async fn wisp_read_frame(
|
||||
&mut self,
|
||||
tx: &mut crate::ws::LockedWebSocketWrite<impl crate::ws::WebSocketWrite>,
|
||||
) -> Result<crate::ws::Frame, crate::WispError> {
|
||||
Ok(self
|
||||
.read_frame(&mut |frame| async { tx.write_frame(frame.into()).await })
|
||||
.await?
|
||||
.into())
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: AsyncWrite + Unpin> crate::ws::WebSocketWrite for WebSocketWrite<S> {
|
||||
async fn wisp_write_frame(&mut self, frame: crate::ws::Frame) -> Result<(), crate::WispError> {
|
||||
self.write_frame(frame.into()).await.map_err(|e| e.into())
|
||||
}
|
||||
}
|
||||
|
|
108
wisp/src/lib.rs
108
wisp/src/lib.rs
|
@ -5,6 +5,11 @@ pub mod ws;
|
|||
|
||||
pub use crate::packet::*;
|
||||
|
||||
use bytes::Bytes;
|
||||
use dashmap::DashMap;
|
||||
use futures::{channel::mpsc, StreamExt};
|
||||
use std::sync::Arc;
|
||||
|
||||
#[derive(Debug, PartialEq)]
|
||||
pub enum Role {
|
||||
Client,
|
||||
|
@ -15,11 +20,13 @@ pub enum Role {
|
|||
pub enum WispError {
|
||||
PacketTooSmall,
|
||||
InvalidPacketType,
|
||||
InvalidStreamType,
|
||||
WsFrameInvalidType,
|
||||
WsFrameNotFinished,
|
||||
WsImplError(Box<dyn std::error::Error>),
|
||||
WsImplError(Box<dyn std::error::Error + Sync + Send>),
|
||||
WsImplNotSupported,
|
||||
Utf8Error(std::str::Utf8Error),
|
||||
Other(Box<dyn std::error::Error + Sync + Send>),
|
||||
}
|
||||
|
||||
impl From<std::str::Utf8Error> for WispError {
|
||||
|
@ -34,13 +41,112 @@ impl std::fmt::Display for WispError {
|
|||
match self {
|
||||
PacketTooSmall => write!(f, "Packet too small"),
|
||||
InvalidPacketType => write!(f, "Invalid packet type"),
|
||||
InvalidStreamType => write!(f, "Invalid stream type"),
|
||||
WsFrameInvalidType => write!(f, "Invalid websocket frame type"),
|
||||
WsFrameNotFinished => write!(f, "Unfinished websocket frame"),
|
||||
WsImplError(err) => write!(f, "Websocket implementation error: {:?}", err),
|
||||
WsImplNotSupported => write!(f, "Websocket implementation error: unsupported feature"),
|
||||
Utf8Error(err) => write!(f, "UTF-8 error: {:?}", err),
|
||||
Other(err) => write!(f, "Other error: {:?}", err),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::error::Error for WispError {}
|
||||
|
||||
pub enum WsEvent {
|
||||
Send(Bytes),
|
||||
Close(ClosePacket),
|
||||
}
|
||||
|
||||
pub struct MuxStream<W>
|
||||
where
|
||||
W: ws::WebSocketWrite,
|
||||
{
|
||||
pub stream_id: u32,
|
||||
rx: mpsc::UnboundedReceiver<WsEvent>,
|
||||
tx: ws::LockedWebSocketWrite<W>,
|
||||
}
|
||||
|
||||
impl<W: ws::WebSocketWrite> MuxStream<W> {
|
||||
pub async fn read(&mut self) -> Option<WsEvent> {
|
||||
self.rx.next().await
|
||||
}
|
||||
|
||||
pub async fn write(&mut self, data: Bytes) -> Result<(), WispError> {
|
||||
self.tx
|
||||
.write_frame(ws::Frame::from(Packet::new_data(self.stream_id, data)))
|
||||
.await
|
||||
}
|
||||
|
||||
pub fn get_write_half(&self) -> ws::LockedWebSocketWrite<W> {
|
||||
self.tx.clone()
|
||||
}
|
||||
}
|
||||
|
||||
pub struct ServerMux<R, W>
|
||||
where
|
||||
R: ws::WebSocketRead,
|
||||
W: ws::WebSocketWrite,
|
||||
{
|
||||
rx: R,
|
||||
tx: ws::LockedWebSocketWrite<W>,
|
||||
stream_map: Arc<DashMap<u32, mpsc::UnboundedSender<WsEvent>>>,
|
||||
}
|
||||
|
||||
impl<R: ws::WebSocketRead, W: ws::WebSocketWrite> ServerMux<R, W> {
|
||||
pub fn new(read: R, write: W) -> Self {
|
||||
Self {
|
||||
rx: read,
|
||||
tx: ws::LockedWebSocketWrite::new(write),
|
||||
stream_map: Arc::new(DashMap::new()),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn server_loop<FR>(
|
||||
&mut self,
|
||||
handler_fn: &mut impl Fn(ConnectPacket, MuxStream<W>) -> FR,
|
||||
) -> Result<(), WispError>
|
||||
where
|
||||
FR: std::future::Future<Output = Result<(), crate::WispError>>,
|
||||
{
|
||||
self.tx
|
||||
.write_frame(ws::Frame::from(Packet::new_continue(0, u32::MAX)))
|
||||
.await?;
|
||||
|
||||
while let Ok(frame) = self.rx.wisp_read_frame(&mut self.tx).await {
|
||||
if let Ok(packet) = Packet::try_from(frame) {
|
||||
use PacketType::*;
|
||||
match packet.packet {
|
||||
Connect(inner_packet) => {
|
||||
let (ch_tx, ch_rx) = mpsc::unbounded();
|
||||
self.stream_map.clone().insert(packet.stream_id, ch_tx);
|
||||
let _ = handler_fn(
|
||||
inner_packet,
|
||||
MuxStream {
|
||||
stream_id: packet.stream_id,
|
||||
rx: ch_rx,
|
||||
tx: self.tx.clone(),
|
||||
},
|
||||
).await;
|
||||
}
|
||||
Data(data) => {
|
||||
if let Some(stream) = self.stream_map.clone().get(&packet.stream_id) {
|
||||
let _ = stream.unbounded_send(WsEvent::Send(data));
|
||||
self.tx
|
||||
.write_frame(ws::Frame::from(Packet::new_continue(packet.stream_id, u32::MAX)))
|
||||
.await?;
|
||||
}
|
||||
}
|
||||
Continue(_) => unreachable!(),
|
||||
Close(inner_packet) => {
|
||||
if let Some(stream) = self.stream_map.clone().get(&packet.stream_id) {
|
||||
let _ = stream.unbounded_send(WsEvent::Close(inner_packet));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
|
|
@ -2,15 +2,33 @@ use crate::ws;
|
|||
use crate::WispError;
|
||||
use bytes::{Buf, BufMut, Bytes};
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum StreamType {
|
||||
Tcp = 0x01,
|
||||
Udp = 0x02,
|
||||
}
|
||||
|
||||
impl TryFrom<u8> for StreamType {
|
||||
type Error = WispError;
|
||||
fn try_from(stream_type: u8) -> Result<Self, Self::Error> {
|
||||
use StreamType::*;
|
||||
match stream_type {
|
||||
0x01 => Ok(Tcp),
|
||||
0x02 => Ok(Udp),
|
||||
_ => Err(Self::Error::InvalidStreamType),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct ConnectPacket {
|
||||
pub stream_type: u8,
|
||||
pub stream_type: StreamType,
|
||||
pub destination_port: u16,
|
||||
pub destination_hostname: String,
|
||||
}
|
||||
|
||||
impl ConnectPacket {
|
||||
pub fn new(stream_type: u8, destination_port: u16, destination_hostname: String) -> Self {
|
||||
pub fn new(stream_type: StreamType, destination_port: u16, destination_hostname: String) -> Self {
|
||||
Self {
|
||||
stream_type,
|
||||
destination_port,
|
||||
|
@ -26,7 +44,7 @@ impl TryFrom<Bytes> for ConnectPacket {
|
|||
return Err(Self::Error::PacketTooSmall);
|
||||
}
|
||||
Ok(Self {
|
||||
stream_type: bytes.get_u8(),
|
||||
stream_type: bytes.get_u8().try_into()?,
|
||||
destination_port: bytes.get_u16_le(),
|
||||
destination_hostname: std::str::from_utf8(&bytes)?.to_string(),
|
||||
})
|
||||
|
@ -36,7 +54,7 @@ impl TryFrom<Bytes> for ConnectPacket {
|
|||
impl From<ConnectPacket> for Vec<u8> {
|
||||
fn from(packet: ConnectPacket) -> Self {
|
||||
let mut encoded = Self::with_capacity(1 + 2 + packet.destination_hostname.len());
|
||||
encoded.put_u8(packet.stream_type);
|
||||
encoded.put_u8(packet.stream_type as u8);
|
||||
encoded.put_u16_le(packet.destination_port);
|
||||
encoded.extend(packet.destination_hostname.bytes());
|
||||
encoded
|
||||
|
@ -108,7 +126,7 @@ impl From<ClosePacket> for Vec<u8> {
|
|||
#[derive(Debug)]
|
||||
pub enum PacketType {
|
||||
Connect(ConnectPacket),
|
||||
Data(Vec<u8>),
|
||||
Data(Bytes),
|
||||
Continue(ContinuePacket),
|
||||
Close(ClosePacket),
|
||||
}
|
||||
|
@ -130,7 +148,7 @@ impl From<PacketType> for Vec<u8> {
|
|||
use PacketType::*;
|
||||
match packet {
|
||||
Connect(x) => x.into(),
|
||||
Data(x) => x,
|
||||
Data(x) => x.to_vec(),
|
||||
Continue(x) => x.into(),
|
||||
Close(x) => x.into(),
|
||||
}
|
||||
|
@ -150,7 +168,7 @@ impl Packet {
|
|||
|
||||
pub fn new_connect(
|
||||
stream_id: u32,
|
||||
stream_type: u8,
|
||||
stream_type: StreamType,
|
||||
destination_port: u16,
|
||||
destination_hostname: String,
|
||||
) -> Self {
|
||||
|
@ -164,7 +182,7 @@ impl Packet {
|
|||
}
|
||||
}
|
||||
|
||||
pub fn new_data(stream_id: u32, data: Vec<u8>) -> Self {
|
||||
pub fn new_data(stream_id: u32, data: Bytes) -> Self {
|
||||
Self {
|
||||
stream_id,
|
||||
packet: PacketType::Data(data),
|
||||
|
@ -198,7 +216,7 @@ impl TryFrom<Bytes> for Packet {
|
|||
stream_id: bytes.get_u32_le(),
|
||||
packet: match packet_type {
|
||||
0x01 => Connect(ConnectPacket::try_from(bytes)?),
|
||||
0x02 => Data(bytes.to_vec()),
|
||||
0x02 => Data(bytes),
|
||||
0x03 => Continue(ContinuePacket::try_from(bytes)?),
|
||||
0x04 => Close(ClosePacket::try_from(bytes)?),
|
||||
_ => return Err(Self::Error::InvalidPacketType),
|
||||
|
|
|
@ -1,10 +1,14 @@
|
|||
use bytes::Bytes;
|
||||
use futures::lock::Mutex;
|
||||
use std::sync::Arc;
|
||||
|
||||
#[derive(Debug, PartialEq, Clone, Copy)]
|
||||
pub enum OpCode {
|
||||
Text,
|
||||
Binary,
|
||||
Close,
|
||||
Ping,
|
||||
Pong,
|
||||
}
|
||||
|
||||
pub struct Frame {
|
||||
|
@ -38,3 +42,37 @@ impl Frame {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub trait WebSocketRead {
|
||||
fn wisp_read_frame(
|
||||
&mut self,
|
||||
tx: &mut crate::ws::LockedWebSocketWrite<impl crate::ws::WebSocketWrite>,
|
||||
) -> impl std::future::Future<Output = Result<Frame, crate::WispError>>;
|
||||
}
|
||||
|
||||
pub trait WebSocketWrite {
|
||||
fn wisp_write_frame(
|
||||
&mut self,
|
||||
frame: Frame,
|
||||
) -> impl std::future::Future<Output = Result<(), crate::WispError>>;
|
||||
}
|
||||
|
||||
pub struct LockedWebSocketWrite<S>(Arc<Mutex<S>>)
|
||||
where
|
||||
S: WebSocketWrite;
|
||||
|
||||
impl<S: WebSocketWrite> LockedWebSocketWrite<S> {
|
||||
pub fn new(ws: S) -> Self {
|
||||
Self(Arc::new(Mutex::new(ws)))
|
||||
}
|
||||
|
||||
pub async fn write_frame(&self, frame: Frame) -> Result<(), crate::WispError> {
|
||||
self.0.lock().await.wisp_write_frame(frame).await
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: WebSocketWrite> Clone for LockedWebSocketWrite<S> {
|
||||
fn clone(&self) -> Self {
|
||||
Self(self.0.clone())
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue