switch to draft 5 handshake

This commit is contained in:
Toshit Chawda 2024-10-24 00:10:15 -07:00
parent c8de5524b4
commit cda7ed2190
No known key found for this signature in database
GPG key ID: 91480ED99E2B3D9D
7 changed files with 194 additions and 136 deletions

33
Cargo.lock generated
View file

@ -707,7 +707,7 @@ dependencies = [
"rustls-pemfile", "rustls-pemfile",
"rustls-pki-types", "rustls-pki-types",
"rustls-webpki", "rustls-webpki",
"send_wrapper 0.6.0", "send_wrapper",
"thiserror", "thiserror",
"tokio", "tokio",
"wasm-bindgen", "wasm-bindgen",
@ -931,16 +931,6 @@ version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f90f7dce0722e95104fcb095585910c0977252f286e354b5e3bd38902cd99988" checksum = "f90f7dce0722e95104fcb095585910c0977252f286e354b5e3bd38902cd99988"
[[package]]
name = "futures-timer"
version = "3.0.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f288b0a4f20f9a56b5d1da57e2227c661b7b16168e2f72365f57b63326e29b24"
dependencies = [
"gloo-timers",
"send_wrapper 0.4.0",
]
[[package]] [[package]]
name = "futures-util" name = "futures-util"
version = "0.3.31" version = "0.3.31"
@ -1001,18 +991,6 @@ dependencies = [
"url", "url",
] ]
[[package]]
name = "gloo-timers"
version = "0.2.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9b995a66bb87bebce9a0f4a95aed01daca4872c050bfcb21653361c03bc35e5c"
dependencies = [
"futures-channel",
"futures-core",
"js-sys",
"wasm-bindgen",
]
[[package]] [[package]]
name = "h2" name = "h2"
version = "0.4.6" version = "0.4.6"
@ -1963,12 +1941,6 @@ version = "1.0.23"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "61697e0a1c7e512e84a621326239844a24d8207b4669b41bc18b32ea5cbf988b" checksum = "61697e0a1c7e512e84a621326239844a24d8207b4669b41bc18b32ea5cbf988b"
[[package]]
name = "send_wrapper"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f638d531eccd6e23b980caf34876660d38e265409d8e99b397ab71eb3612fad0"
[[package]] [[package]]
name = "send_wrapper" name = "send_wrapper"
version = "0.6.0" version = "0.6.0"
@ -3000,7 +2972,7 @@ dependencies = [
[[package]] [[package]]
name = "wisp-mux" name = "wisp-mux"
version = "5.1.0" version = "6.0.0"
dependencies = [ dependencies = [
"async-trait", "async-trait",
"atomic_enum", "atomic_enum",
@ -3011,7 +2983,6 @@ dependencies = [
"fastwebsockets", "fastwebsockets",
"flume", "flume",
"futures", "futures",
"futures-timer",
"getrandom", "getrandom",
"nohash-hasher", "nohash-hasher",
"pin-project-lite", "pin-project-lite",

View file

@ -32,7 +32,7 @@ wasm-bindgen-futures = "0.4.43"
wasm-streams = "0.4.0" wasm-streams = "0.4.0"
web-sys = { version = "0.3.70", features = ["BinaryType", "Headers", "MessageEvent", "Request", "RequestInit", "Response", "ResponseInit", "Url", "WebSocket"] } web-sys = { version = "0.3.70", features = ["BinaryType", "Headers", "MessageEvent", "Request", "RequestInit", "Response", "ResponseInit", "Url", "WebSocket"] }
webpki-roots = "0.26.3" webpki-roots = "0.26.3"
wisp-mux = { path = "../wisp", features = ["wasm"], version = "5.1.0", default-features = false } wisp-mux = { version = "*", path = "../wisp", features = ["wasm"], default-features = false }
[dependencies.getrandom] [dependencies.getrandom]
version = "*" version = "*"

View file

@ -38,7 +38,7 @@ tokio-rustls = { version = "0.26.0", features = ["ring", "tls12"], default-featu
tokio-util = { version = "0.7.11", features = ["codec", "compat", "io-util", "net"] } tokio-util = { version = "0.7.11", features = ["codec", "compat", "io-util", "net"] }
toml = { version = "0.8.19", optional = true } toml = { version = "0.8.19", optional = true }
uuid = { version = "1.10.0", features = ["v4"] } uuid = { version = "1.10.0", features = ["v4"] }
wisp-mux = { version = "5.0.0", path = "../wisp", features = ["fastwebsockets", "generic_stream", "certificate"] } wisp-mux = { version = "*", path = "../wisp", features = ["fastwebsockets", "generic_stream", "certificate"] }
[features] [features]
default = ["toml"] default = ["toml"]

View file

@ -1,6 +1,6 @@
[package] [package]
name = "wisp-mux" name = "wisp-mux"
version = "5.1.0" version = "6.0.0"
license = "LGPL-3.0-only" license = "LGPL-3.0-only"
description = "A library for easily creating Wisp servers and clients." description = "A library for easily creating Wisp servers and clients."
homepage = "https://github.com/MercuryWorkshop/epoxy-tls/tree/multiplexed/wisp" homepage = "https://github.com/MercuryWorkshop/epoxy-tls/tree/multiplexed/wisp"
@ -20,7 +20,6 @@ event-listener = "5.3.1"
fastwebsockets = { version = "0.8.0", features = ["unstable-split"], optional = true } fastwebsockets = { version = "0.8.0", features = ["unstable-split"], optional = true }
flume = "0.11.0" flume = "0.11.0"
futures = "0.3.30" futures = "0.3.30"
futures-timer = "3.0.3"
getrandom = { version = "0.2.15", features = ["std"], optional = true } getrandom = { version = "0.2.15", features = ["std"], optional = true }
nohash-hasher = "0.2.0" nohash-hasher = "0.2.0"
pin-project-lite = "0.2.14" pin-project-lite = "0.2.14"
@ -31,7 +30,7 @@ tokio = { version = "1.39.3", optional = true, default-features = false }
default = ["generic_stream", "certificate"] default = ["generic_stream", "certificate"]
fastwebsockets = ["dep:fastwebsockets", "dep:tokio"] fastwebsockets = ["dep:fastwebsockets", "dep:tokio"]
generic_stream = [] generic_stream = []
wasm = ["futures-timer/wasm-bindgen", "getrandom/js"] wasm = ["getrandom/js"]
certificate = ["dep:ed25519", "dep:bitflags", "dep:getrandom"] certificate = ["dep:ed25519", "dep:bitflags", "dep:getrandom"]
[package.metadata.docs.rs] [package.metadata.docs.rs]

View file

@ -12,12 +12,71 @@ use futures::channel::oneshot;
use crate::{ use crate::{
extensions::{udp::UdpProtocolExtension, AnyProtocolExtension}, extensions::{udp::UdpProtocolExtension, AnyProtocolExtension},
inner::{MuxInner, WsEvent}, inner::{MuxInner, WsEvent},
mux::send_info_packet,
ws::{AppendingWebSocketRead, LockedWebSocketWrite, Payload, WebSocketRead, WebSocketWrite}, ws::{AppendingWebSocketRead, LockedWebSocketWrite, Payload, WebSocketRead, WebSocketWrite},
CloseReason, MuxProtocolExtensionStream, MuxStream, Packet, PacketType, Role, StreamType, CloseReason, MuxProtocolExtensionStream, MuxStream, Packet, PacketType, Role, StreamType,
WispError, WispError,
}; };
use super::{maybe_wisp_v2, send_info_packet, Multiplexor, MuxResult, WispV2Extensions}; use super::{
get_supported_extensions, validate_continue_packet, Multiplexor, MuxResult,
WispHandshakeResult, WispHandshakeResultKind, WispV2Extensions,
};
async fn handshake<R: WebSocketRead>(
rx: &mut R,
tx: &LockedWebSocketWrite,
v2_info: Option<WispV2Extensions>,
) -> Result<(WispHandshakeResult, u32), WispError> {
if let Some(WispV2Extensions {
mut builders,
closure,
}) = v2_info
{
let packet =
Packet::maybe_parse_info(rx.wisp_read_frame(tx).await?, Role::Client, &mut builders)?;
if let PacketType::Info(info) = packet.packet_type {
// v2 server
let buffer_size = validate_continue_packet(rx.wisp_read_frame(tx).await?.try_into()?)?;
(closure)(&mut builders).await?;
send_info_packet(tx, &mut builders).await?;
Ok((
WispHandshakeResult {
kind: WispHandshakeResultKind::V2 {
extensions: get_supported_extensions(info.extensions, &mut builders),
},
downgraded: false,
},
buffer_size,
))
} else {
// downgrade to v1
let buffer_size = validate_continue_packet(packet)?;
Ok((
WispHandshakeResult {
kind: WispHandshakeResultKind::V1 { frame: None },
downgraded: true,
},
buffer_size,
))
}
} else {
// user asked for a v1 client
let buffer_size = validate_continue_packet(rx.wisp_read_frame(tx).await?.try_into()?)?;
Ok((
WispHandshakeResult {
kind: WispHandshakeResultKind::V1 { frame: None },
downgraded: false,
},
buffer_size,
))
}
}
/// Client side multiplexor. /// Client side multiplexor.
pub struct ClientMux { pub struct ClientMux {
@ -48,49 +107,29 @@ impl ClientMux {
W: WebSocketWrite + Send + 'static, W: WebSocketWrite + Send + 'static,
{ {
let tx = LockedWebSocketWrite::new(Box::new(tx)); let tx = LockedWebSocketWrite::new(Box::new(tx));
let first_packet = Packet::try_from(rx.wisp_read_frame(&tx).await?)?;
if first_packet.stream_id != 0 { let (handshake_result, buffer_size) = handshake(&mut rx, &tx, wisp_v2).await?;
return Err(WispError::InvalidStreamId); let (extensions, frame) = handshake_result.kind.into_parts();
}
if let PacketType::Continue(packet) = first_packet.packet_type { let mux_inner = MuxInner::new_client(
let (supported_extensions, extra_packet, downgraded) = if let Some(WispV2Extensions { AppendingWebSocketRead(frame, rx),
mut builders, tx.clone(),
closure, extensions.clone(),
}) = wisp_v2 buffer_size,
{ );
let res = maybe_wisp_v2(&mut rx, &tx, Role::Client, &mut builders).await?;
// if not downgraded
if !res.2 {
(closure)(&mut builders).await?;
send_info_packet(&tx, &mut builders).await?;
}
res
} else {
(Vec::new(), None, true)
};
let mux_result = MuxInner::new_client( Ok(MuxResult(
AppendingWebSocketRead(extra_packet, rx), Self {
tx.clone(), actor_tx: mux_inner.actor_tx,
supported_extensions.clone(), actor_exited: mux_inner.actor_exited,
packet.buffer_remaining,
);
Ok(MuxResult( tx,
Self {
actor_tx: mux_result.actor_tx, downgraded: handshake_result.downgraded,
downgraded, supported_extensions: extensions,
supported_extensions, },
tx, mux_inner.mux.into_future(),
actor_exited: mux_result.actor_exited, ))
},
mux_result.mux.into_future(),
))
} else {
Err(WispError::InvalidPacketType)
}
} }
/// Create a new stream, multiplexed through Wisp. /// Create a new stream, multiplexed through Wisp.

View file

@ -1,53 +1,37 @@
mod client; mod client;
mod server; mod server;
use std::{future::Future, pin::Pin, time::Duration}; use std::{future::Future, pin::Pin};
pub use client::ClientMux; pub use client::ClientMux;
use futures::{select, FutureExt};
use futures_timer::Delay;
pub use server::ServerMux; pub use server::ServerMux;
use crate::{ use crate::{
extensions::{udp::UdpProtocolExtension, AnyProtocolExtension, AnyProtocolExtensionBuilder}, extensions::{udp::UdpProtocolExtension, AnyProtocolExtension, AnyProtocolExtensionBuilder},
ws::{Frame, LockedWebSocketWrite, WebSocketRead}, ws::{Frame, LockedWebSocketWrite},
CloseReason, Packet, PacketType, Role, WispError, CloseReason, Packet, PacketType, Role, WispError,
}; };
async fn maybe_wisp_v2<R>( struct WispHandshakeResult {
read: &mut R, kind: WispHandshakeResultKind,
write: &LockedWebSocketWrite, downgraded: bool,
role: Role, }
builders: &mut [AnyProtocolExtensionBuilder],
) -> Result<(Vec<AnyProtocolExtension>, Option<Frame<'static>>, bool), WispError>
where
R: WebSocketRead + Send,
{
let mut supported_extensions = Vec::new();
let mut extra_packet: Option<Frame<'static>> = None;
let mut downgraded = true;
let extension_ids: Vec<_> = builders.iter().map(|x| x.get_id()).collect(); enum WispHandshakeResultKind {
if let Some(frame) = select! { V2 {
x = read.wisp_read_frame(write).fuse() => Some(x?), extensions: Vec<AnyProtocolExtension>,
_ = Delay::new(Duration::from_secs(5)).fuse() => None },
} { V1 {
let packet = Packet::maybe_parse_info(frame, role, builders)?; frame: Option<Frame<'static>>,
if let PacketType::Info(info) = packet.packet_type { },
supported_extensions = info }
.extensions
.into_iter() impl WispHandshakeResultKind {
.filter(|x| extension_ids.contains(&x.get_id())) pub fn into_parts(self) -> (Vec<AnyProtocolExtension>, Option<Frame<'static>>) {
.collect(); match self {
downgraded = false; Self::V2 { extensions } => (extensions, None),
} else { Self::V1 { frame } => (vec![UdpProtocolExtension.into()], frame),
extra_packet.replace(Frame::from(packet).clone());
} }
} }
for extension in supported_extensions.iter_mut() {
extension.handle_handshake(read, write).await?;
}
Ok((supported_extensions, extra_packet, downgraded))
} }
async fn send_info_packet( async fn send_info_packet(
@ -67,19 +51,42 @@ async fn send_info_packet(
.await .await
} }
fn validate_continue_packet(packet: Packet<'_>) -> Result<u32, WispError> {
if packet.stream_id != 0 {
return Err(WispError::InvalidStreamId);
}
let PacketType::Continue(continue_packet) = packet.packet_type else {
return Err(WispError::InvalidPacketType);
};
Ok(continue_packet.buffer_remaining)
}
fn get_supported_extensions(
extensions: Vec<AnyProtocolExtension>,
builders: &mut [AnyProtocolExtensionBuilder],
) -> Vec<AnyProtocolExtension> {
let extension_ids: Vec<_> = builders.iter().map(|x| x.get_id()).collect();
extensions
.into_iter()
.filter(|x| extension_ids.contains(&x.get_id()))
.collect()
}
trait Multiplexor { trait Multiplexor {
fn has_extension(&self, extension_id: u8) -> bool; fn has_extension(&self, extension_id: u8) -> bool;
async fn exit(&self, reason: CloseReason) -> Result<(), WispError>; async fn exit(&self, reason: CloseReason) -> Result<(), WispError>;
} }
/// Result of creating a multiplexor. Helps require protocol extensions. /// Result of creating a multiplexor. Helps require protocol extensions.
#[allow(private_bounds)] #[expect(private_bounds)]
pub struct MuxResult<M, F>(M, F) pub struct MuxResult<M, F>(M, F)
where where
M: Multiplexor, M: Multiplexor,
F: Future<Output = Result<(), WispError>> + Send; F: Future<Output = Result<(), WispError>> + Send;
#[allow(private_bounds)] #[expect(private_bounds)]
impl<M, F> MuxResult<M, F> impl<M, F> MuxResult<M, F>
where where
M: Multiplexor, M: Multiplexor,

View file

@ -1,6 +1,5 @@
use std::{ use std::{
future::Future, future::Future,
ops::DerefMut,
sync::{ sync::{
atomic::{AtomicBool, Ordering}, atomic::{AtomicBool, Ordering},
Arc, Arc,
@ -14,10 +13,63 @@ use crate::{
extensions::AnyProtocolExtension, extensions::AnyProtocolExtension,
inner::{MuxInner, WsEvent}, inner::{MuxInner, WsEvent},
ws::{AppendingWebSocketRead, LockedWebSocketWrite, Payload, WebSocketRead, WebSocketWrite}, ws::{AppendingWebSocketRead, LockedWebSocketWrite, Payload, WebSocketRead, WebSocketWrite},
CloseReason, ConnectPacket, MuxProtocolExtensionStream, MuxStream, Packet, Role, WispError, CloseReason, ConnectPacket, MuxProtocolExtensionStream, MuxStream, Packet, PacketType, Role,
WispError,
}; };
use super::{maybe_wisp_v2, send_info_packet, Multiplexor, MuxResult, WispV2Extensions}; use super::{
get_supported_extensions, send_info_packet, Multiplexor, MuxResult, WispHandshakeResult,
WispHandshakeResultKind, WispV2Extensions,
};
async fn handshake<R: WebSocketRead>(
rx: &mut R,
tx: &LockedWebSocketWrite,
buffer_size: u32,
v2_info: Option<WispV2Extensions>,
) -> Result<WispHandshakeResult, WispError> {
if let Some(WispV2Extensions {
mut builders,
closure,
}) = v2_info
{
send_info_packet(tx, &mut builders).await?;
tx.write_frame(Packet::new_continue(0, buffer_size).into())
.await?;
(closure)(&mut builders).await?;
let packet =
Packet::maybe_parse_info(rx.wisp_read_frame(tx).await?, Role::Server, &mut builders)?;
if let PacketType::Info(info) = packet.packet_type {
// v2 client
Ok(WispHandshakeResult {
kind: WispHandshakeResultKind::V2 {
extensions: get_supported_extensions(info.extensions, &mut builders),
},
downgraded: false,
})
} else {
// downgrade to v1
Ok(WispHandshakeResult {
kind: WispHandshakeResultKind::V1 {
frame: Some(packet.into()),
},
downgraded: true,
})
}
} else {
// user asked for v1 server
tx.write_frame(Packet::new_continue(0, buffer_size).into())
.await?;
Ok(WispHandshakeResult {
kind: WispHandshakeResultKind::V1 { frame: None },
downgraded: false,
})
}
}
/// Server-side multiplexor. /// Server-side multiplexor.
pub struct ServerMux { pub struct ServerMux {
@ -52,36 +104,26 @@ impl ServerMux {
let tx = LockedWebSocketWrite::new(Box::new(tx)); let tx = LockedWebSocketWrite::new(Box::new(tx));
let ret_tx = tx.clone(); let ret_tx = tx.clone();
let ret = async { let ret = async {
tx.write_frame(Packet::new_continue(0, buffer_size).into()) let handshake_result = handshake(&mut rx, &tx, buffer_size, wisp_v2).await?;
.await?; let (extensions, extra_packet) = handshake_result.kind.into_parts();
let (supported_extensions, extra_packet, downgraded) = if let Some(WispV2Extensions {
mut builders,
closure,
}) = wisp_v2
{
send_info_packet(&tx, builders.deref_mut()).await?;
(closure)(builders.deref_mut()).await?;
maybe_wisp_v2(&mut rx, &tx, Role::Server, &mut builders).await?
} else {
(Vec::new(), None, true)
};
let (mux_result, muxstream_recv) = MuxInner::new_server( let (mux_result, muxstream_recv) = MuxInner::new_server(
AppendingWebSocketRead(extra_packet, rx), AppendingWebSocketRead(extra_packet, rx),
tx.clone(), tx.clone(),
supported_extensions.clone(), extensions.clone(),
buffer_size, buffer_size,
); );
Ok(MuxResult( Ok(MuxResult(
Self { Self {
muxstream_recv,
actor_tx: mux_result.actor_tx, actor_tx: mux_result.actor_tx,
downgraded,
supported_extensions,
tx,
actor_exited: mux_result.actor_exited, actor_exited: mux_result.actor_exited,
muxstream_recv,
tx,
downgraded: handshake_result.downgraded,
supported_extensions: extensions,
}, },
mux_result.mux.into_future(), mux_result.mux.into_future(),
)) ))