diff --git a/AGENTS.md b/AGENTS.md new file mode 120000 index 00000000000..681311eb9cf --- /dev/null +++ b/AGENTS.md @@ -0,0 +1 @@ +CLAUDE.md \ No newline at end of file diff --git a/Cargo.lock b/Cargo.lock index 61c277226dd..6ba87667f10 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -10868,6 +10868,7 @@ dependencies = [ "codspeed-divan-compat", "itertools 0.14.0", "num-traits", + "parking_lot", "prost 0.14.3", "rand 0.10.0", "rstest", @@ -10897,6 +10898,7 @@ dependencies = [ name = "vortex-sequence" version = "0.1.0" dependencies = [ + "codspeed-divan-compat", "itertools 0.14.0", "num-traits", "prost 0.14.3", diff --git a/encodings/runend/Cargo.toml b/encodings/runend/Cargo.toml index 01a5b8d7a3e..a45f5528d69 100644 --- a/encodings/runend/Cargo.toml +++ b/encodings/runend/Cargo.toml @@ -18,6 +18,7 @@ arbitrary = { workspace = true, optional = true } arrow-array = { workspace = true, optional = true } itertools = { workspace = true } num-traits = { workspace = true } +parking_lot = { workspace = true } prost = { workspace = true } vortex-array = { workspace = true } vortex-buffer = { workspace = true } @@ -52,3 +53,15 @@ harness = false [[bench]] name = "run_end_decode" harness = false + +[[bench]] +name = "run_end_scalar_fn" +harness = false + +[[bench]] +name = "run_end_take" +harness = false + +[[bench]] +name = "run_end_filter" +harness = false diff --git a/encodings/runend/benches/run_end_filter.rs b/encodings/runend/benches/run_end_filter.rs new file mode 100644 index 00000000000..f4caba3562b --- /dev/null +++ b/encodings/runend/benches/run_end_filter.rs @@ -0,0 +1,270 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +#![allow(clippy::cast_possible_truncation, clippy::unwrap_used)] + +use std::fmt; + +use divan::Bencher; +use rand::SeedableRng; +use rand::rngs::StdRng; +use rand::seq::SliceRandom; +use vortex_array::ArrayRef; +use vortex_array::IntoArray; +use vortex_array::LEGACY_SESSION; +use 
vortex_array::VortexSessionExecute; +use vortex_array::arrays::PrimitiveArray; +use vortex_buffer::Buffer; +use vortex_mask::Mask; +use vortex_runend::_benchmarking::RunEndFilterMode; +use vortex_runend::_benchmarking::override_run_end_filter_mode; +use vortex_runend::RunEnd; + +fn main() { + divan::main(); +} + +const LEN: usize = 1_048_576; +const TRUE_COUNT: usize = 32_768; +const LONG_SLICE_COUNT: usize = 8; +const SHORT_SLICE_LEN: usize = 8; +const CLUSTER_COUNT: usize = 8; +const LONG_RUN_HEAVY_SLICE_COUNT: usize = 4; + +#[derive(Clone, Copy, Debug)] +enum MaskShape { + Random, + FewLongSlices, + ManyShortSlices, + ClusteredFewRuns, + LongRunHeavy, +} + +impl fmt::Display for MaskShape { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Random => write!(f, "random"), + Self::FewLongSlices => write!(f, "few_long_slices"), + Self::ManyShortSlices => write!(f, "many_short_slices"), + Self::ClusteredFewRuns => write!(f, "clustered_few_runs"), + Self::LongRunHeavy => write!(f, "long_run_heavy"), + } + } +} + +#[derive(Clone, Copy, Debug)] +struct BenchArgs { + run_length: usize, + mask_shape: MaskShape, +} + +impl fmt::Display for BenchArgs { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}_runs_{}", self.mask_shape, self.run_length) + } +} + +const BENCH_ARGS: &[BenchArgs] = &[ + BenchArgs { + run_length: 16, + mask_shape: MaskShape::Random, + }, + BenchArgs { + run_length: 256, + mask_shape: MaskShape::Random, + }, + BenchArgs { + run_length: 4096, + mask_shape: MaskShape::Random, + }, + BenchArgs { + run_length: 16, + mask_shape: MaskShape::FewLongSlices, + }, + BenchArgs { + run_length: 256, + mask_shape: MaskShape::FewLongSlices, + }, + BenchArgs { + run_length: 4096, + mask_shape: MaskShape::FewLongSlices, + }, + BenchArgs { + run_length: 16, + mask_shape: MaskShape::ManyShortSlices, + }, + BenchArgs { + run_length: 256, + mask_shape: MaskShape::ManyShortSlices, + }, + BenchArgs { + run_length: 
4096, + mask_shape: MaskShape::ManyShortSlices, + }, + BenchArgs { + run_length: 16, + mask_shape: MaskShape::ClusteredFewRuns, + }, + BenchArgs { + run_length: 256, + mask_shape: MaskShape::ClusteredFewRuns, + }, + BenchArgs { + run_length: 4096, + mask_shape: MaskShape::ClusteredFewRuns, + }, + BenchArgs { + run_length: 4096, + mask_shape: MaskShape::LongRunHeavy, + }, + BenchArgs { + run_length: 16_384, + mask_shape: MaskShape::LongRunHeavy, + }, + BenchArgs { + run_length: 65_536, + mask_shape: MaskShape::LongRunHeavy, + }, +]; + +#[divan::bench(args = BENCH_ARGS)] +fn filter_auto(bencher: Bencher, args: BenchArgs) { + filter_with_mode(bencher, args, RunEndFilterMode::Auto); +} + +#[divan::bench(args = BENCH_ARGS)] +fn filter_force_take(bencher: Bencher, args: BenchArgs) { + filter_with_mode(bencher, args, RunEndFilterMode::Take); +} + +#[divan::bench(args = BENCH_ARGS)] +fn filter_force_encoded(bencher: Bencher, args: BenchArgs) { + filter_with_mode(bencher, args, RunEndFilterMode::Encoded); +} + +fn filter_with_mode(bencher: Bencher, args: BenchArgs, filter_mode: RunEndFilterMode) { + let array = run_end_fixture(args.run_length); + let mask = mask_fixture(args.mask_shape, args.run_length); + + bencher + .with_inputs(|| { + ( + array.clone(), + mask.clone(), + LEGACY_SESSION.create_execution_ctx(), + ) + }) + .bench_refs(|(array, mask, ctx)| { + let _filter_mode_guard = override_run_end_filter_mode(filter_mode); + let result = array + .filter(mask.clone()) + .unwrap() + .execute::(ctx) + .unwrap(); + divan::black_box(result); + }); +} + +fn run_end_fixture(run_length: usize) -> ArrayRef { + let run_count = LEN.div_ceil(run_length); + let ends = (0..run_count) + .map(|run_idx| ((run_idx + 1) * run_length).min(LEN) as u32) + .collect::>() + .into_array(); + let values = + PrimitiveArray::from_iter((0..run_count).map(|run_idx| run_idx as i32)).into_array(); + + RunEnd::new(ends, values).into_array() +} + +fn mask_fixture(mask_shape: MaskShape, run_length: usize) 
-> Mask { + match mask_shape { + MaskShape::Random => random_mask(run_length), + MaskShape::FewLongSlices => few_long_slices_mask(run_length), + MaskShape::ManyShortSlices => many_short_slices_mask(run_length), + MaskShape::ClusteredFewRuns => clustered_few_runs_mask(run_length), + MaskShape::LongRunHeavy => long_run_heavy_mask(run_length), + } +} + +fn random_mask(run_length: usize) -> Mask { + let mut rng = StdRng::seed_from_u64(run_length as u64); + let mut indices = (0..LEN).collect::>(); + indices.shuffle(&mut rng); + indices.truncate(TRUE_COUNT); + indices.sort_unstable(); + + Mask::from_indices(LEN, indices) +} + +fn few_long_slices_mask(run_length: usize) -> Mask { + let slice_len = TRUE_COUNT / LONG_SLICE_COUNT; + let spacing = LEN / LONG_SLICE_COUNT; + let misalignment = (run_length / 2).min(slice_len / 2); + let slices = (0..LONG_SLICE_COUNT) + .map(|slice_idx| { + let start = slice_idx * spacing + misalignment; + (start, start + slice_len) + }) + .collect(); + + Mask::from_slices(LEN, slices) +} + +fn many_short_slices_mask(run_length: usize) -> Mask { + let slice_count = TRUE_COUNT / SHORT_SLICE_LEN; + let spacing = LEN / slice_count; + let misalignment = (run_length / 4).min(spacing - SHORT_SLICE_LEN); + let slices = (0..slice_count) + .map(|slice_idx| { + let start = slice_idx * spacing + misalignment; + (start, start + SHORT_SLICE_LEN) + }) + .collect(); + + Mask::from_slices(LEN, slices) +} + +fn clustered_few_runs_mask(run_length: usize) -> Mask { + let run_count = LEN.div_ceil(run_length); + let runs_to_keep = TRUE_COUNT.div_ceil(run_length); + let cluster_count = runs_to_keep.min(CLUSTER_COUNT); + let base_cluster_runs = runs_to_keep / cluster_count; + let extra_cluster_runs = runs_to_keep % cluster_count; + let spacing = run_count / cluster_count; + + let mut next_start_run = 0usize; + let slices = (0..cluster_count) + .map(|cluster_idx| { + let cluster_runs = base_cluster_runs + usize::from(cluster_idx < extra_cluster_runs); + let start_run = 
next_start_run; + let end_run = (start_run + cluster_runs).min(run_count); + next_start_run += spacing; + + let start = start_run * run_length; + let end = (end_run * run_length).min(LEN); + (start, end) + }) + .collect(); + + Mask::from_slices(LEN, slices) +} + +fn long_run_heavy_mask(run_length: usize) -> Mask { + let run_count = LEN.div_ceil(run_length); + let slice_count = LONG_RUN_HEAVY_SLICE_COUNT.min(run_count); + let slice_len = (run_length * 3) / 4; + let misalignment = (run_length - slice_len).min(13); + let spacing = run_count / slice_count; + + let slices = (0..slice_count) + .map(|slice_idx| { + let start_run = slice_idx * spacing; + let start = start_run * run_length + misalignment; + let end = (start + slice_len).min(LEN); + (start, end) + }) + .collect(); + + Mask::from_slices(LEN, slices) +} diff --git a/encodings/runend/benches/run_end_scalar_fn.rs b/encodings/runend/benches/run_end_scalar_fn.rs new file mode 100644 index 00000000000..3fe7cf4dc73 --- /dev/null +++ b/encodings/runend/benches/run_end_scalar_fn.rs @@ -0,0 +1,150 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +#![allow(clippy::cast_possible_truncation, clippy::unwrap_used)] + +use std::fmt; + +use divan::Bencher; +use vortex_array::ArrayRef; +use vortex_array::IntoArray; +use vortex_array::LEGACY_SESSION; +use vortex_array::VortexSessionExecute; +use vortex_array::arrays::BoolArray; +use vortex_array::arrays::ConstantArray; +use vortex_array::builtins::ArrayBuiltins; +use vortex_array::dtype::Nullability; +use vortex_array::scalar::Scalar; +use vortex_buffer::Buffer; +use vortex_runend::RunEnd; + +fn main() { + divan::main(); +} + +const LEN: usize = 1_048_576; + +#[derive(Clone, Copy, Debug)] +enum OutputKind { + Utf8, + Binary, +} + +impl fmt::Display for OutputKind { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Utf8 => write!(f, "utf8"), + Self::Binary => write!(f, "binary"), + } + } 
+} + +#[derive(Clone, Copy, Debug)] +struct BenchArgs { + run_length: usize, + output_kind: OutputKind, +} + +impl fmt::Display for BenchArgs { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}_runs_{}", self.output_kind, self.run_length) + } +} + +const BENCH_ARGS: &[BenchArgs] = &[ + BenchArgs { + run_length: 16, + output_kind: OutputKind::Utf8, + }, + BenchArgs { + run_length: 256, + output_kind: OutputKind::Utf8, + }, + BenchArgs { + run_length: 4096, + output_kind: OutputKind::Utf8, + }, + BenchArgs { + run_length: 16, + output_kind: OutputKind::Binary, + }, + BenchArgs { + run_length: 256, + output_kind: OutputKind::Binary, + }, + BenchArgs { + run_length: 4096, + output_kind: OutputKind::Binary, + }, +]; + +#[divan::bench(args = BENCH_ARGS)] +fn zip_constants(bencher: Bencher, args: BenchArgs) { + let mask = bool_run_end_fixture(args.run_length); + let (if_true, if_false) = constants(args.output_kind); + + bencher + .with_inputs(|| { + ( + mask.clone(), + if_true.clone(), + if_false.clone(), + LEGACY_SESSION.create_execution_ctx(), + ) + }) + .bench_refs(|(mask, if_true, if_false, ctx)| { + let result = mask + .zip(if_true.clone(), if_false.clone()) + .unwrap() + .execute::(ctx) + .unwrap(); + divan::black_box(result); + }); +} + +fn bool_run_end_fixture(run_length: usize) -> ArrayRef { + let run_count = LEN.div_ceil(run_length); + let ends = (0..run_count) + .map(|run_idx| ((run_idx + 1) * run_length).min(LEN) as u32) + .collect::>() + .into_array(); + let values = + BoolArray::from_iter((0..run_count).map(|run_idx| run_idx.is_multiple_of(2))).into_array(); + + RunEnd::new(ends, values).into_array() +} + +fn constants(output_kind: OutputKind) -> (ArrayRef, ArrayRef) { + match output_kind { + OutputKind::Utf8 => ( + ConstantArray::new( + Scalar::utf8( + "runend branch with a long utf8 payload", + Nullability::NonNullable, + ), + LEN, + ) + .into_array(), + ConstantArray::new( + Scalar::utf8( + "runend branch with a different utf8 
payload", + Nullability::NonNullable, + ), + LEN, + ) + .into_array(), + ), + OutputKind::Binary => ( + ConstantArray::new( + Scalar::binary(vec![0xAA; 48], Nullability::NonNullable), + LEN, + ) + .into_array(), + ConstantArray::new( + Scalar::binary(vec![0x55; 64], Nullability::NonNullable), + LEN, + ) + .into_array(), + ), + } +} diff --git a/encodings/runend/benches/run_end_take.rs b/encodings/runend/benches/run_end_take.rs new file mode 100644 index 00000000000..85fccb8cf80 --- /dev/null +++ b/encodings/runend/benches/run_end_take.rs @@ -0,0 +1,167 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +#![allow(clippy::cast_possible_truncation, clippy::unwrap_used)] + +use std::fmt; + +use divan::Bencher; +use rand::RngExt; +use rand::SeedableRng; +use rand::rngs::StdRng; +use vortex_array::IntoArray; +use vortex_array::LEGACY_SESSION; +use vortex_array::RecursiveCanonical; +use vortex_array::VortexSessionExecute; +use vortex_array::arrays::PrimitiveArray; +use vortex_buffer::Buffer; +use vortex_runend::RunEnd; + +fn main() { + divan::main(); +} + +const LEN: usize = 1_048_576; +const TAKE_LEN: usize = 32_768; + +#[derive(Clone, Copy, Debug)] +enum IndexPattern { + SortedSparse, + Contiguous, + RandomUnsorted, + Nullable, +} + +impl fmt::Display for IndexPattern { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::SortedSparse => write!(f, "sorted_sparse"), + Self::Contiguous => write!(f, "contiguous"), + Self::RandomUnsorted => write!(f, "random_unsorted"), + Self::Nullable => write!(f, "nullable"), + } + } +} + +#[derive(Clone, Copy, Debug)] +struct BenchArgs { + run_length: usize, + index_pattern: IndexPattern, +} + +impl fmt::Display for BenchArgs { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}_runs_{}", self.index_pattern, self.run_length) + } +} + +const BENCH_ARGS: &[BenchArgs] = &[ + BenchArgs { + run_length: 16, + index_pattern: 
IndexPattern::SortedSparse, + }, + BenchArgs { + run_length: 256, + index_pattern: IndexPattern::SortedSparse, + }, + BenchArgs { + run_length: 4096, + index_pattern: IndexPattern::SortedSparse, + }, + BenchArgs { + run_length: 16, + index_pattern: IndexPattern::Contiguous, + }, + BenchArgs { + run_length: 256, + index_pattern: IndexPattern::Contiguous, + }, + BenchArgs { + run_length: 4096, + index_pattern: IndexPattern::Contiguous, + }, + BenchArgs { + run_length: 16, + index_pattern: IndexPattern::RandomUnsorted, + }, + BenchArgs { + run_length: 256, + index_pattern: IndexPattern::RandomUnsorted, + }, + BenchArgs { + run_length: 4096, + index_pattern: IndexPattern::RandomUnsorted, + }, + BenchArgs { + run_length: 16, + index_pattern: IndexPattern::Nullable, + }, + BenchArgs { + run_length: 256, + index_pattern: IndexPattern::Nullable, + }, + BenchArgs { + run_length: 4096, + index_pattern: IndexPattern::Nullable, + }, +]; + +#[divan::bench(args = BENCH_ARGS)] +fn take_indices(bencher: Bencher, args: BenchArgs) { + let array = run_end_fixture(args.run_length); + let indices = indices_fixture(args.index_pattern); + + bencher + .with_inputs(|| { + ( + array.clone(), + indices.clone(), + LEGACY_SESSION.create_execution_ctx(), + ) + }) + .bench_refs(|(array, indices, ctx)| { + let result = array + .take(indices.clone()) + .unwrap() + .execute::(ctx) + .unwrap(); + divan::black_box(result); + }); +} + +fn run_end_fixture(run_length: usize) -> vortex_array::ArrayRef { + let run_count = LEN.div_ceil(run_length); + let ends = (0..run_count) + .map(|run_idx| ((run_idx + 1) * run_length).min(LEN) as u32) + .collect::>() + .into_array(); + let values = + PrimitiveArray::from_iter((0..run_count).map(|run_idx| run_idx as i32)).into_array(); + + RunEnd::new(ends, values).into_array() +} + +fn indices_fixture(index_pattern: IndexPattern) -> vortex_array::ArrayRef { + match index_pattern { + IndexPattern::SortedSparse => { + let stride = LEN / TAKE_LEN; + 
PrimitiveArray::from_iter((0..TAKE_LEN).map(|idx| (idx * stride) as u32)).into_array() + } + IndexPattern::Contiguous => { + let start = (LEN - TAKE_LEN) / 2; + PrimitiveArray::from_iter((start..start + TAKE_LEN).map(|idx| idx as u32)).into_array() + } + IndexPattern::RandomUnsorted => { + let mut rng = StdRng::seed_from_u64(0); + PrimitiveArray::from_iter((0..TAKE_LEN).map(|_| rng.random_range(0..LEN as u32))) + .into_array() + } + IndexPattern::Nullable => { + let stride = LEN / TAKE_LEN; + PrimitiveArray::from_option_iter( + (0..TAKE_LEN).map(|idx| (!idx.is_multiple_of(8)).then_some((idx * stride) as u32)), + ) + .into_array() + } + } +} diff --git a/encodings/runend/src/compute/filter.rs b/encodings/runend/src/compute/filter.rs index e0ef5382d40..633a8dbc2f7 100644 --- a/encodings/runend/src/compute/filter.rs +++ b/encodings/runend/src/compute/filter.rs @@ -2,9 +2,13 @@ // SPDX-FileCopyrightText: Copyright the Vortex contributors use std::cmp::min; -use std::ops::AddAssign; +use std::sync::atomic::AtomicU8; +use std::sync::atomic::Ordering; use num_traits::AsPrimitive; +use num_traits::NumCast; +use parking_lot::Mutex; +use parking_lot::MutexGuard; use vortex_array::ArrayRef; use vortex_array::ArrayView; use vortex_array::ExecutionCtx; @@ -19,10 +23,31 @@ use vortex_buffer::buffer_mut; use vortex_error::VortexExpect; use vortex_error::VortexResult; use vortex_mask::Mask; +use vortex_mask::MaskValues; +use crate::_benchmarking::RunEndFilterMode; use crate::RunEnd; use crate::compute::take::take_indices_unchecked; -const FILTER_TAKE_THRESHOLD: f64 = 0.1; + +const FILTER_TAKE_MIN_TRUE_COUNT: usize = 25; +const FILTER_ENCODED_DENSITY_SHIFT: usize = 3; +const FILTER_ENCODED_MIN_TRUES_PER_RUN: usize = 32; +const FILTER_ENCODED_MAX_SLICE_COUNT: usize = 32; +const FILTER_ENCODED_MIN_AVG_SLICE_LEN: usize = 256; + +static FILTER_MODE_OVERRIDE: AtomicU8 = AtomicU8::new(RunEndFilterMode::Auto.as_u8()); +static FILTER_MODE_OVERRIDE_LOCK: Mutex<()> = Mutex::new(()); + 
+pub(crate) fn override_run_end_filter_mode(mode: RunEndFilterMode) -> impl Drop { + let lock = FILTER_MODE_OVERRIDE_LOCK.lock(); + let previous_mode = current_filter_mode(); + FILTER_MODE_OVERRIDE.store(mode.as_u8(), Ordering::SeqCst); + + RunEndFilterModeGuard { + previous_mode, + _lock: lock, + } +} impl FilterKernel for RunEnd { fn filter( @@ -34,70 +59,136 @@ impl FilterKernel for RunEnd { .values() .vortex_expect("FilterKernel precondition: mask is Mask::Values"); - let runs_ratio = mask_values.true_count() as f64 / array.ends().len() as f64; - - if runs_ratio < FILTER_TAKE_THRESHOLD || mask_values.true_count() < 25 { - Ok(Some(take_indices_unchecked( + match select_filter_path(array, mask_values) { + FilterPath::Take => Ok(Some(take_indices_unchecked( &array, mask_values.indices(), &Validity::NonNullable, - )?)) - } else { - let primitive_run_ends = array.ends().clone().execute::(ctx)?; - let (run_ends, values_mask) = - match_each_unsigned_integer_ptype!(primitive_run_ends.ptype(), |P| { - filter_run_end_primitive( - primitive_run_ends.as_slice::

(), - array.offset() as u64, - array.len() as u64, - mask_values.bit_buffer(), - )? - }); - let values = array.values().filter(values_mask)?; - - // SAFETY: guaranteed by implementation of filter_run_end_primitive - unsafe { - Ok(Some( - RunEnd::new_unchecked( - run_ends.into_array(), - values, - 0, - mask_values.true_count(), - ) - .into_array(), - )) + )?)), + FilterPath::Encoded => { + let primitive_run_ends = array.ends().clone().execute::(ctx)?; + let (run_ends, values_mask) = + match_each_unsigned_integer_ptype!(primitive_run_ends.ptype(), |P| { + filter_run_end_primitive( + primitive_run_ends.as_slice::

(), + array.offset(), + array.len(), + mask_values.bit_buffer(), + )? + }); + let values = array.values().filter(values_mask)?; + + // SAFETY: guaranteed by implementation of filter_run_end_primitive + unsafe { + Ok(Some( + RunEnd::new_unchecked( + run_ends.into_array(), + values, + 0, + mask_values.true_count(), + ) + .into_array(), + )) + } } } } } -// Code adapted from apache arrow-rs https://github.com/apache/arrow-rs/blob/b1f5c250ebb6c1252b4e7c51d15b8e77f4c361fa/arrow-select/src/filter.rs#L425 -fn filter_run_end_primitive + AsPrimitive>( +impl RunEndFilterMode { + const fn as_u8(self) -> u8 { + match self { + Self::Auto => 0, + Self::Take => 1, + Self::Encoded => 2, + } + } + + const fn from_u8(value: u8) -> Self { + match value { + 1 => Self::Take, + 2 => Self::Encoded, + _ => Self::Auto, + } + } +} + +struct RunEndFilterModeGuard { + previous_mode: RunEndFilterMode, + _lock: MutexGuard<'static, ()>, +} + +impl Drop for RunEndFilterModeGuard { + fn drop(&mut self) { + FILTER_MODE_OVERRIDE.store(self.previous_mode.as_u8(), Ordering::SeqCst); + } +} + +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +enum FilterPath { + Take, + Encoded, +} + +fn current_filter_mode() -> RunEndFilterMode { + RunEndFilterMode::from_u8(FILTER_MODE_OVERRIDE.load(Ordering::SeqCst)) +} + +fn select_filter_path(array: ArrayView<'_, RunEnd>, mask_values: &MaskValues) -> FilterPath { + match current_filter_mode() { + RunEndFilterMode::Auto => auto_filter_path(array, mask_values), + RunEndFilterMode::Take => FilterPath::Take, + RunEndFilterMode::Encoded => FilterPath::Encoded, + } +} + +fn auto_filter_path(array: ArrayView<'_, RunEnd>, mask_values: &MaskValues) -> FilterPath { + let len = array.len(); + let run_count = array.ends().len(); + let true_count = mask_values.true_count(); + let slice_count = mask_values.slices().len(); + let average_slice_len = true_count.div_ceil(slice_count); + + if true_count < FILTER_TAKE_MIN_TRUE_COUNT { + return FilterPath::Take; + } + + let dense_selection 
= true_count.saturating_mul(1 << FILTER_ENCODED_DENSITY_SHIFT) >= len; + let localized_selection = true_count + >= run_count.saturating_mul(FILTER_ENCODED_MIN_TRUES_PER_RUN) + && slice_count <= FILTER_ENCODED_MAX_SLICE_COUNT + && average_slice_len >= FILTER_ENCODED_MIN_AVG_SLICE_LEN; + + if dense_selection || localized_selection { + FilterPath::Encoded + } else { + FilterPath::Take + } +} + +fn filter_run_end_primitive>( run_ends: &[R], - offset: u64, - length: u64, + offset: usize, + length: usize, mask: &BitBuffer, ) -> VortexResult<(PrimitiveArray, Mask)> { let mut new_run_ends = buffer_mut![R::zero(); run_ends.len()]; - let mut start = 0u64; + let mut start = 0usize; let mut j = 0; - let mut count = R::zero(); + let mut count = 0u64; let new_mask: Mask = BitBuffer::collect_bool(run_ends.len(), |i| { - let mut keep = false; let end = min(run_ends[i].as_() - offset, length); + let run_true_count = mask.slice(start..end).true_count() as u64; + let keep = run_true_count != 0; - // Safety: predicate must be the same length as the array the ends have been taken from - for pred in (start..end).map(|i| unsafe { - mask.value_unchecked(i.try_into().vortex_expect("index must fit in usize")) - }) { - count += >::from(pred); - keep |= pred + count += run_true_count; + if keep { + new_run_ends[j] = NumCast::from(count) + .vortex_expect("filtered run end count must fit in run-end type"); + j += 1; } - // this is to avoid branching - new_run_ends[j] = count; - j += keep as usize; start = end; keep @@ -113,12 +204,22 @@ fn filter_run_end_primitive + AsPrimitiv #[cfg(test)] mod tests { + #![allow(clippy::cast_possible_truncation)] + + use vortex_array::ArrayRef; use vortex_array::IntoArray; + use vortex_array::LEGACY_SESSION; + use vortex_array::VortexSessionExecute; use vortex_array::arrays::PrimitiveArray; use vortex_array::assert_arrays_eq; + use vortex_buffer::Buffer; use vortex_error::VortexResult; use vortex_mask::Mask; + use super::FilterPath; + use 
super::override_run_end_filter_mode; + use super::select_filter_path; + use crate::_benchmarking::RunEndFilterMode; use crate::RunEnd; use crate::RunEndArray; @@ -127,6 +228,78 @@ mod tests { .unwrap() } + fn run_end_fixture(run_length: usize, len: usize) -> ArrayRef { + let run_count = len.div_ceil(run_length); + let ends = (0..run_count) + .map(|run_idx| ((run_idx + 1) * run_length).min(len) as u32) + .collect::>() + .into_array(); + let values = + PrimitiveArray::from_iter((0..run_count).map(|run_idx| run_idx as i32)).into_array(); + + RunEnd::new(ends, values).into_array() + } + + fn run_end_offset_fixture( + run_length: usize, + total_len: usize, + offset: usize, + len: usize, + ) -> VortexResult { + let run_count = total_len.div_ceil(run_length); + let ends = (0..run_count) + .map(|run_idx| ((run_idx + 1) * run_length).min(total_len) as u32) + .collect::>() + .into_array(); + let values = + PrimitiveArray::from_iter((0..run_count).map(|run_idx| run_idx as i32)).into_array(); + + Ok(RunEnd::try_new_offset_length(ends, values, offset, len)?.into_array()) + } + + fn sparse_random_mask(len: usize, true_count: usize) -> Mask { + let mut indices = (0..true_count) + .map(|idx| (idx * 7_919) % len) + .collect::>(); + indices.sort_unstable(); + + Mask::from_indices(len, indices) + } + + fn sparse_clustered_mask(len: usize) -> Mask { + Mask::from_slices(len, vec![(1_024, 1_536), (8_192, 8_704)]) + } + + fn sparse_clustered_mask_for_slice(len: usize) -> Mask { + Mask::from_slices(len, vec![(1_024, 1_536), (6_144, 6_656)]) + } + + fn filter_path(array: &ArrayRef, mask: &Mask) -> FilterPath { + select_filter_path( + array.as_::(), + mask.values() + .expect("heuristic tests require a partial filter mask"), + ) + } + + fn filter_with_mode( + array: &ArrayRef, + mask: Mask, + mode: RunEndFilterMode, + ) -> VortexResult { + let _guard = override_run_end_filter_mode(mode); + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + array.filter(mask)?.execute::(&mut ctx) + } + + 
fn assert_encoded_filter_matches_take(array: &ArrayRef, mask: Mask) -> VortexResult<()> { + let take_filtered = filter_with_mode(array, mask.clone(), RunEndFilterMode::Take)?; + let encoded_filtered = filter_with_mode(array, mask, RunEndFilterMode::Encoded)?; + + assert_arrays_eq!(encoded_filtered, take_filtered); + Ok(()) + } + #[test] fn filter_sliced_run_end() -> VortexResult<()> { let arr = ree_array().slice(2..7).unwrap(); @@ -141,4 +314,73 @@ mod tests { ); Ok(()) } + + #[test] + fn heuristic_prefers_take_for_sparse_random_mask() -> VortexResult<()> { + let array = run_end_fixture(1_024, 16_384); + let mask = sparse_random_mask(array.len(), 1_024); + + assert_eq!(filter_path(&array, &mask), FilterPath::Take); + Ok(()) + } + + #[test] + fn heuristic_prefers_encoded_for_sparse_clustered_mask() -> VortexResult<()> { + let array = run_end_fixture(1_024, 16_384); + let mask = sparse_clustered_mask(array.len()); + + assert_eq!(filter_path(&array, &mask), FilterPath::Encoded); + Ok(()) + } + + #[test] + fn heuristic_prefers_take_for_sparse_random_mask_on_slice() -> VortexResult<()> { + let array = run_end_offset_fixture(1_024, 16_384, 1_024, 14_336)?; + let mask = sparse_random_mask(array.len(), 1_024); + + assert_eq!(filter_path(&array, &mask), FilterPath::Take); + Ok(()) + } + + #[test] + fn heuristic_prefers_encoded_for_sparse_clustered_mask_on_slice() -> VortexResult<()> { + let array = run_end_offset_fixture(1_024, 16_384, 1_024, 14_336)?; + let mask = sparse_clustered_mask_for_slice(array.len()); + + assert_eq!(filter_path(&array, &mask), FilterPath::Encoded); + Ok(()) + } + + #[test] + fn encoded_filter_matches_take_on_partial_word_boundaries() -> VortexResult<()> { + let array = run_end_fixture(65, 260); + let mask = Mask::from_slices( + array.len(), + vec![(3, 64), (67, 129), (133, 194), (197, 259)], + ); + + assert_encoded_filter_matches_take(&array, mask) + } + + #[test] + fn encoded_filter_matches_take_on_clustered_masks() -> VortexResult<()> { + let 
array = run_end_fixture(1_024, 16_384); + let mask = Mask::from_slices( + array.len(), + vec![(13, 513), (4_109, 4_733), (9_001, 10_129)], + ); + + assert_encoded_filter_matches_take(&array, mask) + } + + #[test] + fn encoded_filter_matches_take_on_very_short_runs() -> VortexResult<()> { + let array = run_end_fixture(1, 64); + let mask = Mask::from_slices( + array.len(), + vec![(1, 3), (5, 8), (13, 14), (21, 25), (34, 35)], + ); + + assert_encoded_filter_matches_take(&array, mask) + } } diff --git a/encodings/runend/src/compute/take.rs b/encodings/runend/src/compute/take.rs index 511b473c53b..95abb67bedb 100644 --- a/encodings/runend/src/compute/take.rs +++ b/encodings/runend/src/compute/take.rs @@ -18,38 +18,33 @@ use vortex_array::validity::Validity; use vortex_buffer::Buffer; use vortex_error::VortexResult; use vortex_error::vortex_bail; +use vortex_mask::AllOr; +use vortex_mask::Mask; use crate::RunEnd; use crate::RunEndData; +const SORTED_TAKE_MERGE_MIN_VALID_COUNT: usize = 64; +const UNSORTED_TAKE_SORT_MERGE_MIN_VALID_COUNT: usize = 8_192; +const UNSORTED_TAKE_SORT_MERGE_RUN_RATIO: usize = 16; + impl TakeExecute for RunEnd { - #[expect( - clippy::cast_possible_truncation, - reason = "index cast to usize inside macro" - )] fn take( array: ArrayView<'_, Self>, indices: &ArrayRef, ctx: &mut ExecutionCtx, ) -> VortexResult> { let primitive_indices = indices.clone().execute::(ctx)?; + let indices_validity = primitive_indices.validity()?; let checked_indices = match_each_integer_ptype!(primitive_indices.ptype(), |P| { - primitive_indices - .as_slice::

() - .iter() - .copied() - .map(|idx| { - let usize_idx = idx as usize; - if usize_idx >= array.len() { - vortex_bail!(OutOfBounds: usize_idx, 0, array.len()); - } - Ok(usize_idx) - }) - .collect::>>()? + check_indices( + primitive_indices.as_slice::

(), + array.len(), + &indices_validity, + )? }); - let indices_validity = primitive_indices.validity()?; take_indices_unchecked(&array, &checked_indices, &indices_validity).map(Some) } } @@ -61,31 +56,224 @@ pub fn take_indices_unchecked>( validity: &Validity, ) -> VortexResult { let ends = array.ends().to_primitive(); - let ends_len = ends.len(); + let validity_mask = validity.to_mask(indices.len()); - // TODO(joe): use the validity mask to skip search sorted. let physical_indices = match_each_integer_ptype!(ends.ptype(), |I| { let end_slices = ends.as_slice::(); - let physical_indices_vec: Vec = indices + let physical_indices_vec = + collect_physical_indices(end_slices, indices, array.offset(), &validity_mask)?; + let buffer = Buffer::from(physical_indices_vec); + + Ok::(PrimitiveArray::new( + buffer, + validity.clone(), + )) + }); + + array.values().take(physical_indices?.into_array()) +} + +fn check_indices>( + indices: &[T], + len: usize, + validity: &Validity, +) -> VortexResult> { + match validity.to_mask(indices.len()).bit_buffer() { + AllOr::All => indices + .iter() + .copied() + .map(|idx| check_index(idx.as_(), len)) + .collect(), + AllOr::None => Ok(vec![0; indices.len()]), + AllOr::Some(mask) => indices .iter() - .map(|idx| idx.as_() + array.offset()) - .map(|idx| { - match ::from(idx) { - Some(idx) => end_slices.search_sorted(&idx, SearchSortedSide::Right), - None => { - // The idx is too large for I, therefore it's out of bounds. 
- Ok(SearchResult::NotFound(ends_len)) - } + .copied() + .enumerate() + .map(|(position, idx)| { + if mask.value(position) { + check_index(idx.as_(), len) + } else { + Ok(0) } }) - .map(|result| result.map(|r| r.to_ends_index(ends_len) as u64)) - .collect::>>()?; - let buffer = Buffer::from(physical_indices_vec); + .collect(), + } +} - PrimitiveArray::new(buffer, validity.clone()) - }); +fn check_index(index: usize, len: usize) -> VortexResult { + if index >= len { + vortex_bail!(OutOfBounds: index, 0, len); + } + + Ok(index) +} + +fn collect_physical_indices + NumCast + PartialOrd, T: AsPrimitive>( + end_slices: &[E], + indices: &[T], + offset: usize, + validity_mask: &Mask, +) -> VortexResult> { + let valid_count = validity_mask.true_count(); + if valid_count == 0 { + return Ok(vec![0; indices.len()]); + } + + if !should_try_sorted_merge(valid_count) { + return search_physical_indices(end_slices, indices, offset, validity_mask); + } + + let mut physical_indices = vec![0; indices.len()]; + if try_fill_physical_indices_sorted( + end_slices, + indices, + offset, + validity_mask, + &mut physical_indices, + ) { + return Ok(physical_indices); + } + + if !should_sort_merge(valid_count, end_slices.len()) { + return search_physical_indices(end_slices, indices, offset, validity_mask); + } + + let mut indexed_indices = collect_logical_indices(indices, offset, validity_mask); + indexed_indices.sort_unstable_by_key(|&(logical_index, _)| logical_index); + fill_physical_indices(end_slices, indexed_indices, &mut physical_indices); + + Ok(physical_indices) +} + +fn should_try_sorted_merge(valid_count: usize) -> bool { + valid_count >= SORTED_TAKE_MERGE_MIN_VALID_COUNT +} + +fn should_sort_merge(valid_count: usize, run_count: usize) -> bool { + valid_count >= UNSORTED_TAKE_SORT_MERGE_MIN_VALID_COUNT + && valid_count >= run_count.saturating_mul(UNSORTED_TAKE_SORT_MERGE_RUN_RATIO) +} + +fn try_fill_physical_indices_sorted, T: AsPrimitive>( + end_slices: &[E], + indices: &[T], + 
offset: usize, + validity_mask: &Mask, + physical_indices: &mut [u64], +) -> bool { + let mut previous = None; + let mut run_index = 0usize; + + let mut record_index = |position: usize, logical_index: usize| { + if previous.is_some_and(|prev| logical_index < prev) { + return false; + } + previous = Some(logical_index); + physical_indices[position] = advance_to_run(end_slices, &mut run_index, logical_index); + true + }; + + match validity_mask.bit_buffer() { + AllOr::All => indices + .iter() + .copied() + .enumerate() + .all(|(position, idx)| record_index(position, idx.as_() + offset)), + AllOr::None => true, + AllOr::Some(mask) => indices + .iter() + .copied() + .enumerate() + .filter(|(position, _)| mask.value(*position)) + .all(|(position, idx)| record_index(position, idx.as_() + offset)), + } +} + +fn search_physical_indices>( + end_slices: &[E], + indices: &[T], + offset: usize, + validity_mask: &Mask, +) -> VortexResult> { + let ends_len = end_slices.len(); + let mut physical_indices = vec![0; indices.len()]; + + let mut record_index = |position: usize, logical_index: usize| -> VortexResult<()> { + physical_indices[position] = match E::from(logical_index) { + Some(logical_index) => end_slices + .search_sorted(&logical_index, SearchSortedSide::Right)? 
+ .to_ends_index(ends_len) as u64, + None => SearchResult::NotFound(ends_len).to_ends_index(ends_len) as u64, + }; + + Ok(()) + }; - array.values().take(physical_indices.into_array()) + match validity_mask.bit_buffer() { + AllOr::All => { + for (position, idx) in indices.iter().copied().enumerate() { + record_index(position, idx.as_() + offset)?; + } + } + AllOr::None => {} + AllOr::Some(mask) => { + for (position, idx) in indices.iter().copied().enumerate() { + if mask.value(position) { + record_index(position, idx.as_() + offset)?; + } + } + } + } + + Ok(physical_indices) +} + +fn collect_logical_indices>( + indices: &[T], + offset: usize, + validity_mask: &Mask, +) -> Vec<(usize, usize)> { + match validity_mask.bit_buffer() { + AllOr::All => indices + .iter() + .copied() + .enumerate() + .map(|(position, idx)| (idx.as_() + offset, position)) + .collect(), + AllOr::None => Vec::new(), + AllOr::Some(mask) => indices + .iter() + .copied() + .enumerate() + .filter(|(position, _)| mask.value(*position)) + .map(|(position, idx)| (idx.as_() + offset, position)) + .collect(), + } +} + +fn fill_physical_indices>( + end_slices: &[E], + logical_indices: impl IntoIterator, + physical_indices: &mut [u64], +) { + let mut run_index = 0usize; + + for (logical_index, position) in logical_indices { + physical_indices[position] = advance_to_run(end_slices, &mut run_index, logical_index); + } +} + +fn advance_to_run>( + end_slices: &[E], + run_index: &mut usize, + logical_index: usize, +) -> u64 { + while *run_index < end_slices.len() && logical_index >= end_slices[*run_index].as_() { + *run_index += 1; + } + + debug_assert!(*run_index < end_slices.len()); + *run_index as u64 } #[cfg(test)] @@ -96,10 +284,16 @@ mod test { use vortex_array::IntoArray; use vortex_array::LEGACY_SESSION; use vortex_array::VortexSessionExecute; + use vortex_array::arrays::BoolArray; use vortex_array::arrays::PrimitiveArray; use vortex_array::assert_arrays_eq; use 
vortex_array::compute::conformance::take::test_take_conformance; + use vortex_array::validity::Validity; use vortex_buffer::buffer; + use vortex_error::VortexExpect; + use vortex_error::VortexResult; + use vortex_mask::AllOr; + use vortex_mask::Mask; use crate::RunEnd; use crate::RunEndArray; @@ -151,6 +345,86 @@ mod test { assert_arrays_eq!(taken, expected.into_array()); } + #[test] + fn ree_take_null_out_of_bounds_is_ignored() -> VortexResult<()> { + let indices = PrimitiveArray::new( + buffer![0u32, 100, 7], + Validity::Array(BoolArray::from_iter([true, false, true]).into_array()), + ) + .into_array(); + let taken = ree_array().take(indices)?; + + let expected = PrimitiveArray::from_option_iter([Some(1i32), None, Some(2)]).into_array(); + assert_arrays_eq!(taken, expected); + Ok(()) + } + + #[test] + fn ree_take_duplicate_indices() -> VortexResult<()> { + let taken = ree_array().take(buffer![0u32, 0, 8, 8].into_array())?; + + let expected = PrimitiveArray::from_iter([1i32, 1, 5, 5]).into_array(); + assert_arrays_eq!(taken, expected); + Ok(()) + } + + #[test] + fn ree_take_sorted_filter_indices() -> VortexResult<()> { + let indices = match Mask::from_indices(ree_array().len(), vec![0, 1, 6, 7, 8]).indices() { + AllOr::Some(indices) => PrimitiveArray::from_iter( + indices + .iter() + .map(|&idx| u32::try_from(idx).vortex_expect("mask index must fit in u32")), + ) + .into_array(), + AllOr::All | AllOr::None => unreachable!(), + }; + let taken = ree_array().take(indices)?; + + let expected = PrimitiveArray::from_iter([1i32, 1, 2, 2, 5]).into_array(); + assert_arrays_eq!(taken, expected); + Ok(()) + } + + #[test] + fn sliced_take_sorted_filter_indices() -> VortexResult<()> { + let sliced = ree_array().slice(2..10)?; + let indices = match Mask::from_indices(sliced.len(), vec![0, 3, 4, 5, 7]).indices() { + AllOr::Some(indices) => PrimitiveArray::from_iter( + indices + .iter() + .map(|&idx| u32::try_from(idx).vortex_expect("mask index must fit in u32")), + ) + 
.into_array(), + AllOr::All | AllOr::None => unreachable!(), + }; + let taken = sliced.take(indices)?; + + let expected = PrimitiveArray::from_iter([1i32, 4, 2, 2, 5]).into_array(); + assert_arrays_eq!(taken, expected); + Ok(()) + } + + #[test] + fn sorted_merge_threshold_is_conservative_for_small_takes() { + assert!(!super::should_try_sorted_merge( + super::SORTED_TAKE_MERGE_MIN_VALID_COUNT.saturating_sub(1) + )); + assert!(super::should_try_sorted_merge( + super::SORTED_TAKE_MERGE_MIN_VALID_COUNT + )); + } + + #[test] + fn unsorted_sort_merge_threshold_scales_with_run_count() { + assert!(!super::should_sort_merge( + super::UNSORTED_TAKE_SORT_MERGE_MIN_VALID_COUNT.saturating_sub(1), + 1, + )); + assert!(!super::should_sort_merge(32_768, 4_096)); + assert!(super::should_sort_merge(32_768, 256)); + } + #[rstest] #[case(ree_array())] #[case(RunEnd::encode( diff --git a/encodings/runend/src/lib.rs b/encodings/runend/src/lib.rs index cff09767d38..05db107aec6 100644 --- a/encodings/runend/src/lib.rs +++ b/encodings/runend/src/lib.rs @@ -24,6 +24,22 @@ pub mod _benchmarking { pub use compute::take::take_indices_unchecked; use super::*; + + /// Benchmark-only override for the RunEnd filter dispatcher. + #[derive(Clone, Copy, Debug, Eq, PartialEq)] + pub enum RunEndFilterMode { + /// Use the normal dispatcher heuristic. + Auto, + /// Always force the `take(indices)` filter path. + Take, + /// Always force the preserve-`RunEnd` encoded filter path. + Encoded, + } + + /// Override the RunEnd filter dispatcher for the lifetime of the returned guard. 
+ pub fn override_run_end_filter_mode(mode: RunEndFilterMode) -> impl Drop { + compute::filter::override_run_end_filter_mode(mode) + } } use vortex_array::aggregate_fn::AggregateFnVTable; diff --git a/encodings/runend/src/rules.rs b/encodings/runend/src/rules.rs index 2c8a6607156..a8c2f5040dc 100644 --- a/encodings/runend/src/rules.rs +++ b/encodings/runend/src/rules.rs @@ -54,8 +54,10 @@ impl ArrayParentReduceRule for RunEndScalarFnRule { } } - // TODO(ngates): relax this constraint and implement run-end decoding for all vector types. - if !matches!(parent.dtype(), DType::Bool(_) | DType::Primitive(..)) { + if !matches!( + parent.dtype(), + DType::Bool(_) | DType::Primitive(..) | DType::Utf8(_) | DType::Binary(_) + ) { return Ok(None); } @@ -91,3 +93,174 @@ impl ArrayParentReduceRule for RunEndScalarFnRule { )) } } + +#[cfg(test)] +mod tests { + use vortex_array::ArrayRef; + use vortex_array::IntoArray; + use vortex_array::LEGACY_SESSION; + use vortex_array::RecursiveCanonical; + use vortex_array::VortexSessionExecute; + use vortex_array::arrays::BoolArray; + use vortex_array::assert_arrays_eq; + use vortex_array::builtins::ArrayBuiltins; + use vortex_array::dtype::FieldNames; + use vortex_array::dtype::Nullability; + use vortex_array::dtype::PType; + use vortex_array::dtype::StructFields; + use vortex_array::optimizer::ArrayOptimizer; + use vortex_array::scalar::Scalar; + use vortex_buffer::buffer; + + use super::*; + use crate::RunEnd; + + fn bool_mask_fixture() -> ArrayRef { + RunEnd::new( + buffer![256u32, 512, 768, 1024].into_array(), + BoolArray::from_iter([true, false, true, false]).into_array(), + ) + .into_array() + } + + #[test] + fn pushes_down_utf8_zip_to_runend() { + let mask = bool_mask_fixture(); + let if_true = ConstantArray::new( + Scalar::utf8("runend-true-branch", Nullability::NonNullable), + mask.len(), + ) + .into_array(); + let if_false = ConstantArray::new( + Scalar::utf8("runend-false-branch", Nullability::NonNullable), + mask.len(), + ) + 
.into_array(); + + let optimized = mask.zip(if_true, if_false).unwrap().optimize().unwrap(); + + assert!(optimized.is::()); + assert_eq!(optimized.dtype(), &DType::Utf8(Nullability::NonNullable)); + + let actual = optimized + .execute::(&mut LEGACY_SESSION.create_execution_ctx()) + .unwrap() + .0 + .into_array(); + let expected = vortex_array::arrays::VarBinViewArray::from_iter_str((0..1024).map(|idx| { + if idx < 256 || (512..768).contains(&idx) { + "runend-true-branch" + } else { + "runend-false-branch" + } + })) + .into_array(); + assert_arrays_eq!(actual, expected); + } + + #[test] + fn pushes_down_binary_zip_to_runend() { + let mask = bool_mask_fixture(); + let if_true = vec![0xAA; 8]; + let if_false = vec![0x55; 12]; + let optimized = mask + .zip( + ConstantArray::new( + Scalar::binary(if_true.clone(), Nullability::NonNullable), + mask.len(), + ) + .into_array(), + ConstantArray::new( + Scalar::binary(if_false.clone(), Nullability::NonNullable), + mask.len(), + ) + .into_array(), + ) + .unwrap() + .optimize() + .unwrap(); + + assert!(optimized.is::()); + assert_eq!(optimized.dtype(), &DType::Binary(Nullability::NonNullable)); + + let actual = optimized + .execute::(&mut LEGACY_SESSION.create_execution_ctx()) + .unwrap() + .0 + .into_array(); + let expected = vortex_array::arrays::VarBinViewArray::from_iter_bin((0..1024).map(|idx| { + if idx < 256 || (512..768).contains(&idx) { + if_true.clone() + } else { + if_false.clone() + } + })) + .into_array(); + assert_arrays_eq!(actual, expected); + } + + #[test] + fn pushes_down_sliced_nullable_utf8_zip_to_runend() { + let mask = bool_mask_fixture() + .slice(128..896) + .unwrap() + .execute::(&mut LEGACY_SESSION.create_execution_ctx()) + .unwrap(); + let optimized = mask + .zip( + ConstantArray::new( + Scalar::utf8("slice-true-branch", Nullability::Nullable), + mask.len(), + ) + .into_array(), + ConstantArray::new(Scalar::null(DType::Utf8(Nullability::Nullable)), mask.len()) + .into_array(), + ) + .unwrap() + 
.optimize() + .unwrap(); + + assert!(optimized.is::()); + assert_eq!(optimized.dtype(), &DType::Utf8(Nullability::Nullable)); + assert_eq!(optimized.as_::().offset(), 128); + + let actual = optimized + .execute::(&mut LEGACY_SESSION.create_execution_ctx()) + .unwrap() + .0 + .into_array(); + let expected = vortex_array::arrays::VarBinViewArray::from_iter( + (128..896) + .map(|idx| (idx < 256 || (512..768).contains(&idx)).then_some("slice-true-branch")), + DType::Utf8(Nullability::Nullable), + ) + .into_array(); + assert_arrays_eq!(actual, expected); + } + + #[test] + fn keeps_struct_zip_on_fallback_path() { + let mask = bool_mask_fixture(); + let struct_dtype = DType::Struct( + StructFields::new( + FieldNames::from(["value"]), + vec![DType::Primitive(PType::I32, Nullability::NonNullable)], + ), + Nullability::NonNullable, + ); + let if_true = ConstantArray::new( + Scalar::struct_(struct_dtype.clone(), vec![Scalar::from(1i32)]), + mask.len(), + ) + .into_array(); + let if_false = ConstantArray::new( + Scalar::struct_(struct_dtype, vec![Scalar::from(2i32)]), + mask.len(), + ) + .into_array(); + + let optimized = mask.zip(if_true, if_false).unwrap().optimize().unwrap(); + + assert!(!optimized.is::()); + } +} diff --git a/encodings/sequence/Cargo.toml b/encodings/sequence/Cargo.toml index 06f695cb989..65eb49f9228 100644 --- a/encodings/sequence/Cargo.toml +++ b/encodings/sequence/Cargo.toml @@ -24,9 +24,18 @@ vortex-proto = { workspace = true } vortex-session = { workspace = true } [dev-dependencies] +divan = { workspace = true } itertools = { workspace = true } rstest = { workspace = true } vortex-array = { path = "../../vortex-array", features = ["_test-harness"] } [lints] workspace = true + +[[bench]] +name = "sequence_compare" +harness = false + +[[bench]] +name = "sequence_scalar_fn" +harness = false diff --git a/encodings/sequence/benches/sequence_compare.rs b/encodings/sequence/benches/sequence_compare.rs new file mode 100644 index 00000000000..3f031d7f0e9 --- 
/dev/null +++ b/encodings/sequence/benches/sequence_compare.rs @@ -0,0 +1,635 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +#![allow(clippy::unwrap_used)] + +use std::fmt; + +use divan::Bencher; +use vortex_array::Canonical; +use vortex_array::IntoArray; +use vortex_array::VortexSessionExecute; +use vortex_array::arrays::ConstantArray; +use vortex_array::arrays::PrimitiveArray; +use vortex_array::builtins::ArrayBuiltins; +use vortex_array::dtype::Nullability::NonNullable; +use vortex_array::scalar_fn::fns::operators::Operator; +use vortex_sequence::Sequence; +use vortex_session::VortexSession; + +fn main() { + divan::main(); +} + +const LEN: usize = 1_000_000; + +#[derive(Clone, Copy, Debug)] +enum SequenceShape { + Ascending, + Descending, + Constant, +} + +impl SequenceShape { + fn build(self) -> vortex_sequence::SequenceArray { + match self { + Self::Ascending => Sequence::try_new_typed(10i64, 3, NonNullable, LEN).unwrap(), + Self::Descending => { + Sequence::try_new_typed(3_000_000i64, -3, NonNullable, LEN).unwrap() + } + Self::Constant => Sequence::try_new_typed(42i64, 0, NonNullable, LEN).unwrap(), + } + } +} + +impl fmt::Display for SequenceShape { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Ascending => write!(f, "ascending"), + Self::Descending => write!(f, "descending"), + Self::Constant => write!(f, "constant"), + } + } +} + +#[derive(Clone, Copy, Debug)] +enum ConstantCase { + AtStart, + AtMiddle, + AtEnd, + BelowRange, + AboveRange, +} + +impl ConstantCase { + fn value(self, shape: SequenceShape) -> i64 { + match (shape, self) { + (SequenceShape::Ascending, Self::AtStart) => 10, + (SequenceShape::Ascending, Self::AtMiddle) => 10 + 3 * (LEN / 2) as i64, + (SequenceShape::Ascending, Self::AtEnd) => 10 + 3 * (LEN - 1) as i64, + (SequenceShape::Ascending, Self::BelowRange) => 9, + (SequenceShape::Ascending, Self::AboveRange) => 10 + 3 * LEN as i64, + 
(SequenceShape::Descending, Self::AtStart) => 3_000_000, + (SequenceShape::Descending, Self::AtMiddle) => 3_000_000 - 3 * (LEN / 2) as i64, + (SequenceShape::Descending, Self::AtEnd) => 3_000_000 - 3 * (LEN - 1) as i64, + (SequenceShape::Descending, Self::BelowRange) => 3_000_000 - 3 * LEN as i64, + (SequenceShape::Descending, Self::AboveRange) => 3_000_001, + (SequenceShape::Constant, Self::AtStart | Self::AtMiddle | Self::AtEnd) => 42, + (SequenceShape::Constant, Self::BelowRange) => 41, + (SequenceShape::Constant, Self::AboveRange) => 43, + } + } +} + +impl fmt::Display for ConstantCase { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::AtStart => write!(f, "start"), + Self::AtMiddle => write!(f, "middle"), + Self::AtEnd => write!(f, "end"), + Self::BelowRange => write!(f, "below_range"), + Self::AboveRange => write!(f, "above_range"), + } + } +} + +#[derive(Clone, Copy, Debug)] +struct BenchArgs { + shape: SequenceShape, + operator: Operator, + constant_case: ConstantCase, +} + +impl fmt::Display for BenchArgs { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}_{}_{}", self.shape, self.operator, self.constant_case) + } +} + +#[derive(Clone, Copy, Debug)] +struct EqControlArgs { + shape: SequenceShape, + constant_case: ConstantCase, +} + +impl fmt::Display for EqControlArgs { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}_eq_{}", self.shape, self.constant_case) + } +} + +#[derive(Clone, Copy, Debug)] +struct FallbackControlArgs { + shape: SequenceShape, + operator: Operator, + constant_case: ConstantCase, +} + +impl fmt::Display for FallbackControlArgs { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "{}_{}_{}_fallback", + self.shape, self.operator, self.constant_case + ) + } +} + +const BENCH_ARGS: &[BenchArgs] = &[ + BenchArgs { + shape: SequenceShape::Ascending, + operator: Operator::Lt, + constant_case: ConstantCase::AtStart, + }, + 
BenchArgs { + shape: SequenceShape::Ascending, + operator: Operator::Lt, + constant_case: ConstantCase::AtMiddle, + }, + BenchArgs { + shape: SequenceShape::Ascending, + operator: Operator::Lt, + constant_case: ConstantCase::AtEnd, + }, + BenchArgs { + shape: SequenceShape::Ascending, + operator: Operator::Lt, + constant_case: ConstantCase::BelowRange, + }, + BenchArgs { + shape: SequenceShape::Ascending, + operator: Operator::Lt, + constant_case: ConstantCase::AboveRange, + }, + BenchArgs { + shape: SequenceShape::Ascending, + operator: Operator::Lte, + constant_case: ConstantCase::AtStart, + }, + BenchArgs { + shape: SequenceShape::Ascending, + operator: Operator::Lte, + constant_case: ConstantCase::AtMiddle, + }, + BenchArgs { + shape: SequenceShape::Ascending, + operator: Operator::Lte, + constant_case: ConstantCase::AtEnd, + }, + BenchArgs { + shape: SequenceShape::Ascending, + operator: Operator::Lte, + constant_case: ConstantCase::BelowRange, + }, + BenchArgs { + shape: SequenceShape::Ascending, + operator: Operator::Lte, + constant_case: ConstantCase::AboveRange, + }, + BenchArgs { + shape: SequenceShape::Ascending, + operator: Operator::Gt, + constant_case: ConstantCase::AtStart, + }, + BenchArgs { + shape: SequenceShape::Ascending, + operator: Operator::Gt, + constant_case: ConstantCase::AtMiddle, + }, + BenchArgs { + shape: SequenceShape::Ascending, + operator: Operator::Gt, + constant_case: ConstantCase::AtEnd, + }, + BenchArgs { + shape: SequenceShape::Ascending, + operator: Operator::Gt, + constant_case: ConstantCase::BelowRange, + }, + BenchArgs { + shape: SequenceShape::Ascending, + operator: Operator::Gt, + constant_case: ConstantCase::AboveRange, + }, + BenchArgs { + shape: SequenceShape::Ascending, + operator: Operator::Gte, + constant_case: ConstantCase::AtStart, + }, + BenchArgs { + shape: SequenceShape::Ascending, + operator: Operator::Gte, + constant_case: ConstantCase::AtMiddle, + }, + BenchArgs { + shape: SequenceShape::Ascending, + 
operator: Operator::Gte, + constant_case: ConstantCase::AtEnd, + }, + BenchArgs { + shape: SequenceShape::Ascending, + operator: Operator::Gte, + constant_case: ConstantCase::BelowRange, + }, + BenchArgs { + shape: SequenceShape::Ascending, + operator: Operator::Gte, + constant_case: ConstantCase::AboveRange, + }, + BenchArgs { + shape: SequenceShape::Descending, + operator: Operator::Lt, + constant_case: ConstantCase::AtStart, + }, + BenchArgs { + shape: SequenceShape::Descending, + operator: Operator::Lt, + constant_case: ConstantCase::AtMiddle, + }, + BenchArgs { + shape: SequenceShape::Descending, + operator: Operator::Lt, + constant_case: ConstantCase::AtEnd, + }, + BenchArgs { + shape: SequenceShape::Descending, + operator: Operator::Lt, + constant_case: ConstantCase::BelowRange, + }, + BenchArgs { + shape: SequenceShape::Descending, + operator: Operator::Lt, + constant_case: ConstantCase::AboveRange, + }, + BenchArgs { + shape: SequenceShape::Descending, + operator: Operator::Lte, + constant_case: ConstantCase::AtStart, + }, + BenchArgs { + shape: SequenceShape::Descending, + operator: Operator::Lte, + constant_case: ConstantCase::AtMiddle, + }, + BenchArgs { + shape: SequenceShape::Descending, + operator: Operator::Lte, + constant_case: ConstantCase::AtEnd, + }, + BenchArgs { + shape: SequenceShape::Descending, + operator: Operator::Lte, + constant_case: ConstantCase::BelowRange, + }, + BenchArgs { + shape: SequenceShape::Descending, + operator: Operator::Lte, + constant_case: ConstantCase::AboveRange, + }, + BenchArgs { + shape: SequenceShape::Descending, + operator: Operator::Gt, + constant_case: ConstantCase::AtStart, + }, + BenchArgs { + shape: SequenceShape::Descending, + operator: Operator::Gt, + constant_case: ConstantCase::AtMiddle, + }, + BenchArgs { + shape: SequenceShape::Descending, + operator: Operator::Gt, + constant_case: ConstantCase::AtEnd, + }, + BenchArgs { + shape: SequenceShape::Descending, + operator: Operator::Gt, + constant_case: 
ConstantCase::BelowRange, + }, + BenchArgs { + shape: SequenceShape::Descending, + operator: Operator::Gt, + constant_case: ConstantCase::AboveRange, + }, + BenchArgs { + shape: SequenceShape::Descending, + operator: Operator::Gte, + constant_case: ConstantCase::AtStart, + }, + BenchArgs { + shape: SequenceShape::Descending, + operator: Operator::Gte, + constant_case: ConstantCase::AtMiddle, + }, + BenchArgs { + shape: SequenceShape::Descending, + operator: Operator::Gte, + constant_case: ConstantCase::AtEnd, + }, + BenchArgs { + shape: SequenceShape::Descending, + operator: Operator::Gte, + constant_case: ConstantCase::BelowRange, + }, + BenchArgs { + shape: SequenceShape::Descending, + operator: Operator::Gte, + constant_case: ConstantCase::AboveRange, + }, + BenchArgs { + shape: SequenceShape::Constant, + operator: Operator::Lt, + constant_case: ConstantCase::AtStart, + }, + BenchArgs { + shape: SequenceShape::Constant, + operator: Operator::Lt, + constant_case: ConstantCase::AtMiddle, + }, + BenchArgs { + shape: SequenceShape::Constant, + operator: Operator::Lt, + constant_case: ConstantCase::AtEnd, + }, + BenchArgs { + shape: SequenceShape::Constant, + operator: Operator::Lt, + constant_case: ConstantCase::BelowRange, + }, + BenchArgs { + shape: SequenceShape::Constant, + operator: Operator::Lt, + constant_case: ConstantCase::AboveRange, + }, + BenchArgs { + shape: SequenceShape::Constant, + operator: Operator::Lte, + constant_case: ConstantCase::AtStart, + }, + BenchArgs { + shape: SequenceShape::Constant, + operator: Operator::Lte, + constant_case: ConstantCase::AtMiddle, + }, + BenchArgs { + shape: SequenceShape::Constant, + operator: Operator::Lte, + constant_case: ConstantCase::AtEnd, + }, + BenchArgs { + shape: SequenceShape::Constant, + operator: Operator::Lte, + constant_case: ConstantCase::BelowRange, + }, + BenchArgs { + shape: SequenceShape::Constant, + operator: Operator::Lte, + constant_case: ConstantCase::AboveRange, + }, + BenchArgs { + shape: 
SequenceShape::Constant, + operator: Operator::Gt, + constant_case: ConstantCase::AtStart, + }, + BenchArgs { + shape: SequenceShape::Constant, + operator: Operator::Gt, + constant_case: ConstantCase::AtMiddle, + }, + BenchArgs { + shape: SequenceShape::Constant, + operator: Operator::Gt, + constant_case: ConstantCase::AtEnd, + }, + BenchArgs { + shape: SequenceShape::Constant, + operator: Operator::Gt, + constant_case: ConstantCase::BelowRange, + }, + BenchArgs { + shape: SequenceShape::Constant, + operator: Operator::Gt, + constant_case: ConstantCase::AboveRange, + }, + BenchArgs { + shape: SequenceShape::Constant, + operator: Operator::Gte, + constant_case: ConstantCase::AtStart, + }, + BenchArgs { + shape: SequenceShape::Constant, + operator: Operator::Gte, + constant_case: ConstantCase::AtMiddle, + }, + BenchArgs { + shape: SequenceShape::Constant, + operator: Operator::Gte, + constant_case: ConstantCase::AtEnd, + }, + BenchArgs { + shape: SequenceShape::Constant, + operator: Operator::Gte, + constant_case: ConstantCase::BelowRange, + }, + BenchArgs { + shape: SequenceShape::Constant, + operator: Operator::Gte, + constant_case: ConstantCase::AboveRange, + }, +]; + +const EQ_CONTROL_ARGS: &[EqControlArgs] = &[ + EqControlArgs { + shape: SequenceShape::Ascending, + constant_case: ConstantCase::AtStart, + }, + EqControlArgs { + shape: SequenceShape::Ascending, + constant_case: ConstantCase::AtMiddle, + }, + EqControlArgs { + shape: SequenceShape::Ascending, + constant_case: ConstantCase::AtEnd, + }, + EqControlArgs { + shape: SequenceShape::Ascending, + constant_case: ConstantCase::BelowRange, + }, + EqControlArgs { + shape: SequenceShape::Ascending, + constant_case: ConstantCase::AboveRange, + }, + EqControlArgs { + shape: SequenceShape::Descending, + constant_case: ConstantCase::AtStart, + }, + EqControlArgs { + shape: SequenceShape::Descending, + constant_case: ConstantCase::AtMiddle, + }, + EqControlArgs { + shape: SequenceShape::Descending, + constant_case: 
ConstantCase::AtEnd, + }, + EqControlArgs { + shape: SequenceShape::Descending, + constant_case: ConstantCase::BelowRange, + }, + EqControlArgs { + shape: SequenceShape::Descending, + constant_case: ConstantCase::AboveRange, + }, + EqControlArgs { + shape: SequenceShape::Constant, + constant_case: ConstantCase::AtStart, + }, + EqControlArgs { + shape: SequenceShape::Constant, + constant_case: ConstantCase::AtMiddle, + }, + EqControlArgs { + shape: SequenceShape::Constant, + constant_case: ConstantCase::AtEnd, + }, + EqControlArgs { + shape: SequenceShape::Constant, + constant_case: ConstantCase::BelowRange, + }, + EqControlArgs { + shape: SequenceShape::Constant, + constant_case: ConstantCase::AboveRange, + }, +]; + +const FALLBACK_CONTROL_ARGS: &[FallbackControlArgs] = &[ + FallbackControlArgs { + shape: SequenceShape::Ascending, + operator: Operator::Lt, + constant_case: ConstantCase::AtMiddle, + }, + FallbackControlArgs { + shape: SequenceShape::Ascending, + operator: Operator::Lt, + constant_case: ConstantCase::AboveRange, + }, + FallbackControlArgs { + shape: SequenceShape::Ascending, + operator: Operator::Gte, + constant_case: ConstantCase::AtMiddle, + }, + FallbackControlArgs { + shape: SequenceShape::Ascending, + operator: Operator::Gte, + constant_case: ConstantCase::AboveRange, + }, + FallbackControlArgs { + shape: SequenceShape::Descending, + operator: Operator::Lt, + constant_case: ConstantCase::AtMiddle, + }, + FallbackControlArgs { + shape: SequenceShape::Descending, + operator: Operator::Lt, + constant_case: ConstantCase::AboveRange, + }, + FallbackControlArgs { + shape: SequenceShape::Descending, + operator: Operator::Gte, + constant_case: ConstantCase::AtMiddle, + }, + FallbackControlArgs { + shape: SequenceShape::Descending, + operator: Operator::Gte, + constant_case: ConstantCase::AboveRange, + }, + FallbackControlArgs { + shape: SequenceShape::Constant, + operator: Operator::Lt, + constant_case: ConstantCase::AtMiddle, + }, + FallbackControlArgs 
{ + shape: SequenceShape::Constant, + operator: Operator::Lt, + constant_case: ConstantCase::AboveRange, + }, + FallbackControlArgs { + shape: SequenceShape::Constant, + operator: Operator::Gte, + constant_case: ConstantCase::AtMiddle, + }, + FallbackControlArgs { + shape: SequenceShape::Constant, + operator: Operator::Gte, + constant_case: ConstantCase::AboveRange, + }, +]; + +#[divan::bench(args = BENCH_ARGS)] +fn compare_to_constant(bencher: Bencher, args: BenchArgs) { + let sequence = args.shape.build(); + let rhs = ConstantArray::new(args.constant_case.value(args.shape), LEN).into_array(); + let session = VortexSession::empty(); + + bencher + .with_inputs(|| { + ( + sequence.clone().into_array(), + rhs.clone(), + session.create_execution_ctx(), + ) + }) + .bench_refs(|(lhs, rhs, ctx)| { + lhs.clone() + .binary(rhs.clone(), args.operator) + .unwrap() + .execute::(ctx) + .unwrap() + }); +} + +#[divan::bench(args = EQ_CONTROL_ARGS)] +fn compare_eq_control(bencher: Bencher, args: EqControlArgs) { + let sequence = args.shape.build(); + let rhs = ConstantArray::new(args.constant_case.value(args.shape), LEN).into_array(); + let session = VortexSession::empty(); + + bencher + .with_inputs(|| { + ( + sequence.clone().into_array(), + rhs.clone(), + session.create_execution_ctx(), + ) + }) + .bench_refs(|(lhs, rhs, ctx)| { + lhs.clone() + .binary(rhs.clone(), Operator::Eq) + .unwrap() + .execute::(ctx) + .unwrap() + }); +} + +#[divan::bench(args = FALLBACK_CONTROL_ARGS)] +fn compare_non_constant_control(bencher: Bencher, args: FallbackControlArgs) { + let sequence = args.shape.build(); + let rhs = PrimitiveArray::from_iter((0..LEN).map(|_| args.constant_case.value(args.shape))) + .into_array(); + let session = VortexSession::empty(); + + bencher + .with_inputs(|| { + ( + sequence.clone().into_array(), + rhs.clone(), + session.create_execution_ctx(), + ) + }) + .bench_refs(|(lhs, rhs, ctx)| { + lhs.clone() + .binary(rhs.clone(), args.operator) + .unwrap() + 
.execute::(ctx) + .unwrap() + }); +} diff --git a/encodings/sequence/benches/sequence_scalar_fn.rs b/encodings/sequence/benches/sequence_scalar_fn.rs new file mode 100644 index 00000000000..91dfec12db3 --- /dev/null +++ b/encodings/sequence/benches/sequence_scalar_fn.rs @@ -0,0 +1,233 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +#![allow(clippy::unwrap_used)] + +use std::fmt; + +use divan::Bencher; +use vortex_array::Canonical; +use vortex_array::IntoArray; +use vortex_array::VortexSessionExecute; +use vortex_array::arrays::ConstantArray; +use vortex_array::builtins::ArrayBuiltins; +use vortex_array::dtype::Nullability::NonNullable; +use vortex_array::scalar_fn::fns::operators::Operator; +use vortex_sequence::Sequence; +use vortex_session::VortexSession; + +fn main() { + divan::main(); +} + +const LEN: usize = 1_000_000; + +#[derive(Clone, Copy, Debug)] +enum SequenceShape { + Ascending, + Descending, +} + +impl SequenceShape { + fn build(self) -> vortex_sequence::SequenceArray { + match self { + Self::Ascending => Sequence::try_new_typed(10i64, 3, NonNullable, LEN).unwrap(), + Self::Descending => { + Sequence::try_new_typed(3_000_000i64, -3, NonNullable, LEN).unwrap() + } + } + } + + fn midpoint_value(self) -> i64 { + match self { + Self::Ascending => 10 + 3 * (LEN / 2) as i64, + Self::Descending => 3_000_000 - 3 * (LEN / 2) as i64, + } + } +} + +impl fmt::Display for SequenceShape { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Ascending => write!(f, "ascending"), + Self::Descending => write!(f, "descending"), + } + } +} + +#[derive(Clone, Copy, Debug)] +enum AffineExpr { + SeqPlusConst, + ConstPlusSeq, + SeqMinusConst, + ConstMinusSeq, + SeqTimesConst, + ConstTimesSeq, +} + +impl AffineExpr { + fn constant(self) -> i64 { + match self { + Self::SeqPlusConst | Self::ConstPlusSeq => 17, + Self::SeqMinusConst => 11, + Self::ConstMinusSeq => 10_000_000, + 
Self::SeqTimesConst | Self::ConstTimesSeq => 2, + } + } + + fn operator(self) -> Operator { + match self { + Self::SeqPlusConst | Self::ConstPlusSeq => Operator::Add, + Self::SeqMinusConst | Self::ConstMinusSeq => Operator::Sub, + Self::SeqTimesConst | Self::ConstTimesSeq => Operator::Mul, + } + } + + fn midpoint_value(self, shape: SequenceShape) -> i64 { + let midpoint = shape.midpoint_value(); + match self { + Self::SeqPlusConst | Self::ConstPlusSeq => midpoint + self.constant(), + Self::SeqMinusConst => midpoint - self.constant(), + Self::ConstMinusSeq => self.constant() - midpoint, + Self::SeqTimesConst | Self::ConstTimesSeq => midpoint * self.constant(), + } + } +} + +impl fmt::Display for AffineExpr { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::SeqPlusConst => write!(f, "seq_plus_const"), + Self::ConstPlusSeq => write!(f, "const_plus_seq"), + Self::SeqMinusConst => write!(f, "seq_minus_const"), + Self::ConstMinusSeq => write!(f, "const_minus_seq"), + Self::SeqTimesConst => write!(f, "seq_times_const"), + Self::ConstTimesSeq => write!(f, "const_times_seq"), + } + } +} + +#[derive(Clone, Copy, Debug)] +struct BenchArgs { + shape: SequenceShape, + affine: AffineExpr, +} + +impl fmt::Display for BenchArgs { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}_{}", self.shape, self.affine) + } +} + +const BENCH_ARGS: &[BenchArgs] = &[ + BenchArgs { + shape: SequenceShape::Ascending, + affine: AffineExpr::SeqPlusConst, + }, + BenchArgs { + shape: SequenceShape::Ascending, + affine: AffineExpr::ConstPlusSeq, + }, + BenchArgs { + shape: SequenceShape::Ascending, + affine: AffineExpr::SeqMinusConst, + }, + BenchArgs { + shape: SequenceShape::Ascending, + affine: AffineExpr::ConstMinusSeq, + }, + BenchArgs { + shape: SequenceShape::Ascending, + affine: AffineExpr::SeqTimesConst, + }, + BenchArgs { + shape: SequenceShape::Ascending, + affine: AffineExpr::ConstTimesSeq, + }, + BenchArgs { + shape: 
SequenceShape::Descending, + affine: AffineExpr::SeqPlusConst, + }, + BenchArgs { + shape: SequenceShape::Descending, + affine: AffineExpr::ConstPlusSeq, + }, + BenchArgs { + shape: SequenceShape::Descending, + affine: AffineExpr::SeqMinusConst, + }, + BenchArgs { + shape: SequenceShape::Descending, + affine: AffineExpr::ConstMinusSeq, + }, + BenchArgs { + shape: SequenceShape::Descending, + affine: AffineExpr::SeqTimesConst, + }, + BenchArgs { + shape: SequenceShape::Descending, + affine: AffineExpr::ConstTimesSeq, + }, +]; + +#[divan::bench(args = BENCH_ARGS)] +fn affine_compare_to_constant(bencher: Bencher, args: BenchArgs) { + let sequence = args.shape.build().into_array(); + let affine_constant = ConstantArray::new(args.affine.constant(), LEN).into_array(); + let compare_constant = + ConstantArray::new(args.affine.midpoint_value(args.shape), LEN).into_array(); + let session = VortexSession::empty(); + + bencher + .with_inputs(|| { + ( + sequence.clone(), + affine_constant.clone(), + compare_constant.clone(), + session.create_execution_ctx(), + ) + }) + .bench_refs(|(sequence, affine_constant, compare_constant, ctx)| { + apply_affine(sequence.clone(), affine_constant.clone(), args.affine) + .binary(compare_constant.clone(), Operator::Eq) + .unwrap() + .execute::(ctx) + .unwrap(); + }); +} + +#[divan::bench(args = BENCH_ARGS)] +fn affine_transform(bencher: Bencher, args: BenchArgs) { + let sequence = args.shape.build().into_array(); + let affine_constant = ConstantArray::new(args.affine.constant(), LEN).into_array(); + let session = VortexSession::empty(); + + bencher + .with_inputs(|| { + ( + sequence.clone(), + affine_constant.clone(), + session.create_execution_ctx(), + ) + }) + .bench_refs(|(sequence, affine_constant, ctx)| { + apply_affine(sequence.clone(), affine_constant.clone(), args.affine) + .execute::(ctx) + .unwrap(); + }); +} + +fn apply_affine( + sequence: vortex_array::ArrayRef, + affine_constant: vortex_array::ArrayRef, + affine: AffineExpr, +) 
-> vortex_array::ArrayRef { + match affine { + AffineExpr::SeqPlusConst | AffineExpr::SeqMinusConst | AffineExpr::SeqTimesConst => { + sequence.binary(affine_constant, affine.operator()).unwrap() + } + AffineExpr::ConstPlusSeq | AffineExpr::ConstMinusSeq | AffineExpr::ConstTimesSeq => { + affine_constant.binary(sequence, affine.operator()).unwrap() + } + } +} diff --git a/encodings/sequence/src/compute/compare.rs b/encodings/sequence/src/compute/compare.rs index c0c9d1367d6..edd70ac1d1a 100644 --- a/encodings/sequence/src/compute/compare.rs +++ b/encodings/sequence/src/compute/compare.rs @@ -1,20 +1,22 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors +use std::cmp::Ordering; +use std::ops::Range; + use vortex_array::ArrayRef; use vortex_array::ArrayView; use vortex_array::ExecutionCtx; use vortex_array::IntoArray; use vortex_array::arrays::BoolArray; -use vortex_array::arrays::ConstantArray; -use vortex_array::dtype::NativePType; +use vortex_array::dtype::IntegerPType; use vortex_array::dtype::Nullability; use vortex_array::match_each_integer_ptype; use vortex_array::scalar::PValue; -use vortex_array::scalar::Scalar; use vortex_array::scalar_fn::fns::binary::CompareKernel; use vortex_array::scalar_fn::fns::operators::CompareOperator; use vortex_buffer::BitBuffer; +use vortex_buffer::BitBufferMut; use vortex_error::VortexExpect; use vortex_error::VortexResult; use vortex_error::vortex_bail; @@ -29,8 +31,7 @@ impl CompareKernel for Sequence { operator: CompareOperator, _ctx: &mut ExecutionCtx, ) -> VortexResult> { - // TODO(joe): support other operators (NotEq, Lt, Lte, Gt, Gte) in encoded space. - if operator != CompareOperator::Eq { + if operator == CompareOperator::NotEq { return Ok(None); } @@ -38,16 +39,15 @@ impl CompareKernel for Sequence { return Ok(None); }; - // Check if there exists an integer solution to const = base + (0..len) * multiplier. 
- let set_idx = find_intersection_scalar( - lhs.base(), - lhs.multiplier(), - lhs.len(), - constant - .as_primitive() - .pvalue() - .vortex_expect("null constant handled in adaptor"), - ); + let intercept = constant + .as_primitive() + .pvalue() + .vortex_expect("null constant handled in adaptor"); + let Ok(true_range) = + find_true_range_scalar(lhs.base(), lhs.multiplier(), lhs.len(), intercept, operator) + else { + return Ok(None); + }; let nullability = lhs.dtype().nullability() | rhs.dtype().nullability(); let validity = match nullability { @@ -55,18 +55,84 @@ impl CompareKernel for Sequence { Nullability::Nullable => vortex_array::validity::Validity::AllValid, }; - if let Ok(set_idx) = set_idx { - let buffer = BitBuffer::from_iter((0..lhs.len()).map(|idx| idx == set_idx)); - Ok(Some(BoolArray::new(buffer, validity).into_array())) - } else { - Ok(Some( - ConstantArray::new(Scalar::bool(false, nullability), lhs.len()).into_array(), - )) - } + Ok(Some( + BoolArray::new(range_to_bit_buffer(lhs.len(), true_range), validity).into_array(), + )) + } +} + +fn range_to_bit_buffer(len: usize, true_range: Range) -> BitBuffer { + if true_range.start == true_range.end { + return BitBuffer::new_unset(len); + } + if true_range.start == 0 && true_range.end == len { + return BitBuffer::new_set(len); } + + let mut buffer = BitBufferMut::new_unset(len); + buffer.fill_range(true_range.start, true_range.end, true); + buffer.freeze() +} + +fn empty_range() -> Range { + 0..0 +} + +fn full_range(len: usize) -> Range { + 0..len +} + +fn prefix_range(end: usize) -> Range { + 0..end +} + +fn suffix_range(start: usize, len: usize) -> Range { + start..len +} + +fn singleton_range(index: usize) -> Range { + index..index + 1 } -/// Find the index where `base + idx * multiplier == intercept`, if one exists. 
+fn comparison_matches(ordering: Ordering, operator: CompareOperator) -> bool { + match operator { + CompareOperator::Eq => ordering.is_eq(), + CompareOperator::NotEq => ordering.is_ne(), + CompareOperator::Gt => ordering.is_gt(), + CompareOperator::Gte => ordering.is_ge(), + CompareOperator::Lt => ordering.is_lt(), + CompareOperator::Lte => ordering.is_le(), + } +} + +fn constant_true_range(len: usize, ordering: Ordering, operator: CompareOperator) -> Range { + if comparison_matches(ordering, operator) { + full_range(len) + } else { + empty_range() + } +} + +fn usize_to_u128(value: usize) -> VortexResult { + u128::try_from(value).map_err(|_| vortex_err!("Cannot represent {value} as u128")) +} + +fn usize_to_i128(value: usize) -> VortexResult { + i128::try_from(value).map_err(|_| vortex_err!("Cannot represent {value} as i128")) +} + +fn ceil_div_positive_u128(lhs: u128, rhs: u128) -> u128 { + debug_assert!(rhs > 0); + if lhs == 0 { 0 } else { ((lhs - 1) / rhs) + 1 } +} + +fn ceil_div_positive_i128(lhs: i128, rhs: i128) -> i128 { + debug_assert!(lhs >= 0); + debug_assert!(rhs > 0); + if lhs == 0 { 0 } else { ((lhs - 1) / rhs) + 1 } +} + +/// Find the first index where `base + idx * multiplier == intercept`, if one exists. 
/// /// # Errors /// Return `VortexError` if: @@ -88,61 +154,321 @@ pub(crate) fn find_intersection_scalar( }) } -fn find_intersection( +fn find_intersection( base: P, multiplier: P, len: usize, intercept: P, ) -> VortexResult { - if len == 0 { - vortex_bail!("len == 0") + let true_range = find_true_range(base, multiplier, len, intercept, CompareOperator::Eq)?; + if true_range.start == true_range.end { + vortex_bail!("{intercept} does not intersect the sequence"); } + Ok(true_range.start) +} - let count = P::from_usize(len - 1).vortex_expect("idx must fit into type"); - let end_element = base + (multiplier * count); +fn find_true_range_scalar( + base: PValue, + multiplier: PValue, + len: usize, + intercept: PValue, + operator: CompareOperator, +) -> VortexResult> { + match_each_integer_ptype!(base.ptype(), |P| { + let intercept = intercept.cast::

