mirror of
https://github.com/MercuryWorkshop/epoxy-tls.git
synced 2025-05-12 22:10:01 -04:00
rewrite server
This commit is contained in:
parent
3bf19be9f0
commit
24bfcae975
10 changed files with 1301 additions and 178 deletions
20
server/Cargo.toml
Normal file
20
server/Cargo.toml
Normal file
|
@ -0,0 +1,20 @@
|
|||
[package]
|
||||
name = "epoxy-server"
|
||||
version = "2.0.0"
|
||||
edition = "2021"
|
||||
|
||||
[dependencies]
|
||||
anyhow = "1.0.86"
|
||||
bytes = "1.6.1"
|
||||
fastwebsockets = { version = "0.8.0", features = ["unstable-split", "upgrade"] }
|
||||
futures-util = "0.3.30"
|
||||
http-body-util = "0.1.2"
|
||||
hyper = { version = "1.4.1", features = ["server", "http1"] }
|
||||
hyper-util = { version = "0.1.6", features = ["tokio"] }
|
||||
lazy_static = "1.5.0"
|
||||
regex = "1.10.5"
|
||||
serde = { version = "1.0.204", features = ["derive"] }
|
||||
tokio = { version = "1.38.1", features = ["full"] }
|
||||
tokio-util = { version = "0.7.11", features = ["compat", "io-util", "net"] }
|
||||
toml = "0.8.15"
|
||||
wisp-mux = { version = "5.0.0", path = "../wisp", features = ["fastwebsockets"] }
|
491
server/flamegraph.svg
Normal file
491
server/flamegraph.svg
Normal file
File diff suppressed because one or more lines are too long
After Width: | Height: | Size: 460 KiB |
207
server/src/config.rs
Normal file
207
server/src/config.rs
Normal file
|
@ -0,0 +1,207 @@
|
|||
use std::{collections::HashMap, ops::RangeInclusive};
|
||||
|
||||
use lazy_static::lazy_static;
|
||||
use regex::RegexSet;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use wisp_mux::extensions::{
|
||||
password::PasswordProtocolExtensionBuilder, udp::UdpProtocolExtensionBuilder,
|
||||
ProtocolExtensionBuilder,
|
||||
};
|
||||
|
||||
use crate::CONFIG;
|
||||
|
||||
type AnyProtocolExtensionBuilder = Box<dyn ProtocolExtensionBuilder + Sync + Send>;
|
||||
|
||||
struct ConfigCache {
|
||||
pub blocked_ports: Vec<RangeInclusive<u16>>,
|
||||
pub allowed_ports: Vec<RangeInclusive<u16>>,
|
||||
|
||||
pub allowed_hosts: RegexSet,
|
||||
pub blocked_hosts: RegexSet,
|
||||
|
||||
pub wisp_config: (Option<Vec<AnyProtocolExtensionBuilder>>, u32),
|
||||
}
|
||||
|
||||
lazy_static! {
|
||||
static ref CONFIG_CACHE: ConfigCache = {
|
||||
ConfigCache {
|
||||
allowed_ports: CONFIG
|
||||
.stream
|
||||
.allow_ports
|
||||
.iter()
|
||||
.map(|x| x[0]..=x[1])
|
||||
.collect(),
|
||||
blocked_ports: CONFIG
|
||||
.stream
|
||||
.block_ports
|
||||
.iter()
|
||||
.map(|x| x[0]..=x[1])
|
||||
.collect(),
|
||||
allowed_hosts: RegexSet::new(&CONFIG.stream.allow_hosts).unwrap(),
|
||||
blocked_hosts: RegexSet::new(&CONFIG.stream.block_hosts).unwrap(),
|
||||
wisp_config: CONFIG.wisp.to_opts_inner().unwrap(),
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
pub fn validate_config_cache() {
|
||||
let _ = CONFIG_CACHE.wisp_config;
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Default)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum SocketType {
|
||||
#[default]
|
||||
Tcp,
|
||||
Unix,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
#[serde(default)]
|
||||
pub struct ServerConfig {
|
||||
pub bind: String,
|
||||
pub socket: SocketType,
|
||||
pub resolve_ipv6: bool,
|
||||
|
||||
pub max_message_size: usize,
|
||||
// TODO
|
||||
// prefix: String,
|
||||
}
|
||||
|
||||
impl Default for ServerConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
bind: "127.0.0.1:4000".to_owned(),
|
||||
socket: SocketType::default(),
|
||||
resolve_ipv6: false,
|
||||
|
||||
max_message_size: 64 * 1024,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, PartialEq, Eq, PartialOrd, Ord)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum ProtocolExtension {
|
||||
Udp,
|
||||
Password,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
#[serde(default)]
|
||||
pub struct WispConfig {
|
||||
pub wisp_v2: bool,
|
||||
pub buffer_size: u32,
|
||||
|
||||
pub extensions: Vec<ProtocolExtension>,
|
||||
pub password_extension_users: HashMap<String, String>,
|
||||
// TODO
|
||||
// enable_wsproxy: bool,
|
||||
}
|
||||
|
||||
impl Default for WispConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
buffer_size: 512,
|
||||
wisp_v2: false,
|
||||
|
||||
extensions: Vec::new(),
|
||||
password_extension_users: HashMap::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl WispConfig {
|
||||
pub fn to_opts_inner(&self) -> anyhow::Result<(Option<Vec<AnyProtocolExtensionBuilder>>, u32)> {
|
||||
if self.wisp_v2 {
|
||||
let mut extensions: Vec<Box<dyn ProtocolExtensionBuilder + Sync + Send>> = Vec::new();
|
||||
|
||||
if self.extensions.contains(&ProtocolExtension::Udp) {
|
||||
extensions.push(Box::new(UdpProtocolExtensionBuilder));
|
||||
}
|
||||
|
||||
if self.extensions.contains(&ProtocolExtension::Password) {
|
||||
extensions.push(Box::new(PasswordProtocolExtensionBuilder::new_server(
|
||||
self.password_extension_users.clone(),
|
||||
)));
|
||||
}
|
||||
|
||||
Ok((Some(extensions), self.buffer_size))
|
||||
} else {
|
||||
Ok((None, self.buffer_size))
|
||||
}
|
||||
}
|
||||
|
||||
pub fn to_opts(&self) -> (Option<&'static [AnyProtocolExtensionBuilder]>, u32) {
|
||||
(
|
||||
CONFIG_CACHE.wisp_config.0.as_deref(),
|
||||
CONFIG_CACHE.wisp_config.1,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
#[serde(default)]
|
||||
pub struct StreamConfig {
|
||||
pub allow_udp: bool,
|
||||
|
||||
pub allow_direct_ip: bool,
|
||||
pub allow_loopback: bool,
|
||||
pub allow_multicast: bool,
|
||||
|
||||
pub allow_global: bool,
|
||||
pub allow_non_global: bool,
|
||||
|
||||
pub allow_hosts: Vec<String>,
|
||||
pub block_hosts: Vec<String>,
|
||||
|
||||
pub allow_ports: Vec<Vec<u16>>,
|
||||
pub block_ports: Vec<Vec<u16>>,
|
||||
}
|
||||
|
||||
impl Default for StreamConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
allow_udp: true,
|
||||
|
||||
allow_direct_ip: true,
|
||||
allow_loopback: true,
|
||||
allow_multicast: true,
|
||||
|
||||
allow_global: true,
|
||||
allow_non_global: true,
|
||||
|
||||
allow_hosts: Vec::new(),
|
||||
block_hosts: Vec::new(),
|
||||
|
||||
allow_ports: Vec::new(),
|
||||
block_ports: Vec::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl StreamConfig {
|
||||
pub fn allowed_ports(&self) -> &'static [RangeInclusive<u16>] {
|
||||
&CONFIG_CACHE.allowed_ports
|
||||
}
|
||||
|
||||
pub fn blocked_ports(&self) -> &'static [RangeInclusive<u16>] {
|
||||
&CONFIG_CACHE.blocked_ports
|
||||
}
|
||||
|
||||
pub fn allowed_hosts(&self) -> &RegexSet {
|
||||
&CONFIG_CACHE.allowed_hosts
|
||||
}
|
||||
|
||||
pub fn blocked_hosts(&self) -> &RegexSet {
|
||||
&CONFIG_CACHE.blocked_hosts
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Default)]
|
||||
#[serde(default)]
|
||||
pub struct Config {
|
||||
pub server: ServerConfig,
|
||||
pub wisp: WispConfig,
|
||||
pub stream: StreamConfig,
|
||||
}
|
197
server/src/main.rs
Normal file
197
server/src/main.rs
Normal file
|
@ -0,0 +1,197 @@
|
|||
#![feature(ip)]
|
||||
|
||||
use std::{env::args, fs::read_to_string, ops::Deref};
|
||||
|
||||
use anyhow::Context;
|
||||
use bytes::Bytes;
|
||||
use config::{validate_config_cache, Config};
|
||||
use fastwebsockets::{upgrade::UpgradeFut, FragmentCollectorRead};
|
||||
use http_body_util::Empty;
|
||||
use hyper::{body::Incoming, server::conn::http1::Builder, service::service_fn, Request, Response};
|
||||
use hyper_util::rt::TokioIo;
|
||||
use lazy_static::lazy_static;
|
||||
use stream::{
|
||||
copy_read_fast, ClientStream, ResolvedPacket, ServerListener, ServerStream, ServerStreamExt,
|
||||
};
|
||||
use tokio::{io::copy, select};
|
||||
use tokio_util::compat::FuturesAsyncWriteCompatExt;
|
||||
use wisp_mux::{CloseReason, ConnectPacket, MuxStream, ServerMux};
|
||||
|
||||
mod config;
|
||||
mod stream;
|
||||
|
||||
lazy_static! {
|
||||
pub static ref CONFIG: Config = {
|
||||
if let Some(path) = args().nth(1) {
|
||||
toml::from_str(&read_to_string(path).unwrap()).unwrap()
|
||||
} else {
|
||||
Config::default()
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
async fn handle_stream(connect: ConnectPacket, muxstream: MuxStream) {
|
||||
let Ok(resolved) = ClientStream::resolve(connect).await else {
|
||||
let _ = muxstream.close(CloseReason::ServerStreamUnreachable).await;
|
||||
return;
|
||||
};
|
||||
let connect = match resolved {
|
||||
ResolvedPacket::Valid(x) => x,
|
||||
ResolvedPacket::NoResolvedAddrs => {
|
||||
let _ = muxstream.close(CloseReason::ServerStreamUnreachable).await;
|
||||
return;
|
||||
}
|
||||
ResolvedPacket::Blocked => {
|
||||
let _ = muxstream
|
||||
.close(CloseReason::ServerStreamBlockedAddress)
|
||||
.await;
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let Ok(stream) = ClientStream::connect(connect).await else {
|
||||
let _ = muxstream.close(CloseReason::ServerStreamUnreachable).await;
|
||||
return;
|
||||
};
|
||||
|
||||
match stream {
|
||||
ClientStream::Tcp(stream) => {
|
||||
let closer = muxstream.get_close_handle();
|
||||
|
||||
let ret: anyhow::Result<()> = async move {
|
||||
let (muxread, muxwrite) = muxstream.into_io().into_asyncrw().into_split();
|
||||
let (mut tcpread, tcpwrite) = stream.into_split();
|
||||
let mut muxwrite = muxwrite.compat_write();
|
||||
select! {
|
||||
x = copy_read_fast(muxread, tcpwrite) => x?,
|
||||
x = copy(&mut tcpread, &mut muxwrite) => {x?;},
|
||||
}
|
||||
// TODO why is copy_write_fast not working?
|
||||
/*
|
||||
let (muxread, muxwrite) = muxstream.into_split();
|
||||
let muxread = muxread.into_stream().into_asyncread();
|
||||
let (mut tcpread, tcpwrite) = stream.into_split();
|
||||
select! {
|
||||
x = copy_read_fast(muxread, tcpwrite) => x?,
|
||||
x = copy_write_fast(muxwrite, tcpread) => {x?;},
|
||||
}
|
||||
*/
|
||||
Ok(())
|
||||
}
|
||||
.await;
|
||||
|
||||
match ret {
|
||||
Ok(()) => {
|
||||
let _ = closer.close(CloseReason::Voluntary).await;
|
||||
}
|
||||
Err(_) => {
|
||||
let _ = closer.close(CloseReason::Unexpected).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
ClientStream::Udp(stream) => {
|
||||
let closer = muxstream.get_close_handle();
|
||||
|
||||
let ret: anyhow::Result<()> = async move {
|
||||
let mut data = vec![0u8; 65507];
|
||||
loop {
|
||||
select! {
|
||||
size = stream.recv(&mut data) => {
|
||||
let size = size?;
|
||||
muxstream.write(&data[..size]).await?;
|
||||
}
|
||||
data = muxstream.read() => {
|
||||
if let Some(data) = data {
|
||||
stream.send(&data).await?;
|
||||
} else {
|
||||
break Ok(());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
.await;
|
||||
|
||||
match ret {
|
||||
Ok(()) => {
|
||||
let _ = closer.close(CloseReason::Voluntary).await;
|
||||
}
|
||||
Err(_) => {
|
||||
let _ = closer.close(CloseReason::Unexpected).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
ClientStream::Invalid => {
|
||||
let _ = muxstream.close(CloseReason::ServerStreamInvalidInfo).await;
|
||||
}
|
||||
ClientStream::Blocked => {
|
||||
let _ = muxstream
|
||||
.close(CloseReason::ServerStreamBlockedAddress)
|
||||
.await;
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
async fn handle(fut: UpgradeFut) -> anyhow::Result<()> {
|
||||
let mut ws = fut.await.context("failed to await upgrade future")?;
|
||||
|
||||
ws.set_max_message_size(CONFIG.server.max_message_size);
|
||||
|
||||
let (read, write) = ws.split(|x| {
|
||||
let parts = x.into_inner().downcast::<TokioIo<ServerStream>>().unwrap();
|
||||
assert_eq!(parts.read_buf.len(), 0);
|
||||
parts.io.into_inner().split()
|
||||
});
|
||||
let read = FragmentCollectorRead::new(read);
|
||||
|
||||
let (extensions, buffer_size) = CONFIG.wisp.to_opts_inner()?;
|
||||
|
||||
let (mux, fut) = ServerMux::create(read, write, buffer_size, extensions.as_deref())
|
||||
.await
|
||||
.context("failed to create server multiplexor")?
|
||||
.with_no_required_extensions();
|
||||
|
||||
tokio::spawn(tokio::task::unconstrained(fut));
|
||||
|
||||
while let Some((connect, stream)) = mux.server_new_stream().await {
|
||||
tokio::spawn(tokio::task::unconstrained(handle_stream(connect, stream)));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
type Body = Empty<Bytes>;
|
||||
async fn upgrade(mut req: Request<Incoming>) -> anyhow::Result<Response<Body>> {
|
||||
let (resp, fut) = fastwebsockets::upgrade::upgrade(&mut req)?;
|
||||
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = handle(fut).await {
|
||||
println!("{:?}", e);
|
||||
};
|
||||
});
|
||||
|
||||
Ok(resp)
|
||||
}
|
||||
|
||||
#[tokio::main(flavor = "multi_thread")]
|
||||
async fn main() -> anyhow::Result<()> {
|
||||
validate_config_cache();
|
||||
|
||||
println!("{}", toml::to_string_pretty(CONFIG.deref()).unwrap());
|
||||
|
||||
let listener = ServerListener::new().await?;
|
||||
loop {
|
||||
let (stream, _) = listener.accept().await?;
|
||||
tokio::spawn(async move {
|
||||
let stream = TokioIo::new(stream);
|
||||
|
||||
let fut = Builder::new()
|
||||
.serve_connection(stream, service_fn(upgrade))
|
||||
.with_upgrades();
|
||||
|
||||
if let Err(e) = fut.await {
|
||||
println!("{:?}", e);
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
240
server/src/stream.rs
Normal file
240
server/src/stream.rs
Normal file
|
@ -0,0 +1,240 @@
|
|||
use std::{
|
||||
net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr},
|
||||
str::FromStr,
|
||||
};
|
||||
|
||||
use anyhow::Context;
|
||||
use bytes::BytesMut;
|
||||
use futures_util::AsyncBufReadExt;
|
||||
use tokio::{
|
||||
io::{AsyncReadExt, AsyncWriteExt},
|
||||
net::{
|
||||
lookup_host,
|
||||
tcp::{self, OwnedReadHalf, OwnedWriteHalf},
|
||||
unix, TcpListener, TcpStream, UdpSocket, UnixListener, UnixStream,
|
||||
},
|
||||
};
|
||||
use tokio_util::either::Either;
|
||||
use wisp_mux::{ConnectPacket, MuxStreamAsyncRead, MuxStreamWrite, StreamType};
|
||||
|
||||
use crate::{config::SocketType, CONFIG};
|
||||
|
||||
pub enum ServerListener {
|
||||
Tcp(TcpListener),
|
||||
Unix(UnixListener),
|
||||
}
|
||||
|
||||
pub type ServerStream = Either<TcpStream, UnixStream>;
|
||||
pub type ServerStreamRead = Either<tcp::OwnedReadHalf, unix::OwnedReadHalf>;
|
||||
pub type ServerStreamWrite = Either<tcp::OwnedWriteHalf, unix::OwnedWriteHalf>;
|
||||
|
||||
pub trait ServerStreamExt {
|
||||
fn split(self) -> (ServerStreamRead, ServerStreamWrite);
|
||||
}
|
||||
|
||||
impl ServerStreamExt for ServerStream {
|
||||
fn split(self) -> (ServerStreamRead, ServerStreamWrite) {
|
||||
match self {
|
||||
Self::Left(x) => {
|
||||
let (r, w) = x.into_split();
|
||||
(Either::Left(r), Either::Left(w))
|
||||
}
|
||||
Self::Right(x) => {
|
||||
let (r, w) = x.into_split();
|
||||
(Either::Right(r), Either::Right(w))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ServerListener {
|
||||
pub async fn new() -> anyhow::Result<Self> {
|
||||
Ok(match CONFIG.server.socket {
|
||||
SocketType::Tcp => Self::Tcp(
|
||||
TcpListener::bind(&CONFIG.server.bind)
|
||||
.await
|
||||
.with_context(|| {
|
||||
format!("failed to bind to tcp address `{}`", CONFIG.server.bind)
|
||||
})?,
|
||||
),
|
||||
SocketType::Unix => {
|
||||
Self::Unix(UnixListener::bind(&CONFIG.server.bind).with_context(|| {
|
||||
format!("failed to bind to unix socket at `{}`", CONFIG.server.bind)
|
||||
})?)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn accept(&self) -> anyhow::Result<(ServerStream, Option<String>)> {
|
||||
match self {
|
||||
Self::Tcp(x) => x
|
||||
.accept()
|
||||
.await
|
||||
.map(|(x, y)| (Either::Left(x), Some(y.to_string())))
|
||||
.context("failed to accept tcp connection"),
|
||||
Self::Unix(x) => x
|
||||
.accept()
|
||||
.await
|
||||
.map(|(x, y)| {
|
||||
(
|
||||
Either::Right(x),
|
||||
y.as_pathname()
|
||||
.and_then(|x| x.to_str())
|
||||
.map(ToString::to_string),
|
||||
)
|
||||
})
|
||||
.context("failed to accept unix socket connection"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub enum ClientStream {
|
||||
Tcp(TcpStream),
|
||||
Udp(UdpSocket),
|
||||
Blocked,
|
||||
Invalid,
|
||||
}
|
||||
|
||||
pub enum ResolvedPacket {
|
||||
Valid(ConnectPacket),
|
||||
NoResolvedAddrs,
|
||||
Blocked,
|
||||
}
|
||||
|
||||
impl ClientStream {
|
||||
pub async fn resolve(packet: ConnectPacket) -> anyhow::Result<ResolvedPacket> {
|
||||
if !CONFIG.stream.allow_udp && packet.stream_type == StreamType::Udp {
|
||||
return Ok(ResolvedPacket::Blocked);
|
||||
}
|
||||
|
||||
if CONFIG
|
||||
.stream
|
||||
.blocked_ports()
|
||||
.iter()
|
||||
.any(|x| x.contains(&packet.destination_port))
|
||||
&& !CONFIG
|
||||
.stream
|
||||
.allowed_ports()
|
||||
.iter()
|
||||
.any(|x| x.contains(&packet.destination_port))
|
||||
{
|
||||
return Ok(ResolvedPacket::Blocked);
|
||||
}
|
||||
|
||||
if let Ok(addr) = IpAddr::from_str(&packet.destination_hostname) {
|
||||
if !CONFIG.stream.allow_direct_ip {
|
||||
return Ok(ResolvedPacket::Blocked);
|
||||
}
|
||||
|
||||
if addr.is_loopback() && !CONFIG.stream.allow_loopback {
|
||||
return Ok(ResolvedPacket::Blocked);
|
||||
}
|
||||
|
||||
if addr.is_multicast() && !CONFIG.stream.allow_multicast {
|
||||
return Ok(ResolvedPacket::Blocked);
|
||||
}
|
||||
|
||||
if (addr.is_global() && !CONFIG.stream.allow_global)
|
||||
|| (!addr.is_global() && !CONFIG.stream.allow_non_global)
|
||||
{
|
||||
return Ok(ResolvedPacket::Blocked);
|
||||
}
|
||||
}
|
||||
|
||||
if CONFIG
|
||||
.stream
|
||||
.blocked_hosts()
|
||||
.is_match(&packet.destination_hostname)
|
||||
&& !CONFIG
|
||||
.stream
|
||||
.allowed_hosts()
|
||||
.is_match(&packet.destination_hostname)
|
||||
{
|
||||
return Ok(ResolvedPacket::Blocked);
|
||||
}
|
||||
|
||||
let packet = lookup_host(packet.destination_hostname + ":0")
|
||||
.await
|
||||
.context("failed to resolve hostname")?
|
||||
.filter(|x| CONFIG.server.resolve_ipv6 || x.is_ipv4())
|
||||
.map(|x| ConnectPacket {
|
||||
stream_type: packet.stream_type,
|
||||
destination_hostname: x.ip().to_string(),
|
||||
destination_port: packet.destination_port,
|
||||
})
|
||||
.next();
|
||||
|
||||
Ok(packet
|
||||
.map(ResolvedPacket::Valid)
|
||||
.unwrap_or(ResolvedPacket::NoResolvedAddrs))
|
||||
}
|
||||
|
||||
pub async fn connect(packet: ConnectPacket) -> anyhow::Result<Self> {
|
||||
let ipaddr = IpAddr::from_str(&packet.destination_hostname)
|
||||
.context("failed to parse hostname as ipaddr")?;
|
||||
|
||||
match packet.stream_type {
|
||||
StreamType::Tcp => {
|
||||
let stream = TcpStream::connect(SocketAddr::new(ipaddr, packet.destination_port))
|
||||
.await
|
||||
.with_context(|| {
|
||||
format!("failed to connect to host {}", packet.destination_hostname)
|
||||
})?;
|
||||
|
||||
Ok(ClientStream::Tcp(stream))
|
||||
}
|
||||
StreamType::Udp => {
|
||||
if !CONFIG.stream.allow_udp {
|
||||
return Ok(ClientStream::Blocked);
|
||||
}
|
||||
|
||||
let bind_addr = if ipaddr.is_ipv4() {
|
||||
SocketAddr::new(Ipv4Addr::new(0, 0, 0, 0).into(), 0)
|
||||
} else {
|
||||
SocketAddr::new(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0).into(), 0)
|
||||
};
|
||||
|
||||
let stream = UdpSocket::bind(bind_addr).await?;
|
||||
|
||||
stream
|
||||
.connect(SocketAddr::new(ipaddr, packet.destination_port))
|
||||
.await?;
|
||||
|
||||
Ok(ClientStream::Udp(stream))
|
||||
}
|
||||
StreamType::Unknown(_) => Ok(ClientStream::Invalid),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn copy_read_fast(
|
||||
mut muxrx: MuxStreamAsyncRead,
|
||||
mut tcptx: OwnedWriteHalf,
|
||||
) -> std::io::Result<()> {
|
||||
loop {
|
||||
let buf = muxrx.fill_buf().await?;
|
||||
if buf.is_empty() {
|
||||
tcptx.flush().await?;
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let i = tcptx.write(buf).await?;
|
||||
if i == 0 {
|
||||
return Err(std::io::ErrorKind::WriteZero.into());
|
||||
}
|
||||
|
||||
muxrx.consume_unpin(i);
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub async fn copy_write_fast(
|
||||
muxtx: MuxStreamWrite,
|
||||
mut tcprx: OwnedReadHalf,
|
||||
) -> anyhow::Result<()> {
|
||||
loop {
|
||||
let mut buf = BytesMut::with_capacity(8 * 1024);
|
||||
let amt = tcprx.read(&mut buf).await?;
|
||||
muxtx.write(&buf[..amt]).await?;
|
||||
}
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue