mirror of
https://github.com/MercuryWorkshop/epoxy-tls.git
synced 2025-05-12 14:00:01 -04:00
refactor protocol extensions
This commit is contained in:
parent
1ae3986a82
commit
36fddc8943
10 changed files with 289 additions and 190 deletions
|
@ -162,9 +162,11 @@ pub struct WispConfig {
|
|||
#[serde(skip_serializing_if = "HashMap::is_empty")]
|
||||
/// Wisp version 2 password authentication extension username/passwords.
|
||||
pub password_extension_users: HashMap<String, String>,
|
||||
pub password_extension_required: bool,
|
||||
#[serde(skip_serializing_if = "Vec::is_empty")]
|
||||
/// Wisp version 2 certificate authentication extension public ed25519 pem keys.
|
||||
pub certificate_extension_keys: Vec<PathBuf>,
|
||||
pub certificate_extension_required: bool,
|
||||
|
||||
#[serde(skip_serializing_if = "is_default_motd")]
|
||||
/// Wisp version 2 MOTD extension message.
|
||||
|
@ -334,7 +336,9 @@ impl Default for WispConfig {
|
|||
auth_extension: None,
|
||||
|
||||
password_extension_users: HashMap::new(),
|
||||
password_extension_required: true,
|
||||
certificate_extension_keys: Vec::new(),
|
||||
certificate_extension_required: true,
|
||||
|
||||
motd_extension: default_motd(),
|
||||
}
|
||||
|
@ -364,18 +368,24 @@ impl WispConfig {
|
|||
extensions.push(AnyProtocolExtensionBuilder::new(
|
||||
PasswordProtocolExtensionBuilder::new_server(
|
||||
self.password_extension_users.clone(),
|
||||
self.password_extension_required,
|
||||
),
|
||||
));
|
||||
required_extensions.push(PasswordProtocolExtension::ID);
|
||||
if self.password_extension_required {
|
||||
required_extensions.push(PasswordProtocolExtension::ID);
|
||||
}
|
||||
}
|
||||
Some(ProtocolExtensionAuth::Certificate) => {
|
||||
extensions.push(AnyProtocolExtensionBuilder::new(
|
||||
CertAuthProtocolExtensionBuilder::new_server(
|
||||
get_certificates_from_paths(self.certificate_extension_keys.clone())
|
||||
.await?,
|
||||
self.certificate_extension_required,
|
||||
),
|
||||
));
|
||||
required_extensions.push(CertAuthProtocolExtension::ID);
|
||||
if self.certificate_extension_required {
|
||||
required_extensions.push(CertAuthProtocolExtension::ID);
|
||||
}
|
||||
}
|
||||
None => {}
|
||||
}
|
||||
|
|
|
@ -80,12 +80,6 @@ impl ProtocolExtension for TWispServerProtocolExtension {
|
|||
}
|
||||
}
|
||||
|
||||
impl From<TWispServerProtocolExtension> for AnyProtocolExtension {
|
||||
fn from(value: TWispServerProtocolExtension) -> Self {
|
||||
AnyProtocolExtension::new(value)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct TWispServerProtocolExtensionBuilder(TwispMap);
|
||||
|
||||
impl ProtocolExtensionBuilder for TWispServerProtocolExtensionBuilder {
|
||||
|
@ -124,7 +118,7 @@ pub fn new_map() -> TwispMap {
|
|||
}
|
||||
|
||||
pub fn new_ext(map: TwispMap) -> AnyProtocolExtensionBuilder {
|
||||
AnyProtocolExtensionBuilder::new(TWispServerProtocolExtensionBuilder(map))
|
||||
TWispServerProtocolExtensionBuilder(map).into()
|
||||
}
|
||||
|
||||
pub async fn handle_twisp(
|
||||
|
|
|
@ -139,7 +139,7 @@ async fn main() -> Result<(), Box<dyn Error + Send + Sync>> {
|
|||
let split: Vec<_> = auth.split(':').collect();
|
||||
let username = split[0].to_string();
|
||||
let password = split[1..].join(":");
|
||||
PasswordProtocolExtensionBuilder::new_client(username, password)
|
||||
PasswordProtocolExtensionBuilder::new_client(Some((username, password)))
|
||||
});
|
||||
|
||||
println!(
|
||||
|
@ -188,7 +188,7 @@ async fn main() -> Result<(), Box<dyn Error + Send + Sync>> {
|
|||
}
|
||||
if let Some(certauth) = opts.certauth {
|
||||
let key = get_cert(certauth).await?;
|
||||
let extension = CertAuthProtocolExtensionBuilder::new_client(key);
|
||||
let extension = CertAuthProtocolExtensionBuilder::new_client(Some(key));
|
||||
extensions.push(AnyProtocolExtensionBuilder::new(extension));
|
||||
extension_ids.push(CertAuthProtocolExtension::ID);
|
||||
}
|
||||
|
|
|
@ -116,6 +116,8 @@ pub enum CertAuthProtocolExtension {
|
|||
cert_types: SupportedCertificateTypes,
|
||||
/// Random challenge for the client.
|
||||
challenge: Bytes,
|
||||
/// Whether the server requires this authentication method.
|
||||
required: bool,
|
||||
},
|
||||
/// Client variant of certificate authentication protocol extension.
|
||||
Client {
|
||||
|
@ -126,8 +128,8 @@ pub enum CertAuthProtocolExtension {
|
|||
/// Signature of challenge.
|
||||
signature: Bytes,
|
||||
},
|
||||
/// Marker that client has successfully signed the challenge.
|
||||
ClientSigned,
|
||||
/// Marker that client has successfully recieved the challenge.
|
||||
ClientRecieved,
|
||||
/// Marker that server has successfully verified the client.
|
||||
ServerVerified,
|
||||
}
|
||||
|
@ -155,8 +157,10 @@ impl ProtocolExtension for CertAuthProtocolExtension {
|
|||
Self::Server {
|
||||
cert_types,
|
||||
challenge,
|
||||
required,
|
||||
} => {
|
||||
let mut out = BytesMut::with_capacity(1 + challenge.len());
|
||||
let mut out = BytesMut::with_capacity(2 + challenge.len());
|
||||
out.put_u8(*required as u8);
|
||||
out.put_u8(cert_types.bits());
|
||||
out.extend_from_slice(challenge);
|
||||
out.freeze()
|
||||
|
@ -172,7 +176,7 @@ impl ProtocolExtension for CertAuthProtocolExtension {
|
|||
out.extend_from_slice(signature);
|
||||
out.freeze()
|
||||
}
|
||||
Self::ClientSigned => Bytes::new(),
|
||||
Self::ClientRecieved => Bytes::new(),
|
||||
Self::ServerVerified => Bytes::new(),
|
||||
}
|
||||
}
|
||||
|
@ -199,12 +203,6 @@ impl ProtocolExtension for CertAuthProtocolExtension {
|
|||
}
|
||||
}
|
||||
|
||||
impl From<CertAuthProtocolExtension> for AnyProtocolExtension {
|
||||
fn from(value: CertAuthProtocolExtension) -> Self {
|
||||
AnyProtocolExtension(Box::new(value))
|
||||
}
|
||||
}
|
||||
|
||||
/// Certificate authentication protocol extension builder.
|
||||
pub enum CertAuthProtocolExtensionBuilder {
|
||||
/// Server variant of certificate authentication protocol extension before the challenge has
|
||||
|
@ -212,6 +210,8 @@ pub enum CertAuthProtocolExtensionBuilder {
|
|||
ServerBeforeChallenge {
|
||||
/// Keypair verifiers.
|
||||
verifiers: Vec<VerifyKey>,
|
||||
/// Whether the server requires this authentication method.
|
||||
required: bool,
|
||||
},
|
||||
/// Server variant of certificate authentication protocol extension after the challenge has
|
||||
/// been sent.
|
||||
|
@ -220,33 +220,63 @@ pub enum CertAuthProtocolExtensionBuilder {
|
|||
verifiers: Vec<VerifyKey>,
|
||||
/// Challenge to verify against.
|
||||
challenge: Bytes,
|
||||
/// Whether the server requires this authentication method.
|
||||
required: bool,
|
||||
},
|
||||
/// Client variant of certificate authentication protocol extension before the challenge has
|
||||
/// been recieved.
|
||||
ClientBeforeChallenge {
|
||||
/// Keypair signer.
|
||||
signer: SigningKey,
|
||||
signer: Option<SigningKey>,
|
||||
},
|
||||
/// Client variant of certificate authentication protocol extension after the challenge has
|
||||
/// been recieved.
|
||||
ClientAfterChallenge {
|
||||
/// Keypair signer.
|
||||
signer: SigningKey,
|
||||
/// Signature of challenge recieved from the server.
|
||||
signature: Bytes,
|
||||
signer: Option<SigningKey>,
|
||||
/// Supported certificate types recieved from the server.
|
||||
cert_types: SupportedCertificateTypes,
|
||||
/// Challenge recieved from the server.
|
||||
challenge: Bytes,
|
||||
/// Whether the server requires this authentication method.
|
||||
required: bool,
|
||||
},
|
||||
}
|
||||
|
||||
impl CertAuthProtocolExtensionBuilder {
|
||||
/// Create a new server variant of the certificate authentication protocol extension.
|
||||
pub fn new_server(verifiers: Vec<VerifyKey>) -> Self {
|
||||
Self::ServerBeforeChallenge { verifiers }
|
||||
pub fn new_server(verifiers: Vec<VerifyKey>, required: bool) -> Self {
|
||||
Self::ServerBeforeChallenge {
|
||||
verifiers,
|
||||
required,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new client variant of the certificate authentication protocol extension.
|
||||
pub fn new_client(signer: SigningKey) -> Self {
|
||||
pub fn new_client(signer: Option<SigningKey>) -> Self {
|
||||
Self::ClientBeforeChallenge { signer }
|
||||
}
|
||||
|
||||
/// Get whether this authentication method is required. Could return None if the server has not
|
||||
/// sent the certificate authentication protocol extension.
|
||||
pub fn is_required(&self) -> Option<bool> {
|
||||
match self {
|
||||
Self::ServerBeforeChallenge { required, .. } => Some(*required),
|
||||
Self::ServerAfterChallenge { required, .. } => Some(*required),
|
||||
Self::ClientBeforeChallenge { .. } => None,
|
||||
Self::ClientAfterChallenge { required, .. } => Some(*required),
|
||||
}
|
||||
}
|
||||
|
||||
/// Set the credentials sent to the server, if this is a client variant.
|
||||
pub fn set_signing_key(&mut self, key: SigningKey) {
|
||||
match self {
|
||||
Self::ClientBeforeChallenge { signer } | Self::ClientAfterChallenge { signer, .. } => {
|
||||
*signer = Some(key);
|
||||
}
|
||||
Self::ServerBeforeChallenge { .. } | Self::ServerAfterChallenge { .. } => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
|
@ -268,6 +298,7 @@ impl ProtocolExtensionBuilder for CertAuthProtocolExtensionBuilder {
|
|||
Self::ServerAfterChallenge {
|
||||
verifiers,
|
||||
challenge,
|
||||
..
|
||||
} => {
|
||||
// validate and parse response
|
||||
let cert_type = SupportedCertificateTypes::from_bits(bytes.get_u8())
|
||||
|
@ -286,26 +317,19 @@ impl ProtocolExtensionBuilder for CertAuthProtocolExtensionBuilder {
|
|||
}
|
||||
}
|
||||
Self::ClientBeforeChallenge { signer } => {
|
||||
let required = bytes.get_u8() != 0;
|
||||
// sign challenge
|
||||
let cert_types = SupportedCertificateTypes::from_bits(bytes.get_u8())
|
||||
.ok_or(WispError::CertAuthExtensionSigInvalid)?;
|
||||
if !cert_types.iter().any(|x| x == signer.cert_type) {
|
||||
return Err(WispError::CertAuthExtensionSigInvalid);
|
||||
}
|
||||
|
||||
let signed: Bytes = signer
|
||||
.signer
|
||||
.try_sign(&bytes)
|
||||
.map_err(CertAuthError::from)?
|
||||
.to_vec()
|
||||
.into();
|
||||
.ok_or(WispError::CertAuthExtensionCertTypeInvalid)?;
|
||||
|
||||
*self = Self::ClientAfterChallenge {
|
||||
signer: signer.clone(),
|
||||
signature: signed,
|
||||
cert_types,
|
||||
challenge: bytes,
|
||||
required,
|
||||
};
|
||||
|
||||
Ok(CertAuthProtocolExtension::ClientSigned.into())
|
||||
Ok(CertAuthProtocolExtension::ClientRecieved.into())
|
||||
}
|
||||
// client has already recieved a challenge
|
||||
Self::ClientAfterChallenge { .. } => Err(WispError::ExtensionImplNotSupported),
|
||||
|
@ -316,19 +340,26 @@ impl ProtocolExtensionBuilder for CertAuthProtocolExtensionBuilder {
|
|||
// server: 1
|
||||
fn build_to_extension(&mut self, _: Role) -> Result<AnyProtocolExtension, WispError> {
|
||||
match self {
|
||||
Self::ServerBeforeChallenge { verifiers } => {
|
||||
Self::ServerBeforeChallenge {
|
||||
verifiers,
|
||||
required,
|
||||
} => {
|
||||
let mut challenge = [0u8; 64];
|
||||
getrandom::getrandom(&mut challenge).map_err(CertAuthError::from)?;
|
||||
let challenge = Bytes::from(challenge.to_vec());
|
||||
|
||||
let required = *required;
|
||||
|
||||
*self = Self::ServerAfterChallenge {
|
||||
verifiers: verifiers.to_vec(),
|
||||
challenge: challenge.clone(),
|
||||
required,
|
||||
};
|
||||
|
||||
Ok(CertAuthProtocolExtension::Server {
|
||||
cert_types: SupportedCertificateTypes::Ed25519,
|
||||
challenge,
|
||||
required,
|
||||
}
|
||||
.into())
|
||||
}
|
||||
|
@ -336,7 +367,24 @@ impl ProtocolExtensionBuilder for CertAuthProtocolExtensionBuilder {
|
|||
Self::ServerAfterChallenge { .. } => Err(WispError::ExtensionImplNotSupported),
|
||||
// client needs to recieve a challenge
|
||||
Self::ClientBeforeChallenge { .. } => Err(WispError::ExtensionImplNotSupported),
|
||||
Self::ClientAfterChallenge { signer, signature } => {
|
||||
Self::ClientAfterChallenge {
|
||||
signer,
|
||||
challenge,
|
||||
cert_types,
|
||||
..
|
||||
} => {
|
||||
let signer = signer.as_ref().ok_or(WispError::CertAuthExtensionNoKey)?;
|
||||
if !cert_types.iter().any(|x| x == signer.cert_type) {
|
||||
return Err(WispError::CertAuthExtensionCertTypeInvalid);
|
||||
}
|
||||
|
||||
let signature: Bytes = signer
|
||||
.signer
|
||||
.try_sign(challenge)
|
||||
.map_err(CertAuthError::from)?
|
||||
.to_vec()
|
||||
.into();
|
||||
|
||||
Ok(CertAuthProtocolExtension::Client {
|
||||
cert_type: signer.cert_type,
|
||||
hash: signer.hash,
|
||||
|
|
|
@ -74,6 +74,12 @@ impl From<AnyProtocolExtension> for Bytes {
|
|||
}
|
||||
}
|
||||
|
||||
impl<T: ProtocolExtension> From<T> for AnyProtocolExtension {
|
||||
fn from(value: T) -> Self {
|
||||
Self::new(value)
|
||||
}
|
||||
}
|
||||
|
||||
/// A Wisp protocol extension.
|
||||
///
|
||||
/// See [the
|
||||
|
@ -162,6 +168,8 @@ pub trait ProtocolExtensionBuilder: Sync + Send + 'static {
|
|||
fn get_id(&self) -> u8;
|
||||
|
||||
/// Build a protocol extension from the extension's metadata.
|
||||
///
|
||||
/// This is called second on the server and first on the client.
|
||||
fn build_from_bytes(
|
||||
&mut self,
|
||||
bytes: Bytes,
|
||||
|
@ -169,6 +177,8 @@ pub trait ProtocolExtensionBuilder: Sync + Send + 'static {
|
|||
) -> Result<AnyProtocolExtension, WispError>;
|
||||
|
||||
/// Build a protocol extension to send to the other side.
|
||||
///
|
||||
/// This is called first on the server and second on the client.
|
||||
fn build_to_extension(&mut self, role: Role) -> Result<AnyProtocolExtension, WispError>;
|
||||
|
||||
/// Do not override.
|
||||
|
@ -248,3 +258,9 @@ impl DerefMut for AnyProtocolExtensionBuilder {
|
|||
self.0.deref_mut()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: ProtocolExtensionBuilder> From<T> for AnyProtocolExtensionBuilder {
|
||||
fn from(value: T) -> Self {
|
||||
Self::new(value)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -68,12 +68,6 @@ impl ProtocolExtension for MotdProtocolExtension {
|
|||
}
|
||||
}
|
||||
|
||||
impl From<MotdProtocolExtension> for AnyProtocolExtension {
|
||||
fn from(value: MotdProtocolExtension) -> Self {
|
||||
AnyProtocolExtension(Box::new(value))
|
||||
}
|
||||
}
|
||||
|
||||
/// MOTD protocol extension builder.
|
||||
pub enum MotdProtocolExtensionBuilder {
|
||||
/// Server variant of MOTD protocol extension builder. Has the MOTD.
|
||||
|
|
|
@ -1,9 +1,8 @@
|
|||
//! Password protocol extension.
|
||||
//!
|
||||
//! Passwords are sent in plain text!!
|
||||
//! **Passwords are sent in plain text!!**
|
||||
//!
|
||||
//! See [the docs](https://github.com/MercuryWorkshop/wisp-protocol/blob/v2/protocol.md#0x02---password-authentication)
|
||||
|
||||
use std::collections::HashMap;
|
||||
|
||||
use async_trait::async_trait;
|
||||
|
@ -16,56 +15,53 @@ use crate::{
|
|||
|
||||
use super::{AnyProtocolExtension, ProtocolExtension, ProtocolExtensionBuilder};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
/// ID of password protocol extension.
|
||||
pub const PASSWORD_PROTOCOL_EXTENSION_ID: u8 = 0x02;
|
||||
|
||||
/// Password protocol extension.
|
||||
///
|
||||
/// **Passwords are sent in plain text!!**
|
||||
/// **This extension will panic when encoding if the username's length does not fit within a u8
|
||||
/// or the password's length does not fit within a u16.**
|
||||
pub struct PasswordProtocolExtension {
|
||||
/// The username to log in with.
|
||||
///
|
||||
/// This string's length must fit within a u8.
|
||||
pub username: String,
|
||||
/// The password to log in with.
|
||||
///
|
||||
/// This string's length must fit within a u16.
|
||||
pub password: String,
|
||||
role: Role,
|
||||
///
|
||||
/// See [the docs](https://github.com/MercuryWorkshop/wisp-protocol/blob/v2/protocol.md#0x02---password-authentication)
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum PasswordProtocolExtension {
|
||||
/// Password protocol extension before the client INFO packet has been received.
|
||||
ServerBeforeClientInfo {
|
||||
/// Whether this authentication method is required.
|
||||
required: bool,
|
||||
},
|
||||
/// Password protocol extension after the client INFO packet has been received.
|
||||
ServerAfterClientInfo {
|
||||
/// The client's chosen user.
|
||||
chosen_user: String,
|
||||
/// The client's chosen password
|
||||
chosen_password: String,
|
||||
},
|
||||
|
||||
/// Password protocol extension before the server INFO has been received.
|
||||
ClientBeforeServerInfo,
|
||||
/// Password protocol extension after the server INFO has been received.
|
||||
ClientAfterServerInfo {
|
||||
/// The user to send to the server.
|
||||
user: String,
|
||||
/// The password to send to the user.
|
||||
password: String,
|
||||
},
|
||||
}
|
||||
|
||||
impl PasswordProtocolExtension {
|
||||
/// Password protocol extension ID.
|
||||
pub const ID: u8 = 0x02;
|
||||
|
||||
/// Create a new password protocol extension for the server.
|
||||
///
|
||||
/// This signifies that the server requires a password.
|
||||
pub fn new_server() -> Self {
|
||||
Self {
|
||||
username: String::new(),
|
||||
password: String::new(),
|
||||
role: Role::Server,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new password protocol extension for the client, with a username and password.
|
||||
///
|
||||
/// The username's length must fit within a u8. The password's length must fit within a
|
||||
/// u16.
|
||||
pub fn new_client(username: String, password: String) -> Self {
|
||||
Self {
|
||||
username,
|
||||
password,
|
||||
role: Role::Client,
|
||||
}
|
||||
}
|
||||
/// ID of password protocol extension.
|
||||
pub const ID: u8 = PASSWORD_PROTOCOL_EXTENSION_ID;
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ProtocolExtension for PasswordProtocolExtension {
|
||||
fn get_id(&self) -> u8 {
|
||||
Self::ID
|
||||
PASSWORD_PROTOCOL_EXTENSION_ID
|
||||
}
|
||||
|
||||
fn box_clone(&self) -> Box<dyn ProtocolExtension + Sync + Send> {
|
||||
Box::new(self.clone())
|
||||
}
|
||||
|
||||
fn get_supported_packets(&self) -> &'static [u8] {
|
||||
|
@ -77,142 +73,179 @@ impl ProtocolExtension for PasswordProtocolExtension {
|
|||
}
|
||||
|
||||
fn encode(&self) -> Bytes {
|
||||
match self.role {
|
||||
Role::Server => Bytes::new(),
|
||||
Role::Client => {
|
||||
let username = Bytes::from(self.username.clone().into_bytes());
|
||||
let password = Bytes::from(self.password.clone().into_bytes());
|
||||
let username_len = u8::try_from(username.len()).expect("username was too long");
|
||||
let password_len = u16::try_from(password.len()).expect("password was too long");
|
||||
|
||||
let mut bytes =
|
||||
BytesMut::with_capacity(3 + username_len as usize + password_len as usize);
|
||||
bytes.put_u8(username_len);
|
||||
bytes.put_u16_le(password_len);
|
||||
bytes.extend(username);
|
||||
bytes.extend(password);
|
||||
bytes.freeze()
|
||||
match self {
|
||||
Self::ServerBeforeClientInfo { required } => {
|
||||
let mut out = BytesMut::with_capacity(1);
|
||||
out.put_u8(*required as u8);
|
||||
out.freeze()
|
||||
}
|
||||
Self::ServerAfterClientInfo { .. } => Bytes::new(),
|
||||
Self::ClientBeforeServerInfo => Bytes::new(),
|
||||
Self::ClientAfterServerInfo { user, password } => {
|
||||
let mut out = BytesMut::with_capacity(1 + 2 + user.len() + password.len());
|
||||
out.put_u8(user.len().try_into().unwrap());
|
||||
out.put_u16_le(password.len().try_into().unwrap());
|
||||
out.extend_from_slice(user.as_bytes());
|
||||
out.extend_from_slice(password.as_bytes());
|
||||
out.freeze()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_handshake(
|
||||
&mut self,
|
||||
_: &mut dyn WebSocketRead,
|
||||
_: &LockedWebSocketWrite,
|
||||
_read: &mut dyn WebSocketRead,
|
||||
_write: &LockedWebSocketWrite,
|
||||
) -> Result<(), WispError> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn handle_packet(
|
||||
&mut self,
|
||||
_: Bytes,
|
||||
_: &mut dyn WebSocketRead,
|
||||
_: &LockedWebSocketWrite,
|
||||
_packet: Bytes,
|
||||
_read: &mut dyn WebSocketRead,
|
||||
_write: &LockedWebSocketWrite,
|
||||
) -> Result<(), WispError> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn box_clone(&self) -> Box<dyn ProtocolExtension + Sync + Send> {
|
||||
Box::new(self.clone())
|
||||
}
|
||||
}
|
||||
|
||||
impl From<PasswordProtocolExtension> for AnyProtocolExtension {
|
||||
fn from(value: PasswordProtocolExtension) -> Self {
|
||||
AnyProtocolExtension(Box::new(value))
|
||||
Err(WispError::ExtensionImplNotSupported)
|
||||
}
|
||||
}
|
||||
|
||||
/// Password protocol extension builder.
|
||||
///
|
||||
/// **Passwords are sent in plain text!!**
|
||||
pub struct PasswordProtocolExtensionBuilder {
|
||||
/// Map of users and their passwords to allow. Only used on server.
|
||||
pub users: HashMap<String, String>,
|
||||
/// Username to authenticate with. Only used on client.
|
||||
pub username: String,
|
||||
/// Password to authenticate with. Only used on client.
|
||||
pub password: String,
|
||||
///
|
||||
/// See [the docs](https://github.com/MercuryWorkshop/wisp-protocol/blob/v2/protocol.md#0x02---password-authentication)
|
||||
pub enum PasswordProtocolExtensionBuilder {
|
||||
/// Password protocol extension builder before the client INFO has been received.
|
||||
ServerBeforeClientInfo {
|
||||
/// The user+password combinations to verify the client with.
|
||||
users: HashMap<String, String>,
|
||||
/// Whether this authentication method is required.
|
||||
required: bool,
|
||||
},
|
||||
/// Password protocol extension builder after the client INFO has been received.
|
||||
ServerAfterClientInfo {
|
||||
/// The user+password combinations to verify the client with.
|
||||
users: HashMap<String, String>,
|
||||
/// Whether this authentication method is required.
|
||||
required: bool,
|
||||
},
|
||||
|
||||
/// Password protocol extension builder before the server INFO has been received.
|
||||
ClientBeforeServerInfo {
|
||||
/// The credentials to send to the server.
|
||||
creds: Option<(String, String)>,
|
||||
},
|
||||
/// Password protocol extension builder after the server INFO has been received.
|
||||
ClientAfterServerInfo {
|
||||
/// The credentials to send to the server.
|
||||
creds: Option<(String, String)>,
|
||||
/// Whether this authentication method is required.
|
||||
required: bool,
|
||||
},
|
||||
}
|
||||
|
||||
impl PasswordProtocolExtensionBuilder {
|
||||
/// Create a new password protocol extension builder for the server, with a map of users
|
||||
/// and passwords to allow.
|
||||
pub fn new_server(users: HashMap<String, String>) -> Self {
|
||||
Self {
|
||||
users,
|
||||
username: String::new(),
|
||||
password: String::new(),
|
||||
/// ID of password protocol extension.
|
||||
pub const ID: u8 = PASSWORD_PROTOCOL_EXTENSION_ID;
|
||||
|
||||
/// Create a new server variant of the password protocol extension.
|
||||
pub fn new_server(users: HashMap<String, String>, required: bool) -> Self {
|
||||
Self::ServerBeforeClientInfo { users, required }
|
||||
}
|
||||
|
||||
/// Create a new client variant of the password protocol extension with a username and password.
|
||||
pub fn new_client(creds: Option<(String, String)>) -> Self {
|
||||
Self::ClientBeforeServerInfo { creds }
|
||||
}
|
||||
|
||||
/// Get whether this authentication method is required. Could return None if the server has not
|
||||
/// sent the password protocol extension.
|
||||
pub fn is_required(&self) -> Option<bool> {
|
||||
match self {
|
||||
Self::ServerBeforeClientInfo { required, .. } => Some(*required),
|
||||
Self::ServerAfterClientInfo { required, .. } => Some(*required),
|
||||
Self::ClientBeforeServerInfo { .. } => None,
|
||||
Self::ClientAfterServerInfo { required, .. } => Some(*required),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new password protocol extension builder for the client, with a username and
|
||||
/// password to authenticate with.
|
||||
pub fn new_client(username: String, password: String) -> Self {
|
||||
Self {
|
||||
users: HashMap::new(),
|
||||
username,
|
||||
password,
|
||||
/// Set the credentials sent to the server, if this is a client variant.
|
||||
pub fn set_creds(&mut self, credentials: (String, String)) {
|
||||
match self {
|
||||
Self::ClientBeforeServerInfo { creds } | Self::ClientAfterServerInfo { creds, .. } => {
|
||||
*creds = Some(credentials);
|
||||
}
|
||||
Self::ServerBeforeClientInfo { .. } | Self::ServerAfterClientInfo { .. } => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ProtocolExtensionBuilder for PasswordProtocolExtensionBuilder {
|
||||
fn get_id(&self) -> u8 {
|
||||
PasswordProtocolExtension::ID
|
||||
PASSWORD_PROTOCOL_EXTENSION_ID
|
||||
}
|
||||
|
||||
fn build_to_extension(&mut self, _role: Role) -> Result<AnyProtocolExtension, WispError> {
|
||||
match self {
|
||||
Self::ServerBeforeClientInfo { users: _, required } => {
|
||||
Ok(PasswordProtocolExtension::ServerBeforeClientInfo {
|
||||
required: *required,
|
||||
}
|
||||
.into())
|
||||
}
|
||||
Self::ServerAfterClientInfo { .. } => Err(WispError::ExtensionImplNotSupported),
|
||||
Self::ClientBeforeServerInfo { .. } => Err(WispError::ExtensionImplNotSupported),
|
||||
Self::ClientAfterServerInfo { creds, .. } => {
|
||||
let (user, password) = creds.clone().ok_or(WispError::PasswordExtensionNoCreds)?;
|
||||
Ok(PasswordProtocolExtension::ClientAfterServerInfo { user, password }.into())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn build_from_bytes(
|
||||
&mut self,
|
||||
mut payload: Bytes,
|
||||
role: crate::Role,
|
||||
mut bytes: Bytes,
|
||||
_role: Role,
|
||||
) -> Result<AnyProtocolExtension, WispError> {
|
||||
match role {
|
||||
Role::Server => {
|
||||
if payload.remaining() < 3 {
|
||||
return Err(WispError::PacketTooSmall);
|
||||
}
|
||||
match self {
|
||||
Self::ServerBeforeClientInfo { users, required } => {
|
||||
let user_len = bytes.get_u8();
|
||||
let password_len = bytes.get_u16_le();
|
||||
|
||||
let username_len = payload.get_u8();
|
||||
let password_len = payload.get_u16_le();
|
||||
if payload.remaining() < (password_len + username_len as u16) as usize {
|
||||
return Err(WispError::PacketTooSmall);
|
||||
}
|
||||
|
||||
let username =
|
||||
std::str::from_utf8(&payload.split_to(username_len as usize))?.to_string();
|
||||
let user = std::str::from_utf8(&bytes.split_to(user_len as usize))?.to_string();
|
||||
let password =
|
||||
std::str::from_utf8(&payload.split_to(password_len as usize))?.to_string();
|
||||
std::str::from_utf8(&bytes.split_to(password_len as usize))?.to_string();
|
||||
|
||||
let Some(user) = self.users.iter().find(|x| *x.0 == username) else {
|
||||
return Err(WispError::PasswordExtensionCredsInvalid);
|
||||
let valid = users.get(&user).map(|x| *x == password).unwrap_or(false);
|
||||
|
||||
*self = Self::ServerAfterClientInfo {
|
||||
users: users.clone(),
|
||||
required: *required,
|
||||
};
|
||||
|
||||
if *user.1 != password {
|
||||
return Err(WispError::PasswordExtensionCredsInvalid);
|
||||
if !valid {
|
||||
Err(WispError::PasswordExtensionCredsInvalid)
|
||||
} else {
|
||||
Ok(PasswordProtocolExtension::ServerAfterClientInfo {
|
||||
chosen_user: user,
|
||||
chosen_password: password,
|
||||
}
|
||||
.into())
|
||||
}
|
||||
}
|
||||
Self::ServerAfterClientInfo { .. } => Err(WispError::ExtensionImplNotSupported),
|
||||
Self::ClientBeforeServerInfo { creds } => {
|
||||
let required = bytes.get_u8() != 0;
|
||||
|
||||
Ok(PasswordProtocolExtension {
|
||||
username,
|
||||
password,
|
||||
role,
|
||||
}
|
||||
.into())
|
||||
}
|
||||
Role::Client => {
|
||||
Ok(PasswordProtocolExtension::new_client(String::new(), String::new()).into())
|
||||
}
|
||||
}
|
||||
}
|
||||
*self = Self::ClientAfterServerInfo {
|
||||
creds: creds.clone(),
|
||||
required,
|
||||
};
|
||||
|
||||
fn build_to_extension(&mut self, role: Role) -> Result<AnyProtocolExtension, WispError> {
|
||||
Ok(match role {
|
||||
Role::Server => PasswordProtocolExtension::new_server(),
|
||||
Role::Client => {
|
||||
PasswordProtocolExtension::new_client(self.username.clone(), self.password.clone())
|
||||
Ok(PasswordProtocolExtension::ClientBeforeServerInfo.into())
|
||||
}
|
||||
Self::ClientAfterServerInfo { .. } => Err(WispError::ExtensionImplNotSupported),
|
||||
}
|
||||
.into())
|
||||
}
|
||||
}
|
||||
|
|
|
@ -60,12 +60,6 @@ impl ProtocolExtension for UdpProtocolExtension {
|
|||
}
|
||||
}
|
||||
|
||||
impl From<UdpProtocolExtension> for AnyProtocolExtension {
|
||||
fn from(value: UdpProtocolExtension) -> Self {
|
||||
AnyProtocolExtension(Box::new(value))
|
||||
}
|
||||
}
|
||||
|
||||
/// UDP protocol extension builder.
|
||||
pub struct UdpProtocolExtensionBuilder;
|
||||
|
||||
|
|
|
@ -106,10 +106,21 @@ pub enum WispError {
|
|||
/// The specified protocol extensions are not supported by the other side.
|
||||
#[error("Protocol extensions {0:?} not supported")]
|
||||
ExtensionsNotSupported(Vec<u8>),
|
||||
|
||||
/// The password authentication username/password was invalid.
|
||||
#[error("Password protocol extension: Invalid username/password")]
|
||||
PasswordExtensionCredsInvalid,
|
||||
/// No password authentication username/password was provided.
|
||||
#[error("Password protocol extension: No username/password provided")]
|
||||
PasswordExtensionNoCreds,
|
||||
|
||||
/// The certificate authentication certificate type was unsupported.
|
||||
#[error("Certificate authentication protocol extension: Invalid certificate type")]
|
||||
CertAuthExtensionCertTypeInvalid,
|
||||
/// The certificate authentication signature was invalid.
|
||||
#[error("Certificate authentication protocol extension: Invalid signature")]
|
||||
CertAuthExtensionSigInvalid,
|
||||
/// No certificate authentication signing key was provided.
|
||||
#[error("Password protocol extension: No signing key provided")]
|
||||
CertAuthExtensionNoKey,
|
||||
}
|
||||
|
|
|
@ -124,11 +124,10 @@ where
|
|||
}
|
||||
|
||||
type WispV2ClosureResult = Pin<Box<dyn Future<Output = Result<(), WispError>> + Sync + Send>>;
|
||||
type WispV2ClosureBuilders<'a> = &'a mut [AnyProtocolExtensionBuilder];
|
||||
/// Wisp V2 handshake and protocol extension settings wrapper struct.
|
||||
pub struct WispV2Extensions {
|
||||
builders: Vec<AnyProtocolExtensionBuilder>,
|
||||
closure: Box<dyn Fn(WispV2ClosureBuilders) -> WispV2ClosureResult + Send>,
|
||||
closure: Box<dyn Fn(&mut Vec<AnyProtocolExtensionBuilder>) -> WispV2ClosureResult + Send>,
|
||||
}
|
||||
|
||||
impl WispV2Extensions {
|
||||
|
@ -143,7 +142,7 @@ impl WispV2Extensions {
|
|||
/// Create a Wisp V2 settings struct with some middleware.
|
||||
pub fn new_with_middleware<C>(builders: Vec<AnyProtocolExtensionBuilder>, closure: C) -> Self
|
||||
where
|
||||
C: Fn(WispV2ClosureBuilders) -> WispV2ClosureResult + Send + 'static,
|
||||
C: Fn(&mut Vec<AnyProtocolExtensionBuilder>) -> WispV2ClosureResult + Send + 'static,
|
||||
{
|
||||
Self {
|
||||
builders,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue