diff --git a/packages/elf-service/src/search/ranking/retrieval.rs b/packages/elf-service/src/search/ranking/retrieval.rs index d8a22f82..78a1b1e0 100644 --- a/packages/elf-service/src/search/ranking/retrieval.rs +++ b/packages/elf-service/src/search/ranking/retrieval.rs @@ -1,5 +1,8 @@ +mod merge; mod payload; +pub use self::merge::merge_retrieval_candidates; + use std::{ cmp::Ordering, collections::{HashMap, HashSet}, @@ -8,10 +11,7 @@ use std::{ use qdrant_client::qdrant::ScoredPoint; use uuid::Uuid; -use crate::search::{ - ChunkCandidate, ChunkRow, NoteMeta, RetrievalSourceCandidates, RetrievalSourceKind, - ranking::policy::ResolvedRetrievalSourcesPolicy, -}; +use crate::search::{ChunkCandidate, ChunkRow, NoteMeta}; pub fn collect_chunk_candidates( points: &[ScoredPoint], @@ -71,157 +71,6 @@ pub fn collect_chunk_candidates( out } -pub fn retrieval_source_weight( - policy: &ResolvedRetrievalSourcesPolicy, - source: RetrievalSourceKind, -) -> f32 { - match source { - RetrievalSourceKind::Fusion => policy.fusion_weight, - RetrievalSourceKind::StructuredField => policy.structured_field_weight, - RetrievalSourceKind::Recursive => policy.recursive_weight, - } -} - -pub fn retrieval_source_priority( - policy: &ResolvedRetrievalSourcesPolicy, - source: RetrievalSourceKind, -) -> u32 { - match source { - RetrievalSourceKind::StructuredField => policy.structured_field_priority, - RetrievalSourceKind::Fusion => policy.fusion_priority, - RetrievalSourceKind::Recursive => policy.recursive_priority, - } -} - -pub fn retrieval_source_kind_order(source: RetrievalSourceKind) -> u8 { - match source { - RetrievalSourceKind::StructuredField => 0, - RetrievalSourceKind::Fusion => 1, - RetrievalSourceKind::Recursive => 2, - } -} - -pub fn merge_retrieval_candidates( - sources: Vec, - policy: &ResolvedRetrievalSourcesPolicy, - candidate_k: u32, -) -> Vec { - if candidate_k == 0 { - return Vec::new(); - } - - #[derive(Debug)] - struct MergedRetrievalCandidate { - candidate: ChunkCandidate, - source_ranks: HashMap, - combined_score: f32, - } - - let mut by_chunk: HashMap = HashMap::new(); - let mut source_totals: HashMap = HashMap::new(); - - for source in sources { - let mut seen_for_source = HashSet::new(); - - for candidate in &source.candidates { - if seen_for_source.insert(candidate.chunk_id) { - *source_totals.entry(source.source).or_insert(0) += 1; - } - } - for candidate in source.candidates { - let chunk_id = candidate.chunk_id; - let rank = candidate.retrieval_rank; - - match by_chunk.get_mut(&chunk_id) { - Some(existing) => { - let entry = existing.source_ranks.entry(source.source).or_insert(rank); - - *entry = (*entry).min(rank); - }, - None => { - let mut source_ranks = HashMap::new(); - - source_ranks.insert(source.source, rank); - by_chunk.insert( - chunk_id, - MergedRetrievalCandidate { candidate, source_ranks, combined_score: 0.0 }, - ); - }, - } - } - } - - if by_chunk.is_empty() { - return Vec::new(); - } - - for total in source_totals.values_mut() { - *total = (*total).max(1); - } - - let mut source_order: Vec = source_totals.keys().copied().collect(); - - source_order.sort_by(|left, right| { - retrieval_source_priority(policy, *left) - .cmp(&retrieval_source_priority(policy, *right)) - .then_with(|| { - retrieval_source_kind_order(*left).cmp(&retrieval_source_kind_order(*right)) - }) - }); - - let mut merged: Vec = by_chunk.into_values().collect(); - - for candidate in &mut merged { - let mut combined_score = 0.0_f32; - - for (source, rank) in &candidate.source_ranks { - let total = source_totals.get(source).copied().unwrap_or(1); - - combined_score += - retrieval_source_weight(policy, *source) * rank_normalize(*rank, total); - } - - candidate.combined_score = combined_score; - } - - merged.sort_by(|left, right| { - cmp_f32_desc(left.combined_score, right.combined_score) - .then_with(|| right.source_ranks.len().cmp(&left.source_ranks.len())) - .then_with(|| { - for source in &source_order { - let lhs = left.source_ranks.get(source).copied(); - let rhs = right.source_ranks.get(source).copied(); - let ord = rank_asc(lhs, rhs); - - if ord != Ordering::Equal { - return ord; - } - } - - Ordering::Equal - }) - .then_with(|| left.candidate.chunk_id.cmp(&right.candidate.chunk_id)) - }); - - let mut out = Vec::new(); - - for (idx, mut candidate) in merged.into_iter().take(candidate_k as usize).enumerate() { - candidate.candidate.retrieval_rank = idx as u32 + 1; - candidate.candidate.retrieval_score = Some(candidate.combined_score); - - out.push(candidate.candidate); - } - - out -} - -pub fn rank_asc(left: Option, right: Option) -> Ordering { - let lhs = left.unwrap_or(u32::MAX); - let rhs = right.unwrap_or(u32::MAX); - - lhs.cmp(&rhs) -} - pub fn candidate_matches_note( note_meta: &HashMap, candidate: &ChunkCandidate, diff --git a/packages/elf-service/src/search/ranking/retrieval/merge.rs b/packages/elf-service/src/search/ranking/retrieval/merge.rs new file mode 100644 index 00000000..9e94f3ac --- /dev/null +++ b/packages/elf-service/src/search/ranking/retrieval/merge.rs @@ -0,0 +1,162 @@ +use std::{ + cmp::Ordering, + collections::{HashMap, HashSet}, +}; + +use uuid::Uuid; + +use crate::search::{ + ChunkCandidate, RetrievalSourceCandidates, RetrievalSourceKind, + ranking::policy::ResolvedRetrievalSourcesPolicy, +}; + +#[derive(Debug)] +struct MergedRetrievalCandidate { + candidate: ChunkCandidate, + source_ranks: HashMap, + combined_score: f32, +} + +pub fn merge_retrieval_candidates( + sources: Vec, + policy: &ResolvedRetrievalSourcesPolicy, + candidate_k: u32, +) -> Vec { + if candidate_k == 0 { + return Vec::new(); + } + + let mut by_chunk: HashMap = HashMap::new(); + let mut source_totals: HashMap = HashMap::new(); + + for source in sources { + let mut seen_for_source = HashSet::new(); + + for candidate in &source.candidates { + if seen_for_source.insert(candidate.chunk_id) { + *source_totals.entry(source.source).or_insert(0) += 1; + } + } + for candidate in source.candidates { + let chunk_id = candidate.chunk_id; + let rank = candidate.retrieval_rank; + + match by_chunk.get_mut(&chunk_id) { + Some(existing) => { + let entry = existing.source_ranks.entry(source.source).or_insert(rank); + + *entry = (*entry).min(rank); + }, + None => { + let mut source_ranks = HashMap::new(); + + source_ranks.insert(source.source, rank); + by_chunk.insert( + chunk_id, + MergedRetrievalCandidate { candidate, source_ranks, combined_score: 0.0 }, + ); + }, + } + } + } + + if by_chunk.is_empty() { + return Vec::new(); + } + + for total in source_totals.values_mut() { + *total = (*total).max(1); + } + + let mut source_order: Vec = source_totals.keys().copied().collect(); + + source_order.sort_by(|left, right| { + retrieval_source_priority(policy, *left) + .cmp(&retrieval_source_priority(policy, *right)) + .then_with(|| { + retrieval_source_kind_order(*left).cmp(&retrieval_source_kind_order(*right)) + }) + }); + + let mut merged: Vec = by_chunk.into_values().collect(); + + for candidate in &mut merged { + let mut combined_score = 0.0_f32; + + for (source, rank) in &candidate.source_ranks { + let total = source_totals.get(source).copied().unwrap_or(1); + + combined_score += + retrieval_source_weight(policy, *source) * super::rank_normalize(*rank, total); + } + + candidate.combined_score = combined_score; + } + + merged.sort_by(|left, right| { + super::cmp_f32_desc(left.combined_score, right.combined_score) + .then_with(|| right.source_ranks.len().cmp(&left.source_ranks.len())) + .then_with(|| { + for source in &source_order { + let lhs = left.source_ranks.get(source).copied(); + let rhs = right.source_ranks.get(source).copied(); + let ord = rank_asc(lhs, rhs); + + if ord != Ordering::Equal { + return ord; + } + } + + Ordering::Equal + }) + .then_with(|| left.candidate.chunk_id.cmp(&right.candidate.chunk_id)) + }); + + let mut out = Vec::new(); + + for (idx, mut candidate) in merged.into_iter().take(candidate_k as usize).enumerate() { + candidate.candidate.retrieval_rank = idx as u32 + 1; + candidate.candidate.retrieval_score = Some(candidate.combined_score); + + out.push(candidate.candidate); + } + + out +} + +fn retrieval_source_weight( + policy: &ResolvedRetrievalSourcesPolicy, + source: RetrievalSourceKind, +) -> f32 { + match source { + RetrievalSourceKind::Fusion => policy.fusion_weight, + RetrievalSourceKind::StructuredField => policy.structured_field_weight, + RetrievalSourceKind::Recursive => policy.recursive_weight, + } +} + +fn retrieval_source_priority( + policy: &ResolvedRetrievalSourcesPolicy, + source: RetrievalSourceKind, +) -> u32 { + match source { + RetrievalSourceKind::StructuredField => policy.structured_field_priority, + RetrievalSourceKind::Fusion => policy.fusion_priority, + RetrievalSourceKind::Recursive => policy.recursive_priority, + } +} + +fn retrieval_source_kind_order(source: RetrievalSourceKind) -> u8 { + match source { + RetrievalSourceKind::StructuredField => 0, + RetrievalSourceKind::Fusion => 1, + RetrievalSourceKind::Recursive => 2, + } +} + +fn rank_asc(left: Option, right: Option) -> Ordering { + let lhs = left.unwrap_or(u32::MAX); + let rhs = right.unwrap_or(u32::MAX); + + lhs.cmp(&rhs) +} diff --git a/packages/elf-service/src/search/tests.rs b/packages/elf-service/src/search/tests.rs index 7d0faa35..6420b4d3 100644 --- a/packages/elf-service/src/search/tests.rs +++ b/packages/elf-service/src/search/tests.rs @@ -227,6 +227,43 @@ fn merge_retrieval_candidates_prefers_dual_source_signal_on_tie() { assert_eq!(first.chunk_id, shared_chunk_id); } +#[test] +fn merge_retrieval_candidates_uses_configured_source_priority_on_tie() { + let fusion_chunk_id = Uuid::from_u128(1); + let recursive_chunk_id = Uuid::from_u128(2); + let mut fusion_candidate = test_chunk_candidate(Uuid::new_v4(), 1); + let mut recursive_candidate = test_chunk_candidate(Uuid::new_v4(), 1); + + fusion_candidate.chunk_id = fusion_chunk_id; + recursive_candidate.chunk_id = recursive_chunk_id; + + let policy = ranking::ResolvedRetrievalSourcesPolicy { + fusion_weight: 1.0, + structured_field_weight: 0.0, + recursive_weight: 1.0, + fusion_priority: 10, + structured_field_priority: 20, + recursive_priority: 0, + }; + let merged = ranking::merge_retrieval_candidates( + vec![ + RetrievalSourceCandidates { + source: RetrievalSourceKind::Fusion, + candidates: vec![fusion_candidate], + }, + RetrievalSourceCandidates { + source: RetrievalSourceKind::Recursive, + candidates: vec![recursive_candidate], + }, + ], + &policy, + 2, + ); + + assert_eq!(merged[0].chunk_id, recursive_chunk_id); + assert_eq!(merged[1].chunk_id, fusion_chunk_id); +} + #[test] fn retrieval_weight_for_rank_uses_first_matching_segment_or_last() { let segments = vec![