clean up muxresult types

This commit is contained in:
Toshit Chawda 2024-10-23 23:23:55 -07:00
parent fc63298397
commit c8de5524b4
No known key found for this signature in database
GPG key ID: 91480ED99E2B3D9D
4 changed files with 81 additions and 114 deletions

View file

@ -11,13 +11,13 @@ use flume as mpsc;
use futures::channel::oneshot;
use crate::{
extensions::{udp::UdpProtocolExtension, AnyProtocolExtension},
extensions::AnyProtocolExtension,
inner::{MuxInner, WsEvent},
ws::{AppendingWebSocketRead, LockedWebSocketWrite, Payload, WebSocketRead, WebSocketWrite},
CloseReason, ConnectPacket, MuxProtocolExtensionStream, MuxStream, Packet, Role, WispError,
};
use super::{maybe_wisp_v2, send_info_packet, WispV2Extensions};
use super::{maybe_wisp_v2, send_info_packet, Multiplexor, MuxResult, WispV2Extensions};
/// Server-side multiplexor.
pub struct ServerMux {
@ -44,7 +44,7 @@ impl ServerMux {
tx: W,
buffer_size: u32,
wisp_v2: Option<WispV2Extensions>,
) -> Result<ServerMuxResult<impl Future<Output = Result<(), WispError>> + Send>, WispError>
) -> Result<MuxResult<ServerMux, impl Future<Output = Result<(), WispError>> + Send>, WispError>
where
R: WebSocketRead + Send,
W: WebSocketWrite + Send + 'static,
@ -74,7 +74,7 @@ impl ServerMux {
buffer_size,
);
Ok(ServerMuxResult(
Ok(MuxResult(
Self {
muxstream_recv,
actor_tx: mux_result.actor_tx,
@ -175,51 +175,13 @@ impl Drop for ServerMux {
}
}
/// Result of `ServerMux::new`.
pub struct ServerMuxResult<F>(ServerMux, F)
where
F: Future<Output = Result<(), WispError>> + Send;
impl<F> ServerMuxResult<F>
where
F: Future<Output = Result<(), WispError>> + Send,
{
/// Require no protocol extensions.
pub fn with_no_required_extensions(self) -> (ServerMux, F) {
(self.0, self.1)
impl Multiplexor for ServerMux {
fn has_extension(&self, extension_id: u8) -> bool {
self.supported_extensions
.iter()
.any(|x| x.get_id() == extension_id)
}
/// Require protocol extensions by their ID. Will close the multiplexor connection if
/// extensions are not supported.
pub async fn with_required_extensions(
self,
extensions: &[u8],
) -> Result<(ServerMux, F), WispError> {
let mut unsupported_extensions = Vec::new();
for extension in extensions {
if !self
.0
.supported_extensions
.iter()
.any(|x| x.get_id() == *extension)
{
unsupported_extensions.push(*extension);
}
}
if unsupported_extensions.is_empty() {
Ok((self.0, self.1))
} else {
self.0
.close_with_reason(CloseReason::ExtensionsIncompatible)
.await?;
self.1.await?;
Err(WispError::ExtensionsNotSupported(unsupported_extensions))
}
}
/// Shorthand for `with_required_extensions(&[UdpProtocolExtension::ID])`
pub async fn with_udp_extension_required(self) -> Result<(ServerMux, F), WispError> {
self.with_required_extensions(&[UdpProtocolExtension::ID])
.await
async fn exit(&self, reason: CloseReason) -> Result<(), WispError> {
self.close_with_reason(reason).await
}
}