diff --git a/bindings/cpp/include/svs/runtime/vamana_index.h b/bindings/cpp/include/svs/runtime/vamana_index.h index 98831952..5286551d 100644 --- a/bindings/cpp/include/svs/runtime/vamana_index.h +++ b/bindings/cpp/include/svs/runtime/vamana_index.h @@ -40,6 +40,11 @@ struct VamanaSearchParameters { size_t search_buffer_capacity = Unspecify(); size_t prefetch_lookahead = Unspecify(); size_t prefetch_step = Unspecify(); + // Minimum filter hit rate to continue filtered search. + // If the hit rate after the first round falls below this threshold, + // stop and return empty results (caller can fall back to exact search). + // Default 0 means never give up. + float filter_stop = 0.0f; }; } // namespace detail diff --git a/bindings/cpp/src/dynamic_vamana_index_impl.h b/bindings/cpp/src/dynamic_vamana_index_impl.h index 4b16cf4b..6c28ee95 100644 --- a/bindings/cpp/src/dynamic_vamana_index_impl.h +++ b/bindings/cpp/src/dynamic_vamana_index_impl.h @@ -118,6 +118,7 @@ class DynamicVamanaIndexImpl { // Selective search with IDSelector auto old_sp = impl_->get_search_parameters(); impl_->set_search_parameters(sp); + const float filter_stop = params ? 
params->filter_stop : 0.0f; auto search_closure = [&](const auto& range, uint64_t SVS_UNUSED(tid)) { for (auto i : range) { @@ -125,8 +126,13 @@ class DynamicVamanaIndexImpl { auto query = queries.get_datum(i); auto iterator = impl_->batch_iterator(query); size_t found = 0; + size_t total_checked = 0; + auto batch_size = std::max(k, sp.buffer_config_.get_search_window_size()); do { - iterator.next(k); + batch_size = + predict_further_processing(total_checked, found, k, batch_size); + iterator.next(batch_size); + total_checked += iterator.size(); for (auto& neighbor : iterator.results()) { if (filter->is_member(neighbor.id())) { result.set(neighbor, i, found); @@ -136,6 +142,10 @@ class DynamicVamanaIndexImpl { } } } + if (should_stop_filtered_search(total_checked, found, filter_stop)) { + found = 0; + break; + } } while (found < k && !iterator.done()); // Pad results if not enough neighbors found diff --git a/bindings/cpp/src/svs_runtime_utils.h b/bindings/cpp/src/svs_runtime_utils.h index e0d7c68a..1ab1d4d0 100644 --- a/bindings/cpp/src/svs_runtime_utils.h +++ b/bindings/cpp/src/svs_runtime_utils.h @@ -431,6 +431,31 @@ auto dispatch_storage_kind(StorageKind kind, F&& f, Args&&... args) { } } // namespace storage +// Predict how many more items need to be processed to reach the goal, +// based on the observed hit rate so far. +// If there are no hits yet, or the goal is already met, returns `hint` unchanged. +// The caller should cap the result to a max batch size if needed. +inline size_t +predict_further_processing(size_t processed, size_t hits, size_t goal, size_t hint) { + if (hits == 0 || hits >= goal) { + return hint; + } + float batch_size = static_cast<float>(goal - hits) * processed / hits; + return std::max(static_cast<size_t>(batch_size), size_t{1}); +} + +// Check if the filtered search should stop early based on the observed hit rate. +// Returns true if the hit rate is below the threshold, meaning the search should +// give up so the caller can fall back to exact search. Never true with zero hits.
+inline bool +should_stop_filtered_search(size_t total_checked, size_t found, float filter_stop) { + if (filter_stop <= 0 || total_checked == 0 || found == 0) { + return false; + } + float hit_rate = static_cast<float>(found) / total_checked; + return hit_rate < filter_stop; +} + inline svs::threads::ThreadPoolHandle default_threadpool() { return svs::threads::ThreadPoolHandle(svs::threads::OMPThreadPool(omp_get_max_threads()) ); diff --git a/bindings/cpp/src/vamana_index_impl.h b/bindings/cpp/src/vamana_index_impl.h index 4cf58d7e..81cd8ae3 100644 --- a/bindings/cpp/src/vamana_index_impl.h +++ b/bindings/cpp/src/vamana_index_impl.h @@ -124,6 +124,7 @@ class VamanaIndexImpl { get_impl()->set_search_parameters(old_sp); }); get_impl()->set_search_parameters(sp); + const float filter_stop = params ? params->filter_stop : 0.0f; auto search_closure = [&](const auto& range, uint64_t SVS_UNUSED(tid)) { for (auto i : range) { @@ -131,8 +132,13 @@ class VamanaIndexImpl { auto query = queries.get_datum(i); auto iterator = get_impl()->batch_iterator(query); size_t found = 0; + size_t total_checked = 0; + auto batch_size = std::max(k, sp.buffer_config_.get_search_window_size()); do { - iterator.next(k); + batch_size = + predict_further_processing(total_checked, found, k, batch_size); + iterator.next(batch_size); + total_checked += iterator.size(); for (auto& neighbor : iterator.results()) { if (filter->is_member(neighbor.id())) { result.set(neighbor, i, found); @@ -142,6 +148,10 @@ class VamanaIndexImpl { } } } + if (should_stop_filtered_search(total_checked, found, filter_stop)) { + found = 0; + break; + } } while (found < k && !iterator.done()); // Pad results if not enough neighbors found diff --git a/bindings/cpp/tests/runtime_test.cpp b/bindings/cpp/tests/runtime_test.cpp index 201375d3..abd14296 100644 --- a/bindings/cpp/tests/runtime_test.cpp +++ b/bindings/cpp/tests/runtime_test.cpp @@ -501,6 +501,122 @@ CATCH_TEST_CASE("SearchWithIDFilter", "[runtime]") {
svs::runtime::v0::DynamicVamanaIndex::destroy(index); } +CATCH_TEST_CASE("SearchWithRestrictiveFilter", "[runtime][filtered_search]") { + const auto& test_data = get_test_data(); + // Build index + svs::runtime::v0::DynamicVamanaIndex* index = nullptr; + svs::runtime::v0::VamanaIndex::BuildParams build_params{64}; + svs::runtime::v0::Status status = svs::runtime::v0::DynamicVamanaIndex::build( + &index, + test_d, + svs::runtime::v0::MetricType::L2, + svs::runtime::v0::StorageKind::FP32, + build_params + ); + CATCH_REQUIRE(status.ok()); + CATCH_REQUIRE(index != nullptr); + + // Add data + std::vector<size_t> labels(test_n); + std::iota(labels.begin(), labels.end(), 0); + status = index->add(test_n, labels.data(), test_data.data()); + CATCH_REQUIRE(status.ok()); + + const int nq = 5; + const float* xq = test_data.data(); + const int k = 5; + + // 10% selectivity: accept only IDs 0-9 out of 100 + size_t min_id = 0; + size_t max_id = test_n / 10; + test_utils::IDFilterRange filter(min_id, max_id); + + std::vector<float> distances(nq * k); + std::vector<size_t> result_labels(nq * k); + + status = + index->search(nq, xq, k, distances.data(), result_labels.data(), nullptr, &filter); + CATCH_REQUIRE(status.ok()); + + // All returned labels must fall inside the filter range + for (int i = 0; i < nq * k; ++i) { + if (svs::runtime::v0::is_specified(result_labels[i])) { + CATCH_REQUIRE(result_labels[i] >= min_id); + CATCH_REQUIRE(result_labels[i] < max_id); + } + } + + svs::runtime::v0::DynamicVamanaIndex::destroy(index); +} + +CATCH_TEST_CASE("FilterStopEarlyExit", "[runtime][filtered_search]") { + const auto& test_data = get_test_data(); + // Build index + svs::runtime::v0::DynamicVamanaIndex* index = nullptr; + svs::runtime::v0::VamanaIndex::BuildParams build_params{64}; + svs::runtime::v0::Status status = svs::runtime::v0::DynamicVamanaIndex::build( + &index, + test_d, + svs::runtime::v0::MetricType::L2, + svs::runtime::v0::StorageKind::FP32, + build_params + ); + CATCH_REQUIRE(status.ok()); +
CATCH_REQUIRE(index != nullptr); + + // Add data + std::vector<size_t> labels(test_n); + std::iota(labels.begin(), labels.end(), 0); + status = index->add(test_n, labels.data(), test_data.data()); + CATCH_REQUIRE(status.ok()); + + const int nq = 5; + const float* xq = test_data.data(); + const int k = 5; + + // 10% selectivity: accept only IDs 0-9 out of 100 + size_t min_id = 0; + size_t max_id = test_n / 10; + test_utils::IDFilterRange filter(min_id, max_id); + + std::vector<float> distances(nq * k); + std::vector<size_t> result_labels(nq * k); + + // Set filter_stop = 0.5 (50%). With ~10% hit rate, search should give up + // and return unspecified results. + svs::runtime::v0::VamanaIndex::SearchParams search_params; + search_params.filter_stop = 0.5f; + + status = index->search( + nq, xq, k, distances.data(), result_labels.data(), &search_params, &filter + ); + CATCH_REQUIRE(status.ok()); + + // All results should be unspecified (early exit returned empty) + for (int i = 0; i < nq * k; ++i) { + CATCH_REQUIRE(!svs::runtime::v0::is_specified(result_labels[i])); + } + + // Now search without filter_stop — should find valid results + std::vector<float> distances2(nq * k); + std::vector<size_t> result_labels2(nq * k); + + status = index->search( + nq, xq, k, distances2.data(), result_labels2.data(), nullptr, &filter + ); + CATCH_REQUIRE(status.ok()); + + // Should have valid results in the filter range + for (int i = 0; i < nq * k; ++i) { + if (svs::runtime::v0::is_specified(result_labels2[i])) { + CATCH_REQUIRE(result_labels2[i] >= min_id); + CATCH_REQUIRE(result_labels2[i] < max_id); + } + } + + svs::runtime::v0::DynamicVamanaIndex::destroy(index); +} + CATCH_TEST_CASE("RangeSearchFunctional", "[runtime]") { const auto& test_data = get_test_data(); // Build index