clippy pedantic

This commit is contained in:
Toshit Chawda 2024-11-25 13:29:29 -08:00
parent 272610f904
commit 7efda6c533
No known key found for this signature in database
GPG key ID: 91480ED99E2B3D9D
14 changed files with 148 additions and 129 deletions

View file

@ -1,7 +1,6 @@
#![doc(html_no_source)] #![doc(html_no_source)]
#![deny(clippy::todo)] #![deny(clippy::todo)]
#![allow(unexpected_cfgs)] #![allow(unexpected_cfgs)]
#![warn(clippy::large_futures)]
use std::{collections::HashMap, fs::read_to_string, net::IpAddr}; use std::{collections::HashMap, fs::read_to_string, net::IpAddr};

View file

@ -29,8 +29,8 @@ pub enum CertAuthError {
impl std::fmt::Display for CertAuthError { impl std::fmt::Display for CertAuthError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self { match self {
Self::Ed25519(x) => write!(f, "ED25519: {:?}", x), Self::Ed25519(x) => write!(f, "ED25519: {x:?}"),
Self::Getrandom(x) => write!(f, "getrandom: {:?}", x), Self::Getrandom(x) => write!(f, "getrandom: {x:?}"),
} }
} }
} }
@ -57,7 +57,7 @@ bitflags::bitflags! {
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct SupportedCertificateTypes: u8 { pub struct SupportedCertificateTypes: u8 {
/// ED25519 certificate. /// ED25519 certificate.
const Ed25519 = 0b00000001; const Ed25519 = 0b0000_0001;
} }
} }
@ -160,7 +160,7 @@ impl ProtocolExtension for CertAuthProtocolExtension {
required, required,
} => { } => {
let mut out = BytesMut::with_capacity(2 + challenge.len()); let mut out = BytesMut::with_capacity(2 + challenge.len());
out.put_u8(*required as u8); out.put_u8(u8::from(*required));
out.put_u8(cert_types.bits()); out.put_u8(cert_types.bits());
out.extend_from_slice(challenge); out.extend_from_slice(challenge);
out.freeze() out.freeze()
@ -176,8 +176,7 @@ impl ProtocolExtension for CertAuthProtocolExtension {
out.extend_from_slice(signature); out.extend_from_slice(signature);
out.freeze() out.freeze()
} }
Self::ClientRecieved => Bytes::new(), Self::ServerVerified | Self::ClientRecieved => Bytes::new(),
Self::ServerVerified => Bytes::new(),
} }
} }
@ -262,10 +261,10 @@ impl CertAuthProtocolExtensionBuilder {
/// sent the certificate authentication protocol extension. /// sent the certificate authentication protocol extension.
pub fn is_required(&self) -> Option<bool> { pub fn is_required(&self) -> Option<bool> {
match self { match self {
Self::ServerBeforeChallenge { required, .. } => Some(*required), Self::ServerBeforeChallenge { required, .. }
Self::ServerAfterChallenge { required, .. } => Some(*required), | Self::ServerAfterChallenge { required, .. }
| Self::ClientAfterChallenge { required, .. } => Some(*required),
Self::ClientBeforeChallenge { .. } => None, Self::ClientBeforeChallenge { .. } => None,
Self::ClientAfterChallenge { required, .. } => Some(*required),
} }
} }
@ -294,8 +293,6 @@ impl ProtocolExtensionBuilder for CertAuthProtocolExtensionBuilder {
_: Role, _: Role,
) -> Result<AnyProtocolExtension, WispError> { ) -> Result<AnyProtocolExtension, WispError> {
match self { match self {
// server should have already sent the challenge before recieving a response to parse
Self::ServerBeforeChallenge { .. } => Err(WispError::ExtensionImplNotSupported),
Self::ServerAfterChallenge { Self::ServerAfterChallenge {
verifiers, verifiers,
challenge, challenge,
@ -332,8 +329,12 @@ impl ProtocolExtensionBuilder for CertAuthProtocolExtensionBuilder {
Ok(CertAuthProtocolExtension::ClientRecieved.into()) Ok(CertAuthProtocolExtension::ClientRecieved.into())
} }
// client has already recieved a challenge
Self::ClientAfterChallenge { .. } => Err(WispError::ExtensionImplNotSupported), // client has already recieved a challenge or
// server should have already sent the challenge before recieving a response to parse
Self::ClientAfterChallenge { .. } | Self::ServerBeforeChallenge { .. } => {
Err(WispError::ExtensionImplNotSupported)
}
} }
} }
@ -352,7 +353,7 @@ impl ProtocolExtensionBuilder for CertAuthProtocolExtensionBuilder {
let required = *required; let required = *required;
*self = Self::ServerAfterChallenge { *self = Self::ServerAfterChallenge {
verifiers: verifiers.to_vec(), verifiers: verifiers.clone(),
challenge: challenge.clone(), challenge: challenge.clone(),
required, required,
}; };
@ -364,10 +365,6 @@ impl ProtocolExtensionBuilder for CertAuthProtocolExtensionBuilder {
} }
.into()) .into())
} }
// server has already sent a challenge
Self::ServerAfterChallenge { .. } => Err(WispError::ExtensionImplNotSupported),
// client needs to recieve a challenge
Self::ClientBeforeChallenge { .. } => Err(WispError::ExtensionImplNotSupported),
Self::ClientAfterChallenge { Self::ClientAfterChallenge {
signer, signer,
challenge, challenge,
@ -393,6 +390,12 @@ impl ProtocolExtensionBuilder for CertAuthProtocolExtensionBuilder {
} }
.into()) .into())
} }
// server has already sent a challenge or
// client needs to recieve a challenge
Self::ClientBeforeChallenge { .. } | Self::ServerAfterChallenge { .. } => {
Err(WispError::ExtensionImplNotSupported)
}
} }
} }
} }

View file

@ -8,6 +8,7 @@ pub mod udp;
use std::{ use std::{
any::TypeId, any::TypeId,
ops::{Deref, DerefMut}, ops::{Deref, DerefMut},
ptr,
}; };
use async_trait::async_trait; use async_trait::async_trait;
@ -47,13 +48,13 @@ impl AnyProtocolExtension {
impl Deref for AnyProtocolExtension { impl Deref for AnyProtocolExtension {
type Target = dyn ProtocolExtension; type Target = dyn ProtocolExtension;
fn deref(&self) -> &Self::Target { fn deref(&self) -> &Self::Target {
self.0.deref() &*self.0
} }
} }
impl DerefMut for AnyProtocolExtension { impl DerefMut for AnyProtocolExtension {
fn deref_mut(&mut self) -> &mut Self::Target { fn deref_mut(&mut self) -> &mut Self::Target {
self.0.deref_mut() &mut *self.0
} }
} }
@ -137,7 +138,7 @@ impl dyn ProtocolExtension {
if self.__is::<T>() { if self.__is::<T>() {
unsafe { unsafe {
let raw: *mut dyn ProtocolExtension = Box::into_raw(self); let raw: *mut dyn ProtocolExtension = Box::into_raw(self);
Ok(Box::from_raw(raw as *mut T)) Ok(Box::from_raw(raw.cast::<T>()))
} }
} else { } else {
Err(self) Err(self)
@ -146,7 +147,7 @@ impl dyn ProtocolExtension {
fn __downcast_ref<T: ProtocolExtension>(&self) -> Option<&T> { fn __downcast_ref<T: ProtocolExtension>(&self) -> Option<&T> {
if self.__is::<T>() { if self.__is::<T>() {
unsafe { Some(&*(self as *const dyn ProtocolExtension as *const T)) } unsafe { Some(&*ptr::from_ref::<dyn ProtocolExtension>(self).cast::<T>()) }
} else { } else {
None None
} }
@ -154,7 +155,7 @@ impl dyn ProtocolExtension {
fn __downcast_mut<T: ProtocolExtension>(&mut self) -> Option<&mut T> { fn __downcast_mut<T: ProtocolExtension>(&mut self) -> Option<&mut T> {
if self.__is::<T>() { if self.__is::<T>() {
unsafe { Some(&mut *(self as *mut dyn ProtocolExtension as *mut T)) } unsafe { Some(&mut *ptr::from_mut::<dyn ProtocolExtension>(self).cast::<T>()) }
} else { } else {
None None
} }
@ -198,7 +199,7 @@ impl dyn ProtocolExtensionBuilder {
if self.__is::<T>() { if self.__is::<T>() {
unsafe { unsafe {
let raw: *mut dyn ProtocolExtensionBuilder = Box::into_raw(self); let raw: *mut dyn ProtocolExtensionBuilder = Box::into_raw(self);
Ok(Box::from_raw(raw as *mut T)) Ok(Box::from_raw(raw.cast::<T>()))
} }
} else { } else {
Err(self) Err(self)
@ -207,7 +208,7 @@ impl dyn ProtocolExtensionBuilder {
fn __downcast_ref<T: ProtocolExtensionBuilder>(&self) -> Option<&T> { fn __downcast_ref<T: ProtocolExtensionBuilder>(&self) -> Option<&T> {
if self.__is::<T>() { if self.__is::<T>() {
unsafe { Some(&*(self as *const dyn ProtocolExtensionBuilder as *const T)) } unsafe { Some(&*ptr::from_ref::<dyn ProtocolExtensionBuilder>(self).cast::<T>()) }
} else { } else {
None None
} }
@ -215,7 +216,7 @@ impl dyn ProtocolExtensionBuilder {
fn __downcast_mut<T: ProtocolExtensionBuilder>(&mut self) -> Option<&mut T> { fn __downcast_mut<T: ProtocolExtensionBuilder>(&mut self) -> Option<&mut T> {
if self.__is::<T>() { if self.__is::<T>() {
unsafe { Some(&mut *(self as *mut dyn ProtocolExtensionBuilder as *mut T)) } unsafe { Some(&mut *ptr::from_mut::<dyn ProtocolExtensionBuilder>(self).cast::<T>()) }
} else { } else {
None None
} }
@ -250,13 +251,13 @@ impl AnyProtocolExtensionBuilder {
impl Deref for AnyProtocolExtensionBuilder { impl Deref for AnyProtocolExtensionBuilder {
type Target = dyn ProtocolExtensionBuilder; type Target = dyn ProtocolExtensionBuilder;
fn deref(&self) -> &Self::Target { fn deref(&self) -> &Self::Target {
self.0.deref() &*self.0
} }
} }
impl DerefMut for AnyProtocolExtensionBuilder { impl DerefMut for AnyProtocolExtensionBuilder {
fn deref_mut(&mut self) -> &mut Self::Target { fn deref_mut(&mut self) -> &mut Self::Target {
self.0.deref_mut() &mut *self.0
} }
} }

View file

@ -76,11 +76,9 @@ impl ProtocolExtension for PasswordProtocolExtension {
match self { match self {
Self::ServerBeforeClientInfo { required } => { Self::ServerBeforeClientInfo { required } => {
let mut out = BytesMut::with_capacity(1); let mut out = BytesMut::with_capacity(1);
out.put_u8(*required as u8); out.put_u8(u8::from(*required));
out.freeze() out.freeze()
} }
Self::ServerAfterClientInfo { .. } => Bytes::new(),
Self::ClientBeforeServerInfo => Bytes::new(),
Self::ClientAfterServerInfo { user, password } => { Self::ClientAfterServerInfo { user, password } => {
let mut out = BytesMut::with_capacity(1 + 2 + user.len() + password.len()); let mut out = BytesMut::with_capacity(1 + 2 + user.len() + password.len());
out.put_u8(user.len().try_into().unwrap()); out.put_u8(user.len().try_into().unwrap());
@ -89,6 +87,8 @@ impl ProtocolExtension for PasswordProtocolExtension {
out.extend_from_slice(password.as_bytes()); out.extend_from_slice(password.as_bytes());
out.freeze() out.freeze()
} }
Self::ServerAfterClientInfo { .. } | Self::ClientBeforeServerInfo => Bytes::new(),
} }
} }
@ -164,10 +164,10 @@ impl PasswordProtocolExtensionBuilder {
/// sent the password protocol extension. /// sent the password protocol extension.
pub fn is_required(&self) -> Option<bool> { pub fn is_required(&self) -> Option<bool> {
match self { match self {
Self::ServerBeforeClientInfo { required, .. } => Some(*required), Self::ServerBeforeClientInfo { required, .. }
Self::ServerAfterClientInfo { required, .. } => Some(*required), | Self::ServerAfterClientInfo { required, .. }
| Self::ClientAfterServerInfo { required, .. } => Some(*required),
Self::ClientBeforeServerInfo { .. } => None, Self::ClientBeforeServerInfo { .. } => None,
Self::ClientAfterServerInfo { required, .. } => Some(*required),
} }
} }
@ -195,8 +195,9 @@ impl ProtocolExtensionBuilder for PasswordProtocolExtensionBuilder {
} }
.into()) .into())
} }
Self::ServerAfterClientInfo { .. } => Err(WispError::ExtensionImplNotSupported), Self::ServerAfterClientInfo { .. } | Self::ClientBeforeServerInfo { .. } => {
Self::ClientBeforeServerInfo { .. } => Err(WispError::ExtensionImplNotSupported), Err(WispError::ExtensionImplNotSupported)
}
Self::ClientAfterServerInfo { creds, .. } => { Self::ClientAfterServerInfo { creds, .. } => {
let (user, password) = creds.clone().ok_or(WispError::PasswordExtensionNoCreds)?; let (user, password) = creds.clone().ok_or(WispError::PasswordExtensionNoCreds)?;
Ok(PasswordProtocolExtension::ClientAfterServerInfo { user, password }.into()) Ok(PasswordProtocolExtension::ClientAfterServerInfo { user, password }.into())
@ -218,24 +219,23 @@ impl ProtocolExtensionBuilder for PasswordProtocolExtensionBuilder {
let password = let password =
std::str::from_utf8(&bytes.split_to(password_len as usize))?.to_string(); std::str::from_utf8(&bytes.split_to(password_len as usize))?.to_string();
let valid = users.get(&user).map(|x| *x == password).unwrap_or(false); let valid = users.get(&user).is_some_and(|x| *x == password);
*self = Self::ServerAfterClientInfo { *self = Self::ServerAfterClientInfo {
users: users.clone(), users: users.clone(),
required: *required, required: *required,
}; };
if !valid { if valid {
Err(WispError::PasswordExtensionCredsInvalid)
} else {
Ok(PasswordProtocolExtension::ServerAfterClientInfo { Ok(PasswordProtocolExtension::ServerAfterClientInfo {
chosen_user: user, chosen_user: user,
chosen_password: password, chosen_password: password,
} }
.into()) .into())
} else {
Err(WispError::PasswordExtensionCredsInvalid)
} }
} }
Self::ServerAfterClientInfo { .. } => Err(WispError::ExtensionImplNotSupported),
Self::ClientBeforeServerInfo { creds } => { Self::ClientBeforeServerInfo { creds } => {
let required = bytes.get_u8() != 0; let required = bytes.get_u8() != 0;
@ -246,7 +246,9 @@ impl ProtocolExtensionBuilder for PasswordProtocolExtensionBuilder {
Ok(PasswordProtocolExtension::ClientBeforeServerInfo.into()) Ok(PasswordProtocolExtension::ClientBeforeServerInfo.into())
} }
Self::ClientAfterServerInfo { .. } => Err(WispError::ExtensionImplNotSupported), Self::ClientAfterServerInfo { .. } | Self::ServerAfterClientInfo { .. } => {
Err(WispError::ExtensionImplNotSupported)
}
} }
} }
} }

View file

@ -1,4 +1,4 @@
//! WebSocketRead + WebSocketWrite implementation for generic `Stream + Sink`s. //! `WebSocketRead` and `WebSocketWrite` implementation for generic `Stream`s and `Sink`s.
use bytes::{Bytes, BytesMut}; use bytes::{Bytes, BytesMut};
use futures::{Sink, SinkExt, Stream, StreamExt}; use futures::{Sink, SinkExt, Stream, StreamExt};
@ -9,7 +9,7 @@ use crate::{
WispError, WispError,
}; };
/// WebSocketRead implementation for generic `Stream`s. /// `WebSocketRead` implementation for generic `Stream`s.
pub struct GenericWebSocketRead< pub struct GenericWebSocketRead<
T: Stream<Item = Result<BytesMut, E>> + Send + Unpin, T: Stream<Item = Result<BytesMut, E>> + Send + Unpin,
E: Error + Sync + Send + 'static, E: Error + Sync + Send + 'static,
@ -18,12 +18,12 @@ pub struct GenericWebSocketRead<
impl<T: Stream<Item = Result<BytesMut, E>> + Send + Unpin, E: Error + Sync + Send + 'static> impl<T: Stream<Item = Result<BytesMut, E>> + Send + Unpin, E: Error + Sync + Send + 'static>
GenericWebSocketRead<T, E> GenericWebSocketRead<T, E>
{ {
/// Create a new wrapper WebSocketRead implementation. /// Create a new wrapper `WebSocketRead` implementation.
pub fn new(stream: T) -> Self { pub fn new(stream: T) -> Self {
Self(stream) Self(stream)
} }
/// Get the inner Stream from the wrapper. /// Get the inner `Stream` from the wrapper.
pub fn into_inner(self) -> T { pub fn into_inner(self) -> T {
self.0 self.0
} }
@ -45,7 +45,7 @@ impl<T: Stream<Item = Result<BytesMut, E>> + Send + Unpin, E: Error + Sync + Sen
} }
} }
/// WebSocketWrite implementation for generic `Sink`s. /// `WebSocketWrite` implementation for generic `Sink`s.
pub struct GenericWebSocketWrite< pub struct GenericWebSocketWrite<
T: Sink<Bytes, Error = E> + Send + Unpin, T: Sink<Bytes, Error = E> + Send + Unpin,
E: Error + Sync + Send + 'static, E: Error + Sync + Send + 'static,
@ -54,12 +54,12 @@ pub struct GenericWebSocketWrite<
impl<T: Sink<Bytes, Error = E> + Send + Unpin, E: Error + Sync + Send + 'static> impl<T: Sink<Bytes, Error = E> + Send + Unpin, E: Error + Sync + Send + 'static>
GenericWebSocketWrite<T, E> GenericWebSocketWrite<T, E>
{ {
/// Create a new wrapper WebSocketWrite implementation. /// Create a new wrapper `WebSocketWrite` implementation.
pub fn new(stream: T) -> Self { pub fn new(stream: T) -> Self {
Self(stream) Self(stream)
} }
/// Get the inner Sink from the wrapper. /// Get the inner `Sink` from the wrapper.
pub fn into_inner(self) -> T { pub fn into_inner(self) -> T {
self.0 self.0
} }

View file

@ -1,5 +1,12 @@
#![deny(missing_docs, clippy::todo)]
#![cfg_attr(docsrs, feature(doc_cfg))] #![cfg_attr(docsrs, feature(doc_cfg))]
#![warn(clippy::pedantic)]
#![deny(missing_docs, clippy::todo)]
#![allow(
clippy::must_use_candidate,
clippy::missing_errors_doc,
clippy::module_name_repetitions
)]
//! A library for easily creating [Wisp] clients and servers. //! A library for easily creating [Wisp] clients and servers.
//! //!
//! [Wisp]: https://github.com/MercuryWorkshop/wisp-protocol //! [Wisp]: https://github.com/MercuryWorkshop/wisp-protocol

View file

@ -39,14 +39,14 @@ async fn handshake<R: WebSocketRead + 'static, W: WebSocketWrite>(
if let PacketType::Info(info) = packet.packet_type { if let PacketType::Info(info) = packet.packet_type {
// v2 server // v2 server
let buffer_size = validate_continue_packet(rx.wisp_read_frame(tx).await?.try_into()?)?; let buffer_size = validate_continue_packet(&rx.wisp_read_frame(tx).await?.try_into()?)?;
(closure)(&mut builders).await?; (closure)(&mut builders).await?;
send_info_packet(tx, &mut builders).await?; send_info_packet(tx, &mut builders).await?;
let mut supported_extensions = get_supported_extensions(info.extensions, &mut builders); let mut supported_extensions = get_supported_extensions(info.extensions, &mut builders);
for extension in supported_extensions.iter_mut() { for extension in &mut supported_extensions {
extension extension
.handle_handshake(DynWebSocketRead::from_mut(rx), tx) .handle_handshake(DynWebSocketRead::from_mut(rx), tx)
.await?; .await?;
@ -63,7 +63,7 @@ async fn handshake<R: WebSocketRead + 'static, W: WebSocketWrite>(
)) ))
} else { } else {
// downgrade to v1 // downgrade to v1
let buffer_size = validate_continue_packet(packet)?; let buffer_size = validate_continue_packet(&packet)?;
Ok(( Ok((
WispHandshakeResult { WispHandshakeResult {
@ -75,7 +75,7 @@ async fn handshake<R: WebSocketRead + 'static, W: WebSocketWrite>(
} }
} else { } else {
// user asked for a v1 client // user asked for a v1 client
let buffer_size = validate_continue_packet(rx.wisp_read_frame(tx).await?.try_into()?)?; let buffer_size = validate_continue_packet(&rx.wisp_read_frame(tx).await?.try_into()?)?;
Ok(( Ok((
WispHandshakeResult { WispHandshakeResult {

View file

@ -43,7 +43,7 @@ struct MuxMapValue {
is_closed_event: Arc<Event>, is_closed_event: Arc<Event>,
} }
pub struct MuxInner<R: WebSocketRead + 'static, W: WebSocketWrite + 'static> { pub(crate) struct MuxInner<R: WebSocketRead + 'static, W: WebSocketWrite + 'static> {
// gets taken by the mux task // gets taken by the mux task
rx: Option<R>, rx: Option<R>,
// gets taken by the mux task // gets taken by the mux task
@ -68,7 +68,7 @@ pub struct MuxInner<R: WebSocketRead + 'static, W: WebSocketWrite + 'static> {
server_tx: mpsc::Sender<(ConnectPacket, MuxStream<W>)>, server_tx: mpsc::Sender<(ConnectPacket, MuxStream<W>)>,
} }
pub struct MuxInnerResult<R: WebSocketRead + 'static, W: WebSocketWrite + 'static> { pub(crate) struct MuxInnerResult<R: WebSocketRead + 'static, W: WebSocketWrite + 'static> {
pub mux: MuxInner<R, W>, pub mux: MuxInner<R, W>,
pub actor_exited: Arc<AtomicBool>, pub actor_exited: Arc<AtomicBool>,
pub actor_tx: mpsc::Sender<WsEvent<W>>, pub actor_tx: mpsc::Sender<WsEvent<W>>,
@ -84,7 +84,7 @@ impl<R: WebSocketRead + 'static, W: WebSocketWrite + 'static> MuxInner<R, W> {
.collect() .collect()
} }
#[allow(clippy::type_complexity)] #[expect(clippy::type_complexity)]
pub fn new_server( pub fn new_server(
rx: R, rx: R,
maybe_downgrade_packet: Option<Packet<'static>>, maybe_downgrade_packet: Option<Packet<'static>>,
@ -100,6 +100,10 @@ impl<R: WebSocketRead + 'static, W: WebSocketWrite + 'static> MuxInner<R, W> {
let ret_fut_tx = fut_tx.clone(); let ret_fut_tx = fut_tx.clone();
let fut_exited = Arc::new(AtomicBool::new(false)); let fut_exited = Arc::new(AtomicBool::new(false));
// 90% of the buffer size, not possible to overflow
#[expect(clippy::cast_possible_truncation)]
let target_buffer_size = ((u64::from(buffer_size) * 90) / 100) as u32;
( (
MuxInnerResult { MuxInnerResult {
mux: Self { mux: Self {
@ -114,7 +118,7 @@ impl<R: WebSocketRead + 'static, W: WebSocketWrite + 'static> MuxInner<R, W> {
tcp_extensions: Self::get_tcp_extensions(&extensions), tcp_extensions: Self::get_tcp_extensions(&extensions),
extensions: Some(extensions), extensions: Some(extensions),
buffer_size, buffer_size,
target_buffer_size: ((buffer_size as u64 * 90) / 100) as u32, target_buffer_size,
role: Role::Server, role: Role::Server,
@ -172,8 +176,8 @@ impl<R: WebSocketRead + 'static, W: WebSocketWrite + 'static> MuxInner<R, W> {
self.fut_exited.store(true, Ordering::Release); self.fut_exited.store(true, Ordering::Release);
for (_, stream) in self.stream_map.iter() { for stream in self.stream_map.values() {
self.close_stream(stream, ClosePacket::new(CloseReason::Unknown)); Self::close_stream(stream, ClosePacket::new(CloseReason::Unknown));
} }
self.stream_map.clear(); self.stream_map.clear();
@ -181,11 +185,11 @@ impl<R: WebSocketRead + 'static, W: WebSocketWrite + 'static> MuxInner<R, W> {
ret ret
} }
async fn create_new_stream( fn create_new_stream(
&mut self, &mut self,
stream_id: u32, stream_id: u32,
stream_type: StreamType, stream_type: StreamType,
) -> Result<(MuxMapValue, MuxStream<W>), WispError> { ) -> (MuxMapValue, MuxStream<W>) {
let (ch_tx, ch_rx) = mpsc::bounded(if self.role == Role::Server { let (ch_tx, ch_rx) = mpsc::bounded(if self.role == Role::Server {
self.buffer_size as usize self.buffer_size as usize
} else { } else {
@ -201,7 +205,7 @@ impl<R: WebSocketRead + 'static, W: WebSocketWrite + 'static> MuxInner<R, W> {
AtomicCloseReason::new(CloseReason::Unknown).into(); AtomicCloseReason::new(CloseReason::Unknown).into();
let is_closed_event: Arc<Event> = Event::new().into(); let is_closed_event: Arc<Event> = Event::new().into();
Ok(( (
MuxMapValue { MuxMapValue {
stream: ch_tx, stream: ch_tx,
stream_type, stream_type,
@ -229,10 +233,10 @@ impl<R: WebSocketRead + 'static, W: WebSocketWrite + 'static> MuxInner<R, W> {
flow_control_event, flow_control_event,
self.target_buffer_size, self.target_buffer_size,
), ),
)) )
} }
fn close_stream(&self, stream: &MuxMapValue, close_packet: ClosePacket) { fn close_stream(stream: &MuxMapValue, close_packet: ClosePacket) {
stream stream
.close_reason .close_reason
.store(close_packet.reason, Ordering::Release); .store(close_packet.reason, Ordering::Release);
@ -319,8 +323,7 @@ impl<R: WebSocketRead + 'static, W: WebSocketWrite + 'static> MuxInner<R, W> {
.checked_add(1) .checked_add(1)
.ok_or(WispError::MaxStreamCountReached)?; .ok_or(WispError::MaxStreamCountReached)?;
let (map_value, stream) = let (map_value, stream) = self.create_new_stream(stream_id, stream_type);
self.create_new_stream(stream_id, stream_type).await?;
self.tx self.tx
.write_frame( .write_frame(
@ -340,7 +343,7 @@ impl<R: WebSocketRead + 'static, W: WebSocketWrite + 'static> MuxInner<R, W> {
WsEvent::Close(packet, channel) => { WsEvent::Close(packet, channel) => {
if let Some(stream) = self.stream_map.remove(&packet.stream_id) { if let Some(stream) = self.stream_map.remove(&packet.stream_id) {
if let PacketType::Close(close) = packet.packet_type { if let PacketType::Close(close) = packet.packet_type {
self.close_stream(&stream, close); Self::close_stream(&stream, close);
} }
let _ = channel.send(self.tx.write_frame(packet.into()).await); let _ = channel.send(self.tx.write_frame(packet.into()).await);
} else { } else {
@ -383,20 +386,16 @@ impl<R: WebSocketRead + 'static, W: WebSocketWrite + 'static> MuxInner<R, W> {
Ok(()) Ok(())
} }
fn handle_close_packet( fn handle_close_packet(&mut self, stream_id: u32, inner_packet: ClosePacket) -> bool {
&mut self,
stream_id: u32,
inner_packet: ClosePacket,
) -> Result<bool, WispError> {
if stream_id == 0 { if stream_id == 0 {
return Ok(true); return true;
} }
if let Some(stream) = self.stream_map.remove(&stream_id) { if let Some(stream) = self.stream_map.remove(&stream_id) {
self.close_stream(&stream, inner_packet); Self::close_stream(&stream, inner_packet);
} }
Ok(false) false
} }
fn handle_data_packet( fn handle_data_packet(
@ -404,7 +403,7 @@ impl<R: WebSocketRead + 'static, W: WebSocketWrite + 'static> MuxInner<R, W> {
stream_id: u32, stream_id: u32,
optional_frame: Option<Frame<'static>>, optional_frame: Option<Frame<'static>>,
data: Payload<'static>, data: Payload<'static>,
) -> Result<bool, WispError> { ) -> bool {
let mut data = BytesMut::from(data); let mut data = BytesMut::from(data);
if let Some(stream) = self.stream_map.get(&stream_id) { if let Some(stream) = self.stream_map.get(&stream_id) {
@ -427,7 +426,7 @@ impl<R: WebSocketRead + 'static, W: WebSocketWrite + 'static> MuxInner<R, W> {
} }
} }
Ok(false) false
} }
async fn handle_packet( async fn handle_packet(
@ -437,12 +436,12 @@ impl<R: WebSocketRead + 'static, W: WebSocketWrite + 'static> MuxInner<R, W> {
) -> Result<bool, WispError> { ) -> Result<bool, WispError> {
use PacketType as P; use PacketType as P;
match packet.packet_type { match packet.packet_type {
P::Data(data) => self.handle_data_packet(packet.stream_id, optional_frame, data), P::Data(data) => Ok(self.handle_data_packet(packet.stream_id, optional_frame, data)),
P::Close(inner_packet) => self.handle_close_packet(packet.stream_id, inner_packet), P::Close(inner_packet) => Ok(self.handle_close_packet(packet.stream_id, inner_packet)),
_ => match self.role { _ => match self.role {
Role::Server => self.server_handle_packet(packet, optional_frame).await, Role::Server => self.server_handle_packet(packet, optional_frame).await,
Role::Client => self.client_handle_packet(packet, optional_frame).await, Role::Client => self.client_handle_packet(&packet),
}, },
} }
} }
@ -455,9 +454,8 @@ impl<R: WebSocketRead + 'static, W: WebSocketWrite + 'static> MuxInner<R, W> {
use PacketType as P; use PacketType as P;
match packet.packet_type { match packet.packet_type {
P::Connect(inner_packet) => { P::Connect(inner_packet) => {
let (map_value, stream) = self let (map_value, stream) =
.create_new_stream(packet.stream_id, inner_packet.stream_type) self.create_new_stream(packet.stream_id, inner_packet.stream_type);
.await?;
self.server_tx self.server_tx
.send_async((inner_packet, stream)) .send_async((inner_packet, stream))
.await .await
@ -472,11 +470,7 @@ impl<R: WebSocketRead + 'static, W: WebSocketWrite + 'static> MuxInner<R, W> {
} }
} }
async fn client_handle_packet( fn client_handle_packet(&mut self, packet: &Packet<'static>) -> Result<bool, WispError> {
&mut self,
packet: Packet<'static>,
_optional_frame: Option<Frame<'static>>,
) -> Result<bool, WispError> {
use PacketType as P; use PacketType as P;
match packet.packet_type { match packet.packet_type {
P::Continue(inner_packet) => { P::Continue(inner_packet) => {

View file

@ -52,7 +52,7 @@ async fn send_info_packet<W: WebSocketWrite>(
.await .await
} }
fn validate_continue_packet(packet: Packet<'_>) -> Result<u32, WispError> { fn validate_continue_packet(packet: &Packet<'_>) -> Result<u32, WispError> {
if packet.stream_id != 0 { if packet.stream_id != 0 {
return Err(WispError::InvalidStreamId); return Err(WispError::InvalidStreamId);
} }

View file

@ -46,7 +46,7 @@ async fn handshake<R: WebSocketRead + 'static, W: WebSocketWrite>(
if let PacketType::Info(info) = packet.packet_type { if let PacketType::Info(info) = packet.packet_type {
let mut supported_extensions = get_supported_extensions(info.extensions, &mut builders); let mut supported_extensions = get_supported_extensions(info.extensions, &mut builders);
for extension in supported_extensions.iter_mut() { for extension in &mut supported_extensions {
extension extension
.handle_handshake(DynWebSocketRead::from_mut(rx), tx) .handle_handshake(DynWebSocketRead::from_mut(rx), tx)
.await?; .await?;

View file

@ -492,9 +492,9 @@ impl<'a> Packet<'a> {
return Err(WispError::PacketTooSmall); return Err(WispError::PacketTooSmall);
} }
if let Some(builder) = extension_builders.iter_mut().find(|x| x.get_id() == id) { if let Some(builder) = extension_builders.iter_mut().find(|x| x.get_id() == id) {
extensions.push(builder.build_from_bytes(bytes.copy_to_bytes(length), role)?) extensions.push(builder.build_from_bytes(bytes.copy_to_bytes(length), role)?);
} else { } else {
bytes.advance(length) bytes.advance(length);
} }
} }

View file

@ -205,7 +205,7 @@ impl AsyncBufRead for MuxStreamAsyncRW {
} }
fn consume(self: Pin<&mut Self>, amt: usize) { fn consume(self: Pin<&mut Self>, amt: usize) {
self.project().rx.consume(amt) self.project().rx.consume(amt);
} }
} }
@ -270,7 +270,7 @@ impl AsyncBufRead for MuxStreamAsyncRead {
self.project().rx.poll_fill_buf(cx) self.project().rx.poll_fill_buf(cx)
} }
fn consume(self: Pin<&mut Self>, amt: usize) { fn consume(self: Pin<&mut Self>, amt: usize) {
self.project().rx.consume(amt) self.project().rx.consume(amt);
} }
} }
@ -319,7 +319,7 @@ impl AsyncWrite for MuxStreamAsyncWrite {
Poll::Ready(Err(err)) => { Poll::Ready(Err(err)) => {
self.error = Some(err); self.error = Some(err);
} }
Poll::Ready(Ok(_)) | Poll::Pending => {} Poll::Ready(Ok(())) | Poll::Pending => {}
} }
Poll::Ready(Ok(buf.len())) Poll::Ready(Ok(buf.len()))

View file

@ -50,7 +50,7 @@ impl<W: WebSocketWrite + 'static> MuxStreamRead<W> {
} }
let bytes = select! { let bytes = select! {
x = self.rx.recv_async() => x.map_err(|_| WispError::MuxMessageFailedToRecv)?, x = self.rx.recv_async() => x.map_err(|_| WispError::MuxMessageFailedToRecv)?,
_ = self.is_closed_event.listen().fuse() => return Ok(None) () = self.is_closed_event.listen().fuse() => return Ok(None)
}; };
if self.role == Role::Server && self.should_flow_control { if self.role == Role::Server && self.should_flow_control {
let val = self.flow_control_read.fetch_add(1, Ordering::AcqRel) + 1; let val = self.flow_control_read.fetch_add(1, Ordering::AcqRel) + 1;
@ -288,11 +288,14 @@ impl<W: WebSocketWrite + 'static> MuxStream<W> {
stream_id, stream_id,
stream_type, stream_type,
role, role,
tx: tx.clone(), tx: tx.clone(),
rx, rx,
is_closed: is_closed.clone(), is_closed: is_closed.clone(),
is_closed_event: is_closed_event.clone(), is_closed_event,
close_reason: close_reason.clone(), close_reason: close_reason.clone(),
should_flow_control, should_flow_control,
flow_control: flow_control.clone(), flow_control: flow_control.clone(),
flow_control_read: AtomicU32::new(0), flow_control_read: AtomicU32::new(0),
@ -302,13 +305,16 @@ impl<W: WebSocketWrite + 'static> MuxStream<W> {
stream_id, stream_id,
stream_type, stream_type,
role, role,
mux_tx, mux_tx,
tx, tx,
is_closed: is_closed.clone(),
close_reason: close_reason.clone(), is_closed,
close_reason,
continue_recieved,
should_flow_control, should_flow_control,
flow_control: flow_control.clone(), flow_control,
continue_recieved: continue_recieved.clone(),
}, },
} }
} }

View file

@ -15,7 +15,7 @@ use futures::{lock::Mutex, TryFutureExt};
pub enum Payload<'a> { pub enum Payload<'a> {
/// Borrowed payload. Currently used when writing data. /// Borrowed payload. Currently used when writing data.
Borrowed(&'a [u8]), Borrowed(&'a [u8]),
/// BytesMut payload. Currently used when reading data. /// `BytesMut` payload. Currently used when reading data.
Bytes(BytesMut), Bytes(BytesMut),
} }
@ -33,6 +33,7 @@ impl<'a> From<&'a [u8]> for Payload<'a> {
impl Payload<'_> { impl Payload<'_> {
/// Turn a Payload<'a> into a Payload<'static> by copying the data. /// Turn a Payload<'a> into a Payload<'static> by copying the data.
#[must_use]
pub fn into_owned(self) -> Self { pub fn into_owned(self) -> Self {
match self { match self {
Self::Bytes(x) => Self::Bytes(x), Self::Bytes(x) => Self::Bytes(x),
@ -54,7 +55,7 @@ impl Deref for Payload<'_> {
type Target = [u8]; type Target = [u8];
fn deref(&self) -> &Self::Target { fn deref(&self) -> &Self::Target {
match self { match self {
Self::Bytes(x) => x.deref(), Self::Bytes(x) => x,
Self::Borrowed(x) => x, Self::Borrowed(x) => x,
} }
} }
@ -175,7 +176,7 @@ pub trait WebSocketRead: Send {
// similar to what dynosaur does // similar to what dynosaur does
mod wsr_inner { mod wsr_inner {
use std::{future::Future, pin::Pin}; use std::{future::Future, pin::Pin, ptr};
use crate::WispError; use crate::WispError;
@ -187,7 +188,7 @@ mod wsr_inner {
tx: &'a dyn LockingWebSocketWrite, tx: &'a dyn LockingWebSocketWrite,
) -> Pin<Box<dyn Future<Output = Result<Frame<'static>, WispError>> + Send + 'a>>; ) -> Pin<Box<dyn Future<Output = Result<Frame<'static>, WispError>> + Send + 'a>>;
#[allow(clippy::type_complexity)] #[expect(clippy::type_complexity)]
fn wisp_read_split<'a>( fn wisp_read_split<'a>(
&'a mut self, &'a mut self,
tx: &'a dyn LockingWebSocketWrite, tx: &'a dyn LockingWebSocketWrite,
@ -222,7 +223,7 @@ mod wsr_inner {
} }
} }
/// WebSocketRead trait object. /// `WebSocketRead` trait object.
#[repr(transparent)] #[repr(transparent)]
pub struct DynWebSocketRead { pub struct DynWebSocketRead {
ptr: dyn ErasedWebSocketRead + 'static, ptr: dyn ErasedWebSocketRead + 'static,
@ -243,24 +244,26 @@ mod wsr_inner {
} }
} }
impl DynWebSocketRead { impl DynWebSocketRead {
/// Create a WebSocketRead trait object from a boxed WebSocketRead. /// Create a `WebSocketRead` trait object from a boxed `WebSocketRead`.
pub fn new(val: Box<impl WebSocketRead + 'static>) -> Box<Self> { pub fn new(val: Box<impl WebSocketRead + 'static>) -> Box<Self> {
let val: Box<dyn ErasedWebSocketRead + 'static> = val; let val: Box<dyn ErasedWebSocketRead + 'static> = val;
unsafe { std::mem::transmute(val) } unsafe { std::mem::transmute(val) }
} }
/// Create a WebSocketRead trait object from a WebSocketRead. /// Create a `WebSocketRead` trait object from a `WebSocketRead`.
pub fn boxed(val: impl WebSocketRead + 'static) -> Box<Self> { pub fn boxed(val: impl WebSocketRead + 'static) -> Box<Self> {
Self::new(Box::new(val)) Self::new(Box::new(val))
} }
/// Create a WebSocketRead trait object from a WebSocketRead reference. /// Create a `WebSocketRead` trait object from a `WebSocketRead` reference.
pub fn from_ref(val: &(impl WebSocketRead + 'static)) -> &Self { pub fn from_ref(val: &(impl WebSocketRead + 'static)) -> &Self {
let val: &(dyn ErasedWebSocketRead + 'static) = val; let val: &(dyn ErasedWebSocketRead + 'static) = val;
unsafe { std::mem::transmute(val) } unsafe { &*(ptr::from_ref::<dyn ErasedWebSocketRead>(val) as *const DynWebSocketRead) }
} }
/// Create a WebSocketRead trait object from a mutable WebSocketRead reference. /// Create a `WebSocketRead` trait object from a mutable `WebSocketRead` reference.
pub fn from_mut(val: &mut (impl WebSocketRead + 'static)) -> &mut Self { pub fn from_mut(val: &mut (impl WebSocketRead + 'static)) -> &mut Self {
let val: &mut (dyn ErasedWebSocketRead + 'static) = &mut *val; let val: &mut (dyn ErasedWebSocketRead + 'static) = &mut *val;
unsafe { std::mem::transmute(val) } unsafe {
&mut *(ptr::from_mut::<dyn ErasedWebSocketRead>(val) as *mut DynWebSocketRead)
}
} }
} }
} }
@ -294,7 +297,7 @@ pub trait WebSocketWrite: Send {
// similar to what dynosaur does // similar to what dynosaur does
mod wsw_inner { mod wsw_inner {
use std::{future::Future, pin::Pin}; use std::{future::Future, pin::Pin, ptr};
use crate::WispError; use crate::WispError;
@ -340,7 +343,7 @@ mod wsw_inner {
} }
} }
/// WebSocketWrite trait object. /// `WebSocketWrite` trait object.
#[repr(transparent)] #[repr(transparent)]
pub struct DynWebSocketWrite { pub struct DynWebSocketWrite {
ptr: dyn ErasedWebSocketWrite + 'static, ptr: dyn ErasedWebSocketWrite + 'static,
@ -363,24 +366,28 @@ mod wsw_inner {
} }
} }
impl DynWebSocketWrite { impl DynWebSocketWrite {
/// Create a new WebSocketWrite trait object from a boxed WebSocketWrite. /// Create a new `WebSocketWrite` trait object from a boxed `WebSocketWrite`.
pub fn new(val: Box<impl WebSocketWrite + 'static>) -> Box<Self> { pub fn new(val: Box<impl WebSocketWrite + 'static>) -> Box<Self> {
let val: Box<dyn ErasedWebSocketWrite + 'static> = val; let val: Box<dyn ErasedWebSocketWrite + 'static> = val;
unsafe { std::mem::transmute(val) } unsafe { std::mem::transmute(val) }
} }
/// Create a new WebSocketWrite trait object from a WebSocketWrite. /// Create a new `WebSocketWrite` trait object from a `WebSocketWrite`.
pub fn boxed(val: impl WebSocketWrite + 'static) -> Box<Self> { pub fn boxed(val: impl WebSocketWrite + 'static) -> Box<Self> {
Self::new(Box::new(val)) Self::new(Box::new(val))
} }
/// Create a new WebSocketWrite trait object from a WebSocketWrite reference. /// Create a new `WebSocketWrite` trait object from a `WebSocketWrite` reference.
pub fn from_ref(val: &(impl WebSocketWrite + 'static)) -> &Self { pub fn from_ref(val: &(impl WebSocketWrite + 'static)) -> &Self {
let val: &(dyn ErasedWebSocketWrite + 'static) = val; let val: &(dyn ErasedWebSocketWrite + 'static) = val;
unsafe { std::mem::transmute(val) } unsafe {
&*(ptr::from_ref::<dyn ErasedWebSocketWrite>(val) as *const DynWebSocketWrite)
}
} }
/// Create a new WebSocketWrite trait object from a mutable WebSocketWrite reference. /// Create a new `WebSocketWrite` trait object from a mutable `WebSocketWrite` reference.
pub fn from_mut(val: &mut (impl WebSocketWrite + 'static)) -> &mut Self { pub fn from_mut(val: &mut (impl WebSocketWrite + 'static)) -> &mut Self {
let val: &mut (dyn ErasedWebSocketWrite + 'static) = &mut *val; let val: &mut (dyn ErasedWebSocketWrite + 'static) = &mut *val;
unsafe { std::mem::transmute(val) } unsafe {
&mut *(ptr::from_mut::<dyn ErasedWebSocketWrite>(val) as *mut DynWebSocketWrite)
}
} }
} }
} }
@ -390,7 +397,7 @@ mod private {
pub trait Sealed {} pub trait Sealed {}
} }
/// Helper trait object for LockedWebSocketWrite. /// Helper trait object for `LockedWebSocketWrite`.
pub trait LockingWebSocketWrite: private::Sealed + Sync { pub trait LockingWebSocketWrite: private::Sealed + Sync {
/// Write a frame to the websocket. /// Write a frame to the websocket.
fn wisp_write_frame<'a>( fn wisp_write_frame<'a>(
@ -471,11 +478,11 @@ impl<T: WebSocketWrite> LockingWebSocketWrite for LockedWebSocketWrite<T> {
} }
} }
/// Combines two different WebSocketReads together. /// Combines two different `WebSocketRead`s together.
pub enum EitherWebSocketRead<A: WebSocketRead, B: WebSocketRead> { pub enum EitherWebSocketRead<A: WebSocketRead, B: WebSocketRead> {
/// First WebSocketRead variant. /// First `WebSocketRead` variant.
Left(A), Left(A),
/// Second WebSocketRead variant. /// Second `WebSocketRead` variant.
Right(B), Right(B),
} }
impl<A: WebSocketRead, B: WebSocketRead> WebSocketRead for EitherWebSocketRead<A, B> { impl<A: WebSocketRead, B: WebSocketRead> WebSocketRead for EitherWebSocketRead<A, B> {
@ -500,11 +507,11 @@ impl<A: WebSocketRead, B: WebSocketRead> WebSocketRead for EitherWebSocketRead<A
} }
} }
/// Combines two different WebSocketWrites together. /// Combines two different `WebSocketWrite`s together.
pub enum EitherWebSocketWrite<A: WebSocketWrite, B: WebSocketWrite> { pub enum EitherWebSocketWrite<A: WebSocketWrite, B: WebSocketWrite> {
/// First WebSocketWrite variant. /// First `WebSocketWrite` variant.
Left(A), Left(A),
/// Second WebSocketWrite variant. /// Second `WebSocketWrite` variant.
Right(B), Right(B),
} }
impl<A: WebSocketWrite, B: WebSocketWrite> WebSocketWrite for EitherWebSocketWrite<A, B> { impl<A: WebSocketWrite, B: WebSocketWrite> WebSocketWrite for EitherWebSocketWrite<A, B> {