()?; + let base = base.cast::

()?; + let multiplier = multiplier.cast::

()?; + find_true_range(base, multiplier, len, intercept, operator) + }) +} - // Handle ascending vs descending sequences - let (min_val, max_val) = if multiplier.is_ge(P::zero()) { - (base, end_element) +fn find_true_range( + base: P, + multiplier: P, + len: usize, + intercept: P, + operator: CompareOperator, +) -> VortexResult> { + if len == 0 { + vortex_bail!("len == 0"); + } + + if P::PTYPE.is_signed_int() { + signed_true_range( + base.to_i128() + .ok_or_else(|| vortex_err!("Cannot represent {base} as i128"))?, + multiplier + .to_i128() + .ok_or_else(|| vortex_err!("Cannot represent {multiplier} as i128"))?, + len, + intercept + .to_i128() + .ok_or_else(|| vortex_err!("Cannot represent {intercept} as i128"))?, + operator, + ) } else { - (end_element, base) - }; + unsigned_true_range( + base.to_u128() + .ok_or_else(|| vortex_err!("Cannot represent {base} as u128"))?, + multiplier + .to_u128() + .ok_or_else(|| vortex_err!("Cannot represent {multiplier} as u128"))?, + len, + intercept + .to_u128() + .ok_or_else(|| vortex_err!("Cannot represent {intercept} as u128"))?, + operator, + ) + } +} - // Check if intercept is in range - if !intercept.is_ge(min_val) || !intercept.is_le(max_val) { - vortex_bail!("{intercept} is outside of ({min_val}, {max_val}) range") +#[allow(clippy::manual_is_multiple_of)] +fn unsigned_true_range( + base: u128, + multiplier: u128, + len: usize, + intercept: u128, + operator: CompareOperator, +) -> VortexResult> { + if multiplier == 0 { + return Ok(constant_true_range(len, base.cmp(&intercept), operator)); } - // Handle zero multiplier (constant sequence) - if multiplier == P::zero() { - if intercept == base { - return Ok(0); - } else { - vortex_bail!("{intercept} != {base} with zero multiplier") + let last = base + multiplier * usize_to_u128(len - 1)?; + + let true_range = match operator { + CompareOperator::Eq => { + if intercept < base || intercept > last { + empty_range() + } else { + let diff = intercept - base; + if diff % multiplier 
== 0 { + singleton_range( + usize::try_from(diff / multiplier) + .map_err(|_| vortex_err!("index does not fit into usize"))?, + ) + } else { + empty_range() + } + } } - } + CompareOperator::Lt => { + let end = if intercept <= base { + 0 + } else if intercept > last { + len + } else { + usize::try_from(ceil_div_positive_u128(intercept - base, multiplier)) + .map_err(|_| vortex_err!("cut-point does not fit into usize"))? + }; + prefix_range(end) + } + CompareOperator::Lte => { + let end = if intercept < base { + 0 + } else if intercept >= last { + len + } else { + usize::try_from(((intercept - base) / multiplier) + 1) + .map_err(|_| vortex_err!("cut-point does not fit into usize"))? + }; + prefix_range(end) + } + CompareOperator::Gt => { + let start = if intercept < base { + 0 + } else if intercept >= last { + len + } else { + usize::try_from(((intercept - base) / multiplier) + 1) + .map_err(|_| vortex_err!("cut-point does not fit into usize"))? + }; + suffix_range(start, len) + } + CompareOperator::Gte => { + let start = if intercept <= base { + 0 + } else if intercept > last { + len + } else { + usize::try_from(ceil_div_positive_u128(intercept - base, multiplier)) + .map_err(|_| vortex_err!("cut-point does not fit into usize"))? 
+ }; + suffix_range(start, len) + } + CompareOperator::NotEq => vortex_bail!("NotEq cannot be represented as a single range"), + }; - // Check if (intercept - base) is evenly divisible by multiplier - let diff = intercept - base; - if diff % multiplier != P::zero() { - vortex_bail!("{diff} % {multiplier} != 0") + Ok(true_range) +} + +#[allow(clippy::manual_is_multiple_of)] +fn signed_true_range( + base: i128, + multiplier: i128, + len: usize, + intercept: i128, + operator: CompareOperator, +) -> VortexResult> { + if multiplier == 0 { + return Ok(constant_true_range(len, base.cmp(&intercept), operator)); } - let idx = diff / multiplier; - idx.to_usize() - .ok_or_else(|| vortex_err!("Cannot represent {idx} as usize")) + let last = base + multiplier * usize_to_i128(len - 1)?; + + let true_range = if multiplier > 0 { + let max = last; + match operator { + CompareOperator::Eq => { + if intercept < base || intercept > max { + empty_range() + } else { + let diff = intercept - base; + if diff % multiplier == 0 { + singleton_range( + usize::try_from(diff / multiplier) + .map_err(|_| vortex_err!("index does not fit into usize"))?, + ) + } else { + empty_range() + } + } + } + CompareOperator::Lt => { + let end = if intercept <= base { + 0 + } else if intercept > max { + len + } else { + usize::try_from(ceil_div_positive_i128(intercept - base, multiplier)) + .map_err(|_| vortex_err!("cut-point does not fit into usize"))? + }; + prefix_range(end) + } + CompareOperator::Lte => { + let end = if intercept < base { + 0 + } else if intercept >= max { + len + } else { + usize::try_from(((intercept - base) / multiplier) + 1) + .map_err(|_| vortex_err!("cut-point does not fit into usize"))? + }; + prefix_range(end) + } + CompareOperator::Gt => { + let start = if intercept < base { + 0 + } else if intercept >= max { + len + } else { + usize::try_from(((intercept - base) / multiplier) + 1) + .map_err(|_| vortex_err!("cut-point does not fit into usize"))? 
+ }; + suffix_range(start, len) + } + CompareOperator::Gte => { + let start = if intercept <= base { + 0 + } else if intercept > max { + len + } else { + usize::try_from(ceil_div_positive_i128(intercept - base, multiplier)) + .map_err(|_| vortex_err!("cut-point does not fit into usize"))? + }; + suffix_range(start, len) + } + CompareOperator::NotEq => { + vortex_bail!("NotEq cannot be represented as a single range") + } + } + } else { + let min = last; + let step = -multiplier; + match operator { + CompareOperator::Eq => { + if intercept < min || intercept > base { + empty_range() + } else { + let diff = base - intercept; + if diff % step == 0 { + singleton_range( + usize::try_from(diff / step) + .map_err(|_| vortex_err!("index does not fit into usize"))?, + ) + } else { + empty_range() + } + } + } + CompareOperator::Lt => { + let start = if base < intercept { + 0 + } else if min >= intercept { + len + } else { + usize::try_from(((base - intercept) / step) + 1) + .map_err(|_| vortex_err!("cut-point does not fit into usize"))? + }; + suffix_range(start, len) + } + CompareOperator::Lte => { + let start = if base <= intercept { + 0 + } else if min > intercept { + len + } else { + usize::try_from(ceil_div_positive_i128(base - intercept, step)) + .map_err(|_| vortex_err!("cut-point does not fit into usize"))? + }; + suffix_range(start, len) + } + CompareOperator::Gt => { + let end = if base <= intercept { + 0 + } else if min > intercept { + len + } else { + usize::try_from(ceil_div_positive_i128(base - intercept, step)) + .map_err(|_| vortex_err!("cut-point does not fit into usize"))? + }; + prefix_range(end) + } + CompareOperator::Gte => { + let end = if base < intercept { + 0 + } else if min >= intercept { + len + } else { + usize::try_from(((base - intercept) / step) + 1) + .map_err(|_| vortex_err!("cut-point does not fit into usize"))? 
+ }; + prefix_range(end) + } + CompareOperator::NotEq => { + vortex_bail!("NotEq cannot be represented as a single range") + } + } + }; + + Ok(true_range) } #[cfg(test)] mod tests { + use vortex_array::ExecutionCtx; use vortex_array::IntoArray; use vortex_array::arrays::BoolArray; use vortex_array::arrays::ConstantArray; + use vortex_array::arrays::PrimitiveArray; use vortex_array::assert_arrays_eq; use vortex_array::builtins::ArrayBuiltins; use vortex_array::dtype::Nullability::NonNullable; use vortex_array::dtype::Nullability::Nullable; + use vortex_array::scalar_fn::fns::binary::CompareKernel; + use vortex_array::scalar_fn::fns::operators::CompareOperator; use vortex_array::scalar_fn::fns::operators::Operator; + use vortex_session::VortexSession; use crate::Sequence; @@ -181,4 +507,174 @@ mod tests { let expected = BoolArray::from_iter([false, false, false, false]); assert_arrays_eq!(result, expected); } + + #[test] + fn test_compare_range_ascending() { + let lhs = Sequence::try_new_typed(2i64, 3, NonNullable, 5).unwrap(); + + let lt = lhs + .clone() + .into_array() + .binary( + ConstantArray::new(8i64, lhs.len()).into_array(), + Operator::Lt, + ) + .unwrap(); + assert_arrays_eq!(lt, BoolArray::from_iter([true, true, false, false, false])); + + let lte = lhs + .clone() + .into_array() + .binary( + ConstantArray::new(8i64, lhs.len()).into_array(), + Operator::Lte, + ) + .unwrap(); + assert_arrays_eq!(lte, BoolArray::from_iter([true, true, true, false, false])); + + let gt = lhs + .clone() + .into_array() + .binary( + ConstantArray::new(8i64, lhs.len()).into_array(), + Operator::Gt, + ) + .unwrap(); + assert_arrays_eq!(gt, BoolArray::from_iter([false, false, false, true, true])); + + let gte = lhs + .into_array() + .binary(ConstantArray::new(8i64, 5).into_array(), Operator::Gte) + .unwrap(); + assert_arrays_eq!(gte, BoolArray::from_iter([false, false, true, true, true])); + } + + #[test] + fn test_compare_range_descending() { + let lhs = 
Sequence::try_new_typed(14i64, -3, NonNullable, 5).unwrap(); + + let lt = lhs + .clone() + .into_array() + .binary( + ConstantArray::new(8i64, lhs.len()).into_array(), + Operator::Lt, + ) + .unwrap(); + assert_arrays_eq!(lt, BoolArray::from_iter([false, false, false, true, true])); + + let lte = lhs + .clone() + .into_array() + .binary( + ConstantArray::new(8i64, lhs.len()).into_array(), + Operator::Lte, + ) + .unwrap(); + assert_arrays_eq!(lte, BoolArray::from_iter([false, false, true, true, true])); + + let gt = lhs + .clone() + .into_array() + .binary( + ConstantArray::new(8i64, lhs.len()).into_array(), + Operator::Gt, + ) + .unwrap(); + assert_arrays_eq!(gt, BoolArray::from_iter([true, true, false, false, false])); + + let gte = lhs + .into_array() + .binary(ConstantArray::new(8i64, 5).into_array(), Operator::Gte) + .unwrap(); + assert_arrays_eq!(gte, BoolArray::from_iter([true, true, true, false, false])); + } + + #[test] + fn test_compare_constant_sequence_matches_all() { + let lhs = Sequence::try_new_typed(7i64, 0, NonNullable, 4).unwrap(); + + let eq = lhs + .clone() + .into_array() + .binary( + ConstantArray::new(7i64, lhs.len()).into_array(), + Operator::Eq, + ) + .unwrap(); + assert_arrays_eq!(eq, BoolArray::from_iter([true, true, true, true])); + + let lte = lhs + .clone() + .into_array() + .binary( + ConstantArray::new(7i64, lhs.len()).into_array(), + Operator::Lte, + ) + .unwrap(); + assert_arrays_eq!(lte, BoolArray::from_iter([true, true, true, true])); + + let gt = lhs + .clone() + .into_array() + .binary( + ConstantArray::new(6i64, lhs.len()).into_array(), + Operator::Gt, + ) + .unwrap(); + assert_arrays_eq!(gt, BoolArray::from_iter([true, true, true, true])); + + let lt = lhs + .into_array() + .binary(ConstantArray::new(7i64, 4).into_array(), Operator::Lt) + .unwrap(); + assert_arrays_eq!(lt, BoolArray::from_iter([false, false, false, false])); + } + + #[test] + fn test_compare_nullable_range() { + let lhs = Sequence::try_new_typed(2i64, 3, 
Nullable, 4).unwrap(); + let rhs = ConstantArray::new(5i64, lhs.len()); + let result = lhs + .into_array() + .binary(rhs.into_array(), Operator::Gte) + .unwrap(); + let expected = BoolArray::from_iter([Some(false), Some(true), Some(true), Some(true)]); + assert_arrays_eq!(result, expected); + } + + #[test] + fn test_compare_swapped_operands() { + let rhs = Sequence::try_new_typed(2i64, 3, NonNullable, 5).unwrap(); + let lhs = ConstantArray::new(8i64, rhs.len()); + let result = lhs + .into_array() + .binary(rhs.into_array(), Operator::Gt) + .unwrap(); + let expected = BoolArray::from_iter([true, true, false, false, false]); + assert_arrays_eq!(result, expected); + } + + #[test] + fn test_compare_unsigned_sequence_range() { + let lhs = Sequence::try_new_typed(2u64, 3, NonNullable, 5).unwrap(); + let rhs = ConstantArray::new(8u64, lhs.len()); + let result = lhs + .into_array() + .binary(rhs.into_array(), Operator::Gte) + .unwrap(); + let expected = BoolArray::from_iter([false, false, true, true, true]); + assert_arrays_eq!(result, expected); + } + + #[test] + fn test_compare_non_constant_rhs_returns_none() { + let lhs = Sequence::try_new_typed(2i64, 3, NonNullable, 5).unwrap(); + let rhs = PrimitiveArray::from_iter([8i64; 5]).into_array(); + let mut ctx = ExecutionCtx::new(VortexSession::empty()); + + let result = Sequence::compare(lhs.as_view(), &rhs, CompareOperator::Lt, &mut ctx).unwrap(); + + assert!(result.is_none()); + } } diff --git a/encodings/sequence/src/rules.rs b/encodings/sequence/src/rules.rs index f87c08e8109..55618a576ff 100644 --- a/encodings/sequence/src/rules.rs +++ b/encodings/sequence/src/rules.rs @@ -1,15 +1,259 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors +use num_traits::CheckedAdd; +use num_traits::CheckedMul; +use num_traits::CheckedSub; +use num_traits::Zero; +use vortex_array::ArrayRef; +use vortex_array::ArrayView; +use vortex_array::IntoArray; +use vortex_array::arrays::Constant; 
+use vortex_array::arrays::ScalarFnVTable; +use vortex_array::arrays::scalar_fn::AnyScalarFn; use vortex_array::arrays::slice::SliceReduceAdaptor; +use vortex_array::dtype::DType; +use vortex_array::dtype::IntegerPType; +use vortex_array::dtype::PType; +use vortex_array::match_each_integer_ptype; +use vortex_array::optimizer::rules::ArrayParentReduceRule; use vortex_array::optimizer::rules::ParentRuleSet; +use vortex_array::scalar::PValue; +use vortex_array::scalar_fn::fns::binary::Binary; use vortex_array::scalar_fn::fns::cast::CastReduceAdaptor; use vortex_array::scalar_fn::fns::list_contains::ListContainsElementReduceAdaptor; +use vortex_array::scalar_fn::fns::operators::Operator; +use vortex_error::VortexResult; use crate::Sequence; pub(crate) static RULES: ParentRuleSet = ParentRuleSet::new(&[ ParentRuleSet::lift(&CastReduceAdaptor(Sequence)), + ParentRuleSet::lift(&SequenceAffineScalarFnRule), ParentRuleSet::lift(&ListContainsElementReduceAdaptor(Sequence)), ParentRuleSet::lift(&SliceReduceAdaptor(Sequence)), ]); + +#[derive(Debug)] +struct SequenceAffineScalarFnRule; + +impl ArrayParentReduceRule for SequenceAffineScalarFnRule { + type Parent = AnyScalarFn; + + fn reduce_parent( + &self, + sequence: ArrayView<'_, Sequence>, + parent: ArrayView<'_, ScalarFnVTable>, + child_idx: usize, + ) -> VortexResult> { + if parent.nchildren() != 2 || child_idx > 1 { + return Ok(None); + } + + let Some(operator) = parent.scalar_fn().as_opt::().copied() else { + return Ok(None); + }; + + let DType::Primitive(result_ptype, nullability) = parent.dtype() else { + return Ok(None); + }; + if !result_ptype.is_int() { + return Ok(None); + } + + let Some(sibling) = parent.iter_children().nth(child_idx ^ 1) else { + return Ok(None); + }; + let Some(constant) = sibling.as_opt::() else { + return Ok(None); + }; + let Some(constant_value) = constant + .scalar() + .as_primitive_opt() + .and_then(|c| c.pvalue()) + else { + return Ok(None); + }; + + let Some((base, multiplier)) = 
affine_sequence_parts( + sequence.base(), + sequence.multiplier(), + constant_value, + *result_ptype, + operator, + child_idx == 0, + ) else { + return Ok(None); + }; + + Ok(Sequence::try_new( + base, + multiplier, + *result_ptype, + *nullability, + sequence.len(), + ) + .ok() + .map(|sequence| sequence.into_array())) + } +} + +fn affine_sequence_parts( + base: PValue, + multiplier: PValue, + constant: PValue, + result_ptype: PType, + operator: Operator, + sequence_on_lhs: bool, +) -> Option<(PValue, PValue)> { + match_each_integer_ptype!(result_ptype, |P| { + let base = base.cast::

().ok()?; + let multiplier = multiplier.cast::

().ok()?; + let constant = constant.cast::

().ok()?; + + affine_sequence_parts_typed(base, multiplier, constant, operator, sequence_on_lhs) + .map(|(base, multiplier)| (PValue::from(base), PValue::from(multiplier))) + }) +} + +fn affine_sequence_parts_typed

( + base: P, + multiplier: P, + constant: P, + operator: Operator, + sequence_on_lhs: bool, +) -> Option<(P, P)> +where + P: IntegerPType + CheckedAdd + CheckedSub + CheckedMul + Zero + Copy, + PValue: From

, +{ + match (operator, sequence_on_lhs) { + (Operator::Add, _) => Some((base.checked_add(&constant)?, multiplier)), + (Operator::Sub, true) => Some((base.checked_sub(&constant)?, multiplier)), + (Operator::Sub, false) => Some(( + constant.checked_sub(&base)?, + P::zero().checked_sub(&multiplier)?, + )), + (Operator::Mul, _) => Some(( + base.checked_mul(&constant)?, + multiplier.checked_mul(&constant)?, + )), + _ => None, + } +} + +#[cfg(test)] +mod tests { + use rstest::rstest; + use vortex_array::ArrayRef; + use vortex_array::IntoArray; + use vortex_array::arrays::ConstantArray; + use vortex_array::arrays::PrimitiveArray; + use vortex_array::arrays::scalar_fn::ScalarFnArrayExt; + use vortex_array::assert_arrays_eq; + use vortex_array::dtype::Nullability::NonNullable; + use vortex_array::optimizer::ArrayOptimizer; + use vortex_array::scalar_fn::fns::binary::Binary; + use vortex_array::scalar_fn::fns::operators::Operator; + + use crate::Sequence; + use crate::SequenceArray; + + #[rstest] + #[case::seq_plus_const( + Operator::Add, + true, + Sequence::try_new_typed(12i64, 3, NonNullable, 5).unwrap(), + )] + #[case::const_plus_seq( + Operator::Add, + false, + Sequence::try_new_typed(12i64, 3, NonNullable, 5).unwrap(), + )] + #[case::seq_minus_const( + Operator::Sub, + true, + Sequence::try_new_typed(2i64, 3, NonNullable, 5).unwrap(), + )] + #[case::const_minus_seq( + Operator::Sub, + false, + Sequence::try_new_typed(-2i64, -3, NonNullable, 5).unwrap(), + )] + #[case::seq_times_const( + Operator::Mul, + true, + Sequence::try_new_typed(35i64, 15, NonNullable, 5).unwrap(), + )] + #[case::const_times_seq( + Operator::Mul, + false, + Sequence::try_new_typed(35i64, 15, NonNullable, 5).unwrap(), + )] + fn rewrites_affine_binary_ops_to_sequence( + #[case] operator: Operator, + #[case] sequence_on_lhs: bool, + #[case] expected: SequenceArray, + ) { + let sequence = Sequence::try_new_typed(7i64, 3, NonNullable, 5) + .unwrap() + .into_array(); + let constant = 
ConstantArray::new(5i64, sequence.len()).into_array(); + + let optimized = optimize_binary(sequence, constant, operator, sequence_on_lhs); + + assert!(optimized.is::()); + assert_arrays_eq!(optimized, expected.into_array()); + } + + #[test] + fn falls_back_for_overflow_prone_const_minus_seq() { + let sequence = Sequence::try_new_typed(1i8, i8::MIN, NonNullable, 2) + .unwrap() + .into_array(); + let constant = ConstantArray::new(0i8, sequence.len()).into_array(); + + let optimized = optimize_binary(sequence, constant, Operator::Sub, false); + + assert!(!optimized.is::()); + assert_arrays_eq!( + optimized, + PrimitiveArray::from_iter([-1i8, 127]).into_array() + ); + } + + #[test] + fn keeps_division_on_the_fallback_path() { + let sequence = Sequence::try_new_typed(8i64, 4, NonNullable, 4) + .unwrap() + .into_array(); + let constant = ConstantArray::new(2i64, sequence.len()).into_array(); + + let optimized = optimize_binary(sequence, constant, Operator::Div, true); + + assert!(!optimized.is::()); + assert_arrays_eq!( + optimized, + PrimitiveArray::from_iter([4i64, 6, 8, 10]).into_array() + ); + } + + fn optimize_binary( + sequence: ArrayRef, + constant: ArrayRef, + operator: Operator, + sequence_on_lhs: bool, + ) -> ArrayRef { + let children = if sequence_on_lhs { + vec![sequence, constant] + } else { + vec![constant, sequence] + }; + + Binary + .try_new_array(children[0].len(), operator, children) + .unwrap() + .optimize() + .unwrap() + } +} diff --git a/vortex-array/Cargo.toml b/vortex-array/Cargo.toml index dbdfd5a2cd9..4b8d4801289 100644 --- a/vortex-array/Cargo.toml +++ b/vortex-array/Cargo.toml @@ -152,6 +152,10 @@ required-features = ["_test-harness"] name = "dict_mask" harness = false +[[bench]] +name = "masked_scalar_fn" +harness = false + [[bench]] name = "dict_unreferenced_mask" harness = false diff --git a/vortex-array/benches/masked_scalar_fn.rs b/vortex-array/benches/masked_scalar_fn.rs new file mode 100644 index 00000000000..ea84a0d0b85 --- /dev/null 
+++ b/vortex-array/benches/masked_scalar_fn.rs
@@ -0,0 +1,96 @@
+// SPDX-License-Identifier: Apache-2.0
+// SPDX-FileCopyrightText: Copyright the Vortex contributors
+
+#![allow(clippy::unwrap_used)]
+
+use std::fmt;
+
+use divan::Bencher;
+use vortex_array::Canonical;
+use vortex_array::IntoArray;
+use vortex_array::VortexSessionExecute;
+use vortex_array::arrays::ConstantArray;
+use vortex_array::arrays::MaskedArray;
+use vortex_array::arrays::dict_test::gen_primitive_for_dict;
+use vortex_array::builders::dict::dict_encode;
+use vortex_array::builtins::ArrayBuiltins;
+use vortex_array::scalar_fn::fns::operators::Operator;
+use vortex_array::validity::Validity;
+use vortex_session::VortexSession;
+
+fn main() {
+    divan::main();
+}
+
+/// Number of rows in every benchmark fixture.
+const LEN: usize = 1_000_000;
+
+/// Benchmark axes: dictionary cardinality and how often a row is invalid.
+#[derive(Clone, Copy, Debug)]
+struct BenchArgs {
+    unique_count: usize,
+    invalid_stride: usize,
+}
+
+impl fmt::Display for BenchArgs {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        write!(
+            f,
+            "unique_{}_invalid_every_{}",
+            self.unique_count, self.invalid_stride
+        )
+    }
+}
+
+const BENCH_ARGS: &[BenchArgs] = &[
+    BenchArgs {
+        unique_count: 32,
+        invalid_stride: 4,
+    },
+    BenchArgs {
+        unique_count: 32,
+        invalid_stride: 16,
+    },
+    BenchArgs {
+        unique_count: 512,
+        invalid_stride: 4,
+    },
+    BenchArgs {
+        unique_count: 512,
+        invalid_stride: 16,
+    },
+];
+
+/// Benchmarks `masked_dict == constant`, the shape the masked scalar-fn
+/// push-down rule is designed to accelerate.
+#[divan::bench(args = BENCH_ARGS)]
+fn compare_to_constant(bencher: Bencher, args: BenchArgs) {
+    let (masked, target) = masked_dict_fixture(args);
+    let compare_value = ConstantArray::new(target, masked.len()).into_array();
+    let session = VortexSession::empty();
+
+    bencher
+        .with_inputs(|| {
+            (
+                masked.clone(),
+                compare_value.clone(),
+                session.create_execution_ctx(),
+            )
+        })
+        .bench_refs(|(masked, compare_value, ctx)| {
+            let result = masked
+                .binary(compare_value.clone(), Operator::Eq)
+                .unwrap()
+                .execute::<Canonical>(ctx)
+                .unwrap();
+            divan::black_box(result);
+        });
+}
+
+/// Builds a dict-encoded primitive array wrapped in a `MaskedArray` whose
+/// validity invalidates every `invalid_stride`-th row, and returns it together
+/// with a value known to be present (the first generated element).
+fn masked_dict_fixture(args: BenchArgs) -> (vortex_array::ArrayRef, i32) {
+    let primitive = gen_primitive_for_dict::<i32>(LEN, args.unique_count);
+    let target = primitive.as_slice::<i32>()[0];
+    let dict = dict_encode(&primitive.into_array()).unwrap();
+    let validity = Validity::from_iter((0..LEN).map(|idx| idx % args.invalid_stride != 0));
+    let masked = MaskedArray::try_new(dict.into_array(), validity)
+        .unwrap()
+        .into_array();
+
+    (masked, target)
+}
diff --git a/vortex-array/src/arrays/masked/compute/rules.rs b/vortex-array/src/arrays/masked/compute/rules.rs
index 3accb455c3f..1e6d13f89e8 100644
--- a/vortex-array/src/arrays/masked/compute/rules.rs
+++ b/vortex-array/src/arrays/masked/compute/rules.rs
@@ -1,16 +1,194 @@
 // SPDX-License-Identifier: Apache-2.0
 // SPDX-FileCopyrightText: Copyright the Vortex contributors
 
+use vortex_error::VortexResult;
+
+use crate::ArrayRef;
+use crate::IntoArray;
+use crate::array::ArrayView;
+use crate::arrays::Constant;
 use crate::arrays::Masked;
+use crate::arrays::ScalarFnArray;
+use crate::arrays::ScalarFnVTable;
 use crate::arrays::dict::TakeReduceAdaptor;
 use crate::arrays::filter::FilterReduceAdaptor;
+use crate::arrays::scalar_fn::AnyScalarFn;
+use crate::arrays::scalar_fn::ScalarFnArrayExt;
 use crate::arrays::slice::SliceReduceAdaptor;
+use crate::optimizer::ArrayOptimizer;
+use crate::optimizer::rules::ArrayParentReduceRule;
 use crate::optimizer::rules::ParentRuleSet;
+use crate::scalar_fn::EmptyOptions;
+use crate::scalar_fn::fns::mask::Mask as MaskExpr;
 use crate::scalar_fn::fns::mask::MaskReduceAdaptor;
 
 pub(crate) const PARENT_RULES: ParentRuleSet = ParentRuleSet::new(&[
     ParentRuleSet::lift(&FilterReduceAdaptor(Masked)),
     ParentRuleSet::lift(&MaskReduceAdaptor(Masked)),
+    ParentRuleSet::lift(&MaskedScalarFnPushDownRule),
     ParentRuleSet::lift(&SliceReduceAdaptor(Masked)),
     ParentRuleSet::lift(&TakeReduceAdaptor(Masked)),
 ]);
