From 44d170a709e6a6fdb5aec9588b5c3b783ae29946 Mon Sep 17 00:00:00 2001 From: James Jeffries Date: Fri, 26 Jun 2026 10:28:33 +0100 Subject: [PATCH 1/3] feat(OBS-3341): add backoff on auth for 429/5xx responses Retry OAuth token fetch on transient HTTP failures with exponential backoff, jitter, and Retry-After support while honouring DIODE_MAX_AUTH_RETRIES. Co-authored-by: Cursor --- README.md | 1 + netboxlabs/diode/sdk/client.py | 300 +++++++++++++++-------------- tests/test_client.py | 341 +++++++++++++++++++-------------- 3 files changed, 349 insertions(+), 293 deletions(-) diff --git a/README.md b/README.md index 04fce66..459038a 100644 --- a/README.md +++ b/README.md @@ -27,6 +27,7 @@ pip install netboxlabs-diode-sdk * `DIODE_SENTRY_DSN` - Optional Sentry DSN for error reporting * `DIODE_CLIENT_ID` - Client ID for OAuth2 authentication * `DIODE_CLIENT_SECRET` - Client Secret for OAuth2 authentication +* `DIODE_MAX_AUTH_RETRIES` - Maximum attempts for OAuth2 token fetch and gRPC re-authentication on `Unauthenticated` (default: `3`). Token fetch retries with exponential backoff on `429`, `500`, `502`, and `503`, honouring `Retry-After` when present on `429`/`503`. * `DIODE_CERT_FILE` - Path to custom certificate file for TLS connections * `DIODE_SKIP_TLS_VERIFY` - Skip TLS verification (default: `false`) * `DIODE_DRY_RUN_OUTPUT_DIR` - Directory where `DiodeDryRunClient` will write JSON files diff --git a/netboxlabs/diode/sdk/client.py b/netboxlabs/diode/sdk/client.py index 1334366..05a16ac 100644 --- a/netboxlabs/diode/sdk/client.py +++ b/netboxlabs/diode/sdk/client.py @@ -7,11 +7,14 @@ import logging import os import platform +import random import sys import tempfile import time import uuid -from collections.abc import Iterable +from collections.abc import Callable, Iterable +from datetime import datetime, timezone +from email.utils import parsedate_to_datetime from pathlib import Path from typing import Any from urllib.parse import urlparse @@ -50,6 +53,8 @@ _INGEST_SCOPE = "diode:ingest" _LOGGER = logging.getLogger(__name__) _MAX_RETRIES_ENVVAR_NAME = "DIODE_MAX_AUTH_RETRIES" +_AUTH_INITIAL_RETRY_DELAY = 1.0 +_AUTH_MAX_RETRY_DELAY = 30.0 # server policy (MinTime 10s so client pings must be >= 10s, e.g. 30s interval). _GRPC_KEEPALIVE_TIME_MS = 30_000 _GRPC_KEEPALIVE_TIMEOUT_MS = 10_000 @@ -99,9 +104,7 @@ def parse_target(target: str) -> tuple[str, str, bool]: parsed_target = urlparse(target) if parsed_target.scheme not in ["grpc", "grpcs", "http", "https"]: - raise ValueError( - "target should start with grpc://, grpcs://, http:// or https://" - ) + raise ValueError("target should start with grpc://, grpcs://, http:// or https://") # Determine if TLS verification should be enabled tls_verify = _should_verify_tls(parsed_target.scheme) @@ -129,15 +132,11 @@ def _get_required_config_value(env_var_name: str, value: str | None = None) -> s if value is None: value = os.getenv(env_var_name) if value is None: - raise DiodeConfigError( - f"parameter or {env_var_name} environment variable required" - ) + raise DiodeConfigError(f"parameter or {env_var_name} environment variable required") return value -def _get_optional_config_value( - env_var_name: str, value: str | None = None -) -> str | None: +def _get_optional_config_value(env_var_name: str, value: str | None = None) -> str | None: """Get optional config value either from provided value or environment variable.""" if value is None: value = os.getenv(env_var_name) @@ -252,17 +251,11 @@ def _should_bypass_proxy(target_host: str) -> bool: # Maximum reasonable length for hostname/domain (RFC 1035: 253 chars, we allow 256) MAX_NO_PROXY_ENTRY_LENGTH = 256 - no_proxy_list = [ - entry.strip().lower() - for entry in no_proxy.split(",") - if len(entry.strip()) <= MAX_NO_PROXY_ENTRY_LENGTH - ] + no_proxy_list = [entry.strip().lower() for entry in no_proxy.split(",") if len(entry.strip()) <= MAX_NO_PROXY_ENTRY_LENGTH] filtered_count = len([e for e in no_proxy.split(",") if len(e.strip()) > MAX_NO_PROXY_ENTRY_LENGTH]) if filtered_count > 0: - _LOGGER.warning( - f"Ignored {filtered_count} NO_PROXY entries exceeding {MAX_NO_PROXY_ENTRY_LENGTH} characters" - ) + _LOGGER.warning(f"Ignored {filtered_count} NO_PROXY entries exceeding {MAX_NO_PROXY_ENTRY_LENGTH} characters") for entry in no_proxy_list: if entry and _matches_no_proxy_entry(host, entry): @@ -298,8 +291,7 @@ def _get_grpc_proxy_url(target_host: str, use_tls: bool) -> str | None: if proxy_url: if not _validate_proxy_url(proxy_url): _LOGGER.warning( - f"Invalid proxy URL format: {proxy_url}. " - f"Proxy URL must be http:// or https:// with valid host. Ignoring proxy." + f"Invalid proxy URL format: {proxy_url}. Proxy URL must be http:// or https:// with valid host. Ignoring proxy." ) return None _LOGGER.debug(f"Using proxy {proxy_url} for gRPC target {target_host}") @@ -334,21 +326,12 @@ def __init__( log_level = os.getenv(_DIODE_SDK_LOG_LEVEL_ENVVAR_NAME, "INFO").upper() logging.basicConfig(level=log_level) - self._max_auth_retries = int( - _get_optional_config_value(_MAX_RETRIES_ENVVAR_NAME, str(max_auth_retries)) - or max_auth_retries - ) - self._cert_file = _get_optional_config_value( - _DIODE_CERT_FILE_ENVVAR_NAME, cert_file - ) + self._max_auth_retries = int(_get_optional_config_value(_MAX_RETRIES_ENVVAR_NAME, str(max_auth_retries)) or max_auth_retries) + self._cert_file = _get_optional_config_value(_DIODE_CERT_FILE_ENVVAR_NAME, cert_file) self._target, self._path, self._tls_verify = parse_target(target) # Load certificates once if needed - self._certificates = ( - _load_certs(self._cert_file) - if (self._tls_verify or self._cert_file) - else None - ) + self._certificates = _load_certs(self._cert_file) if (self._tls_verify or self._cert_file) else None self._app_name = app_name self._app_version = app_version self._platform = platform.platform() @@ -356,9 +339,7 @@ def __init__( # Read client credentials from environment variables self._client_id = _get_required_config_value(_CLIENT_ID_ENVVAR_NAME, client_id) - self._client_secret = _get_required_config_value( - _CLIENT_SECRET_ENVVAR_NAME, client_secret - ) + self._client_secret = _get_required_config_value(_CLIENT_SECRET_ENVVAR_NAME, client_secret) self._metadata = ( ("platform", self._platform), @@ -367,9 +348,7 @@ def __init__( self._authenticate(_INGEST_SCOPE) - channel_opts = _diode_ingest_grpc_channel_options( - f"{self._name}/{self._version} {self._app_name}/{self._app_version}" - ) + channel_opts = _diode_ingest_grpc_channel_options(f"{self._name}/{self._version} {self._app_name}/{self._app_version}") proxy_url = _get_grpc_proxy_url(self._target, self._tls_verify) if proxy_url: @@ -381,9 +360,7 @@ def __init__( # Channel creation logic if self._tls_verify: credentials = ( - grpc.ssl_channel_credentials(root_certificates=self._certificates) - if self._certificates - else grpc.ssl_channel_credentials() + grpc.ssl_channel_credentials(root_certificates=self._certificates) if self._certificates else grpc.ssl_channel_credentials() ) _LOGGER.debug( @@ -409,9 +386,7 @@ def __init__( _LOGGER.debug(f"Setting up gRPC interceptor for path: {self._path}") rpc_method_interceptor = DiodeMethodClientInterceptor(subpath=self._path) - intercept_channel = grpc.intercept_channel( - self._channel, rpc_method_interceptor - ) + intercept_channel = grpc.intercept_channel(self._channel, rpc_method_interceptor) channel = intercept_channel self._stub = ingester_pb2_grpc.IngesterServiceStub(channel) @@ -420,9 +395,7 @@ def __init__( if self._sentry_dsn is not None: _LOGGER.debug("Setting up Sentry") - self._setup_sentry( - self._sentry_dsn, sentry_traces_sample_rate, sentry_profiles_sample_rate - ) + self._setup_sentry(self._sentry_dsn, sentry_traces_sample_rate, sentry_profiles_sample_rate) @property def name(self) -> str: @@ -502,17 +475,13 @@ def ingest( except grpc.RpcError as err: if err.code() == grpc.StatusCode.UNAUTHENTICATED: if attempt < self._max_auth_retries - 1: - _LOGGER.info( - f"Retrying ingestion due to UNAUTHENTICATED error, attempt {attempt + 1}" - ) + _LOGGER.info(f"Retrying ingestion due to UNAUTHENTICATED error, attempt {attempt + 1}") self._authenticate(_INGEST_SCOPE) continue raise DiodeClientError(err) from err raise RuntimeError("Max retries exceeded") - def _setup_sentry( - self, dsn: str, traces_sample_rate: float, profiles_sample_rate: float - ): + def _setup_sentry(self, dsn: str, traces_sample_rate: float, profiles_sample_rate: float): sentry_sdk.init( dsn=dsn, release=self.version, @@ -541,11 +510,10 @@ def _authenticate(self, scope: str): self._app_version, self._certificates, self._cert_file, + max_retries=self._max_auth_retries, ) access_token = authentication_client.authenticate() - self._metadata = list( - filter(lambda x: x[0] != "authorization", self._metadata) - ) + [("authorization", f"Bearer {access_token}")] + self._metadata = list(filter(lambda x: x[0] != "authorization", self._metadata)) + [("authorization", f"Bearer {access_token}")] class DiodeDryRunClient(DiodeClientInterface): @@ -613,9 +581,7 @@ def ingest( timestamp = time.perf_counter_ns() path = Path(self._output_dir) path.mkdir(parents=True, exist_ok=True) - filename = "".join( - c if c.isalnum() or c in ("_", "-") else "_" for c in self._app_name - ) + filename = "".join(c if c.isalnum() or c in ("_", "-") else "_" for c in self._app_name) file_path = path / f"{filename}_{timestamp}.json" with file_path.open("w") as fh: fh.write(output) @@ -651,18 +617,10 @@ def __init__( self._timeout = timeout self._target, self._path, self._tls_verify = parse_target(target) - self._cert_file = _get_optional_config_value( - _DIODE_CERT_FILE_ENVVAR_NAME, cert_file - ) - self._certificates = ( - _load_certs(self._cert_file) - if (self._tls_verify or self._cert_file) - else None - ) + self._cert_file = _get_optional_config_value(_DIODE_CERT_FILE_ENVVAR_NAME, cert_file) + self._certificates = _load_certs(self._cert_file) if (self._tls_verify or self._cert_file) else None - channel_opts = _otlp_grpc_channel_options( - f"{self._name}/{self._version} {self._app_name}/{self._app_version}" - ) + channel_opts = _otlp_grpc_channel_options(f"{self._name}/{self._version} {self._app_name}/{self._app_version}") proxy_url = _get_grpc_proxy_url(self._target, self._tls_verify) if proxy_url: @@ -678,9 +636,7 @@ def __init__( # Channel creation logic if self._tls_verify: credentials = ( - grpc.ssl_channel_credentials(root_certificates=self._certificates) - if self._certificates - else grpc.ssl_channel_credentials() + grpc.ssl_channel_credentials(root_certificates=self._certificates) if self._certificates else grpc.ssl_channel_credentials() ) _LOGGER.debug( @@ -772,10 +728,7 @@ def ingest( ) -> ingester_pb2.IngestResponse: """Export entities as OTLP logs with optional request-level metadata.""" stream = stream or _DEFAULT_STREAM - log_records = [ - self._entity_to_log_record(entity) - for entity in self._normalize_entities(entities) - ] + log_records = [self._entity_to_log_record(entity) for entity in self._normalize_entities(entities)] if not log_records: return ingester_pb2.IngestResponse() @@ -793,9 +746,7 @@ def ingest( return ingester_pb2.IngestResponse() - def _normalize_entities( - self, entities: Iterable[Entity | ingester_pb2.Entity | None] - ) -> list[ingester_pb2.Entity]: + def _normalize_entities(self, entities: Iterable[Entity | ingester_pb2.Entity | None]) -> list[ingester_pb2.Entity]: normalized: list[ingester_pb2.Entity] = [] for entity in entities: if entity is None: @@ -813,9 +764,7 @@ def _build_export_request( ) -> logs_service_pb2.ExportLogsServiceRequest: resource_logs = logs_pb2.ResourceLogs() resource_logs.resource.attributes.extend(self._resource_attributes()) - resource_logs.resource.attributes.append( - self._string_kv("diode.stream", stream) - ) + resource_logs.resource.attributes.append(self._string_kv("diode.stream", stream)) # Add request-level metadata as resource attributes with diode.metadata.* prefix if metadata: @@ -867,9 +816,7 @@ def _entity_to_log_record( @staticmethod def _string_kv(key: str, value: str) -> common_pb2.KeyValue: - return common_pb2.KeyValue( - key=key, value=common_pb2.AnyValue(string_value=value) - ) + return common_pb2.KeyValue(key=key, value=common_pb2.AnyValue(string_value=value)) @staticmethod def _value_to_any_value(value: Any) -> common_pb2.AnyValue | None: # noqa: C901 @@ -892,18 +839,14 @@ def _value_to_any_value(value: Any) -> common_pb2.AnyValue | None: # noqa: C901 any_value = DiodeOTLPClient._value_to_any_value(item) if any_value: array_values.append(any_value) - return common_pb2.AnyValue( - array_value=common_pb2.ArrayValue(values=array_values) - ) + return common_pb2.AnyValue(array_value=common_pb2.ArrayValue(values=array_values)) if isinstance(value, dict): # Recursively convert dict to KeyValueList kvlist = common_pb2.KeyValueList() for k, v in value.items(): any_value = DiodeOTLPClient._value_to_any_value(v) if any_value: - kvlist.values.append( - common_pb2.KeyValue(key=k, value=any_value) - ) + kvlist.values.append(common_pb2.KeyValue(key=k, value=any_value)) return common_pb2.AnyValue(kvlist_value=kvlist) # Skip unsupported types return None @@ -932,6 +875,10 @@ def __init__( app_version: str, certificates: bytes | None = None, cert_file: str | None = None, + max_retries: int = 3, + initial_retry_delay: float | None = None, + max_retry_delay: float | None = None, + sleep: Callable[[float], None] | None = None, ): self._target = target self._tls_verify = tls_verify @@ -945,6 +892,10 @@ def __init__( self._app_version = app_version self._certificates = certificates self._cert_file = cert_file + self._max_retries = max_retries + self._initial_retry_delay = _AUTH_INITIAL_RETRY_DELAY if initial_retry_delay is None else initial_retry_delay + self._max_retry_delay = _AUTH_MAX_RETRY_DELAY if max_retry_delay is None else max_retry_delay + self._sleep = sleep or time.sleep def authenticate(self) -> str: """Request an OAuth2 token using client credentials and return it.""" @@ -952,65 +903,78 @@ def authenticate(self) -> str: temp_cert_file = None try: - # Configure SSL verification - if self._tls_verify and self._certificates: - # Use cert_file path directly if available, otherwise write to temp file - if self._cert_file: - session.verify = self._cert_file - else: - # Write certificates to temp file for requests - with tempfile.NamedTemporaryFile( - mode="wb", delete=False, suffix=".pem" - ) as f: - f.write(self._certificates) - temp_cert_file = f.name - session.verify = temp_cert_file - elif not self._tls_verify: - session.verify = False - - # Prepare auth request - url = self._get_full_auth_url() - data = { - "grant_type": "client_credentials", - "client_id": self._client_id, - "client_secret": self._client_secret, - "scope": self._scope, - } - headers = { - "Content-Type": "application/x-www-form-urlencoded", - "User-Agent": f"{self._sdk_name}/{self._sdk_version} {self._app_name}/{self._app_version}", - } - - response = session.post(url, data=data, headers=headers) - - if response.status_code != 200: - raise DiodeConfigError( - f"Failed to obtain access token: {response.reason}" - ) - - token_info = response.json() - access_token = token_info.get("access_token") - - if not access_token: - raise DiodeConfigError( - f"Failed to obtain access token for client {self._client_id}" - ) - - _LOGGER.debug(f"Access token obtained for client {self._client_id}") - return access_token - + temp_cert_file = self._configure_auth_session(session) + return self._request_access_token(session) except requests.RequestException as e: raise DiodeConfigError(f"Failed to obtain access token: {e}") finally: - # Clean up temp certificate file if temp_cert_file and os.path.exists(temp_cert_file): try: os.unlink(temp_cert_file) _LOGGER.debug(f"Cleaned up temp certificate file: {temp_cert_file}") except OSError as e: - _LOGGER.warning( - f"Failed to clean up temp certificate file {temp_cert_file}: {e}" - ) + _LOGGER.warning(f"Failed to clean up temp certificate file {temp_cert_file}: {e}") + + def _configure_auth_session(self, session: requests.Session) -> str | None: + temp_cert_file = None + if self._tls_verify and self._certificates: + if self._cert_file: + session.verify = self._cert_file + else: + with tempfile.NamedTemporaryFile(mode="wb", delete=False, suffix=".pem") as f: + f.write(self._certificates) + temp_cert_file = f.name + session.verify = temp_cert_file + elif not self._tls_verify: + session.verify = False + return temp_cert_file + + def _request_access_token(self, session: requests.Session) -> str: + url = self._get_full_auth_url() + data = { + "grant_type": "client_credentials", + "client_id": self._client_id, + "client_secret": self._client_secret, + "scope": self._scope, + } + headers = { + "Content-Type": "application/x-www-form-urlencoded", + "User-Agent": f"{self._sdk_name}/{self._sdk_version} {self._app_name}/{self._app_version}", + } + + last_error = "Failed to obtain access token" + for attempt in range(1, self._max_retries + 1): + response = session.post(url, data=data, headers=headers) + + if response.status_code == 200: + access_token = response.json().get("access_token") + if not access_token: + raise DiodeConfigError(f"Failed to obtain access token for client {self._client_id}") + _LOGGER.debug(f"Access token obtained for client {self._client_id}") + return access_token + + last_error = f"Failed to obtain access token: {response.reason}" + if not _is_retriable_auth_http_status(response.status_code) or attempt >= self._max_retries: + raise DiodeConfigError(last_error) + + delay = _auth_retry_delay( + attempt, + response.status_code, + response.headers.get("Retry-After"), + self._initial_retry_delay, + self._max_retry_delay, + ) + _LOGGER.debug( + "Auth token request failed, retrying", + extra={ + "status_code": response.status_code, + "attempt": attempt, + "retry_in": delay, + }, + ) + self._sleep(delay) + + raise DiodeConfigError(last_error) def _get_auth_url(self) -> str: """Construct the authentication URL, handling trailing slashes in the path.""" @@ -1036,6 +1000,48 @@ def _get_full_auth_url(self) -> str: return f"{scheme}://{self._target}{path}/auth/token" +def _is_retriable_auth_http_status(status_code: int) -> bool: + return status_code in {429, 500, 502, 503} + + +def _parse_retry_after(value: str | None) -> float | None: + if not value: + return None + try: + seconds = int(value) + except ValueError: + seconds = None + else: + if seconds < 0: + return None + return float(seconds) + + try: + retry_at = parsedate_to_datetime(value) + if retry_at.tzinfo is None: + retry_at = retry_at.replace(tzinfo=timezone.utc) + delay = (retry_at - datetime.now(retry_at.tzinfo)).total_seconds() + return max(delay, 0.0) + except (TypeError, ValueError, OverflowError): + return None + + +def _auth_retry_delay( + attempt: int, + status_code: int, + retry_after: str | None, + initial_delay: float, + max_delay: float, +) -> float: + delay: float | None = None + if status_code in (429, 503): + delay = _parse_retry_after(retry_after) + if delay is None: + delay = initial_delay * (2 ** (attempt - 1)) + delay = min(delay, max_delay) + return delay + random.uniform(0, delay / 4) + + class _ClientCallDetails( collections.namedtuple( "_ClientCallDetails", @@ -1058,9 +1064,7 @@ class _ClientCallDetails( """ -class DiodeMethodClientInterceptor( - grpc.UnaryUnaryClientInterceptor, grpc.StreamUnaryClientInterceptor -): +class DiodeMethodClientInterceptor(grpc.UnaryUnaryClientInterceptor, grpc.StreamUnaryClientInterceptor): """ Diode Method Client Interceptor class. @@ -1099,8 +1103,6 @@ def intercept_unary_unary(self, continuation, client_call_details, request): """Intercept unary unary.""" return self._intercept_call(continuation, client_call_details, request) - def intercept_stream_unary( - self, continuation, client_call_details, request_iterator - ): + def intercept_stream_unary(self, continuation, client_call_details, request_iterator): """Intercept stream unary.""" return self._intercept_call(continuation, client_call_details, request_iterator) diff --git a/tests/test_client.py b/tests/test_client.py index 622711c..b1432ee 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -17,12 +17,15 @@ DiodeDryRunClient, DiodeMethodClientInterceptor, DiodeOTLPClient, + _auth_retry_delay, _ClientCallDetails, _diode_ingest_grpc_channel_options, _DiodeAuthentication, _get_sentry_dsn, + _is_retriable_auth_http_status, _load_certs, _otlp_grpc_channel_options, + _parse_retry_after, load_dryrun_entities, parse_target, ) @@ -72,9 +75,7 @@ def test_config_errors(client_id, client_secret, env_var_name): client_id=client_id, client_secret=client_secret, ) - assert ( - str(err.value) == f"parameter or {env_var_name} environment variable required" - ) + assert str(err.value) == f"parameter or {env_var_name} environment variable required" def test_client_error(mock_diode_authentication): @@ -100,10 +101,7 @@ def test_diode_client_error_repr_returns_correct_string(): error = DiodeClientError(grpc_error) error._status_code = grpc.StatusCode.UNAVAILABLE error._details = "Some details about the error" - assert ( - repr(error) - == "" - ) + assert repr(error) == "" def test_load_certs_returns_bytes(): @@ -271,9 +269,7 @@ def test_insecure_channel_options_with_primary_user_agent(mock_diode_authenticat mock_insecure_channel.assert_called_once() _, kwargs = mock_insecure_channel.call_args assert kwargs["options"] == tuple( - _diode_ingest_grpc_channel_options( - f"{client.name}/{client.version} {client.app_name}/{client.app_version}" - ) + _diode_ingest_grpc_channel_options(f"{client.name}/{client.version} {client.app_name}/{client.app_version}") ) @@ -291,9 +287,7 @@ def test_secure_channel_options_with_primary_user_agent(mock_diode_authenticatio mock_secure_channel.assert_called_once() _, kwargs = mock_secure_channel.call_args assert kwargs["options"] == tuple( - _diode_ingest_grpc_channel_options( - f"{client.name}/{client.version} {client.app_name}/{client.app_version}" - ) + _diode_ingest_grpc_channel_options(f"{client.name}/{client.version} {client.app_name}/{client.app_version}") ) @@ -370,9 +364,7 @@ def test_client_setup_sentry_called_when_sentry_dsn_exists(mock_diode_authentica client_secret="123456", sentry_dsn="https://user@password.mock.dsn/123456", ) - mock_setup_sentry.assert_called_once_with( - "https://user@password.mock.dsn/123456", 1.0, 1.0 - ) + mock_setup_sentry.assert_called_once_with("https://user@password.mock.dsn/123456", 1.0, 1.0) def test_client_setup_sentry_not_called_when_sentry_dsn_not_exists( @@ -498,10 +490,7 @@ def continuation(x, _): None, ) request = None - assert ( - interceptor.intercept_unary_unary(continuation, client_call_details, request) - == "/my/path/diode.v1.IngesterService/Ingest" - ) + assert interceptor.intercept_unary_unary(continuation, client_call_details, request) == "/my/path/diode.v1.IngesterService/Ingest" def test_interceptor_intercepts_stream_unary_calls(): @@ -521,9 +510,7 @@ def continuation(x, _): ) request_iterator = None assert ( - interceptor.intercept_stream_unary( - continuation, client_call_details, request_iterator - ) + interceptor.intercept_stream_unary(continuation, client_call_details, request_iterator) == "/my/path/diode.v1.IngesterService/Ingest" ) @@ -705,6 +692,7 @@ def test_diode_authentication_request_exception(mock_diode_authentication): mock_session = mock_session_class.return_value # Import requests.RequestException for the side effect import requests + mock_session.post.side_effect = requests.RequestException("Connection error") with pytest.raises(DiodeConfigError) as excinfo: @@ -712,6 +700,152 @@ def test_diode_authentication_request_exception(mock_diode_authentication): assert "Failed to obtain access token: Connection error" in str(excinfo.value) +@pytest.mark.parametrize( + ("status_code", "expected"), + [ + (429, True), + (500, True), + (502, True), + (503, True), + (401, False), + (403, False), + (400, False), + (200, False), + ], +) +def test_is_retriable_auth_http_status(status_code, expected): + """Check retriable auth HTTP status classification.""" + assert _is_retriable_auth_http_status(status_code) is expected + + +def test_parse_retry_after_seconds(): + """Parse Retry-After header values in seconds.""" + assert _parse_retry_after("5") == 5.0 + + +def test_parse_retry_after_invalid(): + """Return None for invalid Retry-After values.""" + assert _parse_retry_after("not-a-date") is None + assert _parse_retry_after("") is None + + +def test_auth_retry_delay_exponential(): + """Apply exponential backoff when Retry-After is absent.""" + delay1 = _auth_retry_delay(1, 500, None, 1.0, 30.0) + delay2 = _auth_retry_delay(2, 500, None, 1.0, 30.0) + delay3 = _auth_retry_delay(3, 502, None, 1.0, 30.0) + assert 1.0 <= delay1 <= 1.25 + assert 2.0 <= delay2 <= 2.5 + assert 4.0 <= delay3 <= 5.0 + + +def test_auth_retry_delay_honours_retry_after(): + """Honour Retry-After for 429 and 503 responses.""" + delay429 = _auth_retry_delay(1, 429, "7", 1.0, 30.0) + delay503 = _auth_retry_delay(1, 503, "12", 1.0, 30.0) + assert 7.0 <= delay429 <= 8.75 + assert 12.0 <= delay503 <= 15.0 + + +def test_auth_retry_delay_caps_retry_after(): + """Cap Retry-After delays at the configured maximum.""" + delay = _auth_retry_delay(1, 429, "120", 1.0, 30.0) + assert 30.0 <= delay <= 37.5 + + +def test_diode_authentication_retries_retriable_status(mock_diode_authentication): + """Retry auth token fetch on transient HTTP failures.""" + auth = _DiodeAuthentication( + target="localhost:8081", + path="/diode", + tls_verify=False, + client_id="test_client_id", + client_secret="test_client_secret", + scope="diode:ingest", + sdk_name="diode-sdk-python", + sdk_version="0.1.0", + app_name="test-app", + app_version="1.0.0", + max_retries=3, + initial_retry_delay=0, + max_retry_delay=0, + sleep=lambda _delay: None, + ) + responses = [ + mock.Mock(status_code=503, reason="Service Unavailable", headers={"Retry-After": "0"}), + mock.Mock(status_code=200, json=mock.Mock(return_value={"access_token": "mocked_token"})), + ] + with mock.patch("requests.Session") as mock_session_class: + mock_session = mock_session_class.return_value + mock_session.post.side_effect = responses + + token = auth.authenticate() + assert token == "mocked_token" + assert mock_session.post.call_count == 2 + + +def test_diode_authentication_fails_fast_on_401(mock_diode_authentication): + """Do not retry auth token fetch on 401 responses.""" + auth = _DiodeAuthentication( + target="localhost:8081", + path="/diode", + tls_verify=False, + client_id="test_client_id", + client_secret="test_client_secret", + scope="diode:ingest", + sdk_name="diode-sdk-python", + sdk_version="0.1.0", + app_name="test-app", + app_version="1.0.0", + max_retries=3, + initial_retry_delay=0, + max_retry_delay=0, + sleep=lambda _delay: None, + ) + with mock.patch("requests.Session") as mock_session_class: + mock_session = mock_session_class.return_value + mock_response = mock.Mock() + mock_response.status_code = 401 + mock_response.reason = "Unauthorized" + mock_response.headers = {} + mock_session.post.return_value = mock_response + + with pytest.raises(DiodeConfigError): + auth.authenticate() + assert mock_session.post.call_count == 1 + + +def test_diode_authentication_exhausts_retries_on_429(mock_diode_authentication): + """Stop retrying auth token fetch after max attempts on 429.""" + auth = _DiodeAuthentication( + target="localhost:8081", + path="/diode", + tls_verify=False, + client_id="test_client_id", + client_secret="test_client_secret", + scope="diode:ingest", + sdk_name="diode-sdk-python", + sdk_version="0.1.0", + app_name="test-app", + app_version="1.0.0", + max_retries=2, + initial_retry_delay=0, + max_retry_delay=0, + sleep=lambda _delay: None, + ) + with mock.patch("requests.Session") as mock_session_class: + mock_session = mock_session_class.return_value + mock_response = mock.Mock() + mock_response.status_code = 429 + mock_response.reason = "Too Many Requests" + mock_response.headers = {"Retry-After": "0"} + mock_session.post.return_value = mock_response + + with pytest.raises(DiodeConfigError): + auth.authenticate() + assert mock_session.post.call_count == 2 + + def test_ingest_dry_run_stdout(capsys): """Verify ingest prints JSON when dry run is enabled.""" client = DiodeDryRunClient() @@ -767,21 +901,15 @@ def test_load_dryrun_entities_from_fixture(message_path, tmp_path): assert isinstance(entities[0], ingester_pb2.Entity) assert entities[0].asn.asn == 555 assert entities[33].ip_address.address == "192.168.100.1/24" - assert ( - entities[33].ip_address.assigned_object_interface.name == "GigabitEthernet1/0/1" - ) + assert entities[33].ip_address.assigned_object_interface.name == "GigabitEthernet1/0/1" assert entities[-1].wireless_link.ssid == "P2P-Link-1" def test_otlp_client_exports_entities(): """Ensure DiodeOTLPClient serializes entities and exports them as logs.""" with ( - patch( - "netboxlabs.diode.sdk.client.grpc.insecure_channel" - ) as mock_insecure_channel, - patch( - "netboxlabs.diode.sdk.client.logs_service_pb2_grpc.LogsServiceStub" - ) as mock_stub_cls, + patch("netboxlabs.diode.sdk.client.grpc.insecure_channel") as mock_insecure_channel, + patch("netboxlabs.diode.sdk.client.logs_service_pb2_grpc.LogsServiceStub") as mock_stub_cls, ): mock_insecure_channel.return_value = mock.Mock() stub_instance = mock_stub_cls.return_value @@ -792,9 +920,7 @@ def test_otlp_client_exports_entities(): app_version="1.2.3", ) - response = client.ingest( - entities=[Entity(site="Site1"), Entity(device="Device1")] - ) + response = client.ingest(entities=[Entity(site="Site1"), Entity(device="Device1")]) stub_instance.Export.assert_called_once() export_args, export_kwargs = stub_instance.Export.call_args @@ -826,18 +952,12 @@ def details(self): return self._details with ( - patch( - "netboxlabs.diode.sdk.client.grpc.insecure_channel" - ) as mock_insecure_channel, - patch( - "netboxlabs.diode.sdk.client.logs_service_pb2_grpc.LogsServiceStub" - ) as mock_stub_cls, + patch("netboxlabs.diode.sdk.client.grpc.insecure_channel") as mock_insecure_channel, + patch("netboxlabs.diode.sdk.client.logs_service_pb2_grpc.LogsServiceStub") as mock_stub_cls, ): mock_insecure_channel.return_value = mock.Mock() stub_instance = mock_stub_cls.return_value - stub_instance.Export.side_effect = DummyRpcError( - grpc.StatusCode.UNAVAILABLE, "endpoint offline" - ) + stub_instance.Export.side_effect = DummyRpcError(grpc.StatusCode.UNAVAILABLE, "endpoint offline") client = DiodeOTLPClient( target="grpc://collector:4317", @@ -855,13 +975,9 @@ def details(self): def test_otlp_client_grpcs_uses_secure_channel(): """Ensure DiodeOTLPClient configures SSL credentials for secure targets.""" with ( - patch( - "netboxlabs.diode.sdk.client.grpc.ssl_channel_credentials" - ) as mock_ssl_credentials, + patch("netboxlabs.diode.sdk.client.grpc.ssl_channel_credentials") as mock_ssl_credentials, patch("netboxlabs.diode.sdk.client.grpc.secure_channel") as mock_secure_channel, - patch( - "netboxlabs.diode.sdk.client.grpc.intercept_channel" - ) as mock_intercept_channel, + patch("netboxlabs.diode.sdk.client.grpc.intercept_channel") as mock_intercept_channel, patch("netboxlabs.diode.sdk.client.logs_service_pb2_grpc.LogsServiceStub"), ): base_channel = mock.Mock() @@ -898,22 +1014,15 @@ def test_otlp_insecure_channel_options_exclude_diode_keepalive(): mock_insecure.assert_called_once() _, kwargs = mock_insecure.call_args - ua = ( - f"{client.name}/{client.version} " - f"{client.app_name}/{client.app_version}" - ) + ua = f"{client.name}/{client.version} {client.app_name}/{client.app_version}" assert kwargs["options"] == tuple(_otlp_grpc_channel_options(ua)) - assert all( - opt[0] != "grpc.keepalive_time_ms" for opt in kwargs["options"] - ) + assert all(opt[0] != "grpc.keepalive_time_ms" for opt in kwargs["options"]) def test_diode_authentication_with_custom_certificates(): """Test _DiodeAuthentication with custom certificates - covers SSL context creation.""" # Create test certificate content - cert_content = ( - b"-----BEGIN CERTIFICATE-----\nTEST CERT\n-----END CERTIFICATE-----\n" - ) + cert_content = b"-----BEGIN CERTIFICATE-----\nTEST CERT\n-----END CERTIFICATE-----\n" auth = _DiodeAuthentication( target="example.com:443", @@ -984,9 +1093,7 @@ def test_diode_authentication_with_custom_certificates(): def test_load_certs_with_custom_cert_file(tmp_path): """Test _load_certs loads custom certificate file.""" # Create a dummy certificate file - cert_content = ( - b"-----BEGIN CERTIFICATE-----\nTEST CERT\n-----END CERTIFICATE-----\n" - ) + cert_content = b"-----BEGIN CERTIFICATE-----\nTEST CERT\n-----END CERTIFICATE-----\n" cert_file = tmp_path / "custom.pem" cert_file.write_bytes(cert_content) @@ -1004,9 +1111,7 @@ def test_load_certs_with_none_uses_default(): def test_client_with_cert_file_parameter(mock_diode_authentication, tmp_path): """Test DiodeClient with cert_file parameter loads custom cert but respects TLS scheme.""" # Create a dummy certificate file - cert_content = ( - b"-----BEGIN CERTIFICATE-----\nTEST CERT\n-----END CERTIFICATE-----\n" - ) + cert_content = b"-----BEGIN CERTIFICATE-----\nTEST CERT\n-----END CERTIFICATE-----\n" cert_file = tmp_path / "custom.pem" cert_file.write_bytes(cert_content) @@ -1035,9 +1140,7 @@ def test_client_with_cert_file_env_var(mock_diode_authentication, tmp_path): from netboxlabs.diode.sdk.client import _DIODE_CERT_FILE_ENVVAR_NAME # Create a dummy certificate file - cert_content = ( - b"-----BEGIN CERTIFICATE-----\nTEST CERT\n-----END CERTIFICATE-----\n" - ) + cert_content = b"-----BEGIN CERTIFICATE-----\nTEST CERT\n-----END CERTIFICATE-----\n" cert_file = tmp_path / "custom.pem" cert_file.write_bytes(cert_content) @@ -1073,19 +1176,13 @@ def test_client_with_cert_file_env_var(mock_diode_authentication, tmp_path): del os.environ[_DIODE_CERT_FILE_ENVVAR_NAME] -def test_client_cert_file_parameter_overrides_env_var( - mock_diode_authentication, tmp_path -): +def test_client_cert_file_parameter_overrides_env_var(mock_diode_authentication, tmp_path): """Test cert_file parameter takes precedence over environment variable.""" from netboxlabs.diode.sdk.client import _DIODE_CERT_FILE_ENVVAR_NAME # Create two dummy certificate files - env_cert_content = ( - b"-----BEGIN CERTIFICATE-----\nENV CERT\n-----END CERTIFICATE-----\n" - ) - param_cert_content = ( - b"-----BEGIN CERTIFICATE-----\nPARAM CERT\n-----END CERTIFICATE-----\n" - ) + env_cert_content = b"-----BEGIN CERTIFICATE-----\nENV CERT\n-----END CERTIFICATE-----\n" + param_cert_content = b"-----BEGIN CERTIFICATE-----\nPARAM CERT\n-----END CERTIFICATE-----\n" env_cert_file = tmp_path / "env.pem" param_cert_file = tmp_path / "param.pem" @@ -1127,9 +1224,7 @@ def test_client_cert_file_parameter_overrides_env_var( def test_client_secure_channel_uses_custom_cert(mock_diode_authentication, tmp_path): """Test secure channel creation uses custom certificate when provided.""" # Create a dummy certificate file - cert_content = ( - b"-----BEGIN CERTIFICATE-----\nTEST CERT\n-----END CERTIFICATE-----\n" - ) + cert_content = b"-----BEGIN CERTIFICATE-----\nTEST CERT\n-----END CERTIFICATE-----\n" cert_file = tmp_path / "custom.pem" cert_file.write_bytes(cert_content) @@ -1180,9 +1275,7 @@ def test_client_without_cert_file_uses_default_certs(mock_diode_authentication): mock_load_certs.assert_called_with(None) # Verify ssl_channel_credentials was called with default cert content - mock_ssl_creds.assert_called_once_with( - root_certificates=b"default cert content" - ) + mock_ssl_creds.assert_called_once_with(root_certificates=b"default cert content") # Verify secure_channel was called mock_secure_channel.assert_called_once() @@ -1218,16 +1311,12 @@ def test_should_verify_tls_with_skip_env_var(): # Test truthy values that should skip TLS verification for skip_value in ["true", "True", "TRUE", "1", "yes", "on"]: os.environ[_DIODE_SKIP_TLS_VERIFY_ENVVAR_NAME] = skip_value - assert ( - _should_verify_tls("grpcs") is False - ) # Should skip even for secure schemes + assert _should_verify_tls("grpcs") is False # Should skip even for secure schemes # Test falsy values that should NOT skip TLS verification for verify_value in ["false", "0", "no", "off", "", "random"]: os.environ[_DIODE_SKIP_TLS_VERIFY_ENVVAR_NAME] = verify_value - assert ( - _should_verify_tls("grpcs") is True - ) # Should verify for secure schemes + assert _should_verify_tls("grpcs") is True # Should verify for secure schemes finally: # Clean up environment variable @@ -1272,16 +1361,12 @@ def test_client_with_skip_tls_verify_env_var(mock_diode_authentication): del os.environ[_DIODE_SKIP_TLS_VERIFY_ENVVAR_NAME] -def test_client_cert_file_with_skip_tls_verify_env_var( - mock_diode_authentication, tmp_path -): +def test_client_cert_file_with_skip_tls_verify_env_var(mock_diode_authentication, tmp_path): """Test cert_file parameter with DIODE_SKIP_TLS_VERIFY environment variable.""" from netboxlabs.diode.sdk.client import _DIODE_SKIP_TLS_VERIFY_ENVVAR_NAME # Create a dummy certificate file - cert_content = ( - b"-----BEGIN CERTIFICATE-----\nTEST CERT\n-----END CERTIFICATE-----\n" - ) + cert_content = b"-----BEGIN CERTIFICATE-----\nTEST CERT\n-----END CERTIFICATE-----\n" cert_file = tmp_path / "custom.pem" cert_file.write_bytes(cert_content) @@ -1322,17 +1407,13 @@ def test_client_cert_file_with_skip_tls_verify_env_var( def test_certificate_loading_efficiency(tmp_path): """Test that certificates are loaded only once during client initialization.""" # Create a dummy certificate file - cert_content = ( - b"-----BEGIN CERTIFICATE-----\nTEST CERT\n-----END CERTIFICATE-----\n" - ) + cert_content = b"-----BEGIN CERTIFICATE-----\nTEST CERT\n-----END CERTIFICATE-----\n" cert_file = tmp_path / "custom.pem" cert_file.write_bytes(cert_content) with ( mock.patch("netboxlabs.diode.sdk.client._load_certs") as mock_load_certs, - mock.patch( - "netboxlabs.diode.sdk.client._DiodeAuthentication" - ) as mock_auth_class, + mock.patch("netboxlabs.diode.sdk.client._DiodeAuthentication") as mock_auth_class, ): mock_load_certs.return_value = cert_content mock_auth_instance = mock_auth_class.return_value @@ -1493,12 +1574,8 @@ def test_dryrun_client_includes_metadata_in_output(tmp_path): def test_otlp_client_maps_metadata_to_resource_attributes(): """Test DiodeOTLPClient maps request metadata to OTLP resource attributes.""" with ( - patch( - "netboxlabs.diode.sdk.client.grpc.insecure_channel" - ) as mock_insecure_channel, - patch( - "netboxlabs.diode.sdk.client.logs_service_pb2_grpc.LogsServiceStub" - ) as mock_stub_cls, + patch("netboxlabs.diode.sdk.client.grpc.insecure_channel") as mock_insecure_channel, + patch("netboxlabs.diode.sdk.client.logs_service_pb2_grpc.LogsServiceStub") as mock_stub_cls, ): mock_insecure_channel.return_value = mock.Mock() stub_instance = mock_stub_cls.return_value @@ -1541,12 +1618,8 @@ def test_otlp_client_maps_metadata_to_resource_attributes(): def test_otlp_client_handles_nested_metadata(): """Test DiodeOTLPClient handles nested metadata structures.""" with ( - patch( - "netboxlabs.diode.sdk.client.grpc.insecure_channel" - ) as mock_insecure_channel, - patch( - "netboxlabs.diode.sdk.client.logs_service_pb2_grpc.LogsServiceStub" - ) as mock_stub_cls, + patch("netboxlabs.diode.sdk.client.grpc.insecure_channel") as mock_insecure_channel, + patch("netboxlabs.diode.sdk.client.logs_service_pb2_grpc.LogsServiceStub") as mock_stub_cls, ): mock_insecure_channel.return_value = mock.Mock() stub_instance = mock_stub_cls.return_value @@ -1607,12 +1680,8 @@ def test_otlp_client_handles_nested_metadata(): def test_otlp_client_metadata_type_conversion(): """Test DiodeOTLPClient correctly converts different Python types.""" with ( - patch( - "netboxlabs.diode.sdk.client.grpc.insecure_channel" - ) as mock_insecure_channel, - patch( - "netboxlabs.diode.sdk.client.logs_service_pb2_grpc.LogsServiceStub" - ) as mock_stub_cls, + patch("netboxlabs.diode.sdk.client.grpc.insecure_channel") as mock_insecure_channel, + patch("netboxlabs.diode.sdk.client.logs_service_pb2_grpc.LogsServiceStub") as mock_stub_cls, ): mock_insecure_channel.return_value = mock.Mock() stub_instance = mock_stub_cls.return_value @@ -1656,12 +1725,8 @@ def test_otlp_client_metadata_type_conversion(): def test_otlp_client_without_metadata(): """Test DiodeOTLPClient works without metadata (backward compatibility).""" with ( - patch( - "netboxlabs.diode.sdk.client.grpc.insecure_channel" - ) as mock_insecure_channel, - patch( - "netboxlabs.diode.sdk.client.logs_service_pb2_grpc.LogsServiceStub" - ) as mock_stub_cls, + patch("netboxlabs.diode.sdk.client.grpc.insecure_channel") as mock_insecure_channel, + patch("netboxlabs.diode.sdk.client.logs_service_pb2_grpc.LogsServiceStub") as mock_stub_cls, ): mock_insecure_channel.return_value = mock.Mock() stub_instance = mock_stub_cls.return_value @@ -1833,9 +1898,7 @@ def test_diode_client_configures_proxy_option(mock_diode_authentication): options = kwargs["options"] # Check that grpc.http_proxy option is present - proxy_option = next( - (opt for opt in options if opt[0] == "grpc.http_proxy"), None - ) + proxy_option = next((opt for opt in options if opt[0] == "grpc.http_proxy"), None) assert proxy_option is not None assert proxy_option[1] == "http://proxy.example.com:8080" finally: @@ -1864,9 +1927,7 @@ def test_diode_client_uses_insecure_channel_with_proxy_when_skip_tls( options = kwargs["options"] # Verify proxy option is set - proxy_option = next( - (opt for opt in options if opt[0] == "grpc.http_proxy"), None - ) + proxy_option = next((opt for opt in options if opt[0] == "grpc.http_proxy"), None) assert proxy_option is not None assert proxy_option[1] == "http://proxy.example.com:8080" finally: @@ -1893,9 +1954,7 @@ def test_diode_client_respects_no_proxy_for_target(mock_diode_authentication): options = kwargs["options"] # Check that grpc.http_proxy option is NOT present - proxy_option = next( - (opt for opt in options if opt[0] == "grpc.http_proxy"), None - ) + proxy_option = next((opt for opt in options if opt[0] == "grpc.http_proxy"), None) assert proxy_option is None finally: del os.environ["HTTP_PROXY"] @@ -1904,9 +1963,7 @@ def test_diode_client_respects_no_proxy_for_target(mock_diode_authentication): def test_diode_client_with_proxy_and_custom_cert(mock_diode_authentication, tmp_path): """Test DiodeClient with proxy and custom certificate (for MITM proxies).""" - cert_content = ( - b"-----BEGIN CERTIFICATE-----\nTEST CERT\n-----END CERTIFICATE-----\n" - ) + cert_content = b"-----BEGIN CERTIFICATE-----\nTEST CERT\n-----END CERTIFICATE-----\n" cert_file = tmp_path / "custom.pem" cert_file.write_bytes(cert_content) @@ -1938,9 +1995,7 @@ def test_diode_client_with_proxy_and_custom_cert(mock_diode_authentication, tmp_ # Verify proxy option is set _, kwargs = mock_secure_channel.call_args options = kwargs["options"] - proxy_option = next( - (opt for opt in options if opt[0] == "grpc.http_proxy"), None - ) + proxy_option = next((opt for opt in options if opt[0] == "grpc.http_proxy"), None) assert proxy_option is not None assert proxy_option[1] == "http://proxy.example.com:8080" finally: @@ -2142,9 +2197,7 @@ def test_diode_client_with_invalid_proxy_url_falls_back_to_no_proxy( # Verify no proxy option is set (invalid proxy was rejected) _, kwargs = mock_insecure_channel.call_args options = kwargs["options"] - proxy_option = next( - (opt for opt in options if opt[0] == "grpc.http_proxy"), None - ) + proxy_option = next((opt for opt in options if opt[0] == "grpc.http_proxy"), None) assert proxy_option is None # Should log warning about invalid proxy From 45c7360e18e40a74fcaa83d0ff0cf9a51c4ae4da Mon Sep 17 00:00:00 2001 From: James Jeffries Date: Fri, 26 Jun 2026 10:47:44 +0100 Subject: [PATCH 2/3] fix(OBS-3341): clamp auth retry delay after jitter Apply jitter before the final min() so sleeps never exceed the configured max delay cap. Co-authored-by: Cursor --- netboxlabs/diode/sdk/client.py | 3 ++- tests/test_client.py | 9 ++++++++- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/netboxlabs/diode/sdk/client.py b/netboxlabs/diode/sdk/client.py index 05a16ac..d1e57a1 100644 --- a/netboxlabs/diode/sdk/client.py +++ b/netboxlabs/diode/sdk/client.py @@ -1039,7 +1039,8 @@ def _auth_retry_delay( if delay is None: delay = initial_delay * (2 ** (attempt - 1)) delay = min(delay, max_delay) - return delay + random.uniform(0, delay / 4) + delay += random.uniform(0, delay / 4) + return min(delay, max_delay) class _ClientCallDetails( diff --git a/tests/test_client.py b/tests/test_client.py index b1432ee..778f36f 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -750,7 +750,14 @@ def test_auth_retry_delay_honours_retry_after(): def test_auth_retry_delay_caps_retry_after(): """Cap Retry-After delays at the configured maximum.""" delay = _auth_retry_delay(1, 429, "120", 1.0, 30.0) - assert 30.0 <= delay <= 37.5 + assert delay == 30.0 + + +def test_auth_retry_delay_never_exceeds_max(): + """Jitter must not push the final delay above the configured maximum.""" + for attempt in range(1, 8): + delay = _auth_retry_delay(attempt, 500, None, 1.0, 30.0) + assert delay <= 30.0 def test_diode_authentication_retries_retriable_status(mock_diode_authentication): From 5e650c6f339f3442d59c174a198870751d565556 Mon Sep 17 00:00:00 2001 From: James Jeffries Date: Fri, 26 Jun 2026 16:19:56 +0100 Subject: [PATCH 3/3] chore(OBS-3341): drop incidental formatting from auth backoff PR Restore develop formatting in client.py and test_client.py while keeping the OAuth retry/backoff behaviour, helper functions, README note, and tests. Co-authored-by: Cursor --- netboxlabs/diode/sdk/client.py | 150 ++++++++++++++++----- tests/test_client.py | 230 +++++++++++++++++++++++---------- 2 files changed, 280 insertions(+), 100 deletions(-) diff --git a/netboxlabs/diode/sdk/client.py b/netboxlabs/diode/sdk/client.py index d1e57a1..590d2dd 100644 --- a/netboxlabs/diode/sdk/client.py +++ b/netboxlabs/diode/sdk/client.py @@ -104,7 +104,9 @@ def parse_target(target: str) -> tuple[str, str, bool]: parsed_target = urlparse(target) if parsed_target.scheme not in ["grpc", "grpcs", "http", "https"]: - raise ValueError("target should start with grpc://, grpcs://, http:// or https://") + raise ValueError( + "target should start with grpc://, grpcs://, http:// or https://" + ) # Determine if TLS verification should be enabled tls_verify = _should_verify_tls(parsed_target.scheme) @@ -132,11 +134,15 @@ def _get_required_config_value(env_var_name: str, value: str | None = None) -> s if value is None: value = os.getenv(env_var_name) if value is None: - raise DiodeConfigError(f"parameter or {env_var_name} environment variable required") + raise DiodeConfigError( + f"parameter or {env_var_name} environment variable required" + ) return value -def _get_optional_config_value(env_var_name: str, value: str | None = None) -> str | None: +def _get_optional_config_value( + env_var_name: str, value: str | None = None +) -> str | None: """Get optional config value either from provided value or environment variable.""" if value is None: value = os.getenv(env_var_name) @@ -251,11 +257,17 @@ def _should_bypass_proxy(target_host: str) -> bool: # Maximum reasonable length for hostname/domain (RFC 1035: 253 chars, we allow 256) MAX_NO_PROXY_ENTRY_LENGTH = 256 - no_proxy_list = [entry.strip().lower() for entry in no_proxy.split(",") if len(entry.strip()) <= MAX_NO_PROXY_ENTRY_LENGTH] + no_proxy_list = [ + entry.strip().lower() + for entry in no_proxy.split(",") + if len(entry.strip()) <= MAX_NO_PROXY_ENTRY_LENGTH + ] filtered_count = len([e for e in no_proxy.split(",") if len(e.strip()) > MAX_NO_PROXY_ENTRY_LENGTH]) if filtered_count > 0: - _LOGGER.warning(f"Ignored {filtered_count} NO_PROXY entries exceeding {MAX_NO_PROXY_ENTRY_LENGTH} characters") + _LOGGER.warning( + f"Ignored {filtered_count} NO_PROXY entries exceeding {MAX_NO_PROXY_ENTRY_LENGTH} characters" + ) for entry in no_proxy_list: if entry and _matches_no_proxy_entry(host, entry): @@ -291,7 +303,8 @@ def _get_grpc_proxy_url(target_host: str, use_tls: bool) -> str | None: if proxy_url: if not _validate_proxy_url(proxy_url): _LOGGER.warning( - f"Invalid proxy URL format: {proxy_url}. Proxy URL must be http:// or https:// with valid host. Ignoring proxy." + f"Invalid proxy URL format: {proxy_url}. " + f"Proxy URL must be http:// or https:// with valid host. Ignoring proxy." ) return None _LOGGER.debug(f"Using proxy {proxy_url} for gRPC target {target_host}") @@ -326,12 +339,21 @@ def __init__( log_level = os.getenv(_DIODE_SDK_LOG_LEVEL_ENVVAR_NAME, "INFO").upper() logging.basicConfig(level=log_level) - self._max_auth_retries = int(_get_optional_config_value(_MAX_RETRIES_ENVVAR_NAME, str(max_auth_retries)) or max_auth_retries) - self._cert_file = _get_optional_config_value(_DIODE_CERT_FILE_ENVVAR_NAME, cert_file) + self._max_auth_retries = int( + _get_optional_config_value(_MAX_RETRIES_ENVVAR_NAME, str(max_auth_retries)) + or max_auth_retries + ) + self._cert_file = _get_optional_config_value( + _DIODE_CERT_FILE_ENVVAR_NAME, cert_file + ) self._target, self._path, self._tls_verify = parse_target(target) # Load certificates once if needed - self._certificates = _load_certs(self._cert_file) if (self._tls_verify or self._cert_file) else None + self._certificates = ( + _load_certs(self._cert_file) + if (self._tls_verify or self._cert_file) + else None + ) self._app_name = app_name self._app_version = app_version self._platform = platform.platform() @@ -339,7 +361,9 @@ def __init__( # Read client credentials from environment variables self._client_id = _get_required_config_value(_CLIENT_ID_ENVVAR_NAME, client_id) - self._client_secret = _get_required_config_value(_CLIENT_SECRET_ENVVAR_NAME, client_secret) + self._client_secret = _get_required_config_value( + _CLIENT_SECRET_ENVVAR_NAME, client_secret + ) self._metadata = ( ("platform", self._platform), @@ -348,7 +372,9 @@ def __init__( self._authenticate(_INGEST_SCOPE) - channel_opts = _diode_ingest_grpc_channel_options(f"{self._name}/{self._version} {self._app_name}/{self._app_version}") + channel_opts = _diode_ingest_grpc_channel_options( + f"{self._name}/{self._version} {self._app_name}/{self._app_version}" + ) proxy_url = _get_grpc_proxy_url(self._target, self._tls_verify) if proxy_url: @@ -360,7 +386,9 @@ def __init__( # Channel creation logic if self._tls_verify: credentials = ( - grpc.ssl_channel_credentials(root_certificates=self._certificates) if self._certificates else grpc.ssl_channel_credentials() + grpc.ssl_channel_credentials(root_certificates=self._certificates) + if self._certificates + else grpc.ssl_channel_credentials() ) _LOGGER.debug( @@ -386,7 +414,9 @@ def __init__( _LOGGER.debug(f"Setting up gRPC interceptor for path: {self._path}") rpc_method_interceptor = DiodeMethodClientInterceptor(subpath=self._path) - intercept_channel = grpc.intercept_channel(self._channel, rpc_method_interceptor) + intercept_channel = grpc.intercept_channel( + self._channel, rpc_method_interceptor + ) channel = intercept_channel self._stub = ingester_pb2_grpc.IngesterServiceStub(channel) @@ -395,7 +425,9 @@ def __init__( if self._sentry_dsn is not None: _LOGGER.debug("Setting up Sentry") - self._setup_sentry(self._sentry_dsn, sentry_traces_sample_rate, sentry_profiles_sample_rate) + self._setup_sentry( + self._sentry_dsn, sentry_traces_sample_rate, sentry_profiles_sample_rate + ) @property def name(self) -> str: @@ -475,13 +507,17 @@ def ingest( except grpc.RpcError as err: if err.code() == grpc.StatusCode.UNAUTHENTICATED: if attempt < self._max_auth_retries - 1: - _LOGGER.info(f"Retrying ingestion due to UNAUTHENTICATED error, attempt {attempt + 1}") + _LOGGER.info( + f"Retrying ingestion due to UNAUTHENTICATED error, attempt {attempt + 1}" + ) self._authenticate(_INGEST_SCOPE) continue raise DiodeClientError(err) from err raise RuntimeError("Max retries exceeded") - def _setup_sentry(self, dsn: str, traces_sample_rate: float, profiles_sample_rate: float): + def _setup_sentry( + self, dsn: str, traces_sample_rate: float, profiles_sample_rate: float + ): sentry_sdk.init( dsn=dsn, release=self.version, @@ -513,7 +549,9 @@ def _authenticate(self, scope: str): max_retries=self._max_auth_retries, ) access_token = authentication_client.authenticate() - self._metadata = list(filter(lambda x: x[0] != "authorization", self._metadata)) + [("authorization", f"Bearer {access_token}")] + self._metadata = list( + filter(lambda x: x[0] != "authorization", self._metadata) + ) + [("authorization", f"Bearer {access_token}")] class DiodeDryRunClient(DiodeClientInterface): @@ -581,7 +619,9 @@ def ingest( timestamp = time.perf_counter_ns() path = Path(self._output_dir) path.mkdir(parents=True, exist_ok=True) - filename = "".join(c if c.isalnum() or c in ("_", "-") else "_" for c in self._app_name) + filename = "".join( + c if c.isalnum() or c in ("_", "-") else "_" for c in self._app_name + ) file_path = path / f"{filename}_{timestamp}.json" with file_path.open("w") as fh: fh.write(output) @@ -617,10 +657,18 @@ def __init__( self._timeout = timeout self._target, self._path, self._tls_verify = parse_target(target) - self._cert_file = _get_optional_config_value(_DIODE_CERT_FILE_ENVVAR_NAME, cert_file) - self._certificates = _load_certs(self._cert_file) if (self._tls_verify or self._cert_file) else None + self._cert_file = _get_optional_config_value( + _DIODE_CERT_FILE_ENVVAR_NAME, cert_file + ) + self._certificates = ( + _load_certs(self._cert_file) + if (self._tls_verify or self._cert_file) + else None + ) - channel_opts = _otlp_grpc_channel_options(f"{self._name}/{self._version} {self._app_name}/{self._app_version}") + channel_opts = _otlp_grpc_channel_options( + f"{self._name}/{self._version} {self._app_name}/{self._app_version}" + ) proxy_url = _get_grpc_proxy_url(self._target, self._tls_verify) if proxy_url: @@ -636,7 +684,9 @@ def __init__( # Channel creation logic if self._tls_verify: credentials = ( - grpc.ssl_channel_credentials(root_certificates=self._certificates) if self._certificates else grpc.ssl_channel_credentials() + grpc.ssl_channel_credentials(root_certificates=self._certificates) + if self._certificates + else grpc.ssl_channel_credentials() ) _LOGGER.debug( @@ -728,7 +778,10 @@ def ingest( ) -> ingester_pb2.IngestResponse: """Export entities as OTLP logs with optional request-level metadata.""" stream = stream or _DEFAULT_STREAM - log_records = [self._entity_to_log_record(entity) for entity in self._normalize_entities(entities)] + log_records = [ + self._entity_to_log_record(entity) + for entity in self._normalize_entities(entities) + ] if not log_records: return ingester_pb2.IngestResponse() @@ -746,7 +799,9 @@ def ingest( return ingester_pb2.IngestResponse() - def _normalize_entities(self, entities: Iterable[Entity | ingester_pb2.Entity | None]) -> list[ingester_pb2.Entity]: + def _normalize_entities( + self, entities: Iterable[Entity | ingester_pb2.Entity | None] + ) -> list[ingester_pb2.Entity]: normalized: list[ingester_pb2.Entity] = [] for entity in entities: if entity is None: @@ -764,7 +819,9 @@ def _build_export_request( ) -> logs_service_pb2.ExportLogsServiceRequest: resource_logs = logs_pb2.ResourceLogs() resource_logs.resource.attributes.extend(self._resource_attributes()) - resource_logs.resource.attributes.append(self._string_kv("diode.stream", stream)) + resource_logs.resource.attributes.append( + self._string_kv("diode.stream", stream) + ) # Add request-level metadata as resource attributes with diode.metadata.* prefix if metadata: @@ -816,7 +873,9 @@ def _entity_to_log_record( @staticmethod def _string_kv(key: str, value: str) -> common_pb2.KeyValue: - return common_pb2.KeyValue(key=key, value=common_pb2.AnyValue(string_value=value)) + return common_pb2.KeyValue( + key=key, value=common_pb2.AnyValue(string_value=value) + ) @staticmethod def _value_to_any_value(value: Any) -> common_pb2.AnyValue | None: # noqa: C901 @@ -839,14 +898,18 @@ def _value_to_any_value(value: Any) -> common_pb2.AnyValue | None: # noqa: C901 any_value = DiodeOTLPClient._value_to_any_value(item) if any_value: array_values.append(any_value) - return common_pb2.AnyValue(array_value=common_pb2.ArrayValue(values=array_values)) + return common_pb2.AnyValue( + array_value=common_pb2.ArrayValue(values=array_values) + ) if isinstance(value, dict): # Recursively convert dict to KeyValueList kvlist = common_pb2.KeyValueList() for k, v in value.items(): any_value = DiodeOTLPClient._value_to_any_value(v) if any_value: - kvlist.values.append(common_pb2.KeyValue(key=k, value=any_value)) + kvlist.values.append( + common_pb2.KeyValue(key=k, value=any_value) + ) return common_pb2.AnyValue(kvlist_value=kvlist) # Skip unsupported types return None @@ -893,8 +956,12 @@ def __init__( self._certificates = certificates self._cert_file = cert_file self._max_retries = max_retries - self._initial_retry_delay = _AUTH_INITIAL_RETRY_DELAY if initial_retry_delay is None else initial_retry_delay - self._max_retry_delay = _AUTH_MAX_RETRY_DELAY if max_retry_delay is None else max_retry_delay + self._initial_retry_delay = ( + _AUTH_INITIAL_RETRY_DELAY if initial_retry_delay is None else initial_retry_delay + ) + self._max_retry_delay = ( + _AUTH_MAX_RETRY_DELAY if max_retry_delay is None else max_retry_delay + ) self._sleep = sleep or time.sleep def authenticate(self) -> str: @@ -908,20 +975,27 @@ def authenticate(self) -> str: except requests.RequestException as e: raise DiodeConfigError(f"Failed to obtain access token: {e}") finally: + # Clean up temp certificate file if temp_cert_file and os.path.exists(temp_cert_file): try: os.unlink(temp_cert_file) _LOGGER.debug(f"Cleaned up temp certificate file: {temp_cert_file}") except OSError as e: - _LOGGER.warning(f"Failed to clean up temp certificate file {temp_cert_file}: {e}") + _LOGGER.warning( + f"Failed to clean up temp certificate file {temp_cert_file}: {e}" + ) def _configure_auth_session(self, session: requests.Session) -> str | None: temp_cert_file = None if self._tls_verify and self._certificates: + # Use cert_file path directly if available, otherwise write to temp file if self._cert_file: session.verify = self._cert_file else: - with tempfile.NamedTemporaryFile(mode="wb", delete=False, suffix=".pem") as f: + # Write certificates to temp file for requests + with tempfile.NamedTemporaryFile( + mode="wb", delete=False, suffix=".pem" + ) as f: f.write(self._certificates) temp_cert_file = f.name session.verify = temp_cert_file @@ -949,7 +1023,9 @@ def _request_access_token(self, session: requests.Session) -> str: if response.status_code == 200: access_token = response.json().get("access_token") if not access_token: - raise DiodeConfigError(f"Failed to obtain access token for client {self._client_id}") + raise DiodeConfigError( + f"Failed to obtain access token for client {self._client_id}" + ) _LOGGER.debug(f"Access token obtained for client {self._client_id}") return access_token @@ -1065,7 +1141,9 @@ class _ClientCallDetails( """ -class DiodeMethodClientInterceptor(grpc.UnaryUnaryClientInterceptor, grpc.StreamUnaryClientInterceptor): +class DiodeMethodClientInterceptor( + grpc.UnaryUnaryClientInterceptor, grpc.StreamUnaryClientInterceptor +): """ Diode Method Client Interceptor class. @@ -1104,6 +1182,8 @@ def intercept_unary_unary(self, continuation, client_call_details, request): """Intercept unary unary.""" return self._intercept_call(continuation, client_call_details, request) - def intercept_stream_unary(self, continuation, client_call_details, request_iterator): + def intercept_stream_unary( + self, continuation, client_call_details, request_iterator + ): """Intercept stream unary.""" return self._intercept_call(continuation, client_call_details, request_iterator) diff --git a/tests/test_client.py b/tests/test_client.py index 778f36f..54dfbcc 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -12,12 +12,12 @@ import pytest from netboxlabs.diode.sdk.client import ( + _auth_retry_delay, _DIODE_SENTRY_DSN_ENVVAR_NAME, DiodeClient, DiodeDryRunClient, DiodeMethodClientInterceptor, DiodeOTLPClient, - _auth_retry_delay, _ClientCallDetails, _diode_ingest_grpc_channel_options, _DiodeAuthentication, @@ -75,7 +75,9 @@ def test_config_errors(client_id, client_secret, env_var_name): client_id=client_id, client_secret=client_secret, ) - assert str(err.value) == f"parameter or {env_var_name} environment variable required" + assert ( + str(err.value) == f"parameter or {env_var_name} environment variable required" + ) def test_client_error(mock_diode_authentication): @@ -101,7 +103,10 @@ def test_diode_client_error_repr_returns_correct_string(): error = DiodeClientError(grpc_error) error._status_code = grpc.StatusCode.UNAVAILABLE error._details = "Some details about the error" - assert repr(error) == "" + assert ( + repr(error) + == "" + ) def test_load_certs_returns_bytes(): @@ -269,7 +274,9 @@ def test_insecure_channel_options_with_primary_user_agent(mock_diode_authenticat mock_insecure_channel.assert_called_once() _, kwargs = mock_insecure_channel.call_args assert kwargs["options"] == tuple( - _diode_ingest_grpc_channel_options(f"{client.name}/{client.version} {client.app_name}/{client.app_version}") + _diode_ingest_grpc_channel_options( + f"{client.name}/{client.version} {client.app_name}/{client.app_version}" + ) ) @@ -287,7 +294,9 @@ def test_secure_channel_options_with_primary_user_agent(mock_diode_authenticatio mock_secure_channel.assert_called_once() _, kwargs = mock_secure_channel.call_args assert kwargs["options"] == tuple( - _diode_ingest_grpc_channel_options(f"{client.name}/{client.version} {client.app_name}/{client.app_version}") + _diode_ingest_grpc_channel_options( + f"{client.name}/{client.version} {client.app_name}/{client.app_version}" + ) ) @@ -364,7 +373,9 @@ def test_client_setup_sentry_called_when_sentry_dsn_exists(mock_diode_authentica client_secret="123456", sentry_dsn="https://user@password.mock.dsn/123456", ) - mock_setup_sentry.assert_called_once_with("https://user@password.mock.dsn/123456", 1.0, 1.0) + mock_setup_sentry.assert_called_once_with( + "https://user@password.mock.dsn/123456", 1.0, 1.0 + ) def test_client_setup_sentry_not_called_when_sentry_dsn_not_exists( @@ -490,7 +501,10 @@ def continuation(x, _): None, ) request = None - assert interceptor.intercept_unary_unary(continuation, client_call_details, request) == "/my/path/diode.v1.IngesterService/Ingest" + assert ( + interceptor.intercept_unary_unary(continuation, client_call_details, request) + == "/my/path/diode.v1.IngesterService/Ingest" + ) def test_interceptor_intercepts_stream_unary_calls(): @@ -510,7 +524,9 @@ def continuation(x, _): ) request_iterator = None assert ( - interceptor.intercept_stream_unary(continuation, client_call_details, request_iterator) + interceptor.intercept_stream_unary( + continuation, client_call_details, request_iterator + ) == "/my/path/diode.v1.IngesterService/Ingest" ) @@ -692,7 +708,6 @@ def test_diode_authentication_request_exception(mock_diode_authentication): mock_session = mock_session_class.return_value # Import requests.RequestException for the side effect import requests - mock_session.post.side_effect = requests.RequestException("Connection error") with pytest.raises(DiodeConfigError) as excinfo: @@ -779,8 +794,15 @@ def test_diode_authentication_retries_retriable_status(mock_diode_authentication sleep=lambda _delay: None, ) responses = [ - mock.Mock(status_code=503, reason="Service Unavailable", headers={"Retry-After": "0"}), - mock.Mock(status_code=200, json=mock.Mock(return_value={"access_token": "mocked_token"})), + mock.Mock( + status_code=503, + reason="Service Unavailable", + headers={"Retry-After": "0"}, + ), + mock.Mock( + status_code=200, + json=mock.Mock(return_value={"access_token": "mocked_token"}), + ), ] with mock.patch("requests.Session") as mock_session_class: mock_session = mock_session_class.return_value @@ -811,19 +833,16 @@ def test_diode_authentication_fails_fast_on_401(mock_diode_authentication): ) with mock.patch("requests.Session") as mock_session_class: mock_session = mock_session_class.return_value - mock_response = mock.Mock() - mock_response.status_code = 401 - mock_response.reason = "Unauthorized" - mock_response.headers = {} - mock_session.post.return_value = mock_response + mock_session.post.return_value = mock.Mock(status_code=401, reason="Unauthorized") - with pytest.raises(DiodeConfigError): + with pytest.raises(DiodeConfigError) as excinfo: auth.authenticate() + assert "Failed to obtain access token: Unauthorized" in str(excinfo.value) assert mock_session.post.call_count == 1 -def test_diode_authentication_exhausts_retries_on_429(mock_diode_authentication): - """Stop retrying auth token fetch after max attempts on 429.""" +def test_diode_authentication_exhausts_retries(mock_diode_authentication): + """Raise after exhausting auth retry attempts.""" auth = _DiodeAuthentication( target="localhost:8081", path="/diode", @@ -842,14 +861,14 @@ def test_diode_authentication_exhausts_retries_on_429(mock_diode_authentication) ) with mock.patch("requests.Session") as mock_session_class: mock_session = mock_session_class.return_value - mock_response = mock.Mock() - mock_response.status_code = 429 - mock_response.reason = "Too Many Requests" - mock_response.headers = {"Retry-After": "0"} - mock_session.post.return_value = mock_response + mock_session.post.return_value = mock.Mock( + status_code=503, + reason="Service Unavailable", + ) - with pytest.raises(DiodeConfigError): + with pytest.raises(DiodeConfigError) as excinfo: auth.authenticate() + assert "Failed to obtain access token: Service Unavailable" in str(excinfo.value) assert mock_session.post.call_count == 2 @@ -908,15 +927,21 @@ def test_load_dryrun_entities_from_fixture(message_path, tmp_path): assert isinstance(entities[0], ingester_pb2.Entity) assert entities[0].asn.asn == 555 assert entities[33].ip_address.address == "192.168.100.1/24" - assert entities[33].ip_address.assigned_object_interface.name == "GigabitEthernet1/0/1" + assert ( + entities[33].ip_address.assigned_object_interface.name == "GigabitEthernet1/0/1" + ) assert entities[-1].wireless_link.ssid == "P2P-Link-1" def test_otlp_client_exports_entities(): """Ensure DiodeOTLPClient serializes entities and exports them as logs.""" with ( - patch("netboxlabs.diode.sdk.client.grpc.insecure_channel") as mock_insecure_channel, - patch("netboxlabs.diode.sdk.client.logs_service_pb2_grpc.LogsServiceStub") as mock_stub_cls, + patch( + "netboxlabs.diode.sdk.client.grpc.insecure_channel" + ) as mock_insecure_channel, + patch( + "netboxlabs.diode.sdk.client.logs_service_pb2_grpc.LogsServiceStub" + ) as mock_stub_cls, ): mock_insecure_channel.return_value = mock.Mock() stub_instance = mock_stub_cls.return_value @@ -927,7 +952,9 @@ def test_otlp_client_exports_entities(): app_version="1.2.3", ) - response = client.ingest(entities=[Entity(site="Site1"), Entity(device="Device1")]) + response = client.ingest( + entities=[Entity(site="Site1"), Entity(device="Device1")] + ) stub_instance.Export.assert_called_once() export_args, export_kwargs = stub_instance.Export.call_args @@ -959,12 +986,18 @@ def details(self): return self._details with ( - patch("netboxlabs.diode.sdk.client.grpc.insecure_channel") as mock_insecure_channel, - patch("netboxlabs.diode.sdk.client.logs_service_pb2_grpc.LogsServiceStub") as mock_stub_cls, + patch( + "netboxlabs.diode.sdk.client.grpc.insecure_channel" + ) as mock_insecure_channel, + patch( + "netboxlabs.diode.sdk.client.logs_service_pb2_grpc.LogsServiceStub" + ) as mock_stub_cls, ): mock_insecure_channel.return_value = mock.Mock() stub_instance = mock_stub_cls.return_value - stub_instance.Export.side_effect = DummyRpcError(grpc.StatusCode.UNAVAILABLE, "endpoint offline") + stub_instance.Export.side_effect = DummyRpcError( + grpc.StatusCode.UNAVAILABLE, "endpoint offline" + ) client = DiodeOTLPClient( target="grpc://collector:4317", @@ -982,9 +1015,13 @@ def details(self): def test_otlp_client_grpcs_uses_secure_channel(): """Ensure DiodeOTLPClient configures SSL credentials for secure targets.""" with ( - patch("netboxlabs.diode.sdk.client.grpc.ssl_channel_credentials") as mock_ssl_credentials, + patch( + "netboxlabs.diode.sdk.client.grpc.ssl_channel_credentials" + ) as mock_ssl_credentials, patch("netboxlabs.diode.sdk.client.grpc.secure_channel") as mock_secure_channel, - patch("netboxlabs.diode.sdk.client.grpc.intercept_channel") as mock_intercept_channel, + patch( + "netboxlabs.diode.sdk.client.grpc.intercept_channel" + ) as mock_intercept_channel, patch("netboxlabs.diode.sdk.client.logs_service_pb2_grpc.LogsServiceStub"), ): base_channel = mock.Mock() @@ -1021,15 +1058,22 @@ def test_otlp_insecure_channel_options_exclude_diode_keepalive(): mock_insecure.assert_called_once() _, kwargs = mock_insecure.call_args - ua = f"{client.name}/{client.version} {client.app_name}/{client.app_version}" + ua = ( + f"{client.name}/{client.version} " + f"{client.app_name}/{client.app_version}" + ) assert kwargs["options"] == tuple(_otlp_grpc_channel_options(ua)) - assert all(opt[0] != "grpc.keepalive_time_ms" for opt in kwargs["options"]) + assert all( + opt[0] != "grpc.keepalive_time_ms" for opt in kwargs["options"] + ) def test_diode_authentication_with_custom_certificates(): """Test _DiodeAuthentication with custom certificates - covers SSL context creation.""" # Create test certificate content - cert_content = b"-----BEGIN CERTIFICATE-----\nTEST CERT\n-----END CERTIFICATE-----\n" + cert_content = ( + b"-----BEGIN CERTIFICATE-----\nTEST CERT\n-----END CERTIFICATE-----\n" + ) auth = _DiodeAuthentication( target="example.com:443", @@ -1100,7 +1144,9 @@ def test_diode_authentication_with_custom_certificates(): def test_load_certs_with_custom_cert_file(tmp_path): """Test _load_certs loads custom certificate file.""" # Create a dummy certificate file - cert_content = b"-----BEGIN CERTIFICATE-----\nTEST CERT\n-----END CERTIFICATE-----\n" + cert_content = ( + b"-----BEGIN CERTIFICATE-----\nTEST CERT\n-----END CERTIFICATE-----\n" + ) cert_file = tmp_path / "custom.pem" cert_file.write_bytes(cert_content) @@ -1118,7 +1164,9 @@ def test_load_certs_with_none_uses_default(): def test_client_with_cert_file_parameter(mock_diode_authentication, tmp_path): """Test DiodeClient with cert_file parameter loads custom cert but respects TLS scheme.""" # Create a dummy certificate file - cert_content = b"-----BEGIN CERTIFICATE-----\nTEST CERT\n-----END CERTIFICATE-----\n" + cert_content = ( + b"-----BEGIN CERTIFICATE-----\nTEST CERT\n-----END CERTIFICATE-----\n" + ) cert_file = tmp_path / "custom.pem" cert_file.write_bytes(cert_content) @@ -1147,7 +1195,9 @@ def test_client_with_cert_file_env_var(mock_diode_authentication, tmp_path): from netboxlabs.diode.sdk.client import _DIODE_CERT_FILE_ENVVAR_NAME # Create a dummy certificate file - cert_content = b"-----BEGIN CERTIFICATE-----\nTEST CERT\n-----END CERTIFICATE-----\n" + cert_content = ( + b"-----BEGIN CERTIFICATE-----\nTEST CERT\n-----END CERTIFICATE-----\n" + ) cert_file = tmp_path / "custom.pem" cert_file.write_bytes(cert_content) @@ -1183,13 +1233,19 @@ def test_client_with_cert_file_env_var(mock_diode_authentication, tmp_path): del os.environ[_DIODE_CERT_FILE_ENVVAR_NAME] -def test_client_cert_file_parameter_overrides_env_var(mock_diode_authentication, tmp_path): +def test_client_cert_file_parameter_overrides_env_var( + mock_diode_authentication, tmp_path +): """Test cert_file parameter takes precedence over environment variable.""" from netboxlabs.diode.sdk.client import _DIODE_CERT_FILE_ENVVAR_NAME # Create two dummy certificate files - env_cert_content = b"-----BEGIN CERTIFICATE-----\nENV CERT\n-----END CERTIFICATE-----\n" - param_cert_content = b"-----BEGIN CERTIFICATE-----\nPARAM CERT\n-----END CERTIFICATE-----\n" + env_cert_content = ( + b"-----BEGIN CERTIFICATE-----\nENV CERT\n-----END CERTIFICATE-----\n" + ) + param_cert_content = ( + b"-----BEGIN CERTIFICATE-----\nPARAM CERT\n-----END CERTIFICATE-----\n" + ) env_cert_file = tmp_path / "env.pem" param_cert_file = tmp_path / "param.pem" @@ -1231,7 +1287,9 @@ def test_client_cert_file_parameter_overrides_env_var(mock_diode_authentication, def test_client_secure_channel_uses_custom_cert(mock_diode_authentication, tmp_path): """Test secure channel creation uses custom certificate when provided.""" # Create a dummy certificate file - cert_content = b"-----BEGIN CERTIFICATE-----\nTEST CERT\n-----END CERTIFICATE-----\n" + cert_content = ( + b"-----BEGIN CERTIFICATE-----\nTEST CERT\n-----END CERTIFICATE-----\n" + ) cert_file = tmp_path / "custom.pem" cert_file.write_bytes(cert_content) @@ -1282,7 +1340,9 @@ def test_client_without_cert_file_uses_default_certs(mock_diode_authentication): mock_load_certs.assert_called_with(None) # Verify ssl_channel_credentials was called with default cert content - mock_ssl_creds.assert_called_once_with(root_certificates=b"default cert content") + mock_ssl_creds.assert_called_once_with( + root_certificates=b"default cert content" + ) # Verify secure_channel was called mock_secure_channel.assert_called_once() @@ -1318,12 +1378,16 @@ def test_should_verify_tls_with_skip_env_var(): # Test truthy values that should skip TLS verification for skip_value in ["true", "True", "TRUE", "1", "yes", "on"]: os.environ[_DIODE_SKIP_TLS_VERIFY_ENVVAR_NAME] = skip_value - assert _should_verify_tls("grpcs") is False # Should skip even for secure schemes + assert ( + _should_verify_tls("grpcs") is False + ) # Should skip even for secure schemes # Test falsy values that should NOT skip TLS verification for verify_value in ["false", "0", "no", "off", "", "random"]: os.environ[_DIODE_SKIP_TLS_VERIFY_ENVVAR_NAME] = verify_value - assert _should_verify_tls("grpcs") is True # Should verify for secure schemes + assert ( + _should_verify_tls("grpcs") is True + ) # Should verify for secure schemes finally: # Clean up environment variable @@ -1368,12 +1432,16 @@ def test_client_with_skip_tls_verify_env_var(mock_diode_authentication): del os.environ[_DIODE_SKIP_TLS_VERIFY_ENVVAR_NAME] -def test_client_cert_file_with_skip_tls_verify_env_var(mock_diode_authentication, tmp_path): +def test_client_cert_file_with_skip_tls_verify_env_var( + mock_diode_authentication, tmp_path +): """Test cert_file parameter with DIODE_SKIP_TLS_VERIFY environment variable.""" from netboxlabs.diode.sdk.client import _DIODE_SKIP_TLS_VERIFY_ENVVAR_NAME # Create a dummy certificate file - cert_content = b"-----BEGIN CERTIFICATE-----\nTEST CERT\n-----END CERTIFICATE-----\n" + cert_content = ( + b"-----BEGIN CERTIFICATE-----\nTEST CERT\n-----END CERTIFICATE-----\n" + ) cert_file = tmp_path / "custom.pem" cert_file.write_bytes(cert_content) @@ -1414,13 +1482,17 @@ def test_client_cert_file_with_skip_tls_verify_env_var(mock_diode_authentication def test_certificate_loading_efficiency(tmp_path): """Test that certificates are loaded only once during client initialization.""" # Create a dummy certificate file - cert_content = b"-----BEGIN CERTIFICATE-----\nTEST CERT\n-----END CERTIFICATE-----\n" + cert_content = ( + b"-----BEGIN CERTIFICATE-----\nTEST CERT\n-----END CERTIFICATE-----\n" + ) cert_file = tmp_path / "custom.pem" cert_file.write_bytes(cert_content) with ( mock.patch("netboxlabs.diode.sdk.client._load_certs") as mock_load_certs, - mock.patch("netboxlabs.diode.sdk.client._DiodeAuthentication") as mock_auth_class, + mock.patch( + "netboxlabs.diode.sdk.client._DiodeAuthentication" + ) as mock_auth_class, ): mock_load_certs.return_value = cert_content mock_auth_instance = mock_auth_class.return_value @@ -1581,8 +1653,12 @@ def test_dryrun_client_includes_metadata_in_output(tmp_path): def test_otlp_client_maps_metadata_to_resource_attributes(): """Test DiodeOTLPClient maps request metadata to OTLP resource attributes.""" with ( - patch("netboxlabs.diode.sdk.client.grpc.insecure_channel") as mock_insecure_channel, - patch("netboxlabs.diode.sdk.client.logs_service_pb2_grpc.LogsServiceStub") as mock_stub_cls, + patch( + "netboxlabs.diode.sdk.client.grpc.insecure_channel" + ) as mock_insecure_channel, + patch( + "netboxlabs.diode.sdk.client.logs_service_pb2_grpc.LogsServiceStub" + ) as mock_stub_cls, ): mock_insecure_channel.return_value = mock.Mock() stub_instance = mock_stub_cls.return_value @@ -1625,8 +1701,12 @@ def test_otlp_client_maps_metadata_to_resource_attributes(): def test_otlp_client_handles_nested_metadata(): """Test DiodeOTLPClient handles nested metadata structures.""" with ( - patch("netboxlabs.diode.sdk.client.grpc.insecure_channel") as mock_insecure_channel, - patch("netboxlabs.diode.sdk.client.logs_service_pb2_grpc.LogsServiceStub") as mock_stub_cls, + patch( + "netboxlabs.diode.sdk.client.grpc.insecure_channel" + ) as mock_insecure_channel, + patch( + "netboxlabs.diode.sdk.client.logs_service_pb2_grpc.LogsServiceStub" + ) as mock_stub_cls, ): mock_insecure_channel.return_value = mock.Mock() stub_instance = mock_stub_cls.return_value @@ -1687,8 +1767,12 @@ def test_otlp_client_handles_nested_metadata(): def test_otlp_client_metadata_type_conversion(): """Test DiodeOTLPClient correctly converts different Python types.""" with ( - patch("netboxlabs.diode.sdk.client.grpc.insecure_channel") as mock_insecure_channel, - patch("netboxlabs.diode.sdk.client.logs_service_pb2_grpc.LogsServiceStub") as mock_stub_cls, + patch( + "netboxlabs.diode.sdk.client.grpc.insecure_channel" + ) as mock_insecure_channel, + patch( + "netboxlabs.diode.sdk.client.logs_service_pb2_grpc.LogsServiceStub" + ) as mock_stub_cls, ): mock_insecure_channel.return_value = mock.Mock() stub_instance = mock_stub_cls.return_value @@ -1732,8 +1816,12 @@ def test_otlp_client_metadata_type_conversion(): def test_otlp_client_without_metadata(): """Test DiodeOTLPClient works without metadata (backward compatibility).""" with ( - patch("netboxlabs.diode.sdk.client.grpc.insecure_channel") as mock_insecure_channel, - patch("netboxlabs.diode.sdk.client.logs_service_pb2_grpc.LogsServiceStub") as mock_stub_cls, + patch( + "netboxlabs.diode.sdk.client.grpc.insecure_channel" + ) as mock_insecure_channel, + patch( + "netboxlabs.diode.sdk.client.logs_service_pb2_grpc.LogsServiceStub" + ) as mock_stub_cls, ): mock_insecure_channel.return_value = mock.Mock() stub_instance = mock_stub_cls.return_value @@ -1905,7 +1993,9 @@ def test_diode_client_configures_proxy_option(mock_diode_authentication): options = kwargs["options"] # Check that grpc.http_proxy option is present - proxy_option = next((opt for opt in options if opt[0] == "grpc.http_proxy"), None) + proxy_option = next( + (opt for opt in options if opt[0] == "grpc.http_proxy"), None + ) assert proxy_option is not None assert proxy_option[1] == "http://proxy.example.com:8080" finally: @@ -1934,7 +2024,9 @@ def test_diode_client_uses_insecure_channel_with_proxy_when_skip_tls( options = kwargs["options"] # Verify proxy option is set - proxy_option = next((opt for opt in options if opt[0] == "grpc.http_proxy"), None) + proxy_option = next( + (opt for opt in options if opt[0] == "grpc.http_proxy"), None + ) assert proxy_option is not None assert proxy_option[1] == "http://proxy.example.com:8080" finally: @@ -1961,7 +2053,9 @@ def test_diode_client_respects_no_proxy_for_target(mock_diode_authentication): options = kwargs["options"] # Check that grpc.http_proxy option is NOT present - proxy_option = next((opt for opt in options if opt[0] == "grpc.http_proxy"), None) + proxy_option = next( + (opt for opt in options if opt[0] == "grpc.http_proxy"), None + ) assert proxy_option is None finally: del os.environ["HTTP_PROXY"] @@ -1970,7 +2064,9 @@ def test_diode_client_respects_no_proxy_for_target(mock_diode_authentication): def test_diode_client_with_proxy_and_custom_cert(mock_diode_authentication, tmp_path): """Test DiodeClient with proxy and custom certificate (for MITM proxies).""" - cert_content = b"-----BEGIN CERTIFICATE-----\nTEST CERT\n-----END CERTIFICATE-----\n" + cert_content = ( + b"-----BEGIN CERTIFICATE-----\nTEST CERT\n-----END CERTIFICATE-----\n" + ) cert_file = tmp_path / "custom.pem" cert_file.write_bytes(cert_content) @@ -2002,7 +2098,9 @@ def test_diode_client_with_proxy_and_custom_cert(mock_diode_authentication, tmp_ # Verify proxy option is set _, kwargs = mock_secure_channel.call_args options = kwargs["options"] - proxy_option = next((opt for opt in options if opt[0] == "grpc.http_proxy"), None) + proxy_option = next( + (opt for opt in options if opt[0] == "grpc.http_proxy"), None + ) assert proxy_option is not None assert proxy_option[1] == "http://proxy.example.com:8080" finally: @@ -2204,7 +2302,9 @@ def test_diode_client_with_invalid_proxy_url_falls_back_to_no_proxy( # Verify no proxy option is set (invalid proxy was rejected) _, kwargs = mock_insecure_channel.call_args options = kwargs["options"] - proxy_option = next((opt for opt in options if opt[0] == "grpc.http_proxy"), None) + proxy_option = next( + (opt for opt in options if opt[0] == "grpc.http_proxy"), None + ) assert proxy_option is None # Should log warning about invalid proxy