rewrite server

This commit is contained in:
Toshit Chawda 2024-07-20 22:21:51 -07:00
parent 3bf19be9f0
commit 24bfcae975
No known key found for this signature in database
GPG key ID: 91480ED99E2B3D9D
10 changed files with 1301 additions and 178 deletions

20
server/Cargo.toml Normal file
View file

@ -0,0 +1,20 @@
[package]
name = "epoxy-server"
version = "2.0.0"
edition = "2021"
[dependencies]
anyhow = "1.0.86"
bytes = "1.6.1"
fastwebsockets = { version = "0.8.0", features = ["unstable-split", "upgrade"] }
futures-util = "0.3.30"
http-body-util = "0.1.2"
hyper = { version = "1.4.1", features = ["server", "http1"] }
hyper-util = { version = "0.1.6", features = ["tokio"] }
lazy_static = "1.5.0"
regex = "1.10.5"
serde = { version = "1.0.204", features = ["derive"] }
tokio = { version = "1.38.1", features = ["full"] }
tokio-util = { version = "0.7.11", features = ["compat", "io-util", "net"] }
toml = "0.8.15"
wisp-mux = { version = "5.0.0", path = "../wisp", features = ["fastwebsockets"] }

491
server/flamegraph.svg Normal file

File diff suppressed because one or more lines are too long

After

Width:  |  Height:  |  Size: 460 KiB

207
server/src/config.rs Normal file
View file

@ -0,0 +1,207 @@
use std::{collections::HashMap, ops::RangeInclusive};
use lazy_static::lazy_static;
use regex::RegexSet;
use serde::{Deserialize, Serialize};
use wisp_mux::extensions::{
password::PasswordProtocolExtensionBuilder, udp::UdpProtocolExtensionBuilder,
ProtocolExtensionBuilder,
};
use crate::CONFIG;
type AnyProtocolExtensionBuilder = Box<dyn ProtocolExtensionBuilder + Sync + Send>;
struct ConfigCache {
pub blocked_ports: Vec<RangeInclusive<u16>>,
pub allowed_ports: Vec<RangeInclusive<u16>>,
pub allowed_hosts: RegexSet,
pub blocked_hosts: RegexSet,
pub wisp_config: (Option<Vec<AnyProtocolExtensionBuilder>>, u32),
}
lazy_static! {
static ref CONFIG_CACHE: ConfigCache = {
ConfigCache {
allowed_ports: CONFIG
.stream
.allow_ports
.iter()
.map(|x| x[0]..=x[1])
.collect(),
blocked_ports: CONFIG
.stream
.block_ports
.iter()
.map(|x| x[0]..=x[1])
.collect(),
allowed_hosts: RegexSet::new(&CONFIG.stream.allow_hosts).unwrap(),
blocked_hosts: RegexSet::new(&CONFIG.stream.block_hosts).unwrap(),
wisp_config: CONFIG.wisp.to_opts_inner().unwrap(),
}
};
}
pub fn validate_config_cache() {
let _ = CONFIG_CACHE.wisp_config;
}
#[derive(Serialize, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum SocketType {
#[default]
Tcp,
Unix,
}
#[derive(Serialize, Deserialize)]
#[serde(default)]
pub struct ServerConfig {
pub bind: String,
pub socket: SocketType,
pub resolve_ipv6: bool,
pub max_message_size: usize,
// TODO
// prefix: String,
}
impl Default for ServerConfig {
fn default() -> Self {
Self {
bind: "127.0.0.1:4000".to_owned(),
socket: SocketType::default(),
resolve_ipv6: false,
max_message_size: 64 * 1024,
}
}
}
#[derive(Serialize, Deserialize, PartialEq, Eq, PartialOrd, Ord)]
#[serde(rename_all = "lowercase")]
pub enum ProtocolExtension {
Udp,
Password,
}
#[derive(Serialize, Deserialize)]
#[serde(default)]
pub struct WispConfig {
pub wisp_v2: bool,
pub buffer_size: u32,
pub extensions: Vec<ProtocolExtension>,
pub password_extension_users: HashMap<String, String>,
// TODO
// enable_wsproxy: bool,
}
impl Default for WispConfig {
fn default() -> Self {
Self {
buffer_size: 512,
wisp_v2: false,
extensions: Vec::new(),
password_extension_users: HashMap::new(),
}
}
}
impl WispConfig {
pub fn to_opts_inner(&self) -> anyhow::Result<(Option<Vec<AnyProtocolExtensionBuilder>>, u32)> {
if self.wisp_v2 {
let mut extensions: Vec<Box<dyn ProtocolExtensionBuilder + Sync + Send>> = Vec::new();
if self.extensions.contains(&ProtocolExtension::Udp) {
extensions.push(Box::new(UdpProtocolExtensionBuilder));
}
if self.extensions.contains(&ProtocolExtension::Password) {
extensions.push(Box::new(PasswordProtocolExtensionBuilder::new_server(
self.password_extension_users.clone(),
)));
}
Ok((Some(extensions), self.buffer_size))
} else {
Ok((None, self.buffer_size))
}
}
pub fn to_opts(&self) -> (Option<&'static [AnyProtocolExtensionBuilder]>, u32) {
(
CONFIG_CACHE.wisp_config.0.as_deref(),
CONFIG_CACHE.wisp_config.1,
)
}
}
#[derive(Serialize, Deserialize)]
#[serde(default)]
pub struct StreamConfig {
pub allow_udp: bool,
pub allow_direct_ip: bool,
pub allow_loopback: bool,
pub allow_multicast: bool,
pub allow_global: bool,
pub allow_non_global: bool,
pub allow_hosts: Vec<String>,
pub block_hosts: Vec<String>,
pub allow_ports: Vec<Vec<u16>>,
pub block_ports: Vec<Vec<u16>>,
}
impl Default for StreamConfig {
fn default() -> Self {
Self {
allow_udp: true,
allow_direct_ip: true,
allow_loopback: true,
allow_multicast: true,
allow_global: true,
allow_non_global: true,
allow_hosts: Vec::new(),
block_hosts: Vec::new(),
allow_ports: Vec::new(),
block_ports: Vec::new(),
}
}
}
impl StreamConfig {
pub fn allowed_ports(&self) -> &'static [RangeInclusive<u16>] {
&CONFIG_CACHE.allowed_ports
}
pub fn blocked_ports(&self) -> &'static [RangeInclusive<u16>] {
&CONFIG_CACHE.blocked_ports
}
pub fn allowed_hosts(&self) -> &RegexSet {
&CONFIG_CACHE.allowed_hosts
}
pub fn blocked_hosts(&self) -> &RegexSet {
&CONFIG_CACHE.blocked_hosts
}
}
#[derive(Serialize, Deserialize, Default)]
#[serde(default)]
pub struct Config {
pub server: ServerConfig,
pub wisp: WispConfig,
pub stream: StreamConfig,
}

197
server/src/main.rs Normal file
View file

