mirror of
https://github.com/MercuryWorkshop/epoxy-tls.git
synced 2025-05-13 06:20:02 -04:00
add password protocol extension, simplify protocol extension api
This commit is contained in:
parent
b0d1038a3c
commit
481128e4f5
3 changed files with 343 additions and 58 deletions
|
@ -88,7 +88,7 @@ pub trait ProtocolExtension: std::fmt::Debug {
|
|||
fn box_clone(&self) -> Box<dyn ProtocolExtension + Sync + Send>;
|
||||
}
|
||||
|
||||
/// Trait to build a Wisp protocol extension for the client.
|
||||
/// Trait to build a Wisp protocol extension from a payload.
|
||||
pub trait ProtocolExtensionBuilder {
|
||||
/// Get the protocol extension ID.
|
||||
///
|
||||
|
@ -96,7 +96,11 @@ pub trait ProtocolExtensionBuilder {
|
|||
fn get_id(&self) -> u8;
|
||||
|
||||
/// Build a protocol extension from the extension's metadata.
|
||||
fn build(&self, bytes: Bytes, role: Role) -> AnyProtocolExtension;
|
||||
fn build_from_bytes(&self, bytes: Bytes, role: Role)
|
||||
-> Result<AnyProtocolExtension, WispError>;
|
||||
|
||||
/// Build a protocol extension to send to the other side.
|
||||
fn build_to_extension(&self, role: Role) -> AnyProtocolExtension;
|
||||
}
|
||||
|
||||
pub mod udp {
|
||||
|
@ -108,7 +112,6 @@ pub mod udp {
|
|||
//! rx,
|
||||
//! tx,
|
||||
//! 128,
|
||||
//! Some(vec![UdpProtocolExtension().into()]),
|
||||
//! Some(&[&UdpProtocolExtensionBuilder()])
|
||||
//! );
|
||||
//! ```
|
||||
|
@ -154,7 +157,6 @@ pub mod udp {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
/// Handle receiving a packet.
|
||||
async fn handle_packet(
|
||||
&mut self,
|
||||
_: Bytes,
|
||||
|
@ -180,11 +182,294 @@ pub mod udp {
|
|||
|
||||
impl ProtocolExtensionBuilder for UdpProtocolExtensionBuilder {
|
||||
fn get_id(&self) -> u8 {
|
||||
0x01
|
||||
UdpProtocolExtension::ID
|
||||
}
|
||||
|
||||
fn build(&self, _: Bytes, _: crate::Role) -> AnyProtocolExtension {
|
||||
AnyProtocolExtension(Box::new(UdpProtocolExtension()))
|
||||
fn build_from_bytes(
|
||||
&self,
|
||||
_: Bytes,
|
||||
_: crate::Role,
|
||||
) -> Result<AnyProtocolExtension, WispError> {
|
||||
Ok(UdpProtocolExtension().into())
|
||||
}
|
||||
|
||||
fn build_to_extension(&self, _: crate::Role) -> AnyProtocolExtension {
|
||||
UdpProtocolExtension().into()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub mod password {
|
||||
//! Password protocol extension.
|
||||
//!
|
||||
//! # Example
|
||||
//! Server:
|
||||
//! ```
|
||||
//! let mut passwords = HashMap::new();
|
||||
//! passwords.insert("user1".to_string(), "pw".to_string());
|
||||
//! let (mux, fut) = ServerMux::new(
|
||||
//! rx,
|
||||
//! tx,
|
||||
//! 128,
|
||||
//! Some(&[&PasswordProtocolExtensionBuilder::new_server(passwords)])
|
||||
//! );
|
||||
//! ```
|
||||
//!
|
||||
//! Client:
|
||||
//! ```
|
||||
//! let (mux, fut) = ClientMux::new(
|
||||
//! rx,
|
||||
//! tx,
|
||||
//! 128,
|
||||
//! Some(&[
|
||||
//! &PasswordProtocolExtensionBuilder::new_client(
|
||||
//! "user1".to_string(),
|
||||
//! "pw".to_string()
|
||||
//! )
|
||||
//! ])
|
||||
//! );
|
||||
//! ```
|
||||
//! See [the docs](https://github.com/MercuryWorkshop/wisp-protocol/blob/main/protocol.md#0x02---password-authentication)
|
||||
|
||||
use std::{collections::HashMap, error::Error, fmt::Display, string::FromUtf8Error};
|
||||
|
||||
use async_trait::async_trait;
|
||||
use bytes::{Buf, BufMut, Bytes, BytesMut};
|
||||
|
||||
use crate::{
|
||||
ws::{LockedWebSocketWrite, WebSocketRead},
|
||||
Role, WispError,
|
||||
};
|
||||
|
||||
use super::{AnyProtocolExtension, ProtocolExtension, ProtocolExtensionBuilder};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
/// Password protocol extension.
|
||||
///
|
||||
/// **This extension will panic when encoding if the username's length does not fit within a u8
|
||||
/// or the password's length does not fit within a u16.**
|
||||
pub struct PasswordProtocolExtension {
|
||||
/// The username to log in with.
|
||||
///
|
||||
/// This string's length must fit within a u8.
|
||||
pub username: String,
|
||||
/// The password to log in with.
|
||||
///
|
||||
/// This string's length must fit within a u16.
|
||||
pub password: String,
|
||||
role: Role,
|
||||
}
|
||||
|
||||
impl PasswordProtocolExtension {
|
||||
/// Password protocol extension ID.
|
||||
pub const ID: u8 = 0x02;
|
||||
|
||||
/// Create a new password protocol extension for the server.
|
||||
///
|
||||
/// This signifies that the server requires a password.
|
||||
pub fn new_server() -> Self {
|
||||
Self {
|
||||
username: String::new(),
|
||||
password: String::new(),
|
||||
role: Role::Server,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new password protocol extension for the client, with a username and password.
|
||||
///
|
||||
/// The username's length must fit within a u8. The password's length must fit within a
|
||||
/// u16.
|
||||
pub fn new_client(username: String, password: String) -> Self {
|
||||
Self {
|
||||
username,
|
||||
password,
|
||||
role: Role::Client,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ProtocolExtension for PasswordProtocolExtension {
|
||||
fn get_id(&self) -> u8 {
|
||||
Self::ID
|
||||
}
|
||||
|
||||
fn get_supported_packets(&self) -> &'static [u8] {
|
||||
&[]
|
||||
}
|
||||
|
||||
fn encode(&self) -> Bytes {
|
||||
match self.role {
|
||||
Role::Server => Bytes::new(),
|
||||
Role::Client => {
|
||||
let username = Bytes::from(self.username.clone().into_bytes());
|
||||
let password = Bytes::from(self.password.clone().into_bytes());
|
||||
let username_len = u8::try_from(username.len()).expect("username was too long");
|
||||
let password_len =
|
||||
u16::try_from(username.len()).expect("password was too long");
|
||||
|
||||
let mut bytes =
|
||||
BytesMut::with_capacity(3 + username_len as usize + password_len as usize);
|
||||
bytes.put_u8(username_len);
|
||||
bytes.put_u16_le(password_len);
|
||||
bytes.extend(username);
|
||||
bytes.extend(password);
|
||||
bytes.freeze()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_handshake(
|
||||
&mut self,
|
||||
_: &mut dyn WebSocketRead,
|
||||
_: &LockedWebSocketWrite,
|
||||
) -> Result<(), WispError> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn handle_packet(
|
||||
&mut self,
|
||||
_: Bytes,
|
||||
_: &mut dyn WebSocketRead,
|
||||
_: &LockedWebSocketWrite,
|
||||
) -> Result<(), WispError> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn box_clone(&self) -> Box<dyn ProtocolExtension + Sync + Send> {
|
||||
Box::new(self.clone())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
enum PasswordProtocolExtensionError {
|
||||
Utf8Error(FromUtf8Error),
|
||||
InvalidUsername,
|
||||
InvalidPassword,
|
||||
}
|
||||
|
||||
impl Display for PasswordProtocolExtensionError {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
use PasswordProtocolExtensionError as E;
|
||||
match self {
|
||||
E::Utf8Error(e) => write!(f, "{}", e),
|
||||
E::InvalidUsername => write!(f, "Invalid username"),
|
||||
E::InvalidPassword => write!(f, "Invalid password"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Error for PasswordProtocolExtensionError {}
|
||||
|
||||
impl From<PasswordProtocolExtensionError> for WispError {
|
||||
fn from(value: PasswordProtocolExtensionError) -> Self {
|
||||
WispError::ExtensionImplError(Box::new(value))
|
||||
}
|
||||
}
|
||||
|
||||
impl From<FromUtf8Error> for PasswordProtocolExtensionError {
|
||||
fn from(value: FromUtf8Error) -> Self {
|
||||
PasswordProtocolExtensionError::Utf8Error(value)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<PasswordProtocolExtension> for AnyProtocolExtension {
|
||||
fn from(value: PasswordProtocolExtension) -> Self {
|
||||
AnyProtocolExtension(Box::new(value))
|
||||
}
|
||||
}
|
||||
|
||||
/// Password protocol extension builder.
|
||||
pub struct PasswordProtocolExtensionBuilder {
|
||||
/// Map of users and their passwords to allow. Only used on server.
|
||||
pub users: HashMap<String, String>,
|
||||
/// Username to authenticate with. Only used on client.
|
||||
pub username: String,
|
||||
/// Password to authenticate with. Only used on client.
|
||||
pub password: String,
|
||||
}
|
||||
|
||||
impl PasswordProtocolExtensionBuilder {
|
||||
/// Create a new password protocol extension builder for the server, with a map of users
|
||||
/// and passwords to allow.
|
||||
pub fn new_server(users: HashMap<String, String>) -> Self {
|
||||
Self {
|
||||
users,
|
||||
username: String::new(),
|
||||
password: String::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new password protocol extension builder for the client, with a username and
|
||||
/// password to authenticate with.
|
||||
pub fn new_client(username: String, password: String) -> Self {
|
||||
Self {
|
||||
users: HashMap::new(),
|
||||
username,
|
||||
password,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ProtocolExtensionBuilder for PasswordProtocolExtensionBuilder {
|
||||
fn get_id(&self) -> u8 {
|
||||
PasswordProtocolExtension::ID
|
||||
}
|
||||
|
||||
fn build_from_bytes(
|
||||
&self,
|
||||
mut payload: Bytes,
|
||||
role: crate::Role,
|
||||
) -> Result<AnyProtocolExtension, WispError> {
|
||||
match role {
|
||||
Role::Server => {
|
||||
if payload.remaining() < 3 {
|
||||
return Err(WispError::PacketTooSmall);
|
||||
}
|
||||
|
||||
let username_len = payload.get_u8();
|
||||
let password_len = payload.get_u16_le();
|
||||
if payload.remaining() < (password_len + username_len as u16) as usize {
|
||||
return Err(WispError::PacketTooSmall);
|
||||
}
|
||||
|
||||
use PasswordProtocolExtensionError as EError;
|
||||
let username =
|
||||
String::from_utf8(payload.copy_to_bytes(username_len as usize).to_vec())
|
||||
.map_err(|x| WispError::from(EError::from(x)))?;
|
||||
let password =
|
||||
String::from_utf8(payload.copy_to_bytes(password_len as usize).to_vec())
|
||||
.map_err(|x| WispError::from(EError::from(x)))?;
|
||||
|
||||
let Some(user) = self.users.iter().find(|x| *x.0 == username) else {
|
||||
return Err(EError::InvalidUsername.into());
|
||||
};
|
||||
|
||||
if *user.1 != password {
|
||||
return Err(EError::InvalidPassword.into());
|
||||
}
|
||||
Ok(PasswordProtocolExtension {
|
||||
username,
|
||||
password,
|
||||
role,
|
||||
}
|
||||
.into())
|
||||
}
|
||||
Role::Client => {
|
||||
Ok(PasswordProtocolExtension::new_client(String::new(), String::new()).into())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn build_to_extension(&self, role: Role) -> AnyProtocolExtension {
|
||||
match role {
|
||||
Role::Server => PasswordProtocolExtension::new_server(),
|
||||
Role::Client => PasswordProtocolExtension::new_client(
|
||||
self.username.clone(),
|
||||
self.password.clone(),
|
||||
),
|
||||
}
|
||||
.into()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -458,13 +458,11 @@ pub struct ServerMux {
|
|||
impl ServerMux {
|
||||
/// Create a new server-side multiplexor.
|
||||
///
|
||||
/// If either extensions or extension_builders are None a Wisp v1 connection is created
|
||||
/// otherwise a Wisp v2 connection is created.
|
||||
/// If extension_builders is None a Wisp v1 connection is created otherwise a Wisp v2 connection is created.
|
||||
pub async fn new<R, W>(
|
||||
mut read: R,
|
||||
write: W,
|
||||
buffer_size: u32,
|
||||
extensions: Option<Vec<AnyProtocolExtension>>,
|
||||
extension_builders: Option<&[&(dyn ProtocolExtensionBuilder + Sync)]>,
|
||||
) -> Result<(Self, impl Future<Output = Result<(), WispError>> + Send), WispError>
|
||||
where
|
||||
|
@ -483,15 +481,17 @@ impl ServerMux {
|
|||
let mut extra_packet = Vec::with_capacity(1);
|
||||
let mut downgraded = true;
|
||||
|
||||
if let Some(extensions) = extensions {
|
||||
if let Some(builders) = extension_builders {
|
||||
let extensions: Vec<_> = builders
|
||||
.iter()
|
||||
.map(|x| x.build_to_extension(Role::Server))
|
||||
.collect();
|
||||
let extension_ids: Vec<_> = extensions.iter().map(|x| x.get_id()).collect();
|
||||
write
|
||||
.write_frame(Packet::new_info(extensions).into())
|
||||
.await?;
|
||||
if let Some(frame) = select! {
|
||||
x = read.wisp_read_frame(&write).fuse() => Some(x?),
|
||||
// TODO change this to correct timeout once draft 2 is out
|
||||
_ = Delay::new(Duration::from_secs(5)).fuse() => None
|
||||
} {
|
||||
let packet = Packet::maybe_parse_info(frame, Role::Server, builders)?;
|
||||
|
@ -507,7 +507,6 @@ impl ServerMux {
|
|||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok((
|
||||
Self {
|
||||
|
@ -574,12 +573,10 @@ pub struct ClientMux {
|
|||
impl ClientMux {
|
||||
/// Create a new client side multiplexor.
|
||||
///
|
||||
/// If either extensions or extension_builders are None a Wisp v1 connection is created
|
||||
/// otherwise a Wisp v2 connection is created.
|
||||
/// If extension_builders is None a Wisp v1 connection is created otherwise a Wisp v2 connection is created.
|
||||
pub async fn new<R, W>(
|
||||
mut read: R,
|
||||
write: W,
|
||||
extensions: Option<Vec<AnyProtocolExtension>>,
|
||||
extension_builders: Option<&[&(dyn ProtocolExtensionBuilder + Sync)]>,
|
||||
) -> Result<(Self, impl Future<Output = Result<(), WispError>> + Send), WispError>
|
||||
where
|
||||
|
@ -596,12 +593,14 @@ impl ClientMux {
|
|||
let mut extra_packet = Vec::with_capacity(1);
|
||||
let mut downgraded = true;
|
||||
|
||||
if let Some(extensions) = extensions {
|
||||
if let Some(builders) = extension_builders {
|
||||
let extensions: Vec<_> = builders
|
||||
.iter()
|
||||
.map(|x| x.build_to_extension(Role::Client))
|
||||
.collect();
|
||||
let extension_ids: Vec<_> = extensions.iter().map(|x| x.get_id()).collect();
|
||||
if let Some(frame) = select! {
|
||||
x = read.wisp_read_frame(&write).fuse() => Some(x?),
|
||||
// TODO change this to correct timeout once draft 2 is out
|
||||
_ = Delay::new(Duration::from_secs(5)).fuse() => None
|
||||
} {
|
||||
let packet = Packet::maybe_parse_info(frame, Role::Server, builders)?;
|
||||
|
@ -620,7 +619,6 @@ impl ClientMux {
|
|||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for extension in supported_extensions.iter_mut() {
|
||||
extension.handle_handshake(&mut read, &write).await?;
|
||||
|
|
|
@ -444,7 +444,9 @@ impl Packet {
|
|||
return Err(WispError::PacketTooSmall);
|
||||
}
|
||||
if let Some(builder) = extension_builders.iter().find(|x| x.get_id() == id) {
|
||||
extensions.push(builder.build(bytes.copy_to_bytes(length), role))
|
||||
if let Ok(extension) = builder.build_from_bytes(bytes.copy_to_bytes(length), role) {
|
||||
extensions.push(extension)
|
||||
}
|
||||
} else {
|
||||
bytes.advance(length)
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue