mirror of
https://github.com/MercuryWorkshop/epoxy-tls.git
synced 2025-05-12 14:00:01 -04:00
some optimizations and muxprotocolextensionstream for stream id 0
This commit is contained in:
parent
3b8dedeba2
commit
b3f35b232f
7 changed files with 237 additions and 170 deletions
|
@ -1,10 +1,9 @@
|
||||||
use crate::*;
|
use crate::*;
|
||||||
use std::{
|
use std::{
|
||||||
pin::Pin,
|
ops::Deref, pin::Pin, sync::atomic::{AtomicBool, Ordering}, task::{Context, Poll}
|
||||||
sync::atomic::{AtomicBool, Ordering},
|
|
||||||
task::{Context, Poll},
|
|
||||||
};
|
};
|
||||||
|
|
||||||
|
use bytes::BytesMut;
|
||||||
use event_listener::Event;
|
use event_listener::Event;
|
||||||
use futures_util::{FutureExt, Stream};
|
use futures_util::{FutureExt, Stream};
|
||||||
use hyper::body::Body;
|
use hyper::body::Body;
|
||||||
|
@ -207,7 +206,7 @@ impl WebSocketRead for WebSocketReader {
|
||||||
_ = self.close_event.listen().fuse() => Some(Closed),
|
_ = self.close_event.listen().fuse() => Some(Closed),
|
||||||
};
|
};
|
||||||
match res.ok_or(WispError::WsImplSocketClosed)? {
|
match res.ok_or(WispError::WsImplSocketClosed)? {
|
||||||
Message(bin) => Ok(Frame::binary(bin.into())),
|
Message(bin) => Ok(Frame::binary(BytesMut::from(bin.deref()))),
|
||||||
Error => Err(WebSocketError::Unknown.into()),
|
Error => Err(WebSocketError::Unknown.into()),
|
||||||
Closed => Err(WispError::WsImplSocketClosed),
|
Closed => Err(WispError::WsImplSocketClosed),
|
||||||
}
|
}
|
||||||
|
|
|
@ -225,14 +225,14 @@ async fn main() -> Result<(), Box<dyn Error + Send + Sync>> {
|
||||||
interval.tick().await;
|
interval.tick().await;
|
||||||
let now = cnt_avg.get();
|
let now = cnt_avg.get();
|
||||||
let stat = format!(
|
let stat = format!(
|
||||||
"sent &[0; 1024 * {}] cnt: {:?} ({} KiB), +{:?} ({} KiB / 100ms), moving average (10 s): {:?} ({} KiB / 10 s)",
|
"sent &[0; 1024 * {}] cnt: {:?} ({} KiB), +{:?} / 100ms ({} KiB / 1s), moving average (10 s): {:?} / 100ms ({} KiB / 1s)",
|
||||||
opts.packet_size,
|
opts.packet_size,
|
||||||
now,
|
now,
|
||||||
now * opts.packet_size,
|
now * opts.packet_size,
|
||||||
now - last_time,
|
now - last_time,
|
||||||
(now - last_time) * opts.packet_size,
|
(now - last_time) * opts.packet_size * 10,
|
||||||
avg.get_average(),
|
avg.get_average(),
|
||||||
avg.get_average() * opts.packet_size,
|
avg.get_average() * opts.packet_size * 10,
|
||||||
);
|
);
|
||||||
if is_term {
|
if is_term {
|
||||||
println!("\x1b[1A\x1b[2K{}\r", stat);
|
println!("\x1b[1A\x1b[2K{}\r", stat);
|
||||||
|
|
|
@ -30,7 +30,7 @@ impl From<Frame<'_>> for crate::ws::Frame {
|
||||||
Self {
|
Self {
|
||||||
finished: frame.fin,
|
finished: frame.fin,
|
||||||
opcode: frame.opcode.into(),
|
opcode: frame.opcode.into(),
|
||||||
payload: BytesMut::from(frame.payload.deref()).freeze(),
|
payload: BytesMut::from(frame.payload.deref()),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -38,7 +38,7 @@ impl From<Frame<'_>> for crate::ws::Frame {
|
||||||
impl<'a> From<crate::ws::Frame> for Frame<'a> {
|
impl<'a> From<crate::ws::Frame> for Frame<'a> {
|
||||||
fn from(frame: crate::ws::Frame) -> Self {
|
fn from(frame: crate::ws::Frame) -> Self {
|
||||||
use crate::ws::OpCode::*;
|
use crate::ws::OpCode::*;
|
||||||
let payload = Payload::Owned(frame.payload.into());
|
let payload = Payload::Bytes(frame.payload);
|
||||||
match frame.opcode {
|
match frame.opcode {
|
||||||
Text => Self::text(payload),
|
Text => Self::text(payload),
|
||||||
Binary => Self::binary(payload),
|
Binary => Self::binary(payload),
|
||||||
|
|
245
wisp/src/lib.rs
245
wisp/src/lib.rs
|
@ -29,7 +29,7 @@ use std::{
|
||||||
},
|
},
|
||||||
time::Duration,
|
time::Duration,
|
||||||
};
|
};
|
||||||
use ws::AppendingWebSocketRead;
|
use ws::{AppendingWebSocketRead, LockedWebSocketWrite};
|
||||||
|
|
||||||
/// Wisp version supported by this crate.
|
/// Wisp version supported by this crate.
|
||||||
pub const WISP_VERSION: WispVersion = WispVersion { major: 2, minor: 0 };
|
pub const WISP_VERSION: WispVersion = WispVersion { major: 2, minor: 0 };
|
||||||
|
@ -92,6 +92,8 @@ pub enum WispError {
|
||||||
MuxMessageFailedToSend,
|
MuxMessageFailedToSend,
|
||||||
/// Failed to receive message from multiplexor task.
|
/// Failed to receive message from multiplexor task.
|
||||||
MuxMessageFailedToRecv,
|
MuxMessageFailedToRecv,
|
||||||
|
/// Multiplexor task ended.
|
||||||
|
MuxTaskEnded,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl From<std::str::Utf8Error> for WispError {
|
impl From<std::str::Utf8Error> for WispError {
|
||||||
|
@ -145,6 +147,7 @@ impl std::fmt::Display for WispError {
|
||||||
Self::Other(err) => write!(f, "Other error: {}", err),
|
Self::Other(err) => write!(f, "Other error: {}", err),
|
||||||
Self::MuxMessageFailedToSend => write!(f, "Failed to send multiplexor message"),
|
Self::MuxMessageFailedToSend => write!(f, "Failed to send multiplexor message"),
|
||||||
Self::MuxMessageFailedToRecv => write!(f, "Failed to receive multiplexor message"),
|
Self::MuxMessageFailedToRecv => write!(f, "Failed to receive multiplexor message"),
|
||||||
|
Self::MuxTaskEnded => write!(f, "Multiplexor task ended"),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -164,6 +167,7 @@ struct MuxInner {
|
||||||
tx: ws::LockedWebSocketWrite,
|
tx: ws::LockedWebSocketWrite,
|
||||||
stream_map: DashMap<u32, MuxMapValue>,
|
stream_map: DashMap<u32, MuxMapValue>,
|
||||||
buffer_size: u32,
|
buffer_size: u32,
|
||||||
|
fut_exited: Arc<AtomicBool>
|
||||||
}
|
}
|
||||||
|
|
||||||
impl MuxInner {
|
impl MuxInner {
|
||||||
|
@ -210,6 +214,7 @@ impl MuxInner {
|
||||||
_ = self.stream_loop(close_rx, close_tx).fuse() => Ok(()),
|
_ = self.stream_loop(close_rx, close_tx).fuse() => Ok(()),
|
||||||
x = wisp_fut.fuse() => x,
|
x = wisp_fut.fuse() => x,
|
||||||
};
|
};
|
||||||
|
self.fut_exited.store(true, Ordering::Release);
|
||||||
for x in self.stream_map.iter_mut() {
|
for x in self.stream_map.iter_mut() {
|
||||||
x.is_closed.store(true, Ordering::Release);
|
x.is_closed.store(true, Ordering::Release);
|
||||||
x.is_closed_event.notify(usize::MAX);
|
x.is_closed_event.notify(usize::MAX);
|
||||||
|
@ -225,6 +230,7 @@ impl MuxInner {
|
||||||
stream_type: StreamType,
|
stream_type: StreamType,
|
||||||
role: Role,
|
role: Role,
|
||||||
stream_tx: mpsc::Sender<WsEvent>,
|
stream_tx: mpsc::Sender<WsEvent>,
|
||||||
|
tx: LockedWebSocketWrite,
|
||||||
target_buffer_size: u32,
|
target_buffer_size: u32,
|
||||||
) -> Result<(MuxMapValue, MuxStream), WispError> {
|
) -> Result<(MuxMapValue, MuxStream), WispError> {
|
||||||
let (ch_tx, ch_rx) = mpsc::bounded(self.buffer_size as usize);
|
let (ch_tx, ch_rx) = mpsc::bounded(self.buffer_size as usize);
|
||||||
|
@ -249,7 +255,8 @@ impl MuxInner {
|
||||||
role,
|
role,
|
||||||
stream_type,
|
stream_type,
|
||||||
ch_rx,
|
ch_rx,
|
||||||
stream_tx.clone(),
|
stream_tx,
|
||||||
|
tx,
|
||||||
is_closed,
|
is_closed,
|
||||||
is_closed_event,
|
is_closed_event,
|
||||||
flow_control,
|
flow_control,
|
||||||
|
@ -267,16 +274,6 @@ impl MuxInner {
|
||||||
let mut next_free_stream_id: u32 = 1;
|
let mut next_free_stream_id: u32 = 1;
|
||||||
while let Ok(msg) = stream_rx.recv_async().await {
|
while let Ok(msg) = stream_rx.recv_async().await {
|
||||||
match msg {
|
match msg {
|
||||||
WsEvent::SendPacket(packet, channel) => {
|
|
||||||
if self.stream_map.get(&packet.stream_id).is_some() {
|
|
||||||
let _ = channel.send(self.tx.write_frame(packet.into()).await);
|
|
||||||
} else {
|
|
||||||
let _ = channel.send(Err(WispError::InvalidStreamId));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
WsEvent::SendBytes(packet, channel) => {
|
|
||||||
let _ = channel.send(self.tx.write_frame(ws::Frame::binary(packet)).await);
|
|
||||||
}
|
|
||||||
WsEvent::CreateStream(stream_type, host, port, channel) => {
|
WsEvent::CreateStream(stream_type, host, port, channel) => {
|
||||||
let ret: Result<MuxStream, WispError> = async {
|
let ret: Result<MuxStream, WispError> = async {
|
||||||
let stream_id = next_free_stream_id;
|
let stream_id = next_free_stream_id;
|
||||||
|
@ -290,6 +287,7 @@ impl MuxInner {
|
||||||
stream_type,
|
stream_type,
|
||||||
Role::Client,
|
Role::Client,
|
||||||
stream_tx.clone(),
|
stream_tx.clone(),
|
||||||
|
self.tx.clone(),
|
||||||
0,
|
0,
|
||||||
)
|
)
|
||||||
.await?;
|
.await?;
|
||||||
|
@ -330,6 +328,16 @@ impl MuxInner {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn close_stream(&self, packet: Packet) {
|
||||||
|
if let Some((_, stream)) = self.stream_map.remove(&packet.stream_id) {
|
||||||
|
stream.is_closed.store(true, Ordering::Release);
|
||||||
|
stream.is_closed_event.notify(usize::MAX);
|
||||||
|
stream.flow_control.store(u32::MAX, Ordering::Release);
|
||||||
|
stream.flow_control_event.notify(usize::MAX);
|
||||||
|
drop(stream.stream)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
async fn server_loop<R>(
|
async fn server_loop<R>(
|
||||||
&self,
|
&self,
|
||||||
mut rx: R,
|
mut rx: R,
|
||||||
|
@ -353,6 +361,7 @@ impl MuxInner {
|
||||||
{
|
{
|
||||||
use PacketType::*;
|
use PacketType::*;
|
||||||
match packet.packet_type {
|
match packet.packet_type {
|
||||||
|
Continue(_) | Info(_) => break Err(WispError::InvalidPacketType),
|
||||||
Connect(inner_packet) => {
|
Connect(inner_packet) => {
|
||||||
let (map_value, stream) = self
|
let (map_value, stream) = self
|
||||||
.create_new_stream(
|
.create_new_stream(
|
||||||
|
@ -360,6 +369,7 @@ impl MuxInner {
|
||||||
inner_packet.stream_type,
|
inner_packet.stream_type,
|
||||||
Role::Server,
|
Role::Server,
|
||||||
stream_tx.clone(),
|
stream_tx.clone(),
|
||||||
|
self.tx.clone(),
|
||||||
target_buffer_size,
|
target_buffer_size,
|
||||||
)
|
)
|
||||||
.await?;
|
.await?;
|
||||||
|
@ -383,16 +393,11 @@ impl MuxInner {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Continue(_) | Info(_) => break Err(WispError::InvalidPacketType),
|
|
||||||
Close(_) => {
|
Close(_) => {
|
||||||
if packet.stream_id == 0 {
|
if packet.stream_id == 0 {
|
||||||
break Ok(());
|
break Ok(());
|
||||||
}
|
}
|
||||||
if let Some((_, stream)) = self.stream_map.remove(&packet.stream_id) {
|
self.close_stream(packet)
|
||||||
stream.is_closed.store(true, Ordering::Release);
|
|
||||||
stream.is_closed_event.notify(usize::MAX);
|
|
||||||
drop(stream.stream)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -437,11 +442,7 @@ impl MuxInner {
|
||||||
if packet.stream_id == 0 {
|
if packet.stream_id == 0 {
|
||||||
break Ok(());
|
break Ok(());
|
||||||
}
|
}
|
||||||
if let Some((_, stream)) = self.stream_map.remove(&packet.stream_id) {
|
self.close_stream(packet)
|
||||||
stream.is_closed.store(true, Ordering::Release);
|
|
||||||
stream.is_closed_event.notify(usize::MAX);
|
|
||||||
drop(stream.stream)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -449,6 +450,42 @@ impl MuxInner {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn maybe_wisp_v2<R>(
|
||||||
|
read: &mut R,
|
||||||
|
write: &LockedWebSocketWrite,
|
||||||
|
builders: &[Box<dyn ProtocolExtensionBuilder + Sync + Send>],
|
||||||
|
) -> Result<(Vec<AnyProtocolExtension>, Option<ws::Frame>, bool), WispError>
|
||||||
|
where
|
||||||
|
R: ws::WebSocketRead + Send,
|
||||||
|
{
|
||||||
|
let mut supported_extensions = Vec::new();
|
||||||
|
let mut extra_packet = None;
|
||||||
|
let mut downgraded = true;
|
||||||
|
|
||||||
|
let extension_ids: Vec<_> = builders.iter().map(|x| x.get_id()).collect();
|
||||||
|
if let Some(frame) = select! {
|
||||||
|
x = read.wisp_read_frame(write).fuse() => Some(x?),
|
||||||
|
_ = Delay::new(Duration::from_secs(5)).fuse() => None
|
||||||
|
} {
|
||||||
|
let packet = Packet::maybe_parse_info(frame, Role::Client, builders)?;
|
||||||
|
if let PacketType::Info(info) = packet.packet_type {
|
||||||
|
supported_extensions = info
|
||||||
|
.extensions
|
||||||
|
.into_iter()
|
||||||
|
.filter(|x| extension_ids.contains(&x.get_id()))
|
||||||
|
.collect();
|
||||||
|
downgraded = false;
|
||||||
|
} else {
|
||||||
|
extra_packet.replace(packet.into());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for extension in supported_extensions.iter_mut() {
|
||||||
|
extension.handle_handshake(read, write).await?;
|
||||||
|
}
|
||||||
|
Ok((supported_extensions, extra_packet, downgraded))
|
||||||
|
}
|
||||||
|
|
||||||
/// Server-side multiplexor.
|
/// Server-side multiplexor.
|
||||||
///
|
///
|
||||||
/// # Example
|
/// # Example
|
||||||
|
@ -477,6 +514,8 @@ pub struct ServerMux {
|
||||||
pub supported_extension_ids: Vec<u8>,
|
pub supported_extension_ids: Vec<u8>,
|
||||||
close_tx: mpsc::Sender<WsEvent>,
|
close_tx: mpsc::Sender<WsEvent>,
|
||||||
muxstream_recv: mpsc::Receiver<(ConnectPacket, MuxStream)>,
|
muxstream_recv: mpsc::Receiver<(ConnectPacket, MuxStream)>,
|
||||||
|
tx: ws::LockedWebSocketWrite,
|
||||||
|
fut_exited: Arc<AtomicBool>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ServerMux {
|
impl ServerMux {
|
||||||
|
@ -498,41 +537,29 @@ impl ServerMux {
|
||||||
let (close_tx, close_rx) = mpsc::bounded::<WsEvent>(256);
|
let (close_tx, close_rx) = mpsc::bounded::<WsEvent>(256);
|
||||||
let (tx, rx) = mpsc::unbounded::<(ConnectPacket, MuxStream)>();
|
let (tx, rx) = mpsc::unbounded::<(ConnectPacket, MuxStream)>();
|
||||||
let write = ws::LockedWebSocketWrite::new(Box::new(write));
|
let write = ws::LockedWebSocketWrite::new(Box::new(write));
|
||||||
|
let fut_exited = Arc::new(AtomicBool::new(false));
|
||||||
|
|
||||||
write
|
write
|
||||||
.write_frame(Packet::new_continue(0, buffer_size).into())
|
.write_frame(Packet::new_continue(0, buffer_size).into())
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
let mut supported_extensions = Vec::new();
|
let (supported_extensions, extra_packet, downgraded) =
|
||||||
let mut extra_packet = Vec::with_capacity(1);
|
if let Some(builders) = extension_builders {
|
||||||
let mut downgraded = true;
|
write
|
||||||
|
.write_frame(
|
||||||
if let Some(builders) = extension_builders {
|
Packet::new_info(
|
||||||
let extensions: Vec<_> = builders
|
builders
|
||||||
.iter()
|
.iter()
|
||||||
.map(|x| x.build_to_extension(Role::Server))
|
.map(|x| x.build_to_extension(Role::Client))
|
||||||
.collect();
|
.collect(),
|
||||||
let extension_ids: Vec<_> = extensions.iter().map(|x| x.get_id()).collect();
|
)
|
||||||
write
|
.into(),
|
||||||
.write_frame(Packet::new_info(extensions).into())
|
)
|
||||||
.await?;
|
.await?;
|
||||||
if let Some(frame) = select! {
|
maybe_wisp_v2(&mut read, &write, builders).await?
|
||||||
x = read.wisp_read_frame(&write).fuse() => Some(x?),
|
} else {
|
||||||
_ = Delay::new(Duration::from_secs(5)).fuse() => None
|
(Vec::new(), None, true)
|
||||||
} {
|
};
|
||||||
let packet = Packet::maybe_parse_info(frame, Role::Server, builders)?;
|
|
||||||
if let PacketType::Info(info) = packet.packet_type {
|
|
||||||
supported_extensions = info
|
|
||||||
.extensions
|
|
||||||
.into_iter()
|
|
||||||
.filter(|x| extension_ids.contains(&x.get_id()))
|
|
||||||
.collect();
|
|
||||||
downgraded = false;
|
|
||||||
} else {
|
|
||||||
extra_packet.push(packet.into());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(ServerMuxResult(
|
Ok(ServerMuxResult(
|
||||||
Self {
|
Self {
|
||||||
|
@ -540,11 +567,14 @@ impl ServerMux {
|
||||||
close_tx: close_tx.clone(),
|
close_tx: close_tx.clone(),
|
||||||
downgraded,
|
downgraded,
|
||||||
supported_extension_ids: supported_extensions.iter().map(|x| x.get_id()).collect(),
|
supported_extension_ids: supported_extensions.iter().map(|x| x.get_id()).collect(),
|
||||||
|
tx: write.clone(),
|
||||||
|
fut_exited: fut_exited.clone(),
|
||||||
},
|
},
|
||||||
MuxInner {
|
MuxInner {
|
||||||
tx: write,
|
tx: write,
|
||||||
stream_map: DashMap::new(),
|
stream_map: DashMap::new(),
|
||||||
buffer_size,
|
buffer_size,
|
||||||
|
fut_exited
|
||||||
}
|
}
|
||||||
.server_into_future(
|
.server_into_future(
|
||||||
AppendingWebSocketRead(extra_packet, read),
|
AppendingWebSocketRead(extra_packet, read),
|
||||||
|
@ -558,10 +588,16 @@ impl ServerMux {
|
||||||
|
|
||||||
/// Wait for a stream to be created.
|
/// Wait for a stream to be created.
|
||||||
pub async fn server_new_stream(&self) -> Option<(ConnectPacket, MuxStream)> {
|
pub async fn server_new_stream(&self) -> Option<(ConnectPacket, MuxStream)> {
|
||||||
|
if self.fut_exited.load(Ordering::Acquire) {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
self.muxstream_recv.recv_async().await.ok()
|
self.muxstream_recv.recv_async().await.ok()
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn close_internal(&self, reason: Option<CloseReason>) -> Result<(), WispError> {
|
async fn close_internal(&self, reason: Option<CloseReason>) -> Result<(), WispError> {
|
||||||
|
if self.fut_exited.load(Ordering::Acquire) {
|
||||||
|
return Err(WispError::MuxTaskEnded);
|
||||||
|
}
|
||||||
self.close_tx
|
self.close_tx
|
||||||
.send_async(WsEvent::EndFut(reason))
|
.send_async(WsEvent::EndFut(reason))
|
||||||
.await
|
.await
|
||||||
|
@ -570,20 +606,27 @@ impl ServerMux {
|
||||||
|
|
||||||
/// Close all streams.
|
/// Close all streams.
|
||||||
///
|
///
|
||||||
/// Also terminates the multiplexor future. Waiting for a new stream will never succeed after
|
/// Also terminates the multiplexor future.
|
||||||
/// this function is called.
|
|
||||||
pub async fn close(&self) -> Result<(), WispError> {
|
pub async fn close(&self) -> Result<(), WispError> {
|
||||||
self.close_internal(None).await
|
self.close_internal(None).await
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Close all streams and send an extension incompatibility error to the client.
|
/// Close all streams and send an extension incompatibility error to the client.
|
||||||
///
|
///
|
||||||
/// Also terminates the multiplexor future. Waiting for a new stream will never succed after
|
/// Also terminates the multiplexor future.
|
||||||
/// this function is called.
|
|
||||||
pub async fn close_extension_incompat(&self) -> Result<(), WispError> {
|
pub async fn close_extension_incompat(&self) -> Result<(), WispError> {
|
||||||
self.close_internal(Some(CloseReason::IncompatibleExtensions))
|
self.close_internal(Some(CloseReason::IncompatibleExtensions))
|
||||||
.await
|
.await
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Get a protocol extension stream for sending packets with stream id 0.
|
||||||
|
pub fn get_protocol_extension_stream(&self) -> MuxProtocolExtensionStream {
|
||||||
|
MuxProtocolExtensionStream {
|
||||||
|
stream_id: 0,
|
||||||
|
tx: self.tx.clone(),
|
||||||
|
is_closed: self.fut_exited.clone(),
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Drop for ServerMux {
|
impl Drop for ServerMux {
|
||||||
|
@ -656,6 +699,8 @@ pub struct ClientMux {
|
||||||
/// Extensions that are supported by both sides.
|
/// Extensions that are supported by both sides.
|
||||||
pub supported_extension_ids: Vec<u8>,
|
pub supported_extension_ids: Vec<u8>,
|
||||||
stream_tx: mpsc::Sender<WsEvent>,
|
stream_tx: mpsc::Sender<WsEvent>,
|
||||||
|
tx: ws::LockedWebSocketWrite,
|
||||||
|
fut_exited: Arc<AtomicBool>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ClientMux {
|
impl ClientMux {
|
||||||
|
@ -675,44 +720,30 @@ impl ClientMux {
|
||||||
{
|
{
|
||||||
let write = ws::LockedWebSocketWrite::new(Box::new(write));
|
let write = ws::LockedWebSocketWrite::new(Box::new(write));
|
||||||
let first_packet = Packet::try_from(read.wisp_read_frame(&write).await?)?;
|
let first_packet = Packet::try_from(read.wisp_read_frame(&write).await?)?;
|
||||||
|
let fut_exited = Arc::new(AtomicBool::new(false));
|
||||||
|
|
||||||
if first_packet.stream_id != 0 {
|
if first_packet.stream_id != 0 {
|
||||||
return Err(WispError::InvalidStreamId);
|
return Err(WispError::InvalidStreamId);
|
||||||
}
|
}
|
||||||
if let PacketType::Continue(packet) = first_packet.packet_type {
|
if let PacketType::Continue(packet) = first_packet.packet_type {
|
||||||
let mut supported_extensions = Vec::new();
|
let (supported_extensions, extra_packet, downgraded) =
|
||||||
let mut extra_packet = Vec::with_capacity(1);
|
if let Some(builders) = extension_builders {
|
||||||
let mut downgraded = true;
|
let x = maybe_wisp_v2(&mut read, &write, builders).await?;
|
||||||
|
write
|
||||||
if let Some(builders) = extension_builders {
|
.write_frame(
|
||||||
let extensions: Vec<_> = builders
|
Packet::new_info(
|
||||||
.iter()
|
builders
|
||||||
.map(|x| x.build_to_extension(Role::Client))
|
.iter()
|
||||||
.collect();
|
.map(|x| x.build_to_extension(Role::Client))
|
||||||
let extension_ids: Vec<_> = extensions.iter().map(|x| x.get_id()).collect();
|
.collect(),
|
||||||
if let Some(frame) = select! {
|
)
|
||||||
x = read.wisp_read_frame(&write).fuse() => Some(x?),
|
.into(),
|
||||||
_ = Delay::new(Duration::from_secs(5)).fuse() => None
|
)
|
||||||
} {
|
.await?;
|
||||||
let packet = Packet::maybe_parse_info(frame, Role::Client, builders)?;
|
x
|
||||||
if let PacketType::Info(info) = packet.packet_type {
|
} else {
|
||||||
supported_extensions = info
|
(Vec::new(), None, true)
|
||||||
.extensions
|
};
|
||||||
.into_iter()
|
|
||||||
.filter(|x| extension_ids.contains(&x.get_id()))
|
|
||||||
.collect();
|
|
||||||
write
|
|
||||||
.write_frame(Packet::new_info(extensions).into())
|
|
||||||
.await?;
|
|
||||||
downgraded = false;
|
|
||||||
} else {
|
|
||||||
extra_packet.push(packet.into());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for extension in supported_extensions.iter_mut() {
|
|
||||||
extension.handle_handshake(&mut read, &write).await?;
|
|
||||||
}
|
|
||||||
|
|
||||||
let (tx, rx) = mpsc::bounded::<WsEvent>(256);
|
let (tx, rx) = mpsc::bounded::<WsEvent>(256);
|
||||||
Ok(ClientMuxResult(
|
Ok(ClientMuxResult(
|
||||||
|
@ -723,11 +754,14 @@ impl ClientMux {
|
||||||
.iter()
|
.iter()
|
||||||
.map(|x| x.get_id())
|
.map(|x| x.get_id())
|
||||||
.collect(),
|
.collect(),
|
||||||
|
tx: write.clone(),
|
||||||
|
fut_exited: fut_exited.clone(),
|
||||||
},
|
},
|
||||||
MuxInner {
|
MuxInner {
|
||||||
tx: write,
|
tx: write,
|
||||||
stream_map: DashMap::new(),
|
stream_map: DashMap::new(),
|
||||||
buffer_size: packet.buffer_remaining,
|
buffer_size: packet.buffer_remaining,
|
||||||
|
fut_exited
|
||||||
}
|
}
|
||||||
.client_into_future(
|
.client_into_future(
|
||||||
AppendingWebSocketRead(extra_packet, read),
|
AppendingWebSocketRead(extra_packet, read),
|
||||||
|
@ -748,6 +782,9 @@ impl ClientMux {
|
||||||
host: String,
|
host: String,
|
||||||
port: u16,
|
port: u16,
|
||||||
) -> Result<MuxStream, WispError> {
|
) -> Result<MuxStream, WispError> {
|
||||||
|
if self.fut_exited.load(Ordering::Acquire) {
|
||||||
|
return Err(WispError::MuxTaskEnded);
|
||||||
|
}
|
||||||
if stream_type == StreamType::Udp
|
if stream_type == StreamType::Udp
|
||||||
&& !self
|
&& !self
|
||||||
.supported_extension_ids
|
.supported_extension_ids
|
||||||
|
@ -767,6 +804,9 @@ impl ClientMux {
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn close_internal(&self, reason: Option<CloseReason>) -> Result<(), WispError> {
|
async fn close_internal(&self, reason: Option<CloseReason>) -> Result<(), WispError> {
|
||||||
|
if self.fut_exited.load(Ordering::Acquire) {
|
||||||
|
return Err(WispError::MuxTaskEnded);
|
||||||
|
}
|
||||||
self.stream_tx
|
self.stream_tx
|
||||||
.send_async(WsEvent::EndFut(reason))
|
.send_async(WsEvent::EndFut(reason))
|
||||||
.await
|
.await
|
||||||
|
@ -775,20 +815,27 @@ impl ClientMux {
|
||||||
|
|
||||||
/// Close all streams.
|
/// Close all streams.
|
||||||
///
|
///
|
||||||
/// Also terminates the multiplexor future. Creating a stream is UB after calling this
|
/// Also terminates the multiplexor future.
|
||||||
/// function.
|
|
||||||
pub async fn close(&self) -> Result<(), WispError> {
|
pub async fn close(&self) -> Result<(), WispError> {
|
||||||
self.close_internal(None).await
|
self.close_internal(None).await
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Close all streams and send an extension incompatibility error to the client.
|
/// Close all streams and send an extension incompatibility error to the client.
|
||||||
///
|
///
|
||||||
/// Also terminates the multiplexor future. Creating a stream is UB after calling this
|
/// Also terminates the multiplexor future.
|
||||||
/// function.
|
|
||||||
pub async fn close_extension_incompat(&self) -> Result<(), WispError> {
|
pub async fn close_extension_incompat(&self) -> Result<(), WispError> {
|
||||||
self.close_internal(Some(CloseReason::IncompatibleExtensions))
|
self.close_internal(Some(CloseReason::IncompatibleExtensions))
|
||||||
.await
|
.await
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Get a protocol extension stream for sending packets with stream id 0.
|
||||||
|
pub fn get_protocol_extension_stream(&self) -> MuxProtocolExtensionStream {
|
||||||
|
MuxProtocolExtensionStream {
|
||||||
|
stream_id: 0,
|
||||||
|
tx: self.tx.clone(),
|
||||||
|
is_closed: self.fut_exited.clone(),
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Drop for ClientMux {
|
impl Drop for ClientMux {
|
||||||
|
@ -812,7 +859,10 @@ where
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Require protocol extensions by their ID.
|
/// Require protocol extensions by their ID.
|
||||||
pub async fn with_required_extensions(self, extensions: &[u8]) -> Result<(ClientMux, F), WispError> {
|
pub async fn with_required_extensions(
|
||||||
|
self,
|
||||||
|
extensions: &[u8],
|
||||||
|
) -> Result<(ClientMux, F), WispError> {
|
||||||
let mut unsupported_extensions = Vec::new();
|
let mut unsupported_extensions = Vec::new();
|
||||||
for extension in extensions {
|
for extension in extensions {
|
||||||
if !self.0.supported_extension_ids.contains(extension) {
|
if !self.0.supported_extension_ids.contains(extension) {
|
||||||
|
@ -830,6 +880,7 @@ where
|
||||||
|
|
||||||
/// Shorthand for `with_required_extensions(&[UdpProtocolExtension::ID])`
|
/// Shorthand for `with_required_extensions(&[UdpProtocolExtension::ID])`
|
||||||
pub async fn with_udp_extension_required(self) -> Result<(ClientMux, F), WispError> {
|
pub async fn with_udp_extension_required(self) -> Result<(ClientMux, F), WispError> {
|
||||||
self.with_required_extensions(&[UdpProtocolExtension::ID]).await
|
self.with_required_extensions(&[UdpProtocolExtension::ID])
|
||||||
|
.await
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -362,12 +362,12 @@ impl Packet {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn raw_encode(packet_type: u8, stream_id: u32, bytes: Bytes) -> Bytes {
|
pub(crate) fn raw_encode(packet_type: u8, stream_id: u32, bytes: Bytes) -> BytesMut {
|
||||||
let mut encoded = BytesMut::with_capacity(1 + 4 + bytes.len());
|
let mut encoded = BytesMut::with_capacity(1 + 4 + bytes.len());
|
||||||
encoded.put_u8(packet_type);
|
encoded.put_u8(packet_type);
|
||||||
encoded.put_u32_le(stream_id);
|
encoded.put_u32_le(stream_id);
|
||||||
encoded.extend(bytes);
|
encoded.extend(bytes);
|
||||||
encoded.freeze()
|
encoded
|
||||||
}
|
}
|
||||||
|
|
||||||
fn parse_packet(packet_type: u8, mut bytes: Bytes) -> Result<Self, WispError> {
|
fn parse_packet(packet_type: u8, mut bytes: Bytes) -> Result<Self, WispError> {
|
||||||
|
@ -396,7 +396,7 @@ impl Packet {
|
||||||
if frame.opcode != OpCode::Binary {
|
if frame.opcode != OpCode::Binary {
|
||||||
return Err(WispError::WsFrameInvalidType);
|
return Err(WispError::WsFrameInvalidType);
|
||||||
}
|
}
|
||||||
let mut bytes = frame.payload;
|
let mut bytes = frame.payload.freeze();
|
||||||
if bytes.remaining() < 1 {
|
if bytes.remaining() < 1 {
|
||||||
return Err(WispError::PacketTooSmall);
|
return Err(WispError::PacketTooSmall);
|
||||||
}
|
}
|
||||||
|
@ -420,22 +420,40 @@ impl Packet {
|
||||||
if frame.opcode != OpCode::Binary {
|
if frame.opcode != OpCode::Binary {
|
||||||
return Err(WispError::WsFrameInvalidType);
|
return Err(WispError::WsFrameInvalidType);
|
||||||
}
|
}
|
||||||
let mut bytes = frame.payload;
|
let mut bytes = frame.payload.freeze();
|
||||||
if bytes.remaining() < 1 {
|
if bytes.remaining() < 1 {
|
||||||
return Err(WispError::PacketTooSmall);
|
return Err(WispError::PacketTooSmall);
|
||||||
}
|
}
|
||||||
let packet_type = bytes.get_u8();
|
let packet_type = bytes.get_u8();
|
||||||
if let Some(extension) = extensions
|
match packet_type {
|
||||||
.iter_mut()
|
0x01 => Ok(Some(Self {
|
||||||
.find(|x| x.get_supported_packets().iter().any(|x| *x == packet_type))
|
stream_id: bytes.get_u32_le(),
|
||||||
{
|
packet_type: PacketType::Connect(bytes.try_into()?),
|
||||||
extension.handle_packet(bytes, read, write).await?;
|
})),
|
||||||
Ok(None)
|
0x02 => Ok(Some(Self {
|
||||||
} else if packet_type == 0x05 {
|
stream_id: bytes.get_u32_le(),
|
||||||
// Server may send a 0x05 in handshake since it's Wisp v2 but we may be Wisp v1 so we need to ignore 0x05
|
packet_type: PacketType::Data(bytes),
|
||||||
Ok(None)
|
})),
|
||||||
} else {
|
0x03 => Ok(Some(Self {
|
||||||
Ok(Some(Self::parse_packet(packet_type, bytes)?))
|
stream_id: bytes.get_u32_le(),
|
||||||
|
packet_type: PacketType::Continue(bytes.try_into()?),
|
||||||
|
})),
|
||||||
|
0x04 => Ok(Some(Self {
|
||||||
|
stream_id: bytes.get_u32_le(),
|
||||||
|
packet_type: PacketType::Close(bytes.try_into()?),
|
||||||
|
})),
|
||||||
|
0x05 => Ok(None),
|
||||||
|
packet_type => {
|
||||||
|
if let Some(extension) = extensions
|
||||||
|
.iter_mut()
|
||||||
|
.find(|x| x.get_supported_packets().iter().any(|x| *x == packet_type))
|
||||||
|
{
|
||||||
|
extension.handle_packet(bytes, read, write).await?;
|
||||||
|
Ok(None)
|
||||||
|
} else {
|
||||||
|
Err(WispError::InvalidPacketType)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -500,7 +518,7 @@ impl TryFrom<Bytes> for Packet {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl From<Packet> for Bytes {
|
impl From<Packet> for BytesMut {
|
||||||
fn from(packet: Packet) -> Self {
|
fn from(packet: Packet) -> Self {
|
||||||
Packet::raw_encode(
|
Packet::raw_encode(
|
||||||
packet.packet_type.as_u8(),
|
packet.packet_type.as_u8(),
|
||||||
|
@ -519,7 +537,7 @@ impl TryFrom<ws::Frame> for Packet {
|
||||||
if frame.opcode != ws::OpCode::Binary {
|
if frame.opcode != ws::OpCode::Binary {
|
||||||
return Err(Self::Error::WsFrameInvalidType);
|
return Err(Self::Error::WsFrameInvalidType);
|
||||||
}
|
}
|
||||||
frame.payload.try_into()
|
frame.payload.freeze().try_into()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,4 +1,8 @@
|
||||||
use crate::{sink_unfold, CloseReason, Packet, Role, StreamType, WispError};
|
use crate::{
|
||||||
|
sink_unfold,
|
||||||
|
ws::{Frame, LockedWebSocketWrite},
|
||||||
|
CloseReason, Packet, Role, StreamType, WispError,
|
||||||
|
};
|
||||||
|
|
||||||
pub use async_io_stream::IoStream;
|
pub use async_io_stream::IoStream;
|
||||||
use bytes::Bytes;
|
use bytes::Bytes;
|
||||||
|
@ -20,8 +24,6 @@ use std::{
|
||||||
};
|
};
|
||||||
|
|
||||||
pub(crate) enum WsEvent {
|
pub(crate) enum WsEvent {
|
||||||
SendPacket(Packet, oneshot::Sender<Result<(), WispError>>),
|
|
||||||
SendBytes(Bytes, oneshot::Sender<Result<(), WispError>>),
|
|
||||||
Close(Packet, oneshot::Sender<Result<(), WispError>>),
|
Close(Packet, oneshot::Sender<Result<(), WispError>>),
|
||||||
CreateStream(
|
CreateStream(
|
||||||
StreamType,
|
StreamType,
|
||||||
|
@ -39,7 +41,7 @@ pub struct MuxStreamRead {
|
||||||
/// Type of the stream.
|
/// Type of the stream.
|
||||||
pub stream_type: StreamType,
|
pub stream_type: StreamType,
|
||||||
role: Role,
|
role: Role,
|
||||||
tx: mpsc::Sender<WsEvent>,
|
tx: LockedWebSocketWrite,
|
||||||
rx: mpsc::Receiver<Bytes>,
|
rx: mpsc::Receiver<Bytes>,
|
||||||
is_closed: Arc<AtomicBool>,
|
is_closed: Arc<AtomicBool>,
|
||||||
is_closed_event: Arc<Event>,
|
is_closed_event: Arc<Event>,
|
||||||
|
@ -60,19 +62,17 @@ impl MuxStreamRead {
|
||||||
};
|
};
|
||||||
if self.role == Role::Server && self.stream_type == StreamType::Tcp {
|
if self.role == Role::Server && self.stream_type == StreamType::Tcp {
|
||||||
let val = self.flow_control_read.fetch_add(1, Ordering::AcqRel) + 1;
|
let val = self.flow_control_read.fetch_add(1, Ordering::AcqRel) + 1;
|
||||||
if val > self.target_flow_control {
|
if val > self.target_flow_control && !self.is_closed.load(Ordering::Acquire) {
|
||||||
let (tx, rx) = oneshot::channel::<Result<(), WispError>>();
|
|
||||||
self.tx
|
self.tx
|
||||||
.send_async(WsEvent::SendPacket(
|
.write_frame(
|
||||||
Packet::new_continue(
|
Packet::new_continue(
|
||||||
self.stream_id,
|
self.stream_id,
|
||||||
self.flow_control.fetch_add(val, Ordering::AcqRel) + val,
|
self.flow_control.fetch_add(val, Ordering::AcqRel) + val,
|
||||||
),
|
)
|
||||||
tx,
|
.into(),
|
||||||
))
|
)
|
||||||
.await
|
.await
|
||||||
.ok()?;
|
.ok()?;
|
||||||
rx.await.ok()?.ok()?;
|
|
||||||
self.flow_control_read.store(0, Ordering::Release);
|
self.flow_control_read.store(0, Ordering::Release);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -93,7 +93,8 @@ pub struct MuxStreamWrite {
|
||||||
/// Type of the stream.
|
/// Type of the stream.
|
||||||
pub stream_type: StreamType,
|
pub stream_type: StreamType,
|
||||||
role: Role,
|
role: Role,
|
||||||
tx: mpsc::Sender<WsEvent>,
|
mux_tx: mpsc::Sender<WsEvent>,
|
||||||
|
tx: LockedWebSocketWrite,
|
||||||
is_closed: Arc<AtomicBool>,
|
is_closed: Arc<AtomicBool>,
|
||||||
continue_recieved: Arc<Event>,
|
continue_recieved: Arc<Event>,
|
||||||
flow_control: Arc<AtomicU32>,
|
flow_control: Arc<AtomicU32>,
|
||||||
|
@ -102,24 +103,20 @@ pub struct MuxStreamWrite {
|
||||||
impl MuxStreamWrite {
|
impl MuxStreamWrite {
|
||||||
/// Write data to the stream.
|
/// Write data to the stream.
|
||||||
pub async fn write(&self, data: Bytes) -> Result<(), WispError> {
|
pub async fn write(&self, data: Bytes) -> Result<(), WispError> {
|
||||||
if self.is_closed.load(Ordering::Acquire) {
|
|
||||||
return Err(WispError::StreamAlreadyClosed);
|
|
||||||
}
|
|
||||||
if self.role == Role::Client
|
if self.role == Role::Client
|
||||||
&& self.stream_type == StreamType::Tcp
|
&& self.stream_type == StreamType::Tcp
|
||||||
&& self.flow_control.load(Ordering::Acquire) == 0
|
&& self.flow_control.load(Ordering::Acquire) == 0
|
||||||
{
|
{
|
||||||
self.continue_recieved.listen().await;
|
self.continue_recieved.listen().await;
|
||||||
}
|
}
|
||||||
let (tx, rx) = oneshot::channel::<Result<(), WispError>>();
|
if self.is_closed.load(Ordering::Acquire) {
|
||||||
|
return Err(WispError::StreamAlreadyClosed);
|
||||||
|
}
|
||||||
|
|
||||||
self.tx
|
self.tx
|
||||||
.send_async(WsEvent::SendPacket(
|
.write_frame(Packet::new_data(self.stream_id, data).into())
|
||||||
Packet::new_data(self.stream_id, data),
|
.await?;
|
||||||
tx,
|
|
||||||
))
|
|
||||||
.await
|
|
||||||
.map_err(|_| WispError::MuxMessageFailedToSend)?;
|
|
||||||
rx.await.map_err(|_| WispError::MuxMessageFailedToRecv)??;
|
|
||||||
if self.role == Role::Client && self.stream_type == StreamType::Tcp {
|
if self.role == Role::Client && self.stream_type == StreamType::Tcp {
|
||||||
self.flow_control.store(
|
self.flow_control.store(
|
||||||
self.flow_control.load(Ordering::Acquire).saturating_sub(1),
|
self.flow_control.load(Ordering::Acquire).saturating_sub(1),
|
||||||
|
@ -143,7 +140,7 @@ impl MuxStreamWrite {
|
||||||
pub fn get_close_handle(&self) -> MuxStreamCloser {
|
pub fn get_close_handle(&self) -> MuxStreamCloser {
|
||||||
MuxStreamCloser {
|
MuxStreamCloser {
|
||||||
stream_id: self.stream_id,
|
stream_id: self.stream_id,
|
||||||
close_channel: self.tx.clone(),
|
close_channel: self.mux_tx.clone(),
|
||||||
is_closed: self.is_closed.clone(),
|
is_closed: self.is_closed.clone(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -165,7 +162,7 @@ impl MuxStreamWrite {
|
||||||
self.is_closed.store(true, Ordering::Release);
|
self.is_closed.store(true, Ordering::Release);
|
||||||
|
|
||||||
let (tx, rx) = oneshot::channel::<Result<(), WispError>>();
|
let (tx, rx) = oneshot::channel::<Result<(), WispError>>();
|
||||||
self.tx
|
self.mux_tx
|
||||||
.send_async(WsEvent::Close(
|
.send_async(WsEvent::Close(
|
||||||
Packet::new_close(self.stream_id, reason),
|
Packet::new_close(self.stream_id, reason),
|
||||||
tx,
|
tx,
|
||||||
|
@ -199,7 +196,7 @@ impl Drop for MuxStreamWrite {
|
||||||
if !self.is_closed.load(Ordering::Acquire) {
|
if !self.is_closed.load(Ordering::Acquire) {
|
||||||
self.is_closed.store(true, Ordering::Release);
|
self.is_closed.store(true, Ordering::Release);
|
||||||
let (tx, _) = oneshot::channel();
|
let (tx, _) = oneshot::channel();
|
||||||
let _ = self.tx.send(WsEvent::Close(
|
let _ = self.mux_tx.send(WsEvent::Close(
|
||||||
Packet::new_close(self.stream_id, CloseReason::Unknown),
|
Packet::new_close(self.stream_id, CloseReason::Unknown),
|
||||||
tx,
|
tx,
|
||||||
));
|
));
|
||||||
|
@ -222,7 +219,8 @@ impl MuxStream {
|
||||||
role: Role,
|
role: Role,
|
||||||
stream_type: StreamType,
|
stream_type: StreamType,
|
||||||
rx: mpsc::Receiver<Bytes>,
|
rx: mpsc::Receiver<Bytes>,
|
||||||
tx: mpsc::Sender<WsEvent>,
|
mux_tx: mpsc::Sender<WsEvent>,
|
||||||
|
tx: LockedWebSocketWrite,
|
||||||
is_closed: Arc<AtomicBool>,
|
is_closed: Arc<AtomicBool>,
|
||||||
is_closed_event: Arc<Event>,
|
is_closed_event: Arc<Event>,
|
||||||
flow_control: Arc<AtomicU32>,
|
flow_control: Arc<AtomicU32>,
|
||||||
|
@ -247,6 +245,7 @@ impl MuxStream {
|
||||||
stream_id,
|
stream_id,
|
||||||
stream_type,
|
stream_type,
|
||||||
role,
|
role,
|
||||||
|
mux_tx,
|
||||||
tx,
|
tx,
|
||||||
is_closed: is_closed.clone(),
|
is_closed: is_closed.clone(),
|
||||||
flow_control: flow_control.clone(),
|
flow_control: flow_control.clone(),
|
||||||
|
@ -339,26 +338,23 @@ impl MuxStreamCloser {
|
||||||
pub struct MuxProtocolExtensionStream {
|
pub struct MuxProtocolExtensionStream {
|
||||||
/// ID of the stream.
|
/// ID of the stream.
|
||||||
pub stream_id: u32,
|
pub stream_id: u32,
|
||||||
tx: mpsc::Sender<WsEvent>,
|
pub(crate) tx: LockedWebSocketWrite,
|
||||||
is_closed: Arc<AtomicBool>,
|
pub(crate) is_closed: Arc<AtomicBool>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl MuxProtocolExtensionStream {
|
impl MuxProtocolExtensionStream {
|
||||||
/// Send a protocol extension packet.
|
/// Send a protocol extension packet with this stream's ID.
|
||||||
pub async fn send(&self, packet_type: u8, data: Bytes) -> Result<(), WispError> {
|
pub async fn send(&self, packet_type: u8, data: Bytes) -> Result<(), WispError> {
|
||||||
if self.is_closed.load(Ordering::Acquire) {
|
if self.is_closed.load(Ordering::Acquire) {
|
||||||
return Err(WispError::StreamAlreadyClosed);
|
return Err(WispError::StreamAlreadyClosed);
|
||||||
}
|
}
|
||||||
let (tx, rx) = oneshot::channel::<Result<(), WispError>>();
|
|
||||||
self.tx
|
self.tx
|
||||||
.send_async(WsEvent::SendBytes(
|
.write_frame(Frame::binary(Packet::raw_encode(
|
||||||
Packet::raw_encode(packet_type, self.stream_id, data),
|
packet_type,
|
||||||
tx,
|
self.stream_id,
|
||||||
))
|
data,
|
||||||
|
)))
|
||||||
.await
|
.await
|
||||||
.map_err(|_| WispError::MuxMessageFailedToSend)?;
|
|
||||||
rx.await.map_err(|_| WispError::MuxMessageFailedToRecv)??;
|
|
||||||
Ok(())
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -4,9 +4,11 @@
|
||||||
//! for other WebSocket implementations.
|
//! for other WebSocket implementations.
|
||||||
//!
|
//!
|
||||||
//! [`fastwebsockets`]: https://github.com/MercuryWorkshop/epoxy-tls/blob/multiplexed/wisp/src/fastwebsockets.rs
|
//! [`fastwebsockets`]: https://github.com/MercuryWorkshop/epoxy-tls/blob/multiplexed/wisp/src/fastwebsockets.rs
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
use crate::WispError;
|
use crate::WispError;
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use bytes::Bytes;
|
use bytes::BytesMut;
|
||||||
use futures::lock::Mutex;
|
use futures::lock::Mutex;
|
||||||
|
|
||||||
/// Opcode of the WebSocket frame.
|
/// Opcode of the WebSocket frame.
|
||||||
|
@ -32,12 +34,12 @@ pub struct Frame {
|
||||||
/// Opcode of the WebSocket frame.
|
/// Opcode of the WebSocket frame.
|
||||||
pub opcode: OpCode,
|
pub opcode: OpCode,
|
||||||
/// Payload of the WebSocket frame.
|
/// Payload of the WebSocket frame.
|
||||||
pub payload: Bytes,
|
pub payload: BytesMut,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Frame {
|
impl Frame {
|
||||||
/// Create a new text frame.
|
/// Create a new text frame.
|
||||||
pub fn text(payload: Bytes) -> Self {
|
pub fn text(payload: BytesMut) -> Self {
|
||||||
Self {
|
Self {
|
||||||
finished: true,
|
finished: true,
|
||||||
opcode: OpCode::Text,
|
opcode: OpCode::Text,
|
||||||
|
@ -46,7 +48,7 @@ impl Frame {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Create a new binary frame.
|
/// Create a new binary frame.
|
||||||
pub fn binary(payload: Bytes) -> Self {
|
pub fn binary(payload: BytesMut) -> Self {
|
||||||
Self {
|
Self {
|
||||||
finished: true,
|
finished: true,
|
||||||
opcode: OpCode::Binary,
|
opcode: OpCode::Binary,
|
||||||
|
@ -55,7 +57,7 @@ impl Frame {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Create a new close frame.
|
/// Create a new close frame.
|
||||||
pub fn close(payload: Bytes) -> Self {
|
pub fn close(payload: BytesMut) -> Self {
|
||||||
Self {
|
Self {
|
||||||
finished: true,
|
finished: true,
|
||||||
opcode: OpCode::Close,
|
opcode: OpCode::Close,
|
||||||
|
@ -82,12 +84,13 @@ pub trait WebSocketWrite {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Locked WebSocket.
|
/// Locked WebSocket.
|
||||||
pub struct LockedWebSocketWrite(Mutex<Box<dyn WebSocketWrite + Send>>);
|
#[derive(Clone)]
|
||||||
|
pub struct LockedWebSocketWrite(Arc<Mutex<Box<dyn WebSocketWrite + Send>>>);
|
||||||
|
|
||||||
impl LockedWebSocketWrite {
|
impl LockedWebSocketWrite {
|
||||||
/// Create a new locked websocket.
|
/// Create a new locked websocket.
|
||||||
pub fn new(ws: Box<dyn WebSocketWrite + Send>) -> Self {
|
pub fn new(ws: Box<dyn WebSocketWrite + Send>) -> Self {
|
||||||
Self(Mutex::new(ws))
|
Self(Mutex::new(ws).into())
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Write a frame to the websocket.
|
/// Write a frame to the websocket.
|
||||||
|
@ -101,7 +104,7 @@ impl LockedWebSocketWrite {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) struct AppendingWebSocketRead<R>(pub Vec<Frame>, pub R)
|
pub(crate) struct AppendingWebSocketRead<R>(pub Option<Frame>, pub R)
|
||||||
where
|
where
|
||||||
R: WebSocketRead + Send;
|
R: WebSocketRead + Send;
|
||||||
|
|
||||||
|
@ -111,7 +114,7 @@ where
|
||||||
R: WebSocketRead + Send,
|
R: WebSocketRead + Send,
|
||||||
{
|
{
|
||||||
async fn wisp_read_frame(&mut self, tx: &LockedWebSocketWrite) -> Result<Frame, WispError> {
|
async fn wisp_read_frame(&mut self, tx: &LockedWebSocketWrite) -> Result<Frame, WispError> {
|
||||||
if let Some(x) = self.0.pop() {
|
if let Some(x) = self.0.take() {
|
||||||
return Ok(x);
|
return Ok(x);
|
||||||
}
|
}
|
||||||
return self.1.wisp_read_frame(tx).await;
|
return self.1.wisp_read_frame(tx).await;
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue