fix twisp and fix wispnet throttling

This commit is contained in:
Toshit Chawda 2024-11-25 14:24:31 -08:00
parent 19fb49a4cc
commit 31df4aefc6
No known key found for this signature in database
GPG key ID: 91480ED99E2B3D9D
5 changed files with 34 additions and 17 deletions

1
Cargo.lock generated
View file

@ -727,7 +727,6 @@ dependencies = [
"bytes", "bytes",
"cfg-if", "cfg-if",
"clap", "clap",
"console-subscriber",
"ed25519-dalek", "ed25519-dalek",
"env_logger", "env_logger",
"event-listener", "event-listener",

View file

@ -14,7 +14,6 @@ base64 = "0.22.1"
bytes = "1.7.1" bytes = "1.7.1"
cfg-if = "1.0.0" cfg-if = "1.0.0"
clap = { version = "4.5.16", features = ["cargo", "derive"] } clap = { version = "4.5.16", features = ["cargo", "derive"] }
console-subscriber = { version = "0.4.1", optional = true }
ed25519-dalek = { version = "2.1.1", features = ["pem"] } ed25519-dalek = { version = "2.1.1", features = ["pem"] }
env_logger = "0.11.5" env_logger = "0.11.5"
event-listener = "5.3.1" event-listener = "5.3.1"
@ -54,7 +53,6 @@ toml = ["dep:toml"]
twisp = ["dep:pty-process", "dep:libc", "dep:shell-words"] twisp = ["dep:pty-process", "dep:libc", "dep:shell-words"]
speed-limit = ["dep:async-speed-limit"] speed-limit = ["dep:async-speed-limit"]
tokio-console = ["dep:console-subscriber", "tokio/tracing"]
[build-dependencies] [build-dependencies]
vergen-git2 = { version = "1.0.0", features = ["rustc"] } vergen-git2 = { version = "1.0.0", features = ["rustc"] }

View file

@ -224,7 +224,18 @@ async fn handle_stream(
} }
} }
ClientStream::Wispnet(stream, mux_id) => { ClientStream::Wispnet(stream, mux_id) => {
wispnet::handle_stream(muxstream, stream, mux_id, uuid, resolved_stream).await wispnet::handle_stream(
muxstream,
stream,
mux_id,
uuid,
resolved_stream,
#[cfg(feature = "speed-limit")]
read_limit,
#[cfg(feature = "speed-limit")]
write_limit,
)
.await;
} }
ClientStream::NoResolvedAddrs => { ClientStream::NoResolvedAddrs => {
let _ = muxstream.close(CloseReason::ServerStreamUnreachable).await; let _ = muxstream.close(CloseReason::ServerStreamUnreachable).await;

View file

@ -14,7 +14,7 @@ use wisp_mux::{
AnyProtocolExtension, AnyProtocolExtensionBuilder, ProtocolExtension, AnyProtocolExtension, AnyProtocolExtensionBuilder, ProtocolExtension,
ProtocolExtensionBuilder, ProtocolExtensionBuilder,
}, },
ws::{LockedWebSocketWrite, WebSocketRead}, ws::{DynWebSocketRead, LockingWebSocketWrite},
MuxStreamAsyncRead, MuxStreamAsyncWrite, WispError, MuxStreamAsyncRead, MuxStreamAsyncWrite, WispError,
}; };
@ -50,27 +50,30 @@ impl ProtocolExtension for TWispServerProtocolExtension {
async fn handle_handshake( async fn handle_handshake(
&mut self, &mut self,
_: &mut dyn WebSocketRead, _: &mut DynWebSocketRead,
_: &LockedWebSocketWrite, _: &dyn LockingWebSocketWrite,
) -> std::result::Result<(), WispError> { ) -> std::result::Result<(), WispError> {
Ok(()) Ok(())
} }
async fn handle_packet( async fn handle_packet(
&mut self, &mut self,
packet_type: u8,
mut packet: Bytes, mut packet: Bytes,
_: &mut dyn WebSocketRead, _: &mut DynWebSocketRead,
_: &LockedWebSocketWrite, _: &dyn LockingWebSocketWrite,
) -> std::result::Result<(), WispError> { ) -> std::result::Result<(), WispError> {
if packet.remaining() < 4 + 2 + 2 { if packet_type == 0xF0 {
return Err(WispError::PacketTooSmall); if packet.remaining() < 4 + 2 + 2 {
} return Err(WispError::PacketTooSmall);
let stream_id = packet.get_u32_le(); }
let row = packet.get_u16_le(); let stream_id = packet.get_u32_le();
let col = packet.get_u16_le(); let row = packet.get_u16_le();
let col = packet.get_u16_le();
if let Some(pty) = self.0.lock().await.get(&stream_id) { if let Some(pty) = self.0.lock().await.get(&stream_id) {
let _ = set_term_size(*pty, Size::new(row, col)); let _ = set_term_size(*pty, Size::new(row, col));
}
} }
Ok(()) Ok(())
} }

View file

@ -179,6 +179,12 @@ pub async fn handle_stream(
mux_id: String, mux_id: String,
uuid: Uuid, uuid: Uuid,
resolved_stream: ConnectPacket, resolved_stream: ConnectPacket,
#[cfg(feature = "speed-limit")] read_limit: async_speed_limit::Limiter<
async_speed_limit::clock::StandardClock,
>,
#[cfg(feature = "speed-limit")] write_limit: async_speed_limit::Limiter<
async_speed_limit::clock::StandardClock,
>,
) { ) {
if let Some(client) = CLIENTS.lock().await.get(&mux_id) { if let Some(client) = CLIENTS.lock().await.get(&mux_id) {
client client