From cda7ed2190b806617c051dadd1f35868bdc91087 Mon Sep 17 00:00:00 2001 From: Toshit Chawda Date: Thu, 24 Oct 2024 00:10:15 -0700 Subject: [PATCH] switch to draft 5 handshake --- Cargo.lock | 33 +----------- client/Cargo.toml | 2 +- server/Cargo.toml | 2 +- wisp/Cargo.toml | 5 +- wisp/src/mux/client.rs | 119 +++++++++++++++++++++++++++-------------- wisp/src/mux/mod.rs | 83 +++++++++++++++------------- wisp/src/mux/server.rs | 86 +++++++++++++++++++++-------- 7 files changed, 194 insertions(+), 136 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 48e519c..a97ed60 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -707,7 +707,7 @@ dependencies = [ "rustls-pemfile", "rustls-pki-types", "rustls-webpki", - "send_wrapper 0.6.0", + "send_wrapper", "thiserror", "tokio", "wasm-bindgen", @@ -931,16 +931,6 @@ version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f90f7dce0722e95104fcb095585910c0977252f286e354b5e3bd38902cd99988" -[[package]] -name = "futures-timer" -version = "3.0.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f288b0a4f20f9a56b5d1da57e2227c661b7b16168e2f72365f57b63326e29b24" -dependencies = [ - "gloo-timers", - "send_wrapper 0.4.0", -] - [[package]] name = "futures-util" version = "0.3.31" @@ -1001,18 +991,6 @@ dependencies = [ "url", ] -[[package]] -name = "gloo-timers" -version = "0.2.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9b995a66bb87bebce9a0f4a95aed01daca4872c050bfcb21653361c03bc35e5c" -dependencies = [ - "futures-channel", - "futures-core", - "js-sys", - "wasm-bindgen", -] - [[package]] name = "h2" version = "0.4.6" @@ -1963,12 +1941,6 @@ version = "1.0.23" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "61697e0a1c7e512e84a621326239844a24d8207b4669b41bc18b32ea5cbf988b" -[[package]] -name = "send_wrapper" -version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f638d531eccd6e23b980caf34876660d38e265409d8e99b397ab71eb3612fad0" - [[package]] name = "send_wrapper" version = "0.6.0" @@ -3000,7 +2972,7 @@ dependencies = [ [[package]] name = "wisp-mux" -version = "5.1.0" +version = "6.0.0" dependencies = [ "async-trait", "atomic_enum", @@ -3011,7 +2983,6 @@ dependencies = [ "fastwebsockets", "flume", "futures", - "futures-timer", "getrandom", "nohash-hasher", "pin-project-lite", diff --git a/client/Cargo.toml b/client/Cargo.toml index 7e5425a..0aad55b 100644 --- a/client/Cargo.toml +++ b/client/Cargo.toml @@ -32,7 +32,7 @@ wasm-bindgen-futures = "0.4.43" wasm-streams = "0.4.0" web-sys = { version = "0.3.70", features = ["BinaryType", "Headers", "MessageEvent", "Request", "RequestInit", "Response", "ResponseInit", "Url", "WebSocket"] } webpki-roots = "0.26.3" -wisp-mux = { path = "../wisp", features = ["wasm"], version = "5.1.0", default-features = false } +wisp-mux = { version = "*", path = "../wisp", features = ["wasm"], default-features = false } [dependencies.getrandom] version = "*" diff --git a/server/Cargo.toml b/server/Cargo.toml index dc3b7e2..58a5e5b 100644 --- a/server/Cargo.toml +++ b/server/Cargo.toml @@ -38,7 +38,7 @@ tokio-rustls = { version = "0.26.0", features = ["ring", "tls12"], default-featu tokio-util = { version = "0.7.11", features = ["codec", "compat", "io-util", "net"] } toml = { version = "0.8.19", optional = true } uuid = { version = "1.10.0", features = ["v4"] } -wisp-mux = { version = "5.0.0", path = "../wisp", features = ["fastwebsockets", "generic_stream", "certificate"] } +wisp-mux = { version = "*", path = "../wisp", features = ["fastwebsockets", "generic_stream", "certificate"] } [features] default = ["toml"] diff --git a/wisp/Cargo.toml b/wisp/Cargo.toml index 3fd59b4..539f253 100644 --- a/wisp/Cargo.toml +++ b/wisp/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "wisp-mux" -version = "5.1.0" +version = "6.0.0" license = "LGPL-3.0-only" description = "A library for easily creating Wisp servers and clients." homepage = "https://github.com/MercuryWorkshop/epoxy-tls/tree/multiplexed/wisp" @@ -20,7 +20,6 @@ event-listener = "5.3.1" fastwebsockets = { version = "0.8.0", features = ["unstable-split"], optional = true } flume = "0.11.0" futures = "0.3.30" -futures-timer = "3.0.3" getrandom = { version = "0.2.15", features = ["std"], optional = true } nohash-hasher = "0.2.0" pin-project-lite = "0.2.14" @@ -31,7 +30,7 @@ tokio = { version = "1.39.3", optional = true, default-features = false } default = ["generic_stream", "certificate"] fastwebsockets = ["dep:fastwebsockets", "dep:tokio"] generic_stream = [] -wasm = ["futures-timer/wasm-bindgen", "getrandom/js"] +wasm = ["getrandom/js"] certificate = ["dep:ed25519", "dep:bitflags", "dep:getrandom"] [package.metadata.docs.rs] diff --git a/wisp/src/mux/client.rs b/wisp/src/mux/client.rs index 2544c1e..e2f44b0 100644 --- a/wisp/src/mux/client.rs +++ b/wisp/src/mux/client.rs @@ -12,12 +12,71 @@ use futures::channel::oneshot; use crate::{ extensions::{udp::UdpProtocolExtension, AnyProtocolExtension}, inner::{MuxInner, WsEvent}, + mux::send_info_packet, ws::{AppendingWebSocketRead, LockedWebSocketWrite, Payload, WebSocketRead, WebSocketWrite}, CloseReason, MuxProtocolExtensionStream, MuxStream, Packet, PacketType, Role, StreamType, WispError, }; -use super::{maybe_wisp_v2, send_info_packet, Multiplexor, MuxResult, WispV2Extensions}; +use super::{ + get_supported_extensions, validate_continue_packet, Multiplexor, MuxResult, + WispHandshakeResult, WispHandshakeResultKind, WispV2Extensions, +}; + +async fn handshake( + rx: &mut R, + tx: &LockedWebSocketWrite, + v2_info: Option, +) -> Result<(WispHandshakeResult, u32), WispError> { + if let Some(WispV2Extensions { + mut builders, + closure, + }) = v2_info + { + let packet = + Packet::maybe_parse_info(rx.wisp_read_frame(tx).await?, Role::Client, &mut builders)?; + + if let PacketType::Info(info) = packet.packet_type { + // v2 server + let buffer_size = validate_continue_packet(rx.wisp_read_frame(tx).await?.try_into()?)?; + + (closure)(&mut builders).await?; + send_info_packet(tx, &mut builders).await?; + + Ok(( + WispHandshakeResult { + kind: WispHandshakeResultKind::V2 { + extensions: get_supported_extensions(info.extensions, &mut builders), + }, + downgraded: false, + }, + buffer_size, + )) + } else { + // downgrade to v1 + let buffer_size = validate_continue_packet(packet)?; + + Ok(( + WispHandshakeResult { + kind: WispHandshakeResultKind::V1 { frame: None }, + downgraded: true, + }, + buffer_size, + )) + } + } else { + // user asked for a v1 client + let buffer_size = validate_continue_packet(rx.wisp_read_frame(tx).await?.try_into()?)?; + + Ok(( + WispHandshakeResult { + kind: WispHandshakeResultKind::V1 { frame: None }, + downgraded: false, + }, + buffer_size, + )) + } +} /// Client side multiplexor. pub struct ClientMux { @@ -48,49 +107,29 @@ impl ClientMux { W: WebSocketWrite + Send + 'static, { let tx = LockedWebSocketWrite::new(Box::new(tx)); - let first_packet = Packet::try_from(rx.wisp_read_frame(&tx).await?)?; - if first_packet.stream_id != 0 { - return Err(WispError::InvalidStreamId); - } + let (handshake_result, buffer_size) = handshake(&mut rx, &tx, wisp_v2).await?; + let (extensions, frame) = handshake_result.kind.into_parts(); - if let PacketType::Continue(packet) = first_packet.packet_type { - let (supported_extensions, extra_packet, downgraded) = if let Some(WispV2Extensions { - mut builders, - closure, - }) = wisp_v2 - { - let res = maybe_wisp_v2(&mut rx, &tx, Role::Client, &mut builders).await?; - // if not downgraded - if !res.2 { - (closure)(&mut builders).await?; - send_info_packet(&tx, &mut builders).await?; - } - res - } else { - (Vec::new(), None, true) - }; + let mux_inner = MuxInner::new_client( + AppendingWebSocketRead(frame, rx), + tx.clone(), + extensions.clone(), + buffer_size, + ); - let mux_result = MuxInner::new_client( - AppendingWebSocketRead(extra_packet, rx), - tx.clone(), - supported_extensions.clone(), - packet.buffer_remaining, - ); + Ok(MuxResult( + Self { + actor_tx: mux_inner.actor_tx, + actor_exited: mux_inner.actor_exited, - Ok(MuxResult( - Self { - actor_tx: mux_result.actor_tx, - downgraded, - supported_extensions, - tx, - actor_exited: mux_result.actor_exited, - }, - mux_result.mux.into_future(), - )) - } else { - Err(WispError::InvalidPacketType) - } + tx, + + downgraded: handshake_result.downgraded, + supported_extensions: extensions, + }, + mux_inner.mux.into_future(), + )) } /// Create a new stream, multiplexed through Wisp. diff --git a/wisp/src/mux/mod.rs b/wisp/src/mux/mod.rs index 1212359..ba94b73 100644 --- a/wisp/src/mux/mod.rs +++ b/wisp/src/mux/mod.rs @@ -1,53 +1,37 @@ mod client; mod server; -use std::{future::Future, pin::Pin, time::Duration}; +use std::{future::Future, pin::Pin}; pub use client::ClientMux; -use futures::{select, FutureExt}; -use futures_timer::Delay; pub use server::ServerMux; use crate::{ extensions::{udp::UdpProtocolExtension, AnyProtocolExtension, AnyProtocolExtensionBuilder}, - ws::{Frame, LockedWebSocketWrite, WebSocketRead}, + ws::{Frame, LockedWebSocketWrite}, CloseReason, Packet, PacketType, Role, WispError, }; -async fn maybe_wisp_v2( - read: &mut R, - write: &LockedWebSocketWrite, - role: Role, - builders: &mut [AnyProtocolExtensionBuilder], -) -> Result<(Vec, Option>, bool), WispError> -where - R: WebSocketRead + Send, -{ - let mut supported_extensions = Vec::new(); - let mut extra_packet: Option> = None; - let mut downgraded = true; +struct WispHandshakeResult { + kind: WispHandshakeResultKind, + downgraded: bool, +} - let extension_ids: Vec<_> = builders.iter().map(|x| x.get_id()).collect(); - if let Some(frame) = select! { - x = read.wisp_read_frame(write).fuse() => Some(x?), - _ = Delay::new(Duration::from_secs(5)).fuse() => None - } { - let packet = Packet::maybe_parse_info(frame, role, builders)?; - if let PacketType::Info(info) = packet.packet_type { - supported_extensions = info - .extensions - .into_iter() - .filter(|x| extension_ids.contains(&x.get_id())) - .collect(); - downgraded = false; - } else { - extra_packet.replace(Frame::from(packet).clone()); +enum WispHandshakeResultKind { + V2 { + extensions: Vec, + }, + V1 { + frame: Option>, + }, +} + +impl WispHandshakeResultKind { + pub fn into_parts(self) -> (Vec, Option>) { + match self { + Self::V2 { extensions } => (extensions, None), + Self::V1 { frame } => (vec![UdpProtocolExtension.into()], frame), } } - - for extension in supported_extensions.iter_mut() { - extension.handle_handshake(read, write).await?; - } - Ok((supported_extensions, extra_packet, downgraded)) } async fn send_info_packet( @@ -67,19 +51,42 @@ async fn send_info_packet( .await } +fn validate_continue_packet(packet: Packet<'_>) -> Result { + if packet.stream_id != 0 { + return Err(WispError::InvalidStreamId); + } + + let PacketType::Continue(continue_packet) = packet.packet_type else { + return Err(WispError::InvalidPacketType); + }; + + Ok(continue_packet.buffer_remaining) +} + +fn get_supported_extensions( + extensions: Vec, + builders: &mut [AnyProtocolExtensionBuilder], +) -> Vec { + let extension_ids: Vec<_> = builders.iter().map(|x| x.get_id()).collect(); + extensions + .into_iter() + .filter(|x| extension_ids.contains(&x.get_id())) + .collect() +} + trait Multiplexor { fn has_extension(&self, extension_id: u8) -> bool; async fn exit(&self, reason: CloseReason) -> Result<(), WispError>; } /// Result of creating a multiplexor. Helps require protocol extensions. -#[allow(private_bounds)] +#[expect(private_bounds)] pub struct MuxResult(M, F) where M: Multiplexor, F: Future> + Send; -#[allow(private_bounds)] +#[expect(private_bounds)] impl MuxResult where M: Multiplexor, diff --git a/wisp/src/mux/server.rs b/wisp/src/mux/server.rs index aefbf38..a60f409 100644 --- a/wisp/src/mux/server.rs +++ b/wisp/src/mux/server.rs @@ -1,6 +1,5 @@ use std::{ future::Future, - ops::DerefMut, sync::{ atomic::{AtomicBool, Ordering}, Arc, @@ -14,10 +13,63 @@ use crate::{ extensions::AnyProtocolExtension, inner::{MuxInner, WsEvent}, ws::{AppendingWebSocketRead, LockedWebSocketWrite, Payload, WebSocketRead, WebSocketWrite}, - CloseReason, ConnectPacket, MuxProtocolExtensionStream, MuxStream, Packet, Role, WispError, + CloseReason, ConnectPacket, MuxProtocolExtensionStream, MuxStream, Packet, PacketType, Role, + WispError, }; -use super::{maybe_wisp_v2, send_info_packet, Multiplexor, MuxResult, WispV2Extensions}; +use super::{ + get_supported_extensions, send_info_packet, Multiplexor, MuxResult, WispHandshakeResult, + WispHandshakeResultKind, WispV2Extensions, +}; + +async fn handshake( + rx: &mut R, + tx: &LockedWebSocketWrite, + buffer_size: u32, + v2_info: Option, +) -> Result { + if let Some(WispV2Extensions { + mut builders, + closure, + }) = v2_info + { + send_info_packet(tx, &mut builders).await?; + tx.write_frame(Packet::new_continue(0, buffer_size).into()) + .await?; + + (closure)(&mut builders).await?; + + let packet = + Packet::maybe_parse_info(rx.wisp_read_frame(tx).await?, Role::Server, &mut builders)?; + + if let PacketType::Info(info) = packet.packet_type { + // v2 client + Ok(WispHandshakeResult { + kind: WispHandshakeResultKind::V2 { + extensions: get_supported_extensions(info.extensions, &mut builders), + }, + downgraded: false, + }) + } else { + // downgrade to v1 + Ok(WispHandshakeResult { + kind: WispHandshakeResultKind::V1 { + frame: Some(packet.into()), + }, + downgraded: true, + }) + } + } else { + // user asked for v1 server + tx.write_frame(Packet::new_continue(0, buffer_size).into()) + .await?; + + Ok(WispHandshakeResult { + kind: WispHandshakeResultKind::V1 { frame: None }, + downgraded: false, + }) + } +} /// Server-side multiplexor. pub struct ServerMux { @@ -52,36 +104,26 @@ impl ServerMux { let tx = LockedWebSocketWrite::new(Box::new(tx)); let ret_tx = tx.clone(); let ret = async { - tx.write_frame(Packet::new_continue(0, buffer_size).into()) - .await?; - - let (supported_extensions, extra_packet, downgraded) = if let Some(WispV2Extensions { - mut builders, - closure, - }) = wisp_v2 - { - send_info_packet(&tx, builders.deref_mut()).await?; - (closure)(builders.deref_mut()).await?; - maybe_wisp_v2(&mut rx, &tx, Role::Server, &mut builders).await? - } else { - (Vec::new(), None, true) - }; + let handshake_result = handshake(&mut rx, &tx, buffer_size, wisp_v2).await?; + let (extensions, extra_packet) = handshake_result.kind.into_parts(); let (mux_result, muxstream_recv) = MuxInner::new_server( AppendingWebSocketRead(extra_packet, rx), tx.clone(), - supported_extensions.clone(), + extensions.clone(), buffer_size, ); Ok(MuxResult( Self { - muxstream_recv, actor_tx: mux_result.actor_tx, - downgraded, - supported_extensions, - tx, actor_exited: mux_result.actor_exited, + muxstream_recv, + + tx, + + downgraded: handshake_result.downgraded, + supported_extensions: extensions, }, mux_result.mux.into_future(), ))