refactor protocol extensions

This commit is contained in:
Toshit Chawda 2024-10-24 19:47:38 -07:00
parent 1ae3986a82
commit 36fddc8943
No known key found for this signature in database
GPG key ID: 91480ED99E2B3D9D
10 changed files with 289 additions and 190 deletions

View file

@ -162,9 +162,11 @@ pub struct WispConfig {
#[serde(skip_serializing_if = "HashMap::is_empty")] #[serde(skip_serializing_if = "HashMap::is_empty")]
/// Wisp version 2 password authentication extension username/passwords. /// Wisp version 2 password authentication extension username/passwords.
pub password_extension_users: HashMap<String, String>, pub password_extension_users: HashMap<String, String>,
pub password_extension_required: bool,
#[serde(skip_serializing_if = "Vec::is_empty")] #[serde(skip_serializing_if = "Vec::is_empty")]
/// Wisp version 2 certificate authentication extension public ed25519 pem keys. /// Wisp version 2 certificate authentication extension public ed25519 pem keys.
pub certificate_extension_keys: Vec<PathBuf>, pub certificate_extension_keys: Vec<PathBuf>,
pub certificate_extension_required: bool,
#[serde(skip_serializing_if = "is_default_motd")] #[serde(skip_serializing_if = "is_default_motd")]
/// Wisp version 2 MOTD extension message. /// Wisp version 2 MOTD extension message.
@ -334,7 +336,9 @@ impl Default for WispConfig {
auth_extension: None, auth_extension: None,
password_extension_users: HashMap::new(), password_extension_users: HashMap::new(),
password_extension_required: true,
certificate_extension_keys: Vec::new(), certificate_extension_keys: Vec::new(),
certificate_extension_required: true,
motd_extension: default_motd(), motd_extension: default_motd(),
} }
@ -364,18 +368,24 @@ impl WispConfig {
extensions.push(AnyProtocolExtensionBuilder::new( extensions.push(AnyProtocolExtensionBuilder::new(
PasswordProtocolExtensionBuilder::new_server( PasswordProtocolExtensionBuilder::new_server(
self.password_extension_users.clone(), self.password_extension_users.clone(),
self.password_extension_required,
), ),
)); ));
required_extensions.push(PasswordProtocolExtension::ID); if self.password_extension_required {
required_extensions.push(PasswordProtocolExtension::ID);
}
} }
Some(ProtocolExtensionAuth::Certificate) => { Some(ProtocolExtensionAuth::Certificate) => {
extensions.push(AnyProtocolExtensionBuilder::new( extensions.push(AnyProtocolExtensionBuilder::new(
CertAuthProtocolExtensionBuilder::new_server( CertAuthProtocolExtensionBuilder::new_server(
get_certificates_from_paths(self.certificate_extension_keys.clone()) get_certificates_from_paths(self.certificate_extension_keys.clone())
.await?, .await?,
self.certificate_extension_required,
), ),
)); ));
required_extensions.push(CertAuthProtocolExtension::ID); if self.certificate_extension_required {
required_extensions.push(CertAuthProtocolExtension::ID);
}
} }
None => {} None => {}
} }

View file

