diff --git a/Cargo.lock b/Cargo.lock index 7c9ae4b..a6bfb1e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -98,9 +98,9 @@ checksum = "b3d1d046238990b9cf5bcde22a3fb3584ee5cf65fb2765f454ed428c7a0063da" [[package]] name = "async-compression" -version = "0.4.11" +version = "0.4.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cd066d0b4ef8ecb03a55319dc13aa6910616d0f44008a045bb1835af830abff5" +checksum = "fec134f64e2bc57411226dfc4e52dec859ddfc7e711fc5e07b612584f000e4aa" dependencies = [ "brotli", "flate2", @@ -477,6 +477,20 @@ dependencies = [ "parking_lot_core", ] +[[package]] +name = "dashmap" +version = "6.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "804c8821570c3f8b70230c2ba75ffa5c0f9a4189b9a432b6656c536712acae28" +dependencies = [ + "cfg-if", + "crossbeam-utils", + "hashbrown 0.14.5", + "lock_api", + "once_cell", + "parking_lot_core", +] + [[package]] name = "digest" version = "0.10.7" @@ -493,6 +507,29 @@ version = "1.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "60b1af1c220855b6ceac025d3f6ecdd2b7c4894bfe9cd9bda4fbb4bc7c0d4cf0" +[[package]] +name = "env_filter" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a009aa4810eb158359dda09d0c87378e4bbb89b5a801f016885a4707ba24f7ea" +dependencies = [ + "log", + "regex", +] + +[[package]] +name = "env_logger" +version = "0.11.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "38b35839ba51819680ba087cd351788c9a3c476841207e0b8cee0b04722343b9" +dependencies = [ + "anstream", + "anstyle", + "env_filter", + "humantime", + "log", +] + [[package]] name = "epoxy-client" version = "2.0.8" @@ -533,17 +570,21 @@ version = "2.0.0" dependencies = [ "anyhow", "bytes", + "dashmap 6.0.1", + "env_logger", "fastwebsockets", "futures-util", "http-body-util", "hyper 1.4.1", "hyper-util", "lazy_static", + "log", "regex", "serde", "tokio", "tokio-util", "toml", + "uuid", "wisp-mux", ] @@ -583,6 +624,7 @@ checksum = "9fc0510504f03c51ada170672ac806f1f105a88aa97a5281117e1ddc3368e51a" [[package]] name = "fastwebsockets" version = "0.8.0" +source = "git+https://github.com/r58Playz/fastwebsockets#9152ec2e28512feeb93d3aba3b516f07355025b6" dependencies = [ "base64 0.21.7", "bytes", @@ -1112,6 +1154,9 @@ name = "log" version = "0.4.22" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a7a70ba024b9dc04c27ea2f0c0548feb474ec5c54bba33a7f72f873a39d07b24" +dependencies = [ + "serde", +] [[package]] name = "matchers" @@ -1238,9 +1283,9 @@ checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92" [[package]] name = "openssl" -version = "0.10.65" +version = "0.10.66" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c2823eb4c6453ed64055057ea8bd416eda38c71018723869dd043a3b1186115e" +checksum = "9529f4786b70a3e8c61e11179af17ab6188ad8d0ded78c5529441ed39d4bd9c1" dependencies = [ "bitflags 2.6.0", "cfg-if", @@ -1768,9 +1813,9 @@ checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" [[package]] name = "syn" -version = "2.0.71" +version = "2.0.72" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b146dcf730474b4bcd16c311627b31ede9ab149045db4d6088b3becaea046462" +checksum = "dc4b9b9bf2add8093d3f2c0204471e951b2285580335de42f9d2534f3ae7a8af" dependencies = [ "proc-macro2", "quote", @@ -2077,6 +2122,15 @@ version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" +[[package]] +name = "uuid" +version = "1.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "81dfa00651efa65069b0b6b651f4aaa31ba9e3c3ce0137aaad053604ee7e0314" +dependencies = [ + "getrandom", +] + [[package]] name = "valuable" version = "0.1.0" @@ -2386,7 +2440,7 @@ version = "5.0.0" dependencies = [ "async-trait", "bytes", - "dashmap", + "dashmap 5.5.3", "event-listener", "fastwebsockets", "flume", diff --git a/Cargo.toml b/Cargo.toml index c1638f6..d7676dd 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,3 +8,6 @@ debug = true panic = "abort" codegen-units = 1 opt-level = 3 + +[patch.crates-io] +fastwebsockets = { git = "https://github.com/r58Playz/fastwebsockets" } diff --git a/server/Cargo.toml b/server/Cargo.toml index 0e8871d..03be1a4 100644 --- a/server/Cargo.toml +++ b/server/Cargo.toml @@ -6,15 +6,19 @@ edition = "2021" [dependencies] anyhow = "1.0.86" bytes = "1.6.1" +dashmap = "6.0.1" +env_logger = "0.11.3" fastwebsockets = { version = "0.8.0", features = ["unstable-split", "upgrade"] } futures-util = "0.3.30" http-body-util = "0.1.2" hyper = { version = "1.4.1", features = ["server", "http1"] } hyper-util = { version = "0.1.6", features = ["tokio"] } lazy_static = "1.5.0" +log = { version = "0.4.22", features = ["serde", "std"] } regex = "1.10.5" serde = { version = "1.0.204", features = ["derive"] } tokio = { version = "1.38.1", features = ["full"] } tokio-util = { version = "0.7.11", features = ["compat", "io-util", "net"] } toml = "0.8.15" +uuid = { version = "1.10.0", features = ["v4"] } wisp-mux = { version = "5.0.0", path = "../wisp", features = ["fastwebsockets"] } diff --git a/server/src/config.rs b/server/src/config.rs index 2d4dd05..22a2dff 100644 --- a/server/src/config.rs +++ b/server/src/config.rs @@ -1,6 +1,7 @@ use std::{collections::HashMap, ops::RangeInclusive}; use lazy_static::lazy_static; +use log::LevelFilter; use regex::RegexSet; use serde::{Deserialize, Serialize}; use wisp_mux::extensions::{ @@ -48,7 +49,7 @@ pub fn validate_config_cache() { let _ = CONFIG_CACHE.wisp_config; } -#[derive(Serialize, Deserialize, Default)] +#[derive(Serialize, Deserialize, Default, Debug)] #[serde(rename_all = "lowercase")] pub enum SocketType { #[default] @@ -63,19 +64,38 @@ pub struct ServerConfig { pub socket: SocketType, pub resolve_ipv6: bool, + pub verbose_stats: bool, + pub enable_stats_endpoint: bool, + pub stats_endpoint: String, + + pub non_ws_response: String, + + // DO NOT add a trailing slash to this config option + pub prefix: String, + pub max_message_size: usize, - // TODO - // prefix: String, + + pub log_level: LevelFilter, } impl Default for ServerConfig { fn default() -> Self { Self { - bind: "127.0.0.1:4000".to_owned(), + bind: "127.0.0.1:4000".to_string(), socket: SocketType::default(), resolve_ipv6: false, + verbose_stats: true, + stats_endpoint: "/stats".to_string(), + enable_stats_endpoint: true, + + non_ws_response: ":3".to_string(), + + prefix: String::new(), + max_message_size: 64 * 1024, + + log_level: LevelFilter::Info, } } } @@ -90,21 +110,21 @@ pub enum ProtocolExtension { #[derive(Serialize, Deserialize)] #[serde(default)] pub struct WispConfig { - pub wisp_v2: bool, + pub allow_wsproxy: bool, pub buffer_size: u32, + pub wisp_v2: bool, pub extensions: Vec, pub password_extension_users: HashMap, - // TODO - // enable_wsproxy: bool, } impl Default for WispConfig { fn default() -> Self { Self { - buffer_size: 512, - wisp_v2: false, + buffer_size: 128, + allow_wsproxy: true, + wisp_v2: false, extensions: Vec::new(), password_extension_users: HashMap::new(), } @@ -112,7 +132,9 @@ impl Default for WispConfig { } impl WispConfig { - pub fn to_opts_inner(&self) -> anyhow::Result<(Option>, u32)> { + pub(super) fn to_opts_inner( + &self, + ) -> anyhow::Result<(Option>, u32)> { if self.wisp_v2 { let mut extensions: Vec> = Vec::new(); @@ -144,6 +166,7 @@ impl WispConfig { #[serde(default)] pub struct StreamConfig { pub allow_udp: bool, + pub allow_wsproxy_udp: bool, pub allow_direct_ip: bool, pub allow_loopback: bool, @@ -163,6 +186,7 @@ impl Default for StreamConfig { fn default() -> Self { Self { allow_udp: true, + allow_wsproxy_udp: false, allow_direct_ip: true, allow_loopback: true, diff --git a/server/src/handle/mod.rs b/server/src/handle/mod.rs new file mode 100644 index 0000000..90663fc --- /dev/null +++ b/server/src/handle/mod.rs @@ -0,0 +1,5 @@ +mod wisp; +mod wsproxy; + +pub use wisp::handle_wisp; +pub use wsproxy::handle_wsproxy; diff --git a/server/src/handle/wisp.rs b/server/src/handle/wisp.rs new file mode 100644 index 0000000..7bfac71 --- /dev/null +++ b/server/src/handle/wisp.rs @@ -0,0 +1,204 @@ +use anyhow::Context; +use fastwebsockets::{upgrade::UpgradeFut, FragmentCollectorRead}; +use futures_util::FutureExt; +use hyper_util::rt::TokioIo; +use tokio::{ + io::{AsyncBufReadExt, AsyncWriteExt, BufReader}, + net::tcp::{OwnedReadHalf, OwnedWriteHalf}, + select, + task::JoinSet, +}; +use tokio_util::compat::FuturesAsyncReadCompatExt; +use uuid::Uuid; +use wisp_mux::{ + CloseReason, ConnectPacket, MuxStream, MuxStreamAsyncRead, MuxStreamWrite, ServerMux, +}; + +use crate::{ + stream::{ClientStream, ResolvedPacket, ServerStream, ServerStreamExt}, + CLIENTS, CONFIG, +}; + +async fn copy_read_fast( + muxrx: MuxStreamAsyncRead, + mut tcptx: OwnedWriteHalf, +) -> std::io::Result<()> { + let mut muxrx = muxrx.compat(); + loop { + let buf = muxrx.fill_buf().await?; + if buf.is_empty() { + tcptx.flush().await?; + return Ok(()); + } + + let i = tcptx.write(buf).await?; + if i == 0 { + return Err(std::io::ErrorKind::WriteZero.into()); + } + + muxrx.consume(i); + } +} + +async fn copy_write_fast(muxtx: MuxStreamWrite, tcprx: OwnedReadHalf) -> anyhow::Result<()> { + let mut tcprx = BufReader::new(tcprx); + loop { + let buf = tcprx.fill_buf().await?; + muxtx.write(&buf).await?; + let len = buf.len(); + tcprx.consume(len); + } +} + +async fn handle_stream(connect: ConnectPacket, muxstream: MuxStream, id: String) { + let requested_stream = connect.clone(); + + let Ok(resolved) = ClientStream::resolve(connect).await else { + let _ = muxstream.close(CloseReason::ServerStreamUnreachable).await; + return; + }; + let connect = match resolved { + ResolvedPacket::Valid(x) => x, + ResolvedPacket::NoResolvedAddrs => { + let _ = muxstream.close(CloseReason::ServerStreamUnreachable).await; + return; + } + ResolvedPacket::Blocked => { + let _ = muxstream + .close(CloseReason::ServerStreamBlockedAddress) + .await; + return; + } + }; + + let resolved_stream = connect.clone(); + + let Ok(stream) = ClientStream::connect(connect).await else { + let _ = muxstream.close(CloseReason::ServerStreamUnreachable).await; + return; + }; + + let uuid = Uuid::new_v4(); + + CLIENTS + .get(&id) + .unwrap() + .0 + .insert(uuid, (requested_stream, resolved_stream)); + + match stream { + ClientStream::Tcp(stream) => { + let closer = muxstream.get_close_handle(); + + let ret: anyhow::Result<()> = async { + /* + let (muxread, muxwrite) = muxstream.into_io().into_asyncrw().into_split(); + let (mut tcpread, tcpwrite) = stream.into_split(); + let mut muxwrite = muxwrite.compat_write(); + select! { + x = copy_read_fast(muxread, tcpwrite) => x?, + x = copy(&mut tcpread, &mut muxwrite) => {x?;}, + } + */ + // TODO why is copy_write_fast not working? + let (muxread, muxwrite) = muxstream.into_split(); + let muxread = muxread.into_stream().into_asyncread(); + let (tcpread, tcpwrite) = stream.into_split(); + select! { + x = copy_read_fast(muxread, tcpwrite) => x?, + x = copy_write_fast(muxwrite, tcpread) => x?, + } + Ok(()) + } + .await; + + match ret { + Ok(()) => { + let _ = closer.close(CloseReason::Voluntary).await; + } + Err(_) => { + let _ = closer.close(CloseReason::Unexpected).await; + } + } + } + ClientStream::Udp(stream) => { + let closer = muxstream.get_close_handle(); + + let ret: anyhow::Result<()> = async move { + let mut data = vec![0u8; 65507]; + loop { + select! { + size = stream.recv(&mut data) => { + let size = size?; + muxstream.write(&data[..size]).await?; + } + data = muxstream.read() => { + if let Some(data) = data { + stream.send(&data).await?; + } else { + break Ok(()); + } + } + } + } + } + .await; + + match ret { + Ok(()) => { + let _ = closer.close(CloseReason::Voluntary).await; + } + Err(_) => { + let _ = closer.close(CloseReason::Unexpected).await; + } + } + } + ClientStream::Invalid => { + let _ = muxstream.close(CloseReason::ServerStreamInvalidInfo).await; + } + ClientStream::Blocked => { + let _ = muxstream + .close(CloseReason::ServerStreamBlockedAddress) + .await; + } + }; + + CLIENTS.get(&id).unwrap().0.remove(&uuid); +} + +pub async fn handle_wisp(fut: UpgradeFut, id: String) -> anyhow::Result<()> { + let mut ws = fut.await.context("failed to await upgrade future")?; + ws.set_max_message_size(CONFIG.server.max_message_size); + + let (read, write) = ws.split(|x| { + let parts = x.into_inner().downcast::>().unwrap(); + assert_eq!(parts.read_buf.len(), 0); + parts.io.into_inner().split() + }); + let read = FragmentCollectorRead::new(read); + + let (extensions, buffer_size) = CONFIG.wisp.to_opts(); + + let (mux, fut) = ServerMux::create(read, write, buffer_size, extensions) + .await + .context("failed to create server multiplexor")? + .with_no_required_extensions(); + + let mut set: JoinSet<()> = JoinSet::new(); + + set.spawn(tokio::task::unconstrained(fut.map(|_| {}))); + + while let Some((connect, stream)) = mux.server_new_stream().await { + set.spawn(tokio::task::unconstrained(handle_stream( + connect, + stream, + id.clone(), + ))); + } + + set.abort_all(); + + while set.join_next().await.is_some() {} + + Ok(()) +} diff --git a/server/src/handle/wsproxy.rs b/server/src/handle/wsproxy.rs new file mode 100644 index 0000000..a98955e --- /dev/null +++ b/server/src/handle/wsproxy.rs @@ -0,0 +1,145 @@ +use std::str::FromStr; + +use anyhow::Context; +use fastwebsockets::{upgrade::UpgradeFut, CloseCode, FragmentCollector}; +use tokio::{ + io::{AsyncBufReadExt, AsyncWriteExt, BufReader}, + select, +}; +use uuid::Uuid; +use wisp_mux::{ConnectPacket, StreamType}; + +use crate::{ + stream::{ClientStream, ResolvedPacket, WebSocketFrame, WebSocketStreamWrapper}, + CLIENTS, CONFIG, +}; + +pub async fn handle_wsproxy( + fut: UpgradeFut, + id: String, + path: String, + udp: bool, +) -> anyhow::Result<()> { + let mut ws = fut.await.context("failed to await upgrade future")?; + ws.set_max_message_size(CONFIG.server.max_message_size); + let ws = FragmentCollector::new(ws); + let mut ws = WebSocketStreamWrapper(ws); + + if udp && !CONFIG.stream.allow_wsproxy_udp { + let _ = ws.close(CloseCode::Error.into(), b"udp is blocked").await; + return Ok(()); + } + + let vec: Vec<&str> = path.split("/").last().unwrap().split(":").collect(); + let Ok(port) = FromStr::from_str(vec[1]) else { + let _ = ws.close(CloseCode::Error.into(), b"invalid port").await; + return Ok(()); + }; + let connect = ConnectPacket { + stream_type: if udp { + StreamType::Udp + } else { + StreamType::Tcp + }, + destination_hostname: vec[0].to_string(), + destination_port: port, + }; + + let requested_stream = connect.clone(); + + let Ok(resolved) = ClientStream::resolve(connect).await else { + let _ = ws + .close(CloseCode::Error.into(), b"failed to resolve host") + .await; + return Ok(()); + }; + let connect = match resolved { + ResolvedPacket::Valid(x) => x, + ResolvedPacket::NoResolvedAddrs => { + let _ = ws + .close( + CloseCode::Error.into(), + b"host did not resolve to any addrs", + ) + .await; + return Ok(()); + } + ResolvedPacket::Blocked => { + let _ = ws.close(CloseCode::Error.into(), b"host is blocked").await; + return Ok(()); + } + }; + + let resolved_stream = connect.clone(); + + let Ok(stream) = ClientStream::connect(connect).await else { + let _ = ws + .close(CloseCode::Error.into(), b"failed to connect to host") + .await; + return Ok(()); + }; + + let uuid = Uuid::new_v4(); + + CLIENTS + .get(&id) + .unwrap() + .0 + .insert(uuid, (requested_stream, resolved_stream)); + + match stream { + ClientStream::Tcp(stream) => { + let mut stream = BufReader::new(stream); + let ret: anyhow::Result<()> = async { + let mut to_consume = 0usize; + loop { + if to_consume != 0 { + stream.consume(to_consume); + to_consume = 0; + } + select! { + x = ws.read() => { + match x? { + WebSocketFrame::Data(data) => { + stream.write_all(&data).await?; + } + WebSocketFrame::Close => { + stream.shutdown().await?; + } + WebSocketFrame::Ignore => {} + } + } + x = stream.fill_buf() => { + let x = x?; + ws.write(x).await?; + to_consume += x.len(); + } + } + } + } + .await; + match ret { + Ok(_) => { + let _ = ws.close(CloseCode::Normal.into(), b"").await; + } + Err(x) => { + let _ = ws + .close(CloseCode::Normal.into(), x.to_string().as_bytes()) + .await; + } + } + } + ClientStream::Udp(_stream) => { + // TODO + let _ = ws.close(CloseCode::Error.into(), b"coming soon").await; + } + ClientStream::Blocked => { + let _ = ws.close(CloseCode::Error.into(), b"host is blocked").await; + } + ClientStream::Invalid => { + let _ = ws.close(CloseCode::Error.into(), b"host is invalid").await; + } + } + + Ok(()) +} diff --git a/server/src/main.rs b/server/src/main.rs index 061c41a..fdb0d5c 100644 --- a/server/src/main.rs +++ b/server/src/main.rs @@ -1,25 +1,30 @@ #![feature(ip)] -use std::{env::args, fs::read_to_string, ops::Deref}; +use std::{env::args, fmt::Write, fs::read_to_string}; -use anyhow::Context; use bytes::Bytes; use config::{validate_config_cache, Config}; -use fastwebsockets::{upgrade::UpgradeFut, FragmentCollectorRead}; -use http_body_util::Empty; -use hyper::{body::Incoming, server::conn::http1::Builder, service::service_fn, Request, Response}; +use dashmap::DashMap; +use handle::{handle_wisp, handle_wsproxy}; +use http_body_util::Full; +use hyper::{ + body::Incoming, server::conn::http1::Builder, service::service_fn, Request, Response, + StatusCode, +}; use hyper_util::rt::TokioIo; use lazy_static::lazy_static; -use stream::{ - copy_read_fast, ClientStream, ResolvedPacket, ServerListener, ServerStream, ServerStreamExt, -}; -use tokio::{io::copy, select}; -use tokio_util::compat::FuturesAsyncWriteCompatExt; -use wisp_mux::{CloseReason, ConnectPacket, MuxStream, ServerMux}; +use log::{error, info}; +use stream::ServerListener; +use tokio::signal::unix::{signal, SignalKind}; +use uuid::Uuid; +use wisp_mux::{ConnectPacket, StreamType}; mod config; +mod handle; mod stream; +type Client = (DashMap, bool); + lazy_static! { pub static ref CONFIG: Config = { if let Some(path) = args().nth(1) { @@ -28,169 +33,159 @@ lazy_static! { Config::default() } }; + pub static ref CLIENTS: DashMap = DashMap::new(); } -async fn handle_stream(connect: ConnectPacket, muxstream: MuxStream) { - let Ok(resolved) = ClientStream::resolve(connect).await else { - let _ = muxstream.close(CloseReason::ServerStreamUnreachable).await; - return; - }; - let connect = match resolved { - ResolvedPacket::Valid(x) => x, - ResolvedPacket::NoResolvedAddrs => { - let _ = muxstream.close(CloseReason::ServerStreamUnreachable).await; - return; - } - ResolvedPacket::Blocked => { - let _ = muxstream - .close(CloseReason::ServerStreamBlockedAddress) - .await; - return; - } - }; - - let Ok(stream) = ClientStream::connect(connect).await else { - let _ = muxstream.close(CloseReason::ServerStreamUnreachable).await; - return; - }; - - match stream { - ClientStream::Tcp(stream) => { - let closer = muxstream.get_close_handle(); - - let ret: anyhow::Result<()> = async move { - let (muxread, muxwrite) = muxstream.into_io().into_asyncrw().into_split(); - let (mut tcpread, tcpwrite) = stream.into_split(); - let mut muxwrite = muxwrite.compat_write(); - select! { - x = copy_read_fast(muxread, tcpwrite) => x?, - x = copy(&mut tcpread, &mut muxwrite) => {x?;}, - } - // TODO why is copy_write_fast not working? - /* - let (muxread, muxwrite) = muxstream.into_split(); - let muxread = muxread.into_stream().into_asyncread(); - let (mut tcpread, tcpwrite) = stream.into_split(); - select! { - x = copy_read_fast(muxread, tcpwrite) => x?, - x = copy_write_fast(muxwrite, tcpread) => {x?;}, - } - */ - Ok(()) - } - .await; - - match ret { - Ok(()) => { - let _ = closer.close(CloseReason::Voluntary).await; - } - Err(_) => { - let _ = closer.close(CloseReason::Unexpected).await; - } - } - } - ClientStream::Udp(stream) => { - let closer = muxstream.get_close_handle(); - - let ret: anyhow::Result<()> = async move { - let mut data = vec![0u8; 65507]; - loop { - select! { - size = stream.recv(&mut data) => { - let size = size?; - muxstream.write(&data[..size]).await?; - } - data = muxstream.read() => { - if let Some(data) = data { - stream.send(&data).await?; - } else { - break Ok(()); - } - } - } - } - } - .await; - - match ret { - Ok(()) => { - let _ = closer.close(CloseReason::Voluntary).await; - } - Err(_) => { - let _ = closer.close(CloseReason::Unexpected).await; - } - } - } - ClientStream::Invalid => { - let _ = muxstream.close(CloseReason::ServerStreamInvalidInfo).await; - } - ClientStream::Blocked => { - let _ = muxstream - .close(CloseReason::ServerStreamBlockedAddress) - .await; - } - }; +type Body = Full; +fn non_ws_resp() -> Response { + Response::builder() + .status(StatusCode::OK) + .body(Body::new(CONFIG.server.non_ws_response.as_bytes().into())) + .unwrap() } -async fn handle(fut: UpgradeFut) -> anyhow::Result<()> { - let mut ws = fut.await.context("failed to await upgrade future")?; - - ws.set_max_message_size(CONFIG.server.max_message_size); - - let (read, write) = ws.split(|x| { - let parts = x.into_inner().downcast::>().unwrap(); - assert_eq!(parts.read_buf.len(), 0); - parts.io.into_inner().split() - }); - let read = FragmentCollectorRead::new(read); - - let (extensions, buffer_size) = CONFIG.wisp.to_opts_inner()?; - - let (mux, fut) = ServerMux::create(read, write, buffer_size, extensions.as_deref()) - .await - .context("failed to create server multiplexor")? - .with_no_required_extensions(); - - tokio::spawn(tokio::task::unconstrained(fut)); - - while let Some((connect, stream)) = mux.server_new_stream().await { - tokio::spawn(tokio::task::unconstrained(handle_stream(connect, stream))); +async fn upgrade(mut req: Request, id: String) -> anyhow::Result> { + if CONFIG.server.enable_stats_endpoint && req.uri().path() == CONFIG.server.stats_endpoint { + match generate_stats() { + Ok(x) => { + return Ok(Response::builder() + .status(StatusCode::OK) + .body(Body::new(x.into())) + .unwrap()) + } + Err(x) => { + return Ok(Response::builder() + .status(StatusCode::INTERNAL_SERVER_ERROR) + .body(Body::new(x.to_string().into())) + .unwrap()) + } + } + } else if !fastwebsockets::upgrade::is_upgrade_request(&req) { + return Ok(non_ws_resp()); } - Ok(()) -} - -type Body = Empty; -async fn upgrade(mut req: Request) -> anyhow::Result> { let (resp, fut) = fastwebsockets::upgrade::upgrade(&mut req)?; + // replace body of Empty with Full + let resp = Response::from_parts(resp.into_parts().0, Body::new(Bytes::new())); - tokio::spawn(async move { - if let Err(e) = handle(fut).await { - println!("{:?}", e); - }; - }); + if req + .uri() + .path() + .starts_with(&(CONFIG.server.prefix.clone() + "/")) + { + tokio::spawn(async move { + CLIENTS.insert(id.clone(), (DashMap::new(), false)); + if let Err(e) = handle_wisp(fut, id.clone()).await { + error!("error while handling upgraded client: {:?}", e); + }; + CLIENTS.remove(&id) + }); + } else if CONFIG.wisp.allow_wsproxy { + let udp = req.uri().query().unwrap_or_default() == "?udp"; + tokio::spawn(async move { + CLIENTS.insert(id.clone(), (DashMap::new(), true)); + if let Err(e) = handle_wsproxy(fut, id.clone(), req.uri().path().to_string(), udp).await + { + error!("error while handling upgraded client: {:?}", e); + }; + CLIENTS.remove(&id) + }); + } else { + return Ok(non_ws_resp()); + } Ok(resp) } +fn format_stream_type(stream_type: StreamType) -> &'static str { + match stream_type { + StreamType::Tcp => "tcp", + StreamType::Udp => "udp", + StreamType::Unknown(_) => unreachable!(), + } +} + +fn generate_stats() -> Result { + let mut out = String::new(); + let len = CLIENTS.len(); + writeln!( + &mut out, + "{} clients connected{}", + len, + if len != 0 { ":" } else { "" } + )?; + + for client in CLIENTS.iter() { + let len = client.value().0.len(); + + writeln!( + &mut out, + "\tClient \"{}\"{}: {} streams connected{}", + client.key(), + if client.value().1 { " (wsproxy)" } else { "" }, + len, + if len != 0 && CONFIG.server.verbose_stats { + ":" + } else { + "" + } + )?; + + if CONFIG.server.verbose_stats { + for stream in client.value().0.iter() { + writeln!( + &mut out, + "\t\tStream \"{}\": {}", + stream.key(), + format_stream_type(stream.value().0.stream_type) + )?; + writeln!( + &mut out, + "\t\t\tRequested: {}:{}", + stream.value().0.destination_hostname, + stream.value().0.destination_port + )?; + writeln!( + &mut out, + "\t\t\tResolved: {}:{}", + stream.value().1.destination_hostname, + stream.value().1.destination_port + )?; + } + } + } + Ok(out) +} + #[tokio::main(flavor = "multi_thread")] async fn main() -> anyhow::Result<()> { + env_logger::builder() + .filter_level(CONFIG.server.log_level) + .parse_default_env() + .init(); validate_config_cache(); - println!("{}", toml::to_string_pretty(CONFIG.deref()).unwrap()); + info!("listening on {:?} with socket type {:?}", CONFIG.server.bind, CONFIG.server.socket); + + tokio::spawn(async { + let mut sig = signal(SignalKind::user_defined1()).unwrap(); + while sig.recv().await.is_some() { + info!("{}", generate_stats().unwrap()); + } + }); let listener = ServerListener::new().await?; loop { - let (stream, _) = listener.accept().await?; + let (stream, id) = listener.accept().await?; tokio::spawn(async move { let stream = TokioIo::new(stream); let fut = Builder::new() - .serve_connection(stream, service_fn(upgrade)) + .serve_connection(stream, service_fn(|req| upgrade(req, id.clone()))) .with_upgrades(); if let Err(e) = fut.await { - println!("{:?}", e); + error!("error while serving client: {:?}", e); } }); } diff --git a/server/src/stream.rs b/server/src/stream.rs index 8837fe9..ee8c169 100644 --- a/server/src/stream.rs +++ b/server/src/stream.rs @@ -5,17 +5,16 @@ use std::{ use anyhow::Context; use bytes::BytesMut; -use futures_util::AsyncBufReadExt; +use fastwebsockets::{FragmentCollector, Frame, OpCode, Payload, WebSocketError}; +use hyper::upgrade::Upgraded; +use hyper_util::rt::TokioIo; use tokio::{ - io::{AsyncReadExt, AsyncWriteExt}, - net::{ - lookup_host, - tcp::{self, OwnedReadHalf, OwnedWriteHalf}, - unix, TcpListener, TcpStream, UdpSocket, UnixListener, UnixStream, - }, + fs::{remove_file, try_exists}, + net::{lookup_host, tcp, unix, TcpListener, TcpStream, UdpSocket, UnixListener, UnixStream}, }; use tokio_util::either::Either; -use wisp_mux::{ConnectPacket, MuxStreamAsyncRead, MuxStreamWrite, StreamType}; +use uuid::Uuid; +use wisp_mux::{ConnectPacket, StreamType}; use crate::{config::SocketType, CONFIG}; @@ -58,6 +57,9 @@ impl ServerListener { })?, ), SocketType::Unix => { + if try_exists(&CONFIG.server.bind).await? { + remove_file(&CONFIG.server.bind).await?; + } Self::Unix(UnixListener::bind(&CONFIG.server.bind).with_context(|| { format!("failed to bind to unix socket at `{}`", CONFIG.server.bind) })?) @@ -65,12 +67,12 @@ impl ServerListener { }) } - pub async fn accept(&self) -> anyhow::Result<(ServerStream, Option)> { + pub async fn accept(&self) -> anyhow::Result<(ServerStream, String)> { match self { Self::Tcp(x) => x .accept() .await - .map(|(x, y)| (Either::Left(x), Some(y.to_string()))) + .map(|(x, y)| (Either::Left(x), y.to_string())) .context("failed to accept tcp connection"), Self::Unix(x) => x .accept() @@ -80,7 +82,8 @@ impl ServerListener { Either::Right(x), y.as_pathname() .and_then(|x| x.to_str()) - .map(ToString::to_string), + .map(ToString::to_string) + .unwrap_or_else(|| Uuid::new_v4().to_string() + "-unix_socket"), ) }) .context("failed to accept unix socket connection"), @@ -207,34 +210,31 @@ impl ClientStream { } } -pub async fn copy_read_fast( - mut muxrx: MuxStreamAsyncRead, - mut tcptx: OwnedWriteHalf, -) -> std::io::Result<()> { - loop { - let buf = muxrx.fill_buf().await?; - if buf.is_empty() { - tcptx.flush().await?; - return Ok(()); - } - - let i = tcptx.write(buf).await?; - if i == 0 { - return Err(std::io::ErrorKind::WriteZero.into()); - } - - muxrx.consume_unpin(i); - } +pub enum WebSocketFrame { + Data(BytesMut), + Close, + Ignore, } -#[allow(dead_code)] -pub async fn copy_write_fast( - muxtx: MuxStreamWrite, - mut tcprx: OwnedReadHalf, -) -> anyhow::Result<()> { - loop { - let mut buf = BytesMut::with_capacity(8 * 1024); - let amt = tcprx.read(&mut buf).await?; - muxtx.write(&buf[..amt]).await?; +pub struct WebSocketStreamWrapper(pub FragmentCollector>); + +impl WebSocketStreamWrapper { + pub async fn read(&mut self) -> Result { + let frame = self.0.read_frame().await?; + Ok(match frame.opcode { + OpCode::Text | OpCode::Binary => WebSocketFrame::Data(frame.payload.into()), + OpCode::Close => WebSocketFrame::Close, + _ => WebSocketFrame::Ignore, + }) + } + + pub async fn write(&mut self, data: &[u8]) -> Result<(), WebSocketError> { + self.0 + .write_frame(Frame::binary(Payload::Borrowed(data))) + .await + } + + pub async fn close(&mut self, code: u16, reason: &[u8]) -> Result<(), WebSocketError> { + self.0.write_frame(Frame::close(code, reason)).await } }