diff --git a/certs-grabber/src/main.rs b/certs-grabber/src/main.rs index 11574b8..c6dacf8 100644 --- a/certs-grabber/src/main.rs +++ b/certs-grabber/src/main.rs @@ -60,5 +60,5 @@ async fn main() { } code.pop(); code.push_str("];"); - println!("{}",code); + println!("{}", code); } diff --git a/client/src/wrappers.rs b/client/src/wrappers.rs index 5746ac2..47ff2c3 100644 --- a/client/src/wrappers.rs +++ b/client/src/wrappers.rs @@ -123,7 +123,6 @@ impl tower_service::Service for TlsWispService { let stream = service.call(uri_parsed).await?.into_inner(); if utils::get_is_secure(&req).map_err(|_| WispError::InvalidUri)? { let connector = TlsConnector::from(rustls_config); - log!("got stream"); Ok(TokioIo::new(Either::Left( connector .connect( @@ -216,10 +215,7 @@ impl WebSocketRead for WebSocketReader { } impl WebSocketWrapper { - pub fn connect( - url: &str, - protocols: Vec, - ) -> Result<(Self, WebSocketReader), JsValue> { + pub fn connect(url: &str, protocols: Vec) -> Result<(Self, WebSocketReader), JsValue> { let (read_tx, read_rx) = mpsc::unbounded_channel(); let closed = Arc::new(AtomicBool::new(false)); diff --git a/wisp/src/extensions.rs b/wisp/src/extensions.rs deleted file mode 100644 index de472bd..0000000 --- a/wisp/src/extensions.rs +++ /dev/null @@ -1,481 +0,0 @@ -//! Wisp protocol extensions. - -use std::ops::{Deref, DerefMut}; - -use async_trait::async_trait; -use bytes::{BufMut, Bytes, BytesMut}; - -use crate::{ - ws::{LockedWebSocketWrite, WebSocketRead}, - Role, WispError, -}; - -/// Type-erased protocol extension that implements Clone. -#[derive(Debug)] -pub struct AnyProtocolExtension(Box); - -impl AnyProtocolExtension { - /// Create a new type-erased protocol extension. - pub fn new(extension: T) -> Self { - Self(Box::new(extension)) - } -} - -impl Deref for AnyProtocolExtension { - type Target = dyn ProtocolExtension; - fn deref(&self) -> &Self::Target { - self.0.deref() - } -} - -impl DerefMut for AnyProtocolExtension { - fn deref_mut(&mut self) -> &mut Self::Target { - self.0.deref_mut() - } -} - -impl Clone for AnyProtocolExtension { - fn clone(&self) -> Self { - Self(self.0.box_clone()) - } -} - -impl From for Bytes { - fn from(value: AnyProtocolExtension) -> Self { - let mut bytes = BytesMut::with_capacity(5); - let payload = value.encode(); - bytes.put_u8(value.get_id()); - bytes.put_u32_le(payload.len() as u32); - bytes.extend(payload); - bytes.freeze() - } -} - -/// A Wisp protocol extension. -/// -/// See [the -/// docs](https://github.com/MercuryWorkshop/wisp-protocol/blob/main/protocol.md#protocol-extensions). -#[async_trait] -pub trait ProtocolExtension: std::fmt::Debug { - /// Get the protocol extension ID. - fn get_id(&self) -> u8; - /// Get the protocol extension's supported packets. - /// - /// Used to decide whether to call the protocol extension's packet handler. - fn get_supported_packets(&self) -> &'static [u8]; - - /// Encode self into Bytes. - fn encode(&self) -> Bytes; - - /// Handle the handshake part of a Wisp connection. - /// - /// This should be used to send or receive data before any streams are created. - async fn handle_handshake( - &mut self, - read: &mut dyn WebSocketRead, - write: &LockedWebSocketWrite, - ) -> Result<(), WispError>; - - /// Handle receiving a packet. - async fn handle_packet( - &mut self, - packet: Bytes, - read: &mut dyn WebSocketRead, - write: &LockedWebSocketWrite, - ) -> Result<(), WispError>; - - /// Clone the protocol extension. - fn box_clone(&self) -> Box; -} - -/// Trait to build a Wisp protocol extension from a payload. -pub trait ProtocolExtensionBuilder { - /// Get the protocol extension ID. - /// - /// Used to decide whether this builder should be used. - fn get_id(&self) -> u8; - - /// Build a protocol extension from the extension's metadata. - fn build_from_bytes(&self, bytes: Bytes, role: Role) - -> Result; - - /// Build a protocol extension to send to the other side. - fn build_to_extension(&self, role: Role) -> AnyProtocolExtension; -} - -pub mod udp { - //! UDP protocol extension. - //! - //! # Example - //! ``` - //! let (mux, fut) = ServerMux::new( - //! rx, - //! tx, - //! 128, - //! Some(&[Box::new(UdpProtocolExtensionBuilder())]) - //! ); - //! ``` - //! See [the docs](https://github.com/MercuryWorkshop/wisp-protocol/blob/main/protocol.md#0x01---udp) - use async_trait::async_trait; - use bytes::Bytes; - - use crate::{ - ws::{LockedWebSocketWrite, WebSocketRead}, - WispError, - }; - - use super::{AnyProtocolExtension, ProtocolExtension, ProtocolExtensionBuilder}; - - #[derive(Debug)] - /// UDP protocol extension. - pub struct UdpProtocolExtension(); - - impl UdpProtocolExtension { - /// UDP protocol extension ID. - pub const ID: u8 = 0x01; - } - - #[async_trait] - impl ProtocolExtension for UdpProtocolExtension { - fn get_id(&self) -> u8 { - Self::ID - } - - fn get_supported_packets(&self) -> &'static [u8] { - &[] - } - - fn encode(&self) -> Bytes { - Bytes::new() - } - - async fn handle_handshake( - &mut self, - _: &mut dyn WebSocketRead, - _: &LockedWebSocketWrite, - ) -> Result<(), WispError> { - Ok(()) - } - - async fn handle_packet( - &mut self, - _: Bytes, - _: &mut dyn WebSocketRead, - _: &LockedWebSocketWrite, - ) -> Result<(), WispError> { - Ok(()) - } - - fn box_clone(&self) -> Box { - Box::new(Self()) - } - } - - impl From for AnyProtocolExtension { - fn from(value: UdpProtocolExtension) -> Self { - AnyProtocolExtension(Box::new(value)) - } - } - - /// UDP protocol extension builder. - pub struct UdpProtocolExtensionBuilder(); - - impl ProtocolExtensionBuilder for UdpProtocolExtensionBuilder { - fn get_id(&self) -> u8 { - UdpProtocolExtension::ID - } - - fn build_from_bytes( - &self, - _: Bytes, - _: crate::Role, - ) -> Result { - Ok(UdpProtocolExtension().into()) - } - - fn build_to_extension(&self, _: crate::Role) -> AnyProtocolExtension { - UdpProtocolExtension().into() - } - } -} - -pub mod password { - //! Password protocol extension. - //! - //! Passwords are sent in plain text!! - //! - //! # Example - //! Server: - //! ``` - //! let mut passwords = HashMap::new(); - //! passwords.insert("user1".to_string(), "pw".to_string()); - //! let (mux, fut) = ServerMux::new( - //! rx, - //! tx, - //! 128, - //! Some(&[Box::new(PasswordProtocolExtensionBuilder::new_server(passwords))]) - //! ); - //! ``` - //! - //! Client: - //! ``` - //! let (mux, fut) = ClientMux::new( - //! rx, - //! tx, - //! 128, - //! Some(&[ - //! Box::new(PasswordProtocolExtensionBuilder::new_client( - //! "user1".to_string(), - //! "pw".to_string() - //! )) - //! ]) - //! ); - //! ``` - //! See [the docs](https://github.com/MercuryWorkshop/wisp-protocol/blob/main/protocol.md#0x02---password-authentication) - - use std::{collections::HashMap, error::Error, fmt::Display, string::FromUtf8Error}; - - use async_trait::async_trait; - use bytes::{Buf, BufMut, Bytes, BytesMut}; - - use crate::{ - ws::{LockedWebSocketWrite, WebSocketRead}, - Role, WispError, - }; - - use super::{AnyProtocolExtension, ProtocolExtension, ProtocolExtensionBuilder}; - - #[derive(Debug, Clone)] - /// Password protocol extension. - /// - /// **Passwords are sent in plain text!!** - /// **This extension will panic when encoding if the username's length does not fit within a u8 - /// or the password's length does not fit within a u16.** - pub struct PasswordProtocolExtension { - /// The username to log in with. - /// - /// This string's length must fit within a u8. - pub username: String, - /// The password to log in with. - /// - /// This string's length must fit within a u16. - pub password: String, - role: Role, - } - - impl PasswordProtocolExtension { - /// Password protocol extension ID. - pub const ID: u8 = 0x02; - - /// Create a new password protocol extension for the server. - /// - /// This signifies that the server requires a password. - pub fn new_server() -> Self { - Self { - username: String::new(), - password: String::new(), - role: Role::Server, - } - } - - /// Create a new password protocol extension for the client, with a username and password. - /// - /// The username's length must fit within a u8. The password's length must fit within a - /// u16. - pub fn new_client(username: String, password: String) -> Self { - Self { - username, - password, - role: Role::Client, - } - } - } - - #[async_trait] - impl ProtocolExtension for PasswordProtocolExtension { - fn get_id(&self) -> u8 { - Self::ID - } - - fn get_supported_packets(&self) -> &'static [u8] { - &[] - } - - fn encode(&self) -> Bytes { - match self.role { - Role::Server => Bytes::new(), - Role::Client => { - let username = Bytes::from(self.username.clone().into_bytes()); - let password = Bytes::from(self.password.clone().into_bytes()); - let username_len = u8::try_from(username.len()).expect("username was too long"); - let password_len = - u16::try_from(password.len()).expect("password was too long"); - - let mut bytes = - BytesMut::with_capacity(3 + username_len as usize + password_len as usize); - bytes.put_u8(username_len); - bytes.put_u16_le(password_len); - bytes.extend(username); - bytes.extend(password); - bytes.freeze() - } - } - } - - async fn handle_handshake( - &mut self, - _: &mut dyn WebSocketRead, - _: &LockedWebSocketWrite, - ) -> Result<(), WispError> { - Ok(()) - } - - async fn handle_packet( - &mut self, - _: Bytes, - _: &mut dyn WebSocketRead, - _: &LockedWebSocketWrite, - ) -> Result<(), WispError> { - Ok(()) - } - - fn box_clone(&self) -> Box { - Box::new(self.clone()) - } - } - - #[derive(Debug)] - enum PasswordProtocolExtensionError { - Utf8Error(FromUtf8Error), - InvalidUsername, - InvalidPassword, - } - - impl Display for PasswordProtocolExtensionError { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - use PasswordProtocolExtensionError as E; - match self { - E::Utf8Error(e) => write!(f, "{}", e), - E::InvalidUsername => write!(f, "Invalid username"), - E::InvalidPassword => write!(f, "Invalid password"), - } - } - } - - impl Error for PasswordProtocolExtensionError {} - - impl From for WispError { - fn from(value: PasswordProtocolExtensionError) -> Self { - WispError::ExtensionImplError(Box::new(value)) - } - } - - impl From for PasswordProtocolExtensionError { - fn from(value: FromUtf8Error) -> Self { - PasswordProtocolExtensionError::Utf8Error(value) - } - } - - impl From for AnyProtocolExtension { - fn from(value: PasswordProtocolExtension) -> Self { - AnyProtocolExtension(Box::new(value)) - } - } - - /// Password protocol extension builder. - /// - /// **Passwords are sent in plain text!!** - pub struct PasswordProtocolExtensionBuilder { - /// Map of users and their passwords to allow. Only used on server. - pub users: HashMap, - /// Username to authenticate with. Only used on client. - pub username: String, - /// Password to authenticate with. Only used on client. - pub password: String, - } - - impl PasswordProtocolExtensionBuilder { - /// Create a new password protocol extension builder for the server, with a map of users - /// and passwords to allow. - pub fn new_server(users: HashMap) -> Self { - Self { - users, - username: String::new(), - password: String::new(), - } - } - - /// Create a new password protocol extension builder for the client, with a username and - /// password to authenticate with. - pub fn new_client(username: String, password: String) -> Self { - Self { - users: HashMap::new(), - username, - password, - } - } - } - - impl ProtocolExtensionBuilder for PasswordProtocolExtensionBuilder { - fn get_id(&self) -> u8 { - PasswordProtocolExtension::ID - } - - fn build_from_bytes( - &self, - mut payload: Bytes, - role: crate::Role, - ) -> Result { - match role { - Role::Server => { - if payload.remaining() < 3 { - return Err(WispError::PacketTooSmall); - } - - let username_len = payload.get_u8(); - let password_len = payload.get_u16_le(); - if payload.remaining() < (password_len + username_len as u16) as usize { - return Err(WispError::PacketTooSmall); - } - - use PasswordProtocolExtensionError as EError; - let username = - String::from_utf8(payload.copy_to_bytes(username_len as usize).to_vec()) - .map_err(|x| WispError::from(EError::from(x)))?; - let password = - String::from_utf8(payload.copy_to_bytes(password_len as usize).to_vec()) - .map_err(|x| WispError::from(EError::from(x)))?; - - let Some(user) = self.users.iter().find(|x| *x.0 == username) else { - return Err(EError::InvalidUsername.into()); - }; - - if *user.1 != password { - return Err(EError::InvalidPassword.into()); - } - - Ok(PasswordProtocolExtension { - username, - password, - role, - } - .into()) - } - Role::Client => { - Ok(PasswordProtocolExtension::new_client(String::new(), String::new()).into()) - } - } - } - - fn build_to_extension(&self, role: Role) -> AnyProtocolExtension { - match role { - Role::Server => PasswordProtocolExtension::new_server(), - Role::Client => PasswordProtocolExtension::new_client( - self.username.clone(), - self.password.clone(), - ), - } - .into() - } - } -} diff --git a/wisp/src/extensions/mod.rs b/wisp/src/extensions/mod.rs new file mode 100644 index 0000000..8c3ec12 --- /dev/null +++ b/wisp/src/extensions/mod.rs @@ -0,0 +1,106 @@ +//! Wisp protocol extensions. +pub mod password; +pub mod udp; + +use std::ops::{Deref, DerefMut}; + +use async_trait::async_trait; +use bytes::{BufMut, Bytes, BytesMut}; + +use crate::{ + ws::{LockedWebSocketWrite, WebSocketRead}, + Role, WispError, +}; + +/// Type-erased protocol extension that implements Clone. +#[derive(Debug)] +pub struct AnyProtocolExtension(Box); + +impl AnyProtocolExtension { + /// Create a new type-erased protocol extension. + pub fn new(extension: T) -> Self { + Self(Box::new(extension)) + } +} + +impl Deref for AnyProtocolExtension { + type Target = dyn ProtocolExtension; + fn deref(&self) -> &Self::Target { + self.0.deref() + } +} + +impl DerefMut for AnyProtocolExtension { + fn deref_mut(&mut self) -> &mut Self::Target { + self.0.deref_mut() + } +} + +impl Clone for AnyProtocolExtension { + fn clone(&self) -> Self { + Self(self.0.box_clone()) + } +} + +impl From for Bytes { + fn from(value: AnyProtocolExtension) -> Self { + let mut bytes = BytesMut::with_capacity(5); + let payload = value.encode(); + bytes.put_u8(value.get_id()); + bytes.put_u32_le(payload.len() as u32); + bytes.extend(payload); + bytes.freeze() + } +} + +/// A Wisp protocol extension. +/// +/// See [the +/// docs](https://github.com/MercuryWorkshop/wisp-protocol/blob/main/protocol.md#protocol-extensions). +#[async_trait] +pub trait ProtocolExtension: std::fmt::Debug { + /// Get the protocol extension ID. + fn get_id(&self) -> u8; + /// Get the protocol extension's supported packets. + /// + /// Used to decide whether to call the protocol extension's packet handler. + fn get_supported_packets(&self) -> &'static [u8]; + + /// Encode self into Bytes. + fn encode(&self) -> Bytes; + + /// Handle the handshake part of a Wisp connection. + /// + /// This should be used to send or receive data before any streams are created. + async fn handle_handshake( + &mut self, + read: &mut dyn WebSocketRead, + write: &LockedWebSocketWrite, + ) -> Result<(), WispError>; + + /// Handle receiving a packet. + async fn handle_packet( + &mut self, + packet: Bytes, + read: &mut dyn WebSocketRead, + write: &LockedWebSocketWrite, + ) -> Result<(), WispError>; + + /// Clone the protocol extension. + fn box_clone(&self) -> Box; +} + +/// Trait to build a Wisp protocol extension from a payload. +pub trait ProtocolExtensionBuilder { + /// Get the protocol extension ID. + /// + /// Used to decide whether this builder should be used. + fn get_id(&self) -> u8; + + /// Build a protocol extension from the extension's metadata. + fn build_from_bytes(&self, bytes: Bytes, role: Role) + -> Result; + + /// Build a protocol extension to send to the other side. + fn build_to_extension(&self, role: Role) -> AnyProtocolExtension; +} diff --git a/wisp/src/extensions/password.rs b/wisp/src/extensions/password.rs new file mode 100644 index 0000000..3fe15b3 --- /dev/null +++ b/wisp/src/extensions/password.rs @@ -0,0 +1,276 @@ +//! Password protocol extension. +//! +//! Passwords are sent in plain text!! +//! +//! # Example +//! Server: +//! ``` +//! let mut passwords = HashMap::new(); +//! passwords.insert("user1".to_string(), "pw".to_string()); +//! let (mux, fut) = ServerMux::new( +//! rx, +//! tx, +//! 128, +//! Some(&[Box::new(PasswordProtocolExtensionBuilder::new_server(passwords))]) +//! ); +//! ``` +//! +//! Client: +//! ``` +//! let (mux, fut) = ClientMux::new( +//! rx, +//! tx, +//! 128, +//! Some(&[ +//! Box::new(PasswordProtocolExtensionBuilder::new_client( +//! "user1".to_string(), +//! "pw".to_string() +//! )) +//! ]) +//! ); +//! ``` +//! See [the docs](https://github.com/MercuryWorkshop/wisp-protocol/blob/main/protocol.md#0x02---password-authentication) + +use std::{collections::HashMap, error::Error, fmt::Display, string::FromUtf8Error}; + +use async_trait::async_trait; +use bytes::{Buf, BufMut, Bytes, BytesMut}; + +use crate::{ + ws::{LockedWebSocketWrite, WebSocketRead}, + Role, WispError, +}; + +use super::{AnyProtocolExtension, ProtocolExtension, ProtocolExtensionBuilder}; + +#[derive(Debug, Clone)] +/// Password protocol extension. +/// +/// **Passwords are sent in plain text!!** +/// **This extension will panic when encoding if the username's length does not fit within a u8 +/// or the password's length does not fit within a u16.** +pub struct PasswordProtocolExtension { + /// The username to log in with. + /// + /// This string's length must fit within a u8. + pub username: String, + /// The password to log in with. + /// + /// This string's length must fit within a u16. + pub password: String, + role: Role, +} + +impl PasswordProtocolExtension { + /// Password protocol extension ID. + pub const ID: u8 = 0x02; + + /// Create a new password protocol extension for the server. + /// + /// This signifies that the server requires a password. + pub fn new_server() -> Self { + Self { + username: String::new(), + password: String::new(), + role: Role::Server, + } + } + + /// Create a new password protocol extension for the client, with a username and password. + /// + /// The username's length must fit within a u8. The password's length must fit within a + /// u16. + pub fn new_client(username: String, password: String) -> Self { + Self { + username, + password, + role: Role::Client, + } + } +} + +#[async_trait] +impl ProtocolExtension for PasswordProtocolExtension { + fn get_id(&self) -> u8 { + Self::ID + } + + fn get_supported_packets(&self) -> &'static [u8] { + &[] + } + + fn encode(&self) -> Bytes { + match self.role { + Role::Server => Bytes::new(), + Role::Client => { + let username = Bytes::from(self.username.clone().into_bytes()); + let password = Bytes::from(self.password.clone().into_bytes()); + let username_len = u8::try_from(username.len()).expect("username was too long"); + let password_len = u16::try_from(password.len()).expect("password was too long"); + + let mut bytes = + BytesMut::with_capacity(3 + username_len as usize + password_len as usize); + bytes.put_u8(username_len); + bytes.put_u16_le(password_len); + bytes.extend(username); + bytes.extend(password); + bytes.freeze() + } + } + } + + async fn handle_handshake( + &mut self, + _: &mut dyn WebSocketRead, + _: &LockedWebSocketWrite, + ) -> Result<(), WispError> { + Ok(()) + } + + async fn handle_packet( + &mut self, + _: Bytes, + _: &mut dyn WebSocketRead, + _: &LockedWebSocketWrite, + ) -> Result<(), WispError> { + Ok(()) + } + + fn box_clone(&self) -> Box { + Box::new(self.clone()) + } +} + +#[derive(Debug)] +enum PasswordProtocolExtensionError { + Utf8Error(FromUtf8Error), + InvalidUsername, + InvalidPassword, +} + +impl Display for PasswordProtocolExtensionError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + use PasswordProtocolExtensionError as E; + match self { + E::Utf8Error(e) => write!(f, "{}", e), + E::InvalidUsername => write!(f, "Invalid username"), + E::InvalidPassword => write!(f, "Invalid password"), + } + } +} + +impl Error for PasswordProtocolExtensionError {} + +impl From for WispError { + fn from(value: PasswordProtocolExtensionError) -> Self { + WispError::ExtensionImplError(Box::new(value)) + } +} + +impl From for PasswordProtocolExtensionError { + fn from(value: FromUtf8Error) -> Self { + PasswordProtocolExtensionError::Utf8Error(value) + } +} + +impl From for AnyProtocolExtension { + fn from(value: PasswordProtocolExtension) -> Self { + AnyProtocolExtension(Box::new(value)) + } +} + +/// Password protocol extension builder. +/// +/// **Passwords are sent in plain text!!** +pub struct PasswordProtocolExtensionBuilder { + /// Map of users and their passwords to allow. Only used on server. + pub users: HashMap, + /// Username to authenticate with. Only used on client. + pub username: String, + /// Password to authenticate with. Only used on client. + pub password: String, +} + +impl PasswordProtocolExtensionBuilder { + /// Create a new password protocol extension builder for the server, with a map of users + /// and passwords to allow. + pub fn new_server(users: HashMap) -> Self { + Self { + users, + username: String::new(), + password: String::new(), + } + } + + /// Create a new password protocol extension builder for the client, with a username and + /// password to authenticate with. + pub fn new_client(username: String, password: String) -> Self { + Self { + users: HashMap::new(), + username, + password, + } + } +} + +impl ProtocolExtensionBuilder for PasswordProtocolExtensionBuilder { + fn get_id(&self) -> u8 { + PasswordProtocolExtension::ID + } + + fn build_from_bytes( + &self, + mut payload: Bytes, + role: crate::Role, + ) -> Result { + match role { + Role::Server => { + if payload.remaining() < 3 { + return Err(WispError::PacketTooSmall); + } + + let username_len = payload.get_u8(); + let password_len = payload.get_u16_le(); + if payload.remaining() < (password_len + username_len as u16) as usize { + return Err(WispError::PacketTooSmall); + } + + use PasswordProtocolExtensionError as EError; + let username = + String::from_utf8(payload.copy_to_bytes(username_len as usize).to_vec()) + .map_err(|x| WispError::from(EError::from(x)))?; + let password = + String::from_utf8(payload.copy_to_bytes(password_len as usize).to_vec()) + .map_err(|x| WispError::from(EError::from(x)))?; + + let Some(user) = self.users.iter().find(|x| *x.0 == username) else { + return Err(EError::InvalidUsername.into()); + }; + + if *user.1 != password { + return Err(EError::InvalidPassword.into()); + } + + Ok(PasswordProtocolExtension { + username, + password, + role, + } + .into()) + } + Role::Client => { + Ok(PasswordProtocolExtension::new_client(String::new(), String::new()).into()) + } + } + } + + fn build_to_extension(&self, role: Role) -> AnyProtocolExtension { + match role { + Role::Server => PasswordProtocolExtension::new_server(), + Role::Client => { + PasswordProtocolExtension::new_client(self.username.clone(), self.password.clone()) + } + } + .into() + } +} diff --git a/wisp/src/extensions/udp.rs b/wisp/src/extensions/udp.rs new file mode 100644 index 0000000..068b5eb --- /dev/null +++ b/wisp/src/extensions/udp.rs @@ -0,0 +1,93 @@ +//! UDP protocol extension. +//! +//! # Example +//! ``` +//! let (mux, fut) = ServerMux::new( +//! rx, +//! tx, +//! 128, +//! Some(&[Box::new(UdpProtocolExtensionBuilder())]) +//! ); +//! ``` +//! See [the docs](https://github.com/MercuryWorkshop/wisp-protocol/blob/main/protocol.md#0x01---udp) +use async_trait::async_trait; +use bytes::Bytes; + +use crate::{ + ws::{LockedWebSocketWrite, WebSocketRead}, + WispError, +}; + +use super::{AnyProtocolExtension, ProtocolExtension, ProtocolExtensionBuilder}; + +#[derive(Debug)] +/// UDP protocol extension. +pub struct UdpProtocolExtension(); + +impl UdpProtocolExtension { + /// UDP protocol extension ID. + pub const ID: u8 = 0x01; +} + +#[async_trait] +impl ProtocolExtension for UdpProtocolExtension { + fn get_id(&self) -> u8 { + Self::ID + } + + fn get_supported_packets(&self) -> &'static [u8] { + &[] + } + + fn encode(&self) -> Bytes { + Bytes::new() + } + + async fn handle_handshake( + &mut self, + _: &mut dyn WebSocketRead, + _: &LockedWebSocketWrite, + ) -> Result<(), WispError> { + Ok(()) + } + + async fn handle_packet( + &mut self, + _: Bytes, + _: &mut dyn WebSocketRead, + _: &LockedWebSocketWrite, + ) -> Result<(), WispError> { + Ok(()) + } + + fn box_clone(&self) -> Box { + Box::new(Self()) + } +} + +impl From for AnyProtocolExtension { + fn from(value: UdpProtocolExtension) -> Self { + AnyProtocolExtension(Box::new(value)) + } +} + +/// UDP protocol extension builder. +pub struct UdpProtocolExtensionBuilder(); + +impl ProtocolExtensionBuilder for UdpProtocolExtensionBuilder { + fn get_id(&self) -> u8 { + UdpProtocolExtension::ID + } + + fn build_from_bytes( + &self, + _: Bytes, + _: crate::Role, + ) -> Result { + Ok(UdpProtocolExtension().into()) + } + + fn build_to_extension(&self, _: crate::Role) -> AnyProtocolExtension { + UdpProtocolExtension().into() + } +} diff --git a/wisp/src/fastwebsockets.rs b/wisp/src/fastwebsockets.rs index a2c2f7d..e05de91 100644 --- a/wisp/src/fastwebsockets.rs +++ b/wisp/src/fastwebsockets.rs @@ -3,7 +3,7 @@ use std::ops::Deref; use async_trait::async_trait; use bytes::BytesMut; use fastwebsockets::{ - CloseCode, FragmentCollectorRead, Frame, OpCode, Payload, WebSocketError, WebSocketWrite + CloseCode, FragmentCollectorRead, Frame, OpCode, Payload, WebSocketError, WebSocketWrite, }; use tokio::io::{AsyncRead, AsyncWrite}; @@ -79,6 +79,8 @@ impl crate::ws::WebSocketWrite for WebSocketWrite< } async fn wisp_close(&mut self) -> Result<(), WispError> { - self.write_frame(Frame::close(CloseCode::Normal.into(), b"")).await.map_err(|e| e.into()) + self.write_frame(Frame::close(CloseCode::Normal.into(), b"")) + .await + .map_err(|e| e.into()) } } diff --git a/wisp/src/stream.rs b/wisp/src/stream.rs index 8a074a7..e980bec 100644 --- a/wisp/src/stream.rs +++ b/wisp/src/stream.rs @@ -189,7 +189,10 @@ impl Drop for MuxStreamWrite { if !self.is_closed.load(Ordering::Acquire) { self.is_closed.store(true, Ordering::Release); let (tx, _) = oneshot::channel(); - let _ = self.tx.send(WsEvent::Close(Packet::new_close(self.stream_id, CloseReason::Unknown), tx)); + let _ = self.tx.send(WsEvent::Close( + Packet::new_close(self.stream_id, CloseReason::Unknown), + tx, + )); } } }