From ff38a4ccf3d5197dff2459475f7201ab677143d5 Mon Sep 17 00:00:00 2001 From: Pigbibi <20649888+Pigbibi@users.noreply.github.com> Date: Wed, 3 Jun 2026 22:52:24 +0800 Subject: [PATCH] fix schwab market data transient retry --- pyproject.toml | 2 +- setup.py | 2 +- src/quant_platform_kit/__init__.py | 2 +- src/quant_platform_kit/schwab/market_data.py | 112 +++++++++++++++++-- tests/test_schwab_market_data.py | 70 +++++++++++- 5 files changed, 172 insertions(+), 16 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 3231a02..a7d58fc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "quant-platform-kit" -version = "0.7.36" +version = "0.7.37" description = "Shared broker adapters, domain models, execution ports, and notification utilities for QuantStrategyLab strategies." readme = "README.md" requires-python = ">=3.9" diff --git a/setup.py b/setup.py index 46ef4d6..d4256dd 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ setup( name="quant-platform-kit", - version="0.7.36", + version="0.7.37", description="Shared broker adapters, domain models, execution ports, and notification utilities for QuantStrategyLab strategies.", package_dir={"": "src"}, packages=find_packages(where="src"), diff --git a/src/quant_platform_kit/__init__.py b/src/quant_platform_kit/__init__.py index 47efc00..2fddb4b 100644 --- a/src/quant_platform_kit/__init__.py +++ b/src/quant_platform_kit/__init__.py @@ -4,7 +4,7 @@ used by older strategy repositories. """ -__version__ = "0.7.36" +__version__ = "0.7.37" from .common.models import ( ExecutionReport, diff --git a/src/quant_platform_kit/schwab/market_data.py b/src/quant_platform_kit/schwab/market_data.py index fad205d..391197d 100644 --- a/src/quant_platform_kit/schwab/market_data.py +++ b/src/quant_platform_kit/schwab/market_data.py @@ -1,10 +1,102 @@ from __future__ import annotations -from datetime import datetime -from typing import Any +import os +import time +from datetime import datetime, timezone +from email.utils import parsedate_to_datetime +from typing import Any, Callable, Optional from quant_platform_kit.common.models import QuoteSnapshot +RETRYABLE_STATUS_CODES = {429, 500, 502, 503, 504} +DEFAULT_HTTP_MAX_ATTEMPTS = 4 +DEFAULT_HTTP_BACKOFF_SECONDS = 1.0 +DEFAULT_HTTP_MAX_BACKOFF_SECONDS = 8.0 + + +def _env_int(name: str, default: int, *, minimum: int, maximum: int) -> int: + raw_value = os.environ.get(name) + if not raw_value: + return default + try: + value = int(raw_value) + except ValueError: + return default + return min(max(value, minimum), maximum) + + +def _env_float(name: str, default: float, *, minimum: float, maximum: float) -> float: + raw_value = os.environ.get(name) + if not raw_value: + return default + try: + value = float(raw_value) + except ValueError: + return default + return min(max(value, minimum), maximum) + + +def _header_value(headers: Any, name: str) -> Optional[str]: + if not headers: + return None + if hasattr(headers, "get"): + value = headers.get(name) + if value is None: + value = headers.get(name.lower()) + if value is None: + value = headers.get(name.upper()) + return str(value).strip() if value is not None else None + return None + + +def _retry_after_seconds(response: Any, fallback_seconds: float, max_seconds: float) -> float: + raw_value = _header_value(getattr(response, "headers", None), "Retry-After") + if not raw_value: + return min(fallback_seconds, max_seconds) + try: + return min(max(float(raw_value), 0.0), max_seconds) + except ValueError: + pass + + try: + retry_at = parsedate_to_datetime(raw_value) + except (TypeError, ValueError): + return min(fallback_seconds, max_seconds) + if retry_at.tzinfo is None: + retry_at = retry_at.replace(tzinfo=timezone.utc) + wait_seconds = (retry_at - datetime.now(timezone.utc)).total_seconds() + return min(max(wait_seconds, 0.0), max_seconds) + + +def _request_with_retries(request_fn: Callable[[], Any]) -> Any: + max_attempts = _env_int("QPK_SCHWAB_HTTP_MAX_ATTEMPTS", DEFAULT_HTTP_MAX_ATTEMPTS, minimum=1, maximum=8) + backoff_seconds = _env_float( + "QPK_SCHWAB_HTTP_BACKOFF_SECONDS", + DEFAULT_HTTP_BACKOFF_SECONDS, + minimum=0.0, + maximum=30.0, + ) + max_backoff_seconds = _env_float( + "QPK_SCHWAB_HTTP_MAX_BACKOFF_SECONDS", + DEFAULT_HTTP_MAX_BACKOFF_SECONDS, + minimum=0.0, + maximum=60.0, + ) + + response = None + for attempt in range(1, max_attempts + 1): + response = request_fn() + status_code = getattr(response, "status_code", None) + if status_code not in RETRYABLE_STATUS_CODES or attempt >= max_attempts: + return response + + fallback_seconds = backoff_seconds * (2 ** (attempt - 1)) + wait_seconds = _retry_after_seconds(response, fallback_seconds, max_backoff_seconds) + if wait_seconds > 0: + time.sleep(wait_seconds) + + return response + def decode_response_json(response: Any, context: str) -> Any: if response.status_code not in (200, 201): @@ -18,12 +110,14 @@ def decode_response_json(response: Any, context: str) -> Any: def fetch_default_daily_price_history_candles(api_client: Any, symbol: str) -> list[dict[str, Any]]: from schwab import client - response = api_client.get_price_history( - symbol, - period_type=client.Client.PriceHistory.PeriodType.YEAR, - period=client.Client.PriceHistory.Period.TWO_YEARS, - frequency_type=client.Client.PriceHistory.FrequencyType.DAILY, - frequency=client.Client.PriceHistory.Frequency.DAILY, + response = _request_with_retries( + lambda: api_client.get_price_history( + symbol, + period_type=client.Client.PriceHistory.PeriodType.YEAR, + period=client.Client.PriceHistory.Period.TWO_YEARS, + frequency_type=client.Client.PriceHistory.FrequencyType.DAILY, + frequency=client.Client.PriceHistory.Frequency.DAILY, + ) ) payload = decode_response_json(response, f"{symbol} history") candles = payload.get("candles") @@ -33,7 +127,7 @@ def fetch_default_daily_price_history_candles(api_client: Any, symbol: str) -> l def fetch_quotes(api_client: Any, symbols: list[str] | tuple[str, ...]) -> dict[str, QuoteSnapshot]: - payload = decode_response_json(api_client.get_quotes(symbols), "Quotes") + payload = decode_response_json(_request_with_retries(lambda: api_client.get_quotes(symbols)), "Quotes") as_of = datetime.utcnow() snapshots: dict[str, QuoteSnapshot] = {} for symbol in symbols: diff --git a/tests/test_schwab_market_data.py b/tests/test_schwab_market_data.py index f5550ef..7cc4022 100644 --- a/tests/test_schwab_market_data.py +++ b/tests/test_schwab_market_data.py @@ -5,13 +5,18 @@ import unittest from unittest.mock import patch -from quant_platform_kit.schwab.market_data import fetch_default_daily_price_history_candles, fetch_quotes +from quant_platform_kit.schwab.market_data import ( + decode_response_json, + fetch_default_daily_price_history_candles, + fetch_quotes, +) class FakeResponse: - def __init__(self, payload, status_code=200): + def __init__(self, payload, status_code=200, headers=None): self._payload = payload self.status_code = status_code + self.headers = headers or {} self.text = str(payload) def json(self): @@ -32,7 +37,7 @@ def get_quotes(self, symbols): class SchwabMarketDataTests(unittest.TestCase): - def test_fetch_default_daily_price_history_candles(self) -> None: + def _install_fake_schwab_module(self): schwab_module = types.ModuleType("schwab") client_module = types.ModuleType("schwab.client") client_module.Client = types.SimpleNamespace( @@ -43,13 +48,70 @@ def test_fetch_default_daily_price_history_candles(self) -> None: Frequency=types.SimpleNamespace(DAILY="DAILY"), ) ) + return patch.dict(sys.modules, {"schwab": schwab_module, "schwab.client": client_module}) - with patch.dict(sys.modules, {"schwab": schwab_module, "schwab.client": client_module}): + def test_fetch_default_daily_price_history_candles(self) -> None: + with self._install_fake_schwab_module(): candles = fetch_default_daily_price_history_candles(FakeClient(), "QQQ") self.assertEqual(len(candles), 2) self.assertEqual(candles[-1]["close"], 11.0) + def test_fetch_default_daily_price_history_retries_rate_limit(self) -> None: + class RateLimitedClient: + def __init__(self): + self.calls = 0 + + def get_price_history(self, symbol, **_kwargs): + self.calls += 1 + if self.calls == 1: + return FakeResponse({"error": "rate limited"}, status_code=429, headers={"Retry-After": "0.25"}) + return FakeResponse({"candles": [{"close": 12.0}]}) + + rate_limited_client = RateLimitedClient() + with self._install_fake_schwab_module(), patch( + "quant_platform_kit.schwab.market_data.time.sleep" + ) as sleep_mock: + candles = fetch_default_daily_price_history_candles(rate_limited_client, "SOXL") + + self.assertEqual(candles, [{"close": 12.0}]) + self.assertEqual(rate_limited_client.calls, 2) + sleep_mock.assert_called_once_with(0.25) + + def test_fetch_quotes_retries_transient_server_error(self) -> None: + class FlakyQuoteClient: + def __init__(self): + self.calls = 0 + + def get_quotes(self, symbols): + self.calls += 1 + if self.calls < 3: + return FakeResponse({"error": "unavailable"}, status_code=503) + return FakeClient().get_quotes(symbols) + + flaky_client = FlakyQuoteClient() + with patch("quant_platform_kit.schwab.market_data.time.sleep") as sleep_mock: + snapshots = fetch_quotes(flaky_client, ["TQQQ"]) + + self.assertEqual(snapshots["TQQQ"].last_price, 100.0) + self.assertEqual(flaky_client.calls, 3) + self.assertEqual([call.args[0] for call in sleep_mock.call_args_list], [1.0, 2.0]) + + def test_retry_exhaustion_keeps_original_error_context(self) -> None: + class AlwaysRateLimitedClient: + def get_price_history(self, symbol, **_kwargs): + return FakeResponse({"error": "rate limited"}, status_code=429) + + with self._install_fake_schwab_module(), patch( + "quant_platform_kit.schwab.market_data.time.sleep" + ), patch.dict("os.environ", {"QPK_SCHWAB_HTTP_MAX_ATTEMPTS": "2"}): + with self.assertRaisesRegex(RuntimeError, "SOXL history failed: 429"): + fetch_default_daily_price_history_candles(AlwaysRateLimitedClient(), "SOXL") + + def test_decode_response_json_still_reports_non_retryable_errors(self) -> None: + with self.assertRaisesRegex(RuntimeError, "Quotes failed: 400"): + decode_response_json(FakeResponse({"error": "bad request"}, status_code=400), "Quotes") + def test_fetch_quotes_returns_snapshots(self) -> None: snapshots = fetch_quotes(FakeClient(), ["TQQQ", "BOXX"])