wsproxy support with udp, logger, other random stuff

This commit is contained in:
Toshit Chawda 2024-07-21 21:35:33 -07:00
parent 4b44567a0e
commit 04b8feaaf3
No known key found for this signature in database
GPG key ID: 91480ED99E2B3D9D
9 changed files with 637 additions and 203 deletions

View file

@ -6,15 +6,19 @@ edition = "2021"
[dependencies]
anyhow = "1.0.86"
bytes = "1.6.1"
dashmap = "6.0.1"
env_logger = "0.11.3"
fastwebsockets = { version = "0.8.0", features = ["unstable-split", "upgrade"] }
futures-util = "0.3.30"
http-body-util = "0.1.2"
hyper = { version = "1.4.1", features = ["server", "http1"] }
hyper-util = { version = "0.1.6", features = ["tokio"] }
lazy_static = "1.5.0"
log = { version = "0.4.22", features = ["serde", "std"] }
regex = "1.10.5"
serde = { version = "1.0.204", features = ["derive"] }
tokio = { version = "1.38.1", features = ["full"] }
tokio-util = { version = "0.7.11", features = ["compat", "io-util", "net"] }
toml = "0.8.15"
uuid = { version = "1.10.0", features = ["v4"] }
wisp-mux = { version = "5.0.0", path = "../wisp", features = ["fastwebsockets"] }

View file

