mirror of
https://github.com/MercuryWorkshop/epoxy-tls.git
synced 2025-05-13 06:20:02 -04:00
some optimizations and muxprotocolextensionstream for stream id 0
This commit is contained in:
parent
3b8dedeba2
commit
b3f35b232f
7 changed files with 237 additions and 170 deletions
245
wisp/src/lib.rs
245
wisp/src/lib.rs
|
@ -29,7 +29,7 @@ use std::{
|
|||
},
|
||||
time::Duration,
|
||||
};
|
||||
use ws::AppendingWebSocketRead;
|
||||
use ws::{AppendingWebSocketRead, LockedWebSocketWrite};
|
||||
|
||||
/// Wisp version supported by this crate.
|
||||
pub const WISP_VERSION: WispVersion = WispVersion { major: 2, minor: 0 };
|
||||
|
@ -92,6 +92,8 @@ pub enum WispError {
|
|||
MuxMessageFailedToSend,
|
||||
/// Failed to receive message from multiplexor task.
|
||||
MuxMessageFailedToRecv,
|
||||
/// Multiplexor task ended.
|
||||
MuxTaskEnded,
|
||||
}
|
||||
|
||||
impl From<std::str::Utf8Error> for WispError {
|
||||
|
@ -145,6 +147,7 @@ impl std::fmt::Display for WispError {
|
|||
Self::Other(err) => write!(f, "Other error: {}", err),
|
||||
Self::MuxMessageFailedToSend => write!(f, "Failed to send multiplexor message"),
|
||||
Self::MuxMessageFailedToRecv => write!(f, "Failed to receive multiplexor message"),
|
||||
Self::MuxTaskEnded => write!(f, "Multiplexor task ended"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -164,6 +167,7 @@ struct MuxInner {
|
|||
tx: ws::LockedWebSocketWrite,
|
||||
stream_map: DashMap<u32, MuxMapValue>,
|
||||
buffer_size: u32,
|
||||
fut_exited: Arc<AtomicBool>
|
||||
}
|
||||
|
||||
impl MuxInner {
|
||||
|
@ -210,6 +214,7 @@ impl MuxInner {
|
|||
_ = self.stream_loop(close_rx, close_tx).fuse() => Ok(()),
|
||||
x = wisp_fut.fuse() => x,
|
||||
};
|
||||
self.fut_exited.store(true, Ordering::Release);
|
||||
for x in self.stream_map.iter_mut() {
|
||||
x.is_closed.store(true, Ordering::Release);
|
||||
x.is_closed_event.notify(usize::MAX);
|
||||
|
@ -225,6 +230,7 @@ impl MuxInner {
|
|||
stream_type: StreamType,
|
||||
role: Role,
|
||||
stream_tx: mpsc::Sender<WsEvent>,
|
||||
tx: LockedWebSocketWrite,
|
||||
target_buffer_size: u32,
|
||||
) -> Result<(MuxMapValue, MuxStream), WispError> {
|
||||
let (ch_tx, ch_rx) = mpsc::bounded(self.buffer_size as usize);
|
||||
|
@ -249,7 +255,8 @@ impl MuxInner {
|
|||
role,
|
||||
stream_type,
|
||||
ch_rx,
|
||||
stream_tx.clone(),
|
||||
stream_tx,
|
||||
tx,
|
||||
is_closed,
|
||||
is_closed_event,
|
||||
flow_control,
|
||||
|
@ -267,16 +274,6 @@ impl MuxInner {
|
|||
let mut next_free_stream_id: u32 = 1;
|
||||
while let Ok(msg) = stream_rx.recv_async().await {
|
||||
match msg {
|
||||
WsEvent::SendPacket(packet, channel) => {
|
||||
if self.stream_map.get(&packet.stream_id).is_some() {
|
||||
let _ = channel.send(self.tx.write_frame(packet.into()).await);
|
||||
} else {
|
||||
let _ = channel.send(Err(WispError::InvalidStreamId));
|
||||
}
|
||||
}
|
||||
WsEvent::SendBytes(packet, channel) => {
|
||||
let _ = channel.send(self.tx.write_frame(ws::Frame::binary(packet)).await);
|
||||
}
|
||||
WsEvent::CreateStream(stream_type, host, port, channel) => {
|
||||
let ret: Result<MuxStream, WispError> = async {
|
||||
let stream_id = next_free_stream_id;
|
||||
|
@ -290,6 +287,7 @@ impl MuxInner {
|
|||
stream_type,
|
||||
Role::Client,
|
||||
stream_tx.clone(),
|
||||
self.tx.clone(),
|
||||
0,
|
||||
)
|
||||
.await?;
|
||||
|
@ -330,6 +328,16 @@ impl MuxInner {
|
|||
}
|
||||
}
|
||||
|
||||
fn close_stream(&self, packet: Packet) {
|
||||
if let Some((_, stream)) = self.stream_map.remove(&packet.stream_id) {
|
||||
stream.is_closed.store(true, Ordering::Release);
|
||||
stream.is_closed_event.notify(usize::MAX);
|
||||
stream.flow_control.store(u32::MAX, Ordering::Release);
|
||||
stream.flow_control_event.notify(usize::MAX);
|
||||
drop(stream.stream)
|
||||
}
|
||||
}
|
||||
|
||||
async fn server_loop<R>(
|
||||
&self,
|
||||
mut rx: R,
|
||||
|
@ -353,6 +361,7 @@ impl MuxInner {
|
|||
{
|
||||
use PacketType::*;
|
||||
match packet.packet_type {
|
||||
Continue(_) | Info(_) => break Err(WispError::InvalidPacketType),
|
||||
Connect(inner_packet) => {
|
||||
let (map_value, stream) = self
|
||||
.create_new_stream(
|
||||
|
@ -360,6 +369,7 @@ impl MuxInner {
|
|||
inner_packet.stream_type,
|
||||
Role::Server,
|
||||
stream_tx.clone(),
|
||||
self.tx.clone(),
|
||||
target_buffer_size,
|
||||
)
|
||||
.await?;
|
||||
|
@ -383,16 +393,11 @@ impl MuxInner {
|
|||
}
|
||||
}
|
||||
}
|
||||
Continue(_) | Info(_) => break Err(WispError::InvalidPacketType),
|
||||
Close(_) => {
|
||||
if packet.stream_id == 0 {
|
||||
break Ok(());
|
||||
}
|
||||
if let Some((_, stream)) = self.stream_map.remove(&packet.stream_id) {
|
||||
stream.is_closed.store(true, Ordering::Release);
|
||||
stream.is_closed_event.notify(usize::MAX);
|
||||
drop(stream.stream)
|
||||
}
|
||||
self.close_stream(packet)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -437,11 +442,7 @@ impl MuxInner {
|
|||
if packet.stream_id == 0 {
|
||||
break Ok(());
|
||||
}
|
||||
if let Some((_, stream)) = self.stream_map.remove(&packet.stream_id) {
|
||||
stream.is_closed.store(true, Ordering::Release);
|
||||
stream.is_closed_event.notify(usize::MAX);
|
||||
drop(stream.stream)
|
||||
}
|
||||
self.close_stream(packet)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -449,6 +450,42 @@ impl MuxInner {
|
|||
}
|
||||
}
|
||||
|
||||
async fn maybe_wisp_v2<R>(
|
||||
read: &mut R,
|
||||
write: &LockedWebSocketWrite,
|
||||
builders: &[Box<dyn ProtocolExtensionBuilder + Sync + Send>],
|
||||
) -> Result<(Vec<AnyProtocolExtension>, Option<ws::Frame>, bool), WispError>
|
||||
where
|
||||
R: ws::WebSocketRead + Send,
|
||||
{
|
||||
let mut supported_extensions = Vec::new();
|
||||
let mut extra_packet = None;
|
||||
let mut downgraded = true;
|
||||
|
||||
let extension_ids: Vec<_> = builders.iter().map(|x| x.get_id()).collect();
|
||||
if let Some(frame) = select! {
|
||||
x = read.wisp_read_frame(write).fuse() => Some(x?),
|
||||
_ = Delay::new(Duration::from_secs(5)).fuse() => None
|
||||
} {
|
||||
let packet = Packet::maybe_parse_info(frame, Role::Client, builders)?;
|
||||
if let PacketType::Info(info) = packet.packet_type {
|
||||
supported_extensions = info
|
||||
.extensions
|
||||
.into_iter()
|
||||
.filter(|x| extension_ids.contains(&x.get_id()))
|
||||
.collect();
|
||||
downgraded = false;
|
||||
} else {
|
||||
extra_packet.replace(packet.into());
|
||||
}
|
||||
}
|
||||
|
||||
for extension in supported_extensions.iter_mut() {
|
||||
extension.handle_handshake(read, write).await?;
|
||||
}
|
||||
Ok((supported_extensions, extra_packet, downgraded))
|
||||
}
|
||||
|
||||
/// Server-side multiplexor.
|
||||
///
|
||||
/// # Example
|
||||
|
@ -477,6 +514,8 @@ pub struct ServerMux {
|
|||
pub supported_extension_ids: Vec<u8>,
|
||||
close_tx: mpsc::Sender<WsEvent>,
|
||||
muxstream_recv: mpsc::Receiver<(ConnectPacket, MuxStream)>,
|
||||
tx: ws::LockedWebSocketWrite,
|
||||
fut_exited: Arc<AtomicBool>,
|
||||
}
|
||||
|
||||
impl ServerMux {
|
||||
|
@ -498,41 +537,29 @@ impl ServerMux {
|
|||
let (close_tx, close_rx) = mpsc::bounded::<WsEvent>(256);
|
||||
let (tx, rx) = mpsc::unbounded::<(ConnectPacket, MuxStream)>();
|
||||
let write = ws::LockedWebSocketWrite::new(Box::new(write));
|
||||
let fut_exited = Arc::new(AtomicBool::new(false));
|
||||
|
||||
write
|
||||
.write_frame(Packet::new_continue(0, buffer_size).into())
|
||||
.await?;
|
||||
|
||||
let mut supported_extensions = Vec::new();
|
||||
let mut extra_packet = Vec::with_capacity(1);
|
||||
let mut downgraded = true;
|
||||
|
||||
if let Some(builders) = extension_builders {
|
||||
let extensions: Vec<_> = builders
|
||||
.iter()
|
||||
.map(|x| x.build_to_extension(Role::Server))
|
||||
.collect();
|
||||
let extension_ids: Vec<_> = extensions.iter().map(|x| x.get_id()).collect();
|
||||
write
|
||||
.write_frame(Packet::new_info(extensions).into())
|
||||
.await?;
|
||||
if let Some(frame) = select! {
|
||||
x = read.wisp_read_frame(&write).fuse() => Some(x?),
|
||||
_ = Delay::new(Duration::from_secs(5)).fuse() => None
|
||||
} {
|
||||
let packet = Packet::maybe_parse_info(frame, Role::Server, builders)?;
|
||||
if let PacketType::Info(info) = packet.packet_type {
|
||||
supported_extensions = info
|
||||
.extensions
|
||||
.into_iter()
|
||||
.filter(|x| extension_ids.contains(&x.get_id()))
|
||||
.collect();
|
||||
downgraded = false;
|
||||
} else {
|
||||
extra_packet.push(packet.into());
|
||||
}
|
||||
}
|
||||
}
|
||||
let (supported_extensions, extra_packet, downgraded) =
|
||||
if let Some(builders) = extension_builders {
|
||||
write
|
||||
.write_frame(
|
||||
Packet::new_info(
|
||||
builders
|
||||
.iter()
|
||||
.map(|x| x.build_to_extension(Role::Client))
|
||||
.collect(),
|
||||
)
|
||||
.into(),
|
||||
)
|
||||
.await?;
|
||||
maybe_wisp_v2(&mut read, &write, builders).await?
|
||||
} else {
|
||||
(Vec::new(), None, true)
|
||||
};
|
||||
|
||||
Ok(ServerMuxResult(
|
||||
Self {
|
||||
|
@ -540,11 +567,14 @@ impl ServerMux {
|
|||
close_tx: close_tx.clone(),
|
||||
downgraded,
|
||||
supported_extension_ids: supported_extensions.iter().map(|x| x.get_id()).collect(),
|
||||
tx: write.clone(),
|
||||
fut_exited: fut_exited.clone(),
|
||||
},
|
||||
MuxInner {
|
||||
tx: write,
|
||||
stream_map: DashMap::new(),
|
||||
buffer_size,
|
||||
fut_exited
|
||||
}
|
||||
.server_into_future(
|
||||
AppendingWebSocketRead(extra_packet, read),
|
||||
|
@ -558,10 +588,16 @@ impl ServerMux {
|
|||
|
||||
/// Wait for a stream to be created.
|
||||
pub async fn server_new_stream(&self) -> Option<(ConnectPacket, MuxStream)> {
|
||||
if self.fut_exited.load(Ordering::Acquire) {
|
||||
return None;
|
||||
}
|
||||
self.muxstream_recv.recv_async().await.ok()
|
||||
}
|
||||
|
||||
async fn close_internal(&self, reason: Option<CloseReason>) -> Result<(), WispError> {
|
||||
if self.fut_exited.load(Ordering::Acquire) {
|
||||
return Err(WispError::MuxTaskEnded);
|
||||
}
|
||||
self.close_tx
|
||||
.send_async(WsEvent::EndFut(reason))
|
||||
.await
|
||||
|
@ -570,20 +606,27 @@ impl ServerMux {
|
|||
|
||||
/// Close all streams.
|
||||
///
|
||||
/// Also terminates the multiplexor future. Waiting for a new stream will never succeed after
|
||||
/// this function is called.
|
||||
/// Also terminates the multiplexor future.
|
||||
pub async fn close(&self) -> Result<(), WispError> {
|
||||
self.close_internal(None).await
|
||||
}
|
||||
|
||||
/// Close all streams and send an extension incompatibility error to the client.
|
||||
///
|
||||
/// Also terminates the multiplexor future. Waiting for a new stream will never succed after
|
||||
/// this function is called.
|
||||
/// Also terminates the multiplexor future.
|
||||
pub async fn close_extension_incompat(&self) -> Result<(), WispError> {
|
||||
self.close_internal(Some(CloseReason::IncompatibleExtensions))
|
||||
.await
|
||||
}
|
||||
|
||||
/// Get a protocol extension stream for sending packets with stream id 0.
|
||||
pub fn get_protocol_extension_stream(&self) -> MuxProtocolExtensionStream {
|
||||
MuxProtocolExtensionStream {
|
||||
stream_id: 0,
|
||||
tx: self.tx.clone(),
|
||||
is_closed: self.fut_exited.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for ServerMux {
|
||||
|
@ -656,6 +699,8 @@ pub struct ClientMux {
|
|||
/// Extensions that are supported by both sides.
|
||||
pub supported_extension_ids: Vec<u8>,
|
||||
stream_tx: mpsc::Sender<WsEvent>,
|
||||
tx: ws::LockedWebSocketWrite,
|
||||
fut_exited: Arc<AtomicBool>,
|
||||
}
|
||||
|
||||
impl ClientMux {
|
||||
|
@ -675,44 +720,30 @@ impl ClientMux {
|
|||
{
|
||||
let write = ws::LockedWebSocketWrite::new(Box::new(write));
|
||||
let first_packet = Packet::try_from(read.wisp_read_frame(&write).await?)?;
|
||||
let fut_exited = Arc::new(AtomicBool::new(false));
|
||||
|
||||
if first_packet.stream_id != 0 {
|
||||
return Err(WispError::InvalidStreamId);
|
||||
}
|
||||
if let PacketType::Continue(packet) = first_packet.packet_type {
|
||||
let mut supported_extensions = Vec::new();
|
||||
let mut extra_packet = Vec::with_capacity(1);
|
||||
let mut downgraded = true;
|
||||
|
||||
if let Some(builders) = extension_builders {
|
||||
let extensions: Vec<_> = builders
|
||||
.iter()
|
||||
.map(|x| x.build_to_extension(Role::Client))
|
||||
.collect();
|
||||
let extension_ids: Vec<_> = extensions.iter().map(|x| x.get_id()).collect();
|
||||
if let Some(frame) = select! {
|
||||
x = read.wisp_read_frame(&write).fuse() => Some(x?),
|
||||
_ = Delay::new(Duration::from_secs(5)).fuse() => None
|
||||
} {
|
||||
let packet = Packet::maybe_parse_info(frame, Role::Client, builders)?;
|
||||
if let PacketType::Info(info) = packet.packet_type {
|
||||
supported_extensions = info
|
||||
.extensions
|
||||
.into_iter()
|
||||
.filter(|x| extension_ids.contains(&x.get_id()))
|
||||
.collect();
|
||||
write
|
||||
.write_frame(Packet::new_info(extensions).into())
|
||||
.await?;
|
||||
downgraded = false;
|
||||
} else {
|
||||
extra_packet.push(packet.into());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for extension in supported_extensions.iter_mut() {
|
||||
extension.handle_handshake(&mut read, &write).await?;
|
||||
}
|
||||
let (supported_extensions, extra_packet, downgraded) =
|
||||
if let Some(builders) = extension_builders {
|
||||
let x = maybe_wisp_v2(&mut read, &write, builders).await?;
|
||||
write
|
||||
.write_frame(
|
||||
Packet::new_info(
|
||||
builders
|
||||
.iter()
|
||||
.map(|x| x.build_to_extension(Role::Client))
|
||||
.collect(),
|
||||
)
|
||||
.into(),
|
||||
)
|
||||
.await?;
|
||||
x
|
||||
} else {
|
||||
(Vec::new(), None, true)
|
||||
};
|
||||
|
||||
let (tx, rx) = mpsc::bounded::<WsEvent>(256);
|
||||
Ok(ClientMuxResult(
|
||||
|
@ -723,11 +754,14 @@ impl ClientMux {
|
|||
.iter()
|
||||
.map(|x| x.get_id())
|
||||
.collect(),
|
||||
tx: write.clone(),
|
||||
fut_exited: fut_exited.clone(),
|
||||
},
|
||||
MuxInner {
|
||||
tx: write,
|
||||
stream_map: DashMap::new(),
|
||||
buffer_size: packet.buffer_remaining,
|
||||
fut_exited
|
||||
}
|
||||
.client_into_future(
|
||||
AppendingWebSocketRead(extra_packet, read),
|
||||
|
@ -748,6 +782,9 @@ impl ClientMux {
|
|||
host: String,
|
||||
port: u16,
|
||||
) -> Result<MuxStream, WispError> {
|
||||
if self.fut_exited.load(Ordering::Acquire) {
|
||||
return Err(WispError::MuxTaskEnded);
|
||||
}
|
||||
if stream_type == StreamType::Udp
|
||||
&& !self
|
||||
.supported_extension_ids
|
||||
|
@ -767,6 +804,9 @@ impl ClientMux {
|
|||
}
|
||||
|
||||
async fn close_internal(&self, reason: Option<CloseReason>) -> Result<(), WispError> {
|
||||
if self.fut_exited.load(Ordering::Acquire) {
|
||||
return Err(WispError::MuxTaskEnded);
|
||||
}
|
||||
self.stream_tx
|
||||
.send_async(WsEvent::EndFut(reason))
|
||||
.await
|
||||
|
@ -775,20 +815,27 @@ impl ClientMux {
|
|||
|
||||
/// Close all streams.
|
||||
///
|
||||
/// Also terminates the multiplexor future. Creating a stream is UB after calling this
|
||||
/// function.
|
||||
/// Also terminates the multiplexor future.
|
||||
pub async fn close(&self) -> Result<(), WispError> {
|
||||
self.close_internal(None).await
|
||||
}
|
||||
|
||||
/// Close all streams and send an extension incompatibility error to the client.
|
||||
///
|
||||
/// Also terminates the multiplexor future. Creating a stream is UB after calling this
|
||||
/// function.
|
||||
/// Also terminates the multiplexor future.
|
||||
pub async fn close_extension_incompat(&self) -> Result<(), WispError> {
|
||||
self.close_internal(Some(CloseReason::IncompatibleExtensions))
|
||||
.await
|
||||
}
|
||||
|
||||
/// Get a protocol extension stream for sending packets with stream id 0.
|
||||
pub fn get_protocol_extension_stream(&self) -> MuxProtocolExtensionStream {
|
||||
MuxProtocolExtensionStream {
|
||||
stream_id: 0,
|
||||
tx: self.tx.clone(),
|
||||
is_closed: self.fut_exited.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for ClientMux {
|
||||
|
@ -812,7 +859,10 @@ where
|
|||
}
|
||||
|
||||
/// Require protocol extensions by their ID.
|
||||
pub async fn with_required_extensions(self, extensions: &[u8]) -> Result<(ClientMux, F), WispError> {
|
||||
pub async fn with_required_extensions(
|
||||
self,
|
||||
extensions: &[u8],
|
||||
) -> Result<(ClientMux, F), WispError> {
|
||||
let mut unsupported_extensions = Vec::new();
|
||||
for extension in extensions {
|
||||
if !self.0.supported_extension_ids.contains(extension) {
|
||||
|
@ -830,6 +880,7 @@ where
|
|||
|
||||
/// Shorthand for `with_required_extensions(&[UdpProtocolExtension::ID])`
|
||||
pub async fn with_udp_extension_required(self) -> Result<(ClientMux, F), WispError> {
|
||||
self.with_required_extensions(&[UdpProtocolExtension::ID]).await
|
||||
self.with_required_extensions(&[UdpProtocolExtension::ID])
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue