impl client/server version handshake

This commit is contained in:
Spencer Pogorzelski 2023-08-18 14:50:15 -07:00
parent edd83f54f9
commit ba77d77036
9 changed files with 114 additions and 10 deletions

View file

@ -2,11 +2,15 @@ import {
C2SRequestType, C2SRequestType,
C2SRequestTypes, C2SRequestTypes,
C2SWSOpenPayload, C2SWSOpenPayload,
C2S_HELLO,
HTTPRequestPayload, HTTPRequestPayload,
HTTPResponsePayload, HTTPResponsePayload,
PROTOCOL_VERSION,
ProtoBareHeaders, ProtoBareHeaders,
S2CRequestType, S2CRequestType,
S2CRequestTypes, S2CRequestTypes,
S2C_HELLO_ERR,
S2C_HELLO_OK,
Transport, Transport,
WSClosePayload, WSClosePayload,
WSErrorPayload, WSErrorPayload,
@ -33,6 +37,7 @@ type OpenWSMeta = {
}; };
export class Connection { export class Connection {
initialized = false;
requestCallbacks: Record<number, Function> = {}; requestCallbacks: Record<number, Function> = {};
openRequestStreams: Record<number, ReadableStreamDefaultController<any>> = {}; openRequestStreams: Record<number, ReadableStreamDefaultController<any>> = {};
openingSockets: Record<number, OpenWSMeta> = {}; openingSockets: Record<number, OpenWSMeta> = {};
@ -40,8 +45,37 @@ export class Connection {
counter: number = 0; counter: number = 0;
static uninitializedError() {
throw new Error("Connection not initialized");
}
constructor(public transport: Transport) { constructor(public transport: Transport) {
transport.ondata = this.ondata.bind(this); transport.ondata = Connection.uninitializedError;
}
async initialize(): Promise<void> {
const onDataPromise = (): Promise<ArrayBuffer> => {
return new Promise((res) => {
this.transport.ondata = res;
});
};
// maybe some sort of timeout here?
// this code is not the best tbh
this.transport.send(new TextEncoder().encode(C2S_HELLO + PROTOCOL_VERSION));
const msg = await onDataPromise();
const msgText = new TextDecoder().decode(msg);
if (msgText === S2C_HELLO_OK) {
this.transport.ondata = this.ondata.bind(this);
this.initialized = true;
} else if (msgText.startsWith(S2C_HELLO_ERR)) {
const expectedVersion = msgText.slice(S2C_HELLO_ERR.length);
throw new Error(
`We are running protocol version ${PROTOCOL_VERSION}, ` +
`but server expected ${expectedVersion}`
);
} else {
throw new Error("Unexpected server hello response");
}
} }
nextSeq() { nextSeq() {
@ -49,6 +83,8 @@ export class Connection {
} }
ondata(data: ArrayBuffer) { ondata(data: ArrayBuffer) {
if (!this.initialized) return;
let cursor = 0; let cursor = 0;
const view = new DataView(data); const view = new DataView(data);
@ -148,6 +184,10 @@ export class Connection {
type: C2SRequestType, type: C2SRequestType,
data?: ArrayBuffer | Blob data?: ArrayBuffer | Blob
): Promise<void> { ): Promise<void> {
if (!this.initialized) {
Connection.uninitializedError();
}
let header = new window.ArrayBuffer(2 + 1); let header = new window.ArrayBuffer(2 + 1);
let view = new DataView(header); let view = new DataView(header);
@ -174,6 +214,10 @@ export class Connection {
}, },
body: ReadableStream<ArrayBuffer | Uint8Array> | null body: ReadableStream<ArrayBuffer | Uint8Array> | null
): Promise<{ payload: HTTPResponsePayload; body: ArrayBuffer }> { ): Promise<{ payload: HTTPResponsePayload; body: ArrayBuffer }> {
if (!this.initialized) {
Connection.uninitializedError();
}
const payload: HTTPRequestPayload = { ...data, hasBody: Boolean(body) }; const payload: HTTPRequestPayload = { ...data, hasBody: Boolean(body) };
let json = JSON.stringify(payload); let json = JSON.stringify(payload);
@ -202,11 +246,15 @@ export class Connection {
onclose: (code: number, reason: string, wasClean: boolean) => void, onclose: (code: number, reason: string, wasClean: boolean) => void,
onmessage: (data: any) => void, onmessage: (data: any) => void,
onerror: (message: string) => void, onerror: (message: string) => void,
arrayBufferImpl: ArrayBufferConstructor, arrayBufferImpl: ArrayBufferConstructor
): { ): {
send: (data: any) => void; send: (data: any) => void;
close: (code?: number, reason?: string) => void; close: (code?: number, reason?: string) => void;
} { } {
if (!this.initialized) {
Connection.uninitializedError();
}
const payload: C2SWSOpenPayload = { url: url.toString(), protocols }; const payload: C2SWSOpenPayload = { url: url.toString(), protocols };
const payloadJSON = JSON.stringify(payload); const payloadJSON = JSON.stringify(payload);
let seq = this.nextSeq(); let seq = this.nextSeq();

View file

@ -24,4 +24,8 @@ export class DevWsTransport extends Transport {
send(data: ArrayBuffer) { send(data: ArrayBuffer) {
this.ws.send(data); this.ws.send(data);
} }
close() {
this.ws.close();
}
} }

View file

@ -8,11 +8,9 @@ const rtcConf = {
], ],
}; };
export type Offer = { offer: any; localCandidates: any }; export type Offer = { offer: any; localCandidates: any };
export type Answer = { answer: any; candidates: any }; export type Answer = { answer: any; candidates: any };
export class RTCTransport extends Transport { export class RTCTransport extends Transport {
peer: RTCPeerConnection; peer: RTCPeerConnection;
@ -57,6 +55,10 @@ export class RTCTransport extends Transport {
this.dataChannel.send(data); this.dataChannel.send(data);
} }
close() {
this.dataChannel.close();
}
async createOffer(): Promise<Promise<Offer>> { async createOffer(): Promise<Promise<Offer>> {
const localCandidates: RTCIceCandidate[] = []; const localCandidates: RTCIceCandidate[] = [];

View file

@ -52,10 +52,12 @@
let showTrackerList = false; let showTrackerList = false;
function onTransportOpen() { async function onTransportOpen() {
console.log("Transport opened"); console.log("Transport opened");
let connection = new Connection(transport); let connection = new Connection(transport);
// TODO: error handling here
await connection.initialize();
let bare = new AdriftBareClient(connection); let bare = new AdriftBareClient(connection);
console.log(setBareClientImplementation); console.log(setBareClientImplementation);
setBareClientImplementation(bare); setBareClientImplementation(bare);

View file

@ -3,4 +3,5 @@ export abstract class Transport {
constructor(public onopen: () => void, public onclose: () => void) {} constructor(public onopen: () => void, public onclose: () => void) {}
abstract send(data: ArrayBuffer): void; abstract send(data: ArrayBuffer): void;
abstract close(): void;
} }

View file

@ -55,4 +55,11 @@ export type WSErrorPayload = {
// WebRTC max is 16K, let's say 12K to be safe // WebRTC max is 16K, let's say 12K to be safe
export const MAX_CHUNK_SIZE = 12 * 1024; export const MAX_CHUNK_SIZE = 12 * 1024;
export const S2C_HELLO_OK = ":3";
// these two end with a version string
export const C2S_HELLO = "haiii ";
export const S2C_HELLO_ERR = ":< ";
export const PROTOCOL_VERSION = "1.0";
export { Transport } from "./Transport"; export { Transport } from "./Transport";

View file

@ -31,7 +31,10 @@ app.post("/connect", (req, res) => {
app.ws("/dev-ws", (ws, _req) => { app.ws("/dev-ws", (ws, _req) => {
console.log("ws connect"); console.log("ws connect");
const client = new AdriftServer((msg) => ws.send(msg)); const client = new AdriftServer(
(msg) => ws.send(msg),
() => ws.close()
);
ws.on("message", (msg) => { ws.on("message", (msg) => {
if (typeof msg === "string") { if (typeof msg === "string") {

View file

@ -75,9 +75,12 @@ export async function answerRtc(data: any, onrespond: (answer: any) => void) {
dataChannel.onopen = () => { dataChannel.onopen = () => {
console.log("opened"); console.log("opened");
server = new AdriftServer((msg) => { server = new AdriftServer(
if (dataChannel.readyState === "open") dataChannel.send(msg); (msg) => {
}); if (dataChannel.readyState === "open") dataChannel.send(msg);
},
() => dataChannel.close()
);
}; };
dataChannel.onclose = () => { dataChannel.onclose = () => {
console.log("closed"); console.log("closed");

View file

@ -3,12 +3,16 @@ import { IncomingMessage, STATUS_CODES } from "http";
import { WebSocket } from "isomorphic-ws"; import { WebSocket } from "isomorphic-ws";
import { import {
C2SRequestTypes, C2SRequestTypes,
C2S_HELLO,
HTTPRequestPayload, HTTPRequestPayload,
HTTPResponsePayload, HTTPResponsePayload,
MAX_CHUNK_SIZE, MAX_CHUNK_SIZE,
PROTOCOL_VERSION,
ProtoBareHeaders, ProtoBareHeaders,
S2CRequestType, S2CRequestType,
S2CRequestTypes, S2CRequestTypes,
S2C_HELLO_ERR,
S2C_HELLO_OK,
WSClosePayload, WSClosePayload,
WSErrorPayload, WSErrorPayload,
} from "protocol"; } from "protocol";
@ -32,16 +36,41 @@ function bareErrorToResponse(e: BareError): {
} }
export class AdriftServer { export class AdriftServer {
initialized: boolean = false;
send: (msg: ArrayBuffer) => void; send: (msg: ArrayBuffer) => void;
close: () => void;
requestStreams: Record<number, Promise<Writable>> = {}; requestStreams: Record<number, Promise<Writable>> = {};
sockets: Record<number, WebSocket> = {}; sockets: Record<number, WebSocket> = {};
events: EventEmitter; events: EventEmitter;
constructor(send: (msg: ArrayBuffer) => void) { constructor(send: (msg: ArrayBuffer) => void, close: () => void) {
this.send = send; this.send = send;
this.close = close;
this.events = new EventEmitter(); this.events = new EventEmitter();
} }
handleHello(msg: ArrayBuffer) {
try {
const text = new TextDecoder().decode(msg);
if (!text.startsWith(C2S_HELLO)) {
this.close();
return;
}
// later if we want we can supported multiple versions and run different behavior based
// on which we are talking to, might be too much effort idk
const version = text.slice(C2S_HELLO.length);
if (version === PROTOCOL_VERSION) {
this.send(new TextEncoder().encode(S2C_HELLO_OK));
this.initialized = true;
} else {
this.send(new TextEncoder().encode(S2C_HELLO_ERR + PROTOCOL_VERSION));
this.close();
}
} catch (_) {
this.close();
}
}
static parseMsgInit( static parseMsgInit(
msg: ArrayBuffer msg: ArrayBuffer
): { cursor: number; seq: number; op: number } | undefined { ): { cursor: number; seq: number; op: number } | undefined {
@ -207,6 +236,11 @@ export class AdriftServer {
} }
async onMsg(msg: ArrayBuffer) { async onMsg(msg: ArrayBuffer) {
if (!this.initialized) {
this.handleHello(msg);
return;
}
const init = AdriftServer.parseMsgInit(msg); const init = AdriftServer.parseMsgInit(msg);
if (!init) return; if (!init) return;
const { cursor, seq, op } = init; const { cursor, seq, op } = init;