From 31df4aefc64f111682dff20a114c2eedf389615a Mon Sep 17 00:00:00 2001 From: Toshit Chawda Date: Mon, 25 Nov 2024 14:24:31 -0800 Subject: [PATCH] fix twisp and fix wispnet throttling --- Cargo.lock | 1 - server/Cargo.toml | 2 -- server/src/handle/wisp/mod.rs | 13 ++++++++++++- server/src/handle/wisp/twisp.rs | 29 ++++++++++++++++------------- server/src/handle/wisp/wispnet.rs | 6 ++++++ 5 files changed, 34 insertions(+), 17 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 96561f1..d34d035 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -727,7 +727,6 @@ dependencies = [ "bytes", "cfg-if", "clap", - "console-subscriber", "ed25519-dalek", "env_logger", "event-listener", diff --git a/server/Cargo.toml b/server/Cargo.toml index 8ddd6bc..0248352 100644 --- a/server/Cargo.toml +++ b/server/Cargo.toml @@ -14,7 +14,6 @@ base64 = "0.22.1" bytes = "1.7.1" cfg-if = "1.0.0" clap = { version = "4.5.16", features = ["cargo", "derive"] } -console-subscriber = { version = "0.4.1", optional = true } ed25519-dalek = { version = "2.1.1", features = ["pem"] } env_logger = "0.11.5" event-listener = "5.3.1" @@ -54,7 +53,6 @@ toml = ["dep:toml"] twisp = ["dep:pty-process", "dep:libc", "dep:shell-words"] speed-limit = ["dep:async-speed-limit"] -tokio-console = ["dep:console-subscriber", "tokio/tracing"] [build-dependencies] vergen-git2 = { version = "1.0.0", features = ["rustc"] } diff --git a/server/src/handle/wisp/mod.rs b/server/src/handle/wisp/mod.rs index d904e3c..5050b40 100644 --- a/server/src/handle/wisp/mod.rs +++ b/server/src/handle/wisp/mod.rs @@ -224,7 +224,18 @@ async fn handle_stream( } } ClientStream::Wispnet(stream, mux_id) => { - wispnet::handle_stream(muxstream, stream, mux_id, uuid, resolved_stream).await + wispnet::handle_stream( + muxstream, + stream, + mux_id, + uuid, + resolved_stream, + #[cfg(feature = "speed-limit")] + read_limit, + #[cfg(feature = "speed-limit")] + write_limit, + ) + .await; } ClientStream::NoResolvedAddrs => { let _ = muxstream.close(CloseReason::ServerStreamUnreachable).await; diff --git a/server/src/handle/wisp/twisp.rs b/server/src/handle/wisp/twisp.rs index 9a556ef..cf2905d 100644 --- a/server/src/handle/wisp/twisp.rs +++ b/server/src/handle/wisp/twisp.rs @@ -14,7 +14,7 @@ use wisp_mux::{ AnyProtocolExtension, AnyProtocolExtensionBuilder, ProtocolExtension, ProtocolExtensionBuilder, }, - ws::{LockedWebSocketWrite, WebSocketRead}, + ws::{DynWebSocketRead, LockingWebSocketWrite}, MuxStreamAsyncRead, MuxStreamAsyncWrite, WispError, }; @@ -50,27 +50,30 @@ impl ProtocolExtension for TWispServerProtocolExtension { async fn handle_handshake( &mut self, - _: &mut dyn WebSocketRead, - _: &LockedWebSocketWrite, + _: &mut DynWebSocketRead, + _: &dyn LockingWebSocketWrite, ) -> std::result::Result<(), WispError> { Ok(()) } async fn handle_packet( &mut self, + packet_type: u8, mut packet: Bytes, - _: &mut dyn WebSocketRead, - _: &LockedWebSocketWrite, + _: &mut DynWebSocketRead, + _: &dyn LockingWebSocketWrite, ) -> std::result::Result<(), WispError> { - if packet.remaining() < 4 + 2 + 2 { - return Err(WispError::PacketTooSmall); - } - let stream_id = packet.get_u32_le(); - let row = packet.get_u16_le(); - let col = packet.get_u16_le(); + if packet_type == 0xF0 { + if packet.remaining() < 4 + 2 + 2 { + return Err(WispError::PacketTooSmall); + } + let stream_id = packet.get_u32_le(); + let row = packet.get_u16_le(); + let col = packet.get_u16_le(); - if let Some(pty) = self.0.lock().await.get(&stream_id) { - let _ = set_term_size(*pty, Size::new(row, col)); + if let Some(pty) = self.0.lock().await.get(&stream_id) { + let _ = set_term_size(*pty, Size::new(row, col)); + } } Ok(()) } diff --git a/server/src/handle/wisp/wispnet.rs b/server/src/handle/wisp/wispnet.rs index c4439ef..7d9f85e 100644 --- a/server/src/handle/wisp/wispnet.rs +++ b/server/src/handle/wisp/wispnet.rs @@ -179,6 +179,12 @@ pub async fn handle_stream( mux_id: String, uuid: Uuid, resolved_stream: ConnectPacket, + #[cfg(feature = "speed-limit")] read_limit: async_speed_limit::Limiter< + async_speed_limit::clock::StandardClock, + >, + #[cfg(feature = "speed-limit")] write_limit: async_speed_limit::Limiter< + async_speed_limit::clock::StandardClock, + >, ) { if let Some(client) = CLIENTS.lock().await.get(&mux_id) { client