mirror of
https://github.com/MercuryWorkshop/epoxy-tls.git
synced 2025-05-13 06:20:02 -04:00
preliminary support for wisp v2
This commit is contained in:
parent
98072be3d4
commit
ef5ed52e71
18 changed files with 772 additions and 206 deletions
383
wisp/src/lib.rs
383
wisp/src/lib.rs
|
@ -4,6 +4,7 @@
|
|||
//!
|
||||
//! [Wisp]: https://github.com/MercuryWorkshop/wisp-protocol
|
||||
|
||||
pub mod extensions;
|
||||
#[cfg(feature = "fastwebsockets")]
|
||||
#[cfg_attr(docsrs, doc(cfg(feature = "fastwebsockets")))]
|
||||
mod fastwebsockets;
|
||||
|
@ -12,18 +13,28 @@ mod sink_unfold;
|
|||
mod stream;
|
||||
pub mod ws;
|
||||
|
||||
pub use crate::packet::*;
|
||||
pub use crate::stream::*;
|
||||
pub use crate::{packet::*, stream::*};
|
||||
|
||||
use bytes::Bytes;
|
||||
use dashmap::DashMap;
|
||||
use event_listener::Event;
|
||||
use futures::SinkExt;
|
||||
use futures::{channel::mpsc, Future, FutureExt, StreamExt};
|
||||
use std::sync::{
|
||||
atomic::{AtomicBool, AtomicU32, Ordering},
|
||||
Arc,
|
||||
use extensions::{udp::UdpProtocolExtension, AnyProtocolExtension, ProtocolExtensionBuilder};
|
||||
use futures::{
|
||||
channel::{mpsc, oneshot},
|
||||
select, Future, FutureExt, SinkExt, StreamExt,
|
||||
};
|
||||
use futures_timer::Delay;
|
||||
use std::{
|
||||
sync::{
|
||||
atomic::{AtomicBool, AtomicU32, Ordering},
|
||||
Arc,
|
||||
},
|
||||
time::Duration,
|
||||
};
|
||||
use ws::AppendingWebSocketRead;
|
||||
|
||||
/// Wisp version supported by this crate.
|
||||
pub const WISP_VERSION: WispVersion = WispVersion { major: 2, minor: 0 };
|
||||
|
||||
/// The role of the multiplexor.
|
||||
#[derive(Debug, PartialEq, Copy, Clone)]
|
||||
|
@ -37,9 +48,9 @@ pub enum Role {
|
|||
/// Errors the Wisp implementation can return.
|
||||
#[derive(Debug)]
|
||||
pub enum WispError {
|
||||
/// The packet recieved did not have enough data.
|
||||
/// The packet received did not have enough data.
|
||||
PacketTooSmall,
|
||||
/// The packet recieved had an invalid type.
|
||||
/// The packet received had an invalid type.
|
||||
InvalidPacketType,
|
||||
/// The stream had an invalid type.
|
||||
InvalidStreamType,
|
||||
|
@ -47,19 +58,19 @@ pub enum WispError {
|
|||
InvalidStreamId,
|
||||
/// The close packet had an invalid reason.
|
||||
InvalidCloseReason,
|
||||
/// The URI recieved was invalid.
|
||||
/// The URI received was invalid.
|
||||
InvalidUri,
|
||||
/// The URI recieved had no host.
|
||||
/// The URI received had no host.
|
||||
UriHasNoHost,
|
||||
/// The URI recieved had no port.
|
||||
/// The URI received had no port.
|
||||
UriHasNoPort,
|
||||
/// The max stream count was reached.
|
||||
MaxStreamCountReached,
|
||||
/// The stream had already been closed.
|
||||
StreamAlreadyClosed,
|
||||
/// The websocket frame recieved had an invalid type.
|
||||
/// The websocket frame received had an invalid type.
|
||||
WsFrameInvalidType,
|
||||
/// The websocket frame recieved was not finished.
|
||||
/// The websocket frame received was not finished.
|
||||
WsFrameNotFinished,
|
||||
/// Error specific to the websocket implementation.
|
||||
WsImplError(Box<dyn std::error::Error + Sync + Send>),
|
||||
|
@ -67,17 +78,33 @@ pub enum WispError {
|
|||
WsImplSocketClosed,
|
||||
/// The websocket implementation did not support the action.
|
||||
WsImplNotSupported,
|
||||
/// Error specific to the protocol extension implementation.
|
||||
ExtensionImplError(Box<dyn std::error::Error + Sync + Send>),
|
||||
/// The protocol extension implementation did not support the action.
|
||||
ExtensionImplNotSupported,
|
||||
/// The UDP protocol extension is not supported by the server.
|
||||
UdpExtensionNotSupported,
|
||||
/// The string was invalid UTF-8.
|
||||
Utf8Error(std::str::Utf8Error),
|
||||
/// The integer failed to convert.
|
||||
TryFromIntError(std::num::TryFromIntError),
|
||||
/// Other error.
|
||||
Other(Box<dyn std::error::Error + Sync + Send>),
|
||||
/// Failed to send message to multiplexor task.
|
||||
MuxMessageFailedToSend,
|
||||
/// Failed to receive message from multiplexor task.
|
||||
MuxMessageFailedToRecv,
|
||||
}
|
||||
|
||||
impl From<std::str::Utf8Error> for WispError {
|
||||
fn from(err: std::str::Utf8Error) -> WispError {
|
||||
WispError::Utf8Error(err)
|
||||
fn from(err: std::str::Utf8Error) -> Self {
|
||||
Self::Utf8Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<std::num::TryFromIntError> for WispError {
|
||||
fn from(value: std::num::TryFromIntError) -> Self {
|
||||
Self::TryFromIntError(value)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -103,9 +130,21 @@ impl std::fmt::Display for WispError {
|
|||
Self::WsImplNotSupported => {
|
||||
write!(f, "Websocket implementation error: unsupported feature")
|
||||
}
|
||||
Self::ExtensionImplError(err) => {
|
||||
write!(f, "Protocol extension implementation error: {}", err)
|
||||
}
|
||||
Self::ExtensionImplNotSupported => {
|
||||
write!(
|
||||
f,
|
||||
"Protocol extension implementation error: unsupported feature"
|
||||
)
|
||||
}
|
||||
Self::UdpExtensionNotSupported => write!(f, "UDP protocol extension not supported"),
|
||||
Self::Utf8Error(err) => write!(f, "UTF-8 error: {}", err),
|
||||
Self::TryFromIntError(err) => write!(f, "Integer conversion error: {}", err),
|
||||
Self::Other(err) => write!(f, "Other error: {}", err),
|
||||
Self::MuxMessageFailedToSend => write!(f, "Failed to send multiplexor message"),
|
||||
Self::MuxMessageFailedToRecv => write!(f, "Failed to receive multiplexor message"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -120,29 +159,27 @@ struct MuxMapValue {
|
|||
is_closed: Arc<AtomicBool>,
|
||||
}
|
||||
|
||||
struct MuxInner<W>
|
||||
where
|
||||
W: ws::WebSocketWrite,
|
||||
{
|
||||
tx: ws::LockedWebSocketWrite<W>,
|
||||
stream_map: Arc<DashMap<u32, MuxMapValue>>,
|
||||
struct MuxInner {
|
||||
tx: ws::LockedWebSocketWrite,
|
||||
stream_map: DashMap<u32, MuxMapValue>,
|
||||
buffer_size: u32,
|
||||
}
|
||||
|
||||
impl<W: ws::WebSocketWrite> MuxInner<W> {
|
||||
impl MuxInner {
|
||||
pub async fn server_into_future<R>(
|
||||
self,
|
||||
rx: R,
|
||||
close_rx: mpsc::Receiver<WsEvent>,
|
||||
muxstream_sender: mpsc::UnboundedSender<(ConnectPacket, MuxStream)>,
|
||||
buffer_size: u32,
|
||||
close_tx: mpsc::Sender<WsEvent>,
|
||||
) -> Result<(), WispError>
|
||||
where
|
||||
R: ws::WebSocketRead,
|
||||
{
|
||||
self.into_future(
|
||||
self.as_future(
|
||||
close_rx,
|
||||
self.server_loop(rx, muxstream_sender, buffer_size, close_tx),
|
||||
close_tx.clone(),
|
||||
self.server_loop(rx, muxstream_sender, close_tx),
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
@ -151,20 +188,23 @@ impl<W: ws::WebSocketWrite> MuxInner<W> {
|
|||
self,
|
||||
rx: R,
|
||||
close_rx: mpsc::Receiver<WsEvent>,
|
||||
close_tx: mpsc::Sender<WsEvent>,
|
||||
) -> Result<(), WispError>
|
||||
where
|
||||
R: ws::WebSocketRead,
|
||||
{
|
||||
self.into_future(close_rx, self.client_loop(rx)).await
|
||||
self.as_future(close_rx, close_tx, self.client_loop(rx))
|
||||
.await
|
||||
}
|
||||
|
||||
async fn into_future(
|
||||
async fn as_future(
|
||||
&self,
|
||||
close_rx: mpsc::Receiver<WsEvent>,
|
||||
close_tx: mpsc::Sender<WsEvent>,
|
||||
wisp_fut: impl Future<Output = Result<(), WispError>>,
|
||||
) -> Result<(), WispError> {
|
||||
let ret = futures::select! {
|
||||
_ = self.stream_loop(close_rx).fuse() => Ok(()),
|
||||
_ = self.stream_loop(close_rx, close_tx).fuse() => Ok(()),
|
||||
x = wisp_fut.fuse() => x,
|
||||
};
|
||||
self.stream_map.iter_mut().for_each(|mut x| {
|
||||
|
@ -176,7 +216,12 @@ impl<W: ws::WebSocketWrite> MuxInner<W> {
|
|||
ret
|
||||
}
|
||||
|
||||
async fn stream_loop(&self, mut stream_rx: mpsc::Receiver<WsEvent>) {
|
||||
async fn stream_loop(
|
||||
&self,
|
||||
mut stream_rx: mpsc::Receiver<WsEvent>,
|
||||
stream_tx: mpsc::Sender<WsEvent>,
|
||||
) {
|
||||
let mut next_free_stream_id: u32 = 1;
|
||||
while let Some(msg) = stream_rx.next().await {
|
||||
match msg {
|
||||
WsEvent::SendPacket(packet, channel) => {
|
||||
|
@ -186,6 +231,53 @@ impl<W: ws::WebSocketWrite> MuxInner<W> {
|
|||
let _ = channel.send(Err(WispError::InvalidStreamId));
|
||||
}
|
||||
}
|
||||
WsEvent::CreateStream(stream_type, host, port, channel) => {
|
||||
let ret: Result<MuxStream, WispError> = async {
|
||||
let (ch_tx, ch_rx) = mpsc::unbounded();
|
||||
let stream_id = next_free_stream_id;
|
||||
let next_stream_id = next_free_stream_id
|
||||
.checked_add(1)
|
||||
.ok_or(WispError::MaxStreamCountReached)?;
|
||||
|
||||
let flow_control_event: Arc<Event> = Event::new().into();
|
||||
let flow_control: Arc<AtomicU32> = AtomicU32::new(self.buffer_size).into();
|
||||
|
||||
let is_closed: Arc<AtomicBool> = AtomicBool::new(false).into();
|
||||
|
||||
self.tx
|
||||
.write_frame(
|
||||
Packet::new_connect(stream_id, stream_type, port, host).into(),
|
||||
)
|
||||
.await?;
|
||||
|
||||
next_free_stream_id = next_stream_id;
|
||||
|
||||
self.stream_map.insert(
|
||||
stream_id,
|
||||
MuxMapValue {
|
||||
stream: ch_tx,
|
||||
stream_type,
|
||||
flow_control: flow_control.clone(),
|
||||
flow_control_event: flow_control_event.clone(),
|
||||
is_closed: is_closed.clone(),
|
||||
},
|
||||
);
|
||||
|
||||
Ok(MuxStream::new(
|
||||
stream_id,
|
||||
Role::Client,
|
||||
stream_type,
|
||||
ch_rx,
|
||||
stream_tx.clone(),
|
||||
is_closed,
|
||||
flow_control,
|
||||
flow_control_event,
|
||||
0,
|
||||
))
|
||||
}
|
||||
.await;
|
||||
let _ = channel.send(ret);
|
||||
}
|
||||
WsEvent::Close(packet, channel) => {
|
||||
if let Some((_, mut stream)) = self.stream_map.remove(&packet.stream_id) {
|
||||
stream.stream.disconnect();
|
||||
|
@ -204,17 +296,13 @@ impl<W: ws::WebSocketWrite> MuxInner<W> {
|
|||
&self,
|
||||
mut rx: R,
|
||||
muxstream_sender: mpsc::UnboundedSender<(ConnectPacket, MuxStream)>,
|
||||
buffer_size: u32,
|
||||
close_tx: mpsc::Sender<WsEvent>,
|
||||
) -> Result<(), WispError>
|
||||
where
|
||||
R: ws::WebSocketRead,
|
||||
{
|
||||
// will send continues once flow_control is at 10% of max
|
||||
let target_buffer_size = ((buffer_size as u64 * 90) / 100) as u32;
|
||||
self.tx
|
||||
.write_frame(Packet::new_continue(0, buffer_size).into())
|
||||
.await?;
|
||||
let target_buffer_size = ((self.buffer_size as u64 * 90) / 100) as u32;
|
||||
|
||||
loop {
|
||||
let frame = rx.wisp_read_frame(&self.tx).await?;
|
||||
|
@ -228,7 +316,7 @@ impl<W: ws::WebSocketWrite> MuxInner<W> {
|
|||
Connect(inner_packet) => {
|
||||
let (ch_tx, ch_rx) = mpsc::unbounded();
|
||||
let stream_type = inner_packet.stream_type;
|
||||
let flow_control: Arc<AtomicU32> = AtomicU32::new(buffer_size).into();
|
||||
let flow_control: Arc<AtomicU32> = AtomicU32::new(self.buffer_size).into();
|
||||
let flow_control_event: Arc<Event> = Event::new().into();
|
||||
let is_closed: Arc<AtomicBool> = AtomicBool::new(false).into();
|
||||
|
||||
|
@ -273,7 +361,7 @@ impl<W: ws::WebSocketWrite> MuxInner<W> {
|
|||
}
|
||||
}
|
||||
}
|
||||
Continue(_) => break Err(WispError::InvalidPacketType),
|
||||
Continue(_) | Info(_) => break Err(WispError::InvalidPacketType),
|
||||
Close(_) => {
|
||||
if let Some((_, mut stream)) = self.stream_map.remove(&packet.stream_id) {
|
||||
stream.is_closed.store(true, Ordering::Release);
|
||||
|
@ -298,7 +386,7 @@ impl<W: ws::WebSocketWrite> MuxInner<W> {
|
|||
|
||||
use PacketType::*;
|
||||
match packet.packet_type {
|
||||
Connect(_) => break Err(WispError::InvalidPacketType),
|
||||
Connect(_) | Info(_) => break Err(WispError::InvalidPacketType),
|
||||
Data(data) => {
|
||||
if let Some(stream) = self.stream_map.get(&packet.stream_id) {
|
||||
let _ = stream.stream.unbounded_send(data);
|
||||
|
@ -332,7 +420,7 @@ impl<W: ws::WebSocketWrite> MuxInner<W> {
|
|||
/// ```
|
||||
/// use wisp_mux::ServerMux;
|
||||
///
|
||||
/// let (mux, fut) = ServerMux::new(rx, tx, 128);
|
||||
/// let (mux, fut) = ServerMux::new(rx, tx, 128, Some(vec![]), Some([]));
|
||||
/// tokio::spawn(async move {
|
||||
/// if let Err(e) = fut.await {
|
||||
/// println!("error in multiplexor: {:?}", e);
|
||||
|
@ -346,34 +434,89 @@ impl<W: ws::WebSocketWrite> MuxInner<W> {
|
|||
/// }
|
||||
/// ```
|
||||
pub struct ServerMux {
|
||||
/// Whether the connection was downgraded to Wisp v1.
|
||||
///
|
||||
/// If this variable is true you must assume no extensions are supported.
|
||||
pub downgraded: bool,
|
||||
/// Extensions that are supported by both sides.
|
||||
pub supported_extensions: Arc<[AnyProtocolExtension]>,
|
||||
close_tx: mpsc::Sender<WsEvent>,
|
||||
muxstream_recv: mpsc::UnboundedReceiver<(ConnectPacket, MuxStream)>,
|
||||
}
|
||||
|
||||
impl ServerMux {
|
||||
/// Create a new server-side multiplexor.
|
||||
pub fn new<R, W: ws::WebSocketWrite>(
|
||||
read: R,
|
||||
///
|
||||
/// If either extensions or extension_builders are None a Wisp v1 connection is created
|
||||
/// otherwise a Wisp v2 connection is created.
|
||||
pub async fn new<R, W>(
|
||||
mut read: R,
|
||||
write: W,
|
||||
buffer_size: u32,
|
||||
) -> (Self, impl Future<Output = Result<(), WispError>>)
|
||||
extensions: Option<Vec<AnyProtocolExtension>>,
|
||||
extension_builders: Option<&[&(dyn ProtocolExtensionBuilder + Sync)]>,
|
||||
) -> Result<(Self, impl Future<Output = Result<(), WispError>> + Send), WispError>
|
||||
where
|
||||
R: ws::WebSocketRead,
|
||||
R: ws::WebSocketRead + Send,
|
||||
W: ws::WebSocketWrite + Send + 'static,
|
||||
{
|
||||
let (close_tx, close_rx) = mpsc::channel::<WsEvent>(256);
|
||||
let (tx, rx) = mpsc::unbounded::<(ConnectPacket, MuxStream)>();
|
||||
let write = ws::LockedWebSocketWrite::new(write);
|
||||
(
|
||||
let write = ws::LockedWebSocketWrite::new(Box::new(write));
|
||||
|
||||
write
|
||||
.write_frame(Packet::new_continue(0, buffer_size).into())
|
||||
.await?;
|
||||
|
||||
let mut supported_extensions = Vec::new();
|
||||
let mut extra_packet = Vec::with_capacity(1);
|
||||
let mut downgraded = true;
|
||||
|
||||
if let Some(extensions) = extensions {
|
||||
if let Some(builders) = extension_builders {
|
||||
let extension_ids: Vec<_> = extensions.iter().map(|x| x.get_id()).collect();
|
||||
write
|
||||
.write_frame(Packet::new_info(extensions).into())
|
||||
.await?;
|
||||
if let Some(frame) = select! {
|
||||
x = read.wisp_read_frame(&write).fuse() => Some(x?),
|
||||
// TODO change this to correct timeout once draft 2 is out
|
||||
_ = Delay::new(Duration::from_secs(5)).fuse() => None
|
||||
} {
|
||||
let packet = Packet::maybe_parse_info(frame, Role::Server, builders)?;
|
||||
if let PacketType::Info(info) = packet.packet_type {
|
||||
supported_extensions = info
|
||||
.extensions
|
||||
.into_iter()
|
||||
.filter(|x| extension_ids.contains(&x.get_id()))
|
||||
.collect();
|
||||
downgraded = false;
|
||||
} else {
|
||||
extra_packet.push(packet.into());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok((
|
||||
Self {
|
||||
muxstream_recv: rx,
|
||||
close_tx: close_tx.clone(),
|
||||
downgraded,
|
||||
supported_extensions: supported_extensions.into(),
|
||||
},
|
||||
MuxInner {
|
||||
tx: write,
|
||||
stream_map: DashMap::new().into(),
|
||||
stream_map: DashMap::new(),
|
||||
buffer_size,
|
||||
}
|
||||
.server_into_future(read, close_rx, tx, buffer_size, close_tx),
|
||||
)
|
||||
.server_into_future(
|
||||
AppendingWebSocketRead(extra_packet, read),
|
||||
close_rx,
|
||||
tx,
|
||||
close_tx,
|
||||
),
|
||||
))
|
||||
}
|
||||
|
||||
/// Wait for a stream to be created.
|
||||
|
@ -398,7 +541,7 @@ impl ServerMux {
|
|||
/// ```
|
||||
/// use wisp_mux::{ClientMux, StreamType};
|
||||
///
|
||||
/// let (mux, fut) = ClientMux::new(rx, tx).await?;
|
||||
/// let (mux, fut) = ClientMux::new(rx, tx, Some(vec![]), []).await?;
|
||||
/// tokio::spawn(async move {
|
||||
/// if let Err(e) = fut.await {
|
||||
/// println!("error in multiplexor: {:?}", e);
|
||||
|
@ -406,50 +549,88 @@ impl ServerMux {
|
|||
/// });
|
||||
/// let stream = mux.client_new_stream(StreamType::Tcp, "google.com", 80);
|
||||
/// ```
|
||||
pub struct ClientMux<W>
|
||||
where
|
||||
W: ws::WebSocketWrite,
|
||||
{
|
||||
tx: ws::LockedWebSocketWrite<W>,
|
||||
stream_map: Arc<DashMap<u32, MuxMapValue>>,
|
||||
next_free_stream_id: AtomicU32,
|
||||
pub struct ClientMux {
|
||||
/// Whether the connection was downgraded to Wisp v1.
|
||||
///
|
||||
/// If this variable is true you must assume no extensions are supported.
|
||||
pub downgraded: bool,
|
||||
/// Extensions that are supported by both sides.
|
||||
pub supported_extensions: Arc<[AnyProtocolExtension]>,
|
||||
close_tx: mpsc::Sender<WsEvent>,
|
||||
buf_size: u32,
|
||||
target_buf_size: u32,
|
||||
}
|
||||
|
||||
impl<W: ws::WebSocketWrite> ClientMux<W> {
|
||||
impl ClientMux {
|
||||
/// Create a new client side multiplexor.
|
||||
pub async fn new<R>(
|
||||
///
|
||||
/// If either extensions or extension_builders are None a Wisp v1 connection is created
|
||||
/// otherwise a Wisp v2 connection is created.
|
||||
pub async fn new<R, W>(
|
||||
mut read: R,
|
||||
write: W,
|
||||
) -> Result<(Self, impl Future<Output = Result<(), WispError>>), WispError>
|
||||
extensions: Option<Vec<AnyProtocolExtension>>,
|
||||
extension_builders: Option<&[&(dyn ProtocolExtensionBuilder + Sync)]>,
|
||||
) -> Result<(Self, impl Future<Output = Result<(), WispError>> + Send), WispError>
|
||||
where
|
||||
R: ws::WebSocketRead,
|
||||
R: ws::WebSocketRead + Send,
|
||||
W: ws::WebSocketWrite + Send + 'static,
|
||||
{
|
||||
let write = ws::LockedWebSocketWrite::new(write);
|
||||
let write = ws::LockedWebSocketWrite::new(Box::new(write));
|
||||
let first_packet = Packet::try_from(read.wisp_read_frame(&write).await?)?;
|
||||
if first_packet.stream_id != 0 {
|
||||
return Err(WispError::InvalidStreamId);
|
||||
}
|
||||
if let PacketType::Continue(packet) = first_packet.packet_type {
|
||||
let mut supported_extensions = Vec::new();
|
||||
let mut extra_packet = Vec::with_capacity(1);
|
||||
let mut downgraded = true;
|
||||
|
||||
if let Some(extensions) = extensions {
|
||||
if let Some(builders) = extension_builders {
|
||||
let extension_ids: Vec<_> = extensions.iter().map(|x| x.get_id()).collect();
|
||||
if let Some(frame) = select! {
|
||||
x = read.wisp_read_frame(&write).fuse() => Some(x?),
|
||||
// TODO change this to correct timeout once draft 2 is out
|
||||
_ = Delay::new(Duration::from_secs(5)).fuse() => None
|
||||
} {
|
||||
let packet = Packet::maybe_parse_info(frame, Role::Server, builders)?;
|
||||
if let PacketType::Info(info) = packet.packet_type {
|
||||
supported_extensions = info
|
||||
.extensions
|
||||
.into_iter()
|
||||
.filter(|x| extension_ids.contains(&x.get_id()))
|
||||
.collect();
|
||||
write
|
||||
.write_frame(Packet::new_info(extensions).into())
|
||||
.await?;
|
||||
downgraded = false;
|
||||
} else {
|
||||
extra_packet.push(packet.into());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for extension in supported_extensions.iter_mut() {
|
||||
extension.handle_handshake(&mut read, &write).await?;
|
||||
}
|
||||
|
||||
let (tx, rx) = mpsc::channel::<WsEvent>(256);
|
||||
let map = Arc::new(DashMap::new());
|
||||
Ok((
|
||||
Self {
|
||||
tx: write.clone(),
|
||||
stream_map: map.clone(),
|
||||
next_free_stream_id: AtomicU32::new(1),
|
||||
close_tx: tx.clone(),
|
||||
buf_size: packet.buffer_remaining,
|
||||
// server-only
|
||||
target_buf_size: 0,
|
||||
downgraded,
|
||||
supported_extensions: supported_extensions.into(),
|
||||
},
|
||||
MuxInner {
|
||||
tx: write.clone(),
|
||||
stream_map: map.clone(),
|
||||
tx: write,
|
||||
stream_map: DashMap::new(),
|
||||
buffer_size: packet.buffer_remaining,
|
||||
}
|
||||
.client_into_future(read, rx),
|
||||
.client_into_future(
|
||||
AppendingWebSocketRead(extra_packet, read),
|
||||
rx,
|
||||
tx,
|
||||
),
|
||||
))
|
||||
} else {
|
||||
Err(WispError::InvalidPacketType)
|
||||
|
@ -458,51 +639,25 @@ impl<W: ws::WebSocketWrite> ClientMux<W> {
|
|||
|
||||
/// Create a new stream, multiplexed through Wisp.
|
||||
pub async fn client_new_stream(
|
||||
&self,
|
||||
&mut self,
|
||||
stream_type: StreamType,
|
||||
host: String,
|
||||
port: u16,
|
||||
) -> Result<MuxStream, WispError> {
|
||||
let (ch_tx, ch_rx) = mpsc::unbounded();
|
||||
let stream_id = self.next_free_stream_id.load(Ordering::Acquire);
|
||||
let next_stream_id = stream_id
|
||||
.checked_add(1)
|
||||
.ok_or(WispError::MaxStreamCountReached)?;
|
||||
|
||||
let flow_control_event: Arc<Event> = Event::new().into();
|
||||
let flow_control: Arc<AtomicU32> = AtomicU32::new(self.buf_size).into();
|
||||
|
||||
let is_closed: Arc<AtomicBool> = AtomicBool::new(false).into();
|
||||
|
||||
self.tx
|
||||
.write_frame(Packet::new_connect(stream_id, stream_type, port, host).into())
|
||||
.await?;
|
||||
|
||||
self.next_free_stream_id
|
||||
.store(next_stream_id, Ordering::Release);
|
||||
|
||||
self.stream_map.insert(
|
||||
stream_id,
|
||||
MuxMapValue {
|
||||
stream: ch_tx,
|
||||
stream_type,
|
||||
flow_control: flow_control.clone(),
|
||||
flow_control_event: flow_control_event.clone(),
|
||||
is_closed: is_closed.clone(),
|
||||
},
|
||||
);
|
||||
|
||||
Ok(MuxStream::new(
|
||||
stream_id,
|
||||
Role::Client,
|
||||
stream_type,
|
||||
ch_rx,
|
||||
self.close_tx.clone(),
|
||||
is_closed,
|
||||
flow_control,
|
||||
flow_control_event,
|
||||
self.target_buf_size,
|
||||
))
|
||||
if stream_type == StreamType::Udp
|
||||
&& !self
|
||||
.supported_extensions
|
||||
.iter()
|
||||
.any(|x| x.get_id() == UdpProtocolExtension::ID)
|
||||
{
|
||||
return Err(WispError::UdpExtensionNotSupported);
|
||||
}
|
||||
let (tx, rx) = oneshot::channel();
|
||||
self.close_tx
|
||||
.send(WsEvent::CreateStream(stream_type, host, port, tx))
|
||||
.await
|
||||
.map_err(|_| WispError::MuxMessageFailedToSend)?;
|
||||
rx.await.map_err(|_| WispError::MuxMessageFailedToRecv)?
|
||||
}
|
||||
|
||||
/// Close all streams.
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue