From e1e1d09aba64873d6f9d65c7c960e73b33bb618b Mon Sep 17 00:00:00 2001 From: Matt Butrovich Date: Thu, 9 Apr 2026 16:47:49 -0400 Subject: [PATCH] Bring over apache/arrow-rs/9683, integrate into sorts, add heuristic to choose between sort implementations. --- .../physical-plan/src/aggregates/row_hash.rs | 1 + datafusion/physical-plan/src/sorts/mod.rs | 1 + datafusion/physical-plan/src/sorts/radix.rs | 118 ++++++++ datafusion/physical-plan/src/sorts/sort.rs | 267 +++++++++++++++++- datafusion/physical-plan/src/sorts/stream.rs | 12 +- 5 files changed, 390 insertions(+), 9 deletions(-) create mode 100644 datafusion/physical-plan/src/sorts/radix.rs diff --git a/datafusion/physical-plan/src/aggregates/row_hash.rs b/datafusion/physical-plan/src/aggregates/row_hash.rs index 056a7f171a516..a762e2ddd0b4b 100644 --- a/datafusion/physical-plan/src/aggregates/row_hash.rs +++ b/datafusion/physical-plan/src/aggregates/row_hash.rs @@ -1162,6 +1162,7 @@ impl GroupedHashAggregateStream { emit, self.spill_state.spill_expr.clone(), self.batch_size, + false, ); let spillfile = self .spill_state diff --git a/datafusion/physical-plan/src/sorts/mod.rs b/datafusion/physical-plan/src/sorts/mod.rs index a73872a175b9b..a83b44e6a2544 100644 --- a/datafusion/physical-plan/src/sorts/mod.rs +++ b/datafusion/physical-plan/src/sorts/mod.rs @@ -22,6 +22,7 @@ mod cursor; mod merge; mod multi_level_merge; pub mod partial_sort; +mod radix; pub mod sort; pub mod sort_preserving_merge; mod stream; diff --git a/datafusion/physical-plan/src/sorts/radix.rs b/datafusion/physical-plan/src/sorts/radix.rs new file mode 100644 index 0000000000000..377bad33a9d20 --- /dev/null +++ b/datafusion/physical-plan/src/sorts/radix.rs @@ -0,0 +1,118 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// TODO: replace with arrow_row::radix::radix_sort_to_indices once +// available in arrow-rs (see https://github.com/apache/arrow-rs/pull/9683) + +//! MSD radix sort on row-encoded keys. + +use arrow::array::UInt32Array; +use arrow::row::{RowConverter, Rows, SortField}; +use arrow_ord::sort::SortColumn; +use std::sync::Arc; + +/// 256-bucket histogram + scatter costs more than comparison sort at small n. +const FALLBACK_THRESHOLD: usize = 64; + +/// 8 bytes covers the discriminating prefix of most key layouts; deeper +/// recursion hits diminishing returns as buckets become sparse. +const MAX_DEPTH: usize = 8; + +/// Sort row indices using MSD radix sort on row-encoded keys. +/// +/// Returns a `UInt32Array` of row indices in sorted order. 
+pub(crate) fn radix_sort_to_indices(
+    sort_columns: &[SortColumn],
+) -> arrow::error::Result<UInt32Array> {
+    let sort_fields: Vec<SortField> = sort_columns
+        .iter()
+        .map(|col| {
+            SortField::new_with_options(
+                col.values.data_type().clone(),
+                col.options.unwrap_or_default(),
+            )
+        })
+        .collect();
+
+    let arrays: Vec<_> = sort_columns
+        .iter()
+        .map(|col| Arc::clone(&col.values))
+        .collect();
+
+    let converter = RowConverter::new(sort_fields)?;
+    let rows = converter.convert_columns(&arrays)?;
+
+    let n = rows.num_rows();
+    let mut indices: Vec<u32> = (0..n as u32).collect();
+    let mut temp = vec![0u32; n];
+    msd_radix_sort(&mut indices, &mut temp, &rows, 0);
+    Ok(UInt32Array::from(indices))
+}
+
+fn msd_radix_sort(indices: &mut [u32], temp: &mut [u32], rows: &Rows, byte_pos: usize) {
+    let n = indices.len();
+
+    if n <= FALLBACK_THRESHOLD || byte_pos >= MAX_DEPTH {
+        indices.sort_unstable_by(|&a, &b| {
+            let ra = unsafe { rows.row_unchecked(a as usize) };
+            let rb = unsafe { rows.row_unchecked(b as usize) };
+            ra.cmp(&rb)
+        });
+        return;
+    }
+
+    let mut counts = [0u32; 256];
+    for &idx in &*indices {
+        let row = unsafe { rows.row_unchecked(idx as usize) };
+        let byte = row.data().get(byte_pos).copied().unwrap_or(0);
+        counts[byte as usize] += 1;
+    }
+
+    // Skip scatter when all rows share the same byte
+    if counts.iter().filter(|&&c| c > 0).count() == 1 {
+        msd_radix_sort(indices, temp, rows, byte_pos + 1);
+        return;
+    }
+
+    let mut offsets = [0u32; 257];
+    for i in 0..256 {
+        offsets[i + 1] = offsets[i] + counts[i];
+    }
+
+    let temp = &mut temp[..n];
+    let mut write_pos = offsets;
+    for &idx in &*indices {
+        let row = unsafe { rows.row_unchecked(idx as usize) };
+        let byte = row.data().get(byte_pos).copied().unwrap_or(0) as usize;
+        temp[write_pos[byte] as usize] = idx;
+        write_pos[byte] += 1;
+    }
+    indices.copy_from_slice(temp);
+
+    for bucket in 0..256 {
+        let start = offsets[bucket] as usize;
+        let end = offsets[bucket + 1] as usize;
+        if end - start > 1 {
+            msd_radix_sort(
+                &mut indices[start..end],
+                &mut temp[start..end],
+                rows,
+                byte_pos + 1,
+            );
+        }
+    }
+}
diff --git a/datafusion/physical-plan/src/sorts/sort.rs b/datafusion/physical-plan/src/sorts/sort.rs
index 583bfa29b04ad..ddd97c330e330 100644
--- a/datafusion/physical-plan/src/sorts/sort.rs
+++ b/datafusion/physical-plan/src/sorts/sort.rs
@@ -55,7 +55,7 @@ use crate::{
 
 use arrow::array::{Array, RecordBatch, RecordBatchOptions, StringViewArray};
 use arrow::compute::{concat_batches, lexsort_to_indices, take_arrays};
-use arrow::datatypes::SchemaRef;
+use arrow::datatypes::{DataType, SchemaRef};
 use datafusion_common::config::SpillCompression;
 use datafusion_common::tree_node::TreeNodeRecursion;
 use datafusion_common::{
@@ -220,6 +220,8 @@ struct ExternalSorter {
     /// the data will be concatenated and sorted in place rather than
     /// sort/merged.
     sort_in_place_threshold_bytes: usize,
+    /// Whether to use radix sort (decided once from expression types).
+    use_radix: bool,
 
     // ========================================================================
     // STATE BUFFERS:
@@ -294,6 +296,12 @@ impl ExternalSorter {
         )
         .with_compression_type(spill_compression);
 
+        let sort_data_types: Vec<DataType> = expr
+            .iter()
+            .map(|e| e.expr.data_type(&schema))
+            .collect::<Result<Vec<_>>>()?;
+        let use_radix = use_radix_sort(&sort_data_types.iter().collect::<Vec<_>>());
+
         Ok(Self {
             schema,
             in_mem_batches: vec![],
@@ -308,6 +316,7 @@ impl ExternalSorter {
             batch_size,
             sort_spill_reservation_bytes,
             sort_in_place_threshold_bytes,
+            use_radix,
         })
     }
 
@@ -735,13 +744,15 @@ impl ExternalSorter {
         let schema = batch.schema();
         let expressions = self.expr.clone();
         let batch_size = self.batch_size;
+        let use_radix = self.use_radix;
         let output_row_metrics = metrics.output_rows().clone();
 
         let stream = futures::stream::once(async move {
             let schema = batch.schema();
 
             // Sort the batch immediately and get all output batches
-            let sorted_batches = sort_batch_chunked(&batch, &expressions, batch_size)?;
+            let sorted_batches =
+                sort_batch_chunked(&batch, &expressions, batch_size, use_radix)?;
 
             // Resize the reservation to match the actual sorted output size.
             // Using try_resize avoids a release-then-reacquire cycle, which
@@ -886,6 +897,24 @@ pub fn sort_batch(
         .map(|expr| expr.evaluate_to_sort_column(batch))
         .collect::<Result<Vec<_>>>()?;
 
+    if fetch.is_none()
+        && use_radix_sort(
+            &sort_columns
+                .iter()
+                .map(|c| c.values.data_type())
+                .collect::<Vec<_>>(),
+        )
+    {
+        let indices = super::radix::radix_sort_to_indices(&sort_columns)?;
+        let columns = take_arrays(batch.columns(), &indices, None)?;
+        let options = RecordBatchOptions::new().with_row_count(Some(indices.len()));
+        return Ok(RecordBatch::try_new_with_options(
+            batch.schema(),
+            columns,
+            &options,
+        )?);
+    }
+
     let indices = lexsort_to_indices(&sort_columns, fetch)?;
     let columns = take_arrays(batch.columns(), &indices, None)?;
 
@@ -897,6 +926,36 @@ pub fn sort_batch(
     )?)
 }
 
+/// Returns true if radix sort should be used for the given sort column types.
+///
+/// Radix sort is faster for most multi-column sorts but falls back to
+/// lexsort when:
+/// - All sort columns are dictionary-typed (long shared row prefixes
+///   waste radix passes before falling back to comparison sort)
+/// - Any sort column is a nested type (encoding cost is high and lexsort
+///   short-circuits comparison on leading columns)
+pub(super) fn use_radix_sort(data_types: &[&DataType]) -> bool {
+    if data_types.is_empty() {
+        return false;
+    }
+
+    let mut all_dict = true;
+    for dt in data_types {
+        match dt {
+            DataType::Dictionary(_, _) => {}
+            DataType::List(_)
+            | DataType::LargeList(_)
+            | DataType::FixedSizeList(_, _)
+            | DataType::Struct(_)
+            | DataType::Map(_, _)
+            | DataType::Union(_, _) => return false,
+            _ => all_dict = false,
+        }
+    }
+
+    !all_dict
+}
+
 /// Sort a batch and return the result as multiple batches of size `batch_size`.
 /// This is useful when you want to avoid creating one large sorted batch in memory,
 /// and instead want to process the sorted data in smaller chunks.
@@ -904,8 +963,15 @@ pub fn sort_batch_chunked(
     batch: &RecordBatch,
     expressions: &LexOrdering,
     batch_size: usize,
+    use_radix: bool,
 ) -> Result<Vec<RecordBatch>> {
-    IncrementalSortIterator::new(batch.clone(), expressions.clone(), batch_size).collect()
+    IncrementalSortIterator::new(
+        batch.clone(),
+        expressions.clone(),
+        batch_size,
+        use_radix,
+    )
+    .collect()
 }
 
 /// Sort execution plan.
@@ -2594,7 +2660,7 @@ mod tests {
         [PhysicalSortExpr::new_default(Arc::new(Column::new("a", 0)))].into();
 
         // Sort with batch_size = 250
-        let result_batches = sort_batch_chunked(&batch, &expressions, 250)?;
+        let result_batches = sort_batch_chunked(&batch, &expressions, 250, false)?;
 
         // Verify 4 batches are returned
         assert_eq!(result_batches.len(), 4);
@@ -2647,7 +2713,7 @@ mod tests {
         [PhysicalSortExpr::new_default(Arc::new(Column::new("a", 0)))].into();
 
         // Sort with batch_size = 100
-        let result_batches = sort_batch_chunked(&batch, &expressions, 100)?;
+        let result_batches = sort_batch_chunked(&batch, &expressions, 100, false)?;
 
         // Should return exactly 1 batch
         assert_eq!(result_batches.len(), 1);
@@ -2679,7 +2745,7 @@ mod tests {
         [PhysicalSortExpr::new_default(Arc::new(Column::new("a", 0)))].into();
 
         // Sort with batch_size = 100
-        let result_batches = sort_batch_chunked(&batch, &expressions, 100)?;
+        let result_batches = sort_batch_chunked(&batch, &expressions, 100, false)?;
 
         // Should return exactly 10 batches of 100 rows each
        assert_eq!(result_batches.len(), 10);
@@ -2706,7 +2772,7 @@ mod tests {
         let expressions: LexOrdering =
             [PhysicalSortExpr::new_default(Arc::new(Column::new("a", 0)))].into();
 
-        let result_batches = sort_batch_chunked(&batch, &expressions, 100)?;
+        let result_batches = sort_batch_chunked(&batch, &expressions, 100, false)?;
 
         // Empty input produces no output batches (0 chunks)
         assert_eq!(result_batches.len(), 0);
@@ -2929,4 +2995,191 @@ mod tests {
         assert_eq!(desc.self_filters()[0].len(), 1);
         Ok(())
     }
+
+    #[test]
+    fn test_sort_batch_radix_multi_column() {
+        let a1: ArrayRef = Arc::new(Int32Array::from(vec![2, 1, 2, 1]));
+        let a2: ArrayRef = Arc::new(Int32Array::from(vec![4, 3, 2, 1]));
+        let schema = Arc::new(Schema::new(vec![
+            Field::new("a", DataType::Int32, false),
+            Field::new("b", DataType::Int32, false),
+        ]));
+        let batch = RecordBatch::try_new(schema, vec![a1, a2]).unwrap();
+
+        let expressions = LexOrdering::new(vec![
+            PhysicalSortExpr::new_default(Arc::new(Column::new("a", 0))),
+            PhysicalSortExpr::new_default(Arc::new(Column::new("b", 1))),
+        ])
+        .unwrap();
+
+        // No fetch -> should take the radix path
+        let sorted = sort_batch(&batch, &expressions, None).unwrap();
+        let col_a = sorted
+            .column(0)
+            .as_any()
+            .downcast_ref::<Int32Array>()
+            .unwrap();
+        let col_b = sorted
+            .column(1)
+            .as_any()
+            .downcast_ref::<Int32Array>()
+            .unwrap();
+        assert_eq!(col_a.values(), &[1, 1, 2, 2]);
+        assert_eq!(col_b.values(), &[1, 3, 2, 4]);
+    }
+
+    #[test]
+    fn test_sort_batch_lexsort_with_fetch() {
+        let a: ArrayRef = Arc::new(Int32Array::from(vec![5, 3, 1, 4, 2]));
+        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
+        let batch = RecordBatch::try_new(schema, vec![a]).unwrap();
+
+        let expressions = LexOrdering::new(vec![PhysicalSortExpr::new_default(
+            Arc::new(Column::new("a", 0)),
+        )])
+        .unwrap();
+
+        // With fetch -> should use lexsort path
+        let sorted = sort_batch(&batch, &expressions, Some(2)).unwrap();
+        assert_eq!(sorted.num_rows(), 2);
+        let col = sorted
+            .column(0)
+            .as_any()
+            .downcast_ref::<Int32Array>()
+            .unwrap();
+        assert_eq!(col.values(), &[1, 2]);
+    }
+
+    #[test]
+    fn test_use_radix_sort_heuristic() {
+        // Primitive columns -> radix
+        assert!(use_radix_sort(&[&DataType::Int32]));
+
+        // All dictionary -> lexsort
+        let dict_type =
+            DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8));
+        assert!(!use_radix_sort(&[&dict_type]));
+
+        // List column -> lexsort
+        let list_type =
+            DataType::List(Arc::new(Field::new_list_field(DataType::Int32, true)));
+        assert!(!use_radix_sort(&[&list_type]));
+
+        // Mixed dict + primitive -> radix
+        assert!(use_radix_sort(&[&dict_type, &DataType::Int32]));
+
+        // Empty -> no radix
+        assert!(!use_radix_sort(&[]));
+    }
+
+    #[test]
+    fn test_sort_batch_radix_with_nulls_and_options() {
+        let a: ArrayRef = Arc::new(Int32Array::from(vec![
+            Some(3),
+            None,
+            Some(1),
+            None,
+            Some(2),
+        ]));
+        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, true)]));
+        let batch = RecordBatch::try_new(schema, vec![a]).unwrap();
+
+        // Descending, nulls first
+        let expressions = LexOrdering::new(vec![PhysicalSortExpr::new(
+            Arc::new(Column::new("a", 0)),
+            SortOptions {
+                descending: true,
+                nulls_first: true,
+            },
+        )])
+        .unwrap();
+
+        let sorted = sort_batch(&batch, &expressions, None).unwrap();
+        let col = sorted
+            .column(0)
+            .as_any()
+            .downcast_ref::<Int32Array>()
+            .unwrap();
+        // nulls first, then descending: NULL, NULL, 3, 2, 1
+        assert!(col.is_null(0));
+        assert!(col.is_null(1));
+        assert_eq!(col.value(2), 3);
+        assert_eq!(col.value(3), 2);
+        assert_eq!(col.value(4), 1);
+    }
+
+    #[test]
+    fn test_sort_batch_radix_matches_lexsort() {
+        use rand::rngs::StdRng;
+        use rand::{Rng, SeedableRng};
+
+        let mut rng = StdRng::seed_from_u64(0xCAFE);
+
+        for _ in 0..50 {
+            let len = rng.random_range(10..500);
+            let a1: ArrayRef = Arc::new(Int32Array::from(
+                (0..len)
+                    .map(|_| {
+                        if rng.random_bool(0.1) {
+                            None
+                        } else {
+                            Some(rng.random_range(-100..100))
+                        }
+                    })
+                    .collect::<Vec<_>>(),
+            ));
+            let a2: ArrayRef = Arc::new(StringArray::from(
+                (0..len)
+                    .map(|_| {
+                        if rng.random_bool(0.1) {
+                            None
+                        } else {
+                            Some(
+                                ["alpha", "beta", "gamma", "delta", "epsilon"]
+                                    [rng.random_range(0..5)],
+                            )
+                        }
+                    })
+                    .collect::<Vec<_>>(),
+            ));
+
+            let schema = Arc::new(Schema::new(vec![
+                Field::new("a", DataType::Int32, true),
+                Field::new("b", DataType::Utf8, true),
+            ]));
+            let batch = RecordBatch::try_new(schema, vec![a1, a2]).unwrap();
+
+            let desc = rng.random_bool(0.5);
+            let nf = rng.random_bool(0.5);
+            let opts = SortOptions {
+                descending: desc,
+                nulls_first: nf,
+            };
+            let expressions = LexOrdering::new(vec![
+                PhysicalSortExpr::new(Arc::new(Column::new("a", 0)), opts),
+                PhysicalSortExpr::new(Arc::new(Column::new("b", 1)), opts),
+            ])
+            .unwrap();
+
+            // fetch=Some(len) forces the lexsort path while returning all rows
+            let lexsort_result =
+                sort_batch(&batch, &expressions, Some(len as usize)).unwrap();
+            // fetch=None takes the radix path for these column types
+            let radix_result = sort_batch(&batch, &expressions, None).unwrap();
+
+            assert_eq!(
+                radix_result.num_rows(),
+                lexsort_result.num_rows(),
+                "row count mismatch"
+            );
+
+            for col_idx in 0..batch.num_columns() {
+                assert_eq!(
+                    radix_result.column(col_idx).as_ref(),
+                    lexsort_result.column(col_idx).as_ref(),
+                    "column {col_idx} mismatch"
+                );
+            }
+        }
+    }
 }
diff --git a/datafusion/physical-plan/src/sorts/stream.rs b/datafusion/physical-plan/src/sorts/stream.rs
index ff7f259dd1347..9d297c4ea679f 100644
--- a/datafusion/physical-plan/src/sorts/stream.rs
+++ b/datafusion/physical-plan/src/sorts/stream.rs
@@ -299,6 +299,7 @@ pub(crate) struct IncrementalSortIterator {
     batch: RecordBatch,
     expressions: LexOrdering,
     batch_size: usize,
+    use_radix: bool,
    indices: Option<UInt32Array>,
     cursor: usize,
 }
@@ -308,11 +309,13 @@ impl IncrementalSortIterator {
         batch: RecordBatch,
         expressions: LexOrdering,
         batch_size: usize,
+        use_radix: bool,
     ) -> Self {
         Self {
             batch,
             expressions,
             batch_size,
+            use_radix,
             cursor: 0,
             indices: None,
         }
@@ -339,7 +342,12 @@ impl Iterator for IncrementalSortIterator {
             Err(e) => return Some(Err(e)),
         };
 
-        let indices = match lexsort_to_indices(&sort_columns, None) {
+        let indices = if self.use_radix {
+            super::radix::radix_sort_to_indices(&sort_columns)
+        } else {
+            lexsort_to_indices(&sort_columns, None)
+        };
+        let indices = match indices {
             Ok(indices) => indices,
Err(e) => return Some(Err(e.into())), }; @@ -414,7 +422,7 @@ mod tests { .unwrap(); let mut total_rows = 0; - IncrementalSortIterator::new(batch.clone(), expressions, batch_size).try_for_each( + IncrementalSortIterator::new(batch.clone(), expressions, batch_size, false).try_for_each( |result| { let chunk = result?; total_rows += chunk.num_rows();