perf: speedup flat fts #6054
Merged: westonpace merged 8 commits into lance-format:main from westonpace:perf/speedup-flat-fts on Mar 3, 2026.
Changes from all commits (8 commits):
2bfccad  Various perf improvements to flat FTS search (westonpace)
7dfee5e  Rework according to PR suggestions (westonpace)
906b17e  Cleanup to get tests passing (westonpace)
690d8b2  Clippy suggestions (westonpace)
4facb4b  Revert change in func sig (westonpace)
01e3db0  Clippy suggestions and test case fixes (westonpace)
0e462e4  Address review comments (westonpace)
2bdcdce  Missed issues during rebase (westonpace)
@@ -0,0 +1,305 @@

```rust
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The Lance Authors

//! Utilities for working with streams of [`RecordBatch`].

use arrow_array::RecordBatch;
use arrow_schema::{ArrowError, SchemaRef};
use futures::stream::{self, Stream, StreamExt};
use std::pin::Pin;
/// Rechunks a stream of [`RecordBatch`] so that each output batch holds
/// between roughly `min_bytes` and `max_bytes` of array data.
///
/// Small input batches are accumulated (by concatenation) until at least
/// `min_bytes` of data has been collected. If the resulting batch exceeds
/// `max_bytes`, it is sliced into roughly equal pieces of ~`max_bytes`
/// (assuming uniform row sizes).
pub fn rechunk_stream_by_size<S, E>(
    input: S,
    input_schema: SchemaRef,
    min_bytes: usize,
    max_bytes: usize,
) -> impl Stream<Item = Result<RecordBatch, E>>
where
    S: Stream<Item = Result<RecordBatch, E>>,
    E: From<ArrowError>,
{
    stream::try_unfold(
        RechunkState {
            input: Box::pin(input),
            accumulated: Vec::new(),
            acc_bytes: 0,
            done: false,
            input_schema,
            min_bytes,
            max_bytes,
        },
        |mut state| async move {
            if state.done && state.accumulated.is_empty() {
                return Ok(None);
            }

            // Pull batches until we reach the byte target or exhaust input.
            while !state.done && state.acc_bytes < state.min_bytes {
                match state.input.next().await {
                    Some(Ok(batch)) => {
                        state.acc_bytes += batch.get_array_memory_size();
                        state.accumulated.push(batch);
                    }
                    Some(Err(e)) => return Err(e),
                    None => {
                        state.done = true;
                    }
                }
            }

            if state.accumulated.is_empty() {
                return Ok(None);
            }

            // Fast path: if the first accumulated batch already meets the
            // byte threshold, deliver it directly instead of concatenating
            // everything together (which would just get sliced back apart).
            if state.accumulated.len() > 1
                && state.accumulated[0].get_array_memory_size() >= state.min_bytes
            {
                let b = state.accumulated.remove(0);
                state.acc_bytes -= b.get_array_memory_size();
                return Ok(Some((b, state)));
            }

            let batch = if state.accumulated.len() == 1 {
                state.accumulated.pop().unwrap()
            } else {
                let b =
                    arrow_select::concat::concat_batches(&state.input_schema, &state.accumulated)
                        .map_err(E::from)?;
                state.accumulated.clear();
                b
            };
            state.acc_bytes = 0;

            // Slice the batch into ~max_bytes pieces assuming uniform row sizes.
            let batch_bytes = batch.get_array_memory_size();
            let num_rows = batch.num_rows();
            if batch_bytes <= state.max_bytes || num_rows <= 1 {
                Ok(Some((batch, state)))
            } else {
                let rows_per_chunk =
                    (state.max_bytes as u64 * num_rows as u64 / batch_bytes as u64).max(1) as usize;
                let mut slices = Vec::new();
                let mut offset = 0;
                while offset < num_rows {
                    let len = rows_per_chunk.min(num_rows - offset);
                    slices.push(batch.slice(offset, len));
                    offset += len;
                }

                let first = slices.remove(0);

                // Stash leftover slices for subsequent iterations.
                for a in &slices {
                    state.acc_bytes += a.get_array_memory_size();
                }
                state.accumulated = slices;

                Ok(Some((first, state)))
            }
        },
    )
}

/// Internal state for [`rechunk_stream_by_size`].
///
/// Kept as a named struct so the `try_unfold` closure stays readable.
struct RechunkState<S> {
    input: Pin<Box<S>>,
    accumulated: Vec<RecordBatch>,
    acc_bytes: usize,
    done: bool,
    input_schema: SchemaRef,
    min_bytes: usize,
    max_bytes: usize,
}

#[cfg(test)]
mod tests {
    use super::*;

    use std::sync::Arc;

    use arrow_array::Int32Array;
    use arrow_schema::{DataType, Field, Schema};
    use futures::executor::block_on;

    fn make_batch(num_rows: usize) -> RecordBatch {
        let schema = test_schema();
        let values: Vec<i32> = (0..num_rows as i32).collect();
        RecordBatch::try_new(schema, vec![Arc::new(Int32Array::from(values))]).unwrap()
    }

    fn test_schema() -> SchemaRef {
        Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]))
    }

    fn collect_rechunked(
        batches: Vec<RecordBatch>,
        min_bytes: usize,
        max_bytes: usize,
    ) -> Vec<RecordBatch> {
        let input = stream::iter(batches.into_iter().map(Ok::<_, ArrowError>));
        let rechunked = rechunk_stream_by_size(input, test_schema(), min_bytes, max_bytes);
        block_on(rechunked.collect::<Vec<_>>())
            .into_iter()
            .map(|r| r.unwrap())
            .collect()
    }

    fn total_rows(batches: &[RecordBatch]) -> usize {
        batches.iter().map(|b| b.num_rows()).sum()
    }

    #[test]
    fn test_empty_stream() {
        let result = collect_rechunked(vec![], 100, 200);
        assert!(result.is_empty());
    }

    #[test]
    fn test_single_batch_passthrough() {
        let batch = make_batch(100);
        let bytes = batch.get_array_memory_size();
        // Batch is between min and max — should pass through as-is.
        let result = collect_rechunked(vec![batch], bytes / 2, bytes * 2);
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].num_rows(), 100);
    }

    #[test]
    fn test_small_batches_concatenated() {
        let one_batch_bytes = make_batch(10).get_array_memory_size();
        let batches: Vec<_> = (0..8).map(|_| make_batch(10)).collect();
        // min = 5 batches worth, max = 10 batches worth.
        let result = collect_rechunked(batches, one_batch_bytes * 5, one_batch_bytes * 10);
        assert_eq!(total_rows(&result), 80);
        // Should have been concatenated into fewer batches than the 8 inputs.
        assert!(
            result.len() < 8,
            "expected fewer output batches, got {}",
            result.len()
        );
    }

    #[test]
    fn test_large_batch_sliced() {
        let batch = make_batch(1000);
        let bytes = batch.get_array_memory_size();
        let result = collect_rechunked(vec![batch], bytes / 8, bytes / 4);
        assert_eq!(total_rows(&result), 1000);
        assert!(
            result.len() >= 4,
            "expected at least 4 slices, got {}",
            result.len()
        );
    }

    #[test]
    fn test_sliced_leftovers_are_not_recombined() {
        // Key test for the fast-path optimisation. When a large batch is
        // sliced, leftover slices should be delivered one-at-a-time without
        // being concatenated back together. We verify this by checking that
        // every output buffer pointer falls inside the original batch's
        // allocation (i.e. they are all zero-copy slices, not fresh copies).
        let batch = make_batch(1000);
        let bytes = batch.get_array_memory_size();
        let orig_data = batch.column(0).to_data();
        let orig_buf = &orig_data.buffers()[0];
        let orig_start = orig_buf.as_ptr() as usize;
        let orig_end = orig_start + orig_buf.len();

        let result = collect_rechunked(vec![batch], bytes / 8, bytes / 4);

        assert_eq!(total_rows(&result), 1000);
        assert!(result.len() >= 4);

        for (i, b) in result.iter().enumerate() {
            let ptr = b.column(0).to_data().buffers()[0].as_ptr() as usize;
            assert!(
                ptr >= orig_start && ptr < orig_end,
                "slice {i} buffer at {ptr:#x} is outside the original allocation \
                 [{orig_start:#x}, {orig_end:#x}) — it was re-concatenated"
            );
        }
    }

    #[test]
    fn test_flush_remainder_on_stream_end() {
        // Data below min_bytes should still be flushed when the stream ends.
        let batch = make_batch(10);
        let bytes = batch.get_array_memory_size();
        let result = collect_rechunked(vec![batch], bytes * 100, bytes * 200);
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].num_rows(), 10);
    }

    #[test]
    fn test_large_then_small_batches() {
        // After a large batch is fully drained, subsequent small batches
        // should be accumulated normally.
        let large = make_batch(1000);
        let small_bytes = make_batch(10).get_array_memory_size();
        let batches = vec![
            large,
            make_batch(10),
            make_batch(10),
            make_batch(10),
            make_batch(10),
            make_batch(10),
        ];
        let result = collect_rechunked(batches, small_bytes * 3, small_bytes * 100);
        assert_eq!(total_rows(&result), 1050);
        // The large batch should appear (possibly sliced) followed by
        // concatenated small batches, so we should have fewer output batches
        // than the 6 inputs.
        assert!(result.len() < 6);
    }

    #[test]
    fn test_row_preservation_across_slicing() {
        // Verify that every input row appears exactly once in the output
        // and in the correct order after slicing.
        let batch = make_batch(237); // odd count to exercise remainder slice
        let bytes = batch.get_array_memory_size();
        let result = collect_rechunked(vec![batch], bytes / 8, bytes / 5);

        assert_eq!(total_rows(&result), 237);

        let values: Vec<i32> = result
            .iter()
            .flat_map(|b| {
                b.column(0)
                    .as_any()
                    .downcast_ref::<Int32Array>()
                    .unwrap()
                    .values()
                    .iter()
                    .copied()
            })
            .collect();
        let expected: Vec<i32> = (0..237).collect();
        assert_eq!(values, expected);
    }

    #[test]
    fn test_error_propagation() {
        let input = stream::iter(vec![
            Ok(make_batch(10)),
            Err(ArrowError::ComputeError("boom".into())),
            Ok(make_batch(10)),
        ]);
        let rechunked = rechunk_stream_by_size(input, test_schema(), 1, usize::MAX);
        let results: Vec<Result<RecordBatch, ArrowError>> = block_on(rechunked.collect());
        assert!(results.iter().any(|r| r.is_err()));
    }
}
```
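For reference, a minimal sketch of how the new utility might be driven end to end. This is not code from the PR: it assumes `rechunk_stream_by_size` from the file above is in scope, and the 64 KiB / 256 KiB thresholds are illustrative placeholders, not values the PR uses.

```rust
use std::sync::Arc;

use arrow_array::{Int32Array, RecordBatch};
use arrow_schema::{ArrowError, DataType, Field, Schema};
use futures::{executor::block_on, stream, StreamExt};

// Hypothetical driver; assumes `rechunk_stream_by_size` is in scope.
fn main() -> Result<(), ArrowError> {
    let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));

    // Many tiny batches, e.g. as a flat FTS scan over small fragments might yield.
    let input = stream::iter((0..100).map(|i| {
        RecordBatch::try_new(schema.clone(), vec![Arc::new(Int32Array::from(vec![i; 128]))])
    }));

    // Illustrative thresholds: coalesce to at least 64 KiB, slice above 256 KiB.
    let rechunked = rechunk_stream_by_size(input, schema.clone(), 64 * 1024, 256 * 1024);
    for batch in block_on(rechunked.collect::<Vec<_>>()) {
        let batch = batch?;
        println!("{} rows, ~{} bytes", batch.num_rows(), batch.get_array_memory_size());
    }
    Ok(())
}
```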
Review comment:

The logic here appears to be quite sensitive to `max_bytes`. Could our data in `state.accumulated` be repeatedly concatenated and sliced? For example, if `max_bytes` is 1 MiB and we get 64 MiB of data, the data inside `state.accumulated` would go 64 -> 63 -> 62 ... -> 2 -> 1. Should we maintain an offset into the sliced data and make sure we only concat once on the raw input?

Reply from westonpace:

Good catch. I've modified the code so it only concatenates if the first slice is not large enough (smaller than `min_bytes`). Since `min_bytes` should be sufficiently less than `max_bytes`, I think we should be good in most cases. In the event we have some really large outlier row and, as a result, slice inappropriately, we may still concatenate, but I think that's enough of an outlier for the moment.
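To make that exchange concrete, here is a small sketch of the slicing arithmetic. The 1 MiB / 64 MiB figures are the reviewer's hypothetical scenario; the row count and `min_bytes` choice are illustrative assumptions. It shows why one concat plus one slicing pass suffices after the fix: every ~`max_bytes` slice also clears `min_bytes`, so the fast path drains the stashed slices one at a time rather than re-concatenating them.

```rust
// Hypothetical walk-through of the slicing math in rechunk_stream_by_size,
// assuming uniform row sizes (as the PR's doc comment does).
fn main() {
    let max_bytes: u64 = 1 << 20; // 1 MiB, the reviewer's example
    let min_bytes: u64 = max_bytes / 2; // "sufficiently less than max_bytes"
    let batch_bytes: u64 = 64 << 20; // one 64 MiB concatenated batch
    let num_rows: u64 = 1_000_000; // illustrative row count

    // Same formula as the PR: rows per ~max_bytes slice.
    let rows_per_chunk = (max_bytes * num_rows / batch_bytes).max(1);
    let num_slices = num_rows.div_ceil(rows_per_chunk);
    let slice_bytes = batch_bytes / num_slices;

    println!("rows_per_chunk = {rows_per_chunk}"); // 15625
    println!("num_slices = {num_slices}"); // 64

    // Each ~1 MiB slice meets min_bytes, so the fast path returns the
    // stashed slices directly; no repeated concat-and-slice cycle.
    assert!(slice_bytes >= min_bytes);
}
```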