mirror of
https://github.com/MercuryWorkshop/epoxy-tls.git
synced 2025-05-12 14:00:01 -04:00
fix password protocol extension, respect stream id 0 close packets, allow sending stream id 0 close packets
This commit is contained in:
parent
4d433b60c4
commit
d10b7691e4
6 changed files with 220 additions and 50 deletions
|
@ -1,5 +1,5 @@
|
||||||
#![feature(let_chains, ip)]
|
#![feature(let_chains, ip)]
|
||||||
use std::io::Error;
|
use std::{collections::HashMap, io::Error, path::PathBuf, sync::Arc};
|
||||||
|
|
||||||
use bytes::Bytes;
|
use bytes::Bytes;
|
||||||
use clap::Parser;
|
use clap::Parser;
|
||||||
|
@ -20,8 +20,11 @@ use tokio_util::codec::{BytesCodec, Framed};
|
||||||
use tokio_util::either::Either;
|
use tokio_util::either::Either;
|
||||||
|
|
||||||
use wisp_mux::{
|
use wisp_mux::{
|
||||||
extensions::udp::UdpProtocolExtensionBuilder, CloseReason, ConnectPacket, MuxStream, ServerMux,
|
extensions::{
|
||||||
StreamType, WispError,
|
password::{PasswordProtocolExtension, PasswordProtocolExtensionBuilder},
|
||||||
|
udp::UdpProtocolExtensionBuilder,
|
||||||
|
},
|
||||||
|
CloseReason, ConnectPacket, MuxStream, ServerMux, StreamType, WispError,
|
||||||
};
|
};
|
||||||
|
|
||||||
type HttpBody = http_body_util::Full<hyper::body::Bytes>;
|
type HttpBody = http_body_util::Full<hyper::body::Bytes>;
|
||||||
|
@ -56,6 +59,20 @@ struct Cli {
|
||||||
/// Whether the server should block ports other than 80 or 443
|
/// Whether the server should block ports other than 80 or 443
|
||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
block_non_http: bool,
|
block_non_http: bool,
|
||||||
|
/// Path to a file containing `user:password` separated by newlines. This is plaintext!!!
|
||||||
|
///
|
||||||
|
/// `user` cannot contain `:`. Whitespace will be trimmed.
|
||||||
|
#[arg(long)]
|
||||||
|
auth: Option<PathBuf>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
struct MuxOptions {
|
||||||
|
pub block_local: bool,
|
||||||
|
pub block_udp: bool,
|
||||||
|
pub block_non_http: bool,
|
||||||
|
pub enforce_auth: bool,
|
||||||
|
pub auth: Arc<PasswordProtocolExtensionBuilder>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(not(unix))]
|
#[cfg(not(unix))]
|
||||||
|
@ -138,19 +155,44 @@ async fn main() -> Result<(), Error> {
|
||||||
"/".to_string()
|
"/".to_string()
|
||||||
};
|
};
|
||||||
|
|
||||||
|
let mut auth = HashMap::new();
|
||||||
|
let enforce_auth = opt.auth.is_some();
|
||||||
|
if let Some(file) = opt.auth {
|
||||||
|
let file = std::fs::read_to_string(file)?;
|
||||||
|
for entry in file.split('\n').filter_map(|x| {
|
||||||
|
if x.contains(':') {
|
||||||
|
Some(x.trim())
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}) {
|
||||||
|
let split: Vec<_> = entry.split(':').collect();
|
||||||
|
let username = split[0];
|
||||||
|
let password = split[1..].join(":");
|
||||||
|
println!(
|
||||||
|
"adding username {:?} password {:?} to allowed auth",
|
||||||
|
username, password
|
||||||
|
);
|
||||||
|
auth.insert(username.to_string(), password.to_string());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
let pw_ext = Arc::new(PasswordProtocolExtensionBuilder::new_server(auth));
|
||||||
|
|
||||||
|
let mux_options = MuxOptions {
|
||||||
|
block_local: opt.block_local,
|
||||||
|
block_non_http: opt.block_non_http,
|
||||||
|
block_udp: opt.block_udp,
|
||||||
|
auth: pw_ext,
|
||||||
|
enforce_auth,
|
||||||
|
};
|
||||||
|
|
||||||
println!("listening on `{}` with prefix `{}`", addr, prefix);
|
println!("listening on `{}` with prefix `{}`", addr, prefix);
|
||||||
while let Ok((stream, addr)) = socket.accept().await {
|
while let Ok((stream, addr)) = socket.accept().await {
|
||||||
let prefix = prefix.clone();
|
let prefix = prefix.clone();
|
||||||
|
let mux_options = mux_options.clone();
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
let service = service_fn(move |res| {
|
let service = service_fn(move |res| {
|
||||||
accept_http(
|
accept_http(res, addr.clone(), prefix.clone(), mux_options.clone())
|
||||||
res,
|
|
||||||
addr.clone(),
|
|
||||||
prefix.clone(),
|
|
||||||
opt.block_local,
|
|
||||||
opt.block_udp,
|
|
||||||
opt.block_non_http,
|
|
||||||
)
|
|
||||||
});
|
});
|
||||||
let conn = http1::Builder::new()
|
let conn = http1::Builder::new()
|
||||||
.serve_connection(TokioIo::new(stream), service)
|
.serve_connection(TokioIo::new(stream), service)
|
||||||
|
@ -168,9 +210,7 @@ async fn accept_http(
|
||||||
mut req: Request<Incoming>,
|
mut req: Request<Incoming>,
|
||||||
addr: String,
|
addr: String,
|
||||||
prefix: String,
|
prefix: String,
|
||||||
block_local: bool,
|
mux_options: MuxOptions,
|
||||||
block_udp: bool,
|
|
||||||
block_non_http: bool,
|
|
||||||
) -> Result<Response<HttpBody>, WebSocketError> {
|
) -> Result<Response<HttpBody>, WebSocketError> {
|
||||||
let uri = req.uri().path().to_string();
|
let uri = req.uri().path().to_string();
|
||||||
if upgrade::is_upgrade_request(&req)
|
if upgrade::is_upgrade_request(&req)
|
||||||
|
@ -179,12 +219,17 @@ async fn accept_http(
|
||||||
let (res, fut) = upgrade::upgrade(&mut req)?;
|
let (res, fut) = upgrade::upgrade(&mut req)?;
|
||||||
|
|
||||||
if uri.is_empty() {
|
if uri.is_empty() {
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move { accept_ws(fut, addr.clone(), mux_options).await });
|
||||||
accept_ws(fut, addr.clone(), block_local, block_udp, block_non_http).await
|
|
||||||
});
|
|
||||||
} else if let Some(uri) = uri.strip_prefix('/').map(|x| x.to_string()) {
|
} else if let Some(uri) = uri.strip_prefix('/').map(|x| x.to_string()) {
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
accept_wsproxy(fut, uri, addr.clone(), block_local, block_non_http).await
|
accept_wsproxy(
|
||||||
|
fut,
|
||||||
|
uri,
|
||||||
|
addr.clone(),
|
||||||
|
mux_options.block_local,
|
||||||
|
mux_options.block_non_http,
|
||||||
|
)
|
||||||
|
.await
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -258,19 +303,41 @@ async fn handle_mux(packet: ConnectPacket, mut stream: MuxStream) -> Result<bool
|
||||||
async fn accept_ws(
|
async fn accept_ws(
|
||||||
ws: UpgradeFut,
|
ws: UpgradeFut,
|
||||||
addr: String,
|
addr: String,
|
||||||
block_local: bool,
|
mux_options: MuxOptions,
|
||||||
block_non_http: bool,
|
|
||||||
block_udp: bool,
|
|
||||||
) -> Result<(), Box<dyn std::error::Error + Sync + Send>> {
|
) -> Result<(), Box<dyn std::error::Error + Sync + Send>> {
|
||||||
let (rx, tx) = ws.await?.split(tokio::io::split);
|
let (rx, tx) = ws.await?.split(tokio::io::split);
|
||||||
let rx = FragmentCollectorRead::new(rx);
|
let rx = FragmentCollectorRead::new(rx);
|
||||||
|
|
||||||
println!("{:?}: connected", addr);
|
println!("{:?}: connected", addr);
|
||||||
|
let (mut mux, fut) = if mux_options.enforce_auth {
|
||||||
|
let (mut mux, fut) = ServerMux::new(
|
||||||
|
rx,
|
||||||
|
tx,
|
||||||
|
u32::MAX,
|
||||||
|
Some(&[&UdpProtocolExtensionBuilder(), mux_options.auth.as_ref()]),
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
if !mux
|
||||||
|
.supported_extension_ids
|
||||||
|
.iter()
|
||||||
|
.any(|x| *x == PasswordProtocolExtension::ID)
|
||||||
|
{
|
||||||
|
println!(
|
||||||
|
"{:?}: client did not support auth or password was invalid",
|
||||||
|
addr
|
||||||
|
);
|
||||||
|
mux.close_extension_incompat().await?;
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
(mux, fut)
|
||||||
|
} else {
|
||||||
|
ServerMux::new(rx, tx, u32::MAX, Some(&[&UdpProtocolExtensionBuilder()])).await?
|
||||||
|
};
|
||||||
|
|
||||||
let (mut mux, fut) =
|
println!(
|
||||||
ServerMux::new(rx, tx, u32::MAX, Some(&[&UdpProtocolExtensionBuilder()])).await?;
|
"{:?}: downgraded: {} extensions supported: {:?}",
|
||||||
|
addr, mux.downgraded, mux.supported_extension_ids
|
||||||
println!("{:?}: downgraded: {} extensions supported: {:?}", addr, mux.downgraded, mux.supported_extension_ids);
|
);
|
||||||
|
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
if let Err(e) = fut.await {
|
if let Err(e) = fut.await {
|
||||||
|
@ -280,14 +347,14 @@ async fn accept_ws(
|
||||||
|
|
||||||
while let Some((packet, mut stream)) = mux.server_new_stream().await {
|
while let Some((packet, mut stream)) = mux.server_new_stream().await {
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
if (block_non_http
|
if (mux_options.block_non_http
|
||||||
&& !(packet.destination_port == 80 || packet.destination_port == 443))
|
&& !(packet.destination_port == 80 || packet.destination_port == 443))
|
||||||
|| (block_udp && packet.stream_type == StreamType::Udp)
|
|| (mux_options.block_udp && packet.stream_type == StreamType::Udp)
|
||||||
{
|
{
|
||||||
let _ = stream.close(CloseReason::ServerStreamBlockedAddress).await;
|
let _ = stream.close(CloseReason::ServerStreamBlockedAddress).await;
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
if block_local {
|
if mux_options.block_local {
|
||||||
match lookup_host(format!(
|
match lookup_host(format!(
|
||||||
"{}:{}",
|
"{}:{}",
|
||||||
packet.destination_hostname, packet.destination_port
|
packet.destination_hostname, packet.destination_port
|
||||||
|
|
|
@ -27,7 +27,14 @@ use tokio::{
|
||||||
};
|
};
|
||||||
use tokio_native_tls::{native_tls, TlsConnector};
|
use tokio_native_tls::{native_tls, TlsConnector};
|
||||||
use tokio_util::either::Either;
|
use tokio_util::either::Either;
|
||||||
use wisp_mux::{extensions::udp::{UdpProtocolExtension, UdpProtocolExtensionBuilder}, ClientMux, StreamType, WispError};
|
use wisp_mux::{
|
||||||
|
extensions::{
|
||||||
|
password::{PasswordProtocolExtension, PasswordProtocolExtensionBuilder},
|
||||||
|
udp::{UdpProtocolExtension, UdpProtocolExtensionBuilder},
|
||||||
|
ProtocolExtensionBuilder,
|
||||||
|
},
|
||||||
|
ClientMux, StreamType, WispError,
|
||||||
|
};
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
enum WispClientError {
|
enum WispClientError {
|
||||||
|
@ -80,6 +87,11 @@ struct Cli {
|
||||||
/// Ask for UDP
|
/// Ask for UDP
|
||||||
#[arg(short, long)]
|
#[arg(short, long)]
|
||||||
udp: bool,
|
udp: bool,
|
||||||
|
/// Enable auth: format is `username:password`
|
||||||
|
///
|
||||||
|
/// Usernames and passwords are sent in plaintext!!
|
||||||
|
#[arg(long)]
|
||||||
|
auth: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::main(flavor = "multi_thread")]
|
#[tokio::main(flavor = "multi_thread")]
|
||||||
|
@ -103,6 +115,13 @@ async fn main() -> Result<(), Box<dyn Error + Send + Sync>> {
|
||||||
let addr_dest = opts.tcp.ip().to_string();
|
let addr_dest = opts.tcp.ip().to_string();
|
||||||
let addr_dest_port = opts.tcp.port();
|
let addr_dest_port = opts.tcp.port();
|
||||||
|
|
||||||
|
let auth = opts.auth.map(|auth| {
|
||||||
|
let split: Vec<_> = auth.split(':').collect();
|
||||||
|
let username = split[0].to_string();
|
||||||
|
let password = split[1..].join(":");
|
||||||
|
PasswordProtocolExtensionBuilder::new_client(username, password)
|
||||||
|
});
|
||||||
|
|
||||||
println!(
|
println!(
|
||||||
"connecting to {} and sending &[0; 1024 * {}] to {} with threads {}",
|
"connecting to {} and sending &[0; 1024 * {}] to {} with threads {}",
|
||||||
opts.wisp, opts.packet_size, opts.tcp, opts.streams,
|
opts.wisp, opts.packet_size, opts.tcp, opts.streams,
|
||||||
|
@ -133,18 +152,49 @@ async fn main() -> Result<(), Box<dyn Error + Send + Sync>> {
|
||||||
let (rx, tx) = ws.split(tokio::io::split);
|
let (rx, tx) = ws.split(tokio::io::split);
|
||||||
let rx = FragmentCollectorRead::new(rx);
|
let rx = FragmentCollectorRead::new(rx);
|
||||||
|
|
||||||
let (mut mux, fut) = if opts.udp {
|
let mut extensions: Vec<Box<(dyn ProtocolExtensionBuilder + Sync)>> = Vec::new();
|
||||||
let (mux, fut) = ClientMux::new(rx, tx, Some(&[&UdpProtocolExtensionBuilder()])).await?;
|
if opts.udp {
|
||||||
if !mux.supported_extension_ids.iter().any(|x| *x == UdpProtocolExtension::ID) {
|
extensions.push(Box::new(UdpProtocolExtensionBuilder()));
|
||||||
println!("server did not support udp, was downgraded {}, extensions supported {:?}", mux.downgraded, mux.supported_extension_ids);
|
}
|
||||||
exit(1);
|
let enforce_auth = auth.is_some();
|
||||||
}
|
if let Some(auth) = auth {
|
||||||
(mux, fut)
|
extensions.push(Box::new(auth));
|
||||||
} else {
|
}
|
||||||
ClientMux::new(rx, tx, Some(&[])).await?
|
let extensions_mapped: Vec<&(dyn ProtocolExtensionBuilder + Sync)> =
|
||||||
};
|
extensions.iter().map(|x| x.as_ref()).collect();
|
||||||
|
|
||||||
println!("connected and created ClientMux, was downgraded {}, extensions supported {:?}", mux.downgraded, mux.supported_extension_ids);
|
let (mut mux, fut) = ClientMux::new(rx, tx, Some(&extensions_mapped)).await?;
|
||||||
|
if opts.udp
|
||||||
|
&& !mux
|
||||||
|
.supported_extension_ids
|
||||||
|
.iter()
|
||||||
|
.any(|x| *x == UdpProtocolExtension::ID)
|
||||||
|
{
|
||||||
|
println!(
|
||||||
|
"server did not support udp, was downgraded {}, extensions supported {:?}",
|
||||||
|
mux.downgraded, mux.supported_extension_ids
|
||||||
|
);
|
||||||
|
mux.close_extension_incompat().await?;
|
||||||
|
exit(1);
|
||||||
|
}
|
||||||
|
if enforce_auth
|
||||||
|
&& !mux
|
||||||
|
.supported_extension_ids
|
||||||
|
.iter()
|
||||||
|
.any(|x| *x == PasswordProtocolExtension::ID)
|
||||||
|
{
|
||||||
|
println!(
|
||||||
|
"server did not support passwords or password was incorrect, was downgraded {}, extensions supported {:?}",
|
||||||
|
mux.downgraded, mux.supported_extension_ids
|
||||||
|
);
|
||||||
|
mux.close_extension_incompat().await?;
|
||||||
|
exit(1);
|
||||||
|
}
|
||||||
|
|
||||||
|
println!(
|
||||||
|
"connected and created ClientMux, was downgraded {}, extensions supported {:?}",
|
||||||
|
mux.downgraded, mux.supported_extension_ids
|
||||||
|
);
|
||||||
|
|
||||||
let mut threads = Vec::with_capacity(opts.streams * 2 + 3);
|
let mut threads = Vec::with_capacity(opts.streams * 2 + 3);
|
||||||
|
|
||||||
|
|
|
@ -202,6 +202,8 @@ pub mod udp {
|
||||||
pub mod password {
|
pub mod password {
|
||||||
//! Password protocol extension.
|
//! Password protocol extension.
|
||||||
//!
|
//!
|
||||||
|
//! Passwords are sent in plain text!!
|
||||||
|
//!
|
||||||
//! # Example
|
//! # Example
|
||||||
//! Server:
|
//! Server:
|
||||||
//! ```
|
//! ```
|
||||||
|
@ -246,6 +248,7 @@ pub mod password {
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
/// Password protocol extension.
|
/// Password protocol extension.
|
||||||
///
|
///
|
||||||
|
/// **Passwords are sent in plain text!!**
|
||||||
/// **This extension will panic when encoding if the username's length does not fit within a u8
|
/// **This extension will panic when encoding if the username's length does not fit within a u8
|
||||||
/// or the password's length does not fit within a u16.**
|
/// or the password's length does not fit within a u16.**
|
||||||
pub struct PasswordProtocolExtension {
|
pub struct PasswordProtocolExtension {
|
||||||
|
@ -306,7 +309,7 @@ pub mod password {
|
||||||
let password = Bytes::from(self.password.clone().into_bytes());
|
let password = Bytes::from(self.password.clone().into_bytes());
|
||||||
let username_len = u8::try_from(username.len()).expect("username was too long");
|
let username_len = u8::try_from(username.len()).expect("username was too long");
|
||||||
let password_len =
|
let password_len =
|
||||||
u16::try_from(username.len()).expect("password was too long");
|
u16::try_from(password.len()).expect("password was too long");
|
||||||
|
|
||||||
let mut bytes =
|
let mut bytes =
|
||||||
BytesMut::with_capacity(3 + username_len as usize + password_len as usize);
|
BytesMut::with_capacity(3 + username_len as usize + password_len as usize);
|
||||||
|
@ -380,6 +383,8 @@ pub mod password {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Password protocol extension builder.
|
/// Password protocol extension builder.
|
||||||
|
///
|
||||||
|
/// **Passwords are sent in plain text!!**
|
||||||
pub struct PasswordProtocolExtensionBuilder {
|
pub struct PasswordProtocolExtensionBuilder {
|
||||||
/// Map of users and their passwords to allow. Only used on server.
|
/// Map of users and their passwords to allow. Only used on server.
|
||||||
pub users: HashMap<String, String>,
|
pub users: HashMap<String, String>,
|
||||||
|
@ -448,6 +453,7 @@ pub mod password {
|
||||||
if *user.1 != password {
|
if *user.1 != password {
|
||||||
return Err(EError::InvalidPassword.into());
|
return Err(EError::InvalidPassword.into());
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(PasswordProtocolExtension {
|
Ok(PasswordProtocolExtension {
|
||||||
username,
|
username,
|
||||||
password,
|
password,
|
||||||
|
|
|
@ -64,6 +64,8 @@ pub enum WispError {
|
||||||
UriHasNoPort,
|
UriHasNoPort,
|
||||||
/// The max stream count was reached.
|
/// The max stream count was reached.
|
||||||
MaxStreamCountReached,
|
MaxStreamCountReached,
|
||||||
|
/// The Wisp protocol version was incompatible.
|
||||||
|
IncompatibleProtocolVersion,
|
||||||
/// The stream had already been closed.
|
/// The stream had already been closed.
|
||||||
StreamAlreadyClosed,
|
StreamAlreadyClosed,
|
||||||
/// The websocket frame received had an invalid type.
|
/// The websocket frame received had an invalid type.
|
||||||
|
@ -117,6 +119,7 @@ impl std::fmt::Display for WispError {
|
||||||
Self::UriHasNoHost => write!(f, "URI has no host"),
|
Self::UriHasNoHost => write!(f, "URI has no host"),
|
||||||
Self::UriHasNoPort => write!(f, "URI has no port"),
|
Self::UriHasNoPort => write!(f, "URI has no port"),
|
||||||
Self::MaxStreamCountReached => write!(f, "Maximum stream count reached"),
|
Self::MaxStreamCountReached => write!(f, "Maximum stream count reached"),
|
||||||
|
Self::IncompatibleProtocolVersion => write!(f, "Incompatible Wisp protocol version"),
|
||||||
Self::StreamAlreadyClosed => write!(f, "Stream already closed"),
|
Self::StreamAlreadyClosed => write!(f, "Stream already closed"),
|
||||||
Self::WsFrameInvalidType => write!(f, "Invalid websocket frame type"),
|
Self::WsFrameInvalidType => write!(f, "Invalid websocket frame type"),
|
||||||
Self::WsFrameNotFinished => write!(f, "Unfinished websocket frame"),
|
Self::WsFrameNotFinished => write!(f, "Unfinished websocket frame"),
|
||||||
|
@ -286,7 +289,15 @@ impl MuxInner {
|
||||||
let _ = channel.send(Err(WispError::InvalidStreamId));
|
let _ = channel.send(Err(WispError::InvalidStreamId));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
WsEvent::EndFut => break,
|
WsEvent::EndFut(x) => {
|
||||||
|
if let Some(reason) = x {
|
||||||
|
let _ = self
|
||||||
|
.tx
|
||||||
|
.write_frame(Packet::new_close(0, reason).into())
|
||||||
|
.await;
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -364,6 +375,9 @@ impl MuxInner {
|
||||||
}
|
}
|
||||||
Continue(_) | Info(_) => break Err(WispError::InvalidPacketType),
|
Continue(_) | Info(_) => break Err(WispError::InvalidPacketType),
|
||||||
Close(_) => {
|
Close(_) => {
|
||||||
|
if packet.stream_id == 0 {
|
||||||
|
break Ok(());
|
||||||
|
}
|
||||||
if let Some((_, mut stream)) = self.stream_map.remove(&packet.stream_id) {
|
if let Some((_, mut stream)) = self.stream_map.remove(&packet.stream_id) {
|
||||||
stream.is_closed.store(true, Ordering::Release);
|
stream.is_closed.store(true, Ordering::Release);
|
||||||
stream.stream.disconnect();
|
stream.stream.disconnect();
|
||||||
|
@ -410,6 +424,9 @@ impl MuxInner {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Close(_) => {
|
Close(_) => {
|
||||||
|
if packet.stream_id == 0 {
|
||||||
|
break Ok(());
|
||||||
|
}
|
||||||
if let Some((_, mut stream)) = self.stream_map.remove(&packet.stream_id) {
|
if let Some((_, mut stream)) = self.stream_map.remove(&packet.stream_id) {
|
||||||
stream.is_closed.store(true, Ordering::Release);
|
stream.is_closed.store(true, Ordering::Release);
|
||||||
stream.stream.disconnect();
|
stream.stream.disconnect();
|
||||||
|
@ -532,15 +549,28 @@ impl ServerMux {
|
||||||
self.muxstream_recv.next().await
|
self.muxstream_recv.next().await
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn close_internal(&mut self, reason: Option<CloseReason>) -> Result<(), WispError> {
|
||||||
|
self.close_tx
|
||||||
|
.send(WsEvent::EndFut(reason))
|
||||||
|
.await
|
||||||
|
.map_err(|_| WispError::MuxMessageFailedToSend)
|
||||||
|
}
|
||||||
|
|
||||||
/// Close all streams.
|
/// Close all streams.
|
||||||
///
|
///
|
||||||
/// Also terminates the multiplexor future. Waiting for a new stream will never succeed after
|
/// Also terminates the multiplexor future. Waiting for a new stream will never succeed after
|
||||||
/// this function is called.
|
/// this function is called.
|
||||||
pub async fn close(&mut self) -> Result<(), WispError> {
|
pub async fn close(&mut self) -> Result<(), WispError> {
|
||||||
self.close_tx
|
self.close_internal(None).await
|
||||||
.send(WsEvent::EndFut)
|
}
|
||||||
|
|
||||||
|
/// Close all streams and send an extension incompatibility error to the client.
|
||||||
|
///
|
||||||
|
/// Also terminates the multiplexor future. Waiting for a new stream will never succed after
|
||||||
|
/// this function is called.
|
||||||
|
pub async fn close_extension_incompat(&mut self) -> Result<(), WispError> {
|
||||||
|
self.close_internal(Some(CloseReason::IncompatibleExtensions))
|
||||||
.await
|
.await
|
||||||
.map_err(|_| WispError::MuxMessageFailedToSend)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
/// Client side multiplexor.
|
/// Client side multiplexor.
|
||||||
|
@ -600,7 +630,7 @@ impl ClientMux {
|
||||||
x = read.wisp_read_frame(&write).fuse() => Some(x?),
|
x = read.wisp_read_frame(&write).fuse() => Some(x?),
|
||||||
_ = Delay::new(Duration::from_secs(5)).fuse() => None
|
_ = Delay::new(Duration::from_secs(5)).fuse() => None
|
||||||
} {
|
} {
|
||||||
let packet = Packet::maybe_parse_info(frame, Role::Server, builders)?;
|
let packet = Packet::maybe_parse_info(frame, Role::Client, builders)?;
|
||||||
if let PacketType::Info(info) = packet.packet_type {
|
if let PacketType::Info(info) = packet.packet_type {
|
||||||
supported_extensions = info
|
supported_extensions = info
|
||||||
.extensions
|
.extensions
|
||||||
|
@ -671,14 +701,27 @@ impl ClientMux {
|
||||||
rx.await.map_err(|_| WispError::MuxMessageFailedToRecv)?
|
rx.await.map_err(|_| WispError::MuxMessageFailedToRecv)?
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn close_internal(&mut self, reason: Option<CloseReason>) -> Result<(), WispError> {
|
||||||
|
self.close_tx
|
||||||
|
.send(WsEvent::EndFut(reason))
|
||||||
|
.await
|
||||||
|
.map_err(|_| WispError::MuxMessageFailedToSend)
|
||||||
|
}
|
||||||
|
|
||||||
/// Close all streams.
|
/// Close all streams.
|
||||||
///
|
///
|
||||||
/// Also terminates the multiplexor future. Creating a stream is UB after calling this
|
/// Also terminates the multiplexor future. Creating a stream is UB after calling this
|
||||||
/// function.
|
/// function.
|
||||||
pub async fn close(&mut self) -> Result<(), WispError> {
|
pub async fn close(&mut self) -> Result<(), WispError> {
|
||||||
self.close_tx
|
self.close_internal(None).await
|
||||||
.send(WsEvent::EndFut)
|
}
|
||||||
|
|
||||||
|
/// Close all streams and send an extension incompatibility error to the client.
|
||||||
|
///
|
||||||
|
/// Also terminates the multiplexor future. Creating a stream is UB after calling this
|
||||||
|
/// function.
|
||||||
|
pub async fn close_extension_incompat(&mut self) -> Result<(), WispError> {
|
||||||
|
self.close_internal(Some(CloseReason::IncompatibleExtensions))
|
||||||
.await
|
.await
|
||||||
.map_err(|_| WispError::MuxMessageFailedToSend)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -446,6 +446,10 @@ impl Packet {
|
||||||
minor: bytes.get_u8(),
|
minor: bytes.get_u8(),
|
||||||
};
|
};
|
||||||
|
|
||||||
|
if version.major != WISP_VERSION.major {
|
||||||
|
return Err(WispError::IncompatibleProtocolVersion);
|
||||||
|
}
|
||||||
|
|
||||||
let mut extensions = Vec::new();
|
let mut extensions = Vec::new();
|
||||||
|
|
||||||
while bytes.remaining() > 4 {
|
while bytes.remaining() > 4 {
|
||||||
|
|
|
@ -27,7 +27,7 @@ pub(crate) enum WsEvent {
|
||||||
u16,
|
u16,
|
||||||
oneshot::Sender<Result<MuxStream, WispError>>,
|
oneshot::Sender<Result<MuxStream, WispError>>,
|
||||||
),
|
),
|
||||||
EndFut,
|
EndFut(Option<CloseReason>),
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Read side of a multiplexor stream.
|
/// Read side of a multiplexor stream.
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue