From 965e07367b0a14ee8bc40c6fc408334eafadc77f Mon Sep 17 00:00:00 2001 From: Ian Murray Date: Fri, 12 Jun 2026 08:44:23 +0000 Subject: [PATCH] Fix missing request context (cookies/headers) in WebSocket callbacks Propagate cookies, headers, args, path, remote, and origin from the WebSocket handshake onto the callback context in create_ws_context, mirroring the HTTP path. This fixes auth helpers (e.g. dash_enterprise_auth.get_user_data) failing over the WebSocket transport because callback_context.cookies/headers were empty. --- dash/backends/_fastapi.py | 14 ++++++++ dash/backends/_quart.py | 14 ++++++++ dash/backends/ws.py | 23 +++++++++++- tests/websocket/test_ws_context.py | 56 ++++++++++++++++++++++++++++++ 4 files changed, 106 insertions(+), 1 deletion(-) create mode 100644 tests/websocket/test_ws_context.py diff --git a/dash/backends/_fastapi.py b/dash/backends/_fastapi.py index 0516b5edcb..5167922dab 100644 --- a/dash/backends/_fastapi.py +++ b/dash/backends/_fastapi.py @@ -725,6 +725,19 @@ async def websocket_handler(websocket: WebSocket): await websocket.accept() + # Capture request metadata from the WebSocket handshake once per + # connection so that callbacks running over the WebSocket transport + # can access cookies/headers (e.g. for authentication helpers such + # as dash_enterprise_auth.get_user_data). + request_context = { + "cookies": dict(websocket.cookies), + "headers": dict(websocket.headers), + "args": dict(websocket.query_params), + "path": websocket.url.path, + "remote": websocket.client.host if websocket.client else "", + "origin": websocket.headers.get("origin", ""), + } + # Create janus queue for outbound messages (main loop context) outbound_queue: janus.Queue[str] = janus.Queue() # Track pending get_props requests with standard queue.Queue for responses @@ -788,6 +801,7 @@ async def websocket_handler(websocket: WebSocket): payload, ws_cb, FastAPIResponseAdapter(), + request_context, ) # Set up done callback to send response diff --git a/dash/backends/_quart.py b/dash/backends/_quart.py index c7634ce93a..f827d5115a 100644 --- a/dash/backends/_quart.py +++ b/dash/backends/_quart.py @@ -545,6 +545,19 @@ async def websocket_handler(): # pylint: disable=too-many-branches await ws.accept() + # Capture request metadata from the WebSocket handshake once per + # connection so that callbacks running over the WebSocket transport + # can access cookies/headers (e.g. for authentication helpers such + # as dash_enterprise_auth.get_user_data). + request_context = { + "cookies": dict(ws.cookies), + "headers": dict(ws.headers), + "args": dict(ws.args), + "path": ws.path, + "remote": ws.remote_addr, + "origin": ws.headers.get("origin", ""), + } + # Track this connection for graceful shutdown try: ws_obj = ws._get_current_object() @@ -623,6 +636,7 @@ async def websocket_handler(): # pylint: disable=too-many-branches payload, ws_cb, QuartResponseAdapter(), + request_context, ) # Set up done callback to send response diff --git a/dash/backends/ws.py b/dash/backends/ws.py index a4b302f215..c0fe3da4d7 100644 --- a/dash/backends/ws.py +++ b/dash/backends/ws.py @@ -189,6 +189,7 @@ def create_ws_context( payload: dict, response_adapter: "ResponseAdapter", websocket_callback: DashWebsocketCallback, + request_context: "dict | None" = None, ): """Create callback context from WebSocket message. @@ -196,6 +197,12 @@ def create_ws_context( payload: The callback payload response_adapter: The response adapter instance for the backend websocket_callback: The websocket callback instance for the backend + request_context: Optional request metadata (cookies, headers, args, + path, remote, origin) captured from the WebSocket handshake. This + mirrors the context populated for regular HTTP callbacks so that + ``callback_context.cookies``/``headers`` (and downstream helpers + such as ``dash_enterprise_auth.get_user_data``) work inside + WebSocket callbacks. Returns: AttributeDict with callback context @@ -217,6 +224,14 @@ def create_ws_context( g.updated_props = {} g.dash_websocket = websocket_callback + request_context = request_context or {} + g.cookies = request_context.get("cookies", {}) + g.headers = request_context.get("headers", {}) + g.args = request_context.get("args", "") + g.path = request_context.get("path", "") + g.remote = request_context.get("remote", "") + g.origin = request_context.get("origin", "") + return g @@ -396,6 +411,7 @@ def run_callback_in_executor( payload: dict, ws_callback: DashWebsocketCallback, response_adapter: "ResponseAdapter", + request_context: "dict | None" = None, ) -> concurrent.futures.Future: """Submit callback to executor for thread pool execution. @@ -408,6 +424,9 @@ def run_callback_in_executor( payload: The callback payload from WebSocket message ws_callback: WebSocket callback instance for set_prop/get_prop response_adapter: Response adapter for the backend + request_context: Optional request metadata (cookies, headers, args, + path, remote, origin) captured from the WebSocket handshake, made + available on the callback context. Returns: Future representing the pending callback execution @@ -415,7 +434,9 @@ def run_callback_in_executor( def execute() -> dict: try: - cb_ctx = create_ws_context(payload, response_adapter, ws_callback) + cb_ctx = create_ws_context( + payload, response_adapter, ws_callback, request_context + ) # pylint: disable=protected-access func = dash_app._prepare_callback(cb_ctx, payload) args = dash_app._inputs_to_vals( # pylint: disable=protected-access diff --git a/tests/websocket/test_ws_context.py b/tests/websocket/test_ws_context.py new file mode 100644 index 0000000000..1b77d30fe8 --- /dev/null +++ b/tests/websocket/test_ws_context.py @@ -0,0 +1,56 @@ +"""Unit tests for WebSocket callback context creation. + +These tests verify that request metadata captured from the WebSocket +handshake (cookies, headers, etc.) is propagated onto the callback +context. This is required so authentication helpers that read +``callback_context.cookies``/``headers`` (such as +``dash_enterprise_auth.get_user_data``) work inside WebSocket callbacks. +""" + +from dash.backends.ws import create_ws_context + + +def test_create_ws_context_propagates_request_context(): + """Request metadata should be copied onto the callback context.""" + payload = { + "inputs": [], + "state": [], + "outputs": [], + "changedPropIds": [], + } + request_context = { + "cookies": {"kcIdToken": "token-value"}, + "headers": {"Plotly-User-Data": "{}"}, + "args": {"foo": "bar"}, + "path": "/_dash-ws-callback", + "remote": "10.0.0.1", + "origin": "https://example.com", + } + + g = create_ws_context(payload, response_adapter=None, websocket_callback=None, request_context=request_context) + + assert g.cookies == {"kcIdToken": "token-value"} + assert g.headers == {"Plotly-User-Data": "{}"} + assert g.args == {"foo": "bar"} + assert g.path == "/_dash-ws-callback" + assert g.remote == "10.0.0.1" + assert g.origin == "https://example.com" + + +def test_create_ws_context_defaults_without_request_context(): + """Context should expose empty defaults when no request context is given.""" + payload = { + "inputs": [], + "state": [], + "outputs": [], + "changedPropIds": [], + } + + g = create_ws_context(payload, response_adapter=None, websocket_callback=None) + + assert g.cookies == {} + assert g.headers == {} + assert g.args == "" + assert g.path == "" + assert g.remote == "" + assert g.origin == ""