Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
159 changes: 4 additions & 155 deletions packages/elf-service/src/search/ranking/retrieval.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
mod merge;
mod payload;

pub use self::merge::merge_retrieval_candidates;

use std::{
cmp::Ordering,
collections::{HashMap, HashSet},
Expand All @@ -8,10 +11,7 @@ use std::{
use qdrant_client::qdrant::ScoredPoint;
use uuid::Uuid;

use crate::search::{
ChunkCandidate, ChunkRow, NoteMeta, RetrievalSourceCandidates, RetrievalSourceKind,
ranking::policy::ResolvedRetrievalSourcesPolicy,
};
use crate::search::{ChunkCandidate, ChunkRow, NoteMeta};

pub fn collect_chunk_candidates(
points: &[ScoredPoint],
Expand Down Expand Up @@ -71,157 +71,6 @@ pub fn collect_chunk_candidates(
out
}

/// Returns the weight the policy assigns to `source`.
///
/// The merge step multiplies this weight with a source's normalized rank
/// when accumulating a candidate's combined score.
pub fn retrieval_source_weight(
    policy: &ResolvedRetrievalSourcesPolicy,
    source: RetrievalSourceKind,
) -> f32 {
    // Exhaustive on purpose: adding a new source kind must fail to compile here.
    let weight = match source {
        RetrievalSourceKind::StructuredField => policy.structured_field_weight,
        RetrievalSourceKind::Recursive => policy.recursive_weight,
        RetrievalSourceKind::Fusion => policy.fusion_weight,
    };

    weight
}

/// Returns the priority the policy assigns to `source`.
///
/// The merge step sorts sources by this value in ascending order, so a
/// smaller number is consulted earlier when breaking ties between candidates.
pub fn retrieval_source_priority(
    policy: &ResolvedRetrievalSourcesPolicy,
    source: RetrievalSourceKind,
) -> u32 {
    // Exhaustive on purpose: adding a new source kind must fail to compile here.
    let priority = match source {
        RetrievalSourceKind::Fusion => policy.fusion_priority,
        RetrievalSourceKind::Recursive => policy.recursive_priority,
        RetrievalSourceKind::StructuredField => policy.structured_field_priority,
    };

    priority
}

/// Returns a fixed ordinal for `source`, independent of any policy.
///
/// The merge step uses it as a secondary key so that two sources with the
/// same configured priority still sort in a stable, well-defined order:
/// structured-field first, then fusion, then recursive.
pub fn retrieval_source_kind_order(source: RetrievalSourceKind) -> u8 {
    let ordinal = match source {
        RetrievalSourceKind::Recursive => 2,
        RetrievalSourceKind::Fusion => 1,
        RetrievalSourceKind::StructuredField => 0,
    };

    ordinal
}

/// Merges the candidate lists of several retrieval sources into one ranked
/// list of at most `candidate_k` chunks.
///
/// Candidates are de-duplicated by `chunk_id`; the first candidate seen for a
/// chunk is the one returned, and for every source the best (smallest) rank at
/// which the chunk appeared is remembered. Each chunk's combined score is the
/// sum, over the sources it appeared in, of
/// `retrieval_source_weight(policy, source) * rank_normalize(rank, total)`,
/// where `total` is the number of distinct chunk ids counted for that source.
///
/// The result is ordered by:
/// 1. combined score, as compared by `cmp_f32_desc`;
/// 2. number of contributing sources, more first;
/// 3. rank within each source, visiting sources by ascending configured
///    priority (ties between priorities use `retrieval_source_kind_order`),
///    where a missing rank sorts last;
/// 4. `chunk_id`, ascending.
///
/// Returned candidates have `retrieval_rank` rewritten to their 1-based
/// position in the merged list and `retrieval_score` set to the combined score.
/// Returns an empty vector when `candidate_k` is 0 or no candidates are given.
pub fn merge_retrieval_candidates(
    sources: Vec<RetrievalSourceCandidates>,
    policy: &ResolvedRetrievalSourcesPolicy,
    candidate_k: u32,
) -> Vec<ChunkCandidate> {
    if candidate_k == 0 {
        return Vec::new();
    }

    /// Accumulator for one chunk while merging.
    #[derive(Debug)]
    struct MergedRetrievalCandidate {
        candidate: ChunkCandidate,
        source_ranks: HashMap<RetrievalSourceKind, u32>,
        combined_score: f32,
    }

    let mut by_chunk: HashMap<Uuid, MergedRetrievalCandidate> = HashMap::new();
    let mut source_totals: HashMap<RetrievalSourceKind, u32> = HashMap::new();

    for source in sources {
        let kind = source.source;
        let mut seen_for_source = HashSet::new();

        // First pass: count distinct chunk ids in this list for normalization.
        for candidate in &source.candidates {
            if seen_for_source.insert(candidate.chunk_id) {
                *source_totals.entry(kind).or_insert(0) += 1;
            }
        }
        // Second pass: fold the candidates into the per-chunk accumulators.
        for candidate in source.candidates {
            let chunk_id = candidate.chunk_id;
            let rank = candidate.retrieval_rank;

            match by_chunk.get_mut(&chunk_id) {
                Some(existing) => {
                    // Keep the best (smallest) rank this source produced for the chunk.
                    let entry = existing.source_ranks.entry(kind).or_insert(rank);

                    *entry = (*entry).min(rank);
                },
                None => {
                    let mut source_ranks = HashMap::new();

                    source_ranks.insert(kind, rank);
                    by_chunk.insert(
                        chunk_id,
                        MergedRetrievalCandidate { candidate, source_ranks, combined_score: 0.0 },
                    );
                },
            }
        }
    }

    if by_chunk.is_empty() {
        return Vec::new();
    }

    // Entries are only created when incremented, so this clamp is purely a
    // defensive guard for the normalization below.
    for total in source_totals.values_mut() {
        *total = (*total).max(1);
    }

    let mut source_order: Vec<RetrievalSourceKind> = source_totals.keys().copied().collect();

    source_order.sort_by(|left, right| {
        retrieval_source_priority(policy, *left)
            .cmp(&retrieval_source_priority(policy, *right))
            .then_with(|| {
                retrieval_source_kind_order(*left).cmp(&retrieval_source_kind_order(*right))
            })
    });

    let mut merged: Vec<MergedRetrievalCandidate> = by_chunk.into_values().collect();

    for candidate in &mut merged {
        let mut combined_score = 0.0_f32;

        // Sum in the sorted `source_order` rather than in `HashMap` iteration
        // order: f32 addition is not associative, so an arbitrary order could
        // yield scores (and near-tie orderings) that differ between runs.
        for source in &source_order {
            if let Some(&rank) = candidate.source_ranks.get(source) {
                let total = source_totals.get(source).copied().unwrap_or(1);

                combined_score +=
                    retrieval_source_weight(policy, *source) * rank_normalize(rank, total);
            }
        }

        candidate.combined_score = combined_score;
    }

    merged.sort_by(|left, right| {
        cmp_f32_desc(left.combined_score, right.combined_score)
            .then_with(|| right.source_ranks.len().cmp(&left.source_ranks.len()))
            .then_with(|| {
                for source in &source_order {
                    let lhs = left.source_ranks.get(source).copied();
                    let rhs = right.source_ranks.get(source).copied();
                    let ord = rank_asc(lhs, rhs);

                    if ord != Ordering::Equal {
                        return ord;
                    }
                }

                Ordering::Equal
            })
            .then_with(|| left.candidate.chunk_id.cmp(&right.candidate.chunk_id))
    });

    merged
        .into_iter()
        .take(candidate_k as usize)
        .enumerate()
        .map(|(idx, mut entry)| {
            // `idx` is below `candidate_k`, which is a `u32`, so `idx + 1` fits.
            entry.candidate.retrieval_rank = idx as u32 + 1;
            entry.candidate.retrieval_score = Some(entry.combined_score);

            entry.candidate
        })
        .collect()
}

/// Compares two optional ranks in ascending order, treating a missing rank
/// as `u32::MAX` so that it sorts after (or equal to) every present rank.
pub fn rank_asc(left: Option<u32>, right: Option<u32>) -> Ordering {
    let effective = |rank: Option<u32>| rank.unwrap_or(u32::MAX);

    effective(left).cmp(&effective(right))
}

pub fn candidate_matches_note(
note_meta: &HashMap<Uuid, NoteMeta>,
candidate: &ChunkCandidate,
Expand Down
162 changes: 162 additions & 0 deletions packages/elf-service/src/search/ranking/retrieval/merge.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
use std::{
cmp::Ordering,
collections::{HashMap, HashSet},
};

use uuid::Uuid;

use crate::search::{
ChunkCandidate, RetrievalSourceCandidates, RetrievalSourceKind,
ranking::policy::ResolvedRetrievalSourcesPolicy,
};

/// Per-chunk accumulator used by `merge_retrieval_candidates` while it folds
/// several source lists into one ranking.
#[derive(Debug)]
struct MergedRetrievalCandidate {
/// The first candidate encountered for this chunk id; this is the value
/// that is eventually returned to the caller.
candidate: ChunkCandidate,
/// Best (smallest) rank at which each source returned this chunk.
source_ranks: HashMap<RetrievalSourceKind, u32>,
/// Weighted sum of the normalized per-source ranks; initialized to 0.0 and
/// filled in once all sources have been folded.
combined_score: f32,
}

/// Merges the candidate lists of several retrieval sources into one ranked
/// list of at most `candidate_k` chunks.
///
/// Candidates are de-duplicated by `chunk_id`; the first candidate seen for a
/// chunk is the one returned, and for every source the best (smallest) rank at
/// which the chunk appeared is remembered. Each chunk's combined score is the
/// sum, over the sources it appeared in, of
/// `retrieval_source_weight(policy, source) * rank_normalize(rank, total)`,
/// where `total` is the number of distinct chunk ids counted for that source.
///
/// The result is ordered by:
/// 1. combined score, as compared by `cmp_f32_desc`;
/// 2. number of contributing sources, more first;
/// 3. rank within each source, visiting sources by ascending configured
///    priority (ties between priorities use `retrieval_source_kind_order`),
///    where a missing rank sorts last;
/// 4. `chunk_id`, ascending.
///
/// Returned candidates have `retrieval_rank` rewritten to their 1-based
/// position in the merged list and `retrieval_score` set to the combined score.
/// Returns an empty vector when `candidate_k` is 0 or no candidates are given.
pub fn merge_retrieval_candidates(
    sources: Vec<RetrievalSourceCandidates>,
    policy: &ResolvedRetrievalSourcesPolicy,
    candidate_k: u32,
) -> Vec<ChunkCandidate> {
    if candidate_k == 0 {
        return Vec::new();
    }

    let mut by_chunk: HashMap<Uuid, MergedRetrievalCandidate> = HashMap::new();
    let mut source_totals: HashMap<RetrievalSourceKind, u32> = HashMap::new();

    for source in sources {
        let kind = source.source;
        let mut seen_for_source = HashSet::new();

        // First pass: count distinct chunk ids in this list for normalization.
        for candidate in &source.candidates {
            if seen_for_source.insert(candidate.chunk_id) {
                *source_totals.entry(kind).or_insert(0) += 1;
            }
        }
        // Second pass: fold the candidates into the per-chunk accumulators.
        for candidate in source.candidates {
            let chunk_id = candidate.chunk_id;
            let rank = candidate.retrieval_rank;

            match by_chunk.get_mut(&chunk_id) {
                Some(existing) => {
                    // Keep the best (smallest) rank this source produced for the chunk.
                    let entry = existing.source_ranks.entry(kind).or_insert(rank);

                    *entry = (*entry).min(rank);
                },
                None => {
                    let mut source_ranks = HashMap::new();

                    source_ranks.insert(kind, rank);
                    by_chunk.insert(
                        chunk_id,
                        MergedRetrievalCandidate { candidate, source_ranks, combined_score: 0.0 },
                    );
                },
            }
        }
    }

    if by_chunk.is_empty() {
        return Vec::new();
    }

    // Entries are only created when incremented, so this clamp is purely a
    // defensive guard for the normalization below.
    for total in source_totals.values_mut() {
        *total = (*total).max(1);
    }

    let mut source_order: Vec<RetrievalSourceKind> = source_totals.keys().copied().collect();

    source_order.sort_by(|left, right| {
        retrieval_source_priority(policy, *left)
            .cmp(&retrieval_source_priority(policy, *right))
            .then_with(|| {
                retrieval_source_kind_order(*left).cmp(&retrieval_source_kind_order(*right))
            })
    });

    let mut merged: Vec<MergedRetrievalCandidate> = by_chunk.into_values().collect();

    for candidate in &mut merged {
        let mut combined_score = 0.0_f32;

        // Sum in the sorted `source_order` rather than in `HashMap` iteration
        // order: f32 addition is not associative, so an arbitrary order could
        // yield scores (and near-tie orderings) that differ between runs.
        for source in &source_order {
            if let Some(&rank) = candidate.source_ranks.get(source) {
                let total = source_totals.get(source).copied().unwrap_or(1);

                combined_score +=
                    retrieval_source_weight(policy, *source) * super::rank_normalize(rank, total);
            }
        }

        candidate.combined_score = combined_score;
    }

    merged.sort_by(|left, right| {
        super::cmp_f32_desc(left.combined_score, right.combined_score)
            .then_with(|| right.source_ranks.len().cmp(&left.source_ranks.len()))
            .then_with(|| {
                for source in &source_order {
                    let lhs = left.source_ranks.get(source).copied();
                    let rhs = right.source_ranks.get(source).copied();
                    let ord = rank_asc(lhs, rhs);

                    if ord != Ordering::Equal {
                        return ord;
                    }
                }

                Ordering::Equal
            })
            .then_with(|| left.candidate.chunk_id.cmp(&right.candidate.chunk_id))
    });

    merged
        .into_iter()
        .take(candidate_k as usize)
        .enumerate()
        .map(|(idx, mut entry)| {
            // `idx` is below `candidate_k`, which is a `u32`, so `idx + 1` fits.
            entry.candidate.retrieval_rank = idx as u32 + 1;
            entry.candidate.retrieval_score = Some(entry.combined_score);

            entry.candidate
        })
        .collect()
}

/// Looks up the policy weight for `source`; the merge multiplies it with the
/// source's normalized rank when building a candidate's combined score.
fn retrieval_source_weight(
    policy: &ResolvedRetrievalSourcesPolicy,
    source: RetrievalSourceKind,
) -> f32 {
    // No wildcard arm: a new source kind must be handled explicitly.
    match source {
        RetrievalSourceKind::Recursive => policy.recursive_weight,
        RetrievalSourceKind::StructuredField => policy.structured_field_weight,
        RetrievalSourceKind::Fusion => policy.fusion_weight,
    }
}

/// Looks up the policy priority for `source`.
///
/// Sources are sorted ascending by this value during the merge, so a smaller
/// number is consulted earlier when breaking ties.
fn retrieval_source_priority(
    policy: &ResolvedRetrievalSourcesPolicy,
    source: RetrievalSourceKind,
) -> u32 {
    // No wildcard arm: a new source kind must be handled explicitly.
    match source {
        RetrievalSourceKind::Recursive => policy.recursive_priority,
        RetrievalSourceKind::Fusion => policy.fusion_priority,
        RetrievalSourceKind::StructuredField => policy.structured_field_priority,
    }
}

/// Policy-independent ordinal for `source`, used as the secondary sort key
/// when two sources share the same configured priority
/// (structured-field < fusion < recursive).
fn retrieval_source_kind_order(source: RetrievalSourceKind) -> u8 {
    match source {
        RetrievalSourceKind::Fusion => 1,
        RetrievalSourceKind::Recursive => 2,
        RetrievalSourceKind::StructuredField => 0,
    }
}

/// Ascending comparison of two optional ranks in which an absent rank is
/// treated as `u32::MAX`, i.e. it never sorts before a present rank.
fn rank_asc(left: Option<u32>, right: Option<u32>) -> Ordering {
    match (left, right) {
        (Some(lhs), Some(rhs)) => lhs.cmp(&rhs),
        (Some(lhs), None) => lhs.cmp(&u32::MAX),
        (None, Some(rhs)) => u32::MAX.cmp(&rhs),
        (None, None) => Ordering::Equal,
    }
}
37 changes: 37 additions & 0 deletions packages/elf-service/src/search/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,43 @@ fn merge_retrieval_candidates_prefers_dual_source_signal_on_tie() {
assert_eq!(first.chunk_id, shared_chunk_id);
}

#[test]
fn merge_retrieval_candidates_uses_configured_source_priority_on_tie() {
    // Fixed chunk ids: the fusion chunk has the smaller id, so if the merge
    // fell through to its final chunk-id tie-break the fusion chunk would win.
    // Expecting the recursive chunk first therefore shows that the configured
    // source priority is what decided the order.
    let fusion_chunk_id = Uuid::from_u128(1);
    let recursive_chunk_id = Uuid::from_u128(2);

    let mut fusion_candidate = test_chunk_candidate(Uuid::new_v4(), 1);
    let mut recursive_candidate = test_chunk_candidate(Uuid::new_v4(), 1);

    fusion_candidate.chunk_id = fusion_chunk_id;
    recursive_candidate.chunk_id = recursive_chunk_id;

    // Equal weights for the two participating sources; the recursive source
    // gets the numerically smallest priority.
    let policy = ranking::ResolvedRetrievalSourcesPolicy {
        fusion_weight: 1.0,
        structured_field_weight: 0.0,
        recursive_weight: 1.0,
        fusion_priority: 10,
        structured_field_priority: 20,
        recursive_priority: 0,
    };
    let sources = vec![
        RetrievalSourceCandidates {
            source: RetrievalSourceKind::Fusion,
            candidates: vec![fusion_candidate],
        },
        RetrievalSourceCandidates {
            source: RetrievalSourceKind::Recursive,
            candidates: vec![recursive_candidate],
        },
    ];

    let ranked = ranking::merge_retrieval_candidates(sources, &policy, 2);

    assert_eq!(ranked[0].chunk_id, recursive_chunk_id);
    assert_eq!(ranked[1].chunk_id, fusion_chunk_id);
}

#[test]
fn retrieval_weight_for_rank_uses_first_matching_segment_or_last() {
let segments = vec![
Expand Down