diff --git a/docs/environment-variables.md b/docs/environment-variables.md
index 0425f1b..f829229 100644
--- a/docs/environment-variables.md
+++ b/docs/environment-variables.md
@@ -110,24 +110,29 @@ Your Tusk Drift API key, required when using Tusk Cloud for storing and managing
 
 This will securely store your auth key for future replay sessions.
 
-## TUSK_SAMPLING_RATE
+## TUSK_RECORDING_SAMPLING_RATE
 
-Controls what percentage of requests are recorded during trace collection.
+Controls the base recording rate used during trace collection.
 
 - **Type:** Number between 0.0 and 1.0
-- **Default:** 1.0 (100% of requests)
-- **Precedence:** This environment variable is overridden by the `sampling_rate` parameter in `TuskDrift.initialize()`, but takes precedence over the `sampling_rate` setting in `.tusk/config.yaml`
+- **If unset:** Falls back to `.tusk/config.yaml` and then the default base rate of `1.0`
+- **Precedence:** This environment variable is overridden by the `sampling_rate` parameter in `TuskDrift.initialize()`, but takes precedence over `recording.sampling.base_rate` and the legacy `recording.sampling_rate` setting in `.tusk/config.yaml`
+- **Scope:** This only overrides the base rate. It does not change `recording.sampling.mode` or `recording.sampling.min_rate`
 
 **Examples:**
 
 ```bash
 # Record all requests (100%)
-TUSK_SAMPLING_RATE=1.0 python app.py
+TUSK_RECORDING_SAMPLING_RATE=1.0 python app.py
 
 # Record 10% of requests
-TUSK_SAMPLING_RATE=0.1 python app.py
+TUSK_RECORDING_SAMPLING_RATE=0.1 python app.py
 ```
 
+If `recording.sampling.mode: adaptive` is enabled in `.tusk/config.yaml`, this environment variable still only changes the base rate; adaptive load shedding remains active.
+
+`TUSK_RECORDING_SAMPLING_RATE` is the canonical variable, but `TUSK_SAMPLING_RATE` is still accepted as a backward-compatible alias.
+
 For more details on sampling rate configuration methods and precedence, see the [Initialization Guide](./initialization.md#configure-sampling-rate).
 
 ## Rust Core Flags
diff --git a/docs/initialization.md b/docs/initialization.md
index f8c707e..c425f57 100644
--- a/docs/initialization.md
+++ b/docs/initialization.md
@@ -73,8 +73,8 @@ Create an initialization file or add the SDK initialization to your application
 sampling_rate
 float
-1.0
-Override sampling rate (0.0 - 1.0) for recording. Takes precedence over TUSK_SAMPLING_RATE env var and config file.
+None
+Override the base sampling rate (0.0 - 1.0) for recording. Takes precedence over TUSK_RECORDING_SAMPLING_RATE and config file base-rate settings. Does not change recording.sampling.mode.
@@ -172,50 +172,91 @@ if __name__ == "__main__":
 
 ## Configure Sampling Rate
 
-The sampling rate determines what percentage of requests are recorded during replay tests. Tusk Drift supports three ways to configure the sampling rate, with the following precedence (highest to lowest):
+Sampling controls what percentage of inbound requests are recorded in `RECORD` mode.
+
+Tusk Drift supports two sampling modes in `.tusk/config.yaml`:
+
+- `fixed`: record requests at a constant base rate.
+- `adaptive`: start from a base rate and automatically shed load when queue pressure, export failures, or memory pressure indicate the SDK should back off. In severe conditions the SDK can temporarily pause recording entirely.
+
+Sampling configuration is resolved in two layers:
 
-1. **Init Parameter**
-2. **Environment Variable** (`TUSK_SAMPLING_RATE`)
-3. **Configuration File** (`.tusk/config.yaml`)
+1. **Base rate precedence** (highest to lowest):
+   - `TuskDrift.initialize(sampling_rate=...)`
+   - `TUSK_RECORDING_SAMPLING_RATE`
+   - legacy alias `TUSK_SAMPLING_RATE`
+   - `.tusk/config.yaml` `recording.sampling.base_rate`
+   - `.tusk/config.yaml` legacy `recording.sampling_rate`
+   - default base rate `1.0`
+2. **Mode and minimum rate**:
+   - `recording.sampling.mode` comes from `.tusk/config.yaml` and defaults to `fixed`
+   - `recording.sampling.min_rate` is only used in `adaptive` mode and defaults to `0.001` when omitted
 
-If not specified, the default sampling rate is `1.0` (100%).
+> [!NOTE]
+> Requests before `sdk.mark_app_as_ready()` are always recorded. Sampling applies to normal inbound traffic after startup.
 
-### Method 1: Init Parameter (Programmatic Override)
+### Method 1: Init Parameter (Programmatic Base-Rate Override)
 
-Set the sampling rate directly in your initialization code:
+Set the base sampling rate directly in your initialization code:
 
 ```python
 sdk = TuskDrift.initialize(
     api_key=os.environ.get("TUSK_API_KEY"),
-    sampling_rate=0.1,  # 10% of requests
+    sampling_rate=0.1,  # Base rate: 10% of requests
 )
 ```
 
 ### Method 2: Environment Variable
 
-Set the `TUSK_SAMPLING_RATE` environment variable:
+Set the `TUSK_RECORDING_SAMPLING_RATE` environment variable to override the base sampling rate:
 
 ```bash
 # Development - record everything
-TUSK_SAMPLING_RATE=1.0 python app.py
+TUSK_RECORDING_SAMPLING_RATE=1.0 python app.py
 
 # Production - sample 10% of requests
-TUSK_SAMPLING_RATE=0.1 python app.py
+TUSK_RECORDING_SAMPLING_RATE=0.1 python app.py
 ```
 
+`TUSK_SAMPLING_RATE` is still supported as a backward-compatible alias, but new setups should prefer `TUSK_RECORDING_SAMPLING_RATE`.
+
 ### Method 3: Configuration File
 
-Update the configuration file `.tusk/config.yaml` to include a `recording` section:
+Use the nested `recording.sampling` config to choose `fixed` vs `adaptive` mode and set the base/minimum rates.
+
+**Fixed sampling example:**
 
 ```yaml
 # ... existing configuration ...
 
 recording:
-  sampling_rate: 0.1
+  sampling:
+    mode: fixed
+    base_rate: 0.1
   export_spans: true
   enable_env_var_recording: true
 ```
 
+**Adaptive sampling example:**
+
+```yaml
+# ... existing configuration ...
+
+recording:
+  sampling:
+    mode: adaptive
+    base_rate: 0.25
+    min_rate: 0.01
+  export_spans: true
+```
+
+**Legacy config still supported:**
+
+```yaml
+recording:
+  sampling_rate: 0.1
+```
+
 ### Recording Configuration Options
 
@@ -229,10 +270,28 @@ recording:
-
+
+
+
+
+
+
+
-
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/docs/quickstart.md b/docs/quickstart.md
index e25aff5..6549f09 100644
--- a/docs/quickstart.md
+++ b/docs/quickstart.md
@@ -2,9 +2,18 @@
 
 Let's walk through recording and replaying your first trace:
 
-## Step 1: Set sampling rate to 1.0
+## Step 1: Set fixed sampling to 1.0
 
-Set the `sampling_rate` in `.tusk/config.yaml` to 1.0 to ensure that all requests are recorded.
+Update `.tusk/config.yaml` so your first recording captures every request:
+
+```yaml
+recording:
+  sampling:
+    mode: fixed
+    base_rate: 1.0
+```
+
+Legacy `recording.sampling_rate: 1.0` still works, but `recording.sampling` is the preferred config shape.
 
 ## Step 2: Start server in record mode
 
diff --git a/drift/core/adaptive_sampling.py b/drift/core/adaptive_sampling.py
new file mode 100644
index 0000000..ec6710b
--- /dev/null
+++ b/drift/core/adaptive_sampling.py
@@ -0,0 +1,254 @@
+"""Adaptive sampling controller for inbound root-request admission."""
+
+from __future__ import annotations
+
+import logging
+import math
+import random
+import threading
+import time
+from dataclasses import dataclass
+from typing import Literal
+
+logger = logging.getLogger(__name__)
+
+SamplingMode = Literal["fixed", "adaptive"]
+AdaptiveSamplingState = Literal["fixed", "healthy", "warm", "hot", "critical_pause"]
+RootSamplingDecisionReason = Literal[
+    "pre_app_start",
+    "sampled",
+    "not_sampled",
+    "load_shed",
+    "critical_pause",
+]
+
+
+@dataclass
+class ResolvedSamplingConfig:
+    mode: SamplingMode
+    base_rate: float
+    min_rate: float
+
+
+@dataclass
+class AdaptiveSamplingHealthSnapshot:
+    queue_fill_ratio: float | None = None
+    dropped_span_count: int = 0
+    export_failure_count: int = 0
+    export_circuit_open: bool = False
+    memory_pressure_ratio: float | None = None
+
+
+@dataclass
+class RootSamplingDecision:
+    should_record: bool
+    reason: RootSamplingDecisionReason
+    mode: SamplingMode
+    state: AdaptiveSamplingState
+    base_rate: float
+    min_rate: float
+    effective_rate: float
+    admission_multiplier: float
+
+
+def _clamp(value: float, min_value: float, max_value: float) -> float:
+    return min(max_value, max(min_value, value))
+
+
+def _clamp01(value: float) -> float:
+    return _clamp(value, 0.0, 1.0)
+
+
+def _normalize_between(value: float | None, zero_point: float, one_point: float) -> float:
+    if value is None or one_point <= zero_point:
+        return 0.0
+    return _clamp01((value - zero_point) / (one_point - zero_point))
+
+
+class AdaptiveSamplingController:
+    def __init__(
+        self,
+        config: ResolvedSamplingConfig,
+        *,
+        random_fn=random.random,
+        now_fn=time.monotonic,
+    ) -> None:
+        self._config = config
+        self._random_fn = random_fn
+        self._now_fn = now_fn
+        self._lock = threading.RLock()
+
+        self._admission_multiplier = 1.0
+        self._state: AdaptiveSamplingState = "fixed" if config.mode == "fixed" else "healthy"
+        self._paused_until_s = 0.0
+        self._last_updated_at_s: float | None = None
+        self._last_decrease_at_s = 0.0
+
+        self._prev_dropped_span_count = 0
+        self._prev_export_failure_count = 0
+
+        self._queue_fill_ewma: float | None = None
+        self._recent_drop_signal = 0.0
+        self._recent_failure_signal = 0.0
+
+    def update(self, snapshot: AdaptiveSamplingHealthSnapshot) -> None:
+        with self._lock:
+            if self._config.mode != "adaptive":
+                self._state = "fixed"
+                self._admission_multiplier = 1.0
+                return
+
+            now_s = self._now_fn()
+            elapsed_s = 2.0 if self._last_updated_at_s is None else max(0.001, now_s - self._last_updated_at_s)
+            self._last_updated_at_s = now_s
+
+            decay = math.exp(-(elapsed_s * 1000.0) / 30000.0)
+            self._recent_drop_signal *= decay
+            self._recent_failure_signal *= decay
+
+            dropped_delta = max(0, snapshot.dropped_span_count - self._prev_dropped_span_count)
+            export_failure_delta = max(0, snapshot.export_failure_count - self._prev_export_failure_count)
+
+            self._prev_dropped_span_count = snapshot.dropped_span_count
+            self._prev_export_failure_count = snapshot.export_failure_count
+
+            self._recent_drop_signal += dropped_delta
+            self._recent_failure_signal += export_failure_delta
+
+            if snapshot.queue_fill_ratio is not None:
+                queue_fill_ratio = _clamp01(snapshot.queue_fill_ratio)
+                self._queue_fill_ewma = (
+                    queue_fill_ratio
+                    if self._queue_fill_ewma is None
+                    else (0.25 * queue_fill_ratio) + (0.75 * self._queue_fill_ewma)
+                )
+
+            queue_pressure = _normalize_between(self._queue_fill_ewma, 0.20, 0.85)
+            memory_pressure = _normalize_between(snapshot.memory_pressure_ratio, 0.80, 0.92)
+            export_failure_pressure = _clamp01(self._recent_failure_signal / 5.0)
+            pressure = max(queue_pressure, memory_pressure, export_failure_pressure)
+
+            hard_brake = (
+                dropped_delta > 0 or snapshot.export_circuit_open or (snapshot.memory_pressure_ratio or 0.0) >= 0.92
+            )
+
+            previous_state = self._state
+            previous_multiplier = self._admission_multiplier
+
+            if hard_brake:
+                self._paused_until_s = now_s + 15.0
+                self._admission_multiplier = 0.0
+                self._state = "critical_pause"
+                self._last_decrease_at_s = now_s
+                self._log_transition(previous_state, previous_multiplier, pressure, snapshot)
+                return
+
+            if now_s < self._paused_until_s:
+                self._state = "critical_pause"
+                self._log_transition(previous_state, previous_multiplier, pressure, snapshot)
+                return
+
+            min_multiplier = self._get_min_multiplier()
+            if pressure >= 0.70:
+                self._admission_multiplier = max(min_multiplier, self._admission_multiplier * 0.4)
+                self._state = "hot"
+                self._last_decrease_at_s = now_s
+            elif pressure >= 0.45:
+                self._admission_multiplier = max(min_multiplier, self._admission_multiplier * 0.7)
+                self._state = "warm"
+                self._last_decrease_at_s = now_s
+            else:
+                if pressure <= 0.20 and (now_s - self._last_decrease_at_s) >= 10.0:
+                    self._admission_multiplier = min(1.0, self._admission_multiplier + 0.05)
+                self._state = "healthy"
+
+            self._log_transition(previous_state, previous_multiplier, pressure, snapshot)
+
+    def get_decision(self, *, is_pre_app_start: bool) -> RootSamplingDecision:
+        with self._lock:
+            if is_pre_app_start:
+                return RootSamplingDecision(
+                    should_record=True,
+                    reason="pre_app_start",
+                    mode=self._config.mode,
+                    state=self._state,
+                    base_rate=self._config.base_rate,
+                    min_rate=self._config.min_rate,
+                    effective_rate=1.0,
+                    admission_multiplier=1.0,
+                )
+
+            effective_rate = (
+                self.get_effective_sampling_rate()
+                if self._config.mode == "adaptive"
+                else _clamp01(self._config.base_rate)
+            )
+
+            if effective_rate <= 0.0:
+                return RootSamplingDecision(
+                    should_record=False,
+                    reason="critical_pause" if self._state == "critical_pause" else "not_sampled",
+                    mode=self._config.mode,
+                    state=self._state,
+                    base_rate=self._config.base_rate,
+                    min_rate=self._config.min_rate,
+                    effective_rate=effective_rate,
+                    admission_multiplier=self._admission_multiplier,
+                )
+
+            should_record = self._random_fn() < effective_rate
+            return RootSamplingDecision(
+                should_record=should_record,
+                reason=(
+                    "sampled"
+                    if should_record
+                    else "load_shed"
+                    if self._config.mode == "adaptive" and effective_rate < self._config.base_rate
+                    else "not_sampled"
+                ),
+                mode=self._config.mode,
+                state=self._state,
+                base_rate=self._config.base_rate,
+                min_rate=self._config.min_rate,
+                effective_rate=effective_rate,
+                admission_multiplier=self._admission_multiplier if self._config.mode == "adaptive" else 1.0,
+            )
+
+    def get_effective_sampling_rate(self) -> float:
+        with self._lock:
+            if self._config.mode != "adaptive":
+                return _clamp01(self._config.base_rate)
+            if self._state == "critical_pause" and self._now_fn() < self._paused_until_s:
+                return 0.0
+            effective_rate = self._config.base_rate * self._admission_multiplier
+            return _clamp(
+                effective_rate,
+                min(self._config.base_rate, self._config.min_rate),
+                self._config.base_rate,
+            )
+
+    def _get_min_multiplier(self) -> float:
+        if self._config.base_rate <= 0.0 or self._config.min_rate <= 0.0:
+            return 0.0
+        return _clamp01(self._config.min_rate / self._config.base_rate)
+
+    def _log_transition(
+        self,
+        previous_state: AdaptiveSamplingState,
+        previous_multiplier: float,
+        pressure: float,
+        snapshot: AdaptiveSamplingHealthSnapshot,
+    ) -> None:
+        if previous_state == self._state and abs(previous_multiplier - self._admission_multiplier) < 0.05:
+            return
+
+        logger.info(
+            "Adaptive sampling updated (state=%s, multiplier=%.2f, effective_rate=%.4f, pressure=%.2f, queue_fill=%s, memory_pressure_ratio=%s, export_circuit_open=%s).",
+            self._state,
+            self._admission_multiplier,
+            self.get_effective_sampling_rate(),
+            pressure,
+            f"{self._queue_fill_ewma:.2f}" if self._queue_fill_ewma is not None else "n/a",
+            snapshot.memory_pressure_ratio if snapshot.memory_pressure_ratio is not None else "n/a",
+            snapshot.export_circuit_open,
+        )
diff --git a/drift/core/batch_processor.py b/drift/core/batch_processor.py
index 13c89e9..9338657 100644
--- a/drift/core/batch_processor.py
+++ b/drift/core/batch_processor.py
@@ -244,3 +244,8 @@ def queue_size(self) -> int:
     def dropped_span_count(self) -> int:
         """Get the number of dropped spans."""
         return self._dropped_spans
+
+    @property
+    def max_queue_size(self) -> int:
+        """Get the configured maximum queue size."""
+        return self._config.max_queue_size
diff --git a/drift/core/config.py b/drift/core/config.py
index 481c07d..16df032 100644
--- a/drift/core/config.py
+++ b/drift/core/config.py
@@ -66,11 +66,21 @@ class ComparisonConfig:
     ignore_fields: list[str] = field(default_factory=list)
 
 
+@dataclass
+class SamplingConfig:
+    """Configuration for fixed vs adaptive sampling."""
+
+    mode: str | None = None
+    base_rate: float | None = None
+    min_rate: float | None = None
+
+
 @dataclass
 class RecordingConfig:
     """Configuration for recording behavior."""
 
     sampling_rate: float | None = None
+    sampling: SamplingConfig | None = None
     export_spans: bool | None = None
     enable_env_var_recording: bool | None = None
     enable_analytics: bool | None = None
@@ -144,8 +154,42 @@ def _parse_recording_config(data: dict[str, Any]) -> RecordingConfig:
         )
         sampling_rate = None
 
