From 3a43dfcc15d90f8d7b2ea0ac115d95c7d81086fa Mon Sep 17 00:00:00 2001 From: Toshit Chawda Date: Tue, 23 Jul 2024 18:56:21 -0700 Subject: [PATCH] remote transport --- src/client.ts | 37 ++++++++++++++- src/worker.ts | 102 +++++++++++++++--------------------------- src/workerHandlers.ts | 66 +++++++++++++++++++++++++++ 3 files changed, 137 insertions(+), 68 deletions(-) create mode 100644 src/workerHandlers.ts diff --git a/src/client.ts b/src/client.ts index fa7096a..544a298 100644 --- a/src/client.ts +++ b/src/client.ts @@ -1,6 +1,7 @@ -import { BareHeaders, maxRedirects } from './baretypes'; +import { BareHeaders, BareTransport, maxRedirects } from './baretypes'; import { WorkerConnection, WorkerMessage } from './connection'; import { WebSocketFields } from './snapshot'; +import { handleFetch, handleWebsocket, sendError } from './workerHandlers'; const validChars = "!#$%&'*+-.0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ^_`abcdefghijklmnopqrstuvwxyz|~"; @@ -113,6 +114,7 @@ export class BareMuxConnection { } async setManualTransport(functionBody: string, options: any[], transferables?: Transferable[]) { + if (functionBody === "bare-mux-remote") throw new Error("Use setRemoteTransport."); await this.worker.sendMessage({ type: "set", client: { @@ -122,6 +124,39 @@ export class BareMuxConnection { }, transferables); } + async setRemoteTransport(transport: BareTransport, name: string) { + const channel = new MessageChannel(); + + channel.port1.onmessage = async (event: MessageEvent) => { + const port = event.data.port; + const message: WorkerMessage = event.data.message; + + if (message.type === "fetch") { + try { + if (!transport.ready) await transport.init(); + await handleFetch(message, port, transport); + } catch (err) { + sendError(port, err, "fetch"); + } + } else if (message.type === "websocket") { + try { + if (!transport.ready) await transport.init(); + await handleWebsocket(message, port, transport); + } catch (err) { + sendError(port, err, "websocket"); + } + } + } + + await this.worker.sendMessage({ + type: "set", + client: { + function: "bare-mux-remote", + args: [channel.port2, name] + }, + }, [channel.port2]); + } + getInnerPort(): MessagePort | Promise { return this.worker.port; } diff --git a/src/worker.ts b/src/worker.ts index d894fbf..d542184 100644 --- a/src/worker.ts +++ b/src/worker.ts @@ -1,7 +1,8 @@ import { BareTransport } from "./baretypes"; -import { BroadcastMessage, WorkerMessage, WorkerResponse, browserSupportsTransferringStreams } from "./connection" +import { BroadcastMessage, WorkerMessage, WorkerRequest, WorkerResponse } from "./connection" +import { handleFetch, handleWebsocket, sendError } from "./workerHandlers"; -let currentTransport: BareTransport | null = null; +let currentTransport: BareTransport | MessagePort | null = null; let currentTransportName: string = ""; const channel = new BroadcastChannel("bare-mux"); @@ -14,6 +15,14 @@ function noClients(): Error { }); } +function handleRemoteClient(message: WorkerMessage, port: MessagePort) { + const remote = currentTransport as MessagePort; + let transferables: Transferable[] = [port]; + if (message.fetch?.body) transferables.push(message.fetch.body); + if (message.websocket?.channel) transferables.push(message.websocket.channel); + remote.postMessage({ message, port }, transferables); +} + function handleConnection(port: MessagePort) { port.onmessage = async (event: MessageEvent) => { const port = event.data.port; @@ -24,90 +33,49 @@ function handleConnection(port: MessagePort) { try { const AsyncFunction = (async function() { }).constructor; - // @ts-expect-error - const func = new AsyncFunction(message.client.function); - const [newTransport, name] = await func(); - currentTransport = new newTransport(...message.client.args); - currentTransportName = name; - console.log("set transport to ", currentTransport, name); + if (message.client.function === "bare-mux-remote") { + currentTransport = message.client.args[0] as MessagePort; + currentTransportName = `bare-mux-remote (${message.client.args[1]})`; + } else { + // @ts-expect-error + const func = new AsyncFunction(message.client.function); + const [newTransport, name] = await func(); + currentTransport = new newTransport(...message.client.args); + currentTransportName = name; + } + console.log("set transport to ", currentTransport, currentTransportName); port.postMessage({ type: "set" }); } catch (err) { - console.error("error while processing 'set': ", err); - port.postMessage({ type: "error", error: err }); + sendError(port, err, 'set'); } } else if (message.type === "get") { port.postMessage({ type: "get", name: currentTransportName }); } else if (message.type === "fetch") { try { if (!currentTransport) throw noClients(); + if (currentTransport instanceof MessagePort) { + handleRemoteClient(message, port); + return; + } if (!currentTransport.ready) await currentTransport.init(); - const resp = await currentTransport.request( - new URL(message.fetch.remote), - message.fetch.method, - message.fetch.body, - message.fetch.headers, - null - ); - - if (!browserSupportsTransferringStreams() && resp.body instanceof ReadableStream) { - const conversionResp = new Response(resp.body); - resp.body = await conversionResp.arrayBuffer(); - } - - if (resp.body instanceof ReadableStream || resp.body instanceof ArrayBuffer) { - port.postMessage({ type: "fetch", fetch: resp }, [resp.body]); - } else { - port.postMessage({ type: "fetch", fetch: resp }); - } + await handleFetch(message, port, currentTransport); } catch (err) { - console.error("error while processing 'fetch': ", err); - port.postMessage({ type: "error", error: err }); + sendError(port, err, 'fetch'); } } else if (message.type === "websocket") { try { if (!currentTransport) throw noClients(); + if (currentTransport instanceof MessagePort) { + handleRemoteClient(message, port); + return; + } if (!currentTransport.ready) await currentTransport.init(); - const onopen = (protocol: string) => { - message.websocket.channel.postMessage({ type: "open", args: [protocol] }); - }; - const onclose = (code: number, reason: string) => { - message.websocket.channel.postMessage({ type: "close", args: [code, reason] }); - }; - const onerror = (error: string) => { - message.websocket.channel.postMessage({ type: "error", args: [error] }); - }; - const onmessage = (data: Blob | ArrayBuffer | string) => { - if (data instanceof ArrayBuffer) { - message.websocket.channel.postMessage({ type: "message", args: [data] }, [data]); - } else { - message.websocket.channel.postMessage({ type: "message", args: [data] }); - } - } - const [data, close] = currentTransport.connect( - new URL(message.websocket.url), - message.websocket.origin, - message.websocket.protocols, - message.websocket.requestHeaders, - onopen, - onmessage, - onclose, - onerror, - ); - message.websocket.channel.onmessage = (event: MessageEvent) => { - if (event.data.type === "data") { - data(event.data.data); - } else if (event.data.type === "close") { - close(event.data.closeCode, event.data.closeReason); - } - } - - port.postMessage({ type: "websocket" }); + await handleWebsocket(message, port, currentTransport); } catch (err) { - console.error("error while processing 'websocket': ", err); - port.postMessage({ type: "error", error: err }); + sendError(port, err, 'websocket'); } } } diff --git a/src/workerHandlers.ts b/src/workerHandlers.ts new file mode 100644 index 0000000..0f0f845 --- /dev/null +++ b/src/workerHandlers.ts @@ -0,0 +1,66 @@ +import { BareTransport } from "./baretypes"; +import { browserSupportsTransferringStreams, WorkerMessage, WorkerResponse } from "./connection"; + +export function sendError(port: MessagePort, err: Error, name: string) { + console.error(`error while processing '${name}': `, err); + port.postMessage({ type: "error", error: err }); +} + +export async function handleFetch(message: WorkerMessage, port: MessagePort, transport: BareTransport) { + const resp = await transport.request( + new URL(message.fetch.remote), + message.fetch.method, + message.fetch.body, + message.fetch.headers, + null + ); + + if (!browserSupportsTransferringStreams() && resp.body instanceof ReadableStream) { + const conversionResp = new Response(resp.body); + resp.body = await conversionResp.arrayBuffer(); + } + + if (resp.body instanceof ReadableStream || resp.body instanceof ArrayBuffer) { + port.postMessage({ type: "fetch", fetch: resp }, [resp.body]); + } else { + port.postMessage({ type: "fetch", fetch: resp }); + } +} + +export async function handleWebsocket(message: WorkerMessage, port: MessagePort, transport: BareTransport) { + const onopen = (protocol: string) => { + message.websocket.channel.postMessage({ type: "open", args: [protocol] }); + }; + const onclose = (code: number, reason: string) => { + message.websocket.channel.postMessage({ type: "close", args: [code, reason] }); + }; + const onerror = (error: string) => { + message.websocket.channel.postMessage({ type: "error", args: [error] }); + }; + const onmessage = (data: Blob | ArrayBuffer | string) => { + if (data instanceof ArrayBuffer) { + message.websocket.channel.postMessage({ type: "message", args: [data] }, [data]); + } else { + message.websocket.channel.postMessage({ type: "message", args: [data] }); + } + } + const [data, close] = transport.connect( + new URL(message.websocket.url), + message.websocket.origin, + message.websocket.protocols, + message.websocket.requestHeaders, + onopen, + onmessage, + onclose, + onerror, + ); + message.websocket.channel.onmessage = (event: MessageEvent) => { + if (event.data.type === "data") { + data(event.data.data); + } else if (event.data.type === "close") { + close(event.data.closeCode, event.data.closeReason); + } + } + + port.postMessage({ type: "websocket" }); +}