From ba77d77036430a8c2376ed9f06014ae1f9ba0c33 Mon Sep 17 00:00:00 2001 From: Spencer Pogorzelski <34356756+Scoder12@users.noreply.github.com> Date: Fri, 18 Aug 2023 14:50:15 -0700 Subject: [PATCH] impl client/server version handshake --- client/src/Connection.ts | 52 ++++++++++++++++++++++++++++++++++-- client/src/DevWsTransport.ts | 4 +++ client/src/RTCTransport.ts | 6 +++-- frontend/src/App.svelte | 4 ++- protocol/src/Transport.ts | 1 + protocol/src/index.ts | 7 +++++ server/src/dev.ts | 5 +++- server/src/rtc.ts | 9 ++++--- server/src/server.ts | 36 ++++++++++++++++++++++++- 9 files changed, 114 insertions(+), 10 deletions(-) diff --git a/client/src/Connection.ts b/client/src/Connection.ts index bac6529..3e3bceb 100644 --- a/client/src/Connection.ts +++ b/client/src/Connection.ts @@ -2,11 +2,15 @@ import { C2SRequestType, C2SRequestTypes, C2SWSOpenPayload, + C2S_HELLO, HTTPRequestPayload, HTTPResponsePayload, + PROTOCOL_VERSION, ProtoBareHeaders, S2CRequestType, S2CRequestTypes, + S2C_HELLO_ERR, + S2C_HELLO_OK, Transport, WSClosePayload, WSErrorPayload, @@ -33,6 +37,7 @@ type OpenWSMeta = { }; export class Connection { + initialized = false; requestCallbacks: Record = {}; openRequestStreams: Record> = {}; openingSockets: Record = {}; @@ -40,8 +45,37 @@ export class Connection { counter: number = 0; + static uninitializedError() { + throw new Error("Connection not initialized"); + } + constructor(public transport: Transport) { - transport.ondata = this.ondata.bind(this); + transport.ondata = Connection.uninitializedError; + } + + async initialize(): Promise { + const onDataPromise = (): Promise => { + return new Promise((res) => { + this.transport.ondata = res; + }); + }; + // maybe some sort of timeout here? + // this code is not the best tbh + this.transport.send(new TextEncoder().encode(C2S_HELLO + PROTOCOL_VERSION)); + const msg = await onDataPromise(); + const msgText = new TextDecoder().decode(msg); + if (msgText === S2C_HELLO_OK) { + this.transport.ondata = this.ondata.bind(this); + this.initialized = true; + } else if (msgText.startsWith(S2C_HELLO_ERR)) { + const expectedVersion = msgText.slice(S2C_HELLO_ERR.length); + throw new Error( + `We are running protocol version ${PROTOCOL_VERSION}, ` + + `but server expected ${expectedVersion}` + ); + } else { + throw new Error("Unexpected server hello response"); + } } nextSeq() { @@ -49,6 +83,8 @@ export class Connection { } ondata(data: ArrayBuffer) { + if (!this.initialized) return; + let cursor = 0; const view = new DataView(data); @@ -148,6 +184,10 @@ export class Connection { type: C2SRequestType, data?: ArrayBuffer | Blob ): Promise { + if (!this.initialized) { + Connection.uninitializedError(); + } + let header = new window.ArrayBuffer(2 + 1); let view = new DataView(header); @@ -174,6 +214,10 @@ export class Connection { }, body: ReadableStream | null ): Promise<{ payload: HTTPResponsePayload; body: ArrayBuffer }> { + if (!this.initialized) { + Connection.uninitializedError(); + } + const payload: HTTPRequestPayload = { ...data, hasBody: Boolean(body) }; let json = JSON.stringify(payload); @@ -202,11 +246,15 @@ export class Connection { onclose: (code: number, reason: string, wasClean: boolean) => void, onmessage: (data: any) => void, onerror: (message: string) => void, - arrayBufferImpl: ArrayBufferConstructor, + arrayBufferImpl: ArrayBufferConstructor ): { send: (data: any) => void; close: (code?: number, reason?: string) => void; } { + if (!this.initialized) { + Connection.uninitializedError(); + } + const payload: C2SWSOpenPayload = { url: url.toString(), protocols }; const payloadJSON = JSON.stringify(payload); let seq = this.nextSeq(); diff --git a/client/src/DevWsTransport.ts b/client/src/DevWsTransport.ts index 458683c..3f705a1 100644 --- a/client/src/DevWsTransport.ts +++ b/client/src/DevWsTransport.ts @@ -24,4 +24,8 @@ export class DevWsTransport extends Transport { send(data: ArrayBuffer) { this.ws.send(data); } + + close() { + this.ws.close(); + } } diff --git a/client/src/RTCTransport.ts b/client/src/RTCTransport.ts index 9b714ce..d9273dd 100644 --- a/client/src/RTCTransport.ts +++ b/client/src/RTCTransport.ts @@ -8,11 +8,9 @@ const rtcConf = { ], }; - export type Offer = { offer: any; localCandidates: any }; export type Answer = { answer: any; candidates: any }; - export class RTCTransport extends Transport { peer: RTCPeerConnection; @@ -57,6 +55,10 @@ export class RTCTransport extends Transport { this.dataChannel.send(data); } + close() { + this.dataChannel.close(); + } + async createOffer(): Promise> { const localCandidates: RTCIceCandidate[] = []; diff --git a/frontend/src/App.svelte b/frontend/src/App.svelte index 7ff0473..75f8e71 100644 --- a/frontend/src/App.svelte +++ b/frontend/src/App.svelte @@ -52,10 +52,12 @@ let showTrackerList = false; - function onTransportOpen() { + async function onTransportOpen() { console.log("Transport opened"); let connection = new Connection(transport); + // TODO: error handling here + await connection.initialize(); let bare = new AdriftBareClient(connection); console.log(setBareClientImplementation); setBareClientImplementation(bare); diff --git a/protocol/src/Transport.ts b/protocol/src/Transport.ts index 03d3fe1..f685292 100644 --- a/protocol/src/Transport.ts +++ b/protocol/src/Transport.ts @@ -3,4 +3,5 @@ export abstract class Transport { constructor(public onopen: () => void, public onclose: () => void) {} abstract send(data: ArrayBuffer): void; + abstract close(): void; } diff --git a/protocol/src/index.ts b/protocol/src/index.ts index 80e1c2c..0247292 100644 --- a/protocol/src/index.ts +++ b/protocol/src/index.ts @@ -55,4 +55,11 @@ export type WSErrorPayload = { // WebRTC max is 16K, let's say 12K to be safe export const MAX_CHUNK_SIZE = 12 * 1024; +export const S2C_HELLO_OK = ":3"; +// these two end with a version string +export const C2S_HELLO = "haiii "; +export const S2C_HELLO_ERR = ":< "; + +export const PROTOCOL_VERSION = "1.0"; + export { Transport } from "./Transport"; diff --git a/server/src/dev.ts b/server/src/dev.ts index eb836c6..6f589c3 100644 --- a/server/src/dev.ts +++ b/server/src/dev.ts @@ -31,7 +31,10 @@ app.post("/connect", (req, res) => { app.ws("/dev-ws", (ws, _req) => { console.log("ws connect"); - const client = new AdriftServer((msg) => ws.send(msg)); + const client = new AdriftServer( + (msg) => ws.send(msg), + () => ws.close() + ); ws.on("message", (msg) => { if (typeof msg === "string") { diff --git a/server/src/rtc.ts b/server/src/rtc.ts index 2e2bbf8..d26ca5e 100644 --- a/server/src/rtc.ts +++ b/server/src/rtc.ts @@ -75,9 +75,12 @@ export async function answerRtc(data: any, onrespond: (answer: any) => void) { dataChannel.onopen = () => { console.log("opened"); - server = new AdriftServer((msg) => { - if (dataChannel.readyState === "open") dataChannel.send(msg); - }); + server = new AdriftServer( + (msg) => { + if (dataChannel.readyState === "open") dataChannel.send(msg); + }, + () => dataChannel.close() + ); }; dataChannel.onclose = () => { console.log("closed"); diff --git a/server/src/server.ts b/server/src/server.ts index 2ab9a04..a2c15eb 100644 --- a/server/src/server.ts +++ b/server/src/server.ts @@ -3,12 +3,16 @@ import { IncomingMessage, STATUS_CODES } from "http"; import { WebSocket } from "isomorphic-ws"; import { C2SRequestTypes, + C2S_HELLO, HTTPRequestPayload, HTTPResponsePayload, MAX_CHUNK_SIZE, + PROTOCOL_VERSION, ProtoBareHeaders, S2CRequestType, S2CRequestTypes, + S2C_HELLO_ERR, + S2C_HELLO_OK, WSClosePayload, WSErrorPayload, } from "protocol"; @@ -32,16 +36,41 @@ function bareErrorToResponse(e: BareError): { } export class AdriftServer { + initialized: boolean = false; send: (msg: ArrayBuffer) => void; + close: () => void; requestStreams: Record> = {}; sockets: Record = {}; events: EventEmitter; - constructor(send: (msg: ArrayBuffer) => void) { + constructor(send: (msg: ArrayBuffer) => void, close: () => void) { this.send = send; + this.close = close; this.events = new EventEmitter(); } + handleHello(msg: ArrayBuffer) { + try { + const text = new TextDecoder().decode(msg); + if (!text.startsWith(C2S_HELLO)) { + this.close(); + return; + } + // later if we want we can supported multiple versions and run different behavior based + // on which we are talking to, might be too much effort idk + const version = text.slice(C2S_HELLO.length); + if (version === PROTOCOL_VERSION) { + this.send(new TextEncoder().encode(S2C_HELLO_OK)); + this.initialized = true; + } else { + this.send(new TextEncoder().encode(S2C_HELLO_ERR + PROTOCOL_VERSION)); + this.close(); + } + } catch (_) { + this.close(); + } + } + static parseMsgInit( msg: ArrayBuffer ): { cursor: number; seq: number; op: number } | undefined { @@ -207,6 +236,11 @@ export class AdriftServer { } async onMsg(msg: ArrayBuffer) { + if (!this.initialized) { + this.handleHello(msg); + return; + } + const init = AdriftServer.parseMsgInit(msg); if (!init) return; const { cursor, seq, op } = init;