@ -80,12 +80,6 @@ impl ProtocolExtension for TWispServerProtocolExtension {
} }
} }
impl From<TWispServerProtocolExtension> for AnyProtocolExtension {
fn from(value: TWispServerProtocolExtension) -> Self {
AnyProtocolExtension::new(value)
}
}
pub struct TWispServerProtocolExtensionBuilder(TwispMap); pub struct TWispServerProtocolExtensionBuilder(TwispMap);
impl ProtocolExtensionBuilder for TWispServerProtocolExtensionBuilder { impl ProtocolExtensionBuilder for TWispServerProtocolExtensionBuilder {
@ -124,7 +118,7 @@ pub fn new_map() -> TwispMap {
} }
pub fn new_ext(map: TwispMap) -> AnyProtocolExtensionBuilder { pub fn new_ext(map: TwispMap) -> AnyProtocolExtensionBuilder {
AnyProtocolExtensionBuilder::new(TWispServerProtocolExtensionBuilder(map)) TWispServerProtocolExtensionBuilder(map).into()
} }
pub async fn handle_twisp( pub async fn handle_twisp(

View file

@ -139,7 +139,7 @@ async fn main() -> Result<(), Box<dyn Error + Send + Sync>> {
let split: Vec<_> = auth.split(':').collect(); let split: Vec<_> = auth.split(':').collect();
let username = split[0].to_string(); let username = split[0].to_string();
let password = split[1..].join(":"); let password = split[1..].join(":");
PasswordProtocolExtensionBuilder::new_client(username, password) PasswordProtocolExtensionBuilder::new_client(Some((username, password)))
}); });
println!( println!(
@ -188,7 +188,7 @@ async fn main() -> Result<(), Box<dyn Error + Send + Sync>> {
} }
if let Some(certauth) = opts.certauth { if let Some(certauth) = opts.certauth {
let key = get_cert(certauth).await?; let key = get_cert(certauth).await?;
let extension = CertAuthProtocolExtensionBuilder::new_client(key); let extension = CertAuthProtocolExtensionBuilder::new_client(Some(key));
extensions.push(AnyProtocolExtensionBuilder::new(extension)); extensions.push(AnyProtocolExtensionBuilder::new(extension));
extension_ids.push(CertAuthProtocolExtension::ID); extension_ids.push(CertAuthProtocolExtension::ID);
} }

View file

@ -116,6 +116,8 @@ pub enum CertAuthProtocolExtension {
cert_types: SupportedCertificateTypes, cert_types: SupportedCertificateTypes,
/// Random challenge for the client. /// Random challenge for the client.
challenge: Bytes, challenge: Bytes,
/// Whether the server requires this authentication method.
required: bool,
}, },
/// Client variant of certificate authentication protocol extension. /// Client variant of certificate authentication protocol extension.
Client { Client {
@ -126,8 +128,8 @@ pub enum CertAuthProtocolExtension {
/// Signature of challenge. /// Signature of challenge.
signature: Bytes, signature: Bytes,
}, },
/// Marker that client has successfully signed the challenge. /// Marker that client has successfully recieved the challenge.
ClientSigned, ClientRecieved,
/// Marker that server has successfully verified the client. /// Marker that server has successfully verified the client.
ServerVerified, ServerVerified,
} }
@ -155,8 +157,10 @@ impl ProtocolExtension for CertAuthProtocolExtension {
Self::Server { Self::Server {
cert_types, cert_types,
challenge, challenge,
required,
} => { } => {
let mut out = BytesMut::with_capacity(1 + challenge.len()); let mut out = BytesMut::with_capacity(2 + challenge.len());
out.put_u8(*required as u8);
out.put_u8(cert_types.bits()); out.put_u8(cert_types.bits());
out.extend_from_slice(challenge); out.extend_from_slice(challenge);
out.freeze() out.freeze()
@ -172,7 +176,7 @@ impl ProtocolExtension for CertAuthProtocolExtension {
out.extend_from_slice(signature); out.extend_from_slice(signature);
out.freeze() out.freeze()
} }
Self::ClientSigned => Bytes::new(), Self::ClientRecieved => Bytes::new(),
Self::ServerVerified => Bytes::new(), Self::ServerVerified => Bytes::new(),
} }
} }
@ -199,12 +203,6 @@ impl ProtocolExtension for CertAuthProtocolExtension {
} }
} }
impl From<CertAuthProtocolExtension> for AnyProtocolExtension {
fn from(value: CertAuthProtocolExtension) -> Self {
AnyProtocolExtension(Box::new(value))
}
}
/// Certificate authentication protocol extension builder. /// Certificate authentication protocol extension builder.
pub enum CertAuthProtocolExtensionBuilder { pub enum CertAuthProtocolExtensionBuilder {
/// Server variant of certificate authentication protocol extension before the challenge has /// Server variant of certificate authentication protocol extension before the challenge has
@ -212,6 +210,8 @@ pub enum CertAuthProtocolExtensionBuilder {
ServerBeforeChallenge { ServerBeforeChallenge {
/// Keypair verifiers. /// Keypair verifiers.
verifiers: Vec<VerifyKey>, verifiers: Vec<VerifyKey>,
/// Whether the server requires this authentication method.
required: bool,
}, },
/// Server variant of certificate authentication protocol extension after the challenge has /// Server variant of certificate authentication protocol extension after the challenge has
/// been sent. /// been sent.
@ -220,33 +220,63 @@ pub enum CertAuthProtocolExtensionBuilder {
verifiers: Vec<VerifyKey>, verifiers: Vec<VerifyKey>,
/// Challenge to verify against. /// Challenge to verify against.
challenge: Bytes, challenge: Bytes,
/// Whether the server requires this authentication method.
required: bool,
}, },
/// Client variant of certificate authentication protocol extension before the challenge has /// Client variant of certificate authentication protocol extension before the challenge has
/// been recieved. /// been recieved.
ClientBeforeChallenge { ClientBeforeChallenge {
/// Keypair signer. /// Keypair signer.
signer: SigningKey, signer: Option<SigningKey>,
}, },
/// Client variant of certificate authentication protocol extension after the challenge has /// Client variant of certificate authentication protocol extension after the challenge has
/// been recieved. /// been recieved.
ClientAfterChallenge { ClientAfterChallenge {
/// Keypair signer. /// Keypair signer.
signer: SigningKey, signer: Option<SigningKey>,
/// Signature of challenge recieved from the server. /// Supported certificate types recieved from the server.
signature: Bytes, cert_types: SupportedCertificateTypes,
/// Challenge recieved from the server.
challenge: Bytes,
/// Whether the server requires this authentication method.
required: bool,
}, },
} }
impl CertAuthProtocolExtensionBuilder { impl CertAuthProtocolExtensionBuilder {
/// Create a new server variant of the certificate authentication protocol extension. /// Create a new server variant of the certificate authentication protocol extension.
pub fn new_server(verifiers: Vec<VerifyKey>) -> Self { pub fn new_server(verifiers: Vec<VerifyKey>, required: bool) -> Self {
Self::ServerBeforeChallenge { verifiers } Self::ServerBeforeChallenge {
verifiers,
required,
}
} }
/// Create a new client variant of the certificate authentication protocol extension. /// Create a new client variant of the certificate authentication protocol extension.
pub fn new_client(signer: SigningKey) -> Self { pub fn new_client(signer: Option<SigningKey>) -> Self {
Self::ClientBeforeChallenge { signer } Self::ClientBeforeChallenge { signer }
} }
/// Get whether this authentication method is required. Could return None if the server has not
/// sent the certificate authentication protocol extension.
pub fn is_required(&self) -> Option<bool> {
match self {
Self::ServerBeforeChallenge { required, .. } => Some(*required),
Self::ServerAfterChallenge { required, .. } => Some(*required),
Self::ClientBeforeChallenge { .. } => None,
Self::ClientAfterChallenge { required, .. } => Some(*required),
}
}
/// Set the credentials sent to the server, if this is a client variant.
pub fn set_signing_key(&mut self, key: SigningKey) {
match self {
Self::ClientBeforeChallenge { signer } | Self::ClientAfterChallenge { signer, .. } => {
*signer = Some(key);
}
Self::ServerBeforeChallenge { .. } | Self::ServerAfterChallenge { .. } => {}
}
}
} }
#[async_trait] #[async_trait]
@ -268,6 +298,7 @@ impl ProtocolExtensionBuilder for CertAuthProtocolExtensionBuilder {
Self::ServerAfterChallenge { Self::ServerAfterChallenge {
verifiers, verifiers,
challenge, challenge,
..
} => { } => {
// validate and parse response // validate and parse response
let cert_type = SupportedCertificateTypes::from_bits(bytes.get_u8()) let cert_type = SupportedCertificateTypes::from_bits(bytes.get_u8())
@ -286,26 +317,19 @@ impl ProtocolExtensionBuilder for CertAuthProtocolExtensionBuilder {
} }
} }
Self::ClientBeforeChallenge { signer } => { Self::ClientBeforeChallenge { signer } => {
let required = bytes.get_u8() != 0;
// sign challenge // sign challenge
let cert_types = SupportedCertificateTypes::from_bits(bytes.get_u8()) let cert_types = SupportedCertificateTypes::from_bits(bytes.get_u8())
.ok_or(WispError::CertAuthExtensionSigInvalid)?; .ok_or(WispError::CertAuthExtensionCertTypeInvalid)?;
if !cert_types.iter().any(|x| x == signer.cert_type) {
return Err(WispError::CertAuthExtensionSigInvalid);
}
let signed: Bytes = signer
.signer
.try_sign(&bytes)
.map_err(CertAuthError::from)?
.to_vec()
.into();
*self = Self::ClientAfterChallenge { *self = Self::ClientAfterChallenge {
signer: signer.clone(), signer: signer.clone(),
signature: signed, cert_types,
challenge: bytes,
required,
}; };
Ok(CertAuthProtocolExtension::ClientSigned.into()) Ok(CertAuthProtocolExtension::ClientRecieved.into())
} }
// client has already recieved a challenge // client has already recieved a challenge
Self::ClientAfterChallenge { .. } => Err(WispError::ExtensionImplNotSupported), Self::ClientAfterChallenge { .. } => Err(WispError::ExtensionImplNotSupported),
@ -316,19 +340,26 @@ impl ProtocolExtensionBuilder for CertAuthProtocolExtensionBuilder {
// server: 1 // server: 1
fn build_to_extension(&mut self, _: Role) -> Result<AnyProtocolExtension, WispError> { fn build_to_extension(&mut self, _: Role) -> Result<AnyProtocolExtension, WispError> {
match self { match self {
Self::ServerBeforeChallenge { verifiers } => { Self::ServerBeforeChallenge {
verifiers,
required,
} => {
let mut challenge = [0u8; 64]; let mut challenge = [0u8; 64];
getrandom::getrandom(&mut challenge).map_err(CertAuthError::from)?; getrandom::getrandom(&mut challenge).map_err(CertAuthError::from)?;
let challenge = Bytes::from(challenge.to_vec()); let challenge = Bytes::from(challenge.to_vec());
let required = *required;
*self = Self::ServerAfterChallenge { *self = Self::ServerAfterChallenge {
verifiers: verifiers.to_vec(), verifiers: verifiers.to_vec(),
challenge: challenge.clone(), challenge: challenge.clone(),
required,
}; };
Ok(CertAuthProtocolExtension::Server { Ok(CertAuthProtocolExtension::Server {
cert_types: SupportedCertificateTypes::Ed25519, cert_types: SupportedCertificateTypes::Ed25519,
challenge, challenge,
required,
} }
.into()) .into())
} }
@ -336,7 +367,24 @@ impl ProtocolExtensionBuilder for CertAuthProtocolExtensionBuilder {
Self::ServerAfterChallenge { .. } => Err(WispError::ExtensionImplNotSupported), Self::ServerAfterChallenge { .. } => Err(WispError::ExtensionImplNotSupported),
// client needs to recieve a challenge // client needs to recieve a challenge
Self::ClientBeforeChallenge { .. } => Err(WispError::ExtensionImplNotSupported), Self::ClientBeforeChallenge { .. } => Err(WispError::ExtensionImplNotSupported),
Self::ClientAfterChallenge { signer, signature } => { Self::ClientAfterChallenge {
signer,
challenge,
cert_types,
..
} => {
let signer = signer.as_ref().ok_or(WispError::CertAuthExtensionNoKey)?;
if !cert_types.iter().any(|x| x == signer.cert_type) {
return Err(WispError::CertAuthExtensionCertTypeInvalid);
}
let signature: Bytes = signer
.signer
.try_sign(challenge)
.map_err(CertAuthError::from)?
.to_vec()
.into();
Ok(CertAuthProtocolExtension::Client { Ok(CertAuthProtocolExtension::Client {
cert_type: signer.cert_type, cert_type: signer.cert_type,
hash: signer.hash, hash: signer.hash,

View file

@ -74,6 +74,12 @@ impl From<AnyProtocolExtension> for Bytes {
} }
} }
impl<T: ProtocolExtension> From<T> for AnyProtocolExtension {
fn from(value: T) -> Self {
Self::new(value)
}
}
/// A Wisp protocol extension. /// A Wisp protocol extension.
/// ///
/// See [the /// See [the
@ -162,6 +168,8 @@ pub trait ProtocolExtensionBuilder: Sync + Send + 'static {
fn get_id(&self) -> u8; fn get_id(&self) -> u8;
/// Build a protocol extension from the extension's metadata. /// Build a protocol extension from the extension's metadata.
///
/// This is called second on the server and first on the client.
fn build_from_bytes( fn build_from_bytes(
&mut self, &mut self,
bytes: Bytes, bytes: Bytes,
@ -169,6 +177,8 @@ pub trait ProtocolExtensionBuilder: Sync + Send + 'static {
) -> Result<AnyProtocolExtension, WispError>; ) -> Result<AnyProtocolExtension, WispError>;
/// Build a protocol extension to send to the other side. /// Build a protocol extension to send to the other side.
///
/// This is called first on the server and second on the client.
fn build_to_extension(&mut self, role: Role) -> Result<AnyProtocolExtension, WispError>; fn build_to_extension(&mut self, role: Role) -> Result<AnyProtocolExtension, WispError>;
/// Do not override. /// Do not override.
@ -248,3 +258,9 @@ impl DerefMut for AnyProtocolExtensionBuilder {
self.0.deref_mut() self.0.deref_mut()
} }
} }
impl<T: ProtocolExtensionBuilder> From<T> for AnyProtocolExtensionBuilder {
fn from(value: T) -> Self {
Self::new(value)
}
}

View file

@ -68,12 +68,6 @@ impl ProtocolExtension for MotdProtocolExtension {
} }
} }
impl From<MotdProtocolExtension> for AnyProtocolExtension {
fn from(value: MotdProtocolExtension) -> Self {
AnyProtocolExtension(Box::new(value))
}
}
/// MOTD protocol extension builder. /// MOTD protocol extension builder.
pub enum MotdProtocolExtensionBuilder { pub enum MotdProtocolExtensionBuilder {
/// Server variant of MOTD protocol extension builder. Has the MOTD. /// Server variant of MOTD protocol extension builder. Has the MOTD.

View file

@ -1,9 +1,8 @@
//! Password protocol extension. //! Password protocol extension.
//! //!
//! Passwords are sent in plain text!! //! **Passwords are sent in plain text!!**
//! //!
//! See [the docs](https://github.com/MercuryWorkshop/wisp-protocol/blob/v2/protocol.md#0x02---password-authentication) //! See [the docs](https://github.com/MercuryWorkshop/wisp-protocol/blob/v2/protocol.md#0x02---password-authentication)
use std::collections::HashMap; use std::collections::HashMap;
use async_trait::async_trait; use async_trait::async_trait;
@ -16,56 +15,53 @@ use crate::{
use super::{AnyProtocolExtension, ProtocolExtension, ProtocolExtensionBuilder}; use super::{AnyProtocolExtension, ProtocolExtension, ProtocolExtensionBuilder};
#[derive(Debug, Clone)] /// ID of password protocol extension.
pub const PASSWORD_PROTOCOL_EXTENSION_ID: u8 = 0x02;
/// Password protocol extension. /// Password protocol extension.
/// ///
/// **Passwords are sent in plain text!!** /// **Passwords are sent in plain text!!**
/// **This extension will panic when encoding if the username's length does not fit within a u8 ///
/// or the password's length does not fit within a u16.** /// See [the docs](https://github.com/MercuryWorkshop/wisp-protocol/blob/v2/protocol.md#0x02---password-authentication)
pub struct PasswordProtocolExtension { #[derive(Debug, Clone)]
/// The username to log in with. pub enum PasswordProtocolExtension {
/// /// Password protocol extension before the client INFO packet has been received.
/// This string's length must fit within a u8. ServerBeforeClientInfo {
pub username: String, /// Whether this authentication method is required.
/// The password to log in with. required: bool,
/// },
/// This string's length must fit within a u16. /// Password protocol extension after the client INFO packet has been received.
pub password: String, ServerAfterClientInfo {
role: Role, /// The client's chosen user.
chosen_user: String,
/// The client's chosen password
chosen_password: String,
},
/// Password protocol extension before the server INFO has been received.
ClientBeforeServerInfo,
/// Password protocol extension after the server INFO has been received.
ClientAfterServerInfo {
/// The user to send to the server.
user: String,
/// The password to send to the user.
password: String,
},
} }
impl PasswordProtocolExtension { impl PasswordProtocolExtension {
/// Password protocol extension ID. /// ID of password protocol extension.
pub const ID: u8 = 0x02; pub const ID: u8 = PASSWORD_PROTOCOL_EXTENSION_ID;
/// Create a new password protocol extension for the server.
///
/// This signifies that the server requires a password.
pub fn new_server() -> Self {
Self {
username: String::new(),
password: String::new(),
role: Role::Server,
}
}
/// Create a new password protocol extension for the client, with a username and password.
///
/// The username's length must fit within a u8. The password's length must fit within a
/// u16.
pub fn new_client(username: String, password: String) -> Self {
Self {
username,
password,
role: Role::Client,
}
}
} }
#[async_trait] #[async_trait]
impl ProtocolExtension for PasswordProtocolExtension { impl ProtocolExtension for PasswordProtocolExtension {
fn get_id(&self) -> u8 { fn get_id(&self) -> u8 {
Self::ID PASSWORD_PROTOCOL_EXTENSION_ID
}
fn box_clone(&self) -> Box<dyn ProtocolExtension + Sync + Send> {
Box::new(self.clone())
} }
fn get_supported_packets(&self) -> &'static [u8] { fn get_supported_packets(&self) -> &'static [u8] {
@ -77,142 +73,179 @@ impl ProtocolExtension for PasswordProtocolExtension {
} }
fn encode(&self) -> Bytes { fn encode(&self) -> Bytes {
match self.role { match self {
Role::Server => Bytes::new(), Self::ServerBeforeClientInfo { required } => {
Role::Client => { let mut out = BytesMut::with_capacity(1);
let username = Bytes::from(self.username.clone().into_bytes()); out.put_u8(*required as u8);
let password = Bytes::from(self.password.clone().into_bytes()); out.freeze()
let username_len = u8::try_from(username.len()).expect("username was too long"); }
let password_len = u16::try_from(password.len()).expect("password was too long"); Self::ServerAfterClientInfo { .. } => Bytes::new(),
Self::ClientBeforeServerInfo => Bytes::new(),
let mut bytes = Self::ClientAfterServerInfo { user, password } => {
BytesMut::with_capacity(3 + username_len as usize + password_len as usize); let mut out = BytesMut::with_capacity(1 + 2 + user.len() + password.len());
bytes.put_u8(username_len); out.put_u8(user.len().try_into().unwrap());
bytes.put_u16_le(password_len); out.put_u16_le(password.len().try_into().unwrap());
bytes.extend(username); out.extend_from_slice(user.as_bytes());
bytes.extend(password); out.extend_from_slice(password.as_bytes());
bytes.freeze() out.freeze()
} }
} }
} }
async fn handle_handshake( async fn handle_handshake(
&mut self, &mut self,
_: &mut dyn WebSocketRead, _read: &mut dyn WebSocketRead,
_: &LockedWebSocketWrite, _write: &LockedWebSocketWrite,
) -> Result<(), WispError> { ) -> Result<(), WispError> {
Ok(()) Ok(())
} }
async fn handle_packet( async fn handle_packet(
&mut self, &mut self,
_: Bytes, _packet: Bytes,
_: &mut dyn WebSocketRead, _read: &mut dyn WebSocketRead,
_: &LockedWebSocketWrite, _write: &LockedWebSocketWrite,
) -> Result<(), WispError> { ) -> Result<(), WispError> {
Ok(()) Err(WispError::ExtensionImplNotSupported)
}
fn box_clone(&self) -> Box<dyn ProtocolExtension + Sync + Send> {
Box::new(self.clone())
}
}
impl From<PasswordProtocolExtension> for AnyProtocolExtension {
fn from(value: PasswordProtocolExtension) -> Self {
AnyProtocolExtension(Box::new(value))
} }
} }
/// Password protocol extension builder. /// Password protocol extension builder.
/// ///
/// **Passwords are sent in plain text!!** /// **Passwords are sent in plain text!!**
pub struct PasswordProtocolExtensionBuilder { ///
/// Map of users and their passwords to allow. Only used on server. /// See [the docs](https://github.com/MercuryWorkshop/wisp-protocol/blob/v2/protocol.md#0x02---password-authentication)
pub users: HashMap<String, String>, pub enum PasswordProtocolExtensionBuilder {
/// Username to authenticate with. Only used on client. /// Password protocol extension builder before the client INFO has been received.
pub username: String, ServerBeforeClientInfo {
/// Password to authenticate with. Only used on client. /// The user+password combinations to verify the client with.
pub password: String, users: HashMap<String, String>,
/// Whether this authentication method is required.
required: bool,
},
/// Password protocol extension builder after the client INFO has been received.
ServerAfterClientInfo {
/// The user+password combinations to verify the client with.
users: HashMap<String, String>,
/// Whether this authentication method is required.
required: bool,
},
/// Password protocol extension builder before the server INFO has been received.
ClientBeforeServerInfo {
/// The credentials to send to the server.
creds: Option<(String, String)>,
},
/// Password protocol extension builder after the server INFO has been received.
ClientAfterServerInfo {
/// The credentials to send to the server.
creds: Option<(String, String)>,
/// Whether this authentication method is required.
required: bool,
},
} }
impl PasswordProtocolExtensionBuilder { impl PasswordProtocolExtensionBuilder {
/// Create a new password protocol extension builder for the server, with a map of users /// ID of password protocol extension.
/// and passwords to allow. pub const ID: u8 = PASSWORD_PROTOCOL_EXTENSION_ID;
pub fn new_server(users: HashMap<String, String>) -> Self {
Self { /// Create a new server variant of the password protocol extension.
users, pub fn new_server(users: HashMap<String, String>, required: bool) -> Self {
username: String::new(), Self::ServerBeforeClientInfo { users, required }
password: String::new(), }
/// Create a new client variant of the password protocol extension with a username and password.
pub fn new_client(creds: Option<(String, String)>) -> Self {
Self::ClientBeforeServerInfo { creds }
}
/// Get whether this authentication method is required. Could return None if the server has not
/// sent the password protocol extension.
pub fn is_required(&self) -> Option<bool> {
match self {
Self::ServerBeforeClientInfo { required, .. } => Some(*required),
Self::ServerAfterClientInfo { required, .. } => Some(*required),
Self::ClientBeforeServerInfo { .. } => None,
Self::ClientAfterServerInfo { required, .. } => Some(*required),
} }
} }
/// Create a new password protocol extension builder for the client, with a username and /// Set the credentials sent to the server, if this is a client variant.
/// password to authenticate with. pub fn set_creds(&mut self, credentials: (String, String)) {
pub fn new_client(username: String, password: String) -> Self { match self {
Self { Self::ClientBeforeServerInfo { creds } | Self::ClientAfterServerInfo { creds, .. } => {
users: HashMap::new(), *creds = Some(credentials);
username, }
password, Self::ServerBeforeClientInfo { .. } | Self::ServerAfterClientInfo { .. } => {}
} }
} }
} }
impl ProtocolExtensionBuilder for PasswordProtocolExtensionBuilder { impl ProtocolExtensionBuilder for PasswordProtocolExtensionBuilder {
fn get_id(&self) -> u8 { fn get_id(&self) -> u8 {
PasswordProtocolExtension::ID PASSWORD_PROTOCOL_EXTENSION_ID
}
fn build_to_extension(&mut self, _role: Role) -> Result<AnyProtocolExtension, WispError> {
match self {
Self::ServerBeforeClientInfo { users: _, required } => {
Ok(PasswordProtocolExtension::ServerBeforeClientInfo {
required: *required,
}
.into())
}
Self::ServerAfterClientInfo { .. } => Err(WispError::ExtensionImplNotSupported),
Self::ClientBeforeServerInfo { .. } => Err(WispError::ExtensionImplNotSupported),
Self::ClientAfterServerInfo { creds, .. } => {
let (user, password) = creds.clone().ok_or(WispError::PasswordExtensionNoCreds)?;
Ok(PasswordProtocolExtension::ClientAfterServerInfo { user, password }.into())
}
}
} }
fn build_from_bytes( fn build_from_bytes(
&mut self, &mut self,
mut payload: Bytes, mut bytes: Bytes,
role: crate::Role, _role: Role,
) -> Result<AnyProtocolExtension, WispError> { ) -> Result<AnyProtocolExtension, WispError> {
match role { match self {
Role::Server => { Self::ServerBeforeClientInfo { users, required } => {
if payload.remaining() < 3 { let user_len = bytes.get_u8();
return Err(WispError::PacketTooSmall); let password_len = bytes.get_u16_le();
}
let username_len = payload.get_u8(); let user = std::str::from_utf8(&bytes.split_to(user_len as usize))?.to_string();
let password_len = payload.get_u16_le();
if payload.remaining() < (password_len + username_len as u16) as usize {
return Err(WispError::PacketTooSmall);
}
let username =
std::str::from_utf8(&payload.split_to(username_len as usize))?.to_string();
let password = let password =
std::str::from_utf8(&payload.split_to(password_len as usize))?.to_string(); std::str::from_utf8(&bytes.split_to(password_len as usize))?.to_string();
let Some(user) = self.users.iter().find(|x| *x.0 == username) else { let valid = users.get(&user).map(|x| *x == password).unwrap_or(false);
return Err(WispError::PasswordExtensionCredsInvalid);
*self = Self::ServerAfterClientInfo {
users: users.clone(),
required: *required,
}; };
if *user.1 != password { if !valid {
return Err(WispError::PasswordExtensionCredsInvalid); Err(WispError::PasswordExtensionCredsInvalid)
} else {
Ok(PasswordProtocolExtension::ServerAfterClientInfo {
chosen_user: user,
chosen_password: password,
}
.into())
} }
}
Self::ServerAfterClientInfo { .. } => Err(WispError::ExtensionImplNotSupported),
Self::ClientBeforeServerInfo { creds } => {
let required = bytes.get_u8() != 0;
Ok(PasswordProtocolExtension { *self = Self::ClientAfterServerInfo {
username, creds: creds.clone(),
password, required,
role, };
}
.into())
}
Role::Client => {
Ok(PasswordProtocolExtension::new_client(String::new(), String::new()).into())
}
}
}
fn build_to_extension(&mut self, role: Role) -> Result<AnyProtocolExtension, WispError> { Ok(PasswordProtocolExtension::ClientBeforeServerInfo.into())
Ok(match role {
Role::Server => PasswordProtocolExtension::new_server(),
Role::Client => {
PasswordProtocolExtension::new_client(self.username.clone(), self.password.clone())
} }
Self::ClientAfterServerInfo { .. } => Err(WispError::ExtensionImplNotSupported),
} }
.into())
} }
} }

View file

@ -60,12 +60,6 @@ impl ProtocolExtension for UdpProtocolExtension {
} }
} }
impl From<UdpProtocolExtension> for AnyProtocolExtension {
fn from(value: UdpProtocolExtension) -> Self {
AnyProtocolExtension(Box::new(value))
}
}
/// UDP protocol extension builder. /// UDP protocol extension builder.
pub struct UdpProtocolExtensionBuilder; pub struct UdpProtocolExtensionBuilder;

View file

@ -106,10 +106,21 @@ pub enum WispError {
/// The specified protocol extensions are not supported by the other side. /// The specified protocol extensions are not supported by the other side.
#[error("Protocol extensions {0:?} not supported")] #[error("Protocol extensions {0:?} not supported")]
ExtensionsNotSupported(Vec<u8>), ExtensionsNotSupported(Vec<u8>),
/// The password authentication username/password was invalid. /// The password authentication username/password was invalid.
#[error("Password protocol extension: Invalid username/password")] #[error("Password protocol extension: Invalid username/password")]
PasswordExtensionCredsInvalid, PasswordExtensionCredsInvalid,
/// No password authentication username/password was provided.
#[error("Password protocol extension: No username/password provided")]
PasswordExtensionNoCreds,
/// The certificate authentication certificate type was unsupported.
#[error("Certificate authentication protocol extension: Invalid certificate type")]
CertAuthExtensionCertTypeInvalid,
/// The certificate authentication signature was invalid. /// The certificate authentication signature was invalid.
#[error("Certificate authentication protocol extension: Invalid signature")] #[error("Certificate authentication protocol extension: Invalid signature")]
CertAuthExtensionSigInvalid, CertAuthExtensionSigInvalid,
/// No certificate authentication signing key was provided.
#[error("Password protocol extension: No signing key provided")]
CertAuthExtensionNoKey,
} }

