mirror of
https://github.com/MercuryWorkshop/epoxy-tls.git
synced 2025-05-12 22:10:01 -04:00
move the wisp logic into wisp lib
This commit is contained in:
parent
379e07d643
commit
2a5684192a
8 changed files with 314 additions and 198 deletions
2
Cargo.lock
generated
2
Cargo.lock
generated
|
@ -1530,9 +1530,11 @@ name = "wisp-mux"
|
||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"bytes",
|
"bytes",
|
||||||
|
"dashmap",
|
||||||
"fastwebsockets",
|
"fastwebsockets",
|
||||||
"futures",
|
"futures",
|
||||||
"futures-util",
|
"futures-util",
|
||||||
|
"tokio",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
|
|
@ -1,20 +0,0 @@
|
||||||
use fastwebsockets::{WebSocketWrite, Frame, WebSocketError};
|
|
||||||
use hyper::upgrade::Upgraded;
|
|
||||||
use hyper_util::rt::TokioIo;
|
|
||||||
use std::sync::Arc;
|
|
||||||
use tokio::{io::WriteHalf, sync::Mutex};
|
|
||||||
|
|
||||||
type Ws = WebSocketWrite<WriteHalf<TokioIo<Upgraded>>>;
|
|
||||||
|
|
||||||
#[derive(Clone)]
|
|
||||||
pub struct LockedWebSocketWrite(Arc<Mutex<Ws>>);
|
|
||||||
|
|
||||||
impl LockedWebSocketWrite {
|
|
||||||
pub fn new(ws: Ws) -> Self {
|
|
||||||
Self(Arc::new(Mutex::new(ws)))
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn write_frame(&self, frame: Frame<'_>) -> Result<(), WebSocketError> {
|
|
||||||
self.0.lock().await.write_frame(frame).await
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,25 +1,22 @@
|
||||||
mod lockedws;
|
#![feature(let_chains)]
|
||||||
|
use std::io::Error;
|
||||||
use std::{io::Error, sync::Arc};
|
|
||||||
|
|
||||||
use bytes::Bytes;
|
use bytes::Bytes;
|
||||||
use fastwebsockets::{
|
use fastwebsockets::{
|
||||||
upgrade, CloseCode, FragmentCollector, Frame, OpCode, Payload, WebSocketError,
|
upgrade, CloseCode, FragmentCollector, FragmentCollectorRead, Frame, OpCode, Payload,
|
||||||
|
WebSocketError,
|
||||||
};
|
};
|
||||||
use futures_util::{SinkExt, StreamExt};
|
use futures_util::{SinkExt, StreamExt, TryFutureExt};
|
||||||
use hyper::{
|
use hyper::{
|
||||||
body::Incoming, header::HeaderValue, server::conn::http1, service::service_fn, Request,
|
body::Incoming, header::HeaderValue, server::conn::http1, service::service_fn, Request,
|
||||||
Response, StatusCode,
|
Response, StatusCode,
|
||||||
};
|
};
|
||||||
use hyper_util::rt::TokioIo;
|
use hyper_util::rt::TokioIo;
|
||||||
use tokio::{
|
use tokio::net::{TcpListener, TcpStream};
|
||||||
net::{TcpListener, TcpStream},
|
|
||||||
sync::mpsc,
|
|
||||||
};
|
|
||||||
use tokio_native_tls::{native_tls, TlsAcceptor};
|
use tokio_native_tls::{native_tls, TlsAcceptor};
|
||||||
use tokio_util::codec::{BytesCodec, Framed};
|
use tokio_util::codec::{BytesCodec, Framed};
|
||||||
|
|
||||||
use wisp_mux::{ws, Packet, PacketType};
|
use wisp_mux::{ws, ConnectPacket, MuxStream, Packet, ServerMux, StreamType, WispError, WsEvent};
|
||||||
|
|
||||||
type HttpBody = http_body_util::Empty<hyper::body::Bytes>;
|
type HttpBody = http_body_util::Empty<hyper::body::Bytes>;
|
||||||
|
|
||||||
|
@ -73,37 +70,26 @@ async fn accept_http(
|
||||||
addr: String,
|
addr: String,
|
||||||
prefix: String,
|
prefix: String,
|
||||||
) -> Result<Response<HttpBody>, WebSocketError> {
|
) -> Result<Response<HttpBody>, WebSocketError> {
|
||||||
if upgrade::is_upgrade_request(&req) && req.uri().path().to_string().starts_with(&prefix) {
|
if upgrade::is_upgrade_request(&req)
|
||||||
|
&& req.uri().path().to_string().starts_with(&prefix)
|
||||||
|
&& let Some(protocol) = req.headers().get("Sec-Websocket-Protocol")
|
||||||
|
&& protocol == "wisp-v1"
|
||||||
|
{
|
||||||
let uri = req.uri().clone();
|
let uri = req.uri().clone();
|
||||||
let (mut res, fut) = upgrade::upgrade(&mut req)?;
|
let (mut res, fut) = upgrade::upgrade(&mut req)?;
|
||||||
|
|
||||||
tokio::spawn(async move {
|
if *uri.path() != prefix {
|
||||||
if *uri.path() != prefix {
|
tokio::spawn(async move {
|
||||||
if let Err(e) =
|
accept_wsproxy(fut, uri.path().strip_prefix(&prefix).unwrap(), addr.clone()).await
|
||||||
accept_wsproxy(fut, uri.path().strip_prefix(&prefix).unwrap(), addr.clone())
|
});
|
||||||
.await
|
} else {
|
||||||
{
|
tokio::spawn(async move { accept_ws(fut, addr.clone()).await });
|
||||||
println!("{:?}: error in ws handling: {:?}", addr, e);
|
|
||||||
}
|
|
||||||
} else if let Err(e) = accept_ws(fut, addr.clone()).await {
|
|
||||||
println!("{:?}: error in ws handling: {:?}", addr, e);
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
if let Some(protocol) = req.headers().get("Sec-Websocket-Protocol") {
|
|
||||||
let first_protocol = protocol
|
|
||||||
.to_str()
|
|
||||||
.expect("failed to get protocol")
|
|
||||||
.split(',')
|
|
||||||
.next()
|
|
||||||
.expect("failed to get first protocol")
|
|
||||||
.trim();
|
|
||||||
res.headers_mut().insert(
|
|
||||||
"Sec-Websocket-Protocol",
|
|
||||||
HeaderValue::from_str(first_protocol).unwrap(),
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
res.headers_mut().insert(
|
||||||
|
"Sec-Websocket-Protocol",
|
||||||
|
HeaderValue::from_str("wisp-v1").unwrap(),
|
||||||
|
);
|
||||||
Ok(res)
|
Ok(res)
|
||||||
} else {
|
} else {
|
||||||
Ok(Response::builder()
|
Ok(Response::builder()
|
||||||
|
@ -113,127 +99,80 @@ async fn accept_http(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
enum WsEvent {
|
async fn handle_mux(
|
||||||
Send(Bytes),
|
packet: ConnectPacket,
|
||||||
Close,
|
mut stream: MuxStream<impl ws::WebSocketWrite>,
|
||||||
|
) -> Result<(), WispError> {
|
||||||
|
let uri = format!(
|
||||||
|
"{}:{}",
|
||||||
|
packet.destination_hostname, packet.destination_port
|
||||||
|
);
|
||||||
|
match packet.stream_type {
|
||||||
|
StreamType::Tcp => {
|
||||||
|
let tcp_stream = TcpStream::connect(uri)
|
||||||
|
.await
|
||||||
|
.map_err(|x| WispError::Other(Box::new(x)))?;
|
||||||
|
let mut tcp_stream_framed = Framed::new(tcp_stream, BytesCodec::new());
|
||||||
|
|
||||||
|
loop {
|
||||||
|
tokio::select! {
|
||||||
|
event = stream.read() => {
|
||||||
|
match event {
|
||||||
|
Some(event) => match event {
|
||||||
|
WsEvent::Send(data) => {
|
||||||
|
tcp_stream_framed.send(data).await.map_err(|x| WispError::Other(Box::new(x)))?;
|
||||||
|
}
|
||||||
|
WsEvent::Close(_) => break,
|
||||||
|
},
|
||||||
|
None => break
|
||||||
|
}
|
||||||
|
},
|
||||||
|
event = tcp_stream_framed.next() => {
|
||||||
|
match event.and_then(|x| x.ok()) {
|
||||||
|
Some(event) => stream.write(event.into()).await?,
|
||||||
|
None => break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
StreamType::Udp => todo!(),
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn accept_ws(
|
async fn accept_ws(
|
||||||
fut: upgrade::UpgradeFut,
|
fut: upgrade::UpgradeFut,
|
||||||
addr: String,
|
addr: String,
|
||||||
) -> Result<(), Box<dyn std::error::Error>> {
|
) -> Result<(), Box<dyn std::error::Error + Sync + Send>> {
|
||||||
let (mut rx, tx) = fut.await?.split(tokio::io::split);
|
let (rx, tx) = fut.await?.split(tokio::io::split);
|
||||||
let tx = lockedws::LockedWebSocketWrite::new(tx);
|
let rx = FragmentCollectorRead::new(rx);
|
||||||
|
|
||||||
let stream_map = Arc::new(dashmap::DashMap::<u32, mpsc::UnboundedSender<WsEvent>>::new());
|
|
||||||
|
|
||||||
println!("{:?}: connected", addr);
|
println!("{:?}: connected", addr);
|
||||||
|
|
||||||
tx.write_frame(ws::Frame::from(Packet::new_continue(0, u32::MAX)).into())
|
let mut mux = ServerMux::new(rx, tx);
|
||||||
.await?;
|
|
||||||
|
|
||||||
while let Ok(frame) = rx.read_frame(&mut |x| tx.write_frame(x)).await {
|
mux.server_loop(&mut |packet, stream| async move {
|
||||||
if let Ok(packet) = Packet::try_from(ws::Frame::try_from(frame)?) {
|
let tx_cloned = stream.get_write_half();
|
||||||
use PacketType::*;
|
let stream_id = stream.stream_id;
|
||||||
match packet.packet {
|
tokio::spawn(async move {
|
||||||
Connect(inner_packet) => {
|
let _ = handle_mux(packet, stream)
|
||||||
let (ch_tx, mut ch_rx) = mpsc::unbounded_channel::<WsEvent>();
|
.or_else(|err| async {
|
||||||
stream_map.clone().insert(packet.stream_id, ch_tx);
|
let _ = tx_cloned
|
||||||
let tx_cloned = tx.clone();
|
.write_frame(ws::Frame::from(Packet::new_close(stream_id, 0x03)))
|
||||||
tokio::spawn(async move {
|
.await;
|
||||||
let tcp_stream = match TcpStream::connect(format!(
|
Err(err)
|
||||||
"{}:{}",
|
})
|
||||||
inner_packet.destination_hostname, inner_packet.destination_port
|
.and_then(|_| async {
|
||||||
))
|
tx_cloned
|
||||||
|
.write_frame(ws::Frame::from(Packet::new_close(stream_id, 0x02)))
|
||||||
.await
|
.await
|
||||||
{
|
})
|
||||||
Ok(stream) => stream,
|
.await;
|
||||||
Err(err) => {
|
});
|
||||||
tx_cloned
|
Ok(())
|
||||||
.write_frame(
|
})
|
||||||
ws::Frame::from(Packet::new_close(packet.stream_id, 0x03))
|
.await?;
|
||||||
.into(),
|
|
||||||
)
|
|
||||||
.await
|
|
||||||
.map_err(std::io::Error::other)?;
|
|
||||||
return Err(Box::new(err));
|
|
||||||
}
|
|
||||||
};
|
|
||||||
let mut tcp_stream = Framed::new(tcp_stream, BytesCodec::new());
|
|
||||||
loop {
|
|
||||||
tokio::select! {
|
|
||||||
event = tcp_stream.next() => {
|
|
||||||
if let Some(res) = event {
|
|
||||||
match res {
|
|
||||||
Ok(buf) => {
|
|
||||||
tx_cloned.write_frame(
|
|
||||||
ws::Frame::from(
|
|
||||||
Packet::new_data(
|
|
||||||
packet.stream_id,
|
|
||||||
buf.to_vec()
|
|
||||||
)
|
|
||||||
).into()
|
|
||||||
).await.map_err(std::io::Error::other)?;
|
|
||||||
}
|
|
||||||
Err(err) => {
|
|
||||||
tx_cloned
|
|
||||||
.write_frame(
|
|
||||||
ws::Frame::from(Packet::new_close(
|
|
||||||
packet.stream_id,
|
|
||||||
0x03,
|
|
||||||
))
|
|
||||||
.into(),
|
|
||||||
)
|
|
||||||
.await
|
|
||||||
.map_err(std::io::Error::other)?;
|
|
||||||
return Err(Box::new(err));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
event = ch_rx.recv() => {
|
|
||||||
if let Some(event) = event {
|
|
||||||
match event {
|
|
||||||
WsEvent::Send(buf) => {
|
|
||||||
tcp_stream.send(buf).await?;
|
|
||||||
tx_cloned
|
|
||||||
.write_frame(
|
|
||||||
ws::Frame::from(
|
|
||||||
Packet::new_continue(
|
|
||||||
packet.stream_id,
|
|
||||||
u32::MAX
|
|
||||||
)
|
|
||||||
).into()
|
|
||||||
).await.map_err(std::io::Error::other)?;
|
|
||||||
}
|
|
||||||
WsEvent::Close => {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
}
|
|
||||||
Ok(())
|
|
||||||
});
|
|
||||||
}
|
|
||||||
Data(inner_packet) => {
|
|
||||||
println!("recieved data for {:?}", packet.stream_id);
|
|
||||||
if let Some(stream) = stream_map.clone().get(&packet.stream_id) {
|
|
||||||
let _ = stream.send(WsEvent::Send(inner_packet.into()));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Continue(_) => unreachable!(),
|
|
||||||
Close(_) => {
|
|
||||||
if let Some(stream) = stream_map.clone().get(&packet.stream_id) {
|
|
||||||
let _ = stream.send(WsEvent::Close);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
println!("{:?}: disconnected", addr);
|
println!("{:?}: disconnected", addr);
|
||||||
Ok(())
|
Ok(())
|
||||||
|
@ -243,7 +182,7 @@ async fn accept_wsproxy(
|
||||||
fut: upgrade::UpgradeFut,
|
fut: upgrade::UpgradeFut,
|
||||||
incoming_uri: &str,
|
incoming_uri: &str,
|
||||||
addr: String,
|
addr: String,
|
||||||
) -> Result<(), Box<dyn std::error::Error>> {
|
) -> Result<(), Box<dyn std::error::Error + Sync + Send>> {
|
||||||
let mut ws_stream = FragmentCollector::new(fut.await?);
|
let mut ws_stream = FragmentCollector::new(fut.await?);
|
||||||
|
|
||||||
println!("{:?}: connected (wsproxy)", addr);
|
println!("{:?}: connected (wsproxy)", addr);
|
||||||
|
|
|
@ -5,9 +5,11 @@ edition = "2021"
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
bytes = "1.5.0"
|
bytes = "1.5.0"
|
||||||
fastwebsockets = { version = "0.6.0", optional = true }
|
dashmap = "5.5.3"
|
||||||
|
fastwebsockets = { version = "0.6.0", features = ["unstable-split"], optional = true }
|
||||||
futures = "0.3.30"
|
futures = "0.3.30"
|
||||||
futures-util = "0.3.30"
|
futures-util = "0.3.30"
|
||||||
|
tokio = { version = "1.35.1", optional = true }
|
||||||
|
|
||||||
[features]
|
[features]
|
||||||
fastwebsockets = ["dep:fastwebsockets"]
|
fastwebsockets = ["dep:fastwebsockets", "dep:tokio"]
|
||||||
|
|
|
@ -1,30 +1,30 @@
|
||||||
use bytes::Bytes;
|
use bytes::Bytes;
|
||||||
use fastwebsockets::{Payload, Frame, OpCode};
|
use fastwebsockets::{
|
||||||
|
FragmentCollectorRead, Frame, OpCode, Payload, WebSocketError, WebSocketWrite,
|
||||||
|
};
|
||||||
|
use tokio::io::{AsyncRead, AsyncWrite};
|
||||||
|
|
||||||
impl TryFrom<OpCode> for crate::ws::OpCode {
|
impl From<OpCode> for crate::ws::OpCode {
|
||||||
type Error = crate::WispError;
|
fn from(opcode: OpCode) -> Self {
|
||||||
fn try_from(opcode: OpCode) -> Result<Self, Self::Error> {
|
|
||||||
use OpCode::*;
|
use OpCode::*;
|
||||||
match opcode {
|
match opcode {
|
||||||
Continuation => Err(Self::Error::WsImplNotSupported),
|
Continuation => unreachable!(),
|
||||||
Text => Ok(Self::Text),
|
Text => Self::Text,
|
||||||
Binary => Ok(Self::Binary),
|
Binary => Self::Binary,
|
||||||
Close => Ok(Self::Close),
|
Close => Self::Close,
|
||||||
Ping => Err(Self::Error::WsImplNotSupported),
|
Ping => Self::Ping,
|
||||||
Pong => Err(Self::Error::WsImplNotSupported),
|
Pong => Self::Pong,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl TryFrom<Frame<'_>> for crate::ws::Frame {
|
impl From<Frame<'_>> for crate::ws::Frame {
|
||||||
type Error = crate::WispError;
|
fn from(mut frame: Frame) -> Self {
|
||||||
fn try_from(mut frame: Frame) -> Result<Self, Self::Error> {
|
Self {
|
||||||
let opcode = frame.opcode.try_into()?;
|
|
||||||
Ok(Self {
|
|
||||||
finished: frame.fin,
|
finished: frame.fin,
|
||||||
opcode,
|
opcode: frame.opcode.into(),
|
||||||
payload: Bytes::copy_from_slice(frame.payload.to_mut()),
|
payload: Bytes::copy_from_slice(frame.payload.to_mut()),
|
||||||
})
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -34,7 +34,38 @@ impl From<crate::ws::Frame> for Frame<'_> {
|
||||||
match frame.opcode {
|
match frame.opcode {
|
||||||
Text => Self::text(Payload::Owned(frame.payload.to_vec())),
|
Text => Self::text(Payload::Owned(frame.payload.to_vec())),
|
||||||
Binary => Self::binary(Payload::Owned(frame.payload.to_vec())),
|
Binary => Self::binary(Payload::Owned(frame.payload.to_vec())),
|
||||||
Close => Self::close_raw(Payload::Owned(frame.payload.to_vec()))
|
Close => Self::close_raw(Payload::Owned(frame.payload.to_vec())),
|
||||||
|
Ping => Self::new(
|
||||||
|
true,
|
||||||
|
OpCode::Ping,
|
||||||
|
None,
|
||||||
|
Payload::Owned(frame.payload.to_vec()),
|
||||||
|
),
|
||||||
|
Pong => Self::pong(Payload::Owned(frame.payload.to_vec())),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl From<WebSocketError> for crate::WispError {
|
||||||
|
fn from(err: WebSocketError) -> Self {
|
||||||
|
Self::WsImplError(Box::new(err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<S: AsyncRead + Unpin> crate::ws::WebSocketRead for FragmentCollectorRead<S> {
|
||||||
|
async fn wisp_read_frame(
|
||||||
|
&mut self,
|
||||||
|
tx: &mut crate::ws::LockedWebSocketWrite<impl crate::ws::WebSocketWrite>,
|
||||||
|
) -> Result<crate::ws::Frame, crate::WispError> {
|
||||||
|
Ok(self
|
||||||
|
.read_frame(&mut |frame| async { tx.write_frame(frame.into()).await })
|
||||||
|
.await?
|
||||||
|
.into())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<S: AsyncWrite + Unpin> crate::ws::WebSocketWrite for WebSocketWrite<S> {
|
||||||
|
async fn wisp_write_frame(&mut self, frame: crate::ws::Frame) -> Result<(), crate::WispError> {
|
||||||
|
self.write_frame(frame.into()).await.map_err(|e| e.into())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
108
wisp/src/lib.rs
108
wisp/src/lib.rs
|
@ -5,6 +5,11 @@ pub mod ws;
|
||||||
|
|
||||||
pub use crate::packet::*;
|
pub use crate::packet::*;
|
||||||
|
|
||||||
|
use bytes::Bytes;
|
||||||
|
use dashmap::DashMap;
|
||||||
|
use futures::{channel::mpsc, StreamExt};
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
#[derive(Debug, PartialEq)]
|
#[derive(Debug, PartialEq)]
|
||||||
pub enum Role {
|
pub enum Role {
|
||||||
Client,
|
Client,
|
||||||
|
@ -15,11 +20,13 @@ pub enum Role {
|
||||||
pub enum WispError {
|
pub enum WispError {
|
||||||
PacketTooSmall,
|
PacketTooSmall,
|
||||||
InvalidPacketType,
|
InvalidPacketType,
|
||||||
|
InvalidStreamType,
|
||||||
WsFrameInvalidType,
|
WsFrameInvalidType,
|
||||||
WsFrameNotFinished,
|
WsFrameNotFinished,
|
||||||
WsImplError(Box<dyn std::error::Error>),
|
WsImplError(Box<dyn std::error::Error + Sync + Send>),
|
||||||
WsImplNotSupported,
|
WsImplNotSupported,
|
||||||
Utf8Error(std::str::Utf8Error),
|
Utf8Error(std::str::Utf8Error),
|
||||||
|
Other(Box<dyn std::error::Error + Sync + Send>),
|
||||||
}
|
}
|
||||||
|
|
||||||
impl From<std::str::Utf8Error> for WispError {
|
impl From<std::str::Utf8Error> for WispError {
|
||||||
|
@ -34,13 +41,112 @@ impl std::fmt::Display for WispError {
|
||||||
match self {
|
match self {
|
||||||
PacketTooSmall => write!(f, "Packet too small"),
|
PacketTooSmall => write!(f, "Packet too small"),
|
||||||
InvalidPacketType => write!(f, "Invalid packet type"),
|
InvalidPacketType => write!(f, "Invalid packet type"),
|
||||||
|
InvalidStreamType => write!(f, "Invalid stream type"),
|
||||||
WsFrameInvalidType => write!(f, "Invalid websocket frame type"),
|
WsFrameInvalidType => write!(f, "Invalid websocket frame type"),
|
||||||
WsFrameNotFinished => write!(f, "Unfinished websocket frame"),
|
WsFrameNotFinished => write!(f, "Unfinished websocket frame"),
|
||||||
WsImplError(err) => write!(f, "Websocket implementation error: {:?}", err),
|
WsImplError(err) => write!(f, "Websocket implementation error: {:?}", err),
|
||||||
WsImplNotSupported => write!(f, "Websocket implementation error: unsupported feature"),
|
WsImplNotSupported => write!(f, "Websocket implementation error: unsupported feature"),
|
||||||
Utf8Error(err) => write!(f, "UTF-8 error: {:?}", err),
|
Utf8Error(err) => write!(f, "UTF-8 error: {:?}", err),
|
||||||
|
Other(err) => write!(f, "Other error: {:?}", err),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl std::error::Error for WispError {}
|
impl std::error::Error for WispError {}
|
||||||
|
|
||||||
|
pub enum WsEvent {
|
||||||
|
Send(Bytes),
|
||||||
|
Close(ClosePacket),
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct MuxStream<W>
|
||||||
|
where
|
||||||
|
W: ws::WebSocketWrite,
|
||||||
|
{
|
||||||
|
pub stream_id: u32,
|
||||||
|
rx: mpsc::UnboundedReceiver<WsEvent>,
|
||||||
|
tx: ws::LockedWebSocketWrite<W>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<W: ws::WebSocketWrite> MuxStream<W> {
|
||||||
|
pub async fn read(&mut self) -> Option<WsEvent> {
|
||||||
|
self.rx.next().await
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn write(&mut self, data: Bytes) -> Result<(), WispError> {
|
||||||
|
self.tx
|
||||||
|
.write_frame(ws::Frame::from(Packet::new_data(self.stream_id, data)))
|
||||||
|
.await
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get_write_half(&self) -> ws::LockedWebSocketWrite<W> {
|
||||||
|
self.tx.clone()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct ServerMux<R, W>
|
||||||
|
where
|
||||||
|
R: ws::WebSocketRead,
|
||||||
|
W: ws::WebSocketWrite,
|
||||||
|
{
|
||||||
|
rx: R,
|
||||||
|
tx: ws::LockedWebSocketWrite<W>,
|
||||||
|
stream_map: Arc<DashMap<u32, mpsc::UnboundedSender<WsEvent>>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<R: ws::WebSocketRead, W: ws::WebSocketWrite> ServerMux<R, W> {
|
||||||
|
pub fn new(read: R, write: W) -> Self {
|
||||||
|
Self {
|
||||||
|
rx: read,
|
||||||
|
tx: ws::LockedWebSocketWrite::new(write),
|
||||||
|
stream_map: Arc::new(DashMap::new()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn server_loop<FR>(
|
||||||
|
&mut self,
|
||||||
|
handler_fn: &mut impl Fn(ConnectPacket, MuxStream<W>) -> FR,
|
||||||
|
) -> Result<(), WispError>
|
||||||
|
where
|
||||||
|
FR: std::future::Future<Output = Result<(), crate::WispError>>,
|
||||||
|
{
|
||||||
|
self.tx
|
||||||
|
.write_frame(ws::Frame::from(Packet::new_continue(0, u32::MAX)))
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
while let Ok(frame) = self.rx.wisp_read_frame(&mut self.tx).await {
|
||||||
|
if let Ok(packet) = Packet::try_from(frame) {
|
||||||
|
use PacketType::*;
|
||||||
|
match packet.packet {
|
||||||
|
Connect(inner_packet) => {
|
||||||
|
let (ch_tx, ch_rx) = mpsc::unbounded();
|
||||||
|
self.stream_map.clone().insert(packet.stream_id, ch_tx);
|
||||||
|
let _ = handler_fn(
|
||||||
|
inner_packet,
|
||||||
|
MuxStream {
|
||||||
|
stream_id: packet.stream_id,
|
||||||
|
rx: ch_rx,
|
||||||
|
tx: self.tx.clone(),
|
||||||
|
},
|
||||||
|
).await;
|
||||||
|
}
|
||||||
|
Data(data) => {
|
||||||
|
if let Some(stream) = self.stream_map.clone().get(&packet.stream_id) {
|
||||||
|
let _ = stream.unbounded_send(WsEvent::Send(data));
|
||||||
|
self.tx
|
||||||
|
.write_frame(ws::Frame::from(Packet::new_continue(packet.stream_id, u32::MAX)))
|
||||||
|
.await?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Continue(_) => unreachable!(),
|
||||||
|
Close(inner_packet) => {
|
||||||
|
if let Some(stream) = self.stream_map.clone().get(&packet.stream_id) {
|
||||||
|
let _ = stream.unbounded_send(WsEvent::Close(inner_packet));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -2,15 +2,33 @@ use crate::ws;
|
||||||
use crate::WispError;
|
use crate::WispError;
|
||||||
use bytes::{Buf, BufMut, Bytes};
|
use bytes::{Buf, BufMut, Bytes};
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub enum StreamType {
|
||||||
|
Tcp = 0x01,
|
||||||
|
Udp = 0x02,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TryFrom<u8> for StreamType {
|
||||||
|
type Error = WispError;
|
||||||
|
fn try_from(stream_type: u8) -> Result<Self, Self::Error> {
|
||||||
|
use StreamType::*;
|
||||||
|
match stream_type {
|
||||||
|
0x01 => Ok(Tcp),
|
||||||
|
0x02 => Ok(Udp),
|
||||||
|
_ => Err(Self::Error::InvalidStreamType),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct ConnectPacket {
|
pub struct ConnectPacket {
|
||||||
pub stream_type: u8,
|
pub stream_type: StreamType,
|
||||||
pub destination_port: u16,
|
pub destination_port: u16,
|
||||||
pub destination_hostname: String,
|
pub destination_hostname: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ConnectPacket {
|
impl ConnectPacket {
|
||||||
pub fn new(stream_type: u8, destination_port: u16, destination_hostname: String) -> Self {
|
pub fn new(stream_type: StreamType, destination_port: u16, destination_hostname: String) -> Self {
|
||||||
Self {
|
Self {
|
||||||
stream_type,
|
stream_type,
|
||||||
destination_port,
|
destination_port,
|
||||||
|
@ -26,7 +44,7 @@ impl TryFrom<Bytes> for ConnectPacket {
|
||||||
return Err(Self::Error::PacketTooSmall);
|
return Err(Self::Error::PacketTooSmall);
|
||||||
}
|
}
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
stream_type: bytes.get_u8(),
|
stream_type: bytes.get_u8().try_into()?,
|
||||||
destination_port: bytes.get_u16_le(),
|
destination_port: bytes.get_u16_le(),
|
||||||
destination_hostname: std::str::from_utf8(&bytes)?.to_string(),
|
destination_hostname: std::str::from_utf8(&bytes)?.to_string(),
|
||||||
})
|
})
|
||||||
|
@ -36,7 +54,7 @@ impl TryFrom<Bytes> for ConnectPacket {
|
||||||
impl From<ConnectPacket> for Vec<u8> {
|
impl From<ConnectPacket> for Vec<u8> {
|
||||||
fn from(packet: ConnectPacket) -> Self {
|
fn from(packet: ConnectPacket) -> Self {
|
||||||
let mut encoded = Self::with_capacity(1 + 2 + packet.destination_hostname.len());
|
let mut encoded = Self::with_capacity(1 + 2 + packet.destination_hostname.len());
|
||||||
encoded.put_u8(packet.stream_type);
|
encoded.put_u8(packet.stream_type as u8);
|
||||||
encoded.put_u16_le(packet.destination_port);
|
encoded.put_u16_le(packet.destination_port);
|
||||||
encoded.extend(packet.destination_hostname.bytes());
|
encoded.extend(packet.destination_hostname.bytes());
|
||||||
encoded
|
encoded
|
||||||
|
@ -108,7 +126,7 @@ impl From<ClosePacket> for Vec<u8> {
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub enum PacketType {
|
pub enum PacketType {
|
||||||
Connect(ConnectPacket),
|
Connect(ConnectPacket),
|
||||||
Data(Vec<u8>),
|
Data(Bytes),
|
||||||
Continue(ContinuePacket),
|
Continue(ContinuePacket),
|
||||||
Close(ClosePacket),
|
Close(ClosePacket),
|
||||||
}
|
}
|
||||||
|
@ -130,7 +148,7 @@ impl From<PacketType> for Vec<u8> {
|
||||||
use PacketType::*;
|
use PacketType::*;
|
||||||
match packet {
|
match packet {
|
||||||
Connect(x) => x.into(),
|
Connect(x) => x.into(),
|
||||||
Data(x) => x,
|
Data(x) => x.to_vec(),
|
||||||
Continue(x) => x.into(),
|
Continue(x) => x.into(),
|
||||||
Close(x) => x.into(),
|
Close(x) => x.into(),
|
||||||
}
|
}
|
||||||
|
@ -150,7 +168,7 @@ impl Packet {
|
||||||
|
|
||||||
pub fn new_connect(
|
pub fn new_connect(
|
||||||
stream_id: u32,
|
stream_id: u32,
|
||||||
stream_type: u8,
|
stream_type: StreamType,
|
||||||
destination_port: u16,
|
destination_port: u16,
|
||||||
destination_hostname: String,
|
destination_hostname: String,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
|
@ -164,7 +182,7 @@ impl Packet {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn new_data(stream_id: u32, data: Vec<u8>) -> Self {
|
pub fn new_data(stream_id: u32, data: Bytes) -> Self {
|
||||||
Self {
|
Self {
|
||||||
stream_id,
|
stream_id,
|
||||||
packet: PacketType::Data(data),
|
packet: PacketType::Data(data),
|
||||||
|
@ -198,7 +216,7 @@ impl TryFrom<Bytes> for Packet {
|
||||||
stream_id: bytes.get_u32_le(),
|
stream_id: bytes.get_u32_le(),
|
||||||
packet: match packet_type {
|
packet: match packet_type {
|
||||||
0x01 => Connect(ConnectPacket::try_from(bytes)?),
|
0x01 => Connect(ConnectPacket::try_from(bytes)?),
|
||||||
0x02 => Data(bytes.to_vec()),
|
0x02 => Data(bytes),
|
||||||
0x03 => Continue(ContinuePacket::try_from(bytes)?),
|
0x03 => Continue(ContinuePacket::try_from(bytes)?),
|
||||||
0x04 => Close(ClosePacket::try_from(bytes)?),
|
0x04 => Close(ClosePacket::try_from(bytes)?),
|
||||||
_ => return Err(Self::Error::InvalidPacketType),
|
_ => return Err(Self::Error::InvalidPacketType),
|
||||||
|
|
|
@ -1,10 +1,14 @@
|
||||||
use bytes::Bytes;
|
use bytes::Bytes;
|
||||||
|
use futures::lock::Mutex;
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
#[derive(Debug, PartialEq, Clone, Copy)]
|
#[derive(Debug, PartialEq, Clone, Copy)]
|
||||||
pub enum OpCode {
|
pub enum OpCode {
|
||||||
Text,
|
Text,
|
||||||
Binary,
|
Binary,
|
||||||
Close,
|
Close,
|
||||||
|
Ping,
|
||||||
|
Pong,
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct Frame {
|
pub struct Frame {
|
||||||
|
@ -38,3 +42,37 @@ impl Frame {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub trait WebSocketRead {
|
||||||
|
fn wisp_read_frame(
|
||||||
|
&mut self,
|
||||||
|
tx: &mut crate::ws::LockedWebSocketWrite<impl crate::ws::WebSocketWrite>,
|
||||||
|
) -> impl std::future::Future<Output = Result<Frame, crate::WispError>>;
|
||||||
|
}
|
||||||
|
|
||||||
|
pub trait WebSocketWrite {
|
||||||
|
fn wisp_write_frame(
|
||||||
|
&mut self,
|
||||||
|
frame: Frame,
|
||||||
|
) -> impl std::future::Future<Output = Result<(), crate::WispError>>;
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct LockedWebSocketWrite<S>(Arc<Mutex<S>>)
|
||||||
|
where
|
||||||
|
S: WebSocketWrite;
|
||||||
|
|
||||||
|
impl<S: WebSocketWrite> LockedWebSocketWrite<S> {
|
||||||
|
pub fn new(ws: S) -> Self {
|
||||||
|
Self(Arc::new(Mutex::new(ws)))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn write_frame(&self, frame: Frame) -> Result<(), crate::WispError> {
|
||||||
|
self.0.lock().await.wisp_write_frame(frame).await
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<S: WebSocketWrite> Clone for LockedWebSocketWrite<S> {
|
||||||
|
fn clone(&self) -> Self {
|
||||||
|
Self(self.0.clone())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue