Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions bindings/cpp/include/svs/runtime/vamana_index.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,11 @@ struct VamanaSearchParameters {
size_t search_buffer_capacity = Unspecify<size_t>();
size_t prefetch_lookahead = Unspecify<size_t>();
size_t prefetch_step = Unspecify<size_t>();
// Minimum filter hit rate to continue filtered search.
// The cumulative hit rate (filter matches / candidates examined) is checked after
// each search round; if it falls below this threshold, the search stops and
// returns empty results (caller can fall back to exact search).
// A value of 0 (the default) means never give up.
float filter_stop = 0.0f;
};
} // namespace detail

Expand Down
12 changes: 11 additions & 1 deletion bindings/cpp/src/dynamic_vamana_index_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -118,15 +118,21 @@ class DynamicVamanaIndexImpl {
// Selective search with IDSelector
auto old_sp = impl_->get_search_parameters();
impl_->set_search_parameters(sp);
const float filter_stop = params ? params->filter_stop : 0.0f;

auto search_closure = [&](const auto& range, uint64_t SVS_UNUSED(tid)) {
for (auto i : range) {
// For every query
auto query = queries.get_datum(i);
auto iterator = impl_->batch_iterator(query);
size_t found = 0;
size_t total_checked = 0;
auto batch_size = std::max(k, sp.buffer_config_.get_search_window_size());
do {
iterator.next(k);
batch_size =
predict_further_processing(total_checked, found, k, batch_size);
iterator.next(batch_size);
total_checked += iterator.size();
for (auto& neighbor : iterator.results()) {
if (filter->is_member(neighbor.id())) {
result.set(neighbor, i, found);
Expand All @@ -136,6 +142,10 @@ class DynamicVamanaIndexImpl {
}
}
}
if (should_stop_filtered_search(total_checked, found, filter_stop)) {
found = 0;
break;
}
} while (found < k && !iterator.done());

// Pad results if not enough neighbors found
Expand Down
25 changes: 25 additions & 0 deletions bindings/cpp/src/svs_runtime_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -431,6 +431,31 @@ auto dispatch_storage_kind(StorageKind kind, F&& f, Args&&... args) {
}
} // namespace storage

// Predict how many more items need to be processed to reach the goal,
// based on the observed hit rate so far.
// If no hits yet, returns `hint` unchanged.
// The caller should cap the result to a max batch size if needed.
inline size_t
predict_further_processing(size_t processed, size_t hits, size_t goal, size_t hint) {
if (hits == 0 || hits >= goal) {
return hint;
}
float batch_size = static_cast<float>(goal - hits) * processed / hits;
return std::max(static_cast<size_t>(batch_size), size_t{1});
}

// Check if the filtered search should stop early based on the observed hit rate.
// Returns true if the hit rate is below the threshold, meaning the caller should
// give up and let the caller fall back to exact search.
inline bool
should_stop_filtered_search(size_t total_checked, size_t found, float filter_stop) {
if (filter_stop <= 0 || total_checked == 0 || found == 0) {
return false;
}
float hit_rate = static_cast<float>(found) / total_checked;
return hit_rate < filter_stop;
}

inline svs::threads::ThreadPoolHandle default_threadpool() {
return svs::threads::ThreadPoolHandle(svs::threads::OMPThreadPool(omp_get_max_threads())
);
Expand Down
12 changes: 11 additions & 1 deletion bindings/cpp/src/vamana_index_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -124,15 +124,21 @@ class VamanaIndexImpl {
get_impl()->set_search_parameters(old_sp);
});
get_impl()->set_search_parameters(sp);
const float filter_stop = params ? params->filter_stop : 0.0f;

auto search_closure = [&](const auto& range, uint64_t SVS_UNUSED(tid)) {
for (auto i : range) {
// For every query
auto query = queries.get_datum(i);
auto iterator = get_impl()->batch_iterator(query);
size_t found = 0;
size_t total_checked = 0;
auto batch_size = std::max(k, sp.buffer_config_.get_search_window_size());
do {
iterator.next(k);
batch_size =
predict_further_processing(total_checked, found, k, batch_size);
iterator.next(batch_size);
total_checked += iterator.size();
for (auto& neighbor : iterator.results()) {
if (filter->is_member(neighbor.id())) {
result.set(neighbor, i, found);
Expand All @@ -142,6 +148,10 @@ class VamanaIndexImpl {
}
}
}
if (should_stop_filtered_search(total_checked, found, filter_stop)) {
found = 0;
break;
}
} while (found < k && !iterator.done());

// Pad results if not enough neighbors found
Expand Down
116 changes: 116 additions & 0 deletions bindings/cpp/tests/runtime_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -501,6 +501,122 @@ CATCH_TEST_CASE("SearchWithIDFilter", "[runtime]") {
svs::runtime::v0::DynamicVamanaIndex::destroy(index);
}

CATCH_TEST_CASE("SearchWithRestrictiveFilter", "[runtime][filtered_search]") {
const auto& test_data = get_test_data();
// Build index
svs::runtime::v0::DynamicVamanaIndex* index = nullptr;
svs::runtime::v0::VamanaIndex::BuildParams build_params{64};
svs::runtime::v0::Status status = svs::runtime::v0::DynamicVamanaIndex::build(
&index,
test_d,
svs::runtime::v0::MetricType::L2,
svs::runtime::v0::StorageKind::FP32,
build_params
);
CATCH_REQUIRE(status.ok());
CATCH_REQUIRE(index != nullptr);

// Add data
std::vector<size_t> labels(test_n);
std::iota(labels.begin(), labels.end(), 0);
status = index->add(test_n, labels.data(), test_data.data());
CATCH_REQUIRE(status.ok());

const int nq = 5;
const float* xq = test_data.data();
const int k = 5;

// 10% selectivity: accept only IDs 0-9 out of 100
size_t min_id = 0;
size_t max_id = test_n / 10;
test_utils::IDFilterRange filter(min_id, max_id);

std::vector<float> distances(nq * k);
std::vector<size_t> result_labels(nq * k);

status =
index->search(nq, xq, k, distances.data(), result_labels.data(), nullptr, &filter);
CATCH_REQUIRE(status.ok());

// All returned labels must fall inside the filter range
for (int i = 0; i < nq * k; ++i) {
if (svs::runtime::v0::is_specified(result_labels[i])) {
CATCH_REQUIRE(result_labels[i] >= min_id);
CATCH_REQUIRE(result_labels[i] < max_id);
}
}

svs::runtime::v0::DynamicVamanaIndex::destroy(index);
}

CATCH_TEST_CASE("FilterStopEarlyExit", "[runtime][filtered_search]") {
const auto& test_data = get_test_data();
// Build index
svs::runtime::v0::DynamicVamanaIndex* index = nullptr;
svs::runtime::v0::VamanaIndex::BuildParams build_params{64};
svs::runtime::v0::Status status = svs::runtime::v0::DynamicVamanaIndex::build(
&index,
test_d,
svs::runtime::v0::MetricType::L2,
svs::runtime::v0::StorageKind::FP32,
build_params
);
CATCH_REQUIRE(status.ok());
CATCH_REQUIRE(index != nullptr);

// Add data
std::vector<size_t> labels(test_n);
std::iota(labels.begin(), labels.end(), 0);
status = index->add(test_n, labels.data(), test_data.data());
CATCH_REQUIRE(status.ok());

const int nq = 5;
const float* xq = test_data.data();
const int k = 5;

// 10% selectivity: accept only IDs 0-9 out of 100
size_t min_id = 0;
size_t max_id = test_n / 10;
test_utils::IDFilterRange filter(min_id, max_id);

std::vector<float> distances(nq * k);
std::vector<size_t> result_labels(nq * k);

// Set filter_stop = 0.5 (50%). With ~10% hit rate, search should give up
// and return unspecified results.
svs::runtime::v0::VamanaIndex::SearchParams search_params;
search_params.filter_stop = 0.5f;

status = index->search(
nq, xq, k, distances.data(), result_labels.data(), &search_params, &filter
);
CATCH_REQUIRE(status.ok());

// All results should be unspecified (early exit returned empty)
for (int i = 0; i < nq * k; ++i) {
CATCH_REQUIRE(!svs::runtime::v0::is_specified(result_labels[i]));
}

// Now search without filter_stop — should find valid results
std::vector<float> distances2(nq * k);
std::vector<size_t> result_labels2(nq * k);

status = index->search(
nq, xq, k, distances2.data(), result_labels2.data(), nullptr, &filter
);
CATCH_REQUIRE(status.ok());

// Should have valid results in the filter range
for (int i = 0; i < nq * k; ++i) {
if (svs::runtime::v0::is_specified(result_labels2[i])) {
CATCH_REQUIRE(result_labels2[i] >= min_id);
CATCH_REQUIRE(result_labels2[i] < max_id);
}
}

svs::runtime::v0::DynamicVamanaIndex::destroy(index);
}

CATCH_TEST_CASE("RangeSearchFunctional", "[runtime]") {
const auto& test_data = get_test_data();
// Build index
Expand Down
Loading