Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions dash/backends/_fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -725,6 +725,19 @@ async def websocket_handler(websocket: WebSocket):

await websocket.accept()

# Capture request metadata from the WebSocket handshake once per
# connection so that callbacks running over the WebSocket transport
# can access cookies/headers (e.g. for authentication helpers such
# as dash_enterprise_auth.get_user_data).
request_context = {
"cookies": dict(websocket.cookies),
"headers": dict(websocket.headers),
"args": dict(websocket.query_params),
"path": websocket.url.path,
"remote": websocket.client.host if websocket.client else "",
"origin": websocket.headers.get("origin", ""),
}

# Create janus queue for outbound messages (main loop context)
outbound_queue: janus.Queue[str] = janus.Queue()
# Track pending get_props requests with standard queue.Queue for responses
Expand Down Expand Up @@ -788,6 +801,7 @@ async def websocket_handler(websocket: WebSocket):
payload,
ws_cb,
FastAPIResponseAdapter(),
request_context,
)

# Set up done callback to send response
Expand Down
14 changes: 14 additions & 0 deletions dash/backends/_quart.py
Original file line number Diff line number Diff line change
Expand Up @@ -545,6 +545,19 @@ async def websocket_handler(): # pylint: disable=too-many-branches

await ws.accept()

# Capture request metadata from the WebSocket handshake once per
# connection so that callbacks running over the WebSocket transport
# can access cookies/headers (e.g. for authentication helpers such
# as dash_enterprise_auth.get_user_data).
request_context = {
"cookies": dict(ws.cookies),
"headers": dict(ws.headers),
"args": dict(ws.args),
"path": ws.path,
"remote": ws.remote_addr,
"origin": ws.headers.get("origin", ""),
}

# Track this connection for graceful shutdown
try:
ws_obj = ws._get_current_object()
Expand Down Expand Up @@ -623,6 +636,7 @@ async def websocket_handler(): # pylint: disable=too-many-branches
payload,
ws_cb,
QuartResponseAdapter(),
request_context,
)

# Set up done callback to send response
Expand Down
23 changes: 22 additions & 1 deletion dash/backends/ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,13 +189,20 @@ def create_ws_context(
payload: dict,
response_adapter: "ResponseAdapter",
websocket_callback: DashWebsocketCallback,
request_context: "dict | None" = None,
):
"""Create callback context from WebSocket message.

Args:
payload: The callback payload
response_adapter: The response adapter instance for the backend
websocket_callback: The websocket callback instance for the backend
request_context: Optional request metadata (cookies, headers, args,
path, remote, origin) captured from the WebSocket handshake. This
mirrors the context populated for regular HTTP callbacks so that
``callback_context.cookies``/``headers`` (and downstream helpers
such as ``dash_enterprise_auth.get_user_data``) work inside
WebSocket callbacks.

Returns:
AttributeDict with callback context
Expand All @@ -217,6 +224,14 @@ def create_ws_context(
g.updated_props = {}
g.dash_websocket = websocket_callback

request_context = request_context or {}
g.cookies = request_context.get("cookies", {})
g.headers = request_context.get("headers", {})
g.args = request_context.get("args", "")
g.path = request_context.get("path", "")
g.remote = request_context.get("remote", "")
g.origin = request_context.get("origin", "")

return g


Expand Down Expand Up @@ -396,6 +411,7 @@ def run_callback_in_executor(
payload: dict,
ws_callback: DashWebsocketCallback,
response_adapter: "ResponseAdapter",
request_context: "dict | None" = None,
) -> concurrent.futures.Future:
"""Submit callback to executor for thread pool execution.

Expand All @@ -408,14 +424,19 @@ def run_callback_in_executor(
payload: The callback payload from WebSocket message
ws_callback: WebSocket callback instance for set_prop/get_prop
response_adapter: Response adapter for the backend
request_context: Optional request metadata (cookies, headers, args,
path, remote, origin) captured from the WebSocket handshake, made
available on the callback context.

Returns:
Future representing the pending callback execution
"""

def execute() -> dict:
try:
cb_ctx = create_ws_context(payload, response_adapter, ws_callback)
cb_ctx = create_ws_context(
payload, response_adapter, ws_callback, request_context
)
# pylint: disable=protected-access
func = dash_app._prepare_callback(cb_ctx, payload)
args = dash_app._inputs_to_vals( # pylint: disable=protected-access
Expand Down
56 changes: 56 additions & 0 deletions tests/websocket/test_ws_context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
"""Unit tests for WebSocket callback context creation.

These tests verify that request metadata captured from the WebSocket
handshake (cookies, headers, etc.) is propagated onto the callback
context. This is required so authentication helpers that read
``callback_context.cookies``/``headers`` (such as
``dash_enterprise_auth.get_user_data``) work inside WebSocket callbacks.
"""

from dash.backends.ws import create_ws_context


def test_create_ws_context_propagates_request_context():
"""Request metadata should be copied onto the callback context."""
payload = {
"inputs": [],
"state": [],
"outputs": [],
"changedPropIds": [],
}
request_context = {
"cookies": {"kcIdToken": "token-value"},
"headers": {"Plotly-User-Data": "{}"},
"args": {"foo": "bar"},
"path": "/_dash-ws-callback",
"remote": "10.0.0.1",

Check warning on line 26 in tests/websocket/test_ws_context.py

View check run for this annotation

SonarQubeCloud / SonarCloud Code Analysis

Make sure using this hardcoded IP address "10.0.0.1" is safe here.

See more on https://sonarcloud.io/project/issues?id=plotly_dash&issues=AZ67Ff_7YiCndbFqAufS&open=AZ67Ff_7YiCndbFqAufS&pullRequest=3815
"origin": "https://example.com",
}

g = create_ws_context(payload, response_adapter=None, websocket_callback=None, request_context=request_context)

Check failure on line 30 in tests/websocket/test_ws_context.py

View check run for this annotation

SonarQubeCloud / SonarCloud Code Analysis

Change this argument; Function "create_ws_context" expects a different type

See more on https://sonarcloud.io/project/issues?id=plotly_dash&issues=AZ67Ff_7YiCndbFqAufT&open=AZ67Ff_7YiCndbFqAufT&pullRequest=3815

assert g.cookies == {"kcIdToken": "token-value"}
assert g.headers == {"Plotly-User-Data": "{}"}
assert g.args == {"foo": "bar"}
assert g.path == "/_dash-ws-callback"
assert g.remote == "10.0.0.1"

Check warning on line 36 in tests/websocket/test_ws_context.py

View check run for this annotation

SonarQubeCloud / SonarCloud Code Analysis

Make sure using this hardcoded IP address "10.0.0.1" is safe here.

See more on https://sonarcloud.io/project/issues?id=plotly_dash&issues=AZ67Ff_7YiCndbFqAufU&open=AZ67Ff_7YiCndbFqAufU&pullRequest=3815
assert g.origin == "https://example.com"


def test_create_ws_context_defaults_without_request_context():
"""Context should expose empty defaults when no request context is given."""
payload = {
"inputs": [],
"state": [],
"outputs": [],
"changedPropIds": [],
}

g = create_ws_context(payload, response_adapter=None, websocket_callback=None)

Check failure on line 49 in tests/websocket/test_ws_context.py

View check run for this annotation

SonarQubeCloud / SonarCloud Code Analysis

Change this argument; Function "create_ws_context" expects a different type

See more on https://sonarcloud.io/project/issues?id=plotly_dash&issues=AZ67Ff_7YiCndbFqAufV&open=AZ67Ff_7YiCndbFqAufV&pullRequest=3815

assert g.cookies == {}
assert g.headers == {}
assert g.args == ""
assert g.path == ""
assert g.remote == ""
assert g.origin == ""
Loading