rewrite server

This commit is contained in:
Toshit Chawda 2024-07-20 22:21:51 -07:00
parent 3bf19be9f0
commit 24bfcae975
No known key found for this signature in database
GPG key ID: 91480ED99E2B3D9D
10 changed files with 1301 additions and 178 deletions

224
Cargo.lock generated
View file

@ -295,9 +295,9 @@ checksum = "a12916984aab3fa6e39d655a33e09c0071eb36d6ab3aea5c2d78551f1df6d952"
[[package]]
name = "cc"
version = "1.1.3"
version = "1.1.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "18e2d530f35b40a84124146478cd16f34225306a8441998836466a2e2961c950"
checksum = "2aba8f4e9906c7ce3c73463f62a7f0c65183ada1a2d47e397cc8810827f9694f"
[[package]]
name = "certs-grabber"
@ -333,7 +333,6 @@ dependencies = [
"anstyle",
"clap_lex",
"strsim",
"terminal_size",
]
[[package]]
@ -354,21 +353,6 @@ version = "0.7.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4b82cf0babdbd58558212896d1a4272303a57bdb245c2bf1147185fb45640e70"
[[package]]
name = "clio"
version = "0.3.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b7fc6734af48458f72f5a3fa7b840903606427d98a710256e808f76a965047d9"
dependencies = [
"cfg-if",
"clap",
"is-terminal",
"libc",
"tempfile",
"walkdir",
"windows-sys 0.42.0",
]
[[package]]
name = "colorchoice"
version = "1.0.1"
@ -545,21 +529,21 @@ dependencies = [
[[package]]
name = "epoxy-server"
version = "1.0.0"
version = "2.0.0"
dependencies = [
"anyhow",
"bytes",
"cfg-if",
"clap",
"clio",
"console-subscriber",
"dashmap",
"fastwebsockets",
"futures-util",
"http-body-util",
"hyper 1.4.1",
"hyper-util",
"lazy_static",
"regex",
"serde",
"tokio",
"tokio-util",
"toml",
"wisp-mux",
]
@ -599,8 +583,6 @@ checksum = "9fc0510504f03c51ada170672ac806f1f105a88aa97a5281117e1ddc3368e51a"
[[package]]
name = "fastwebsockets"
version = "0.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "26da0c7b5cef45c521a6f9cdfffdfeb6c9f5804fbac332deb5ae254634c7a6be"
dependencies = [
"base64 0.21.7",
"bytes",
@ -1067,17 +1049,6 @@ dependencies = [
"hashbrown 0.14.5",
]
[[package]]
name = "is-terminal"
version = "0.4.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f23ff5ef2b80d608d61efee834934d862cd92461afc0560dedf493e4c033738b"
dependencies = [
"hermit-abi",
"libc",
"windows-sys 0.52.0",
]
[[package]]
name = "is_terminal_polyfill"
version = "1.70.0"
@ -1267,9 +1238,9 @@ checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92"
[[package]]
name = "openssl"
version = "0.10.64"
version = "0.10.65"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "95a0481286a310808298130d22dd1fef0fa571e05a8f44ec801801e84b216b1f"
checksum = "c2823eb4c6453ed64055057ea8bd416eda38c71018723869dd043a3b1186115e"
dependencies = [
"bitflags 2.6.0",
"cfg-if",
@ -1299,9 +1270,9 @@ checksum = "ff011a302c396a5197692431fc1948019154afc178baf7d8e37367442a4601cf"
[[package]]
name = "openssl-sys"
version = "0.9.102"
version = "0.9.103"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c597637d56fbc83893a35eb0dd04b2b8e7a50c91e64e9493e398b5df4fb45fa2"
checksum = "7f9e8deee91df40a943c71b917e5874b951d32a802526c85721ce3b776c929d6"
dependencies = [
"cc",
"libc",
@ -1470,9 +1441,9 @@ dependencies = [
[[package]]
name = "redox_syscall"
version = "0.5.2"
version = "0.5.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c82cf8cff14456045f55ec4241383baeff27af886adb72ffb2162f99911de0fd"
checksum = "2a908a6e00f1fdd0dfd9c0eb08ce85126f6d8bbda50017e74bc4a4b7d4a926a4"
dependencies = [
"bitflags 2.6.0",
]
@ -1601,15 +1572,6 @@ version = "1.0.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f3cb5ba0dc43242ce17de99c180e96db90b235b8a9fdc9543c96d2209116bd9f"
[[package]]
name = "same-file"
version = "1.0.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "93fc1dc3aaa9bfed95e02e6eadabb4baf7e3078b0bd1b4d7b6b0b68378900502"
dependencies = [
"winapi-util",
]
[[package]]
name = "schannel"
version = "0.1.23"
@ -1627,9 +1589,9 @@ checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49"
[[package]]
name = "security-framework"
version = "2.11.0"
version = "2.11.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c627723fd09706bacdb5cf41499e95098555af3c3c29d014dc3c458ef6be11c0"
checksum = "897b2245f0b511c87893af39b033e5ca9cce68824c4d7e7630b5a1d339658d02"
dependencies = [
"bitflags 2.6.0",
"core-foundation",
@ -1640,9 +1602,9 @@ dependencies = [
[[package]]
name = "security-framework-sys"
version = "2.11.0"
version = "2.11.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "317936bbbd05227752583946b9e66d7ce3b489f84e11a94a510b4437fef407d7"
checksum = "75da29fe9b9b08fe9d6b22b5b4bcbc75d8db3aa31e639aa56bb62e9d46bfceaf"
dependencies = [
"core-foundation-sys",
"libc",
@ -1685,6 +1647,15 @@ dependencies = [
"serde",
]
[[package]]
name = "serde_spanned"
version = "0.6.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "79e674e01f999af37c49f70a6ede167a8a60b2503e56c5599532a65baa5969a0"
dependencies = [
"serde",
]
[[package]]
name = "sha1"
version = "0.10.6"
@ -1824,30 +1795,20 @@ dependencies = [
"windows-sys 0.52.0",
]
[[package]]
name = "terminal_size"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "21bebf2b7c9e0a515f6e0f8c51dc0f8e4696391e6f1ff30379559f8365fb0df7"
dependencies = [
"rustix",
"windows-sys 0.48.0",
]
[[package]]
name = "thiserror"
version = "1.0.62"
version = "1.0.63"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f2675633b1499176c2dff06b0856a27976a8f9d436737b4cf4f312d4d91d8bbb"
checksum = "c0342370b38b6a11b6cc11d6a805569958d54cfa061a29969c3b5ce2ea405724"
dependencies = [
"thiserror-impl",
]
[[package]]
name = "thiserror-impl"
version = "1.0.62"
version = "1.0.63"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d20468752b09f49e909e55a5d338caa8bedf615594e9d80bc4c565d30faf798c"
checksum = "a4558b58466b9ad7ca0f102865eccc95938dca1a74a856f2b57b6629050da261"
dependencies = [
"proc-macro2",
"quote",
@ -1866,9 +1827,9 @@ dependencies = [
[[package]]
name = "tokio"
version = "1.38.0"
version = "1.38.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ba4f4a02a7a80d6f274636f0aa95c7e383b912d41fe721a31f29e29698585a4a"
checksum = "eb2caba9f80616f438e09748d5acda951967e1ea58508ef53d9c6402485a46df"
dependencies = [
"backtrace",
"bytes",
@ -1940,6 +1901,40 @@ dependencies = [
"tokio",
]
[[package]]
name = "toml"
version = "0.8.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ac2caab0bf757388c6c0ae23b3293fdb463fee59434529014f85e3263b995c28"
dependencies = [
"serde",
"serde_spanned",
"toml_datetime",
"toml_edit",
]
[[package]]
name = "toml_datetime"
version = "0.6.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4badfd56924ae69bcc9039335b2e017639ce3f9b001c393c1b2d1ef846ce2cbf"
dependencies = [
"serde",
]
[[package]]
name = "toml_edit"
version = "0.22.16"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "278f3d518e152219c994ce877758516bca5e118eaed6996192a774fb9fbf0788"
dependencies = [
"indexmap 2.2.6",
"serde",
"serde_spanned",
"toml_datetime",
"winnow",
]
[[package]]
name = "tonic"
version = "0.10.2"
@ -2100,16 +2095,6 @@ version = "0.9.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f"
[[package]]
name = "walkdir"
version = "2.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "29790946404f91d9c5d06f9874efddea1dc06c5efe94541a7d6863108e3a5e4b"
dependencies = [
"same-file",
"winapi-util",
]
[[package]]
name = "want"
version = "0.3.1"
@ -2247,30 +2232,6 @@ dependencies = [
"rustls-pki-types",
]
[[package]]
name = "winapi-util"
version = "0.1.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4d4cc384e1e73b93bafa6fb4f1df8c41695c8a91cf9c4c64358067d15a7b6c6b"
dependencies = [
"windows-sys 0.52.0",
]
[[package]]
name = "windows-sys"
version = "0.42.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5a3e1820f08b8513f676f7ab6c1f99ff312fb97b553d30ff4dd86f9f15728aa7"
dependencies = [
"windows_aarch64_gnullvm 0.42.2",
"windows_aarch64_msvc 0.42.2",
"windows_i686_gnu 0.42.2",
"windows_i686_msvc 0.42.2",
"windows_x86_64_gnu 0.42.2",
"windows_x86_64_gnullvm 0.42.2",
"windows_x86_64_msvc 0.42.2",
]
[[package]]
name = "windows-sys"
version = "0.48.0"
@ -2320,12 +2281,6 @@ dependencies = [
"windows_x86_64_msvc 0.52.6",
]
[[package]]
name = "windows_aarch64_gnullvm"
version = "0.42.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "597a5118570b68bc08d8d59125332c54f1ba9d9adeedeef5b99b02ba2b0698f8"
[[package]]
name = "windows_aarch64_gnullvm"
version = "0.48.5"
@ -2338,12 +2293,6 @@ version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3"
[[package]]
name = "windows_aarch64_msvc"
version = "0.42.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e08e8864a60f06ef0d0ff4ba04124db8b0fb3be5776a5cd47641e942e58c4d43"
[[package]]
name = "windows_aarch64_msvc"
version = "0.48.5"
@ -2356,12 +2305,6 @@ version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469"
[[package]]
name = "windows_i686_gnu"
version = "0.42.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c61d927d8da41da96a81f029489353e68739737d3beca43145c8afec9a31a84f"
[[package]]
name = "windows_i686_gnu"
version = "0.48.5"
@ -2380,12 +2323,6 @@ version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66"
[[package]]
name = "windows_i686_msvc"
version = "0.42.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "44d840b6ec649f480a41c8d80f9c65108b92d89345dd94027bfe06ac444d1060"
[[package]]
name = "windows_i686_msvc"
version = "0.48.5"
@ -2398,12 +2335,6 @@ version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66"
[[package]]
name = "windows_x86_64_gnu"
version = "0.42.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8de912b8b8feb55c064867cf047dda097f92d51efad5b491dfb98f6bbb70cb36"
[[package]]
name = "windows_x86_64_gnu"
version = "0.48.5"
@ -2416,12 +2347,6 @@ version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78"
[[package]]
name = "windows_x86_64_gnullvm"
version = "0.42.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "26d41b46a36d453748aedef1486d5c7a85db22e56aff34643984ea85514e94a3"
[[package]]
name = "windows_x86_64_gnullvm"
version = "0.48.5"
@ -2434,12 +2359,6 @@ version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d"
[[package]]
name = "windows_x86_64_msvc"
version = "0.42.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9aec5da331524158c6d1a4ac0ab1541149c0b9505fde06423b02f5ef0106b9f0"
[[package]]
name = "windows_x86_64_msvc"
version = "0.48.5"
@ -2452,6 +2371,15 @@ version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec"
[[package]]
name = "winnow"
version = "0.6.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "374ec40a2d767a3c1b4972d9475ecd557356637be906f2cb3f7fe17a6eb5e22f"
dependencies = [
"memchr",
]
[[package]]
name = "wisp-mux"
version = "5.0.0"

View file

@ -1,6 +1,6 @@
[workspace]
resolver = "2"
members = ["client", "wisp", "simple-wisp-client", "certs-grabber"]
members = ["server", "client", "wisp", "simple-wisp-client", "certs-grabber"]
[profile.release]
lto = true

20
server/Cargo.toml Normal file
View file

@ -0,0 +1,20 @@
[package]
name = "epoxy-server"
version = "2.0.0"
edition = "2021"
[dependencies]
anyhow = "1.0.86"
bytes = "1.6.1"
fastwebsockets = { version = "0.8.0", features = ["unstable-split", "upgrade"] }
futures-util = "0.3.30"
http-body-util = "0.1.2"
hyper = { version = "1.4.1", features = ["server", "http1"] }
hyper-util = { version = "0.1.6", features = ["tokio"] }
lazy_static = "1.5.0"
regex = "1.10.5"
serde = { version = "1.0.204", features = ["derive"] }
tokio = { version = "1.38.1", features = ["full"] }
tokio-util = { version = "0.7.11", features = ["compat", "io-util", "net"] }
toml = "0.8.15"
wisp-mux = { version = "5.0.0", path = "../wisp", features = ["fastwebsockets"] }

491
server/flamegraph.svg Normal file

File diff suppressed because one or more lines are too long

After

Width:  |  Height:  |  Size: 460 KiB

207
server/src/config.rs Normal file
View file

@ -0,0 +1,207 @@
use std::{collections::HashMap, ops::RangeInclusive};
use lazy_static::lazy_static;
use regex::RegexSet;
use serde::{Deserialize, Serialize};
use wisp_mux::extensions::{
password::PasswordProtocolExtensionBuilder, udp::UdpProtocolExtensionBuilder,
ProtocolExtensionBuilder,
};
use crate::CONFIG;
type AnyProtocolExtensionBuilder = Box<dyn ProtocolExtensionBuilder + Sync + Send>;
struct ConfigCache {
pub blocked_ports: Vec<RangeInclusive<u16>>,
pub allowed_ports: Vec<RangeInclusive<u16>>,
pub allowed_hosts: RegexSet,
pub blocked_hosts: RegexSet,
pub wisp_config: (Option<Vec<AnyProtocolExtensionBuilder>>, u32),
}
lazy_static! {
static ref CONFIG_CACHE: ConfigCache = {
ConfigCache {
allowed_ports: CONFIG
.stream
.allow_ports
.iter()
.map(|x| x[0]..=x[1])
.collect(),
blocked_ports: CONFIG
.stream
.block_ports
.iter()
.map(|x| x[0]..=x[1])
.collect(),
allowed_hosts: RegexSet::new(&CONFIG.stream.allow_hosts).unwrap(),
blocked_hosts: RegexSet::new(&CONFIG.stream.block_hosts).unwrap(),
wisp_config: CONFIG.wisp.to_opts_inner().unwrap(),
}
};
}
pub fn validate_config_cache() {
let _ = CONFIG_CACHE.wisp_config;
}
#[derive(Serialize, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum SocketType {
#[default]
Tcp,
Unix,
}
#[derive(Serialize, Deserialize)]
#[serde(default)]
pub struct ServerConfig {
pub bind: String,
pub socket: SocketType,
pub resolve_ipv6: bool,
pub max_message_size: usize,
// TODO
// prefix: String,
}
impl Default for ServerConfig {
fn default() -> Self {
Self {
bind: "127.0.0.1:4000".to_owned(),
socket: SocketType::default(),
resolve_ipv6: false,
max_message_size: 64 * 1024,
}
}
}
#[derive(Serialize, Deserialize, PartialEq, Eq, PartialOrd, Ord)]
#[serde(rename_all = "lowercase")]
pub enum ProtocolExtension {
Udp,
Password,
}
#[derive(Serialize, Deserialize)]
#[serde(default)]
pub struct WispConfig {
pub wisp_v2: bool,
pub buffer_size: u32,
pub extensions: Vec<ProtocolExtension>,
pub password_extension_users: HashMap<String, String>,
// TODO
// enable_wsproxy: bool,
}
impl Default for WispConfig {
fn default() -> Self {
Self {
buffer_size: 512,
wisp_v2: false,
extensions: Vec::new(),
password_extension_users: HashMap::new(),
}
}
}
impl WispConfig {
pub fn to_opts_inner(&self) -> anyhow::Result<(Option<Vec<AnyProtocolExtensionBuilder>>, u32)> {
if self.wisp_v2 {
let mut extensions: Vec<Box<dyn ProtocolExtensionBuilder + Sync + Send>> = Vec::new();
if self.extensions.contains(&ProtocolExtension::Udp) {
extensions.push(Box::new(UdpProtocolExtensionBuilder));
}
if self.extensions.contains(&ProtocolExtension::Password) {
extensions.push(Box::new(PasswordProtocolExtensionBuilder::new_server(
self.password_extension_users.clone(),
)));
}
Ok((Some(extensions), self.buffer_size))
} else {
Ok((None, self.buffer_size))
}
}
pub fn to_opts(&self) -> (Option<&'static [AnyProtocolExtensionBuilder]>, u32) {
(
CONFIG_CACHE.wisp_config.0.as_deref(),
CONFIG_CACHE.wisp_config.1,
)
}
}
#[derive(Serialize, Deserialize)]
#[serde(default)]
pub struct StreamConfig {
pub allow_udp: bool,
pub allow_direct_ip: bool,
pub allow_loopback: bool,
pub allow_multicast: bool,
pub allow_global: bool,
pub allow_non_global: bool,
pub allow_hosts: Vec<String>,
pub block_hosts: Vec<String>,
pub allow_ports: Vec<Vec<u16>>,
pub block_ports: Vec<Vec<u16>>,
}
impl Default for StreamConfig {
fn default() -> Self {
Self {
allow_udp: true,
allow_direct_ip: true,
allow_loopback: true,
allow_multicast: true,
allow_global: true,
allow_non_global: true,
allow_hosts: Vec::new(),
block_hosts: Vec::new(),
allow_ports: Vec::new(),
block_ports: Vec::new(),
}
}
}
impl StreamConfig {
pub fn allowed_ports(&self) -> &'static [RangeInclusive<u16>] {
&CONFIG_CACHE.allowed_ports
}
pub fn blocked_ports(&self) -> &'static [RangeInclusive<u16>] {
&CONFIG_CACHE.blocked_ports
}
pub fn allowed_hosts(&self) -> &RegexSet {
&CONFIG_CACHE.allowed_hosts
}
pub fn blocked_hosts(&self) -> &RegexSet {
&CONFIG_CACHE.blocked_hosts
}
}
#[derive(Serialize, Deserialize, Default)]
#[serde(default)]
pub struct Config {
pub server: ServerConfig,
pub wisp: WispConfig,
pub stream: StreamConfig,
}

197
server/src/main.rs Normal file
View file

@ -0,0 +1,197 @@
#![feature(ip)]
use std::{env::args, fs::read_to_string, ops::Deref};
use anyhow::Context;
use bytes::Bytes;
use config::{validate_config_cache, Config};
use fastwebsockets::{upgrade::UpgradeFut, FragmentCollectorRead};
use http_body_util::Empty;
use hyper::{body::Incoming, server::conn::http1::Builder, service::service_fn, Request, Response};
use hyper_util::rt::TokioIo;
use lazy_static::lazy_static;
use stream::{
copy_read_fast, ClientStream, ResolvedPacket, ServerListener, ServerStream, ServerStreamExt,
};
use tokio::{io::copy, select};
use tokio_util::compat::FuturesAsyncWriteCompatExt;
use wisp_mux::{CloseReason, ConnectPacket, MuxStream, ServerMux};
mod config;
mod stream;
lazy_static! {
pub static ref CONFIG: Config = {
if let Some(path) = args().nth(1) {
toml::from_str(&read_to_string(path).unwrap()).unwrap()
} else {
Config::default()
}
};
}
async fn handle_stream(connect: ConnectPacket, muxstream: MuxStream) {
let Ok(resolved) = ClientStream::resolve(connect).await else {
let _ = muxstream.close(CloseReason::ServerStreamUnreachable).await;
return;
};
let connect = match resolved {
ResolvedPacket::Valid(x) => x,
ResolvedPacket::NoResolvedAddrs => {
let _ = muxstream.close(CloseReason::ServerStreamUnreachable).await;
return;
}
ResolvedPacket::Blocked => {
let _ = muxstream
.close(CloseReason::ServerStreamBlockedAddress)
.await;
return;
}
};
let Ok(stream) = ClientStream::connect(connect).await else {
let _ = muxstream.close(CloseReason::ServerStreamUnreachable).await;
return;
};
match stream {
ClientStream::Tcp(stream) => {
let closer = muxstream.get_close_handle();
let ret: anyhow::Result<()> = async move {
let (muxread, muxwrite) = muxstream.into_io().into_asyncrw().into_split();
let (mut tcpread, tcpwrite) = stream.into_split();
let mut muxwrite = muxwrite.compat_write();
select! {
x = copy_read_fast(muxread, tcpwrite) => x?,
x = copy(&mut tcpread, &mut muxwrite) => {x?;},
}
// TODO why is copy_write_fast not working?
/*
let (muxread, muxwrite) = muxstream.into_split();
let muxread = muxread.into_stream().into_asyncread();
let (mut tcpread, tcpwrite) = stream.into_split();
select! {
x = copy_read_fast(muxread, tcpwrite) => x?,
x = copy_write_fast(muxwrite, tcpread) => {x?;},
}
*/
Ok(())
}
.await;
match ret {
Ok(()) => {
let _ = closer.close(CloseReason::Voluntary).await;
}
Err(_) => {
let _ = closer.close(CloseReason::Unexpected).await;
}
}
}
ClientStream::Udp(stream) => {
let closer = muxstream.get_close_handle();
let ret: anyhow::Result<()> = async move {
let mut data = vec![0u8; 65507];
loop {
select! {
size = stream.recv(&mut data) => {
let size = size?;
muxstream.write(&data[..size]).await?;
}
data = muxstream.read() => {
if let Some(data) = data {
stream.send(&data).await?;
} else {
break Ok(());
}
}
}
}
}
.await;
match ret {
Ok(()) => {
let _ = closer.close(CloseReason::Voluntary).await;
}
Err(_) => {
let _ = closer.close(CloseReason::Unexpected).await;
}
}
}
ClientStream::Invalid => {
let _ = muxstream.close(CloseReason::ServerStreamInvalidInfo).await;
}
ClientStream::Blocked => {
let _ = muxstream
.close(CloseReason::ServerStreamBlockedAddress)
.await;
}
};
}
async fn handle(fut: UpgradeFut) -> anyhow::Result<()> {
let mut ws = fut.await.context("failed to await upgrade future")?;
ws.set_max_message_size(CONFIG.server.max_message_size);
let (read, write) = ws.split(|x| {
let parts = x.into_inner().downcast::<TokioIo<ServerStream>>().unwrap();
assert_eq!(parts.read_buf.len(), 0);
parts.io.into_inner().split()
});
let read = FragmentCollectorRead::new(read);
let (extensions, buffer_size) = CONFIG.wisp.to_opts_inner()?;
let (mux, fut) = ServerMux::create(read, write, buffer_size, extensions.as_deref())
.await
.context("failed to create server multiplexor")?
.with_no_required_extensions();
tokio::spawn(tokio::task::unconstrained(fut));
while let Some((connect, stream)) = mux.server_new_stream().await {
tokio::spawn(tokio::task::unconstrained(handle_stream(connect, stream)));
}
Ok(())
}
type Body = Empty<Bytes>;
async fn upgrade(mut req: Request<Incoming>) -> anyhow::Result<Response<Body>> {
let (resp, fut) = fastwebsockets::upgrade::upgrade(&mut req)?;
tokio::spawn(async move {
if let Err(e) = handle(fut).await {
println!("{:?}", e);
};
});
Ok(resp)
}
#[tokio::main(flavor = "multi_thread")]
async fn main() -> anyhow::Result<()> {
validate_config_cache();
println!("{}", toml::to_string_pretty(CONFIG.deref()).unwrap());
let listener = ServerListener::new().await?;
loop {
let (stream, _) = listener.accept().await?;
tokio::spawn(async move {
let stream = TokioIo::new(stream);
let fut = Builder::new()
.serve_connection(stream, service_fn(upgrade))
.with_upgrades();
if let Err(e) = fut.await {
println!("{:?}", e);
}
});
}
}

240
server/src/stream.rs Normal file
View file

@ -0,0 +1,240 @@
use std::{
net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr},
str::FromStr,
};
use anyhow::Context;
use bytes::BytesMut;
use futures_util::AsyncBufReadExt;
use tokio::{
io::{AsyncReadExt, AsyncWriteExt},
net::{
lookup_host,
tcp::{self, OwnedReadHalf, OwnedWriteHalf},
unix, TcpListener, TcpStream, UdpSocket, UnixListener, UnixStream,
},
};
use tokio_util::either::Either;
use wisp_mux::{ConnectPacket, MuxStreamAsyncRead, MuxStreamWrite, StreamType};
use crate::{config::SocketType, CONFIG};
pub enum ServerListener {
Tcp(TcpListener),
Unix(UnixListener),
}
pub type ServerStream = Either<TcpStream, UnixStream>;
pub type ServerStreamRead = Either<tcp::OwnedReadHalf, unix::OwnedReadHalf>;
pub type ServerStreamWrite = Either<tcp::OwnedWriteHalf, unix::OwnedWriteHalf>;
pub trait ServerStreamExt {
fn split(self) -> (ServerStreamRead, ServerStreamWrite);
}
impl ServerStreamExt for ServerStream {
fn split(self) -> (ServerStreamRead, ServerStreamWrite) {
match self {
Self::Left(x) => {
let (r, w) = x.into_split();
(Either::Left(r), Either::Left(w))
}
Self::Right(x) => {
let (r, w) = x.into_split();
(Either::Right(r), Either::Right(w))
}
}
}
}
impl ServerListener {
pub async fn new() -> anyhow::Result<Self> {
Ok(match CONFIG.server.socket {
SocketType::Tcp => Self::Tcp(
TcpListener::bind(&CONFIG.server.bind)
.await
.with_context(|| {
format!("failed to bind to tcp address `{}`", CONFIG.server.bind)
})?,
),
SocketType::Unix => {
Self::Unix(UnixListener::bind(&CONFIG.server.bind).with_context(|| {
format!("failed to bind to unix socket at `{}`", CONFIG.server.bind)
})?)
}
})
}
pub async fn accept(&self) -> anyhow::Result<(ServerStream, Option<String>)> {
match self {
Self::Tcp(x) => x
.accept()
.await
.map(|(x, y)| (Either::Left(x), Some(y.to_string())))
.context("failed to accept tcp connection"),
Self::Unix(x) => x
.accept()
.await
.map(|(x, y)| {
(
Either::Right(x),
y.as_pathname()
.and_then(|x| x.to_str())
.map(ToString::to_string),
)
})
.context("failed to accept unix socket connection"),
}
}
}
pub enum ClientStream {
Tcp(TcpStream),
Udp(UdpSocket),
Blocked,
Invalid,
}
pub enum ResolvedPacket {
Valid(ConnectPacket),
NoResolvedAddrs,
Blocked,
}
impl ClientStream {
pub async fn resolve(packet: ConnectPacket) -> anyhow::Result<ResolvedPacket> {
if !CONFIG.stream.allow_udp && packet.stream_type == StreamType::Udp {
return Ok(ResolvedPacket::Blocked);
}
if CONFIG
.stream
.blocked_ports()
.iter()
.any(|x| x.contains(&packet.destination_port))
&& !CONFIG
.stream
.allowed_ports()
.iter()
.any(|x| x.contains(&packet.destination_port))
{
return Ok(ResolvedPacket::Blocked);
}
if let Ok(addr) = IpAddr::from_str(&packet.destination_hostname) {
if !CONFIG.stream.allow_direct_ip {
return Ok(ResolvedPacket::Blocked);
}
if addr.is_loopback() && !CONFIG.stream.allow_loopback {
return Ok(ResolvedPacket::Blocked);
}
if addr.is_multicast() && !CONFIG.stream.allow_multicast {
return Ok(ResolvedPacket::Blocked);
}
if (addr.is_global() && !CONFIG.stream.allow_global)
|| (!addr.is_global() && !CONFIG.stream.allow_non_global)
{
return Ok(ResolvedPacket::Blocked);
}
}
if CONFIG
.stream
.blocked_hosts()
.is_match(&packet.destination_hostname)
&& !CONFIG
.stream
.allowed_hosts()
.is_match(&packet.destination_hostname)
{
return Ok(ResolvedPacket::Blocked);
}
let packet = lookup_host(packet.destination_hostname + ":0")
.await
.context("failed to resolve hostname")?
.filter(|x| CONFIG.server.resolve_ipv6 || x.is_ipv4())
.map(|x| ConnectPacket {
stream_type: packet.stream_type,
destination_hostname: x.ip().to_string(),
destination_port: packet.destination_port,
})
.next();
Ok(packet
.map(ResolvedPacket::Valid)
.unwrap_or(ResolvedPacket::NoResolvedAddrs))
}
pub async fn connect(packet: ConnectPacket) -> anyhow::Result<Self> {
let ipaddr = IpAddr::from_str(&packet.destination_hostname)
.context("failed to parse hostname as ipaddr")?;
match packet.stream_type {
StreamType::Tcp => {
let stream = TcpStream::connect(SocketAddr::new(ipaddr, packet.destination_port))
.await
.with_context(|| {
format!("failed to connect to host {}", packet.destination_hostname)
})?;
Ok(ClientStream::Tcp(stream))
}
StreamType::Udp => {
if !CONFIG.stream.allow_udp {
return Ok(ClientStream::Blocked);
}
let bind_addr = if ipaddr.is_ipv4() {
SocketAddr::new(Ipv4Addr::new(0, 0, 0, 0).into(), 0)
} else {
SocketAddr::new(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0).into(), 0)
};
let stream = UdpSocket::bind(bind_addr).await?;
stream
.connect(SocketAddr::new(ipaddr, packet.destination_port))
.await?;
Ok(ClientStream::Udp(stream))
}
StreamType::Unknown(_) => Ok(ClientStream::Invalid),
}
}
}
pub async fn copy_read_fast(
mut muxrx: MuxStreamAsyncRead,
mut tcptx: OwnedWriteHalf,
) -> std::io::Result<()> {
loop {
let buf = muxrx.fill_buf().await?;
if buf.is_empty() {
tcptx.flush().await?;
return Ok(());
}
let i = tcptx.write(buf).await?;
if i == 0 {
return Err(std::io::ErrorKind::WriteZero.into());
}
muxrx.consume_unpin(i);
}
}
#[allow(dead_code)]
pub async fn copy_write_fast(
muxtx: MuxStreamWrite,
mut tcprx: OwnedReadHalf,
) -> anyhow::Result<()> {
loop {
let mut buf = BytesMut::with_capacity(8 * 1024);
let amt = tcprx.read(&mut buf).await?;
muxtx.write(&buf[..amt]).await?;
}
}

View file

@ -9,7 +9,7 @@ use tokio::io::{AsyncRead, AsyncWrite};
use crate::{ws::LockedWebSocketWrite, WispError};
fn match_payload<'a>(payload: Payload<'a>) -> crate::ws::Payload<'a> {
fn match_payload(payload: Payload<'_>) -> crate::ws::Payload<'_> {
match payload {
Payload::Bytes(x) => crate::ws::Payload::Bytes(x),
Payload::Owned(x) => crate::ws::Payload::Bytes(BytesMut::from(x.deref())),
@ -18,7 +18,7 @@ fn match_payload<'a>(payload: Payload<'a>) -> crate::ws::Payload<'a> {
}
}
fn match_payload_reverse<'a>(payload: crate::ws::Payload<'a>) -> Payload<'a> {
fn match_payload_reverse(payload: crate::ws::Payload<'_>) -> Payload<'_> {
match payload {
crate::ws::Payload::Bytes(x) => Payload::Bytes(x),
crate::ws::Payload::Borrowed(x) => Payload::Borrowed(x),
@ -94,6 +94,18 @@ impl<S: AsyncWrite + Unpin + Send> crate::ws::WebSocketWrite for WebSocketWrite<
self.write_frame(frame.into()).await.map_err(|e| e.into())
}
async fn wisp_write_split(&mut self, header: crate::ws::Frame<'_>, body: crate::ws::Frame<'_>) -> Result<(), WispError> {
let mut header = Frame::from(header);
header.fin = false;
self.write_frame(header).await?;
let mut body = Frame::from(body);
body.opcode = OpCode::Continuation;
self.write_frame(body).await?;
Ok(())
}
async fn wisp_close(&mut self) -> Result<(), WispError> {
self.write_frame(Frame::close(CloseCode::Normal.into(), b""))
.await

View file

@ -12,7 +12,7 @@ use futures::{
ready, select,
stream::{self, IntoAsyncRead},
task::{noop_waker_ref, Context, Poll},
AsyncBufRead, AsyncRead, AsyncWrite, Future, FutureExt, Sink, Stream, TryStreamExt,
AsyncBufRead, AsyncRead, AsyncWrite, FutureExt, Sink, Stream, TryStreamExt,
};
use pin_project_lite::pin_project;
use std::{
@ -79,11 +79,18 @@ impl MuxStreamRead {
Some(bytes)
}
pub(crate) fn into_stream(self) -> Pin<Box<dyn Stream<Item = Bytes> + Send>> {
pub(crate) fn into_inner_stream(self) -> Pin<Box<dyn Stream<Item = Bytes> + Send>> {
Box::pin(stream::unfold(self, |rx| async move {
Some((rx.read().await?, rx))
}))
}
/// Turn the read half into one that implements futures `Stream`, consuming it.
pub fn into_stream(self) -> MuxStreamIoStream {
MuxStreamIoStream {
rx: self.into_inner_stream(),
}
}
}
/// Write side of a multiplexor stream.
@ -101,9 +108,10 @@ pub struct MuxStreamWrite {
}
impl MuxStreamWrite {
pub(crate) async fn write_payload_internal(
pub(crate) async fn write_payload_internal<'a>(
&self,
frame: Frame<'static>,
header: Frame<'static>,
body: Frame<'a>,
) -> Result<(), WispError> {
if self.role == Role::Client
&& self.stream_type == StreamType::Tcp
@ -115,7 +123,7 @@ impl MuxStreamWrite {
return Err(WispError::StreamAlreadyClosed);
}
self.tx.write_frame(frame).await?;
self.tx.write_split(header, body).await?;
if self.role == Role::Client && self.stream_type == StreamType::Tcp {
self.flow_control.store(
@ -127,12 +135,13 @@ impl MuxStreamWrite {
}
/// Write a payload to the stream.
pub fn write_payload<'a>(
&'a self,
data: Payload<'_>,
) -> impl Future<Output = Result<(), WispError>> + 'a {
let frame: Frame<'static> = Frame::from(Packet::new_data(self.stream_id, data));
self.write_payload_internal(frame)
pub async fn write_payload(&self, data: Payload<'_>) -> Result<(), WispError> {
let frame: Frame<'static> = Frame::from(Packet::new_data(
self.stream_id,
Payload::Bytes(BytesMut::new()),
));
self.write_payload_internal(frame, Frame::binary(data))
.await
}
/// Write data to the stream.
@ -188,12 +197,14 @@ impl MuxStreamWrite {
Ok(())
}
pub(crate) fn into_sink(self) -> Pin<Box<dyn Sink<Frame<'static>, Error = WispError> + Send>> {
pub(crate) fn into_inner_sink(
self,
) -> Pin<Box<dyn Sink<Payload<'static>, Error = WispError> + Send>> {
let handle = self.get_close_handle();
Box::pin(sink_unfold::unfold(
self,
|tx, data| async move {
tx.write_payload_internal(data).await?;
tx.write_payload(data).await?;
Ok(tx)
},
handle,
@ -203,6 +214,13 @@ impl MuxStreamWrite {
},
))
}
/// Turn the write half into one that implements futures `Sink`, consuming it.
pub fn into_sink(self) -> MuxStreamIoSink {
MuxStreamIoSink {
tx: self.into_inner_sink(),
}
}
}
impl Drop for MuxStreamWrite {
@ -316,13 +334,8 @@ impl MuxStream {
/// Turn the stream into one that implements futures `Stream + Sink`, consuming it.
pub fn into_io(self) -> MuxStreamIo {
MuxStreamIo {
rx: MuxStreamIoStream {
rx: self.rx.into_stream(),
},
tx: MuxStreamIoSink {
tx: self.tx.into_sink(),
stream_id: self.stream_id,
},
}
}
}
@ -456,8 +469,7 @@ pin_project! {
/// Write side of a multiplexor stream that implements futures `Sink`.
pub struct MuxStreamIoSink {
#[pin]
tx: Pin<Box<dyn Sink<Frame<'static>, Error = WispError> + Send>>,
stream_id: u32,
tx: Pin<Box<dyn Sink<Payload<'static>, Error = WispError> + Send>>,
}
}
@ -477,13 +489,9 @@ impl Sink<&[u8]> for MuxStreamIoSink {
.map_err(std::io::Error::other)
}
fn start_send(self: Pin<&mut Self>, item: &[u8]) -> Result<(), Self::Error> {
let stream_id = self.stream_id;
self.project()
.tx
.start_send(Frame::from(Packet::new_data(
stream_id,
Payload::Borrowed(item),
)))
.start_send(Payload::Bytes(BytesMut::from(item)))
.map_err(std::io::Error::other)
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {

View file

@ -166,6 +166,18 @@ pub trait WebSocketWrite {
/// Close the socket.
async fn wisp_close(&mut self) -> Result<(), WispError>;
/// Write a split frame to the socket.
async fn wisp_write_split(
&mut self,
header: Frame<'_>,
body: Frame<'_>,
) -> Result<(), WispError> {
let mut payload = BytesMut::from(header.payload);
payload.extend_from_slice(&body.payload);
self.wisp_write_frame(Frame::binary(Payload::Bytes(payload)))
.await
}
}
/// Locked WebSocket.
@ -183,6 +195,14 @@ impl LockedWebSocketWrite {
self.0.lock().await.wisp_write_frame(frame).await
}
pub(crate) async fn write_split(
&self,
header: Frame<'_>,
body: Frame<'_>,
) -> Result<(), WispError> {
self.0.lock().await.wisp_write_split(header, body).await
}
/// Close the websocket.
pub async fn close(&self) -> Result<(), WispError> {
self.0.lock().await.wisp_close().await