Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions api/experimentation/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,11 @@
variant key."""

EXPOSURE_HOURLY_BUCKET_MAX_WINDOW = timedelta(hours=72)

CONTROL_VARIANT_KEY = "control"

# Below these per-variant floors a metric shows "collecting data" rather than
# inference; sample-ratio is only checked once there is enough traffic to judge.
RESULTS_MIN_IDENTITIES_PER_VARIANT = 50
RESULTS_MIN_CONVERSIONS_PER_VARIANT = 5
SRM_MIN_TOTAL_IDENTITIES = 100
33 changes: 33 additions & 0 deletions api/experimentation/dataclasses.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from dataclasses import dataclass
from datetime import datetime

from experimentation.stats import Inference, VariantStats
from experimentation.types import ExposureGranularity


Expand Down Expand Up @@ -34,3 +35,35 @@ class ExposuresTimeseries:
class ExposuresSummary:
excluded_identities: int
timeseries: ExposuresTimeseries


@dataclass(frozen=True)
class MetricSpec:
metric_id: int
event: str
aggregation: str
lower_is_better: bool


@dataclass(frozen=True)
class ResultsAggregates:
"""Sufficient statistics gathered from the warehouse for one experiment:
the specs they were computed from, per-variant identity counts, and per
metric the per-variant ``VariantStats``. Bundled so the keys can't drift."""

specs: list[MetricSpec]
exposure_counts: dict[str, int]
metric_stats: dict[int, dict[str, VariantStats]]


@dataclass(frozen=True)
class MetricResult:
metric_id: int
variants: dict[str, VariantStats]
inference: dict[str, Inference | None]


@dataclass(frozen=True)
class ResultsSummary:
srm_p_value: float | None
metrics: list[MetricResult]
208 changes: 208 additions & 0 deletions api/experimentation/services.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import typing
from dataclasses import replace
from functools import lru_cache

import structlog
Expand All @@ -12,24 +13,39 @@
from audit.models import AuditLog
from audit.related_object_type import RelatedObjectType
from experimentation.constants import (
CONTROL_VARIANT_KEY,
EXPERIMENT_FLAG,
EXPOSURE_EVENT_NAME,
EXPOSURE_HOURLY_BUCKET_MAX_WINDOW,
RESULTS_MIN_CONVERSIONS_PER_VARIANT,
RESULTS_MIN_IDENTITIES_PER_VARIANT,
SRM_MIN_TOTAL_IDENTITIES,
WAREHOUSE_CONNECTION_FLAG,
)
from experimentation.dataclasses import (
ExposureBucket,
ExposuresSummary,
ExposuresTimeseries,
ExposuresTimeseriesPoint,
MetricResult,
MetricSpec,
ResultsAggregates,
ResultsSummary,
WarehouseEventStats,
)
from experimentation.models import (
VALID_STATUS_TRANSITIONS,
ExperimentStatus,
MetricAggregation,
WarehouseConnectionStatus,
WarehouseType,
)
from experimentation.stats import (
Inference,
VariantStats,
compare_to_control,
srm_p_value,
)
from integrations.flagsmith.client import get_openfeature_client

if typing.TYPE_CHECKING:
Expand Down Expand Up @@ -227,6 +243,198 @@ def get_exposure_buckets(
]


# First-exposed identities per variant (same dedup/quarantine as the exposures
# query), each carrying the per-metric value used to derive sufficient
# statistics. Metric events are attributed from first exposure onwards.
_RESULTS_EXPOSURES_CTE = """
WITH exposures AS (
SELECT
identifier,
if(uniqExact(value) > 1, '', any(value)) AS variant,
uniqExact(value) > 1 AS quarantined,
min(timestamp) AS first_exposure
FROM events
WHERE environment_key = %(environment_key)s
AND event = %(exposure_event)s
AND feature_name = %(feature_name)s
AND timestamp >= %(window_start)s
AND timestamp < %(window_end)s
GROUP BY identifier
)"""


def _metric_value_expression(index: int, aggregation: str) -> str:
# Post-exposure attribution lives here, not in the JOIN ON: ClickHouse
# rejects an ON clause mixing left and right columns in an inequality.
condition = (
f"m.event = %(metric_{index}_event)s AND m.timestamp >= e.first_exposure"
)
value = "toFloat64OrZero(m.value)"
if aggregation == MetricAggregation.OCCURRENCE:
return f"countIf({condition}) > 0"
if aggregation == MetricAggregation.COUNT:
return f"countIf({condition})"
if aggregation == MetricAggregation.SUM:
return f"sumIf({value}, {condition})"
if aggregation == MetricAggregation.MEAN:
# per-identity average, zero when the identity has no matching events
return f"if(countIf({condition}) > 0, avgIf({value}, {condition}), 0)"
raise ValueError(f"Unsupported metric aggregation: {aggregation}")


def _build_results_query(specs: Sequence[MetricSpec]) -> str:
if not specs:
return (
_RESULTS_EXPOSURES_CTE
+ "\nSELECT variant, count() AS n"
+ "\nFROM exposures\nWHERE quarantined = 0\nGROUP BY variant"
)
value_expressions = ",\n ".join(
f"{_metric_value_expression(i, spec.aggregation)} AS m{i}"
for i, spec in enumerate(specs)
)
outer_aggregates = ",\n ".join(
f"sum(m{i}) AS m{i}_sum, sum(m{i} * m{i}) AS m{i}_sum_squares"
for i in range(len(specs))
)
return (
_RESULTS_EXPOSURES_CTE
+ f""",
unit_values AS (
SELECT
e.variant AS variant,
{value_expressions}
FROM exposures AS e
LEFT JOIN events AS m
ON m.identifier = e.identifier
AND m.environment_key = %(environment_key)s
AND m.event IN %(metric_events)s
AND m.timestamp < %(window_end)s
WHERE e.quarantined = 0
GROUP BY e.identifier, e.variant
)
SELECT variant, count() AS n,
{outer_aggregates}
FROM unit_values
GROUP BY variant"""
)


def get_metric_variant_stats(
*,
environment_key: str,
feature_name: str,
window_start: datetime,
window_end: datetime,
specs: Sequence[MetricSpec],
) -> ResultsAggregates:
"""Run the warehouse query, returning per-variant identity counts and, per
metric, per-variant sufficient statistics."""
params: dict[str, object] = {
"environment_key": environment_key,
"exposure_event": EXPOSURE_EVENT_NAME,
"feature_name": feature_name,
"window_start": window_start,
"window_end": window_end,
}
if specs:
params["metric_events"] = [spec.event for spec in specs]
for index, spec in enumerate(specs):
params[f"metric_{index}_event"] = spec.event

rows = _get_clickhouse_client().execute(_build_results_query(specs), params)

exposure_counts: dict[str, int] = {}
metric_stats: dict[int, dict[str, VariantStats]] = {
spec.metric_id: {} for spec in specs
}
for row in rows:
# Columns are emitted in query order: variant, n, then a (sum,
# sum_squares) pair per spec — consumed positionally in that order.
columns = iter(row)
variant = next(columns)
n = int(next(columns))
exposure_counts[variant] = n
for spec in specs:
metric_stats[spec.metric_id][variant] = VariantStats(
n=n,
sum=float(next(columns)),
sum_squares=float(next(columns)),
)
return ResultsAggregates(
specs=list(specs),
exposure_counts=exposure_counts,
metric_stats=metric_stats,
)


def build_results_summary(
aggregates: ResultsAggregates,
*,
expected_shares: dict[str, float],
) -> ResultsSummary:
exposure_counts = aggregates.exposure_counts
total = sum(exposure_counts.values())
if expected_shares and total >= SRM_MIN_TOTAL_IDENTITIES:
srm = srm_p_value(
[exposure_counts.get(variant, 0) for variant in expected_shares],
list(expected_shares.values()),
)
else:
srm = None
return ResultsSummary(
srm_p_value=srm,
metrics=[
MetricResult(
metric_id=spec.metric_id,
variants=aggregates.metric_stats.get(spec.metric_id, {}),
inference=_metric_inference(
spec, aggregates.metric_stats.get(spec.metric_id, {})
),
)
for spec in aggregates.specs
],
)


def _metric_inference(
spec: MetricSpec,
variants: dict[str, VariantStats],
) -> dict[str, Inference | None]:
control = variants.get(CONTROL_VARIANT_KEY)
return {
variant_key: _infer_treatment(spec, control, treatment)
for variant_key, treatment in variants.items()
if variant_key != CONTROL_VARIANT_KEY
}


def _infer_treatment(
spec: MetricSpec,
control: VariantStats | None,
treatment: VariantStats,
) -> Inference | None:
# Product floor for showing a result at all; compare_to_control applies its
# own independent guards (e.g. zero control mean) on top of this.
if (
control is None
or control.n < RESULTS_MIN_IDENTITIES_PER_VARIANT
or treatment.n < RESULTS_MIN_IDENTITIES_PER_VARIANT
):
return None
if spec.aggregation == MetricAggregation.OCCURRENCE and (
control.sum < RESULTS_MIN_CONVERSIONS_PER_VARIANT
or treatment.sum < RESULTS_MIN_CONVERSIONS_PER_VARIANT
):
return None
inference = compare_to_control(control, treatment)
if inference is not None and spec.lower_is_better:
# "Winning" means moving the metric the good way; for a lower-is-better
# metric that's a fall, so the chance of winning is the chance lift < 0.
inference = replace(inference, chance_to_win=1.0 - inference.chance_to_win)
return inference


def _resolve_audit_log_author(
user: FFAdminUser,
) -> dict[str, int | None]:
Expand Down
Loading
Loading