@ -0,0 +1,197 @@
#![feature(ip)]
use std::{env::args, fs::read_to_string, ops::Deref};
use anyhow::Context;
use bytes::Bytes;
use config::{validate_config_cache, Config};
use fastwebsockets::{upgrade::UpgradeFut, FragmentCollectorRead};
use http_body_util::Empty;
use hyper::{body::Incoming, server::conn::http1::Builder, service::service_fn, Request, Response};
use hyper_util::rt::TokioIo;
use lazy_static::lazy_static;
use stream::{
copy_read_fast, ClientStream, ResolvedPacket, ServerListener, ServerStream, ServerStreamExt,
};
use tokio::{io::copy, select};
use tokio_util::compat::FuturesAsyncWriteCompatExt;
use wisp_mux::{CloseReason, ConnectPacket, MuxStream, ServerMux};
mod config;
mod stream;
lazy_static! {
pub static ref CONFIG: Config = {
if let Some(path) = args().nth(1) {
toml::from_str(&read_to_string(path).unwrap()).unwrap()
} else {
Config::default()
}
};
}
async fn handle_stream(connect: ConnectPacket, muxstream: MuxStream) {
let Ok(resolved) = ClientStream::resolve(connect).await else {
let _ = muxstream.close(CloseReason::ServerStreamUnreachable).await;
return;
};
let connect = match resolved {
ResolvedPacket::Valid(x) => x,
ResolvedPacket::NoResolvedAddrs => {
let _ = muxstream.close(CloseReason::ServerStreamUnreachable).await;
return;
}
ResolvedPacket::Blocked => {
let _ = muxstream
.close(CloseReason::ServerStreamBlockedAddress)
.await;
return;
}
};
let Ok(stream) = ClientStream::connect(connect).await else {
let _ = muxstream.close(CloseReason::ServerStreamUnreachable).await;
return;
};
match stream {
ClientStream::Tcp(stream) => {
let closer = muxstream.get_close_handle();
let ret: anyhow::Result<()> = async move {
let (muxread, muxwrite) = muxstream.into_io().into_asyncrw().into_split();
let (mut tcpread, tcpwrite) = stream.into_split();
let mut muxwrite = muxwrite.compat_write();
select! {
x = copy_read_fast(muxread, tcpwrite) => x?,
x = copy(&mut tcpread, &mut muxwrite) => {x?;},
}
// TODO why is copy_write_fast not working?
/*
let (muxread, muxwrite) = muxstream.into_split();
let muxread = muxread.into_stream().into_asyncread();
let (mut tcpread, tcpwrite) = stream.into_split();
select! {
x = copy_read_fast(muxread, tcpwrite) => x?,
x = copy_write_fast(muxwrite, tcpread) => {x?;},
}
*/
Ok(())
}
.await;
match ret {
Ok(()) => {
let _ = closer.close(CloseReason::Voluntary).await;
}
Err(_) => {
let _ = closer.close(CloseReason::Unexpected).await;
}
}
}
ClientStream::Udp(stream) => {
let closer = muxstream.get_close_handle();
let ret: anyhow::Result<()> = async move {
let mut data = vec![0u8; 65507];
loop {
select! {
size = stream.recv(&mut data) => {
let size = size?;
muxstream.write(&data[..size]).await?;
}
data = muxstream.read() => {
if let Some(data) = data {
stream.send(&data).await?;
} else {
break Ok(());
}
}
}
}
}
.await;
match ret {
Ok(()) => {
let _ = closer.close(CloseReason::Voluntary).await;
}
Err(_) => {
let _ = closer.close(CloseReason::Unexpected).await;
}
}
}
ClientStream::Invalid => {
let _ = muxstream.close(CloseReason::ServerStreamInvalidInfo).await;
}
ClientStream::Blocked => {
let _ = muxstream
.close(CloseReason::ServerStreamBlockedAddress)
.await;
}
};
}
async fn handle(fut: UpgradeFut) -> anyhow::Result<()> {
let mut ws = fut.await.context("failed to await upgrade future")?;
ws.set_max_message_size(CONFIG.server.max_message_size);
let (read, write) = ws.split(|x| {
let parts = x.into_inner().downcast::<TokioIo<ServerStream>>().unwrap();
assert_eq!(parts.read_buf.len(), 0);
parts.io.into_inner().split()
});
let read = FragmentCollectorRead::new(read);
let (extensions, buffer_size) = CONFIG.wisp.to_opts_inner()?;
let (mux, fut) = ServerMux::create(read, write, buffer_size, extensions.as_deref())
.await
.context("failed to create server multiplexor")?
.with_no_required_extensions();
tokio::spawn(tokio::task::unconstrained(fut));
while let Some((connect, stream)) = mux.server_new_stream().await {
tokio::spawn(tokio::task::unconstrained(handle_stream(connect, stream)));
}
Ok(())
}
type Body = Empty<Bytes>;
async fn upgrade(mut req: Request<Incoming>) -> anyhow::Result<Response<Body>> {
let (resp, fut) = fastwebsockets::upgrade::upgrade(&mut req)?;
tokio::spawn(async move {
if let Err(e) = handle(fut).await {
println!("{:?}", e);
};
});
Ok(resp)
}
#[tokio::main(flavor = "multi_thread")]
async fn main() -> anyhow::Result<()> {
validate_config_cache();
println!("{}", toml::to_string_pretty(CONFIG.deref()).unwrap());
let listener = ServerListener::new().await?;
loop {
let (stream, _) = listener.accept().await?;
tokio::spawn(async move {
let stream = TokioIo::new(stream);
let fut = Builder::new()
.serve_connection(stream, service_fn(upgrade))
.with_upgrades();
if let Err(e) = fut.await {
println!("{:?}", e);
}
});
}
}

240
server/src/stream.rs Normal file
View file

@ -0,0 +1,240 @@
use std::{
net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr},
str::FromStr,
};
use anyhow::Context;
use bytes::BytesMut;
use futures_util::AsyncBufReadExt;
use tokio::{
io::{AsyncReadExt, AsyncWriteExt},
net::{
lookup_host,
tcp::{self, OwnedReadHalf, OwnedWriteHalf},
unix, TcpListener, TcpStream, UdpSocket, UnixListener, UnixStream,
},
};
use tokio_util::either::Either;
use wisp_mux::{ConnectPacket, MuxStreamAsyncRead, MuxStreamWrite, StreamType};
use crate::{config::SocketType, CONFIG};
pub enum ServerListener {
Tcp(TcpListener),
Unix(UnixListener),
}
pub type ServerStream = Either<TcpStream, UnixStream>;
pub type ServerStreamRead = Either<tcp::OwnedReadHalf, unix::OwnedReadHalf>;
pub type ServerStreamWrite = Either<tcp::OwnedWriteHalf, unix::OwnedWriteHalf>;
pub trait ServerStreamExt {
fn split(self) -> (ServerStreamRead, ServerStreamWrite);
}
impl ServerStreamExt for ServerStream {
fn split(self) -> (ServerStreamRead, ServerStreamWrite) {
match self {
Self::Left(x) => {
let (r, w) = x.into_split();
(Either::Left(r), Either::Left(w))
}
Self::Right(x) => {
let (r, w) = x.into_split();
(Either::Right(r), Either::Right(w))
}
}
}
}
impl ServerListener {
pub async fn new() -> anyhow::Result<Self> {
Ok(match CONFIG.server.socket {
SocketType::Tcp => Self::Tcp(
TcpListener::bind(&CONFIG.server.bind)
.await
.with_context(|| {
format!("failed to bind to tcp address `{}`", CONFIG.server.bind)
})?,
),
SocketType::Unix => {
Self::Unix(UnixListener::bind(&CONFIG.server.bind).with_context(|| {
format!("failed to bind to unix socket at `{}`", CONFIG.server.bind)
})?)
}
})
}
pub async fn accept(&self) -> anyhow::Result<(ServerStream, Option<String>)> {
match self {
Self::Tcp(x) => x
.accept()
.await
.map(|(x, y)| (Either::Left(x), Some(y.to_string())))
.context("failed to accept tcp connection"),
Self::Unix(x) => x
.accept()
.await
.map(|(x, y)| {
(
Either::Right(x),
y.as_pathname()
.and_then(|x| x.to_str())
.map(ToString::to_string),
)
})
.context("failed to accept unix socket connection"),
}
}
}
pub enum ClientStream {
Tcp(TcpStream),
Udp(UdpSocket),
Blocked,
Invalid,
}
pub enum ResolvedPacket {
Valid(ConnectPacket),
NoResolvedAddrs,
Blocked,
}
impl ClientStream {
pub async fn resolve(packet: ConnectPacket) -> anyhow::Result<ResolvedPacket> {
if !CONFIG.stream.allow_udp && packet.stream_type == StreamType::Udp {
return Ok(ResolvedPacket::Blocked);
}
if CONFIG
.stream
.blocked_ports()
.iter()
.any(|x| x.contains(&packet.destination_port))
&& !CONFIG
.stream
.allowed_ports()
.iter()
.any(|x| x.contains(&packet.destination_port))
{
return Ok(ResolvedPacket::Blocked);
}
if let Ok(addr) = IpAddr::from_str(&packet.destination_hostname) {
if !CONFIG.stream.allow_direct_ip {
return Ok(ResolvedPacket::Blocked);
}
if addr.is_loopback() && !CONFIG.stream.allow_loopback {
return Ok(ResolvedPacket::Blocked);
}
if addr.is_multicast() && !CONFIG.stream.allow_multicast {
return Ok(ResolvedPacket::Blocked);
}
if (addr.is_global() && !CONFIG.stream.allow_global)
|| (!addr.is_global() && !CONFIG.stream.allow_non_global)
{
return Ok(ResolvedPacket::Blocked);
}
}
if CONFIG
.stream
.blocked_hosts()
.is_match(&packet.destination_hostname)
&& !CONFIG
.stream
.allowed_hosts()
.is_match(&packet.destination_hostname)
{
return Ok(ResolvedPacket::Blocked);
}
let packet = lookup_host(packet.destination_hostname + ":0")
.await
.context("failed to resolve hostname")?
.filter(|x| CONFIG.server.resolve_ipv6 || x.is_ipv4())
.map(|x| ConnectPacket {
stream_type: packet.stream_type,
destination_hostname: x.ip().to_string(),
destination_port: packet.destination_port,
})
.next();
Ok(packet
.map(ResolvedPacket::Valid)
.unwrap_or(ResolvedPacket::NoResolvedAddrs))
}
pub async fn connect(packet: ConnectPacket) -> anyhow::Result<Self> {
let ipaddr = IpAddr::from_str(&packet.destination_hostname)
.context("failed to parse hostname as ipaddr")?;
match packet.stream_type {
StreamType::Tcp => {
let stream = TcpStream::connect(SocketAddr::new(ipaddr, packet.destination_port))
.await
.with_context(|| {
format!("failed to connect to host {}", packet.destination_hostname)
})?;
Ok(ClientStream::Tcp(stream))
}
StreamType::Udp => {
if !CONFIG.stream.allow_udp {
return Ok(ClientStream::Blocked);
}
let bind_addr = if ipaddr.is_ipv4() {
SocketAddr::new(Ipv4Addr::new(0, 0, 0, 0).into(), 0)
} else {
SocketAddr::new(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0).into(), 0)
};
let stream = UdpSocket::bind(bind_addr).await?;
stream
.connect(SocketAddr::new(ipaddr, packet.destination_port))
.await?;
Ok(ClientStream::Udp(stream))
}
StreamType::Unknown(_) => Ok(ClientStream::Invalid),
}
}
}
pub async fn copy_read_fast(
mut muxrx: MuxStreamAsyncRead,
mut tcptx: OwnedWriteHalf,
) -> std::io::Result<()> {
loop {
let buf = muxrx.fill_buf().await?;
if buf.is_empty() {
tcptx.flush().await?;
return Ok(());
}
let i = tcptx.write(buf).await?;
if i == 0 {
return Err(std::io::ErrorKind::WriteZero.into());
}
muxrx.consume_unpin(i);
}
}
#[allow(dead_code)]
pub async fn copy_write_fast(
muxtx: MuxStreamWrite,
mut tcprx: OwnedReadHalf,
) -> anyhow::Result<()> {
loop {
let mut buf = BytesMut::with_capacity(8 * 1024);
let amt = tcprx.read(&mut buf).await?;
muxtx.write(&buf[..amt]).await?;
}
}