View file

@ -124,11 +124,10 @@ where
} }
type WispV2ClosureResult = Pin<Box<dyn Future<Output = Result<(), WispError>> + Sync + Send>>; type WispV2ClosureResult = Pin<Box<dyn Future<Output = Result<(), WispError>> + Sync + Send>>;
type WispV2ClosureBuilders<'a> = &'a mut [AnyProtocolExtensionBuilder];
/// Wisp V2 handshake and protocol extension settings wrapper struct. /// Wisp V2 handshake and protocol extension settings wrapper struct.
pub struct WispV2Extensions { pub struct WispV2Extensions {
builders: Vec<AnyProtocolExtensionBuilder>, builders: Vec<AnyProtocolExtensionBuilder>,
closure: Box<dyn Fn(WispV2ClosureBuilders) -> WispV2ClosureResult + Send>, closure: Box<dyn Fn(&mut Vec<AnyProtocolExtensionBuilder>) -> WispV2ClosureResult + Send>,
} }
impl WispV2Extensions { impl WispV2Extensions {
@ -143,7 +142,7 @@ impl WispV2Extensions {
/// Create a Wisp V2 settings struct with some middleware. /// Create a Wisp V2 settings struct with some middleware.
pub fn new_with_middleware<C>(builders: Vec<AnyProtocolExtensionBuilder>, closure: C) -> Self pub fn new_with_middleware<C>(builders: Vec<AnyProtocolExtensionBuilder>, closure: C) -> Self
where where
C: Fn(WispV2ClosureBuilders) -> WispV2ClosureResult + Send + 'static, C: Fn(&mut Vec<AnyProtocolExtensionBuilder>) -> WispV2ClosureResult + Send + 'static,
{ {
Self { Self {
builders, builders,