+    sampling = None
+    raw_sampling = data.get("sampling")
+    if isinstance(raw_sampling, dict):
+        base_rate = raw_sampling.get("base_rate")
+        if base_rate is not None and not isinstance(base_rate, (int, float)):
+            logger.warning(
+                f"Invalid 'sampling.base_rate' in config: expected number, got {type(base_rate).__name__}. "
+                "This value will be ignored."
+            )
+            base_rate = None
+
+        min_rate = raw_sampling.get("min_rate")
+        if min_rate is not None and not isinstance(min_rate, (int, float)):
+            logger.warning(
+                f"Invalid 'sampling.min_rate' in config: expected number, got {type(min_rate).__name__}. "
+                "This value will be ignored."
+            )
+            min_rate = None
+
+        mode = raw_sampling.get("mode")
+        if mode is not None and not isinstance(mode, str):
+            logger.warning(
+                f"Invalid 'sampling.mode' in config: expected string, got {type(mode).__name__}. "
+                "This value will be ignored."
+            )
+            mode = None
+
+        sampling = SamplingConfig(
+            mode=mode,
+            base_rate=float(base_rate) if base_rate is not None else None,
+            min_rate=float(min_rate) if min_rate is not None else None,
+        )
+
     return RecordingConfig(
         sampling_rate=sampling_rate,
+        sampling=sampling,
         export_spans=data.get("export_spans"),
         enable_env_var_recording=data.get("enable_env_var_recording"),
         enable_analytics=data.get("enable_analytics"),
diff --git a/drift/core/drift_sdk.py b/drift/core/drift_sdk.py
index 395f8b9..b93ac2c 100644
--- a/drift/core/drift_sdk.py
+++ b/drift/core/drift_sdk.py
@@ -8,6 +8,7 @@
 import platform
 import random
 import stat
+import threading
 import time
 from pathlib import Path
 from typing import TYPE_CHECKING, Any
@@ -19,6 +20,13 @@
 
 from ..instrumentation.registry import install_hooks
 from ..version import SDK_VERSION
+from .adaptive_sampling import (
+    AdaptiveSamplingController,
+    AdaptiveSamplingHealthSnapshot,
+    ResolvedSamplingConfig,
+    RootSamplingDecision,
+    SamplingMode,
+)
 from .communication.communicator import CommunicatorConfig, ProtobufCommunicator
 from .communication.types import MockRequestInput, MockResponseOutput
 from .config import TuskConfig, TuskFileConfig, load_tusk_config
@@ -48,6 +56,12 @@ def __init__(self) -> None:
         self.app_ready = False
         self._sdk_instance_id = self._generate_sdk_instance_id()
         self._sampling_rate: float = 1.0
+        self._sampling_mode: str = "fixed"
+        self._min_sampling_rate: float = 0.0
+        self._adaptive_sampling_controller: AdaptiveSamplingController | None = None
+        self._adaptive_sampling_thread: threading.Thread | None = None
+        self._adaptive_sampling_stop_event = threading.Event()
+        self._effective_memory_limit_bytes: int | None = None
         self._transform_configs: dict[str, Any] | None = None
         self._init_params: dict[str, Any] = {}
 
@@ -121,14 +135,16 @@ def _log_startup_summary(self, env: str, use_remote_export: bool) -> None:
         )
 
         logger.info(
-            "SDK initialized successfully (version=%s, mode=%s, env=%s, service=%s, serviceId=%s, exportSpans=%s, samplingRate=%s, logLevel=%s, runtime=python %s, platform=%s/%s).",
+            "SDK initialized successfully (version=%s, mode=%s, env=%s, service=%s, serviceId=%s, exportSpans=%s, samplingMode=%s, samplingBaseRate=%s, samplingMinRate=%s, logLevel=%s, runtime=python %s, platform=%s/%s).",
             SDK_VERSION,
             self.mode,
             env,
             service_name,
             service_id,
             use_remote_export,
+            self._sampling_mode,
             self._sampling_rate,
+            self._min_sampling_rate,
             get_log_level(),
             platform.python_version(),
             platform.system().lower(),
@@ -168,8 +184,6 @@ def initialize(
             "log_level": log_level,
         }
 
-        instance._sampling_rate = instance._determine_sampling_rate(sampling_rate)
-
         effective_api_key = api_key or os.environ.get("TUSK_API_KEY")
 
         if not env:
@@ -183,6 +197,11 @@ def initialize(
             logger.debug("Already initialized, skipping...")
             return instance
 
+        sampling_config = instance._determine_sampling_config(sampling_rate)
+        instance._sampling_rate = sampling_config.base_rate
+        instance._sampling_mode = sampling_config.mode
+        instance._min_sampling_rate = sampling_config.min_rate
+
         # Start coverage collection after the _initialized guard so repeated
         # initialize() calls don't stop/restart coverage and lose accumulated data.
         from .coverage_server import start_coverage_collection
@@ -280,8 +299,6 @@ def initialize(
         instance._td_span_processor = TdSpanProcessor(
             exporter=instance.span_exporter,
             mode=instance.mode,
-            sampling_rate=instance._sampling_rate,
-            app_ready=instance.app_ready,
             environment=env,
         )
         instance._td_span_processor.start()
@@ -306,6 +323,7 @@ def initialize(
         install_hooks()
 
         instance._init_auto_instrumentations()
+        instance._start_adaptive_sampling_control_loop()
 
         # Create env vars snapshot if enabled (matches Node SDK behavior)
         instance.create_env_vars_snapshot()
@@ -318,38 +336,222 @@ def initialize(
 
         return instance
 
-    def _determine_sampling_rate(self, init_param: float | None) -> float:
-        """Determine the sampling rate from various sources (precedence order)."""
-        # 1. Init param takes precedence
+    def _determine_sampling_config(self, init_param: float | None) -> ResolvedSamplingConfig:
+        """Determine the effective sampling config from init params, env, and file config."""
+        recording_config = self.file_config.recording if self.file_config else None
+        config_sampling = recording_config.sampling if recording_config else None
+
+        mode: SamplingMode = "fixed"
+        if config_sampling and config_sampling.mode in {"fixed", "adaptive"}:
+            mode = "adaptive" if config_sampling.mode == "adaptive" else "fixed"
+        elif config_sampling and config_sampling.mode:
+            logger.warning(
+                "Invalid sampling mode from config file: %s. Must be 'fixed' or 'adaptive'. Ignoring.",
+                config_sampling.mode,
+            )
+
+        base_rate: float | None = None
         if init_param is not None:
             validated = validate_sampling_rate(init_param, "init params")
             if validated is not None:
                 logger.debug(f"Using sampling rate from init params: {validated}")
-                return validated
+                base_rate = validated
 
-        # 2. Environment variable
-        env_rate = os.environ.get("TUSK_SAMPLING_RATE")
-        if env_rate is not None:
-            try:
-                parsed = float(env_rate)
-                validated = validate_sampling_rate(parsed, "TUSK_SAMPLING_RATE env var")
-                if validated is not None:
-                    logger.debug(f"Using sampling rate from env var: {validated}")
-                    return validated
-            except ValueError:
-                logger.warning(f"Invalid TUSK_SAMPLING_RATE env var: {env_rate}")
-
-        # 3. Config file
-        if self.file_config and self.file_config.recording and self.file_config.recording.sampling_rate is not None:
-            config_rate = self.file_config.recording.sampling_rate
-            validated = validate_sampling_rate(config_rate, "config file")
+        if base_rate is None:
+            for env_key in ("TUSK_RECORDING_SAMPLING_RATE", "TUSK_SAMPLING_RATE"):
+                env_rate = os.environ.get(env_key)
+                if env_rate is None:
+                    continue
+
+                try:
+                    parsed = float(env_rate)
+                    validated = validate_sampling_rate(parsed, f"{env_key} env var")
+                    if validated is not None:
+                        logger.debug(f"Using sampling rate from {env_key} env var: {validated}")
+                        base_rate = validated
+                        break
+                except ValueError:
+                    logger.warning(f"Invalid {env_key} env var: {env_rate}")
+
+        if base_rate is None and config_sampling and config_sampling.base_rate is not None:
+            validated = validate_sampling_rate(config_sampling.base_rate, "config file recording.sampling.base_rate")
+            if validated is not None:
+                logger.debug(f"Using sampling rate from config file recording.sampling.base_rate: {validated}")
+                base_rate = validated
+
+        if base_rate is None and recording_config and recording_config.sampling_rate is not None:
+            validated = validate_sampling_rate(recording_config.sampling_rate, "config file recording.sampling_rate")
             if validated is not None:
-                logger.debug(f"Using sampling rate from config file: {validated}")
-                return validated
+                logger.debug(f"Using sampling rate from config file recording.sampling_rate: {validated}")
+                base_rate = validated
+
+        if base_rate is None:
+            logger.debug("Using default sampling rate: 1.0")
+            base_rate = 1.0
+
+        min_rate = 0.0
+        if mode == "adaptive":
+            validated_min_rate = validate_sampling_rate(
+                config_sampling.min_rate if config_sampling else None,
+                "config file recording.sampling.min_rate",
+            )
+            min_rate = validated_min_rate if validated_min_rate is not None else 0.001
+            min_rate = min(base_rate, min_rate)
+
+        return ResolvedSamplingConfig(
+            mode=mode,
+            base_rate=base_rate,
+            min_rate=min_rate,
+        )
+
+    def _determine_sampling_rate(self, init_param: float | None) -> float:
+        """Backward-compatible helper that returns only the effective base sampling rate."""
+        return self._determine_sampling_config(init_param).base_rate
+
+    def _start_adaptive_sampling_control_loop(self) -> None:
+        if self.mode != TuskDriftMode.RECORD or self._sampling_mode != "adaptive":
+            return
+
+        self._adaptive_sampling_controller = AdaptiveSamplingController(
+            ResolvedSamplingConfig(
+                mode="adaptive",
+                base_rate=self._sampling_rate,
+                min_rate=self._min_sampling_rate,
+            )
+        )
+        self._effective_memory_limit_bytes = self._detect_effective_memory_limit_bytes()
+        self._adaptive_sampling_stop_event.clear()
+
+        self._adaptive_sampling_thread = threading.Thread(
+            target=self._adaptive_sampling_loop,
+            daemon=True,
+            name="drift-adaptive-sampling",
+        )
+        self._adaptive_sampling_thread.start()
+        self._safe_update_adaptive_sampling_health()
+
+    def _adaptive_sampling_loop(self) -> None:
+        while not self._adaptive_sampling_stop_event.wait(timeout=2.0):
+            self._safe_update_adaptive_sampling_health()
+
+    def _safe_update_adaptive_sampling_health(self) -> None:
+        try:
+            self._update_adaptive_sampling_health()
+        except Exception:
+            logger.error("Adaptive sampling health update failed; keeping previous controller state.", exc_info=True)
+
+    def _update_adaptive_sampling_health(self) -> None:
+        if self._adaptive_sampling_controller is None:
+            return
+
+        batch_processor = self._td_span_processor._batch_processor if self._td_span_processor else None
+        queue_fill_ratio = None
+        dropped_span_count = 0
+        if batch_processor is not None and batch_processor.max_queue_size > 0:
+            queue_fill_ratio = batch_processor.queue_size / batch_processor.max_queue_size
+            dropped_span_count = batch_processor.dropped_span_count
+
+        export_failure_count = 0
+        export_circuit_open = False
+        if self.span_exporter is not None:
+            for adapter in self.span_exporter.get_adapters():
+                spans_failed = getattr(adapter, "spans_failed", 0)
+                export_failure_count += int(spans_failed)
+                export_circuit_open = export_circuit_open or getattr(adapter, "circuit_state", "") == "open"
+
+        self._adaptive_sampling_controller.update(
+            AdaptiveSamplingHealthSnapshot(
+                queue_fill_ratio=queue_fill_ratio,
+                dropped_span_count=dropped_span_count,
+                export_failure_count=export_failure_count,
+                export_circuit_open=export_circuit_open,
+                memory_pressure_ratio=self._get_memory_pressure_ratio(),
+            )
+        )
+
+    def _detect_effective_memory_limit_bytes(self) -> int | None:
+        candidates = (
+            "/sys/fs/cgroup/memory.max",
+            "/sys/fs/cgroup/memory/memory.limit_in_bytes",
+        )
+        for path in candidates:
+            parsed = self._read_numeric_control_file(path)
+            if parsed is None:
+                continue
+            if parsed <= 0 or parsed > 1_000_000_000_000_000:
+                continue
+            return parsed
+        return None
+
+    def _get_memory_pressure_ratio(self) -> float | None:
+        if self._effective_memory_limit_bytes is None or self._effective_memory_limit_bytes <= 0:
+            return None
+
+        cgroup_current = self._read_numeric_control_file("/sys/fs/cgroup/memory.current")
+        if cgroup_current is not None:
+            return cgroup_current / self._effective_memory_limit_bytes
+
+        cgroup_v1_current = self._read_numeric_control_file("/sys/fs/cgroup/memory/memory.usage_in_bytes")
+        if cgroup_v1_current is not None:
+            return cgroup_v1_current / self._effective_memory_limit_bytes
+
+        current_rss_bytes = self._read_current_rss_bytes()
+        if current_rss_bytes is not None:
+            return current_rss_bytes / self._effective_memory_limit_bytes
+
+        return None
+
+    @staticmethod
+    def _parse_proc_status_rss_bytes(raw_status: str) -> int | None:
+        for line in raw_status.splitlines():
+            if not line.startswith("VmRSS:"):
+                continue
+
+            parts = line.split()
+            if len(parts) < 3 or parts[2].lower() != "kb":
+                return None
+
+            return int(parts[1]) * 1024
+
+        return None
+
+    @staticmethod
+    def _parse_proc_statm_rss_bytes(raw_statm: str, page_size: int) -> int | None:
+        fields = raw_statm.split()
+        if len(fields) < 2:
+            return None
+
+        return int(fields[1]) * page_size
+
+    def _read_current_rss_bytes(self) -> int | None:
+        try:
+            proc_status_path = Path("/proc/self/status")
+            if proc_status_path.exists():
+                parsed = self._parse_proc_status_rss_bytes(proc_status_path.read_text())
+                if parsed is not None:
+                    return parsed
+        except Exception:
+            pass
+
+        try:
+            proc_statm_path = Path("/proc/self/statm")
+            if proc_statm_path.exists():
+                return self._parse_proc_statm_rss_bytes(proc_statm_path.read_text(), int(os.sysconf("SC_PAGE_SIZE")))
+        except Exception:
+            pass
 
-        # 4. Default
-        logger.debug("Using default sampling rate: 1.0")
-        return 1.0
+        return None
+
+    def _read_numeric_control_file(self, path: str) -> int | None:
+        try:
+            if not os.path.exists(path):
+                return None
+            raw_value = Path(path).read_text().strip()
+            if not raw_value or raw_value == "max":
+                return None
+            return int(raw_value)
+        except Exception:
+            return None
 
     def _detect_mode(self) -> TuskDriftMode:
         """Detect the SDK mode from environment variable."""
@@ -648,10 +850,6 @@ def mark_app_as_ready(self) -> None:
 
         self.app_ready = True
 
-        # Update span processor with app_ready flag
-        if self._td_span_processor:
-            self._td_span_processor.update_app_ready(True)
-
         if self.mode == TuskDriftMode.REPLAY:
             logger.debug("Replay mode active - ready to serve mocked responses")
         elif self.mode == TuskDriftMode.RECORD:
@@ -801,6 +999,34 @@ async def send_unpatched_dependency_alert(
         except Exception as e:
             logger.debug(f"Failed to send unpatched dependency alert: {e}")
 
+    def should_record_root_request(self, *, is_pre_app_start: bool) -> RootSamplingDecision:
+        if self._adaptive_sampling_controller is not None:
+            return self._adaptive_sampling_controller.get_decision(is_pre_app_start=is_pre_app_start)
+
+        if is_pre_app_start:
+            return RootSamplingDecision(
+                should_record=True,
+                reason="pre_app_start",
+                mode="fixed",
+                state="fixed",
+                base_rate=self._sampling_rate,
+                min_rate=self._min_sampling_rate,
+                effective_rate=1.0,
+                admission_multiplier=1.0,
+            )
+
+        should_record = should_sample(self._sampling_rate, True)
+        return RootSamplingDecision(
+            should_record=should_record,
+            reason="sampled" if should_record else "not_sampled",
+            mode="fixed",
+            state="fixed",
+            base_rate=self._sampling_rate,
+            min_rate=self._min_sampling_rate,
+            effective_rate=self._sampling_rate,
+            admission_multiplier=1.0,
+        )
+
     def get_sampling_rate(self) -> float:
         """Get the current sampling rate."""
         return self._sampling_rate
@@ -835,6 +1061,11 @@ def shutdown(self) -> None:
 
         from .coverage_server import stop_coverage_collection
 
+        self._adaptive_sampling_stop_event.set()
+        if self._adaptive_sampling_thread is not None:
+            self._adaptive_sampling_thread.join(timeout=5.0)
+            self._adaptive_sampling_thread = None
+
         # Shutdown OpenTelemetry tracer provider
         if self._td_span_processor is not None:
             self._td_span_processor.shutdown()
diff --git a/drift/core/mode_utils.py b/drift/core/mode_utils.py
index 69d9d3a..f671a26 100644
--- a/drift/core/mode_utils.py
+++ b/drift/core/mode_utils.py
@@ -180,18 +180,18 @@ def should_record_inbound_http_request(
     Returns:
         Tuple of (should_record, skip_reason):
         - should_record: True if request should be recorded
-        - skip_reason: If False, explains why ("dropped" or "not_sampled"), None otherwise
+        - skip_reason: If False, explains why ("dropped", "not_sampled",
+          "load_shed", or "critical_pause"), None otherwise
     """
     if transform_engine and transform_engine.should_drop_inbound_request(method, target, headers):
         return False, "dropped"
 
     if not is_pre_app_start:
         from .drift_sdk import TuskDrift
-        from .sampling import should_sample
 
         sdk = TuskDrift.get_instance()
-        sampling_rate = sdk.get_sampling_rate()
-        if not should_sample(sampling_rate, is_app_ready=True):
-            return False, "not_sampled"
+        decision = sdk.should_record_root_request(is_pre_app_start=is_pre_app_start)
+        if not decision.should_record:
+            return False, decision.reason
 

     return True, None
diff --git a/drift/core/no_recording.py b/drift/core/no_recording.py
new file mode 100644
index 0000000..33dc256
--- /dev/null
+++ b/drift/core/no_recording.py
@@ -0,0 +1,22 @@
+"""Context helpers for suppressing child span creation."""
+
+from __future__ import annotations
+
+from collections.abc import Iterator
+from contextlib import contextmanager
+from contextvars import ContextVar
+
+_recording_suppressed: ContextVar[bool] = ContextVar("td_recording_suppressed", default=False)
+
+
+def is_recording_suppressed() -> bool:
+    return _recording_suppressed.get()
+
+
+@contextmanager
+def suppress_recording() -> Iterator[None]:
+    token = _recording_suppressed.set(True)
+    try:
+        yield
+    finally:
+        _recording_suppressed.reset(token)
diff --git a/drift/core/tracing/span_utils.py b/drift/core/tracing/span_utils.py
index e246833..bc41d7e 100644
--- a/drift/core/tracing/span_utils.py
+++ b/drift/core/tracing/span_utils.py
@@ -19,6 +19,7 @@
 from opentelemetry.trace import SpanKind as OTelSpanKind
 from opentelemetry.trace import Status, StatusCode

+from ..no_recording import is_recording_suppressed
 from ..types import TuskDriftMode
 from .td_attributes import TdSpanAttributes

@@ -135,6 +136,10 @@ def create_span(options: CreateSpanOptions) -> SpanInfo | None:
         Returns None if span creation fails.
         """
         try:
+            if is_recording_suppressed():
+                logger.debug(f"[SpanUtils] Skipping span creation for '{options.name}' - recording suppressed")
+                return None
+
             # Import here to avoid circular dependency
             from ..drift_sdk import TuskDrift

diff --git a/drift/core/tracing/td_span_processor.py b/drift/core/tracing/td_span_processor.py
index ef1e6e7..98adc6d 100644
--- a/drift/core/tracing/td_span_processor.py
+++ b/drift/core/tracing/td_span_processor.py
@@ -47,21 +47,21 @@ class TdSpanProcessor(SpanProcessor):

     This processor implements OpenTelemetry's SpanProcessor interface and
     serves as the bridge between OTel's tracing system and Drift's export infrastructure.
+    Root-request admission sampling happens earlier in inbound instrumentations,
+    so this processor only handles ended spans that were already allowed through.

     When a span ends:
     1. Convert to CleanSpanData using otel_converter
-    2. Apply sampling logic
-    3. Apply trace blocking logic
-    4. Validate protobuf serialization
-    5. Forward to batch processor for export
+    2. Apply trace blocking logic
+    3. Validate protobuf serialization
+    4. Handle REPLAY-mode inbound span forwarding
+    5. Forward RECORD-mode spans to the batch processor for export
     """

     def __init__(
         self,
         exporter: TdSpanExporter,
         mode: TuskDriftMode,
-        sampling_rate: float = 1.0,
-        app_ready: bool = False,
         environment: str | None = None,
     ) -> None:
         """Initialize the TdSpanProcessor.
@@ -69,14 +69,10 @@ def __init__(
         Args:
             exporter: The TdSpanExporter to use for span export
             mode: SDK mode (RECORD, REPLAY, DISABLED)
-            sampling_rate: Sampling rate (0.0-1.0)
-            app_ready: Whether the application is ready
             environment: Environment name to include on spans
         """
         self._exporter = exporter
         self._mode = mode
-        self._sampling_rate = sampling_rate
-        self._app_ready = app_ready
         self._environment = environment

         # We'll import and create batch processor lazily to avoid circular imports
@@ -244,23 +240,3 @@ def force_flush(self, timeout_millis: int = 30000) -> bool:
         except Exception as e:
             logger.error(f"Error during force_flush: {e}")
             return False
-
-    def update_app_ready(self, app_ready: bool) -> None:
-        """Update the app_ready flag.
-
-        This is called when the application marks itself as ready.
-
-        Args:
-            app_ready: Whether the application is ready
-        """
-        self._app_ready = app_ready
-        logger.debug(f"TdSpanProcessor app_ready updated to {app_ready}")
-
-    def update_sampling_rate(self, sampling_rate: float) -> None:
-        """Update the sampling rate.
-
-        Args:
-            sampling_rate: New sampling rate (0.0-1.0)
-        """
-        self._sampling_rate = sampling_rate
-        logger.debug(f"TdSpanProcessor sampling_rate updated to {sampling_rate}")
diff --git a/drift/instrumentation/aiohttp/e2e-tests/docker-compose.yml b/drift/instrumentation/aiohttp/e2e-tests/docker-compose.yml
index 14976cb..52a4ec4 100644
--- a/drift/instrumentation/aiohttp/e2e-tests/docker-compose.yml
+++ b/drift/instrumentation/aiohttp/e2e-tests/docker-compose.yml
@@ -27,6 +27,7 @@ services:
       - BENCHMARKS=${BENCHMARKS:-}
       - BENCHMARK_DURATION=${BENCHMARK_DURATION:-10}
       - BENCHMARK_WARMUP=${BENCHMARK_WARMUP:-3}
+      - TUSK_RECORDING_SAMPLING_RATE=${TUSK_RECORDING_SAMPLING_RATE:-}
       - TUSK_SAMPLING_RATE=${TUSK_SAMPLING_RATE:-}
       - USE_MOCK_EXTERNALS=${USE_MOCK_EXTERNALS:-1}
       - MOCK_SERVER_BASE_URL=${MOCK_SERVER_BASE_URL:-http://mock-upstream:8081}
diff --git a/drift/instrumentation/django/e2e-tests/docker-compose.yml b/drift/instrumentation/django/e2e-tests/docker-compose.yml
index 801a3c7..72b9b09 100644
--- a/drift/instrumentation/django/e2e-tests/docker-compose.yml
+++ b/drift/instrumentation/django/e2e-tests/docker-compose.yml
@@ -28,6 +28,7 @@ services:
       - BENCHMARKS=${BENCHMARKS:-}
       - BENCHMARK_DURATION=${BENCHMARK_DURATION:-10}
       - BENCHMARK_WARMUP=${BENCHMARK_WARMUP:-3}
+      - TUSK_RECORDING_SAMPLING_RATE=${TUSK_RECORDING_SAMPLING_RATE:-}
       - TUSK_SAMPLING_RATE=${TUSK_SAMPLING_RATE:-}
       - USE_MOCK_EXTERNALS=${USE_MOCK_EXTERNALS:-1}
       - MOCK_SERVER_BASE_URL=${MOCK_SERVER_BASE_URL:-http://mock-upstream:8081}
diff --git a/drift/instrumentation/django/middleware.py b/drift/instrumentation/django/middleware.py
index 670679d..6514855 100644
--- a/drift/instrumentation/django/middleware.py
+++ b/drift/instrumentation/django/middleware.py
@@ -17,6 +17,7 @@ if TYPE_CHECKING:
     from django.http import HttpRequest, HttpResponse

 from ...core.mode_utils import handle_record_mode, should_record_inbound_http_request
+from ...core.no_recording import suppress_recording
 from ...core.tracing import TdSpanAttributes
 from ...core.tracing.span_utils import CreateSpanOptions, SpanInfo, SpanUtils
 from ...core.types import (
@@ -190,7 +191,8 @@ def _record_request(self, request: HttpRequest, sdk, is_pre_app_start: bool) ->
         )
         if not should_record:
             logger.debug(f"[Django] Skipping request ({skip_reason}), path={path}")
-            return self.get_response(request)
+            with suppress_recording():
+                return self.get_response(request)

         start_time_ns = time.time_ns()
         span_name = f"{method} {path}"
diff --git a/drift/instrumentation/fastapi/e2e-tests/docker-compose.yml b/drift/instrumentation/fastapi/e2e-tests/docker-compose.yml
index cf2e18c..9767b72 100644
--- a/drift/instrumentation/fastapi/e2e-tests/docker-compose.yml
+++ b/drift/instrumentation/fastapi/e2e-tests/docker-compose.yml
@@ -27,6 +27,7 @@ services:
       - BENCHMARKS=${BENCHMARKS:-}
       - BENCHMARK_DURATION=${BENCHMARK_DURATION:-10}
       - BENCHMARK_WARMUP=${BENCHMARK_WARMUP:-3}
+      - TUSK_RECORDING_SAMPLING_RATE=${TUSK_RECORDING_SAMPLING_RATE:-}
       - TUSK_SAMPLING_RATE=${TUSK_SAMPLING_RATE:-}
       - USE_MOCK_EXTERNALS=${USE_MOCK_EXTERNALS:-1}
       - MOCK_SERVER_BASE_URL=${MOCK_SERVER_BASE_URL:-http://mock-upstream:8081}
diff --git a/drift/instrumentation/fastapi/instrumentation.py b/drift/instrumentation/fastapi/instrumentation.py
index d280544..5e1df77 100644
--- a/drift/instrumentation/fastapi/instrumentation.py
+++ b/drift/instrumentation/fastapi/instrumentation.py
@@ -27,6 +27,7 @@
 from ...core.drift_sdk import TuskDrift
 from ...core.json_schema_helper import JsonSchemaHelper, SchemaMerge
 from ...core.mode_utils import handle_record_mode, should_record_inbound_http_request
+from ...core.no_recording import suppress_recording
 from ...core.tracing import TdSpanAttributes
 from ...core.tracing.span_utils import CreateSpanOptions, SpanInfo, SpanUtils
 from ...core.types import (
@@ -267,7 +268,8 @@ async def _record_request(
     )
     if not should_record:
         logger.debug(f"[FastAPI] Skipping request ({skip_reason}), path={raw_path}")
-        return await original_call(app, scope, receive, send)
+        with suppress_recording():
+            return await original_call(app, scope, receive, send)

     start_time_ns = time.time_ns()

diff --git a/drift/instrumentation/flask/e2e-tests/docker-compose.yml b/drift/instrumentation/flask/e2e-tests/docker-compose.yml
index 8c73754..5d21955 100644
--- a/drift/instrumentation/flask/e2e-tests/docker-compose.yml
+++ b/drift/instrumentation/flask/e2e-tests/docker-compose.yml
@@ -27,6 +27,7 @@ services:
       - BENCHMARKS=${BENCHMARKS:-}
       - BENCHMARK_DURATION=${BENCHMARK_DURATION:-10}
       - BENCHMARK_WARMUP=${BENCHMARK_WARMUP:-3}
+      - TUSK_RECORDING_SAMPLING_RATE=${TUSK_RECORDING_SAMPLING_RATE:-}
       - TUSK_SAMPLING_RATE=${TUSK_SAMPLING_RATE:-}
       - USE_MOCK_EXTERNALS=${USE_MOCK_EXTERNALS:-1}
       - MOCK_SERVER_BASE_URL=${MOCK_SERVER_BASE_URL:-http://mock-upstream:8081}
diff --git a/drift/instrumentation/grpc/e2e-tests/docker-compose.yml b/drift/instrumentation/grpc/e2e-tests/docker-compose.yml
index c0d45a7..1671f4b 100644
--- a/drift/instrumentation/grpc/e2e-tests/docker-compose.yml
+++ b/drift/instrumentation/grpc/e2e-tests/docker-compose.yml
@@ -18,6 +18,7 @@ services:
       - BENCHMARKS=${BENCHMARKS:-}
       - BENCHMARK_DURATION=${BENCHMARK_DURATION:-10}
       - BENCHMARK_WARMUP=${BENCHMARK_WARMUP:-3}
+      - TUSK_RECORDING_SAMPLING_RATE=${TUSK_RECORDING_SAMPLING_RATE:-}
       - TUSK_SAMPLING_RATE=${TUSK_SAMPLING_RATE:-}
     working_dir: /app
     volumes:
diff --git a/drift/instrumentation/httpx/e2e-tests/docker-compose.yml b/drift/instrumentation/httpx/e2e-tests/docker-compose.yml
index ae57669..6471395 100644
--- a/drift/instrumentation/httpx/e2e-tests/docker-compose.yml
+++ b/drift/instrumentation/httpx/e2e-tests/docker-compose.yml
@@ -27,6 +27,7 @@ services:
       - BENCHMARKS=${BENCHMARKS:-}
       - BENCHMARK_DURATION=${BENCHMARK_DURATION:-10}
       - BENCHMARK_WARMUP=${BENCHMARK_WARMUP:-3}
+      - TUSK_RECORDING_SAMPLING_RATE=${TUSK_RECORDING_SAMPLING_RATE:-}
       - TUSK_SAMPLING_RATE=${TUSK_SAMPLING_RATE:-}
       - USE_MOCK_EXTERNALS=${USE_MOCK_EXTERNALS:-1}
       - MOCK_SERVER_BASE_URL=${MOCK_SERVER_BASE_URL:-http://mock-upstream:8081}
diff --git a/drift/instrumentation/psycopg/e2e-tests/docker-compose.yml b/drift/instrumentation/psycopg/e2e-tests/docker-compose.yml
index 34e6e0a..95a6b6a 100644
--- a/drift/instrumentation/psycopg/e2e-tests/docker-compose.yml
+++ b/drift/instrumentation/psycopg/e2e-tests/docker-compose.yml
@@ -38,6 +38,7 @@ services:
       - BENCHMARKS=${BENCHMARKS:-}
       - BENCHMARK_DURATION=${BENCHMARK_DURATION:-10}
       - BENCHMARK_WARMUP=${BENCHMARK_WARMUP:-3}
+      - TUSK_RECORDING_SAMPLING_RATE=${TUSK_RECORDING_SAMPLING_RATE:-}
       - TUSK_SAMPLING_RATE=${TUSK_SAMPLING_RATE:-}
     working_dir: /app
     volumes:
diff --git a/drift/instrumentation/psycopg2/e2e-tests/docker-compose.yml b/drift/instrumentation/psycopg2/e2e-tests/docker-compose.yml
index 608fa98..0e10cc8 100644
--- a/drift/instrumentation/psycopg2/e2e-tests/docker-compose.yml
+++ b/drift/instrumentation/psycopg2/e2e-tests/docker-compose.yml
@@ -38,6 +38,7 @@ services:
       - BENCHMARKS=${BENCHMARKS:-}
       - BENCHMARK_DURATION=${BENCHMARK_DURATION:-10}
       - BENCHMARK_WARMUP=${BENCHMARK_WARMUP:-3}
+      - TUSK_RECORDING_SAMPLING_RATE=${TUSK_RECORDING_SAMPLING_RATE:-}
       - TUSK_SAMPLING_RATE=${TUSK_SAMPLING_RATE:-}
     working_dir: /app
     volumes:
diff --git a/drift/instrumentation/redis/e2e-tests/docker-compose.yml b/drift/instrumentation/redis/e2e-tests/docker-compose.yml
index 84b269c..64772ea 100644
--- a/drift/instrumentation/redis/e2e-tests/docker-compose.yml
+++ b/drift/instrumentation/redis/e2e-tests/docker-compose.yml
@@ -31,6 +31,7 @@ services:
       - BENCHMARKS=${BENCHMARKS:-}
       - BENCHMARK_DURATION=${BENCHMARK_DURATION:-10}
       - BENCHMARK_WARMUP=${BENCHMARK_WARMUP:-3}
+      - TUSK_RECORDING_SAMPLING_RATE=${TUSK_RECORDING_SAMPLING_RATE:-}
       - TUSK_SAMPLING_RATE=${TUSK_SAMPLING_RATE:-}
     working_dir: /app
     volumes:
diff --git a/drift/instrumentation/requests/e2e-tests/docker-compose.yml b/drift/instrumentation/requests/e2e-tests/docker-compose.yml
index 997da3a..0107d39 100644
--- a/drift/instrumentation/requests/e2e-tests/docker-compose.yml
+++ b/drift/instrumentation/requests/e2e-tests/docker-compose.yml
@@ -27,6 +27,7 @@ services:
       - BENCHMARKS=${BENCHMARKS:-}
       - BENCHMARK_DURATION=${BENCHMARK_DURATION:-10}
       - BENCHMARK_WARMUP=${BENCHMARK_WARMUP:-3}
+      - TUSK_RECORDING_SAMPLING_RATE=${TUSK_RECORDING_SAMPLING_RATE:-}
       - TUSK_SAMPLING_RATE=${TUSK_SAMPLING_RATE:-}
       - USE_MOCK_EXTERNALS=${USE_MOCK_EXTERNALS:-1}
       - MOCK_SERVER_BASE_URL=${MOCK_SERVER_BASE_URL:-http://mock-upstream:8081}
diff --git a/drift/instrumentation/urllib/e2e-tests/docker-compose.yml b/drift/instrumentation/urllib/e2e-tests/docker-compose.yml
index 4beb9b9..1436145 100644
--- a/drift/instrumentation/urllib/e2e-tests/docker-compose.yml
+++ b/drift/instrumentation/urllib/e2e-tests/docker-compose.yml
@@ -27,6 +27,7 @@ services:
       - BENCHMARKS=${BENCHMARKS:-}
       - BENCHMARK_DURATION=${BENCHMARK_DURATION:-10}
       - BENCHMARK_WARMUP=${BENCHMARK_WARMUP:-3}
+      - TUSK_RECORDING_SAMPLING_RATE=${TUSK_RECORDING_SAMPLING_RATE:-}
       - TUSK_SAMPLING_RATE=${TUSK_SAMPLING_RATE:-}
       - USE_MOCK_EXTERNALS=${USE_MOCK_EXTERNALS:-1}
       - MOCK_SERVER_BASE_URL=${MOCK_SERVER_BASE_URL:-http://mock-upstream:8081}
diff --git a/drift/instrumentation/urllib3/e2e-tests/docker-compose.yml b/drift/instrumentation/urllib3/e2e-tests/docker-compose.yml
index 10b39a7..c61e044 100644
--- a/drift/instrumentation/urllib3/e2e-tests/docker-compose.yml
+++ b/drift/instrumentation/urllib3/e2e-tests/docker-compose.yml
@@ -27,6 +27,7 @@ services:
       - BENCHMARKS=${BENCHMARKS:-}
       - BENCHMARK_DURATION=${BENCHMARK_DURATION:-10}
       - BENCHMARK_WARMUP=${BENCHMARK_WARMUP:-3}
+      - TUSK_RECORDING_SAMPLING_RATE=${TUSK_RECORDING_SAMPLING_RATE:-}
       - TUSK_SAMPLING_RATE=${TUSK_SAMPLING_RATE:-}
       - USE_MOCK_EXTERNALS=${USE_MOCK_EXTERNALS:-1}
       - MOCK_SERVER_BASE_URL=${MOCK_SERVER_BASE_URL:-http://mock-upstream:8081}
diff --git a/drift/instrumentation/wsgi/handler.py b/drift/instrumentation/wsgi/handler.py
index e5f4aed..4428e83 100644
--- a/drift/instrumentation/wsgi/handler.py
+++ b/drift/instrumentation/wsgi/handler.py
@@ -9,7 +9,7 @@
 import json
 import logging
 import time
-from collections.abc import Iterable
+from collections.abc import Iterable, Iterator
 from typing import TYPE_CHECKING, Any

 from opentelemetry import context as otel_context
@@ -31,6 +31,7 @@
 from ...core.mode_utils import handle_record_mode, should_record_inbound_http_request
+from ...core.no_recording import suppress_recording
 from ...core.tracing import TdSpanAttributes
 from ...core.tracing.span_utils import CreateSpanOptions, SpanUtils
 from ...core.types import (
@@ -53,6 +54,30 @@
 )


+class SuppressedResponseIterable(Iterable[bytes]):
+    """Keep no-record suppression active while a skipped WSGI response is consumed."""
+
+    def __init__(self, response: Iterable[bytes]):
+        self._response = response
+
+    def __iter__(self) -> Iterator[bytes]:
+        with suppress_recording():
+            iterator = iter(self._response)
+        while True:
+            try:
+                with suppress_recording():
+                    chunk = next(iterator)
+            except StopIteration:
+                return
+            yield chunk
+
+    def close(self) -> None:
+        close_method = getattr(self._response, "close", None)
+        if close_method is not None:
+            with suppress_recording():
+                close_method()
+
+
 def handle_wsgi_request(
     app: WSGIApplication,
     environ: WSGIEnvironment,
@@ -225,7 +250,9 @@ def _create_and_handle_request(
     )
     if not should_record:
         logger.debug(f"[WSGI] Skipping request ({skip_reason}), path={path}")
-        return original_wsgi_app(app, environ, start_response)
+        with suppress_recording():
+            response = original_wsgi_app(app, environ, start_response)
+        return SuppressedResponseIterable(response)

     # Capture request body
     request_body = capture_request_body(environ)
diff --git a/drift/stack-tests/django-postgres/docker-compose.yml b/drift/stack-tests/django-postgres/docker-compose.yml
index c20e94f..6877e01 100644
--- a/drift/stack-tests/django-postgres/docker-compose.yml
+++ b/drift/stack-tests/django-postgres/docker-compose.yml
@@ -39,6 +39,7 @@ services:
       - BENCHMARKS=${BENCHMARKS:-}
       - BENCHMARK_DURATION=${BENCHMARK_DURATION:-10}
       - BENCHMARK_WARMUP=${BENCHMARK_WARMUP:-3}
+      - TUSK_RECORDING_SAMPLING_RATE=${TUSK_RECORDING_SAMPLING_RATE:-}
       - TUSK_SAMPLING_RATE=${TUSK_SAMPLING_RATE:-}
     working_dir: /app
     volumes:
diff --git a/drift/stack-tests/django-redis/docker-compose.yml b/drift/stack-tests/django-redis/docker-compose.yml
index 1570b8a..e1c34db 100644
--- a/drift/stack-tests/django-redis/docker-compose.yml
+++ b/drift/stack-tests/django-redis/docker-compose.yml
@@ -32,6 +32,7 @@ services:
       - BENCHMARKS=${BENCHMARKS:-}
       - BENCHMARK_DURATION=${BENCHMARK_DURATION:-10}
       - BENCHMARK_WARMUP=${BENCHMARK_WARMUP:-3}
+      - TUSK_RECORDING_SAMPLING_RATE=${TUSK_RECORDING_SAMPLING_RATE:-}
       - TUSK_SAMPLING_RATE=${TUSK_SAMPLING_RATE:-}
     working_dir: /app
     volumes:
diff --git a/drift/stack-tests/fastapi-postgres/docker-compose.yml b/drift/stack-tests/fastapi-postgres/docker-compose.yml
index 3497e5e..86ae38c 100644
--- a/drift/stack-tests/fastapi-postgres/docker-compose.yml
+++ b/drift/stack-tests/fastapi-postgres/docker-compose.yml
@@ -38,6 +38,7 @@ services:
       - BENCHMARKS=${BENCHMARKS:-}
       - BENCHMARK_DURATION=${BENCHMARK_DURATION:-10}
       - BENCHMARK_WARMUP=${BENCHMARK_WARMUP:-3}
+      - TUSK_RECORDING_SAMPLING_RATE=${TUSK_RECORDING_SAMPLING_RATE:-}
       - TUSK_SAMPLING_RATE=${TUSK_SAMPLING_RATE:-}
     working_dir: /app
     volumes:
diff --git a/tests/unit/test_adaptive_sampling.py b/tests/unit/test_adaptive_sampling.py
new file mode 100644
index 0000000..1f4d451
--- /dev/null
+++ b/tests/unit/test_adaptive_sampling.py
@@ -0,0 +1,120 @@
+import math
+import threading
+
+from drift.core.adaptive_sampling import (
+    AdaptiveSamplingController,
+    AdaptiveSamplingHealthSnapshot,
+    ResolvedSamplingConfig,
+)
+
+
+def test_pre_app_start_always_records():
+    controller = AdaptiveSamplingController(
+        ResolvedSamplingConfig(mode="adaptive", base_rate=0.0, min_rate=0.0),
+        random_fn=lambda: 0.99,
+        now_fn=lambda: 0.0,
+    )
+
+    decision = controller.get_decision(is_pre_app_start=True)
+
+    assert decision.should_record is True
+    assert decision.reason == "pre_app_start"
+    assert decision.effective_rate == 1.0
+
+
+def test_controller_load_sheds_and_pauses_on_drops():
+    now = {"value": 0.0}
+    controller = AdaptiveSamplingController(
+        ResolvedSamplingConfig(mode="adaptive", base_rate=0.5, min_rate=0.1),
+        random_fn=lambda: 0.3,
+        now_fn=lambda: now["value"],
+    )
+
+    controller.update(AdaptiveSamplingHealthSnapshot(queue_fill_ratio=0.9))
+    load_shed_decision = controller.get_decision(is_pre_app_start=False)
+    assert load_shed_decision.state == "hot"
+    assert load_shed_decision.effective_rate < 0.5
+    assert load_shed_decision.should_record is False
+    assert load_shed_decision.reason == "load_shed"
+
+    now["value"] = 1.0
+    controller.update(AdaptiveSamplingHealthSnapshot(queue_fill_ratio=0.1, dropped_span_count=1))
+    paused_decision = controller.get_decision(is_pre_app_start=False)
+    assert paused_decision.state == "critical_pause"
+    assert paused_decision.should_record is False
+    assert paused_decision.reason == "critical_pause"
+
+
+def test_elapsed_time_uses_zero_timestamp_as_real_value():
+    now = {"value": 0.0}
+    controller = AdaptiveSamplingController(
+        ResolvedSamplingConfig(mode="adaptive", base_rate=0.5, min_rate=0.1),
+        random_fn=lambda: 0.0,
+        now_fn=lambda: now["value"],
+    )
+
+    controller.update(AdaptiveSamplingHealthSnapshot(export_failure_count=1))
+    now["value"] = 0.5
+    controller.update(AdaptiveSamplingHealthSnapshot(export_failure_count=1))
+
+    expected_decay = math.exp(-(0.5 * 1000.0) / 30000.0)
+    assert math.isclose(controller._recent_failure_signal, expected_decay, rel_tol=1e-6)
+
+
+def test_get_decision_waits_for_controller_lock():
+    controller = AdaptiveSamplingController(
+        ResolvedSamplingConfig(mode="adaptive", base_rate=0.5, min_rate=0.1),
+        random_fn=lambda: 0.0,
+        now_fn=lambda: 0.0,
+    )
+    started = threading.Event()
+    finished = threading.Event()
+    result = {}
+
+    def worker() -> None:
+        started.set()
+        result["decision"] = controller.get_decision(is_pre_app_start=False)
+        finished.set()
+
+    thread = threading.Thread(target=worker)
+    controller._lock.acquire()
+    try:
+        thread.start()
+        assert started.wait(timeout=1.0)
+        assert not finished.wait(timeout=0.05)
+    finally:
+        controller._lock.release()
+
+    assert finished.wait(timeout=1.0)
+    thread.join(timeout=1.0)
+    assert not thread.is_alive()
+    assert result["decision"].effective_rate == 0.5
+
+
+def test_update_waits_for_controller_lock():
+    controller = AdaptiveSamplingController(
+        ResolvedSamplingConfig(mode="adaptive", base_rate=0.5, min_rate=0.1),
+        random_fn=lambda: 0.0,
+        now_fn=lambda: 0.0,
+    )
+    started = threading.Event()
+    finished = threading.Event()
+
+    def worker() -> None:
+        started.set()
+        controller.update(AdaptiveSamplingHealthSnapshot(queue_fill_ratio=0.9))
+        finished.set()
+
+    thread = threading.Thread(target=worker)
+    controller._lock.acquire()
+    try:
+        thread.start()
+        assert started.wait(timeout=1.0)
+        assert not finished.wait(timeout=0.05)
+    finally:
+        controller._lock.release()
+
+    assert finished.wait(timeout=1.0)
+    thread.join(timeout=1.0)
+    assert not thread.is_alive()
+    assert controller.get_decision(is_pre_app_start=False).state == "hot"
diff --git a/tests/unit/test_config_loading.py b/tests/unit/test_config_loading.py
index 9f1570f..8d6d1b5 100644
--- a/tests/unit/test_config_loading.py
+++ b/tests/unit/test_config_loading.py
@@ -214,6 +214,38 @@ def test_handles_partial_config(self):
             finally:
                 os.chdir(original_cwd)

+    def test_loads_nested_sampling_config(self):
+        """Should load recording.sampling config alongside legacy fields."""
+        with tempfile.TemporaryDirectory() as tmpdir:
+            project_root = Path(tmpdir)
+            (project_root / "pyproject.toml").touch()
+
+            tusk_dir = project_root / ".tusk"
+            tusk_dir.mkdir()
+            (tusk_dir / "config.yaml").write_text(
+                """
+recording:
+  sampling:
+    mode: adaptive
+    base_rate: 0.25
+    min_rate: 0.05
+"""
+            )
+
+            original_cwd = os.getcwd()
+            try:
+                os.chdir(project_root)
+                config = load_tusk_config()
+
+                assert config is not None
+                assert config.recording is not None
+                assert config.recording.sampling is not None
+                assert config.recording.sampling.mode == "adaptive"
+                assert config.recording.sampling.base_rate == 0.25
+                assert config.recording.sampling.min_rate == 0.05
+            finally:
+                os.chdir(original_cwd)
+
     def test_handles_invalid_yaml(self):
         """Should return None when YAML is invalid."""
         with tempfile.TemporaryDirectory() as tmpdir:
diff --git a/tests/unit/test_drift_sdk.py b/tests/unit/test_drift_sdk.py
index 954ae31..adb16d0 100644
--- a/tests/unit/test_drift_sdk.py
+++ b/tests/unit/test_drift_sdk.py
@@ -6,6 +6,8 @@

 import pytest

+from drift.core.adaptive_sampling import AdaptiveSamplingController, ResolvedSamplingConfig
+from drift.core.config import RecordingConfig, SamplingConfig, TuskFileConfig
 from drift.core.drift_sdk import TuskDrift
 from drift.core.types import TuskDriftMode

@@ -19,7 +21,13 @@ def reset_singleton(self):
         TuskDrift._instance = None
         TuskDrift._initialized = False
         # Clear environment variables
-        env_vars = ["TUSK_DRIFT_MODE", "TUSK_API_KEY", "TUSK_SAMPLING_RATE", "ENV"]
+        env_vars = [
+            "TUSK_DRIFT_MODE",
+            "TUSK_API_KEY",
+            "TUSK_RECORDING_SAMPLING_RATE",
+            "TUSK_SAMPLING_RATE",
+            "ENV",
+        ]
         original_env = {k: os.environ.get(k) for k in env_vars}
         for var in env_vars:
             if var in os.environ:
@@ -120,9 +128,10 @@ def reset_singleton(self):
         """Reset singleton state before each test."""
         TuskDrift._instance = None
         TuskDrift._initialized = False
-        # Clear sampling rate env var
-        if "TUSK_SAMPLING_RATE" in os.environ:
-            del os.environ["TUSK_SAMPLING_RATE"]
+        # Clear sampling rate env vars
+        for env_var in ("TUSK_RECORDING_SAMPLING_RATE", "TUSK_SAMPLING_RATE"):
+            if env_var in os.environ:
+                del os.environ[env_var]
         yield
         TuskDrift._instance = None
         TuskDrift._initialized = False
@@ -136,10 +145,31 @@ def test_uses_init_param_sampling_rate(self, reset_singleton):

         assert result == 0.5

-    def test_uses_env_var_sampling_rate(self, reset_singleton):
-        """Should use sampling rate from env var if init param not provided."""
+    def test_uses_recording_env_var_sampling_rate(self, reset_singleton):
+        """Should use the canonical recording env var if init param not provided."""
         os.environ["TUSK_DRIFT_MODE"] = "DISABLED"
-        os.environ["TUSK_SAMPLING_RATE"] = "0.25"
+        os.environ["TUSK_RECORDING_SAMPLING_RATE"] = "0.25"
+        instance = TuskDrift.get_instance()
+
+        result = instance._determine_sampling_rate(None)
+
+        assert result == 0.25
+
+    def test_uses_legacy_sampling_env_var_as_alias(self, reset_singleton):
+        """Should fall back to the legacy env var for backward compatibility."""
+        os.environ["TUSK_DRIFT_MODE"] = "DISABLED"
+        os.environ["TUSK_SAMPLING_RATE"] = "0.2"
+        instance = TuskDrift.get_instance()
+
+        result = instance._determine_sampling_rate(None)
+
+        assert result == 0.2
+
+    def test_recording_env_var_takes_precedence_over_legacy_alias(self, reset_singleton):
+        """Should prefer the canonical env var when both env vars are set."""
+        os.environ["TUSK_DRIFT_MODE"] = "DISABLED"
+        os.environ["TUSK_RECORDING_SAMPLING_RATE"] = "0.25"
+        os.environ["TUSK_SAMPLING_RATE"] = "0.1"
         instance = TuskDrift.get_instance()

         result = instance._determine_sampling_rate(None)
@@ -149,13 +179,37 @@ def test_uses_env_var_sampling_rate(self, reset_singleton):
     def test_init_param_takes_precedence_over_env_var(self, reset_singleton):
         """Should prefer init param over env var."""
         os.environ["TUSK_DRIFT_MODE"] = "DISABLED"
-        os.environ["TUSK_SAMPLING_RATE"] = "0.25"
+        os.environ["TUSK_RECORDING_SAMPLING_RATE"] = "0.25"
         instance = TuskDrift.get_instance()

         result = instance._determine_sampling_rate(0.75)

         assert result == 0.75

+    def test_invalid_init_param_falls_back_to_env_var(self, reset_singleton):
+        """Should use env var when init param is present but invalid."""
+        os.environ["TUSK_DRIFT_MODE"] = "DISABLED"
+        os.environ["TUSK_RECORDING_SAMPLING_RATE"] = "0.25"
+        instance = TuskDrift.get_instance()
+
+        result = instance._determine_sampling_rate(2.0)
+
+        assert result == 0.25
+
+    def test_invalid_init_param_falls_back_to_config_file(self, reset_singleton):
+        """Should use config file when init param is present but invalid."""
+        os.environ["TUSK_DRIFT_MODE"] = "DISABLED"
+        instance = TuskDrift.get_instance()
+        instance.file_config = TuskFileConfig(
+            recording=RecordingConfig(
+                sampling=SamplingConfig(base_rate=0.4),
+            )
+        )
+
+        result = instance._determine_sampling_rate(2.0)
+
+        assert result == 0.4
+
     def test_defaults_to_1_0(self, reset_singleton):
         """Should default to 1.0 (100%) sampling rate."""
         os.environ["TUSK_DRIFT_MODE"] = "DISABLED"
@@ -165,16 +219,42 @@ def test_defaults_to_1_0(self, reset_singleton):

         assert result == 1.0

-    def test_rejects_invalid_env_var_sampling_rate(self, reset_singleton):
-        """Should reject invalid env var and use default."""
+    def test_rejects_invalid_recording_env_var_sampling_rate(self, reset_singleton):
+        """Should reject an invalid canonical env var and use default."""
         os.environ["TUSK_DRIFT_MODE"] = "DISABLED"
-        os.environ["TUSK_SAMPLING_RATE"] = "invalid"
+        os.environ["TUSK_RECORDING_SAMPLING_RATE"] = "invalid"
         instance = TuskDrift.get_instance()

         result = instance._determine_sampling_rate(None)

         assert result == 1.0

+    def test_invalid_env_var_falls_back_to_config_file(self, reset_singleton):
+        """Should use config file when env var is present but invalid."""
+        os.environ["TUSK_DRIFT_MODE"] = "DISABLED"
+        os.environ["TUSK_RECORDING_SAMPLING_RATE"] = "invalid"
+        instance = TuskDrift.get_instance()
+        instance.file_config = TuskFileConfig(
+            recording=RecordingConfig(
+                sampling=SamplingConfig(base_rate=0.4),
+            )
+        )
+
+        result = instance._determine_sampling_rate(None)
+
+        assert result == 0.4
+
+    def test_invalid_recording_env_var_falls_back_to_legacy_alias(self, reset_singleton):
+        """Should use the legacy alias when the canonical env var is invalid."""
+        os.environ["TUSK_DRIFT_MODE"] = "DISABLED"
+        os.environ["TUSK_RECORDING_SAMPLING_RATE"] = "invalid"
+        os.environ["TUSK_SAMPLING_RATE"] = "0.4"
+        instance = TuskDrift.get_instance()
+
+        result = instance._determine_sampling_rate(None)
+
+        assert result == 0.4
+

 class TestTuskDriftInitialize:
     """Tests for TuskDrift.initialize method."""
@@ -185,7 +265,13 @@ def reset_singleton(self):
         TuskDrift._instance = None
         TuskDrift._initialized = False
         # Clear environment variables
-        env_vars = ["TUSK_DRIFT_MODE", "TUSK_API_KEY", "TUSK_SAMPLING_RATE", "ENV"]
+        env_vars = [
+            "TUSK_DRIFT_MODE",
+            "TUSK_API_KEY",
+            "TUSK_RECORDING_SAMPLING_RATE",
+            "TUSK_SAMPLING_RATE",
+            "ENV",
+        ]
         for var in env_vars:
             if var in os.environ:
                 del os.environ[var]
@@ -252,6 +338,47 @@ def test_idempotent_initialization(self, reset_singleton, mocker):
         # TracerProvider should only be created once
         assert mock_provider.call_count == 1

+    def test_second_initialize_does_not_mutate_live_adaptive_sampling_state(self, reset_singleton, mocker):
+        """Should keep sampling fields aligned with the live controller on repeated initialize calls."""
+        mocker.patch("drift.core.drift_sdk.install_hooks")
+        mocker.patch("drift.core.drift_sdk.atexit")
+        mocker.patch("drift.core.drift_sdk.TracerProvider")
+        mocker.patch("drift.core.drift_sdk.trace")
+        mocker.patch.object(TuskDrift, "_start_adaptive_sampling_control_loop")
+        os.environ["TUSK_DRIFT_MODE"] = "RECORD"
+
+        instance = TuskDrift.get_instance()
+        instance.file_config = TuskFileConfig(
+            recording=RecordingConfig(
+                sampling=SamplingConfig(mode="adaptive", base_rate=0.5, min_rate=0.1),
+            )
+        )
+
+        initialized_instance = TuskDrift.initialize(env="test")
+        initialized_instance._adaptive_sampling_controller = AdaptiveSamplingController(
+            ResolvedSamplingConfig(mode="adaptive", base_rate=0.5, min_rate=0.1),
+            random_fn=lambda: 0.0,
+            now_fn=lambda: 0.0,
+        )
+
+        initialized_instance.file_config = TuskFileConfig(
+            recording=RecordingConfig(
+                sampling=SamplingConfig(mode="fixed", base_rate=0.2, min_rate=None),
+            )
+        )
+
+        second_instance = TuskDrift.initialize(env="test", sampling_rate=0.9)
+
+        assert second_instance is initialized_instance
+        assert second_instance._sampling_rate == 0.5
+        assert second_instance._sampling_mode == "adaptive"
+        assert second_instance._min_sampling_rate == 0.1
+
+        decision = second_instance.should_record_root_request(is_pre_app_start=False)
+        assert decision.mode == "adaptive"
+        assert decision.base_rate == 0.5
+        assert decision.min_rate == 0.1
+

 class TestTuskDriftMarkAppAsReady:
     """Tests for TuskDrift.mark_app_as_ready method."""
@@ -412,6 +539,105 @@ def test_shutdown_cleans_up_resources(self, reset_singleton, mocker):
         mock_tracer_provider.shutdown.assert_called_once()


+class TestTuskDriftAdaptiveSampling:
+    """Tests for adaptive sampling health monitoring."""
+
+    @pytest.fixture(autouse=True)
+    def reset_singleton(self):
+        """Reset singleton state before each test."""
+        TuskDrift._instance = None
+        TuskDrift._initialized = False
+        yield
+        TuskDrift._instance = None
+        TuskDrift._initialized = False
+
+    def test_safe_update_logs_and_swallows_health_update_exceptions(self, reset_singleton, mocker):
+        """Should log and continue when health updates fail."""
+        instance = TuskDrift.get_instance()
+        mocker.patch.object(instance, "_update_adaptive_sampling_health", side_effect=RuntimeError("boom"))
+        log_error = mocker.patch("drift.core.drift_sdk.logger.error")
+
+        instance._safe_update_adaptive_sampling_health()
+
+        log_error.assert_called_once()
+        assert "Adaptive sampling health update failed" in log_error.call_args.args[0]
+
+    def test_adaptive_sampling_loop_continues_after_update_exception(self, reset_singleton, mocker):
+        """Should keep polling after a single health update failure."""
+        instance = TuskDrift.get_instance()
+        stop_event = mocker.MagicMock()
+        stop_event.wait.side_effect = [False, False, True]
+        instance._adaptive_sampling_stop_event = stop_event
+        log_error = mocker.patch("drift.core.drift_sdk.logger.error")
+
+        update_health = mocker.patch.object(
+            instance,
+            "_update_adaptive_sampling_health",
+            side_effect=[RuntimeError("boom"), None],
+        )
+
+        instance._adaptive_sampling_loop()
+
+        assert update_health.call_count == 2
+        log_error.assert_called_once()
+        assert "Adaptive sampling health update failed" in log_error.call_args.args[0]
+
+
+class TestTuskDriftMemoryPressure:
+    """Tests for memory pressure measurement helpers."""
+
+    @pytest.fixture(autouse=True)
+    def reset_singleton(self):
+        """Reset singleton state before each test."""
+        TuskDrift._instance = None
+        TuskDrift._initialized = False
+        yield
+        TuskDrift._instance = None
+        TuskDrift._initialized = False
+
+    def test_parse_proc_status_rss_bytes(self, reset_singleton):
+        """Should parse current RSS from /proc/self/status."""
+        raw_status = "Name:\tpython\nVmRSS:\t1234 kB\nThreads:\t8\n"
+
+        assert TuskDrift._parse_proc_status_rss_bytes(raw_status) == 1234 * 1024
+
+    def test_read_current_rss_bytes_falls_back_to_proc_statm(self, reset_singleton, mocker):
+        """Should use /proc/self/statm when /proc/self/status is unavailable."""
+        instance = TuskDrift.get_instance()
+
+        mocker.patch(
+            "drift.core.drift_sdk.Path.exists",
+            autospec=True,
+            side_effect=lambda path: str(path) == "/proc/self/statm",
+        )
+        mocker.patch(
+            "drift.core.drift_sdk.Path.read_text",
+            autospec=True,
+            side_effect=lambda path: "100 25 0 0 0 0 0\n" if str(path) == "/proc/self/statm" else "",
+        )
+        mocker.patch("drift.core.drift_sdk.os.sysconf", return_value=4096)
+
+        assert instance._read_current_rss_bytes() == 25 * 4096
+
+    def test_get_memory_pressure_ratio_uses_current_rss_fallback(self, reset_singleton, mocker):
+        """Should use current RSS fallback when cgroup current usage is unavailable."""
+        instance = TuskDrift.get_instance()
+        instance._effective_memory_limit_bytes = 1024
+        mocker.patch.object(instance, "_read_numeric_control_file", return_value=None)
+        mocker.patch.object(instance, "_read_current_rss_bytes", return_value=256)
+
+        assert instance._get_memory_pressure_ratio() == 0.25
+
+    def test_get_memory_pressure_ratio_returns_none_without_current_measurement(self, reset_singleton, mocker):
+        """Should return None when no current memory measurement is available."""
+        instance = TuskDrift.get_instance()
+        instance._effective_memory_limit_bytes = 1024
+        mocker.patch.object(instance, "_read_numeric_control_file", return_value=None)
+        mocker.patch.object(instance, "_read_current_rss_bytes", return_value=None)
+
+        assert instance._get_memory_pressure_ratio() is None
+
+
 class TestTuskDriftGetTracer:
     """Tests for TuskDrift.get_tracer method."""

diff --git a/tests/unit/test_mode_utils.py b/tests/unit/test_mode_utils.py
index c60cb0f..3f9b001 100644
--- a/tests/unit/test_mode_utils.py
+++ b/tests/unit/test_mode_utils.py
@@ -311,11 +311,9 @@ def test_returns_true_when_no_drop_and_sampled(self, mocker):
         """Should return (True, None) when not dropped and sampled."""
         mock_drift = mocker.patch("drift.core.drift_sdk.TuskDrift")
         mock_sdk = mocker.MagicMock()
-        mock_sdk.get_sampling_rate.return_value = 1.0
+        mock_sdk.should_record_root_request.return_value.should_record = True
         mock_drift.get_instance.return_value = mock_sdk

-        mocker.patch("drift.core.sampling.should_sample", return_value=True)
-
         result, reason = should_record_inbound_http_request(
             method="GET",
             target="/api/users",
@@ -361,11 +359,10 @@ def test_returns_false_when_not_sampled(self, mocker):
         """Should return (False, 'not_sampled') when sampling decides to skip."""
         mock_drift = mocker.patch("drift.core.drift_sdk.TuskDrift")
         mock_sdk = mocker.MagicMock()
-        mock_sdk.get_sampling_rate.return_value = 0.0
+        mock_sdk.should_record_root_request.return_value.should_record = False
+        mock_sdk.should_record_root_request.return_value.reason = "not_sampled"
         mock_drift.get_instance.return_value = mock_sdk

-        mocker.patch("drift.core.sampling.should_sample", return_value=False)
- result, reason = should_record_inbound_http_request( method="GET", target="/api/users", @@ -377,12 +374,33 @@ def test_returns_false_when_not_sampled(self, mocker): assert result is False assert reason == "not_sampled" + def test_returns_controller_reason_when_adaptive_sampling_skips(self, mocker): + """Should preserve adaptive controller reasons for debug logging.""" + mock_drift = mocker.patch("drift.core.drift_sdk.TuskDrift") + mock_sdk = mocker.MagicMock() + mock_sdk.should_record_root_request.return_value.should_record = False + mock_sdk.should_record_root_request.return_value.reason = "critical_pause" + mock_drift.get_instance.return_value = mock_sdk + + result, reason = should_record_inbound_http_request( + method="GET", + target="/api/users", + headers={}, + transform_engine=None, + is_pre_app_start=False, + ) + + assert result is False + assert reason == "critical_pause" + def test_drop_check_happens_before_sampling(self, mocker): """Should check drop rules before sampling.""" mock_transform = mocker.MagicMock() mock_transform.should_drop_inbound_request.return_value = True - mock_sample = mocker.patch("drift.core.sampling.should_sample") + mock_drift = mocker.patch("drift.core.drift_sdk.TuskDrift") + mock_sdk = mocker.MagicMock() + mock_drift.get_instance.return_value = mock_sdk result, reason = should_record_inbound_http_request( method="GET", @@ -392,7 +410,6 @@ def test_drop_check_happens_before_sampling(self, mocker): is_pre_app_start=False, ) - # should_sample should never be called if dropped - mock_sample.assert_not_called() + mock_sdk.should_record_root_request.assert_not_called() assert result is False assert reason == "dropped" diff --git a/tests/unit/test_sampling.py b/tests/unit/test_sampling.py index f6ee591..f166ec9 100644 --- a/tests/unit/test_sampling.py +++ b/tests/unit/test_sampling.py @@ -90,7 +90,7 @@ def test_rate_above_one_returns_none(self): def test_custom_source_in_warning(self): """Should include custom source in warning 
message.""" # Just verify it doesn't raise with custom source - result = validate_sampling_rate(-0.5, source="env var TUSK_SAMPLING_RATE") + result = validate_sampling_rate(-0.5, source="env var TUSK_RECORDING_SAMPLING_RATE") assert result is None def test_converts_to_float(self): diff --git a/tests/unit/test_span_utils.py b/tests/unit/test_span_utils.py index c23d0b2..1e1bf36 100644 --- a/tests/unit/test_span_utils.py +++ b/tests/unit/test_span_utils.py @@ -7,6 +7,7 @@ from opentelemetry.trace import SpanKind as OTelSpanKind from opentelemetry.trace import Status, StatusCode +from drift.core.no_recording import suppress_recording from drift.core.tracing.span_utils import ( AddSpanAttributesOptions, CreateSpanOptions, @@ -205,6 +206,23 @@ def test_returns_none_on_exception(self, mocker): assert result is None + def test_returns_none_when_recording_is_suppressed(self, mocker): + """Should not create spans when no-record context is active.""" + mock_drift = mocker.patch("drift.core.drift_sdk.TuskDrift") + mock_sdk = mocker.MagicMock() + mock_sdk.get_tracer.return_value = mocker.MagicMock() + mock_drift.get_instance.return_value = mock_sdk + + options = CreateSpanOptions( + name="test-span", + kind=OTelSpanKind.SERVER, + ) + + with suppress_recording(): + result = SpanUtils.create_span(options) + + assert result is None + class TestSpanUtilsWithSpan: """Tests for SpanUtils.with_span context manager.""" diff --git a/tests/unit/test_td_span_processor.py b/tests/unit/test_td_span_processor.py index eae2f61..5a34c3c 100644 --- a/tests/unit/test_td_span_processor.py +++ b/tests/unit/test_td_span_processor.py @@ -23,8 +23,7 @@ def test_initializes_with_required_params(self, mocker): assert processor._exporter is mock_exporter assert processor._mode == TuskDriftMode.RECORD - assert processor._sampling_rate == 1.0 - assert processor._app_ready is False + assert processor._environment is None assert processor._started is False def test_initializes_with_optional_params(self, 
mocker): @@ -34,14 +33,10 @@ def test_initializes_with_optional_params(self, mocker): processor = TdSpanProcessor( exporter=mock_exporter, mode=TuskDriftMode.REPLAY, - sampling_rate=0.5, - app_ready=True, environment="production", ) assert processor._mode == TuskDriftMode.REPLAY - assert processor._sampling_rate == 0.5 - assert processor._app_ready is True assert processor._environment == "production" @@ -392,35 +387,3 @@ def test_force_flush_handles_exception(self, mocker): result = processor.force_flush() assert result is False - - -class TestTdSpanProcessorUpdateMethods: - """Tests for TdSpanProcessor update methods.""" - - def test_update_app_ready(self, mocker): - """Should update app_ready flag.""" - mock_exporter = mocker.MagicMock() - processor = TdSpanProcessor( - exporter=mock_exporter, - mode=TuskDriftMode.RECORD, - ) - - assert processor._app_ready is False - - processor.update_app_ready(True) - - assert processor._app_ready is True - - def test_update_sampling_rate(self, mocker): - """Should update sampling rate.""" - mock_exporter = mocker.MagicMock() - processor = TdSpanProcessor( - exporter=mock_exporter, - mode=TuskDriftMode.RECORD, - ) - - assert processor._sampling_rate == 1.0 - - processor.update_sampling_rate(0.5) - - assert processor._sampling_rate == 0.5 diff --git a/tests/unit/test_wsgi_handler.py b/tests/unit/test_wsgi_handler.py new file mode 100644 index 0000000..0850852 --- /dev/null +++ b/tests/unit/test_wsgi_handler.py @@ -0,0 +1,83 @@ +"""Tests for WSGI handler request lifecycle behavior.""" + +from __future__ import annotations + +from collections.abc import Iterable, Iterator +from typing import Any + +from drift.instrumentation.wsgi.handler import _create_and_handle_request + + +class StreamingResponse(Iterable[bytes]): + def __init__(self, observed: list[tuple[str, bool]]) -> None: + self._observed = observed + self._yielded = False + + def __iter__(self) -> Iterator[bytes]: + from drift.core.no_recording import 
is_recording_suppressed + + self._observed.append(("iter", is_recording_suppressed())) + return self + + def __next__(self) -> bytes: + from drift.core.no_recording import is_recording_suppressed + + self._observed.append(("next", is_recording_suppressed())) + if self._yielded: + raise StopIteration + self._yielded = True + return b"chunk" + + def close(self) -> None: + from drift.core.no_recording import is_recording_suppressed + + self._observed.append(("close", is_recording_suppressed())) + + +def test_skipped_wsgi_request_keeps_suppression_during_streaming_iteration_and_close(mocker) -> None: + observed: list[tuple[str, bool]] = [] + response = StreamingResponse(observed) + + mocker.patch( + "drift.instrumentation.wsgi.handler.should_record_inbound_http_request", + return_value=(False, "not_sampled"), + ) + + def original_wsgi_app(_app: Any, _environ: dict[str, Any], _start_response: Any) -> Iterable[bytes]: + from drift.core.no_recording import is_recording_suppressed + + observed.append(("call", is_recording_suppressed())) + return response + + def app(_environ: dict[str, Any], _start_response: Any) -> Iterable[bytes]: + return response + + wrapped_response = _create_and_handle_request( + app=app, + environ={ + "REQUEST_METHOD": "GET", + "PATH_INFO": "/stream", + "QUERY_STRING": "", + }, + start_response=lambda status, headers, exc_info=None: None, + original_wsgi_app=original_wsgi_app, + framework_name="wsgi", + instrumentation_name="WsgiInstrumentation", + transform_engine=None, + sdk=object(), + is_pre_app_start=False, + replay_token=None, + ) + + assert list(wrapped_response) == [b"chunk"] + close_method = getattr(wrapped_response, "close", None) + assert close_method is not None + close_method() + + assert observed == [ + ("call", True), + ("iter", True), + ("next", True), + ("next", True), + ("close", True), + ]
-| `sampling_rate` | `float` | `1.0` | The sampling rate (0.0 - 1.0). 1.0 means 100% of requests are recorded, 0.0 means 0% of requests are recorded. |
+| `sampling.mode` | `"fixed" \| "adaptive"` | `"fixed"` | Selects constant sampling or adaptive load shedding. |
+| `sampling.base_rate` | `float` | `1.0` | The base sampling rate (0.0 - 1.0). This is the preferred config key and can be overridden by `TUSK_RECORDING_SAMPLING_RATE` or the `sampling_rate` init parameter. |
+| `sampling.min_rate` | `float` | `0.001` in adaptive mode | The minimum steady-state sampling rate for adaptive mode. In critical conditions the SDK can still temporarily pause recording. |
+| `sampling_rate` | `float` | `None` | Legacy fallback for the base sampling rate. Still supported for backward compatibility, but `recording.sampling.base_rate` is preferred. |
export_spans