use async hashmap

This commit is contained in:
Toshit Chawda 2024-08-31 10:40:40 -07:00
parent b42cf07a24
commit b1f56c1dae
No known key found for this signature in database
GPG key ID: 91480ED99E2B3D9D
3 changed files with 59 additions and 32 deletions

17
Cargo.lock generated
View file

@ -1485,12 +1485,27 @@ version = "1.0.18"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f3cb5ba0dc43242ce17de99c180e96db90b235b8a9fdc9543c96d2209116bd9f" checksum = "f3cb5ba0dc43242ce17de99c180e96db90b235b8a9fdc9543c96d2209116bd9f"
[[package]]
name = "scc"
version = "2.1.16"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "aeb7ac86243095b70a7920639507b71d51a63390d1ba26c4f60a552fbb914a37"
dependencies = [
"sdd",
]
[[package]] [[package]]
name = "scopeguard" name = "scopeguard"
version = "1.2.0" version = "1.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49"
[[package]]
name = "sdd"
version = "3.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0495e4577c672de8254beb68d01a9b62d0e8a13c099edecdbedccce3223cd29f"
[[package]] [[package]]
name = "send_wrapper" name = "send_wrapper"
version = "0.4.0" version = "0.4.0"
@ -2253,13 +2268,13 @@ dependencies = [
"async-trait", "async-trait",
"atomic_enum", "atomic_enum",
"bytes", "bytes",
"dashmap",
"event-listener", "event-listener",
"fastwebsockets", "fastwebsockets",
"flume", "flume",
"futures", "futures",
"futures-timer", "futures-timer",
"pin-project-lite", "pin-project-lite",
"scc",
"tokio", "tokio",
] ]

View file

@ -12,13 +12,13 @@ edition = "2021"
async-trait = "0.1.81" async-trait = "0.1.81"
atomic_enum = "0.3.0" atomic_enum = "0.3.0"
bytes = "1.7.1" bytes = "1.7.1"
dashmap = { version = "6.0.1", features = ["inline"] }
event-listener = "5.3.1" event-listener = "5.3.1"
fastwebsockets = { version = "0.8.0", features = ["unstable-split"], optional = true } fastwebsockets = { version = "0.8.0", features = ["unstable-split"], optional = true }
flume = "0.11.0" flume = "0.11.0"
futures = "0.3.30" futures = "0.3.30"
futures-timer = "3.0.3" futures-timer = "3.0.3"
pin-project-lite = "0.2.14" pin-project-lite = "0.2.14"
scc = "2.1.16"
tokio = { version = "1.39.3", optional = true, default-features = false } tokio = { version = "1.39.3", optional = true, default-features = false }
[features] [features]

View file

@ -19,12 +19,12 @@ pub mod ws;
pub use crate::{packet::*, stream::*}; pub use crate::{packet::*, stream::*};
use bytes::{Bytes, BytesMut}; use bytes::{Bytes, BytesMut};
use dashmap::DashMap;
use event_listener::Event; use event_listener::Event;
use extensions::{udp::UdpProtocolExtension, AnyProtocolExtension, ProtocolExtensionBuilder}; use extensions::{udp::UdpProtocolExtension, AnyProtocolExtension, ProtocolExtensionBuilder};
use flume as mpsc; use flume as mpsc;
use futures::{channel::oneshot, select, Future, FutureExt}; use futures::{channel::oneshot, select, Future, FutureExt};
use futures_timer::Delay; use futures_timer::Delay;
use scc::HashMap;
use std::{ use std::{
sync::{ sync::{
atomic::{AtomicBool, AtomicU32, Ordering}, atomic::{AtomicBool, AtomicU32, Ordering},
@ -169,9 +169,16 @@ struct MuxMapValue {
is_closed_event: Arc<Event>, is_closed_event: Arc<Event>,
} }
impl Drop for MuxMapValue {
fn drop(&mut self) {
self.is_closed.store(true, Ordering::Release);
self.is_closed_event.notify(usize::MAX);
}
}
struct MuxInner { struct MuxInner {
tx: ws::LockedWebSocketWrite, tx: ws::LockedWebSocketWrite,
stream_map: DashMap<u32, MuxMapValue>, stream_map: HashMap<u32, MuxMapValue>,
buffer_size: u32, buffer_size: u32,
fut_exited: Arc<AtomicBool>, fut_exited: Arc<AtomicBool>,
} }
@ -221,11 +228,7 @@ impl MuxInner {
x = wisp_fut.fuse() => x, x = wisp_fut.fuse() => x,
}; };
self.fut_exited.store(true, Ordering::Release); self.fut_exited.store(true, Ordering::Release);
for x in self.stream_map.iter_mut() { self.stream_map.clear_async().await;
x.is_closed.store(true, Ordering::Release);
x.is_closed_event.notify(usize::MAX);
}
self.stream_map.clear();
let _ = self.tx.close().await; let _ = self.tx.close().await;
ret ret
} }
@ -310,7 +313,7 @@ impl MuxInner {
) )
.await?; .await?;
self.stream_map.insert(stream_id, map_value); self.stream_map.upsert_async(stream_id, map_value).await;
next_free_stream_id = next_stream_id; next_free_stream_id = next_stream_id;
@ -320,12 +323,12 @@ impl MuxInner {
let _ = channel.send(ret); let _ = channel.send(ret);
} }
WsEvent::Close(packet, channel) => { WsEvent::Close(packet, channel) => {
if let Some((_, stream)) = self.stream_map.remove(&packet.stream_id) { if let Some((_, stream)) = self.stream_map.remove_async(&packet.stream_id).await
{
if let PacketType::Close(close) = packet.packet_type { if let PacketType::Close(close) = packet.packet_type {
self.close_stream(packet.stream_id, close); self.close_stream(stream, close);
} }
let _ = channel.send(self.tx.write_frame(packet.into()).await); let _ = channel.send(self.tx.write_frame(packet.into()).await);
drop(stream.stream)
} else { } else {
let _ = channel.send(Err(WispError::InvalidStreamId)); let _ = channel.send(Err(WispError::InvalidStreamId));
} }
@ -343,17 +346,14 @@ impl MuxInner {
} }
} }
fn close_stream(&self, stream_id: u32, close_packet: ClosePacket) { fn close_stream(&self, stream: MuxMapValue, close_packet: ClosePacket) {
if let Some((_, stream)) = self.stream_map.remove(&stream_id) { stream
stream .close_reason
.close_reason .store(close_packet.reason, Ordering::Release);
.store(close_packet.reason, Ordering::Release); stream.is_closed.store(true, Ordering::Release);
stream.is_closed.store(true, Ordering::Release); stream.is_closed_event.notify(usize::MAX);
stream.is_closed_event.notify(usize::MAX); stream.flow_control.store(u32::MAX, Ordering::Release);
stream.flow_control.store(u32::MAX, Ordering::Release); stream.flow_control_event.notify(usize::MAX);
stream.flow_control_event.notify(usize::MAX);
drop(stream.stream)
}
} }
async fn server_loop<R>( async fn server_loop<R>(
@ -404,11 +404,13 @@ impl MuxInner {
.send_async((inner_packet, stream)) .send_async((inner_packet, stream))
.await .await
.map_err(|_| WispError::MuxMessageFailedToSend)?; .map_err(|_| WispError::MuxMessageFailedToSend)?;
self.stream_map.insert(packet.stream_id, map_value); self.stream_map
.upsert_async(packet.stream_id, map_value)
.await;
} }
Data(data) => { Data(data) => {
let mut data = BytesMut::from(data); let mut data = BytesMut::from(data);
if let Some(stream) = self.stream_map.get(&packet.stream_id) { if let Some(stream) = self.stream_map.get_async(&packet.stream_id).await {
if let Some(extra_frame) = optional_frame { if let Some(extra_frame) = optional_frame {
if data.is_empty() { if data.is_empty() {
data = extra_frame.payload.into(); data = extra_frame.payload.into();
@ -432,7 +434,12 @@ impl MuxInner {
if packet.stream_id == 0 { if packet.stream_id == 0 {
break Ok(()); break Ok(());
} }
self.close_stream(packet.stream_id, inner_packet)
if let Some((_, stream)) =
self.stream_map.remove_async(&packet.stream_id).await
{
self.close_stream(stream, inner_packet)
}
} }
} }
} }
@ -469,7 +476,7 @@ impl MuxInner {
Connect(_) | Info(_) => break Err(WispError::InvalidPacketType), Connect(_) | Info(_) => break Err(WispError::InvalidPacketType),
Data(data) => { Data(data) => {
let mut data = BytesMut::from(data); let mut data = BytesMut::from(data);
if let Some(stream) = self.stream_map.get(&packet.stream_id) { if let Some(stream) = self.stream_map.get_async(&packet.stream_id).await {
if let Some(extra_frame) = optional_frame { if let Some(extra_frame) = optional_frame {
if data.is_empty() { if data.is_empty() {
data = extra_frame.payload.into(); data = extra_frame.payload.into();
@ -481,7 +488,7 @@ impl MuxInner {
} }
} }
Continue(inner_packet) => { Continue(inner_packet) => {
if let Some(stream) = self.stream_map.get(&packet.stream_id) { if let Some(stream) = self.stream_map.get_async(&packet.stream_id).await {
if stream.stream_type == StreamType::Tcp { if stream.stream_type == StreamType::Tcp {
stream stream
.flow_control .flow_control
@ -494,7 +501,12 @@ impl MuxInner {
if packet.stream_id == 0 { if packet.stream_id == 0 {
break Ok(()); break Ok(());
} }
self.close_stream(packet.stream_id, inner_packet);
if let Some((_, stream)) =
self.stream_map.remove_async(&packet.stream_id).await
{
self.close_stream(stream, inner_packet)
}
} }
} }
} }
@ -624,7 +636,7 @@ impl ServerMux {
}, },
MuxInner { MuxInner {
tx: write, tx: write,
stream_map: DashMap::new(), stream_map: HashMap::new(),
buffer_size, buffer_size,
fut_exited, fut_exited,
} }
@ -814,7 +826,7 @@ impl ClientMux {
}, },
MuxInner { MuxInner {
tx: write, tx: write,
stream_map: DashMap::new(), stream_map: HashMap::new(),
buffer_size: packet.buffer_remaining, buffer_size: packet.buffer_remaining,
fut_exited, fut_exited,
} }