Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
140 changes: 13 additions & 127 deletions apps/elf-eval/src/app/trace_compare.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
use std::{collections::HashMap, path::Path};
mod analysis;

use std::path::Path;

use color_eyre::{Result, eyre};
use time::format_description::well_known::Rfc3339;
Expand All @@ -9,13 +11,12 @@ use crate::app::{
metrics::{self},
types::{
TraceCompareCandidateRow, TraceCompareChurn, TraceCompareGuardrails, TraceCompareOutput,
TraceComparePolicies, TraceComparePolicy, TraceCompareRegressionAttribution,
TraceCompareStageDelta, TraceCompareStageRow, TraceCompareSummary, TraceCompareTrace,
TraceCompareTraceRow, TraceCompareVariant,
TraceComparePolicies, TraceComparePolicy, TraceCompareStageRow, TraceCompareSummary,
TraceCompareTrace, TraceCompareTraceRow, TraceCompareVariant,
},
};
use elf_config::Config;
use elf_service::search::{self, TraceReplayCandidate, TraceReplayContext};
use elf_service::search::{self, TraceReplayContext};
use elf_storage::db::Db;

pub(super) async fn trace_compare(
Expand Down Expand Up @@ -85,124 +86,6 @@ pub(super) async fn trace_compare(
})
}

/// Converts stored candidate rows into replayable candidates.
///
/// For each row the `candidate_snapshot` JSON is tried first. It is accepted only when it
/// deserializes into a [`TraceReplayCandidate`] whose `note_id` and `chunk_id` are both
/// non-nil. Otherwise the candidate is rebuilt from the row's own columns, with
/// `retrieval_score` and every `diversity_*` field left as `None`, and with a
/// `retrieval_rank` that does not fit in `u32` replaced by `0`.
fn decode_trace_replay_candidates(
    rows: Vec<TraceCompareCandidateRow>,
) -> Vec<TraceReplayCandidate> {
    rows.into_iter()
        .map(|row| {
            // The row is owned, so the snapshot is moved into the deserializer rather than
            // deep-cloned; nothing reads `row.candidate_snapshot` afterwards.
            let from_snapshot =
                serde_json::from_value::<TraceReplayCandidate>(row.candidate_snapshot)
                    .ok()
                    .filter(|candidate| {
                        candidate.note_id != Uuid::nil() && candidate.chunk_id != Uuid::nil()
                    });

            // A `match` (not a closure) so the remaining fields of the partially moved row
            // can be used without relying on edition-specific closure capture rules.
            match from_snapshot {
                Some(candidate) => candidate,
                None => TraceReplayCandidate {
                    note_id: row.note_id,
                    chunk_id: row.chunk_id,
                    chunk_index: row.chunk_index,
                    snippet: row.snippet,
                    retrieval_rank: u32::try_from(row.retrieval_rank).unwrap_or(0),
                    retrieval_score: None,
                    rerank_score: row.rerank_score,
                    note_scope: row.note_scope,
                    note_importance: row.note_importance,
                    note_updated_at: row.note_updated_at,
                    note_hit_count: row.note_hit_count,
                    note_last_hit_at: row.note_last_hit_at,
                    diversity_selected: None,
                    diversity_selected_rank: None,
                    diversity_selected_reason: None,
                    diversity_skipped_reason: None,
                    diversity_nearest_selected_note_id: None,
                    diversity_similarity: None,
                    diversity_mmr_score: None,
                    diversity_missing_embedding: None,
                },
            }
        })
        .collect()
}

fn build_trace_compare_stage_deltas(
stage_rows: &[TraceCompareStageRow],
a_selected_count: u32,
b_selected_count: u32,
) -> Vec<TraceCompareStageDelta> {
if stage_rows.is_empty() {
return vec![TraceCompareStageDelta {
stage_order: 1,
stage_name: "selection.final".to_string(),
baseline_item_count: 0,
a_item_count: a_selected_count,
b_item_count: b_selected_count,
item_count_delta: b_selected_count as i64 - a_selected_count as i64,
baseline_stats: None,
}];
}

let mut out = Vec::with_capacity(stage_rows.len());

for row in stage_rows {
let baseline_item_count = row.item_count.max(0) as u32;
let (a_item_count, b_item_count) = if row.stage_name == "selection.final" {
(a_selected_count, b_selected_count)
} else {
(baseline_item_count, baseline_item_count)
};
let baseline_stats = row.stage_payload.get("stats").cloned();

out.push(TraceCompareStageDelta {
stage_order: row.stage_order.max(0) as u32,
stage_name: row.stage_name.clone(),
baseline_item_count,
a_item_count,
b_item_count,
item_count_delta: b_item_count as i64 - a_item_count as i64,
baseline_stats,
});
}

out
}

