mirror of
https://github.com/MercuryWorkshop/epoxy-tls.git
synced 2025-05-12 22:10:01 -04:00
wisp-mux... SEVEN!!
This commit is contained in:
parent
194ad4e5c8
commit
3f381d6b39
53 changed files with 3721 additions and 4821 deletions
|
@ -1,241 +1,196 @@
|
|||
use std::{
|
||||
future::Future,
|
||||
sync::{
|
||||
atomic::{AtomicBool, Ordering},
|
||||
Arc,
|
||||
},
|
||||
};
|
||||
|
||||
use flume as mpsc;
|
||||
use futures::channel::oneshot;
|
||||
use futures::SinkExt;
|
||||
|
||||
use crate::{
|
||||
extensions::AnyProtocolExtension,
|
||||
ws::{DynWebSocketRead, LockedWebSocketWrite, Payload, WebSocketRead, WebSocketWrite},
|
||||
CloseReason, ConnectPacket, MuxProtocolExtensionStream, MuxStream, Packet, PacketType, Role,
|
||||
WispError,
|
||||
locked_sink::LockedWebSocketWrite,
|
||||
packet::{CloseReason, ConnectPacket, MaybeInfoPacket, Packet, StreamType},
|
||||
stream::MuxStream,
|
||||
ws::{Payload, WebSocketRead, WebSocketReadExt, WebSocketWrite},
|
||||
Role, WispError,
|
||||
};
|
||||
|
||||
use super::{
|
||||
get_supported_extensions,
|
||||
inner::{MuxInner, WsEvent},
|
||||
send_info_packet, Multiplexor, MuxResult, WispHandshakeResult, WispHandshakeResultKind,
|
||||
WispV2Handshake,
|
||||
get_supported_extensions, handle_handshake,
|
||||
inner::{FlowControl, MultiplexorActor, StreamMap},
|
||||
send_info_packet, Multiplexor, MultiplexorImpl, MuxResult, WispHandshakeResult,
|
||||
WispHandshakeResultKind, WispV2Handshake,
|
||||
};
|
||||
|
||||
async fn handshake<R: WebSocketRead + 'static, W: WebSocketWrite>(
|
||||
rx: &mut R,
|
||||
tx: &LockedWebSocketWrite<W>,
|
||||
buffer_size: u32,
|
||||
v2_info: Option<WispV2Handshake>,
|
||||
) -> Result<WispHandshakeResult, WispError> {
|
||||
if let Some(WispV2Handshake {
|
||||
mut builders,
|
||||
closure,
|
||||
}) = v2_info
|
||||
{
|
||||
send_info_packet(tx, &mut builders).await?;
|
||||
tx.write_frame(Packet::new_continue(0, buffer_size).into())
|
||||
.await?;
|
||||
|
||||
(closure)(&mut builders).await?;
|
||||
|
||||
let packet =
|
||||
Packet::maybe_parse_info(rx.wisp_read_frame(tx).await?, Role::Server, &mut builders)?;
|
||||
|
||||
if let PacketType::Info(info) = packet.packet_type {
|
||||
let mut supported_extensions = get_supported_extensions(info.extensions, &mut builders);
|
||||
|
||||
for extension in &mut supported_extensions {
|
||||
extension
|
||||
.handle_handshake(DynWebSocketRead::from_mut(rx), tx)
|
||||
.await?;
|
||||
}
|
||||
|
||||
// v2 client
|
||||
Ok(WispHandshakeResult {
|
||||
kind: WispHandshakeResultKind::V2 {
|
||||
extensions: supported_extensions,
|
||||
},
|
||||
downgraded: false,
|
||||
})
|
||||
} else {
|
||||
// downgrade to v1
|
||||
Ok(WispHandshakeResult {
|
||||
kind: WispHandshakeResultKind::V1 {
|
||||
frame: Some(packet),
|
||||
},
|
||||
downgraded: true,
|
||||
})
|
||||
}
|
||||
} else {
|
||||
// user asked for v1 server
|
||||
tx.write_frame(Packet::new_continue(0, buffer_size).into())
|
||||
.await?;
|
||||
|
||||
Ok(WispHandshakeResult {
|
||||
kind: WispHandshakeResultKind::V1 { frame: None },
|
||||
downgraded: false,
|
||||
})
|
||||
}
|
||||
pub(crate) struct ServerActor<W: WebSocketWrite> {
|
||||
stream_tx: flume::Sender<(ConnectPacket, MuxStream<W>)>,
|
||||
}
|
||||
|
||||
/// Server-side multiplexor.
|
||||
pub struct ServerMux<W: WebSocketWrite + 'static> {
|
||||
/// Whether the connection was downgraded to Wisp v1.
|
||||
///
|
||||
/// If this variable is true you must assume no extensions are supported.
|
||||
pub downgraded: bool,
|
||||
/// Extensions that are supported by both sides.
|
||||
pub supported_extensions: Vec<AnyProtocolExtension>,
|
||||
actor_tx: mpsc::Sender<WsEvent<W>>,
|
||||
muxstream_recv: mpsc::Receiver<(ConnectPacket, MuxStream<W>)>,
|
||||
tx: LockedWebSocketWrite<W>,
|
||||
actor_exited: Arc<AtomicBool>,
|
||||
}
|
||||
|
||||
impl<W: WebSocketWrite + 'static> ServerMux<W> {
|
||||
/// Create a new server-side multiplexor.
|
||||
///
|
||||
/// If `wisp_v2` is None a Wisp v1 connection is created otherwise a Wisp v2 connection is created.
|
||||
/// **It is not guaranteed that all extensions you specify are available.** You must manually check
|
||||
/// if the extensions you need are available after the multiplexor has been created.
|
||||
pub async fn create<R>(
|
||||
mut rx: R,
|
||||
tx: W,
|
||||
buffer_size: u32,
|
||||
wisp_v2: Option<WispV2Handshake>,
|
||||
) -> Result<
|
||||
MuxResult<ServerMux<W>, impl Future<Output = Result<(), WispError>> + Send>,
|
||||
WispError,
|
||||
>
|
||||
where
|
||||
R: WebSocketRead + Send + 'static,
|
||||
{
|
||||
let tx = LockedWebSocketWrite::new(tx);
|
||||
let ret_tx = tx.clone();
|
||||
let ret = async {
|
||||
let handshake_result = handshake(&mut rx, &tx, buffer_size, wisp_v2).await?;
|
||||
let (extensions, extra_packet) = handshake_result.kind.into_parts();
|
||||
|
||||
let (mux_result, muxstream_recv) = MuxInner::new_server(
|
||||
rx,
|
||||
extra_packet,
|
||||
tx.clone(),
|
||||
extensions.clone(),
|
||||
buffer_size,
|
||||
);
|
||||
|
||||
Ok(MuxResult(
|
||||
Self {
|
||||
actor_tx: mux_result.actor_tx,
|
||||
actor_exited: mux_result.actor_exited,
|
||||
muxstream_recv,
|
||||
|
||||
tx,
|
||||
|
||||
downgraded: handshake_result.downgraded,
|
||||
supported_extensions: extensions,
|
||||
},
|
||||
mux_result.mux.into_future(),
|
||||
))
|
||||
}
|
||||
.await;
|
||||
|
||||
match ret {
|
||||
Ok(x) => Ok(x),
|
||||
Err(x) => match x {
|
||||
WispError::PasswordExtensionCredsInvalid => {
|
||||
ret_tx
|
||||
.write_frame(
|
||||
Packet::new_close(0, CloseReason::ExtensionsPasswordAuthFailed).into(),
|
||||
)
|
||||
.await?;
|
||||
ret_tx.close().await?;
|
||||
Err(x)
|
||||
}
|
||||
WispError::CertAuthExtensionSigInvalid => {
|
||||
ret_tx
|
||||
.write_frame(
|
||||
Packet::new_close(0, CloseReason::ExtensionsCertAuthFailed).into(),
|
||||
)
|
||||
.await?;
|
||||
ret_tx.close().await?;
|
||||
Err(x)
|
||||
}
|
||||
x => Err(x),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
/// Wait for a stream to be created.
|
||||
pub async fn server_new_stream(&self) -> Option<(ConnectPacket, MuxStream<W>)> {
|
||||
if self.actor_exited.load(Ordering::Acquire) {
|
||||
return None;
|
||||
}
|
||||
self.muxstream_recv.recv_async().await.ok()
|
||||
}
|
||||
|
||||
/// Send a ping to the client.
|
||||
pub async fn send_ping(&self, payload: Payload<'static>) -> Result<(), WispError> {
|
||||
if self.actor_exited.load(Ordering::Acquire) {
|
||||
return Err(WispError::MuxTaskEnded);
|
||||
}
|
||||
let (tx, rx) = oneshot::channel();
|
||||
self.actor_tx
|
||||
.send_async(WsEvent::SendPing(payload, tx))
|
||||
.await
|
||||
.map_err(|_| WispError::MuxMessageFailedToSend)?;
|
||||
rx.await.map_err(|_| WispError::MuxMessageFailedToRecv)?
|
||||
}
|
||||
|
||||
async fn close_internal(&self, reason: Option<CloseReason>) -> Result<(), WispError> {
|
||||
if self.actor_exited.load(Ordering::Acquire) {
|
||||
return Err(WispError::MuxTaskEnded);
|
||||
}
|
||||
self.actor_tx
|
||||
.send_async(WsEvent::EndFut(reason))
|
||||
.await
|
||||
impl<W: WebSocketWrite> MultiplexorActor<W> for ServerActor<W> {
|
||||
fn handle_connect_packet(
|
||||
&mut self,
|
||||
stream: MuxStream<W>,
|
||||
pkt: ConnectPacket,
|
||||
) -> Result<(), WispError> {
|
||||
self.stream_tx
|
||||
.send((pkt, stream))
|
||||
.map_err(|_| WispError::MuxMessageFailedToSend)
|
||||
}
|
||||
|
||||
/// Close all streams.
|
||||
///
|
||||
/// Also terminates the multiplexor future.
|
||||
pub async fn close(&self) -> Result<(), WispError> {
|
||||
self.close_internal(None).await
|
||||
fn handle_data_packet(
|
||||
&mut self,
|
||||
id: u32,
|
||||
pkt: Payload,
|
||||
streams: &mut StreamMap,
|
||||
) -> Result<(), WispError> {
|
||||
if let Some(stream) = streams.get(&id) {
|
||||
if stream.stream.try_send(pkt).is_ok() {
|
||||
stream.info.flow_dec();
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Close all streams and send a close reason on stream ID 0.
|
||||
///
|
||||
/// Also terminates the multiplexor future.
|
||||
pub async fn close_with_reason(&self, reason: CloseReason) -> Result<(), WispError> {
|
||||
self.close_internal(Some(reason)).await
|
||||
fn handle_continue_packet(
|
||||
&mut self,
|
||||
_: u32,
|
||||
_: crate::packet::ContinuePacket,
|
||||
_: &mut StreamMap,
|
||||
) -> Result<(), WispError> {
|
||||
Err(WispError::InvalidPacketType(0x03))
|
||||
}
|
||||
|
||||
/// Get a protocol extension stream for sending packets with stream id 0.
|
||||
pub fn get_protocol_extension_stream(&self) -> MuxProtocolExtensionStream<W> {
|
||||
MuxProtocolExtensionStream {
|
||||
stream_id: 0,
|
||||
tx: self.tx.clone(),
|
||||
is_closed: self.actor_exited.clone(),
|
||||
fn get_flow_control(ty: StreamType, flow_stream_types: &[u8]) -> FlowControl {
|
||||
if flow_stream_types.contains(&ty.into()) {
|
||||
FlowControl::EnabledSendMessages
|
||||
} else {
|
||||
FlowControl::Disabled
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<W: WebSocketWrite + 'static> Drop for ServerMux<W> {
|
||||
fn drop(&mut self) {
|
||||
let _ = self.actor_tx.send(WsEvent::EndFut(None));
|
||||
pub struct ServerImpl<W: WebSocketWrite> {
|
||||
buffer_size: u32,
|
||||
stream_rx: flume::Receiver<(ConnectPacket, MuxStream<W>)>,
|
||||
}
|
||||
|
||||
impl<W: WebSocketWrite> MultiplexorImpl<W> for ServerImpl<W> {
|
||||
type Actor = ServerActor<W>;
|
||||
|
||||
async fn handshake<R: WebSocketRead>(
|
||||
&mut self,
|
||||
rx: &mut R,
|
||||
tx: &mut LockedWebSocketWrite<W>,
|
||||
v2: Option<WispV2Handshake>,
|
||||
) -> Result<WispHandshakeResult, WispError> {
|
||||
if let Some(WispV2Handshake {
|
||||
mut builders,
|
||||
closure,
|
||||
}) = v2
|
||||
{
|
||||
send_info_packet(tx, &mut builders, Role::Server).await?;
|
||||
tx.lock().await;
|
||||
tx.get()
|
||||
.send(Packet::new_continue(0, self.buffer_size).encode())
|
||||
.await?;
|
||||
tx.unlock();
|
||||
|
||||
(closure)(&mut builders).await?;
|
||||
|
||||
let packet =
|
||||
MaybeInfoPacket::decode(rx.next_erroring().await?, &mut builders, Role::Server)?;
|
||||
|
||||
match packet {
|
||||
MaybeInfoPacket::Info(info) => {
|
||||
let mut supported_extensions =
|
||||
get_supported_extensions(info.extensions, &mut builders);
|
||||
|
||||
handle_handshake(rx, tx, &mut supported_extensions).await?;
|
||||
|
||||
// v2 client
|
||||
Ok(WispHandshakeResult {
|
||||
kind: WispHandshakeResultKind::V2 {
|
||||
extensions: supported_extensions,
|
||||
},
|
||||
downgraded: false,
|
||||
buffer_size: self.buffer_size,
|
||||
})
|
||||
}
|
||||
MaybeInfoPacket::Packet(packet) => {
|
||||
// downgrade to v1
|
||||
Ok(WispHandshakeResult {
|
||||
kind: WispHandshakeResultKind::V1 {
|
||||
packet: Some(packet),
|
||||
},
|
||||
downgraded: true,
|
||||
buffer_size: self.buffer_size,
|
||||
})
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// user asked for v1 server
|
||||
tx.lock().await;
|
||||
tx.get()
|
||||
.send(Packet::new_continue(0, self.buffer_size).encode())
|
||||
.await?;
|
||||
tx.unlock();
|
||||
|
||||
Ok(WispHandshakeResult {
|
||||
kind: WispHandshakeResultKind::V1 { packet: None },
|
||||
downgraded: false,
|
||||
buffer_size: self.buffer_size,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_error(
|
||||
&mut self,
|
||||
err: WispError,
|
||||
tx: &mut LockedWebSocketWrite<W>,
|
||||
) -> Result<WispError, WispError> {
|
||||
match err {
|
||||
WispError::PasswordExtensionCredsInvalid => {
|
||||
tx.lock().await;
|
||||
tx.get()
|
||||
.send(Packet::new_close(0, CloseReason::ExtensionsPasswordAuthFailed).encode())
|
||||
.await?;
|
||||
tx.get().close().await?;
|
||||
tx.unlock();
|
||||
Ok(err)
|
||||
}
|
||||
WispError::CertAuthExtensionSigInvalid => {
|
||||
tx.lock().await;
|
||||
tx.get()
|
||||
.send(Packet::new_close(0, CloseReason::ExtensionsCertAuthFailed).encode())
|
||||
.await?;
|
||||
tx.get().close().await?;
|
||||
tx.unlock();
|
||||
Ok(err)
|
||||
}
|
||||
x => Ok(x),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<W: WebSocketWrite + 'static> Multiplexor for ServerMux<W> {
|
||||
fn has_extension(&self, extension_id: u8) -> bool {
|
||||
self.supported_extensions
|
||||
.iter()
|
||||
.any(|x| x.get_id() == extension_id)
|
||||
impl<W: WebSocketWrite> Multiplexor<ServerImpl<W>, W> {
|
||||
/// Create a new server-side multiplexor.
|
||||
///
|
||||
/// If `wisp_v2` is None a Wisp v1 connection is created, otherwise a Wisp v2 connection is created.
|
||||
/// **It is not guaranteed that all extensions you specify are available.** You must manually check
|
||||
/// if the extensions you need are available after the multiplexor has been created.
|
||||
#[expect(clippy::new_ret_no_self)]
|
||||
pub async fn new<R: WebSocketRead>(
|
||||
rx: R,
|
||||
tx: W,
|
||||
buffer_size: u32,
|
||||
wisp_v2: Option<WispV2Handshake>,
|
||||
) -> Result<MuxResult<ServerImpl<W>, W>, WispError> {
|
||||
let (stream_tx, stream_rx) = flume::unbounded();
|
||||
|
||||
let mux = ServerImpl {
|
||||
buffer_size,
|
||||
stream_rx,
|
||||
};
|
||||
let actor = ServerActor { stream_tx };
|
||||
|
||||
Self::create(rx, tx, wisp_v2, mux, actor).await
|
||||
}
|
||||
async fn exit(&self, reason: CloseReason) -> Result<(), WispError> {
|
||||
self.close_with_reason(reason).await
|
||||
|
||||
/// Wait for a stream to be created.
|
||||
pub async fn wait_for_stream(&self) -> Option<(ConnectPacket, MuxStream<W>)> {
|
||||
self.mux.stream_rx.recv_async().await.ok()
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue