From d011d73a800870ce4057f39d052de752783d5367 Mon Sep 17 00:00:00 2001 From: velzie Date: Fri, 2 Aug 2024 22:44:16 -0400 Subject: [PATCH] fix postmessage shit --- src/client/client.ts | 4 ++-- src/client/dom/serviceworker.ts | 4 +++- src/client/index.ts | 4 +++- src/client/swruntime.ts | 38 ++++++++++++++++----------------- src/worker/fakesw.ts | 12 +++++++---- src/worker/fetch.ts | 18 +++++++++------- 6 files changed, 44 insertions(+), 36 deletions(-) diff --git a/src/client/client.ts b/src/client/client.ts index 3656b34..316d180 100644 --- a/src/client/client.ts +++ b/src/client/client.ts @@ -46,7 +46,7 @@ export class ScramjetClient { windowProxy: any; locationProxy: any; - eventcallbacks: WeakMap< + eventcallbacks: Map< any, [ { @@ -55,7 +55,7 @@ export class ScramjetClient { proxiedCallback: AnyFunction; }, ] - > = new WeakMap(); + > = new Map(); constructor(public global: typeof globalThis) { if ("document" in self) { diff --git a/src/client/dom/serviceworker.ts b/src/client/dom/serviceworker.ts index 1b3a133..6c334f5 100644 --- a/src/client/dom/serviceworker.ts +++ b/src/client/dom/serviceworker.ts @@ -1,5 +1,6 @@ import { encodeUrl } from "../shared"; import { ScramjetClient } from "../client"; +import { type MessageC2W } from "../../worker"; // we need a late order because we're mangling with addEventListener at a higher level export const order = 2; @@ -47,7 +48,8 @@ export default function (client: ScramjetClient, self: Self) { { scramjet$type: "registerServiceWorker", port: handle, - }, + origin: client.url.origin, + } as MessageC2W, [handle] ); diff --git a/src/client/index.ts b/src/client/index.ts index 532bc5e..38bd4bb 100644 --- a/src/client/index.ts +++ b/src/client/index.ts @@ -15,7 +15,9 @@ if (!(ScramjetClient.SCRAMJET in self)) { const client = new ScramjetClient(self); client.hook(); - if (issw) { + if ( + new URL(self.location.href).searchParams.get("dest") === "serviceworker" + ) { const runtime = new ScramjetServiceWorkerRuntime(client); runtime.hook(); } diff --git a/src/client/swruntime.ts b/src/client/swruntime.ts index 7c73c9a..75b6d8c 100644 --- a/src/client/swruntime.ts +++ b/src/client/swruntime.ts @@ -3,32 +3,30 @@ import { encodeUrl } from "./shared"; export class ScramjetServiceWorkerRuntime { constructor(public client: ScramjetClient) { - addEventListener("connect", (cevent: MessageEvent) => { + // @ts-ignore + self.onconnect = (cevent: MessageEvent) => { const port = cevent.ports[0]; port.addEventListener("message", (event) => { if ("scramjet$type" in event.data) { - handleMessage(client, event.data, port); + handleMessage(client, event.data); } }); port.start(); - }); + }; } hook() {} } -function handleMessage( - client: ScramjetClient, - data: MessageW2R, - port: MessagePort -) { +function handleMessage(client: ScramjetClient, data: MessageW2R) { + const port = data.scramjet$port; const type = data.scramjet$type; const token = data.scramjet$token; if (type === "fetch") { - const fetchhandlers = client.eventcallbacks.get("fetch"); + const fetchhandlers = client.eventcallbacks.get(self); if (!fetchhandlers) return; for (const handler of fetchhandlers) { @@ -37,24 +35,21 @@ function handleMessage( body: request.body, headers: new Headers(request.headers), method: request.method, - mode: request.mode, + mode: "same-origin", }); Object.defineProperty(fakeRequest, "destination", { value: request.destinitation, }); - const fakeFetchEvent = new FetchEvent("fetch", { - request: fakeRequest, - }); - + // TODO: clean up, maybe put into a class + const fakeFetchEvent: any = new Event("fetch"); + fakeFetchEvent.request = fakeRequest; fakeFetchEvent.respondWith = async ( response: Response | Promise ) => { response = await response; - - response.body; - port.postMessage({ + const message: MessageR2W = { scramjet$type: "fetch", scramjet$token: token, scramjet$response: { @@ -63,7 +58,9 @@ function handleMessage( status: response.status, statusText: response.statusText, }, - } as MessageR2W); + }; + + port.postMessage(message, [response.body]); }; handler.proxiedCallback(trustEvent(fakeFetchEvent)); @@ -76,7 +73,7 @@ function trustEvent(event: Event): Event { get(target, prop, reciever) { if (prop === "isTrusted") return true; - return Reflect.get(target, prop, reciever); + return Reflect.get(target, prop); }, }); } @@ -118,4 +115,5 @@ type MessageCommon = { }; export type MessageR2W = MessageCommon & MessageTypeR2W; -export type MessageW2R = MessageCommon & MessageTypeW2R; +export type MessageW2R = MessageCommon & + MessageTypeW2R & { scramjet$port: MessagePort }; diff --git a/src/worker/fakesw.ts b/src/worker/fakesw.ts index eabcb5c..ad23fc2 100644 --- a/src/worker/fakesw.ts +++ b/src/worker/fakesw.ts @@ -3,18 +3,18 @@ import { type MessageW2R, type MessageR2W } from "../client/swruntime"; export class FakeServiceWorker { syncToken = 0; promises: Record void> = {}; + messageChannel = new MessageChannel(); constructor( public handle: MessagePort, public origin: string ) { - this.handle.start(); - - this.handle.addEventListener("message", (event) => { + this.messageChannel.port1.addEventListener("message", (event) => { if ("scramjet$type" in event.data) { this.handleMessage(event.data); } }); + this.messageChannel.port1.start(); } handleMessage(data: MessageR2W) { @@ -31,6 +31,7 @@ export class FakeServiceWorker { const message: MessageW2R = { scramjet$type: "fetch", scramjet$token: token, + scramjet$port: this.messageChannel.port2, scramjet$request: { url: request.url, body: request.body, @@ -41,7 +42,10 @@ export class FakeServiceWorker { }, }; - this.handle.postMessage(message); + const transfer: any = request.body ? [request.body] : []; + transfer.push(this.messageChannel.port2); + + this.handle.postMessage(message, transfer); const { scramjet$response: r } = (await new Promise((resolve) => { this.promises[token] = resolve; diff --git a/src/worker/fetch.ts b/src/worker/fetch.ts index b55e57e..c696f91 100644 --- a/src/worker/fetch.ts +++ b/src/worker/fetch.ts @@ -3,6 +3,7 @@ import IDBMap from "@webreflection/idb-map"; import { ParseResultType } from "parse-domain"; import { ScramjetServiceWorker } from "."; import { renderError } from "./error"; +import { FakeServiceWorker } from "./fakesw"; const { encodeUrl, decodeUrl } = self.$scramjet.shared.url; const { rewriteHeaders, rewriteHtml, rewriteJs, rewriteCss, rewriteWorkers } = @@ -35,14 +36,6 @@ export async function swfetch( }); } - const activeWorker = this.serviceWorkers.find( - (w) => w.origin === new URL(request.url).origin - ); - if (activeWorker) { - // TODO: check scope - return await activeWorker.fetch(request); - } - const urlParam = new URLSearchParams(new URL(request.url).search); if (urlParam.has("url")) { @@ -53,6 +46,15 @@ export async function swfetch( try { const url = new URL(decodeUrl(request.url)); + + const activeWorker: FakeServiceWorker | null = this.serviceWorkers.find( + (w) => w.origin === url.origin + ); + + if (activeWorker) { + // TODO: check scope + return await activeWorker.fetch(request); + } if (url.origin == new URL(request.url).origin) { throw new Error( "attempted to fetch from same origin - this means the site has obtained a reference to the real origin, aborting"