/// Picks the stage that a comparison outcome is attributed to, with supporting evidence.
///
/// The checks run in priority order:
/// 1. A negative `retrieval_top3_retention_delta` is attributed to `selection.final`; the
///    evidence quotes both retention values and the baseline item count of the
///    `recall.candidates` stage (0 when that stage is absent).
/// 2. Otherwise, a positive set churn or positional churn is attributed to `rerank.score`.
/// 3. Otherwise the result is `not_applicable`.
fn build_trace_compare_regression_attribution(
    churn: &TraceCompareChurn,
    guardrails: &TraceCompareGuardrails,
    stage_deltas: &[TraceCompareStageDelta],
) -> TraceCompareRegressionAttribution {
    // Index the stages by name; when a name repeats, the later entry wins.
    let mut stage_by_name: HashMap<&str, &TraceCompareStageDelta> =
        HashMap::with_capacity(stage_deltas.len());
    for delta in stage_deltas {
        stage_by_name.insert(delta.stage_name.as_str(), delta);
    }

    let retention_regressed = guardrails.retrieval_top3_retention_delta < 0.0;
    let ranking_churned = churn.set_churn_at_k > 0.0 || churn.positional_churn_at_k > 0.0;

    let (primary_stage, evidence) = if retention_regressed {
        let recall_count = stage_by_name
            .get("recall.candidates")
            .map_or(0, |stage| stage.baseline_item_count);

        (
            "selection.final",
            format!(
                "retrieval_top3_retention dropped by {:.4} (a={:.4}, b={:.4}); recall baseline item_count={recall_count}",
                guardrails.retrieval_top3_retention_delta,
                guardrails.a_retrieval_top3_retention,
                guardrails.b_retrieval_top3_retention
            ),
        )
    } else if ranking_churned {
        (
            "rerank.score",
            format!(
                "top-k churn changed without retrieval-top3 regression (set_churn_at_k={:.4}, positional_churn_at_k={:.4})",
                churn.set_churn_at_k, churn.positional_churn_at_k
            ),
        )
    } else {
        ("not_applicable", "No regression signal detected.".to_string())
    };

    TraceCompareRegressionAttribution { primary_stage: primary_stage.to_string(), evidence }
}

async fn compare_trace_id(
db: &Db,
config_a: &Config,
Expand All @@ -226,7 +109,7 @@ async fn compare_trace_id(
.created_at
.format(&Rfc3339)
.map_err(|err| eyre::eyre!("Failed to format trace created_at: {err}"))?;
let candidates = decode_trace_replay_candidates(candidate_rows);
let candidates = analysis::decode_trace_replay_candidates(candidate_rows);
let top_k = args.top_k.unwrap_or(context.top_k).max(1);
let items_a =
search::replay_ranking_from_candidates(config_a, &context, None, &candidates, top_k)
Expand All @@ -251,13 +134,16 @@ async fn compare_trace_id(
b_retrieval_top3_retention: b_retention,
retrieval_top3_retention_delta: b_retention - a_retention,
};
let stage_deltas = build_trace_compare_stage_deltas(
let stage_deltas = analysis::build_trace_compare_stage_deltas(
stage_rows.as_slice(),
items_a.len() as u32,
items_b.len() as u32,
);
let regression_attribution =
build_trace_compare_regression_attribution(&churn, &guardrails, stage_deltas.as_slice());
let regression_attribution = analysis::build_trace_compare_regression_attribution(
&churn,
&guardrails,
stage_deltas.as_slice(),
);

Ok(TraceCompareTrace {
trace_id: context.trace_id,
Expand Down
Loading