mirror of
https://github.com/MercuryWorkshop/epoxy-tls.git
synced 2025-05-12 14:00:01 -04:00
separate clientmux and servermux into new files
This commit is contained in:
parent
65a7904437
commit
2efb641228
4 changed files with 559 additions and 519 deletions
521
wisp/src/lib.rs
521
wisp/src/lib.rs
|
@ -12,29 +12,15 @@ mod fastwebsockets;
|
||||||
#[cfg_attr(docsrs, doc(cfg(feature = "generic_stream")))]
|
#[cfg_attr(docsrs, doc(cfg(feature = "generic_stream")))]
|
||||||
pub mod generic;
|
pub mod generic;
|
||||||
mod inner;
|
mod inner;
|
||||||
|
mod mux;
|
||||||
mod packet;
|
mod packet;
|
||||||
mod sink_unfold;
|
mod sink_unfold;
|
||||||
mod stream;
|
mod stream;
|
||||||
pub mod ws;
|
pub mod ws;
|
||||||
|
|
||||||
pub use crate::{packet::*, stream::*};
|
pub use crate::{mux::*, packet::*, stream::*};
|
||||||
|
|
||||||
use extensions::{udp::UdpProtocolExtension, AnyProtocolExtension, AnyProtocolExtensionBuilder};
|
|
||||||
use flume as mpsc;
|
|
||||||
use futures::{channel::oneshot, select, Future, FutureExt};
|
|
||||||
use futures_timer::Delay;
|
|
||||||
use inner::{MuxInner, WsEvent};
|
|
||||||
use std::{
|
|
||||||
ops::DerefMut,
|
|
||||||
pin::Pin,
|
|
||||||
sync::{
|
|
||||||
atomic::{AtomicBool, Ordering},
|
|
||||||
Arc,
|
|
||||||
},
|
|
||||||
time::Duration,
|
|
||||||
};
|
|
||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
use ws::{AppendingWebSocketRead, LockedWebSocketWrite, Payload};
|
|
||||||
|
|
||||||
/// Wisp version supported by this crate.
|
/// Wisp version supported by this crate.
|
||||||
pub const WISP_VERSION: WispVersion = WispVersion { major: 2, minor: 0 };
|
pub const WISP_VERSION: WispVersion = WispVersion { major: 2, minor: 0 };
|
||||||
|
@ -128,506 +114,3 @@ pub enum WispError {
|
||||||
#[error("Certificate authentication protocol extension: Invalid signature")]
|
#[error("Certificate authentication protocol extension: Invalid signature")]
|
||||||
CertAuthExtensionSigInvalid,
|
CertAuthExtensionSigInvalid,
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn maybe_wisp_v2<R>(
|
|
||||||
read: &mut R,
|
|
||||||
write: &LockedWebSocketWrite,
|
|
||||||
role: Role,
|
|
||||||
builders: &mut [AnyProtocolExtensionBuilder],
|
|
||||||
) -> Result<(Vec<AnyProtocolExtension>, Option<ws::Frame<'static>>, bool), WispError>
|
|
||||||
where
|
|
||||||
R: ws::WebSocketRead + Send,
|
|
||||||
{
|
|
||||||
let mut supported_extensions = Vec::new();
|
|
||||||
let mut extra_packet: Option<ws::Frame<'static>> = None;
|
|
||||||
let mut downgraded = true;
|
|
||||||
|
|
||||||
let extension_ids: Vec<_> = builders.iter().map(|x| x.get_id()).collect();
|
|
||||||
if let Some(frame) = select! {
|
|
||||||
x = read.wisp_read_frame(write).fuse() => Some(x?),
|
|
||||||
_ = Delay::new(Duration::from_secs(5)).fuse() => None
|
|
||||||
} {
|
|
||||||
let packet = Packet::maybe_parse_info(frame, role, builders)?;
|
|
||||||
if let PacketType::Info(info) = packet.packet_type {
|
|
||||||
supported_extensions = info
|
|
||||||
.extensions
|
|
||||||
.into_iter()
|
|
||||||
.filter(|x| extension_ids.contains(&x.get_id()))
|
|
||||||
.collect();
|
|
||||||
downgraded = false;
|
|
||||||
} else {
|
|
||||||
extra_packet.replace(ws::Frame::from(packet).clone());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for extension in supported_extensions.iter_mut() {
|
|
||||||
extension.handle_handshake(read, write).await?;
|
|
||||||
}
|
|
||||||
Ok((supported_extensions, extra_packet, downgraded))
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn send_info_packet(
|
|
||||||
write: &LockedWebSocketWrite,
|
|
||||||
builders: &mut [AnyProtocolExtensionBuilder],
|
|
||||||
) -> Result<(), WispError> {
|
|
||||||
write
|
|
||||||
.write_frame(
|
|
||||||
Packet::new_info(
|
|
||||||
builders
|
|
||||||
.iter_mut()
|
|
||||||
.map(|x| x.build_to_extension(Role::Server))
|
|
||||||
.collect::<Result<Vec<_>, _>>()?,
|
|
||||||
)
|
|
||||||
.into(),
|
|
||||||
)
|
|
||||||
.await
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Wisp V2 handshake and protocol extension settings wrapper struct.
|
|
||||||
pub struct WispV2Extensions {
|
|
||||||
builders: Vec<AnyProtocolExtensionBuilder>,
|
|
||||||
closure: Box<
|
|
||||||
dyn Fn(
|
|
||||||
&mut [AnyProtocolExtensionBuilder],
|
|
||||||
) -> Pin<Box<dyn Future<Output = Result<(), WispError>> + Sync + Send>>
|
|
||||||
+ Send,
|
|
||||||
>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl WispV2Extensions {
|
|
||||||
/// Create a Wisp V2 settings struct with no middleware.
|
|
||||||
pub fn new(builders: Vec<AnyProtocolExtensionBuilder>) -> Self {
|
|
||||||
Self {
|
|
||||||
builders,
|
|
||||||
closure: Box::new(|_| Box::pin(async { Ok(()) })),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Create a Wisp V2 settings struct with some middleware.
|
|
||||||
pub fn new_with_middleware<C>(builders: Vec<AnyProtocolExtensionBuilder>, closure: C) -> Self
|
|
||||||
where
|
|
||||||
C: Fn(
|
|
||||||
&mut [AnyProtocolExtensionBuilder],
|
|
||||||
) -> Pin<Box<dyn Future<Output = Result<(), WispError>> + Sync + Send>>
|
|
||||||
+ Send
|
|
||||||
+ 'static,
|
|
||||||
{
|
|
||||||
Self {
|
|
||||||
builders,
|
|
||||||
closure: Box::new(closure),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Add a Wisp V2 extension builder to the settings struct.
|
|
||||||
pub fn add_extension(&mut self, extension: AnyProtocolExtensionBuilder) {
|
|
||||||
self.builders.push(extension);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Server-side multiplexor.
|
|
||||||
pub struct ServerMux {
|
|
||||||
/// Whether the connection was downgraded to Wisp v1.
|
|
||||||
///
|
|
||||||
/// If this variable is true you must assume no extensions are supported.
|
|
||||||
pub downgraded: bool,
|
|
||||||
/// Extensions that are supported by both sides.
|
|
||||||
pub supported_extensions: Vec<AnyProtocolExtension>,
|
|
||||||
actor_tx: mpsc::Sender<WsEvent>,
|
|
||||||
muxstream_recv: mpsc::Receiver<(ConnectPacket, MuxStream)>,
|
|
||||||
tx: ws::LockedWebSocketWrite,
|
|
||||||
actor_exited: Arc<AtomicBool>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl ServerMux {
|
|
||||||
/// Create a new server-side multiplexor.
|
|
||||||
///
|
|
||||||
/// If `wisp_v2` is None a Wisp v1 connection is created otherwise a Wisp v2 connection is created.
|
|
||||||
/// **It is not guaranteed that all extensions you specify are available.** You must manually check
|
|
||||||
/// if the extensions you need are available after the multiplexor has been created.
|
|
||||||
pub async fn create<R, W>(
|
|
||||||
mut rx: R,
|
|
||||||
tx: W,
|
|
||||||
buffer_size: u32,
|
|
||||||
wisp_v2: Option<WispV2Extensions>,
|
|
||||||
) -> Result<ServerMuxResult<impl Future<Output = Result<(), WispError>> + Send>, WispError>
|
|
||||||
where
|
|
||||||
R: ws::WebSocketRead + Send,
|
|
||||||
W: ws::WebSocketWrite + Send + 'static,
|
|
||||||
{
|
|
||||||
let tx = ws::LockedWebSocketWrite::new(Box::new(tx));
|
|
||||||
let ret_tx = tx.clone();
|
|
||||||
let ret = async {
|
|
||||||
tx.write_frame(Packet::new_continue(0, buffer_size).into())
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
let (supported_extensions, extra_packet, downgraded) = if let Some(WispV2Extensions {
|
|
||||||
mut builders,
|
|
||||||
closure,
|
|
||||||
}) = wisp_v2
|
|
||||||
{
|
|
||||||
send_info_packet(&tx, builders.deref_mut()).await?;
|
|
||||||
(closure)(builders.deref_mut()).await?;
|
|
||||||
maybe_wisp_v2(&mut rx, &tx, Role::Server, &mut builders).await?
|
|
||||||
} else {
|
|
||||||
(Vec::new(), None, true)
|
|
||||||
};
|
|
||||||
|
|
||||||
let (mux_result, muxstream_recv) = MuxInner::new_server(
|
|
||||||
AppendingWebSocketRead(extra_packet, rx),
|
|
||||||
tx.clone(),
|
|
||||||
supported_extensions.clone(),
|
|
||||||
buffer_size,
|
|
||||||
);
|
|
||||||
|
|
||||||
Ok(ServerMuxResult(
|
|
||||||
Self {
|
|
||||||
muxstream_recv,
|
|
||||||
actor_tx: mux_result.actor_tx,
|
|
||||||
downgraded,
|
|
||||||
supported_extensions,
|
|
||||||
tx,
|
|
||||||
actor_exited: mux_result.actor_exited,
|
|
||||||
},
|
|
||||||
mux_result.mux.into_future(),
|
|
||||||
))
|
|
||||||
}
|
|
||||||
.await;
|
|
||||||
|
|
||||||
match ret {
|
|
||||||
Ok(x) => Ok(x),
|
|
||||||
Err(x) => match x {
|
|
||||||
WispError::PasswordExtensionCredsInvalid => {
|
|
||||||
ret_tx
|
|
||||||
.write_frame(
|
|
||||||
Packet::new_close(0, CloseReason::ExtensionsPasswordAuthFailed).into(),
|
|
||||||
)
|
|
||||||
.await?;
|
|
||||||
ret_tx.close().await?;
|
|
||||||
Err(x)
|
|
||||||
}
|
|
||||||
WispError::CertAuthExtensionSigInvalid => {
|
|
||||||
ret_tx
|
|
||||||
.write_frame(
|
|
||||||
Packet::new_close(0, CloseReason::ExtensionsCertAuthFailed).into(),
|
|
||||||
)
|
|
||||||
.await?;
|
|
||||||
ret_tx.close().await?;
|
|
||||||
Err(x)
|
|
||||||
}
|
|
||||||
x => Err(x),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Wait for a stream to be created.
|
|
||||||
pub async fn server_new_stream(&self) -> Option<(ConnectPacket, MuxStream)> {
|
|
||||||
if self.actor_exited.load(Ordering::Acquire) {
|
|
||||||
return None;
|
|
||||||
}
|
|
||||||
self.muxstream_recv.recv_async().await.ok()
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Send a ping to the client.
|
|
||||||
pub async fn send_ping(&self, payload: Payload<'static>) -> Result<(), WispError> {
|
|
||||||
if self.actor_exited.load(Ordering::Acquire) {
|
|
||||||
return Err(WispError::MuxTaskEnded);
|
|
||||||
}
|
|
||||||
let (tx, rx) = oneshot::channel();
|
|
||||||
self.actor_tx
|
|
||||||
.send_async(WsEvent::SendPing(payload, tx))
|
|
||||||
.await
|
|
||||||
.map_err(|_| WispError::MuxMessageFailedToSend)?;
|
|
||||||
rx.await.map_err(|_| WispError::MuxMessageFailedToRecv)?
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn close_internal(&self, reason: Option<CloseReason>) -> Result<(), WispError> {
|
|
||||||
if self.actor_exited.load(Ordering::Acquire) {
|
|
||||||
return Err(WispError::MuxTaskEnded);
|
|
||||||
}
|
|
||||||
self.actor_tx
|
|
||||||
.send_async(WsEvent::EndFut(reason))
|
|
||||||
.await
|
|
||||||
.map_err(|_| WispError::MuxMessageFailedToSend)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Close all streams.
|
|
||||||
///
|
|
||||||
/// Also terminates the multiplexor future.
|
|
||||||
pub async fn close(&self) -> Result<(), WispError> {
|
|
||||||
self.close_internal(None).await
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Close all streams and send a close reason on stream ID 0.
|
|
||||||
///
|
|
||||||
/// Also terminates the multiplexor future.
|
|
||||||
pub async fn close_with_reason(&self, reason: CloseReason) -> Result<(), WispError> {
|
|
||||||
self.close_internal(Some(reason)).await
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Get a protocol extension stream for sending packets with stream id 0.
|
|
||||||
pub fn get_protocol_extension_stream(&self) -> MuxProtocolExtensionStream {
|
|
||||||
MuxProtocolExtensionStream {
|
|
||||||
stream_id: 0,
|
|
||||||
tx: self.tx.clone(),
|
|
||||||
is_closed: self.actor_exited.clone(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Drop for ServerMux {
|
|
||||||
fn drop(&mut self) {
|
|
||||||
let _ = self.actor_tx.send(WsEvent::EndFut(None));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Result of `ServerMux::new`.
|
|
||||||
pub struct ServerMuxResult<F>(ServerMux, F)
|
|
||||||
where
|
|
||||||
F: Future<Output = Result<(), WispError>> + Send;
|
|
||||||
|
|
||||||
impl<F> ServerMuxResult<F>
|
|
||||||
where
|
|
||||||
F: Future<Output = Result<(), WispError>> + Send,
|
|
||||||
{
|
|
||||||
/// Require no protocol extensions.
|
|
||||||
pub fn with_no_required_extensions(self) -> (ServerMux, F) {
|
|
||||||
(self.0, self.1)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Require protocol extensions by their ID. Will close the multiplexor connection if
|
|
||||||
/// extensions are not supported.
|
|
||||||
pub async fn with_required_extensions(
|
|
||||||
self,
|
|
||||||
extensions: &[u8],
|
|
||||||
) -> Result<(ServerMux, F), WispError> {
|
|
||||||
let mut unsupported_extensions = Vec::new();
|
|
||||||
for extension in extensions {
|
|
||||||
if !self
|
|
||||||
.0
|
|
||||||
.supported_extensions
|
|
||||||
.iter()
|
|
||||||
.any(|x| x.get_id() == *extension)
|
|
||||||
{
|
|
||||||
unsupported_extensions.push(*extension);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if unsupported_extensions.is_empty() {
|
|
||||||
Ok((self.0, self.1))
|
|
||||||
} else {
|
|
||||||
self.0
|
|
||||||
.close_with_reason(CloseReason::ExtensionsIncompatible)
|
|
||||||
.await?;
|
|
||||||
self.1.await?;
|
|
||||||
Err(WispError::ExtensionsNotSupported(unsupported_extensions))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Shorthand for `with_required_extensions(&[UdpProtocolExtension::ID])`
|
|
||||||
pub async fn with_udp_extension_required(self) -> Result<(ServerMux, F), WispError> {
|
|
||||||
self.with_required_extensions(&[UdpProtocolExtension::ID])
|
|
||||||
.await
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Client side multiplexor.
|
|
||||||
pub struct ClientMux {
|
|
||||||
/// Whether the connection was downgraded to Wisp v1.
|
|
||||||
///
|
|
||||||
/// If this variable is true you must assume no extensions are supported.
|
|
||||||
pub downgraded: bool,
|
|
||||||
/// Extensions that are supported by both sides.
|
|
||||||
pub supported_extensions: Vec<AnyProtocolExtension>,
|
|
||||||
actor_tx: mpsc::Sender<WsEvent>,
|
|
||||||
tx: ws::LockedWebSocketWrite,
|
|
||||||
actor_exited: Arc<AtomicBool>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl ClientMux {
|
|
||||||
/// Create a new client side multiplexor.
|
|
||||||
///
|
|
||||||
/// If `wisp_v2` is None a Wisp v1 connection is created otherwise a Wisp v2 connection is created.
|
|
||||||
/// **It is not guaranteed that all extensions you specify are available.** You must manually check
|
|
||||||
/// if the extensions you need are available after the multiplexor has been created.
|
|
||||||
pub async fn create<R, W>(
|
|
||||||
mut rx: R,
|
|
||||||
tx: W,
|
|
||||||
wisp_v2: Option<WispV2Extensions>,
|
|
||||||
) -> Result<ClientMuxResult<impl Future<Output = Result<(), WispError>> + Send>, WispError>
|
|
||||||
where
|
|
||||||
R: ws::WebSocketRead + Send,
|
|
||||||
W: ws::WebSocketWrite + Send + 'static,
|
|
||||||
{
|
|
||||||
let tx = ws::LockedWebSocketWrite::new(Box::new(tx));
|
|
||||||
let first_packet = Packet::try_from(rx.wisp_read_frame(&tx).await?)?;
|
|
||||||
|
|
||||||
if first_packet.stream_id != 0 {
|
|
||||||
return Err(WispError::InvalidStreamId);
|
|
||||||
}
|
|
||||||
|
|
||||||
if let PacketType::Continue(packet) = first_packet.packet_type {
|
|
||||||
let (supported_extensions, extra_packet, downgraded) = if let Some(WispV2Extensions {
|
|
||||||
mut builders,
|
|
||||||
closure,
|
|
||||||
}) = wisp_v2
|
|
||||||
{
|
|
||||||
let res = maybe_wisp_v2(&mut rx, &tx, Role::Client, &mut builders).await?;
|
|
||||||
// if not downgraded
|
|
||||||
if !res.2 {
|
|
||||||
(closure)(&mut builders).await?;
|
|
||||||
send_info_packet(&tx, &mut builders).await?;
|
|
||||||
}
|
|
||||||
res
|
|
||||||
} else {
|
|
||||||
(Vec::new(), None, true)
|
|
||||||
};
|
|
||||||
|
|
||||||
let mux_result = MuxInner::new_client(
|
|
||||||
AppendingWebSocketRead(extra_packet, rx),
|
|
||||||
tx.clone(),
|
|
||||||
supported_extensions.clone(),
|
|
||||||
packet.buffer_remaining,
|
|
||||||
);
|
|
||||||
|
|
||||||
Ok(ClientMuxResult(
|
|
||||||
Self {
|
|
||||||
actor_tx: mux_result.actor_tx,
|
|
||||||
downgraded,
|
|
||||||
supported_extensions,
|
|
||||||
tx,
|
|
||||||
actor_exited: mux_result.actor_exited,
|
|
||||||
},
|
|
||||||
mux_result.mux.into_future(),
|
|
||||||
))
|
|
||||||
} else {
|
|
||||||
Err(WispError::InvalidPacketType)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Create a new stream, multiplexed through Wisp.
|
|
||||||
pub async fn client_new_stream(
|
|
||||||
&self,
|
|
||||||
stream_type: StreamType,
|
|
||||||
host: String,
|
|
||||||
port: u16,
|
|
||||||
) -> Result<MuxStream, WispError> {
|
|
||||||
if self.actor_exited.load(Ordering::Acquire) {
|
|
||||||
return Err(WispError::MuxTaskEnded);
|
|
||||||
}
|
|
||||||
if stream_type == StreamType::Udp
|
|
||||||
&& !self
|
|
||||||
.supported_extensions
|
|
||||||
.iter()
|
|
||||||
.any(|x| x.get_id() == UdpProtocolExtension::ID)
|
|
||||||
{
|
|
||||||
return Err(WispError::ExtensionsNotSupported(vec![
|
|
||||||
UdpProtocolExtension::ID,
|
|
||||||
]));
|
|
||||||
}
|
|
||||||
let (tx, rx) = oneshot::channel();
|
|
||||||
self.actor_tx
|
|
||||||
.send_async(WsEvent::CreateStream(stream_type, host, port, tx))
|
|
||||||
.await
|
|
||||||
.map_err(|_| WispError::MuxMessageFailedToSend)?;
|
|
||||||
rx.await.map_err(|_| WispError::MuxMessageFailedToRecv)?
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Send a ping to the server.
|
|
||||||
pub async fn send_ping(&self, payload: Payload<'static>) -> Result<(), WispError> {
|
|
||||||
if self.actor_exited.load(Ordering::Acquire) {
|
|
||||||
return Err(WispError::MuxTaskEnded);
|
|
||||||
}
|
|
||||||
let (tx, rx) = oneshot::channel();
|
|
||||||
self.actor_tx
|
|
||||||
.send_async(WsEvent::SendPing(payload, tx))
|
|
||||||
.await
|
|
||||||
.map_err(|_| WispError::MuxMessageFailedToSend)?;
|
|
||||||
rx.await.map_err(|_| WispError::MuxMessageFailedToRecv)?
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn close_internal(&self, reason: Option<CloseReason>) -> Result<(), WispError> {
|
|
||||||
if self.actor_exited.load(Ordering::Acquire) {
|
|
||||||
return Err(WispError::MuxTaskEnded);
|
|
||||||
}
|
|
||||||
self.actor_tx
|
|
||||||
.send_async(WsEvent::EndFut(reason))
|
|
||||||
.await
|
|
||||||
.map_err(|_| WispError::MuxMessageFailedToSend)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Close all streams.
|
|
||||||
///
|
|
||||||
/// Also terminates the multiplexor future.
|
|
||||||
pub async fn close(&self) -> Result<(), WispError> {
|
|
||||||
self.close_internal(None).await
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Close all streams and send a close reason on stream ID 0.
|
|
||||||
///
|
|
||||||
/// Also terminates the multiplexor future.
|
|
||||||
pub async fn close_with_reason(&self, reason: CloseReason) -> Result<(), WispError> {
|
|
||||||
self.close_internal(Some(reason)).await
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Get a protocol extension stream for sending packets with stream id 0.
|
|
||||||
pub fn get_protocol_extension_stream(&self) -> MuxProtocolExtensionStream {
|
|
||||||
MuxProtocolExtensionStream {
|
|
||||||
stream_id: 0,
|
|
||||||
tx: self.tx.clone(),
|
|
||||||
is_closed: self.actor_exited.clone(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Drop for ClientMux {
|
|
||||||
fn drop(&mut self) {
|
|
||||||
let _ = self.actor_tx.send(WsEvent::EndFut(None));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Result of `ClientMux::new`.
|
|
||||||
pub struct ClientMuxResult<F>(ClientMux, F)
|
|
||||||
where
|
|
||||||
F: Future<Output = Result<(), WispError>> + Send;
|
|
||||||
|
|
||||||
impl<F> ClientMuxResult<F>
|
|
||||||
where
|
|
||||||
F: Future<Output = Result<(), WispError>> + Send,
|
|
||||||
{
|
|
||||||
/// Require no protocol extensions.
|
|
||||||
pub fn with_no_required_extensions(self) -> (ClientMux, F) {
|
|
||||||
(self.0, self.1)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Require protocol extensions by their ID.
|
|
||||||
pub async fn with_required_extensions(
|
|
||||||
self,
|
|
||||||
extensions: &[u8],
|
|
||||||
) -> Result<(ClientMux, F), WispError> {
|
|
||||||
let mut unsupported_extensions = Vec::new();
|
|
||||||
for extension in extensions {
|
|
||||||
if !self
|
|
||||||
.0
|
|
||||||
.supported_extensions
|
|
||||||
.iter()
|
|
||||||
.any(|x| x.get_id() == *extension)
|
|
||||||
{
|
|
||||||
unsupported_extensions.push(*extension);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if unsupported_extensions.is_empty() {
|
|
||||||
Ok((self.0, self.1))
|
|
||||||
} else {
|
|
||||||
self.0
|
|
||||||
.close_with_reason(CloseReason::ExtensionsIncompatible)
|
|
||||||
.await?;
|
|
||||||
self.1.await?;
|
|
||||||
Err(WispError::ExtensionsNotSupported(unsupported_extensions))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Shorthand for `with_required_extensions(&[UdpProtocolExtension::ID])`
|
|
||||||
pub async fn with_udp_extension_required(self) -> Result<(ClientMux, F), WispError> {
|
|
||||||
self.with_required_extensions(&[UdpProtocolExtension::ID])
|
|
||||||
.await
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
223
wisp/src/mux/client.rs
Normal file
223
wisp/src/mux/client.rs
Normal file
|
@ -0,0 +1,223 @@
|
||||||
|
use std::{
|
||||||
|
future::Future,
|
||||||
|
sync::{
|
||||||
|
atomic::{AtomicBool, Ordering},
|
||||||
|
Arc,
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
use flume as mpsc;
|
||||||
|
use futures::channel::oneshot;
|
||||||
|
|
||||||
|
use crate::{
|
||||||
|
extensions::{udp::UdpProtocolExtension, AnyProtocolExtension},
|
||||||
|
inner::{MuxInner, WsEvent},
|
||||||
|
ws::{AppendingWebSocketRead, LockedWebSocketWrite, WebSocketRead, WebSocketWrite, Payload},
|
||||||
|
CloseReason, MuxProtocolExtensionStream, MuxStream, Packet, PacketType, Role, StreamType,
|
||||||
|
WispError,
|
||||||
|
};
|
||||||
|
|
||||||
|
use super::{maybe_wisp_v2, send_info_packet, WispV2Extensions};
|
||||||
|
|
||||||
|
/// Client side multiplexor.
|
||||||
|
pub struct ClientMux {
|
||||||
|
/// Whether the connection was downgraded to Wisp v1.
|
||||||
|
///
|
||||||
|
/// If this variable is true you must assume no extensions are supported.
|
||||||
|
pub downgraded: bool,
|
||||||
|
/// Extensions that are supported by both sides.
|
||||||
|
pub supported_extensions: Vec<AnyProtocolExtension>,
|
||||||
|
actor_tx: mpsc::Sender<WsEvent>,
|
||||||
|
tx: LockedWebSocketWrite,
|
||||||
|
actor_exited: Arc<AtomicBool>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ClientMux {
|
||||||
|
/// Create a new client side multiplexor.
|
||||||
|
///
|
||||||
|
/// If `wisp_v2` is None a Wisp v1 connection is created otherwise a Wisp v2 connection is created.
|
||||||
|
/// **It is not guaranteed that all extensions you specify are available.** You must manually check
|
||||||
|
/// if the extensions you need are available after the multiplexor has been created.
|
||||||
|
pub async fn create<R, W>(
|
||||||
|
mut rx: R,
|
||||||
|
tx: W,
|
||||||
|
wisp_v2: Option<WispV2Extensions>,
|
||||||
|
) -> Result<ClientMuxResult<impl Future<Output = Result<(), WispError>> + Send>, WispError>
|
||||||
|
where
|
||||||
|
R: WebSocketRead + Send,
|
||||||
|
W: WebSocketWrite + Send + 'static,
|
||||||
|
{
|
||||||
|
let tx = LockedWebSocketWrite::new(Box::new(tx));
|
||||||
|
let first_packet = Packet::try_from(rx.wisp_read_frame(&tx).await?)?;
|
||||||
|
|
||||||
|
if first_packet.stream_id != 0 {
|
||||||
|
return Err(WispError::InvalidStreamId);
|
||||||
|
}
|
||||||
|
|
||||||
|
if let PacketType::Continue(packet) = first_packet.packet_type {
|
||||||
|
let (supported_extensions, extra_packet, downgraded) = if let Some(WispV2Extensions {
|
||||||
|
mut builders,
|
||||||
|
closure,
|
||||||
|
}) = wisp_v2
|
||||||
|
{
|
||||||
|
let res = maybe_wisp_v2(&mut rx, &tx, Role::Client, &mut builders).await?;
|
||||||
|
// if not downgraded
|
||||||
|
if !res.2 {
|
||||||
|
(closure)(&mut builders).await?;
|
||||||
|
send_info_packet(&tx, &mut builders).await?;
|
||||||
|
}
|
||||||
|
res
|
||||||
|
} else {
|
||||||
|
(Vec::new(), None, true)
|
||||||
|
};
|
||||||
|
|
||||||
|
let mux_result = MuxInner::new_client(
|
||||||
|
AppendingWebSocketRead(extra_packet, rx),
|
||||||
|
tx.clone(),
|
||||||
|
supported_extensions.clone(),
|
||||||
|
packet.buffer_remaining,
|
||||||
|
);
|
||||||
|
|
||||||
|
Ok(ClientMuxResult(
|
||||||
|
Self {
|
||||||
|
actor_tx: mux_result.actor_tx,
|
||||||
|
downgraded,
|
||||||
|
supported_extensions,
|
||||||
|
tx,
|
||||||
|
actor_exited: mux_result.actor_exited,
|
||||||
|
},
|
||||||
|
mux_result.mux.into_future(),
|
||||||
|
))
|
||||||
|
} else {
|
||||||
|
Err(WispError::InvalidPacketType)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create a new stream, multiplexed through Wisp.
|
||||||
|
pub async fn client_new_stream(
|
||||||
|
&self,
|
||||||
|
stream_type: StreamType,
|
||||||
|
host: String,
|
||||||
|
port: u16,
|
||||||
|
) -> Result<MuxStream, WispError> {
|
||||||
|
if self.actor_exited.load(Ordering::Acquire) {
|
||||||
|
return Err(WispError::MuxTaskEnded);
|
||||||
|
}
|
||||||
|
if stream_type == StreamType::Udp
|
||||||
|
&& !self
|
||||||
|
.supported_extensions
|
||||||
|
.iter()
|
||||||
|
.any(|x| x.get_id() == UdpProtocolExtension::ID)
|
||||||
|
{
|
||||||
|
return Err(WispError::ExtensionsNotSupported(vec![
|
||||||
|
UdpProtocolExtension::ID,
|
||||||
|
]));
|
||||||
|
}
|
||||||
|
let (tx, rx) = oneshot::channel();
|
||||||
|
self.actor_tx
|
||||||
|
.send_async(WsEvent::CreateStream(stream_type, host, port, tx))
|
||||||
|
.await
|
||||||
|
.map_err(|_| WispError::MuxMessageFailedToSend)?;
|
||||||
|
rx.await.map_err(|_| WispError::MuxMessageFailedToRecv)?
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Send a ping to the server.
|
||||||
|
pub async fn send_ping(&self, payload: Payload<'static>) -> Result<(), WispError> {
|
||||||
|
if self.actor_exited.load(Ordering::Acquire) {
|
||||||
|
return Err(WispError::MuxTaskEnded);
|
||||||
|
}
|
||||||
|
let (tx, rx) = oneshot::channel();
|
||||||
|
self.actor_tx
|
||||||
|
.send_async(WsEvent::SendPing(payload, tx))
|
||||||
|
.await
|
||||||
|
.map_err(|_| WispError::MuxMessageFailedToSend)?;
|
||||||
|
rx.await.map_err(|_| WispError::MuxMessageFailedToRecv)?
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn close_internal(&self, reason: Option<CloseReason>) -> Result<(), WispError> {
|
||||||
|
if self.actor_exited.load(Ordering::Acquire) {
|
||||||
|
return Err(WispError::MuxTaskEnded);
|
||||||
|
}
|
||||||
|
self.actor_tx
|
||||||
|
.send_async(WsEvent::EndFut(reason))
|
||||||
|
.await
|
||||||
|
.map_err(|_| WispError::MuxMessageFailedToSend)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Close all streams.
|
||||||
|
///
|
||||||
|
/// Also terminates the multiplexor future.
|
||||||
|
pub async fn close(&self) -> Result<(), WispError> {
|
||||||
|
self.close_internal(None).await
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Close all streams and send a close reason on stream ID 0.
|
||||||
|
///
|
||||||
|
/// Also terminates the multiplexor future.
|
||||||
|
pub async fn close_with_reason(&self, reason: CloseReason) -> Result<(), WispError> {
|
||||||
|
self.close_internal(Some(reason)).await
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get a protocol extension stream for sending packets with stream id 0.
|
||||||
|
pub fn get_protocol_extension_stream(&self) -> MuxProtocolExtensionStream {
|
||||||
|
MuxProtocolExtensionStream {
|
||||||
|
stream_id: 0,
|
||||||
|
tx: self.tx.clone(),
|
||||||
|
is_closed: self.actor_exited.clone(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Drop for ClientMux {
|
||||||
|
fn drop(&mut self) {
|
||||||
|
let _ = self.actor_tx.send(WsEvent::EndFut(None));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Result of `ClientMux::new`.
|
||||||
|
pub struct ClientMuxResult<F>(ClientMux, F)
|
||||||
|
where
|
||||||
|
F: Future<Output = Result<(), WispError>> + Send;
|
||||||
|
|
||||||
|
impl<F> ClientMuxResult<F>
|
||||||
|
where
|
||||||
|
F: Future<Output = Result<(), WispError>> + Send,
|
||||||
|
{
|
||||||
|
/// Require no protocol extensions.
|
||||||
|
pub fn with_no_required_extensions(self) -> (ClientMux, F) {
|
||||||
|
(self.0, self.1)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Require protocol extensions by their ID.
|
||||||
|
pub async fn with_required_extensions(
|
||||||
|
self,
|
||||||
|
extensions: &[u8],
|
||||||
|
) -> Result<(ClientMux, F), WispError> {
|
||||||
|
let mut unsupported_extensions = Vec::new();
|
||||||
|
for extension in extensions {
|
||||||
|
if !self
|
||||||
|
.0
|
||||||
|
.supported_extensions
|
||||||
|
.iter()
|
||||||
|
.any(|x| x.get_id() == *extension)
|
||||||
|
{
|
||||||
|
unsupported_extensions.push(*extension);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if unsupported_extensions.is_empty() {
|
||||||
|
Ok((self.0, self.1))
|
||||||
|
} else {
|
||||||
|
self.0
|
||||||
|
.close_with_reason(CloseReason::ExtensionsIncompatible)
|
||||||
|
.await?;
|
||||||
|
self.1.await?;
|
||||||
|
Err(WispError::ExtensionsNotSupported(unsupported_extensions))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Shorthand for `with_required_extensions(&[UdpProtocolExtension::ID])`
|
||||||
|
pub async fn with_udp_extension_required(self) -> Result<(ClientMux, F), WispError> {
|
||||||
|
self.with_required_extensions(&[UdpProtocolExtension::ID])
|
||||||
|
.await
|
||||||
|
}
|
||||||
|
}
|
109
wisp/src/mux/mod.rs
Normal file
109
wisp/src/mux/mod.rs
Normal file
|
@ -0,0 +1,109 @@
|
||||||
|
mod client;
|
||||||
|
mod server;
|
||||||
|
use std::{future::Future, pin::Pin, time::Duration};
|
||||||
|
|
||||||
|
pub use client::ClientMux;
|
||||||
|
use futures::{select, FutureExt};
|
||||||
|
use futures_timer::Delay;
|
||||||
|
pub use server::{ServerMux, ServerMuxResult};
|
||||||
|
|
||||||
|
use crate::{
|
||||||
|
extensions::{AnyProtocolExtension, AnyProtocolExtensionBuilder},
|
||||||
|
ws::{Frame, LockedWebSocketWrite, WebSocketRead},
|
||||||
|
Packet, PacketType, Role, WispError,
|
||||||
|
};
|
||||||
|
|
||||||
|
async fn maybe_wisp_v2<R>(
|
||||||
|
read: &mut R,
|
||||||
|
write: &LockedWebSocketWrite,
|
||||||
|
role: Role,
|
||||||
|
builders: &mut [AnyProtocolExtensionBuilder],
|
||||||
|
) -> Result<(Vec<AnyProtocolExtension>, Option<Frame<'static>>, bool), WispError>
|
||||||
|
where
|
||||||
|
R: WebSocketRead + Send,
|
||||||
|
{
|
||||||
|
let mut supported_extensions = Vec::new();
|
||||||
|
let mut extra_packet: Option<Frame<'static>> = None;
|
||||||
|
let mut downgraded = true;
|
||||||
|
|
||||||
|
let extension_ids: Vec<_> = builders.iter().map(|x| x.get_id()).collect();
|
||||||
|
if let Some(frame) = select! {
|
||||||
|
x = read.wisp_read_frame(write).fuse() => Some(x?),
|
||||||
|
_ = Delay::new(Duration::from_secs(5)).fuse() => None
|
||||||
|
} {
|
||||||
|
let packet = Packet::maybe_parse_info(frame, role, builders)?;
|
||||||
|
if let PacketType::Info(info) = packet.packet_type {
|
||||||
|
supported_extensions = info
|
||||||
|
.extensions
|
||||||
|
.into_iter()
|
||||||
|
.filter(|x| extension_ids.contains(&x.get_id()))
|
||||||
|
.collect();
|
||||||
|
downgraded = false;
|
||||||
|
} else {
|
||||||
|
extra_packet.replace(Frame::from(packet).clone());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for extension in supported_extensions.iter_mut() {
|
||||||
|
extension.handle_handshake(read, write).await?;
|
||||||
|
}
|
||||||
|
Ok((supported_extensions, extra_packet, downgraded))
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn send_info_packet(
|
||||||
|
write: &LockedWebSocketWrite,
|
||||||
|
builders: &mut [AnyProtocolExtensionBuilder],
|
||||||
|
) -> Result<(), WispError> {
|
||||||
|
write
|
||||||
|
.write_frame(
|
||||||
|
Packet::new_info(
|
||||||
|
builders
|
||||||
|
.iter_mut()
|
||||||
|
.map(|x| x.build_to_extension(Role::Server))
|
||||||
|
.collect::<Result<Vec<_>, _>>()?,
|
||||||
|
)
|
||||||
|
.into(),
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Wisp V2 handshake and protocol extension settings wrapper struct.
|
||||||
|
pub struct WispV2Extensions {
|
||||||
|
builders: Vec<AnyProtocolExtensionBuilder>,
|
||||||
|
closure: Box<
|
||||||
|
dyn Fn(
|
||||||
|
&mut [AnyProtocolExtensionBuilder],
|
||||||
|
) -> Pin<Box<dyn Future<Output = Result<(), WispError>> + Sync + Send>>
|
||||||
|
+ Send,
|
||||||
|
>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl WispV2Extensions {
|
||||||
|
/// Create a Wisp V2 settings struct with no middleware.
|
||||||
|
pub fn new(builders: Vec<AnyProtocolExtensionBuilder>) -> Self {
|
||||||
|
Self {
|
||||||
|
builders,
|
||||||
|
closure: Box::new(|_| Box::pin(async { Ok(()) })),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create a Wisp V2 settings struct with some middleware.
|
||||||
|
pub fn new_with_middleware<C>(builders: Vec<AnyProtocolExtensionBuilder>, closure: C) -> Self
|
||||||
|
where
|
||||||
|
C: Fn(
|
||||||
|
&mut [AnyProtocolExtensionBuilder],
|
||||||
|
) -> Pin<Box<dyn Future<Output = Result<(), WispError>> + Sync + Send>>
|
||||||
|
+ Send
|
||||||
|
+ 'static,
|
||||||
|
{
|
||||||
|
Self {
|
||||||
|
builders,
|
||||||
|
closure: Box::new(closure),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Add a Wisp V2 extension builder to the settings struct.
|
||||||
|
pub fn add_extension(&mut self, extension: AnyProtocolExtensionBuilder) {
|
||||||
|
self.builders.push(extension);
|
||||||
|
}
|
||||||
|
}
|
225
wisp/src/mux/server.rs
Normal file
225
wisp/src/mux/server.rs
Normal file
|
@ -0,0 +1,225 @@
|
||||||
|
use std::{
|
||||||
|
future::Future,
|
||||||
|
ops::DerefMut,
|
||||||
|
sync::{
|
||||||
|
atomic::{AtomicBool, Ordering},
|
||||||
|
Arc,
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
use flume as mpsc;
|
||||||
|
use futures::channel::oneshot;
|
||||||
|
|
||||||
|
use crate::{
|
||||||
|
extensions::{udp::UdpProtocolExtension, AnyProtocolExtension},
|
||||||
|
inner::{MuxInner, WsEvent},
|
||||||
|
ws::{AppendingWebSocketRead, LockedWebSocketWrite, Payload, WebSocketRead, WebSocketWrite},
|
||||||
|
CloseReason, ConnectPacket, MuxProtocolExtensionStream, MuxStream, Packet, Role, WispError,
|
||||||
|
};
|
||||||
|
|
||||||
|
use super::{maybe_wisp_v2, send_info_packet, WispV2Extensions};
|
||||||
|
|
||||||
|
/// Server-side multiplexor.
|
||||||
|
pub struct ServerMux {
|
||||||
|
/// Whether the connection was downgraded to Wisp v1.
|
||||||
|
///
|
||||||
|
/// If this variable is true you must assume no extensions are supported.
|
||||||
|
pub downgraded: bool,
|
||||||
|
/// Extensions that are supported by both sides.
|
||||||
|
pub supported_extensions: Vec<AnyProtocolExtension>,
|
||||||
|
actor_tx: mpsc::Sender<WsEvent>,
|
||||||
|
muxstream_recv: mpsc::Receiver<(ConnectPacket, MuxStream)>,
|
||||||
|
tx: LockedWebSocketWrite,
|
||||||
|
actor_exited: Arc<AtomicBool>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ServerMux {
|
||||||
|
/// Create a new server-side multiplexor.
|
||||||
|
///
|
||||||
|
/// If `wisp_v2` is None a Wisp v1 connection is created otherwise a Wisp v2 connection is created.
|
||||||
|
/// **It is not guaranteed that all extensions you specify are available.** You must manually check
|
||||||
|
/// if the extensions you need are available after the multiplexor has been created.
|
||||||
|
pub async fn create<R, W>(
|
||||||
|
mut rx: R,
|
||||||
|
tx: W,
|
||||||
|
buffer_size: u32,
|
||||||
|
wisp_v2: Option<WispV2Extensions>,
|
||||||
|
) -> Result<ServerMuxResult<impl Future<Output = Result<(), WispError>> + Send>, WispError>
|
||||||
|
where
|
||||||
|
R: WebSocketRead + Send,
|
||||||
|
W: WebSocketWrite + Send + 'static,
|
||||||
|
{
|
||||||
|
let tx = LockedWebSocketWrite::new(Box::new(tx));
|
||||||
|
let ret_tx = tx.clone();
|
||||||
|
let ret = async {
|
||||||
|
tx.write_frame(Packet::new_continue(0, buffer_size).into())
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
let (supported_extensions, extra_packet, downgraded) = if let Some(WispV2Extensions {
|
||||||
|
mut builders,
|
||||||
|
closure,
|
||||||
|
}) = wisp_v2
|
||||||
|
{
|
||||||
|
send_info_packet(&tx, builders.deref_mut()).await?;
|
||||||
|
(closure)(builders.deref_mut()).await?;
|
||||||
|
maybe_wisp_v2(&mut rx, &tx, Role::Server, &mut builders).await?
|
||||||
|
} else {
|
||||||
|
(Vec::new(), None, true)
|
||||||
|
};
|
||||||
|
|
||||||
|
let (mux_result, muxstream_recv) = MuxInner::new_server(
|
||||||
|
AppendingWebSocketRead(extra_packet, rx),
|
||||||
|
tx.clone(),
|
||||||
|
supported_extensions.clone(),
|
||||||
|
buffer_size,
|
||||||
|
);
|
||||||
|
|
||||||
|
Ok(ServerMuxResult(
|
||||||
|
Self {
|
||||||
|
muxstream_recv,
|
||||||
|
actor_tx: mux_result.actor_tx,
|
||||||
|
downgraded,
|
||||||
|
supported_extensions,
|
||||||
|
tx,
|
||||||
|
actor_exited: mux_result.actor_exited,
|
||||||
|
},
|
||||||
|
mux_result.mux.into_future(),
|
||||||
|
))
|
||||||
|
}
|
||||||
|
.await;
|
||||||
|
|
||||||
|
match ret {
|
||||||
|
Ok(x) => Ok(x),
|
||||||
|
Err(x) => match x {
|
||||||
|
WispError::PasswordExtensionCredsInvalid => {
|
||||||
|
ret_tx
|
||||||
|
.write_frame(
|
||||||
|
Packet::new_close(0, CloseReason::ExtensionsPasswordAuthFailed).into(),
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
ret_tx.close().await?;
|
||||||
|
Err(x)
|
||||||
|
}
|
||||||
|
WispError::CertAuthExtensionSigInvalid => {
|
||||||
|
ret_tx
|
||||||
|
.write_frame(
|
||||||
|
Packet::new_close(0, CloseReason::ExtensionsCertAuthFailed).into(),
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
ret_tx.close().await?;
|
||||||
|
Err(x)
|
||||||
|
}
|
||||||
|
x => Err(x),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Wait for a stream to be created.
|
||||||
|
pub async fn server_new_stream(&self) -> Option<(ConnectPacket, MuxStream)> {
|
||||||
|
if self.actor_exited.load(Ordering::Acquire) {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
self.muxstream_recv.recv_async().await.ok()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Send a ping to the client.
|
||||||
|
pub async fn send_ping(&self, payload: Payload<'static>) -> Result<(), WispError> {
|
||||||
|
if self.actor_exited.load(Ordering::Acquire) {
|
||||||
|
return Err(WispError::MuxTaskEnded);
|
||||||
|
}
|
||||||
|
let (tx, rx) = oneshot::channel();
|
||||||
|
self.actor_tx
|
||||||
|
.send_async(WsEvent::SendPing(payload, tx))
|
||||||
|
.await
|
||||||
|
.map_err(|_| WispError::MuxMessageFailedToSend)?;
|
||||||
|
rx.await.map_err(|_| WispError::MuxMessageFailedToRecv)?
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn close_internal(&self, reason: Option<CloseReason>) -> Result<(), WispError> {
|
||||||
|
if self.actor_exited.load(Ordering::Acquire) {
|
||||||
|
return Err(WispError::MuxTaskEnded);
|
||||||
|
}
|
||||||
|
self.actor_tx
|
||||||
|
.send_async(WsEvent::EndFut(reason))
|
||||||
|
.await
|
||||||
|
.map_err(|_| WispError::MuxMessageFailedToSend)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Close all streams.
|
||||||
|
///
|
||||||
|
/// Also terminates the multiplexor future.
|
||||||
|
pub async fn close(&self) -> Result<(), WispError> {
|
||||||
|
self.close_internal(None).await
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Close all streams and send a close reason on stream ID 0.
|
||||||
|
///
|
||||||
|
/// Also terminates the multiplexor future.
|
||||||
|
pub async fn close_with_reason(&self, reason: CloseReason) -> Result<(), WispError> {
|
||||||
|
self.close_internal(Some(reason)).await
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get a protocol extension stream for sending packets with stream id 0.
|
||||||
|
pub fn get_protocol_extension_stream(&self) -> MuxProtocolExtensionStream {
|
||||||
|
MuxProtocolExtensionStream {
|
||||||
|
stream_id: 0,
|
||||||
|
tx: self.tx.clone(),
|
||||||
|
is_closed: self.actor_exited.clone(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Drop for ServerMux {
|
||||||
|
fn drop(&mut self) {
|
||||||
|
let _ = self.actor_tx.send(WsEvent::EndFut(None));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Result of `ServerMux::new`.
|
||||||
|
pub struct ServerMuxResult<F>(ServerMux, F)
|
||||||
|
where
|
||||||
|
F: Future<Output = Result<(), WispError>> + Send;
|
||||||
|
|
||||||
|
impl<F> ServerMuxResult<F>
|
||||||
|
where
|
||||||
|
F: Future<Output = Result<(), WispError>> + Send,
|
||||||
|
{
|
||||||
|
/// Require no protocol extensions.
|
||||||
|
pub fn with_no_required_extensions(self) -> (ServerMux, F) {
|
||||||
|
(self.0, self.1)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Require protocol extensions by their ID. Will close the multiplexor connection if
|
||||||
|
/// extensions are not supported.
|
||||||
|
pub async fn with_required_extensions(
|
||||||
|
self,
|
||||||
|
extensions: &[u8],
|
||||||
|
) -> Result<(ServerMux, F), WispError> {
|
||||||
|
let mut unsupported_extensions = Vec::new();
|
||||||
|
for extension in extensions {
|
||||||
|
if !self
|
||||||
|
.0
|
||||||
|
.supported_extensions
|
||||||
|
.iter()
|
||||||
|
.any(|x| x.get_id() == *extension)
|
||||||
|
{
|
||||||
|
unsupported_extensions.push(*extension);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if unsupported_extensions.is_empty() {
|
||||||
|
Ok((self.0, self.1))
|
||||||
|
} else {
|
||||||
|
self.0
|
||||||
|
.close_with_reason(CloseReason::ExtensionsIncompatible)
|
||||||
|
.await?;
|
||||||
|
self.1.await?;
|
||||||
|
Err(WispError::ExtensionsNotSupported(unsupported_extensions))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Shorthand for `with_required_extensions(&[UdpProtocolExtension::ID])`
|
||||||
|
pub async fn with_udp_extension_required(self) -> Result<(ServerMux, F), WispError> {
|
||||||
|
self.with_required_extensions(&[UdpProtocolExtension::ID])
|
||||||
|
.await
|
||||||
|
}
|
||||||
|
}
|
Loading…
Add table
Add a link
Reference in a new issue