serverside done except it deadlocks

This commit is contained in:
Toshit Chawda 2024-01-22 20:11:58 -08:00
parent 1f23c26db6
commit 24d145cc66
No known key found for this signature in database
GPG key ID: 91480ED99E2B3D9D
8 changed files with 176 additions and 24 deletions

1
Cargo.lock generated
View file

@ -1530,6 +1530,7 @@ name = "wisp-mux"
version = "0.1.0" version = "0.1.0"
dependencies = [ dependencies = [
"bytes", "bytes",
"fastwebsockets",
"futures", "futures",
"futures-util", "futures-util",
] ]

View file

@ -14,4 +14,4 @@ hyper-util = { version = "0.1.2", features = ["tokio"] }
tokio = { version = "1.5.1", features = ["rt-multi-thread", "macros"] } tokio = { version = "1.5.1", features = ["rt-multi-thread", "macros"] }
tokio-native-tls = "0.3.1" tokio-native-tls = "0.3.1"
tokio-util = { version = "0.7.10", features = ["codec"] } tokio-util = { version = "0.7.10", features = ["codec"] }
wisp-mux = { path = "../wisp" } wisp-mux = { path = "../wisp", features = ["fastwebsockets"] }

View file

@ -6,6 +6,7 @@ use tokio::sync::Mutex;
type Ws = FragmentCollector<TokioIo<Upgraded>>; type Ws = FragmentCollector<TokioIo<Upgraded>>;
#[derive(Clone)]
pub struct LockedWebSocket(Arc<Mutex<Ws>>); pub struct LockedWebSocket(Arc<Mutex<Ws>>);
impl LockedWebSocket { impl LockedWebSocket {

View file

@ -12,8 +12,10 @@ use hyper::{
Response, StatusCode, Response, StatusCode,
}; };
use hyper_util::rt::TokioIo; use hyper_util::rt::TokioIo;
use tokio::net::{TcpListener, TcpStream}; use tokio::{
use tokio::sync::mpsc; net::{TcpListener, TcpStream},
sync::mpsc,
};
use tokio_native_tls::{native_tls, TlsAcceptor}; use tokio_native_tls::{native_tls, TlsAcceptor};
use tokio_util::codec::{BytesCodec, Framed}; use tokio_util::codec::{BytesCodec, Framed};
@ -120,35 +122,120 @@ async fn accept_ws(
fut: upgrade::UpgradeFut, fut: upgrade::UpgradeFut,
addr: String, addr: String,
) -> Result<(), Box<dyn std::error::Error>> { ) -> Result<(), Box<dyn std::error::Error>> {
let ws_stream = lockedws::LockedWebSocket::new(FragmentCollector::new(fut.await?)); let ws_stream = FragmentCollector::new(fut.await?);
let ws_stream = lockedws::LockedWebSocket::new(ws_stream);
let stream_map = Arc::new(dashmap::DashMap::<u32, mpsc::UnboundedSender<WsEvent>>::new()); let stream_map = Arc::new(dashmap::DashMap::<u32, mpsc::UnboundedSender<WsEvent>>::new());
println!("{:?}: connected", addr); println!("{:?}: connected", addr);
while let Ok(mut frame) = ws_stream.read_frame().await { ws_stream
use fastwebsockets::OpCode::*; .write_frame(ws::Frame::from(Packet::new_continue(0, u32::MAX)).into())
let frame = match frame.opcode { .await?;
Continuation => unreachable!(),
Text => ws::Frame::text(Bytes::copy_from_slice(frame.payload.to_mut())), while let Ok(frame) = ws_stream.read_frame().await {
Binary => ws::Frame::binary(Bytes::copy_from_slice(frame.payload.to_mut())), if let Ok(packet) = Packet::try_from(ws::Frame::try_from(frame)?) {
Close => ws::Frame::close(Bytes::copy_from_slice(frame.payload.to_mut())),
Ping => continue,
Pong => continue,
};
if let Ok(packet) = Packet::try_from(frame) {
use PacketType::*; use PacketType::*;
match packet.packet { match packet.packet {
Connect(inner_packet) => { Connect(inner_packet) => {
let (ch_tx, ch_rx) = mpsc::unbounded_channel::<WsEvent>(); let (ch_tx, mut ch_rx) = mpsc::unbounded_channel::<WsEvent>();
stream_map.clone().insert(packet.stream_id, ch_tx); stream_map.clone().insert(packet.stream_id, ch_tx);
let ws_stream_cloned = ws_stream.clone();
tokio::spawn(async move { tokio::spawn(async move {
let tcp_stream = match TcpStream::connect(format!(
"{}:{}",
inner_packet.destination_hostname, inner_packet.destination_port
))
.await
{
Ok(stream) => stream,
Err(err) => {
ws_stream_cloned
.write_frame(
ws::Frame::from(Packet::new_close(packet.stream_id, 0x03))
.into(),
)
.await
.map_err(std::io::Error::other)?;
return Err(Box::new(err));
}
};
println!("muxing");
let mut tcp_stream = Framed::new(tcp_stream, BytesCodec::new());
loop {
tokio::select! {
event = tcp_stream.next() => {
println!("recvd");
if let Some(res) = event {
match res {
Ok(buf) => {
ws_stream_cloned.write_frame(
ws::Frame::from(
Packet::new_data(
packet.stream_id,
buf.to_vec()
)
).into()
).await.map_err(std::io::Error::other)?;
}
Err(err) => {
ws_stream_cloned
.write_frame(
ws::Frame::from(Packet::new_close(
packet.stream_id,
0x03,
))
.into(),
)
.await
.map_err(std::io::Error::other)?;
return Err(Box::new(err));
}
}
}
}
event = ch_rx.recv() => {
if let Some(event) = event {
match event {
WsEvent::Send(buf) => {
tcp_stream.send(buf).await?;
println!("sending");
ws_stream_cloned
.write_frame(
ws::Frame::from(
Packet::new_continue(
packet.stream_id,
u32::MAX
)
).into()
).await.map_err(std::io::Error::other)?;
println!("sent");
}
WsEvent::Close => {
break;
}
}
} else {
break;
}
}
};
}
Ok(())
}); });
} }
Data(inner_packet) => {} Data(inner_packet) => {
println!("recieved data for {:?}", packet.stream_id);
if let Some(stream) = stream_map.clone().get(&packet.stream_id) {
let _ = stream.send(WsEvent::Send(inner_packet.into()));
}
}
Continue(_) => unreachable!(), Continue(_) => unreachable!(),
Close(inner_packet) => {} Close(_) => {
if let Some(stream) = stream_map.clone().get(&packet.stream_id) {
let _ = stream.send(WsEvent::Close);
}
}
} }
} }
} }

View file

@ -3,9 +3,11 @@ name = "wisp-mux"
version = "0.1.0" version = "0.1.0"
edition = "2021" edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies] [dependencies]
bytes = "1.5.0" bytes = "1.5.0"
fastwebsockets = { version = "0.6.0", optional = true }
futures = "0.3.30" futures = "0.3.30"
futures-util = "0.3.30" futures-util = "0.3.30"
[features]
fastwebsockets = ["dep:fastwebsockets"]

View file

@ -0,0 +1,40 @@
use bytes::Bytes;
use fastwebsockets::{Payload, Frame, OpCode};
impl TryFrom<OpCode> for crate::ws::OpCode {
type Error = crate::WispError;
fn try_from(opcode: OpCode) -> Result<Self, Self::Error> {
use OpCode::*;
match opcode {
Continuation => Err(Self::Error::WsImplNotSupported),
Text => Ok(Self::Text),
Binary => Ok(Self::Binary),
Close => Ok(Self::Close),
Ping => Err(Self::Error::WsImplNotSupported),
Pong => Err(Self::Error::WsImplNotSupported),
}
}
}
impl TryFrom<Frame<'_>> for crate::ws::Frame {
type Error = crate::WispError;
fn try_from(mut frame: Frame) -> Result<Self, Self::Error> {
let opcode = frame.opcode.try_into()?;
Ok(Self {
finished: frame.fin,
opcode,
payload: Bytes::copy_from_slice(frame.payload.to_mut()),
})
}
}
impl From<crate::ws::Frame> for Frame<'_> {
fn from(frame: crate::ws::Frame) -> Self {
use crate::ws::OpCode::*;
match frame.opcode {
Text => Self::text(Payload::Owned(frame.payload.to_vec())),
Binary => Self::binary(Payload::Owned(frame.payload.to_vec())),
Close => Self::close_raw(Payload::Owned(frame.payload.to_vec()))
}
}
}

