diff --git a/README.md b/README.md index 04fce66..459038a 100644 --- a/README.md +++ b/README.md @@ -27,6 +27,7 @@ pip install netboxlabs-diode-sdk * `DIODE_SENTRY_DSN` - Optional Sentry DSN for error reporting * `DIODE_CLIENT_ID` - Client ID for OAuth2 authentication * `DIODE_CLIENT_SECRET` - Client Secret for OAuth2 authentication +* `DIODE_MAX_AUTH_RETRIES` - Maximum attempts for OAuth2 token fetch and gRPC re-authentication on `Unauthenticated` (default: `3`). Token fetch retries with exponential backoff on `429`, `500`, `502`, and `503`, honouring `Retry-After` when present on `429`/`503`. * `DIODE_CERT_FILE` - Path to custom certificate file for TLS connections * `DIODE_SKIP_TLS_VERIFY` - Skip TLS verification (default: `false`) * `DIODE_DRY_RUN_OUTPUT_DIR` - Directory where `DiodeDryRunClient` will write JSON files diff --git a/netboxlabs/diode/sdk/client.py b/netboxlabs/diode/sdk/client.py index 1334366..590d2dd 100644 --- a/netboxlabs/diode/sdk/client.py +++ b/netboxlabs/diode/sdk/client.py @@ -7,11 +7,14 @@ import logging import os import platform +import random import sys import tempfile import time import uuid -from collections.abc import Iterable +from collections.abc import Callable, Iterable +from datetime import datetime, timezone +from email.utils import parsedate_to_datetime from pathlib import Path from typing import Any from urllib.parse import urlparse @@ -50,6 +53,8 @@ _INGEST_SCOPE = "diode:ingest" _LOGGER = logging.getLogger(__name__) _MAX_RETRIES_ENVVAR_NAME = "DIODE_MAX_AUTH_RETRIES" +_AUTH_INITIAL_RETRY_DELAY = 1.0 +_AUTH_MAX_RETRY_DELAY = 30.0 # server policy (MinTime 10s so client pings must be >= 10s, e.g. 30s interval). _GRPC_KEEPALIVE_TIME_MS = 30_000 _GRPC_KEEPALIVE_TIMEOUT_MS = 10_000 @@ -541,6 +546,7 @@ def _authenticate(self, scope: str): self._app_version, self._certificates, self._cert_file, + max_retries=self._max_auth_retries, ) access_token = authentication_client.authenticate() self._metadata = list( @@ -932,6 +938,10 @@ def __init__( app_version: str, certificates: bytes | None = None, cert_file: str | None = None, + max_retries: int = 3, + initial_retry_delay: float | None = None, + max_retry_delay: float | None = None, + sleep: Callable[[float], None] | None = None, ): self._target = target self._tls_verify = tls_verify @@ -945,6 +955,14 @@ def __init__( self._app_version = app_version self._certificates = certificates self._cert_file = cert_file + self._max_retries = max_retries + self._initial_retry_delay = ( + _AUTH_INITIAL_RETRY_DELAY if initial_retry_delay is None else initial_retry_delay + ) + self._max_retry_delay = ( + _AUTH_MAX_RETRY_DELAY if max_retry_delay is None else max_retry_delay + ) + self._sleep = sleep or time.sleep def authenticate(self) -> str: """Request an OAuth2 token using client credentials and return it.""" @@ -952,53 +970,8 @@ def authenticate(self) -> str: temp_cert_file = None try: - # Configure SSL verification - if self._tls_verify and self._certificates: - # Use cert_file path directly if available, otherwise write to temp file - if self._cert_file: - session.verify = self._cert_file - else: - # Write certificates to temp file for requests - with tempfile.NamedTemporaryFile( - mode="wb", delete=False, suffix=".pem" - ) as f: - f.write(self._certificates) - temp_cert_file = f.name - session.verify = temp_cert_file - elif not self._tls_verify: - session.verify = False - - # Prepare auth request - url = self._get_full_auth_url() - data = { - "grant_type": "client_credentials", - "client_id": self._client_id, - "client_secret": self._client_secret, - "scope": self._scope, - } - headers = { - "Content-Type": "application/x-www-form-urlencoded", - "User-Agent": f"{self._sdk_name}/{self._sdk_version} {self._app_name}/{self._app_version}", - } - - response = session.post(url, data=data, headers=headers) - - if response.status_code != 200: - raise DiodeConfigError( - f"Failed to obtain access token: {response.reason}" - ) - - token_info = response.json() - access_token = token_info.get("access_token") - - if not access_token: - raise DiodeConfigError( - f"Failed to obtain access token for client {self._client_id}" - ) - - _LOGGER.debug(f"Access token obtained for client {self._client_id}") - return access_token - + temp_cert_file = self._configure_auth_session(session) + return self._request_access_token(session) except requests.RequestException as e: raise DiodeConfigError(f"Failed to obtain access token: {e}") finally: @@ -1012,6 +985,73 @@ def authenticate(self) -> str: f"Failed to clean up temp certificate file {temp_cert_file}: {e}" ) + def _configure_auth_session(self, session: requests.Session) -> str | None: + temp_cert_file = None + if self._tls_verify and self._certificates: + # Use cert_file path directly if available, otherwise write to temp file + if self._cert_file: + session.verify = self._cert_file + else: + # Write certificates to temp file for requests + with tempfile.NamedTemporaryFile( + mode="wb", delete=False, suffix=".pem" + ) as f: + f.write(self._certificates) + temp_cert_file = f.name + session.verify = temp_cert_file + elif not self._tls_verify: + session.verify = False + return temp_cert_file + + def _request_access_token(self, session: requests.Session) -> str: + url = self._get_full_auth_url() + data = { + "grant_type": "client_credentials", + "client_id": self._client_id, + "client_secret": self._client_secret, + "scope": self._scope, + } + headers = { + "Content-Type": "application/x-www-form-urlencoded", + "User-Agent": f"{self._sdk_name}/{self._sdk_version} {self._app_name}/{self._app_version}", + } + + last_error = "Failed to obtain access token" + for attempt in range(1, self._max_retries + 1): + response = session.post(url, data=data, headers=headers) + + if response.status_code == 200: + access_token = response.json().get("access_token") + if not access_token: + raise DiodeConfigError( + f"Failed to obtain access token for client {self._client_id}" + ) + _LOGGER.debug(f"Access token obtained for client {self._client_id}") + return access_token + + last_error = f"Failed to obtain access token: {response.reason}" + if not _is_retriable_auth_http_status(response.status_code) or attempt >= self._max_retries: + raise DiodeConfigError(last_error) + + delay = _auth_retry_delay( + attempt, + response.status_code, + response.headers.get("Retry-After"), + self._initial_retry_delay, + self._max_retry_delay, + ) + _LOGGER.debug( + "Auth token request failed, retrying", + extra={ + "status_code": response.status_code, + "attempt": attempt, + "retry_in": delay, + }, + ) + self._sleep(delay) + + raise DiodeConfigError(last_error) + def _get_auth_url(self) -> str: """Construct the authentication URL, handling trailing slashes in the path.""" # Ensure the path does not have trailing slashes @@ -1036,6 +1076,49 @@ def _get_full_auth_url(self) -> str: return f"{scheme}://{self._target}{path}/auth/token" +def _is_retriable_auth_http_status(status_code: int) -> bool: + return status_code in {429, 500, 502, 503} + + +def _parse_retry_after(value: str | None) -> float | None: + if not value: + return None + try: + seconds = int(value) + except ValueError: + seconds = None + else: + if seconds < 0: + return None + return float(seconds) + + try: + retry_at = parsedate_to_datetime(value) + if retry_at.tzinfo is None: + retry_at = retry_at.replace(tzinfo=timezone.utc) + delay = (retry_at - datetime.now(retry_at.tzinfo)).total_seconds() + return max(delay, 0.0) + except (TypeError, ValueError, OverflowError): + return None + + +def _auth_retry_delay( + attempt: int, + status_code: int, + retry_after: str | None, + initial_delay: float, + max_delay: float, +) -> float: + delay: float | None = None + if status_code in (429, 503): + delay = _parse_retry_after(retry_after) + if delay is None: + delay = initial_delay * (2 ** (attempt - 1)) + delay = min(delay, max_delay) + delay += random.uniform(0, delay / 4) + return min(delay, max_delay) + + class _ClientCallDetails( collections.namedtuple( "_ClientCallDetails", diff --git a/tests/test_client.py b/tests/test_client.py index 622711c..54dfbcc 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -12,6 +12,7 @@ import pytest from netboxlabs.diode.sdk.client import ( + _auth_retry_delay, _DIODE_SENTRY_DSN_ENVVAR_NAME, DiodeClient, DiodeDryRunClient, @@ -21,8 +22,10 @@ _diode_ingest_grpc_channel_options, _DiodeAuthentication, _get_sentry_dsn, + _is_retriable_auth_http_status, _load_certs, _otlp_grpc_channel_options, + _parse_retry_after, load_dryrun_entities, parse_target, ) @@ -712,6 +715,163 @@ def test_diode_authentication_request_exception(mock_diode_authentication): assert "Failed to obtain access token: Connection error" in str(excinfo.value) +@pytest.mark.parametrize( + ("status_code", "expected"), + [ + (429, True), + (500, True), + (502, True), + (503, True), + (401, False), + (403, False), + (400, False), + (200, False), + ], +) +def test_is_retriable_auth_http_status(status_code, expected): + """Check retriable auth HTTP status classification.""" + assert _is_retriable_auth_http_status(status_code) is expected + + +def test_parse_retry_after_seconds(): + """Parse Retry-After header values in seconds.""" + assert _parse_retry_after("5") == 5.0 + + +def test_parse_retry_after_invalid(): + """Return None for invalid Retry-After values.""" + assert _parse_retry_after("not-a-date") is None + assert _parse_retry_after("") is None + + +def test_auth_retry_delay_exponential(): + """Apply exponential backoff when Retry-After is absent.""" + delay1 = _auth_retry_delay(1, 500, None, 1.0, 30.0) + delay2 = _auth_retry_delay(2, 500, None, 1.0, 30.0) + delay3 = _auth_retry_delay(3, 502, None, 1.0, 30.0) + assert 1.0 <= delay1 <= 1.25 + assert 2.0 <= delay2 <= 2.5 + assert 4.0 <= delay3 <= 5.0 + + +def test_auth_retry_delay_honours_retry_after(): + """Honour Retry-After for 429 and 503 responses.""" + delay429 = _auth_retry_delay(1, 429, "7", 1.0, 30.0) + delay503 = _auth_retry_delay(1, 503, "12", 1.0, 30.0) + assert 7.0 <= delay429 <= 8.75 + assert 12.0 <= delay503 <= 15.0 + + +def test_auth_retry_delay_caps_retry_after(): + """Cap Retry-After delays at the configured maximum.""" + delay = _auth_retry_delay(1, 429, "120", 1.0, 30.0) + assert delay == 30.0 + + +def test_auth_retry_delay_never_exceeds_max(): + """Jitter must not push the final delay above the configured maximum.""" + for attempt in range(1, 8): + delay = _auth_retry_delay(attempt, 500, None, 1.0, 30.0) + assert delay <= 30.0 + + +def test_diode_authentication_retries_retriable_status(mock_diode_authentication): + """Retry auth token fetch on transient HTTP failures.""" + auth = _DiodeAuthentication( + target="localhost:8081", + path="/diode", + tls_verify=False, + client_id="test_client_id", + client_secret="test_client_secret", + scope="diode:ingest", + sdk_name="diode-sdk-python", + sdk_version="0.1.0", + app_name="test-app", + app_version="1.0.0", + max_retries=3, + initial_retry_delay=0, + max_retry_delay=0, + sleep=lambda _delay: None, + ) + responses = [ + mock.Mock( + status_code=503, + reason="Service Unavailable", + headers={"Retry-After": "0"}, + ), + mock.Mock( + status_code=200, + json=mock.Mock(return_value={"access_token": "mocked_token"}), + ), + ] + with mock.patch("requests.Session") as mock_session_class: + mock_session = mock_session_class.return_value + mock_session.post.side_effect = responses + + token = auth.authenticate() + assert token == "mocked_token" + assert mock_session.post.call_count == 2 + + +def test_diode_authentication_fails_fast_on_401(mock_diode_authentication): + """Do not retry auth token fetch on 401 responses.""" + auth = _DiodeAuthentication( + target="localhost:8081", + path="/diode", + tls_verify=False, + client_id="test_client_id", + client_secret="test_client_secret", + scope="diode:ingest", + sdk_name="diode-sdk-python", + sdk_version="0.1.0", + app_name="test-app", + app_version="1.0.0", + max_retries=3, + initial_retry_delay=0, + max_retry_delay=0, + sleep=lambda _delay: None, + ) + with mock.patch("requests.Session") as mock_session_class: + mock_session = mock_session_class.return_value + mock_session.post.return_value = mock.Mock(status_code=401, reason="Unauthorized") + + with pytest.raises(DiodeConfigError) as excinfo: + auth.authenticate() + assert "Failed to obtain access token: Unauthorized" in str(excinfo.value) + assert mock_session.post.call_count == 1 + + +def test_diode_authentication_exhausts_retries(mock_diode_authentication): + """Raise after exhausting auth retry attempts.""" + auth = _DiodeAuthentication( + target="localhost:8081", + path="/diode", + tls_verify=False, + client_id="test_client_id", + client_secret="test_client_secret", + scope="diode:ingest", + sdk_name="diode-sdk-python", + sdk_version="0.1.0", + app_name="test-app", + app_version="1.0.0", + max_retries=2, + initial_retry_delay=0, + max_retry_delay=0, + sleep=lambda _delay: None, + ) + with mock.patch("requests.Session") as mock_session_class: + mock_session = mock_session_class.return_value + mock_session.post.return_value = mock.Mock( + status_code=503, + reason="Service Unavailable", + ) + + with pytest.raises(DiodeConfigError) as excinfo: + auth.authenticate() + assert "Failed to obtain access token: Service Unavailable" in str(excinfo.value) + assert mock_session.post.call_count == 2 + + def test_ingest_dry_run_stdout(capsys): """Verify ingest prints JSON when dry run is enabled.""" client = DiodeDryRunClient()