From 73cabb3338685e4cd431c31dad8a46255d9319fb Mon Sep 17 00:00:00 2001 From: Hopelynconsult Date: Thu, 7 May 2026 19:21:27 +0300 Subject: [PATCH 1/2] feat(governance): add distributional drift detection (PSI, KS) Complement to the per-point anomaly detector (#35): the anomaly detector flags individual predictions whose features fall outside historical norms; this module compares the *distribution* of recent predictions (or inputs) against a reference baseline and flags drift even when no single prediction is anomalous. Two non-parametric tests: - Population Stability Index over reference quantile bins. PSI < 0.1 stable, 0.1-0.25 moderate, > 0.25 severe (industry-standard rule of thumb). - Two-sample Kolmogorov-Smirnov, with the asymptotic p-value computed from the standard Kolmogorov series so we don't pull in scipy at evaluation time. Both run per-feature; a DriftReport aggregates per-feature DriftResults so callers (CI gate, monitoring dashboards) decide their own aggregation policy. Designed to plug into the prediction-history JSONL emitted by the anomaly detector so drift can run as a scheduled CI step over the last N days of production predictions. 
"""
Distributional drift detection for ClimateVision inputs and predictions.

The existing anomaly detector (``governance.anomaly_detector``) flags
*individual* predictions whose features fall outside historical norms.
This module is its complement: it compares the *distribution* of recent
predictions (or inputs) against a reference baseline and flags drift
even when no single prediction is anomalous.

Two well-understood non-parametric tests are exposed:

- **Population Stability Index (PSI)** -- bins both windows on the
  reference's quantiles and sums (p_i - q_i) * log(p_i / q_i).  The
  industry-standard rule of thumb: PSI < 0.1 stable, 0.1-0.25 moderate
  drift, > 0.25 significant drift.
- **Kolmogorov-Smirnov (KS)** -- supremum of the gap between the two
  empirical CDFs, with a two-sample asymptotic p-value computed from the
  standard Kolmogorov series (no scipy dependency at evaluation time).

Both run on a single feature at a time.  Multi-feature drift is reported
as a list of per-feature ``DriftResult`` objects so callers can decide
how to aggregate (any-feature-drifts vs. average) without baking that
policy into the detector.

Designed to plug into the prediction-history JSONL written by the
anomaly detector, so drift checks can run as a CI step over the last
N days of production predictions.
"""

from __future__ import annotations

import json
import logging
from dataclasses import asdict, dataclass, field
from pathlib import Path
from typing import List, Optional, Sequence, Union

import numpy as np

logger = logging.getLogger(__name__)

DEFAULT_PSI_BINS = 10
PSI_STABLE = 0.10       # PSI below this: distributions considered stable
PSI_MODERATE = 0.25     # PSI at/above this: significant ("severe") drift
DEFAULT_KS_SIGNIFICANCE = 0.05
_BIN_EPS = 1e-6         # floor for empty-bin proportions so log() stays finite


@dataclass
class DriftResult:
    """Per-feature drift assessment from one test on one window pair."""

    feature: str
    method: str             # "psi" or "ks"
    statistic: float        # PSI value, or KS supremum gap
    threshold: float        # decision threshold the 'drifted' flag used
    drifted: bool
    severity: str           # "stable", "moderate", "severe"
    p_value: Optional[float] = None  # populated only for the KS test
    n_reference: int = 0    # sample sizes so small windows can be discounted
    n_current: int = 0


@dataclass
class DriftReport:
    """Multi-feature drift report covering one window comparison."""

    reference_window: str
    current_window: str
    method: str
    results: List[DriftResult] = field(default_factory=list)

    @property
    def any_drifted(self) -> bool:
        """True when at least one feature crossed its drift threshold."""
        return any(r.drifted for r in self.results)

    @property
    def severe_features(self) -> List[str]:
        """Names of features whose drift severity is \"severe\"."""
        return [r.feature for r in self.results if r.severity == "severe"]

    def to_dict(self) -> dict:
        """JSON-serialisable representation (used by write_drift_report)."""
        return {
            "reference_window": self.reference_window,
            "current_window": self.current_window,
            "method": self.method,
            "any_drifted": self.any_drifted,
            "severe_features": self.severe_features,
            "results": [asdict(r) for r in self.results],
        }


def _as_array(values: Sequence[float], name: str) -> np.ndarray:
    """Coerce a window to a flat float64 array, rejecting empty/non-finite input."""
    arr = np.asarray(values, dtype=np.float64).ravel()
    if arr.size == 0:
        raise ValueError(f"{name} window is empty")
    if not np.all(np.isfinite(arr)):
        raise ValueError(f"{name} window contains non-finite values")
    return arr


def population_stability_index(
    reference: Sequence[float],
    current: Sequence[float],
    n_bins: int = DEFAULT_PSI_BINS,
) -> float:
    """Compute PSI between a reference and a current sample.

    Bins are derived from quantiles of the reference distribution so the
    reference always has roughly equal mass per bin, which is the canonical
    PSI definition.  Empty bins are floored to a small epsilon to keep the
    log finite.

    Raises ValueError for empty/non-finite windows or ``n_bins < 1``.

    NOTE(review): when the reference is constant, binning degrades to a
    single catch-all bin, so PSI is ~0 regardless of the current window --
    a constant baseline cannot signal drift with this test.  Callers with
    constant baselines should prefer the KS method.
    """
    if n_bins < 1:
        # Previously accepted silently and produced degenerate binning.
        raise ValueError(f"n_bins must be >= 1, got {n_bins}")
    ref = _as_array(reference, "reference")
    cur = _as_array(current, "current")

    quantiles = np.linspace(0.0, 1.0, n_bins + 1)
    edges = np.unique(np.quantile(ref, quantiles))
    if edges.size < 2:
        # Reference is constant -- fall back to one catch-all bin.  (The
        # original built finite min/max edges here only to overwrite them
        # with +/-inf below; construct the final edges directly.)
        edges = np.array([-np.inf, np.inf])
    else:
        # Open the outer bins so current values outside the reference's
        # observed range still land in a bin instead of being dropped.
        edges[0] = -np.inf
        edges[-1] = np.inf

    ref_counts, _ = np.histogram(ref, bins=edges)
    cur_counts, _ = np.histogram(cur, bins=edges)

    ref_pct = np.clip(ref_counts / ref_counts.sum(), _BIN_EPS, None)
    cur_pct = np.clip(cur_counts / cur_counts.sum(), _BIN_EPS, None)

    return float(np.sum((cur_pct - ref_pct) * np.log(cur_pct / ref_pct)))


def _psi_severity(psi: float) -> str:
    """Map a PSI value onto the industry-standard three-level scale."""
    if psi < PSI_STABLE:
        return "stable"
    if psi < PSI_MODERATE:
        return "moderate"
    return "severe"


def kolmogorov_smirnov(
    reference: Sequence[float],
    current: Sequence[float],
) -> tuple[float, float]:
    """Two-sample KS statistic + asymptotic p-value.

    Avoids importing scipy by computing the supremum gap of the empirical
    CDFs directly and using the standard Kolmogorov asymptotic series
    (with the Stephens small-sample correction ``en + 0.12 + 0.11/en``)
    for the p-value.

    Returns ``(statistic, p_value)``; raises ValueError for empty or
    non-finite windows.
    """
    ref = _as_array(reference, "reference")
    cur = _as_array(current, "current")

    # Evaluate both empirical CDFs on the pooled sample; the supremum of
    # their gap is attained at one of these points.
    combined = np.sort(np.concatenate([ref, cur]))
    cdf_ref = np.searchsorted(np.sort(ref), combined, side="right") / ref.size
    cdf_cur = np.searchsorted(np.sort(cur), combined, side="right") / cur.size
    statistic = float(np.max(np.abs(cdf_ref - cdf_cur)))

    n = ref.size
    m = cur.size
    en = float(np.sqrt(n * m / (n + m)))
    lam = (en + 0.12 + 0.11 / en) * statistic

    # Kolmogorov asymptotic distribution: P(K > lam) = 2 * sum over k of
    # (-1)^(k-1) * exp(-2 * lam^2 * k^2), truncated when terms vanish.
    p = 0.0
    for k in range(1, 101):
        term = 2 * (-1) ** (k - 1) * np.exp(-2 * (lam ** 2) * (k ** 2))
        p += term
        if abs(term) < 1e-12:
            break
    p_value = float(min(max(p, 0.0), 1.0))
    return statistic, p_value


def detect_drift(
    reference: dict[str, Sequence[float]],
    current: dict[str, Sequence[float]],
    *,
    method: str = "psi",
    reference_window: str = "baseline",
    current_window: str = "current",
    psi_bins: int = DEFAULT_PSI_BINS,
    ks_significance: float = DEFAULT_KS_SIGNIFICANCE,
) -> DriftReport:
    """Per-feature drift assessment over two windows.

    ``reference`` and ``current`` are dicts mapping feature name to a 1D
    sample of values from the respective window.  Feature sets must match
    exactly (symmetric difference must be empty).

    Raises ValueError for an unknown ``method`` or mismatched features.

    NOTE(review): decision policies intentionally differ per method --
    PSI flags ``drifted`` only at the significant (severe) threshold,
    while KS flags at the ``ks_significance`` level even when severity is
    only "moderate".  Confirm this asymmetry is intended by the CI gate.
    """
    if method not in {"psi", "ks"}:
        raise ValueError(f"unknown method: {method!r}; expected 'psi' or 'ks'")

    missing = set(reference.keys()) ^ set(current.keys())
    if missing:
        raise ValueError(f"feature mismatch between windows: {missing}")

    results: List[DriftResult] = []
    for feature in reference.keys():
        ref_values = reference[feature]
        cur_values = current[feature]
        if method == "psi":
            psi = population_stability_index(ref_values, cur_values, n_bins=psi_bins)
            results.append(
                DriftResult(
                    feature=feature,
                    method="psi",
                    statistic=psi,
                    threshold=PSI_MODERATE,
                    # Aligned with the significant-drift threshold; moderate
                    # PSI (0.1-0.25) is reported via severity but not flagged.
                    drifted=psi >= PSI_MODERATE,
                    severity=_psi_severity(psi),
                    n_reference=len(ref_values),
                    n_current=len(cur_values),
                )
            )
        else:
            statistic, p_value = kolmogorov_smirnov(ref_values, cur_values)
            # "severe" reserved for p-values an order of magnitude below
            # the significance level (significance / 5 by convention here).
            if p_value < ks_significance / 5:
                severity = "severe"
            elif p_value < ks_significance:
                severity = "moderate"
            else:
                severity = "stable"
            results.append(
                DriftResult(
                    feature=feature,
                    method="ks",
                    statistic=statistic,
                    threshold=ks_significance,
                    drifted=p_value < ks_significance,
                    severity=severity,
                    p_value=p_value,
                    n_reference=len(ref_values),
                    n_current=len(cur_values),
                )
            )

    return DriftReport(
        reference_window=reference_window,
        current_window=current_window,
        method=method,
        results=results,
    )


def write_drift_report(
    report: DriftReport, path: Union[str, Path]
) -> Path:
    """Persist a DriftReport to disk as JSON, creating parent directories.

    Returns the output path so callers can chain/log it.
    """
    out = Path(path)
    out.parent.mkdir(parents=True, exist_ok=True)
    # Explicit encoding: the write_text default follows the platform
    # locale, which would make report bytes differ across OSes in CI.
    out.write_text(json.dumps(report.to_dict(), indent=2), encoding="utf-8")
    logger.info("Wrote drift report to %s", out)
    return out