wisp part 1

This commit is contained in:
r58Playz 2024-01-22 18:19:39 -08:00
parent ad7a34e86d
commit 1f23c26db6
6 changed files with 117 additions and 46 deletions

20
Cargo.lock generated
View file

@ -214,6 +214,19 @@ dependencies = [
"typenum", "typenum",
] ]
[[package]]
name = "dashmap"
version = "5.5.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "978747c1d849a7d2ee5e8adc0159961c48fb7e5db2f06af6723b80123bb53856"
dependencies = [
"cfg-if",
"hashbrown",
"lock_api",
"once_cell",
"parking_lot_core",
]
[[package]] [[package]]
name = "digest" name = "digest"
version = "0.10.7" version = "0.10.7"
@ -266,6 +279,7 @@ name = "epoxy-server"
version = "1.0.0" version = "1.0.0"
dependencies = [ dependencies = [
"bytes", "bytes",
"dashmap",
"fastwebsockets", "fastwebsockets",
"futures-util", "futures-util",
"http-body-util", "http-body-util",
@ -461,6 +475,12 @@ version = "0.28.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4271d37baee1b8c7e4b708028c57d816cf9d2434acb33a549475f78c181f6253" checksum = "4271d37baee1b8c7e4b708028c57d816cf9d2434acb33a549475f78c181f6253"
[[package]]
name = "hashbrown"
version = "0.14.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "290f1a1d9242c78d09ce40a5e87e7554ee637af1351968159f4952f028f75604"
[[package]] [[package]]
name = "hermit-abi" name = "hermit-abi"
version = "0.3.3" version = "0.3.3"

View file

@ -5,6 +5,7 @@ edition = "2021"
[dependencies] [dependencies]
bytes = "1.5.0" bytes = "1.5.0"
dashmap = "5.5.3"
fastwebsockets = { version = "0.6.0", features = ["upgrade", "simdutf8"] } fastwebsockets = { version = "0.6.0", features = ["upgrade", "simdutf8"] }
futures-util = { version = "0.3.30", features = ["sink"] } futures-util = { version = "0.3.30", features = ["sink"] }
http-body-util = "0.1.0" http-body-util = "0.1.0"

23
server/src/lockedws.rs Normal file
View file

@ -0,0 +1,23 @@
use fastwebsockets::{FragmentCollector, Frame, WebSocketError};
use hyper::upgrade::Upgraded;
use hyper_util::rt::TokioIo;
use std::sync::Arc;
use tokio::sync::Mutex;
type Ws = FragmentCollector<TokioIo<Upgraded>>;
pub struct LockedWebSocket(Arc<Mutex<Ws>>);
impl LockedWebSocket {
pub fn new(ws: Ws) -> Self {
Self(Arc::new(Mutex::new(ws)))
}
pub async fn read_frame(&self) -> Result<Frame, WebSocketError> {
self.0.lock().await.read_frame().await
}
pub async fn write_frame(&self, frame: Frame<'_>) -> Result<(), WebSocketError> {
self.0.lock().await.write_frame(frame).await
}
}

View file

@ -1,4 +1,6 @@
use std::io::Error; mod lockedws;
use std::{io::Error, sync::Arc};
use bytes::Bytes; use bytes::Bytes;
use fastwebsockets::{ use fastwebsockets::{
@ -11,9 +13,12 @@ use hyper::{
}; };
use hyper_util::rt::TokioIo; use hyper_util::rt::TokioIo;
use tokio::net::{TcpListener, TcpStream}; use tokio::net::{TcpListener, TcpStream};
use tokio::sync::mpsc;
use tokio_native_tls::{native_tls, TlsAcceptor}; use tokio_native_tls::{native_tls, TlsAcceptor};
use tokio_util::codec::{BytesCodec, Framed}; use tokio_util::codec::{BytesCodec, Framed};
use wisp_mux::{ws, Packet, PacketType};
type HttpBody = http_body_util::Empty<hyper::body::Bytes>; type HttpBody = http_body_util::Empty<hyper::body::Bytes>;
#[tokio::main(flavor = "multi_thread")] #[tokio::main(flavor = "multi_thread")]
@ -47,7 +52,8 @@ async fn main() -> Result<(), Error> {
tokio::spawn(async move { tokio::spawn(async move {
let stream = acceptor_cloned.accept(stream).await.expect("not tls"); let stream = acceptor_cloned.accept(stream).await.expect("not tls");
let io = TokioIo::new(stream); let io = TokioIo::new(stream);
let service = service_fn(move |res| accept_http(res, addr.to_string(), prefix_cloned.clone())); let service =
service_fn(move |res| accept_http(res, addr.to_string(), prefix_cloned.clone()));
let conn = http1::Builder::new() let conn = http1::Builder::new()
.serve_connection(io, service) .serve_connection(io, service)
.with_upgrades(); .with_upgrades();
@ -72,10 +78,13 @@ async fn accept_http(
tokio::spawn(async move { tokio::spawn(async move {
if *uri.path() != prefix { if *uri.path() != prefix {
if let Err(e) = if let Err(e) =
accept_wsproxy(fut, uri.path().to_string(), addr.clone(), prefix).await accept_wsproxy(fut, uri.path().strip_prefix(&prefix).unwrap(), addr.clone())
.await
{ {
println!("{:?}: error in ws handling: {:?}", addr, e); println!("{:?}: error in ws handling: {:?}", addr, e);
} }
} else if let Err(e) = accept_ws(fut, addr.clone()).await {
println!("{:?}: error in ws handling: {:?}", addr, e);
} }
}); });
@ -102,18 +111,60 @@ async fn accept_http(
} }
} }
enum WsEvent {
Send(Bytes),
Close,
}
async fn accept_ws(
fut: upgrade::UpgradeFut,
addr: String,
) -> Result<(), Box<dyn std::error::Error>> {
let ws_stream = lockedws::LockedWebSocket::new(FragmentCollector::new(fut.await?));
let stream_map = Arc::new(dashmap::DashMap::<u32, mpsc::UnboundedSender<WsEvent>>::new());
println!("{:?}: connected", addr);
while let Ok(mut frame) = ws_stream.read_frame().await {
use fastwebsockets::OpCode::*;
let frame = match frame.opcode {
Continuation => unreachable!(),
Text => ws::Frame::text(Bytes::copy_from_slice(frame.payload.to_mut())),
Binary => ws::Frame::binary(Bytes::copy_from_slice(frame.payload.to_mut())),
Close => ws::Frame::close(Bytes::copy_from_slice(frame.payload.to_mut())),
Ping => continue,
Pong => continue,
};
if let Ok(packet) = Packet::try_from(frame) {
use PacketType::*;
match packet.packet {
Connect(inner_packet) => {
let (ch_tx, ch_rx) = mpsc::unbounded_channel::<WsEvent>();
stream_map.clone().insert(packet.stream_id, ch_tx);
tokio::spawn(async move {
});
}
Data(inner_packet) => {}
Continue(_) => unreachable!(),
Close(inner_packet) => {}
}
}
}
println!("{:?}: disconnected", addr);
Ok(())
}
async fn accept_wsproxy( async fn accept_wsproxy(
fut: upgrade::UpgradeFut, fut: upgrade::UpgradeFut,
incoming_uri: String, incoming_uri: &str,
addr: String, addr: String,
prefix: String,
) -> Result<(), Box<dyn std::error::Error>> { ) -> Result<(), Box<dyn std::error::Error>> {
let mut ws_stream = FragmentCollector::new(fut.await?); let mut ws_stream = FragmentCollector::new(fut.await?);
// should always have prefix println!("{:?}: connected (wsproxy)", addr);
let incoming_uri = incoming_uri.strip_prefix(&prefix).unwrap();
println!("{:?}: connected", addr);
let tcp_stream = match TcpStream::connect(incoming_uri).await { let tcp_stream = match TcpStream::connect(incoming_uri).await {
Ok(stream) => stream, Ok(stream) => stream,
@ -132,56 +183,33 @@ async fn accept_wsproxy(
event = ws_stream.read_frame() => { event = ws_stream.read_frame() => {
match event { match event {
Ok(frame) => { Ok(frame) => {
print!("{:?}: event ws - ", addr);
match frame.opcode { match frame.opcode {
OpCode::Text | OpCode::Binary => { OpCode::Text | OpCode::Binary => {
if tcp_stream_framed.send(Bytes::from(frame.payload.to_vec())).await.is_ok() { let _ = tcp_stream_framed.send(Bytes::from(frame.payload.to_vec())).await;
println!("sent success");
} else {
println!("sent FAILED");
}
} }
OpCode::Close => { OpCode::Close => {
if <Framed<tokio::net::TcpStream, BytesCodec> as SinkExt<Bytes>>::close(&mut tcp_stream_framed).await.is_ok() { // tokio closes the stream for us
println!("closed success"); drop(tcp_stream_framed);
} else {
println!("closed FAILED");
}
break; break;
} }
_ => { _ => {}
println!("ignored");
}
} }
}, },
Err(err) => { Err(_) => {
print!("{:?}: err in ws: {:?} - ", addr, err); // tokio closes the stream for us
if <Framed<tokio::net::TcpStream, BytesCodec> as SinkExt<Bytes>>::close(&mut tcp_stream_framed).await.is_ok() { drop(tcp_stream_framed);
println!("closed tcp success");
} else {
println!("closed tcp FAILED");
}
break; break;
} }
} }
}, },
event = tcp_stream_framed.next() => { event = tcp_stream_framed.next() => {
if let Some(res) = event { if let Some(res) = event {
print!("{:?}: event tcp - ", addr);
match res { match res {
Ok(buf) => { Ok(buf) => {
if ws_stream.write_frame(Frame::binary(Payload::Owned(buf.to_vec()))).await.is_ok() { let _ = ws_stream.write_frame(Frame::binary(Payload::Borrowed(&buf))).await;
println!("sent success");
} else {
println!("sent FAILED");
}
} }
Err(_) => { Err(_) => {
if ws_stream.write_frame(Frame::close(CloseCode::Away.into(), b"tcp side is going away")).await.is_ok() { let _ = ws_stream.write_frame(Frame::close(CloseCode::Away.into(), b"tcp side is going away")).await;
println!("closed success");
} else {
println!("closed FAILED");
}
} }
} }
} }
@ -189,8 +217,7 @@ async fn accept_wsproxy(
} }
} }
println!("\"{}\": connection closed", addr); println!("{:?}: disconnected (wsproxy)", addr);
Ok(()) Ok(())
} }

View file

@ -1,5 +1,5 @@
mod packet; mod packet;
mod ws; pub mod ws;
pub use crate::packet::*; pub use crate::packet::*;

View file

@ -139,8 +139,8 @@ impl From<PacketType> for Vec<u8> {
#[derive(Debug)] #[derive(Debug)]
pub struct Packet { pub struct Packet {
stream_id: u32, pub stream_id: u32,
packet: PacketType, pub packet: PacketType,
} }
impl Packet { impl Packet {