From b2435b554a1dc95380391dc90669e028f7c139d4 Mon Sep 17 00:00:00 2001 From: Toshit Chawda Date: Sat, 14 Sep 2024 10:28:45 -0700 Subject: [PATCH] ed25519 auth --- Cargo.lock | 75 ++++++++ client/src/stream_provider.rs | 2 +- server/src/handle/wisp.rs | 2 +- simple-wisp-client/src/main.rs | 2 +- wisp/Cargo.toml | 8 +- wisp/src/extensions/cert.rs | 319 ++++++++++++++++++++++++++++++++ wisp/src/extensions/mod.rs | 6 +- wisp/src/extensions/password.rs | 35 +--- wisp/src/extensions/udp.rs | 15 +- wisp/src/lib.rs | 94 ++++------ wisp/src/packet.rs | 10 +- 11 files changed, 450 insertions(+), 118 deletions(-) create mode 100644 wisp/src/extensions/cert.rs diff --git a/Cargo.lock b/Cargo.lock index 3f7727e..595234a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -252,6 +252,12 @@ version = "0.22.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" +[[package]] +name = "base64ct" +version = "1.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8c3c1a368f70d6cf7302d78f8f7093da241fb8e8807c05cc9e51a125895a6d5b" + [[package]] name = "bitflags" version = "2.6.0" @@ -421,6 +427,12 @@ dependencies = [ "tracing-subscriber", ] +[[package]] +name = "const-oid" +version = "0.9.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c2459377285ad874054d797f3ccebf984978aa39129f6eafde5cdc8315b612f8" + [[package]] name = "cpufeatures" version = "0.2.13" @@ -484,6 +496,17 @@ version = "2.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e8566979429cf69b49a5c740c60791108e86440e8be149bbea4fe54d2c32d6e2" +[[package]] +name = "der" +version = "0.7.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f55bf8e7b65898637379c1b74eb1551107c8294ed26d855ceb9fd1a09cfc9bc0" +dependencies = [ + "const-oid", + "pem-rfc7468", + "zeroize", +] + [[package]] name = "digest" version = "0.10.7" @@ -494,6 +517,17 @@ dependencies = [ "crypto-common", ] +[[package]] +name = "ed25519" +version = "2.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "115531babc129696a58c64a4fef0a8bf9e9698629fb97e9e40767d235cfbcd53" +dependencies = [ + "pkcs8", + "signature", + "zeroize", +] + [[package]] name = "either" version = "1.13.0" @@ -1387,6 +1421,15 @@ version = "1.0.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a" +[[package]] +name = "pem-rfc7468" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "88b39c9bfcfc231068454382784bb460aae594343fb030d46e9f50a645418412" +dependencies = [ + "base64ct", +] + [[package]] name = "percent-encoding" version = "2.3.1" @@ -1425,6 +1468,16 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" +[[package]] +name = "pkcs8" +version = "0.10.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f950b2377845cebe5cf8b5165cb3cc1a5e0fa5cfa3e1f7f55707d8fd82e0a7b7" +dependencies = [ + "der", + "spki", +] + [[package]] name = "ppv-lite86" version = "0.2.20" @@ -1801,6 +1854,15 @@ dependencies = [ "libc", ] +[[package]] +name = "signature" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "77549399552de45a898a580c1b41d445bf730df867cc44e6c0233bbc4b8329de" +dependencies = [ + "rand_core", +] + [[package]] name = "simdutf8" version = "0.1.4" @@ -1869,6 +1931,16 @@ dependencies = [ "lock_api", ] +[[package]] +name = "spki" +version = "0.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d91ed6c858b01f942cd56b37a94b3e0a1798290327d1236e4d9cf4eaca44d29d" +dependencies = [ + "base64ct", + "der", +] + [[package]] name = "strsim" version = "0.11.1" @@ -2597,12 +2669,15 @@ version = "5.1.0" dependencies = [ "async-trait", "atomic_enum", + "bitflags", "bytes", + "ed25519", "event-listener", "fastwebsockets", "flume", "futures", "futures-timer", + "getrandom", "nohash-hasher", "pin-project-lite", "tokio", diff --git a/client/src/stream_provider.rs b/client/src/stream_provider.rs index 8ad3f00..11c1097 100644 --- a/client/src/stream_provider.rs +++ b/client/src/stream_provider.rs @@ -109,7 +109,7 @@ impl StreamProvider { let extensions_vec: Vec> = vec![Box::new(UdpProtocolExtensionBuilder)]; let extensions = if self.wisp_v2 { - Some(extensions_vec.as_slice()) + Some(extensions_vec) } else { None }; diff --git a/server/src/handle/wisp.rs b/server/src/handle/wisp.rs index 6b8af0b..1e59863 100644 --- a/server/src/handle/wisp.rs +++ b/server/src/handle/wisp.rs @@ -230,7 +230,7 @@ pub async fn handle_wisp(stream: WispResult, id: String) -> anyhow::Result<()> { } } - let (mux, fut) = ServerMux::create(read, write, buffer_size, extensions.as_deref()) + let (mux, fut) = ServerMux::create(read, write, buffer_size, extensions) .await .context("failed to create server multiplexor")? .with_no_required_extensions(); diff --git a/simple-wisp-client/src/main.rs b/simple-wisp-client/src/main.rs index 14bfba1..870e828 100644 --- a/simple-wisp-client/src/main.rs +++ b/simple-wisp-client/src/main.rs @@ -163,7 +163,7 @@ async fn main() -> Result<(), Box> { .await? .with_no_required_extensions() } else { - ClientMux::create(rx, tx, Some(extensions.as_slice())) + ClientMux::create(rx, tx, Some(extensions)) .await? .with_required_extensions(extension_ids.as_slice()) .await? diff --git a/wisp/Cargo.toml b/wisp/Cargo.toml index 4e70af4..8799d75 100644 --- a/wisp/Cargo.toml +++ b/wisp/Cargo.toml @@ -13,21 +13,25 @@ categories = ["network-programming", "asynchronous", "web-programming::websocket [dependencies] async-trait = "0.1.81" atomic_enum = "0.3.0" +bitflags = { version = "2.6.0", optional = true, features = ["std"] } bytes = "1.7.1" +ed25519 = { version = "2.2.3", optional = true, features = ["pem", "zeroize"] } event-listener = "5.3.1" fastwebsockets = { version = "0.8.0", features = ["unstable-split"], optional = true } flume = "0.11.0" futures = "0.3.30" futures-timer = "3.0.3" +getrandom = { version = "0.2.15", features = ["std"], optional = true } nohash-hasher = "0.2.0" pin-project-lite = "0.2.14" tokio = { version = "1.39.3", optional = true, default-features = false } [features] -default = ["generic_stream"] +default = ["generic_stream", "certificate"] fastwebsockets = ["dep:fastwebsockets", "dep:tokio"] generic_stream = [] -wasm = ["futures-timer/wasm-bindgen"] +wasm = ["futures-timer/wasm-bindgen", "getrandom/js"] +certificate = ["dep:ed25519", "dep:bitflags", "dep:getrandom"] [package.metadata.docs.rs] all-features = true diff --git a/wisp/src/extensions/cert.rs b/wisp/src/extensions/cert.rs new file mode 100644 index 0000000..fc4bcf0 --- /dev/null +++ b/wisp/src/extensions/cert.rs @@ -0,0 +1,319 @@ +//! Certificate authentication protocol extension. +//! + +use std::sync::Arc; + +use async_trait::async_trait; +use bytes::{Buf, BufMut, Bytes, BytesMut}; +use ed25519::{ + signature::{Signer, Verifier}, + Signature, +}; + +use crate::{ + ws::{LockedWebSocketWrite, WebSocketRead}, + Role, WispError, +}; + +use super::{AnyProtocolExtension, ProtocolExtension, ProtocolExtensionBuilder}; + +/// Certificate authentication protocol extension error. +#[derive(Debug)] +pub enum CertAuthError { + /// Invalid or unsupported certificate type + InvalidCertType, + /// Invalid signature + InvalidSignature, + /// ED25519 error + Ed25519(ed25519::Error), + /// Getrandom error + Getrandom(getrandom::Error), +} + +impl std::fmt::Display for CertAuthError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::InvalidCertType => write!(f, "Invalid or unsupported certificate type"), + Self::InvalidSignature => write!(f, "Invalid signature"), + Self::Ed25519(x) => write!(f, "ED25519: {:?}", x), + Self::Getrandom(x) => write!(f, "getrandom: {:?}", x), + } + } +} +impl std::error::Error for CertAuthError {} + +impl From for CertAuthError { + fn from(value: ed25519::Error) -> Self { + CertAuthError::Ed25519(value) + } +} +impl From for CertAuthError { + fn from(value: getrandom::Error) -> Self { + CertAuthError::Getrandom(value) + } +} +impl From for WispError { + fn from(value: CertAuthError) -> Self { + WispError::ExtensionImplError(Box::new(value)) + } +} + +bitflags::bitflags! { + /// Supported certificate types for certificate authentication. + #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] + pub struct SupportedCertificateTypes: u8 { + /// ED25519 certificate. + const Ed25519 = 0b00000001; + } +} + +/// Verification key. +#[derive(Clone)] +pub struct VerifyKey { + /// Certificate type of the keypair. + pub cert_type: SupportedCertificateTypes, + /// SHA-512 hash of the public key. + pub hash: [u8; 64], + /// Verifier. + pub verifier: Arc>, +} + +/// Signing key. +#[derive(Clone)] +pub struct SigningKey { + /// Certificate type of the keypair. + pub cert_type: SupportedCertificateTypes, + /// SHA-512 hash of the public key. + pub hash: [u8; 64], + /// Signer. + pub signer: Arc>, +} + +/// Certificate authentication protocol extension. +#[derive(Debug, Clone)] +pub enum CertAuthProtocolExtension { + /// Server variant of certificate authentication protocol extension. + Server { + /// Supported certificate types on the server. + cert_types: SupportedCertificateTypes, + /// Random challenge for the client. + challenge: Bytes, + }, + /// Client variant of certificate authentication protocol extension. + Client { + /// Chosen certificate type. + cert_type: SupportedCertificateTypes, + /// Hash of public key. + hash: [u8; 64], + /// Signature of challenge. + signature: Bytes, + }, + /// Marker that client has successfully signed the challenge. + ClientSigned, + /// Marker that server has successfully verified the client. + ServerVerified, +} + +impl CertAuthProtocolExtension { + /// ID of certificate authentication protocol extension. + pub const ID: u8 = 0x03; +} + +#[async_trait] +impl ProtocolExtension for CertAuthProtocolExtension { + fn get_id(&self) -> u8 { + Self::ID + } + + fn get_supported_packets(&self) -> &'static [u8] { + &[] + } + fn get_congestion_stream_types(&self) -> &'static [u8] { + &[] + } + + fn encode(&self) -> Bytes { + match self { + Self::Server { + cert_types, + challenge, + } => { + let mut out = BytesMut::with_capacity(1 + challenge.len()); + out.put_u8(cert_types.bits()); + out.extend_from_slice(challenge); + out.freeze() + } + Self::Client { + cert_type, + hash, + signature, + } => { + let mut out = BytesMut::with_capacity(1 + signature.len()); + out.put_u8(cert_type.bits()); + out.extend_from_slice(hash); + out.extend_from_slice(signature); + out.freeze() + } + Self::ClientSigned => Bytes::new(), + Self::ServerVerified => Bytes::new(), + } + } + + async fn handle_handshake( + &mut self, + _: &mut dyn WebSocketRead, + _: &LockedWebSocketWrite, + ) -> Result<(), WispError> { + Ok(()) + } + + async fn handle_packet( + &mut self, + _: Bytes, + _: &mut dyn WebSocketRead, + _: &LockedWebSocketWrite, + ) -> Result<(), WispError> { + Ok(()) + } + + fn box_clone(&self) -> Box { + Box::new(self.clone()) + } +} + +impl From for AnyProtocolExtension { + fn from(value: CertAuthProtocolExtension) -> Self { + AnyProtocolExtension(Box::new(value)) + } +} + +/// Certificate authentication protocol extension builder. +pub enum CertAuthProtocolExtensionBuilder { + /// Server variant of certificate authentication protocol extension before the challenge has + /// been sent. + ServerBeforeChallenge { + /// Keypair verifiers. + verifiers: Vec, + }, + /// Server variant of certificate authentication protocol extension after the challenge has + /// been sent. + ServerAfterChallenge { + /// Keypair verifiers. + verifiers: Vec, + /// Challenge to verify against. + challenge: Bytes, + }, + /// Client variant of certificate authentication protocol extension before the challenge has + /// been recieved. + ClientBeforeChallenge { + /// Keypair signer. + signer: SigningKey, + }, + /// Client variant of certificate authentication protocol extension after the challenge has + /// been recieved. + ClientAfterChallenge { + /// Keypair signer. + signer: SigningKey, + /// Signature of challenge recieved from the server. + signature: Bytes, + }, +} + +#[async_trait] +impl ProtocolExtensionBuilder for CertAuthProtocolExtensionBuilder { + fn get_id(&self) -> u8 { + CertAuthProtocolExtension::ID + } + + // client: 1 + // server: 2 + fn build_from_bytes( + &mut self, + mut bytes: Bytes, + _: Role, + ) -> Result { + match self { + // server should have already sent the challenge before recieving a response to parse + Self::ServerBeforeChallenge { .. } => Err(WispError::ExtensionImplNotSupported), + Self::ServerAfterChallenge { + verifiers, + challenge, + } => { + // validate and parse response + let cert_type = SupportedCertificateTypes::from_bits(bytes.get_u8()) + .ok_or(CertAuthError::InvalidCertType)?; + let hash = bytes.split_to(64); + let sig = Signature::from_slice(&bytes).map_err(CertAuthError::from)?; + let is_valid = verifiers + .iter() + .filter(|x| x.cert_type == cert_type && x.hash == *hash) + .any(|x| x.verifier.verify(challenge, &sig).is_ok()); + + if is_valid { + Ok(CertAuthProtocolExtension::ServerVerified.into()) + } else { + Err(CertAuthError::InvalidSignature.into()) + } + } + Self::ClientBeforeChallenge { signer } => { + // sign challenge + let cert_types = SupportedCertificateTypes::from_bits(bytes.get_u8()) + .ok_or(CertAuthError::InvalidCertType)?; + if !cert_types.iter().any(|x| x == signer.cert_type) { + return Err(CertAuthError::InvalidCertType.into()); + } + + let signed: Bytes = signer + .signer + .try_sign(&bytes) + .map_err(CertAuthError::from)? + .to_vec() + .into(); + + *self = Self::ClientAfterChallenge { + signer: signer.clone(), + signature: signed, + }; + + Ok(CertAuthProtocolExtension::ClientSigned.into()) + } + // client has already recieved a challenge + Self::ClientAfterChallenge { .. } => Err(WispError::ExtensionImplNotSupported), + } + } + + // client: 2 + // server: 1 + fn build_to_extension(&mut self, _: Role) -> Result { + match self { + Self::ServerBeforeChallenge { verifiers } => { + let mut challenge = BytesMut::with_capacity(64); + getrandom::getrandom(&mut challenge).map_err(CertAuthError::from)?; + let challenge = challenge.freeze(); + + *self = Self::ServerAfterChallenge { + verifiers: verifiers.to_vec(), + challenge: challenge.clone(), + }; + + Ok(CertAuthProtocolExtension::Server { + cert_types: SupportedCertificateTypes::Ed25519, + challenge, + } + .into()) + } + // server has already sent a challenge + Self::ServerAfterChallenge { .. } => Err(WispError::ExtensionImplNotSupported), + // client needs to recieve a challenge + Self::ClientBeforeChallenge { .. } => Err(WispError::ExtensionImplNotSupported), + Self::ClientAfterChallenge { signer, signature } => { + Ok(CertAuthProtocolExtension::Client { + cert_type: signer.cert_type, + hash: signer.hash, + signature: signature.clone(), + } + .into()) + } + } + } +} diff --git a/wisp/src/extensions/mod.rs b/wisp/src/extensions/mod.rs index 141d45d..c1ab7be 100644 --- a/wisp/src/extensions/mod.rs +++ b/wisp/src/extensions/mod.rs @@ -1,6 +1,8 @@ //! Wisp protocol extensions. pub mod password; pub mod udp; +#[cfg(feature = "certificate")] +pub mod cert; use std::ops::{Deref, DerefMut}; @@ -102,9 +104,9 @@ pub trait ProtocolExtensionBuilder { fn get_id(&self) -> u8; /// Build a protocol extension from the extension's metadata. - fn build_from_bytes(&self, bytes: Bytes, role: Role) + fn build_from_bytes(&mut self, bytes: Bytes, role: Role) -> Result; /// Build a protocol extension to send to the other side. - fn build_to_extension(&self, role: Role) -> AnyProtocolExtension; + fn build_to_extension(&mut self, role: Role) -> Result; } diff --git a/wisp/src/extensions/password.rs b/wisp/src/extensions/password.rs index 6246c6c..6d6b6ca 100644 --- a/wisp/src/extensions/password.rs +++ b/wisp/src/extensions/password.rs @@ -2,33 +2,6 @@ //! //! Passwords are sent in plain text!! //! -//! # Example -//! Server: -//! ``` -//! let mut passwords = HashMap::new(); -//! passwords.insert("user1".to_string(), "pw".to_string()); -//! let (mux, fut) = ServerMux::new( -//! rx, -//! tx, -//! 128, -//! Some(&[Box::new(PasswordProtocolExtensionBuilder::new_server(passwords))]) -//! ); -//! ``` -//! -//! Client: -//! ``` -//! let (mux, fut) = ClientMux::new( -//! rx, -//! tx, -//! 128, -//! Some(&[ -//! Box::new(PasswordProtocolExtensionBuilder::new_client( -//! "user1".to_string(), -//! "pw".to_string() -//! )) -//! ]) -//! ); -//! ``` //! See [the docs](https://github.com/MercuryWorkshop/wisp-protocol/blob/v2/protocol.md#0x02---password-authentication) use std::{collections::HashMap, error::Error, fmt::Display, string::FromUtf8Error}; @@ -223,7 +196,7 @@ impl ProtocolExtensionBuilder for PasswordProtocolExtensionBuilder { } fn build_from_bytes( - &self, + &mut self, mut payload: Bytes, role: crate::Role, ) -> Result { @@ -268,13 +241,13 @@ impl ProtocolExtensionBuilder for PasswordProtocolExtensionBuilder { } } - fn build_to_extension(&self, role: Role) -> AnyProtocolExtension { - match role { + fn build_to_extension(&mut self, role: Role) -> Result { + Ok(match role { Role::Server => PasswordProtocolExtension::new_server(), Role::Client => { PasswordProtocolExtension::new_client(self.username.clone(), self.password.clone()) } } - .into() + .into()) } } diff --git a/wisp/src/extensions/udp.rs b/wisp/src/extensions/udp.rs index b2b5150..c1e025b 100644 --- a/wisp/src/extensions/udp.rs +++ b/wisp/src/extensions/udp.rs @@ -1,14 +1,5 @@ //! UDP protocol extension. //! -//! # Example -//! ``` -//! let (mux, fut) = ServerMux::new( -//! rx, -//! tx, -//! 128, -//! Some(&[Box::new(UdpProtocolExtensionBuilder)]) -//! ); -//! ``` //! See [the docs](https://github.com/MercuryWorkshop/wisp-protocol/blob/v2/protocol.md#0x01---udp) use async_trait::async_trait; use bytes::Bytes; @@ -84,14 +75,14 @@ impl ProtocolExtensionBuilder for UdpProtocolExtensionBuilder { } fn build_from_bytes( - &self, + &mut self, _: Bytes, _: crate::Role, ) -> Result { Ok(UdpProtocolExtension.into()) } - fn build_to_extension(&self, _: crate::Role) -> AnyProtocolExtension { - UdpProtocolExtension.into() + fn build_to_extension(&mut self, _: crate::Role) -> Result { + Ok(UdpProtocolExtension.into()) } } diff --git a/wisp/src/lib.rs b/wisp/src/lib.rs index 6ccd8ec..b39f7c4 100644 --- a/wisp/src/lib.rs +++ b/wisp/src/lib.rs @@ -1,4 +1,4 @@ -#![deny(missing_docs)] +#![deny(missing_docs, clippy::todo)] #![cfg_attr(docsrs, feature(doc_cfg))] //! A library for easily creating [Wisp] clients and servers. //! @@ -162,7 +162,7 @@ impl std::error::Error for WispError {} async fn maybe_wisp_v2( read: &mut R, write: &LockedWebSocketWrite, - builders: &[Box], + builders: &mut [Box], ) -> Result<(Vec, Option>, bool), WispError> where R: ws::WebSocketRead + Send, @@ -195,25 +195,24 @@ where Ok((supported_extensions, extra_packet, downgraded)) } +async fn send_info_packet( + write: &LockedWebSocketWrite, + builders: &mut [Box], +) -> Result<(), WispError> { + write + .write_frame( + Packet::new_info( + builders + .iter_mut() + .map(|x| x.build_to_extension(Role::Server)) + .collect::, _>>()?, + ) + .into(), + ) + .await +} + /// Server-side multiplexor. -/// -/// # Example -/// ``` -/// use wisp_mux::ServerMux; -/// -/// let (mux, fut) = ServerMux::new(rx, tx, 128, Some([])); -/// tokio::spawn(async move { -/// if let Err(e) = fut.await { -/// println!("error in multiplexor: {:?}", e); -/// } -/// }); -/// while let Some((packet, stream)) = mux.server_new_stream().await { -/// tokio::spawn(async move { -/// let url = format!("{}:{}", packet.destination_hostname, packet.destination_port); -/// // do something with `url` and `packet.stream_type` -/// }); -/// } -/// ``` pub struct ServerMux { /// Whether the connection was downgraded to Wisp v1. /// @@ -237,7 +236,7 @@ impl ServerMux { mut rx: R, tx: W, buffer_size: u32, - extension_builders: Option<&[Box]>, + extension_builders: Option>>, ) -> Result> + Send>, WispError> where R: ws::WebSocketRead + Send, @@ -249,18 +248,9 @@ impl ServerMux { .await?; let (supported_extensions, extra_packet, downgraded) = - if let Some(builders) = extension_builders { - tx.write_frame( - Packet::new_info( - builders - .iter() - .map(|x| x.build_to_extension(Role::Client)) - .collect(), - ) - .into(), - ) - .await?; - maybe_wisp_v2(&mut rx, &tx, builders).await? + if let Some(mut builders) = extension_builders { + send_info_packet(&tx, &mut builders).await?; + maybe_wisp_v2(&mut rx, &tx, &mut builders).await? } else { (Vec::new(), None, true) }; @@ -367,7 +357,9 @@ where } else { self.0.close_extension_incompat().await?; self.1.await?; - Err(WispError::ExtensionsNotSupported(unsupported_extensions)) + Err(WispError::ExtensionsNotSupported( + unsupported_extensions, + )) } } @@ -379,19 +371,6 @@ where } /// Client side multiplexor. -/// -/// # Example -/// ``` -/// use wisp_mux::{ClientMux, StreamType}; -/// -/// let (mux, fut) = ClientMux::new(rx, tx, Some([])).await?; -/// tokio::spawn(async move { -/// if let Err(e) = fut.await { -/// println!("error in multiplexor: {:?}", e); -/// } -/// }); -/// let stream = mux.client_new_stream(StreamType::Tcp, "google.com", 80); -/// ``` pub struct ClientMux { /// Whether the connection was downgraded to Wisp v1. /// @@ -413,7 +392,7 @@ impl ClientMux { pub async fn create( mut rx: R, tx: W, - extension_builders: Option<&[Box]>, + extension_builders: Option>>, ) -> Result> + Send>, WispError> where R: ws::WebSocketRead + Send, @@ -428,22 +407,13 @@ impl ClientMux { if let PacketType::Continue(packet) = first_packet.packet_type { let (supported_extensions, extra_packet, downgraded) = - if let Some(builders) = extension_builders { - let x = maybe_wisp_v2(&mut rx, &tx, builders).await?; + if let Some(mut builders) = extension_builders { + let res = maybe_wisp_v2(&mut rx, &tx, &mut builders).await?; // if not downgraded - if !x.2 { - tx.write_frame( - Packet::new_info( - builders - .iter() - .map(|x| x.build_to_extension(Role::Client)) - .collect(), - ) - .into(), - ) - .await?; + if !res.2 { + send_info_packet(&tx, &mut builders).await?; } - x + res } else { (Vec::new(), None, true) }; diff --git a/wisp/src/packet.rs b/wisp/src/packet.rs index ce857d2..0463f8f 100644 --- a/wisp/src/packet.rs +++ b/wisp/src/packet.rs @@ -428,7 +428,7 @@ impl<'a> Packet<'a> { pub(crate) fn maybe_parse_info( frame: Frame<'a>, role: Role, - extension_builders: &[Box<(dyn ProtocolExtensionBuilder + Send + Sync)>], + extension_builders: &mut [Box<(dyn ProtocolExtensionBuilder + Send + Sync)>], ) -> Result { if !frame.finished { return Err(WispError::WsFrameNotFinished); @@ -502,7 +502,7 @@ impl<'a> Packet<'a> { fn parse_info( mut bytes: Payload<'a>, role: Role, - extension_builders: &[Box<(dyn ProtocolExtensionBuilder + Send + Sync)>], + extension_builders: &mut [Box<(dyn ProtocolExtensionBuilder + Send + Sync)>], ) -> Result { // packet type is already read by code that calls this if bytes.remaining() < 4 + 2 { @@ -530,10 +530,8 @@ impl<'a> Packet<'a> { if bytes.remaining() < length { return Err(WispError::PacketTooSmall); } - if let Some(builder) = extension_builders.iter().find(|x| x.get_id() == id) { - if let Ok(extension) = builder.build_from_bytes(bytes.copy_to_bytes(length), role) { - extensions.push(extension) - } + if let Some(builder) = extension_builders.iter_mut().find(|x| x.get_id() == id) { + extensions.push(builder.build_from_bytes(bytes.copy_to_bytes(length), role)?) } else { bytes.advance(length) }