mirror of
https://github.com/MercuryWorkshop/epoxy-tls.git
synced 2025-05-12 22:10:01 -04:00
serverside done except it deadlocks
This commit is contained in:
parent
1f23c26db6
commit
24d145cc66
8 changed files with 176 additions and 24 deletions
1
Cargo.lock
generated
1
Cargo.lock
generated
|
@ -1530,6 +1530,7 @@ name = "wisp-mux"
|
|||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"bytes",
|
||||
"fastwebsockets",
|
||||
"futures",
|
||||
"futures-util",
|
||||
]
|
||||
|
|
|
@ -14,4 +14,4 @@ hyper-util = { version = "0.1.2", features = ["tokio"] }
|
|||
tokio = { version = "1.5.1", features = ["rt-multi-thread", "macros"] }
|
||||
tokio-native-tls = "0.3.1"
|
||||
tokio-util = { version = "0.7.10", features = ["codec"] }
|
||||
wisp-mux = { path = "../wisp" }
|
||||
wisp-mux = { path = "../wisp", features = ["fastwebsockets"] }
|
||||
|
|
|
@ -6,6 +6,7 @@ use tokio::sync::Mutex;
|
|||
|
||||
type Ws = FragmentCollector<TokioIo<Upgraded>>;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct LockedWebSocket(Arc<Mutex<Ws>>);
|
||||
|
||||
impl LockedWebSocket {
|
||||
|
|
|
@ -12,8 +12,10 @@ use hyper::{
|
|||
Response, StatusCode,
|
||||
};
|
||||
use hyper_util::rt::TokioIo;
|
||||
use tokio::net::{TcpListener, TcpStream};
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::{
|
||||
net::{TcpListener, TcpStream},
|
||||
sync::mpsc,
|
||||
};
|
||||
use tokio_native_tls::{native_tls, TlsAcceptor};
|
||||
use tokio_util::codec::{BytesCodec, Framed};
|
||||
|
||||
|
@ -120,35 +122,120 @@ async fn accept_ws(
|
|||
fut: upgrade::UpgradeFut,
|
||||
addr: String,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let ws_stream = lockedws::LockedWebSocket::new(FragmentCollector::new(fut.await?));
|
||||
let ws_stream = FragmentCollector::new(fut.await?);
|
||||
let ws_stream = lockedws::LockedWebSocket::new(ws_stream);
|
||||
|
||||
let stream_map = Arc::new(dashmap::DashMap::<u32, mpsc::UnboundedSender<WsEvent>>::new());
|
||||
|
||||
println!("{:?}: connected", addr);
|
||||
|
||||
while let Ok(mut frame) = ws_stream.read_frame().await {
|
||||
use fastwebsockets::OpCode::*;
|
||||
let frame = match frame.opcode {
|
||||
Continuation => unreachable!(),
|
||||
Text => ws::Frame::text(Bytes::copy_from_slice(frame.payload.to_mut())),
|
||||
Binary => ws::Frame::binary(Bytes::copy_from_slice(frame.payload.to_mut())),
|
||||
Close => ws::Frame::close(Bytes::copy_from_slice(frame.payload.to_mut())),
|
||||
Ping => continue,
|
||||
Pong => continue,
|
||||
};
|
||||
if let Ok(packet) = Packet::try_from(frame) {
|
||||
ws_stream
|
||||
.write_frame(ws::Frame::from(Packet::new_continue(0, u32::MAX)).into())
|
||||
.await?;
|
||||
|
||||
while let Ok(frame) = ws_stream.read_frame().await {
|
||||
if let Ok(packet) = Packet::try_from(ws::Frame::try_from(frame)?) {
|
||||
use PacketType::*;
|
||||
match packet.packet {
|
||||
Connect(inner_packet) => {
|
||||
let (ch_tx, ch_rx) = mpsc::unbounded_channel::<WsEvent>();
|
||||
let (ch_tx, mut ch_rx) = mpsc::unbounded_channel::<WsEvent>();
|
||||
stream_map.clone().insert(packet.stream_id, ch_tx);
|
||||
let ws_stream_cloned = ws_stream.clone();
|
||||
tokio::spawn(async move {
|
||||
|
||||
let tcp_stream = match TcpStream::connect(format!(
|
||||
"{}:{}",
|
||||
inner_packet.destination_hostname, inner_packet.destination_port
|
||||
))
|
||||
.await
|
||||
{
|
||||
Ok(stream) => stream,
|
||||
Err(err) => {
|
||||
ws_stream_cloned
|
||||
.write_frame(
|
||||
ws::Frame::from(Packet::new_close(packet.stream_id, 0x03))
|
||||
.into(),
|
||||
)
|
||||
.await
|
||||
.map_err(std::io::Error::other)?;
|
||||
return Err(Box::new(err));
|
||||
}
|
||||
};
|
||||
println!("muxing");
|
||||
let mut tcp_stream = Framed::new(tcp_stream, BytesCodec::new());
|
||||
loop {
|
||||
tokio::select! {
|
||||
event = tcp_stream.next() => {
|
||||
println!("recvd");
|
||||
if let Some(res) = event {
|
||||
match res {
|
||||
Ok(buf) => {
|
||||
ws_stream_cloned.write_frame(
|
||||
ws::Frame::from(
|
||||
Packet::new_data(
|
||||
packet.stream_id,
|
||||
buf.to_vec()
|
||||
)
|
||||
).into()
|
||||
).await.map_err(std::io::Error::other)?;
|
||||
}
|
||||
Err(err) => {
|
||||
ws_stream_cloned
|
||||
.write_frame(
|
||||
ws::Frame::from(Packet::new_close(
|
||||
packet.stream_id,
|
||||
0x03,
|
||||
))
|
||||
.into(),
|
||||
)
|
||||
.await
|
||||
.map_err(std::io::Error::other)?;
|
||||
return Err(Box::new(err));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
event = ch_rx.recv() => {
|
||||
if let Some(event) = event {
|
||||
match event {
|
||||
WsEvent::Send(buf) => {
|
||||
tcp_stream.send(buf).await?;
|
||||
println!("sending");
|
||||
ws_stream_cloned
|
||||
.write_frame(
|
||||
ws::Frame::from(
|
||||
Packet::new_continue(
|
||||
packet.stream_id,
|
||||
u32::MAX
|
||||
)
|
||||
).into()
|
||||
).await.map_err(std::io::Error::other)?;
|
||||
println!("sent");
|
||||
}
|
||||
WsEvent::Close => {
|
||||
break;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
Ok(())
|
||||
});
|
||||
}
|
||||
Data(inner_packet) => {}
|
||||
Data(inner_packet) => {
|
||||
println!("recieved data for {:?}", packet.stream_id);
|
||||
if let Some(stream) = stream_map.clone().get(&packet.stream_id) {
|
||||
let _ = stream.send(WsEvent::Send(inner_packet.into()));
|
||||
}
|
||||
}
|
||||
Continue(_) => unreachable!(),
|
||||
Close(inner_packet) => {}
|
||||
Close(_) => {
|
||||
if let Some(stream) = stream_map.clone().get(&packet.stream_id) {
|
||||
let _ = stream.send(WsEvent::Close);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -3,9 +3,11 @@ name = "wisp-mux"
|
|||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
|
||||
[dependencies]
|
||||
bytes = "1.5.0"
|
||||
fastwebsockets = { version = "0.6.0", optional = true }
|
||||
futures = "0.3.30"
|
||||
futures-util = "0.3.30"
|
||||
|
||||
[features]
|
||||
fastwebsockets = ["dep:fastwebsockets"]
|
||||
|
|
40
wisp/src/fastwebsockets.rs
Normal file
40
wisp/src/fastwebsockets.rs
Normal file
|
@ -0,0 +1,40 @@
|
|||
use bytes::Bytes;
|
||||
use fastwebsockets::{Payload, Frame, OpCode};
|
||||
|
||||
impl TryFrom<OpCode> for crate::ws::OpCode {
|
||||
type Error = crate::WispError;
|
||||
fn try_from(opcode: OpCode) -> Result<Self, Self::Error> {
|
||||
use OpCode::*;
|
||||
match opcode {
|
||||
Continuation => Err(Self::Error::WsImplNotSupported),
|
||||
Text => Ok(Self::Text),
|
||||
Binary => Ok(Self::Binary),
|
||||
Close => Ok(Self::Close),
|
||||
Ping => Err(Self::Error::WsImplNotSupported),
|
||||
Pong => Err(Self::Error::WsImplNotSupported),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<Frame<'_>> for crate::ws::Frame {
|
||||
type Error = crate::WispError;
|
||||
fn try_from(mut frame: Frame) -> Result<Self, Self::Error> {
|
||||
let opcode = frame.opcode.try_into()?;
|
||||
Ok(Self {
|
||||
finished: frame.fin,
|
||||
opcode,
|
||||
payload: Bytes::copy_from_slice(frame.payload.to_mut()),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl From<crate::ws::Frame> for Frame<'_> {
|
||||
fn from(frame: crate::ws::Frame) -> Self {
|
||||
use crate::ws::OpCode::*;
|
||||
match frame.opcode {
|
||||
Text => Self::text(Payload::Owned(frame.payload.to_vec())),
|
||||
Binary => Self::binary(Payload::Owned(frame.payload.to_vec())),
|
||||
Close => Self::close_raw(Payload::Owned(frame.payload.to_vec()))
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,3 +1,5 @@
|
|||
#[cfg(feature = "fastwebsockets")]
|
||||
mod fastwebsockets;
|
||||
mod packet;
|
||||
pub mod ws;
|
||||
|
||||
|
@ -9,12 +11,14 @@ pub enum Role {
|
|||
Server,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum WispError {
|
||||
PacketTooSmall,
|
||||
InvalidPacketType,
|
||||
WsFrameInvalidType,
|
||||
WsFrameNotFinished,
|
||||
WsImplError(Box<dyn std::error::Error>),
|
||||
WsImplNotSupported,
|
||||
Utf8Error(std::str::Utf8Error),
|
||||
}
|
||||
|
||||
|
@ -23,3 +27,20 @@ impl From<std::str::Utf8Error> for WispError {
|
|||
WispError::Utf8Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for WispError {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
|
||||
use WispError::*;
|
||||
match self {
|
||||
PacketTooSmall => write!(f, "Packet too small"),
|
||||
InvalidPacketType => write!(f, "Invalid packet type"),
|
||||
WsFrameInvalidType => write!(f, "Invalid websocket frame type"),
|
||||
WsFrameNotFinished => write!(f, "Unfinished websocket frame"),
|
||||
WsImplError(err) => write!(f, "Websocket implementation error: {:?}", err),
|
||||
WsImplNotSupported => write!(f, "Websocket implementation error: unsupported feature"),
|
||||
Utf8Error(err) => write!(f, "UTF-8 error: {:?}", err),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::error::Error for WispError {}
|
||||
|
|
|
@ -4,9 +4,9 @@ use bytes::{Buf, BufMut, Bytes};
|
|||
|
||||
#[derive(Debug)]
|
||||
pub struct ConnectPacket {
|
||||
stream_type: u8,
|
||||
destination_port: u16,
|
||||
destination_hostname: String,
|
||||
pub stream_type: u8,
|
||||
pub destination_port: u16,
|
||||
pub destination_hostname: String,
|
||||
}
|
||||
|
||||
impl ConnectPacket {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue