From e2b85c3d9962d0c3c922ed68fdd9311a0cdc9bcd Mon Sep 17 00:00:00 2001 From: Toshit Chawda Date: Sun, 7 Jul 2024 14:10:19 -0700 Subject: [PATCH] websocket support --- src/client.ts | 209 ++++++++++++++++++++++++++++++++++++++++++++-- src/connection.ts | 55 ++++++++++-- src/worker.ts | 53 ++++++++++-- 3 files changed, 298 insertions(+), 19 deletions(-) diff --git a/src/client.ts b/src/client.ts index 4301fd5..33149d2 100644 --- a/src/client.ts +++ b/src/client.ts @@ -1,7 +1,21 @@ import { BareHeaders, BareTransport, maxRedirects } from './baretypes'; -import { WorkerConnection } from './connection'; +import { WorkerConnection, WorkerMessage, WorkerResponse } from './connection'; import { WebSocketFields } from './snapshot'; +const validChars = + "!#$%&'*+-.0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ^_`abcdefghijklmnopqrstuvwxyz|~"; + +export function validProtocol(protocol: string): boolean { + for (let i = 0; i < protocol.length; i++) { + const char = protocol[i]; + + if (!validChars.includes(char)) { + return false; + } + } + + return true; +} // get the unhooked value const getRealReadyState = Object.getOwnPropertyDescriptor( @@ -118,7 +132,190 @@ export class BareClient { requestHeaders: BareHeaders, arrayBufferImpl: typeof ArrayBuffer, ): WebSocket { - throw new Error("todo"); + try { + remote = new URL(remote); + } catch (err) { + throw new DOMException( + `Faiiled to construct 'WebSocket': The URL '${remote}' is invalid.` + ); + } + + if (!wsProtocols.includes(remote.protocol)) + throw new DOMException( + `Failed to construct 'WebSocket': The URL's scheme must be either 'ws' or 'wss'. '${remote.protocol}' is not allowed.` + ); + + if (!Array.isArray(protocols)) protocols = [protocols]; + + protocols = protocols.map(String); + + for (const proto of protocols) + if (!validProtocol(proto)) + throw new DOMException( + `Failed to construct 'WebSocket': The subprotocol '${proto}' is invalid.` + ); + + + let wsImpl = (webSocketImpl || WebSocket) as WebSocketImpl; + const socket = new wsImpl("ws://127.0.0.1:1", protocols); + + let fakeProtocol = ''; + + let fakeReadyState: number = WebSocketFields.CONNECTING; + + let initialErrorHappened = false; + socket.addEventListener("error", (e) => { + if (!initialErrorHappened) { + fakeReadyState = WebSocket.CONNECTING; + e.stopImmediatePropagation(); + initialErrorHappened = true; + } + }); + let initialCloseHappened = false; + socket.addEventListener("close", (e) => { + if (!initialCloseHappened) { + e.stopImmediatePropagation(); + initialCloseHappened = true; + } + }); + // TODO socket onerror will be broken + + arrayBufferImpl = arrayBufferImpl || wsImpl.constructor.constructor("return ArrayBuffer")().prototype; + requestHeaders = requestHeaders || {}; + requestHeaders['Host'] = (new URL(remote)).host; + // requestHeaders['Origin'] = origin; + requestHeaders['Pragma'] = 'no-cache'; + requestHeaders['Cache-Control'] = 'no-cache'; + requestHeaders['Upgrade'] = 'websocket'; + // requestHeaders['User-Agent'] = navigator.userAgent; + requestHeaders['Connection'] = 'Upgrade'; + + const onopen = (protocol: string) => { + fakeReadyState = WebSocketFields.OPEN; + fakeProtocol = protocol; + + (socket as any).meta = { + headers: { + "sec-websocket-protocol": protocol, + } + }; // what the fuck is a meta + socket.dispatchEvent(new Event("open")); + }; + + const onmessage = async (payload) => { + if (typeof payload === "string") { + socket.dispatchEvent(new MessageEvent("message", { data: payload })); + } else if ("byteLength" in payload) { + if (socket.binaryType === "blob") { + payload = new Blob([payload]); + } else { + Object.setPrototypeOf(payload, arrayBufferImpl); + } + + socket.dispatchEvent(new MessageEvent("message", { data: payload })); + } else if ("arrayBuffer" in payload) { + if (socket.binaryType === "arraybuffer") { + payload = await payload.arrayBuffer() + Object.setPrototypeOf(payload, arrayBufferImpl); + } + + socket.dispatchEvent(new MessageEvent("message", { data: payload })); + } + }; + + const onclose = (code, reason) => { + fakeReadyState = WebSocketFields.CLOSED; + socket.dispatchEvent(new CloseEvent("close", { code, reason })); + }; + + const onerror = () => { + fakeReadyState = WebSocketFields.CLOSED; + socket.dispatchEvent(new Event("error")) + }; + + const channel = new MessageChannel(); + + channel.port1.onmessage = event => { + if (event.data.type === "open") { + onopen(event.data.args[0]); + } else if (event.data.type === "message") { + onmessage(event.data.args[0]); + } else if (event.data.type === "close") { + onclose(event.data.args[0], event.data.args[1]); + } else if (event.data.type === "error") { + onerror(/* event.data.args[0] */); + } + } + + this.worker.sendMessage({ + type: "websocket", + websocket: { + url: remote.toString(), + origin: origin, + protocols: protocols, + requestHeaders: requestHeaders, + }, + websocketChannel: channel.port2, + }, [channel.port2]) + + // protocol is always an empty before connecting + // updated when we receive the metadata + // this value doesn't change when it's CLOSING or CLOSED etc + const getReadyState = () => fakeReadyState; + + // we have to hook .readyState ourselves + + Object.defineProperty(socket, 'readyState', { + get: getReadyState, + configurable: true, + enumerable: true, + }); + + /** + * @returns The error that should be thrown if send() were to be called on this socket according to the fake readyState value + */ + const getSendError = () => { + const readyState = getReadyState(); + + if (readyState === WebSocketFields.CONNECTING) + return new DOMException( + "Failed to execute 'send' on 'WebSocket': Still in CONNECTING state." + ); + }; + + // we have to hook .send ourselves + // use ...args to avoid giving the number of args a quantity + // no arguments will trip the following error: TypeError: Failed to execute 'send' on 'WebSocket': 1 argument required, but only 0 present. + socket.send = function(...args) { + const error = getSendError(); + + if (error) throw error; + let data = args[0]; + // @ts-expect-error idk why it errors? + if (data.buffer) data = data.buffer; + + channel.port1.postMessage({ type: "data", data: data }, data instanceof ArrayBuffer ? [data] : []); + }; + + socket.close = function(code: number, reason: string) { + channel.port1.postMessage({ type: "close", closeCode: code, closeReason: reason }); + } + + Object.defineProperty(socket, 'url', { + get: () => remote.toString(), + configurable: true, + enumerable: true, + }); + + const getProtocol = () => fakeProtocol; + + Object.defineProperty(socket, 'protocol', { + get: getProtocol, + configurable: true, + enumerable: true, + }); + + return socket; } async fetch( @@ -159,15 +356,15 @@ export class BareClient { if ('host' in headers) headers.host = urlO.host; else headers.Host = urlO.host; - let resp = (await this.worker.sendMessage({ + const message = Object.assign({ type: "fetch", fetch: { remote: urlO.toString(), method: req.method, - body: body, headers: headers, - } - })).fetch; + }, + }, body ? { fetchBody: body } : {}); + let resp = (await this.worker.sendMessage(message as WorkerMessage, body ? [body] : [])).fetch; let responseobj: BareResponse & Partial = new Response( statusEmpty.includes(resp.status) ? undefined : resp.body, { diff --git a/src/connection.ts b/src/connection.ts index c74b2fc..486624d 100644 --- a/src/connection.ts +++ b/src/connection.ts @@ -13,13 +13,20 @@ function tryGetPort(client: SWClient): Promise { } export type WorkerMessage = { - type: "fetch" | "set", + type: "fetch" | "websocket" | "set", fetch?: { remote: string, method: string, - body: ReadableStream | null, headers: BareHeaders, }, + fetchBody?: ReadableStream, + websocket?: { + url: string, + origin: string, + protocols: string[], + requestHeaders: BareHeaders, + }, + websocketChannel?: MessagePort, client?: string, }; @@ -29,21 +36,32 @@ export type WorkerRequest = { } export type WorkerResponse = { - type: "fetch" | "set" | "error", + type: "fetch" | "websocket" | "set" | "error", fetch?: TransferrableResponse, error?: Error, } +type BroadcastMessage = { + type: "getPath" | "path", + path?: string, +} + export class WorkerConnection { + channel: BroadcastChannel; port: MessagePort | Promise; constructor(workerPath?: string) { + this.channel = new BroadcastChannel("bare-mux"); // @ts-expect-error if (self.clients) { + // running in a ServiceWorker + // ask a window for the worker port // @ts-expect-error const clients: Promise = self.clients.matchAll({ type: "window", includeUncontrolled: true }); this.port = clients.then(clients => Promise.any(clients.map((x: SWClient) => tryGetPort(x)))); } else if (workerPath && SharedWorker) { + // running in a window, was passed a workerPath + // create the SharedWorker and help other bare-mux clients get the workerPath navigator.serviceWorker.addEventListener("message", event => { if (event.data.type === "getPort" && event.data.port) { const worker = new SharedWorker(workerPath); @@ -51,18 +69,41 @@ export class WorkerConnection { } }); + this.channel.onmessage = (event: MessageEvent) => { + if (event.data.type === "getPath") { + this.channel.postMessage({ type: "path", path: workerPath }); + } + } + const worker = new SharedWorker(workerPath, "bare-mux-worker"); this.port = worker.port; + } else if (SharedWorker) { + // running in a window, was not passed a workerPath + // ask other bare-mux clients for the workerPath + this.port = new Promise(resolve => { + this.channel.onmessage = (event: MessageEvent) => { + if (event.data.type === "path") { + const worker = new SharedWorker(event.data.path, "bare-mux-worker"); + this.channel.onmessage = (event: MessageEvent) => { + if (event.data.type === "getPath") { + this.channel.postMessage({ type: "path", path: event.data.path }); + } + } + resolve(worker.port); + } + } + this.channel.postMessage({ type: "getPath" }); + }); } else { - throw new Error("workerPath was not passed or SharedWorker does not exist and am not running in a Service Worker."); + // SharedWorker does not exist + throw new Error("Unable to get a channel to the SharedWorker."); } } - async sendMessage(message: WorkerMessage): Promise { + async sendMessage(message: WorkerMessage, transferable?: Transferable[]): Promise { if (this.port instanceof Promise) this.port = await this.port; let channel = new MessageChannel(); - let toTransfer: Transferable[] = [channel.port2]; - if (message.fetch && message.fetch.body) toTransfer.push(message.fetch.body); + let toTransfer: Transferable[] = [channel.port2, ...(transferable || [])]; this.port.postMessage({ message: message, port: channel.port2 }, toTransfer); diff --git a/src/worker.ts b/src/worker.ts index 3e6c99d..8508d71 100644 --- a/src/worker.ts +++ b/src/worker.ts @@ -1,5 +1,5 @@ import { BareTransport } from "./baretypes"; -import { WorkerMessage } from "./connection" +import { WorkerMessage, WorkerResponse } from "./connection" let currentTransport: BareTransport | null = null; @@ -11,7 +11,7 @@ function handleConnection(port: MessagePort) { const func = new Function(message.client); currentTransport = await func(); console.log("set transport to ", currentTransport); - port.postMessage({ type: "set" }); + port.postMessage({ type: "set" }); } else if (message.type === "fetch") { try { if (!currentTransport) throw new Error("No BareTransport was set. Try creating a BareMuxConnection and calling set() on it."); @@ -19,18 +19,59 @@ function handleConnection(port: MessagePort) { const resp = await currentTransport.request( new URL(message.fetch.remote), message.fetch.method, - message.fetch.body, + message.fetchBody, message.fetch.headers, null ); if (resp.body instanceof ReadableStream || resp.body instanceof ArrayBuffer) { - port.postMessage({ type: "fetch", fetch: resp }, [resp.body]); + port.postMessage({ type: "fetch", fetch: resp }, [resp.body]); } else { - port.postMessage({ type: "fetch", fetch: resp }); + port.postMessage({ type: "fetch", fetch: resp }); } } catch (err) { - port.postMessage({ type: "error", error: err }); + port.postMessage({ type: "error", error: err }); + } + } else if (message.type === "websocket") { + try { + if (!currentTransport) throw new Error("No BareTransport was set. Try creating a BareMuxConnection and calling set() on it."); + if (!currentTransport.ready) await currentTransport.init(); + const onopen = (protocol: string) => { + message.websocketChannel.postMessage({ type: "open", args: [protocol] }); + }; + const onclose = (code: number, reason: string) => { + message.websocketChannel.postMessage({ type: "close", args: [code, reason] }); + }; + const onerror = (error: string) => { + message.websocketChannel.postMessage({ type: "error", args: [error] }); + }; + const onmessage = (data: Blob | ArrayBuffer | string) => { + if (data instanceof ArrayBuffer) { + message.websocketChannel.postMessage({ type: "message", args: [data] }, [data]); + } else { + message.websocketChannel.postMessage({ type: "message", args: [data] }); + } + } + const [data, close] = currentTransport.connect( + new URL(message.websocket.url), + message.websocket.origin, + message.websocket.protocols, + message.websocket.requestHeaders, + onopen, + onmessage, + onclose, + onerror, + ); + message.websocketChannel.onmessage = (event: MessageEvent) => { + if (event.data.type === "data") { + data(event.data.data); + } else if (event.data.type === "close") { + close(event.data.closeCode, event.data.closeReason); + } + } + port.postMessage({ type: "websocket" }); + } catch (err) { + port.postMessage({ type: "error", error: err }); } } }