make requiring protocol extensions easy

This commit is contained in:
Toshit Chawda 2024-04-20 18:38:38 -07:00
parent 063b527914
commit 01d7ac5002
No known key found for this signature in database
GPG key ID: 91480ED99E2B3D9D
6 changed files with 143 additions and 90 deletions

1
Cargo.lock generated
View file

@ -3230,7 +3230,6 @@ dependencies = [
"flume", "flume",
"futures", "futures",
"futures-timer", "futures-timer",
"futures-util",
"pin-project-lite", "pin-project-lite",
"tokio", "tokio",
] ]

View file

@ -203,7 +203,11 @@ pub async fn make_mux(
let (wtx, wrx) = let (wtx, wrx) =
WebSocketWrapper::connect(url, vec![]).map_err(|_| WispError::WsImplSocketClosed)?; WebSocketWrapper::connect(url, vec![]).map_err(|_| WispError::WsImplSocketClosed)?;
wtx.wait_for_open().await; wtx.wait_for_open().await;
ClientMux::new(wrx, wtx, Some(&[Box::new(UdpProtocolExtensionBuilder())])).await Ok(
ClientMux::create(wrx, wtx, Some(&[Box::new(UdpProtocolExtensionBuilder())]))
.await?
.with_no_required_extensions(),
)
} }
pub fn spawn_mux_fut( pub fn spawn_mux_fut(

View file

@ -253,45 +253,39 @@ async fn accept_http(
} }
} }
async fn handle_mux(packet: ConnectPacket, stream: MuxStream) -> Result<bool, WispError> { async fn handle_mux(
packet: ConnectPacket,
stream: MuxStream,
) -> Result<bool, Box<dyn std::error::Error + Sync + Send>> {
let uri = format!( let uri = format!(
"{}:{}", "{}:{}",
packet.destination_hostname, packet.destination_port packet.destination_hostname, packet.destination_port
); );
match packet.stream_type { match packet.stream_type {
StreamType::Tcp => { StreamType::Tcp => {
let mut tcp_stream = TcpStream::connect(uri) let mut tcp_stream = TcpStream::connect(uri).await?;
.await
.map_err(|x| WispError::Other(Box::new(x)))?;
let mut mux_stream = stream.into_io().into_asyncrw(); let mut mux_stream = stream.into_io().into_asyncrw();
copy_bidirectional(&mut mux_stream, &mut tcp_stream) copy_bidirectional(&mut mux_stream, &mut tcp_stream).await?;
.await
.map_err(|x| WispError::Other(Box::new(x)))?;
} }
StreamType::Udp => { StreamType::Udp => {
let uri = lookup_host(uri) let uri = lookup_host(uri)
.await .await?
.map_err(|x| WispError::Other(Box::new(x)))?
.next() .next()
.ok_or(WispError::InvalidUri)?; .ok_or(WispError::InvalidUri)?;
let udp_socket = UdpSocket::bind(if uri.is_ipv4() { "0.0.0.0:0" } else { "[::]:0" }) let udp_socket =
.await UdpSocket::bind(if uri.is_ipv4() { "0.0.0.0:0" } else { "[::]:0" }).await?;
.map_err(|x| WispError::Other(Box::new(x)))?; udp_socket.connect(uri).await?;
udp_socket
.connect(uri)
.await
.map_err(|x| WispError::Other(Box::new(x)))?;
let mut data = vec![0u8; 65507]; // udp standard max datagram size let mut data = vec![0u8; 65507]; // udp standard max datagram size
loop { loop {
tokio::select! { tokio::select! {
size = udp_socket.recv(&mut data).map_err(|x| WispError::Other(Box::new(x))) => { size = udp_socket.recv(&mut data) => {
let size = size?; let size = size?;
stream.write(Bytes::copy_from_slice(&data[..size])).await? stream.write(Bytes::copy_from_slice(&data[..size])).await?
}, },
event = stream.read() => { event = stream.read() => {
match event { match event {
Some(event) => { Some(event) => {
let _ = udp_socket.send(&event).await.map_err(|x| WispError::Other(Box::new(x)))?; let _ = udp_socket.send(&event).await?;
} }
None => break, None => break,
} }
@ -319,28 +313,18 @@ async fn accept_ws(
// to prevent memory ""leaks"" because users are sending in packets way too fast the buffer // to prevent memory ""leaks"" because users are sending in packets way too fast the buffer
// size is set to 128 // size is set to 128
let (mux, fut) = if mux_options.enforce_auth { let (mux, fut) = if mux_options.enforce_auth {
let (mux, fut) = ServerMux::new(rx, tx, 128, Some(mux_options.auth.as_slice())).await?; ServerMux::create(rx, tx, 128, Some(mux_options.auth.as_slice()))
if !mux .await?
.supported_extension_ids .with_required_extensions(&[PasswordProtocolExtension::ID]).await?
.iter()
.any(|x| *x == PasswordProtocolExtension::ID)
{
println!(
"{:?}: client did not support auth or password was invalid",
addr
);
mux.close_extension_incompat().await?;
return Ok(());
}
(mux, fut)
} else { } else {
ServerMux::new( ServerMux::create(
rx, rx,
tx, tx,
128, 128,
Some(&[Box::new(UdpProtocolExtensionBuilder())]), Some(&[Box::new(UdpProtocolExtensionBuilder())]),
) )
.await? .await?
.with_no_required_extensions()
}; };
println!( println!(

View file

@ -156,53 +156,33 @@ async fn main() -> Result<(), Box<dyn Error + Send + Sync>> {
let rx = FragmentCollectorRead::new(rx); let rx = FragmentCollectorRead::new(rx);
let mut extensions: Vec<Box<(dyn ProtocolExtensionBuilder + Send + Sync)>> = Vec::new(); let mut extensions: Vec<Box<(dyn ProtocolExtensionBuilder + Send + Sync)>> = Vec::new();
let mut extension_ids: Vec<u8> = Vec::new();
if opts.udp { if opts.udp {
extensions.push(Box::new(UdpProtocolExtensionBuilder())); extensions.push(Box::new(UdpProtocolExtensionBuilder()));
extension_ids.push(UdpProtocolExtension::ID);
} }
let enforce_auth = auth.is_some();
if let Some(auth) = auth { if let Some(auth) = auth {
extensions.push(Box::new(auth)); extensions.push(Box::new(auth));
extension_ids.push(PasswordProtocolExtension::ID);
} }
let (mux, fut) = if opts.wisp_v1 { let (mux, fut) = if opts.wisp_v1 {
ClientMux::new(rx, tx, None).await? ClientMux::create(rx, tx, None)
.await?
.with_no_required_extensions()
} else { } else {
ClientMux::new(rx, tx, Some(extensions.as_slice())).await? ClientMux::create(rx, tx, Some(extensions.as_slice()))
.await?
.with_required_extensions(extension_ids.as_slice()).await?
}; };
if opts.udp
&& !mux
.supported_extension_ids
.iter()
.any(|x| *x == UdpProtocolExtension::ID)
{
println!(
"server did not support udp, was downgraded {}, extensions supported {:?}",
mux.downgraded, mux.supported_extension_ids
);
mux.close_extension_incompat().await?;
exit(1);
}
if enforce_auth
&& !mux
.supported_extension_ids
.iter()
.any(|x| *x == PasswordProtocolExtension::ID)
{
println!(
"server did not support passwords or password was incorrect, was downgraded {}, extensions supported {:?}",
mux.downgraded, mux.supported_extension_ids
);
mux.close_extension_incompat().await?;
exit(1);
}
println!( println!(
"connected and created ClientMux, was downgraded {}, extensions supported {:?}", "connected and created ClientMux, was downgraded {}, extensions supported {:?}",
mux.downgraded, mux.supported_extension_ids mux.downgraded, mux.supported_extension_ids
); );
let mut threads = Vec::with_capacity(opts.streams * 2 + 3); let mut threads = Vec::with_capacity(opts.streams + 4);
let mut reads = Vec::with_capacity(opts.streams);
threads.push(tokio::spawn(fut)); threads.push(tokio::spawn(fut));
@ -226,13 +206,15 @@ async fn main() -> Result<(), Box<dyn Error + Send + Sync>> {
#[allow(unreachable_code)] #[allow(unreachable_code)]
Ok::<(), WispError>(()) Ok::<(), WispError>(())
})); }));
threads.push(tokio::spawn(async move { reads.push(cr);
loop {
cr.read().await;
}
}));
} }
threads.push(tokio::spawn(async move {
loop {
select_all(reads.iter().map(|x| Box::pin(x.read()))).await;
}
}));
let cnt_avg = cnt.clone(); let cnt_avg = cnt.clone();
threads.push(tokio::spawn(async move { threads.push(tokio::spawn(async move {
let mut interval = interval(Duration::from_millis(100)); let mut interval = interval(Duration::from_millis(100));
@ -295,14 +277,16 @@ async fn main() -> Result<(), Box<dyn Error + Send + Sync>> {
mux.close().await?; mux.close().await?;
println!( if duration_since.as_secs() != 0 {
"\nresults: {} packets of &[0; 1024 * {}] ({} KiB) sent in {} ({} KiB/s)", println!(
cnt.get(), "\nresults: {} packets of &[0; 1024 * {}] ({} KiB) sent in {} ({} KiB/s)",
opts.packet_size, cnt.get(),
cnt.get() * opts.packet_size, opts.packet_size,
format_duration(duration_since), cnt.get() * opts.packet_size,
(cnt.get() * opts.packet_size) as u64 / duration_since.as_secs(), format_duration(duration_since),
); (cnt.get() * opts.packet_size) as u64 / duration_since.as_secs(),
);
}
Ok(()) Ok(())
} }

View file

@ -18,7 +18,6 @@ fastwebsockets = { version = "0.7.1", features = ["unstable-split"], optional =
flume = "0.11.0" flume = "0.11.0"
futures = "0.3.30" futures = "0.3.30"
futures-timer = "3.0.3" futures-timer = "3.0.3"
futures-util = "0.3.30"
pin-project-lite = "0.2.13" pin-project-lite = "0.2.13"
tokio = { version = "1.35.1", optional = true, default-features = false } tokio = { version = "1.35.1", optional = true, default-features = false }

View file

@ -80,8 +80,8 @@ pub enum WispError {
ExtensionImplError(Box<dyn std::error::Error + Sync + Send>), ExtensionImplError(Box<dyn std::error::Error + Sync + Send>),
/// The protocol extension implementation did not support the action. /// The protocol extension implementation did not support the action.
ExtensionImplNotSupported, ExtensionImplNotSupported,
/// The UDP protocol extension is not supported by the server. /// The specified protocol extensions are not supported by the server.
UdpExtensionNotSupported, ExtensionsNotSupported(Vec<u8>),
/// The string was invalid UTF-8. /// The string was invalid UTF-8.
Utf8Error(std::str::Utf8Error), Utf8Error(std::str::Utf8Error),
/// The integer failed to convert. /// The integer failed to convert.
@ -137,7 +137,9 @@ impl std::fmt::Display for WispError {
"Protocol extension implementation error: unsupported feature" "Protocol extension implementation error: unsupported feature"
) )
} }
Self::UdpExtensionNotSupported => write!(f, "UDP protocol extension not supported"), Self::ExtensionsNotSupported(list) => {
write!(f, "Protocol extensions {:?} not supported", list)
}
Self::Utf8Error(err) => write!(f, "UTF-8 error: {}", err), Self::Utf8Error(err) => write!(f, "UTF-8 error: {}", err),
Self::TryFromIntError(err) => write!(f, "Integer conversion error: {}", err), Self::TryFromIntError(err) => write!(f, "Integer conversion error: {}", err),
Self::Other(err) => write!(f, "Other error: {}", err), Self::Other(err) => write!(f, "Other error: {}", err),
@ -483,12 +485,12 @@ impl ServerMux {
/// If `extension_builders` is None a Wisp v1 connection is created otherwise a Wisp v2 connection is created. /// If `extension_builders` is None a Wisp v1 connection is created otherwise a Wisp v2 connection is created.
/// **It is not guaranteed that all extensions you specify are available.** You must manually check /// **It is not guaranteed that all extensions you specify are available.** You must manually check
/// if the extensions you need are available after the multiplexor has been created. /// if the extensions you need are available after the multiplexor has been created.
pub async fn new<R, W>( pub async fn create<R, W>(
mut read: R, mut read: R,
write: W, write: W,
buffer_size: u32, buffer_size: u32,
extension_builders: Option<&[Box<dyn ProtocolExtensionBuilder + Send + Sync>]>, extension_builders: Option<&[Box<dyn ProtocolExtensionBuilder + Send + Sync>]>,
) -> Result<(Self, impl Future<Output = Result<(), WispError>> + Send), WispError> ) -> Result<ServerMuxResult<impl Future<Output = Result<(), WispError>> + Send>, WispError>
where where
R: ws::WebSocketRead + Send, R: ws::WebSocketRead + Send,
W: ws::WebSocketWrite + Send + 'static, W: ws::WebSocketWrite + Send + 'static,
@ -532,7 +534,7 @@ impl ServerMux {
} }
} }
Ok(( Ok(ServerMuxResult(
Self { Self {
muxstream_recv: rx, muxstream_recv: rx,
close_tx: close_tx.clone(), close_tx: close_tx.clone(),
@ -590,6 +592,48 @@ impl Drop for ServerMux {
} }
} }
/// Result of `ServerMux::new`.
pub struct ServerMuxResult<F>(ServerMux, F)
where
F: Future<Output = Result<(), WispError>> + Send;
impl<F> ServerMuxResult<F>
where
F: Future<Output = Result<(), WispError>> + Send,
{
/// Require no protocol extensions.
pub fn with_no_required_extensions(self) -> (ServerMux, F) {
(self.0, self.1)
}
/// Require protocol extensions by their ID. Will close the multiplexor connection if
/// extensions are not supported.
pub async fn with_required_extensions(
self,
extensions: &[u8],
) -> Result<(ServerMux, F), WispError> {
let mut unsupported_extensions = Vec::new();
for extension in extensions {
if !self.0.supported_extension_ids.contains(extension) {
unsupported_extensions.push(*extension);
}
}
if unsupported_extensions.is_empty() {
Ok((self.0, self.1))
} else {
self.0.close_extension_incompat().await?;
self.1.await?;
Err(WispError::ExtensionsNotSupported(unsupported_extensions))
}
}
/// Shorthand for `with_required_extensions(&[UdpProtocolExtension::ID])`
pub async fn with_udp_extension_required(self) -> Result<(ServerMux, F), WispError> {
self.with_required_extensions(&[UdpProtocolExtension::ID])
.await
}
}
/// Client side multiplexor. /// Client side multiplexor.
/// ///
/// # Example /// # Example
@ -620,11 +664,11 @@ impl ClientMux {
/// If `extension_builders` is None a Wisp v1 connection is created otherwise a Wisp v2 connection is created. /// If `extension_builders` is None a Wisp v1 connection is created otherwise a Wisp v2 connection is created.
/// **It is not guaranteed that all extensions you specify are available.** You must manually check /// **It is not guaranteed that all extensions you specify are available.** You must manually check
/// if the extensions you need are available after the multiplexor has been created. /// if the extensions you need are available after the multiplexor has been created.
pub async fn new<R, W>( pub async fn create<R, W>(
mut read: R, mut read: R,
write: W, write: W,
extension_builders: Option<&[Box<dyn ProtocolExtensionBuilder + Send + Sync>]>, extension_builders: Option<&[Box<dyn ProtocolExtensionBuilder + Send + Sync>]>,
) -> Result<(Self, impl Future<Output = Result<(), WispError>> + Send), WispError> ) -> Result<ClientMuxResult<impl Future<Output = Result<(), WispError>> + Send>, WispError>
where where
R: ws::WebSocketRead + Send, R: ws::WebSocketRead + Send,
W: ws::WebSocketWrite + Send + 'static, W: ws::WebSocketWrite + Send + 'static,
@ -671,7 +715,7 @@ impl ClientMux {
} }
let (tx, rx) = mpsc::bounded::<WsEvent>(256); let (tx, rx) = mpsc::bounded::<WsEvent>(256);
Ok(( Ok(ClientMuxResult(
Self { Self {
stream_tx: tx.clone(), stream_tx: tx.clone(),
downgraded, downgraded,
@ -710,7 +754,9 @@ impl ClientMux {
.iter() .iter()
.any(|x| *x == UdpProtocolExtension::ID) .any(|x| *x == UdpProtocolExtension::ID)
{ {
return Err(WispError::UdpExtensionNotSupported); return Err(WispError::ExtensionsNotSupported(vec![
UdpProtocolExtension::ID,
]));
} }
let (tx, rx) = oneshot::channel(); let (tx, rx) = oneshot::channel();
self.stream_tx self.stream_tx
@ -750,3 +796,40 @@ impl Drop for ClientMux {
let _ = self.stream_tx.send(WsEvent::EndFut(None)); let _ = self.stream_tx.send(WsEvent::EndFut(None));
} }
} }
/// Result of `ClientMux::new`.
pub struct ClientMuxResult<F>(ClientMux, F)
where
F: Future<Output = Result<(), WispError>> + Send;
impl<F> ClientMuxResult<F>
where
F: Future<Output = Result<(), WispError>> + Send,
{
/// Require no protocol extensions.
pub fn with_no_required_extensions(self) -> (ClientMux, F) {
(self.0, self.1)
}
/// Require protocol extensions by their ID.
pub async fn with_required_extensions(self, extensions: &[u8]) -> Result<(ClientMux, F), WispError> {
let mut unsupported_extensions = Vec::new();
for extension in extensions {
if !self.0.supported_extension_ids.contains(extension) {
unsupported_extensions.push(*extension);
}
}
if unsupported_extensions.is_empty() {
Ok((self.0, self.1))
} else {
self.0.close_extension_incompat().await?;
self.1.await?;
Err(WispError::ExtensionsNotSupported(unsupported_extensions))
}
}
/// Shorthand for `with_required_extensions(&[UdpProtocolExtension::ID])`
pub async fn with_udp_extension_required(self) -> Result<(ClientMux, F), WispError> {
self.with_required_extensions(&[UdpProtocolExtension::ID]).await
}
}