separate clientmux and servermux into new files

This commit is contained in:
Toshit Chawda 2024-10-23 23:00:23 -07:00
parent 65a7904437
commit 2efb641228
No known key found for this signature in database
GPG key ID: 91480ED99E2B3D9D
4 changed files with 559 additions and 519 deletions

View file

@ -12,29 +12,15 @@ mod fastwebsockets;
#[cfg_attr(docsrs, doc(cfg(feature = "generic_stream")))]
pub mod generic;
mod inner;
mod mux;
mod packet;
mod sink_unfold;
mod stream;
pub mod ws;
pub use crate::{packet::*, stream::*};
pub use crate::{mux::*, packet::*, stream::*};
use extensions::{udp::UdpProtocolExtension, AnyProtocolExtension, AnyProtocolExtensionBuilder};
use flume as mpsc;
use futures::{channel::oneshot, select, Future, FutureExt};
use futures_timer::Delay;
use inner::{MuxInner, WsEvent};
use std::{
ops::DerefMut,
pin::Pin,
sync::{
atomic::{AtomicBool, Ordering},
Arc,
},
time::Duration,
};
use thiserror::Error;
use ws::{AppendingWebSocketRead, LockedWebSocketWrite, Payload};
/// Wisp version supported by this crate.
pub const WISP_VERSION: WispVersion = WispVersion { major: 2, minor: 0 };
@ -128,506 +114,3 @@ pub enum WispError {
#[error("Certificate authentication protocol extension: Invalid signature")]
CertAuthExtensionSigInvalid,
}
async fn maybe_wisp_v2<R>(
read: &mut R,
write: &LockedWebSocketWrite,
role: Role,
builders: &mut [AnyProtocolExtensionBuilder],
) -> Result<(Vec<AnyProtocolExtension>, Option<ws::Frame<'static>>, bool), WispError>
where
R: ws::WebSocketRead + Send,
{
let mut supported_extensions = Vec::new();
let mut extra_packet: Option<ws::Frame<'static>> = None;
let mut downgraded = true;
let extension_ids: Vec<_> = builders.iter().map(|x| x.get_id()).collect();
if let Some(frame) = select! {
x = read.wisp_read_frame(write).fuse() => Some(x?),
_ = Delay::new(Duration::from_secs(5)).fuse() => None
} {
let packet = Packet::maybe_parse_info(frame, role, builders)?;
if let PacketType::Info(info) = packet.packet_type {
supported_extensions = info
.extensions
.into_iter()
.filter(|x| extension_ids.contains(&x.get_id()))
.collect();
downgraded = false;
} else {
extra_packet.replace(ws::Frame::from(packet).clone());
}
}
for extension in supported_extensions.iter_mut() {
extension.handle_handshake(read, write).await?;
}
Ok((supported_extensions, extra_packet, downgraded))
}
async fn send_info_packet(
write: &LockedWebSocketWrite,
builders: &mut [AnyProtocolExtensionBuilder],
) -> Result<(), WispError> {
write
.write_frame(
Packet::new_info(
builders
.iter_mut()
.map(|x| x.build_to_extension(Role::Server))
.collect::<Result<Vec<_>, _>>()?,
)
.into(),
)
.await
}
/// Wisp V2 handshake and protocol extension settings wrapper struct.
pub struct WispV2Extensions {
builders: Vec<AnyProtocolExtensionBuilder>,
closure: Box<
dyn Fn(
&mut [AnyProtocolExtensionBuilder],
) -> Pin<Box<dyn Future<Output = Result<(), WispError>> + Sync + Send>>
+ Send,
>,
}
impl WispV2Extensions {
/// Create a Wisp V2 settings struct with no middleware.
pub fn new(builders: Vec<AnyProtocolExtensionBuilder>) -> Self {
Self {
builders,
closure: Box::new(|_| Box::pin(async { Ok(()) })),
}
}
/// Create a Wisp V2 settings struct with some middleware.
pub fn new_with_middleware<C>(builders: Vec<AnyProtocolExtensionBuilder>, closure: C) -> Self
where
C: Fn(
&mut [AnyProtocolExtensionBuilder],
) -> Pin<Box<dyn Future<Output = Result<(), WispError>> + Sync + Send>>
+ Send
+ 'static,
{
Self {
builders,
closure: Box::new(closure),
}
}
/// Add a Wisp V2 extension builder to the settings struct.
pub fn add_extension(&mut self, extension: AnyProtocolExtensionBuilder) {
self.builders.push(extension);
}
}
/// Server-side multiplexor.
pub struct ServerMux {
/// Whether the connection was downgraded to Wisp v1.
///
/// If this variable is true you must assume no extensions are supported.
pub downgraded: bool,
/// Extensions that are supported by both sides.
pub supported_extensions: Vec<AnyProtocolExtension>,
actor_tx: mpsc::Sender<WsEvent>,
muxstream_recv: mpsc::Receiver<(ConnectPacket, MuxStream)>,
tx: ws::LockedWebSocketWrite,
actor_exited: Arc<AtomicBool>,
}
impl ServerMux {
/// Create a new server-side multiplexor.
///
/// If `wisp_v2` is None a Wisp v1 connection is created otherwise a Wisp v2 connection is created.
/// **It is not guaranteed that all extensions you specify are available.** You must manually check
/// if the extensions you need are available after the multiplexor has been created.
pub async fn create<R, W>(
mut rx: R,
tx: W,
buffer_size: u32,
wisp_v2: Option<WispV2Extensions>,
) -> Result<ServerMuxResult<impl Future<Output = Result<(), WispError>> + Send>, WispError>
where
R: ws::WebSocketRead + Send,
W: ws::WebSocketWrite + Send + 'static,
{
let tx = ws::LockedWebSocketWrite::new(Box::new(tx));
let ret_tx = tx.clone();
let ret = async {
tx.write_frame(Packet::new_continue(0, buffer_size).into())
.await?;
let (supported_extensions, extra_packet, downgraded) = if let Some(WispV2Extensions {
mut builders,
closure,
}) = wisp_v2
{
send_info_packet(&tx, builders.deref_mut()).await?;
(closure)(builders.deref_mut()).await?;
maybe_wisp_v2(&mut rx, &tx, Role::Server, &mut builders).await?
} else {
(Vec::new(), None, true)
};
let (mux_result, muxstream_recv) = MuxInner::new_server(
AppendingWebSocketRead(extra_packet, rx),
tx.clone(),
supported_extensions.clone(),
buffer_size,
);
Ok(ServerMuxResult(
Self {
muxstream_recv,
actor_tx: mux_result.actor_tx,
downgraded,
supported_extensions,
tx,
actor_exited: mux_result.actor_exited,
},
mux_result.mux.into_future(),
))
}
.await;
match ret {
Ok(x) => Ok(x),
Err(x) => match x {
WispError::PasswordExtensionCredsInvalid => {
ret_tx
.write_frame(
Packet::new_close(0, CloseReason::ExtensionsPasswordAuthFailed).into(),
)
.await?;
ret_tx.close().await?;
Err(x)
}
WispError::CertAuthExtensionSigInvalid => {
ret_tx
.write_frame(
Packet::new_close(0, CloseReason::ExtensionsCertAuthFailed).into(),
)
.await?;
ret_tx.close().await?;
Err(x)
}
x => Err(x),
},
}
}
/// Wait for a stream to be created.
pub async fn server_new_stream(&self) -> Option<(ConnectPacket, MuxStream)> {
if self.actor_exited.load(Ordering::Acquire) {
return None;
}
self.muxstream_recv.recv_async().await.ok()
}
/// Send a ping to the client.
pub async fn send_ping(&self, payload: Payload<'static>) -> Result<(), WispError> {
if self.actor_exited.load(Ordering::Acquire) {
return Err(WispError::MuxTaskEnded);
}
let (tx, rx) = oneshot::channel();
self.actor_tx
.send_async(WsEvent::SendPing(payload, tx))
.await
.map_err(|_| WispError::MuxMessageFailedToSend)?;
rx.await.map_err(|_| WispError::MuxMessageFailedToRecv)?
}
async fn close_internal(&self, reason: Option<CloseReason>) -> Result<(), WispError> {
if self.actor_exited.load(Ordering::Acquire) {
return Err(WispError::MuxTaskEnded);
}
self.actor_tx
.send_async(WsEvent::EndFut(reason))
.await
.map_err(|_| WispError::MuxMessageFailedToSend)
}
/// Close all streams.
///
/// Also terminates the multiplexor future.
pub async fn close(&self) -> Result<(), WispError> {
self.close_internal(None).await
}
/// Close all streams and send a close reason on stream ID 0.
///
/// Also terminates the multiplexor future.
pub async fn close_with_reason(&self, reason: CloseReason) -> Result<(), WispError> {
self.close_internal(Some(reason)).await
}
/// Get a protocol extension stream for sending packets with stream id 0.
pub fn get_protocol_extension_stream(&self) -> MuxProtocolExtensionStream {
MuxProtocolExtensionStream {
stream_id: 0,
tx: self.tx.clone(),
is_closed: self.actor_exited.clone(),
}
}
}
impl Drop for ServerMux {
fn drop(&mut self) {
let _ = self.actor_tx.send(WsEvent::EndFut(None));
}
}
/// Result of `ServerMux::new`.
pub struct ServerMuxResult<F>(ServerMux, F)
where
F: Future<Output = Result<(), WispError>> + Send;
impl<F> ServerMuxResult<F>
where
F: Future<Output = Result<(), WispError>> + Send,
{
/// Require no protocol extensions.
pub fn with_no_required_extensions(self) -> (ServerMux, F) {
(self.0, self.1)
}
/// Require protocol extensions by their ID. Will close the multiplexor connection if
/// extensions are not supported.
pub async fn with_required_extensions(
self,
extensions: &[u8],
) -> Result<(ServerMux, F), WispError> {
let mut unsupported_extensions = Vec::new();
for extension in extensions {
if !self
.0
.supported_extensions
.iter()
.any(|x| x.get_id() == *extension)
{
unsupported_extensions.push(*extension);
}
}
if unsupported_extensions.is_empty() {
Ok((self.0, self.1))
} else {
self.0
.close_with_reason(CloseReason::ExtensionsIncompatible)
.await?;
self.1.await?;
Err(WispError::ExtensionsNotSupported(unsupported_extensions))
}
}
/// Shorthand for `with_required_extensions(&[UdpProtocolExtension::ID])`
pub async fn with_udp_extension_required(self) -> Result<(ServerMux, F), WispError> {
self.with_required_extensions(&[UdpProtocolExtension::ID])
.await
}
}
/// Client side multiplexor.
pub struct ClientMux {
/// Whether the connection was downgraded to Wisp v1.
///
/// If this variable is true you must assume no extensions are supported.
pub downgraded: bool,
/// Extensions that are supported by both sides.
pub supported_extensions: Vec<AnyProtocolExtension>,
actor_tx: mpsc::Sender<WsEvent>,
tx: ws::LockedWebSocketWrite,
actor_exited: Arc<AtomicBool>,
}
impl ClientMux {
/// Create a new client side multiplexor.
///
/// If `wisp_v2` is None a Wisp v1 connection is created otherwise a Wisp v2 connection is created.
/// **It is not guaranteed that all extensions you specify are available.** You must manually check
/// if the extensions you need are available after the multiplexor has been created.
pub async fn create<R, W>(
mut rx: R,
tx: W,
wisp_v2: Option<WispV2Extensions>,
) -> Result<ClientMuxResult<impl Future<Output = Result<(), WispError>> + Send>, WispError>
where
R: ws::WebSocketRead + Send,
W: ws::WebSocketWrite + Send + 'static,
{
let tx = ws::LockedWebSocketWrite::new(Box::new(tx));
let first_packet = Packet::try_from(rx.wisp_read_frame(&tx).await?)?;
if first_packet.stream_id != 0 {
return Err(WispError::InvalidStreamId);
}
if let PacketType::Continue(packet) = first_packet.packet_type {
let (supported_extensions, extra_packet, downgraded) = if let Some(WispV2Extensions {
mut builders,
closure,
}) = wisp_v2
{
let res = maybe_wisp_v2(&mut rx, &tx, Role::Client, &mut builders).await?;
// if not downgraded
if !res.2 {
(closure)(&mut builders).await?;
send_info_packet(&tx, &mut builders).await?;
}
res
} else {
(Vec::new(), None, true)
};
let mux_result = MuxInner::new_client(
AppendingWebSocketRead(extra_packet, rx),
tx.clone(),
supported_extensions.clone(),
packet.buffer_remaining,
);
Ok(ClientMuxResult(
Self {
actor_tx: mux_result.actor_tx,
downgraded,
supported_extensions,
tx,
actor_exited: mux_result.actor_exited,
},
mux_result.mux.into_future(),
))
} else {
Err(WispError::InvalidPacketType)
}
}
/// Create a new stream, multiplexed through Wisp.
pub async fn client_new_stream(
&self,
stream_type: StreamType,
host: String,
port: u16,
) -> Result<MuxStream, WispError> {
if self.actor_exited.load(Ordering::Acquire) {
return Err(WispError::MuxTaskEnded);
}
if stream_type == StreamType::Udp
&& !self
.supported_extensions
.iter()
.any(|x| x.get_id() == UdpProtocolExtension::ID)
{
return Err(WispError::ExtensionsNotSupported(vec![
UdpProtocolExtension::ID,
]));
}
let (tx, rx) = oneshot::channel();
self.actor_tx
.send_async(WsEvent::CreateStream(stream_type, host, port, tx))
.await
.map_err(|_| WispError::MuxMessageFailedToSend)?;
rx.await.map_err(|_| WispError::MuxMessageFailedToRecv)?
}
/// Send a ping to the server.
pub async fn send_ping(&self, payload: Payload<'static>) -> Result<(), WispError> {
if self.actor_exited.load(Ordering::Acquire) {
return Err(WispError::MuxTaskEnded);
}
let (tx, rx) = oneshot::channel();
self.actor_tx
.send_async(WsEvent::SendPing(payload, tx))
.await
.map_err(|_| WispError::MuxMessageFailedToSend)?;
rx.await.map_err(|_| WispError::MuxMessageFailedToRecv)?
}
async fn close_internal(&self, reason: Option<CloseReason>) -> Result<(), WispError> {
if self.actor_exited.load(Ordering::Acquire) {
return Err(WispError::MuxTaskEnded);
}
self.actor_tx
.send_async(WsEvent::EndFut(reason))
.await
.map_err(|_| WispError::MuxMessageFailedToSend)
}
/// Close all streams.
///
/// Also terminates the multiplexor future.
pub async fn close(&self) -> Result<(), WispError> {
self.close_internal(None).await
}
/// Close all streams and send a close reason on stream ID 0.
///
/// Also terminates the multiplexor future.
pub async fn close_with_reason(&self, reason: CloseReason) -> Result<(), WispError> {
self.close_internal(Some(reason)).await
}
/// Get a protocol extension stream for sending packets with stream id 0.
pub fn get_protocol_extension_stream(&self) -> MuxProtocolExtensionStream {
MuxProtocolExtensionStream {
stream_id: 0,
tx: self.tx.clone(),
is_closed: self.actor_exited.clone(),
}
}
}
impl Drop for ClientMux {
fn drop(&mut self) {
let _ = self.actor_tx.send(WsEvent::EndFut(None));
}
}
/// Result of `ClientMux::new`.
pub struct ClientMuxResult<F>(ClientMux, F)
where
F: Future<Output = Result<(), WispError>> + Send;
impl<F> ClientMuxResult<F>
where
F: Future<Output = Result<(), WispError>> + Send,
{
/// Require no protocol extensions.
pub fn with_no_required_extensions(self) -> (ClientMux, F) {
(self.0, self.1)
}
/// Require protocol extensions by their ID.
pub async fn with_required_extensions(
self,
extensions: &[u8],
) -> Result<(ClientMux, F), WispError> {
let mut unsupported_extensions = Vec::new();
for extension in extensions {
if !self
.0
.supported_extensions
.iter()
.any(|x| x.get_id() == *extension)
{
unsupported_extensions.push(*extension);
}
}
if unsupported_extensions.is_empty() {
Ok((self.0, self.1))
} else {
self.0
.close_with_reason(CloseReason::ExtensionsIncompatible)
.await?;
self.1.await?;
Err(WispError::ExtensionsNotSupported(unsupported_extensions))
}
}
/// Shorthand for `with_required_extensions(&[UdpProtocolExtension::ID])`
pub async fn with_udp_extension_required(self) -> Result<(ClientMux, F), WispError> {
self.with_required_extensions(&[UdpProtocolExtension::ID])
.await
}
}

223
wisp/src/mux/client.rs Normal file
View file

@ -0,0 +1,223 @@
use std::{
future::Future,
sync::{
atomic::{AtomicBool, Ordering},
Arc,
},
};
use flume as mpsc;
use futures::channel::oneshot;
use crate::{
extensions::{udp::UdpProtocolExtension, AnyProtocolExtension},
inner::{MuxInner, WsEvent},
ws::{AppendingWebSocketRead, LockedWebSocketWrite, WebSocketRead, WebSocketWrite, Payload},
CloseReason, MuxProtocolExtensionStream, MuxStream, Packet, PacketType, Role, StreamType,
WispError,
};
use super::{maybe_wisp_v2, send_info_packet, WispV2Extensions};
/// Client side multiplexor.
pub struct ClientMux {
/// Whether the connection was downgraded to Wisp v1.
///
/// If this variable is true you must assume no extensions are supported.
pub downgraded: bool,
/// Extensions that are supported by both sides.
pub supported_extensions: Vec<AnyProtocolExtension>,
actor_tx: mpsc::Sender<WsEvent>,
tx: LockedWebSocketWrite,
actor_exited: Arc<AtomicBool>,
}
impl ClientMux {
/// Create a new client side multiplexor.
///
/// If `wisp_v2` is None a Wisp v1 connection is created otherwise a Wisp v2 connection is created.
/// **It is not guaranteed that all extensions you specify are available.** You must manually check
/// if the extensions you need are available after the multiplexor has been created.
pub async fn create<R, W>(
mut rx: R,
tx: W,
wisp_v2: Option<WispV2Extensions>,
) -> Result<ClientMuxResult<impl Future<Output = Result<(), WispError>> + Send>, WispError>
where
R: WebSocketRead + Send,
W: WebSocketWrite + Send + 'static,
{
let tx = LockedWebSocketWrite::new(Box::new(tx));
let first_packet = Packet::try_from(rx.wisp_read_frame(&tx).await?)?;
if first_packet.stream_id != 0 {
return Err(WispError::InvalidStreamId);
}
if let PacketType::Continue(packet) = first_packet.packet_type {
let (supported_extensions, extra_packet, downgraded) = if let Some(WispV2Extensions {
mut builders,
closure,
}) = wisp_v2
{
let res = maybe_wisp_v2(&mut rx, &tx, Role::Client, &mut builders).await?;
// if not downgraded
if !res.2 {
(closure)(&mut builders).await?;
send_info_packet(&tx, &mut builders).await?;
}
res
} else {
(Vec::new(), None, true)
};
let mux_result = MuxInner::new_client(
AppendingWebSocketRead(extra_packet, rx),
tx.clone(),
supported_extensions.clone(),
packet.buffer_remaining,
);
Ok(ClientMuxResult(
Self {
actor_tx: mux_result.actor_tx,
downgraded,
supported_extensions,
tx,
actor_exited: mux_result.actor_exited,
},
mux_result.mux.into_future(),
))
} else {
Err(WispError::InvalidPacketType)
}
}
/// Create a new stream, multiplexed through Wisp.
pub async fn client_new_stream(
&self,
stream_type: StreamType,
host: String,
port: u16,
) -> Result<MuxStream, WispError> {
if self.actor_exited.load(Ordering::Acquire) {
return Err(WispError::MuxTaskEnded);
}
if stream_type == StreamType::Udp
&& !self
.supported_extensions
.iter()
.any(|x| x.get_id() == UdpProtocolExtension::ID)
{
return Err(WispError::ExtensionsNotSupported(vec![
UdpProtocolExtension::ID,
]));
}
let (tx, rx) = oneshot::channel();
self.actor_tx
.send_async(WsEvent::CreateStream(stream_type, host, port, tx))
.await
.map_err(|_| WispError::MuxMessageFailedToSend)?;
rx.await.map_err(|_| WispError::MuxMessageFailedToRecv)?
}
/// Send a ping to the server.
pub async fn send_ping(&self, payload: Payload<'static>) -> Result<(), WispError> {
if self.actor_exited.load(Ordering::Acquire) {
return Err(WispError::MuxTaskEnded);
}
let (tx, rx) = oneshot::channel();
self.actor_tx
.send_async(WsEvent::SendPing(payload, tx))
.await
.map_err(|_| WispError::MuxMessageFailedToSend)?;
rx.await.map_err(|_| WispError::MuxMessageFailedToRecv)?
}
async fn close_internal(&self, reason: Option<CloseReason>) -> Result<(), WispError> {
if self.actor_exited.load(Ordering::Acquire) {
return Err(WispError::MuxTaskEnded);
}
self.actor_tx
.send_async(WsEvent::EndFut(reason))
.await
.map_err(|_| WispError::MuxMessageFailedToSend)
}
/// Close all streams.
///
/// Also terminates the multiplexor future.
pub async fn close(&self) -> Result<(), WispError> {
self.close_internal(None).await
}
/// Close all streams and send a close reason on stream ID 0.
///
/// Also terminates the multiplexor future.
pub async fn close_with_reason(&self, reason: CloseReason) -> Result<(), WispError> {
self.close_internal(Some(reason)).await
}
/// Get a protocol extension stream for sending packets with stream id 0.
pub fn get_protocol_extension_stream(&self) -> MuxProtocolExtensionStream {
MuxProtocolExtensionStream {
stream_id: 0,
tx: self.tx.clone(),
is_closed: self.actor_exited.clone(),
}
}
}
impl Drop for ClientMux {
fn drop(&mut self) {
let _ = self.actor_tx.send(WsEvent::EndFut(None));
}
}
/// Result of `ClientMux::new`.
pub struct ClientMuxResult<F>(ClientMux, F)
where
F: Future<Output = Result<(), WispError>> + Send;
impl<F> ClientMuxResult<F>
where
F: Future<Output = Result<(), WispError>> + Send,
{
/// Require no protocol extensions.
pub fn with_no_required_extensions(self) -> (ClientMux, F) {
(self.0, self.1)
}
/// Require protocol extensions by their ID.
pub async fn with_required_extensions(
self,
extensions: &[u8],
) -> Result<(ClientMux, F), WispError> {
let mut unsupported_extensions = Vec::new();
for extension in extensions {
if !self
.0
.supported_extensions
.iter()
.any(|x| x.get_id() == *extension)
{
unsupported_extensions.push(*extension);
}
}
if unsupported_extensions.is_empty() {
Ok((self.0, self.1))
} else {
self.0
.close_with_reason(CloseReason::ExtensionsIncompatible)
.await?;
self.1.await?;
Err(WispError::ExtensionsNotSupported(unsupported_extensions))
}
}
/// Shorthand for `with_required_extensions(&[UdpProtocolExtension::ID])`
pub async fn with_udp_extension_required(self) -> Result<(ClientMux, F), WispError> {
self.with_required_extensions(&[UdpProtocolExtension::ID])
.await
}
}

109
wisp/src/mux/mod.rs Normal file
View file

@ -0,0 +1,109 @@
mod client;
mod server;
use std::{future::Future, pin::Pin, time::Duration};
pub use client::ClientMux;
use futures::{select, FutureExt};
use futures_timer::Delay;
pub use server::{ServerMux, ServerMuxResult};
use crate::{
extensions::{AnyProtocolExtension, AnyProtocolExtensionBuilder},
ws::{Frame, LockedWebSocketWrite, WebSocketRead},
Packet, PacketType, Role, WispError,
};
async fn maybe_wisp_v2<R>(
read: &mut R,
write: &LockedWebSocketWrite,
role: Role,
builders: &mut [AnyProtocolExtensionBuilder],
) -> Result<(Vec<AnyProtocolExtension>, Option<Frame<'static>>, bool), WispError>
where
R: WebSocketRead + Send,
{
let mut supported_extensions = Vec::new();
let mut extra_packet: Option<Frame<'static>> = None;
let mut downgraded = true;
let extension_ids: Vec<_> = builders.iter().map(|x| x.get_id()).collect();
if let Some(frame) = select! {
x = read.wisp_read_frame(write).fuse() => Some(x?),
_ = Delay::new(Duration::from_secs(5)).fuse() => None
} {
let packet = Packet::maybe_parse_info(frame, role, builders)?;
if let PacketType::Info(info) = packet.packet_type {
supported_extensions = info
.extensions
.into_iter()
.filter(|x| extension_ids.contains(&x.get_id()))
.collect();
downgraded = false;
} else {
extra_packet.replace(Frame::from(packet).clone());
}
}
for extension in supported_extensions.iter_mut() {
extension.handle_handshake(read, write).await?;
}
Ok((supported_extensions, extra_packet, downgraded))
}
async fn send_info_packet(
write: &LockedWebSocketWrite,
builders: &mut [AnyProtocolExtensionBuilder],
) -> Result<(), WispError> {
write
.write_frame(
Packet::new_info(
builders
.iter_mut()
.map(|x| x.build_to_extension(Role::Server))
.collect::<Result<Vec<_>, _>>()?,
)
.into(),
)
.await
}
/// Wisp V2 handshake and protocol extension settings wrapper struct.
pub struct WispV2Extensions {
builders: Vec<AnyProtocolExtensionBuilder>,
closure: Box<
dyn Fn(
&mut [AnyProtocolExtensionBuilder],
) -> Pin<Box<dyn Future<Output = Result<(), WispError>> + Sync + Send>>
+ Send,
>,
}
impl WispV2Extensions {
/// Create a Wisp V2 settings struct with no middleware.
pub fn new(builders: Vec<AnyProtocolExtensionBuilder>) -> Self {
Self {
builders,
closure: Box::new(|_| Box::pin(async { Ok(()) })),
}
}
/// Create a Wisp V2 settings struct with some middleware.
pub fn new_with_middleware<C>(builders: Vec<AnyProtocolExtensionBuilder>, closure: C) -> Self
where
C: Fn(
&mut [AnyProtocolExtensionBuilder],
) -> Pin<Box<dyn Future<Output = Result<(), WispError>> + Sync + Send>>
+ Send
+ 'static,
{
Self {
builders,
closure: Box::new(closure),
}
}
/// Add a Wisp V2 extension builder to the settings struct.
pub fn add_extension(&mut self, extension: AnyProtocolExtensionBuilder) {
self.builders.push(extension);
}
}

225
wisp/src/mux/server.rs Normal file
View file

@ -0,0 +1,225 @@
use std::{
future::Future,
ops::DerefMut,
sync::{
atomic::{AtomicBool, Ordering},
Arc,
},
};
use flume as mpsc;
use futures::channel::oneshot;
use crate::{
extensions::{udp::UdpProtocolExtension, AnyProtocolExtension},
inner::{MuxInner, WsEvent},
ws::{AppendingWebSocketRead, LockedWebSocketWrite, Payload, WebSocketRead, WebSocketWrite},
CloseReason, ConnectPacket, MuxProtocolExtensionStream, MuxStream, Packet, Role, WispError,
};
use super::{maybe_wisp_v2, send_info_packet, WispV2Extensions};
/// Server-side multiplexor.
pub struct ServerMux {
/// Whether the connection was downgraded to Wisp v1.
///
/// If this variable is true you must assume no extensions are supported.
pub downgraded: bool,
/// Extensions that are supported by both sides.
pub supported_extensions: Vec<AnyProtocolExtension>,
actor_tx: mpsc::Sender<WsEvent>,
muxstream_recv: mpsc::Receiver<(ConnectPacket, MuxStream)>,
tx: LockedWebSocketWrite,
actor_exited: Arc<AtomicBool>,
}
impl ServerMux {
/// Create a new server-side multiplexor.
///
/// If `wisp_v2` is None a Wisp v1 connection is created otherwise a Wisp v2 connection is created.
/// **It is not guaranteed that all extensions you specify are available.** You must manually check
/// if the extensions you need are available after the multiplexor has been created.
pub async fn create<R, W>(
mut rx: R,
tx: W,
buffer_size: u32,
wisp_v2: Option<WispV2Extensions>,
) -> Result<ServerMuxResult<impl Future<Output = Result<(), WispError>> + Send>, WispError>
where
R: WebSocketRead + Send,
W: WebSocketWrite + Send + 'static,
{
let tx = LockedWebSocketWrite::new(Box::new(tx));
let ret_tx = tx.clone();
let ret = async {
tx.write_frame(Packet::new_continue(0, buffer_size).into())
.await?;
let (supported_extensions, extra_packet, downgraded) = if let Some(WispV2Extensions {
mut builders,
closure,
}) = wisp_v2
{
send_info_packet(&tx, builders.deref_mut()).await?;
(closure)(builders.deref_mut()).await?;
maybe_wisp_v2(&mut rx, &tx, Role::Server, &mut builders).await?
} else {
(Vec::new(), None, true)
};
let (mux_result, muxstream_recv) = MuxInner::new_server(
AppendingWebSocketRead(extra_packet, rx),
tx.clone(),
supported_extensions.clone(),
buffer_size,
);
Ok(ServerMuxResult(
Self {
muxstream_recv,
actor_tx: mux_result.actor_tx,
downgraded,
supported_extensions,
tx,
actor_exited: mux_result.actor_exited,
},
mux_result.mux.into_future(),
))
}
.await;
match ret {
Ok(x) => Ok(x),
Err(x) => match x {
WispError::PasswordExtensionCredsInvalid => {
ret_tx
.write_frame(
Packet::new_close(0, CloseReason::ExtensionsPasswordAuthFailed).into(),
)
.await?;
ret_tx.close().await?;
Err(x)
}
WispError::CertAuthExtensionSigInvalid => {
ret_tx
.write_frame(
Packet::new_close(0, CloseReason::ExtensionsCertAuthFailed).into(),
)
.await?;
ret_tx.close().await?;
Err(x)
}
x => Err(x),
},
}
}
/// Wait for a stream to be created.
pub async fn server_new_stream(&self) -> Option<(ConnectPacket, MuxStream)> {
if self.actor_exited.load(Ordering::Acquire) {
return None;
}
self.muxstream_recv.recv_async().await.ok()
}
/// Send a ping to the client.
pub async fn send_ping(&self, payload: Payload<'static>) -> Result<(), WispError> {
if self.actor_exited.load(Ordering::Acquire) {
return Err(WispError::MuxTaskEnded);
}
let (tx, rx) = oneshot::channel();
self.actor_tx
.send_async(WsEvent::SendPing(payload, tx))
.await
.map_err(|_| WispError::MuxMessageFailedToSend)?;
rx.await.map_err(|_| WispError::MuxMessageFailedToRecv)?
}
async fn close_internal(&self, reason: Option<CloseReason>) -> Result<(), WispError> {
if self.actor_exited.load(Ordering::Acquire) {
return Err(WispError::MuxTaskEnded);
}
self.actor_tx
.send_async(WsEvent::EndFut(reason))
.await
.map_err(|_| WispError::MuxMessageFailedToSend)
}
/// Close all streams.
///
/// Also terminates the multiplexor future.
pub async fn close(&self) -> Result<(), WispError> {
self.close_internal(None).await
}
/// Close all streams and send a close reason on stream ID 0.
///
/// Also terminates the multiplexor future.
pub async fn close_with_reason(&self, reason: CloseReason) -> Result<(), WispError> {
self.close_internal(Some(reason)).await
}
/// Get a protocol extension stream for sending packets with stream id 0.
pub fn get_protocol_extension_stream(&self) -> MuxProtocolExtensionStream {
MuxProtocolExtensionStream {
stream_id: 0,
tx: self.tx.clone(),
is_closed: self.actor_exited.clone(),
}
}
}
impl Drop for ServerMux {
fn drop(&mut self) {
let _ = self.actor_tx.send(WsEvent::EndFut(None));
}
}
/// Result of `ServerMux::new`.
pub struct ServerMuxResult<F>(ServerMux, F)
where
F: Future<Output = Result<(), WispError>> + Send;
impl<F> ServerMuxResult<F>
where
F: Future<Output = Result<(), WispError>> + Send,
{
/// Require no protocol extensions.
pub fn with_no_required_extensions(self) -> (ServerMux, F) {
(self.0, self.1)
}
/// Require protocol extensions by their ID. Will close the multiplexor connection if
/// extensions are not supported.
pub async fn with_required_extensions(
self,
extensions: &[u8],
) -> Result<(ServerMux, F), WispError> {
let mut unsupported_extensions = Vec::new();
for extension in extensions {
if !self
.0
.supported_extensions
.iter()
.any(|x| x.get_id() == *extension)
{
unsupported_extensions.push(*extension);
}
}
if unsupported_extensions.is_empty() {
Ok((self.0, self.1))
} else {
self.0
.close_with_reason(CloseReason::ExtensionsIncompatible)
.await?;
self.1.await?;
Err(WispError::ExtensionsNotSupported(unsupported_extensions))
}
}
/// Shorthand for `with_required_extensions(&[UdpProtocolExtension::ID])`
pub async fn with_udp_extension_required(self) -> Result<(ServerMux, F), WispError> {
self.with_required_extensions(&[UdpProtocolExtension::ID])
.await
}
}