diff --git a/wisp/src/extensions.rs b/wisp/src/extensions.rs index 9358c4a..dfefae6 100644 --- a/wisp/src/extensions.rs +++ b/wisp/src/extensions.rs @@ -88,7 +88,7 @@ pub trait ProtocolExtension: std::fmt::Debug { fn box_clone(&self) -> Box; } -/// Trait to build a Wisp protocol extension for the client. +/// Trait to build a Wisp protocol extension from a payload. pub trait ProtocolExtensionBuilder { /// Get the protocol extension ID. /// @@ -96,7 +96,11 @@ pub trait ProtocolExtensionBuilder { fn get_id(&self) -> u8; /// Build a protocol extension from the extension's metadata. - fn build(&self, bytes: Bytes, role: Role) -> AnyProtocolExtension; + fn build_from_bytes(&self, bytes: Bytes, role: Role) + -> Result; + + /// Build a protocol extension to send to the other side. + fn build_to_extension(&self, role: Role) -> AnyProtocolExtension; } pub mod udp { @@ -108,7 +112,6 @@ pub mod udp { //! rx, //! tx, //! 128, - //! Some(vec![UdpProtocolExtension().into()]), //! Some(&[&UdpProtocolExtensionBuilder()]) //! ); //! ``` @@ -154,7 +157,6 @@ pub mod udp { Ok(()) } - /// Handle receiving a packet. async fn handle_packet( &mut self, _: Bytes, @@ -180,11 +182,294 @@ pub mod udp { impl ProtocolExtensionBuilder for UdpProtocolExtensionBuilder { fn get_id(&self) -> u8 { - 0x01 + UdpProtocolExtension::ID } - fn build(&self, _: Bytes, _: crate::Role) -> AnyProtocolExtension { - AnyProtocolExtension(Box::new(UdpProtocolExtension())) + fn build_from_bytes( + &self, + _: Bytes, + _: crate::Role, + ) -> Result { + Ok(UdpProtocolExtension().into()) + } + + fn build_to_extension(&self, _: crate::Role) -> AnyProtocolExtension { + UdpProtocolExtension().into() + } + } +} + +pub mod password { + //! Password protocol extension. + //! + //! # Example + //! Server: + //! ``` + //! let mut passwords = HashMap::new(); + //! passwords.insert("user1".to_string(), "pw".to_string()); + //! let (mux, fut) = ServerMux::new( + //! rx, + //! tx, + //! 128, + //! Some(&[&PasswordProtocolExtensionBuilder::new_server(passwords)]) + //! ); + //! ``` + //! + //! Client: + //! ``` + //! let (mux, fut) = ClientMux::new( + //! rx, + //! tx, + //! 128, + //! Some(&[ + //! &PasswordProtocolExtensionBuilder::new_client( + //! "user1".to_string(), + //! "pw".to_string() + //! ) + //! ]) + //! ); + //! ``` + //! See [the docs](https://github.com/MercuryWorkshop/wisp-protocol/blob/main/protocol.md#0x02---password-authentication) + + use std::{collections::HashMap, error::Error, fmt::Display, string::FromUtf8Error}; + + use async_trait::async_trait; + use bytes::{Buf, BufMut, Bytes, BytesMut}; + + use crate::{ + ws::{LockedWebSocketWrite, WebSocketRead}, + Role, WispError, + }; + + use super::{AnyProtocolExtension, ProtocolExtension, ProtocolExtensionBuilder}; + + #[derive(Debug, Clone)] + /// Password protocol extension. + /// + /// **This extension will panic when encoding if the username's length does not fit within a u8 + /// or the password's length does not fit within a u16.** + pub struct PasswordProtocolExtension { + /// The username to log in with. + /// + /// This string's length must fit within a u8. + pub username: String, + /// The password to log in with. + /// + /// This string's length must fit within a u16. + pub password: String, + role: Role, + } + + impl PasswordProtocolExtension { + /// Password protocol extension ID. + pub const ID: u8 = 0x02; + + /// Create a new password protocol extension for the server. + /// + /// This signifies that the server requires a password. + pub fn new_server() -> Self { + Self { + username: String::new(), + password: String::new(), + role: Role::Server, + } + } + + /// Create a new password protocol extension for the client, with a username and password. + /// + /// The username's length must fit within a u8. The password's length must fit within a + /// u16. + pub fn new_client(username: String, password: String) -> Self { + Self { + username, + password, + role: Role::Client, + } + } + } + + #[async_trait] + impl ProtocolExtension for PasswordProtocolExtension { + fn get_id(&self) -> u8 { + Self::ID + } + + fn get_supported_packets(&self) -> &'static [u8] { + &[] + } + + fn encode(&self) -> Bytes { + match self.role { + Role::Server => Bytes::new(), + Role::Client => { + let username = Bytes::from(self.username.clone().into_bytes()); + let password = Bytes::from(self.password.clone().into_bytes()); + let username_len = u8::try_from(username.len()).expect("username was too long"); + let password_len = + u16::try_from(username.len()).expect("password was too long"); + + let mut bytes = + BytesMut::with_capacity(3 + username_len as usize + password_len as usize); + bytes.put_u8(username_len); + bytes.put_u16_le(password_len); + bytes.extend(username); + bytes.extend(password); + bytes.freeze() + } + } + } + + async fn handle_handshake( + &mut self, + _: &mut dyn WebSocketRead, + _: &LockedWebSocketWrite, + ) -> Result<(), WispError> { + Ok(()) + } + + async fn handle_packet( + &mut self, + _: Bytes, + _: &mut dyn WebSocketRead, + _: &LockedWebSocketWrite, + ) -> Result<(), WispError> { + Ok(()) + } + + fn box_clone(&self) -> Box { + Box::new(self.clone()) + } + } + + #[derive(Debug)] + enum PasswordProtocolExtensionError { + Utf8Error(FromUtf8Error), + InvalidUsername, + InvalidPassword, + } + + impl Display for PasswordProtocolExtensionError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + use PasswordProtocolExtensionError as E; + match self { + E::Utf8Error(e) => write!(f, "{}", e), + E::InvalidUsername => write!(f, "Invalid username"), + E::InvalidPassword => write!(f, "Invalid password"), + } + } + } + + impl Error for PasswordProtocolExtensionError {} + + impl From for WispError { + fn from(value: PasswordProtocolExtensionError) -> Self { + WispError::ExtensionImplError(Box::new(value)) + } + } + + impl From for PasswordProtocolExtensionError { + fn from(value: FromUtf8Error) -> Self { + PasswordProtocolExtensionError::Utf8Error(value) + } + } + + impl From for AnyProtocolExtension { + fn from(value: PasswordProtocolExtension) -> Self { + AnyProtocolExtension(Box::new(value)) + } + } + + /// Password protocol extension builder. + pub struct PasswordProtocolExtensionBuilder { + /// Map of users and their passwords to allow. Only used on server. + pub users: HashMap, + /// Username to authenticate with. Only used on client. + pub username: String, + /// Password to authenticate with. Only used on client. + pub password: String, + } + + impl PasswordProtocolExtensionBuilder { + /// Create a new password protocol extension builder for the server, with a map of users + /// and passwords to allow. + pub fn new_server(users: HashMap) -> Self { + Self { + users, + username: String::new(), + password: String::new(), + } + } + + /// Create a new password protocol extension builder for the client, with a username and + /// password to authenticate with. + pub fn new_client(username: String, password: String) -> Self { + Self { + users: HashMap::new(), + username, + password, + } + } + } + + impl ProtocolExtensionBuilder for PasswordProtocolExtensionBuilder { + fn get_id(&self) -> u8 { + PasswordProtocolExtension::ID + } + + fn build_from_bytes( + &self, + mut payload: Bytes, + role: crate::Role, + ) -> Result { + match role { + Role::Server => { + if payload.remaining() < 3 { + return Err(WispError::PacketTooSmall); + } + + let username_len = payload.get_u8(); + let password_len = payload.get_u16_le(); + if payload.remaining() < (password_len + username_len as u16) as usize { + return Err(WispError::PacketTooSmall); + } + + use PasswordProtocolExtensionError as EError; + let username = + String::from_utf8(payload.copy_to_bytes(username_len as usize).to_vec()) + .map_err(|x| WispError::from(EError::from(x)))?; + let password = + String::from_utf8(payload.copy_to_bytes(password_len as usize).to_vec()) + .map_err(|x| WispError::from(EError::from(x)))?; + + let Some(user) = self.users.iter().find(|x| *x.0 == username) else { + return Err(EError::InvalidUsername.into()); + }; + + if *user.1 != password { + return Err(EError::InvalidPassword.into()); + } + Ok(PasswordProtocolExtension { + username, + password, + role, + } + .into()) + } + Role::Client => { + Ok(PasswordProtocolExtension::new_client(String::new(), String::new()).into()) + } + } + } + + fn build_to_extension(&self, role: Role) -> AnyProtocolExtension { + match role { + Role::Server => PasswordProtocolExtension::new_server(), + Role::Client => PasswordProtocolExtension::new_client( + self.username.clone(), + self.password.clone(), + ), + } + .into() } } } diff --git a/wisp/src/lib.rs b/wisp/src/lib.rs index 7458bf4..40b21f7 100644 --- a/wisp/src/lib.rs +++ b/wisp/src/lib.rs @@ -458,13 +458,11 @@ pub struct ServerMux { impl ServerMux { /// Create a new server-side multiplexor. /// - /// If either extensions or extension_builders are None a Wisp v1 connection is created - /// otherwise a Wisp v2 connection is created. + /// If extension_builders is None a Wisp v1 connection is created otherwise a Wisp v2 connection is created. pub async fn new( mut read: R, write: W, buffer_size: u32, - extensions: Option>, extension_builders: Option<&[&(dyn ProtocolExtensionBuilder + Sync)]>, ) -> Result<(Self, impl Future> + Send), WispError> where @@ -483,28 +481,29 @@ impl ServerMux { let mut extra_packet = Vec::with_capacity(1); let mut downgraded = true; - if let Some(extensions) = extensions { - if let Some(builders) = extension_builders { - let extension_ids: Vec<_> = extensions.iter().map(|x| x.get_id()).collect(); - write - .write_frame(Packet::new_info(extensions).into()) - .await?; - if let Some(frame) = select! { - x = read.wisp_read_frame(&write).fuse() => Some(x?), - // TODO change this to correct timeout once draft 2 is out - _ = Delay::new(Duration::from_secs(5)).fuse() => None - } { - let packet = Packet::maybe_parse_info(frame, Role::Server, builders)?; - if let PacketType::Info(info) = packet.packet_type { - supported_extensions = info - .extensions - .into_iter() - .filter(|x| extension_ids.contains(&x.get_id())) - .collect(); - downgraded = false; - } else { - extra_packet.push(packet.into()); - } + if let Some(builders) = extension_builders { + let extensions: Vec<_> = builders + .iter() + .map(|x| x.build_to_extension(Role::Server)) + .collect(); + let extension_ids: Vec<_> = extensions.iter().map(|x| x.get_id()).collect(); + write + .write_frame(Packet::new_info(extensions).into()) + .await?; + if let Some(frame) = select! { + x = read.wisp_read_frame(&write).fuse() => Some(x?), + _ = Delay::new(Duration::from_secs(5)).fuse() => None + } { + let packet = Packet::maybe_parse_info(frame, Role::Server, builders)?; + if let PacketType::Info(info) = packet.packet_type { + supported_extensions = info + .extensions + .into_iter() + .filter(|x| extension_ids.contains(&x.get_id())) + .collect(); + downgraded = false; + } else { + extra_packet.push(packet.into()); } } } @@ -574,12 +573,10 @@ pub struct ClientMux { impl ClientMux { /// Create a new client side multiplexor. /// - /// If either extensions or extension_builders are None a Wisp v1 connection is created - /// otherwise a Wisp v2 connection is created. + /// If extension_builders is None a Wisp v1 connection is created otherwise a Wisp v2 connection is created. pub async fn new( mut read: R, write: W, - extensions: Option>, extension_builders: Option<&[&(dyn ProtocolExtensionBuilder + Sync)]>, ) -> Result<(Self, impl Future> + Send), WispError> where @@ -596,28 +593,29 @@ impl ClientMux { let mut extra_packet = Vec::with_capacity(1); let mut downgraded = true; - if let Some(extensions) = extensions { - if let Some(builders) = extension_builders { - let extension_ids: Vec<_> = extensions.iter().map(|x| x.get_id()).collect(); - if let Some(frame) = select! { - x = read.wisp_read_frame(&write).fuse() => Some(x?), - // TODO change this to correct timeout once draft 2 is out - _ = Delay::new(Duration::from_secs(5)).fuse() => None - } { - let packet = Packet::maybe_parse_info(frame, Role::Server, builders)?; - if let PacketType::Info(info) = packet.packet_type { - supported_extensions = info - .extensions - .into_iter() - .filter(|x| extension_ids.contains(&x.get_id())) - .collect(); - write - .write_frame(Packet::new_info(extensions).into()) - .await?; - downgraded = false; - } else { - extra_packet.push(packet.into()); - } + if let Some(builders) = extension_builders { + let extensions: Vec<_> = builders + .iter() + .map(|x| x.build_to_extension(Role::Client)) + .collect(); + let extension_ids: Vec<_> = extensions.iter().map(|x| x.get_id()).collect(); + if let Some(frame) = select! { + x = read.wisp_read_frame(&write).fuse() => Some(x?), + _ = Delay::new(Duration::from_secs(5)).fuse() => None + } { + let packet = Packet::maybe_parse_info(frame, Role::Server, builders)?; + if let PacketType::Info(info) = packet.packet_type { + supported_extensions = info + .extensions + .into_iter() + .filter(|x| extension_ids.contains(&x.get_id())) + .collect(); + write + .write_frame(Packet::new_info(extensions).into()) + .await?; + downgraded = false; + } else { + extra_packet.push(packet.into()); } } } diff --git a/wisp/src/packet.rs b/wisp/src/packet.rs index 388fae7..0017307 100644 --- a/wisp/src/packet.rs +++ b/wisp/src/packet.rs @@ -444,7 +444,9 @@ impl Packet { return Err(WispError::PacketTooSmall); } if let Some(builder) = extension_builders.iter().find(|x| x.get_id() == id) { - extensions.push(builder.build(bytes.copy_to_bytes(length), role)) + if let Ok(extension) = builder.build_from_bytes(bytes.copy_to_bytes(length), role) { + extensions.push(extension) + } } else { bytes.advance(length) }