diff --git a/server/src/config.rs b/server/src/config.rs index ea73aa2..50c705f 100644 --- a/server/src/config.rs +++ b/server/src/config.rs @@ -162,9 +162,11 @@ pub struct WispConfig { #[serde(skip_serializing_if = "HashMap::is_empty")] /// Wisp version 2 password authentication extension username/passwords. pub password_extension_users: HashMap, + pub password_extension_required: bool, #[serde(skip_serializing_if = "Vec::is_empty")] /// Wisp version 2 certificate authentication extension public ed25519 pem keys. pub certificate_extension_keys: Vec, + pub certificate_extension_required: bool, #[serde(skip_serializing_if = "is_default_motd")] /// Wisp version 2 MOTD extension message. @@ -334,7 +336,9 @@ impl Default for WispConfig { auth_extension: None, password_extension_users: HashMap::new(), + password_extension_required: true, certificate_extension_keys: Vec::new(), + certificate_extension_required: true, motd_extension: default_motd(), } @@ -364,18 +368,24 @@ impl WispConfig { extensions.push(AnyProtocolExtensionBuilder::new( PasswordProtocolExtensionBuilder::new_server( self.password_extension_users.clone(), + self.password_extension_required, ), )); - required_extensions.push(PasswordProtocolExtension::ID); + if self.password_extension_required { + required_extensions.push(PasswordProtocolExtension::ID); + } } Some(ProtocolExtensionAuth::Certificate) => { extensions.push(AnyProtocolExtensionBuilder::new( CertAuthProtocolExtensionBuilder::new_server( get_certificates_from_paths(self.certificate_extension_keys.clone()) .await?, + self.certificate_extension_required, ), )); - required_extensions.push(CertAuthProtocolExtension::ID); + if self.certificate_extension_required { + required_extensions.push(CertAuthProtocolExtension::ID); + } } None => {} } diff --git a/server/src/handle/wisp/twisp.rs b/server/src/handle/wisp/twisp.rs index 00363bc..9a556ef 100644 --- a/server/src/handle/wisp/twisp.rs +++ b/server/src/handle/wisp/twisp.rs @@ -80,12 +80,6 @@ impl ProtocolExtension for TWispServerProtocolExtension { } } -impl From for AnyProtocolExtension { - fn from(value: TWispServerProtocolExtension) -> Self { - AnyProtocolExtension::new(value) - } -} - pub struct TWispServerProtocolExtensionBuilder(TwispMap); impl ProtocolExtensionBuilder for TWispServerProtocolExtensionBuilder { @@ -124,7 +118,7 @@ pub fn new_map() -> TwispMap { } pub fn new_ext(map: TwispMap) -> AnyProtocolExtensionBuilder { - AnyProtocolExtensionBuilder::new(TWispServerProtocolExtensionBuilder(map)) + TWispServerProtocolExtensionBuilder(map).into() } pub async fn handle_twisp( diff --git a/simple-wisp-client/src/main.rs b/simple-wisp-client/src/main.rs index a3c705a..8d52f74 100644 --- a/simple-wisp-client/src/main.rs +++ b/simple-wisp-client/src/main.rs @@ -139,7 +139,7 @@ async fn main() -> Result<(), Box> { let split: Vec<_> = auth.split(':').collect(); let username = split[0].to_string(); let password = split[1..].join(":"); - PasswordProtocolExtensionBuilder::new_client(username, password) + PasswordProtocolExtensionBuilder::new_client(Some((username, password))) }); println!( @@ -188,7 +188,7 @@ async fn main() -> Result<(), Box> { } if let Some(certauth) = opts.certauth { let key = get_cert(certauth).await?; - let extension = CertAuthProtocolExtensionBuilder::new_client(key); + let extension = CertAuthProtocolExtensionBuilder::new_client(Some(key)); extensions.push(AnyProtocolExtensionBuilder::new(extension)); extension_ids.push(CertAuthProtocolExtension::ID); } diff --git a/wisp/src/extensions/cert.rs b/wisp/src/extensions/cert.rs index 981ba16..b5b2075 100644 --- a/wisp/src/extensions/cert.rs +++ b/wisp/src/extensions/cert.rs @@ -116,6 +116,8 @@ pub enum CertAuthProtocolExtension { cert_types: SupportedCertificateTypes, /// Random challenge for the client. challenge: Bytes, + /// Whether the server requires this authentication method. + required: bool, }, /// Client variant of certificate authentication protocol extension. Client { @@ -126,8 +128,8 @@ pub enum CertAuthProtocolExtension { /// Signature of challenge. signature: Bytes, }, - /// Marker that client has successfully signed the challenge. - ClientSigned, + /// Marker that client has successfully recieved the challenge. + ClientRecieved, /// Marker that server has successfully verified the client. ServerVerified, } @@ -155,8 +157,10 @@ impl ProtocolExtension for CertAuthProtocolExtension { Self::Server { cert_types, challenge, + required, } => { - let mut out = BytesMut::with_capacity(1 + challenge.len()); + let mut out = BytesMut::with_capacity(2 + challenge.len()); + out.put_u8(*required as u8); out.put_u8(cert_types.bits()); out.extend_from_slice(challenge); out.freeze() @@ -172,7 +176,7 @@ impl ProtocolExtension for CertAuthProtocolExtension { out.extend_from_slice(signature); out.freeze() } - Self::ClientSigned => Bytes::new(), + Self::ClientRecieved => Bytes::new(), Self::ServerVerified => Bytes::new(), } } @@ -199,12 +203,6 @@ impl ProtocolExtension for CertAuthProtocolExtension { } } -impl From for AnyProtocolExtension { - fn from(value: CertAuthProtocolExtension) -> Self { - AnyProtocolExtension(Box::new(value)) - } -} - /// Certificate authentication protocol extension builder. pub enum CertAuthProtocolExtensionBuilder { /// Server variant of certificate authentication protocol extension before the challenge has @@ -212,6 +210,8 @@ pub enum CertAuthProtocolExtensionBuilder { ServerBeforeChallenge { /// Keypair verifiers. verifiers: Vec, + /// Whether the server requires this authentication method. + required: bool, }, /// Server variant of certificate authentication protocol extension after the challenge has /// been sent. @@ -220,33 +220,63 @@ pub enum CertAuthProtocolExtensionBuilder { verifiers: Vec, /// Challenge to verify against. challenge: Bytes, + /// Whether the server requires this authentication method. + required: bool, }, /// Client variant of certificate authentication protocol extension before the challenge has /// been recieved. ClientBeforeChallenge { /// Keypair signer. - signer: SigningKey, + signer: Option, }, /// Client variant of certificate authentication protocol extension after the challenge has /// been recieved. ClientAfterChallenge { /// Keypair signer. - signer: SigningKey, - /// Signature of challenge recieved from the server. - signature: Bytes, + signer: Option, + /// Supported certificate types recieved from the server. + cert_types: SupportedCertificateTypes, + /// Challenge recieved from the server. + challenge: Bytes, + /// Whether the server requires this authentication method. + required: bool, }, } impl CertAuthProtocolExtensionBuilder { /// Create a new server variant of the certificate authentication protocol extension. - pub fn new_server(verifiers: Vec) -> Self { - Self::ServerBeforeChallenge { verifiers } + pub fn new_server(verifiers: Vec, required: bool) -> Self { + Self::ServerBeforeChallenge { + verifiers, + required, + } } /// Create a new client variant of the certificate authentication protocol extension. - pub fn new_client(signer: SigningKey) -> Self { + pub fn new_client(signer: Option) -> Self { Self::ClientBeforeChallenge { signer } } + + /// Get whether this authentication method is required. Could return None if the server has not + /// sent the certificate authentication protocol extension. + pub fn is_required(&self) -> Option { + match self { + Self::ServerBeforeChallenge { required, .. } => Some(*required), + Self::ServerAfterChallenge { required, .. } => Some(*required), + Self::ClientBeforeChallenge { .. } => None, + Self::ClientAfterChallenge { required, .. } => Some(*required), + } + } + + /// Set the credentials sent to the server, if this is a client variant. + pub fn set_signing_key(&mut self, key: SigningKey) { + match self { + Self::ClientBeforeChallenge { signer } | Self::ClientAfterChallenge { signer, .. } => { + *signer = Some(key); + } + Self::ServerBeforeChallenge { .. } | Self::ServerAfterChallenge { .. } => {} + } + } } #[async_trait] @@ -268,6 +298,7 @@ impl ProtocolExtensionBuilder for CertAuthProtocolExtensionBuilder { Self::ServerAfterChallenge { verifiers, challenge, + .. } => { // validate and parse response let cert_type = SupportedCertificateTypes::from_bits(bytes.get_u8()) @@ -286,26 +317,19 @@ impl ProtocolExtensionBuilder for CertAuthProtocolExtensionBuilder { } } Self::ClientBeforeChallenge { signer } => { + let required = bytes.get_u8() != 0; // sign challenge let cert_types = SupportedCertificateTypes::from_bits(bytes.get_u8()) - .ok_or(WispError::CertAuthExtensionSigInvalid)?; - if !cert_types.iter().any(|x| x == signer.cert_type) { - return Err(WispError::CertAuthExtensionSigInvalid); - } - - let signed: Bytes = signer - .signer - .try_sign(&bytes) - .map_err(CertAuthError::from)? - .to_vec() - .into(); + .ok_or(WispError::CertAuthExtensionCertTypeInvalid)?; *self = Self::ClientAfterChallenge { signer: signer.clone(), - signature: signed, + cert_types, + challenge: bytes, + required, }; - Ok(CertAuthProtocolExtension::ClientSigned.into()) + Ok(CertAuthProtocolExtension::ClientRecieved.into()) } // client has already recieved a challenge Self::ClientAfterChallenge { .. } => Err(WispError::ExtensionImplNotSupported), @@ -316,19 +340,26 @@ impl ProtocolExtensionBuilder for CertAuthProtocolExtensionBuilder { // server: 1 fn build_to_extension(&mut self, _: Role) -> Result { match self { - Self::ServerBeforeChallenge { verifiers } => { + Self::ServerBeforeChallenge { + verifiers, + required, + } => { let mut challenge = [0u8; 64]; getrandom::getrandom(&mut challenge).map_err(CertAuthError::from)?; let challenge = Bytes::from(challenge.to_vec()); + let required = *required; + *self = Self::ServerAfterChallenge { verifiers: verifiers.to_vec(), challenge: challenge.clone(), + required, }; Ok(CertAuthProtocolExtension::Server { cert_types: SupportedCertificateTypes::Ed25519, challenge, + required, } .into()) } @@ -336,7 +367,24 @@ impl ProtocolExtensionBuilder for CertAuthProtocolExtensionBuilder { Self::ServerAfterChallenge { .. } => Err(WispError::ExtensionImplNotSupported), // client needs to recieve a challenge Self::ClientBeforeChallenge { .. } => Err(WispError::ExtensionImplNotSupported), - Self::ClientAfterChallenge { signer, signature } => { + Self::ClientAfterChallenge { + signer, + challenge, + cert_types, + .. + } => { + let signer = signer.as_ref().ok_or(WispError::CertAuthExtensionNoKey)?; + if !cert_types.iter().any(|x| x == signer.cert_type) { + return Err(WispError::CertAuthExtensionCertTypeInvalid); + } + + let signature: Bytes = signer + .signer + .try_sign(challenge) + .map_err(CertAuthError::from)? + .to_vec() + .into(); + Ok(CertAuthProtocolExtension::Client { cert_type: signer.cert_type, hash: signer.hash, diff --git a/wisp/src/extensions/mod.rs b/wisp/src/extensions/mod.rs index 3c81875..8d2e936 100644 --- a/wisp/src/extensions/mod.rs +++ b/wisp/src/extensions/mod.rs @@ -74,6 +74,12 @@ impl From for Bytes { } } +impl From for AnyProtocolExtension { + fn from(value: T) -> Self { + Self::new(value) + } +} + /// A Wisp protocol extension. /// /// See [the @@ -162,6 +168,8 @@ pub trait ProtocolExtensionBuilder: Sync + Send + 'static { fn get_id(&self) -> u8; /// Build a protocol extension from the extension's metadata. + /// + /// This is called second on the server and first on the client. fn build_from_bytes( &mut self, bytes: Bytes, @@ -169,6 +177,8 @@ pub trait ProtocolExtensionBuilder: Sync + Send + 'static { ) -> Result; /// Build a protocol extension to send to the other side. + /// + /// This is called first on the server and second on the client. fn build_to_extension(&mut self, role: Role) -> Result; /// Do not override. @@ -248,3 +258,9 @@ impl DerefMut for AnyProtocolExtensionBuilder { self.0.deref_mut() } } + +impl From for AnyProtocolExtensionBuilder { + fn from(value: T) -> Self { + Self::new(value) + } +} diff --git a/wisp/src/extensions/motd.rs b/wisp/src/extensions/motd.rs index f835486..2c1c17a 100644 --- a/wisp/src/extensions/motd.rs +++ b/wisp/src/extensions/motd.rs @@ -68,12 +68,6 @@ impl ProtocolExtension for MotdProtocolExtension { } } -impl From for AnyProtocolExtension { - fn from(value: MotdProtocolExtension) -> Self { - AnyProtocolExtension(Box::new(value)) - } -} - /// MOTD protocol extension builder. pub enum MotdProtocolExtensionBuilder { /// Server variant of MOTD protocol extension builder. Has the MOTD. diff --git a/wisp/src/extensions/password.rs b/wisp/src/extensions/password.rs index c20e01b..ad5f71f 100644 --- a/wisp/src/extensions/password.rs +++ b/wisp/src/extensions/password.rs @@ -1,9 +1,8 @@ //! Password protocol extension. //! -//! Passwords are sent in plain text!! +//! **Passwords are sent in plain text!!** //! //! See [the docs](https://github.com/MercuryWorkshop/wisp-protocol/blob/v2/protocol.md#0x02---password-authentication) - use std::collections::HashMap; use async_trait::async_trait; @@ -16,56 +15,53 @@ use crate::{ use super::{AnyProtocolExtension, ProtocolExtension, ProtocolExtensionBuilder}; -#[derive(Debug, Clone)] +/// ID of password protocol extension. +pub const PASSWORD_PROTOCOL_EXTENSION_ID: u8 = 0x02; + /// Password protocol extension. /// /// **Passwords are sent in plain text!!** -/// **This extension will panic when encoding if the username's length does not fit within a u8 -/// or the password's length does not fit within a u16.** -pub struct PasswordProtocolExtension { - /// The username to log in with. - /// - /// This string's length must fit within a u8. - pub username: String, - /// The password to log in with. - /// - /// This string's length must fit within a u16. - pub password: String, - role: Role, +/// +/// See [the docs](https://github.com/MercuryWorkshop/wisp-protocol/blob/v2/protocol.md#0x02---password-authentication) +#[derive(Debug, Clone)] +pub enum PasswordProtocolExtension { + /// Password protocol extension before the client INFO packet has been received. + ServerBeforeClientInfo { + /// Whether this authentication method is required. + required: bool, + }, + /// Password protocol extension after the client INFO packet has been received. + ServerAfterClientInfo { + /// The client's chosen user. + chosen_user: String, + /// The client's chosen password + chosen_password: String, + }, + + /// Password protocol extension before the server INFO has been received. + ClientBeforeServerInfo, + /// Password protocol extension after the server INFO has been received. + ClientAfterServerInfo { + /// The user to send to the server. + user: String, + /// The password to send to the user. + password: String, + }, } impl PasswordProtocolExtension { - /// Password protocol extension ID. - pub const ID: u8 = 0x02; - - /// Create a new password protocol extension for the server. - /// - /// This signifies that the server requires a password. - pub fn new_server() -> Self { - Self { - username: String::new(), - password: String::new(), - role: Role::Server, - } - } - - /// Create a new password protocol extension for the client, with a username and password. - /// - /// The username's length must fit within a u8. The password's length must fit within a - /// u16. - pub fn new_client(username: String, password: String) -> Self { - Self { - username, - password, - role: Role::Client, - } - } + /// ID of password protocol extension. + pub const ID: u8 = PASSWORD_PROTOCOL_EXTENSION_ID; } #[async_trait] impl ProtocolExtension for PasswordProtocolExtension { fn get_id(&self) -> u8 { - Self::ID + PASSWORD_PROTOCOL_EXTENSION_ID + } + + fn box_clone(&self) -> Box { + Box::new(self.clone()) } fn get_supported_packets(&self) -> &'static [u8] { @@ -77,142 +73,179 @@ impl ProtocolExtension for PasswordProtocolExtension { } fn encode(&self) -> Bytes { - match self.role { - Role::Server => Bytes::new(), - Role::Client => { - let username = Bytes::from(self.username.clone().into_bytes()); - let password = Bytes::from(self.password.clone().into_bytes()); - let username_len = u8::try_from(username.len()).expect("username was too long"); - let password_len = u16::try_from(password.len()).expect("password was too long"); - - let mut bytes = - BytesMut::with_capacity(3 + username_len as usize + password_len as usize); - bytes.put_u8(username_len); - bytes.put_u16_le(password_len); - bytes.extend(username); - bytes.extend(password); - bytes.freeze() + match self { + Self::ServerBeforeClientInfo { required } => { + let mut out = BytesMut::with_capacity(1); + out.put_u8(*required as u8); + out.freeze() + } + Self::ServerAfterClientInfo { .. } => Bytes::new(), + Self::ClientBeforeServerInfo => Bytes::new(), + Self::ClientAfterServerInfo { user, password } => { + let mut out = BytesMut::with_capacity(1 + 2 + user.len() + password.len()); + out.put_u8(user.len().try_into().unwrap()); + out.put_u16_le(password.len().try_into().unwrap()); + out.extend_from_slice(user.as_bytes()); + out.extend_from_slice(password.as_bytes()); + out.freeze() } } } async fn handle_handshake( &mut self, - _: &mut dyn WebSocketRead, - _: &LockedWebSocketWrite, + _read: &mut dyn WebSocketRead, + _write: &LockedWebSocketWrite, ) -> Result<(), WispError> { Ok(()) } async fn handle_packet( &mut self, - _: Bytes, - _: &mut dyn WebSocketRead, - _: &LockedWebSocketWrite, + _packet: Bytes, + _read: &mut dyn WebSocketRead, + _write: &LockedWebSocketWrite, ) -> Result<(), WispError> { - Ok(()) - } - - fn box_clone(&self) -> Box { - Box::new(self.clone()) - } -} - -impl From for AnyProtocolExtension { - fn from(value: PasswordProtocolExtension) -> Self { - AnyProtocolExtension(Box::new(value)) + Err(WispError::ExtensionImplNotSupported) } } /// Password protocol extension builder. /// /// **Passwords are sent in plain text!!** -pub struct PasswordProtocolExtensionBuilder { - /// Map of users and their passwords to allow. Only used on server. - pub users: HashMap, - /// Username to authenticate with. Only used on client. - pub username: String, - /// Password to authenticate with. Only used on client. - pub password: String, +/// +/// See [the docs](https://github.com/MercuryWorkshop/wisp-protocol/blob/v2/protocol.md#0x02---password-authentication) +pub enum PasswordProtocolExtensionBuilder { + /// Password protocol extension builder before the client INFO has been received. + ServerBeforeClientInfo { + /// The user+password combinations to verify the client with. + users: HashMap, + /// Whether this authentication method is required. + required: bool, + }, + /// Password protocol extension builder after the client INFO has been received. + ServerAfterClientInfo { + /// The user+password combinations to verify the client with. + users: HashMap, + /// Whether this authentication method is required. + required: bool, + }, + + /// Password protocol extension builder before the server INFO has been received. + ClientBeforeServerInfo { + /// The credentials to send to the server. + creds: Option<(String, String)>, + }, + /// Password protocol extension builder after the server INFO has been received. + ClientAfterServerInfo { + /// The credentials to send to the server. + creds: Option<(String, String)>, + /// Whether this authentication method is required. + required: bool, + }, } impl PasswordProtocolExtensionBuilder { - /// Create a new password protocol extension builder for the server, with a map of users - /// and passwords to allow. - pub fn new_server(users: HashMap) -> Self { - Self { - users, - username: String::new(), - password: String::new(), + /// ID of password protocol extension. + pub const ID: u8 = PASSWORD_PROTOCOL_EXTENSION_ID; + + /// Create a new server variant of the password protocol extension. + pub fn new_server(users: HashMap, required: bool) -> Self { + Self::ServerBeforeClientInfo { users, required } + } + + /// Create a new client variant of the password protocol extension with a username and password. + pub fn new_client(creds: Option<(String, String)>) -> Self { + Self::ClientBeforeServerInfo { creds } + } + + /// Get whether this authentication method is required. Could return None if the server has not + /// sent the password protocol extension. + pub fn is_required(&self) -> Option { + match self { + Self::ServerBeforeClientInfo { required, .. } => Some(*required), + Self::ServerAfterClientInfo { required, .. } => Some(*required), + Self::ClientBeforeServerInfo { .. } => None, + Self::ClientAfterServerInfo { required, .. } => Some(*required), } } - /// Create a new password protocol extension builder for the client, with a username and - /// password to authenticate with. - pub fn new_client(username: String, password: String) -> Self { - Self { - users: HashMap::new(), - username, - password, + /// Set the credentials sent to the server, if this is a client variant. + pub fn set_creds(&mut self, credentials: (String, String)) { + match self { + Self::ClientBeforeServerInfo { creds } | Self::ClientAfterServerInfo { creds, .. } => { + *creds = Some(credentials); + } + Self::ServerBeforeClientInfo { .. } | Self::ServerAfterClientInfo { .. } => {} } } } impl ProtocolExtensionBuilder for PasswordProtocolExtensionBuilder { fn get_id(&self) -> u8 { - PasswordProtocolExtension::ID + PASSWORD_PROTOCOL_EXTENSION_ID + } + + fn build_to_extension(&mut self, _role: Role) -> Result { + match self { + Self::ServerBeforeClientInfo { users: _, required } => { + Ok(PasswordProtocolExtension::ServerBeforeClientInfo { + required: *required, + } + .into()) + } + Self::ServerAfterClientInfo { .. } => Err(WispError::ExtensionImplNotSupported), + Self::ClientBeforeServerInfo { .. } => Err(WispError::ExtensionImplNotSupported), + Self::ClientAfterServerInfo { creds, .. } => { + let (user, password) = creds.clone().ok_or(WispError::PasswordExtensionNoCreds)?; + Ok(PasswordProtocolExtension::ClientAfterServerInfo { user, password }.into()) + } + } } fn build_from_bytes( &mut self, - mut payload: Bytes, - role: crate::Role, + mut bytes: Bytes, + _role: Role, ) -> Result { - match role { - Role::Server => { - if payload.remaining() < 3 { - return Err(WispError::PacketTooSmall); - } + match self { + Self::ServerBeforeClientInfo { users, required } => { + let user_len = bytes.get_u8(); + let password_len = bytes.get_u16_le(); - let username_len = payload.get_u8(); - let password_len = payload.get_u16_le(); - if payload.remaining() < (password_len + username_len as u16) as usize { - return Err(WispError::PacketTooSmall); - } - - let username = - std::str::from_utf8(&payload.split_to(username_len as usize))?.to_string(); + let user = std::str::from_utf8(&bytes.split_to(user_len as usize))?.to_string(); let password = - std::str::from_utf8(&payload.split_to(password_len as usize))?.to_string(); + std::str::from_utf8(&bytes.split_to(password_len as usize))?.to_string(); - let Some(user) = self.users.iter().find(|x| *x.0 == username) else { - return Err(WispError::PasswordExtensionCredsInvalid); + let valid = users.get(&user).map(|x| *x == password).unwrap_or(false); + + *self = Self::ServerAfterClientInfo { + users: users.clone(), + required: *required, }; - if *user.1 != password { - return Err(WispError::PasswordExtensionCredsInvalid); + if !valid { + Err(WispError::PasswordExtensionCredsInvalid) + } else { + Ok(PasswordProtocolExtension::ServerAfterClientInfo { + chosen_user: user, + chosen_password: password, + } + .into()) } + } + Self::ServerAfterClientInfo { .. } => Err(WispError::ExtensionImplNotSupported), + Self::ClientBeforeServerInfo { creds } => { + let required = bytes.get_u8() != 0; - Ok(PasswordProtocolExtension { - username, - password, - role, - } - .into()) - } - Role::Client => { - Ok(PasswordProtocolExtension::new_client(String::new(), String::new()).into()) - } - } - } + *self = Self::ClientAfterServerInfo { + creds: creds.clone(), + required, + }; - fn build_to_extension(&mut self, role: Role) -> Result { - Ok(match role { - Role::Server => PasswordProtocolExtension::new_server(), - Role::Client => { - PasswordProtocolExtension::new_client(self.username.clone(), self.password.clone()) + Ok(PasswordProtocolExtension::ClientBeforeServerInfo.into()) } + Self::ClientAfterServerInfo { .. } => Err(WispError::ExtensionImplNotSupported), } - .into()) } } diff --git a/wisp/src/extensions/udp.rs b/wisp/src/extensions/udp.rs index c1e025b..33bea07 100644 --- a/wisp/src/extensions/udp.rs +++ b/wisp/src/extensions/udp.rs @@ -60,12 +60,6 @@ impl ProtocolExtension for UdpProtocolExtension { } } -impl From for AnyProtocolExtension { - fn from(value: UdpProtocolExtension) -> Self { - AnyProtocolExtension(Box::new(value)) - } -} - /// UDP protocol extension builder. pub struct UdpProtocolExtensionBuilder; diff --git a/wisp/src/lib.rs b/wisp/src/lib.rs index 213d214..ff4e14b 100644 --- a/wisp/src/lib.rs +++ b/wisp/src/lib.rs @@ -106,10 +106,21 @@ pub enum WispError { /// The specified protocol extensions are not supported by the other side. #[error("Protocol extensions {0:?} not supported")] ExtensionsNotSupported(Vec), + /// The password authentication username/password was invalid. #[error("Password protocol extension: Invalid username/password")] PasswordExtensionCredsInvalid, + /// No password authentication username/password was provided. + #[error("Password protocol extension: No username/password provided")] + PasswordExtensionNoCreds, + + /// The certificate authentication certificate type was unsupported. + #[error("Certificate authentication protocol extension: Invalid certificate type")] + CertAuthExtensionCertTypeInvalid, /// The certificate authentication signature was invalid. #[error("Certificate authentication protocol extension: Invalid signature")] CertAuthExtensionSigInvalid, + /// No certificate authentication signing key was provided. + #[error("Password protocol extension: No signing key provided")] + CertAuthExtensionNoKey, } diff --git a/wisp/src/mux/mod.rs b/wisp/src/mux/mod.rs index ba94b73..9166a89 100644 --- a/wisp/src/mux/mod.rs +++ b/wisp/src/mux/mod.rs @@ -124,11 +124,10 @@ where } type WispV2ClosureResult = Pin> + Sync + Send>>; -type WispV2ClosureBuilders<'a> = &'a mut [AnyProtocolExtensionBuilder]; /// Wisp V2 handshake and protocol extension settings wrapper struct. pub struct WispV2Extensions { builders: Vec, - closure: Box WispV2ClosureResult + Send>, + closure: Box) -> WispV2ClosureResult + Send>, } impl WispV2Extensions { @@ -143,7 +142,7 @@ impl WispV2Extensions { /// Create a Wisp V2 settings struct with some middleware. pub fn new_with_middleware(builders: Vec, closure: C) -> Self where - C: Fn(WispV2ClosureBuilders) -> WispV2ClosureResult + Send + 'static, + C: Fn(&mut Vec) -> WispV2ClosureResult + Send + 'static, { Self { builders,