Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion sdks/python/pmxt/ws_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
"""

import json
import logging
import socket
import threading
import time
Expand All @@ -21,6 +22,7 @@
MAX_QUEUED_MESSAGES_PER_SUBSCRIPTION = 100_000
CONNECT_ATTEMPTS = 3
_NO_DATA = object()
logger = logging.getLogger(__name__)


def _connect_websocket(ws: Any, url: str, timeout: float) -> None:
Expand Down Expand Up @@ -130,7 +132,10 @@ def _ensure_connected(self) -> None:
try:
ws.close()
except Exception:
pass
logger.debug(
"Failed to close unsuccessful WebSocket connection",
exc_info=True,
)
if attempt < CONNECT_ATTEMPTS - 1:
time.sleep(0.25 * (attempt + 1))

Expand Down
55 changes: 53 additions & 2 deletions sdks/python/tests/test_ws_client.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,35 @@
import importlib.util
import logging
import pathlib
import threading
import time
import socket
import sys
import types

import pmxt.ws_client as ws_client
from pmxt.ws_client import SidecarWsClient, _WsSubscription, _connect_websocket

def _load_ws_client_module():
"""Load pmxt.ws_client without requiring generated pmxt_internal artifacts."""
package_dir = pathlib.Path(__file__).resolve().parents[1] / "pmxt"
package = sys.modules.setdefault("pmxt", types.ModuleType("pmxt"))
package.__path__ = [str(package_dir)]

errors_spec = importlib.util.spec_from_file_location("pmxt.errors", package_dir / "errors.py")
errors_module = importlib.util.module_from_spec(errors_spec)
sys.modules["pmxt.errors"] = errors_module
errors_spec.loader.exec_module(errors_module)

spec = importlib.util.spec_from_file_location("pmxt.ws_client", package_dir / "ws_client.py")
module = importlib.util.module_from_spec(spec)
sys.modules["pmxt.ws_client"] = module
spec.loader.exec_module(module)
return module


ws_client = _load_ws_client_module()
SidecarWsClient = ws_client.SidecarWsClient
_WsSubscription = ws_client._WsSubscription
_connect_websocket = ws_client._connect_websocket


def _register_subscription(client, request_id="req-firehose"):
Expand Down Expand Up @@ -133,3 +157,30 @@ def fake_connect(_ws, _url, timeout):

assert attempts["count"] == 3
assert client._ws is not None


def test_logs_failed_cleanup_when_retrying_handshake_failure(monkeypatch, caplog):
class FakeWebSocket:
def close(self):
raise RuntimeError("close failed")

def fail_connect(*_args, **_kwargs):
raise OSError("handshake failed")

monkeypatch.setitem(sys.modules, "websocket", types.SimpleNamespace(WebSocket=FakeWebSocket))
monkeypatch.setattr(ws_client, "_connect_websocket", fail_connect)
monkeypatch.setattr(ws_client.time, "sleep", lambda _seconds: None)

client = SidecarWsClient("https://api.pmxt.dev", api_key="pmxt_test")

with caplog.at_level(logging.DEBUG, logger="pmxt.ws_client"):
try:
with client._lock:
client._ensure_connected()
except OSError as exc:
assert str(exc) == "handshake failed"
else:
raise AssertionError("expected handshake failure")

assert "Failed to close unsuccessful WebSocket connection" in caplog.text
assert "close failed" in caplog.text
Loading