diff --git a/README.md b/README.md index c21d92d..c86515b 100644 --- a/README.md +++ b/README.md @@ -1,8 +1,8 @@ # shinyloadtest -A load-generation tool for [Shiny](https://shiny.posit.co/) applications. -shinyloadtest replays recorded sessions against a deployed Shiny app, simulating -concurrent users to measure application performance under load. +A load-testing tool for [Shiny](https://shiny.posit.co/) applications. +shinyloadtest records a user session and replays it with concurrent workers +to measure application performance under load. ## Installation @@ -20,13 +20,49 @@ Or run directly with npx: npx @posit-dev/shinyloadtest --help ``` -## Usage +## Quick Start ```bash -shinyloadtest replay recording.log https://example.com/app [options] +# 1. Record a session against a running Shiny app +shinyloadtest record https://example.com/app + +# 2. Interact with the app in the browser at the printed proxy URL, +# then close the browser tab to stop recording. + +# 3. Replay the recording with multiple concurrent users +shinyloadtest replay recording.log https://example.com/app --workers 5 +``` + +## Recording + +```bash +shinyloadtest record [options] ``` -### Options +Starts a local reverse proxy that sits between your browser and the Shiny +application. All HTTP and WebSocket traffic is captured to a recording file. +Navigate to the proxy URL printed on startup, interact with the app as a +typical user would, then close the browser tab (or press Ctrl+C) to stop. + +### Record Options + +| Option | Description | +|--------|-------------| +| `--port ` | Local proxy port (default: `8600`) | +| `--host ` | Local proxy host (default: `127.0.0.1`) | +| `--output ` | Output recording file (default: `recording.log`) | +| `--open` | Open browser automatically | + +## Replay + +```bash +shinyloadtest replay [app-url] [options] +``` + +Replays a recorded session with one or more concurrent workers. If `app-url` +is omitted, the target URL from the recording file is used. + +### Replay Options | Option | Description | |--------|-------------| @@ -49,24 +85,19 @@ shinyloadtest supports authentication via environment variables: | `SHINYLOADTEST_PASS` | Password for Shiny Server Pro or Posit Connect | | `SHINYLOADTEST_CONNECT_API_KEY` | API key for Posit Connect | +These variables are used during both recording and replay. If the app requires +login and environment variables are not set, `record` will prompt interactively +(TTY required). + > **Note:** If the recording was made with a Connect API key, playback must > also use a Connect API key. Likewise, if the recording was made without an > API key, playback must not use one. -## Example - -```bash -shinyloadtest replay recording.log https://rsc.example.com/app \ - --workers 5 \ - --loaded-duration-minutes 10 \ - --output-dir load-test-results -``` - ## Companion Package shinyloadtest is designed to work with the [shinyloadtest](https://rstudio.github.io/shinyloadtest) R package. -Use the R package to record sessions and analyze load test results. +Use the R package to analyze load test results. ## Migration from shinycannon diff --git a/examples/loadtest-demo-py/app.py b/examples/loadtest-demo-py/app.py index 9099e4e..b1873c7 100644 --- a/examples/loadtest-demo-py/app.py +++ b/examples/loadtest-demo-py/app.py @@ -92,6 +92,8 @@ async def _update_products(): @reactive.event(input.run) async def result(): product = input.product() + if not product: + return None n = simulation_sizes[input.sim_size()] # Fake slow database read for historical sales data @@ -114,6 +116,8 @@ async def result(): @render_plotly async def sim_plot(): res = await result() + if res is None: + return None fig = go.Figure() fig.add_trace( @@ -147,6 +151,8 @@ async def sim_plot(): @render.text async def result_text(): res = await result() + if res is None: + return "" q05, q50, q95 = np.quantile(res["demand"], [0.05, 0.5, 0.95]) return ( f"Product: {res['product']} ({res['category']})\n" diff --git a/examples/loadtest-demo-r/app.R b/examples/loadtest-demo-r/app.R index 4080ceb..b803ae3 100644 --- a/examples/loadtest-demo-r/app.R +++ b/examples/loadtest-demo-r/app.R @@ -54,7 +54,7 @@ ui <- page_sidebar( ), input_task_button("run", "Run Forecast") ), - class = "bslib-page-dasboard", + class = "bslib-page-dashboard", shiny::useBusyIndicators(), card( min_height = 300, @@ -84,6 +84,7 @@ server <- function(input, output, session) { # Step 3 -> Step 4: run forecast on button click result <- eventReactive(input$run, { + req(input$product) product <- input$product n <- simulation_sizes[[input$sim_size]] diff --git a/src/cli.ts b/src/cli.ts index 4fb0b51..dee22f7 100644 --- a/src/cli.ts +++ b/src/cli.ts @@ -2,11 +2,12 @@ import * as fs from "node:fs" import { Command } from "commander" import { bold, cyan, dim, green, magenta, yellow } from "yoctocolors" import { VERSION } from "./version.js" -import { defaultOutputDir } from "./output.js" +import { defaultOutputDir } from "./replay/output.js" import { parseLogLevel, LogLevel } from "./logger.js" import { getCreds } from "./auth.js" import { type Creds } from "./types.js" import { readRecording } from "./recording.js" +import { type RecordOptions } from "./record/record.js" // --------------------------------------------------------------------------- // ParsedArgs @@ -85,11 +86,19 @@ export function serializeArgs(args: ParsedArgs): { return { argsString, argsJson } } +// --------------------------------------------------------------------------- +// CLI result discriminated union +// --------------------------------------------------------------------------- + +export type CliResult = + | { command: "replay"; args: ParsedArgs } + | { command: "record"; options: RecordOptions } + // --------------------------------------------------------------------------- // Argument parsing // --------------------------------------------------------------------------- -export function parseArgs(argv?: string[]): ParsedArgs { +export function parseArgs(argv?: string[]): CliResult { const program = new Command() const colorArgument = (str: string): string => { @@ -107,7 +116,65 @@ export function parseArgs(argv?: string[]): ParsedArgs { .description("Load testing tool for Shiny applications.") .version(VERSION) - let result: ParsedArgs | undefined + let result: CliResult | undefined + + const recordCmd = program + .command("record") + .configureHelp({ + styleTitle: (str) => bold(str), + styleArgumentTerm: (str) => colorArgument(str), + styleArgumentText: (str) => colorArgument(str), + styleOptionTerm: (str) => cyan(str), + }) + .description( + "Record a Shiny application session for later replay.\n\n" + + "Starts a local reverse proxy. Navigate your browser through the proxy\n" + + "to interact with the Shiny application; all WebSocket and HTTP traffic\n" + + "is captured to a recording file.\n\n" + + dim("Example:") + + "\n" + + ` ${cyan("$")} shinyloadtest record https://rsc.example.com/app`, + ) + .argument("", "URL of the Shiny application to record") + .option("--port ", "Local proxy port", "8600") + .option("--host ", "Local proxy host", "127.0.0.1") + .option("--output ", "Output recording file", "recording.log") + .option("--open", "Open browser automatically", false) + .addHelpText( + "after", + `\n${bold("Environment variables:")}\n` + + ` ${yellow("SHINYLOADTEST_USER")} Username for SSP or Connect auth\n` + + ` ${yellow("SHINYLOADTEST_PASS")} Password for SSP or Connect auth\n` + + ` ${yellow("SHINYLOADTEST_CONNECT_API_KEY")} Posit Connect API key\n` + + `\n${dim(" Legacy SHINYCANNON_* environment variables are also supported.")}`, + ) + .action( + ( + targetUrl: string, + opts: { + port: string + host: string + output: string + open: boolean + }, + ) => { + const port = Number(opts.port) + if (!Number.isInteger(port) || port < 1 || port > 65535) { + throw new Error(`Invalid port value: ${opts.port}`) + } + + result = { + command: "record", + options: { + targetUrl, + port, + host: opts.host, + output: opts.output, + open: opts.open, + }, + } + }, + ) const replayCmd = program .command("replay") @@ -231,17 +298,20 @@ export function parseArgs(argv?: string[]): ParsedArgs { } result = { - recordingPath, - appUrl, - workers, - loadedDurationMinutes, - startInterval, - headers, - outputDir: opts.outputDir, - overwriteOutput: opts.overwriteOutput, - debugLog: opts.debugLog, - logLevel: parseLogLevel(opts.logLevel), - creds: getCreds(), + command: "replay", + args: { + recordingPath, + appUrl, + workers, + loadedDurationMinutes, + startInterval, + headers, + outputDir: opts.outputDir, + overwriteOutput: opts.overwriteOutput, + debugLog: opts.debugLog, + logLevel: parseLogLevel(opts.logLevel), + creds: getCreds(), + }, } }, ) @@ -259,6 +329,10 @@ export function parseArgs(argv?: string[]): ParsedArgs { replayCmd.help() } + if (userArgs.length === 1 && userArgs[0] === "record") { + recordCmd.help() + } + program.parse(raw) if (!result) { diff --git a/src/main.ts b/src/main.ts index f0dde8e..aac8bbb 100644 --- a/src/main.ts +++ b/src/main.ts @@ -2,17 +2,25 @@ import * as path from "node:path" import { CookieJar } from "tough-cookie" import { VERSION } from "./version.js" import { parseArgs, serializeArgs } from "./cli.js" +import { record } from "./record/record.js" import { readRecording, recordingDuration } from "./recording.js" import { createLogger } from "./logger.js" -import { createOutputDir } from "./output.js" -import { runEnduranceTest } from "./worker.js" +import { createOutputDir } from "./replay/output.js" +import { runEnduranceTest } from "./replay/worker.js" import { SERVER_TYPE_NAMES, ServerType } from "./types.js" import { HttpClient } from "./http.js" import { detectServerType } from "./detect.js" -import { TerminalUI } from "./ui.js" +import { ReplayTerminalUI } from "./replay/ui.js" async function main(): Promise { - const args = parseArgs() + const result = parseArgs() + + if (result.command === "record") { + await record(result.options) + return + } + + const args = result.args const recording = readRecording(args.recordingPath) const duration = recordingDuration(recording) @@ -91,7 +99,7 @@ async function main(): Promise { const { argsString, argsJson } = serializeArgs(args) const ui = process.stderr.isTTY - ? new TerminalUI({ + ? new ReplayTerminalUI({ version: VERSION, appUrl: args.appUrl, workers: args.workers, diff --git a/src/record/events.ts b/src/record/events.ts new file mode 100644 index 0000000..f41f6ef --- /dev/null +++ b/src/record/events.ts @@ -0,0 +1,139 @@ +/** + * Event construction helpers for recording. + * Creates event objects that serialize to the recording log format. + */ + +// --------------------------------------------------------------------------- +// Types +// --------------------------------------------------------------------------- + +export interface RecordingHttpEvent { + readonly type: "REQ_HOME" | "REQ_TOK" | "REQ_SINF" | "REQ_GET" | "REQ_POST" + readonly begin: string + readonly end: string + readonly status: number + readonly url: string + readonly datafile?: string +} + +export interface RecordingWsEvent { + readonly type: + | "WS_OPEN" + | "WS_RECV" + | "WS_RECV_INIT" + | "WS_RECV_BEGIN_UPLOAD" + | "WS_SEND" + | "WS_CLOSE" + readonly begin: string + readonly url?: string + readonly message?: string +} + +export type RecordingEvent = RecordingHttpEvent | RecordingWsEvent + +// --------------------------------------------------------------------------- +// Timestamp helper +// --------------------------------------------------------------------------- + +export function toISOTimestamp(date: Date): string { + return date.toISOString() +} + +// --------------------------------------------------------------------------- +// Factory functions +// --------------------------------------------------------------------------- + +export function makeHttpEvent( + type: RecordingHttpEvent["type"], + begin: Date, + end: Date, + status: number, + url: string, + datafile?: string, +): RecordingHttpEvent { + const event: RecordingHttpEvent = { + type, + begin: toISOTimestamp(begin), + end: toISOTimestamp(end), + status, + url, + } + if (datafile !== undefined) { + return { ...event, datafile } + } + return event +} + +export function makeWsEvent( + type: RecordingWsEvent["type"], + begin: Date, + extras?: { url?: string; message?: string }, +): RecordingWsEvent { + const event: RecordingWsEvent = { + type, + begin: toISOTimestamp(begin), + } + if (extras?.url !== undefined && extras?.message !== undefined) { + return { ...event, url: extras.url, message: extras.message } + } + if (extras?.url !== undefined) { + return { ...event, url: extras.url } + } + if (extras?.message !== undefined) { + return { ...event, message: extras.message } + } + return event +} + +// --------------------------------------------------------------------------- +// GET request classification +// --------------------------------------------------------------------------- + +/** + * Classify a GET request by its path and return the appropriate event type. + * Also extracts token values from the path for SINF requests. + * + * Classification rules (from the R reference implementation): + * - Path ends with `/` or `.rmd` (case-insensitive) → REQ_HOME + * - Path contains `__token__` → REQ_TOK + * - Path matches `/__sockjs__/...n=` → REQ_SINF (extracts ROBUST_ID) + * - Everything else → REQ_GET + */ +export function classifyGetRequest(pathWithQuery: string): { + type: RecordingHttpEvent["type"] + robustId?: string +} { + // Extract just the path (before query string) + const path = pathWithQuery.split("?")[0] ?? pathWithQuery + + // REQ_HOME: ends with / or .rmd (case-insensitive) + if (/(\/|\.rmd)$/i.test(path)) { + return { type: "REQ_HOME" } + } + + // REQ_TOK: contains __token__ + if (path.includes("__token__")) { + return { type: "REQ_TOK" } + } + + // REQ_SINF: /__sockjs__/ with n= in path segment or query param + if (path.includes("/__sockjs__/")) { + // Check path segments for n= + const pathNMatch = /\/n=([^/?&]+)/.exec(path) + if (pathNMatch?.[1]) { + return { type: "REQ_SINF", robustId: pathNMatch[1] } + } + // Check query params for n= + try { + const parsed = new URL(pathWithQuery, "http://localhost") + const n = parsed.searchParams.get("n") + if (n) { + return { type: "REQ_SINF", robustId: n } + } + } catch { + // Malformed URL, fall through + } + } + + return { type: "REQ_GET" } +} diff --git a/src/record/proxy.ts b/src/record/proxy.ts new file mode 100644 index 0000000..04ab1fe --- /dev/null +++ b/src/record/proxy.ts @@ -0,0 +1,667 @@ +/** + * HTTP reverse proxy for recording Shiny sessions. + * Intercepts browser requests, forwards them to the target app, + * records events, and returns responses to the browser. + */ + +import * as http from "node:http" +import * as https from "node:https" +import type { Socket } from "node:net" +import WebSocket, { WebSocketServer } from "ws" +import { CookieJar } from "tough-cookie" +import { RecordingWriter } from "./writer.js" +import { RecordingTokens } from "./tokens.js" +import { classifyGetRequest, makeHttpEvent, makeWsEvent } from "./events.js" +import { extractWorkerId, extractToken } from "../http.js" +import { canIgnore, parseMessage } from "../sockjs.js" +import { httpToWs } from "../url.js" + +// --------------------------------------------------------------------------- +// Constants +// --------------------------------------------------------------------------- + +const HOP_BY_HOP_HEADERS = new Set([ + "connection", + "keep-alive", + "proxy-authenticate", + "proxy-authorization", + "te", + "trailer", + "transfer-encoding", + "upgrade", +]) + +const SINF_DELAY_MS = 750 + +// --------------------------------------------------------------------------- +// Helpers +// --------------------------------------------------------------------------- + +function stripHopByHop( + headers: http.IncomingHttpHeaders, +): Record { + const toRemove = new Set(HOP_BY_HOP_HEADERS) + const connHeader = headers["connection"] + if (connHeader) { + const raw = Array.isArray(connHeader) ? connHeader.join(",") : connHeader + for (const token of raw.split(",")) { + toRemove.add(token.trim().toLowerCase()) + } + } + + const result: Record = {} + for (const [key, value] of Object.entries(headers)) { + if (value !== undefined && !toRemove.has(key.toLowerCase())) { + result[key] = value + } + } + return result +} + +function shouldIgnoreGet(path: string): boolean { + return /.*favicon\.ico$/.test(path) +} + +function delay(ms: number): Promise { + return new Promise((resolve) => setTimeout(resolve, ms)) +} + +// --------------------------------------------------------------------------- +// Types +// --------------------------------------------------------------------------- + +export interface RecordingProxyOptions { + targetUrl: string + host: string + port: number + writer: RecordingWriter + tokens: RecordingTokens + cookieJar: CookieJar + authHeaders: Record + onFirstConnection?: () => void + onShutdown: () => void +} + +interface ForwardedResponse { + statusCode: number + headers: http.IncomingHttpHeaders + body: Buffer +} + +// --------------------------------------------------------------------------- +// RecordingProxy +// --------------------------------------------------------------------------- + +export class RecordingProxy { + private server: http.Server | null = null + private wss: WebSocketServer | null = null + private readonly target: URL + private readonly host: string + private readonly port: number + private readonly writer: RecordingWriter + private readonly tokens: RecordingTokens + private readonly cookieJar: CookieJar + private readonly authHeaders: Record + private onFirstConnection: (() => void) | null + readonly onShutdown: () => void + private connected = false + private activeWsCount = 0 + private shutdownTimer: ReturnType | null = null + private static readonly SHUTDOWN_GRACE_MS = 500 + + constructor(options: RecordingProxyOptions) { + this.target = new URL(options.targetUrl) + this.host = options.host + this.port = options.port + this.writer = options.writer + this.tokens = options.tokens + this.cookieJar = options.cookieJar + this.authHeaders = options.authHeaders + this.onFirstConnection = options.onFirstConnection ?? null + this.onShutdown = options.onShutdown + } + + private notifyFirstConnection(): void { + if (!this.connected) { + this.connected = true + this.onFirstConnection?.() + this.onFirstConnection = null + } + } + + async start(): Promise { + this.server = http.createServer((req, res) => { + this.notifyFirstConnection() + this.handleRequest(req, res).catch((err) => { + console.error("Proxy error:", err) + if (!res.headersSent) { + res.writeHead(502) + res.end("Bad Gateway") + } + }) + }) + + this.wss = new WebSocketServer({ noServer: true }) + + this.server.on("upgrade", (req, socket, head) => { + this.handleUpgrade(req, socket as Socket, head) + }) + + return new Promise((resolve, reject) => { + const onError = (err: Error): void => reject(err) + this.server!.once("error", onError) + this.server!.listen(this.port, this.host, () => { + this.server!.removeListener("error", onError) + resolve() + }) + }) + } + + async stop(): Promise { + if (this.shutdownTimer) { + clearTimeout(this.shutdownTimer) + this.shutdownTimer = null + } + return new Promise((resolve) => { + const closeServer = (): void => { + if (this.server) { + this.server.close(() => resolve()) + } else { + resolve() + } + } + if (this.wss) { + this.wss.close(() => closeServer()) + } else { + closeServer() + } + }) + } + + /** Expose the underlying server for WebSocket upgrade handling. */ + get httpServer(): http.Server | null { + return this.server + } + + // ------------------------------------------------------------------------- + // WebSocket handling + // ------------------------------------------------------------------------- + + private handleUpgrade( + req: http.IncomingMessage, + socket: Socket, + head: Buffer, + ): void { + this.wss!.handleUpgrade(req, socket, head, (clientWs) => { + this.handleWebSocket(req, clientWs) + }) + } + + private handleWebSocket( + req: http.IncomingMessage, + clientWs: WebSocket, + ): void { + this.activeWsCount++ + if (this.shutdownTimer) { + clearTimeout(this.shutdownTimer) + this.shutdownTimer = null + } + + const rawUrl = req.url ?? "/" + let parsed: URL + try { + parsed = new URL(rawUrl, "http://localhost") + } catch { + this.activeWsCount-- + clientWs.close(1002, "Malformed URL") + return + } + const pathInfo = parsed.pathname + + // Discover SOCKJSID from the path + // Pattern: /...///websocket + const sockjsMatch = /\/([^/]+\/[^/]+)\/websocket$/.exec(pathInfo) + if (sockjsMatch?.[1]) { + this.tokens.discover("SOCKJSID", sockjsMatch[1]) + } + + // Record WS_OPEN with token-replaced URL + this.writer.writeEvent( + makeWsEvent("WS_OPEN", new Date(), { + url: this.tokens.replaceInString(pathInfo), + }), + ) + + // Build target WebSocket URL + const targetHttpUrl = new URL(this.target.toString()) + targetHttpUrl.pathname = this.target.pathname.replace(/\/$/, "") + pathInfo + targetHttpUrl.search = parsed.search + const wsUrl = httpToWs(targetHttpUrl.toString()) + + // Build outgoing WS headers (cookies + auth) + const wsHeaders: Record = {} + const cookieString = this.cookieJar.getCookieStringSync( + this.target.toString(), + ) + if (cookieString) { + wsHeaders["Cookie"] = cookieString + } + for (const [key, value] of Object.entries(this.authHeaders)) { + wsHeaders[key] = value + } + + // Connect to target + const serverWs = new WebSocket(wsUrl, { headers: wsHeaders }) + + // Server send buffer (for messages from client before server is open) + const serverSendBuffer: string[] = [] + + const serverSend = (msg: string): void => { + if ( + serverWs.readyState === WebSocket.OPEN && + serverSendBuffer.length === 0 + ) { + serverWs.send(msg) + } else { + serverSendBuffer.push(msg) + } + } + + const sendToClient = (msg: string): void => { + if (clientWs.readyState === WebSocket.OPEN) { + clientWs.send(msg) + } + } + + // Server -> Client relay + serverWs.on("message", (data) => { + const msg = data.toString() + + // SockJS open frame: record and relay + if (msg === "o") { + this.writer.writeEvent( + makeWsEvent("WS_RECV", new Date(), { message: msg }), + ) + sendToClient(msg) + return + } + + // Ignorable messages: relay but don't record + if (canIgnore(msg)) { + sendToClient(msg) + return + } + + let parsed: Record | null = null + try { + parsed = parseMessage(msg) + } catch { + // parseMessage failed on malformed/non-SockJS frame; fall through + // to generic WS_RECV recording below + } + + // WS_RECV_INIT: config message + if ( + parsed && + "config" in parsed && + typeof parsed["config"] === "object" && + parsed["config"] !== null + ) { + const config = parsed["config"] as Record + if (typeof config["sessionId"] === "string") { + this.tokens.discover("SESSION", config["sessionId"]) + } + this.writer.writeEvent( + makeWsEvent("WS_RECV_INIT", new Date(), { + message: this.tokens.replaceInString(msg), + }), + ) + sendToClient(msg) + return + } + + // WS_RECV_BEGIN_UPLOAD: upload response with jobId + if (parsed) { + const response = parsed["response"] + if (response && typeof response === "object") { + const value = (response as Record)["value"] + if (value && typeof value === "object") { + const v = value as Record + if (typeof v["jobId"] === "string") { + if (typeof v["uploadUrl"] === "string") { + this.tokens.discover("UPLOAD_URL", v["uploadUrl"]) + } + this.tokens.discover("UPLOAD_JOB_ID", v["jobId"]) + this.writer.writeEvent( + makeWsEvent("WS_RECV_BEGIN_UPLOAD", new Date(), { + message: this.tokens.replaceInString(msg), + }), + ) + sendToClient(msg) + return + } + } + } + } + + // Regular WS_RECV + this.writer.writeEvent( + makeWsEvent("WS_RECV", new Date(), { message: msg }), + ) + sendToClient(msg) + }) + + // Server open: flush buffer + serverWs.on("open", () => { + for (const msg of serverSendBuffer) { + serverWs.send(msg) + } + serverSendBuffer.length = 0 + }) + + // Server close: close client if still open + serverWs.on("close", () => { + if (clientWs.readyState === WebSocket.OPEN) { + clientWs.close() + } + }) + + serverWs.on("error", (err) => { + console.error("Server WebSocket error:", err) + if (clientWs.readyState === WebSocket.OPEN) { + clientWs.close() + } + }) + + // Client -> Server relay + clientWs.on("message", (data) => { + const msg = data.toString() + + if (canIgnore(msg)) { + serverSend(msg) + return + } + + this.writer.writeEvent( + makeWsEvent("WS_SEND", new Date(), { + message: this.tokens.replaceInString(msg), + }), + ) + serverSend(msg) + }) + + // Client close: close server, record WS_CLOSE, schedule shutdown + clientWs.on("close", () => { + if (serverWs.readyState <= WebSocket.OPEN) { + serverWs.close() + } + this.writer.writeEvent(makeWsEvent("WS_CLOSE", new Date())) + this.activeWsCount-- + if (this.activeWsCount <= 0) { + this.shutdownTimer = setTimeout(() => { + if (this.activeWsCount <= 0) { + this.onShutdown() + } + }, RecordingProxy.SHUTDOWN_GRACE_MS) + } + }) + } + + // ------------------------------------------------------------------------- + // Request handling + // ------------------------------------------------------------------------- + + private async handleRequest( + req: http.IncomingMessage, + res: http.ServerResponse, + ): Promise { + const method = req.method ?? "GET" + const pathWithQuery = req.url ?? "/" + + // Build the full target URL preserving the base path + const incoming = new URL(pathWithQuery, "http://localhost") + const targetUrl = new URL(this.target.toString()) + targetUrl.pathname = + this.target.pathname.replace(/\/$/, "") + incoming.pathname + targetUrl.search = incoming.search + + // Buffer request body (for POST) + const requestBody = await this.bufferBody(req) + + // Build outgoing headers + const outHeaders = this.buildOutgoingHeaders(req.headers, targetUrl) + + // Add request body content-length for POST + if (method === "POST" && requestBody.length > 0) { + outHeaders["content-length"] = String(requestBody.length) + } + + // Capture timing + const begin = new Date() + + // Forward request to target + const forwarded = await this.forwardRequest( + method, + targetUrl, + outHeaders, + requestBody, + ) + + const end = new Date() + + // Store cookies from response + await this.storeCookies(forwarded.headers, targetUrl.toString()) + + // Record event (may introduce delay for SINF) + if (method === "POST") { + await this.recordPost( + pathWithQuery, + forwarded.statusCode, + requestBody, + begin, + end, + ) + } else if (method === "GET") { + await this.recordGet( + pathWithQuery, + forwarded.statusCode, + forwarded.body, + begin, + end, + ) + } + // Other methods (HEAD, PUT, etc.) are proxied but not recorded + + // Send response back to browser + const cleanHeaders = stripHopByHop(forwarded.headers) + // Strip content-encoding: we remove accept-encoding from outgoing requests + // so the upstream should return uncompressed content; this is a safety net + delete cleanHeaders["content-encoding"] + // Update content-length to match the actual body we're sending + cleanHeaders["content-length"] = String(forwarded.body.length) + + res.writeHead(forwarded.statusCode, cleanHeaders) + res.end(forwarded.body) + } + + // ------------------------------------------------------------------------- + // Header building + // ------------------------------------------------------------------------- + + private buildOutgoingHeaders( + incomingHeaders: http.IncomingHttpHeaders, + targetUrl: URL, + ): Record { + const cleaned = stripHopByHop(incomingHeaders) + const result: Record = {} + + // Flatten to single strings (take first value for arrays) + for (const [key, value] of Object.entries(cleaned)) { + if (Array.isArray(value)) { + result[key] = value[0] ?? "" + } else { + result[key] = value + } + } + + // Rewrite host to target + const portSuffix = targetUrl.port ? `:${targetUrl.port}` : "" + result["host"] = targetUrl.hostname + portSuffix + + // Don't send accept-encoding so we get uncompressed responses + delete result["accept-encoding"] + + // Add cookies from jar + const cookieString = this.cookieJar.getCookieStringSync( + targetUrl.toString(), + ) + if (cookieString) { + result["cookie"] = cookieString + } + + // Add auth headers + for (const [key, value] of Object.entries(this.authHeaders)) { + result[key] = value + } + + return result + } + + // ------------------------------------------------------------------------- + // Request forwarding + // ------------------------------------------------------------------------- + + private forwardRequest( + method: string, + targetUrl: URL, + headers: Record, + body: Buffer, + ): Promise { + return new Promise((resolve, reject) => { + const transport = targetUrl.protocol === "https:" ? https : http + + const options: http.RequestOptions = { + method, + hostname: targetUrl.hostname, + port: targetUrl.port || (targetUrl.protocol === "https:" ? 443 : 80), + path: targetUrl.pathname + targetUrl.search, + headers, + } + + const outReq = transport.request(options, (outRes) => { + const chunks: Buffer[] = [] + outRes.on("data", (chunk: Buffer) => chunks.push(chunk)) + outRes.on("end", () => { + resolve({ + statusCode: outRes.statusCode ?? 502, + headers: outRes.headers, + body: Buffer.concat(chunks), + }) + }) + outRes.on("error", reject) + }) + + outReq.on("error", reject) + + if (body.length > 0) { + outReq.write(body) + } + outReq.end() + }) + } + + // ------------------------------------------------------------------------- + // Body buffering + // ------------------------------------------------------------------------- + + private bufferBody(stream: http.IncomingMessage): Promise { + return new Promise((resolve, reject) => { + const chunks: Buffer[] = [] + stream.on("data", (chunk: Buffer) => chunks.push(chunk)) + stream.on("end", () => resolve(Buffer.concat(chunks))) + stream.on("error", reject) + }) + } + + // ------------------------------------------------------------------------- + // Cookie management + // ------------------------------------------------------------------------- + + private async storeCookies( + headers: http.IncomingHttpHeaders, + url: string, + ): Promise { + const setCookieHeaders = headers["set-cookie"] + if (!setCookieHeaders) return + + const cookies = Array.isArray(setCookieHeaders) + ? setCookieHeaders + : [setCookieHeaders] + + for (const cookie of cookies) { + await this.cookieJar.setCookie(cookie, url) + } + } + + // ------------------------------------------------------------------------- + // Event recording + // ------------------------------------------------------------------------- + + private async recordGet( + pathWithQuery: string, + statusCode: number, + responseBody: Buffer, + begin: Date, + end: Date, + ): Promise { + const pathOnly = pathWithQuery.split("?")[0] ?? pathWithQuery + + // Don't record favicon requests + if (shouldIgnoreGet(pathOnly)) return + + const { type, robustId } = classifyGetRequest(pathWithQuery) + const bodyText = responseBody.toString("utf-8") + + // Token discovery based on event type + if (type === "REQ_HOME") { + const workerId = extractWorkerId(bodyText) + if (workerId) { + this.tokens.discover("WORKER", workerId) + } + } else if (type === "REQ_TOK") { + const token = extractToken(bodyText) + if (token) { + this.tokens.discover("TOKEN", token) + } + } else if (type === "REQ_SINF" && robustId) { + this.tokens.discover("ROBUST_ID", robustId) + // Crude workaround: delay before responding + await delay(SINF_DELAY_MS) + } + + const recordedUrl = this.tokens.replaceInString(pathWithQuery) + + this.writer.writeEvent( + makeHttpEvent(type, begin, end, statusCode, recordedUrl), + ) + } + + private async recordPost( + pathWithQuery: string, + statusCode: number, + requestBody: Buffer, + begin: Date, + end: Date, + ): Promise { + const recordedUrl = this.tokens.replaceInString(pathWithQuery) + + let datafile: string | undefined + if (requestBody.length > 0) { + datafile = this.writer.writePostData(requestBody) + } + + this.writer.writeEvent( + makeHttpEvent("REQ_POST", begin, end, statusCode, recordedUrl, datafile), + ) + } +} diff --git a/src/record/record.ts b/src/record/record.ts new file mode 100644 index 0000000..67626d6 --- /dev/null +++ b/src/record/record.ts @@ -0,0 +1,303 @@ +import { execFile } from "node:child_process" +import * as readline from "node:readline" +import { CookieJar } from "tough-cookie" +import { VERSION } from "../version.js" +import { ServerType, SERVER_TYPE_NAMES } from "../types.js" +import { HttpClient } from "../http.js" +import { detectServerType } from "../detect.js" +import { + isProtected, + loginRSC, + loginSSP, + loginUrlFor, + extractHiddenInputs, + getCreds, + connectApiKeyHeader, +} from "../auth.js" +import { RecordingWriter } from "./writer.js" +import { RecordingTokens } from "./tokens.js" +import { RecordingProxy } from "./proxy.js" +import { RecordTerminalUI } from "./ui.js" + +export interface RecordOptions { + readonly targetUrl: string + readonly port: number + readonly host: string + readonly output: string + readonly open: boolean +} + +export async function record(options: RecordOptions): Promise { + const { targetUrl, port, host, output, open } = options + + const ui = process.stderr.isTTY ? new RecordTerminalUI() : undefined + + // Validate target URL + let parsedUrl: URL + try { + parsedUrl = new URL(targetUrl) + } catch { + throw new Error(`Invalid target URL: ${targetUrl}`) + } + + if (parsedUrl.protocol !== "http:" && parsedUrl.protocol !== "https:") { + throw new Error(`Target URL must use http or https: ${targetUrl}`) + } + + // Set up HTTP client for detection and auth + const cookieJar = new CookieJar() + const httpClient = new HttpClient({ + cookieJar, + headers: {}, + userAgent: `shinyloadtest/${VERSION}`, + }) + + try { + // Detect server type + ui?.startDetecting() + const serverType = await detectServerType(targetUrl, httpClient) + const serverTypeName = SERVER_TYPE_NAMES.get(serverType) ?? serverType + ui?.detectedServerType(serverTypeName) + if (!ui) console.error(`Target type: ${serverTypeName}`) + + // Reject shinyapps.io + if (serverType === ServerType.SAI) { + ui?.cleanup() + throw new Error("Recording shinyapps.io applications is not supported.") + } + + // RSC fragment check + if (serverType === ServerType.RSC && targetUrl.includes("#")) { + ui?.cleanup() + throw new Error( + "The app URL contains a '#' fragment. For Posit Connect, use the " + + "content URL (solo mode) instead of the dashboard URL.", + ) + } + + // Authentication + const creds = getCreds() + let authHeaders: Record = {} + let rscApiKeyRequired = false + + if (creds.connectApiKey !== null) { + ui?.startAuthenticating("Connect API key") + authHeaders = connectApiKeyHeader(creds.connectApiKey) + rscApiKeyRequired = true + ui?.authenticated("Posit Connect") + if (!ui) console.error("Logged in to Posit Connect") + } else if (await isProtected(httpClient, targetUrl)) { + let username = creds.user + let password = creds.pass + + if (username === null || password === null) { + if (!process.stdin.isTTY) { + throw new Error( + "The application requires authentication but stdin is not a TTY. " + + "Set SHINYLOADTEST_USER and SHINYLOADTEST_PASS environment variables.", + ) + } + ui?.cleanup() + console.error("The application requires authentication.") + const prompted = await promptCredentials() + username = prompted.username + password = prompted.password + } + + const loginUrl = loginUrlFor(targetUrl, serverType) + + if (serverType === ServerType.RSC) { + ui?.startAuthenticating("username/password") + await loginRSC(httpClient, loginUrl, username, password) + ui?.authenticated("Posit Connect") + if (!ui) console.error("Logged in to Posit Connect") + } else if (serverType === ServerType.SSP) { + ui?.startAuthenticating("username/password") + const loginPage = await httpClient.get(loginUrl) + const hiddenInputs = extractHiddenInputs(loginPage.body) + await loginSSP(httpClient, loginUrl, username, password, hiddenInputs) + ui?.authenticated("Shiny Server Pro") + if (!ui) console.error("Logged in to Shiny Server Pro") + } + } + + // Create recording writer + const writer = new RecordingWriter({ + outputPath: output, + targetUrl, + targetType: serverType, + rscApiKeyRequired, + }) + + try { + // Create recording tokens + const tokens = new RecordingTokens() + + // Create and start proxy + const startTime = Date.now() + let shutdownResolve: (() => void) | null = null + const shutdownPromise = new Promise((resolve) => { + shutdownResolve = resolve + }) + + let recording = false + const proxy = new RecordingProxy({ + targetUrl, + host, + port, + writer, + tokens, + cookieJar, + authHeaders, + onFirstConnection: () => { + recording = true + ui?.startRecording(() => writer.eventCount) + }, + onShutdown: () => { + ui?.stopRecording("disconnected") + if (!ui) console.error("Client disconnected. Stopping recording.") + shutdownResolve?.() + }, + }) + + await proxy.start() + + const proxyUrl = `http://${host}:${port}` + + ui?.showBanner({ version: VERSION, targetUrl, proxyUrl, output }) + ui?.startWaiting(proxyUrl) + if (!ui) { + console.error(`Proxy URL: ${proxyUrl}`) + console.error(`Output: ${output}`) + console.error( + `Navigate your browser to the proxy URL to begin recording: ${proxyUrl}`, + ) + } + + // Open browser if requested + if (open) { + openBrowser(proxyUrl) + } + + // Handle Ctrl+C + const handleSignal = (): void => { + ui?.stopRecording(recording ? "interrupted" : "cancelled") + if (!ui) console.error("Interrupted. Stopping recording.") + shutdownResolve?.() + } + process.on("SIGINT", handleSignal) + process.on("SIGTERM", handleSignal) + + // Wait for shutdown + await shutdownPromise + + // Clean up + process.removeListener("SIGINT", handleSignal) + process.removeListener("SIGTERM", handleSignal) + + await proxy.stop() + + ui?.finish({ + output, + eventCount: writer.eventCount, + postFileCount: writer.postFileCount_, + duration: Date.now() - startTime, + }) + if (!ui) { + console.error(`Recording saved to: ${output}`) + if (writer.postFileCount_ > 0) { + console.error( + `Note: ${writer.postFileCount_} POST file(s) saved alongside the recording.`, + ) + } + } + } finally { + writer.close() + } + } finally { + ui?.cleanup() + } +} + +// --------------------------------------------------------------------------- +// Helpers +// --------------------------------------------------------------------------- + +async function promptCredentials(): Promise<{ + username: string + password: string +}> { + const rl = readline.createInterface({ + input: process.stdin, + output: process.stderr, + }) + + const question = (prompt: string): Promise => + new Promise((resolve) => { + rl.question(prompt, (answer) => resolve(answer)) + }) + + const questionHidden = (prompt: string): Promise => + new Promise((resolve) => { + const stdin = process.stdin + const wasTTY = stdin.isTTY && typeof stdin.setRawMode === "function" + if (wasTTY) stdin.setRawMode(true) + process.stderr.write(prompt) + + let input = "" + const onData = (ch: Buffer): void => { + const c = ch.toString() + if (c === "\n" || c === "\r") { + if (wasTTY) stdin.setRawMode(false) + stdin.removeListener("data", onData) + process.stderr.write("\n") + resolve(input) + } else if (c === "\u0003") { + // Ctrl+C + if (wasTTY) stdin.setRawMode(false) + stdin.removeListener("data", onData) + resolve("") + } else if (c === "\u007f" || c === "\b") { + input = input.slice(0, -1) + } else { + input += c + } + } + stdin.resume() + stdin.on("data", onData) + }) + + try { + const username = await question("Username: ") + const password = await questionHidden("Password: ") + + if (!username || !password) { + throw new Error("Login aborted (credentials not provided).") + } + + return { username, password } + } finally { + rl.close() + } +} + +function openBrowser(url: string): void { + const platform = process.platform + let cmd: string + let args: string[] + if (platform === "darwin") { + cmd = "open" + args = [url] + } else if (platform === "win32") { + cmd = "cmd" + args = ["/c", "start", "", url] + } else { + cmd = "xdg-open" + args = [url] + } + execFile(cmd, args, (err) => { + if (err) { + console.error(`Could not open browser: ${err.message}`) + } + }) +} diff --git a/src/record/tokens.ts b/src/record/tokens.ts new file mode 100644 index 0000000..c6e7a69 --- /dev/null +++ b/src/record/tokens.ts @@ -0,0 +1,52 @@ +/** + * Token discovery and replacement for recording. + * Maps actual session values to their ${PLACEHOLDER} equivalents. + */ + +export class RecordingTokens { + // Maps actual value -> placeholder name (e.g. "abc123" -> "WORKER") + private readonly tokens = new Map() + + /** + * Register a discovered token value. + * If the value is empty or already known, this is a no-op. + */ + discover(name: string, value: string): void { + if (!value) return + if (!this.tokens.has(value)) { + this.tokens.set(value, name) + } + } + + /** + * Replace all known actual values in a string with their ${PLACEHOLDER} equivalents. + * Longer values are replaced first to avoid partial matches. + */ + replaceInString(str: string): string { + if (this.tokens.size === 0) return str + + // Sort by value length descending to replace longer matches first + const entries = [...this.tokens.entries()].sort( + (a, b) => b[0].length - a[0].length, + ) + + let result = str + for (const [actual, name] of entries) { + result = result.replaceAll(actual, `\${${name}}`) + } + return result + } + + /** Number of discovered tokens. */ + get size(): number { + return this.tokens.size + } + + /** Check if a token name has been discovered. */ + has(name: string): boolean { + for (const v of this.tokens.values()) { + if (v === name) return true + } + return false + } +} diff --git a/src/record/ui.ts b/src/record/ui.ts new file mode 100644 index 0000000..46df335 --- /dev/null +++ b/src/record/ui.ts @@ -0,0 +1,175 @@ +/** + * Terminal UI for the record subcommand. Provides spinners, colors, + * and a live event counter when stderr is a TTY. + */ + +import ora, { type Ora } from "ora" +import { bold, cyan, dim, green, yellow } from "yoctocolors" + +// --------------------------------------------------------------------------- +// Types +// --------------------------------------------------------------------------- + +export interface RecordUIConfig { + version: string + targetUrl: string + proxyUrl: string + output: string +} + +// --------------------------------------------------------------------------- +// Helpers +// --------------------------------------------------------------------------- + +function formatDuration(ms: number): string { + const totalSec = Math.max(0, Math.ceil(ms / 1000)) + const min = Math.floor(totalSec / 60) + const sec = totalSec % 60 + if (min > 0) return `${min}m ${String(sec).padStart(2, "0")}s` + return `${sec}s` +} + +function timestamp(): string { + const d = new Date() + const pad = (n: number): string => String(n).padStart(2, "0") + return `${d.getFullYear()}-${pad(d.getMonth() + 1)}-${pad(d.getDate())} ${pad(d.getHours())}:${pad(d.getMinutes())}:${pad(d.getSeconds())}` +} + +// --------------------------------------------------------------------------- +// RecordTerminalUI +// --------------------------------------------------------------------------- + +export class RecordTerminalUI { + private spinner: Ora + private updateTimer: ReturnType | null = null + private recordingStartTime = 0 + private getEventCount: (() => number) | null = null + + constructor() { + this.spinner = ora({ + stream: process.stderr, + color: "cyan", + discardStdin: false, + }) + } + + startDetecting(): void { + this.spinner.start("Detecting server type...") + } + + detectedServerType(serverTypeName: string): void { + this.spinner.succeed(`Target type: ${bold(serverTypeName)}`) + } + + startAuthenticating(method: string): void { + this.spinner = ora({ + stream: process.stderr, + color: "cyan", + discardStdin: false, + }) + this.spinner.start(`Authenticating ${dim(`(${method})`)}`) + } + + authenticated(serverName: string): void { + this.spinner.succeed(`Logged in to ${bold(serverName)}`) + } + + showBanner(config: RecordUIConfig): void { + const { version, targetUrl, proxyUrl, output } = config + const w = process.stderr.write.bind(process.stderr) + + w("\n") + w(` ${bold(cyan("shinyloadtest record"))} ${dim(`v${version}`)}\n`) + w("\n") + w(` ${dim("Target:")} ${bold(targetUrl)}\n`) + w(` ${dim("Proxy:")} ${bold(cyan(proxyUrl))}\n`) + w(` ${dim("Output:")} ${bold(output)}\n`) + w("\n") + } + + startWaiting(proxyUrl: string): void { + const w = process.stderr.write.bind(process.stderr) + w(`${cyan("\u2192")} Navigate to: ${bold(cyan(proxyUrl))}\n`) + w("\n") + this.spinner = ora({ + stream: process.stderr, + color: "cyan", + discardStdin: false, + }) + this.spinner.start("Waiting for browser") + } + + startRecording(getEventCount: () => number): void { + this.getEventCount = getEventCount + this.recordingStartTime = Date.now() + + this.spinner.succeed(`Browser connected ${dim(`[${timestamp()}]`)}`) + this.spinner = ora({ + stream: process.stderr, + color: "cyan", + discardStdin: false, + }) + this.spinner.start(this.recordingText()) + + this.updateTimer = setInterval(() => { + this.spinner.text = this.recordingText() + }, 1000) + } + + stopRecording(reason?: "disconnected" | "interrupted" | "cancelled"): void { + this.stopUpdates() + this.spinner.stop() + const w = process.stderr.write.bind(process.stderr) + const label = + reason === "interrupted" + ? "Recording interrupted" + : reason === "cancelled" + ? "Recording cancelled" + : "Browser disconnected" + w(`${green("\u2714")} ${label} ${dim(`[${timestamp()}]`)}\n`) + } + + finish(config: { + output: string + eventCount: number + postFileCount: number + duration: number + }): void { + this.stopUpdates() + + const w = process.stderr.write.bind(process.stderr) + + w("\n") + w(`${green("\u2714")} Recording saved to ${bold(config.output)}\n`) + w( + ` ${dim("Events:")} ${bold(String(config.eventCount))} captured in ${bold(formatDuration(config.duration))}\n`, + ) + if (config.postFileCount > 0) { + w( + ` ${dim("POST data:")} ${yellow(String(config.postFileCount))} file(s) created\n`, + ) + } + w("\n") + } + + cleanup(): void { + this.stopUpdates() + this.spinner.stop() + } + + private recordingText(): string { + const count = this.getEventCount?.() ?? 0 + const elapsed = Date.now() - this.recordingStartTime + return [ + `${bold("Recording")} ${bold(formatDuration(elapsed))} ${dim("elapsed")} ${dim("\u2502")} ${bold(String(count))} ${dim("events captured")}`, + `${cyan("\u2139")} ${dim("Close browser to stop recording")}`, + ].join("\n") + } + + private stopUpdates(): void { + if (this.updateTimer !== null) { + clearInterval(this.updateTimer) + this.updateTimer = null + } + } +} diff --git a/src/record/writer.ts b/src/record/writer.ts new file mode 100644 index 0000000..ada05cf --- /dev/null +++ b/src/record/writer.ts @@ -0,0 +1,70 @@ +import * as fs from "node:fs" +import * as path from "node:path" +import { type RecordingEvent } from "./events.js" +import { type ServerType, SERVER_TYPE_NAMES } from "../types.js" + +export interface RecordingWriterOptions { + readonly outputPath: string + readonly targetUrl: string + readonly targetType: ServerType + readonly rscApiKeyRequired: boolean +} + +export class RecordingWriter { + private readonly fd: number + private readonly outputPath: string + private postFileCount = 0 + private eventCount_ = 0 + + constructor(options: RecordingWriterOptions) { + this.outputPath = options.outputPath + + // Open file for writing (truncate if exists) + this.fd = fs.openSync(options.outputPath, "w") + + // Write header + const targetTypeName = + SERVER_TYPE_NAMES.get(options.targetType) ?? options.targetType + + this.writeLine(`# version: 1`) + this.writeLine(`# target_url: ${options.targetUrl}`) + this.writeLine(`# target_type: ${targetTypeName}`) + if (options.rscApiKeyRequired) { + this.writeLine(`# rscApiKeyRequired: true`) + } + } + + writeEvent(event: RecordingEvent): void { + this.writeLine(JSON.stringify(event)) + this.eventCount_++ + } + + /** + * Write a POST body to an adjacent file. + * Returns the basename of the created file (for the datafile field). + */ + writePostData(data: Buffer): string { + const postPath = `${this.outputPath}.post.${this.postFileCount}` + this.postFileCount++ + fs.writeFileSync(postPath, data) + return path.basename(postPath) + } + + close(): void { + fs.closeSync(this.fd) + } + + /** Number of POST data files written. */ + get postFileCount_(): number { + return this.postFileCount + } + + /** Number of events written. */ + get eventCount(): number { + return this.eventCount_ + } + + private writeLine(line: string): void { + fs.writeSync(this.fd, line + "\n") + } +} diff --git a/src/output.ts b/src/replay/output.ts similarity index 100% rename from src/output.ts rename to src/replay/output.ts diff --git a/src/session.ts b/src/replay/session.ts similarity index 97% rename from src/session.ts rename to src/replay/session.ts index 67ed36e..9f6062e 100644 --- a/src/session.ts +++ b/src/replay/session.ts @@ -6,7 +6,7 @@ import * as fs from "node:fs" import * as path from "node:path" import { CookieJar } from "tough-cookie" -import { VERSION } from "./version.js" +import { VERSION } from "../version.js" import { loginUrlFor, extractHiddenInputs, @@ -15,27 +15,27 @@ import { loginSSP, getConnectCookies, connectApiKeyHeader, -} from "./auth.js" -import { detectServerType } from "./detect.js" +} from "../auth.js" +import { detectServerType } from "../detect.js" import { HttpClient, validateStatus, extractWorkerId, extractToken, getCookieString, -} from "./http.js" -import type { Logger } from "./logger.js" +} from "../http.js" +import type { Logger } from "../logger.js" import { SessionWriter } from "./output.js" -import { normalizeMessage, parseMessage } from "./sockjs.js" -import { replaceTokens, createTokenDictionary } from "./tokens.js" -import type { Recording, RecordingEvent, Creds } from "./types.js" +import { normalizeMessage, parseMessage } from "../sockjs.js" +import { replaceTokens, createTokenDictionary } from "../tokens.js" +import type { Recording, RecordingEvent, Creds } from "../types.js" import { ALLOWED_TOKENS, ServerType, hasUserPass, hasConnectApiKey, -} from "./types.js" -import { joinPaths, httpToWs } from "./url.js" +} from "../types.js" +import { joinPaths, httpToWs } from "../url.js" import { ShinyWebSocket } from "./websocket.js" // --------------------------------------------------------------------------- diff --git a/src/ui.ts b/src/replay/ui.ts similarity index 99% rename from src/ui.ts rename to src/replay/ui.ts index f05c7b0..c9c72f0 100644 --- a/src/ui.ts +++ b/src/replay/ui.ts @@ -68,10 +68,10 @@ function statsLine(stats: StatsCounts): string { } // --------------------------------------------------------------------------- -// TerminalUI +// ReplayTerminalUI // --------------------------------------------------------------------------- -export class TerminalUI { +export class ReplayTerminalUI { private config: UIConfig private spinner: Ora private updateTimer: ReturnType | null = null diff --git a/src/websocket.ts b/src/replay/websocket.ts similarity index 98% rename from src/websocket.ts rename to src/replay/websocket.ts index 9e0334d..f7be31e 100644 --- a/src/websocket.ts +++ b/src/replay/websocket.ts @@ -2,8 +2,8 @@ // Wraps the `ws` library for use during Shiny session playback. import WebSocket from "ws" -import { canIgnore } from "./sockjs.js" -import { RECEIVE_QUEUE_SIZE } from "./types.js" +import { canIgnore } from "../sockjs.js" +import { RECEIVE_QUEUE_SIZE } from "../types.js" // --------------------------------------------------------------------------- // AsyncQueue diff --git a/src/worker.ts b/src/replay/worker.ts similarity index 96% rename from src/worker.ts rename to src/replay/worker.ts index 1b01dd0..65656b9 100644 --- a/src/worker.ts +++ b/src/replay/worker.ts @@ -4,11 +4,11 @@ * staggered start, loaded duration control, and progress reporting. */ -import type { Logger } from "./logger.js" +import type { Logger } from "../logger.js" import { Stats, runSession } from "./session.js" import type { SessionConfig } from "./session.js" -import type { Recording, Creds } from "./types.js" -import type { TerminalUI } from "./ui.js" +import type { Recording, Creds } from "../types.js" +import type { ReplayTerminalUI } from "./ui.js" // --------------------------------------------------------------------------- // Types @@ -27,7 +27,7 @@ export interface EnduranceTestConfig { logger: Logger argsString: string argsJson: string - ui?: TerminalUI + ui?: ReplayTerminalUI } // --------------------------------------------------------------------------- diff --git a/src/tests/auth-integration.test.ts b/src/tests/auth-integration.test.ts index ead4f63..0fcf0ba 100644 --- a/src/tests/auth-integration.test.ts +++ b/src/tests/auth-integration.test.ts @@ -3,8 +3,8 @@ import * as fs from "node:fs" import * as path from "node:path" import * as os from "node:os" import { MockShinyServer } from "./helpers/mock-shiny-server.js" -import { runSession, Stats } from "../session.js" -import type { SessionConfig } from "../session.js" +import { runSession, Stats } from "../replay/session.js" +import type { SessionConfig } from "../replay/session.js" import { readRecordingFromString } from "../recording.js" import type { Logger } from "../logger.js" diff --git a/src/tests/cli.test.ts b/src/tests/cli.test.ts index e381fa7..bd39c78 100644 --- a/src/tests/cli.test.ts +++ b/src/tests/cli.test.ts @@ -172,7 +172,7 @@ describe("parseArgs", () => { }) it("defaults startInterval to null when not provided", () => { - const args = parseArgs([ + const result = parseArgs([ "node", "script", "replay", @@ -180,28 +180,34 @@ describe("parseArgs", () => { "http://example.com", ]) - expect(args.startInterval).toBeNull() - expect(args.recordingPath).toBe(recordingFile) - expect(args.appUrl).toBe("http://example.com") - expect(args.workers).toBe(1) - expect(args.loadedDurationMinutes).toBe(5) - expect(args.logLevel).toBe(LogLevel.WARN) + expect(result.command).toBe("replay") + if (result.command !== "replay") return + expect(result.args.startInterval).toBeNull() + expect(result.args.recordingPath).toBe(recordingFile) + expect(result.args.appUrl).toBe("http://example.com") + expect(result.args.workers).toBe(1) + expect(result.args.loadedDurationMinutes).toBe(5) + expect(result.args.logLevel).toBe(LogLevel.WARN) }) it("uses explicit app-url when provided", () => { - const args = parseArgs([ + const result = parseArgs([ "node", "script", "replay", recordingFile, "http://override.example.com", ]) - expect(args.appUrl).toBe("http://override.example.com") + expect(result.command).toBe("replay") + if (result.command !== "replay") return + expect(result.args.appUrl).toBe("http://override.example.com") }) it("resolves app-url from recording target_url when omitted", () => { - const args = parseArgs(["node", "script", "replay", recordingFile]) - expect(args.appUrl).toBe("http://recorded-host.example.com") + const result = parseArgs(["node", "script", "replay", recordingFile]) + expect(result.command).toBe("replay") + if (result.command !== "replay") return + expect(result.args.appUrl).toBe("http://recorded-host.example.com") }) }) diff --git a/src/tests/error-handling.test.ts b/src/tests/error-handling.test.ts index 04adaa0..57c4f02 100644 --- a/src/tests/error-handling.test.ts +++ b/src/tests/error-handling.test.ts @@ -4,12 +4,12 @@ import * as net from "node:net" import * as path from "node:path" import * as os from "node:os" -import { runSession, Stats } from "../session.js" +import { runSession, Stats } from "../replay/session.js" import { readRecordingFromString } from "../recording.js" import { MockShinyServer } from "./helpers/mock-shiny-server.js" import { createLogger, LogLevel } from "../logger.js" import type { Logger } from "../logger.js" -import { createOutputDir } from "../output.js" +import { createOutputDir } from "../replay/output.js" /** Get an unused local port by briefly binding and releasing. */ function getUnusedPort(): Promise { diff --git a/src/tests/output.test.ts b/src/tests/output.test.ts index 91da71a..e3a4793 100644 --- a/src/tests/output.test.ts +++ b/src/tests/output.test.ts @@ -3,7 +3,11 @@ import * as fs from "node:fs" import * as path from "node:path" import * as os from "node:os" -import { createOutputDir, defaultOutputDir, SessionWriter } from "../output.js" +import { + createOutputDir, + defaultOutputDir, + SessionWriter, +} from "../replay/output.js" describe("defaultOutputDir", () => { it("replaces colons with underscores", () => { diff --git a/src/tests/record-e2e.test.ts b/src/tests/record-e2e.test.ts new file mode 100644 index 0000000..70a580e --- /dev/null +++ b/src/tests/record-e2e.test.ts @@ -0,0 +1,430 @@ +import { describe, it, expect, beforeEach, afterEach } from "vitest" +import * as http from "node:http" +import * as fs from "node:fs" +import * as os from "node:os" +import * as path from "node:path" +import WebSocket from "ws" +import { WebSocketServer } from "ws" +import { CookieJar } from "tough-cookie" +import { RecordingProxy } from "../record/proxy.js" +import { RecordingWriter } from "../record/writer.js" +import { RecordingTokens } from "../record/tokens.js" +import { ServerType } from "../types.js" + +// --------------------------------------------------------------------------- +// Mock target server (HTTP + WebSocket) +// --------------------------------------------------------------------------- + +interface MockTarget { + server: http.Server + wss: WebSocketServer + port: number + start(): Promise + stop(): Promise +} + +function createMockTarget(): MockTarget { + const server = http.createServer((req, res) => { + const url = req.url ?? "/" + + if (req.method === "GET" && (url === "/" || url.startsWith("/?"))) { + res.writeHead(200, { "Content-Type": "text/html" }) + res.end( + "" + + '' + + '' + + "Shiny App", + ) + return + } + + if (req.method === "GET" && url === "/__token__") { + res.writeHead(200, { "Content-Type": "text/plain" }) + res.end("tokenvalue123") + return + } + + if ( + req.method === "GET" && + url.startsWith("/__sockjs__/") && + url.includes("n=") + ) { + res.writeHead(200, { "Content-Type": "application/json" }) + res.end( + JSON.stringify({ + websocket: true, + cookie_needed: false, + origins: ["*:*"], + entropy: 87654321, + }), + ) + return + } + + if (req.method === "POST" && url === "/upload") { + const chunks: Buffer[] = [] + req.on("data", (chunk: Buffer) => chunks.push(chunk)) + req.on("end", () => { + const body = Buffer.concat(chunks) + res.writeHead(200, { "Content-Type": "text/plain" }) + res.end(body) + }) + return + } + + res.writeHead(404) + res.end("Not found") + }) + + const wss = new WebSocketServer({ server }) + + wss.on("connection", (ws) => { + ws.send("o") + + const initMsg = JSON.stringify({ + config: { sessionId: "session999" }, + custom: {}, + }) + ws.send(initMsg) + + ws.on("message", (_data) => { + if (ws.readyState === WebSocket.OPEN) { + ws.send( + JSON.stringify({ + values: { x: 1 }, + inputMessages: [], + errors: {}, + }), + ) + } + }) + }) + + let port = 0 + + return { + server, + wss, + get port() { + return port + }, + start(): Promise { + return new Promise((resolve) => { + server.listen(0, "127.0.0.1", () => { + const addr = server.address() as import("node:net").AddressInfo + port = addr.port + resolve() + }) + }) + }, + stop(): Promise { + return new Promise((resolve) => { + wss.close(() => { + server.close(() => resolve()) + }) + }) + }, + } +} + +// --------------------------------------------------------------------------- +// Helpers +// --------------------------------------------------------------------------- + +function httpRequest( + port: number, + method: string, + path: string, + body?: string, +): Promise<{ statusCode: number; body: string }> { + return new Promise((resolve, reject) => { + const req = http.request( + { hostname: "127.0.0.1", port, method, path }, + (res) => { + const chunks: Buffer[] = [] + res.on("data", (chunk: Buffer) => chunks.push(chunk)) + res.on("end", () => { + resolve({ + statusCode: res.statusCode ?? 0, + body: Buffer.concat(chunks).toString(), + }) + }) + res.on("error", reject) + }, + ) + req.on("error", reject) + if (body) req.write(body) + req.end() + }) +} + +function readRecordingLines(recordingPath: string): string[] { + return fs + .readFileSync(recordingPath, "utf-8") + .split("\n") + .filter((line) => line.length > 0) +} + +function readRecordingEvents( + recordingPath: string, +): Array> { + return readRecordingLines(recordingPath) + .filter((line) => !line.startsWith("#")) + .map((line) => JSON.parse(line) as Record) +} + +function readRecordingHeaders(recordingPath: string): string[] { + return readRecordingLines(recordingPath).filter((line) => + line.startsWith("#"), + ) +} + +function waitForMessages( + ws: WebSocket, + count = 1, + timeoutMs = 5000, +): Promise { + return new Promise((resolve, reject) => { + const msgs: string[] = [] + const timer = setTimeout(() => { + ws.removeListener("message", onMessage) + reject( + new Error( + `waitForMessages: timed out after ${timeoutMs}ms (got ${msgs.length}/${count})`, + ), + ) + }, timeoutMs) + const onMessage = (data: WebSocket.RawData): void => { + msgs.push(data.toString()) + if (msgs.length >= count) { + clearTimeout(timer) + ws.removeListener("message", onMessage) + resolve(msgs) + } + } + ws.on("message", onMessage) + }) +} + +function waitForClose(ws: WebSocket): Promise { + return new Promise((resolve) => { + ws.on("close", () => resolve()) + }) +} + +function delay(ms: number): Promise { + return new Promise((resolve) => setTimeout(resolve, ms)) +} + +// --------------------------------------------------------------------------- +// Test suite +// --------------------------------------------------------------------------- + +describe("Recording E2E lifecycle", () => { + let mockTarget: MockTarget + let tmpDir: string + let recordingPath: string + let writer: RecordingWriter + let tokens: RecordingTokens + let proxy: RecordingProxy + let proxyPort: number + let shutdownCalled: boolean + + beforeEach(async () => { + mockTarget = createMockTarget() + await mockTarget.start() + + tmpDir = fs.mkdtempSync(path.join(os.tmpdir(), "shinycannon-e2e-test-")) + recordingPath = path.join(tmpDir, "recording.log") + + writer = new RecordingWriter({ + outputPath: recordingPath, + targetUrl: `http://127.0.0.1:${mockTarget.port}`, + targetType: ServerType.SHN, + rscApiKeyRequired: false, + }) + + tokens = new RecordingTokens() + shutdownCalled = false + + proxy = new RecordingProxy({ + targetUrl: `http://127.0.0.1:${mockTarget.port}`, + host: "127.0.0.1", + port: 0, + writer, + tokens, + cookieJar: new CookieJar(), + authHeaders: {}, + onShutdown: () => { + shutdownCalled = true + }, + }) + + await proxy.start() + const addr = proxy.httpServer!.address() as import("node:net").AddressInfo + proxyPort = addr.port + }) + + afterEach(async () => { + await proxy.stop() + writer.close() + await mockTarget.stop() + fs.rmSync(tmpDir, { recursive: true, force: true }) + }) + + it( + "E2E-01: full recording lifecycle produces valid recording file", + { timeout: 10000 }, + async () => { + // --- HTTP phase --- + + const homeRes = await httpRequest(proxyPort, "GET", "/") + expect(homeRes.statusCode).toBe(200) + await delay(50) + + const tokRes = await httpRequest(proxyPort, "GET", "/__token__") + expect(tokRes.body).toBe("tokenvalue123") + await delay(50) + + // SINF has a ~750ms delay inside the proxy + const sinfRes = await httpRequest( + proxyPort, + "GET", + "/__sockjs__/000/abc/n=robustXYZ", + ) + expect(sinfRes.statusCode).toBe(200) + await delay(50) + + // --- POST phase --- + + const postRes = await httpRequest( + proxyPort, + "POST", + "/upload", + "upload-data", + ) + expect(postRes.statusCode).toBe(200) + await delay(50) + + // --- WebSocket phase --- + + const ws = new WebSocket( + `ws://127.0.0.1:${proxyPort}/__sockjs__/000/sess001/websocket`, + ) + + // Collect "o" + config init (2 server-initiated messages) upfront, + // then send a client message and collect the echo (1 more = 3 total). + const serverMsgsPromise = waitForMessages(ws, 3) + + // Wait for the WS to be open before sending + await new Promise((resolve, reject) => { + if (ws.readyState === WebSocket.OPEN) { + resolve() + } else { + ws.on("open", () => resolve()) + ws.on("error", reject) + } + }) + + // Give the server time to send "o" and init before we send + await delay(100) + ws.send(JSON.stringify({ method: "init", data: {} })) + + const serverMsgs = await serverMsgsPromise + expect(serverMsgs[0]).toBe("o") + expect(serverMsgs[1]).toContain("sessionId") + + await delay(100) + + // Close the WS and wait for shutdown callback + ws.close() + await waitForClose(ws) + // Wait longer than the shutdown grace period (500ms) + await delay(700) + + expect(shutdownCalled).toBe(true) + + // --- Verify recording file --- + + // Header checks + const headers = readRecordingHeaders(recordingPath) + expect(headers).toContain("# version: 1") + expect( + headers.some((h) => h.includes(`http://127.0.0.1:${mockTarget.port}`)), + ).toBe(true) + expect(headers.some((h) => h.includes("R/Shiny"))).toBe(true) + expect(headers.every((h) => !h.includes("rscApiKeyRequired"))).toBe(true) + + // Events in order + const events = readRecordingEvents(recordingPath) + const types = events.map((e) => e["type"]) + + const homeIdx = types.indexOf("REQ_HOME") + const tokIdx = types.indexOf("REQ_TOK") + const sinfIdx = types.indexOf("REQ_SINF") + const postIdx = types.indexOf("REQ_POST") + const wsOpenIdx = types.indexOf("WS_OPEN") + const wsCloseIdx = types.lastIndexOf("WS_CLOSE") + + expect(homeIdx).toBeGreaterThanOrEqual(0) + expect(tokIdx).toBeGreaterThan(homeIdx) + expect(sinfIdx).toBeGreaterThan(tokIdx) + expect(postIdx).toBeGreaterThan(sinfIdx) + expect(wsOpenIdx).toBeGreaterThan(postIdx) + expect(wsCloseIdx).toBeGreaterThan(wsOpenIdx) + + // WS_RECV "o" frame + const wsRecvO = events.find( + (e) => e["type"] === "WS_RECV" && e["message"] === "o", + ) + expect(wsRecvO).toBeDefined() + expect(types.indexOf("WS_RECV")).toBeGreaterThan(wsOpenIdx) + + // WS_RECV_INIT + const wsRecvInit = events.find((e) => e["type"] === "WS_RECV_INIT") + expect(wsRecvInit).toBeDefined() + expect(wsRecvInit!["message"]).toContain("${SESSION}") + expect(wsRecvInit!["message"] as string).not.toContain("session999") + + // WS_SEND + const wsSend = events.find((e) => e["type"] === "WS_SEND") + expect(wsSend).toBeDefined() + + // WS_RECV echo (not the "o" frame) + const wsRecvEcho = events.find( + (e) => + e["type"] === "WS_RECV" && + typeof e["message"] === "string" && + (e["message"] as string).includes("values"), + ) + expect(wsRecvEcho).toBeDefined() + + // Token replacements in recorded URLs + const homeEvent = events[homeIdx] + expect(homeEvent!["url"]).toBe("/") + + const sinfEvent = events[sinfIdx] + expect(sinfEvent!["url"]).toContain("${ROBUST_ID}") + expect(sinfEvent!["url"] as string).not.toContain("robustXYZ") + + const wsOpenEvent = events[wsOpenIdx] + expect(wsOpenEvent!["url"]).toContain("${SOCKJSID}") + expect(wsOpenEvent!["url"] as string).not.toContain("sess001") + + // POST datafile + const postEvent = events[postIdx] + expect(typeof postEvent!["datafile"]).toBe("string") + const datafileName = postEvent!["datafile"] as string + const datafilePath = path.join(tmpDir, datafileName) + expect(fs.existsSync(datafilePath)).toBe(true) + expect(fs.readFileSync(datafilePath, "utf-8")).toBe("upload-data") + + // Token discovery + expect(tokens.has("WORKER")).toBe(true) + expect(tokens.has("TOKEN")).toBe(true) + expect(tokens.has("ROBUST_ID")).toBe(true) + expect(tokens.has("SOCKJSID")).toBe(true) + expect(tokens.has("SESSION")).toBe(true) + }, + ) +}) diff --git a/src/tests/record-events.test.ts b/src/tests/record-events.test.ts new file mode 100644 index 0000000..0c8765c --- /dev/null +++ b/src/tests/record-events.test.ts @@ -0,0 +1,122 @@ +import { describe, it, expect } from "vitest" + +import { + toISOTimestamp, + makeHttpEvent, + makeWsEvent, + classifyGetRequest, +} from "../record/events.js" + +describe("toISOTimestamp()", () => { + it("returns ISO 8601 string", () => { + const date = new Date("2024-01-15T10:30:00.000Z") + expect(toISOTimestamp(date)).toBe("2024-01-15T10:30:00.000Z") + }) +}) + +describe("makeHttpEvent()", () => { + it("creates event with correct fields", () => { + const begin = new Date("2024-01-15T10:00:00.000Z") + const end = new Date("2024-01-15T10:00:01.000Z") + const event = makeHttpEvent("REQ_HOME", begin, end, 200, "/") + expect(event.type).toBe("REQ_HOME") + expect(event.begin).toBe("2024-01-15T10:00:00.000Z") + expect(event.end).toBe("2024-01-15T10:00:01.000Z") + expect(event.status).toBe(200) + expect(event.url).toBe("/") + expect(event.datafile).toBeUndefined() + }) + + it("includes datafile when provided", () => { + const begin = new Date("2024-01-15T10:00:00.000Z") + const end = new Date("2024-01-15T10:00:01.000Z") + const event = makeHttpEvent( + "REQ_POST", + begin, + end, + 200, + "/upload", + "recording.log.post.0", + ) + expect(event.type).toBe("REQ_POST") + expect(event.datafile).toBe("recording.log.post.0") + }) +}) + +describe("makeWsEvent()", () => { + it("creates event with just type and begin for WS_CLOSE", () => { + const begin = new Date("2024-01-15T10:00:00.000Z") + const event = makeWsEvent("WS_CLOSE", begin) + expect(event.type).toBe("WS_CLOSE") + expect(event.begin).toBe("2024-01-15T10:00:00.000Z") + expect(event.url).toBeUndefined() + expect(event.message).toBeUndefined() + }) + + it("includes url for WS_OPEN", () => { + const begin = new Date("2024-01-15T10:00:00.000Z") + const event = makeWsEvent("WS_OPEN", begin, { url: "/app/ws" }) + expect(event.type).toBe("WS_OPEN") + expect(event.url).toBe("/app/ws") + expect(event.message).toBeUndefined() + }) + + it("includes message for WS_RECV", () => { + const begin = new Date("2024-01-15T10:00:00.000Z") + const event = makeWsEvent("WS_RECV", begin, { message: '{"data":1}' }) + expect(event.type).toBe("WS_RECV") + expect(event.message).toBe('{"data":1}') + expect(event.url).toBeUndefined() + }) +}) + +describe("classifyGetRequest()", () => { + it('classifies "/" as REQ_HOME', () => { + expect(classifyGetRequest("/")).toEqual({ type: "REQ_HOME" }) + }) + + it('classifies "/app/" as REQ_HOME', () => { + expect(classifyGetRequest("/app/")).toEqual({ type: "REQ_HOME" }) + }) + + it('classifies "/app/something.Rmd" as REQ_HOME (case-insensitive)', () => { + expect(classifyGetRequest("/app/something.Rmd")).toEqual({ + type: "REQ_HOME", + }) + }) + + it('classifies "/__token__" as REQ_TOK', () => { + expect(classifyGetRequest("/__token__")).toEqual({ type: "REQ_TOK" }) + }) + + it('classifies "/__sockjs__/000/abc123/n=xyz789" as REQ_SINF with robustId', () => { + expect(classifyGetRequest("/__sockjs__/000/abc123/n=xyz789")).toEqual({ + type: "REQ_SINF", + robustId: "xyz789", + }) + }) + + it('classifies "/shared/shiny.js" as REQ_GET', () => { + expect(classifyGetRequest("/shared/shiny.js")).toEqual({ type: "REQ_GET" }) + }) + + it("classifies path with query string correctly", () => { + expect(classifyGetRequest("/app/?_ga=123")).toEqual({ type: "REQ_HOME" }) + expect(classifyGetRequest("/shared/shiny.js?v=1")).toEqual({ + type: "REQ_GET", + }) + expect( + classifyGetRequest("/__sockjs__/000/abc123/n=xyz789?foo=bar"), + ).toEqual({ type: "REQ_SINF", robustId: "xyz789" }) + }) + + it("classifies REQ_SINF with n= in query params", () => { + expect(classifyGetRequest("/__sockjs__/000/abc123?n=xyz789")).toEqual({ + type: "REQ_SINF", + robustId: "xyz789", + }) + expect( + classifyGetRequest("/__sockjs__/000/abc123?foo=bar&n=xyz789"), + ).toEqual({ type: "REQ_SINF", robustId: "xyz789" }) + }) +}) diff --git a/src/tests/record-proxy.test.ts b/src/tests/record-proxy.test.ts new file mode 100644 index 0000000..73027b3 --- /dev/null +++ b/src/tests/record-proxy.test.ts @@ -0,0 +1,324 @@ +import { describe, it, expect, beforeEach, afterEach } from "vitest" +import * as http from "node:http" +import * as fs from "node:fs" +import * as os from "node:os" +import * as path from "node:path" +import { CookieJar } from "tough-cookie" +import { RecordingProxy } from "../record/proxy.js" +import { RecordingWriter } from "../record/writer.js" +import { RecordingTokens } from "../record/tokens.js" +import { ServerType } from "../types.js" + +// --------------------------------------------------------------------------- +// Mock target server +// --------------------------------------------------------------------------- + +function createMockTarget(): http.Server { + return http.createServer((req, res) => { + const url = req.url ?? "/" + + if (req.method === "GET" && (url === "/" || url.startsWith("/?"))) { + res.writeHead(200, { "Content-Type": "text/html", "x-custom": "hello" }) + res.end( + "" + + '' + + '' + + "Shiny App", + ) + return + } + + if (req.method === "GET" && url === "/__token__") { + res.writeHead(200, { "Content-Type": "text/plain" }) + res.end("test-token-abc") + return + } + + if ( + req.method === "GET" && + url.startsWith("/__sockjs__/") && + url.includes("n=") + ) { + res.writeHead(200, { "Content-Type": "application/json" }) + res.end( + JSON.stringify({ + websocket: true, + cookie_needed: false, + origins: ["*:*"], + entropy: 12345678, + }), + ) + return + } + + if (req.method === "POST" && url === "/upload") { + const chunks: Buffer[] = [] + req.on("data", (chunk: Buffer) => chunks.push(chunk)) + req.on("end", () => { + const body = Buffer.concat(chunks) + res.writeHead(200, { "Content-Type": "text/plain" }) + res.end(body) + }) + return + } + + if (req.method === "GET" && url === "/favicon.ico") { + res.writeHead(200, { "Content-Type": "image/x-icon" }) + res.end("") + return + } + + if (req.method === "GET" && url.includes("/shared/")) { + res.writeHead(200, { "Content-Type": "application/javascript" }) + res.end("// js") + return + } + + res.writeHead(404) + res.end("Not found") + }) +} + +// --------------------------------------------------------------------------- +// HTTP request helper +// --------------------------------------------------------------------------- + +interface ProxyResponse { + statusCode: number + headers: http.IncomingHttpHeaders + body: string +} + +function makeRequest( + proxyPort: number, + method: string, + urlPath: string, + body?: string, +): Promise { + return new Promise((resolve, reject) => { + const options: http.RequestOptions = { + hostname: "127.0.0.1", + port: proxyPort, + path: urlPath, + method, + headers: body + ? { + "content-type": "text/plain", + "content-length": String(Buffer.byteLength(body)), + } + : {}, + } + + const req = http.request(options, (res) => { + const chunks: Buffer[] = [] + res.on("data", (chunk: Buffer) => chunks.push(chunk)) + res.on("end", () => { + resolve({ + statusCode: res.statusCode ?? 0, + headers: res.headers, + body: Buffer.concat(chunks).toString("utf-8"), + }) + }) + res.on("error", reject) + }) + + req.on("error", reject) + + if (body) { + req.write(body) + } + req.end() + }) +} + +function get(proxyPort: number, urlPath: string): Promise { + return makeRequest(proxyPort, "GET", urlPath) +} + +function post( + proxyPort: number, + urlPath: string, + body: string, +): Promise { + return makeRequest(proxyPort, "POST", urlPath, body) +} + +// --------------------------------------------------------------------------- +// Test helpers +// --------------------------------------------------------------------------- + +function startServer(server: http.Server): Promise { + return new Promise((resolve) => { + server.listen(0, "127.0.0.1", () => { + const addr = server.address() as import("node:net").AddressInfo + resolve(addr.port) + }) + }) +} + +function stopServer(server: http.Server): Promise { + return new Promise((resolve) => { + server.close(() => resolve()) + }) +} + +function readRecordingEvents( + recordingPath: string, +): Array> { + return fs + .readFileSync(recordingPath, "utf-8") + .split("\n") + .filter((line) => line.length > 0 && !line.startsWith("#")) + .map((line) => JSON.parse(line) as Record) +} + +// --------------------------------------------------------------------------- +// Test suite +// --------------------------------------------------------------------------- + +describe("RecordingProxy", () => { + let mockTarget: http.Server + let mockTargetPort: number + let tmpDir: string + let recordingPath: string + let writer: RecordingWriter + let tokens: RecordingTokens + let proxy: RecordingProxy + let proxyPort: number + + beforeEach(async () => { + mockTarget = createMockTarget() + mockTargetPort = await startServer(mockTarget) + + tmpDir = fs.mkdtempSync(path.join(os.tmpdir(), "shinycannon-proxy-test-")) + recordingPath = path.join(tmpDir, "recording.log") + + writer = new RecordingWriter({ + outputPath: recordingPath, + targetUrl: `http://127.0.0.1:${mockTargetPort}`, + targetType: ServerType.SHN, + rscApiKeyRequired: false, + }) + + tokens = new RecordingTokens() + + // Find a free port for the proxy by using port 0 and extracting it after start + proxy = new RecordingProxy({ + targetUrl: `http://127.0.0.1:${mockTargetPort}`, + host: "127.0.0.1", + port: 0, + writer, + tokens, + cookieJar: new CookieJar(), + authHeaders: {}, + onShutdown: () => {}, + }) + + await proxy.start() + const server = proxy.httpServer! + const addr = server.address() as import("node:net").AddressInfo + proxyPort = addr.port + }) + + afterEach(async () => { + await proxy.stop() + writer.close() + await stopServer(mockTarget) + fs.rmSync(tmpDir, { recursive: true, force: true }) + }) + + it("PRXY-01: GET / is proxied and recorded as REQ_HOME", async () => { + const res = await get(proxyPort, "/") + + expect(res.statusCode).toBe(200) + expect(res.body).toContain('') + + const events = readRecordingEvents(recordingPath) + const homeEvent = events.find((e) => e["type"] === "REQ_HOME") + expect(homeEvent).toBeDefined() + expect(homeEvent!["status"]).toBe(200) + + expect(tokens.has("WORKER")).toBe(true) + }) + + it("PRXY-02: GET __token__ is recorded as REQ_TOK", async () => { + const res = await get(proxyPort, "/__token__") + + expect(res.body).toBe("test-token-abc") + + const events = readRecordingEvents(recordingPath) + const tokEvent = events.find((e) => e["type"] === "REQ_TOK") + expect(tokEvent).toBeDefined() + expect(tokEvent!["status"]).toBe(200) + + expect(tokens.has("TOKEN")).toBe(true) + }) + + it( + "PRXY-03: GET __sockjs__ info is recorded as REQ_SINF with ROBUST_ID discovery", + { timeout: 5000 }, + async () => { + const res = await get(proxyPort, "/__sockjs__/000/abc/n=robustid123") + + expect(res.statusCode).toBe(200) + + const events = readRecordingEvents(recordingPath) + const sinfEvent = events.find((e) => e["type"] === "REQ_SINF") + expect(sinfEvent).toBeDefined() + expect(sinfEvent!["status"]).toBe(200) + + expect(tokens.has("ROBUST_ID")).toBe(true) + }, + ) + + it("PRXY-04: POST is recorded as REQ_POST with datafile", async () => { + const res = await post(proxyPort, "/upload", "test-post-body") + + expect(res.statusCode).toBe(200) + expect(res.body).toBe("test-post-body") + + const events = readRecordingEvents(recordingPath) + const postEvent = events.find((e) => e["type"] === "REQ_POST") + expect(postEvent).toBeDefined() + expect(typeof postEvent!["datafile"]).toBe("string") + + const datafileName = postEvent!["datafile"] as string + const datafilePath = path.join(tmpDir, datafileName) + expect(fs.existsSync(datafilePath)).toBe(true) + expect(fs.readFileSync(datafilePath, "utf-8")).toBe("test-post-body") + }) + + it("PRXY-05: favicon.ico is proxied but NOT recorded", async () => { + const res = await get(proxyPort, "/favicon.ico") + + expect(res.statusCode).toBe(200) + + const events = readRecordingEvents(recordingPath) + expect(events.length).toBe(0) + }) + + it("PRXY-06: token replacement in recorded URLs", async () => { + // First request discovers WORKER token + await get(proxyPort, "/") + // Second request uses the worker path + await get(proxyPort, "/_w_testworker/shared/something.js") + + const events = readRecordingEvents(recordingPath) + + const getEvent = events.find( + (e) => e["type"] === "REQ_GET" && (e["url"] as string).includes("shared"), + ) + expect(getEvent).toBeDefined() + expect(getEvent!["url"]).toBe("/_w_${WORKER}/shared/something.js") + }) + + it("PRXY-07: response headers are passed through (minus hop-by-hop)", async () => { + const res = await get(proxyPort, "/") + + // Custom header passes through + expect(res.headers["x-custom"]).toBe("hello") + + // Hop-by-hop headers are stripped + expect(res.headers["transfer-encoding"]).toBeUndefined() + }) +}) diff --git a/src/tests/record-tokens.test.ts b/src/tests/record-tokens.test.ts new file mode 100644 index 0000000..d0aa880 --- /dev/null +++ b/src/tests/record-tokens.test.ts @@ -0,0 +1,64 @@ +import { describe, it, expect } from "vitest" + +import { RecordingTokens } from "../record/tokens.js" + +describe("RecordingTokens", () => { + describe("discover()", () => { + it("registers a token and increases size", () => { + const tokens = new RecordingTokens() + expect(tokens.size).toBe(0) + tokens.discover("WORKER", "abc123") + expect(tokens.size).toBe(1) + }) + + it("is a no-op for empty values", () => { + const tokens = new RecordingTokens() + tokens.discover("WORKER", "") + expect(tokens.size).toBe(0) + }) + }) + + describe("replaceInString()", () => { + it("replaces known values with ${PLACEHOLDER}", () => { + const tokens = new RecordingTokens() + tokens.discover("WORKER", "abc123") + expect(tokens.replaceInString("hello abc123 world")).toBe( + "hello ${WORKER} world", + ) + }) + + it("returns the original string when no tokens are registered", () => { + const tokens = new RecordingTokens() + expect(tokens.replaceInString("hello world")).toBe("hello world") + }) + + it("replaces longer values first to avoid partial matches", () => { + const tokens = new RecordingTokens() + tokens.discover("WORKER", "abc") + tokens.discover("SESSION", "abcdef") + expect(tokens.replaceInString("abcdef")).toBe("${SESSION}") + }) + + it("replaces all occurrences in the string", () => { + const tokens = new RecordingTokens() + tokens.discover("TOKEN", "xyz") + expect(tokens.replaceInString("xyz and xyz")).toBe( + "${TOKEN} and ${TOKEN}", + ) + }) + }) + + describe("has()", () => { + it("returns true for discovered token names", () => { + const tokens = new RecordingTokens() + tokens.discover("SESSION", "sess-val") + expect(tokens.has("SESSION")).toBe(true) + }) + + it("returns false for undiscovered token names", () => { + const tokens = new RecordingTokens() + tokens.discover("SESSION", "sess-val") + expect(tokens.has("WORKER")).toBe(false) + }) + }) +}) diff --git a/src/tests/record-writer.test.ts b/src/tests/record-writer.test.ts new file mode 100644 index 0000000..6208b5a --- /dev/null +++ b/src/tests/record-writer.test.ts @@ -0,0 +1,153 @@ +import * as fs from "node:fs" +import * as os from "node:os" +import * as path from "node:path" +import { describe, it, expect, afterEach } from "vitest" + +import { RecordingWriter } from "../record/writer.js" +import { ServerType } from "../types.js" + +let tempFiles: string[] = [] + +function tempPath(name: string): string { + const p = path.join(os.tmpdir(), `record-writer-test-${Date.now()}-${name}`) + tempFiles.push(p) + return p +} + +afterEach(() => { + for (const f of tempFiles) { + // Remove base file and any adjacent post data files + for (let i = 0; i < 10; i++) { + const postFile = `${f}.post.${i}` + if (fs.existsSync(postFile)) fs.unlinkSync(postFile) + } + if (fs.existsSync(f)) fs.unlinkSync(f) + } + tempFiles = [] +}) + +describe("RecordingWriter constructor", () => { + it("writes header with version, target_url, target_type", () => { + const outputPath = tempPath("header.log") + const writer = new RecordingWriter({ + outputPath, + targetUrl: "https://example.com/app", + targetType: ServerType.SHN, + rscApiKeyRequired: false, + }) + writer.close() + + const contents = fs.readFileSync(outputPath, "utf8") + expect(contents).toContain("# version: 1") + expect(contents).toContain("# target_url: https://example.com/app") + expect(contents).toContain("# target_type: R/Shiny") + }) + + it("includes rscApiKeyRequired line when true", () => { + const outputPath = tempPath("rsc-true.log") + const writer = new RecordingWriter({ + outputPath, + targetUrl: "https://connect.example.com/app", + targetType: ServerType.RSC, + rscApiKeyRequired: true, + }) + writer.close() + + const contents = fs.readFileSync(outputPath, "utf8") + expect(contents).toContain("# rscApiKeyRequired: true") + }) + + it("omits rscApiKeyRequired line when false", () => { + const outputPath = tempPath("rsc-false.log") + const writer = new RecordingWriter({ + outputPath, + targetUrl: "https://example.com/app", + targetType: ServerType.SHN, + rscApiKeyRequired: false, + }) + writer.close() + + const contents = fs.readFileSync(outputPath, "utf8") + expect(contents).not.toContain("rscApiKeyRequired") + }) +}) + +describe("RecordingWriter writeEvent()", () => { + it("appends JSON line", () => { + const outputPath = tempPath("events.log") + const writer = new RecordingWriter({ + outputPath, + targetUrl: "https://example.com/app", + targetType: ServerType.SHN, + rscApiKeyRequired: false, + }) + + const event = { + type: "REQ_HOME" as const, + begin: "2024-01-15T10:00:00.000Z", + end: "2024-01-15T10:00:01.000Z", + status: 200, + url: "/", + } + writer.writeEvent(event) + writer.close() + + const lines = fs.readFileSync(outputPath, "utf8").split("\n") + const jsonLine = lines.find((l) => l.startsWith("{")) + expect(jsonLine).toBeDefined() + expect(JSON.parse(jsonLine!)).toEqual(event) + }) +}) + +describe("RecordingWriter writePostData()", () => { + it("creates adjacent file with correct name", () => { + const outputPath = tempPath("post.log") + const writer = new RecordingWriter({ + outputPath, + targetUrl: "https://example.com/app", + targetType: ServerType.SHN, + rscApiKeyRequired: false, + }) + + const data = Buffer.from("hello=world") + const basename = writer.writePostData(data) + writer.close() + + expect(basename).toBe(path.basename(`${outputPath}.post.0`)) + const postPath = `${outputPath}.post.0` + expect(fs.existsSync(postPath)).toBe(true) + expect(fs.readFileSync(postPath)).toEqual(data) + }) + + it("increments file counter", () => { + const outputPath = tempPath("postcounter.log") + const writer = new RecordingWriter({ + outputPath, + targetUrl: "https://example.com/app", + targetType: ServerType.SHN, + rscApiKeyRequired: false, + }) + + const first = writer.writePostData(Buffer.from("first")) + const second = writer.writePostData(Buffer.from("second")) + writer.close() + + expect(first).toBe(path.basename(`${outputPath}.post.0`)) + expect(second).toBe(path.basename(`${outputPath}.post.1`)) + expect(fs.existsSync(`${outputPath}.post.0`)).toBe(true) + expect(fs.existsSync(`${outputPath}.post.1`)).toBe(true) + }) +}) + +describe("RecordingWriter close()", () => { + it("does not throw", () => { + const outputPath = tempPath("close.log") + const writer = new RecordingWriter({ + outputPath, + targetUrl: "https://example.com/app", + targetType: ServerType.SHN, + rscApiKeyRequired: false, + }) + expect(() => writer.close()).not.toThrow() + }) +}) diff --git a/src/tests/record-ws-proxy.test.ts b/src/tests/record-ws-proxy.test.ts new file mode 100644 index 0000000..3986220 --- /dev/null +++ b/src/tests/record-ws-proxy.test.ts @@ -0,0 +1,407 @@ +import { describe, it, expect, beforeEach, afterEach } from "vitest" +import * as http from "node:http" +import * as fs from "node:fs" +import * as os from "node:os" +import * as path from "node:path" +import WebSocket, { WebSocketServer } from "ws" +import { CookieJar } from "tough-cookie" +import { RecordingProxy } from "../record/proxy.js" +import { RecordingWriter } from "../record/writer.js" +import { RecordingTokens } from "../record/tokens.js" +import { ServerType } from "../types.js" + +// --------------------------------------------------------------------------- +// Mock target server (HTTP + WebSocket) +// --------------------------------------------------------------------------- + +interface MockTargetOptions { + /** Extra messages to send after the config init message */ + extraMessages?: string[] +} + +interface MockTarget { + server: http.Server + wss: WebSocketServer + port: number + start(): Promise + stop(): Promise +} + +function createMockTarget(options: MockTargetOptions = {}): MockTarget { + const server = http.createServer((req, res) => { + const url = req.url ?? "/" + if (url === "/" || url.startsWith("/?")) { + res.writeHead(200, { "Content-Type": "text/html" }) + res.end( + "" + + '' + + "Shiny App", + ) + return + } + res.writeHead(404) + res.end("Not found") + }) + + const wss = new WebSocketServer({ server }) + + wss.on("connection", (ws) => { + // Send SockJS open frame + ws.send("o") + + // Send config init message + const initMsg = JSON.stringify({ + config: { sessionId: "sess-12345" }, + custom: {}, + }) + ws.send(initMsg) + + // Send any extra messages after init + for (const msg of options.extraMessages ?? []) { + ws.send(msg) + } + + // Echo client messages back as a response + ws.on("message", (_data) => { + if (ws.readyState === WebSocket.OPEN) { + ws.send( + JSON.stringify({ + values: { x: 1 }, + inputMessages: [], + errors: {}, + }), + ) + } + }) + }) + + let port = 0 + + return { + server, + wss, + get port() { + return port + }, + start(): Promise { + return new Promise((resolve) => { + server.listen(0, "127.0.0.1", () => { + const addr = server.address() as import("node:net").AddressInfo + port = addr.port + resolve() + }) + }) + }, + stop(): Promise { + return new Promise((resolve) => { + wss.close(() => { + server.close(() => resolve()) + }) + }) + }, + } +} + +// --------------------------------------------------------------------------- +// Helpers +// --------------------------------------------------------------------------- + +function readRecordingEvents( + recordingPath: string, +): Array> { + return fs + .readFileSync(recordingPath, "utf-8") + .split("\n") + .filter((line) => line.length > 0 && !line.startsWith("#")) + .map((line) => JSON.parse(line) as Record) +} + +function waitForMessages( + ws: WebSocket, + count = 1, + timeoutMs = 5000, +): Promise { + return new Promise((resolve, reject) => { + const msgs: string[] = [] + const timer = setTimeout(() => { + ws.removeListener("message", onMessage) + reject( + new Error( + `waitForMessages: timed out after ${timeoutMs}ms (got ${msgs.length}/${count})`, + ), + ) + }, timeoutMs) + const onMessage = (data: WebSocket.RawData): void => { + msgs.push(data.toString()) + if (msgs.length >= count) { + clearTimeout(timer) + ws.removeListener("message", onMessage) + resolve(msgs) + } + } + ws.on("message", onMessage) + }) +} + +function waitForClose(ws: WebSocket): Promise { + return new Promise((resolve) => { + ws.on("close", () => resolve()) + }) +} + +function connectWsToProxy(proxyPort: number, wsPath: string): WebSocket { + return new WebSocket(`ws://127.0.0.1:${proxyPort}${wsPath}`) +} + +function delay(ms: number): Promise { + return new Promise((resolve) => setTimeout(resolve, ms)) +} + +// --------------------------------------------------------------------------- +// Test suite +// --------------------------------------------------------------------------- + +describe("RecordingProxy WebSocket", () => { + let mockTarget: MockTarget + let tmpDir: string + let recordingPath: string + let writer: RecordingWriter + let tokens: RecordingTokens + let proxy: RecordingProxy + let proxyPort: number + let shutdownCalled: boolean + + beforeEach(async () => { + mockTarget = createMockTarget() + await mockTarget.start() + + tmpDir = fs.mkdtempSync( + path.join(os.tmpdir(), "shinycannon-ws-proxy-test-"), + ) + recordingPath = path.join(tmpDir, "recording.log") + + writer = new RecordingWriter({ + outputPath: recordingPath, + targetUrl: `http://127.0.0.1:${mockTarget.port}`, + targetType: ServerType.SHN, + rscApiKeyRequired: false, + }) + + tokens = new RecordingTokens() + shutdownCalled = false + + proxy = new RecordingProxy({ + targetUrl: `http://127.0.0.1:${mockTarget.port}`, + host: "127.0.0.1", + port: 0, + writer, + tokens, + cookieJar: new CookieJar(), + authHeaders: {}, + onShutdown: () => { + shutdownCalled = true + }, + }) + + await proxy.start() + const addr = proxy.httpServer!.address() as import("node:net").AddressInfo + proxyPort = addr.port + }) + + afterEach(async () => { + await proxy.stop() + writer.close() + await mockTarget.stop() + fs.rmSync(tmpDir, { recursive: true, force: true }) + }) + + it( + "WS-01: WS_OPEN is recorded with token-replaced URL", + { timeout: 5000 }, + async () => { + const ws = connectWsToProxy(proxyPort, "/__sockjs__/000/abc123/websocket") + // Wait for "o" frame (first message) + await waitForMessages(ws, 1) + ws.close() + await waitForClose(ws) + await delay(100) + + const events = readRecordingEvents(recordingPath) + const openEvent = events.find((e) => e["type"] === "WS_OPEN") + expect(openEvent).toBeDefined() + expect(openEvent!["url"]).toContain("${SOCKJSID}") + expect(openEvent!["url"]).not.toContain("abc123") + expect(tokens.has("SOCKJSID")).toBe(true) + }, + ) + + it( + "WS-02: SockJS open frame is recorded as WS_RECV", + { timeout: 5000 }, + async () => { + const ws = connectWsToProxy(proxyPort, "/__sockjs__/000/abc123/websocket") + await waitForMessages(ws, 1) + ws.close() + await waitForClose(ws) + await delay(100) + + const events = readRecordingEvents(recordingPath) + const recvEvent = events.find( + (e) => e["type"] === "WS_RECV" && e["message"] === "o", + ) + expect(recvEvent).toBeDefined() + }, + ) + + it( + "WS-03: Config init message is recorded as WS_RECV_INIT with SESSION token", + { timeout: 5000 }, + async () => { + const ws = connectWsToProxy(proxyPort, "/__sockjs__/000/abc123/websocket") + // Wait for "o" and config message (2 messages) + await waitForMessages(ws, 2) + ws.close() + await waitForClose(ws) + await delay(100) + + const events = readRecordingEvents(recordingPath) + const initEvent = events.find((e) => e["type"] === "WS_RECV_INIT") + expect(initEvent).toBeDefined() + expect(initEvent!["message"]).toContain("${SESSION}") + expect(initEvent!["message"] as string).not.toContain("sess-12345") + expect(tokens.has("SESSION")).toBe(true) + }, + ) + + it( + "WS-04: Client messages are recorded as WS_SEND", + { timeout: 5000 }, + async () => { + const ws = connectWsToProxy(proxyPort, "/__sockjs__/000/abc123/websocket") + // Wait for init sequence ("o" + config) + await waitForMessages(ws, 2) + + const clientMsg = JSON.stringify({ method: "init", data: {} }) + ws.send(clientMsg) + // Wait for echo response + await waitForMessages(ws, 1) + await delay(100) + + ws.close() + await waitForClose(ws) + await delay(100) + + const events = readRecordingEvents(recordingPath) + const sendEvent = events.find((e) => e["type"] === "WS_SEND") + expect(sendEvent).toBeDefined() + expect(sendEvent!["message"]).toBe(clientMsg) + }, + ) + + it( + "WS-05: Server response is recorded as WS_RECV", + { timeout: 5000 }, + async () => { + const ws = connectWsToProxy(proxyPort, "/__sockjs__/000/abc123/websocket") + await waitForMessages(ws, 2) + + ws.send(JSON.stringify({ method: "init", data: {} })) + await waitForMessages(ws, 1) + await delay(100) + + ws.close() + await waitForClose(ws) + await delay(100) + + const events = readRecordingEvents(recordingPath) + const recvEvents = events.filter((e) => e["type"] === "WS_RECV") + // Should have at least: "o" frame + echo response + const echoEvent = recvEvents.find( + (e) => + typeof e["message"] === "string" && + (e["message"] as string).includes("values"), + ) + expect(echoEvent).toBeDefined() + }, + ) + + it( + "WS-06: Heartbeat 'h' is relayed to client but NOT recorded", + { timeout: 5000 }, + async () => { + // Create a fresh mock that sends "h" after init + await proxy.stop() + writer.close() + await mockTarget.stop() + fs.rmSync(tmpDir, { recursive: true, force: true }) + + mockTarget = createMockTarget({ extraMessages: ["h"] }) + await mockTarget.start() + + tmpDir = fs.mkdtempSync( + path.join(os.tmpdir(), "shinycannon-ws-proxy-test-h-"), + ) + recordingPath = path.join(tmpDir, "recording.log") + + writer = new RecordingWriter({ + outputPath: recordingPath, + targetUrl: `http://127.0.0.1:${mockTarget.port}`, + targetType: ServerType.SHN, + rscApiKeyRequired: false, + }) + + tokens = new RecordingTokens() + + proxy = new RecordingProxy({ + targetUrl: `http://127.0.0.1:${mockTarget.port}`, + host: "127.0.0.1", + port: 0, + writer, + tokens, + cookieJar: new CookieJar(), + authHeaders: {}, + onShutdown: () => {}, + }) + + await proxy.start() + const addr = proxy.httpServer!.address() as import("node:net").AddressInfo + proxyPort = addr.port + + const ws = connectWsToProxy(proxyPort, "/__sockjs__/000/abc123/websocket") + // Wait for "o", config, and "h" (3 messages) + const msgs = await waitForMessages(ws, 3) + await delay(100) + + // "h" should have been relayed to the client + expect(msgs).toContain("h") + + ws.close() + await waitForClose(ws) + await delay(100) + + const events = readRecordingEvents(recordingPath) + // No event should have message "h" + const heartbeatEvent = events.find((e) => e["message"] === "h") + expect(heartbeatEvent).toBeUndefined() + }, + ) + + it( + "WS-07: Client disconnect records WS_CLOSE and triggers onShutdown", + { timeout: 5000 }, + async () => { + const ws = connectWsToProxy(proxyPort, "/__sockjs__/000/abc123/websocket") + await waitForMessages(ws, 2) + + ws.close() + await waitForClose(ws) + // Wait longer than the shutdown grace period (500ms) + await delay(700) + + const events = readRecordingEvents(recordingPath) + const closeEvent = events.find((e) => e["type"] === "WS_CLOSE") + expect(closeEvent).toBeDefined() + expect(shutdownCalled).toBe(true) + }, + ) +}) diff --git a/src/tests/session-integration.test.ts b/src/tests/session-integration.test.ts index 931420b..a915528 100644 --- a/src/tests/session-integration.test.ts +++ b/src/tests/session-integration.test.ts @@ -3,11 +3,16 @@ import * as fs from "node:fs" import * as path from "node:path" import * as os from "node:os" import { MockShinyServer } from "./helpers/mock-shiny-server.js" -import { runSession, Stats, extractCommId, replaceCommIds } from "../session.js" -import type { SessionConfig } from "../session.js" +import { + runSession, + Stats, + extractCommId, + replaceCommIds, +} from "../replay/session.js" +import type { SessionConfig } from "../replay/session.js" import { readRecordingFromString } from "../recording.js" import { createLogger, LogLevel } from "../logger.js" -import { createOutputDir } from "../output.js" +import { createOutputDir } from "../replay/output.js" let mock: MockShinyServer let tmpDir: string diff --git a/src/tests/stats.test.ts b/src/tests/stats.test.ts index f211d6a..43be5a6 100644 --- a/src/tests/stats.test.ts +++ b/src/tests/stats.test.ts @@ -1,7 +1,7 @@ import { describe, it, expect } from "vitest" -import { Stats } from "../session.js" -import { formatNumber, formatRate } from "../ui.js" +import { Stats } from "../replay/session.js" +import { formatNumber, formatRate } from "../replay/ui.js" describe("Stats", () => { it("tracks event count via recordEvent()", () => { diff --git a/src/tests/websocket.test.ts b/src/tests/websocket.test.ts index 7f747ff..9d72702 100644 --- a/src/tests/websocket.test.ts +++ b/src/tests/websocket.test.ts @@ -1,7 +1,7 @@ import { describe, it, expect } from "vitest" import { WebSocketServer } from "ws" -import { AsyncQueue, ShinyWebSocket } from "../websocket.js" -import type { WSMessage } from "../websocket.js" +import { AsyncQueue, ShinyWebSocket } from "../replay/websocket.js" +import type { WSMessage } from "../replay/websocket.js" describe("AsyncQueue", () => { it("returns items in FIFO order", async () => { diff --git a/src/tests/worker-integration.test.ts b/src/tests/worker-integration.test.ts index cfdcb71..4f9f8f8 100644 --- a/src/tests/worker-integration.test.ts +++ b/src/tests/worker-integration.test.ts @@ -3,11 +3,11 @@ import * as fs from "node:fs" import * as path from "node:path" import * as os from "node:os" import { MockShinyServer } from "./helpers/mock-shiny-server.js" -import { runEnduranceTest } from "../worker.js" -import type { EnduranceTestConfig } from "../worker.js" +import { runEnduranceTest } from "../replay/worker.js" +import type { EnduranceTestConfig } from "../replay/worker.js" import { readRecordingFromString } from "../recording.js" import type { Logger } from "../logger.js" -import { createOutputDir } from "../output.js" +import { createOutputDir } from "../replay/output.js" // --------------------------------------------------------------------------- // Capturing logger