diff --git a/crates/bindings-typescript/src/lib/binary_writer.ts b/crates/bindings-typescript/src/lib/binary_writer.ts index a66310745ce..6370e48153c 100644 --- a/crates/bindings-typescript/src/lib/binary_writer.ts +++ b/crates/bindings-typescript/src/lib/binary_writer.ts @@ -93,6 +93,12 @@ export default class BinaryWriter { this.offset += 1; } + writeBytes(value: Uint8Array): void { + this.expandBuffer(value.length); + new Uint8Array(this.buffer.buffer, this.offset, value.length).set(value); + this.offset += value.length; + } + writeI8(value: number): void { this.expandBuffer(1); this.view.setInt8(this.offset, value); diff --git a/crates/bindings-typescript/src/sdk/db_connection_impl.ts b/crates/bindings-typescript/src/sdk/db_connection_impl.ts index b873b30bff6..954954040c8 100644 --- a/crates/bindings-typescript/src/sdk/db_connection_impl.ts +++ b/crates/bindings-typescript/src/sdk/db_connection_impl.ts @@ -1,7 +1,7 @@ import { ConnectionId, ProductBuilder, ProductType } from '../'; import { AlgebraicType, type ComparablePrimitive } from '../'; -import { BinaryReader } from '../'; -import { BinaryWriter } from '../'; +import BinaryReader from '../lib/binary_reader.ts'; +import BinaryWriter from '../lib/binary_writer.ts'; import { BsatnRowList, ClientMessage, @@ -60,6 +60,18 @@ import type { ProceduresView } from './procedures.ts'; import type { Values } from '../lib/type_util.ts'; import type { TransactionUpdate } from './client_api/types.ts'; import { InternalError, SenderError } from '../lib/errors.ts'; +import { + normalizeWsProtocol, + PREFERRED_WS_PROTOCOLS, + V2_WS_PROTOCOL, + V3_WS_PROTOCOL, + type NegotiatedWsProtocol, +} from './websocket_protocols'; +import { + countClientMessagesForV3Frame, + encodeClientMessagesV3, + forEachServerMessageV3, +} from './websocket_v3_frames.ts'; export { DbConnectionBuilder, @@ -117,6 +129,9 @@ const CLIENT_MESSAGE_CALL_REDUCER_TAG = getClientMessageVariantTag('CallReducer'); const CLIENT_MESSAGE_CALL_PROCEDURE_TAG = getClientMessageVariantTag('CallProcedure'); +// Keep individual v3 frames bounded so one burst does not monopolize the send +// path or create very large websocket writes. +const MAX_V3_OUTBOUND_FRAME_BYTES = 256 * 1024; export class DbConnectionImpl implements DbContext @@ -172,6 +187,8 @@ export class DbConnectionImpl #inboundQueueOffset = 0; #isDrainingInboundQueue = false; #outboundQueue: Uint8Array[] = []; + #isOutboundFlushScheduled = false; + #negotiatedWsProtocol: NegotiatedWsProtocol = V2_WS_PROTOCOL; #subscriptionManager = new SubscriptionManager(); #remoteModule: RemoteModule; #reducerCallbacks = new Map< @@ -198,6 +215,7 @@ export class DbConnectionImpl #sourceNameToTableDef: Record>; #messageReader = new BinaryReader(new Uint8Array()); #rowListReader = new BinaryReader(new Uint8Array()); + #clientFrameEncoder = new BinaryWriter(1024); #boundSubscriptionBuilder!: () => SubscriptionBuilderImpl; #boundDisconnect!: () => void; @@ -296,7 +314,7 @@ export class DbConnectionImpl this.wsPromise = createWSFn({ url, nameOrAddress, - wsProtocol: 'v2.bsatn.spacetimedb', + wsProtocol: [...PREFERRED_WS_PROTOCOLS], authToken: token, compression: compression, lightMode: lightMode, @@ -595,23 +613,99 @@ export class DbConnectionImpl } #flushOutboundQueue(wsResolved: WebsocketAdapter): void { + if (this.#negotiatedWsProtocol === V3_WS_PROTOCOL) { + this.#flushOutboundQueueV3(wsResolved); + return; + } + this.#flushOutboundQueueV2(wsResolved); + } + + #flushOutboundQueueV2(wsResolved: WebsocketAdapter): void { const pending = this.#outboundQueue.splice(0); for (const message of pending) { wsResolved.send(message); } } + #flushOutboundQueueV3(wsResolved: WebsocketAdapter): void { + if (this.#outboundQueue.length === 0) { + return; + } + + // Emit at most one bounded frame per flush. If more encoded v2 messages + // remain in the queue, they are sent by a later scheduled flush so inbound + // traffic and other tasks get a chance to run between websocket writes. + const batchSize = countClientMessagesForV3Frame( + this.#outboundQueue, + MAX_V3_OUTBOUND_FRAME_BYTES + ); + wsResolved.send( + encodeClientMessagesV3( + this.#clientFrameEncoder, + this.#outboundQueue, + batchSize + ) + ); + + if (batchSize === this.#outboundQueue.length) { + this.#outboundQueue.length = 0; + return; + } + + this.#outboundQueue.copyWithin(0, batchSize); + this.#outboundQueue.length -= batchSize; + if (this.#outboundQueue.length > 0) { + this.#scheduleDeferredOutboundFlush(); + } + } + + #scheduleOutboundFlush(): void { + this.#scheduleOutboundFlushWith('microtask'); + } + + #scheduleDeferredOutboundFlush(): void { + this.#scheduleOutboundFlushWith('next-task'); + } + + #scheduleOutboundFlushWith(schedule: 'microtask' | 'next-task'): void { + if (this.#isOutboundFlushScheduled) { + return; + } + + this.#isOutboundFlushScheduled = true; + const flush = () => { + this.#isOutboundFlushScheduled = false; + if (this.ws && this.isActive) { + this.#flushOutboundQueue(this.ws); + } + }; + + // The first v3 flush stays on the current turn so same-tick sends coalesce. + // Follow-up flushes after a size-capped frame yield to the next task so we + // do not sit in a tight send loop while inbound websocket work is waiting. + if (schedule === 'next-task') { + setTimeout(flush, 0); + } else { + queueMicrotask(flush); + } + } + #reducerArgsEncoder = new BinaryWriter(1024); #clientMessageEncoder = new BinaryWriter(1024); #sendEncodedMessage(encoded: Uint8Array, describe: () => string): void { + stdbLogger('trace', describe); if (this.ws && this.isActive) { - if (this.#outboundQueue.length) this.#flushOutboundQueue(this.ws); + if (this.#negotiatedWsProtocol === V2_WS_PROTOCOL) { + if (this.#outboundQueue.length) this.#flushOutboundQueue(this.ws); + this.ws.send(encoded); + return; + } - stdbLogger('trace', describe); - this.ws.send(encoded); + this.#outboundQueue.push(encoded.slice()); + this.#scheduleOutboundFlush(); } else { - stdbLogger('trace', describe); - // use slice() to copy, in case the clientMessageEncoder's buffer gets used + // Use slice() to copy, in case the clientMessageEncoder's buffer gets reused + // before the connection opens or before a v3 microbatch flush runs. this.#outboundQueue.push(encoded.slice()); } } @@ -681,6 +775,9 @@ export class DbConnectionImpl * Handles WebSocket onOpen event. */ #handleOnOpen(): void { + if (this.ws) { + this.#negotiatedWsProtocol = normalizeWsProtocol(this.ws.protocol); + } this.isActive = true; if (this.ws) { this.#flushOutboundQueue(this.ws); @@ -728,10 +825,17 @@ export class DbConnectionImpl ); } - #processMessage(data: Uint8Array): void { - const reader = this.#messageReader; - reader.reset(data); - const serverMessage = ServerMessage.deserialize(reader); + #dispatchPendingCallbacks(callbacks: readonly PendingCallback[]): void { + stdbLogger( + 'trace', + () => `Calling ${callbacks.length} triggered row callbacks` + ); + for (const callback of callbacks) { + callback.cb(); + } + } + + #processServerMessage(serverMessage: ServerMessage): void { stdbLogger( 'trace', () => `Processing server message: ${stringify(serverMessage)}` @@ -769,13 +873,7 @@ export class DbConnectionImpl const callbacks = this.#applyTableUpdates(tableUpdates, eventContext); const { event: _, ...subscriptionEventContext } = eventContext; subscription.emitter.emit('applied', subscriptionEventContext); - stdbLogger( - 'trace', - () => `Calling ${callbacks.length} triggered row callbacks` - ); - for (const callback of callbacks) { - callback.cb(); - } + this.#dispatchPendingCallbacks(callbacks); break; } case 'UnsubscribeApplied': { @@ -801,13 +899,7 @@ export class DbConnectionImpl const { event: _, ...subscriptionEventContext } = eventContext; subscription.emitter.emit('end', subscriptionEventContext); this.#subscriptionManager.subscriptions.delete(querySetId); - stdbLogger( - 'trace', - () => `Calling ${callbacks.length} triggered row callbacks` - ); - for (const callback of callbacks) { - callback.cb(); - } + this.#dispatchPendingCallbacks(callbacks); break; } case 'SubscriptionError': { @@ -861,13 +953,7 @@ export class DbConnectionImpl eventContext, serverMessage.value ); - stdbLogger( - 'trace', - () => `Calling ${callbacks.length} triggered row callbacks` - ); - for (const callback of callbacks) { - callback.cb(); - } + this.#dispatchPendingCallbacks(callbacks); break; } case 'ReducerResult': { @@ -899,13 +985,7 @@ export class DbConnectionImpl eventContext, result.value.transactionUpdate ); - stdbLogger( - 'trace', - () => `Calling ${callbacks.length} triggered row callbacks` - ); - for (const callback of callbacks) { - callback.cb(); - } + this.#dispatchPendingCallbacks(callbacks); } this.#reducerCallInfo.delete(requestId); const cb = this.#reducerCallbacks.get(requestId); @@ -934,6 +1014,31 @@ export class DbConnectionImpl } } + #processV2Message(data: Uint8Array): void { + const reader = this.#messageReader; + reader.reset(data); + this.#processServerMessage(ServerMessage.deserialize(reader)); + } + + #processMessage(data: Uint8Array): void { + if (this.#negotiatedWsProtocol !== V3_WS_PROTOCOL) { + this.#processV2Message(data); + return; + } + + const messageCount = forEachServerMessageV3( + this.#messageReader, + data, + serverMessage => { + this.#processServerMessage(serverMessage); + } + ); + stdbLogger( + 'trace', + () => `Processing server v3 payload with ${messageCount} message(s)` + ); + } + /** * Handles WebSocket onMessage event. * @param wsMessage MessageEvent object. diff --git a/crates/bindings-typescript/src/sdk/websocket_decompress_adapter.ts b/crates/bindings-typescript/src/sdk/websocket_decompress_adapter.ts index 40157393dd1..b5db13a8149 100644 --- a/crates/bindings-typescript/src/sdk/websocket_decompress_adapter.ts +++ b/crates/bindings-typescript/src/sdk/websocket_decompress_adapter.ts @@ -2,6 +2,7 @@ import { decompress } from './decompress'; import { resolveWS } from './ws'; export interface WebsocketAdapter { + readonly protocol: string; send(msg: Uint8Array): void; close(): void; @@ -12,6 +13,10 @@ export interface WebsocketAdapter { } export class WebsocketDecompressAdapter implements WebsocketAdapter { + get protocol(): string { + return this.#ws.protocol; + } + set onclose(handler: (ev: CloseEvent) => void) { this.#ws.onclose = handler; } @@ -73,7 +78,7 @@ export class WebsocketDecompressAdapter implements WebsocketAdapter { confirmedReads, }: { url: URL; - wsProtocol: string; + wsProtocol: string | string[]; nameOrAddress: string; authToken?: string; compression: 'gzip' | 'none'; diff --git a/crates/bindings-typescript/src/sdk/websocket_protocols.ts b/crates/bindings-typescript/src/sdk/websocket_protocols.ts new file mode 100644 index 00000000000..2d6598143e9 --- /dev/null +++ b/crates/bindings-typescript/src/sdk/websocket_protocols.ts @@ -0,0 +1,25 @@ +import { stdbLogger } from './logger.ts'; + +export const V2_WS_PROTOCOL = 'v2.bsatn.spacetimedb'; +export const V3_WS_PROTOCOL = 'v3.bsatn.spacetimedb'; +export const PREFERRED_WS_PROTOCOLS = [V3_WS_PROTOCOL, V2_WS_PROTOCOL] as const; + +export type NegotiatedWsProtocol = + | typeof V2_WS_PROTOCOL + | typeof V3_WS_PROTOCOL; + +export function normalizeWsProtocol(protocol: string): NegotiatedWsProtocol { + if (protocol === V3_WS_PROTOCOL) { + return V3_WS_PROTOCOL; + } + // We treat an empty negotiated subprotocol as legacy v2 for compatibility. + if (protocol === '' || protocol === V2_WS_PROTOCOL) { + return V2_WS_PROTOCOL; + } + + stdbLogger( + 'warn', + `Unexpected websocket subprotocol "${protocol}", falling back to ${V2_WS_PROTOCOL}.` + ); + return V2_WS_PROTOCOL; +} diff --git a/crates/bindings-typescript/src/sdk/websocket_test_adapter.ts b/crates/bindings-typescript/src/sdk/websocket_test_adapter.ts index 6ac15f0e7fe..4c2c48b42f0 100644 --- a/crates/bindings-typescript/src/sdk/websocket_test_adapter.ts +++ b/crates/bindings-typescript/src/sdk/websocket_test_adapter.ts @@ -1,61 +1,109 @@ -import { BinaryReader, BinaryWriter } from '../'; +import BinaryReader from '../lib/binary_reader.ts'; +import BinaryWriter from '../lib/binary_writer.ts'; import { ClientMessage, ServerMessage } from './client_api/types'; import type { WebsocketAdapter } from './websocket_decompress_adapter'; +import { PREFERRED_WS_PROTOCOLS, V3_WS_PROTOCOL } from './websocket_protocols'; +import { + decodeClientMessagesV3, + encodeServerMessagesV3, +} from './websocket_v3_frames.ts'; class WebsocketTestAdapter implements WebsocketAdapter { - onclose: any; - // eslint-disable-next-line @typescript-eslint/no-unsafe-function-type - onopen!: () => void; - onmessage: any; - onerror: any; + protocol: string = ''; - messageQueue: any[]; + messageQueue: Uint8Array[]; outgoingMessages: ClientMessage[]; closed: boolean; + supportedProtocols: string[]; + + #onclose: (ev: CloseEvent) => void = () => {}; + #onopen: () => void = () => {}; + #onmessage: (msg: { data: Uint8Array }) => void = () => {}; constructor() { this.messageQueue = []; this.outgoingMessages = []; this.closed = false; + this.supportedProtocols = [...PREFERRED_WS_PROTOCOLS]; + } + + set onclose(handler: (ev: CloseEvent) => void) { + this.#onclose = handler; + } + + set onopen(handler: () => void) { + this.#onopen = handler; } - send(message: any): void { - const parsedMessage = ClientMessage.deserialize(new BinaryReader(message)); - this.outgoingMessages.push(parsedMessage); - // console.ClientMessageSerde.deserialize(message); - this.messageQueue.push(message); + set onmessage(handler: (msg: { data: Uint8Array }) => void) { + this.#onmessage = handler; + } + + set onerror(_handler: (msg: ErrorEvent) => void) {} + + send(message: Uint8Array): void { + const rawMessage = message.slice(); + const outgoingMessages = + this.protocol === V3_WS_PROTOCOL + ? decodeClientMessagesV3(rawMessage) + : [rawMessage]; + + for (const outgoingMessage of outgoingMessages) { + this.outgoingMessages.push( + ClientMessage.deserialize(new BinaryReader(outgoingMessage)) + ); + } + this.messageQueue.push(rawMessage); } close(): void { this.closed = true; - this.onclose?.({ code: 1000, reason: 'normal closure', wasClean: true }); + this.#onclose({ + code: 1000, + reason: 'normal closure', + wasClean: true, + } as CloseEvent); } acceptConnection(): void { - this.onopen(); + this.#onopen(); } sendToClient(message: ServerMessage): void { const writer = new BinaryWriter(1024); ServerMessage.serialize(writer, message); - const rawBytes = writer.getBuffer(); + const rawBytes = writer.getBuffer().slice(); // The brotli library's `compress` is somehow broken: it returns `null` for some inputs. // See https://github.com/foliojs/brotli.js/issues/36, which is closed but not actually fixed. // So we send the uncompressed data here, and in `spacetimedb.ts`, // if compression fails, we treat the raw message as having been uncompressed all along. // const data = compress(rawBytes); - this.onmessage({ data: rawBytes }); + const outboundData = + this.protocol === V3_WS_PROTOCOL + ? encodeServerMessagesV3(writer, [rawBytes]).slice() + : rawBytes; + this.#onmessage({ data: outboundData }); } async createWebSocketFn(_args: { url: URL; - wsProtocol: string; + wsProtocol: string | string[]; nameOrAddress: string; authToken?: string; compression: 'gzip' | 'none'; lightMode: boolean; confirmedReads?: boolean; }): Promise { + const requestedProtocols = Array.isArray(_args.wsProtocol) + ? _args.wsProtocol + : [_args.wsProtocol]; + const negotiatedProtocol = requestedProtocols.find(protocol => + this.supportedProtocols.includes(protocol) + ); + if (!negotiatedProtocol) { + return Promise.reject(new Error('No compatible websocket protocol')); + } + this.protocol = negotiatedProtocol; return this; } } diff --git a/crates/bindings-typescript/src/sdk/websocket_v3_frames.ts b/crates/bindings-typescript/src/sdk/websocket_v3_frames.ts new file mode 100644 index 00000000000..36ea683ecfb --- /dev/null +++ b/crates/bindings-typescript/src/sdk/websocket_v3_frames.ts @@ -0,0 +1,124 @@ +import BinaryReader from '../lib/binary_reader.ts'; +import BinaryWriter from '../lib/binary_writer.ts'; +import { ClientMessage, ServerMessage } from './client_api/types'; + +// v3 is only a transport framing convention. The payload is one or more +// already-encoded v2 websocket messages concatenated back-to-back, so these +// helpers intentionally operate on raw bytes. +const EMPTY_V3_PAYLOAD_ERR = + 'v3 websocket payloads must contain at least one message'; + +function ensureMessages(messages: readonly Uint8Array[]): void { + if (messages.length === 0) { + throw new RangeError(EMPTY_V3_PAYLOAD_ERR); + } +} + +function ensureMessageCount( + messages: readonly Uint8Array[], + messageCount: number +): void { + ensureMessages(messages); + if (messageCount < 1 || messageCount > messages.length) { + throw new RangeError( + `v3 websocket payload requested ${messageCount} messages from ${messages.length}` + ); + } +} + +function concatenateMessagesV3( + writer: BinaryWriter, + messages: readonly Uint8Array[], + messageCount: number = messages.length +): Uint8Array { + ensureMessageCount(messages, messageCount); + writer.clear(); + for (let i = 0; i < messageCount; i++) { + writer.writeBytes(messages[i]!); + } + return writer.getBuffer(); +} + +function splitMessagesV3( + reader: BinaryReader, + data: Uint8Array, + deserialize: (reader: BinaryReader) => unknown +): Uint8Array[] { + reader.reset(data); + if (reader.remaining === 0) { + throw new RangeError(EMPTY_V3_PAYLOAD_ERR); + } + + const messages: Uint8Array[] = []; + while (reader.remaining > 0) { + const startOffset = reader.offset; + deserialize(reader); + messages.push(data.subarray(startOffset, reader.offset)); + } + + return messages; +} + +export function countClientMessagesForV3Frame( + messages: readonly Uint8Array[], + maxFrameBytes: number +): number { + ensureMessages(messages); + + const firstMessage = messages[0]!; + if (firstMessage.length > maxFrameBytes) { + return 1; + } + + let count = 1; + let frameSize = firstMessage.length; + while (count < messages.length) { + const nextMessage = messages[count]!; + const nextFrameSize = frameSize + nextMessage.length; + if (nextFrameSize > maxFrameBytes) { + break; + } + frameSize = nextFrameSize; + count += 1; + } + return count; +} + +export function encodeClientMessagesV3( + writer: BinaryWriter, + messages: readonly Uint8Array[], + messageCount: number = messages.length +): Uint8Array { + return concatenateMessagesV3(writer, messages, messageCount); +} + +export function decodeClientMessagesV3(data: Uint8Array): Uint8Array[] { + return splitMessagesV3(new BinaryReader(data), data, reader => + ClientMessage.deserialize(reader) + ); +} + +export function encodeServerMessagesV3( + writer: BinaryWriter, + messages: readonly Uint8Array[] +): Uint8Array { + return concatenateMessagesV3(writer, messages); +} + +export function forEachServerMessageV3( + reader: BinaryReader, + data: Uint8Array, + visit: (message: ServerMessage) => void +): number { + reader.reset(data); + if (reader.remaining === 0) { + throw new RangeError(EMPTY_V3_PAYLOAD_ERR); + } + + let count = 0; + while (reader.remaining > 0) { + visit(ServerMessage.deserialize(reader)); + count += 1; + } + return count; +} diff --git a/crates/bindings-typescript/tests/db_connection.test.ts b/crates/bindings-typescript/tests/db_connection.test.ts index ec17430e41a..1161c14aaa5 100644 --- a/crates/bindings-typescript/tests/db_connection.test.ts +++ b/crates/bindings-typescript/tests/db_connection.test.ts @@ -10,6 +10,8 @@ import { } from '../src'; import { ServerMessage } from '../src/sdk/client_api/types'; import WebsocketTestAdapter from '../src/sdk/websocket_test_adapter'; +import { V2_WS_PROTOCOL, V3_WS_PROTOCOL } from '../src/sdk/websocket_protocols'; +import { decodeClientMessagesV3 } from '../src/sdk/websocket_v3_frames.ts'; import { DbConnection } from '../test-app/src/module_bindings'; import User from '../test-app/src/module_bindings/user_table'; import { @@ -194,6 +196,63 @@ describe('DbConnection', () => { expect(called).toBeTruthy(); }); + test('batches same-tick reducer calls when v3 is negotiated', async () => { + const wsAdapter = new WebsocketTestAdapter(); + const client = DbConnection.builder() + .withUri('ws://127.0.0.1:1234') + .withDatabaseName('db') + .withWSFn(wsAdapter.createWebSocketFn.bind(wsAdapter) as any) + .build(); + + await client['wsPromise']; + wsAdapter.acceptConnection(); + + void client.reducers.createPlayer({ + name: 'Player One', + location: { x: 1, y: 2 }, + }); + void client.reducers.createPlayer({ + name: 'Player Two', + location: { x: 3, y: 4 }, + }); + + await Promise.resolve(); + + expect(wsAdapter.protocol).toEqual(V3_WS_PROTOCOL); + expect(wsAdapter.messageQueue).toHaveLength(1); + expect(wsAdapter.outgoingMessages).toHaveLength(2); + + expect(decodeClientMessagesV3(wsAdapter.messageQueue[0])).toHaveLength(2); + }); + + test('falls back to v2 and does not batch reducer calls when v3 is unavailable', async () => { + const wsAdapter = new WebsocketTestAdapter(); + wsAdapter.supportedProtocols = [V2_WS_PROTOCOL]; + const client = DbConnection.builder() + .withUri('ws://127.0.0.1:1234') + .withDatabaseName('db') + .withWSFn(wsAdapter.createWebSocketFn.bind(wsAdapter) as any) + .build(); + + await client['wsPromise']; + wsAdapter.acceptConnection(); + + void client.reducers.createPlayer({ + name: 'Player One', + location: { x: 1, y: 2 }, + }); + void client.reducers.createPlayer({ + name: 'Player Two', + location: { x: 3, y: 4 }, + }); + + await Promise.resolve(); + + expect(wsAdapter.protocol).toEqual(V2_WS_PROTOCOL); + expect(wsAdapter.messageQueue).toHaveLength(2); + expect(wsAdapter.outgoingMessages).toHaveLength(2); + }); + test('disconnects when SubscriptionError has no requestId', async () => { const onDisconnectPromise = new Deferred(); const wsAdapter = new WebsocketTestAdapter(); @@ -750,6 +809,7 @@ describe('DbConnection', () => { .withWSFn(wsAdapter.createWebSocketFn.bind(wsAdapter) as any) .build(); await client['wsPromise']; + wsAdapter.acceptConnection(); const user1 = { identity: bobIdentity, username: 'bob' }; const user2 = { identity: sallyIdentity, diff --git a/crates/bindings-typescript/tests/websocket_v3_frames.test.ts b/crates/bindings-typescript/tests/websocket_v3_frames.test.ts new file mode 100644 index 00000000000..eefe7c0c640 --- /dev/null +++ b/crates/bindings-typescript/tests/websocket_v3_frames.test.ts @@ -0,0 +1,100 @@ +import { describe, expect, test } from 'vitest'; +import BinaryReader from '../src/lib/binary_reader.ts'; +import BinaryWriter from '../src/lib/binary_writer.ts'; +import { ClientMessage } from '../src/sdk/client_api/types'; +import { + countClientMessagesForV3Frame, + decodeClientMessagesV3, + encodeClientMessagesV3, +} from '../src/sdk/websocket_v3_frames'; + +function encodeClientMessage(message: ClientMessage): Uint8Array { + const writer = new BinaryWriter(128); + ClientMessage.serialize(writer, message); + return writer.getBuffer().slice(); +} + +describe('websocket_v3_frames', () => { + test('counts as many client messages as fit within the encoded frame limit', () => { + const messages = [ + new Uint8Array(10), + new Uint8Array(20), + new Uint8Array(30), + ]; + + expect(countClientMessagesForV3Frame(messages, 10)).toBe(1); + expect(countClientMessagesForV3Frame(messages, 30)).toBe(2); + expect(countClientMessagesForV3Frame(messages, 60)).toBe(3); + }); + + test('still emits an oversized first message on its own', () => { + const messages = [new Uint8Array(300_000), new Uint8Array(10)]; + expect(countClientMessagesForV3Frame(messages, 256 * 1024)).toBe(1); + }); + + test('encodes and decodes raw concatenated v2 messages', () => { + const encodedMessages = [ + encodeClientMessage( + ClientMessage.CallReducer({ + requestId: 7, + flags: 0, + reducer: 'first', + args: new Uint8Array([1, 2]), + }) + ), + encodeClientMessage( + ClientMessage.CallProcedure({ + requestId: 8, + flags: 0, + procedure: 'second', + args: new Uint8Array([3, 4, 5]), + }) + ), + ]; + const payload = encodeClientMessagesV3( + new BinaryWriter(128), + encodedMessages + ); + const decodedMessages = decodeClientMessagesV3(payload); + + expect(decodedMessages).toHaveLength(2); + expect( + ClientMessage.deserialize(new BinaryReader(decodedMessages[0])).tag + ).toBe('CallReducer'); + expect( + ClientMessage.deserialize(new BinaryReader(decodedMessages[1])).tag + ).toBe('CallProcedure'); + }); + + test('can encode only a prefix of queued client messages', () => { + const encodedMessages = [ + encodeClientMessage( + ClientMessage.CallReducer({ + requestId: 7, + flags: 0, + reducer: 'first', + args: new Uint8Array([1, 2]), + }) + ), + encodeClientMessage( + ClientMessage.CallProcedure({ + requestId: 8, + flags: 0, + procedure: 'second', + args: new Uint8Array([3, 4, 5]), + }) + ), + ]; + const payload = encodeClientMessagesV3( + new BinaryWriter(128), + encodedMessages, + 1 + ); + const decodedMessages = decodeClientMessagesV3(payload); + + expect(decodedMessages).toHaveLength(1); + expect( + ClientMessage.deserialize(new BinaryReader(decodedMessages[0])).tag + ).toBe('CallReducer'); + }); +});