Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions python/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 4 additions & 4 deletions python/python/tests/test_scalar_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -4461,8 +4461,8 @@ def test_vector_filter_fts_search(tmp_path):
nearest=vector_query,
filter=PhraseQuery("text", "text"),
)
result = scanner.to_table()
assert [299, 300] == result["id"].to_pylist()
with pytest.raises(ValueError):
scanner.to_table()

# Case 6: search with prefilter=false, search_filter=phrase("text")
scanner = ds.scanner(
Expand All @@ -4473,5 +4473,5 @@ def test_vector_filter_fts_search(tmp_path):
"search_filter": PhraseQuery("text", "text"),
},
)
result = scanner.to_table()
assert [300] == result["id"].to_pylist()
with pytest.raises(ValueError):
scanner.to_table()
10 changes: 4 additions & 6 deletions python/python/tests/test_vector_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -2898,9 +2898,8 @@ def test_fts_filter_vector_search(tmp_path):
filter=PhraseQuery("text", "text"),
)

result = scanner.to_table()
ids_result = result["id"].to_pylist()
assert [299, 300] == ids_result
with pytest.raises(ValueError):
scanner.to_table()

# Case 6: search with prefilter=false, search_filter=phrase("text")
scanner = dataset.scanner(
Expand All @@ -2912,6 +2911,5 @@ def test_fts_filter_vector_search(tmp_path):
},
)

result = scanner.to_table()
ids_result = result["id"].to_pylist()
assert [300] == ids_result
with pytest.raises(ValueError):
scanner.to_table()
1 change: 1 addition & 0 deletions rust/lance-arrow/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ arrow-ord = { workspace = true }
arrow-schema = { workspace = true }
arrow-select = { workspace = true }
bytes = { workspace = true }
futures = { workspace = true }
half = { workspace = true }
jsonb = { workspace = true }
num-traits = { workspace = true }
Expand Down
1 change: 1 addition & 0 deletions rust/lance-arrow/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ pub mod json;
pub mod list;
pub mod memory;
pub mod scalar;
pub mod stream;
pub mod r#struct;

/// Arrow extension metadata key for extension name
Expand Down
305 changes: 305 additions & 0 deletions rust/lance-arrow/src/stream.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,305 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The Lance Authors

//! Utilities for working with streams of [`RecordBatch`].

use arrow_array::RecordBatch;
use arrow_schema::{ArrowError, SchemaRef};
use futures::stream::{self, Stream, StreamExt};
use std::pin::Pin;

/// Rechunks a stream of [`RecordBatch`] so that each output batch holds
/// roughly between `min_bytes` and `max_bytes` of array data.
///
/// Small input batches are accumulated (by concatenation) until at least
/// `min_bytes` of data has been collected. If the resulting batch exceeds
/// `max_bytes`, it is sliced into roughly equal pieces of ~`max_bytes`
/// (assuming uniform row sizes).
///
/// These bounds are best-effort, not hard guarantees: a leftover slice from
/// a previous iteration (or a single oversized input row) can be emitted
/// as-is via the fast path below, so callers must tolerate the occasional
/// batch larger than `max_bytes`. `min_bytes` should be comfortably smaller
/// than `max_bytes` for the accumulation heuristics to behave well.
///
/// Errors from the input stream are propagated as-is; concatenation errors
/// are converted through `E: From<ArrowError>`.
pub fn rechunk_stream_by_size<S, E>(
    input: S,
    input_schema: SchemaRef,
    min_bytes: usize,
    max_bytes: usize,
) -> impl Stream<Item = Result<RecordBatch, E>>
where
    S: Stream<Item = Result<RecordBatch, E>>,
    E: From<ArrowError>,
{
    stream::try_unfold(
        RechunkState {
            input: Box::pin(input),
            accumulated: Vec::new(),
            acc_bytes: 0,
            done: false,
            input_schema,
            min_bytes,
            max_bytes,
        },
        |mut state| async move {
            // Fully drained and nothing pending: end the output stream.
            if state.done && state.accumulated.is_empty() {
                return Ok(None);
            }

            // Pull batches until we reach the byte target or exhaust input.
            while !state.done && state.acc_bytes < state.min_bytes {
                match state.input.next().await {
                    Some(Ok(batch)) => {
                        state.acc_bytes += batch.get_array_memory_size();
                        state.accumulated.push(batch);
                    }
                    Some(Err(e)) => return Err(e),
                    None => {
                        state.done = true;
                    }
                }
            }

            // Input ended with nothing buffered (e.g. empty input stream).
            if state.accumulated.is_empty() {
                return Ok(None);
            }

            // Fast path: if the first accumulated batch already meets the
            // byte threshold, deliver it directly instead of concatenating
            // everything together (which would just get sliced back apart).
            // This is what keeps leftover slices from a prior slicing pass
            // zero-copy: they are handed out one at a time rather than being
            // repeatedly re-concatenated and re-sliced.
            if state.accumulated.len() > 1
                && state.accumulated[0].get_array_memory_size() >= state.min_bytes
            {
                let b = state.accumulated.remove(0);
                state.acc_bytes -= b.get_array_memory_size();
                return Ok(Some((b, state)));
            }

            // Slow path: merge everything buffered into a single batch.
            let batch = if state.accumulated.len() == 1 {
                state.accumulated.pop().unwrap()
            } else {
                let b =
                    arrow_select::concat::concat_batches(&state.input_schema, &state.accumulated)
                        .map_err(E::from)?;
                state.accumulated.clear();
                b
            };
            state.acc_bytes = 0;

            // Slice the batch into ~max_bytes pieces assuming uniform row sizes.
            let batch_bytes = batch.get_array_memory_size();
            let num_rows = batch.num_rows();
            if batch_bytes <= state.max_bytes || num_rows <= 1 {
                Ok(Some((batch, state)))
            } else {
                // u64 math avoids overflow of `max_bytes * num_rows` on
                // 32-bit targets; `.max(1)` guarantees forward progress.
                let rows_per_chunk =
                    (state.max_bytes as u64 * num_rows as u64 / batch_bytes as u64).max(1) as usize;
                let mut slices = Vec::new();
                let mut offset = 0;
                while offset < num_rows {
                    let len = rows_per_chunk.min(num_rows - offset);
                    slices.push(batch.slice(offset, len));
                    offset += len;
                }

                let first = slices.remove(0);

                // Stash leftover slices for subsequent iterations; the fast
                // path above will emit them one by one without copying.
                for a in &slices {
                    state.acc_bytes += a.get_array_memory_size();
                }
                state.accumulated = slices;

                Ok(Some((first, state)))
            }
        },
    )
}

/// Internal state for [`rechunk_stream_by_size`].
///
/// Kept as a named struct so the `try_unfold` closure stays readable.
struct RechunkState<S> {
    /// Upstream batch source, pinned so it can be polled across awaits.
    input: Pin<Box<S>>,
    /// Batches (or leftover slices from a prior slicing pass) awaiting
    /// emission or concatenation.
    accumulated: Vec<RecordBatch>,
    /// Running total of `get_array_memory_size` for `accumulated`.
    acc_bytes: usize,
    /// True once the input stream has returned `None`.
    done: bool,
    /// Schema used when concatenating accumulated batches.
    input_schema: SchemaRef,
    /// Accumulate at least this many bytes before emitting a batch.
    min_bytes: usize,
    /// Batches larger than this are sliced into ~`max_bytes` pieces.
    max_bytes: usize,
}

#[cfg(test)]
mod tests {
    use super::*;

    use std::sync::Arc;

    use arrow_array::Int32Array;
    use arrow_schema::{DataType, Field, Schema};
    use futures::executor::block_on;

    /// Builds a single-column Int32 batch containing values `0..num_rows`.
    fn make_batch(num_rows: usize) -> RecordBatch {
        let schema = test_schema();
        let values: Vec<i32> = (0..num_rows as i32).collect();
        RecordBatch::try_new(schema, vec![Arc::new(Int32Array::from(values))]).unwrap()
    }

    /// Single non-nullable Int32 column "a" — shared by all tests.
    fn test_schema() -> SchemaRef {
        Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]))
    }

    /// Runs `rechunk_stream_by_size` over `batches` on a blocking executor
    /// and collects the output, panicking on any stream error.
    fn collect_rechunked(
        batches: Vec<RecordBatch>,
        min_bytes: usize,
        max_bytes: usize,
    ) -> Vec<RecordBatch> {
        let input = stream::iter(batches.into_iter().map(Ok::<_, ArrowError>));
        let rechunked = rechunk_stream_by_size(input, test_schema(), min_bytes, max_bytes);
        block_on(rechunked.collect::<Vec<_>>())
            .into_iter()
            .map(|r| r.unwrap())
            .collect()
    }

    /// Sums row counts across batches, for conservation-of-rows checks.
    fn total_rows(batches: &[RecordBatch]) -> usize {
        batches.iter().map(|b| b.num_rows()).sum()
    }

    #[test]
    fn test_empty_stream() {
        // An empty input must produce an empty output, not a panic or hang.
        let result = collect_rechunked(vec![], 100, 200);
        assert!(result.is_empty());
    }

    #[test]
    fn test_single_batch_passthrough() {
        let batch = make_batch(100);
        let bytes = batch.get_array_memory_size();
        // Batch is between min and max — should pass through as-is.
        let result = collect_rechunked(vec![batch], bytes / 2, bytes * 2);
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].num_rows(), 100);
    }

    #[test]
    fn test_small_batches_concatenated() {
        let one_batch_bytes = make_batch(10).get_array_memory_size();
        let batches: Vec<_> = (0..8).map(|_| make_batch(10)).collect();
        // min = 5 batches worth, max = 10 batches worth.
        let result = collect_rechunked(batches, one_batch_bytes * 5, one_batch_bytes * 10);
        assert_eq!(total_rows(&result), 80);
        // Should have been concatenated into fewer batches than the 8 inputs.
        assert!(
            result.len() < 8,
            "expected fewer output batches, got {}",
            result.len()
        );
    }

    #[test]
    fn test_large_batch_sliced() {
        // A single batch at 4x max_bytes should come out in >= 4 pieces.
        let batch = make_batch(1000);
        let bytes = batch.get_array_memory_size();
        let result = collect_rechunked(vec![batch], bytes / 8, bytes / 4);
        assert_eq!(total_rows(&result), 1000);
        assert!(
            result.len() >= 4,
            "expected at least 4 slices, got {}",
            result.len()
        );
    }

    #[test]
    fn test_sliced_leftovers_are_not_recombined() {
        // Key test for the fast-path optimisation. When a large batch is
        // sliced, leftover slices should be delivered one-at-a-time without
        // being concatenated back together. We verify this by checking that
        // every output buffer pointer falls inside the original batch's
        // allocation (i.e. they are all zero-copy slices, not fresh copies).
        let batch = make_batch(1000);
        let bytes = batch.get_array_memory_size();
        let orig_data = batch.column(0).to_data();
        let orig_buf = &orig_data.buffers()[0];
        let orig_start = orig_buf.as_ptr() as usize;
        let orig_end = orig_start + orig_buf.len();

        let result = collect_rechunked(vec![batch], bytes / 8, bytes / 4);

        assert_eq!(total_rows(&result), 1000);
        assert!(result.len() >= 4);

        for (i, b) in result.iter().enumerate() {
            let ptr = b.column(0).to_data().buffers()[0].as_ptr() as usize;
            assert!(
                ptr >= orig_start && ptr < orig_end,
                "slice {i} buffer at {ptr:#x} is outside the original allocation \
                 [{orig_start:#x}, {orig_end:#x}) — it was re-concatenated"
            );
        }
    }

    #[test]
    fn test_flush_remainder_on_stream_end() {
        // Data below min_bytes should still be flushed when the stream ends.
        let batch = make_batch(10);
        let bytes = batch.get_array_memory_size();
        let result = collect_rechunked(vec![batch], bytes * 100, bytes * 200);
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].num_rows(), 10);
    }

    #[test]
    fn test_large_then_small_batches() {
        // After a large batch is fully drained, subsequent small batches
        // should be accumulated normally.
        let large = make_batch(1000);
        let small_bytes = make_batch(10).get_array_memory_size();
        let batches = vec![
            large,
            make_batch(10),
            make_batch(10),
            make_batch(10),
            make_batch(10),
            make_batch(10),
        ];
        let result = collect_rechunked(batches, small_bytes * 3, small_bytes * 100);
        assert_eq!(total_rows(&result), 1050);
        // The large batch should appear (possibly sliced) followed by
        // concatenated small batches, so we should have fewer output batches
        // than the 6 inputs.
        assert!(result.len() < 6);
    }

    #[test]
    fn test_row_preservation_across_slicing() {
        // Verify that every input row appears exactly once in the output
        // and in the correct order after slicing.
        let batch = make_batch(237); // odd count to exercise remainder slice
        let bytes = batch.get_array_memory_size();
        let result = collect_rechunked(vec![batch], bytes / 8, bytes / 5);

        assert_eq!(total_rows(&result), 237);

        let values: Vec<i32> = result
            .iter()
            .flat_map(|b| {
                b.column(0)
                    .as_any()
                    .downcast_ref::<Int32Array>()
                    .unwrap()
                    .values()
                    .iter()
                    .copied()
            })
            .collect();
        let expected: Vec<i32> = (0..237).collect();
        assert_eq!(values, expected);
    }

    #[test]
    fn test_error_propagation() {
        // An Err mid-stream must surface to the caller rather than being
        // swallowed by the accumulation loop.
        let input = stream::iter(vec![
            Ok(make_batch(10)),
            Err(ArrowError::ComputeError("boom".into())),
            Ok(make_batch(10)),
        ]);
        let rechunked = rechunk_stream_by_size(input, test_schema(), 1, usize::MAX);
        let results: Vec<Result<RecordBatch, ArrowError>> = block_on(rechunked.collect());
        assert!(results.iter().any(|r| r.is_err()));
    }
}
Loading