add a new Payload struct to allow for one-copy writes and cargo fmt

This commit is contained in:
Toshit Chawda 2024-07-17 16:23:58 -07:00
parent 314c1bfa75
commit d6353bd5a9
No known key found for this signature in database
GPG key ID: 91480ED99E2B3D9D
18 changed files with 3533 additions and 3395 deletions

View file

@ -3,73 +3,78 @@ use std::{str::from_utf8, sync::Arc};
use base64::{prelude::BASE64_STANDARD, Engine};
use bytes::Bytes;
use fastwebsockets::{
FragmentCollectorRead, Frame, OpCode, Payload, Role, WebSocket, WebSocketWrite,
FragmentCollectorRead, Frame, OpCode, Payload, Role, WebSocket, WebSocketWrite,
};
use futures_util::lock::Mutex;
use getrandom::getrandom;
use http::{
header::{
CONNECTION, HOST, SEC_WEBSOCKET_KEY, SEC_WEBSOCKET_PROTOCOL, SEC_WEBSOCKET_VERSION, UPGRADE, USER_AGENT,
},
Method, Request, Response, StatusCode, Uri,
header::{
CONNECTION, HOST, SEC_WEBSOCKET_KEY, SEC_WEBSOCKET_PROTOCOL, SEC_WEBSOCKET_VERSION,
UPGRADE, USER_AGENT,
},
Method, Request, Response, StatusCode, Uri,
};
use hyper::{
body::Incoming,
upgrade::{self, Upgraded},
body::Incoming,
upgrade::{self, Upgraded},
};
use js_sys::{ArrayBuffer, Function, Object, Uint8Array};
use tokio::io::WriteHalf;
use wasm_bindgen::{prelude::*, JsError, JsValue};
use wasm_bindgen_futures::spawn_local;
use crate::{tokioio::TokioIo, utils::entries_of_object, EpoxyClient, EpoxyError, EpoxyHandlers, HttpBody};
use crate::{
tokioio::TokioIo, utils::entries_of_object, EpoxyClient, EpoxyError, EpoxyHandlers, HttpBody,
};
#[wasm_bindgen]
pub struct EpoxyWebSocket {
tx: Arc<Mutex<WebSocketWrite<WriteHalf<TokioIo<Upgraded>>>>>,
onerror: Function,
tx: Arc<Mutex<WebSocketWrite<WriteHalf<TokioIo<Upgraded>>>>>,
onerror: Function,
}
#[wasm_bindgen]
impl EpoxyWebSocket {
pub(crate) async fn connect(
client: &EpoxyClient,
handlers: EpoxyHandlers,
url: String,
protocols: Vec<String>,
pub(crate) async fn connect(
client: &EpoxyClient,
handlers: EpoxyHandlers,
url: String,
protocols: Vec<String>,
headers: JsValue,
user_agent: &str,
) -> Result<Self, EpoxyError> {
let EpoxyHandlers {
onopen,
onclose,
onerror,
onmessage,
} = handlers;
let onerror_cloned = onerror.clone();
let ret: Result<EpoxyWebSocket, EpoxyError> = async move {
let url: Uri = url.try_into()?;
let host = url.host().ok_or(EpoxyError::NoUrlHost)?;
) -> Result<Self, EpoxyError> {
let EpoxyHandlers {
onopen,
onclose,
onerror,
onmessage,
} = handlers;
let onerror_cloned = onerror.clone();
let ret: Result<EpoxyWebSocket, EpoxyError> = async move {
let url: Uri = url.try_into()?;
let host = url.host().ok_or(EpoxyError::NoUrlHost)?;
let mut rand = [0u8; 16];
getrandom(&mut rand)?;
let key = BASE64_STANDARD.encode(rand);
let mut rand = [0u8; 16];
getrandom(&mut rand)?;
let key = BASE64_STANDARD.encode(rand);
let mut request = Request::builder()
.method(Method::GET)
.uri(url.clone())
.header(HOST, host)
.header(CONNECTION, "upgrade")
.header(UPGRADE, "websocket")
.header(SEC_WEBSOCKET_KEY, key)
.header(SEC_WEBSOCKET_VERSION, "13")
let mut request = Request::builder()
.method(Method::GET)
.uri(url.clone())
.header(HOST, host)
.header(CONNECTION, "upgrade")
.header(UPGRADE, "websocket")
.header(SEC_WEBSOCKET_KEY, key)
.header(SEC_WEBSOCKET_VERSION, "13")
.header(USER_AGENT, user_agent);
if !protocols.is_empty() {
request = request.header(SEC_WEBSOCKET_PROTOCOL, protocols.join(","));
}
if !protocols.is_empty() {
request = request.header(SEC_WEBSOCKET_PROTOCOL, protocols.join(","));
}
if web_sys::Headers::instanceof(&headers) && let Ok(entries) = Object::from_entries(&headers) {
if web_sys::Headers::instanceof(&headers)
&& let Ok(entries) = Object::from_entries(&headers)
{
for header in entries_of_object(&entries) {
request = request.header(&header[0], &header[1]);
}
@ -79,153 +84,153 @@ impl EpoxyWebSocket {
}
}
let request = request.body(HttpBody::new(Bytes::new()))?;
let request = request.body(HttpBody::new(Bytes::new()))?;
let mut response = client.client.request(request).await?;
verify(&response)?;
let mut response = client.client.request(request).await?;
verify(&response)?;
let websocket = WebSocket::after_handshake(
TokioIo::new(upgrade::on(&mut response).await?),
Role::Client,
);
let websocket = WebSocket::after_handshake(
TokioIo::new(upgrade::on(&mut response).await?),
Role::Client,
);
let (rx, tx) = websocket.split(tokio::io::split);
let (rx, tx) = websocket.split(tokio::io::split);
let mut rx = FragmentCollectorRead::new(rx);
let tx = Arc::new(Mutex::new(tx));
let mut rx = FragmentCollectorRead::new(rx);
let tx = Arc::new(Mutex::new(tx));
let read_tx = tx.clone();
let onerror_cloned = onerror.clone();
let read_tx = tx.clone();
let onerror_cloned = onerror.clone();
spawn_local(async move {
loop {
match rx
.read_frame(&mut |arg| async {
read_tx.lock().await.write_frame(arg).await
})
.await
{
Ok(frame) => match frame.opcode {
OpCode::Text => {
if let Ok(str) = from_utf8(&frame.payload) {
let _ = onmessage.call1(&JsValue::null(), &str.into());
}
}
OpCode::Binary => {
let _ = onmessage.call1(
&JsValue::null(),
&Uint8Array::from(frame.payload.to_vec().as_slice()).into(),
);
}
OpCode::Close => {
break;
}
// ping/pong/continue
_ => {}
},
Err(err) => {
let _ = onerror.call1(&JsValue::null(), &JsError::from(err).into());
break;
}
}
}
let _ = onclose.call0(&JsValue::null());
});
spawn_local(async move {
loop {
match rx
.read_frame(&mut |arg| async {
read_tx.lock().await.write_frame(arg).await
})
.await
{
Ok(frame) => match frame.opcode {
OpCode::Text => {
if let Ok(str) = from_utf8(&frame.payload) {
let _ = onmessage.call1(&JsValue::null(), &str.into());
}
}
OpCode::Binary => {
let _ = onmessage.call1(
&JsValue::null(),
&Uint8Array::from(frame.payload.to_vec().as_slice()).into(),
);
}
OpCode::Close => {
break;
}
// ping/pong/continue
_ => {}
},
Err(err) => {
let _ = onerror.call1(&JsValue::null(), &JsError::from(err).into());
break;
}
}
}
let _ = onclose.call0(&JsValue::null());
});
let _ = onopen.call0(&JsValue::null());
let _ = onopen.call0(&JsValue::null());
Ok(Self {
tx,
onerror: onerror_cloned,
})
}
.await;
Ok(Self {
tx,
onerror: onerror_cloned,
})
}
.await;
match ret {
Ok(ok) => Ok(ok),
Err(err) => {
let _ = onerror_cloned.call1(&JsValue::null(), &err.to_string().into());
Err(err)
}
}
}
match ret {
Ok(ok) => Ok(ok),
Err(err) => {
let _ = onerror_cloned.call1(&JsValue::null(), &err.to_string().into());
Err(err)
}
}
}
pub async fn send(&self, payload: JsValue) -> Result<(), EpoxyError> {
let ret = if let Some(str) = payload.as_string() {
self.tx
.lock()
.await
.write_frame(Frame::text(Payload::Owned(str.as_bytes().to_vec())))
.await
.map_err(EpoxyError::from)
} else if let Ok(binary) = payload.dyn_into::<ArrayBuffer>() {
self.tx
.lock()
.await
.write_frame(Frame::binary(Payload::Owned(
Uint8Array::new(&binary).to_vec(),
)))
.await
.map_err(EpoxyError::from)
} else {
Err(EpoxyError::WsInvalidPayload)
};
pub async fn send(&self, payload: JsValue) -> Result<(), EpoxyError> {
let ret = if let Some(str) = payload.as_string() {
self.tx
.lock()
.await
.write_frame(Frame::text(Payload::Owned(str.as_bytes().to_vec())))
.await
.map_err(EpoxyError::from)
} else if let Ok(binary) = payload.dyn_into::<ArrayBuffer>() {
self.tx
.lock()
.await
.write_frame(Frame::binary(Payload::Owned(
Uint8Array::new(&binary).to_vec(),
)))
.await
.map_err(EpoxyError::from)
} else {
Err(EpoxyError::WsInvalidPayload)
};
match ret {
Ok(ok) => Ok(ok),
Err(err) => {
let _ = self
.onerror
.call1(&JsValue::null(), &err.to_string().into());
Err(err)
}
}
}
match ret {
Ok(ok) => Ok(ok),
Err(err) => {
let _ = self
.onerror
.call1(&JsValue::null(), &err.to_string().into());
Err(err)
}
}
}
pub async fn close(&self, code: u16, reason: String) -> Result<(), EpoxyError> {
let ret = self
.tx
.lock()
.await
.write_frame(Frame::close(code, reason.as_bytes()))
.await;
match ret {
Ok(ok) => Ok(ok),
Err(err) => {
let _ = self
.onerror
.call1(&JsValue::null(), &err.to_string().into());
Err(err.into())
}
}
}
pub async fn close(&self, code: u16, reason: String) -> Result<(), EpoxyError> {
let ret = self
.tx
.lock()
.await
.write_frame(Frame::close(code, reason.as_bytes()))
.await;
match ret {
Ok(ok) => Ok(ok),
Err(err) => {
let _ = self
.onerror
.call1(&JsValue::null(), &err.to_string().into());
Err(err.into())
}
}
}
}
// https://github.com/snapview/tungstenite-rs/blob/314feea3055a93e585882fb769854a912a7e6dae/src/handshake/client.rs#L189
fn verify(response: &Response<Incoming>) -> Result<(), EpoxyError> {
if response.status() != StatusCode::SWITCHING_PROTOCOLS {
return Err(EpoxyError::WsInvalidStatusCode);
}
if response.status() != StatusCode::SWITCHING_PROTOCOLS {
return Err(EpoxyError::WsInvalidStatusCode);
}
let headers = response.headers();
let headers = response.headers();
if !headers
.get(UPGRADE)
.and_then(|h| h.to_str().ok())
.map(|h| h.eq_ignore_ascii_case("websocket"))
.unwrap_or(false)
{
return Err(EpoxyError::WsInvalidUpgradeHeader);
}
if !headers
.get(UPGRADE)
.and_then(|h| h.to_str().ok())
.map(|h| h.eq_ignore_ascii_case("websocket"))
.unwrap_or(false)
{
return Err(EpoxyError::WsInvalidUpgradeHeader);
}
if !headers
.get(CONNECTION)
.and_then(|h| h.to_str().ok())
.map(|h| h.eq_ignore_ascii_case("Upgrade"))
.unwrap_or(false)
{
return Err(EpoxyError::WsInvalidConnectionHeader);
}
if !headers
.get(CONNECTION)
.and_then(|h| h.to_str().ok())
.map(|h| h.eq_ignore_ascii_case("Upgrade"))
.unwrap_or(false)
{
return Err(EpoxyError::WsInvalidConnectionHeader);
}
Ok(())
Ok(())
}