diff --git a/api/experimentation/constants.py b/api/experimentation/constants.py index 07b0eb643718..8622d72229d7 100644 --- a/api/experimentation/constants.py +++ b/api/experimentation/constants.py @@ -8,3 +8,11 @@ variant key.""" EXPOSURE_HOURLY_BUCKET_MAX_WINDOW = timedelta(hours=72) + +CONTROL_VARIANT_KEY = "control" + +# Below these per-variant floors a metric shows "collecting data" rather than +# inference; sample-ratio is only checked once there is enough traffic to judge. +RESULTS_MIN_IDENTITIES_PER_VARIANT = 50 +RESULTS_MIN_CONVERSIONS_PER_VARIANT = 5 +SRM_MIN_TOTAL_IDENTITIES = 100 diff --git a/api/experimentation/dataclasses.py b/api/experimentation/dataclasses.py index ebaa23932fb9..8b6fdc557400 100644 --- a/api/experimentation/dataclasses.py +++ b/api/experimentation/dataclasses.py @@ -1,6 +1,7 @@ from dataclasses import dataclass from datetime import datetime +from experimentation.stats import Inference, VariantStats from experimentation.types import ExposureGranularity @@ -34,3 +35,35 @@ class ExposuresTimeseries: class ExposuresSummary: excluded_identities: int timeseries: ExposuresTimeseries + + +@dataclass(frozen=True) +class MetricSpec: + metric_id: int + event: str + aggregation: str + lower_is_better: bool + + +@dataclass(frozen=True) +class ResultsAggregates: + """Sufficient statistics gathered from the warehouse for one experiment: + the specs they were computed from, per-variant identity counts, and per + metric the per-variant ``VariantStats``. Bundled so the keys can't drift.""" + + specs: list[MetricSpec] + exposure_counts: dict[str, int] + metric_stats: dict[int, dict[str, VariantStats]] + + +@dataclass(frozen=True) +class MetricResult: + metric_id: int + variants: dict[str, VariantStats] + inference: dict[str, Inference | None] + + +@dataclass(frozen=True) +class ResultsSummary: + srm_p_value: float | None + metrics: list[MetricResult] diff --git a/api/experimentation/services.py b/api/experimentation/services.py index c9a98572b99d..d91bdbcf877e 100644 --- a/api/experimentation/services.py +++ b/api/experimentation/services.py @@ -1,6 +1,7 @@ from __future__ import annotations import typing +from dataclasses import replace from functools import lru_cache import structlog @@ -12,9 +13,13 @@ from audit.models import AuditLog from audit.related_object_type import RelatedObjectType from experimentation.constants import ( + CONTROL_VARIANT_KEY, EXPERIMENT_FLAG, EXPOSURE_EVENT_NAME, EXPOSURE_HOURLY_BUCKET_MAX_WINDOW, + RESULTS_MIN_CONVERSIONS_PER_VARIANT, + RESULTS_MIN_IDENTITIES_PER_VARIANT, + SRM_MIN_TOTAL_IDENTITIES, WAREHOUSE_CONNECTION_FLAG, ) from experimentation.dataclasses import ( @@ -22,14 +27,25 @@ ExposuresSummary, ExposuresTimeseries, ExposuresTimeseriesPoint, + MetricResult, + MetricSpec, + ResultsAggregates, + ResultsSummary, WarehouseEventStats, ) from experimentation.models import ( VALID_STATUS_TRANSITIONS, ExperimentStatus, + MetricAggregation, WarehouseConnectionStatus, WarehouseType, ) +from experimentation.stats import ( + Inference, + VariantStats, + compare_to_control, + srm_p_value, +) from integrations.flagsmith.client import get_openfeature_client if typing.TYPE_CHECKING: @@ -227,6 +243,198 @@ def get_exposure_buckets( ] +# First-exposed identities per variant (same dedup/quarantine as the exposures +# query), each carrying the per-metric value used to derive sufficient +# statistics. Metric events are attributed from first exposure onwards. +_RESULTS_EXPOSURES_CTE = """ +WITH exposures AS ( + SELECT + identifier, + if(uniqExact(value) > 1, '', any(value)) AS variant, + uniqExact(value) > 1 AS quarantined, + min(timestamp) AS first_exposure + FROM events + WHERE environment_key = %(environment_key)s + AND event = %(exposure_event)s + AND feature_name = %(feature_name)s + AND timestamp >= %(window_start)s + AND timestamp < %(window_end)s + GROUP BY identifier +)""" + + +def _metric_value_expression(index: int, aggregation: str) -> str: + # Post-exposure attribution lives here, not in the JOIN ON: ClickHouse + # rejects an ON clause mixing left and right columns in an inequality. + condition = ( + f"m.event = %(metric_{index}_event)s AND m.timestamp >= e.first_exposure" + ) + value = "toFloat64OrZero(m.value)" + if aggregation == MetricAggregation.OCCURRENCE: + return f"countIf({condition}) > 0" + if aggregation == MetricAggregation.COUNT: + return f"countIf({condition})" + if aggregation == MetricAggregation.SUM: + return f"sumIf({value}, {condition})" + if aggregation == MetricAggregation.MEAN: + # per-identity average, zero when the identity has no matching events + return f"if(countIf({condition}) > 0, avgIf({value}, {condition}), 0)" + raise ValueError(f"Unsupported metric aggregation: {aggregation}") + + +def _build_results_query(specs: Sequence[MetricSpec]) -> str: + if not specs: + return ( + _RESULTS_EXPOSURES_CTE + + "\nSELECT variant, count() AS n" + + "\nFROM exposures\nWHERE quarantined = 0\nGROUP BY variant" + ) + value_expressions = ",\n ".join( + f"{_metric_value_expression(i, spec.aggregation)} AS m{i}" + for i, spec in enumerate(specs) + ) + outer_aggregates = ",\n ".join( + f"sum(m{i}) AS m{i}_sum, sum(m{i} * m{i}) AS m{i}_sum_squares" + for i in range(len(specs)) + ) + return ( + _RESULTS_EXPOSURES_CTE + + f""", +unit_values AS ( + SELECT + e.variant AS variant, + {value_expressions} + FROM exposures AS e + LEFT JOIN events AS m + ON m.identifier = e.identifier + AND m.environment_key = %(environment_key)s + AND m.event IN %(metric_events)s + AND m.timestamp < %(window_end)s + WHERE e.quarantined = 0 + GROUP BY e.identifier, e.variant +) +SELECT variant, count() AS n, + {outer_aggregates} +FROM unit_values +GROUP BY variant""" + ) + + +def get_metric_variant_stats( + *, + environment_key: str, + feature_name: str, + window_start: datetime, + window_end: datetime, + specs: Sequence[MetricSpec], +) -> ResultsAggregates: + """Run the warehouse query, returning per-variant identity counts and, per + metric, per-variant sufficient statistics.""" + params: dict[str, object] = { + "environment_key": environment_key, + "exposure_event": EXPOSURE_EVENT_NAME, + "feature_name": feature_name, + "window_start": window_start, + "window_end": window_end, + } + if specs: + params["metric_events"] = [spec.event for spec in specs] + for index, spec in enumerate(specs): + params[f"metric_{index}_event"] = spec.event + + rows = _get_clickhouse_client().execute(_build_results_query(specs), params) + + exposure_counts: dict[str, int] = {} + metric_stats: dict[int, dict[str, VariantStats]] = { + spec.metric_id: {} for spec in specs + } + for row in rows: + # Columns are emitted in query order: variant, n, then a (sum, + # sum_squares) pair per spec — consumed positionally in that order. + columns = iter(row) + variant = next(columns) + n = int(next(columns)) + exposure_counts[variant] = n + for spec in specs: + metric_stats[spec.metric_id][variant] = VariantStats( + n=n, + sum=float(next(columns)), + sum_squares=float(next(columns)), + ) + return ResultsAggregates( + specs=list(specs), + exposure_counts=exposure_counts, + metric_stats=metric_stats, + ) + + +def build_results_summary( + aggregates: ResultsAggregates, + *, + expected_shares: dict[str, float], +) -> ResultsSummary: + exposure_counts = aggregates.exposure_counts + total = sum(exposure_counts.values()) + if expected_shares and total >= SRM_MIN_TOTAL_IDENTITIES: + srm = srm_p_value( + [exposure_counts.get(variant, 0) for variant in expected_shares], + list(expected_shares.values()), + ) + else: + srm = None + return ResultsSummary( + srm_p_value=srm, + metrics=[ + MetricResult( + metric_id=spec.metric_id, + variants=aggregates.metric_stats.get(spec.metric_id, {}), + inference=_metric_inference( + spec, aggregates.metric_stats.get(spec.metric_id, {}) + ), + ) + for spec in aggregates.specs + ], + ) + + +def _metric_inference( + spec: MetricSpec, + variants: dict[str, VariantStats], +) -> dict[str, Inference | None]: + control = variants.get(CONTROL_VARIANT_KEY) + return { + variant_key: _infer_treatment(spec, control, treatment) + for variant_key, treatment in variants.items() + if variant_key != CONTROL_VARIANT_KEY + } + + +def _infer_treatment( + spec: MetricSpec, + control: VariantStats | None, + treatment: VariantStats, +) -> Inference | None: + # Product floor for showing a result at all; compare_to_control applies its + # own independent guards (e.g. zero control mean) on top of this. + if ( + control is None + or control.n < RESULTS_MIN_IDENTITIES_PER_VARIANT + or treatment.n < RESULTS_MIN_IDENTITIES_PER_VARIANT + ): + return None + if spec.aggregation == MetricAggregation.OCCURRENCE and ( + control.sum < RESULTS_MIN_CONVERSIONS_PER_VARIANT + or treatment.sum < RESULTS_MIN_CONVERSIONS_PER_VARIANT + ): + return None + inference = compare_to_control(control, treatment) + if inference is not None and spec.lower_is_better: + # "Winning" means moving the metric the good way; for a lower-is-better + # metric that's a fall, so the chance of winning is the chance lift < 0. + inference = replace(inference, chance_to_win=1.0 - inference.chance_to_win) + return inference + + def _resolve_audit_log_author( user: FFAdminUser, ) -> dict[str, int | None]: diff --git a/api/tests/unit/experimentation/test_services.py b/api/tests/unit/experimentation/test_services.py index 88c673b55246..163a18e23e80 100644 --- a/api/tests/unit/experimentation/test_services.py +++ b/api/tests/unit/experimentation/test_services.py @@ -1,3 +1,4 @@ +from dataclasses import asdict from datetime import datetime, timezone import pytest @@ -12,13 +13,17 @@ ExposuresSummary, ExposuresTimeseries, ExposuresTimeseriesPoint, + MetricSpec, + ResultsAggregates, WarehouseEventStats, ) from experimentation.models import ( + MetricAggregation, WarehouseConnection, WarehouseConnectionStatus, WarehouseType, ) +from experimentation.stats import VariantStats def test_get_clickhouse_client__configured_url__builds_client_with_timeouts( @@ -538,3 +543,358 @@ def test_refresh_warehouse_connection_status__already_connected__is_noop( # Then assert result.status == WarehouseConnectionStatus.CONNECTED assert log.events == [] + + +def _spec( + metric_id: int = 7, + event: str = "purchase", + aggregation: str = MetricAggregation.OCCURRENCE, + lower_is_better: bool = False, +) -> MetricSpec: + return MetricSpec( + metric_id=metric_id, + event=event, + aggregation=aggregation, + lower_is_better=lower_is_better, + ) + + +def _aggregates( + specs: list[MetricSpec], + exposure_counts: dict[str, int], + metric_stats: dict[int, dict[str, VariantStats]], +) -> ResultsAggregates: + return ResultsAggregates( + specs=specs, + exposure_counts=exposure_counts, + metric_stats=metric_stats, + ) + + +def test_get_metric_variant_stats__metrics__queries_and_maps_rows( + mocker: MockerFixture, +) -> None: + # Given the warehouse returns per-variant counts and two metrics' sums + rows = [ + ("control", 1000, 100.0, 100.0, 5000.0, 30000.0), + ("variant_a", 1000, 120.0, 120.0, 5200.0, 31000.0), + ] + mock_client = mocker.Mock() + mock_client.execute.return_value = rows + mocker.patch( + "experimentation.services._get_clickhouse_client", + return_value=mock_client, + ) + specs = [ + _spec(metric_id=7, event="purchase", aggregation=MetricAggregation.OCCURRENCE), + _spec(metric_id=9, event="revenue", aggregation=MetricAggregation.SUM), + ] + window_start = datetime(2026, 6, 1, tzinfo=timezone.utc) + window_end = datetime(2026, 6, 10, tzinfo=timezone.utc) + + # When + aggregates = services.get_metric_variant_stats( + environment_key="env-key-123", + feature_name="my-feature", + window_start=window_start, + window_end=window_end, + specs=specs, + ) + + # Then per-variant counts and sufficient statistics are mapped per metric + assert aggregates.exposure_counts == {"control": 1000, "variant_a": 1000} + assert aggregates.metric_stats[7]["variant_a"] == VariantStats( + n=1000, sum=120.0, sum_squares=120.0 + ) + assert aggregates.metric_stats[9]["control"] == VariantStats( + n=1000, sum=5000.0, sum_squares=30000.0 + ) + # And the query joins post-exposure metric events and excludes quarantined + sql, params = mock_client.execute.call_args.args + assert "LEFT JOIN events AS m" in sql + assert "m.timestamp >= e.first_exposure" in sql + assert "timestamp < %(window_end)s" in sql + assert "WHERE e.quarantined = 0" in sql + assert ( + "countIf(m.event = %(metric_0_event)s AND m.timestamp >= e.first_exposure)" + " > 0 AS m0" in sql + ) + assert ( + "sumIf(toFloat64OrZero(m.value), m.event = %(metric_1_event)s" + " AND m.timestamp >= e.first_exposure) AS m1" in sql + ) + assert "sum(m0) AS m0_sum, sum(m0 * m0) AS m0_sum_squares" in sql + assert params["metric_events"] == ["purchase", "revenue"] + assert params["metric_0_event"] == "purchase" + assert params["metric_1_event"] == "revenue" + assert params["window_end"] == window_end + + +def test_get_metric_variant_stats__no_metrics__counts_variants_only( + mocker: MockerFixture, +) -> None: + # Given an experiment with no attached metrics + mock_client = mocker.Mock() + mock_client.execute.return_value = [("control", 1000), ("variant_a", 900)] + mocker.patch( + "experimentation.services._get_clickhouse_client", + return_value=mock_client, + ) + + # When + aggregates = services.get_metric_variant_stats( + environment_key="env-key-123", + feature_name="my-feature", + window_start=datetime(2026, 6, 1, tzinfo=timezone.utc), + window_end=datetime(2026, 6, 10, tzinfo=timezone.utc), + specs=[], + ) + + # Then only the per-variant counts are returned, with no metric join + assert aggregates.exposure_counts == {"control": 1000, "variant_a": 900} + assert aggregates.metric_stats == {} + sql, params = mock_client.execute.call_args.args + assert "SELECT variant, count() AS n" in sql + assert "LEFT JOIN" not in sql + assert "metric_events" not in params + + +@pytest.mark.parametrize( + "aggregation, expected", + [ + ( + MetricAggregation.OCCURRENCE, + "countIf(m.event = %(metric_0_event)s AND m.timestamp >= e.first_exposure) > 0", + ), + ( + MetricAggregation.COUNT, + "countIf(m.event = %(metric_0_event)s AND m.timestamp >= e.first_exposure)", + ), + ( + MetricAggregation.SUM, + "sumIf(toFloat64OrZero(m.value), m.event = %(metric_0_event)s AND m.timestamp >= e.first_exposure)", + ), + ( + MetricAggregation.MEAN, + "if(countIf(m.event = %(metric_0_event)s AND m.timestamp >= e.first_exposure) > 0, " + "avgIf(toFloat64OrZero(m.value), m.event = %(metric_0_event)s AND m.timestamp >= e.first_exposure), 0)", + ), + ], + ids=["occurrence", "count", "sum", "mean"], +) +def test_metric_value_expression__aggregation__builds_clause( + aggregation: str, + expected: str, +) -> None: + # Given an aggregation + # When / Then it maps to its per-identity unit-value expression + assert services._metric_value_expression(0, aggregation) == expected + + +def test_metric_value_expression__unknown_aggregation__raises() -> None: + # Given an aggregation the query builder does not support + # When / Then it refuses rather than silently emitting the wrong clause + with pytest.raises(ValueError, match="Unsupported metric aggregation"): + services._metric_value_expression(0, "median") + + +def test_build_results_summary__healthy_arms__infers_each_treatment() -> None: + # Given a 10% control and a 12% treatment, both well above the floor + control = VariantStats(n=1000, sum=100.0, sum_squares=100.0) + treatment = VariantStats(n=1000, sum=120.0, sum_squares=120.0) + aggregates = _aggregates( + [_spec(metric_id=7)], + {"control": 1000, "variant_a": 1000}, + {7: {"control": control, "variant_a": treatment}}, + ) + + # When + summary = services.build_results_summary( + aggregates, + expected_shares={"control": 0.5, "variant_a": 0.5}, + ) + + # Then the treatment is compared to control and the raw stats are kept + assert summary.metrics[0].variants == { + "control": control, + "variant_a": treatment, + } + inference = summary.metrics[0].inference["variant_a"] + assert inference is not None + assert inference.lift == pytest.approx(0.2) + assert inference.chance_to_win == pytest.approx(0.90379, abs=1e-4) + + +def test_build_results_summary__below_identity_floor__inference_none() -> None: + # Given arms below the minimum identities per variant + arm = VariantStats(n=40, sum=4.0, sum_squares=4.0) + aggregates = _aggregates( + [_spec(metric_id=7)], + {"control": 40, "variant_a": 40}, + {7: {"control": arm, "variant_a": arm}}, + ) + + # When + summary = services.build_results_summary(aggregates, expected_shares={}) + + # Then inference is withheld + assert summary.metrics[0].inference["variant_a"] is None + + +def test_build_results_summary__occurrence_below_conversion_floor__inference_none() -> ( + None +): + # Given enough identities but too few conversions on an occurrence metric + control = VariantStats(n=100, sum=10.0, sum_squares=10.0) + treatment = VariantStats(n=100, sum=3.0, sum_squares=3.0) + aggregates = _aggregates( + [_spec(metric_id=7, aggregation=MetricAggregation.OCCURRENCE)], + {"control": 100, "variant_a": 100}, + {7: {"control": control, "variant_a": treatment}}, + ) + + # When + summary = services.build_results_summary(aggregates, expected_shares={}) + + # Then inference is withheld + assert summary.metrics[0].inference["variant_a"] is None + + +def test_build_results_summary__lower_is_better__flips_chance_to_win() -> None: + # Given a value metric where a fall is the win + control = VariantStats(n=1000, sum=100.0, sum_squares=100.0) + treatment = VariantStats(n=1000, sum=120.0, sum_squares=120.0) + aggregates = _aggregates( + [_spec(metric_id=7, aggregation=MetricAggregation.SUM, lower_is_better=True)], + {"control": 1000, "variant_a": 1000}, + {7: {"control": control, "variant_a": treatment}}, + ) + + # When + summary = services.build_results_summary(aggregates, expected_shares={}) + + # Then the rise counts against the treatment + inference = summary.metrics[0].inference["variant_a"] + assert inference is not None + assert inference.lift == pytest.approx(0.2) + assert inference.chance_to_win == pytest.approx(1 - 0.90379, abs=1e-4) + + +def test_build_results_summary__zero_control_mean__inference_none() -> None: + # Given a control with no value: the relative lift is undefined + control = VariantStats(n=100, sum=0.0, sum_squares=0.0) + treatment = VariantStats(n=100, sum=50.0, sum_squares=50.0) + aggregates = _aggregates( + [_spec(metric_id=7, aggregation=MetricAggregation.COUNT)], + {"control": 100, "variant_a": 100}, + {7: {"control": control, "variant_a": treatment}}, + ) + + # When + summary = services.build_results_summary(aggregates, expected_shares={}) + + # Then inference is withheld by the kernel's own guard + assert summary.metrics[0].inference["variant_a"] is None + + +def test_build_results_summary__no_control_variant__inference_none() -> None: + # Given a metric with stats for a treatment but no control + treatment = VariantStats(n=1000, sum=120.0, sum_squares=120.0) + aggregates = _aggregates( + [_spec(metric_id=7)], + {"variant_a": 1000}, + {7: {"variant_a": treatment}}, + ) + + # When + summary = services.build_results_summary(aggregates, expected_shares={}) + + # Then inference is withheld + assert summary.metrics[0].inference["variant_a"] is None + + +def test_build_results_summary__balanced_traffic__srm_reports_no_mismatch() -> None: + # Given a balanced split above the SRM gate + aggregates = _aggregates([], {"control": 5000, "variant_a": 5000}, {}) + + # When + summary = services.build_results_summary( + aggregates, + expected_shares={"control": 0.5, "variant_a": 0.5}, + ) + + # Then + assert summary.srm_p_value == pytest.approx(1.0) + assert summary.metrics == [] + + +def test_build_results_summary__imbalanced_traffic__srm_below_threshold() -> None: + # Given a 60/40 split against an expected 50/50 + aggregates = _aggregates([], {"control": 6000, "variant_a": 4000}, {}) + + # When + summary = services.build_results_summary( + aggregates, + expected_shares={"control": 0.5, "variant_a": 0.5}, + ) + + # Then the mismatch is flagged + assert summary.srm_p_value is not None + assert summary.srm_p_value < 0.001 + + +@pytest.mark.parametrize( + "exposure_counts, expected_shares", + [ + ({"control": 40, "variant_a": 40}, {"control": 0.5, "variant_a": 0.5}), + ({"control": 5000, "variant_a": 5000}, {}), + ], + ids=["below_gate", "no_expected_shares"], +) +def test_build_results_summary__srm_not_computable__srm_none( + exposure_counts: dict[str, int], + expected_shares: dict[str, float], +) -> None: + # Given too little traffic, or no configured split to compare against + aggregates = _aggregates([], exposure_counts, {}) + + # When + summary = services.build_results_summary( + aggregates, expected_shares=expected_shares + ) + + # Then SRM is not reported + assert summary.srm_p_value is None + + +def test_build_results_summary__computed__serialises_to_wire_shape() -> None: + # Given a computed summary + control = VariantStats(n=1000, sum=100.0, sum_squares=100.0) + treatment = VariantStats(n=1000, sum=120.0, sum_squares=120.0) + aggregates = _aggregates( + [_spec(metric_id=7)], + {"control": 1000, "variant_a": 1000}, + {7: {"control": control, "variant_a": treatment}}, + ) + summary = services.build_results_summary( + aggregates, + expected_shares={"control": 0.5, "variant_a": 0.5}, + ) + + # When + payload = asdict(summary) + + # Then the payload nests raw stats and per-treatment inference + assert payload["srm_p_value"] == pytest.approx(1.0) + assert payload["metrics"][0]["metric_id"] == 7 + assert payload["metrics"][0]["variants"]["control"] == { + "n": 1000, + "sum": 100.0, + "sum_squares": 100.0, + } + assert set(payload["metrics"][0]["inference"]["variant_a"]) == { + "lift", + "ci_low", + "ci_high", + "chance_to_win", + }