diff --git a/docs/environment-variables.md b/docs/environment-variables.md index 0425f1b..f829229 100644 --- a/docs/environment-variables.md +++ b/docs/environment-variables.md @@ -110,24 +110,29 @@ Your Tusk Drift API key, required when using Tusk Cloud for storing and managing This will securely store your auth key for future replay sessions. -## TUSK_SAMPLING_RATE +## TUSK_RECORDING_SAMPLING_RATE -Controls what percentage of requests are recorded during trace collection. +Controls the base recording rate used during trace collection. - **Type:** Number between 0.0 and 1.0 -- **Default:** 1.0 (100% of requests) -- **Precedence:** This environment variable is overridden by the `sampling_rate` parameter in `TuskDrift.initialize()`, but takes precedence over the `sampling_rate` setting in `.tusk/config.yaml` +- **If unset:** Falls back to `.tusk/config.yaml` and then the default base rate of `1.0` +- **Precedence:** This environment variable is overridden by the `sampling_rate` parameter in `TuskDrift.initialize()`, but takes precedence over `recording.sampling.base_rate` and the legacy `recording.sampling_rate` setting in `.tusk/config.yaml` +- **Scope:** This only overrides the base rate. It does not change `recording.sampling.mode` or `recording.sampling.min_rate` **Examples:** ```bash # Record all requests (100%) -TUSK_SAMPLING_RATE=1.0 python app.py +TUSK_RECORDING_SAMPLING_RATE=1.0 python app.py # Record 10% of requests -TUSK_SAMPLING_RATE=0.1 python app.py +TUSK_RECORDING_SAMPLING_RATE=0.1 python app.py ``` +If `recording.sampling.mode: adaptive` is enabled in `.tusk/config.yaml`, this environment variable still only changes the base rate; adaptive load shedding remains active. + +`TUSK_RECORDING_SAMPLING_RATE` is the canonical variable, but `TUSK_SAMPLING_RATE` is still accepted as a backward-compatible alias. + For more details on sampling rate configuration methods and precedence, see the [Initialization Guide](./initialization.md#configure-sampling-rate). ## Rust Core Flags diff --git a/docs/initialization.md b/docs/initialization.md index f8c707e..c425f57 100644 --- a/docs/initialization.md +++ b/docs/initialization.md @@ -73,8 +73,8 @@ Create an initialization file or add the SDK initialization to your application
sampling_ratefloat1.0TUSK_SAMPLING_RATE env var and config file.NoneTUSK_RECORDING_SAMPLING_RATE and config file base-rate settings. Does not change recording.sampling.mode.sampling_rate |
+ sampling.mode |
+ "fixed" | "adaptive" |
+ "fixed" |
+ Selects constant sampling or adaptive load shedding. | +
sampling.base_rate |
float |
1.0 |
- The sampling rate (0.0 - 1.0). 1.0 means 100% of requests are recorded, 0.0 means 0% of requests are recorded. | +The base sampling rate (0.0 - 1.0). This is the preferred config key and can be overridden by TUSK_RECORDING_SAMPLING_RATE or the sampling_rate init parameter. |
+
sampling.min_rate |
+ float |
+ 0.001 in adaptive mode |
+ The minimum steady-state sampling rate for adaptive mode. In critical conditions the SDK can still temporarily pause recording. | +|
sampling_rate |
+ float |
+ None |
+ Legacy fallback for the base sampling rate. Still supported for backward compatibility, but recording.sampling.base_rate is preferred. |
|
export_spans |
diff --git a/docs/quickstart.md b/docs/quickstart.md
index e25aff5..6549f09 100644
--- a/docs/quickstart.md
+++ b/docs/quickstart.md
@@ -2,9 +2,18 @@
Let's walk through recording and replaying your first trace:
-## Step 1: Set sampling rate to 1.0
+## Step 1: Set fixed sampling to 1.0
-Set the `sampling_rate` in `.tusk/config.yaml` to 1.0 to ensure that all requests are recorded.
+Update `.tusk/config.yaml` so your first recording captures every request:
+
+```yaml
+recording:
+ sampling:
+ mode: fixed
+ base_rate: 1.0
+```
+
+Legacy `recording.sampling_rate: 1.0` still works, but `recording.sampling` is the preferred config shape.
## Step 2: Start server in record mode
diff --git a/drift/core/adaptive_sampling.py b/drift/core/adaptive_sampling.py
new file mode 100644
index 0000000..ec6710b
--- /dev/null
+++ b/drift/core/adaptive_sampling.py
@@ -0,0 +1,254 @@
+"""Adaptive sampling controller for inbound root-request admission."""
+
+from __future__ import annotations
+
+import logging
+import math
+import random
+import threading
+import time
+from dataclasses import dataclass
+from typing import Literal
+
+logger = logging.getLogger(__name__)
+
+SamplingMode = Literal["fixed", "adaptive"]
+AdaptiveSamplingState = Literal["fixed", "healthy", "warm", "hot", "critical_pause"]
+RootSamplingDecisionReason = Literal[
+ "pre_app_start",
+ "sampled",
+ "not_sampled",
+ "load_shed",
+ "critical_pause",
+]
+
+
+@dataclass
+class ResolvedSamplingConfig:
+ mode: SamplingMode
+ base_rate: float
+ min_rate: float
+
+
+@dataclass
+class AdaptiveSamplingHealthSnapshot:
+ queue_fill_ratio: float | None = None
+ dropped_span_count: int = 0
+ export_failure_count: int = 0
+ export_circuit_open: bool = False
+ memory_pressure_ratio: float | None = None
+
+
+@dataclass
+class RootSamplingDecision:
+ should_record: bool
+ reason: RootSamplingDecisionReason
+ mode: SamplingMode
+ state: AdaptiveSamplingState
+ base_rate: float
+ min_rate: float
+ effective_rate: float
+ admission_multiplier: float
+
+
+def _clamp(value: float, min_value: float, max_value: float) -> float:
+ return min(max_value, max(min_value, value))
+
+
+def _clamp01(value: float) -> float:
+ return _clamp(value, 0.0, 1.0)
+
+
+def _normalize_between(value: float | None, zero_point: float, one_point: float) -> float:
+ if value is None or one_point <= zero_point:
+ return 0.0
+ return _clamp01((value - zero_point) / (one_point - zero_point))
+
+
+class AdaptiveSamplingController:
+ def __init__(
+ self,
+ config: ResolvedSamplingConfig,
+ *,
+ random_fn=random.random,
+ now_fn=time.monotonic,
+ ) -> None:
+ self._config = config
+ self._random_fn = random_fn
+ self._now_fn = now_fn
+ self._lock = threading.RLock()
+
+ self._admission_multiplier = 1.0
+ self._state: AdaptiveSamplingState = "fixed" if config.mode == "fixed" else "healthy"
+ self._paused_until_s = 0.0
+ self._last_updated_at_s: float | None = None
+ self._last_decrease_at_s = 0.0
+
+ self._prev_dropped_span_count = 0
+ self._prev_export_failure_count = 0
+
+ self._queue_fill_ewma: float | None = None
+ self._recent_drop_signal = 0.0
+ self._recent_failure_signal = 0.0
+
+ def update(self, snapshot: AdaptiveSamplingHealthSnapshot) -> None:
+ with self._lock:
+ if self._config.mode != "adaptive":
+ self._state = "fixed"
+ self._admission_multiplier = 1.0
+ return
+
+ now_s = self._now_fn()
+ elapsed_s = 2.0 if self._last_updated_at_s is None else max(0.001, now_s - self._last_updated_at_s)
+ self._last_updated_at_s = now_s
+
+ decay = math.exp(-(elapsed_s * 1000.0) / 30000.0)
+ self._recent_drop_signal *= decay
+ self._recent_failure_signal *= decay
+
+ dropped_delta = max(0, snapshot.dropped_span_count - self._prev_dropped_span_count)
+ export_failure_delta = max(0, snapshot.export_failure_count - self._prev_export_failure_count)
+
+ self._prev_dropped_span_count = snapshot.dropped_span_count
+ self._prev_export_failure_count = snapshot.export_failure_count
+
+ self._recent_drop_signal += dropped_delta
+ self._recent_failure_signal += export_failure_delta
+
+ if snapshot.queue_fill_ratio is not None:
+ queue_fill_ratio = _clamp01(snapshot.queue_fill_ratio)
+ self._queue_fill_ewma = (
+ queue_fill_ratio
+ if self._queue_fill_ewma is None
+ else (0.25 * queue_fill_ratio) + (0.75 * self._queue_fill_ewma)
+ )
+
+ queue_pressure = _normalize_between(self._queue_fill_ewma, 0.20, 0.85)
+ memory_pressure = _normalize_between(snapshot.memory_pressure_ratio, 0.80, 0.92)
+ export_failure_pressure = _clamp01(self._recent_failure_signal / 5.0)
+ pressure = max(queue_pressure, memory_pressure, export_failure_pressure)
+
+ hard_brake = (
+ dropped_delta > 0 or snapshot.export_circuit_open or (snapshot.memory_pressure_ratio or 0.0) >= 0.92
+ )
+
+ previous_state = self._state
+ previous_multiplier = self._admission_multiplier
+
+ if hard_brake:
+ self._paused_until_s = now_s + 15.0
+ self._admission_multiplier = 0.0
+ self._state = "critical_pause"
+ self._last_decrease_at_s = now_s
+ self._log_transition(previous_state, previous_multiplier, pressure, snapshot)
+ return
+
+ if now_s < self._paused_until_s:
+ self._state = "critical_pause"
+ self._log_transition(previous_state, previous_multiplier, pressure, snapshot)
+ return
+
+ min_multiplier = self._get_min_multiplier()
+ if pressure >= 0.70:
+ self._admission_multiplier = max(min_multiplier, self._admission_multiplier * 0.4)
+ self._state = "hot"
+ self._last_decrease_at_s = now_s
+ elif pressure >= 0.45:
+ self._admission_multiplier = max(min_multiplier, self._admission_multiplier * 0.7)
+ self._state = "warm"
+ self._last_decrease_at_s = now_s
+ else:
+ if pressure <= 0.20 and (now_s - self._last_decrease_at_s) >= 10.0:
+ self._admission_multiplier = min(1.0, self._admission_multiplier + 0.05)
+ self._state = "healthy"
+
+ self._log_transition(previous_state, previous_multiplier, pressure, snapshot)
+
+ def get_decision(self, *, is_pre_app_start: bool) -> RootSamplingDecision:
+ with self._lock:
+ if is_pre_app_start:
+ return RootSamplingDecision(
+ should_record=True,
+ reason="pre_app_start",
+ mode=self._config.mode,
+ state=self._state,
+ base_rate=self._config.base_rate,
+ min_rate=self._config.min_rate,
+ effective_rate=1.0,
+ admission_multiplier=1.0,
+ )
+
+ effective_rate = (
+ self.get_effective_sampling_rate()
+ if self._config.mode == "adaptive"
+ else _clamp01(self._config.base_rate)
+ )
+
+ if effective_rate <= 0.0:
+ return RootSamplingDecision(
+ should_record=False,
+ reason="critical_pause" if self._state == "critical_pause" else "not_sampled",
+ mode=self._config.mode,
+ state=self._state,
+ base_rate=self._config.base_rate,
+ min_rate=self._config.min_rate,
+ effective_rate=effective_rate,
+ admission_multiplier=self._admission_multiplier,
+ )
+
+ should_record = self._random_fn() < effective_rate
+ return RootSamplingDecision(
+ should_record=should_record,
+ reason=(
+ "sampled"
+ if should_record
+ else "load_shed"
+ if self._config.mode == "adaptive" and effective_rate < self._config.base_rate
+ else "not_sampled"
+ ),
+ mode=self._config.mode,
+ state=self._state,
+ base_rate=self._config.base_rate,
+ min_rate=self._config.min_rate,
+ effective_rate=effective_rate,
+ admission_multiplier=self._admission_multiplier if self._config.mode == "adaptive" else 1.0,
+ )
+
+ def get_effective_sampling_rate(self) -> float:
+ with self._lock:
+ if self._config.mode != "adaptive":
+ return _clamp01(self._config.base_rate)
+ if self._state == "critical_pause" and self._now_fn() < self._paused_until_s:
+ return 0.0
+ effective_rate = self._config.base_rate * self._admission_multiplier
+ return _clamp(
+ effective_rate,
+ min(self._config.base_rate, self._config.min_rate),
+ self._config.base_rate,
+ )
+
+ def _get_min_multiplier(self) -> float:
+ if self._config.base_rate <= 0.0 or self._config.min_rate <= 0.0:
+ return 0.0
+ return _clamp01(self._config.min_rate / self._config.base_rate)
+
+ def _log_transition(
+ self,
+ previous_state: AdaptiveSamplingState,
+ previous_multiplier: float,
+ pressure: float,
+ snapshot: AdaptiveSamplingHealthSnapshot,
+ ) -> None:
+ if previous_state == self._state and abs(previous_multiplier - self._admission_multiplier) < 0.05:
+ return
+
+ logger.info(
+ "Adaptive sampling updated (state=%s, multiplier=%.2f, effective_rate=%.4f, pressure=%.2f, queue_fill=%s, memory_pressure_ratio=%s, export_circuit_open=%s).",
+ self._state,
+ self._admission_multiplier,
+ self.get_effective_sampling_rate(),
+ pressure,
+ f"{self._queue_fill_ewma:.2f}" if self._queue_fill_ewma is not None else "n/a",
+ snapshot.memory_pressure_ratio if snapshot.memory_pressure_ratio is not None else "n/a",
+ snapshot.export_circuit_open,
+ )
diff --git a/drift/core/batch_processor.py b/drift/core/batch_processor.py
index 13c89e9..9338657 100644
--- a/drift/core/batch_processor.py
+++ b/drift/core/batch_processor.py
@@ -244,3 +244,8 @@ def queue_size(self) -> int:
def dropped_span_count(self) -> int:
"""Get the number of dropped spans."""
return self._dropped_spans
+
+ @property
+ def max_queue_size(self) -> int:
+ """Get the configured maximum queue size."""
+ return self._config.max_queue_size
diff --git a/drift/core/config.py b/drift/core/config.py
index 481c07d..16df032 100644
--- a/drift/core/config.py
+++ b/drift/core/config.py
@@ -66,11 +66,21 @@ class ComparisonConfig:
ignore_fields: list[str] = field(default_factory=list)
+@dataclass
+class SamplingConfig:
+ """Configuration for fixed vs adaptive sampling."""
+
+ mode: str | None = None
+ base_rate: float | None = None
+ min_rate: float | None = None
+
+
@dataclass
class RecordingConfig:
"""Configuration for recording behavior."""
sampling_rate: float | None = None
+ sampling: SamplingConfig | None = None
export_spans: bool | None = None
enable_env_var_recording: bool | None = None
enable_analytics: bool | None = None
@@ -144,8 +154,42 @@ def _parse_recording_config(data: dict[str, Any]) -> RecordingConfig:
)
sampling_rate = None
+ sampling = None
+ raw_sampling = data.get("sampling")
+ if isinstance(raw_sampling, dict):
+ base_rate = raw_sampling.get("base_rate")
+ if base_rate is not None and not isinstance(base_rate, (int, float)):
+ logger.warning(
+ f"Invalid 'sampling.base_rate' in config: expected number, got {type(base_rate).__name__}. "
+ "This value will be ignored."
+ )
+ base_rate = None
+
+ min_rate = raw_sampling.get("min_rate")
+ if min_rate is not None and not isinstance(min_rate, (int, float)):
+ logger.warning(
+ f"Invalid 'sampling.min_rate' in config: expected number, got {type(min_rate).__name__}. "
+ "This value will be ignored."
+ )
+ min_rate = None
+
+ mode = raw_sampling.get("mode")
+ if mode is not None and not isinstance(mode, str):
+ logger.warning(
+ f"Invalid 'sampling.mode' in config: expected string, got {type(mode).__name__}. "
+ "This value will be ignored."
+ )
+ mode = None
+
+ sampling = SamplingConfig(
+ mode=mode,
+ base_rate=float(base_rate) if base_rate is not None else None,
+ min_rate=float(min_rate) if min_rate is not None else None,
+ )
+
return RecordingConfig(
sampling_rate=sampling_rate,
+ sampling=sampling,
export_spans=data.get("export_spans"),
enable_env_var_recording=data.get("enable_env_var_recording"),
enable_analytics=data.get("enable_analytics"),
diff --git a/drift/core/drift_sdk.py b/drift/core/drift_sdk.py
index 395f8b9..b93ac2c 100644
--- a/drift/core/drift_sdk.py
+++ b/drift/core/drift_sdk.py
@@ -8,6 +8,7 @@
import platform
import random
import stat
+import threading
import time
from pathlib import Path
from typing import TYPE_CHECKING, Any
@@ -19,6 +20,13 @@
from ..instrumentation.registry import install_hooks
from ..version import SDK_VERSION
+from .adaptive_sampling import (
+ AdaptiveSamplingController,
+ AdaptiveSamplingHealthSnapshot,
+ ResolvedSamplingConfig,
+ RootSamplingDecision,
+ SamplingMode,
+)
from .communication.communicator import CommunicatorConfig, ProtobufCommunicator
from .communication.types import MockRequestInput, MockResponseOutput
from .config import TuskConfig, TuskFileConfig, load_tusk_config
@@ -48,6 +56,12 @@ def __init__(self) -> None:
self.app_ready = False
self._sdk_instance_id = self._generate_sdk_instance_id()
self._sampling_rate: float = 1.0
+ self._sampling_mode: str = "fixed"
+ self._min_sampling_rate: float = 0.0
+ self._adaptive_sampling_controller: AdaptiveSamplingController | None = None
+ self._adaptive_sampling_thread: threading.Thread | None = None
+ self._adaptive_sampling_stop_event = threading.Event()
+ self._effective_memory_limit_bytes: int | None = None
self._transform_configs: dict[str, Any] | None = None
self._init_params: dict[str, Any] = {}
@@ -121,14 +135,16 @@ def _log_startup_summary(self, env: str, use_remote_export: bool) -> None:
)
logger.info(
- "SDK initialized successfully (version=%s, mode=%s, env=%s, service=%s, serviceId=%s, exportSpans=%s, samplingRate=%s, logLevel=%s, runtime=python %s, platform=%s/%s).",
+ "SDK initialized successfully (version=%s, mode=%s, env=%s, service=%s, serviceId=%s, exportSpans=%s, samplingMode=%s, samplingBaseRate=%s, samplingMinRate=%s, logLevel=%s, runtime=python %s, platform=%s/%s).",
SDK_VERSION,
self.mode,
env,
service_name,
service_id,
use_remote_export,
+ self._sampling_mode,
self._sampling_rate,
+ self._min_sampling_rate,
get_log_level(),
platform.python_version(),
platform.system().lower(),
@@ -168,8 +184,6 @@ def initialize(
"log_level": log_level,
}
- instance._sampling_rate = instance._determine_sampling_rate(sampling_rate)
-
effective_api_key = api_key or os.environ.get("TUSK_API_KEY")
if not env:
@@ -183,6 +197,11 @@ def initialize(
logger.debug("Already initialized, skipping...")
return instance
+ sampling_config = instance._determine_sampling_config(sampling_rate)
+ instance._sampling_rate = sampling_config.base_rate
+ instance._sampling_mode = sampling_config.mode
+ instance._min_sampling_rate = sampling_config.min_rate
+
# Start coverage collection after the _initialized guard so repeated
# initialize() calls don't stop/restart coverage and lose accumulated data.
from .coverage_server import start_coverage_collection
@@ -280,8 +299,6 @@ def initialize(
instance._td_span_processor = TdSpanProcessor(
exporter=instance.span_exporter,
mode=instance.mode,
- sampling_rate=instance._sampling_rate,
- app_ready=instance.app_ready,
environment=env,
)
instance._td_span_processor.start()
@@ -306,6 +323,7 @@ def initialize(
install_hooks()
instance._init_auto_instrumentations()
+ instance._start_adaptive_sampling_control_loop()
# Create env vars snapshot if enabled (matches Node SDK behavior)
instance.create_env_vars_snapshot()
@@ -318,38 +336,222 @@ def initialize(
return instance
- def _determine_sampling_rate(self, init_param: float | None) -> float:
- """Determine the sampling rate from various sources (precedence order)."""
- # 1. Init param takes precedence
+ def _determine_sampling_config(self, init_param: float | None) -> ResolvedSamplingConfig:
+ """Determine the effective sampling config from init params, env, and file config."""
+ recording_config = self.file_config.recording if self.file_config else None
+ config_sampling = recording_config.sampling if recording_config else None
+
+ mode: SamplingMode = "fixed"
+ if config_sampling and config_sampling.mode in {"fixed", "adaptive"}:
+ mode = "adaptive" if config_sampling.mode == "adaptive" else "fixed"
+ elif config_sampling and config_sampling.mode:
+ logger.warning(
+ "Invalid sampling mode from config file: %s. Must be 'fixed' or 'adaptive'. Ignoring.",
+ config_sampling.mode,
+ )
+
+ base_rate: float | None = None
if init_param is not None:
validated = validate_sampling_rate(init_param, "init params")
if validated is not None:
logger.debug(f"Using sampling rate from init params: {validated}")
- return validated
+ base_rate = validated
- # 2. Environment variable
- env_rate = os.environ.get("TUSK_SAMPLING_RATE")
- if env_rate is not None:
- try:
- parsed = float(env_rate)
- validated = validate_sampling_rate(parsed, "TUSK_SAMPLING_RATE env var")
- if validated is not None:
- logger.debug(f"Using sampling rate from env var: {validated}")
- return validated
- except ValueError:
- logger.warning(f"Invalid TUSK_SAMPLING_RATE env var: {env_rate}")
-
- # 3. Config file
- if self.file_config and self.file_config.recording and self.file_config.recording.sampling_rate is not None:
- config_rate = self.file_config.recording.sampling_rate
- validated = validate_sampling_rate(config_rate, "config file")
+ if base_rate is None:
+ for env_key in ("TUSK_RECORDING_SAMPLING_RATE", "TUSK_SAMPLING_RATE"):
+ env_rate = os.environ.get(env_key)
+ if env_rate is None:
+ continue
+
+ try:
+ parsed = float(env_rate)
+ validated = validate_sampling_rate(parsed, f"{env_key} env var")
+ if validated is not None:
+ logger.debug(f"Using sampling rate from {env_key} env var: {validated}")
+ base_rate = validated
+ break
+ except ValueError:
+ logger.warning(f"Invalid {env_key} env var: {env_rate}")
+
+ if base_rate is None and config_sampling and config_sampling.base_rate is not None:
+ validated = validate_sampling_rate(config_sampling.base_rate, "config file recording.sampling.base_rate")
+ if validated is not None:
+ logger.debug(f"Using sampling rate from config file recording.sampling.base_rate: {validated}")
+ base_rate = validated
+
+ if base_rate is None and recording_config and recording_config.sampling_rate is not None:
+ validated = validate_sampling_rate(recording_config.sampling_rate, "config file recording.sampling_rate")
if validated is not None:
- logger.debug(f"Using sampling rate from config file: {validated}")
- return validated
+ logger.debug(f"Using sampling rate from config file recording.sampling_rate: {validated}")
+ base_rate = validated
+
+ if base_rate is None:
+ logger.debug("Using default sampling rate: 1.0")
+ base_rate = 1.0
+
+ min_rate = 0.0
+ if mode == "adaptive":
+ validated_min_rate = validate_sampling_rate(
+ config_sampling.min_rate if config_sampling else None,
+ "config file recording.sampling.min_rate",
+ )
+ min_rate = validated_min_rate if validated_min_rate is not None else 0.001
+ min_rate = min(base_rate, min_rate)
+
+ return ResolvedSamplingConfig(
+ mode=mode,
+ base_rate=base_rate,
+ min_rate=min_rate,
+ )
+
+ def _determine_sampling_rate(self, init_param: float | None) -> float:
+ """Backward-compatible helper that returns only the effective base sampling rate."""
+ return self._determine_sampling_config(init_param).base_rate
+
+ def _start_adaptive_sampling_control_loop(self) -> None:
+ if self.mode != TuskDriftMode.RECORD or self._sampling_mode != "adaptive":
+ return
+
+ self._adaptive_sampling_controller = AdaptiveSamplingController(
+ ResolvedSamplingConfig(
+ mode="adaptive",
+ base_rate=self._sampling_rate,
+ min_rate=self._min_sampling_rate,
+ )
+ )
+ self._effective_memory_limit_bytes = self._detect_effective_memory_limit_bytes()
+ self._adaptive_sampling_stop_event.clear()
+
+ self._adaptive_sampling_thread = threading.Thread(
+ target=self._adaptive_sampling_loop,
+ daemon=True,
+ name="drift-adaptive-sampling",
+ )
+ self._adaptive_sampling_thread.start()
+ self._safe_update_adaptive_sampling_health()
+
+ def _adaptive_sampling_loop(self) -> None:
+ while not self._adaptive_sampling_stop_event.wait(timeout=2.0):
+ self._safe_update_adaptive_sampling_health()
+
+ def _safe_update_adaptive_sampling_health(self) -> None:
+ try:
+ self._update_adaptive_sampling_health()
+ except Exception:
+ logger.error("Adaptive sampling health update failed; keeping previous controller state.", exc_info=True)
+
+ def _update_adaptive_sampling_health(self) -> None:
+ if self._adaptive_sampling_controller is None:
+ return
+
+ batch_processor = self._td_span_processor._batch_processor if self._td_span_processor else None
+ queue_fill_ratio = None
+ dropped_span_count = 0
+ if batch_processor is not None and batch_processor.max_queue_size > 0:
+ queue_fill_ratio = batch_processor.queue_size / batch_processor.max_queue_size
+ dropped_span_count = batch_processor.dropped_span_count
+
+ export_failure_count = 0
+ export_circuit_open = False
+ if self.span_exporter is not None:
+ for adapter in self.span_exporter.get_adapters():
+ spans_failed = getattr(adapter, "spans_failed", 0)
+ export_failure_count += int(spans_failed)
+ export_circuit_open = export_circuit_open or getattr(adapter, "circuit_state", "") == "open"
+
+ self._adaptive_sampling_controller.update(
+ AdaptiveSamplingHealthSnapshot(
+ queue_fill_ratio=queue_fill_ratio,
+ dropped_span_count=dropped_span_count,
+ export_failure_count=export_failure_count,
+ export_circuit_open=export_circuit_open,
+ memory_pressure_ratio=self._get_memory_pressure_ratio(),
+ )
+ )
+
+ def _detect_effective_memory_limit_bytes(self) -> int | None:
+ candidates = (
+ "/sys/fs/cgroup/memory.max",
+ "/sys/fs/cgroup/memory/memory.limit_in_bytes",
+ )
+ for path in candidates:
+ parsed = self._read_numeric_control_file(path)
+ if parsed is None:
+ continue
+ if parsed <= 0 or parsed > 1_000_000_000_000_000:
+ continue
+ return parsed
+ return None
+
+ def _get_memory_pressure_ratio(self) -> float | None:
+ if self._effective_memory_limit_bytes is None or self._effective_memory_limit_bytes <= 0:
+ return None
+
+ cgroup_current = self._read_numeric_control_file("/sys/fs/cgroup/memory.current")
+ if cgroup_current is not None:
+ return cgroup_current / self._effective_memory_limit_bytes
+
+ cgroup_v1_current = self._read_numeric_control_file("/sys/fs/cgroup/memory/memory.usage_in_bytes")
+ if cgroup_v1_current is not None:
+ return cgroup_v1_current / self._effective_memory_limit_bytes
+
+ current_rss_bytes = self._read_current_rss_bytes()
+ if current_rss_bytes is not None:
+ return current_rss_bytes / self._effective_memory_limit_bytes
+
+ return None
+
+ @staticmethod
+ def _parse_proc_status_rss_bytes(raw_status: str) -> int | None:
+ for line in raw_status.splitlines():
+ if not line.startswith("VmRSS:"):
+ continue
+
+ parts = line.split()
+ if len(parts) < 3 or parts[2].lower() != "kb":
+ return None
+
+ return int(parts[1]) * 1024
+
+ return None
+
+ @staticmethod
+ def _parse_proc_statm_rss_bytes(raw_statm: str, page_size: int) -> int | None:
+ fields = raw_statm.split()
+ if len(fields) < 2:
+ return None
+
+ return int(fields[1]) * page_size
+
+ def _read_current_rss_bytes(self) -> int | None:
+ try:
+ proc_status_path = Path("/proc/self/status")
+ if proc_status_path.exists():
+ parsed = self._parse_proc_status_rss_bytes(proc_status_path.read_text())
+ if parsed is not None:
+ return parsed
+ except Exception:
+ pass
+
+ try:
+ proc_statm_path = Path("/proc/self/statm")
+ if proc_statm_path.exists():
+ return self._parse_proc_statm_rss_bytes(proc_statm_path.read_text(), int(os.sysconf("SC_PAGE_SIZE")))
+ except Exception:
+ pass
- # 4. Default
- logger.debug("Using default sampling rate: 1.0")
- return 1.0
+ return None
+
+ def _read_numeric_control_file(self, path: str) -> int | None:
+ try:
+ if not os.path.exists(path):
+ return None
+ raw_value = Path(path).read_text().strip()
+ if not raw_value or raw_value == "max":
+ return None
+ return int(raw_value)
+ except Exception:
+ return None
def _detect_mode(self) -> TuskDriftMode:
"""Detect the SDK mode from environment variable."""
@@ -648,10 +850,6 @@ def mark_app_as_ready(self) -> None:
self.app_ready = True
- # Update span processor with app_ready flag
- if self._td_span_processor:
- self._td_span_processor.update_app_ready(True)
-
if self.mode == TuskDriftMode.REPLAY:
logger.debug("Replay mode active - ready to serve mocked responses")
elif self.mode == TuskDriftMode.RECORD:
@@ -801,6 +999,34 @@ async def send_unpatched_dependency_alert(
except Exception as e:
logger.debug(f"Failed to send unpatched dependency alert: {e}")
+ def should_record_root_request(self, *, is_pre_app_start: bool) -> RootSamplingDecision:
+ if self._adaptive_sampling_controller is not None:
+ return self._adaptive_sampling_controller.get_decision(is_pre_app_start=is_pre_app_start)
+
+ if is_pre_app_start:
+ return RootSamplingDecision(
+ should_record=True,
+ reason="pre_app_start",
+ mode="fixed",
+ state="fixed",
+ base_rate=self._sampling_rate,
+ min_rate=self._min_sampling_rate,
+ effective_rate=1.0,
+ admission_multiplier=1.0,
+ )
+
+ should_record = should_sample(self._sampling_rate, True)
+ return RootSamplingDecision(
+ should_record=should_record,
+ reason="sampled" if should_record else "not_sampled",
+ mode="fixed",
+ state="fixed",
+ base_rate=self._sampling_rate,
+ min_rate=self._min_sampling_rate,
+ effective_rate=self._sampling_rate,
+ admission_multiplier=1.0,
+ )
+
def get_sampling_rate(self) -> float:
"""Get the current sampling rate."""
return self._sampling_rate
@@ -835,6 +1061,11 @@ def shutdown(self) -> None:
from .coverage_server import stop_coverage_collection
+ self._adaptive_sampling_stop_event.set()
+ if self._adaptive_sampling_thread is not None:
+ self._adaptive_sampling_thread.join(timeout=5.0)
+ self._adaptive_sampling_thread = None
+
# Shutdown OpenTelemetry tracer provider
if self._td_span_processor is not None:
self._td_span_processor.shutdown()
diff --git a/drift/core/mode_utils.py b/drift/core/mode_utils.py
index 69d9d3a..f671a26 100644
--- a/drift/core/mode_utils.py
+++ b/drift/core/mode_utils.py
@@ -180,18 +180,18 @@ def should_record_inbound_http_request(
Returns:
Tuple of (should_record, skip_reason):
- should_record: True if request should be recorded
- - skip_reason: If False, explains why ("dropped" or "not_sampled"), None otherwise
+ - skip_reason: If False, explains why ("dropped", "not_sampled",
+ "load_shed", or "critical_pause"), None otherwise
"""
if transform_engine and transform_engine.should_drop_inbound_request(method, target, headers):
return False, "dropped"
if not is_pre_app_start:
from .drift_sdk import TuskDrift
- from .sampling import should_sample
sdk = TuskDrift.get_instance()
- sampling_rate = sdk.get_sampling_rate()
- if not should_sample(sampling_rate, is_app_ready=True):
- return False, "not_sampled"
+ decision = sdk.should_record_root_request(is_pre_app_start=is_pre_app_start)
+ if not decision.should_record:
+ return False, decision.reason
return True, None
diff --git a/drift/core/no_recording.py b/drift/core/no_recording.py
new file mode 100644
index 0000000..33dc256
--- /dev/null
+++ b/drift/core/no_recording.py
@@ -0,0 +1,22 @@
+"""Context helpers for suppressing child span creation."""
+
+from __future__ import annotations
+
+from collections.abc import Iterator
+from contextlib import contextmanager
+from contextvars import ContextVar
+
+_recording_suppressed: ContextVar[bool] = ContextVar("td_recording_suppressed", default=False)
+
+
+def is_recording_suppressed() -> bool:
+ return _recording_suppressed.get()
+
+
+@contextmanager
+def suppress_recording() -> Iterator[None]:
+ token = _recording_suppressed.set(True)
+ try:
+ yield
+ finally:
+ _recording_suppressed.reset(token)
diff --git a/drift/core/tracing/span_utils.py b/drift/core/tracing/span_utils.py
index e246833..bc41d7e 100644
--- a/drift/core/tracing/span_utils.py
+++ b/drift/core/tracing/span_utils.py
@@ -19,6 +19,7 @@
from opentelemetry.trace import SpanKind as OTelSpanKind
from opentelemetry.trace import Status, StatusCode
+from ..no_recording import is_recording_suppressed
from ..types import TuskDriftMode
from .td_attributes import TdSpanAttributes
@@ -135,6 +136,10 @@ def create_span(options: CreateSpanOptions) -> SpanInfo | None:
Returns None if span creation fails.
"""
try:
+ if is_recording_suppressed():
+ logger.debug(f"[SpanUtils] Skipping span creation for '{options.name}' - recording suppressed")
+ return None
+
# Import here to avoid circular dependency
from ..drift_sdk import TuskDrift
diff --git a/drift/core/tracing/td_span_processor.py b/drift/core/tracing/td_span_processor.py
index ef1e6e7..98adc6d 100644
--- a/drift/core/tracing/td_span_processor.py
+++ b/drift/core/tracing/td_span_processor.py
@@ -47,21 +47,21 @@ class TdSpanProcessor(SpanProcessor):
This processor implements OpenTelemetry's SpanProcessor interface and serves
as the bridge between OTel's tracing system and Drift's export infrastructure.
+ Root-request admission sampling happens earlier in inbound instrumentations,
+ so this processor only handles ended spans that were already allowed through.
When a span ends:
1. Convert to CleanSpanData using otel_converter
- 2. Apply sampling logic
- 3. Apply trace blocking logic
- 4. Validate protobuf serialization
- 5. Forward to batch processor for export
+ 2. Apply trace blocking logic
+ 3. Validate protobuf serialization
+ 4. Handle REPLAY-mode inbound span forwarding
+ 5. Forward RECORD-mode spans to the batch processor for export
"""
def __init__(
self,
exporter: TdSpanExporter,
mode: TuskDriftMode,
- sampling_rate: float = 1.0,
- app_ready: bool = False,
environment: str | None = None,
) -> None:
"""Initialize the TdSpanProcessor.
@@ -69,14 +69,10 @@ def __init__(
Args:
exporter: The TdSpanExporter to use for span export
mode: SDK mode (RECORD, REPLAY, DISABLED)
- sampling_rate: Sampling rate (0.0-1.0)
- app_ready: Whether the application is ready
environment: Environment name to include on spans
"""
self._exporter = exporter
self._mode = mode
- self._sampling_rate = sampling_rate
- self._app_ready = app_ready
self._environment = environment
# We'll import and create batch processor lazily to avoid circular imports
@@ -244,23 +240,3 @@ def force_flush(self, timeout_millis: int = 30000) -> bool:
except Exception as e:
logger.error(f"Error during force_flush: {e}")
return False
-
- def update_app_ready(self, app_ready: bool) -> None:
- """Update the app_ready flag.
-
- This is called when the application marks itself as ready.
-
- Args:
- app_ready: Whether the application is ready
- """
- self._app_ready = app_ready
- logger.debug(f"TdSpanProcessor app_ready updated to {app_ready}")
-
- def update_sampling_rate(self, sampling_rate: float) -> None:
- """Update the sampling rate.
-
- Args:
- sampling_rate: New sampling rate (0.0-1.0)
- """
- self._sampling_rate = sampling_rate
- logger.debug(f"TdSpanProcessor sampling_rate updated to {sampling_rate}")
diff --git a/drift/instrumentation/aiohttp/e2e-tests/docker-compose.yml b/drift/instrumentation/aiohttp/e2e-tests/docker-compose.yml
index 14976cb..52a4ec4 100644
--- a/drift/instrumentation/aiohttp/e2e-tests/docker-compose.yml
+++ b/drift/instrumentation/aiohttp/e2e-tests/docker-compose.yml
@@ -27,6 +27,7 @@ services:
- BENCHMARKS=${BENCHMARKS:-}
- BENCHMARK_DURATION=${BENCHMARK_DURATION:-10}
- BENCHMARK_WARMUP=${BENCHMARK_WARMUP:-3}
+ - TUSK_RECORDING_SAMPLING_RATE=${TUSK_RECORDING_SAMPLING_RATE:-}
- TUSK_SAMPLING_RATE=${TUSK_SAMPLING_RATE:-}
- USE_MOCK_EXTERNALS=${USE_MOCK_EXTERNALS:-1}
- MOCK_SERVER_BASE_URL=${MOCK_SERVER_BASE_URL:-http://mock-upstream:8081}
diff --git a/drift/instrumentation/django/e2e-tests/docker-compose.yml b/drift/instrumentation/django/e2e-tests/docker-compose.yml
index 801a3c7..72b9b09 100644
--- a/drift/instrumentation/django/e2e-tests/docker-compose.yml
+++ b/drift/instrumentation/django/e2e-tests/docker-compose.yml
@@ -28,6 +28,7 @@ services:
- BENCHMARKS=${BENCHMARKS:-}
- BENCHMARK_DURATION=${BENCHMARK_DURATION:-10}
- BENCHMARK_WARMUP=${BENCHMARK_WARMUP:-3}
+ - TUSK_RECORDING_SAMPLING_RATE=${TUSK_RECORDING_SAMPLING_RATE:-}
- TUSK_SAMPLING_RATE=${TUSK_SAMPLING_RATE:-}
- USE_MOCK_EXTERNALS=${USE_MOCK_EXTERNALS:-1}
- MOCK_SERVER_BASE_URL=${MOCK_SERVER_BASE_URL:-http://mock-upstream:8081}
diff --git a/drift/instrumentation/django/middleware.py b/drift/instrumentation/django/middleware.py
index 670679d..6514855 100644
--- a/drift/instrumentation/django/middleware.py
+++ b/drift/instrumentation/django/middleware.py
@@ -17,6 +17,7 @@
if TYPE_CHECKING:
from django.http import HttpRequest, HttpResponse
from ...core.mode_utils import handle_record_mode, should_record_inbound_http_request
+from ...core.no_recording import suppress_recording
from ...core.tracing import TdSpanAttributes
from ...core.tracing.span_utils import CreateSpanOptions, SpanInfo, SpanUtils
from ...core.types import (
@@ -190,7 +191,8 @@ def _record_request(self, request: HttpRequest, sdk, is_pre_app_start: bool) ->
)
if not should_record:
logger.debug(f"[Django] Skipping request ({skip_reason}), path={path}")
- return self.get_response(request)
+ with suppress_recording():
+ return self.get_response(request)
start_time_ns = time.time_ns()
span_name = f"{method} {path}"
diff --git a/drift/instrumentation/fastapi/e2e-tests/docker-compose.yml b/drift/instrumentation/fastapi/e2e-tests/docker-compose.yml
index cf2e18c..9767b72 100644
--- a/drift/instrumentation/fastapi/e2e-tests/docker-compose.yml
+++ b/drift/instrumentation/fastapi/e2e-tests/docker-compose.yml
@@ -27,6 +27,7 @@ services:
- BENCHMARKS=${BENCHMARKS:-}
- BENCHMARK_DURATION=${BENCHMARK_DURATION:-10}
- BENCHMARK_WARMUP=${BENCHMARK_WARMUP:-3}
+ - TUSK_RECORDING_SAMPLING_RATE=${TUSK_RECORDING_SAMPLING_RATE:-}
- TUSK_SAMPLING_RATE=${TUSK_SAMPLING_RATE:-}
- USE_MOCK_EXTERNALS=${USE_MOCK_EXTERNALS:-1}
- MOCK_SERVER_BASE_URL=${MOCK_SERVER_BASE_URL:-http://mock-upstream:8081}
diff --git a/drift/instrumentation/fastapi/instrumentation.py b/drift/instrumentation/fastapi/instrumentation.py
index d280544..5e1df77 100644
--- a/drift/instrumentation/fastapi/instrumentation.py
+++ b/drift/instrumentation/fastapi/instrumentation.py
@@ -27,6 +27,7 @@
from ...core.drift_sdk import TuskDrift
from ...core.json_schema_helper import JsonSchemaHelper, SchemaMerge
from ...core.mode_utils import handle_record_mode, should_record_inbound_http_request
+from ...core.no_recording import suppress_recording
from ...core.tracing import TdSpanAttributes
from ...core.tracing.span_utils import CreateSpanOptions, SpanInfo, SpanUtils
from ...core.types import (
@@ -267,7 +268,8 @@ async def _record_request(
)
if not should_record:
logger.debug(f"[FastAPI] Skipping request ({skip_reason}), path={raw_path}")
- return await original_call(app, scope, receive, send)
+ with suppress_recording():
+ return await original_call(app, scope, receive, send)
start_time_ns = time.time_ns()
diff --git a/drift/instrumentation/flask/e2e-tests/docker-compose.yml b/drift/instrumentation/flask/e2e-tests/docker-compose.yml
index 8c73754..5d21955 100644
--- a/drift/instrumentation/flask/e2e-tests/docker-compose.yml
+++ b/drift/instrumentation/flask/e2e-tests/docker-compose.yml
@@ -27,6 +27,7 @@ services:
- BENCHMARKS=${BENCHMARKS:-}
- BENCHMARK_DURATION=${BENCHMARK_DURATION:-10}
- BENCHMARK_WARMUP=${BENCHMARK_WARMUP:-3}
+ - TUSK_RECORDING_SAMPLING_RATE=${TUSK_RECORDING_SAMPLING_RATE:-}
- TUSK_SAMPLING_RATE=${TUSK_SAMPLING_RATE:-}
- USE_MOCK_EXTERNALS=${USE_MOCK_EXTERNALS:-1}
- MOCK_SERVER_BASE_URL=${MOCK_SERVER_BASE_URL:-http://mock-upstream:8081}
diff --git a/drift/instrumentation/grpc/e2e-tests/docker-compose.yml b/drift/instrumentation/grpc/e2e-tests/docker-compose.yml
index c0d45a7..1671f4b 100644
--- a/drift/instrumentation/grpc/e2e-tests/docker-compose.yml
+++ b/drift/instrumentation/grpc/e2e-tests/docker-compose.yml
@@ -18,6 +18,7 @@ services:
- BENCHMARKS=${BENCHMARKS:-}
- BENCHMARK_DURATION=${BENCHMARK_DURATION:-10}
- BENCHMARK_WARMUP=${BENCHMARK_WARMUP:-3}
+ - TUSK_RECORDING_SAMPLING_RATE=${TUSK_RECORDING_SAMPLING_RATE:-}
- TUSK_SAMPLING_RATE=${TUSK_SAMPLING_RATE:-}
working_dir: /app
volumes:
diff --git a/drift/instrumentation/httpx/e2e-tests/docker-compose.yml b/drift/instrumentation/httpx/e2e-tests/docker-compose.yml
index ae57669..6471395 100644
--- a/drift/instrumentation/httpx/e2e-tests/docker-compose.yml
+++ b/drift/instrumentation/httpx/e2e-tests/docker-compose.yml
@@ -27,6 +27,7 @@ services:
- BENCHMARKS=${BENCHMARKS:-}
- BENCHMARK_DURATION=${BENCHMARK_DURATION:-10}
- BENCHMARK_WARMUP=${BENCHMARK_WARMUP:-3}
+ - TUSK_RECORDING_SAMPLING_RATE=${TUSK_RECORDING_SAMPLING_RATE:-}
- TUSK_SAMPLING_RATE=${TUSK_SAMPLING_RATE:-}
- USE_MOCK_EXTERNALS=${USE_MOCK_EXTERNALS:-1}
- MOCK_SERVER_BASE_URL=${MOCK_SERVER_BASE_URL:-http://mock-upstream:8081}
diff --git a/drift/instrumentation/psycopg/e2e-tests/docker-compose.yml b/drift/instrumentation/psycopg/e2e-tests/docker-compose.yml
index 34e6e0a..95a6b6a 100644
--- a/drift/instrumentation/psycopg/e2e-tests/docker-compose.yml
+++ b/drift/instrumentation/psycopg/e2e-tests/docker-compose.yml
@@ -38,6 +38,7 @@ services:
- BENCHMARKS=${BENCHMARKS:-}
- BENCHMARK_DURATION=${BENCHMARK_DURATION:-10}
- BENCHMARK_WARMUP=${BENCHMARK_WARMUP:-3}
+ - TUSK_RECORDING_SAMPLING_RATE=${TUSK_RECORDING_SAMPLING_RATE:-}
- TUSK_SAMPLING_RATE=${TUSK_SAMPLING_RATE:-}
working_dir: /app
volumes:
diff --git a/drift/instrumentation/psycopg2/e2e-tests/docker-compose.yml b/drift/instrumentation/psycopg2/e2e-tests/docker-compose.yml
index 608fa98..0e10cc8 100644
--- a/drift/instrumentation/psycopg2/e2e-tests/docker-compose.yml
+++ b/drift/instrumentation/psycopg2/e2e-tests/docker-compose.yml
@@ -38,6 +38,7 @@ services:
- BENCHMARKS=${BENCHMARKS:-}
- BENCHMARK_DURATION=${BENCHMARK_DURATION:-10}
- BENCHMARK_WARMUP=${BENCHMARK_WARMUP:-3}
+ - TUSK_RECORDING_SAMPLING_RATE=${TUSK_RECORDING_SAMPLING_RATE:-}
- TUSK_SAMPLING_RATE=${TUSK_SAMPLING_RATE:-}
working_dir: /app
volumes:
diff --git a/drift/instrumentation/redis/e2e-tests/docker-compose.yml b/drift/instrumentation/redis/e2e-tests/docker-compose.yml
index 84b269c..64772ea 100644
--- a/drift/instrumentation/redis/e2e-tests/docker-compose.yml
+++ b/drift/instrumentation/redis/e2e-tests/docker-compose.yml
@@ -31,6 +31,7 @@ services:
- BENCHMARKS=${BENCHMARKS:-}
- BENCHMARK_DURATION=${BENCHMARK_DURATION:-10}
- BENCHMARK_WARMUP=${BENCHMARK_WARMUP:-3}
+ - TUSK_RECORDING_SAMPLING_RATE=${TUSK_RECORDING_SAMPLING_RATE:-}
- TUSK_SAMPLING_RATE=${TUSK_SAMPLING_RATE:-}
working_dir: /app
volumes:
diff --git a/drift/instrumentation/requests/e2e-tests/docker-compose.yml b/drift/instrumentation/requests/e2e-tests/docker-compose.yml
index 997da3a..0107d39 100644
--- a/drift/instrumentation/requests/e2e-tests/docker-compose.yml
+++ b/drift/instrumentation/requests/e2e-tests/docker-compose.yml
@@ -27,6 +27,7 @@ services:
- BENCHMARKS=${BENCHMARKS:-}
- BENCHMARK_DURATION=${BENCHMARK_DURATION:-10}
- BENCHMARK_WARMUP=${BENCHMARK_WARMUP:-3}
+ - TUSK_RECORDING_SAMPLING_RATE=${TUSK_RECORDING_SAMPLING_RATE:-}
- TUSK_SAMPLING_RATE=${TUSK_SAMPLING_RATE:-}
- USE_MOCK_EXTERNALS=${USE_MOCK_EXTERNALS:-1}
- MOCK_SERVER_BASE_URL=${MOCK_SERVER_BASE_URL:-http://mock-upstream:8081}
diff --git a/drift/instrumentation/urllib/e2e-tests/docker-compose.yml b/drift/instrumentation/urllib/e2e-tests/docker-compose.yml
index 4beb9b9..1436145 100644
--- a/drift/instrumentation/urllib/e2e-tests/docker-compose.yml
+++ b/drift/instrumentation/urllib/e2e-tests/docker-compose.yml
@@ -27,6 +27,7 @@ services:
- BENCHMARKS=${BENCHMARKS:-}
- BENCHMARK_DURATION=${BENCHMARK_DURATION:-10}
- BENCHMARK_WARMUP=${BENCHMARK_WARMUP:-3}
+ - TUSK_RECORDING_SAMPLING_RATE=${TUSK_RECORDING_SAMPLING_RATE:-}
- TUSK_SAMPLING_RATE=${TUSK_SAMPLING_RATE:-}
- USE_MOCK_EXTERNALS=${USE_MOCK_EXTERNALS:-1}
- MOCK_SERVER_BASE_URL=${MOCK_SERVER_BASE_URL:-http://mock-upstream:8081}
diff --git a/drift/instrumentation/urllib3/e2e-tests/docker-compose.yml b/drift/instrumentation/urllib3/e2e-tests/docker-compose.yml
index 10b39a7..c61e044 100644
--- a/drift/instrumentation/urllib3/e2e-tests/docker-compose.yml
+++ b/drift/instrumentation/urllib3/e2e-tests/docker-compose.yml
@@ -27,6 +27,7 @@ services:
- BENCHMARKS=${BENCHMARKS:-}
- BENCHMARK_DURATION=${BENCHMARK_DURATION:-10}
- BENCHMARK_WARMUP=${BENCHMARK_WARMUP:-3}
+ - TUSK_RECORDING_SAMPLING_RATE=${TUSK_RECORDING_SAMPLING_RATE:-}
- TUSK_SAMPLING_RATE=${TUSK_SAMPLING_RATE:-}
- USE_MOCK_EXTERNALS=${USE_MOCK_EXTERNALS:-1}
- MOCK_SERVER_BASE_URL=${MOCK_SERVER_BASE_URL:-http://mock-upstream:8081}
diff --git a/drift/instrumentation/wsgi/handler.py b/drift/instrumentation/wsgi/handler.py
index e5f4aed..4428e83 100644
--- a/drift/instrumentation/wsgi/handler.py
+++ b/drift/instrumentation/wsgi/handler.py
@@ -9,7 +9,7 @@
import json
import logging
import time
-from collections.abc import Iterable
+from collections.abc import Iterable, Iterator
from typing import TYPE_CHECKING, Any
from opentelemetry import context as otel_context
@@ -31,6 +31,7 @@
from ...core.mode_utils import handle_record_mode, should_record_inbound_http_request
+from ...core.no_recording import suppress_recording
from ...core.tracing import TdSpanAttributes
from ...core.tracing.span_utils import CreateSpanOptions, SpanUtils
from ...core.types import (
@@ -53,6 +54,30 @@
)
+class SuppressedResponseIterable(Iterable[bytes]):
+ """Keep no-record suppression active while a skipped WSGI response is consumed."""
+
+ def __init__(self, response: Iterable[bytes]):
+ self._response = response
+
+ def __iter__(self) -> Iterator[bytes]:
+ with suppress_recording():
+ iterator = iter(self._response)
+ while True:
+ try:
+ with suppress_recording():
+ chunk = next(iterator)
+ except StopIteration:
+ return
+ yield chunk
+
+ def close(self) -> None:
+ close_method = getattr(self._response, "close", None)
+ if close_method is not None:
+ with suppress_recording():
+ close_method()
+
+
def handle_wsgi_request(
app: WSGIApplication,
environ: WSGIEnvironment,
@@ -225,7 +250,9 @@ def _create_and_handle_request(
)
if not should_record:
logger.debug(f"[WSGI] Skipping request ({skip_reason}), path={path}")
- return original_wsgi_app(app, environ, start_response)
+ with suppress_recording():
+ response = original_wsgi_app(app, environ, start_response)
+ return SuppressedResponseIterable(response)
# Capture request body
request_body = capture_request_body(environ)
diff --git a/drift/stack-tests/django-postgres/docker-compose.yml b/drift/stack-tests/django-postgres/docker-compose.yml
index c20e94f..6877e01 100644
--- a/drift/stack-tests/django-postgres/docker-compose.yml
+++ b/drift/stack-tests/django-postgres/docker-compose.yml
@@ -39,6 +39,7 @@ services:
- BENCHMARKS=${BENCHMARKS:-}
- BENCHMARK_DURATION=${BENCHMARK_DURATION:-10}
- BENCHMARK_WARMUP=${BENCHMARK_WARMUP:-3}
+ - TUSK_RECORDING_SAMPLING_RATE=${TUSK_RECORDING_SAMPLING_RATE:-}
- TUSK_SAMPLING_RATE=${TUSK_SAMPLING_RATE:-}
working_dir: /app
volumes:
diff --git a/drift/stack-tests/django-redis/docker-compose.yml b/drift/stack-tests/django-redis/docker-compose.yml
index 1570b8a..e1c34db 100644
--- a/drift/stack-tests/django-redis/docker-compose.yml
+++ b/drift/stack-tests/django-redis/docker-compose.yml
@@ -32,6 +32,7 @@ services:
- BENCHMARKS=${BENCHMARKS:-}
- BENCHMARK_DURATION=${BENCHMARK_DURATION:-10}
- BENCHMARK_WARMUP=${BENCHMARK_WARMUP:-3}
+ - TUSK_RECORDING_SAMPLING_RATE=${TUSK_RECORDING_SAMPLING_RATE:-}
- TUSK_SAMPLING_RATE=${TUSK_SAMPLING_RATE:-}
working_dir: /app
volumes:
diff --git a/drift/stack-tests/fastapi-postgres/docker-compose.yml b/drift/stack-tests/fastapi-postgres/docker-compose.yml
index 3497e5e..86ae38c 100644
--- a/drift/stack-tests/fastapi-postgres/docker-compose.yml
+++ b/drift/stack-tests/fastapi-postgres/docker-compose.yml
@@ -38,6 +38,7 @@ services:
- BENCHMARKS=${BENCHMARKS:-}
- BENCHMARK_DURATION=${BENCHMARK_DURATION:-10}
- BENCHMARK_WARMUP=${BENCHMARK_WARMUP:-3}
+ - TUSK_RECORDING_SAMPLING_RATE=${TUSK_RECORDING_SAMPLING_RATE:-}
- TUSK_SAMPLING_RATE=${TUSK_SAMPLING_RATE:-}
working_dir: /app
volumes:
diff --git a/tests/unit/test_adaptive_sampling.py b/tests/unit/test_adaptive_sampling.py
new file mode 100644
index 0000000..1f4d451
--- /dev/null
+++ b/tests/unit/test_adaptive_sampling.py
@@ -0,0 +1,120 @@
+import math
+import threading
+
+from drift.core.adaptive_sampling import (
+ AdaptiveSamplingController,
+ AdaptiveSamplingHealthSnapshot,
+ ResolvedSamplingConfig,
+)
+
+
+def test_pre_app_start_always_records():
+ controller = AdaptiveSamplingController(
+ ResolvedSamplingConfig(mode="adaptive", base_rate=0.0, min_rate=0.0),
+ random_fn=lambda: 0.99,
+ now_fn=lambda: 0.0,
+ )
+
+ decision = controller.get_decision(is_pre_app_start=True)
+
+ assert decision.should_record is True
+ assert decision.reason == "pre_app_start"
+ assert decision.effective_rate == 1.0
+
+
+def test_controller_load_sheds_and_pauses_on_drops():
+ now = {"value": 0.0}
+ controller = AdaptiveSamplingController(
+ ResolvedSamplingConfig(mode="adaptive", base_rate=0.5, min_rate=0.1),
+ random_fn=lambda: 0.3,
+ now_fn=lambda: now["value"],
+ )
+
+ controller.update(AdaptiveSamplingHealthSnapshot(queue_fill_ratio=0.9))
+ load_shed_decision = controller.get_decision(is_pre_app_start=False)
+ assert load_shed_decision.state == "hot"
+ assert load_shed_decision.effective_rate < 0.5
+ assert load_shed_decision.should_record is False
+ assert load_shed_decision.reason == "load_shed"
+
+ now["value"] = 1.0
+ controller.update(AdaptiveSamplingHealthSnapshot(queue_fill_ratio=0.1, dropped_span_count=1))
+ paused_decision = controller.get_decision(is_pre_app_start=False)
+ assert paused_decision.state == "critical_pause"
+ assert paused_decision.should_record is False
+ assert paused_decision.reason == "critical_pause"
+
+
+def test_elapsed_time_uses_zero_timestamp_as_real_value():
+ now = {"value": 0.0}
+ controller = AdaptiveSamplingController(
+ ResolvedSamplingConfig(mode="adaptive", base_rate=0.5, min_rate=0.1),
+ random_fn=lambda: 0.0,
+ now_fn=lambda: now["value"],
+ )
+
+ controller.update(AdaptiveSamplingHealthSnapshot(export_failure_count=1))
+ now["value"] = 0.5
+ controller.update(AdaptiveSamplingHealthSnapshot(export_failure_count=1))
+
+ expected_decay = math.exp(-(0.5 * 1000.0) / 30000.0)
+ assert math.isclose(controller._recent_failure_signal, expected_decay, rel_tol=1e-6)
+
+
+def test_get_decision_waits_for_controller_lock():
+ controller = AdaptiveSamplingController(
+ ResolvedSamplingConfig(mode="adaptive", base_rate=0.5, min_rate=0.1),
+ random_fn=lambda: 0.0,
+ now_fn=lambda: 0.0,
+ )
+ started = threading.Event()
+ finished = threading.Event()
+ result = {}
+
+ def worker() -> None:
+ started.set()
+ result["decision"] = controller.get_decision(is_pre_app_start=False)
+ finished.set()
+
+ thread = threading.Thread(target=worker)
+ controller._lock.acquire()
+ try:
+ thread.start()
+ assert started.wait(timeout=1.0)
+ assert not finished.wait(timeout=0.05)
+ finally:
+ controller._lock.release()
+
+ assert finished.wait(timeout=1.0)
+ thread.join(timeout=1.0)
+ assert not thread.is_alive()
+ assert result["decision"].effective_rate == 0.5
+
+
+def test_update_waits_for_controller_lock():
+ controller = AdaptiveSamplingController(
+ ResolvedSamplingConfig(mode="adaptive", base_rate=0.5, min_rate=0.1),
+ random_fn=lambda: 0.0,
+ now_fn=lambda: 0.0,
+ )
+ started = threading.Event()
+ finished = threading.Event()
+
+ def worker() -> None:
+ started.set()
+ controller.update(AdaptiveSamplingHealthSnapshot(queue_fill_ratio=0.9))
+ finished.set()
+
+ thread = threading.Thread(target=worker)
+ controller._lock.acquire()
+ try:
+ thread.start()
+ assert started.wait(timeout=1.0)
+ assert not finished.wait(timeout=0.05)
+ finally:
+ controller._lock.release()
+
+ assert finished.wait(timeout=1.0)
+ thread.join(timeout=1.0)
+ assert not thread.is_alive()
+ assert controller.get_decision(is_pre_app_start=False).state == "hot"
diff --git a/tests/unit/test_config_loading.py b/tests/unit/test_config_loading.py
index 9f1570f..8d6d1b5 100644
--- a/tests/unit/test_config_loading.py
+++ b/tests/unit/test_config_loading.py
@@ -214,6 +214,38 @@ def test_handles_partial_config(self):
finally:
os.chdir(original_cwd)
+ def test_loads_nested_sampling_config(self):
+ """Should load recording.sampling config alongside legacy fields."""
+ with tempfile.TemporaryDirectory() as tmpdir:
+ project_root = Path(tmpdir)
+ (project_root / "pyproject.toml").touch()
+
+ tusk_dir = project_root / ".tusk"
+ tusk_dir.mkdir()
+ (tusk_dir / "config.yaml").write_text(
+ """
+recording:
+ sampling:
+ mode: adaptive
+ base_rate: 0.25
+ min_rate: 0.05
+"""
+ )
+
+ original_cwd = os.getcwd()
+ try:
+ os.chdir(project_root)
+ config = load_tusk_config()
+
+ assert config is not None
+ assert config.recording is not None
+ assert config.recording.sampling is not None
+ assert config.recording.sampling.mode == "adaptive"
+ assert config.recording.sampling.base_rate == 0.25
+ assert config.recording.sampling.min_rate == 0.05
+ finally:
+ os.chdir(original_cwd)
+
def test_handles_invalid_yaml(self):
"""Should return None when YAML is invalid."""
with tempfile.TemporaryDirectory() as tmpdir:
diff --git a/tests/unit/test_drift_sdk.py b/tests/unit/test_drift_sdk.py
index 954ae31..adb16d0 100644
--- a/tests/unit/test_drift_sdk.py
+++ b/tests/unit/test_drift_sdk.py
@@ -6,6 +6,8 @@
import pytest
+from drift.core.adaptive_sampling import AdaptiveSamplingController, ResolvedSamplingConfig
+from drift.core.config import RecordingConfig, SamplingConfig, TuskFileConfig
from drift.core.drift_sdk import TuskDrift
from drift.core.types import TuskDriftMode
@@ -19,7 +21,13 @@ def reset_singleton(self):
TuskDrift._instance = None
TuskDrift._initialized = False
# Clear environment variables
- env_vars = ["TUSK_DRIFT_MODE", "TUSK_API_KEY", "TUSK_SAMPLING_RATE", "ENV"]
+ env_vars = [
+ "TUSK_DRIFT_MODE",
+ "TUSK_API_KEY",
+ "TUSK_RECORDING_SAMPLING_RATE",
+ "TUSK_SAMPLING_RATE",
+ "ENV",
+ ]
original_env = {k: os.environ.get(k) for k in env_vars}
for var in env_vars:
if var in os.environ:
@@ -120,9 +128,10 @@ def reset_singleton(self):
"""Reset singleton state before each test."""
TuskDrift._instance = None
TuskDrift._initialized = False
- # Clear sampling rate env var
- if "TUSK_SAMPLING_RATE" in os.environ:
- del os.environ["TUSK_SAMPLING_RATE"]
+ # Clear sampling rate env vars
+ for env_var in ("TUSK_RECORDING_SAMPLING_RATE", "TUSK_SAMPLING_RATE"):
+ if env_var in os.environ:
+ del os.environ[env_var]
yield
TuskDrift._instance = None
TuskDrift._initialized = False
@@ -136,10 +145,31 @@ def test_uses_init_param_sampling_rate(self, reset_singleton):
assert result == 0.5
- def test_uses_env_var_sampling_rate(self, reset_singleton):
- """Should use sampling rate from env var if init param not provided."""
+ def test_uses_recording_env_var_sampling_rate(self, reset_singleton):
+ """Should use the canonical recording env var if init param not provided."""
os.environ["TUSK_DRIFT_MODE"] = "DISABLED"
- os.environ["TUSK_SAMPLING_RATE"] = "0.25"
+ os.environ["TUSK_RECORDING_SAMPLING_RATE"] = "0.25"
+ instance = TuskDrift.get_instance()
+
+ result = instance._determine_sampling_rate(None)
+
+ assert result == 0.25
+
+ def test_uses_legacy_sampling_env_var_as_alias(self, reset_singleton):
+ """Should fall back to the legacy env var for backward compatibility."""
+ os.environ["TUSK_DRIFT_MODE"] = "DISABLED"
+ os.environ["TUSK_SAMPLING_RATE"] = "0.2"
+ instance = TuskDrift.get_instance()
+
+ result = instance._determine_sampling_rate(None)
+
+ assert result == 0.2
+
+ def test_recording_env_var_takes_precedence_over_legacy_alias(self, reset_singleton):
+ """Should prefer the canonical env var when both env vars are set."""
+ os.environ["TUSK_DRIFT_MODE"] = "DISABLED"
+ os.environ["TUSK_RECORDING_SAMPLING_RATE"] = "0.25"
+ os.environ["TUSK_SAMPLING_RATE"] = "0.1"
instance = TuskDrift.get_instance()
result = instance._determine_sampling_rate(None)
@@ -149,13 +179,37 @@ def test_uses_env_var_sampling_rate(self, reset_singleton):
def test_init_param_takes_precedence_over_env_var(self, reset_singleton):
"""Should prefer init param over env var."""
os.environ["TUSK_DRIFT_MODE"] = "DISABLED"
- os.environ["TUSK_SAMPLING_RATE"] = "0.25"
+ os.environ["TUSK_RECORDING_SAMPLING_RATE"] = "0.25"
instance = TuskDrift.get_instance()
result = instance._determine_sampling_rate(0.75)
assert result == 0.75
+ def test_invalid_init_param_falls_back_to_env_var(self, reset_singleton):
+ """Should use env var when init param is present but invalid."""
+ os.environ["TUSK_DRIFT_MODE"] = "DISABLED"
+ os.environ["TUSK_RECORDING_SAMPLING_RATE"] = "0.25"
+ instance = TuskDrift.get_instance()
+
+ result = instance._determine_sampling_rate(2.0)
+
+ assert result == 0.25
+
+ def test_invalid_init_param_falls_back_to_config_file(self, reset_singleton):
+ """Should use config file when init param is present but invalid."""
+ os.environ["TUSK_DRIFT_MODE"] = "DISABLED"
+ instance = TuskDrift.get_instance()
+ instance.file_config = TuskFileConfig(
+ recording=RecordingConfig(
+ sampling=SamplingConfig(base_rate=0.4),
+ )
+ )
+
+ result = instance._determine_sampling_rate(2.0)
+
+ assert result == 0.4
+
def test_defaults_to_1_0(self, reset_singleton):
"""Should default to 1.0 (100%) sampling rate."""
os.environ["TUSK_DRIFT_MODE"] = "DISABLED"
@@ -165,16 +219,42 @@ def test_defaults_to_1_0(self, reset_singleton):
assert result == 1.0
- def test_rejects_invalid_env_var_sampling_rate(self, reset_singleton):
- """Should reject invalid env var and use default."""
+ def test_rejects_invalid_recording_env_var_sampling_rate(self, reset_singleton):
+ """Should reject an invalid canonical env var and use default."""
os.environ["TUSK_DRIFT_MODE"] = "DISABLED"
- os.environ["TUSK_SAMPLING_RATE"] = "invalid"
+ os.environ["TUSK_RECORDING_SAMPLING_RATE"] = "invalid"
instance = TuskDrift.get_instance()
result = instance._determine_sampling_rate(None)
assert result == 1.0
+ def test_invalid_env_var_falls_back_to_config_file(self, reset_singleton):
+ """Should use config file when env var is present but invalid."""
+ os.environ["TUSK_DRIFT_MODE"] = "DISABLED"
+ os.environ["TUSK_RECORDING_SAMPLING_RATE"] = "invalid"
+ instance = TuskDrift.get_instance()
+ instance.file_config = TuskFileConfig(
+ recording=RecordingConfig(
+ sampling=SamplingConfig(base_rate=0.4),
+ )
+ )
+
+ result = instance._determine_sampling_rate(None)
+
+ assert result == 0.4
+
+ def test_invalid_recording_env_var_falls_back_to_legacy_alias(self, reset_singleton):
+ """Should use the legacy alias when the canonical env var is invalid."""
+ os.environ["TUSK_DRIFT_MODE"] = "DISABLED"
+ os.environ["TUSK_RECORDING_SAMPLING_RATE"] = "invalid"
+ os.environ["TUSK_SAMPLING_RATE"] = "0.4"
+ instance = TuskDrift.get_instance()
+
+ result = instance._determine_sampling_rate(None)
+
+ assert result == 0.4
+
class TestTuskDriftInitialize:
"""Tests for TuskDrift.initialize method."""
@@ -185,7 +265,13 @@ def reset_singleton(self):
TuskDrift._instance = None
TuskDrift._initialized = False
# Clear environment variables
- env_vars = ["TUSK_DRIFT_MODE", "TUSK_API_KEY", "TUSK_SAMPLING_RATE", "ENV"]
+ env_vars = [
+ "TUSK_DRIFT_MODE",
+ "TUSK_API_KEY",
+ "TUSK_RECORDING_SAMPLING_RATE",
+ "TUSK_SAMPLING_RATE",
+ "ENV",
+ ]
for var in env_vars:
if var in os.environ:
del os.environ[var]
@@ -252,6 +338,47 @@ def test_idempotent_initialization(self, reset_singleton, mocker):
# TracerProvider should only be created once
assert mock_provider.call_count == 1
+ def test_second_initialize_does_not_mutate_live_adaptive_sampling_state(self, reset_singleton, mocker):
+ """Should keep sampling fields aligned with the live controller on repeated initialize calls."""
+ mocker.patch("drift.core.drift_sdk.install_hooks")
+ mocker.patch("drift.core.drift_sdk.atexit")
+ mocker.patch("drift.core.drift_sdk.TracerProvider")
+ mocker.patch("drift.core.drift_sdk.trace")
+ mocker.patch.object(TuskDrift, "_start_adaptive_sampling_control_loop")
+ os.environ["TUSK_DRIFT_MODE"] = "RECORD"
+
+ instance = TuskDrift.get_instance()
+ instance.file_config = TuskFileConfig(
+ recording=RecordingConfig(
+ sampling=SamplingConfig(mode="adaptive", base_rate=0.5, min_rate=0.1),
+ )
+ )
+
+ initialized_instance = TuskDrift.initialize(env="test")
+ initialized_instance._adaptive_sampling_controller = AdaptiveSamplingController(
+ ResolvedSamplingConfig(mode="adaptive", base_rate=0.5, min_rate=0.1),
+ random_fn=lambda: 0.0,
+ now_fn=lambda: 0.0,
+ )
+
+ initialized_instance.file_config = TuskFileConfig(
+ recording=RecordingConfig(
+ sampling=SamplingConfig(mode="fixed", base_rate=0.2, min_rate=None),
+ )
+ )
+
+ second_instance = TuskDrift.initialize(env="test", sampling_rate=0.9)
+
+ assert second_instance is initialized_instance
+ assert second_instance._sampling_rate == 0.5
+ assert second_instance._sampling_mode == "adaptive"
+ assert second_instance._min_sampling_rate == 0.1
+
+ decision = second_instance.should_record_root_request(is_pre_app_start=False)
+ assert decision.mode == "adaptive"
+ assert decision.base_rate == 0.5
+ assert decision.min_rate == 0.1
+
class TestTuskDriftMarkAppAsReady:
"""Tests for TuskDrift.mark_app_as_ready method."""
@@ -412,6 +539,105 @@ def test_shutdown_cleans_up_resources(self, reset_singleton, mocker):
mock_tracer_provider.shutdown.assert_called_once()
+class TestTuskDriftAdaptiveSampling:
+ """Tests for adaptive sampling health monitoring."""
+
+ @pytest.fixture(autouse=True)
+ def reset_singleton(self):
+ """Reset singleton state before each test."""
+ TuskDrift._instance = None
+ TuskDrift._initialized = False
+ yield
+ TuskDrift._instance = None
+ TuskDrift._initialized = False
+
+ def test_safe_update_logs_and_swallows_health_update_exceptions(self, reset_singleton, mocker):
+ """Should log and continue when health updates fail."""
+ instance = TuskDrift.get_instance()
+ mocker.patch.object(instance, "_update_adaptive_sampling_health", side_effect=RuntimeError("boom"))
+ log_error = mocker.patch("drift.core.drift_sdk.logger.error")
+
+ instance._safe_update_adaptive_sampling_health()
+
+ log_error.assert_called_once()
+ assert "Adaptive sampling health update failed" in log_error.call_args.args[0]
+
+ def test_adaptive_sampling_loop_continues_after_update_exception(self, reset_singleton, mocker):
+ """Should keep polling after a single health update failure."""
+ instance = TuskDrift.get_instance()
+ stop_event = mocker.MagicMock()
+ stop_event.wait.side_effect = [False, False, True]
+ instance._adaptive_sampling_stop_event = stop_event
+ log_error = mocker.patch("drift.core.drift_sdk.logger.error")
+
+ update_health = mocker.patch.object(
+ instance,
+ "_update_adaptive_sampling_health",
+ side_effect=[RuntimeError("boom"), None],
+ )
+
+ instance._adaptive_sampling_loop()
+
+ assert update_health.call_count == 2
+ log_error.assert_called_once()
+ assert "Adaptive sampling health update failed" in log_error.call_args.args[0]
+
+
+class TestTuskDriftMemoryPressure:
+ """Tests for memory pressure measurement helpers."""
+
+ @pytest.fixture(autouse=True)
+ def reset_singleton(self):
+ """Reset singleton state before each test."""
+ TuskDrift._instance = None
+ TuskDrift._initialized = False
+ yield
+ TuskDrift._instance = None
+ TuskDrift._initialized = False
+
+ def test_parse_proc_status_rss_bytes(self, reset_singleton):
+ """Should parse current RSS from /proc/self/status."""
+ raw_status = "Name:\tpython\nVmRSS:\t1234 kB\nThreads:\t8\n"
+
+ assert TuskDrift._parse_proc_status_rss_bytes(raw_status) == 1234 * 1024
+
+ def test_read_current_rss_bytes_falls_back_to_proc_statm(self, reset_singleton, mocker):
+ """Should use /proc/self/statm when /proc/self/status is unavailable."""
+ instance = TuskDrift.get_instance()
+
+ mocker.patch(
+ "drift.core.drift_sdk.Path.exists",
+ autospec=True,
+ side_effect=lambda path: str(path) == "/proc/self/statm",
+ )
+ mocker.patch(
+ "drift.core.drift_sdk.Path.read_text",
+ autospec=True,
+ side_effect=lambda path: "100 25 0 0 0 0 0\n" if str(path) == "/proc/self/statm" else "",
+ )
+ mocker.patch("drift.core.drift_sdk.os.sysconf", return_value=4096)
+
+ assert instance._read_current_rss_bytes() == 25 * 4096
+
+ def test_get_memory_pressure_ratio_uses_current_rss_fallback(self, reset_singleton, mocker):
+ """Should use current RSS fallback when cgroup current usage is unavailable."""
+ instance = TuskDrift.get_instance()
+ instance._effective_memory_limit_bytes = 1024
+ mocker.patch.object(instance, "_read_numeric_control_file", return_value=None)
+ mocker.patch.object(instance, "_read_current_rss_bytes", return_value=256)
+
+ assert instance._get_memory_pressure_ratio() == 0.25
+
+ def test_get_memory_pressure_ratio_returns_none_without_current_measurement(self, reset_singleton, mocker):
+ """Should return None when no current memory measurement is available."""
+ instance = TuskDrift.get_instance()
+ instance._effective_memory_limit_bytes = 1024
+ mocker.patch.object(instance, "_read_numeric_control_file", return_value=None)
+ mocker.patch.object(instance, "_read_current_rss_bytes", return_value=None)
+
+ assert instance._get_memory_pressure_ratio() is None
+
+
class TestTuskDriftGetTracer:
"""Tests for TuskDrift.get_tracer method."""
diff --git a/tests/unit/test_mode_utils.py b/tests/unit/test_mode_utils.py
index c60cb0f..3f9b001 100644
--- a/tests/unit/test_mode_utils.py
+++ b/tests/unit/test_mode_utils.py
@@ -311,11 +311,9 @@ def test_returns_true_when_no_drop_and_sampled(self, mocker):
"""Should return (True, None) when not dropped and sampled."""
mock_drift = mocker.patch("drift.core.drift_sdk.TuskDrift")
mock_sdk = mocker.MagicMock()
- mock_sdk.get_sampling_rate.return_value = 1.0
+ mock_sdk.should_record_root_request.return_value.should_record = True
mock_drift.get_instance.return_value = mock_sdk
- mocker.patch("drift.core.sampling.should_sample", return_value=True)
-
result, reason = should_record_inbound_http_request(
method="GET",
target="/api/users",
@@ -361,11 +359,10 @@ def test_returns_false_when_not_sampled(self, mocker):
"""Should return (False, 'not_sampled') when sampling decides to skip."""
mock_drift = mocker.patch("drift.core.drift_sdk.TuskDrift")
mock_sdk = mocker.MagicMock()
- mock_sdk.get_sampling_rate.return_value = 0.0
+ mock_sdk.should_record_root_request.return_value.should_record = False
+ mock_sdk.should_record_root_request.return_value.reason = "not_sampled"
mock_drift.get_instance.return_value = mock_sdk
- mocker.patch("drift.core.sampling.should_sample", return_value=False)
-
result, reason = should_record_inbound_http_request(
method="GET",
target="/api/users",
@@ -377,12 +374,33 @@ def test_returns_false_when_not_sampled(self, mocker):
assert result is False
assert reason == "not_sampled"
+ def test_returns_controller_reason_when_adaptive_sampling_skips(self, mocker):
+ """Should preserve adaptive controller reasons for debug logging."""
+ mock_drift = mocker.patch("drift.core.drift_sdk.TuskDrift")
+ mock_sdk = mocker.MagicMock()
+ mock_sdk.should_record_root_request.return_value.should_record = False
+ mock_sdk.should_record_root_request.return_value.reason = "critical_pause"
+ mock_drift.get_instance.return_value = mock_sdk
+
+ result, reason = should_record_inbound_http_request(
+ method="GET",
+ target="/api/users",
+ headers={},
+ transform_engine=None,
+ is_pre_app_start=False,
+ )
+
+ assert result is False
+ assert reason == "critical_pause"
+
def test_drop_check_happens_before_sampling(self, mocker):
"""Should check drop rules before sampling."""
mock_transform = mocker.MagicMock()
mock_transform.should_drop_inbound_request.return_value = True
- mock_sample = mocker.patch("drift.core.sampling.should_sample")
+ mock_drift = mocker.patch("drift.core.drift_sdk.TuskDrift")
+ mock_sdk = mocker.MagicMock()
+ mock_drift.get_instance.return_value = mock_sdk
result, reason = should_record_inbound_http_request(
method="GET",
@@ -392,7 +410,6 @@ def test_drop_check_happens_before_sampling(self, mocker):
is_pre_app_start=False,
)
- # should_sample should never be called if dropped
- mock_sample.assert_not_called()
+ mock_sdk.should_record_root_request.assert_not_called()
assert result is False
assert reason == "dropped"
diff --git a/tests/unit/test_sampling.py b/tests/unit/test_sampling.py
index f6ee591..f166ec9 100644
--- a/tests/unit/test_sampling.py
+++ b/tests/unit/test_sampling.py
@@ -90,7 +90,7 @@ def test_rate_above_one_returns_none(self):
def test_custom_source_in_warning(self):
"""Should include custom source in warning message."""
# Just verify it doesn't raise with custom source
- result = validate_sampling_rate(-0.5, source="env var TUSK_SAMPLING_RATE")
+ result = validate_sampling_rate(-0.5, source="env var TUSK_RECORDING_SAMPLING_RATE")
assert result is None
def test_converts_to_float(self):
diff --git a/tests/unit/test_span_utils.py b/tests/unit/test_span_utils.py
index c23d0b2..1e1bf36 100644
--- a/tests/unit/test_span_utils.py
+++ b/tests/unit/test_span_utils.py
@@ -7,6 +7,7 @@
from opentelemetry.trace import SpanKind as OTelSpanKind
from opentelemetry.trace import Status, StatusCode
+from drift.core.no_recording import suppress_recording
from drift.core.tracing.span_utils import (
AddSpanAttributesOptions,
CreateSpanOptions,
@@ -205,6 +206,23 @@ def test_returns_none_on_exception(self, mocker):
assert result is None
+ def test_returns_none_when_recording_is_suppressed(self, mocker):
+ """Should not create spans when no-record context is active."""
+ mock_drift = mocker.patch("drift.core.drift_sdk.TuskDrift")
+ mock_sdk = mocker.MagicMock()
+ mock_sdk.get_tracer.return_value = mocker.MagicMock()
+ mock_drift.get_instance.return_value = mock_sdk
+
+ options = CreateSpanOptions(
+ name="test-span",
+ kind=OTelSpanKind.SERVER,
+ )
+
+ with suppress_recording():
+ result = SpanUtils.create_span(options)
+
+ assert result is None
+
class TestSpanUtilsWithSpan:
"""Tests for SpanUtils.with_span context manager."""
diff --git a/tests/unit/test_td_span_processor.py b/tests/unit/test_td_span_processor.py
index eae2f61..5a34c3c 100644
--- a/tests/unit/test_td_span_processor.py
+++ b/tests/unit/test_td_span_processor.py
@@ -23,8 +23,7 @@ def test_initializes_with_required_params(self, mocker):
assert processor._exporter is mock_exporter
assert processor._mode == TuskDriftMode.RECORD
- assert processor._sampling_rate == 1.0
- assert processor._app_ready is False
+ assert processor._environment is None
assert processor._started is False
def test_initializes_with_optional_params(self, mocker):
@@ -34,14 +33,10 @@ def test_initializes_with_optional_params(self, mocker):
processor = TdSpanProcessor(
exporter=mock_exporter,
mode=TuskDriftMode.REPLAY,
- sampling_rate=0.5,
- app_ready=True,
environment="production",
)
assert processor._mode == TuskDriftMode.REPLAY
- assert processor._sampling_rate == 0.5
- assert processor._app_ready is True
assert processor._environment == "production"
@@ -392,35 +387,3 @@ def test_force_flush_handles_exception(self, mocker):
result = processor.force_flush()
assert result is False
-
-
-class TestTdSpanProcessorUpdateMethods:
- """Tests for TdSpanProcessor update methods."""
-
- def test_update_app_ready(self, mocker):
- """Should update app_ready flag."""
- mock_exporter = mocker.MagicMock()
- processor = TdSpanProcessor(
- exporter=mock_exporter,
- mode=TuskDriftMode.RECORD,
- )
-
- assert processor._app_ready is False
-
- processor.update_app_ready(True)
-
- assert processor._app_ready is True
-
- def test_update_sampling_rate(self, mocker):
- """Should update sampling rate."""
- mock_exporter = mocker.MagicMock()
- processor = TdSpanProcessor(
- exporter=mock_exporter,
- mode=TuskDriftMode.RECORD,
- )
-
- assert processor._sampling_rate == 1.0
-
- processor.update_sampling_rate(0.5)
-
- assert processor._sampling_rate == 0.5
diff --git a/tests/unit/test_wsgi_handler.py b/tests/unit/test_wsgi_handler.py
new file mode 100644
index 0000000..0850852
--- /dev/null
+++ b/tests/unit/test_wsgi_handler.py
@@ -0,0 +1,83 @@
+"""Tests for WSGI handler request lifecycle behavior."""
+
+from __future__ import annotations
+
+from collections.abc import Iterable, Iterator
+from typing import Any
+
+from drift.instrumentation.wsgi.handler import _create_and_handle_request
+
+
+class StreamingResponse(Iterable[bytes]):
+ def __init__(self, observed: list[tuple[str, bool]]) -> None:
+ self._observed = observed
+ self._yielded = False
+
+ def __iter__(self) -> Iterator[bytes]:
+ from drift.core.no_recording import is_recording_suppressed
+
+ self._observed.append(("iter", is_recording_suppressed()))
+ return self
+
+ def __next__(self) -> bytes:
+ from drift.core.no_recording import is_recording_suppressed
+
+ self._observed.append(("next", is_recording_suppressed()))
+ if self._yielded:
+ raise StopIteration
+ self._yielded = True
+ return b"chunk"
+
+ def close(self) -> None:
+ from drift.core.no_recording import is_recording_suppressed
+
+ self._observed.append(("close", is_recording_suppressed()))
+
+
+def test_skipped_wsgi_request_keeps_suppression_during_streaming_iteration_and_close(mocker) -> None:
+ observed: list[tuple[str, bool]] = []
+ response = StreamingResponse(observed)
+
+ mocker.patch(
+ "drift.instrumentation.wsgi.handler.should_record_inbound_http_request",
+ return_value=(False, "not_sampled"),
+ )
+
+ def original_wsgi_app(_app: Any, _environ: dict[str, Any], _start_response: Any) -> Iterable[bytes]:
+ from drift.core.no_recording import is_recording_suppressed
+
+ observed.append(("call", is_recording_suppressed()))
+ return response
+
+ def app(_environ: dict[str, Any], _start_response: Any) -> Iterable[bytes]:
+ return response
+
+ wrapped_response = _create_and_handle_request(
+ app=app,
+ environ={
+ "REQUEST_METHOD": "GET",
+ "PATH_INFO": "/stream",
+ "QUERY_STRING": "",
+ },
+ start_response=lambda status, headers, exc_info=None: None,
+ original_wsgi_app=original_wsgi_app,
+ framework_name="wsgi",
+ instrumentation_name="WsgiInstrumentation",
+ transform_engine=None,
+ sdk=object(),
+ is_pre_app_start=False,
+ replay_token=None,
+ )
+
+ assert list(wrapped_response) == [b"chunk"]
+ close_method = getattr(wrapped_response, "close", None)
+ assert close_method is not None
+ close_method()
+
+ assert observed == [
+ ("call", True),
+ ("iter", True),
+ ("next", True),
+ ("next", True),
+ ("close", True),
+ ]