This commit is contained in:
Toshit Chawda 2024-11-24 21:10:08 -08:00
parent 3cef68d164
commit 13f282160b
No known key found for this signature in database
GPG key ID: 91480ED99E2B3D9D
17 changed files with 482 additions and 65 deletions

1
Cargo.lock generated
View file

@ -723,6 +723,7 @@ dependencies = [
"anyhow", "anyhow",
"async-speed-limit", "async-speed-limit",
"async-trait", "async-trait",
"base64 0.22.1",
"bytes", "bytes",
"cfg-if", "cfg-if",
"clap", "clap",

View file

@ -6,7 +6,8 @@ edition = "2021"
[dependencies] [dependencies]
anyhow = "1.0.86" anyhow = "1.0.86"
async-speed-limit = { version = "0.4.2", optional = true } async-speed-limit = { version = "0.4.2", optional = true }
async-trait = { version = "0.1.81", optional = true } async-trait = "0.1.81"
base64 = "0.22.1"
bytes = "1.7.1" bytes = "1.7.1"
cfg-if = "1.0.0" cfg-if = "1.0.0"
clap = { version = "4.5.16", features = ["cargo", "derive"] } clap = { version = "4.5.16", features = ["cargo", "derive"] }
@ -48,7 +49,7 @@ default = ["toml"]
yaml = ["dep:serde_yaml"] yaml = ["dep:serde_yaml"]
toml = ["dep:toml"] toml = ["dep:toml"]
twisp = ["dep:pty-process", "dep:libc", "dep:async-trait", "dep:shell-words"] twisp = ["dep:pty-process", "dep:libc", "dep:shell-words"]
speed-limit = ["dep:async-speed-limit"] speed-limit = ["dep:async-speed-limit"]
tokio-console = ["dep:console-subscriber", "tokio/tracing"] tokio-console = ["dep:console-subscriber", "tokio/tracing"]

View file

@ -123,6 +123,8 @@ pub enum ProtocolExtension {
Udp, Udp,
/// Wisp version 2 MOTD protocol extension. /// Wisp version 2 MOTD protocol extension.
Motd, Motd,
/// Unofficial Wispnet-like protocol extension.
Wispnet,
} }
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq)] #[derive(Debug, Serialize, Deserialize, PartialEq, Eq)]
@ -366,6 +368,11 @@ impl Default for WispConfig {
} }
impl WispConfig { impl WispConfig {
#[doc(hidden)]
pub fn has_wispnet(&self) -> bool {
self.extensions.contains(&ProtocolExtension::Wispnet)
}
#[doc(hidden)] #[doc(hidden)]
pub async fn to_opts(&self) -> anyhow::Result<(Option<WispV2Handshake>, Vec<u8>, u32)> { pub async fn to_opts(&self) -> anyhow::Result<(Option<WispV2Handshake>, Vec<u8>, u32)> {
if self.wisp_v2 { if self.wisp_v2 {

View file

@ -1,6 +1,7 @@
#[cfg(feature = "twisp")] #[cfg(feature = "twisp")]
pub mod twisp; pub mod twisp;
pub mod utils; pub mod utils;
pub mod wispnet;
use std::{sync::Arc, time::Duration}; use std::{sync::Arc, time::Duration};
@ -23,6 +24,7 @@ use wisp_mux::{
ws::Payload, CloseReason, ConnectPacket, MuxStream, MuxStreamAsyncRead, MuxStreamWrite, ws::Payload, CloseReason, ConnectPacket, MuxStream, MuxStreamAsyncRead, MuxStreamWrite,
ServerMux, ServerMux,
}; };
use wispnet::route_wispnet;
use crate::{ use crate::{
route::{WispResult, WispStreamWrite}, route::{WispResult, WispStreamWrite},
@ -100,8 +102,23 @@ async fn handle_stream(
let _ = muxstream.close(CloseReason::ServerStreamUnreachable).await; let _ = muxstream.close(CloseReason::ServerStreamUnreachable).await;
return; return;
}; };
let connect = match resolved { let (stream, resolved_stream) = match resolved {
ResolvedPacket::Valid(x) => x, ResolvedPacket::Valid(connect) => {
let resolved = connect.clone();
let Ok(stream) = ClientStream::connect(connect).await else {
let _ = muxstream.close(CloseReason::ServerStreamUnreachable).await;
return;
};
(stream, resolved)
}
ResolvedPacket::ValidWispnet(server, connect) => {
let resolved = connect.clone();
let Ok(stream) = route_wispnet(server, connect).await else {
let _ = muxstream.close(CloseReason::ServerStreamUnreachable).await;
return;
};
(stream, resolved)
}
ResolvedPacket::NoResolvedAddrs => { ResolvedPacket::NoResolvedAddrs => {
let _ = muxstream.close(CloseReason::ServerStreamUnreachable).await; let _ = muxstream.close(CloseReason::ServerStreamUnreachable).await;
return; return;
@ -118,13 +135,6 @@ async fn handle_stream(
} }
}; };
let resolved_stream = connect.clone();
let Ok(stream) = ClientStream::connect(connect).await else {
let _ = muxstream.close(CloseReason::ServerStreamUnreachable).await;
return;
};
let uuid = Uuid::new_v4(); let uuid = Uuid::new_v4();
debug!( debug!(
@ -137,7 +147,7 @@ async fn handle_stream(
.0 .0
.lock() .lock()
.await .await
.insert(uuid, (requested_stream, resolved_stream)); .insert(uuid, (requested_stream, resolved_stream.clone()));
} }
let forward_fut = async { let forward_fut = async {
@ -213,6 +223,12 @@ async fn handle_stream(
} }
} }
} }
ClientStream::Wispnet(stream, mux_id) => {
wispnet::handle_stream(muxstream, stream, mux_id, uuid, resolved_stream).await
}
ClientStream::NoResolvedAddrs => {
let _ = muxstream.close(CloseReason::ServerStreamUnreachable).await;
}
ClientStream::Invalid => { ClientStream::Invalid => {
let _ = muxstream.close(CloseReason::ServerStreamInvalidInfo).await; let _ = muxstream.close(CloseReason::ServerStreamInvalidInfo).await;
} }

View file

@ -0,0 +1,244 @@
use std::{
collections::HashMap,
sync::atomic::{AtomicU32, Ordering},
};
use anyhow::{Context, Result};
use async_trait::async_trait;
use bytes::{Buf, BufMut, Bytes, BytesMut};
use lazy_static::lazy_static;
use log::debug;
use tokio::{select, sync::Mutex};
use uuid::Uuid;
use wisp_mux::{
extensions::{
AnyProtocolExtension, ProtocolExtension, ProtocolExtensionBuilder, ProtocolExtensionVecExt,
},
ws::{DynWebSocketRead, Frame, LockingWebSocketWrite, Payload},
ClientMux, CloseReason, ConnectPacket, MuxStream, MuxStreamRead, MuxStreamWrite, Role,
WispError, WispV2Handshake,
};
use crate::{
route::{WispResult, WispStreamWrite},
stream::ClientStream,
CLIENTS,
};
struct WispnetClient {
mux: ClientMux<WispStreamWrite>,
id: String,
private: bool,
}
lazy_static! {
static ref WISPNET_SERVERS: Mutex<HashMap<u32, WispnetClient>> = Mutex::new(HashMap::new());
static ref WISPNET_IDS: AtomicU32 = AtomicU32::new(0);
}
// payload:
// client (acting like wisp server) sends a bool saying if it wants to be private or not
// server (acting like wisp client) sends a u32 client id
//
// packets:
// client sends a 0xF1 on stream id 0 with no body to probe
// server sends back a 0xF1 on stream id 0 with a body of a bunch of u32s for each public client
struct WispnetServerProtocolExtensionBuilder(u32);
impl WispnetServerProtocolExtensionBuilder {
const ID: u8 = 0xF1;
}
#[async_trait]
impl ProtocolExtensionBuilder for WispnetServerProtocolExtensionBuilder {
fn get_id(&self) -> u8 {
Self::ID
}
fn build_from_bytes(
&mut self,
mut bytes: Bytes,
_: Role,
) -> Result<AnyProtocolExtension, WispError> {
if bytes.remaining() < 1 {
return Err(WispError::PacketTooSmall);
};
Ok(WispnetServerProtocolExtension(self.0, bytes.get_u8() != 0).into())
}
fn build_to_extension(&mut self, _: Role) -> Result<AnyProtocolExtension, WispError> {
Ok(WispnetServerProtocolExtension(self.0, false).into())
}
}
#[derive(Debug, Copy, Clone)]
struct WispnetServerProtocolExtension(u32, pub bool);
impl WispnetServerProtocolExtension {
const ID: u8 = 0xF1;
}
#[async_trait]
impl ProtocolExtension for WispnetServerProtocolExtension {
fn get_id(&self) -> u8 {
Self::ID
}
fn get_supported_packets(&self) -> &'static [u8] {
&[Self::ID]
}
fn get_congestion_stream_types(&self) -> &'static [u8] {
&[]
}
fn encode(&self) -> Bytes {
let mut out = BytesMut::new();
out.put_u32_le(self.0);
out.freeze()
}
async fn handle_handshake(
&mut self,
_: &mut DynWebSocketRead,
_: &dyn LockingWebSocketWrite,
) -> Result<(), WispError> {
Ok(())
}
async fn handle_packet(
&mut self,
packet_type: u8,
mut packet: Bytes,
_: &mut DynWebSocketRead,
write: &dyn LockingWebSocketWrite,
) -> Result<(), WispError> {
if packet_type == Self::ID {
if packet.remaining() < 4 {
return Err(WispError::PacketTooSmall);
}
if packet.get_u32_le() != 0 {
return Err(WispError::InvalidStreamId);
}
let mut out = BytesMut::new();
out.put_u8(Self::ID);
out.put_u32_le(0);
let locked = WISPNET_SERVERS.lock().await;
for client in locked.iter() {
if !client.1.private {
out.put_u32_le(*client.0);
}
}
drop(locked);
write
.wisp_write_frame(Frame::binary(Payload::Bytes(out)))
.await?;
}
Ok(())
}
fn box_clone(&self) -> Box<dyn ProtocolExtension + Sync + Send> {
Box::new(*self)
}
}
pub async fn route_wispnet(server: u32, packet: ConnectPacket) -> Result<ClientStream> {
if let Some(server) = WISPNET_SERVERS.lock().await.get(&server) {
let stream = server
.mux
.client_new_stream(
packet.stream_type,
packet.destination_hostname,
packet.destination_port,
)
.await
.context("failed to connect to wispnet server")?;
Ok(ClientStream::Wispnet(stream, server.id.clone()))
} else {
Ok(ClientStream::NoResolvedAddrs)
}
}
async fn copy_wisp(
rx: MuxStreamRead<WispStreamWrite>,
tx: MuxStreamWrite<WispStreamWrite>,
#[cfg(feature = "speed-limit")] limiter: async_speed_limit::Limiter<
async_speed_limit::clock::StandardClock,
>,
) -> Result<()> {
while let Some(data) = rx.read().await? {
#[cfg(feature = "speed-limit")]
limiter.consume(data.len()).await;
tx.write_payload(Payload::Borrowed(data.as_ref())).await?;
}
Ok(())
}
pub async fn handle_stream(
mux: MuxStream<WispStreamWrite>,
wisp: MuxStream<WispStreamWrite>,
mux_id: String,
uuid: Uuid,
resolved_stream: ConnectPacket,
) {
if let Some(client) = CLIENTS.lock().await.get(&mux_id) {
client
.0
.lock()
.await
.insert(uuid, (resolved_stream.clone(), resolved_stream));
}
let closer = mux.get_close_handle();
let (muxread, muxwrite) = mux.into_split();
let (wispread, wispwrite) = wisp.into_split();
let _ = select! {
x = copy_wisp(muxread, wispwrite, #[cfg(feature = "speed-limit")] write_limit) => x,
x = copy_wisp(wispread, muxwrite, #[cfg(feature = "speed-limit")] read_limit) => x,
};
let _ = closer
.close(closer.get_close_reason().unwrap_or(CloseReason::Unknown))
.await;
if let Some(client) = CLIENTS.lock().await.get(&mux_id) {
client.0.lock().await.remove(&uuid);
}
}
pub async fn handle_wispnet(stream: WispResult, id: String) -> Result<()> {
let (read, write) = stream;
let net_id = WISPNET_IDS.fetch_add(1, Ordering::SeqCst);
let extensions = vec![WispnetServerProtocolExtensionBuilder(net_id).into()];
let (mux, fut) = ClientMux::create(read, write, Some(WispV2Handshake::new(extensions)))
.await
.context("failed to create client multiplexor")?
.with_required_extensions(&[WispnetServerProtocolExtension::ID])
.await
.context("wispnet client did not have wispnet extension")?;
let is_private = mux
.supported_extensions
.find_extension::<WispnetServerProtocolExtension>()
.context("failed to find wispnet extension")?
.1;
WISPNET_SERVERS.lock().await.insert(
net_id,
WispnetClient {
mux,
id: id.clone(),
private: is_private,
},
);
// probably the only time someone would do this
let ret = fut.await;
debug!("wispnet client id {:?} multiplexor result {:?}", id, ret);
WISPNET_SERVERS.lock().await.remove(&net_id);
Ok(())
}

View file

@ -7,9 +7,10 @@ use tokio::{
select, select,
}; };
use uuid::Uuid; use uuid::Uuid;
use wisp_mux::{ConnectPacket, StreamType}; use wisp_mux::{ws::Payload, CloseReason, ConnectPacket, StreamType};
use crate::{ use crate::{
handle::wisp::wispnet::route_wispnet,
stream::{ClientStream, ResolvedPacket, WebSocketFrame, WebSocketStreamWrapper}, stream::{ClientStream, ResolvedPacket, WebSocketFrame, WebSocketStreamWrapper},
CLIENTS, CONFIG, CLIENTS, CONFIG,
}; };
@ -48,8 +49,27 @@ pub async fn handle_wsproxy(
.await; .await;
return Ok(()); return Ok(());
}; };
let connect = match resolved { let (stream, resolved_stream) = match resolved {
ResolvedPacket::Valid(x) => x, ResolvedPacket::Valid(connect) => {
let resolved = connect.clone();
let Ok(stream) = ClientStream::connect(connect).await else {
let _ = ws
.close(CloseCode::Error.into(), b"failed to connect to host")
.await;
return Ok(());
};
(stream, resolved)
}
ResolvedPacket::ValidWispnet(server, connect) => {
let resolved = connect.clone();
let Ok(stream) = route_wispnet(server, connect).await else {
let _ = ws
.close(CloseCode::Error.into(), b"failed to connect to host")
.await;
return Ok(());
};
(stream, resolved)
}
ResolvedPacket::NoResolvedAddrs => { ResolvedPacket::NoResolvedAddrs => {
let _ = ws let _ = ws
.close( .close(
@ -74,15 +94,6 @@ pub async fn handle_wsproxy(
} }
}; };
let resolved_stream = connect.clone();
let Ok(stream) = ClientStream::connect(connect).await else {
let _ = ws
.close(CloseCode::Error.into(), b"failed to connect to host")
.await;
return Ok(());
};
let uuid = Uuid::new_v4(); let uuid = Uuid::new_v4();
debug!( debug!(
@ -95,7 +106,7 @@ pub async fn handle_wsproxy(
.0 .0
.lock() .lock()
.await .await
.insert(uuid, (requested_stream, resolved_stream)); .insert(uuid, (requested_stream, resolved_stream.clone()));
} }
match stream { match stream {
@ -173,6 +184,65 @@ pub async fn handle_wsproxy(
.close(CloseCode::Error.into(), b"twisp is not supported") .close(CloseCode::Error.into(), b"twisp is not supported")
.await; .await;
} }
ClientStream::Wispnet(stream, mux_id) => {
if let Some(client) = CLIENTS.lock().await.get(&mux_id) {
client
.0
.lock()
.await
.insert(uuid, (resolved_stream.clone(), resolved_stream));
}
let ret: anyhow::Result<()> = async {
loop {
select! {
x = ws.read() => {
match x? {
WebSocketFrame::Data(data) => {
stream.write_payload(Payload::Bytes(data)).await?;
}
WebSocketFrame::Close => {
stream.close(CloseReason::Voluntary).await?;
}
WebSocketFrame::Ignore => {}
}
}
x = stream.read() => {
let Some(x) = x? else {
break;
};
ws.write(&x).await?;
}
}
}
Ok(())
}
.await;
if let Some(client) = CLIENTS.lock().await.get(&mux_id) {
client.0.lock().await.remove(&uuid);
}
match ret {
Ok(_) => {
let _ = ws.close(CloseCode::Normal.into(), b"").await;
}
Err(x) => {
let _ = ws
.close(CloseCode::Normal.into(), x.to_string().as_bytes())
.await;
}
}
}
ClientStream::NoResolvedAddrs => {
let _ = ws
.close(
CloseCode::Error.into(),
b"host did not resolve to any addrs",
)
.await;
return Ok(());
}
ClientStream::Blocked => { ClientStream::Blocked => {
let _ = ws.close(CloseCode::Error.into(), b"host is blocked").await; let _ = ws.close(CloseCode::Error.into(), b"host is blocked").await;
} }

View file

@ -1,13 +1,14 @@
#![doc(html_no_source)] #![doc(html_no_source)]
#![deny(clippy::todo)] #![deny(clippy::todo)]
#![allow(unexpected_cfgs)] #![allow(unexpected_cfgs)]
#![warn(clippy::large_futures)]
use std::{collections::HashMap, fs::read_to_string, net::IpAddr}; use std::{collections::HashMap, fs::read_to_string, net::IpAddr};
use anyhow::{Context, Result}; use anyhow::{Context, Result};
use clap::Parser; use clap::Parser;
use config::{validate_config_cache, Cli, Config, RuntimeFlavor}; use config::{validate_config_cache, Cli, Config, RuntimeFlavor};
use handle::{handle_wisp, handle_wsproxy}; use handle::{handle_wisp, handle_wsproxy, wisp::wispnet::handle_wispnet};
use hickory_resolver::{ use hickory_resolver::{
config::{NameServerConfigGroup, ResolverConfig, ResolverOpts}, config::{NameServerConfigGroup, ResolverConfig, ResolverOpts},
system_conf::read_system_conf, system_conf::read_system_conf,
@ -41,7 +42,7 @@ mod stream;
mod util_chain; mod util_chain;
#[doc(hidden)] #[doc(hidden)]
type Client = (Mutex<HashMap<Uuid, (ConnectPacket, ConnectPacket)>>, bool); type Client = (Mutex<HashMap<Uuid, (ConnectPacket, ConnectPacket)>>, String);
#[doc(hidden)] #[doc(hidden)]
#[derive(Debug)] #[derive(Debug)]
@ -234,14 +235,18 @@ async fn async_main() -> Result<()> {
#[doc(hidden)] #[doc(hidden)]
fn handle_stream(stream: ServerRouteResult, id: String) { fn handle_stream(stream: ServerRouteResult, id: String) {
tokio::spawn(async move { tokio::spawn(async move {
CLIENTS CLIENTS.lock().await.insert(
.lock() id.clone(),
.await (Mutex::new(HashMap::new()), format!("{}", stream)),
.insert(id.clone(), (Mutex::new(HashMap::new()), false)); );
let res = match stream { let res = match stream {
ServerRouteResult::Wisp(stream, is_v2) => handle_wisp(stream, is_v2, id.clone()).await, ServerRouteResult::Wisp {
ServerRouteResult::WsProxy(ws, path, udp) => { stream,
handle_wsproxy(ws, id.clone(), path, udp).await has_ws_protocol,
} => handle_wisp(stream, has_ws_protocol, id.clone()).await,
ServerRouteResult::Wispnet { stream } => handle_wispnet(stream, id.clone()).await,
ServerRouteResult::WsProxy { stream, path, udp } => {
handle_wsproxy(stream, id.clone(), path, udp).await
} }
}; };
if let Err(e) = res { if let Err(e) = res {

View file

@ -36,15 +36,26 @@ pub type WispStreamWrite = EitherWebSocketWrite<
pub type WispResult = (WispStreamRead, WispStreamWrite); pub type WispResult = (WispStreamRead, WispStreamWrite);
pub enum ServerRouteResult { pub enum ServerRouteResult {
Wisp(WispResult, bool), Wisp {
WsProxy(WebSocketStreamWrapper, String, bool), stream: WispResult,
has_ws_protocol: bool,
},
Wispnet {
stream: WispResult,
},
WsProxy {
stream: WebSocketStreamWrapper,
path: String,
udp: bool,
},
} }
impl Display for ServerRouteResult { impl Display for ServerRouteResult {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self { match self {
Self::Wisp(..) => write!(f, "Wisp"), Self::Wisp { .. } => write!(f, "Wisp"),
Self::WsProxy(_, path, udp) => write!(f, "WsProxy path {:?} udp {:?}", path, udp), Self::Wispnet { .. } => write!(f, "Wispnet"),
Self::WsProxy { path, udp, .. } => write!(f, "WsProxy path {:?} udp {:?}", path, udp),
} }
} }
} }
@ -81,8 +92,14 @@ fn get_header(headers: &HeaderMap, header: &str) -> Option<String> {
} }
enum HttpUpgradeResult { enum HttpUpgradeResult {
Wisp(bool), Wisp {
WsProxy(String, bool), has_ws_protocol: bool,
is_wispnet: bool,
},
WsProxy {
path: String,
udp: bool,
},
} }
async fn ws_upgrade<F, R>( async fn ws_upgrade<F, R>(
@ -126,11 +143,19 @@ where
let ws_protocol = headers.get(SEC_WEBSOCKET_PROTOCOL); let ws_protocol = headers.get(SEC_WEBSOCKET_PROTOCOL);
let req_path = req.uri().path().to_string(); let req_path = req.uri().path().to_string();
if req_path.starts_with(&(CONFIG.wisp.prefix.clone() + "/")) { if req_path.ends_with(&(CONFIG.wisp.prefix.clone() + "/")) {
let has_ws_protocol = ws_protocol.is_some(); let has_ws_protocol = ws_protocol.is_some();
let is_wispnet = CONFIG.wisp.has_wispnet() && req.uri().query().unwrap_or_default() == "net";
tokio::spawn(async move { tokio::spawn(async move {
if let Err(err) = if let Err(err) = (callback)(
(callback)(fut, HttpUpgradeResult::Wisp(has_ws_protocol), ip_header).await fut,
HttpUpgradeResult::Wisp {
has_ws_protocol,
is_wispnet,
},
ip_header,
)
.await
{ {
error!("error while serving client: {:?}", err); error!("error while serving client: {:?}", err);
} }
@ -140,10 +165,17 @@ where
.append(SEC_WEBSOCKET_PROTOCOL, protocol.clone()); .append(SEC_WEBSOCKET_PROTOCOL, protocol.clone());
} }
} else if CONFIG.wisp.allow_wsproxy { } else if CONFIG.wisp.allow_wsproxy {
let udp = req.uri().query().unwrap_or_default() == "?udp"; let udp = req.uri().query().unwrap_or_default() == "udp";
tokio::spawn(async move { tokio::spawn(async move {
if let Err(err) = if let Err(err) = (callback)(
(callback)(fut, HttpUpgradeResult::WsProxy(req_path, udp), ip_header).await fut,
HttpUpgradeResult::WsProxy {
path: req_path,
udp,
},
ip_header,
)
.await
{ {
error!("error while serving client: {:?}", err); error!("error while serving client: {:?}", err);
} }
@ -188,7 +220,10 @@ pub async fn route(
ws.set_auto_pong(false); ws.set_auto_pong(false);
match res { match res {
HttpUpgradeResult::Wisp(is_v2) => { HttpUpgradeResult::Wisp {
has_ws_protocol,
is_wispnet,
} => {
let (read, write) = ws.split(|x| { let (read, write) = ws.split(|x| {
let parts = x let parts = x
.into_inner() .into_inner()
@ -198,21 +233,33 @@ pub async fn route(
(chain(Cursor::new(parts.read_buf), r), w) (chain(Cursor::new(parts.read_buf), r), w)
}); });
(callback)( let result = if is_wispnet {
ServerRouteResult::Wisp( ServerRouteResult::Wispnet {
( stream: (
EitherWebSocketRead::Left(read), EitherWebSocketRead::Left(read),
EitherWebSocketWrite::Left(write), EitherWebSocketWrite::Left(write),
), ),
is_v2, }
), } else {
maybe_ip, ServerRouteResult::Wisp {
) stream: (
EitherWebSocketRead::Left(read),
EitherWebSocketWrite::Left(write),
),
has_ws_protocol,
}
};
(callback)(result, maybe_ip)
} }
HttpUpgradeResult::WsProxy(path, udp) => { HttpUpgradeResult::WsProxy { path, udp } => {
let ws = WebSocketStreamWrapper(FragmentCollector::new(ws)); let ws = WebSocketStreamWrapper(FragmentCollector::new(ws));
(callback)( (callback)(
ServerRouteResult::WsProxy(ws, path, udp), ServerRouteResult::WsProxy {
stream: ws,
path,
udp,
},
maybe_ip, maybe_ip,
); );
} }
@ -237,13 +284,13 @@ pub async fn route(
let write = GenericWebSocketWrite::new(FramedWrite::new(write, codec)); let write = GenericWebSocketWrite::new(FramedWrite::new(write, codec));
(callback)( (callback)(
ServerRouteResult::Wisp( ServerRouteResult::Wisp {
( stream: (
EitherWebSocketRead::Right(read), EitherWebSocketRead::Right(read),
EitherWebSocketWrite::Right(write), EitherWebSocketWrite::Right(write),
), ),
true, has_ws_protocol: true,
), },
None, None,
); );
} }

View file

@ -50,7 +50,7 @@ impl From<(ConnectPacket, ConnectPacket)> for StreamStats {
#[derive(Serialize)] #[derive(Serialize)]
struct ClientStats { struct ClientStats {
wsproxy: bool, client_type: String,
streams: HashMap<String, StreamStats>, streams: HashMap<String, StreamStats>,
} }
@ -81,7 +81,7 @@ pub async fn generate_stats() -> anyhow::Result<String> {
clients.insert( clients.insert(
client.0.to_string(), client.0.to_string(),
ClientStats { ClientStats {
wsproxy: client.1 .1, client_type: client.1 .1.clone(),
streams: client streams: client
.1 .1
.0 .0

View file

@ -4,16 +4,18 @@ use std::{
}; };
use anyhow::Context; use anyhow::Context;
use base64::{prelude::BASE64_STANDARD, Engine};
use bytes::BytesMut; use bytes::BytesMut;
use cfg_if::cfg_if; use cfg_if::cfg_if;
use fastwebsockets::{FragmentCollector, Frame, OpCode, Payload, WebSocketError}; use fastwebsockets::{FragmentCollector, Frame, OpCode, Payload, WebSocketError};
use hyper::upgrade::Upgraded; use hyper::upgrade::Upgraded;
use hyper_util::rt::TokioIo; use hyper_util::rt::TokioIo;
use log::debug;
use regex::RegexSet; use regex::RegexSet;
use tokio::net::{TcpStream, UdpSocket}; use tokio::net::{TcpStream, UdpSocket};
use wisp_mux::{ConnectPacket, StreamType}; use wisp_mux::{ConnectPacket, MuxStream, StreamType};
use crate::{CONFIG, RESOLVER}; use crate::{route::WispStreamWrite, CONFIG, RESOLVER};
fn match_addr(str: &str, allowed: &RegexSet, blocked: &RegexSet) -> bool { fn match_addr(str: &str, allowed: &RegexSet, blocked: &RegexSet) -> bool {
blocked.is_match(str) && !allowed.is_match(str) blocked.is_match(str) && !allowed.is_match(str)
@ -40,6 +42,9 @@ pub enum ClientStream {
Udp(UdpSocket), Udp(UdpSocket),
#[cfg(feature = "twisp")] #[cfg(feature = "twisp")]
Pty(tokio::process::Child, pty_process::Pty), Pty(tokio::process::Child, pty_process::Pty),
Wispnet(MuxStream<WispStreamWrite>, String),
NoResolvedAddrs,
Blocked, Blocked,
Invalid, Invalid,
} }
@ -105,6 +110,7 @@ fn is_global(addr: &IpAddr) -> bool {
pub enum ResolvedPacket { pub enum ResolvedPacket {
Valid(ConnectPacket), Valid(ConnectPacket),
ValidWispnet(u32, ConnectPacket),
NoResolvedAddrs, NoResolvedAddrs,
Blocked, Blocked,
Invalid, Invalid,
@ -112,6 +118,20 @@ pub enum ResolvedPacket {
impl ClientStream { impl ClientStream {
pub async fn resolve(packet: ConnectPacket) -> anyhow::Result<ResolvedPacket> { pub async fn resolve(packet: ConnectPacket) -> anyhow::Result<ResolvedPacket> {
if CONFIG.wisp.has_wispnet() && packet.destination_hostname.ends_with(".wisp") {
if let Some(wispnet_server) = packet.destination_hostname.split(".wisp").next() {
debug!("routing {:?} through wispnet", packet);
let decoded = BASE64_STANDARD
.decode(wispnet_server)
.context("failed to decode wispnet server")?;
let server_id = u32::from_str(
&String::from_utf8(decoded).context("wispnet server was not a string")?,
)
.context("failed to parse wispnet server from string")?;
return Ok(ResolvedPacket::ValidWispnet(server_id, packet));
}
}
cfg_if! { cfg_if! {
if #[cfg(feature = "twisp")] { if #[cfg(feature = "twisp")] {
if let StreamType::Unknown(ty) = packet.stream_type { if let StreamType::Unknown(ty) = packet.stream_type {

View file

@ -191,6 +191,7 @@ impl ProtocolExtension for CertAuthProtocolExtension {
async fn handle_packet( async fn handle_packet(
&mut self, &mut self,
_: u8,
_: Bytes, _: Bytes,
_: &mut DynWebSocketRead, _: &mut DynWebSocketRead,
_: &dyn LockingWebSocketWrite, _: &dyn LockingWebSocketWrite,

View file

@ -112,6 +112,7 @@ pub trait ProtocolExtension: std::fmt::Debug + Sync + Send + 'static {
/// Handle receiving a packet. /// Handle receiving a packet.
async fn handle_packet( async fn handle_packet(
&mut self, &mut self,
packet_type: u8,
packet: Bytes, packet: Bytes,
read: &mut DynWebSocketRead, read: &mut DynWebSocketRead,
write: &dyn LockingWebSocketWrite, write: &dyn LockingWebSocketWrite,

View file

@ -56,6 +56,7 @@ impl ProtocolExtension for MotdProtocolExtension {
async fn handle_packet( async fn handle_packet(
&mut self, &mut self,
_: u8,
_: Bytes, _: Bytes,
_: &mut DynWebSocketRead, _: &mut DynWebSocketRead,
_: &dyn LockingWebSocketWrite, _: &dyn LockingWebSocketWrite,

View file

@ -102,6 +102,7 @@ impl ProtocolExtension for PasswordProtocolExtension {
async fn handle_packet( async fn handle_packet(
&mut self, &mut self,
_: u8,
_: Bytes, _: Bytes,
_: &mut DynWebSocketRead, _: &mut DynWebSocketRead,
_: &dyn LockingWebSocketWrite, _: &dyn LockingWebSocketWrite,

View file

@ -48,6 +48,7 @@ impl ProtocolExtension for UdpProtocolExtension {
async fn handle_packet( async fn handle_packet(
&mut self, &mut self,
_: u8,
_: Bytes, _: Bytes,
_: &mut DynWebSocketRead, _: &mut DynWebSocketRead,
_: &dyn LockingWebSocketWrite, _: &dyn LockingWebSocketWrite,

View file

@ -189,7 +189,7 @@ impl<R: WebSocketRead + 'static, W: WebSocketWrite + 'static> MuxInner<R, W> {
let (ch_tx, ch_rx) = mpsc::bounded(if self.role == Role::Server { let (ch_tx, ch_rx) = mpsc::bounded(if self.role == Role::Server {
self.buffer_size as usize self.buffer_size as usize
} else { } else {
usize::MAX usize::MAX - 8
}); });
let should_flow_control = self.tcp_extensions.contains(&stream_type.into()); let should_flow_control = self.tcp_extensions.contains(&stream_type.into());

View file

@ -572,6 +572,7 @@ impl<'a> Packet<'a> {
{ {
extension extension
.handle_packet( .handle_packet(
packet_type,
BytesMut::from(bytes).freeze(), BytesMut::from(bytes).freeze(),
DynWebSocketRead::from_mut(read), DynWebSocketRead::from_mut(read),
write, write,