Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ pip install netboxlabs-diode-sdk
* `DIODE_SENTRY_DSN` - Optional Sentry DSN for error reporting
* `DIODE_CLIENT_ID` - Client ID for OAuth2 authentication
* `DIODE_CLIENT_SECRET` - Client Secret for OAuth2 authentication
* `DIODE_MAX_AUTH_RETRIES` - Maximum attempts for OAuth2 token fetch and gRPC re-authentication on `Unauthenticated` (default: `3`). Token fetch retries with exponential backoff on `429`, `500`, `502`, and `503`, honouring `Retry-After` when present on `429`/`503`.
* `DIODE_CERT_FILE` - Path to custom certificate file for TLS connections
* `DIODE_SKIP_TLS_VERIFY` - Skip TLS verification (default: `false`)
* `DIODE_DRY_RUN_OUTPUT_DIR` - Directory where `DiodeDryRunClient` will write JSON files
Expand Down
179 changes: 131 additions & 48 deletions netboxlabs/diode/sdk/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,14 @@
import logging
import os
import platform
import random
import sys
import tempfile
import time
import uuid
from collections.abc import Iterable
from collections.abc import Callable, Iterable
from datetime import datetime, timezone
from email.utils import parsedate_to_datetime
from pathlib import Path
from typing import Any
from urllib.parse import urlparse
Expand Down Expand Up @@ -50,6 +53,8 @@
_INGEST_SCOPE = "diode:ingest"
_LOGGER = logging.getLogger(__name__)
_MAX_RETRIES_ENVVAR_NAME = "DIODE_MAX_AUTH_RETRIES"
_AUTH_INITIAL_RETRY_DELAY = 1.0
_AUTH_MAX_RETRY_DELAY = 30.0
# server policy (MinTime 10s so client pings must be >= 10s, e.g. 30s interval).
_GRPC_KEEPALIVE_TIME_MS = 30_000
_GRPC_KEEPALIVE_TIMEOUT_MS = 10_000
Expand Down Expand Up @@ -541,6 +546,7 @@ def _authenticate(self, scope: str):
self._app_version,
self._certificates,
self._cert_file,
max_retries=self._max_auth_retries,
)
access_token = authentication_client.authenticate()
self._metadata = list(
Expand Down Expand Up @@ -932,6 +938,10 @@ def __init__(
app_version: str,
certificates: bytes | None = None,
cert_file: str | None = None,
max_retries: int = 3,
initial_retry_delay: float | None = None,
max_retry_delay: float | None = None,
sleep: Callable[[float], None] | None = None,
):
self._target = target
self._tls_verify = tls_verify
Expand All @@ -945,60 +955,23 @@ def __init__(
self._app_version = app_version
self._certificates = certificates
self._cert_file = cert_file
self._max_retries = max_retries
self._initial_retry_delay = (
_AUTH_INITIAL_RETRY_DELAY if initial_retry_delay is None else initial_retry_delay
)
self._max_retry_delay = (
_AUTH_MAX_RETRY_DELAY if max_retry_delay is None else max_retry_delay
)
self._sleep = sleep or time.sleep

def authenticate(self) -> str:
"""Request an OAuth2 token using client credentials and return it."""
session = requests.Session()
temp_cert_file = None

try:
# Configure SSL verification
if self._tls_verify and self._certificates:
# Use cert_file path directly if available, otherwise write to temp file
if self._cert_file:
session.verify = self._cert_file
else:
# Write certificates to temp file for requests
with tempfile.NamedTemporaryFile(
mode="wb", delete=False, suffix=".pem"
) as f:
f.write(self._certificates)
temp_cert_file = f.name
session.verify = temp_cert_file
elif not self._tls_verify:
session.verify = False

# Prepare auth request
url = self._get_full_auth_url()
data = {
"grant_type": "client_credentials",
"client_id": self._client_id,
"client_secret": self._client_secret,
"scope": self._scope,
}
headers = {
"Content-Type": "application/x-www-form-urlencoded",
"User-Agent": f"{self._sdk_name}/{self._sdk_version} {self._app_name}/{self._app_version}",
}

response = session.post(url, data=data, headers=headers)

if response.status_code != 200:
raise DiodeConfigError(
f"Failed to obtain access token: {response.reason}"
)

token_info = response.json()
access_token = token_info.get("access_token")

if not access_token:
raise DiodeConfigError(
f"Failed to obtain access token for client {self._client_id}"
)

_LOGGER.debug(f"Access token obtained for client {self._client_id}")
return access_token

temp_cert_file = self._configure_auth_session(session)
return self._request_access_token(session)
except requests.RequestException as e:
raise DiodeConfigError(f"Failed to obtain access token: {e}")
finally:
Expand All @@ -1012,6 +985,73 @@ def authenticate(self) -> str:
f"Failed to clean up temp certificate file {temp_cert_file}: {e}"
)

def _configure_auth_session(self, session: requests.Session) -> str | None:
temp_cert_file = None
if self._tls_verify and self._certificates:
# Use cert_file path directly if available, otherwise write to temp file
if self._cert_file:
session.verify = self._cert_file
else:
# Write certificates to temp file for requests
with tempfile.NamedTemporaryFile(
mode="wb", delete=False, suffix=".pem"
) as f:
f.write(self._certificates)
temp_cert_file = f.name
session.verify = temp_cert_file
elif not self._tls_verify:
session.verify = False
return temp_cert_file

def _request_access_token(self, session: requests.Session) -> str:
url = self._get_full_auth_url()
data = {
"grant_type": "client_credentials",
"client_id": self._client_id,
"client_secret": self._client_secret,
"scope": self._scope,
}
headers = {
"Content-Type": "application/x-www-form-urlencoded",
"User-Agent": f"{self._sdk_name}/{self._sdk_version} {self._app_name}/{self._app_version}",
}

last_error = "Failed to obtain access token"
for attempt in range(1, self._max_retries + 1):
response = session.post(url, data=data, headers=headers)

if response.status_code == 200:
access_token = response.json().get("access_token")
if not access_token:
raise DiodeConfigError(
f"Failed to obtain access token for client {self._client_id}"
)
_LOGGER.debug(f"Access token obtained for client {self._client_id}")
return access_token

last_error = f"Failed to obtain access token: {response.reason}"
if not _is_retriable_auth_http_status(response.status_code) or attempt >= self._max_retries:
raise DiodeConfigError(last_error)

delay = _auth_retry_delay(
attempt,
response.status_code,
response.headers.get("Retry-After"),
self._initial_retry_delay,
self._max_retry_delay,
)
_LOGGER.debug(
"Auth token request failed, retrying",
extra={
"status_code": response.status_code,
"attempt": attempt,
"retry_in": delay,
},
)
self._sleep(delay)

raise DiodeConfigError(last_error)

def _get_auth_url(self) -> str:
"""Construct the authentication URL, handling trailing slashes in the path."""
# Ensure the path does not have trailing slashes
Expand All @@ -1036,6 +1076,49 @@ def _get_full_auth_url(self) -> str:
return f"{scheme}://{self._target}{path}/auth/token"


def _is_retriable_auth_http_status(status_code: int) -> bool:
return status_code in {429, 500, 502, 503}


def _parse_retry_after(value: str | None) -> float | None:
if not value:
return None
try:
seconds = int(value)
except ValueError:
seconds = None
else:
if seconds < 0:
return None
return float(seconds)

try:
retry_at = parsedate_to_datetime(value)
if retry_at.tzinfo is None:
retry_at = retry_at.replace(tzinfo=timezone.utc)
delay = (retry_at - datetime.now(retry_at.tzinfo)).total_seconds()
return max(delay, 0.0)
except (TypeError, ValueError, OverflowError):
return None


def _auth_retry_delay(
attempt: int,
status_code: int,
retry_after: str | None,
initial_delay: float,
max_delay: float,
) -> float:
delay: float | None = None
if status_code in (429, 503):
delay = _parse_retry_after(retry_after)
if delay is None:
delay = initial_delay * (2 ** (attempt - 1))
delay = min(delay, max_delay)
delay += random.uniform(0, delay / 4)
return min(delay, max_delay)


class _ClientCallDetails(
collections.namedtuple(
"_ClientCallDetails",
Expand Down
Loading
Loading