diff --git a/Cargo.lock b/Cargo.lock index 8fa37cd..33b2300 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -861,6 +861,18 @@ dependencies = [ "miniz_oxide", ] +[[package]] +name = "flume" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "55ac459de2512911e4b674ce33cf20befaba382d05b62b008afc1c8b57cbf181" +dependencies = [ + "futures-core", + "futures-sink", + "nanorand", + "spin", +] + [[package]] name = "fnv" version = "1.0.7" @@ -1487,6 +1499,15 @@ dependencies = [ "windows-sys 0.48.0", ] +[[package]] +name = "nanorand" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6a51313c5820b0b02bd422f4b44776fbf47961755c74ce64afc73bfad10226c3" +dependencies = [ + "getrandom", +] + [[package]] name = "native-tls" version = "0.2.11" @@ -2273,6 +2294,9 @@ name = "spin" version = "0.9.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67" +dependencies = [ + "lock_api", +] [[package]] name = "strsim" @@ -3203,6 +3227,7 @@ dependencies = [ "dashmap", "event-listener", "fastwebsockets 0.7.1", + "flume", "futures", "futures-timer", "futures-util", diff --git a/client/demo.js b/client/demo.js index a5b35a2..298e2f2 100644 --- a/client/demo.js +++ b/client/demo.js @@ -238,9 +238,9 @@ onmessage = async (msg) => { log(`total avg mux (${num_outer_tests} tests of ${num_inner_tests} reqs): ${total_mux_multi} ms or ${total_mux_multi / 1000} s`); } else { - let resp = await epoxy_client.fetch("https://httpbin.org/get"); + let resp = await epoxy_client.fetch("https://www.example.com/"); console.log(resp, Object.fromEntries(resp.headers)); - plog(await resp.json()); + log(await resp.text()); } log("done"); }; diff --git a/client/src/utils.rs b/client/src/utils.rs index 3f6dc9b..98717a5 100644 --- a/client/src/utils.rs +++ b/client/src/utils.rs @@ -200,13 +200,10 @@ pub async fn make_mux( ), WispError, > { - let (wtx, wrx) = WebSocketWrapper::connect(url, vec![]) - .await - .map_err(|_| WispError::WsImplSocketClosed)?; + let (wtx, wrx) = + WebSocketWrapper::connect(url, vec![]).map_err(|_| WispError::WsImplSocketClosed)?; wtx.wait_for_open().await; - let mux = ClientMux::new(wrx, wtx, Some(&[Box::new(UdpProtocolExtensionBuilder())])).await?; - - Ok(mux) + ClientMux::new(wrx, wtx, Some(&[Box::new(UdpProtocolExtensionBuilder())])).await } pub fn spawn_mux_fut( @@ -215,6 +212,7 @@ pub fn spawn_mux_fut( url: String, ) { wasm_bindgen_futures::spawn_local(async move { + debug!("epoxy: mux future started"); if let Err(e) = fut.await { log!("epoxy: error in mux future, restarting: {:?}", e); while let Err(e) = replace_mux(mux.clone(), &url).await { @@ -229,7 +227,7 @@ pub fn spawn_mux_fut( pub async fn replace_mux(mux: Arc>, url: &str) -> Result<(), WispError> { let (mux_replace, fut) = make_mux(url).await?; let mut mux_write = mux.write().await; - mux_write.close().await?; + let _ = mux_write.close().await; *mux_write = mux_replace; drop(mux_write); spawn_mux_fut(mux, fut, url.into()); diff --git a/client/src/wrappers.rs b/client/src/wrappers.rs index e67779e..5746ac2 100644 --- a/client/src/wrappers.rs +++ b/client/src/wrappers.rs @@ -123,6 +123,7 @@ impl tower_service::Service for TlsWispService { let stream = service.call(uri_parsed).await?.into_inner(); if utils::get_is_secure(&req).map_err(|_| WispError::InvalidUri)? { let connector = TlsConnector::from(rustls_config); + log!("got stream"); Ok(TokioIo::new(Either::Left( connector .connect( @@ -143,6 +144,7 @@ impl tower_service::Service for TlsWispService { pub enum WebSocketError { Unknown, SendFailed, + CloseFailed, } impl std::fmt::Display for WebSocketError { @@ -151,6 +153,7 @@ impl std::fmt::Display for WebSocketError { match self { Unknown => write!(f, "Unknown error"), SendFailed => write!(f, "Send failed"), + CloseFailed => write!(f, "Close failed"), } } } @@ -213,7 +216,7 @@ impl WebSocketRead for WebSocketReader { } impl WebSocketWrapper { - pub async fn connect( + pub fn connect( url: &str, protocols: Vec, ) -> Result<(Self, WebSocketReader), JsValue> { @@ -327,6 +330,12 @@ impl WebSocketWrite for WebSocketWrapper { _ => Err(WispError::WsImplNotSupported), } } + + async fn wisp_close(&mut self) -> Result<(), WispError> { + self.inner + .close() + .map_err(|_| WebSocketError::CloseFailed.into()) + } } impl Drop for WebSocketWrapper { diff --git a/server/src/main.rs b/server/src/main.rs index 7e0e581..a6c8b8c 100644 --- a/server/src/main.rs +++ b/server/src/main.rs @@ -12,9 +12,13 @@ use hyper::{ body::Incoming, server::conn::http1, service::service_fn, Request, Response, StatusCode, }; use hyper_util::rt::TokioIo; -use tokio::net::{lookup_host, TcpListener, TcpStream, UdpSocket}; #[cfg(unix)] use tokio::net::{UnixListener, UnixStream}; +use tokio::{ + io::{copy_bidirectional, split, BufReader, BufWriter}, + net::{lookup_host, TcpListener, TcpStream, UdpSocket}, + select, +}; use tokio_util::codec::{BytesCodec, Framed}; #[cfg(unix)] use tokio_util::either::Either; @@ -22,9 +26,10 @@ use tokio_util::either::Either; use wisp_mux::{ extensions::{ password::{PasswordProtocolExtension, PasswordProtocolExtensionBuilder}, - udp::UdpProtocolExtensionBuilder, ProtocolExtensionBuilder, + udp::UdpProtocolExtensionBuilder, + ProtocolExtensionBuilder, }, - CloseReason, ConnectPacket, MuxStream, ServerMux, StreamType, WispError, + CloseReason, ConnectPacket, IoStream, MuxStream, MuxStreamIo, ServerMux, StreamType, WispError, }; type HttpBody = http_body_util::Full; @@ -182,7 +187,10 @@ async fn main() -> Result<(), Error> { block_local: opt.block_local, block_non_http: opt.block_non_http, block_udp: opt.block_udp, - auth: Arc::new(vec![Box::new(UdpProtocolExtensionBuilder()), Box::new(pw_ext)]), + auth: Arc::new(vec![ + Box::new(UdpProtocolExtensionBuilder()), + Box::new(pw_ext), + ]), enforce_auth, }; @@ -257,7 +265,7 @@ async fn handle_mux(packet: ConnectPacket, mut stream: MuxStream) -> Result Result<(), Box> { avg.get_average() * opts.packet_size, ); if is_term { - print!("\x1b[2K{}\r", stat); + println!("\x1b[1A\x1b[2K{}\r", stat); } else { println!("{}", stat); } @@ -284,6 +284,8 @@ async fn main() -> Result<(), Box> { let out = select_all(threads.into_iter()).await; + let duration_since = Instant::now().duration_since(start_time); + if let Err(err) = out.0? { println!("\n\nerr: {:?}", err); exit(1); @@ -291,10 +293,10 @@ async fn main() -> Result<(), Box> { out.2.into_iter().for_each(|x| x.abort()); - let duration_since = Instant::now().duration_since(start_time); + mux.close().await?; println!( - "\n\nresults: {} packets of &[0; 1024 * {}] ({} KiB) sent in {} ({} KiB/s)", + "\nresults: {} packets of &[0; 1024 * {}] ({} KiB) sent in {} ({} KiB/s)", cnt.get(), opts.packet_size, cnt.get() * opts.packet_size, diff --git a/wisp/Cargo.toml b/wisp/Cargo.toml index 8cf2cba..795a1c6 100644 --- a/wisp/Cargo.toml +++ b/wisp/Cargo.toml @@ -15,6 +15,7 @@ bytes = "1.5.0" dashmap = { version = "5.5.3", features = ["inline"] } event-listener = "5.0.0" fastwebsockets = { version = "0.7.1", features = ["unstable-split"], optional = true } +flume = "0.11.0" futures = "0.3.30" futures-timer = "3.0.3" futures-util = "0.3.30" diff --git a/wisp/src/fastwebsockets.rs b/wisp/src/fastwebsockets.rs index cd0199c..a2c2f7d 100644 --- a/wisp/src/fastwebsockets.rs +++ b/wisp/src/fastwebsockets.rs @@ -3,7 +3,7 @@ use std::ops::Deref; use async_trait::async_trait; use bytes::BytesMut; use fastwebsockets::{ - FragmentCollectorRead, Frame, OpCode, Payload, WebSocketError, WebSocketWrite, + CloseCode, FragmentCollectorRead, Frame, OpCode, Payload, WebSocketError, WebSocketWrite }; use tokio::io::{AsyncRead, AsyncWrite}; @@ -77,4 +77,8 @@ impl crate::ws::WebSocketWrite for WebSocketWrite< async fn wisp_write_frame(&mut self, frame: crate::ws::Frame) -> Result<(), WispError> { self.write_frame(frame.into()).await.map_err(|e| e.into()) } + + async fn wisp_close(&mut self) -> Result<(), WispError> { + self.write_frame(Frame::close(CloseCode::Normal.into(), b"")).await.map_err(|e| e.into()) + } } diff --git a/wisp/src/lib.rs b/wisp/src/lib.rs index 1454145..d68edf0 100644 --- a/wisp/src/lib.rs +++ b/wisp/src/lib.rs @@ -1,4 +1,4 @@ -#![deny(missing_docs)] +#![deny(missing_docs, warnings)] #![cfg_attr(docsrs, feature(doc_cfg))] //! A library for easily creating [Wisp] clients and servers. //! @@ -19,9 +19,8 @@ use bytes::Bytes; use dashmap::DashMap; use event_listener::Event; use extensions::{udp::UdpProtocolExtension, AnyProtocolExtension, ProtocolExtensionBuilder}; -use futures::{ - channel::{mpsc, oneshot}, lock::Mutex, select, Future, FutureExt, SinkExt, StreamExt -}; +use flume as mpsc; +use futures::{channel::oneshot, select, Future, FutureExt}; use futures_timer::Delay; use std::{ sync::{ @@ -151,11 +150,12 @@ impl std::fmt::Display for WispError { impl std::error::Error for WispError {} struct MuxMapValue { - stream: Mutex>, + stream: mpsc::Sender, stream_type: StreamType, flow_control: Arc, flow_control_event: Arc, is_closed: Arc, + is_closed_event: Arc, } struct MuxInner { @@ -170,7 +170,7 @@ impl MuxInner { rx: R, extensions: Vec, close_rx: mpsc::Receiver, - muxstream_sender: mpsc::UnboundedSender<(ConnectPacket, MuxStream)>, + muxstream_sender: mpsc::Sender<(ConnectPacket, MuxStream)>, close_tx: mpsc::Sender, ) -> Result<(), WispError> where @@ -210,20 +210,60 @@ impl MuxInner { }; for x in self.stream_map.iter_mut() { x.is_closed.store(true, Ordering::Release); - x.stream.lock().await.disconnect(); - x.stream.lock().await.close_channel(); + x.is_closed_event.notify(usize::MAX); } self.stream_map.clear(); + let _ = self.tx.close().await; ret } + async fn create_new_stream( + &self, + stream_id: u32, + stream_type: StreamType, + role: Role, + stream_tx: mpsc::Sender, + target_buffer_size: u32, + ) -> Result<(MuxMapValue, MuxStream), WispError> { + let (ch_tx, ch_rx) = mpsc::bounded(self.buffer_size as usize); + + let flow_control_event: Arc = Event::new().into(); + let flow_control: Arc = AtomicU32::new(self.buffer_size).into(); + + let is_closed: Arc = AtomicBool::new(false).into(); + let is_closed_event: Arc = Event::new().into(); + + Ok(( + MuxMapValue { + stream: ch_tx, + stream_type, + flow_control: flow_control.clone(), + flow_control_event: flow_control_event.clone(), + is_closed: is_closed.clone(), + is_closed_event: is_closed_event.clone(), + }, + MuxStream::new( + stream_id, + role, + stream_type, + ch_rx, + stream_tx.clone(), + is_closed, + is_closed_event, + flow_control, + flow_control_event, + target_buffer_size, + ), + )) + } + async fn stream_loop( &self, - mut stream_rx: mpsc::Receiver, + stream_rx: mpsc::Receiver, stream_tx: mpsc::Sender, ) { let mut next_free_stream_id: u32 = 1; - while let Some(msg) = stream_rx.next().await { + while let Ok(msg) = stream_rx.recv_async().await { match msg { WsEvent::SendPacket(packet, channel) => { if self.stream_map.get(&packet.stream_id).is_some() { @@ -234,16 +274,20 @@ impl MuxInner { } WsEvent::CreateStream(stream_type, host, port, channel) => { let ret: Result = async { - let (ch_tx, ch_rx) = mpsc::channel(self.buffer_size as usize); let stream_id = next_free_stream_id; let next_stream_id = next_free_stream_id .checked_add(1) .ok_or(WispError::MaxStreamCountReached)?; - let flow_control_event: Arc = Event::new().into(); - let flow_control: Arc = AtomicU32::new(self.buffer_size).into(); - - let is_closed: Arc = AtomicBool::new(false).into(); + let (map_value, stream) = self + .create_new_stream( + stream_id, + stream_type, + Role::Client, + stream_tx.clone(), + 0, + ) + .await?; self.tx .write_frame( @@ -251,39 +295,19 @@ impl MuxInner { ) .await?; + self.stream_map.insert(stream_id, map_value); + next_free_stream_id = next_stream_id; - self.stream_map.insert( - stream_id, - MuxMapValue { - stream: ch_tx.into(), - stream_type, - flow_control: flow_control.clone(), - flow_control_event: flow_control_event.clone(), - is_closed: is_closed.clone(), - }, - ); - - Ok(MuxStream::new( - stream_id, - Role::Client, - stream_type, - ch_rx, - stream_tx.clone(), - is_closed, - flow_control, - flow_control_event, - 0, - )) + Ok(stream) } .await; let _ = channel.send(ret); } WsEvent::Close(packet, channel) => { if let Some((_, stream)) = self.stream_map.remove(&packet.stream_id) { - stream.stream.lock().await.disconnect(); - stream.stream.lock().await.close_channel(); let _ = channel.send(self.tx.write_frame(packet.into()).await); + drop(stream.stream) } else { let _ = channel.send(Err(WispError::InvalidStreamId)); } @@ -305,8 +329,8 @@ impl MuxInner { &self, mut rx: R, mut extensions: Vec, - muxstream_sender: mpsc::UnboundedSender<(ConnectPacket, MuxStream)>, - close_tx: mpsc::Sender, + muxstream_sender: mpsc::Sender<(ConnectPacket, MuxStream)>, + stream_tx: mpsc::Sender, ) -> Result<(), WispError> where R: ws::WebSocketRead + Send, @@ -325,42 +349,24 @@ impl MuxInner { use PacketType::*; match packet.packet_type { Connect(inner_packet) => { - let (ch_tx, ch_rx) = mpsc::channel(self.buffer_size as usize); - let stream_type = inner_packet.stream_type; - let flow_control: Arc = AtomicU32::new(self.buffer_size).into(); - let flow_control_event: Arc = Event::new().into(); - let is_closed: Arc = AtomicBool::new(false).into(); - - self.stream_map.insert( - packet.stream_id, - MuxMapValue { - stream: ch_tx.into(), - stream_type, - flow_control: flow_control.clone(), - flow_control_event: flow_control_event.clone(), - is_closed: is_closed.clone(), - }, - ); + let (map_value, stream) = self + .create_new_stream( + packet.stream_id, + inner_packet.stream_type, + Role::Server, + stream_tx.clone(), + target_buffer_size, + ) + .await?; muxstream_sender - .unbounded_send(( - inner_packet, - MuxStream::new( - packet.stream_id, - Role::Server, - stream_type, - ch_rx, - close_tx.clone(), - is_closed, - flow_control, - flow_control_event, - target_buffer_size, - ), - )) - .map_err(|x| WispError::Other(Box::new(x)))?; + .send_async((inner_packet, stream)) + .await + .map_err(|_| WispError::MuxMessageFailedToSend)?; + self.stream_map.insert(packet.stream_id, map_value); } Data(data) => { if let Some(stream) = self.stream_map.get(&packet.stream_id) { - let _ = stream.stream.lock().await.send(data).await; + let _ = stream.stream.send_async(data).await; if stream.stream_type == StreamType::Tcp { stream.flow_control.store( stream @@ -379,8 +385,8 @@ impl MuxInner { } if let Some((_, stream)) = self.stream_map.remove(&packet.stream_id) { stream.is_closed.store(true, Ordering::Release); - stream.stream.lock().await.disconnect(); - stream.stream.lock().await.close_channel(); + stream.is_closed_event.notify(usize::MAX); + drop(stream.stream) } } } @@ -409,7 +415,7 @@ impl MuxInner { Connect(_) | Info(_) => break Err(WispError::InvalidPacketType), Data(data) => { if let Some(stream) = self.stream_map.get(&packet.stream_id) { - let _ = stream.stream.lock().await.send(data).await; + let _ = stream.stream.send_async(data).await; } } Continue(inner_packet) => { @@ -428,8 +434,8 @@ impl MuxInner { } if let Some((_, stream)) = self.stream_map.remove(&packet.stream_id) { stream.is_closed.store(true, Ordering::Release); - stream.stream.lock().await.disconnect(); - stream.stream.lock().await.close_channel(); + stream.is_closed_event.notify(usize::MAX); + drop(stream.stream) } } } @@ -465,7 +471,7 @@ pub struct ServerMux { /// Extensions that are supported by both sides. pub supported_extension_ids: Vec, close_tx: mpsc::Sender, - muxstream_recv: mpsc::UnboundedReceiver<(ConnectPacket, MuxStream)>, + muxstream_recv: mpsc::Receiver<(ConnectPacket, MuxStream)>, } impl ServerMux { @@ -484,7 +490,7 @@ impl ServerMux { R: ws::WebSocketRead + Send, W: ws::WebSocketWrite + Send + 'static, { - let (close_tx, close_rx) = mpsc::channel::(256); + let (close_tx, close_rx) = mpsc::bounded::(256); let (tx, rx) = mpsc::unbounded::<(ConnectPacket, MuxStream)>(); let write = ws::LockedWebSocketWrite::new(Box::new(write)); @@ -547,12 +553,12 @@ impl ServerMux { /// Wait for a stream to be created. pub async fn server_new_stream(&mut self) -> Option<(ConnectPacket, MuxStream)> { - self.muxstream_recv.next().await + self.muxstream_recv.recv_async().await.ok() } async fn close_internal(&mut self, reason: Option) -> Result<(), WispError> { self.close_tx - .send(WsEvent::EndFut(reason)) + .send_async(WsEvent::EndFut(reason)) .await .map_err(|_| WispError::MuxMessageFailedToSend) } @@ -574,6 +580,13 @@ impl ServerMux { .await } } + +impl Drop for ServerMux { + fn drop(&mut self) { + let _ = self.close_tx.send(WsEvent::EndFut(None)); + } +} + /// Client side multiplexor. /// /// # Example @@ -595,7 +608,7 @@ pub struct ClientMux { pub downgraded: bool, /// Extensions that are supported by both sides. pub supported_extension_ids: Vec, - close_tx: mpsc::Sender, + stream_tx: mpsc::Sender, } impl ClientMux { @@ -654,10 +667,10 @@ impl ClientMux { extension.handle_handshake(&mut read, &write).await?; } - let (tx, rx) = mpsc::channel::(256); + let (tx, rx) = mpsc::bounded::(256); Ok(( Self { - close_tx: tx.clone(), + stream_tx: tx.clone(), downgraded, supported_extension_ids: supported_extensions .iter() @@ -697,16 +710,16 @@ impl ClientMux { return Err(WispError::UdpExtensionNotSupported); } let (tx, rx) = oneshot::channel(); - self.close_tx - .send(WsEvent::CreateStream(stream_type, host, port, tx)) + self.stream_tx + .send_async(WsEvent::CreateStream(stream_type, host, port, tx)) .await .map_err(|_| WispError::MuxMessageFailedToSend)?; rx.await.map_err(|_| WispError::MuxMessageFailedToRecv)? } async fn close_internal(&mut self, reason: Option) -> Result<(), WispError> { - self.close_tx - .send(WsEvent::EndFut(reason)) + self.stream_tx + .send_async(WsEvent::EndFut(reason)) .await .map_err(|_| WispError::MuxMessageFailedToSend) } @@ -728,3 +741,9 @@ impl ClientMux { .await } } + +impl Drop for ClientMux { + fn drop(&mut self) { + let _ = self.stream_tx.send(WsEvent::EndFut(None)); + } +} diff --git a/wisp/src/stream.rs b/wisp/src/stream.rs index 1a8c2da..8a074a7 100644 --- a/wisp/src/stream.rs +++ b/wisp/src/stream.rs @@ -1,13 +1,14 @@ use crate::{sink_unfold, CloseReason, Packet, Role, StreamType, WispError}; -use async_io_stream::IoStream; +pub use async_io_stream::IoStream; use bytes::Bytes; use event_listener::Event; +use flume as mpsc; use futures::{ - channel::{mpsc, oneshot}, - stream, + channel::oneshot, + select, stream, task::{Context, Poll}, - Sink, SinkExt, Stream, StreamExt, + FutureExt, Sink, Stream, }; use pin_project_lite::pin_project; use std::{ @@ -40,6 +41,7 @@ pub struct MuxStreamRead { tx: mpsc::Sender, rx: mpsc::Receiver, is_closed: Arc, + is_closed_event: Arc, flow_control: Arc, flow_control_read: AtomicU32, target_flow_control: u32, @@ -51,13 +53,16 @@ impl MuxStreamRead { if self.is_closed.load(Ordering::Acquire) { return None; } - let bytes = self.rx.next().await?; + let bytes = select! { + x = self.rx.recv_async() => x.ok()?, + _ = self.is_closed_event.listen().fuse() => return None + }; if self.role == Role::Server && self.stream_type == StreamType::Tcp { let val = self.flow_control_read.fetch_add(1, Ordering::AcqRel) + 1; if val > self.target_flow_control { let (tx, rx) = oneshot::channel::>(); self.tx - .send(WsEvent::SendPacket( + .send_async(WsEvent::SendPacket( Packet::new_continue( self.stream_id, self.flow_control.fetch_add(val, Ordering::AcqRel) + val, @@ -107,13 +112,13 @@ impl MuxStreamWrite { } let (tx, rx) = oneshot::channel::>(); self.tx - .send(WsEvent::SendPacket( + .send_async(WsEvent::SendPacket( Packet::new_data(self.stream_id, data), tx, )) .await - .map_err(|x| WispError::Other(Box::new(x)))?; - rx.await.map_err(|x| WispError::Other(Box::new(x)))??; + .map_err(|_| WispError::MuxMessageFailedToSend)?; + rx.await.map_err(|_| WispError::MuxMessageFailedToRecv)??; if self.role == Role::Client && self.stream_type == StreamType::Tcp { self.flow_control.store( self.flow_control.load(Ordering::Acquire).saturating_sub(1), @@ -151,13 +156,13 @@ impl MuxStreamWrite { let (tx, rx) = oneshot::channel::>(); self.tx - .send(WsEvent::Close( + .send_async(WsEvent::Close( Packet::new_close(self.stream_id, reason), tx, )) .await - .map_err(|x| WispError::Other(Box::new(x)))?; - rx.await.map_err(|x| WispError::Other(Box::new(x)))??; + .map_err(|_| WispError::MuxMessageFailedToSend)?; + rx.await.map_err(|_| WispError::MuxMessageFailedToRecv)??; Ok(()) } @@ -179,6 +184,16 @@ impl MuxStreamWrite { } } +impl Drop for MuxStreamWrite { + fn drop(&mut self) { + if !self.is_closed.load(Ordering::Acquire) { + self.is_closed.store(true, Ordering::Release); + let (tx, _) = oneshot::channel(); + let _ = self.tx.send(WsEvent::Close(Packet::new_close(self.stream_id, CloseReason::Unknown), tx)); + } + } +} + /// Multiplexor stream. pub struct MuxStream { /// ID of the stream. @@ -196,6 +211,7 @@ impl MuxStream { rx: mpsc::Receiver, tx: mpsc::Sender, is_closed: Arc, + is_closed_event: Arc, flow_control: Arc, continue_recieved: Arc, target_flow_control: u32, @@ -209,6 +225,7 @@ impl MuxStream { tx: tx.clone(), rx, is_closed: is_closed.clone(), + is_closed_event: is_closed_event.clone(), flow_control: flow_control.clone(), flow_control_read: AtomicU32::new(0), target_flow_control, @@ -288,13 +305,13 @@ impl MuxStreamCloser { let (tx, rx) = oneshot::channel::>(); self.close_channel - .send(WsEvent::Close( + .send_async(WsEvent::Close( Packet::new_close(self.stream_id, reason), tx, )) .await - .map_err(|x| WispError::Other(Box::new(x)))?; - rx.await.map_err(|x| WispError::Other(Box::new(x)))??; + .map_err(|_| WispError::MuxMessageFailedToSend)?; + rx.await.map_err(|_| WispError::MuxMessageFailedToRecv)??; Ok(()) } diff --git a/wisp/src/ws.rs b/wisp/src/ws.rs index 7348bb8..258a5d1 100644 --- a/wisp/src/ws.rs +++ b/wisp/src/ws.rs @@ -76,6 +76,9 @@ pub trait WebSocketRead { pub trait WebSocketWrite { /// Write a frame to the socket. async fn wisp_write_frame(&mut self, frame: Frame) -> Result<(), WispError>; + + /// Close the socket. + async fn wisp_close(&mut self) -> Result<(), WispError>; } /// Locked WebSocket. @@ -88,9 +91,14 @@ impl LockedWebSocketWrite { } /// Write a frame to the websocket. - pub async fn write_frame(&self, frame: Frame) -> Result<(), crate::WispError> { + pub async fn write_frame(&self, frame: Frame) -> Result<(), WispError> { self.0.lock().await.wisp_write_frame(frame).await } + + /// Close the websocket. + pub async fn close(&self) -> Result<(), WispError> { + self.0.lock().await.wisp_close().await + } } pub(crate) struct AppendingWebSocketRead(pub Vec, pub R)