make requiring protocol extensions easy

This commit is contained in:
Toshit Chawda 2024-04-20 18:38:38 -07:00
parent 063b527914
commit 01d7ac5002
No known key found for this signature in database
GPG key ID: 91480ED99E2B3D9D
6 changed files with 143 additions and 90 deletions

View file

@ -253,45 +253,39 @@ async fn accept_http(
}
}
async fn handle_mux(packet: ConnectPacket, stream: MuxStream) -> Result<bool, WispError> {
async fn handle_mux(
packet: ConnectPacket,
stream: MuxStream,
) -> Result<bool, Box<dyn std::error::Error + Sync + Send>> {
let uri = format!(
"{}:{}",
packet.destination_hostname, packet.destination_port
);
match packet.stream_type {
StreamType::Tcp => {
let mut tcp_stream = TcpStream::connect(uri)
.await
.map_err(|x| WispError::Other(Box::new(x)))?;
let mut tcp_stream = TcpStream::connect(uri).await?;
let mut mux_stream = stream.into_io().into_asyncrw();
copy_bidirectional(&mut mux_stream, &mut tcp_stream)
.await
.map_err(|x| WispError::Other(Box::new(x)))?;
copy_bidirectional(&mut mux_stream, &mut tcp_stream).await?;
}
StreamType::Udp => {
let uri = lookup_host(uri)
.await
.map_err(|x| WispError::Other(Box::new(x)))?
.await?
.next()
.ok_or(WispError::InvalidUri)?;
let udp_socket = UdpSocket::bind(if uri.is_ipv4() { "0.0.0.0:0" } else { "[::]:0" })
.await
.map_err(|x| WispError::Other(Box::new(x)))?;
udp_socket
.connect(uri)
.await
.map_err(|x| WispError::Other(Box::new(x)))?;
let udp_socket =
UdpSocket::bind(if uri.is_ipv4() { "0.0.0.0:0" } else { "[::]:0" }).await?;
udp_socket.connect(uri).await?;
let mut data = vec![0u8; 65507]; // udp standard max datagram size
loop {
tokio::select! {
size = udp_socket.recv(&mut data).map_err(|x| WispError::Other(Box::new(x))) => {
size = udp_socket.recv(&mut data) => {
let size = size?;
stream.write(Bytes::copy_from_slice(&data[..size])).await?
},
event = stream.read() => {
match event {
Some(event) => {
let _ = udp_socket.send(&event).await.map_err(|x| WispError::Other(Box::new(x)))?;
let _ = udp_socket.send(&event).await?;
}
None => break,
}
@ -319,28 +313,18 @@ async fn accept_ws(
// to prevent memory ""leaks"" because users are sending in packets way too fast the buffer
// size is set to 128
let (mux, fut) = if mux_options.enforce_auth {
let (mux, fut) = ServerMux::new(rx, tx, 128, Some(mux_options.auth.as_slice())).await?;
if !mux
.supported_extension_ids
.iter()
.any(|x| *x == PasswordProtocolExtension::ID)
{
println!(
"{:?}: client did not support auth or password was invalid",
addr
);
mux.close_extension_incompat().await?;
return Ok(());
}
(mux, fut)
ServerMux::create(rx, tx, 128, Some(mux_options.auth.as_slice()))
.await?
.with_required_extensions(&[PasswordProtocolExtension::ID]).await?
} else {
ServerMux::new(
ServerMux::create(
rx,
tx,
128,
Some(&[Box::new(UdpProtocolExtensionBuilder())]),
)
.await?
.with_no_required_extensions()
};
println!(