From 13f282160b832a19c17ee67a1c1d83ecf12aea69 Mon Sep 17 00:00:00 2001 From: Toshit Chawda Date: Sun, 24 Nov 2024 21:10:08 -0800 Subject: [PATCH] wisp net --- Cargo.lock | 1 + server/Cargo.toml | 5 +- server/src/config.rs | 7 + server/src/handle/wisp/mod.rs | 36 +++-- server/src/handle/wisp/wispnet.rs | 244 ++++++++++++++++++++++++++++++ server/src/handle/wsproxy.rs | 96 ++++++++++-- server/src/main.rs | 23 +-- server/src/route.rs | 99 ++++++++---- server/src/stats.rs | 4 +- server/src/stream.rs | 24 ++- wisp/src/extensions/cert.rs | 1 + wisp/src/extensions/mod.rs | 1 + wisp/src/extensions/motd.rs | 1 + wisp/src/extensions/password.rs | 1 + wisp/src/extensions/udp.rs | 1 + wisp/src/mux/inner.rs | 2 +- wisp/src/packet.rs | 1 + 17 files changed, 482 insertions(+), 65 deletions(-) create mode 100644 server/src/handle/wisp/wispnet.rs diff --git a/Cargo.lock b/Cargo.lock index 48287db..96561f1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -723,6 +723,7 @@ dependencies = [ "anyhow", "async-speed-limit", "async-trait", + "base64 0.22.1", "bytes", "cfg-if", "clap", diff --git a/server/Cargo.toml b/server/Cargo.toml index 26ec28e..f281d36 100644 --- a/server/Cargo.toml +++ b/server/Cargo.toml @@ -6,7 +6,8 @@ edition = "2021" [dependencies] anyhow = "1.0.86" async-speed-limit = { version = "0.4.2", optional = true } -async-trait = { version = "0.1.81", optional = true } +async-trait = "0.1.81" +base64 = "0.22.1" bytes = "1.7.1" cfg-if = "1.0.0" clap = { version = "4.5.16", features = ["cargo", "derive"] } @@ -48,7 +49,7 @@ default = ["toml"] yaml = ["dep:serde_yaml"] toml = ["dep:toml"] -twisp = ["dep:pty-process", "dep:libc", "dep:async-trait", "dep:shell-words"] +twisp = ["dep:pty-process", "dep:libc", "dep:shell-words"] speed-limit = ["dep:async-speed-limit"] tokio-console = ["dep:console-subscriber", "tokio/tracing"] diff --git a/server/src/config.rs b/server/src/config.rs index 103dcfe..7c1f4f1 100644 --- a/server/src/config.rs +++ b/server/src/config.rs @@ -123,6 +123,8 @@ pub enum ProtocolExtension { Udp, /// Wisp version 2 MOTD protocol extension. Motd, + /// Unofficial Wispnet-like protocol extension. + Wispnet, } #[derive(Debug, Serialize, Deserialize, PartialEq, Eq)] @@ -366,6 +368,11 @@ impl Default for WispConfig { } impl WispConfig { + #[doc(hidden)] + pub fn has_wispnet(&self) -> bool { + self.extensions.contains(&ProtocolExtension::Wispnet) + } + #[doc(hidden)] pub async fn to_opts(&self) -> anyhow::Result<(Option, Vec, u32)> { if self.wisp_v2 { diff --git a/server/src/handle/wisp/mod.rs b/server/src/handle/wisp/mod.rs index 70810b5..d904e3c 100644 --- a/server/src/handle/wisp/mod.rs +++ b/server/src/handle/wisp/mod.rs @@ -1,6 +1,7 @@ #[cfg(feature = "twisp")] pub mod twisp; pub mod utils; +pub mod wispnet; use std::{sync::Arc, time::Duration}; @@ -23,6 +24,7 @@ use wisp_mux::{ ws::Payload, CloseReason, ConnectPacket, MuxStream, MuxStreamAsyncRead, MuxStreamWrite, ServerMux, }; +use wispnet::route_wispnet; use crate::{ route::{WispResult, WispStreamWrite}, @@ -100,8 +102,23 @@ async fn handle_stream( let _ = muxstream.close(CloseReason::ServerStreamUnreachable).await; return; }; - let connect = match resolved { - ResolvedPacket::Valid(x) => x, + let (stream, resolved_stream) = match resolved { + ResolvedPacket::Valid(connect) => { + let resolved = connect.clone(); + let Ok(stream) = ClientStream::connect(connect).await else { + let _ = muxstream.close(CloseReason::ServerStreamUnreachable).await; + return; + }; + (stream, resolved) + } + ResolvedPacket::ValidWispnet(server, connect) => { + let resolved = connect.clone(); + let Ok(stream) = route_wispnet(server, connect).await else { + let _ = muxstream.close(CloseReason::ServerStreamUnreachable).await; + return; + }; + (stream, resolved) + } ResolvedPacket::NoResolvedAddrs => { let _ = muxstream.close(CloseReason::ServerStreamUnreachable).await; return; @@ -118,13 +135,6 @@ async fn handle_stream( } }; - let resolved_stream = connect.clone(); - - let Ok(stream) = ClientStream::connect(connect).await else { - let _ = muxstream.close(CloseReason::ServerStreamUnreachable).await; - return; - }; - let uuid = Uuid::new_v4(); debug!( @@ -137,7 +147,7 @@ async fn handle_stream( .0 .lock() .await - .insert(uuid, (requested_stream, resolved_stream)); + .insert(uuid, (requested_stream, resolved_stream.clone())); } let forward_fut = async { @@ -213,6 +223,12 @@ async fn handle_stream( } } } + ClientStream::Wispnet(stream, mux_id) => { + wispnet::handle_stream(muxstream, stream, mux_id, uuid, resolved_stream).await + } + ClientStream::NoResolvedAddrs => { + let _ = muxstream.close(CloseReason::ServerStreamUnreachable).await; + } ClientStream::Invalid => { let _ = muxstream.close(CloseReason::ServerStreamInvalidInfo).await; } diff --git a/server/src/handle/wisp/wispnet.rs b/server/src/handle/wisp/wispnet.rs new file mode 100644 index 0000000..c4439ef --- /dev/null +++ b/server/src/handle/wisp/wispnet.rs @@ -0,0 +1,244 @@ +use std::{ + collections::HashMap, + sync::atomic::{AtomicU32, Ordering}, +}; + +use anyhow::{Context, Result}; +use async_trait::async_trait; +use bytes::{Buf, BufMut, Bytes, BytesMut}; +use lazy_static::lazy_static; +use log::debug; +use tokio::{select, sync::Mutex}; +use uuid::Uuid; +use wisp_mux::{ + extensions::{ + AnyProtocolExtension, ProtocolExtension, ProtocolExtensionBuilder, ProtocolExtensionVecExt, + }, + ws::{DynWebSocketRead, Frame, LockingWebSocketWrite, Payload}, + ClientMux, CloseReason, ConnectPacket, MuxStream, MuxStreamRead, MuxStreamWrite, Role, + WispError, WispV2Handshake, +}; + +use crate::{ + route::{WispResult, WispStreamWrite}, + stream::ClientStream, + CLIENTS, +}; + +struct WispnetClient { + mux: ClientMux, + id: String, + private: bool, +} + +lazy_static! { + static ref WISPNET_SERVERS: Mutex> = Mutex::new(HashMap::new()); + static ref WISPNET_IDS: AtomicU32 = AtomicU32::new(0); +} + +// payload: +// client (acting like wisp server) sends a bool saying if it wants to be private or not +// server (acting like wisp client) sends a u32 client id +// +// packets: +// client sends a 0xF1 on stream id 0 with no body to probe +// server sends back a 0xF1 on stream id 0 with a body of a bunch of u32s for each public client + +struct WispnetServerProtocolExtensionBuilder(u32); +impl WispnetServerProtocolExtensionBuilder { + const ID: u8 = 0xF1; +} + +#[async_trait] +impl ProtocolExtensionBuilder for WispnetServerProtocolExtensionBuilder { + fn get_id(&self) -> u8 { + Self::ID + } + + fn build_from_bytes( + &mut self, + mut bytes: Bytes, + _: Role, + ) -> Result { + if bytes.remaining() < 1 { + return Err(WispError::PacketTooSmall); + }; + Ok(WispnetServerProtocolExtension(self.0, bytes.get_u8() != 0).into()) + } + + fn build_to_extension(&mut self, _: Role) -> Result { + Ok(WispnetServerProtocolExtension(self.0, false).into()) + } +} + +#[derive(Debug, Copy, Clone)] +struct WispnetServerProtocolExtension(u32, pub bool); +impl WispnetServerProtocolExtension { + const ID: u8 = 0xF1; +} + +#[async_trait] +impl ProtocolExtension for WispnetServerProtocolExtension { + fn get_id(&self) -> u8 { + Self::ID + } + fn get_supported_packets(&self) -> &'static [u8] { + &[Self::ID] + } + fn get_congestion_stream_types(&self) -> &'static [u8] { + &[] + } + fn encode(&self) -> Bytes { + let mut out = BytesMut::new(); + out.put_u32_le(self.0); + out.freeze() + } + + async fn handle_handshake( + &mut self, + _: &mut DynWebSocketRead, + _: &dyn LockingWebSocketWrite, + ) -> Result<(), WispError> { + Ok(()) + } + + async fn handle_packet( + &mut self, + packet_type: u8, + mut packet: Bytes, + _: &mut DynWebSocketRead, + write: &dyn LockingWebSocketWrite, + ) -> Result<(), WispError> { + if packet_type == Self::ID { + if packet.remaining() < 4 { + return Err(WispError::PacketTooSmall); + } + if packet.get_u32_le() != 0 { + return Err(WispError::InvalidStreamId); + } + + let mut out = BytesMut::new(); + out.put_u8(Self::ID); + out.put_u32_le(0); + + let locked = WISPNET_SERVERS.lock().await; + for client in locked.iter() { + if !client.1.private { + out.put_u32_le(*client.0); + } + } + drop(locked); + + write + .wisp_write_frame(Frame::binary(Payload::Bytes(out))) + .await?; + } + Ok(()) + } + + fn box_clone(&self) -> Box { + Box::new(*self) + } +} + +pub async fn route_wispnet(server: u32, packet: ConnectPacket) -> Result { + if let Some(server) = WISPNET_SERVERS.lock().await.get(&server) { + let stream = server + .mux + .client_new_stream( + packet.stream_type, + packet.destination_hostname, + packet.destination_port, + ) + .await + .context("failed to connect to wispnet server")?; + Ok(ClientStream::Wispnet(stream, server.id.clone())) + } else { + Ok(ClientStream::NoResolvedAddrs) + } +} + +async fn copy_wisp( + rx: MuxStreamRead, + tx: MuxStreamWrite, + #[cfg(feature = "speed-limit")] limiter: async_speed_limit::Limiter< + async_speed_limit::clock::StandardClock, + >, +) -> Result<()> { + while let Some(data) = rx.read().await? { + #[cfg(feature = "speed-limit")] + limiter.consume(data.len()).await; + tx.write_payload(Payload::Borrowed(data.as_ref())).await?; + } + Ok(()) +} + +pub async fn handle_stream( + mux: MuxStream, + wisp: MuxStream, + mux_id: String, + uuid: Uuid, + resolved_stream: ConnectPacket, +) { + if let Some(client) = CLIENTS.lock().await.get(&mux_id) { + client + .0 + .lock() + .await + .insert(uuid, (resolved_stream.clone(), resolved_stream)); + } + + let closer = mux.get_close_handle(); + + let (muxread, muxwrite) = mux.into_split(); + let (wispread, wispwrite) = wisp.into_split(); + let _ = select! { + x = copy_wisp(muxread, wispwrite, #[cfg(feature = "speed-limit")] write_limit) => x, + x = copy_wisp(wispread, muxwrite, #[cfg(feature = "speed-limit")] read_limit) => x, + }; + + let _ = closer + .close(closer.get_close_reason().unwrap_or(CloseReason::Unknown)) + .await; + + if let Some(client) = CLIENTS.lock().await.get(&mux_id) { + client.0.lock().await.remove(&uuid); + } +} + +pub async fn handle_wispnet(stream: WispResult, id: String) -> Result<()> { + let (read, write) = stream; + let net_id = WISPNET_IDS.fetch_add(1, Ordering::SeqCst); + + let extensions = vec![WispnetServerProtocolExtensionBuilder(net_id).into()]; + + let (mux, fut) = ClientMux::create(read, write, Some(WispV2Handshake::new(extensions))) + .await + .context("failed to create client multiplexor")? + .with_required_extensions(&[WispnetServerProtocolExtension::ID]) + .await + .context("wispnet client did not have wispnet extension")?; + + let is_private = mux + .supported_extensions + .find_extension::() + .context("failed to find wispnet extension")? + .1; + + WISPNET_SERVERS.lock().await.insert( + net_id, + WispnetClient { + mux, + id: id.clone(), + private: is_private, + }, + ); + + // probably the only time someone would do this + let ret = fut.await; + debug!("wispnet client id {:?} multiplexor result {:?}", id, ret); + + WISPNET_SERVERS.lock().await.remove(&net_id); + + Ok(()) +} diff --git a/server/src/handle/wsproxy.rs b/server/src/handle/wsproxy.rs index 6d2c216..b415a5e 100644 --- a/server/src/handle/wsproxy.rs +++ b/server/src/handle/wsproxy.rs @@ -7,9 +7,10 @@ use tokio::{ select, }; use uuid::Uuid; -use wisp_mux::{ConnectPacket, StreamType}; +use wisp_mux::{ws::Payload, CloseReason, ConnectPacket, StreamType}; use crate::{ + handle::wisp::wispnet::route_wispnet, stream::{ClientStream, ResolvedPacket, WebSocketFrame, WebSocketStreamWrapper}, CLIENTS, CONFIG, }; @@ -48,8 +49,27 @@ pub async fn handle_wsproxy( .await; return Ok(()); }; - let connect = match resolved { - ResolvedPacket::Valid(x) => x, + let (stream, resolved_stream) = match resolved { + ResolvedPacket::Valid(connect) => { + let resolved = connect.clone(); + let Ok(stream) = ClientStream::connect(connect).await else { + let _ = ws + .close(CloseCode::Error.into(), b"failed to connect to host") + .await; + return Ok(()); + }; + (stream, resolved) + } + ResolvedPacket::ValidWispnet(server, connect) => { + let resolved = connect.clone(); + let Ok(stream) = route_wispnet(server, connect).await else { + let _ = ws + .close(CloseCode::Error.into(), b"failed to connect to host") + .await; + return Ok(()); + }; + (stream, resolved) + } ResolvedPacket::NoResolvedAddrs => { let _ = ws .close( @@ -74,15 +94,6 @@ pub async fn handle_wsproxy( } }; - let resolved_stream = connect.clone(); - - let Ok(stream) = ClientStream::connect(connect).await else { - let _ = ws - .close(CloseCode::Error.into(), b"failed to connect to host") - .await; - return Ok(()); - }; - let uuid = Uuid::new_v4(); debug!( @@ -95,7 +106,7 @@ pub async fn handle_wsproxy( .0 .lock() .await - .insert(uuid, (requested_stream, resolved_stream)); + .insert(uuid, (requested_stream, resolved_stream.clone())); } match stream { @@ -173,6 +184,65 @@ pub async fn handle_wsproxy( .close(CloseCode::Error.into(), b"twisp is not supported") .await; } + ClientStream::Wispnet(stream, mux_id) => { + if let Some(client) = CLIENTS.lock().await.get(&mux_id) { + client + .0 + .lock() + .await + .insert(uuid, (resolved_stream.clone(), resolved_stream)); + } + + let ret: anyhow::Result<()> = async { + loop { + select! { + x = ws.read() => { + match x? { + WebSocketFrame::Data(data) => { + stream.write_payload(Payload::Bytes(data)).await?; + } + WebSocketFrame::Close => { + stream.close(CloseReason::Voluntary).await?; + } + WebSocketFrame::Ignore => {} + } + } + x = stream.read() => { + let Some(x) = x? else { + break; + }; + ws.write(&x).await?; + } + } + } + Ok(()) + } + .await; + + if let Some(client) = CLIENTS.lock().await.get(&mux_id) { + client.0.lock().await.remove(&uuid); + } + + match ret { + Ok(_) => { + let _ = ws.close(CloseCode::Normal.into(), b"").await; + } + Err(x) => { + let _ = ws + .close(CloseCode::Normal.into(), x.to_string().as_bytes()) + .await; + } + } + } + ClientStream::NoResolvedAddrs => { + let _ = ws + .close( + CloseCode::Error.into(), + b"host did not resolve to any addrs", + ) + .await; + return Ok(()); + } ClientStream::Blocked => { let _ = ws.close(CloseCode::Error.into(), b"host is blocked").await; } diff --git a/server/src/main.rs b/server/src/main.rs index 93e0168..96a7277 100644 --- a/server/src/main.rs +++ b/server/src/main.rs @@ -1,13 +1,14 @@ #![doc(html_no_source)] #![deny(clippy::todo)] #![allow(unexpected_cfgs)] +#![warn(clippy::large_futures)] use std::{collections::HashMap, fs::read_to_string, net::IpAddr}; use anyhow::{Context, Result}; use clap::Parser; use config::{validate_config_cache, Cli, Config, RuntimeFlavor}; -use handle::{handle_wisp, handle_wsproxy}; +use handle::{handle_wisp, handle_wsproxy, wisp::wispnet::handle_wispnet}; use hickory_resolver::{ config::{NameServerConfigGroup, ResolverConfig, ResolverOpts}, system_conf::read_system_conf, @@ -41,7 +42,7 @@ mod stream; mod util_chain; #[doc(hidden)] -type Client = (Mutex>, bool); +type Client = (Mutex>, String); #[doc(hidden)] #[derive(Debug)] @@ -234,14 +235,18 @@ async fn async_main() -> Result<()> { #[doc(hidden)] fn handle_stream(stream: ServerRouteResult, id: String) { tokio::spawn(async move { - CLIENTS - .lock() - .await - .insert(id.clone(), (Mutex::new(HashMap::new()), false)); + CLIENTS.lock().await.insert( + id.clone(), + (Mutex::new(HashMap::new()), format!("{}", stream)), + ); let res = match stream { - ServerRouteResult::Wisp(stream, is_v2) => handle_wisp(stream, is_v2, id.clone()).await, - ServerRouteResult::WsProxy(ws, path, udp) => { - handle_wsproxy(ws, id.clone(), path, udp).await + ServerRouteResult::Wisp { + stream, + has_ws_protocol, + } => handle_wisp(stream, has_ws_protocol, id.clone()).await, + ServerRouteResult::Wispnet { stream } => handle_wispnet(stream, id.clone()).await, + ServerRouteResult::WsProxy { stream, path, udp } => { + handle_wsproxy(stream, id.clone(), path, udp).await } }; if let Err(e) = res { diff --git a/server/src/route.rs b/server/src/route.rs index bd83003..64e6eac 100644 --- a/server/src/route.rs +++ b/server/src/route.rs @@ -36,15 +36,26 @@ pub type WispStreamWrite = EitherWebSocketWrite< pub type WispResult = (WispStreamRead, WispStreamWrite); pub enum ServerRouteResult { - Wisp(WispResult, bool), - WsProxy(WebSocketStreamWrapper, String, bool), + Wisp { + stream: WispResult, + has_ws_protocol: bool, + }, + Wispnet { + stream: WispResult, + }, + WsProxy { + stream: WebSocketStreamWrapper, + path: String, + udp: bool, + }, } impl Display for ServerRouteResult { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { - Self::Wisp(..) => write!(f, "Wisp"), - Self::WsProxy(_, path, udp) => write!(f, "WsProxy path {:?} udp {:?}", path, udp), + Self::Wisp { .. } => write!(f, "Wisp"), + Self::Wispnet { .. } => write!(f, "Wispnet"), + Self::WsProxy { path, udp, .. } => write!(f, "WsProxy path {:?} udp {:?}", path, udp), } } } @@ -81,8 +92,14 @@ fn get_header(headers: &HeaderMap, header: &str) -> Option { } enum HttpUpgradeResult { - Wisp(bool), - WsProxy(String, bool), + Wisp { + has_ws_protocol: bool, + is_wispnet: bool, + }, + WsProxy { + path: String, + udp: bool, + }, } async fn ws_upgrade( @@ -126,11 +143,19 @@ where let ws_protocol = headers.get(SEC_WEBSOCKET_PROTOCOL); let req_path = req.uri().path().to_string(); - if req_path.starts_with(&(CONFIG.wisp.prefix.clone() + "/")) { + if req_path.ends_with(&(CONFIG.wisp.prefix.clone() + "/")) { let has_ws_protocol = ws_protocol.is_some(); + let is_wispnet = CONFIG.wisp.has_wispnet() && req.uri().query().unwrap_or_default() == "net"; tokio::spawn(async move { - if let Err(err) = - (callback)(fut, HttpUpgradeResult::Wisp(has_ws_protocol), ip_header).await + if let Err(err) = (callback)( + fut, + HttpUpgradeResult::Wisp { + has_ws_protocol, + is_wispnet, + }, + ip_header, + ) + .await { error!("error while serving client: {:?}", err); } @@ -140,10 +165,17 @@ where .append(SEC_WEBSOCKET_PROTOCOL, protocol.clone()); } } else if CONFIG.wisp.allow_wsproxy { - let udp = req.uri().query().unwrap_or_default() == "?udp"; + let udp = req.uri().query().unwrap_or_default() == "udp"; tokio::spawn(async move { - if let Err(err) = - (callback)(fut, HttpUpgradeResult::WsProxy(req_path, udp), ip_header).await + if let Err(err) = (callback)( + fut, + HttpUpgradeResult::WsProxy { + path: req_path, + udp, + }, + ip_header, + ) + .await { error!("error while serving client: {:?}", err); } @@ -188,7 +220,10 @@ pub async fn route( ws.set_auto_pong(false); match res { - HttpUpgradeResult::Wisp(is_v2) => { + HttpUpgradeResult::Wisp { + has_ws_protocol, + is_wispnet, + } => { let (read, write) = ws.split(|x| { let parts = x .into_inner() @@ -198,21 +233,33 @@ pub async fn route( (chain(Cursor::new(parts.read_buf), r), w) }); - (callback)( - ServerRouteResult::Wisp( - ( + let result = if is_wispnet { + ServerRouteResult::Wispnet { + stream: ( EitherWebSocketRead::Left(read), EitherWebSocketWrite::Left(write), ), - is_v2, - ), - maybe_ip, - ) + } + } else { + ServerRouteResult::Wisp { + stream: ( + EitherWebSocketRead::Left(read), + EitherWebSocketWrite::Left(write), + ), + has_ws_protocol, + } + }; + + (callback)(result, maybe_ip) } - HttpUpgradeResult::WsProxy(path, udp) => { + HttpUpgradeResult::WsProxy { path, udp } => { let ws = WebSocketStreamWrapper(FragmentCollector::new(ws)); (callback)( - ServerRouteResult::WsProxy(ws, path, udp), + ServerRouteResult::WsProxy { + stream: ws, + path, + udp, + }, maybe_ip, ); } @@ -237,13 +284,13 @@ pub async fn route( let write = GenericWebSocketWrite::new(FramedWrite::new(write, codec)); (callback)( - ServerRouteResult::Wisp( - ( + ServerRouteResult::Wisp { + stream: ( EitherWebSocketRead::Right(read), EitherWebSocketWrite::Right(write), ), - true, - ), + has_ws_protocol: true, + }, None, ); } diff --git a/server/src/stats.rs b/server/src/stats.rs index c26a93b..51df4e1 100644 --- a/server/src/stats.rs +++ b/server/src/stats.rs @@ -50,7 +50,7 @@ impl From<(ConnectPacket, ConnectPacket)> for StreamStats { #[derive(Serialize)] struct ClientStats { - wsproxy: bool, + client_type: String, streams: HashMap, } @@ -81,7 +81,7 @@ pub async fn generate_stats() -> anyhow::Result { clients.insert( client.0.to_string(), ClientStats { - wsproxy: client.1 .1, + client_type: client.1 .1.clone(), streams: client .1 .0 diff --git a/server/src/stream.rs b/server/src/stream.rs index 4da30e0..06f4aa3 100644 --- a/server/src/stream.rs +++ b/server/src/stream.rs @@ -4,16 +4,18 @@ use std::{ }; use anyhow::Context; +use base64::{prelude::BASE64_STANDARD, Engine}; use bytes::BytesMut; use cfg_if::cfg_if; use fastwebsockets::{FragmentCollector, Frame, OpCode, Payload, WebSocketError}; use hyper::upgrade::Upgraded; use hyper_util::rt::TokioIo; +use log::debug; use regex::RegexSet; use tokio::net::{TcpStream, UdpSocket}; -use wisp_mux::{ConnectPacket, StreamType}; +use wisp_mux::{ConnectPacket, MuxStream, StreamType}; -use crate::{CONFIG, RESOLVER}; +use crate::{route::WispStreamWrite, CONFIG, RESOLVER}; fn match_addr(str: &str, allowed: &RegexSet, blocked: &RegexSet) -> bool { blocked.is_match(str) && !allowed.is_match(str) @@ -40,6 +42,9 @@ pub enum ClientStream { Udp(UdpSocket), #[cfg(feature = "twisp")] Pty(tokio::process::Child, pty_process::Pty), + Wispnet(MuxStream, String), + + NoResolvedAddrs, Blocked, Invalid, } @@ -105,6 +110,7 @@ fn is_global(addr: &IpAddr) -> bool { pub enum ResolvedPacket { Valid(ConnectPacket), + ValidWispnet(u32, ConnectPacket), NoResolvedAddrs, Blocked, Invalid, @@ -112,6 +118,20 @@ pub enum ResolvedPacket { impl ClientStream { pub async fn resolve(packet: ConnectPacket) -> anyhow::Result { + if CONFIG.wisp.has_wispnet() && packet.destination_hostname.ends_with(".wisp") { + if let Some(wispnet_server) = packet.destination_hostname.split(".wisp").next() { + debug!("routing {:?} through wispnet", packet); + let decoded = BASE64_STANDARD + .decode(wispnet_server) + .context("failed to decode wispnet server")?; + let server_id = u32::from_str( + &String::from_utf8(decoded).context("wispnet server was not a string")?, + ) + .context("failed to parse wispnet server from string")?; + return Ok(ResolvedPacket::ValidWispnet(server_id, packet)); + } + } + cfg_if! { if #[cfg(feature = "twisp")] { if let StreamType::Unknown(ty) = packet.stream_type { diff --git a/wisp/src/extensions/cert.rs b/wisp/src/extensions/cert.rs index 1112510..4276fab 100644 --- a/wisp/src/extensions/cert.rs +++ b/wisp/src/extensions/cert.rs @@ -191,6 +191,7 @@ impl ProtocolExtension for CertAuthProtocolExtension { async fn handle_packet( &mut self, + _: u8, _: Bytes, _: &mut DynWebSocketRead, _: &dyn LockingWebSocketWrite, diff --git a/wisp/src/extensions/mod.rs b/wisp/src/extensions/mod.rs index b4eaf5f..048f682 100644 --- a/wisp/src/extensions/mod.rs +++ b/wisp/src/extensions/mod.rs @@ -112,6 +112,7 @@ pub trait ProtocolExtension: std::fmt::Debug + Sync + Send + 'static { /// Handle receiving a packet. async fn handle_packet( &mut self, + packet_type: u8, packet: Bytes, read: &mut DynWebSocketRead, write: &dyn LockingWebSocketWrite, diff --git a/wisp/src/extensions/motd.rs b/wisp/src/extensions/motd.rs index 0747df9..f718cad 100644 --- a/wisp/src/extensions/motd.rs +++ b/wisp/src/extensions/motd.rs @@ -56,6 +56,7 @@ impl ProtocolExtension for MotdProtocolExtension { async fn handle_packet( &mut self, + _: u8, _: Bytes, _: &mut DynWebSocketRead, _: &dyn LockingWebSocketWrite, diff --git a/wisp/src/extensions/password.rs b/wisp/src/extensions/password.rs index 08b26f6..59efe0a 100644 --- a/wisp/src/extensions/password.rs +++ b/wisp/src/extensions/password.rs @@ -102,6 +102,7 @@ impl ProtocolExtension for PasswordProtocolExtension { async fn handle_packet( &mut self, + _: u8, _: Bytes, _: &mut DynWebSocketRead, _: &dyn LockingWebSocketWrite, diff --git a/wisp/src/extensions/udp.rs b/wisp/src/extensions/udp.rs index 1bb32c0..50cc445 100644 --- a/wisp/src/extensions/udp.rs +++ b/wisp/src/extensions/udp.rs @@ -48,6 +48,7 @@ impl ProtocolExtension for UdpProtocolExtension { async fn handle_packet( &mut self, + _: u8, _: Bytes, _: &mut DynWebSocketRead, _: &dyn LockingWebSocketWrite, diff --git a/wisp/src/mux/inner.rs b/wisp/src/mux/inner.rs index a7afa2f..e39a1e9 100644 --- a/wisp/src/mux/inner.rs +++ b/wisp/src/mux/inner.rs @@ -189,7 +189,7 @@ impl MuxInner { let (ch_tx, ch_rx) = mpsc::bounded(if self.role == Role::Server { self.buffer_size as usize } else { - usize::MAX + usize::MAX - 8 }); let should_flow_control = self.tcp_extensions.contains(&stream_type.into()); diff --git a/wisp/src/packet.rs b/wisp/src/packet.rs index f3b6463..125bb7d 100644 --- a/wisp/src/packet.rs +++ b/wisp/src/packet.rs @@ -572,6 +572,7 @@ impl<'a> Packet<'a> { { extension .handle_packet( + packet_type, BytesMut::from(bytes).freeze(), DynWebSocketRead::from_mut(read), write,