56 changes: 54 additions & 2 deletions rust/lance-index/src/scalar/btree.rs
@@ -59,7 +59,7 @@ use lance_io::object_store::ObjectStore;
use log::{debug, warn};
use object_store::path::Path;
use rangemap::RangeInclusiveMap;
use roaring::RoaringBitmap;
use roaring::{RoaringBitmap, RoaringTreemap};
use serde::{Deserialize, Serialize, Serializer};
use snafu::location;
use tracing::{info, instrument};
@@ -1303,12 +1303,42 @@ impl BTreeIndex {
) -> Result<SendableRecordBatchStream> {
let value_column_index = new_data.schema().index_of(VALUE_COLUMN_NAME)?;

let new_input = Arc::new(OneShotExec::new(new_data));
// Collect new data row IDs so we can remove stale entries from the old
// index. When stable row IDs are used, the _rowid values are stable row
// IDs (not physical addresses), so fragment-based filtering alone is
// insufficient — an updated row keeps its stable row ID but moves to a
// new fragment. Without this dedup step, both the old (stale) and new
// entries would survive, causing duplicate row IDs in the merged index.
let new_schema = new_data.schema();
let new_batches: Vec<RecordBatch> = new_data.try_collect().await?;
Collaborator review comment (on the `try_collect` line above): This will collect all our new data, seems bad.
let new_row_ids: RoaringTreemap = new_batches
.iter()
.flat_map(|batch| {
batch[ROW_ID]
.as_any()
.downcast_ref::<arrow_array::UInt64Array>()
.expect("expected UInt64Array for row_id column")
.values()
.iter()
.copied()
})
.collect();
let new_input = Arc::new(OneShotExec::new(Box::pin(RecordBatchStreamAdapter::new(
new_schema,
futures::stream::iter(new_batches.into_iter().map(Ok)),
))));

let old_stream = self.into_data_stream().await?;
let old_stream = match valid_old_fragments {
Some(valid_frags) => filter_row_ids_by_fragments(old_stream, valid_frags),
None => old_stream,
};
// Remove old entries for row IDs that appear in the new data
let old_stream = if new_row_ids.is_empty() {
old_stream
} else {
filter_out_row_ids(old_stream, new_row_ids)
};
let old_input = Arc::new(OneShotExec::new(old_stream));
debug_assert_eq!(
old_input.schema().flattened_fields().len(),
@@ -1361,6 +1391,28 @@ fn filter_row_ids_by_fragments(
Box::pin(RecordBatchStreamAdapter::new(schema, filtered))
}

/// Filter a stream to remove rows whose row IDs are in the given set.
/// Used during index optimization to remove stale entries for updated rows.
fn filter_out_row_ids(
stream: SendableRecordBatchStream,
row_ids_to_remove: RoaringTreemap,
) -> SendableRecordBatchStream {
let schema = stream.schema();
let filtered = stream.map(move |batch_result| {
let batch = batch_result?;
let row_ids = batch[ROW_ID]
.as_any()
.downcast_ref::<arrow_array::UInt64Array>()
.expect("expected UInt64Array for row_id column");
let mask: arrow_array::BooleanArray = row_ids
.iter()
.map(|id| id.map(|id| !row_ids_to_remove.contains(id)))
.collect();
Ok(arrow_select::filter::filter_record_batch(&batch, &mask)?)
});
Box::pin(RecordBatchStreamAdapter::new(schema, filtered))
}

fn wrap_bound(bound: &Bound<ScalarValue>) -> Bound<OrderableScalarValue> {
match bound {
Bound::Unbounded => Bound::Unbounded,
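For readers unfamiliar with the mask-and-filter pattern used by `filter_out_row_ids` above, here is a minimal standalone sketch (not part of the diff) of the same technique: build a `RoaringTreemap` of stale row IDs, derive a `BooleanArray` mask from the batch's `_rowid` column, and apply `filter_record_batch`. The column names and example values are illustrative only; the crates (`roaring`, `arrow_array`, `arrow_select`) are the ones already used in the diff.

// Standalone illustration only; not part of the PR.
use std::sync::Arc;

use arrow_array::{ArrayRef, BooleanArray, Int32Array, RecordBatch, UInt64Array};
use arrow_select::filter::filter_record_batch;
use roaring::RoaringTreemap;

fn main() {
    // Row IDs 1 and 3 were rewritten by an update, so their old index entries are stale.
    let stale: RoaringTreemap = [1u64, 3].into_iter().collect();

    let row_ids = UInt64Array::from(vec![0u64, 1, 2, 3]);
    let values = Int32Array::from(vec![10, 11, 12, 13]);
    let batch = RecordBatch::try_from_iter(vec![
        ("_rowid", Arc::new(row_ids.clone()) as ArrayRef),
        ("value", Arc::new(values) as ArrayRef),
    ])
    .unwrap();

    // Keep only rows whose row ID is NOT in the stale set.
    let mask: BooleanArray = row_ids
        .iter()
        .map(|id| id.map(|id| !stale.contains(id)))
        .collect();
    let filtered = filter_record_batch(&batch, &mask).unwrap();

    assert_eq!(filtered.num_rows(), 2); // rows with IDs 0 and 2 survive
}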
74 changes: 73 additions & 1 deletion rust/lance/tests/query/primitives.rs
@@ -9,10 +9,11 @@ use arrow_array::{
LargeBinaryArray, LargeStringArray, RecordBatch, StringArray, StringViewArray,
};
use arrow_schema::DataType;
use lance::dataset::{InsertBuilder, UpdateBuilder, WriteParams};
use lance::Dataset;

use lance_datagen::{array, gen_batch, ArrayGeneratorExt, RowCount};
use lance_index::IndexType;
use lance_index::{DatasetIndexExt, IndexType};

use super::{test_filter, test_scan, test_take};
use crate::utils::DatasetTestCases;
@@ -403,3 +404,74 @@ async fn test_query_decimal(#[case] data_type: DataType) {
})
.await
}

/// Regression test: BTree index optimize after update should not panic with
/// "RowAddrTreeMap::from_sorted_iter called with non-sorted input".
///
/// Sequence: Write(SRID) → BTree(int_col) → Update(int_col) → OptimizeIndices →
/// filtered scan. The optimize merges old and new index data; the merged pages
/// must have properly sorted row addresses.
#[tokio::test]
async fn test_btree_optimize_after_update() {
use lance_index::scalar::ScalarIndexParams;

// Create 100 rows with an int_col and a category column for selective updates
let ids: Vec<i32> = (0..100).collect();
let values: Vec<i32> = (0..100).collect();
let categories: Vec<&str> = (0..100)
.map(|i| if i % 5 == 0 { "A" } else { "B" })
.collect();

let batch = RecordBatch::try_from_iter(vec![
("id", Arc::new(Int32Array::from(ids)) as ArrayRef),
("int_col", Arc::new(Int32Array::from(values)) as ArrayRef),
(
"category",
Arc::new(StringArray::from(categories)) as ArrayRef,
),
])
.unwrap();

// Write with stable row IDs
let mut ds = InsertBuilder::new("memory://")
.with_params(&WriteParams {
enable_stable_row_ids: true,
..Default::default()
})
.execute(vec![batch])
.await
.unwrap();

// Create BTree index on int_col
ds.create_index_builder(
&["int_col"],
IndexType::BTree,
&ScalarIndexParams::default(),
)
.await
.unwrap();

// Update int_col for category='A' rows (20% of data)
let result = UpdateBuilder::new(Arc::new(ds))
.update_where("category = 'A'")
.unwrap()
.set("int_col", "-1")
.unwrap()
.build()
.unwrap()
.execute()
.await
.unwrap();
ds = result.new_dataset.as_ref().clone();

// Optimize indices — merges old BTree pages with new data
ds.optimize_indices(&Default::default()).await.unwrap();

// Filtered scan should not panic
let mut scanner = ds.scan();
scanner.filter("int_col < 200").unwrap();
let result = scanner.try_into_batch().await.unwrap();

// All 100 rows should pass (20 updated to -1, 80 unchanged, all < 200)
assert_eq!(result.num_rows(), 100);
}
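As a further, purely hypothetical extension of this test (not part of the PR), one could also assert that the update itself is reflected in the index-backed scan, reusing only APIs the test already exercises (`scan`, `filter`, `try_into_batch`):

    // Hypothetical extra check, appended inside the same test: the 20 updated
    // rows (category = 'A') should be the only ones with int_col == -1.
    let mut scanner = ds.scan();
    scanner.filter("int_col = -1").unwrap();
    let updated = scanner.try_into_batch().await.unwrap();
    assert_eq!(updated.num_rows(), 20);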