This commit is contained in:
Toshit Chawda 2024-11-24 21:10:08 -08:00
parent 3cef68d164
commit 13f282160b
No known key found for this signature in database
GPG key ID: 91480ED99E2B3D9D
17 changed files with 482 additions and 65 deletions

View file

@ -6,7 +6,8 @@ edition = "2021"
[dependencies]
anyhow = "1.0.86"
async-speed-limit = { version = "0.4.2", optional = true }
async-trait = { version = "0.1.81", optional = true }
async-trait = "0.1.81"
base64 = "0.22.1"
bytes = "1.7.1"
cfg-if = "1.0.0"
clap = { version = "4.5.16", features = ["cargo", "derive"] }
@ -48,7 +49,7 @@ default = ["toml"]
yaml = ["dep:serde_yaml"]
toml = ["dep:toml"]
twisp = ["dep:pty-process", "dep:libc", "dep:async-trait", "dep:shell-words"]
twisp = ["dep:pty-process", "dep:libc", "dep:shell-words"]
speed-limit = ["dep:async-speed-limit"]
tokio-console = ["dep:console-subscriber", "tokio/tracing"]

View file

@ -123,6 +123,8 @@ pub enum ProtocolExtension {
Udp,
/// Wisp version 2 MOTD protocol extension.
Motd,
/// Unofficial Wispnet-like protocol extension.
Wispnet,
}
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq)]
@ -366,6 +368,11 @@ impl Default for WispConfig {
}
impl WispConfig {
#[doc(hidden)]
pub fn has_wispnet(&self) -> bool {
self.extensions.contains(&ProtocolExtension::Wispnet)
}
#[doc(hidden)]
pub async fn to_opts(&self) -> anyhow::Result<(Option<WispV2Handshake>, Vec<u8>, u32)> {
if self.wisp_v2 {

View file

@ -1,6 +1,7 @@
#[cfg(feature = "twisp")]
pub mod twisp;
pub mod utils;
pub mod wispnet;
use std::{sync::Arc, time::Duration};
@ -23,6 +24,7 @@ use wisp_mux::{
ws::Payload, CloseReason, ConnectPacket, MuxStream, MuxStreamAsyncRead, MuxStreamWrite,
ServerMux,
};
use wispnet::route_wispnet;
use crate::{
route::{WispResult, WispStreamWrite},
@ -100,8 +102,23 @@ async fn handle_stream(
let _ = muxstream.close(CloseReason::ServerStreamUnreachable).await;
return;
};
let connect = match resolved {
ResolvedPacket::Valid(x) => x,
let (stream, resolved_stream) = match resolved {
ResolvedPacket::Valid(connect) => {
let resolved = connect.clone();
let Ok(stream) = ClientStream::connect(connect).await else {
let _ = muxstream.close(CloseReason::ServerStreamUnreachable).await;
return;
};
(stream, resolved)
}
ResolvedPacket::ValidWispnet(server, connect) => {
let resolved = connect.clone();
let Ok(stream) = route_wispnet(server, connect).await else {
let _ = muxstream.close(CloseReason::ServerStreamUnreachable).await;
return;
};
(stream, resolved)
}
ResolvedPacket::NoResolvedAddrs => {
let _ = muxstream.close(CloseReason::ServerStreamUnreachable).await;
return;
@ -118,13 +135,6 @@ async fn handle_stream(
}
};
let resolved_stream = connect.clone();
let Ok(stream) = ClientStream::connect(connect).await else {
let _ = muxstream.close(CloseReason::ServerStreamUnreachable).await;
return;
};
let uuid = Uuid::new_v4();
debug!(
@ -137,7 +147,7 @@ async fn handle_stream(
.0
.lock()
.await
.insert(uuid, (requested_stream, resolved_stream));
.insert(uuid, (requested_stream, resolved_stream.clone()));
}
let forward_fut = async {
@ -213,6 +223,12 @@ async fn handle_stream(
}
}
}
ClientStream::Wispnet(stream, mux_id) => {
wispnet::handle_stream(muxstream, stream, mux_id, uuid, resolved_stream).await
}
ClientStream::NoResolvedAddrs => {
let _ = muxstream.close(CloseReason::ServerStreamUnreachable).await;
}
ClientStream::Invalid => {
let _ = muxstream.close(CloseReason::ServerStreamInvalidInfo).await;
}

View file

@ -0,0 +1,244 @@
use std::{
collections::HashMap,
sync::atomic::{AtomicU32, Ordering},
};
use anyhow::{Context, Result};
use async_trait::async_trait;
use bytes::{Buf, BufMut, Bytes, BytesMut};
use lazy_static::lazy_static;
use log::debug;
use tokio::{select, sync::Mutex};
use uuid::Uuid;
use wisp_mux::{
extensions::{
AnyProtocolExtension, ProtocolExtension, ProtocolExtensionBuilder, ProtocolExtensionVecExt,
},
ws::{DynWebSocketRead, Frame, LockingWebSocketWrite, Payload},
ClientMux, CloseReason, ConnectPacket, MuxStream, MuxStreamRead, MuxStreamWrite, Role,
WispError, WispV2Handshake,
};
use crate::{
route::{WispResult, WispStreamWrite},
stream::ClientStream,
CLIENTS,
};
struct WispnetClient {
mux: ClientMux<WispStreamWrite>,
id: String,
private: bool,
}
lazy_static! {
static ref WISPNET_SERVERS: Mutex<HashMap<u32, WispnetClient>> = Mutex::new(HashMap::new());
static ref WISPNET_IDS: AtomicU32 = AtomicU32::new(0);
}
// payload:
// client (acting like wisp server) sends a bool saying if it wants to be private or not
// server (acting like wisp client) sends a u32 client id
//
// packets:
// client sends a 0xF1 on stream id 0 with no body to probe
// server sends back a 0xF1 on stream id 0 with a body of a bunch of u32s for each public client
struct WispnetServerProtocolExtensionBuilder(u32);
impl WispnetServerProtocolExtensionBuilder {
const ID: u8 = 0xF1;
}
#[async_trait]
impl ProtocolExtensionBuilder for WispnetServerProtocolExtensionBuilder {
fn get_id(&self) -> u8 {
Self::ID
}
fn build_from_bytes(
&mut self,
mut bytes: Bytes,
_: Role,
) -> Result<AnyProtocolExtension, WispError> {
if bytes.remaining() < 1 {
return Err(WispError::PacketTooSmall);
};
Ok(WispnetServerProtocolExtension(self.0, bytes.get_u8() != 0).into())
}
fn build_to_extension(&mut self, _: Role) -> Result<AnyProtocolExtension, WispError> {
Ok(WispnetServerProtocolExtension(self.0, false).into())
}
}
#[derive(Debug, Copy, Clone)]
struct WispnetServerProtocolExtension(u32, pub bool);
impl WispnetServerProtocolExtension {
const ID: u8 = 0xF1;
}
#[async_trait]
impl ProtocolExtension for WispnetServerProtocolExtension {
fn get_id(&self) -> u8 {
Self::ID
}
fn get_supported_packets(&self) -> &'static [u8] {
&[Self::ID]
}
fn get_congestion_stream_types(&self) -> &'static [u8] {
&[]
}
fn encode(&self) -> Bytes {
let mut out = BytesMut::new();
out.put_u32_le(self.0);
out.freeze()
}
async fn handle_handshake(
&mut self,
_: &mut DynWebSocketRead,
_: &dyn LockingWebSocketWrite,
) -> Result<(), WispError> {
Ok(())
}
async fn handle_packet(
&mut self,
packet_type: u8,
mut packet: Bytes,
_: &mut DynWebSocketRead,
write: &dyn LockingWebSocketWrite,
) -> Result<(), WispError> {
if packet_type == Self::ID {
if packet.remaining() < 4 {
return Err(WispError::PacketTooSmall);
}
if packet.get_u32_le() != 0 {
return Err(WispError::InvalidStreamId);
}
let mut out = BytesMut::new();
out.put_u8(Self::ID);
out.put_u32_le(0);
let locked = WISPNET_SERVERS.lock().await;
for client in locked.iter() {
if !client.1.private {
out.put_u32_le(*client.0);
}
}
drop(locked);
write
.wisp_write_frame(Frame::binary(Payload::Bytes(out)))
.await?;
}
Ok(())
}
fn box_clone(&self) -> Box<dyn ProtocolExtension + Sync + Send> {
Box::new(*self)
}
}
pub async fn route_wispnet(server: u32, packet: ConnectPacket) -> Result<ClientStream> {
if let Some(server) = WISPNET_SERVERS.lock().await.get(&server) {
let stream = server
.mux
.client_new_stream(
packet.stream_type,
packet.destination_hostname,
packet.destination_port,
)
.await
.context("failed to connect to wispnet server")?;
Ok(ClientStream::Wispnet(stream, server.id.clone()))
} else {
Ok(ClientStream::NoResolvedAddrs)
}
}
async fn copy_wisp(
rx: MuxStreamRead<WispStreamWrite>,
tx: MuxStreamWrite<WispStreamWrite>,
#[cfg(feature = "speed-limit")] limiter: async_speed_limit::Limiter<
async_speed_limit::clock::StandardClock,
>,
) -> Result<()> {
while let Some(data) = rx.read().await? {
#[cfg(feature = "speed-limit")]
limiter.consume(data.len()).await;
tx.write_payload(Payload::Borrowed(data.as_ref())).await?;
}
Ok(())
}
pub async fn handle_stream(
mux: MuxStream<WispStreamWrite>,
wisp: MuxStream<WispStreamWrite>,
mux_id: String,
uuid: Uuid,
resolved_stream: ConnectPacket,
) {
if let Some(client) = CLIENTS.lock().await.get(&mux_id) {
client
.0
.lock()
.await
.insert(uuid, (resolved_stream.clone(), resolved_stream));
}
let closer = mux.get_close_handle();
let (muxread, muxwrite) = mux.into_split();
let (wispread, wispwrite) = wisp.into_split();
let _ = select! {
x = copy_wisp(muxread, wispwrite, #[cfg(feature = "speed-limit")] write_limit) => x,
x = copy_wisp(wispread, muxwrite, #[cfg(feature = "speed-limit")] read_limit) => x,
};
let _ = closer
.close(closer.get_close_reason().unwrap_or(CloseReason::Unknown))
.await;
if let Some(client) = CLIENTS.lock().await.get(&mux_id) {
client.0.lock().await.remove(&uuid);
}
}
pub async fn handle_wispnet(stream: WispResult, id: String) -> Result<()> {
let (read, write) = stream;
let net_id = WISPNET_IDS.fetch_add(1, Ordering::SeqCst);
let extensions = vec![WispnetServerProtocolExtensionBuilder(net_id).into()];
let (mux, fut) = ClientMux::create(read, write, Some(WispV2Handshake::new(extensions)))
.await
.context("failed to create client multiplexor")?
.with_required_extensions(&[WispnetServerProtocolExtension::ID])
.await
.context("wispnet client did not have wispnet extension")?;
let is_private = mux
.supported_extensions
.find_extension::<WispnetServerProtocolExtension>()
.context("failed to find wispnet extension")?
.1;
WISPNET_SERVERS.lock().await.insert(
net_id,
WispnetClient {
mux,
id: id.clone(),
private: is_private,
},
);
// probably the only time someone would do this
let ret = fut.await;
debug!("wispnet client id {:?} multiplexor result {:?}", id, ret);
WISPNET_SERVERS.lock().await.remove(&net_id);
Ok(())
}

View file

@ -7,9 +7,10 @@ use tokio::{
select,
};
use uuid::Uuid;
use wisp_mux::{ConnectPacket, StreamType};
use wisp_mux::{ws::Payload, CloseReason, ConnectPacket, StreamType};
use crate::{
handle::wisp::wispnet::route_wispnet,
stream::{ClientStream, ResolvedPacket, WebSocketFrame, WebSocketStreamWrapper},
CLIENTS, CONFIG,
};
@ -48,8 +49,27 @@ pub async fn handle_wsproxy(
.await;
return Ok(());
};
let connect = match resolved {
ResolvedPacket::Valid(x) => x,
let (stream, resolved_stream) = match resolved {
ResolvedPacket::Valid(connect) => {
let resolved = connect.clone();
let Ok(stream) = ClientStream::connect(connect).await else {
let _ = ws
.close(CloseCode::Error.into(), b"failed to connect to host")
.await;
return Ok(());
};
(stream, resolved)
}
ResolvedPacket::ValidWispnet(server, connect) => {
let resolved = connect.clone();
let Ok(stream) = route_wispnet(server, connect).await else {
let _ = ws
.close(CloseCode::Error.into(), b"failed to connect to host")
.await;
return Ok(());
};
(stream, resolved)
}
ResolvedPacket::NoResolvedAddrs => {
let _ = ws
.close(
@ -74,15 +94,6 @@ pub async fn handle_wsproxy(
}
};
let resolved_stream = connect.clone();
let Ok(stream) = ClientStream::connect(connect).await else {
let _ = ws
.close(CloseCode::Error.into(), b"failed to connect to host")
.await;
return Ok(());
};
let uuid = Uuid::new_v4();
debug!(
@ -95,7 +106,7 @@ pub async fn handle_wsproxy(
.0
.lock()
.await
.insert(uuid, (requested_stream, resolved_stream));
.insert(uuid, (requested_stream, resolved_stream.clone()));
}
match stream {
@ -173,6 +184,65 @@ pub async fn handle_wsproxy(
.close(CloseCode::Error.into(), b"twisp is not supported")
.await;
}
ClientStream::Wispnet(stream, mux_id) => {
if let Some(client) = CLIENTS.lock().await.get(&mux_id) {
client
.0
.lock()
.await
.insert(uuid, (resolved_stream.clone(), resolved_stream));
}
let ret: anyhow::Result<()> = async {
loop {
select! {
x = ws.read() => {
match x? {
WebSocketFrame::Data(data) => {
stream.write_payload(Payload::Bytes(data)).await?;
}
WebSocketFrame::Close => {
stream.close(CloseReason::Voluntary).await?;
}
WebSocketFrame::Ignore => {}
}
}
x = stream.read() => {
let Some(x) = x? else {
break;
};
ws.write(&x).await?;
}
}
}
Ok(())
}
.await;
if let Some(client) = CLIENTS.lock().await.get(&mux_id) {
client.0.lock().await.remove(&uuid);
}
match ret {
Ok(_) => {
let _ = ws.close(CloseCode::Normal.into(), b"").await;
}
Err(x) => {
let _ = ws
.close(CloseCode::Normal.into(), x.to_string().as_bytes())
.await;
}
}
}
ClientStream::NoResolvedAddrs => {
let _ = ws
.close(
CloseCode::Error.into(),
b"host did not resolve to any addrs",
)
.await;
return Ok(());
}
ClientStream::Blocked => {
let _ = ws.close(CloseCode::Error.into(), b"host is blocked").await;
}

View file

@ -1,13 +1,14 @@
#![doc(html_no_source)]
#![deny(clippy::todo)]
#![allow(unexpected_cfgs)]
#![warn(clippy::large_futures)]
use std::{collections::HashMap, fs::read_to_string, net::IpAddr};
use anyhow::{Context, Result};
use clap::Parser;
use config::{validate_config_cache, Cli, Config, RuntimeFlavor};
use handle::{handle_wisp, handle_wsproxy};
use handle::{handle_wisp, handle_wsproxy, wisp::wispnet::handle_wispnet};
use hickory_resolver::{
config::{NameServerConfigGroup, ResolverConfig, ResolverOpts},
system_conf::read_system_conf,
@ -41,7 +42,7 @@ mod stream;
mod util_chain;
#[doc(hidden)]
type Client = (Mutex<HashMap<Uuid, (ConnectPacket, ConnectPacket)>>, bool);
type Client = (Mutex<HashMap<Uuid, (ConnectPacket, ConnectPacket)>>, String);
#[doc(hidden)]
#[derive(Debug)]
@ -234,14 +235,18 @@ async fn async_main() -> Result<()> {
#[doc(hidden)]
fn handle_stream(stream: ServerRouteResult, id: String) {
tokio::spawn(async move {
CLIENTS
.lock()
.await
.insert(id.clone(), (Mutex::new(HashMap::new()), false));
CLIENTS.lock().await.insert(
id.clone(),
(Mutex::new(HashMap::new()), format!("{}", stream)),
);
let res = match stream {
ServerRouteResult::Wisp(stream, is_v2) => handle_wisp(stream, is_v2, id.clone()).await,
ServerRouteResult::WsProxy(ws, path, udp) => {
handle_wsproxy(ws, id.clone(), path, udp).await
ServerRouteResult::Wisp {
stream,
has_ws_protocol,
} => handle_wisp(stream, has_ws_protocol, id.clone()).await,
ServerRouteResult::Wispnet { stream } => handle_wispnet(stream, id.clone()).await,
ServerRouteResult::WsProxy { stream, path, udp } => {
handle_wsproxy(stream, id.clone(), path, udp).await
}
};
if let Err(e) = res {

View file

@ -36,15 +36,26 @@ pub type WispStreamWrite = EitherWebSocketWrite<
pub type WispResult = (WispStreamRead, WispStreamWrite);
pub enum ServerRouteResult {
Wisp(WispResult, bool),
WsProxy(WebSocketStreamWrapper, String, bool),
Wisp {
stream: WispResult,
has_ws_protocol: bool,
},
Wispnet {
stream: WispResult,
},
WsProxy {
stream: WebSocketStreamWrapper,
path: String,
udp: bool,
},
}
impl Display for ServerRouteResult {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Wisp(..) => write!(f, "Wisp"),
Self::WsProxy(_, path, udp) => write!(f, "WsProxy path {:?} udp {:?}", path, udp),
Self::Wisp { .. } => write!(f, "Wisp"),
Self::Wispnet { .. } => write!(f, "Wispnet"),
Self::WsProxy { path, udp, .. } => write!(f, "WsProxy path {:?} udp {:?}", path, udp),
}
}
}
@ -81,8 +92,14 @@ fn get_header(headers: &HeaderMap, header: &str) -> Option<String> {
}
enum HttpUpgradeResult {
Wisp(bool),
WsProxy(String, bool),
Wisp {
has_ws_protocol: bool,
is_wispnet: bool,
},
WsProxy {
path: String,
udp: bool,
},
}
async fn ws_upgrade<F, R>(
@ -126,11 +143,19 @@ where
let ws_protocol = headers.get(SEC_WEBSOCKET_PROTOCOL);
let req_path = req.uri().path().to_string();
if req_path.starts_with(&(CONFIG.wisp.prefix.clone() + "/")) {
if req_path.ends_with(&(CONFIG.wisp.prefix.clone() + "/")) {
let has_ws_protocol = ws_protocol.is_some();
let is_wispnet = CONFIG.wisp.has_wispnet() && req.uri().query().unwrap_or_default() == "net";
tokio::spawn(async move {
if let Err(err) =
(callback)(fut, HttpUpgradeResult::Wisp(has_ws_protocol), ip_header).await
if let Err(err) = (callback)(
fut,
HttpUpgradeResult::Wisp {
has_ws_protocol,
is_wispnet,
},
ip_header,
)
.await
{
error!("error while serving client: {:?}", err);
}
@ -140,10 +165,17 @@ where
.append(SEC_WEBSOCKET_PROTOCOL, protocol.clone());
}
} else if CONFIG.wisp.allow_wsproxy {
let udp = req.uri().query().unwrap_or_default() == "?udp";
let udp = req.uri().query().unwrap_or_default() == "udp";
tokio::spawn(async move {
if let Err(err) =
(callback)(fut, HttpUpgradeResult::WsProxy(req_path, udp), ip_header).await
if let Err(err) = (callback)(
fut,
HttpUpgradeResult::WsProxy {
path: req_path,
udp,
},
ip_header,
)
.await
{
error!("error while serving client: {:?}", err);
}
@ -188,7 +220,10 @@ pub async fn route(
ws.set_auto_pong(false);
match res {
HttpUpgradeResult::Wisp(is_v2) => {
HttpUpgradeResult::Wisp {
has_ws_protocol,
is_wispnet,
} => {
let (read, write) = ws.split(|x| {
let parts = x
.into_inner()
@ -198,21 +233,33 @@ pub async fn route(
(chain(Cursor::new(parts.read_buf), r), w)
});
(callback)(
ServerRouteResult::Wisp(
(
let result = if is_wispnet {
ServerRouteResult::Wispnet {
stream: (
EitherWebSocketRead::Left(read),
EitherWebSocketWrite::Left(write),
),
is_v2,
),
maybe_ip,
)
}
} else {
ServerRouteResult::Wisp {
stream: (
EitherWebSocketRead::Left(read),
EitherWebSocketWrite::Left(write),
),
has_ws_protocol,
}
};
(callback)(result, maybe_ip)
}
HttpUpgradeResult::WsProxy(path, udp) => {
HttpUpgradeResult::WsProxy { path, udp } => {
let ws = WebSocketStreamWrapper(FragmentCollector::new(ws));
(callback)(
ServerRouteResult::WsProxy(ws, path, udp),
ServerRouteResult::WsProxy {
stream: ws,
path,
udp,
},
maybe_ip,
);
}
@ -237,13 +284,13 @@ pub async fn route(
let write = GenericWebSocketWrite::new(FramedWrite::new(write, codec));
(callback)(
ServerRouteResult::Wisp(
(
ServerRouteResult::Wisp {
stream: (
EitherWebSocketRead::Right(read),
EitherWebSocketWrite::Right(write),
),
true,
),
has_ws_protocol: true,
},
None,
);
}

View file

@ -50,7 +50,7 @@ impl From<(ConnectPacket, ConnectPacket)> for StreamStats {
#[derive(Serialize)]
struct ClientStats {
wsproxy: bool,
client_type: String,
streams: HashMap<String, StreamStats>,
}
@ -81,7 +81,7 @@ pub async fn generate_stats() -> anyhow::Result<String> {
clients.insert(
client.0.to_string(),
ClientStats {
wsproxy: client.1 .1,
client_type: client.1 .1.clone(),
streams: client
.1
.0

View file

@ -4,16 +4,18 @@ use std::{
};
use anyhow::Context;
use base64::{prelude::BASE64_STANDARD, Engine};
use bytes::BytesMut;
use cfg_if::cfg_if;
use fastwebsockets::{FragmentCollector, Frame, OpCode, Payload, WebSocketError};
use hyper::upgrade::Upgraded;
use hyper_util::rt::TokioIo;
use log::debug;
use regex::RegexSet;
use tokio::net::{TcpStream, UdpSocket};
use wisp_mux::{ConnectPacket, StreamType};
use wisp_mux::{ConnectPacket, MuxStream, StreamType};
use crate::{CONFIG, RESOLVER};
use crate::{route::WispStreamWrite, CONFIG, RESOLVER};
fn match_addr(str: &str, allowed: &RegexSet, blocked: &RegexSet) -> bool {
blocked.is_match(str) && !allowed.is_match(str)
@ -40,6 +42,9 @@ pub enum ClientStream {
Udp(UdpSocket),
#[cfg(feature = "twisp")]
Pty(tokio::process::Child, pty_process::Pty),
Wispnet(MuxStream<WispStreamWrite>, String),
NoResolvedAddrs,
Blocked,
Invalid,
}
@ -105,6 +110,7 @@ fn is_global(addr: &IpAddr) -> bool {
pub enum ResolvedPacket {
Valid(ConnectPacket),
ValidWispnet(u32, ConnectPacket),
NoResolvedAddrs,
Blocked,
Invalid,
@ -112,6 +118,20 @@ pub enum ResolvedPacket {
impl ClientStream {
pub async fn resolve(packet: ConnectPacket) -> anyhow::Result<ResolvedPacket> {
if CONFIG.wisp.has_wispnet() && packet.destination_hostname.ends_with(".wisp") {
if let Some(wispnet_server) = packet.destination_hostname.split(".wisp").next() {
debug!("routing {:?} through wispnet", packet);
let decoded = BASE64_STANDARD
.decode(wispnet_server)
.context("failed to decode wispnet server")?;
let server_id = u32::from_str(
&String::from_utf8(decoded).context("wispnet server was not a string")?,
)
.context("failed to parse wispnet server from string")?;
return Ok(ResolvedPacket::ValidWispnet(server_id, packet));
}
}
cfg_if! {
if #[cfg(feature = "twisp")] {
if let StreamType::Unknown(ty) = packet.stream_type {