diff --git a/Cargo.lock b/Cargo.lock index b6e0b7d..a3ba6a1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1485,12 +1485,27 @@ version = "1.0.18" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f3cb5ba0dc43242ce17de99c180e96db90b235b8a9fdc9543c96d2209116bd9f" +[[package]] +name = "scc" +version = "2.1.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "aeb7ac86243095b70a7920639507b71d51a63390d1ba26c4f60a552fbb914a37" +dependencies = [ + "sdd", +] + [[package]] name = "scopeguard" version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" +[[package]] +name = "sdd" +version = "3.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0495e4577c672de8254beb68d01a9b62d0e8a13c099edecdbedccce3223cd29f" + [[package]] name = "send_wrapper" version = "0.4.0" @@ -2253,13 +2268,13 @@ dependencies = [ "async-trait", "atomic_enum", "bytes", - "dashmap", "event-listener", "fastwebsockets", "flume", "futures", "futures-timer", "pin-project-lite", + "scc", "tokio", ] diff --git a/wisp/Cargo.toml b/wisp/Cargo.toml index 0d9ffac..28b8b76 100644 --- a/wisp/Cargo.toml +++ b/wisp/Cargo.toml @@ -12,13 +12,13 @@ edition = "2021" async-trait = "0.1.81" atomic_enum = "0.3.0" bytes = "1.7.1" -dashmap = { version = "6.0.1", features = ["inline"] } event-listener = "5.3.1" fastwebsockets = { version = "0.8.0", features = ["unstable-split"], optional = true } flume = "0.11.0" futures = "0.3.30" futures-timer = "3.0.3" pin-project-lite = "0.2.14" +scc = "2.1.16" tokio = { version = "1.39.3", optional = true, default-features = false } [features] diff --git a/wisp/src/lib.rs b/wisp/src/lib.rs index 1ba04e8..4f77e77 100644 --- a/wisp/src/lib.rs +++ b/wisp/src/lib.rs @@ -19,12 +19,12 @@ pub mod ws; pub use crate::{packet::*, stream::*}; use bytes::{Bytes, BytesMut}; -use dashmap::DashMap; use event_listener::Event; use extensions::{udp::UdpProtocolExtension, AnyProtocolExtension, ProtocolExtensionBuilder}; use flume as mpsc; use futures::{channel::oneshot, select, Future, FutureExt}; use futures_timer::Delay; +use scc::HashMap; use std::{ sync::{ atomic::{AtomicBool, AtomicU32, Ordering}, @@ -169,9 +169,16 @@ struct MuxMapValue { is_closed_event: Arc, } +impl Drop for MuxMapValue { + fn drop(&mut self) { + self.is_closed.store(true, Ordering::Release); + self.is_closed_event.notify(usize::MAX); + } +} + struct MuxInner { tx: ws::LockedWebSocketWrite, - stream_map: DashMap, + stream_map: HashMap, buffer_size: u32, fut_exited: Arc, } @@ -221,11 +228,7 @@ impl MuxInner { x = wisp_fut.fuse() => x, }; self.fut_exited.store(true, Ordering::Release); - for x in self.stream_map.iter_mut() { - x.is_closed.store(true, Ordering::Release); - x.is_closed_event.notify(usize::MAX); - } - self.stream_map.clear(); + self.stream_map.clear_async().await; let _ = self.tx.close().await; ret } @@ -310,7 +313,7 @@ impl MuxInner { ) .await?; - self.stream_map.insert(stream_id, map_value); + self.stream_map.upsert_async(stream_id, map_value).await; next_free_stream_id = next_stream_id; @@ -320,12 +323,12 @@ impl MuxInner { let _ = channel.send(ret); } WsEvent::Close(packet, channel) => { - if let Some((_, stream)) = self.stream_map.remove(&packet.stream_id) { + if let Some((_, stream)) = self.stream_map.remove_async(&packet.stream_id).await + { if let PacketType::Close(close) = packet.packet_type { - self.close_stream(packet.stream_id, close); + self.close_stream(stream, close); } let _ = channel.send(self.tx.write_frame(packet.into()).await); - drop(stream.stream) } else { let _ = channel.send(Err(WispError::InvalidStreamId)); } @@ -343,17 +346,14 @@ impl MuxInner { } } - fn close_stream(&self, stream_id: u32, close_packet: ClosePacket) { - if let Some((_, stream)) = self.stream_map.remove(&stream_id) { - stream - .close_reason - .store(close_packet.reason, Ordering::Release); - stream.is_closed.store(true, Ordering::Release); - stream.is_closed_event.notify(usize::MAX); - stream.flow_control.store(u32::MAX, Ordering::Release); - stream.flow_control_event.notify(usize::MAX); - drop(stream.stream) - } + fn close_stream(&self, stream: MuxMapValue, close_packet: ClosePacket) { + stream + .close_reason + .store(close_packet.reason, Ordering::Release); + stream.is_closed.store(true, Ordering::Release); + stream.is_closed_event.notify(usize::MAX); + stream.flow_control.store(u32::MAX, Ordering::Release); + stream.flow_control_event.notify(usize::MAX); } async fn server_loop( @@ -404,11 +404,13 @@ impl MuxInner { .send_async((inner_packet, stream)) .await .map_err(|_| WispError::MuxMessageFailedToSend)?; - self.stream_map.insert(packet.stream_id, map_value); + self.stream_map + .upsert_async(packet.stream_id, map_value) + .await; } Data(data) => { let mut data = BytesMut::from(data); - if let Some(stream) = self.stream_map.get(&packet.stream_id) { + if let Some(stream) = self.stream_map.get_async(&packet.stream_id).await { if let Some(extra_frame) = optional_frame { if data.is_empty() { data = extra_frame.payload.into(); @@ -432,7 +434,12 @@ impl MuxInner { if packet.stream_id == 0 { break Ok(()); } - self.close_stream(packet.stream_id, inner_packet) + + if let Some((_, stream)) = + self.stream_map.remove_async(&packet.stream_id).await + { + self.close_stream(stream, inner_packet) + } } } } @@ -469,7 +476,7 @@ impl MuxInner { Connect(_) | Info(_) => break Err(WispError::InvalidPacketType), Data(data) => { let mut data = BytesMut::from(data); - if let Some(stream) = self.stream_map.get(&packet.stream_id) { + if let Some(stream) = self.stream_map.get_async(&packet.stream_id).await { if let Some(extra_frame) = optional_frame { if data.is_empty() { data = extra_frame.payload.into(); @@ -481,7 +488,7 @@ impl MuxInner { } } Continue(inner_packet) => { - if let Some(stream) = self.stream_map.get(&packet.stream_id) { + if let Some(stream) = self.stream_map.get_async(&packet.stream_id).await { if stream.stream_type == StreamType::Tcp { stream .flow_control @@ -494,7 +501,12 @@ impl MuxInner { if packet.stream_id == 0 { break Ok(()); } - self.close_stream(packet.stream_id, inner_packet); + + if let Some((_, stream)) = + self.stream_map.remove_async(&packet.stream_id).await + { + self.close_stream(stream, inner_packet) + } } } } @@ -624,7 +636,7 @@ impl ServerMux { }, MuxInner { tx: write, - stream_map: DashMap::new(), + stream_map: HashMap::new(), buffer_size, fut_exited, } @@ -814,7 +826,7 @@ impl ClientMux { }, MuxInner { tx: write, - stream_map: DashMap::new(), + stream_map: HashMap::new(), buffer_size: packet.buffer_remaining, fut_exited, }