mirror of
https://github.com/MercuryWorkshop/epoxy-tls.git
synced 2025-05-13 06:20:02 -04:00
features for config formats, generic wsr/wsw, length delimited codec transport
This commit is contained in:
parent
1b03620be0
commit
b6727b5019
12 changed files with 399 additions and 196 deletions
|
@ -18,10 +18,16 @@ lazy_static = "1.5.0"
|
||||||
log = { version = "0.4.22", features = ["serde", "std"] }
|
log = { version = "0.4.22", features = ["serde", "std"] }
|
||||||
regex = "1.10.6"
|
regex = "1.10.6"
|
||||||
serde = { version = "1.0.208", features = ["derive"] }
|
serde = { version = "1.0.208", features = ["derive"] }
|
||||||
serde_json = "1.0.125"
|
serde_json = { version = "1.0.125", optional = true }
|
||||||
serde_yaml = "0.9.34"
|
serde_yaml = { version = "0.9.34", optional = true }
|
||||||
tokio = { version = "1.39.3", features = ["full"] }
|
tokio = { version = "1.39.3", features = ["full"] }
|
||||||
tokio-util = { version = "0.7.11", features = ["compat", "io-util", "net"] }
|
tokio-util = { version = "0.7.11", features = ["codec", "compat", "io-util", "net"] }
|
||||||
toml = "0.8.19"
|
toml = { version = "0.8.19", optional = true }
|
||||||
uuid = { version = "1.10.0", features = ["v4"] }
|
uuid = { version = "1.10.0", features = ["v4"] }
|
||||||
wisp-mux = { version = "5.0.0", path = "../wisp", features = ["fastwebsockets"] }
|
wisp-mux = { version = "5.0.0", path = "../wisp", features = ["fastwebsockets", "generic_stream"] }
|
||||||
|
|
||||||
|
[features]
|
||||||
|
default = ["toml"]
|
||||||
|
json = ["dep:serde_json"]
|
||||||
|
yaml = ["dep:serde_yaml"]
|
||||||
|
toml = ["dep:toml"]
|
||||||
|
|
|
@ -22,6 +22,18 @@ pub enum SocketType {
|
||||||
Unix,
|
Unix,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize, Deserialize, Default, Debug)]
|
||||||
|
#[serde(rename_all = "lowercase")]
|
||||||
|
pub enum SocketTransport {
|
||||||
|
/// WebSocket transport.
|
||||||
|
#[default]
|
||||||
|
WebSocket,
|
||||||
|
/// Little-endian u32 length-delimited codec. See
|
||||||
|
/// [tokio-util](https://docs.rs/tokio-util/latest/tokio_util/codec/length_delimited/index.html)
|
||||||
|
/// for more information.
|
||||||
|
LengthDelimitedLe,
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize)]
|
#[derive(Serialize, Deserialize)]
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
pub struct ServerConfig {
|
pub struct ServerConfig {
|
||||||
|
@ -29,6 +41,8 @@ pub struct ServerConfig {
|
||||||
pub bind: String,
|
pub bind: String,
|
||||||
/// Socket type to listen on.
|
/// Socket type to listen on.
|
||||||
pub socket: SocketType,
|
pub socket: SocketType,
|
||||||
|
/// Transport to listen on.
|
||||||
|
pub transport: SocketTransport,
|
||||||
/// Whether or not to resolve and connect to IPV6 upstream addresses.
|
/// Whether or not to resolve and connect to IPV6 upstream addresses.
|
||||||
pub resolve_ipv6: bool,
|
pub resolve_ipv6: bool,
|
||||||
/// Whether or not to enable TCP nodelay on client TCP streams.
|
/// Whether or not to enable TCP nodelay on client TCP streams.
|
||||||
|
@ -189,6 +203,7 @@ impl Default for ServerConfig {
|
||||||
Self {
|
Self {
|
||||||
bind: "127.0.0.1:4000".to_string(),
|
bind: "127.0.0.1:4000".to_string(),
|
||||||
socket: SocketType::default(),
|
socket: SocketType::default(),
|
||||||
|
transport: SocketTransport::default(),
|
||||||
resolve_ipv6: false,
|
resolve_ipv6: false,
|
||||||
tcp_nodelay: false,
|
tcp_nodelay: false,
|
||||||
|
|
||||||
|
@ -318,16 +333,22 @@ impl StreamConfig {
|
||||||
impl Config {
|
impl Config {
|
||||||
pub fn ser(&self) -> anyhow::Result<String> {
|
pub fn ser(&self) -> anyhow::Result<String> {
|
||||||
Ok(match CLI.format {
|
Ok(match CLI.format {
|
||||||
|
#[cfg(feature = "toml")]
|
||||||
ConfigFormat::Toml => toml::to_string_pretty(self)?,
|
ConfigFormat::Toml => toml::to_string_pretty(self)?,
|
||||||
|
#[cfg(feature = "json")]
|
||||||
ConfigFormat::Json => serde_json::to_string_pretty(self)?,
|
ConfigFormat::Json => serde_json::to_string_pretty(self)?,
|
||||||
|
#[cfg(feature = "yaml")]
|
||||||
ConfigFormat::Yaml => serde_yaml::to_string(self)?,
|
ConfigFormat::Yaml => serde_yaml::to_string(self)?,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn de(string: String) -> anyhow::Result<Self> {
|
pub fn de(string: String) -> anyhow::Result<Self> {
|
||||||
Ok(match CLI.format {
|
Ok(match CLI.format {
|
||||||
|
#[cfg(feature = "toml")]
|
||||||
ConfigFormat::Toml => toml::from_str(&string)?,
|
ConfigFormat::Toml => toml::from_str(&string)?,
|
||||||
|
#[cfg(feature = "json")]
|
||||||
ConfigFormat::Json => serde_json::from_str(&string)?,
|
ConfigFormat::Json => serde_json::from_str(&string)?,
|
||||||
|
#[cfg(feature = "yaml")]
|
||||||
ConfigFormat::Yaml => serde_yaml::from_str(&string)?,
|
ConfigFormat::Yaml => serde_yaml::from_str(&string)?,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -335,9 +356,12 @@ impl Config {
|
||||||
|
|
||||||
#[derive(Clone, Copy, Eq, PartialEq, PartialOrd, Ord, Default, ValueEnum)]
|
#[derive(Clone, Copy, Eq, PartialEq, PartialOrd, Ord, Default, ValueEnum)]
|
||||||
pub enum ConfigFormat {
|
pub enum ConfigFormat {
|
||||||
|
#[cfg(feature = "toml")]
|
||||||
#[default]
|
#[default]
|
||||||
Toml,
|
Toml,
|
||||||
|
#[cfg(feature = "json")]
|
||||||
Json,
|
Json,
|
||||||
|
#[cfg(feature = "yaml")]
|
||||||
Yaml,
|
Yaml,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,11 +1,7 @@
|
||||||
use std::io::Cursor;
|
|
||||||
|
|
||||||
use anyhow::Context;
|
use anyhow::Context;
|
||||||
use fastwebsockets::upgrade::UpgradeFut;
|
|
||||||
use futures_util::FutureExt;
|
use futures_util::FutureExt;
|
||||||
use hyper_util::rt::TokioIo;
|
|
||||||
use tokio::{
|
use tokio::{
|
||||||
io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader},
|
io::{AsyncBufReadExt, AsyncWriteExt, BufReader},
|
||||||
net::tcp::{OwnedReadHalf, OwnedWriteHalf},
|
net::tcp::{OwnedReadHalf, OwnedWriteHalf},
|
||||||
select,
|
select,
|
||||||
task::JoinSet,
|
task::JoinSet,
|
||||||
|
@ -17,7 +13,8 @@ use wisp_mux::{
|
||||||
};
|
};
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
stream::{ClientStream, ResolvedPacket, ServerStream, ServerStreamExt},
|
listener::WispResult,
|
||||||
|
stream::{ClientStream, ResolvedPacket},
|
||||||
CLIENTS, CONFIG,
|
CLIENTS, CONFIG,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -162,16 +159,8 @@ async fn handle_stream(connect: ConnectPacket, muxstream: MuxStream, id: String)
|
||||||
CLIENTS.get(&id).unwrap().0.remove(&uuid);
|
CLIENTS.get(&id).unwrap().0.remove(&uuid);
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn handle_wisp(fut: UpgradeFut, id: String) -> anyhow::Result<()> {
|
pub async fn handle_wisp(stream: WispResult, id: String) -> anyhow::Result<()> {
|
||||||
let mut ws = fut.await.context("failed to await upgrade future")?;
|
let (read, write) = stream;
|
||||||
ws.set_max_message_size(CONFIG.server.max_message_size);
|
|
||||||
|
|
||||||
let (read, write) = ws.split(|x| {
|
|
||||||
let parts = x.into_inner().downcast::<TokioIo<ServerStream>>().unwrap();
|
|
||||||
let (r, w) = parts.io.into_inner().split();
|
|
||||||
(Cursor::new(parts.read_buf).chain(r), w)
|
|
||||||
});
|
|
||||||
|
|
||||||
let (extensions, buffer_size) = CONFIG.wisp.to_opts();
|
let (extensions, buffer_size) = CONFIG.wisp.to_opts();
|
||||||
|
|
||||||
let (mux, fut) = ServerMux::create(read, write, buffer_size, extensions)
|
let (mux, fut) = ServerMux::create(read, write, buffer_size, extensions)
|
||||||
|
|
|
@ -1,7 +1,6 @@
|
||||||
use std::str::FromStr;
|
use std::str::FromStr;
|
||||||
|
|
||||||
use anyhow::Context;
|
use fastwebsockets::CloseCode;
|
||||||
use fastwebsockets::{upgrade::UpgradeFut, CloseCode, FragmentCollector};
|
|
||||||
use tokio::{
|
use tokio::{
|
||||||
io::{AsyncBufReadExt, AsyncWriteExt, BufReader},
|
io::{AsyncBufReadExt, AsyncWriteExt, BufReader},
|
||||||
select,
|
select,
|
||||||
|
@ -15,16 +14,11 @@ use crate::{
|
||||||
};
|
};
|
||||||
|
|
||||||
pub async fn handle_wsproxy(
|
pub async fn handle_wsproxy(
|
||||||
fut: UpgradeFut,
|
mut ws: WebSocketStreamWrapper,
|
||||||
id: String,
|
id: String,
|
||||||
path: String,
|
path: String,
|
||||||
udp: bool,
|
udp: bool,
|
||||||
) -> anyhow::Result<()> {
|
) -> anyhow::Result<()> {
|
||||||
let mut ws = fut.await.context("failed to await upgrade future")?;
|
|
||||||
ws.set_max_message_size(CONFIG.server.max_message_size);
|
|
||||||
let ws = FragmentCollector::new(ws);
|
|
||||||
let mut ws = WebSocketStreamWrapper(ws);
|
|
||||||
|
|
||||||
if udp && !CONFIG.stream.allow_wsproxy_udp {
|
if udp && !CONFIG.stream.allow_wsproxy_udp {
|
||||||
let _ = ws.close(CloseCode::Error.into(), b"udp is blocked").await;
|
let _ = ws.close(CloseCode::Error.into(), b"udp is blocked").await;
|
||||||
return Ok(());
|
return Ok(());
|
||||||
|
|
236
server/src/listener.rs
Normal file
236
server/src/listener.rs
Normal file
|
@ -0,0 +1,236 @@
|
||||||
|
use std::{future::Future, io::Cursor};
|
||||||
|
|
||||||
|
use anyhow::Context;
|
||||||
|
use bytes::Bytes;
|
||||||
|
use fastwebsockets::{upgrade::UpgradeFut, FragmentCollector};
|
||||||
|
use http_body_util::Full;
|
||||||
|
use hyper::{
|
||||||
|
body::Incoming, server::conn::http1::Builder, service::service_fn, Request, Response,
|
||||||
|
StatusCode,
|
||||||
|
};
|
||||||
|
use hyper_util::rt::TokioIo;
|
||||||
|
use log::error;
|
||||||
|
use tokio::{
|
||||||
|
fs::{remove_file, try_exists},
|
||||||
|
io::AsyncReadExt,
|
||||||
|
net::{tcp, unix, TcpListener, TcpStream, UnixListener, UnixStream},
|
||||||
|
};
|
||||||
|
use tokio_util::{
|
||||||
|
codec::{FramedRead, FramedWrite, LengthDelimitedCodec},
|
||||||
|
either::Either,
|
||||||
|
};
|
||||||
|
use uuid::Uuid;
|
||||||
|
use wisp_mux::{
|
||||||
|
generic::{GenericWebSocketRead, GenericWebSocketWrite},
|
||||||
|
ws::{WebSocketRead, WebSocketWrite},
|
||||||
|
};
|
||||||
|
|
||||||
|
use crate::{
|
||||||
|
config::{SocketTransport, SocketType},
|
||||||
|
generate_stats,
|
||||||
|
stream::WebSocketStreamWrapper,
|
||||||
|
CONFIG,
|
||||||
|
};
|
||||||
|
|
||||||
|
pub type ServerStream = Either<TcpStream, UnixStream>;
|
||||||
|
pub type ServerStreamRead = Either<tcp::OwnedReadHalf, unix::OwnedReadHalf>;
|
||||||
|
pub type ServerStreamWrite = Either<tcp::OwnedWriteHalf, unix::OwnedWriteHalf>;
|
||||||
|
|
||||||
|
type Body = Full<Bytes>;
|
||||||
|
fn non_ws_resp() -> Response<Body> {
|
||||||
|
Response::builder()
|
||||||
|
.status(StatusCode::OK)
|
||||||
|
.body(Body::new(CONFIG.server.non_ws_response.as_bytes().into()))
|
||||||
|
.unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn ws_upgrade<T, R>(mut req: Request<Incoming>, callback: T) -> anyhow::Result<Response<Body>>
|
||||||
|
where
|
||||||
|
T: FnOnce(UpgradeFut, bool, bool, String) -> R,
|
||||||
|
R: Future<Output = anyhow::Result<()>>,
|
||||||
|
{
|
||||||
|
if CONFIG.server.enable_stats_endpoint && req.uri().path() == CONFIG.server.stats_endpoint {
|
||||||
|
match generate_stats() {
|
||||||
|
Ok(x) => {
|
||||||
|
return Ok(Response::builder()
|
||||||
|
.status(StatusCode::OK)
|
||||||
|
.body(Body::new(x.into()))
|
||||||
|
.unwrap())
|
||||||
|
}
|
||||||
|
Err(x) => {
|
||||||
|
return Ok(Response::builder()
|
||||||
|
.status(StatusCode::INTERNAL_SERVER_ERROR)
|
||||||
|
.body(Body::new(x.to_string().into()))
|
||||||
|
.unwrap())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if !fastwebsockets::upgrade::is_upgrade_request(&req) {
|
||||||
|
return Ok(non_ws_resp());
|
||||||
|
}
|
||||||
|
|
||||||
|
let (resp, fut) = fastwebsockets::upgrade::upgrade(&mut req)?;
|
||||||
|
// replace body of Empty<Bytes> with Full<Bytes>
|
||||||
|
let resp = Response::from_parts(resp.into_parts().0, Body::new(Bytes::new()));
|
||||||
|
|
||||||
|
if req
|
||||||
|
.uri()
|
||||||
|
.path()
|
||||||
|
.starts_with(&(CONFIG.server.prefix.clone() + "/"))
|
||||||
|
{
|
||||||
|
(callback)(fut, false, false, req.uri().path().to_string());
|
||||||
|
} else if CONFIG.wisp.allow_wsproxy {
|
||||||
|
let udp = req.uri().query().unwrap_or_default() == "?udp";
|
||||||
|
(callback)(fut, true, udp, req.uri().path().to_string());
|
||||||
|
} else {
|
||||||
|
return Ok(non_ws_resp());
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(resp)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub trait ServerStreamExt {
|
||||||
|
fn split(self) -> (ServerStreamRead, ServerStreamWrite);
|
||||||
|
async fn route(self, callback: impl FnOnce(ServerRouteResult) + Clone) -> anyhow::Result<()>;
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ServerStreamExt for ServerStream {
|
||||||
|
fn split(self) -> (ServerStreamRead, ServerStreamWrite) {
|
||||||
|
match self {
|
||||||
|
Self::Left(x) => {
|
||||||
|
let (r, w) = x.into_split();
|
||||||
|
(Either::Left(r), Either::Left(w))
|
||||||
|
}
|
||||||
|
Self::Right(x) => {
|
||||||
|
let (r, w) = x.into_split();
|
||||||
|
(Either::Right(r), Either::Right(w))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn route(self, callback: impl FnOnce(ServerRouteResult) + Clone) -> anyhow::Result<()> {
|
||||||
|
match CONFIG.server.transport {
|
||||||
|
SocketTransport::WebSocket => {
|
||||||
|
let stream = TokioIo::new(self);
|
||||||
|
|
||||||
|
let fut = Builder::new()
|
||||||
|
.serve_connection(
|
||||||
|
stream,
|
||||||
|
service_fn(move |req| {
|
||||||
|
let callback = callback.clone();
|
||||||
|
|
||||||
|
ws_upgrade(req, |fut, wsproxy, udp, path| async move {
|
||||||
|
let mut ws = fut.await.context("failed to await upgrade future")?;
|
||||||
|
ws.set_max_message_size(CONFIG.server.max_message_size);
|
||||||
|
|
||||||
|
if wsproxy {
|
||||||
|
let ws = WebSocketStreamWrapper(FragmentCollector::new(ws));
|
||||||
|
(callback)(ServerRouteResult::WsProxy(ws, path, udp));
|
||||||
|
} else {
|
||||||
|
let (read, write) = ws.split(|x| {
|
||||||
|
let parts = x
|
||||||
|
.into_inner()
|
||||||
|
.downcast::<TokioIo<ServerStream>>()
|
||||||
|
.unwrap();
|
||||||
|
let (r, w) = parts.io.into_inner().split();
|
||||||
|
(Cursor::new(parts.read_buf).chain(r), w)
|
||||||
|
});
|
||||||
|
|
||||||
|
(callback)(ServerRouteResult::Wisp((
|
||||||
|
Box::new(read),
|
||||||
|
Box::new(write),
|
||||||
|
)))
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
})
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
.with_upgrades();
|
||||||
|
|
||||||
|
if let Err(e) = fut.await {
|
||||||
|
error!("error while serving client: {:?}", e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
SocketTransport::LengthDelimitedLe => {
|
||||||
|
let codec = LengthDelimitedCodec::builder()
|
||||||
|
.little_endian()
|
||||||
|
.max_frame_length(usize::MAX)
|
||||||
|
.new_codec();
|
||||||
|
|
||||||
|
let (read, write) = self.split();
|
||||||
|
let read = GenericWebSocketRead::new(FramedRead::new(read, codec.clone()));
|
||||||
|
let write = GenericWebSocketWrite::new(FramedWrite::new(write, codec));
|
||||||
|
|
||||||
|
(callback)(ServerRouteResult::Wisp((Box::new(read), Box::new(write))));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub type WispResult = (
|
||||||
|
Box<dyn WebSocketRead + Send>,
|
||||||
|
Box<dyn WebSocketWrite + Send>,
|
||||||
|
);
|
||||||
|
|
||||||
|
pub enum ServerRouteResult {
|
||||||
|
Wisp(WispResult),
|
||||||
|
WsProxy(WebSocketStreamWrapper, String, bool),
|
||||||
|
}
|
||||||
|
|
||||||
|
pub enum ServerListener {
|
||||||
|
Tcp(TcpListener),
|
||||||
|
Unix(UnixListener),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ServerListener {
|
||||||
|
pub async fn new() -> anyhow::Result<Self> {
|
||||||
|
Ok(match CONFIG.server.socket {
|
||||||
|
SocketType::Tcp => Self::Tcp(
|
||||||
|
TcpListener::bind(&CONFIG.server.bind)
|
||||||
|
.await
|
||||||
|
.with_context(|| {
|
||||||
|
format!("failed to bind to tcp address `{}`", CONFIG.server.bind)
|
||||||
|
})?,
|
||||||
|
),
|
||||||
|
SocketType::Unix => {
|
||||||
|
if try_exists(&CONFIG.server.bind).await? {
|
||||||
|
remove_file(&CONFIG.server.bind).await?;
|
||||||
|
}
|
||||||
|
Self::Unix(UnixListener::bind(&CONFIG.server.bind).with_context(|| {
|
||||||
|
format!("failed to bind to unix socket at `{}`", CONFIG.server.bind)
|
||||||
|
})?)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn accept(&self) -> anyhow::Result<(ServerStream, String)> {
|
||||||
|
match self {
|
||||||
|
Self::Tcp(x) => {
|
||||||
|
let (stream, addr) = x
|
||||||
|
.accept()
|
||||||
|
.await
|
||||||
|
.context("failed to accept tcp connection")?;
|
||||||
|
if CONFIG.server.tcp_nodelay {
|
||||||
|
stream
|
||||||
|
.set_nodelay(true)
|
||||||
|
.context("failed to set tcp nodelay")?;
|
||||||
|
}
|
||||||
|
Ok((Either::Left(stream), addr.to_string()))
|
||||||
|
}
|
||||||
|
Self::Unix(x) => x
|
||||||
|
.accept()
|
||||||
|
.await
|
||||||
|
.map(|(x, y)| {
|
||||||
|
(
|
||||||
|
Either::Right(x),
|
||||||
|
y.as_pathname()
|
||||||
|
.and_then(|x| x.to_str())
|
||||||
|
.map(ToString::to_string)
|
||||||
|
.unwrap_or_else(|| Uuid::new_v4().to_string() + "-unix_socket"),
|
||||||
|
)
|
||||||
|
})
|
||||||
|
.context("failed to accept unix socket connection"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -2,26 +2,20 @@
|
||||||
|
|
||||||
use std::{fmt::Write, fs::read_to_string};
|
use std::{fmt::Write, fs::read_to_string};
|
||||||
|
|
||||||
use bytes::Bytes;
|
|
||||||
use clap::Parser;
|
use clap::Parser;
|
||||||
use config::{validate_config_cache, Cli, Config};
|
use config::{validate_config_cache, Cli, Config};
|
||||||
use dashmap::DashMap;
|
use dashmap::DashMap;
|
||||||
use handle::{handle_wisp, handle_wsproxy};
|
use handle::{handle_wisp, handle_wsproxy};
|
||||||
use http_body_util::Full;
|
|
||||||
use hyper::{
|
|
||||||
body::Incoming, server::conn::http1::Builder, service::service_fn, Request, Response,
|
|
||||||
StatusCode,
|
|
||||||
};
|
|
||||||
use hyper_util::rt::TokioIo;
|
|
||||||
use lazy_static::lazy_static;
|
use lazy_static::lazy_static;
|
||||||
|
use listener::{ServerListener, ServerRouteResult, ServerStreamExt};
|
||||||
use log::{error, info};
|
use log::{error, info};
|
||||||
use stream::ServerListener;
|
|
||||||
use tokio::signal::unix::{signal, SignalKind};
|
use tokio::signal::unix::{signal, SignalKind};
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
use wisp_mux::{ConnectPacket, StreamType};
|
use wisp_mux::{ConnectPacket, StreamType};
|
||||||
|
|
||||||
mod config;
|
mod config;
|
||||||
mod handle;
|
mod handle;
|
||||||
|
mod listener;
|
||||||
mod stream;
|
mod stream;
|
||||||
|
|
||||||
type Client = (DashMap<Uuid, (ConnectPacket, ConnectPacket)>, bool);
|
type Client = (DashMap<Uuid, (ConnectPacket, ConnectPacket)>, bool);
|
||||||
|
@ -38,67 +32,6 @@ lazy_static! {
|
||||||
pub static ref CLIENTS: DashMap<String, Client> = DashMap::new();
|
pub static ref CLIENTS: DashMap<String, Client> = DashMap::new();
|
||||||
}
|
}
|
||||||
|
|
||||||
type Body = Full<Bytes>;
|
|
||||||
fn non_ws_resp() -> Response<Body> {
|
|
||||||
Response::builder()
|
|
||||||
.status(StatusCode::OK)
|
|
||||||
.body(Body::new(CONFIG.server.non_ws_response.as_bytes().into()))
|
|
||||||
.unwrap()
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn upgrade(mut req: Request<Incoming>, id: String) -> anyhow::Result<Response<Body>> {
|
|
||||||
if CONFIG.server.enable_stats_endpoint && req.uri().path() == CONFIG.server.stats_endpoint {
|
|
||||||
match generate_stats() {
|
|
||||||
Ok(x) => {
|
|
||||||
return Ok(Response::builder()
|
|
||||||
.status(StatusCode::OK)
|
|
||||||
.body(Body::new(x.into()))
|
|
||||||
.unwrap())
|
|
||||||
}
|
|
||||||
Err(x) => {
|
|
||||||
return Ok(Response::builder()
|
|
||||||
.status(StatusCode::INTERNAL_SERVER_ERROR)
|
|
||||||
.body(Body::new(x.to_string().into()))
|
|
||||||
.unwrap())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else if !fastwebsockets::upgrade::is_upgrade_request(&req) {
|
|
||||||
return Ok(non_ws_resp());
|
|
||||||
}
|
|
||||||
|
|
||||||
let (resp, fut) = fastwebsockets::upgrade::upgrade(&mut req)?;
|
|
||||||
// replace body of Empty<Bytes> with Full<Bytes>
|
|
||||||
let resp = Response::from_parts(resp.into_parts().0, Body::new(Bytes::new()));
|
|
||||||
|
|
||||||
if req
|
|
||||||
.uri()
|
|
||||||
.path()
|
|
||||||
.starts_with(&(CONFIG.server.prefix.clone() + "/"))
|
|
||||||
{
|
|
||||||
tokio::spawn(async move {
|
|
||||||
CLIENTS.insert(id.clone(), (DashMap::new(), false));
|
|
||||||
if let Err(e) = handle_wisp(fut, id.clone()).await {
|
|
||||||
error!("error while handling upgraded client: {:?}", e);
|
|
||||||
};
|
|
||||||
CLIENTS.remove(&id)
|
|
||||||
});
|
|
||||||
} else if CONFIG.wisp.allow_wsproxy {
|
|
||||||
let udp = req.uri().query().unwrap_or_default() == "?udp";
|
|
||||||
tokio::spawn(async move {
|
|
||||||
CLIENTS.insert(id.clone(), (DashMap::new(), true));
|
|
||||||
if let Err(e) = handle_wsproxy(fut, id.clone(), req.uri().path().to_string(), udp).await
|
|
||||||
{
|
|
||||||
error!("error while handling upgraded client: {:?}", e);
|
|
||||||
};
|
|
||||||
CLIENTS.remove(&id)
|
|
||||||
});
|
|
||||||
} else {
|
|
||||||
return Ok(non_ws_resp());
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(resp)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn format_stream_type(stream_type: StreamType) -> &'static str {
|
fn format_stream_type(stream_type: StreamType) -> &'static str {
|
||||||
match stream_type {
|
match stream_type {
|
||||||
StreamType::Tcp => "tcp",
|
StreamType::Tcp => "tcp",
|
||||||
|
@ -159,6 +92,22 @@ fn generate_stats() -> Result<String, std::fmt::Error> {
|
||||||
Ok(out)
|
Ok(out)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn handle_stream(stream: ServerRouteResult, id: String) {
|
||||||
|
tokio::spawn(async move {
|
||||||
|
CLIENTS.insert(id.clone(), (DashMap::new(), false));
|
||||||
|
let res = match stream {
|
||||||
|
ServerRouteResult::Wisp(stream) => handle_wisp(stream, id.clone()).await,
|
||||||
|
ServerRouteResult::WsProxy(ws, path, udp) => {
|
||||||
|
handle_wsproxy(ws, id.clone(), path, udp).await
|
||||||
|
}
|
||||||
|
};
|
||||||
|
if let Err(e) = res {
|
||||||
|
error!("error while handling client: {:?}", e);
|
||||||
|
}
|
||||||
|
CLIENTS.remove(&id)
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
#[tokio::main(flavor = "multi_thread")]
|
#[tokio::main(flavor = "multi_thread")]
|
||||||
async fn main() -> anyhow::Result<()> {
|
async fn main() -> anyhow::Result<()> {
|
||||||
if CLI.default_config {
|
if CLI.default_config {
|
||||||
|
@ -174,8 +123,8 @@ async fn main() -> anyhow::Result<()> {
|
||||||
validate_config_cache();
|
validate_config_cache();
|
||||||
|
|
||||||
info!(
|
info!(
|
||||||
"listening on {:?} with socket type {:?}",
|
"listening on {:?} with socket type {:?} and socket transport {:?}",
|
||||||
CONFIG.server.bind, CONFIG.server.socket
|
CONFIG.server.bind, CONFIG.server.socket, CONFIG.server.transport
|
||||||
);
|
);
|
||||||
|
|
||||||
tokio::spawn(async {
|
tokio::spawn(async {
|
||||||
|
@ -189,14 +138,10 @@ async fn main() -> anyhow::Result<()> {
|
||||||
loop {
|
loop {
|
||||||
let (stream, id) = listener.accept().await?;
|
let (stream, id) = listener.accept().await?;
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
let stream = TokioIo::new(stream);
|
let res = stream.route(move |stream| handle_stream(stream, id)).await;
|
||||||
|
|
||||||
let fut = Builder::new()
|
if let Err(e) = res {
|
||||||
.serve_connection(stream, service_fn(|req| upgrade(req, id.clone())))
|
error!("error while routing client: {:?}", e);
|
||||||
.with_upgrades();
|
|
||||||
|
|
||||||
if let Err(e) = fut.await {
|
|
||||||
error!("error while serving client: {:?}", e);
|
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
|
@ -9,95 +9,10 @@ use fastwebsockets::{FragmentCollector, Frame, OpCode, Payload, WebSocketError};
|
||||||
use hyper::upgrade::Upgraded;
|
use hyper::upgrade::Upgraded;
|
||||||
use hyper_util::rt::TokioIo;
|
use hyper_util::rt::TokioIo;
|
||||||
use regex::RegexSet;
|
use regex::RegexSet;
|
||||||
use tokio::{
|
use tokio::net::{lookup_host, TcpStream, UdpSocket};
|
||||||
fs::{remove_file, try_exists},
|
|
||||||
net::{lookup_host, tcp, unix, TcpListener, TcpStream, UdpSocket, UnixListener, UnixStream},
|
|
||||||
};
|
|
||||||
use tokio_util::either::Either;
|
|
||||||
use uuid::Uuid;
|
|
||||||
use wisp_mux::{ConnectPacket, StreamType};
|
use wisp_mux::{ConnectPacket, StreamType};
|
||||||
|
|
||||||
use crate::{config::SocketType, CONFIG};
|
use crate::CONFIG;
|
||||||
|
|
||||||
pub enum ServerListener {
|
|
||||||
Tcp(TcpListener),
|
|
||||||
Unix(UnixListener),
|
|
||||||
}
|
|
||||||
|
|
||||||
pub type ServerStream = Either<TcpStream, UnixStream>;
|
|
||||||
pub type ServerStreamRead = Either<tcp::OwnedReadHalf, unix::OwnedReadHalf>;
|
|
||||||
pub type ServerStreamWrite = Either<tcp::OwnedWriteHalf, unix::OwnedWriteHalf>;
|
|
||||||
|
|
||||||
pub trait ServerStreamExt {
|
|
||||||
fn split(self) -> (ServerStreamRead, ServerStreamWrite);
|
|
||||||
}
|
|
||||||
|
|
||||||
impl ServerStreamExt for ServerStream {
|
|
||||||
fn split(self) -> (ServerStreamRead, ServerStreamWrite) {
|
|
||||||
match self {
|
|
||||||
Self::Left(x) => {
|
|
||||||
let (r, w) = x.into_split();
|
|
||||||
(Either::Left(r), Either::Left(w))
|
|
||||||
}
|
|
||||||
Self::Right(x) => {
|
|
||||||
let (r, w) = x.into_split();
|
|
||||||
(Either::Right(r), Either::Right(w))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl ServerListener {
|
|
||||||
pub async fn new() -> anyhow::Result<Self> {
|
|
||||||
Ok(match CONFIG.server.socket {
|
|
||||||
SocketType::Tcp => Self::Tcp(
|
|
||||||
TcpListener::bind(&CONFIG.server.bind)
|
|
||||||
.await
|
|
||||||
.with_context(|| {
|
|
||||||
format!("failed to bind to tcp address `{}`", CONFIG.server.bind)
|
|
||||||
})?,
|
|
||||||
),
|
|
||||||
SocketType::Unix => {
|
|
||||||
if try_exists(&CONFIG.server.bind).await? {
|
|
||||||
remove_file(&CONFIG.server.bind).await?;
|
|
||||||
}
|
|
||||||
Self::Unix(UnixListener::bind(&CONFIG.server.bind).with_context(|| {
|
|
||||||
format!("failed to bind to unix socket at `{}`", CONFIG.server.bind)
|
|
||||||
})?)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn accept(&self) -> anyhow::Result<(ServerStream, String)> {
|
|
||||||
match self {
|
|
||||||
Self::Tcp(x) => {
|
|
||||||
let (stream, addr) = x
|
|
||||||
.accept()
|
|
||||||
.await
|
|
||||||
.context("failed to accept tcp connection")?;
|
|
||||||
if CONFIG.server.tcp_nodelay {
|
|
||||||
stream
|
|
||||||
.set_nodelay(true)
|
|
||||||
.context("failed to set tcp nodelay")?;
|
|
||||||
}
|
|
||||||
Ok((Either::Left(stream), addr.to_string()))
|
|
||||||
}
|
|
||||||
Self::Unix(x) => x
|
|
||||||
.accept()
|
|
||||||
.await
|
|
||||||
.map(|(x, y)| {
|
|
||||||
(
|
|
||||||
Either::Right(x),
|
|
||||||
y.as_pathname()
|
|
||||||
.and_then(|x| x.to_str())
|
|
||||||
.map(ToString::to_string)
|
|
||||||
.unwrap_or_else(|| Uuid::new_v4().to_string() + "-unix_socket"),
|
|
||||||
)
|
|
||||||
})
|
|
||||||
.context("failed to accept unix socket connection"),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn match_addr(str: &str, allowed: &RegexSet, blocked: &RegexSet) -> bool {
|
fn match_addr(str: &str, allowed: &RegexSet, blocked: &RegexSet) -> bool {
|
||||||
blocked.is_match(str) && !allowed.is_match(str)
|
blocked.is_match(str) && !allowed.is_match(str)
|
||||||
|
|
|
@ -22,7 +22,9 @@ pin-project-lite = "0.2.14"
|
||||||
tokio = { version = "1.39.3", optional = true, default-features = false }
|
tokio = { version = "1.39.3", optional = true, default-features = false }
|
||||||
|
|
||||||
[features]
|
[features]
|
||||||
|
default = ["generic_stream"]
|
||||||
fastwebsockets = ["dep:fastwebsockets", "dep:tokio"]
|
fastwebsockets = ["dep:fastwebsockets", "dep:tokio"]
|
||||||
|
generic_stream = []
|
||||||
wasm = ["futures-timer/wasm-bindgen"]
|
wasm = ["futures-timer/wasm-bindgen"]
|
||||||
|
|
||||||
[package.metadata.docs.rs]
|
[package.metadata.docs.rs]
|
||||||
|
|
|
@ -1,3 +1,5 @@
|
||||||
|
//! WebSocketRead + WebSocketWrite implementation for the fastwebsockets library.
|
||||||
|
|
||||||
use std::ops::Deref;
|
use std::ops::Deref;
|
||||||
|
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
|
|
87
wisp/src/generic.rs
Normal file
87
wisp/src/generic.rs
Normal file
|
@ -0,0 +1,87 @@
|
||||||
|
//! WebSocketRead + WebSocketWrite implementation for generic `Stream + Sink`s.
|
||||||
|
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use bytes::{Bytes, BytesMut};
|
||||||
|
use futures::{Sink, SinkExt, Stream, StreamExt};
|
||||||
|
use std::error::Error;
|
||||||
|
|
||||||
|
use crate::{
|
||||||
|
ws::{Frame, LockedWebSocketWrite, Payload, WebSocketRead, WebSocketWrite},
|
||||||
|
WispError,
|
||||||
|
};
|
||||||
|
|
||||||
|
/// WebSocketRead implementation for generic `Stream`s.
|
||||||
|
pub struct GenericWebSocketRead<
|
||||||
|
T: Stream<Item = Result<BytesMut, E>> + Send + Unpin,
|
||||||
|
E: Error + Sync + Send + 'static,
|
||||||
|
>(T);
|
||||||
|
|
||||||
|
impl<T: Stream<Item = Result<BytesMut, E>> + Send + Unpin, E: Error + Sync + Send + 'static>
|
||||||
|
GenericWebSocketRead<T, E>
|
||||||
|
{
|
||||||
|
/// Create a new wrapper WebSocketRead implementation.
|
||||||
|
pub fn new(stream: T) -> Self {
|
||||||
|
Self(stream)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get the inner Stream from the wrapper.
|
||||||
|
pub fn into_inner(self) -> T {
|
||||||
|
self.0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl<T: Stream<Item = Result<BytesMut, E>> + Send + Unpin, E: Error + Sync + Send + 'static>
|
||||||
|
WebSocketRead for GenericWebSocketRead<T, E>
|
||||||
|
{
|
||||||
|
async fn wisp_read_frame(
|
||||||
|
&mut self,
|
||||||
|
_tx: &LockedWebSocketWrite,
|
||||||
|
) -> Result<Frame<'static>, WispError> {
|
||||||
|
match self.0.next().await {
|
||||||
|
Some(data) => Ok(Frame::binary(Payload::Bytes(
|
||||||
|
data.map_err(|x| WispError::WsImplError(Box::new(x)))?,
|
||||||
|
))),
|
||||||
|
None => Ok(Frame::close(Payload::Bytes(BytesMut::new()))),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// WebSocketWrite implementation for generic `Sink`s.
|
||||||
|
pub struct GenericWebSocketWrite<
|
||||||
|
T: Sink<Bytes, Error = E> + Send + Unpin,
|
||||||
|
E: Error + Sync + Send + 'static,
|
||||||
|
>(T);
|
||||||
|
|
||||||
|
impl<T: Sink<Bytes, Error = E> + Send + Unpin, E: Error + Sync + Send + 'static>
|
||||||
|
GenericWebSocketWrite<T, E>
|
||||||
|
{
|
||||||
|
/// Create a new wrapper WebSocketWrite implementation.
|
||||||
|
pub fn new(stream: T) -> Self {
|
||||||
|
Self(stream)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get the inner Sink from the wrapper.
|
||||||
|
pub fn into_inner(self) -> T {
|
||||||
|
self.0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl<T: Sink<Bytes, Error = E> + Send + Unpin, E: Error + Sync + Send + 'static> WebSocketWrite
|
||||||
|
for GenericWebSocketWrite<T, E>
|
||||||
|
{
|
||||||
|
async fn wisp_write_frame(&mut self, frame: Frame<'_>) -> Result<(), WispError> {
|
||||||
|
self.0
|
||||||
|
.send(BytesMut::from(frame.payload).freeze())
|
||||||
|
.await
|
||||||
|
.map_err(|x| WispError::WsImplError(Box::new(x)))
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn wisp_close(&mut self) -> Result<(), WispError> {
|
||||||
|
self.0
|
||||||
|
.close()
|
||||||
|
.await
|
||||||
|
.map_err(|x| WispError::WsImplError(Box::new(x)))
|
||||||
|
}
|
||||||
|
}
|
|
@ -8,6 +8,9 @@ pub mod extensions;
|
||||||
#[cfg(feature = "fastwebsockets")]
|
#[cfg(feature = "fastwebsockets")]
|
||||||
#[cfg_attr(docsrs, doc(cfg(feature = "fastwebsockets")))]
|
#[cfg_attr(docsrs, doc(cfg(feature = "fastwebsockets")))]
|
||||||
mod fastwebsockets;
|
mod fastwebsockets;
|
||||||
|
#[cfg(feature = "generic_stream")]
|
||||||
|
#[cfg_attr(docsrs, doc(cfg(feature = "generic_stream")))]
|
||||||
|
pub mod generic;
|
||||||
mod packet;
|
mod packet;
|
||||||
mod sink_unfold;
|
mod sink_unfold;
|
||||||
mod stream;
|
mod stream;
|
||||||
|
|
|
@ -167,7 +167,7 @@ pub trait WebSocketRead {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
impl WebSocketRead for Box<dyn WebSocketRead + Send + Sync> {
|
impl WebSocketRead for Box<dyn WebSocketRead + Send> {
|
||||||
async fn wisp_read_frame(
|
async fn wisp_read_frame(
|
||||||
&mut self,
|
&mut self,
|
||||||
tx: &LockedWebSocketWrite,
|
tx: &LockedWebSocketWrite,
|
||||||
|
@ -206,7 +206,7 @@ pub trait WebSocketWrite {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
impl WebSocketWrite for Box<dyn WebSocketWrite + Send + Sync> {
|
impl WebSocketWrite for Box<dyn WebSocketWrite + Send> {
|
||||||
async fn wisp_write_frame(&mut self, frame: Frame<'_>) -> Result<(), WispError> {
|
async fn wisp_write_frame(&mut self, frame: Frame<'_>) -> Result<(), WispError> {
|
||||||
self.as_mut().wisp_write_frame(frame).await
|
self.as_mut().wisp_write_frame(frame).await
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue