impl client/server version handshake

This commit is contained in:
Spencer Pogorzelski 2023-08-18 14:50:15 -07:00
parent edd83f54f9
commit ba77d77036
9 changed files with 114 additions and 10 deletions

View file

@ -2,11 +2,15 @@ import {
C2SRequestType,
C2SRequestTypes,
C2SWSOpenPayload,
C2S_HELLO,
HTTPRequestPayload,
HTTPResponsePayload,
PROTOCOL_VERSION,
ProtoBareHeaders,
S2CRequestType,
S2CRequestTypes,
S2C_HELLO_ERR,
S2C_HELLO_OK,
Transport,
WSClosePayload,
WSErrorPayload,
@ -33,6 +37,7 @@ type OpenWSMeta = {
};
export class Connection {
initialized = false;
requestCallbacks: Record<number, Function> = {};
openRequestStreams: Record<number, ReadableStreamDefaultController<any>> = {};
openingSockets: Record<number, OpenWSMeta> = {};
@ -40,8 +45,37 @@ export class Connection {
counter: number = 0;
static uninitializedError() {
throw new Error("Connection not initialized");
}
constructor(public transport: Transport) {
transport.ondata = this.ondata.bind(this);
transport.ondata = Connection.uninitializedError;
}
async initialize(): Promise<void> {
const onDataPromise = (): Promise<ArrayBuffer> => {
return new Promise((res) => {
this.transport.ondata = res;
});
};
// maybe some sort of timeout here?
// this code is not the best tbh
this.transport.send(new TextEncoder().encode(C2S_HELLO + PROTOCOL_VERSION));
const msg = await onDataPromise();
const msgText = new TextDecoder().decode(msg);
if (msgText === S2C_HELLO_OK) {
this.transport.ondata = this.ondata.bind(this);
this.initialized = true;
} else if (msgText.startsWith(S2C_HELLO_ERR)) {
const expectedVersion = msgText.slice(S2C_HELLO_ERR.length);
throw new Error(
`We are running protocol version ${PROTOCOL_VERSION}, ` +
`but server expected ${expectedVersion}`
);
} else {
throw new Error("Unexpected server hello response");
}
}
nextSeq() {
@ -49,6 +83,8 @@ export class Connection {
}
ondata(data: ArrayBuffer) {
if (!this.initialized) return;
let cursor = 0;
const view = new DataView(data);
@ -148,6 +184,10 @@ export class Connection {
type: C2SRequestType,
data?: ArrayBuffer | Blob
): Promise<void> {
if (!this.initialized) {
Connection.uninitializedError();
}
let header = new window.ArrayBuffer(2 + 1);
let view = new DataView(header);
@ -174,6 +214,10 @@ export class Connection {
},
body: ReadableStream<ArrayBuffer | Uint8Array> | null
): Promise<{ payload: HTTPResponsePayload; body: ArrayBuffer }> {
if (!this.initialized) {
Connection.uninitializedError();
}
const payload: HTTPRequestPayload = { ...data, hasBody: Boolean(body) };
let json = JSON.stringify(payload);
@ -202,11 +246,15 @@ export class Connection {
onclose: (code: number, reason: string, wasClean: boolean) => void,
onmessage: (data: any) => void,
onerror: (message: string) => void,
arrayBufferImpl: ArrayBufferConstructor,
arrayBufferImpl: ArrayBufferConstructor
): {
send: (data: any) => void;
close: (code?: number, reason?: string) => void;
} {
if (!this.initialized) {
Connection.uninitializedError();
}
const payload: C2SWSOpenPayload = { url: url.toString(), protocols };
const payloadJSON = JSON.stringify(payload);
let seq = this.nextSeq();

View file

@ -24,4 +24,8 @@ export class DevWsTransport extends Transport {
send(data: ArrayBuffer) {
this.ws.send(data);
}
close() {
this.ws.close();
}
}

View file

@ -8,11 +8,9 @@ const rtcConf = {
],
};
export type Offer = { offer: any; localCandidates: any };
export type Answer = { answer: any; candidates: any };
export class RTCTransport extends Transport {
peer: RTCPeerConnection;
@ -57,6 +55,10 @@ export class RTCTransport extends Transport {
this.dataChannel.send(data);
}
close() {
this.dataChannel.close();
}
async createOffer(): Promise<Promise<Offer>> {
const localCandidates: RTCIceCandidate[] = [];

View file

@ -52,10 +52,12 @@
let showTrackerList = false;
function onTransportOpen() {
async function onTransportOpen() {
console.log("Transport opened");
let connection = new Connection(transport);
// TODO: error handling here
await connection.initialize();
let bare = new AdriftBareClient(connection);
console.log(setBareClientImplementation);
setBareClientImplementation(bare);

View file

@ -3,4 +3,5 @@ export abstract class Transport {
constructor(public onopen: () => void, public onclose: () => void) {}
abstract send(data: ArrayBuffer): void;
abstract close(): void;
}

View file

@ -55,4 +55,11 @@ export type WSErrorPayload = {
// WebRTC max is 16K, let's say 12K to be safe
export const MAX_CHUNK_SIZE = 12 * 1024;
export const S2C_HELLO_OK = ":3";
// these two end with a version string
export const C2S_HELLO = "haiii ";
export const S2C_HELLO_ERR = ":< ";
export const PROTOCOL_VERSION = "1.0";
export { Transport } from "./Transport";

View file

@ -31,7 +31,10 @@ app.post("/connect", (req, res) => {
app.ws("/dev-ws", (ws, _req) => {
console.log("ws connect");
const client = new AdriftServer((msg) => ws.send(msg));
const client = new AdriftServer(
(msg) => ws.send(msg),
() => ws.close()
);
ws.on("message", (msg) => {
if (typeof msg === "string") {

View file

@ -75,9 +75,12 @@ export async function answerRtc(data: any, onrespond: (answer: any) => void) {
dataChannel.onopen = () => {
console.log("opened");
server = new AdriftServer((msg) => {
if (dataChannel.readyState === "open") dataChannel.send(msg);
});
server = new AdriftServer(
(msg) => {
if (dataChannel.readyState === "open") dataChannel.send(msg);
},
() => dataChannel.close()
);
};
dataChannel.onclose = () => {
console.log("closed");

View file

@ -3,12 +3,16 @@ import { IncomingMessage, STATUS_CODES } from "http";
import { WebSocket } from "isomorphic-ws";
import {
C2SRequestTypes,
C2S_HELLO,
HTTPRequestPayload,
HTTPResponsePayload,
MAX_CHUNK_SIZE,
PROTOCOL_VERSION,
ProtoBareHeaders,
S2CRequestType,
S2CRequestTypes,
S2C_HELLO_ERR,
S2C_HELLO_OK,
WSClosePayload,
WSErrorPayload,
} from "protocol";
@ -32,16 +36,41 @@ function bareErrorToResponse(e: BareError): {
}
export class AdriftServer {
initialized: boolean = false;
send: (msg: ArrayBuffer) => void;
close: () => void;
requestStreams: Record<number, Promise<Writable>> = {};
sockets: Record<number, WebSocket> = {};
events: EventEmitter;
constructor(send: (msg: ArrayBuffer) => void) {
constructor(send: (msg: ArrayBuffer) => void, close: () => void) {
this.send = send;
this.close = close;
this.events = new EventEmitter();
}
handleHello(msg: ArrayBuffer) {
try {
const text = new TextDecoder().decode(msg);
if (!text.startsWith(C2S_HELLO)) {
this.close();
return;
}
// later if we want we can supported multiple versions and run different behavior based
// on which we are talking to, might be too much effort idk
const version = text.slice(C2S_HELLO.length);
if (version === PROTOCOL_VERSION) {
this.send(new TextEncoder().encode(S2C_HELLO_OK));
this.initialized = true;
} else {
this.send(new TextEncoder().encode(S2C_HELLO_ERR + PROTOCOL_VERSION));
this.close();
}
} catch (_) {
this.close();
}
}
static parseMsgInit(
msg: ArrayBuffer
): { cursor: number; seq: number; op: number } | undefined {
@ -207,6 +236,11 @@ export class AdriftServer {
}
async onMsg(msg: ArrayBuffer) {
if (!this.initialized) {
this.handleHello(msg);
return;
}
const init = AdriftServer.parseMsgInit(msg);
if (!init) return;
const { cursor, seq, op } = init;