mirror of
https://github.com/MercuryWorkshop/epoxy-tls.git
synced 2025-05-12 22:10:01 -04:00
wsproxy support with udp, logger, other random stuff
This commit is contained in:
parent
4b44567a0e
commit
04b8feaaf3
9 changed files with 637 additions and 203 deletions
|
@ -1,25 +1,30 @@
|
|||
#![feature(ip)]
|
||||
|
||||
use std::{env::args, fs::read_to_string, ops::Deref};
|
||||
use std::{env::args, fmt::Write, fs::read_to_string};
|
||||
|
||||
use anyhow::Context;
|
||||
use bytes::Bytes;
|
||||
use config::{validate_config_cache, Config};
|
||||
use fastwebsockets::{upgrade::UpgradeFut, FragmentCollectorRead};
|
||||
use http_body_util::Empty;
|
||||
use hyper::{body::Incoming, server::conn::http1::Builder, service::service_fn, Request, Response};
|
||||
use dashmap::DashMap;
|
||||
use handle::{handle_wisp, handle_wsproxy};
|
||||
use http_body_util::Full;
|
||||
use hyper::{
|
||||
body::Incoming, server::conn::http1::Builder, service::service_fn, Request, Response,
|
||||
StatusCode,
|
||||
};
|
||||
use hyper_util::rt::TokioIo;
|
||||
use lazy_static::lazy_static;
|
||||
use stream::{
|
||||
copy_read_fast, ClientStream, ResolvedPacket, ServerListener, ServerStream, ServerStreamExt,
|
||||
};
|
||||
use tokio::{io::copy, select};
|
||||
use tokio_util::compat::FuturesAsyncWriteCompatExt;
|
||||
use wisp_mux::{CloseReason, ConnectPacket, MuxStream, ServerMux};
|
||||
use log::{error, info};
|
||||
use stream::ServerListener;
|
||||
use tokio::signal::unix::{signal, SignalKind};
|
||||
use uuid::Uuid;
|
||||
use wisp_mux::{ConnectPacket, StreamType};
|
||||
|
||||
mod config;
|
||||
mod handle;
|
||||
mod stream;
|
||||
|
||||
type Client = (DashMap<Uuid, (ConnectPacket, ConnectPacket)>, bool);
|
||||
|
||||
lazy_static! {
|
||||
pub static ref CONFIG: Config = {
|
||||
if let Some(path) = args().nth(1) {
|
||||
|
@ -28,169 +33,159 @@ lazy_static! {
|
|||
Config::default()
|
||||
}
|
||||
};
|
||||
pub static ref CLIENTS: DashMap<String, Client> = DashMap::new();
|
||||
}
|
||||
|
||||
async fn handle_stream(connect: ConnectPacket, muxstream: MuxStream) {
|
||||
let Ok(resolved) = ClientStream::resolve(connect).await else {
|
||||
let _ = muxstream.close(CloseReason::ServerStreamUnreachable).await;
|
||||
return;
|
||||
};
|
||||
let connect = match resolved {
|
||||
ResolvedPacket::Valid(x) => x,
|
||||
ResolvedPacket::NoResolvedAddrs => {
|
||||
let _ = muxstream.close(CloseReason::ServerStreamUnreachable).await;
|
||||
return;
|
||||
}
|
||||
ResolvedPacket::Blocked => {
|
||||
let _ = muxstream
|
||||
.close(CloseReason::ServerStreamBlockedAddress)
|
||||
.await;
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let Ok(stream) = ClientStream::connect(connect).await else {
|
||||
let _ = muxstream.close(CloseReason::ServerStreamUnreachable).await;
|
||||
return;
|
||||
};
|
||||
|
||||
match stream {
|
||||
ClientStream::Tcp(stream) => {
|
||||
let closer = muxstream.get_close_handle();
|
||||
|
||||
let ret: anyhow::Result<()> = async move {
|
||||
let (muxread, muxwrite) = muxstream.into_io().into_asyncrw().into_split();
|
||||
let (mut tcpread, tcpwrite) = stream.into_split();
|
||||
let mut muxwrite = muxwrite.compat_write();
|
||||
select! {
|
||||
x = copy_read_fast(muxread, tcpwrite) => x?,
|
||||
x = copy(&mut tcpread, &mut muxwrite) => {x?;},
|
||||
}
|
||||
// TODO why is copy_write_fast not working?
|
||||
/*
|
||||
let (muxread, muxwrite) = muxstream.into_split();
|
||||
let muxread = muxread.into_stream().into_asyncread();
|
||||
let (mut tcpread, tcpwrite) = stream.into_split();
|
||||
select! {
|
||||
x = copy_read_fast(muxread, tcpwrite) => x?,
|
||||
x = copy_write_fast(muxwrite, tcpread) => {x?;},
|
||||
}
|
||||
*/
|
||||
Ok(())
|
||||
}
|
||||
.await;
|
||||
|
||||
match ret {
|
||||
Ok(()) => {
|
||||
let _ = closer.close(CloseReason::Voluntary).await;
|
||||
}
|
||||
Err(_) => {
|
||||
let _ = closer.close(CloseReason::Unexpected).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
ClientStream::Udp(stream) => {
|
||||
let closer = muxstream.get_close_handle();
|
||||
|
||||
let ret: anyhow::Result<()> = async move {
|
||||
let mut data = vec![0u8; 65507];
|
||||
loop {
|
||||
select! {
|
||||
size = stream.recv(&mut data) => {
|
||||
let size = size?;
|
||||
muxstream.write(&data[..size]).await?;
|
||||
}
|
||||
data = muxstream.read() => {
|
||||
if let Some(data) = data {
|
||||
stream.send(&data).await?;
|
||||
} else {
|
||||
break Ok(());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
.await;
|
||||
|
||||
match ret {
|
||||
Ok(()) => {
|
||||
let _ = closer.close(CloseReason::Voluntary).await;
|
||||
}
|
||||
Err(_) => {
|
||||
let _ = closer.close(CloseReason::Unexpected).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
ClientStream::Invalid => {
|
||||
let _ = muxstream.close(CloseReason::ServerStreamInvalidInfo).await;
|
||||
}
|
||||
ClientStream::Blocked => {
|
||||
let _ = muxstream
|
||||
.close(CloseReason::ServerStreamBlockedAddress)
|
||||
.await;
|
||||
}
|
||||
};
|
||||
type Body = Full<Bytes>;
|
||||
fn non_ws_resp() -> Response<Body> {
|
||||
Response::builder()
|
||||
.status(StatusCode::OK)
|
||||
.body(Body::new(CONFIG.server.non_ws_response.as_bytes().into()))
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
async fn handle(fut: UpgradeFut) -> anyhow::Result<()> {
|
||||
let mut ws = fut.await.context("failed to await upgrade future")?;
|
||||
|
||||
ws.set_max_message_size(CONFIG.server.max_message_size);
|
||||
|
||||
let (read, write) = ws.split(|x| {
|
||||
let parts = x.into_inner().downcast::<TokioIo<ServerStream>>().unwrap();
|
||||
assert_eq!(parts.read_buf.len(), 0);
|
||||
parts.io.into_inner().split()
|
||||
});
|
||||
let read = FragmentCollectorRead::new(read);
|
||||
|
||||
let (extensions, buffer_size) = CONFIG.wisp.to_opts_inner()?;
|
||||
|
||||
let (mux, fut) = ServerMux::create(read, write, buffer_size, extensions.as_deref())
|
||||
.await
|
||||
.context("failed to create server multiplexor")?
|
||||
.with_no_required_extensions();
|
||||
|
||||
tokio::spawn(tokio::task::unconstrained(fut));
|
||||
|
||||
while let Some((connect, stream)) = mux.server_new_stream().await {
|
||||
tokio::spawn(tokio::task::unconstrained(handle_stream(connect, stream)));
|
||||
async fn upgrade(mut req: Request<Incoming>, id: String) -> anyhow::Result<Response<Body>> {
|
||||
if CONFIG.server.enable_stats_endpoint && req.uri().path() == CONFIG.server.stats_endpoint {
|
||||
match generate_stats() {
|
||||
Ok(x) => {
|
||||
return Ok(Response::builder()
|
||||
.status(StatusCode::OK)
|
||||
.body(Body::new(x.into()))
|
||||
.unwrap())
|
||||
}
|
||||
Err(x) => {
|
||||
return Ok(Response::builder()
|
||||
.status(StatusCode::INTERNAL_SERVER_ERROR)
|
||||
.body(Body::new(x.to_string().into()))
|
||||
.unwrap())
|
||||
}
|
||||
}
|
||||
} else if !fastwebsockets::upgrade::is_upgrade_request(&req) {
|
||||
return Ok(non_ws_resp());
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
type Body = Empty<Bytes>;
|
||||
async fn upgrade(mut req: Request<Incoming>) -> anyhow::Result<Response<Body>> {
|
||||
let (resp, fut) = fastwebsockets::upgrade::upgrade(&mut req)?;
|
||||
// replace body of Empty<Bytes> with Full<Bytes>
|
||||
let resp = Response::from_parts(resp.into_parts().0, Body::new(Bytes::new()));
|
||||
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = handle(fut).await {
|
||||
println!("{:?}", e);
|
||||
};
|
||||
});
|
||||
if req
|
||||
.uri()
|
||||
.path()
|
||||
.starts_with(&(CONFIG.server.prefix.clone() + "/"))
|
||||
{
|
||||
tokio::spawn(async move {
|
||||
CLIENTS.insert(id.clone(), (DashMap::new(), false));
|
||||
if let Err(e) = handle_wisp(fut, id.clone()).await {
|
||||
error!("error while handling upgraded client: {:?}", e);
|
||||
};
|
||||
CLIENTS.remove(&id)
|
||||
});
|
||||
} else if CONFIG.wisp.allow_wsproxy {
|
||||
let udp = req.uri().query().unwrap_or_default() == "?udp";
|
||||
tokio::spawn(async move {
|
||||
CLIENTS.insert(id.clone(), (DashMap::new(), true));
|
||||
if let Err(e) = handle_wsproxy(fut, id.clone(), req.uri().path().to_string(), udp).await
|
||||
{
|
||||
error!("error while handling upgraded client: {:?}", e);
|
||||
};
|
||||
CLIENTS.remove(&id)
|
||||
});
|
||||
} else {
|
||||
return Ok(non_ws_resp());
|
||||
}
|
||||
|
||||
Ok(resp)
|
||||
}
|
||||
|
||||
fn format_stream_type(stream_type: StreamType) -> &'static str {
|
||||
match stream_type {
|
||||
StreamType::Tcp => "tcp",
|
||||
StreamType::Udp => "udp",
|
||||
StreamType::Unknown(_) => unreachable!(),
|
||||
}
|
||||
}
|
||||
|
||||
fn generate_stats() -> Result<String, std::fmt::Error> {
|
||||
let mut out = String::new();
|
||||
let len = CLIENTS.len();
|
||||
writeln!(
|
||||
&mut out,
|
||||
"{} clients connected{}",
|
||||
len,
|
||||
if len != 0 { ":" } else { "" }
|
||||
)?;
|
||||
|
||||
for client in CLIENTS.iter() {
|
||||
let len = client.value().0.len();
|
||||
|
||||
writeln!(
|
||||
&mut out,
|
||||
"\tClient \"{}\"{}: {} streams connected{}",
|
||||
client.key(),
|
||||
if client.value().1 { " (wsproxy)" } else { "" },
|
||||
len,
|
||||
if len != 0 && CONFIG.server.verbose_stats {
|
||||
":"
|
||||
} else {
|
||||
""
|
||||
}
|
||||
)?;
|
||||
|
||||
if CONFIG.server.verbose_stats {
|
||||
for stream in client.value().0.iter() {
|
||||
writeln!(
|
||||
&mut out,
|
||||
"\t\tStream \"{}\": {}",
|
||||
stream.key(),
|
||||
format_stream_type(stream.value().0.stream_type)
|
||||
)?;
|
||||
writeln!(
|
||||
&mut out,
|
||||
"\t\t\tRequested: {}:{}",
|
||||
stream.value().0.destination_hostname,
|
||||
stream.value().0.destination_port
|
||||
)?;
|
||||
writeln!(
|
||||
&mut out,
|
||||
"\t\t\tResolved: {}:{}",
|
||||
stream.value().1.destination_hostname,
|
||||
stream.value().1.destination_port
|
||||
)?;
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(out)
|
||||
}
|
||||
|
||||
#[tokio::main(flavor = "multi_thread")]
|
||||
async fn main() -> anyhow::Result<()> {
|
||||
env_logger::builder()
|
||||
.filter_level(CONFIG.server.log_level)
|
||||
.parse_default_env()
|
||||
.init();
|
||||
validate_config_cache();
|
||||
|
||||
println!("{}", toml::to_string_pretty(CONFIG.deref()).unwrap());
|
||||
info!("listening on {:?} with socket type {:?}", CONFIG.server.bind, CONFIG.server.socket);
|
||||
|
||||
tokio::spawn(async {
|
||||
let mut sig = signal(SignalKind::user_defined1()).unwrap();
|
||||
while sig.recv().await.is_some() {
|
||||
info!("{}", generate_stats().unwrap());
|
||||
}
|
||||
});
|
||||
|
||||
let listener = ServerListener::new().await?;
|
||||
loop {
|
||||
let (stream, _) = listener.accept().await?;
|
||||
let (stream, id) = listener.accept().await?;
|
||||
tokio::spawn(async move {
|
||||
let stream = TokioIo::new(stream);
|
||||
|
||||
let fut = Builder::new()
|
||||
.serve_connection(stream, service_fn(upgrade))
|
||||
.serve_connection(stream, service_fn(|req| upgrade(req, id.clone())))
|
||||
.with_upgrades();
|
||||
|
||||
if let Err(e) = fut.await {
|
||||
println!("{:?}", e);
|
||||
error!("error while serving client: {:?}", e);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue