use async hashmap

This commit is contained in:
Toshit Chawda 2024-08-31 10:40:40 -07:00
parent b42cf07a24
commit b1f56c1dae
No known key found for this signature in database
GPG key ID: 91480ED99E2B3D9D
3 changed files with 59 additions and 32 deletions

View file

@ -19,12 +19,12 @@ pub mod ws;
pub use crate::{packet::*, stream::*};
use bytes::{Bytes, BytesMut};
use dashmap::DashMap;
use event_listener::Event;
use extensions::{udp::UdpProtocolExtension, AnyProtocolExtension, ProtocolExtensionBuilder};
use flume as mpsc;
use futures::{channel::oneshot, select, Future, FutureExt};
use futures_timer::Delay;
use scc::HashMap;
use std::{
sync::{
atomic::{AtomicBool, AtomicU32, Ordering},
@ -169,9 +169,16 @@ struct MuxMapValue {
is_closed_event: Arc<Event>,
}
impl Drop for MuxMapValue {
fn drop(&mut self) {
self.is_closed.store(true, Ordering::Release);
self.is_closed_event.notify(usize::MAX);
}
}
struct MuxInner {
tx: ws::LockedWebSocketWrite,
stream_map: DashMap<u32, MuxMapValue>,
stream_map: HashMap<u32, MuxMapValue>,
buffer_size: u32,
fut_exited: Arc<AtomicBool>,
}
@ -221,11 +228,7 @@ impl MuxInner {
x = wisp_fut.fuse() => x,
};
self.fut_exited.store(true, Ordering::Release);
for x in self.stream_map.iter_mut() {
x.is_closed.store(true, Ordering::Release);
x.is_closed_event.notify(usize::MAX);
}
self.stream_map.clear();
self.stream_map.clear_async().await;
let _ = self.tx.close().await;
ret
}
@ -310,7 +313,7 @@ impl MuxInner {
)
.await?;
self.stream_map.insert(stream_id, map_value);
self.stream_map.upsert_async(stream_id, map_value).await;
next_free_stream_id = next_stream_id;
@ -320,12 +323,12 @@ impl MuxInner {
let _ = channel.send(ret);
}
WsEvent::Close(packet, channel) => {
if let Some((_, stream)) = self.stream_map.remove(&packet.stream_id) {
if let Some((_, stream)) = self.stream_map.remove_async(&packet.stream_id).await
{
if let PacketType::Close(close) = packet.packet_type {
self.close_stream(packet.stream_id, close);
self.close_stream(stream, close);
}
let _ = channel.send(self.tx.write_frame(packet.into()).await);
drop(stream.stream)
} else {
let _ = channel.send(Err(WispError::InvalidStreamId));
}
@ -343,17 +346,14 @@ impl MuxInner {
}
}
fn close_stream(&self, stream_id: u32, close_packet: ClosePacket) {
if let Some((_, stream)) = self.stream_map.remove(&stream_id) {
stream
.close_reason
.store(close_packet.reason, Ordering::Release);
stream.is_closed.store(true, Ordering::Release);
stream.is_closed_event.notify(usize::MAX);
stream.flow_control.store(u32::MAX, Ordering::Release);
stream.flow_control_event.notify(usize::MAX);
drop(stream.stream)
}
fn close_stream(&self, stream: MuxMapValue, close_packet: ClosePacket) {
stream
.close_reason
.store(close_packet.reason, Ordering::Release);
stream.is_closed.store(true, Ordering::Release);
stream.is_closed_event.notify(usize::MAX);
stream.flow_control.store(u32::MAX, Ordering::Release);
stream.flow_control_event.notify(usize::MAX);
}
async fn server_loop<R>(
@ -404,11 +404,13 @@ impl MuxInner {
.send_async((inner_packet, stream))
.await
.map_err(|_| WispError::MuxMessageFailedToSend)?;
self.stream_map.insert(packet.stream_id, map_value);
self.stream_map
.upsert_async(packet.stream_id, map_value)
.await;
}
Data(data) => {
let mut data = BytesMut::from(data);
if let Some(stream) = self.stream_map.get(&packet.stream_id) {
if let Some(stream) = self.stream_map.get_async(&packet.stream_id).await {
if let Some(extra_frame) = optional_frame {
if data.is_empty() {
data = extra_frame.payload.into();
@ -432,7 +434,12 @@ impl MuxInner {
if packet.stream_id == 0 {
break Ok(());
}
self.close_stream(packet.stream_id, inner_packet)
if let Some((_, stream)) =
self.stream_map.remove_async(&packet.stream_id).await
{
self.close_stream(stream, inner_packet)
}
}
}
}
@ -469,7 +476,7 @@ impl MuxInner {
Connect(_) | Info(_) => break Err(WispError::InvalidPacketType),
Data(data) => {
let mut data = BytesMut::from(data);
if let Some(stream) = self.stream_map.get(&packet.stream_id) {
if let Some(stream) = self.stream_map.get_async(&packet.stream_id).await {
if let Some(extra_frame) = optional_frame {
if data.is_empty() {
data = extra_frame.payload.into();
@ -481,7 +488,7 @@ impl MuxInner {
}
}
Continue(inner_packet) => {
if let Some(stream) = self.stream_map.get(&packet.stream_id) {
if let Some(stream) = self.stream_map.get_async(&packet.stream_id).await {
if stream.stream_type == StreamType::Tcp {
stream
.flow_control
@ -494,7 +501,12 @@ impl MuxInner {
if packet.stream_id == 0 {
break Ok(());
}
self.close_stream(packet.stream_id, inner_packet);
if let Some((_, stream)) =
self.stream_map.remove_async(&packet.stream_id).await
{
self.close_stream(stream, inner_packet)
}
}
}
}
@ -624,7 +636,7 @@ impl ServerMux {
},
MuxInner {
tx: write,
stream_map: DashMap::new(),
stream_map: HashMap::new(),
buffer_size,
fut_exited,
}
@ -814,7 +826,7 @@ impl ClientMux {
},
MuxInner {
tx: write,
stream_map: DashMap::new(),
stream_map: HashMap::new(),
buffer_size: packet.buffer_remaining,
fut_exited,
}