@ -1,6 +1,7 @@
use std::{collections::HashMap, ops::RangeInclusive};
use lazy_static::lazy_static;
use log::LevelFilter;
use regex::RegexSet;
use serde::{Deserialize, Serialize};
use wisp_mux::extensions::{
@ -48,7 +49,7 @@ pub fn validate_config_cache() {
let _ = CONFIG_CACHE.wisp_config;
}
#[derive(Serialize, Deserialize, Default)]
#[derive(Serialize, Deserialize, Default, Debug)]
#[serde(rename_all = "lowercase")]
pub enum SocketType {
#[default]
@ -63,19 +64,38 @@ pub struct ServerConfig {
pub socket: SocketType,
pub resolve_ipv6: bool,
pub verbose_stats: bool,
pub enable_stats_endpoint: bool,
pub stats_endpoint: String,
pub non_ws_response: String,
// DO NOT add a trailing slash to this config option
pub prefix: String,
pub max_message_size: usize,
// TODO
// prefix: String,
pub log_level: LevelFilter,
}
impl Default for ServerConfig {
fn default() -> Self {
Self {
bind: "127.0.0.1:4000".to_owned(),
bind: "127.0.0.1:4000".to_string(),
socket: SocketType::default(),
resolve_ipv6: false,
verbose_stats: true,
stats_endpoint: "/stats".to_string(),
enable_stats_endpoint: true,
non_ws_response: ":3".to_string(),
prefix: String::new(),
max_message_size: 64 * 1024,
log_level: LevelFilter::Info,
}
}
}
@ -90,21 +110,21 @@ pub enum ProtocolExtension {
#[derive(Serialize, Deserialize)]
#[serde(default)]
pub struct WispConfig {
pub wisp_v2: bool,
pub allow_wsproxy: bool,
pub buffer_size: u32,
pub wisp_v2: bool,
pub extensions: Vec<ProtocolExtension>,
pub password_extension_users: HashMap<String, String>,
// TODO
// enable_wsproxy: bool,
}
impl Default for WispConfig {
fn default() -> Self {
Self {
buffer_size: 512,
wisp_v2: false,
buffer_size: 128,
allow_wsproxy: true,
wisp_v2: false,
extensions: Vec::new(),
password_extension_users: HashMap::new(),
}
@ -112,7 +132,9 @@ impl Default for WispConfig {
}
impl WispConfig {
pub fn to_opts_inner(&self) -> anyhow::Result<(Option<Vec<AnyProtocolExtensionBuilder>>, u32)> {
pub(super) fn to_opts_inner(
&self,
) -> anyhow::Result<(Option<Vec<AnyProtocolExtensionBuilder>>, u32)> {
if self.wisp_v2 {
let mut extensions: Vec<Box<dyn ProtocolExtensionBuilder + Sync + Send>> = Vec::new();
@ -144,6 +166,7 @@ impl WispConfig {
#[serde(default)]
pub struct StreamConfig {
pub allow_udp: bool,
pub allow_wsproxy_udp: bool,
pub allow_direct_ip: bool,
pub allow_loopback: bool,
@ -163,6 +186,7 @@ impl Default for StreamConfig {
fn default() -> Self {
Self {
allow_udp: true,
allow_wsproxy_udp: false,
allow_direct_ip: true,
allow_loopback: true,

5
server/src/handle/mod.rs Normal file
View file

@ -0,0 +1,5 @@
mod wisp;
mod wsproxy;
pub use wisp::handle_wisp;
pub use wsproxy::handle_wsproxy;

204
server/src/handle/wisp.rs Normal file
View file

@ -0,0 +1,204 @@
use anyhow::Context;
use fastwebsockets::{upgrade::UpgradeFut, FragmentCollectorRead};
use futures_util::FutureExt;
use hyper_util::rt::TokioIo;
use tokio::{
io::{AsyncBufReadExt, AsyncWriteExt, BufReader},
net::tcp::{OwnedReadHalf, OwnedWriteHalf},
select,
task::JoinSet,
};
use tokio_util::compat::FuturesAsyncReadCompatExt;
use uuid::Uuid;
use wisp_mux::{
CloseReason, ConnectPacket, MuxStream, MuxStreamAsyncRead, MuxStreamWrite, ServerMux,
};
use crate::{
stream::{ClientStream, ResolvedPacket, ServerStream, ServerStreamExt},
CLIENTS, CONFIG,
};
async fn copy_read_fast(
muxrx: MuxStreamAsyncRead,
mut tcptx: OwnedWriteHalf,
) -> std::io::Result<()> {
let mut muxrx = muxrx.compat();
loop {
let buf = muxrx.fill_buf().await?;
if buf.is_empty() {
tcptx.flush().await?;
return Ok(());
}
let i = tcptx.write(buf).await?;
if i == 0 {
return Err(std::io::ErrorKind::WriteZero.into());
}
muxrx.consume(i);
}
}
async fn copy_write_fast(muxtx: MuxStreamWrite, tcprx: OwnedReadHalf) -> anyhow::Result<()> {
let mut tcprx = BufReader::new(tcprx);
loop {
let buf = tcprx.fill_buf().await?;
muxtx.write(&buf).await?;
let len = buf.len();
tcprx.consume(len);
}
}
async fn handle_stream(connect: ConnectPacket, muxstream: MuxStream, id: String) {
let requested_stream = connect.clone();
let Ok(resolved) = ClientStream::resolve(connect).await else {
let _ = muxstream.close(CloseReason::ServerStreamUnreachable).await;
return;
};
let connect = match resolved {
ResolvedPacket::Valid(x) => x,
ResolvedPacket::NoResolvedAddrs => {
let _ = muxstream.close(CloseReason::ServerStreamUnreachable).await;
return;
}
ResolvedPacket::Blocked => {
let _ = muxstream
.close(CloseReason::ServerStreamBlockedAddress)
.await;
return;
}
};
let resolved_stream = connect.clone();
let Ok(stream) = ClientStream::connect(connect).await else {
let _ = muxstream.close(CloseReason::ServerStreamUnreachable).await;
return;
};
let uuid = Uuid::new_v4();
CLIENTS
.get(&id)
.unwrap()
.0
.insert(uuid, (requested_stream, resolved_stream));
match stream {
ClientStream::Tcp(stream) => {
let closer = muxstream.get_close_handle();
let ret: anyhow::Result<()> = async {
/*
let (muxread, muxwrite) = muxstream.into_io().into_asyncrw().into_split();
let (mut tcpread, tcpwrite) = stream.into_split();
let mut muxwrite = muxwrite.compat_write();
select! {
x = copy_read_fast(muxread, tcpwrite) => x?,
x = copy(&mut tcpread, &mut muxwrite) => {x?;},
}
*/
// TODO why is copy_write_fast not working?
let (muxread, muxwrite) = muxstream.into_split();
let muxread = muxread.into_stream().into_asyncread();
let (tcpread, tcpwrite) = stream.into_split();
select! {
x = copy_read_fast(muxread, tcpwrite) => x?,
x = copy_write_fast(muxwrite, tcpread) => x?,
}
Ok(())
}
.await;
match ret {
Ok(()) => {
let _ = closer.close(CloseReason::Voluntary).await;
}
Err(_) => {
let _ = closer.close(CloseReason::Unexpected).await;
}
}
}
ClientStream::Udp(stream) => {
let closer = muxstream.get_close_handle();
let ret: anyhow::Result<()> = async move {
let mut data = vec![0u8; 65507];
loop {
select! {
size = stream.recv(&mut data) => {
let size = size?;
muxstream.write(&data[..size]).await?;
}
data = muxstream.read() => {
if let Some(data) = data {
stream.send(&data).await?;
} else {
break Ok(());
}
}
}
}
}
.await;
match ret {
Ok(()) => {
let _ = closer.close(CloseReason::Voluntary).await;
}
Err(_) => {
let _ = closer.close(CloseReason::Unexpected).await;
}
}
}
ClientStream::Invalid => {
let _ = muxstream.close(CloseReason::ServerStreamInvalidInfo).await;
}
ClientStream::Blocked => {
let _ = muxstream
.close(CloseReason::ServerStreamBlockedAddress)
.await;
}
};
CLIENTS.get(&id).unwrap().0.remove(&uuid);
}
pub async fn handle_wisp(fut: UpgradeFut, id: String) -> anyhow::Result<()> {
let mut ws = fut.await.context("failed to await upgrade future")?;
ws.set_max_message_size(CONFIG.server.max_message_size);
let (read, write) = ws.split(|x| {
let parts = x.into_inner().downcast::<TokioIo<ServerStream>>().unwrap();
assert_eq!(parts.read_buf.len(), 0);
parts.io.into_inner().split()
});
let read = FragmentCollectorRead::new(read);
let (extensions, buffer_size) = CONFIG.wisp.to_opts();
let (mux, fut) = ServerMux::create(read, write, buffer_size, extensions)
.await
.context("failed to create server multiplexor")?
.with_no_required_extensions();
let mut set: JoinSet<()> = JoinSet::new();
set.spawn(tokio::task::unconstrained(fut.map(|_| {})));
while let Some((connect, stream)) = mux.server_new_stream().await {
set.spawn(tokio::task::unconstrained(handle_stream(
connect,
stream,
id.clone(),
)));
}
set.abort_all();
while set.join_next().await.is_some() {}
Ok(())
}

View file

@ -0,0 +1,145 @@
use std::str::FromStr;
use anyhow::Context;
use fastwebsockets::{upgrade::UpgradeFut, CloseCode, FragmentCollector};
use tokio::{
io::{AsyncBufReadExt, AsyncWriteExt, BufReader},
select,
};
use uuid::Uuid;
use wisp_mux::{ConnectPacket, StreamType};
use crate::{
stream::{ClientStream, ResolvedPacket, WebSocketFrame, WebSocketStreamWrapper},
CLIENTS, CONFIG,
};
pub async fn handle_wsproxy(
fut: UpgradeFut,
id: String,
path: String,
udp: bool,
) -> anyhow::Result<()> {
let mut ws = fut.await.context("failed to await upgrade future")?;
ws.set_max_message_size(CONFIG.server.max_message_size);
let ws = FragmentCollector::new(ws);
let mut ws = WebSocketStreamWrapper(ws);
if udp && !CONFIG.stream.allow_wsproxy_udp {
let _ = ws.close(CloseCode::Error.into(), b"udp is blocked").await;
return Ok(());
}
let vec: Vec<&str> = path.split("/").last().unwrap().split(":").collect();
let Ok(port) = FromStr::from_str(vec[1]) else {
let _ = ws.close(CloseCode::Error.into(), b"invalid port").await;
return Ok(());
};
let connect = ConnectPacket {
stream_type: if udp {
StreamType::Udp
} else {
StreamType::Tcp
},
destination_hostname: vec[0].to_string(),
destination_port: port,
};
let requested_stream = connect.clone();
let Ok(resolved) = ClientStream::resolve(connect).await else {
let _ = ws
.close(CloseCode::Error.into(), b"failed to resolve host")
.await;
return Ok(());
};
let connect = match resolved {
ResolvedPacket::Valid(x) => x,
ResolvedPacket::NoResolvedAddrs => {
let _ = ws
.close(
CloseCode::Error.into(),
b"host did not resolve to any addrs",
)
.await;
return Ok(());
}
ResolvedPacket::Blocked => {
let _ = ws.close(CloseCode::Error.into(), b"host is blocked").await;
return Ok(());
}
};
let resolved_stream = connect.clone();
let Ok(stream) = ClientStream::connect(connect).await else {
let _ = ws
.close(CloseCode::Error.into(), b"failed to connect to host")
.await;
return Ok(());
};
let uuid = Uuid::new_v4();
CLIENTS
.get(&id)
.unwrap()
.0
.insert(uuid, (requested_stream, resolved_stream));
match stream {
ClientStream::Tcp(stream) => {
let mut stream = BufReader::new(stream);
let ret: anyhow::Result<()> = async {
let mut to_consume = 0usize;
loop {
if to_consume != 0 {
stream.consume(to_consume);
to_consume = 0;
}
select! {
x = ws.read() => {
match x? {
WebSocketFrame::Data(data) => {
stream.write_all(&data).await?;
}
WebSocketFrame::Close => {
stream.shutdown().await?;
}
WebSocketFrame::Ignore => {}
}
}
x = stream.fill_buf() => {
let x = x?;
ws.write(x).await?;
to_consume += x.len();
}
}
}
}
.await;
match ret {
Ok(_) => {
let _ = ws.close(CloseCode::Normal.into(), b"").await;
}
Err(x) => {
let _ = ws
.close(CloseCode::Normal.into(), x.to_string().as_bytes())
.await;
}
}
}
ClientStream::Udp(_stream) => {
// TODO
let _ = ws.close(CloseCode::Error.into(), b"coming soon").await;
}
ClientStream::Blocked => {
let _ = ws.close(CloseCode::Error.into(), b"host is blocked").await;
}
ClientStream::Invalid => {
let _ = ws.close(CloseCode::Error.into(), b"host is invalid").await;
}
}
Ok(())
}

View file

@ -1,25 +1,30 @@
#![feature(ip)]
use std::{env::args, fs::read_to_string, ops::Deref};
use std::{env::args, fmt::Write, fs::read_to_string};
use anyhow::Context;
use bytes::Bytes;
use config::{validate_config_cache, Config};
use fastwebsockets::{upgrade::UpgradeFut, FragmentCollectorRead};
use http_body_util::Empty;
use hyper::{body::Incoming, server::conn::http1::Builder, service::service_fn, Request, Response};
use dashmap::DashMap;
use handle::{handle_wisp, handle_wsproxy};
use http_body_util::Full;
use hyper::{
body::Incoming, server::conn::http1::Builder, service::service_fn, Request, Response,
StatusCode,
};
use hyper_util::rt::TokioIo;
use lazy_static::lazy_static;
use stream::{
copy_read_fast, ClientStream, ResolvedPacket, ServerListener, ServerStream, ServerStreamExt,
};
use tokio::{io::copy, select};
use tokio_util::compat::FuturesAsyncWriteCompatExt;
use wisp_mux::{CloseReason, ConnectPacket, MuxStream, ServerMux};
use log::{error, info};
use stream::ServerListener;
use tokio::signal::unix::{signal, SignalKind};
use uuid::Uuid;
use wisp_mux::{ConnectPacket, StreamType};
mod config;
mod handle;
mod stream;
type Client = (DashMap<Uuid, (ConnectPacket, ConnectPacket)>, bool);
lazy_static! {
pub static ref CONFIG: Config = {
if let Some(path) = args().nth(1) {
@ -28,169 +33,159 @@ lazy_static! {
Config::default()
}
};
pub static ref CLIENTS: DashMap<String, Client> = DashMap::new();
}
async fn handle_stream(connect: ConnectPacket, muxstream: MuxStream) {
let Ok(resolved) = ClientStream::resolve(connect).await else {
let _ = muxstream.close(CloseReason::ServerStreamUnreachable).await;
return;
};
let connect = match resolved {
ResolvedPacket::Valid(x) => x,
ResolvedPacket::NoResolvedAddrs => {
let _ = muxstream.close(CloseReason::ServerStreamUnreachable).await;
return;
}
ResolvedPacket::Blocked => {
let _ = muxstream
.close(CloseReason::ServerStreamBlockedAddress)
.await;
return;
}
};
let Ok(stream) = ClientStream::connect(connect).await else {
let _ = muxstream.close(CloseReason::ServerStreamUnreachable).await;
return;
};
match stream {
ClientStream::Tcp(stream) => {
let closer = muxstream.get_close_handle();
let ret: anyhow::Result<()> = async move {
let (muxread, muxwrite) = muxstream.into_io().into_asyncrw().into_split();
let (mut tcpread, tcpwrite) = stream.into_split();
let mut muxwrite = muxwrite.compat_write();
select! {
x = copy_read_fast(muxread, tcpwrite) => x?,
x = copy(&mut tcpread, &mut muxwrite) => {x?;},
}
// TODO why is copy_write_fast not working?
/*
let (muxread, muxwrite) = muxstream.into_split();
let muxread = muxread.into_stream().into_asyncread();
let (mut tcpread, tcpwrite) = stream.into_split();
select! {
x = copy_read_fast(muxread, tcpwrite) => x?,
x = copy_write_fast(muxwrite, tcpread) => {x?;},
}
*/
Ok(())
}
.await;
match ret {
Ok(()) => {
let _ = closer.close(CloseReason::Voluntary).await;
}
Err(_) => {
let _ = closer.close(CloseReason::Unexpected).await;
}
}
}
ClientStream::Udp(stream) => {
let closer = muxstream.get_close_handle();
let ret: anyhow::Result<()> = async move {
let mut data = vec![0u8; 65507];
loop {
select! {
size = stream.recv(&mut data) => {
let size = size?;
muxstream.write(&data[..size]).await?;
}
data = muxstream.read() => {
if let Some(data) = data {
stream.send(&data).await?;
} else {
break Ok(());
}
}
}
}
}
.await;
match ret {
Ok(()) => {
let _ = closer.close(CloseReason::Voluntary).await;
}
Err(_) => {
let _ = closer.close(CloseReason::Unexpected).await;
}
}
}
ClientStream::Invalid => {
let _ = muxstream.close(CloseReason::ServerStreamInvalidInfo).await;
}
ClientStream::Blocked => {
let _ = muxstream
.close(CloseReason::ServerStreamBlockedAddress)
.await;
}
};
type Body = Full<Bytes>;
fn non_ws_resp() -> Response<Body> {
Response::builder()
.status(StatusCode::OK)
.body(Body::new(CONFIG.server.non_ws_response.as_bytes().into()))
.unwrap()
}
async fn handle(fut: UpgradeFut) -> anyhow::Result<()> {
let mut ws = fut.await.context("failed to await upgrade future")?;
ws.set_max_message_size(CONFIG.server.max_message_size);
let (read, write) = ws.split(|x| {
let parts = x.into_inner().downcast::<TokioIo<ServerStream>>().unwrap();
assert_eq!(parts.read_buf.len(), 0);
parts.io.into_inner().split()
});
let read = FragmentCollectorRead::new(read);
let (extensions, buffer_size) = CONFIG.wisp.to_opts_inner()?;
let (mux, fut) = ServerMux::create(read, write, buffer_size, extensions.as_deref())
.await
.context("failed to create server multiplexor")?
.with_no_required_extensions();
tokio::spawn(tokio::task::unconstrained(fut));
while let Some((connect, stream)) = mux.server_new_stream().await {
tokio::spawn(tokio::task::unconstrained(handle_stream(connect, stream)));
async fn upgrade(mut req: Request<Incoming>, id: String) -> anyhow::Result<Response<Body>> {
if CONFIG.server.enable_stats_endpoint && req.uri().path() == CONFIG.server.stats_endpoint {
match generate_stats() {
Ok(x) => {
return Ok(Response::builder()
.status(StatusCode::OK)
.body(Body::new(x.into()))
.unwrap())
}
Err(x) => {
return Ok(Response::builder()
.status(StatusCode::INTERNAL_SERVER_ERROR)
.body(Body::new(x.to_string().into()))
.unwrap())
}
}
} else if !fastwebsockets::upgrade::is_upgrade_request(&req) {
return Ok(non_ws_resp());
}
Ok(())
}
type Body = Empty<Bytes>;
async fn upgrade(mut req: Request<Incoming>) -> anyhow::Result<Response<Body>> {
let (resp, fut) = fastwebsockets::upgrade::upgrade(&mut req)?;
// replace body of Empty<Bytes> with Full<Bytes>
let resp = Response::from_parts(resp.into_parts().0, Body::new(Bytes::new()));
tokio::spawn(async move {
if let Err(e) = handle(fut).await {
println!("{:?}", e);
};
});
if req
.uri()
.path()
.starts_with(&(CONFIG.server.prefix.clone() + "/"))
{
tokio::spawn(async move {
CLIENTS.insert(id.clone(), (DashMap::new(), false));
if let Err(e) = handle_wisp(fut, id.clone()).await {
error!("error while handling upgraded client: {:?}", e);
};
CLIENTS.remove(&id)
});
} else if CONFIG.wisp.allow_wsproxy {
let udp = req.uri().query().unwrap_or_default() == "?udp";
tokio::spawn(async move {
CLIENTS.insert(id.clone(), (DashMap::new(), true));
if let Err(e) = handle_wsproxy(fut, id.clone(), req.uri().path().to_string(), udp).await
{
error!("error while handling upgraded client: {:?}", e);
};
CLIENTS.remove(&id)
});
} else {
return Ok(non_ws_resp());
}
Ok(resp)
}
fn format_stream_type(stream_type: StreamType) -> &'static str {
match stream_type {
StreamType::Tcp => "tcp",
StreamType::Udp => "udp",
StreamType::Unknown(_) => unreachable!(),
}
}
fn generate_stats() -> Result<String, std::fmt::Error> {
let mut out = String::new();
let len = CLIENTS.len();
writeln!(
&mut out,
"{} clients connected{}",
len,
if len != 0 { ":" } else { "" }
)?;
for client in CLIENTS.iter() {
let len = client.value().0.len();
writeln!(
&mut out,
"\tClient \"{}\"{}: {} streams connected{}",
client.key(),
if client.value().1 { " (wsproxy)" } else { "" },
len,
if len != 0 && CONFIG.server.verbose_stats {
":"
} else {
""
}
)?;
if CONFIG.server.verbose_stats {
for stream in client.value().0.iter() {
writeln!(
&mut out,
"\t\tStream \"{}\": {}",
stream.key(),
format_stream_type(stream.value().0.stream_type)
)?;
writeln!(
&mut out,
"\t\t\tRequested: {}:{}",
stream.value().0.destination_hostname,
stream.value().0.destination_port
)?;
writeln!(
&mut out,
"\t\t\tResolved: {}:{}",
stream.value().1.destination_hostname,
stream.value().1.destination_port
)?;
}
}
}
Ok(out)
}
#[tokio::main(flavor = "multi_thread")]
async fn main() -> anyhow::Result<()> {
env_logger::builder()
.filter_level(CONFIG.server.log_level)
.parse_default_env()
.init();
validate_config_cache();
println!("{}", toml::to_string_pretty(CONFIG.deref()).unwrap());
info!("listening on {:?} with socket type {:?}", CONFIG.server.bind, CONFIG.server.socket);
tokio::spawn(async {
let mut sig = signal(SignalKind::user_defined1()).unwrap();
while sig.recv().await.is_some() {
info!("{}", generate_stats().unwrap());
}
});
let listener = ServerListener::new().await?;
loop {
let (stream, _) = listener.accept().await?;
let (stream, id) = listener.accept().await?;
tokio::spawn(async move {
let stream = TokioIo::new(stream);
let fut = Builder::new()
.serve_connection(stream, service_fn(upgrade))
.serve_connection(stream, service_fn(|req| upgrade(req, id.clone())))
.with_upgrades();
if let Err(e) = fut.await {
println!("{:?}", e);
error!("error while serving client: {:?}", e);
}
});
}

View file

@ -5,17 +5,16 @@ use std::{
use anyhow::Context;
use bytes::BytesMut;
use futures_util::AsyncBufReadExt;
use fastwebsockets::{FragmentCollector, Frame, OpCode, Payload, WebSocketError};
use hyper::upgrade::Upgraded;
use hyper_util::rt::TokioIo;
use tokio::{
io::{AsyncReadExt, AsyncWriteExt},
net::{
lookup_host,
tcp::{self, OwnedReadHalf, OwnedWriteHalf},
unix, TcpListener, TcpStream, UdpSocket, UnixListener, UnixStream,
},
fs::{remove_file, try_exists},
net::{lookup_host, tcp, unix, TcpListener, TcpStream, UdpSocket, UnixListener, UnixStream},
};
use tokio_util::either::Either;
use wisp_mux::{ConnectPacket, MuxStreamAsyncRead, MuxStreamWrite, StreamType};
use uuid::Uuid;
use wisp_mux::{ConnectPacket, StreamType};
use crate::{config::SocketType, CONFIG};
@ -58,6 +57,9 @@ impl ServerListener {
})?,
),
SocketType::Unix => {
if try_exists(&CONFIG.server.bind).await? {
remove_file(&CONFIG.server.bind).await?;
}
Self::Unix(UnixListener::bind(&CONFIG.server.bind).with_context(|| {
format!("failed to bind to unix socket at `{}`", CONFIG.server.bind)
})?)
@ -65,12 +67,12 @@ impl ServerListener {
})
}
pub async fn accept(&self) -> anyhow::Result<(ServerStream, Option<String>)> {
pub async fn accept(&self) -> anyhow::Result<(ServerStream, String)> {
match self {
Self::Tcp(x) => x
.accept()
.await
.map(|(x, y)| (Either::Left(x), Some(y.to_string())))
.map(|(x, y)| (Either::Left(x), y.to_string()))
.context("failed to accept tcp connection"),
Self::Unix(x) => x
.accept()
@ -80,7 +82,8 @@ impl ServerListener {
Either::Right(x),
y.as_pathname()
.and_then(|x| x.to_str())
.map(ToString::to_string),
.map(ToString::to_string)
.unwrap_or_else(|| Uuid::new_v4().to_string() + "-unix_socket"),
)
})
.context("failed to accept unix socket connection"),
@ -207,34 +210,31 @@ impl ClientStream {
}
}
pub async fn copy_read_fast(
mut muxrx: MuxStreamAsyncRead,
mut tcptx: OwnedWriteHalf,
) -> std::io::Result<()> {
loop {
let buf = muxrx.fill_buf().await?;
if buf.is_empty() {
tcptx.flush().await?;
return Ok(());
}
let i = tcptx.write(buf).await?;
if i == 0 {
return Err(std::io::ErrorKind::WriteZero.into());
}
muxrx.consume_unpin(i);
}
pub enum WebSocketFrame {
Data(BytesMut),
Close,
Ignore,
}
#[allow(dead_code)]
pub async fn copy_write_fast(
muxtx: MuxStreamWrite,
mut tcprx: OwnedReadHalf,
) -> anyhow::Result<()> {
loop {
let mut buf = BytesMut::with_capacity(8 * 1024);
let amt = tcprx.read(&mut buf).await?;
muxtx.write(&buf[..amt]).await?;
pub struct WebSocketStreamWrapper(pub FragmentCollector<TokioIo<Upgraded>>);
impl WebSocketStreamWrapper {
pub async fn read(&mut self) -> Result<WebSocketFrame, WebSocketError> {
let frame = self.0.read_frame().await?;
Ok(match frame.opcode {
OpCode::Text | OpCode::Binary => WebSocketFrame::Data(frame.payload.into()),
OpCode::Close => WebSocketFrame::Close,
_ => WebSocketFrame::Ignore,
})
}
pub async fn write(&mut self, data: &[u8]) -> Result<(), WebSocketError> {
self.0
.write_frame(Frame::binary(Payload::Borrowed(data)))
.await
}
pub async fn close(&mut self, code: u16, reason: &[u8]) -> Result<(), WebSocketError> {
self.0.write_frame(Frame::close(code, reason)).await
}
}