+
+/// Pushes a scalar function through a `Masked` child: rewrites
+/// `fn(masked(x), consts...)` into `mask(fn(x, consts...), validity)` so the
+/// function can keep optimizing against the unmasked encoding (e.g. a dict).
+///
+/// Only fires when the function is neither null-sensitive nor fallible
+/// (evaluating it on rows that will be masked out must be a no-op), and when
+/// every sibling input is a constant (so siblings need no re-alignment).
+#[derive(Debug)]
+struct MaskedScalarFnPushDownRule;
+
+impl ArrayParentReduceRule for MaskedScalarFnPushDownRule {
+    type Parent = AnyScalarFn;
+
+    fn reduce_parent(
+        &self,
+        array: ArrayView<'_, Masked>,
+        parent: ArrayView<'_, ScalarFnVTable>,
+        child_idx: usize,
+    ) -> VortexResult<Option<ArrayRef>> {
+        // A null-sensitive fn changes meaning on masked-out rows; a fallible fn
+        // could error on values the mask would have hidden. Both must bail.
+        let signature = parent.scalar_fn().signature();
+        if signature.is_null_sensitive() || signature.is_fallible() {
+            return Ok(None);
+        }
+
+        // All siblings of the masked child must be constants.
+        if !parent
+            .iter_children()
+            .enumerate()
+            .all(|(idx, child)| idx == child_idx || child.is::<Constant>())
+        {
+            return Ok(None);
+        }
+
+        // Re-build the scalar fn over the *unmasked* child and let the
+        // optimizer push it further down (e.g. into dict values).
+        let pushed_child = ScalarFnArray::try_new(
+            parent.scalar_fn().clone(),
+            parent
+                .iter_children()
+                .enumerate()
+                .map(|(idx, child)| {
+                    if idx == child_idx {
+                        array.child().clone()
+                    } else {
+                        child.clone()
+                    }
+                })
+                .collect(),
+            parent.len(),
+        )?
+        .into_array()
+        .optimize()?;
+
+        // Re-apply the original validity on top of the pushed-down result.
+        Ok(Some(MaskExpr.try_new_array(
+            parent.len(),
+            EmptyOptions,
+            [pushed_child, array.validity()?.to_array(parent.len())],
+        )?))
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use vortex_error::VortexResult;
+
+    use super::*;
+    use crate::Canonical;
+    use crate::IntoArray;
+    use crate::LEGACY_SESSION;
+    use crate::VortexSessionExecute;
+    use crate::arrays::BoolArray;
+    use crate::arrays::ConstantArray;
+    use crate::arrays::Dict;
+    use crate::arrays::DictArray;
+    use crate::arrays::MaskedArray;
+    use crate::arrays::PrimitiveArray;
+    use crate::arrays::ScalarFnVTable as ScalarFnArrayVTable;
+    use crate::assert_arrays_eq;
+    use crate::builtins::ArrayBuiltins;
+    use crate::scalar_fn::ScalarFnVTable as ScalarFnTrait;
+    use crate::scalar_fn::fns::binary::Binary;
+    use crate::scalar_fn::fns::is_null::IsNull;
+    use crate::scalar_fn::fns::operators::Operator;
+    use crate::validity::Validity;
+
+    #[test]
+    fn pushes_down_compare_to_masked_dict() -> VortexResult<()> {
+        let masked = masked_dict_fixture()?;
+        let optimized = masked.binary(
+            ConstantArray::new(9i32, masked.len()).into_array(),
+            Operator::Eq,
+        )?;
+
+        // NOTE(review): the turbofish arguments in this test were lost during
+        // extraction and are reconstructed from the imports — verify upstream.
+        let encoded_result = if optimized.is::<Dict>() {
+            optimized.clone()
+        } else {
+            assert!(root_scalar_fn_is::<MaskExpr>(&optimized));
+            optimized.nth_child(0).unwrap()
+        };
+        assert!(encoded_result.is::<Dict>());
+        assert!(!encoded_result.is::<Masked>());
+
+        let expected = BoolArray::from_iter([
+            Some(false),
+            None,
+            Some(false),
+            Some(true),
+            None,
+            Some(false),
+        ])
+        .into_array();
+        let mut ctx = LEGACY_SESSION.create_execution_ctx();
+        assert_arrays_eq!(
+            optimized.execute::<Canonical>(&mut ctx)?.into_array(),
+            expected
+        );
+        Ok(())
+    }
+
+    #[test]
+    fn keeps_is_null_on_fallback_path() -> VortexResult<()> {
+        // is_null is null-sensitive, so the rule must not fire.
+        let optimized = masked_dict_fixture()?.is_null()?;
+
+        assert!(root_scalar_fn_is::<IsNull>(&optimized));
+        assert!(optimized.nth_child(0).unwrap().is::<Masked>());
+        Ok(())
+    }
+
+    #[test]
+    fn keeps_fallible_binary_on_fallback_path() -> VortexResult<()> {
+        // Division can fail (divide-by-zero), so the rule must not fire.
+        let masked = masked_dict_fixture()?;
+        let optimized = masked.binary(
+            ConstantArray::new(2i32, masked.len()).into_array(),
+            Operator::Div,
+        )?;
+
+        assert!(root_scalar_fn_is::<Binary>(&optimized));
+        assert!(optimized.nth_child(0).unwrap().is::<Masked>());
+        Ok(())
+    }
+
+    #[test]
+    fn keeps_non_constant_sibling_on_fallback_path() -> VortexResult<()> {
+        // A non-constant sibling would need re-alignment, so the rule must not fire.
+        let masked = masked_dict_fixture()?;
+        let sibling = PrimitiveArray::from_iter([7i32, 9, 11, 9, 7, 11]).into_array();
+        let optimized = masked.binary(sibling, Operator::Eq)?;
+
+        assert!(root_scalar_fn_is::<Binary>(&optimized));
+        assert!(optimized.nth_child(0).unwrap().is::<Masked>());
+        Ok(())
+    }
+
+    /// Six-row dict-encoded array [7, 9, 11, 9, 7, 11] with rows 1 and 4 invalid.
+    fn masked_dict_fixture() -> VortexResult<ArrayRef> {
+        let dict = DictArray::try_new(
+            PrimitiveArray::from_iter([0u8, 1, 2, 1, 0, 2]).into_array(),
+            PrimitiveArray::from_iter([7i32, 9, 11]).into_array(),
+        )?
+        .into_array();
+
+        Ok(MaskedArray::try_new(
+            dict,
+            Validity::from_iter([true, false, true, true, false, true]),
+        )?
+        .into_array())
+    }
+
+    /// True iff `array` is a scalar-fn array whose function is of type `F`.
+    fn root_scalar_fn_is<F: ScalarFnTrait>(array: &ArrayRef) -> bool {
+        array
+            .as_opt::<ScalarFnArrayVTable>()
+            .is_some_and(|scalar_fn| scalar_fn.scalar_fn().is::<F>())
+    }
+}