View file

@ -1,3 +1,5 @@
#[cfg(feature = "fastwebsockets")]
mod fastwebsockets;
mod packet; mod packet;
pub mod ws; pub mod ws;
@ -9,12 +11,14 @@ pub enum Role {
Server, Server,
} }
#[derive(Debug)]
pub enum WispError { pub enum WispError {
PacketTooSmall, PacketTooSmall,
InvalidPacketType, InvalidPacketType,
WsFrameInvalidType, WsFrameInvalidType,
WsFrameNotFinished, WsFrameNotFinished,
WsImplError(Box<dyn std::error::Error>), WsImplError(Box<dyn std::error::Error>),
WsImplNotSupported,
Utf8Error(std::str::Utf8Error), Utf8Error(std::str::Utf8Error),
} }
@ -23,3 +27,20 @@ impl From<std::str::Utf8Error> for WispError {
WispError::Utf8Error(err) WispError::Utf8Error(err)
} }
} }
impl std::fmt::Display for WispError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
use WispError::*;
match self {
PacketTooSmall => write!(f, "Packet too small"),
InvalidPacketType => write!(f, "Invalid packet type"),
WsFrameInvalidType => write!(f, "Invalid websocket frame type"),
WsFrameNotFinished => write!(f, "Unfinished websocket frame"),
WsImplError(err) => write!(f, "Websocket implementation error: {:?}", err),
WsImplNotSupported => write!(f, "Websocket implementation error: unsupported feature"),
Utf8Error(err) => write!(f, "UTF-8 error: {:?}", err),
}
}
}
impl std::error::Error for WispError {}

View file

@ -4,9 +4,9 @@ use bytes::{Buf, BufMut, Bytes};
#[derive(Debug)] #[derive(Debug)]
pub struct ConnectPacket { pub struct ConnectPacket {
stream_type: u8, pub stream_type: u8,
destination_port: u16, pub destination_port: u16,
destination_hostname: String, pub destination_hostname: String,
} }
impl ConnectPacket { impl ConnectPacket {