diff --git a/common/src/main/scala/org/apache/comet/CometConf.scala b/common/src/main/scala/org/apache/comet/CometConf.scala index 4d2e37924a..ad3774567c 100644 --- a/common/src/main/scala/org/apache/comet/CometConf.scala +++ b/common/src/main/scala/org/apache/comet/CometConf.scala @@ -343,6 +343,17 @@ object CometConf extends ShimCometConf { .booleanConf .createWithDefault(true) + val COMET_SHUFFLE_DIRECT_READ_ENABLED: ConfigEntry[Boolean] = + conf("spark.comet.shuffle.directRead.enabled") + .category(CATEGORY_EXEC) + .doc( + "When enabled, native operators that consume shuffle output will read " + + "compressed shuffle blocks directly in native code, bypassing Arrow FFI. " + + "Only applies to native shuffle (not JVM columnar shuffle). " + + "Requires spark.comet.exec.shuffle.enabled to be true.") + .booleanConf + .createWithDefault(true) + val COMET_SHUFFLE_MODE: ConfigEntry[String] = conf(s"$COMET_EXEC_CONFIG_PREFIX.shuffle.mode") .category(CATEGORY_SHUFFLE) .doc( diff --git a/native/core/src/execution/jni_api.rs b/native/core/src/execution/jni_api.rs index 361deae182..c9db6cc8aa 100644 --- a/native/core/src/execution/jni_api.rs +++ b/native/core/src/execution/jni_api.rs @@ -26,9 +26,9 @@ use crate::{ }, jvm_bridge::{jni_new_global_ref, JVMClasses}, }; -use arrow::array::{Array, RecordBatch, UInt32Array}; +use arrow::array::{Array, ArrayRef, RecordBatch, UInt32Array}; use arrow::compute::{take, TakeOptions}; -use arrow::datatypes::DataType as ArrowDataType; +use arrow::datatypes::{DataType as ArrowDataType, Field, Schema}; use datafusion::common::{DataFusionError, Result as DataFusionResult, ScalarValue}; use datafusion::execution::disk_manager::DiskManagerMode; use datafusion::execution::memory_pool::MemoryPool; @@ -39,7 +39,7 @@ use datafusion::{ physical_plan::{display::DisplayableExecutionPlan, SendableRecordBatchStream}, prelude::{SessionConfig, SessionContext}, }; -use datafusion_comet_proto::spark_operator::Operator; +use datafusion_comet_proto::spark_operator::{self, Operator}; use datafusion_spark::function::bitwise::bit_count::SparkBitCount; use datafusion_spark::function::bitwise::bit_get::SparkBitGet; use datafusion_spark::function::bitwise::bitwise_not::SparkBitwiseNot; @@ -72,6 +72,7 @@ use jni::{ sys::{jboolean, jdouble, jint, jlong}, JNIEnv, }; +use prost::Message; use std::collections::HashMap; use std::path::PathBuf; use std::time::{Duration, Instant}; @@ -82,8 +83,9 @@ use tokio::sync::mpsc; use crate::execution::memory_pools::{ create_memory_pool, handle_task_shared_pool_release, parse_memory_pool_config, MemoryPoolConfig, }; -use crate::execution::operators::ScanExec; -use crate::execution::shuffle::{read_ipc_compressed, CompressionCodec}; +use crate::execution::operators::{ScanExec, ShuffleScanExec}; +use crate::execution::shuffle::read_shuffle_block; +use crate::execution::shuffle::CompressionCodec; use crate::execution::spark_plan::SparkPlan; use crate::execution::tracing::{log_memory_usage, trace_begin, trace_end, with_trace}; @@ -151,6 +153,8 @@ struct ExecutionContext { pub root_op: Option>, /// The input sources for the DataFusion plan pub scans: Vec, + /// The shuffle scan input sources for the DataFusion plan + pub shuffle_scans: Vec, /// The global reference of input sources for the DataFusion plan pub input_sources: Vec>, /// The record batch stream to pull results from @@ -311,6 +315,7 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan( partition_count: partition_count as usize, root_op: None, scans: vec![], + shuffle_scans: vec![], input_sources, stream: None, batch_receiver: None, @@ -491,6 +496,10 @@ fn pull_input_batches(exec_context: &mut ExecutionContext) -> Result<(), CometEr exec_context.scans.iter_mut().try_for_each(|scan| { scan.get_next_batch()?; Ok::<(), CometError>(()) + })?; + exec_context.shuffle_scans.iter_mut().try_for_each(|scan| { + scan.get_next_batch()?; + Ok::<(), CometError>(()) }) } @@ -539,7 +548,7 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_executePlan( let planner = PhysicalPlanner::new(Arc::clone(&exec_context.session_ctx), partition) .with_exec_id(exec_context_id); - let (scans, root_op) = planner.create_plan( + let (scans, shuffle_scans, root_op) = planner.create_plan( &exec_context.spark_plan, &mut exec_context.input_sources.clone(), exec_context.partition_count, @@ -548,6 +557,7 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_executePlan( exec_context.plan_creation_time += physical_plan_time; exec_context.scans = scans; + exec_context.shuffle_scans = shuffle_scans; if exec_context.explain_native { let formatted_plan_str = @@ -560,7 +570,7 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_executePlan( // so we should always execute partition 0. let stream = root_op.native_plan.execute(0, task_ctx)?; - if exec_context.scans.is_empty() { + if exec_context.scans.is_empty() && exec_context.shuffle_scans.is_empty() { // No JVM data sources — spawn onto tokio so the executor // thread parks in blocking_recv instead of busy-polling. // @@ -867,6 +877,79 @@ pub extern "system" fn Java_org_apache_comet_Native_sortRowPartitionsNative( }) } +/// Casts any dictionary-encoded columns in a RecordBatch to their value types. +/// Used by the JNI decode path where the JVM expects plain Arrow types. +fn unpack_dictionary_columns(batch: RecordBatch) -> DataFusionResult { + let mut needs_cast = false; + for col in batch.columns() { + if matches!(col.data_type(), ArrowDataType::Dictionary(_, _)) { + needs_cast = true; + break; + } + } + if !needs_cast { + return Ok(batch); + } + + let mut new_columns: Vec = Vec::with_capacity(batch.num_columns()); + let mut new_fields: Vec = Vec::with_capacity(batch.num_columns()); + for (col, field) in batch.columns().iter().zip(batch.schema().fields()) { + match col.data_type() { + ArrowDataType::Dictionary(_, value_type) => { + new_columns.push(arrow::compute::cast(col.as_ref(), value_type)?); + new_fields.push(Arc::new(arrow::datatypes::Field::new( + field.name(), + value_type.as_ref().clone(), + field.is_nullable(), + ))); + } + _ => { + new_columns.push(Arc::clone(col)); + new_fields.push(Arc::clone(field)); + } + } + } + let schema = Arc::new(Schema::new(new_fields)); + Ok(RecordBatch::try_new(schema, new_columns)?) +} + +// Thread-local cache for the parsed shuffle schema. The schema bytes are +// identical for every call within a given shuffle reader, so we avoid +// re-parsing protobuf and re-allocating Field/Schema objects on each batch. +thread_local! { + static CACHED_SHUFFLE_SCHEMA: std::cell::RefCell, Arc)>> = + const { std::cell::RefCell::new(None) }; +} + +/// Parse shuffle schema from protobuf bytes, using a thread-local cache. +fn get_or_parse_shuffle_schema( + env: &mut JNIEnv, + schema_bytes: &JByteArray, +) -> CometResult> { + let schema_vec = env.convert_byte_array(schema_bytes)?; + + CACHED_SHUFFLE_SCHEMA.with(|cache| { + let mut cache = cache.borrow_mut(); + if let Some((ref cached_bytes, ref cached_schema)) = *cache { + if cached_bytes == &schema_vec { + return Ok(Arc::clone(cached_schema)); + } + } + + let shuffle_scan = spark_operator::ShuffleScan::decode(schema_vec.as_slice()) + .map_err(|e| CometError::Internal(format!("Failed to parse shuffle schema: {e}")))?; + let fields: Vec = shuffle_scan + .fields + .iter() + .enumerate() + .map(|(i, dt)| Field::new(format!("c{i}"), to_arrow_datatype(dt), true)) + .collect(); + let schema = Arc::new(Schema::new(fields)); + *cache = Some((schema_vec, Arc::clone(&schema))); + Ok(schema) + }) +} + #[no_mangle] /// Used by Comet native shuffle reader /// # Safety @@ -878,14 +961,23 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_decodeShuffleBlock( length: jint, array_addrs: JLongArray, schema_addrs: JLongArray, + schema_bytes: JByteArray, tracing_enabled: jboolean, ) -> jlong { try_unwrap_or_throw(&e, |mut env| { + let schema = get_or_parse_shuffle_schema(&mut env, &schema_bytes)?; + with_trace("decodeShuffleBlock", tracing_enabled != JNI_FALSE, || { let raw_pointer = env.get_direct_buffer_address(&byte_buffer)?; let length = length as usize; let slice: &[u8] = unsafe { std::slice::from_raw_parts(raw_pointer, length) }; - let batch = read_ipc_compressed(slice)?; + let batch = read_shuffle_block(slice, &schema)?; + + // The raw shuffle format may preserve dictionary-encoded columns. + // The JVM side expects plain (non-dictionary) types, so cast any + // dictionary columns to their value types before FFI export. + let batch = unpack_dictionary_columns(batch)?; + prepare_output(&mut env, array_addrs, schema_addrs, batch, false) }) }) diff --git a/native/core/src/execution/operators/mod.rs b/native/core/src/execution/operators/mod.rs index 07ee995367..ad3ec3f08b 100644 --- a/native/core/src/execution/operators/mod.rs +++ b/native/core/src/execution/operators/mod.rs @@ -34,7 +34,9 @@ pub use parquet_writer::ParquetWriterExec; mod csv_scan; pub mod projection; mod scan; +mod shuffle_scan; pub use csv_scan::init_csv_datasource_exec; +pub use shuffle_scan::ShuffleScanExec; /// Error returned during executing operators. #[derive(thiserror::Error, Debug)] diff --git a/native/core/src/execution/operators/projection.rs b/native/core/src/execution/operators/projection.rs index 6ba1bb5d59..194fa6769a 100644 --- a/native/core/src/execution/operators/projection.rs +++ b/native/core/src/execution/operators/projection.rs @@ -25,8 +25,7 @@ use jni::objects::GlobalRef; use crate::{ execution::{ - operators::{ExecutionError, ScanExec}, - planner::{operator_registry::OperatorBuilder, PhysicalPlanner}, + planner::{operator_registry::OperatorBuilder, PhysicalPlanner, PlanCreationResult}, spark_plan::SparkPlan, }, extract_op, @@ -42,12 +41,13 @@ impl OperatorBuilder for ProjectionBuilder { inputs: &mut Vec>, partition_count: usize, planner: &PhysicalPlanner, - ) -> Result<(Vec, Arc), ExecutionError> { + ) -> PlanCreationResult { let project = extract_op!(spark_plan, Projection); let children = &spark_plan.children; assert_eq!(children.len(), 1); - let (scans, child) = planner.create_plan(&children[0], inputs, partition_count)?; + let (scans, shuffle_scans, child) = + planner.create_plan(&children[0], inputs, partition_count)?; // Create projection expressions let exprs: Result, _> = project @@ -68,6 +68,7 @@ impl OperatorBuilder for ProjectionBuilder { Ok(( scans, + shuffle_scans, Arc::new(SparkPlan::new(spark_plan.plan_id, projection, vec![child])), )) } diff --git a/native/core/src/execution/operators/shuffle_scan.rs b/native/core/src/execution/operators/shuffle_scan.rs new file mode 100644 index 0000000000..9f4102e601 --- /dev/null +++ b/native/core/src/execution/operators/shuffle_scan.rs @@ -0,0 +1,397 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::{ + errors::CometError, + execution::{ + operators::ExecutionError, planner::TEST_EXEC_CONTEXT_ID, + shuffle::codec::read_shuffle_block, + }, + jvm_bridge::{jni_call, JVMClasses}, +}; +use arrow::array::ArrayRef; +use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; +use datafusion::common::{arrow_datafusion_err, Result as DataFusionResult}; +use datafusion::physical_plan::execution_plan::{Boundedness, EmissionType}; +use datafusion::physical_plan::metrics::{ + BaselineMetrics, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet, Time, +}; +use datafusion::{ + execution::TaskContext, + physical_expr::*, + physical_plan::{ExecutionPlan, *}, +}; +use futures::Stream; +use jni::objects::{GlobalRef, JByteBuffer, JObject}; +use std::{ + any::Any, + pin::Pin, + sync::{Arc, Mutex}, + task::{Context, Poll}, +}; + +use super::scan::InputBatch; + +/// ShuffleScanExec reads compressed shuffle blocks from JVM via JNI and decodes them natively. +/// Unlike ScanExec which receives Arrow arrays via FFI, ShuffleScanExec receives raw compressed +/// bytes from CometShuffleBlockIterator and decodes them using read_shuffle_block(). +#[derive(Debug, Clone)] +pub struct ShuffleScanExec { + /// The ID of the execution context that owns this subquery. + pub exec_context_id: i64, + /// The input source: a global reference to a JVM CometShuffleBlockIterator object. + pub input_source: Option>, + /// The data types of columns in the shuffle output. + pub data_types: Vec, + /// Schema of the shuffle output. + pub schema: SchemaRef, + /// The current input batch, populated by get_next_batch() before poll_next(). + pub batch: Arc>>, + /// Cache of plan properties. + cache: PlanProperties, + /// Metrics collector. + metrics: ExecutionPlanMetricsSet, + /// Baseline metrics. + baseline_metrics: BaselineMetrics, + /// Time spent decoding compressed shuffle blocks. + decode_time: Time, +} + +impl ShuffleScanExec { + pub fn new( + exec_context_id: i64, + input_source: Option>, + data_types: Vec, + ) -> Result { + let metrics_set = ExecutionPlanMetricsSet::default(); + let baseline_metrics = BaselineMetrics::new(&metrics_set, 0); + let decode_time = MetricBuilder::new(&metrics_set).subset_time("decode_time", 0); + + let schema = schema_from_data_types(&data_types); + + let cache = PlanProperties::new( + EquivalenceProperties::new(Arc::clone(&schema)), + Partitioning::UnknownPartitioning(1), + EmissionType::Final, + Boundedness::Bounded, + ); + + Ok(Self { + exec_context_id, + input_source, + data_types, + batch: Arc::new(Mutex::new(None)), + cache, + metrics: metrics_set, + baseline_metrics, + schema, + decode_time, + }) + } + + /// Feeds input batch into this scan. Only used in unit tests. + pub fn set_input_batch(&mut self, input: InputBatch) { + *self.batch.try_lock().unwrap() = Some(input); + } + + /// Pull next input batch from JVM. Called externally before poll_next() + /// because JNI calls cannot happen from within poll_next on tokio threads. + pub fn get_next_batch(&mut self) -> Result<(), CometError> { + if self.input_source.is_none() { + // Unit test mode - no JNI calls needed. + return Ok(()); + } + let mut timer = self.baseline_metrics.elapsed_compute().timer(); + + let mut current_batch = self.batch.try_lock().unwrap(); + if current_batch.is_none() { + let next_batch = Self::get_next( + self.exec_context_id, + self.input_source.as_ref().unwrap().as_obj(), + &self.data_types, + &self.decode_time, + &self.schema, + )?; + *current_batch = Some(next_batch); + } + + timer.stop(); + + Ok(()) + } + + /// Invokes JNI calls to get the next compressed shuffle block and decode it. + fn get_next( + exec_context_id: i64, + iter: &JObject, + data_types: &[DataType], + decode_time: &Time, + schema: &Arc, + ) -> Result { + if exec_context_id == TEST_EXEC_CONTEXT_ID { + return Ok(InputBatch::EOF); + } + + if iter.is_null() { + return Err(CometError::from(ExecutionError::GeneralError(format!( + "Null shuffle block iterator object. Plan id: {exec_context_id}" + )))); + } + + let mut env = JVMClasses::get_env()?; + + // has_next() reads the next block and returns its length, or -1 if EOF + let block_length: i32 = unsafe { + jni_call!(&mut env, + comet_shuffle_block_iterator(iter).has_next() -> i32)? + }; + + if block_length == -1 { + return Ok(InputBatch::EOF); + } + + // Get the DirectByteBuffer containing the compressed shuffle block + let buffer: JObject = unsafe { + jni_call!(&mut env, + comet_shuffle_block_iterator(iter).get_buffer() -> JObject)? + }; + + let byte_buffer = JByteBuffer::from(buffer); + let raw_pointer = env.get_direct_buffer_address(&byte_buffer)?; + let length = block_length as usize; + let slice: &[u8] = unsafe { std::slice::from_raw_parts(raw_pointer, length) }; + + // Decode the compressed shuffle block + let mut timer = decode_time.timer(); + let batch = read_shuffle_block(slice, schema)?; + timer.stop(); + + let num_rows = batch.num_rows(); + + // Extract column arrays from the RecordBatch. + let columns: Vec = batch.columns().to_vec(); + + debug_assert_eq!( + columns.len(), + data_types.len(), + "Shuffle block column count mismatch: got {} but expected {}", + columns.len(), + data_types.len() + ); + + Ok(InputBatch::new(columns, Some(num_rows))) + } +} + +fn schema_from_data_types(data_types: &[DataType]) -> SchemaRef { + let fields = data_types + .iter() + .enumerate() + .map(|(idx, dt)| Field::new(format!("col_{idx}"), dt.clone(), true)) + .collect::>(); + + Arc::new(Schema::new(fields)) +} + +impl ExecutionPlan for ShuffleScanExec { + fn as_any(&self) -> &dyn Any { + self + } + + fn schema(&self) -> SchemaRef { + Arc::clone(&self.schema) + } + + fn children(&self) -> Vec<&Arc> { + vec![] + } + + fn with_new_children( + self: Arc, + _: Vec>, + ) -> datafusion::common::Result> { + Ok(self) + } + + fn execute( + &self, + partition: usize, + _: Arc, + ) -> datafusion::common::Result { + Ok(Box::pin(ShuffleScanStream::new( + self.clone(), + self.schema(), + partition, + self.baseline_metrics.clone(), + ))) + } + + fn properties(&self) -> &PlanProperties { + &self.cache + } + + fn name(&self) -> &str { + "ShuffleScanExec" + } + + fn metrics(&self) -> Option { + Some(self.metrics.clone_inner()) + } +} + +impl DisplayAs for ShuffleScanExec { + fn fmt_as(&self, t: DisplayFormatType, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match t { + DisplayFormatType::Default | DisplayFormatType::Verbose => { + let fields: Vec = self + .data_types + .iter() + .enumerate() + .map(|(idx, dt)| format!("col_{idx}: {dt}")) + .collect(); + write!(f, "ShuffleScanExec: schema=[{}]", fields.join(", "))?; + } + DisplayFormatType::TreeRender => unimplemented!(), + } + Ok(()) + } +} + +/// An async stream that feeds decoded shuffle batches into the DataFusion plan. +struct ShuffleScanStream { + /// The ShuffleScanExec producing input batches. + shuffle_scan: ShuffleScanExec, + /// Schema of the output. + schema: SchemaRef, + /// Metrics. + baseline_metrics: BaselineMetrics, +} + +impl ShuffleScanStream { + pub fn new( + shuffle_scan: ShuffleScanExec, + schema: SchemaRef, + _partition: usize, + baseline_metrics: BaselineMetrics, + ) -> Self { + Self { + shuffle_scan, + schema, + baseline_metrics, + } + } +} + +impl Stream for ShuffleScanStream { + type Item = DataFusionResult; + + fn poll_next(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { + let mut timer = self.baseline_metrics.elapsed_compute().timer(); + let mut scan_batch = self.shuffle_scan.batch.try_lock().unwrap(); + + let input_batch = &*scan_batch; + let input_batch = if let Some(batch) = input_batch { + batch + } else { + timer.stop(); + return Poll::Pending; + }; + + let result = match input_batch { + InputBatch::EOF => Poll::Ready(None), + InputBatch::Batch(columns, num_rows) => { + self.baseline_metrics.record_output(*num_rows); + let options = + arrow::array::RecordBatchOptions::new().with_row_count(Some(*num_rows)); + let maybe_batch = arrow::array::RecordBatch::try_new_with_options( + Arc::clone(&self.schema), + columns.clone(), + &options, + ) + .map_err(|e| arrow_datafusion_err!(e)); + Poll::Ready(Some(maybe_batch)) + } + }; + + *scan_batch = None; + + timer.stop(); + + result + } +} + +impl RecordBatchStream for ShuffleScanStream { + fn schema(&self) -> SchemaRef { + Arc::clone(&self.schema) + } +} + +#[cfg(test)] +mod tests { + use crate::execution::shuffle::codec::{CompressionCodec, ShuffleBlockWriter}; + use arrow::array::{Int32Array, StringArray}; + use arrow::datatypes::{DataType, Field, Schema}; + use arrow::record_batch::RecordBatch; + use datafusion::physical_plan::metrics::Time; + use std::io::Cursor; + use std::sync::Arc; + + use crate::execution::shuffle::codec::read_shuffle_block; + + #[test] + #[cfg_attr(miri, ignore)] // Miri cannot call FFI functions (zstd) + fn test_read_compressed_block() { + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("name", DataType::Utf8, true), + ])); + let batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(Int32Array::from(vec![1, 2, 3])), + Arc::new(StringArray::from(vec!["a", "b", "c"])), + ], + ) + .unwrap(); + + // Write as compressed raw batch + let writer = + ShuffleBlockWriter::try_new(&batch.schema(), CompressionCodec::Zstd(1)).unwrap(); + let mut buf = Cursor::new(Vec::new()); + let ipc_time = Time::new(); + writer.write_batch(&batch, &mut buf, &ipc_time).unwrap(); + + // Read back (skip 16-byte header: 8 compressed_length + 8 field_count) + let bytes = buf.into_inner(); + let body = &bytes[16..]; + + let decoded = read_shuffle_block(body, &schema).unwrap(); + assert_eq!(decoded.num_rows(), 3); + assert_eq!(decoded.num_columns(), 2); + + // Verify data + let col0 = decoded + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(col0.value(0), 1); + assert_eq!(col0.value(1), 2); + assert_eq!(col0.value(2), 3); + } +} diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs index bd37755922..b5892d763c 100644 --- a/native/core/src/execution/planner.rs +++ b/native/core/src/execution/planner.rs @@ -27,7 +27,7 @@ use crate::{ errors::ExpressionError, execution::{ expressions::subquery::Subquery, - operators::{ExecutionError, ExpandExec, ParquetWriterExec, ScanExec}, + operators::{ExecutionError, ExpandExec, ParquetWriterExec, ScanExec, ShuffleScanExec}, planner::expression_registry::ExpressionRegistry, planner::operator_registry::OperatorRegistry, serde::to_arrow_datatype, @@ -141,6 +141,8 @@ use url::Url; type PhyAggResult = Result, ExecutionError>; type PhyExprResult = Result, String)>, ExecutionError>; type PartitionPhyExprResult = Result>, ExecutionError>; +pub type PlanCreationResult = + Result<(Vec, Vec, Arc), ExecutionError>; struct JoinParameters { pub left: Arc, @@ -913,7 +915,7 @@ impl PhysicalPlanner { spark_plan: &'a Operator, inputs: &mut Vec>, partition_count: usize, - ) -> Result<(Vec, Arc), ExecutionError> { + ) -> PlanCreationResult { // Try to use the modular registry first - this automatically handles any registered operator types if OperatorRegistry::global().can_handle(spark_plan) { return OperatorRegistry::global().create_plan( @@ -929,7 +931,8 @@ impl PhysicalPlanner { match spark_plan.op_struct.as_ref().unwrap() { OpStruct::Filter(filter) => { assert_eq!(children.len(), 1); - let (scans, child) = self.create_plan(&children[0], inputs, partition_count)?; + let (scans, shuffle_scans, child) = + self.create_plan(&children[0], inputs, partition_count)?; let predicate = self.create_expr(filter.predicate.as_ref().unwrap(), child.schema())?; @@ -940,12 +943,14 @@ impl PhysicalPlanner { Ok(( scans, + shuffle_scans, Arc::new(SparkPlan::new(spark_plan.plan_id, filter, vec![child])), )) } OpStruct::HashAgg(agg) => { assert_eq!(children.len(), 1); - let (scans, child) = self.create_plan(&children[0], inputs, partition_count)?; + let (scans, shuffle_scans, child) = + self.create_plan(&children[0], inputs, partition_count)?; let group_exprs: PhyExprResult = agg .grouping_exprs @@ -996,6 +1001,7 @@ impl PhysicalPlanner { if agg.result_exprs.is_empty() { Ok(( scans, + shuffle_scans, Arc::new(SparkPlan::new(spark_plan.plan_id, aggregate, vec![child])), )) } else { @@ -1012,6 +1018,7 @@ impl PhysicalPlanner { )?); Ok(( scans, + shuffle_scans, Arc::new(SparkPlan::new_with_additional( spark_plan.plan_id, projection, @@ -1030,7 +1037,8 @@ impl PhysicalPlanner { "Invalid limit/offset combination: [{num}. {offset}]" ))); } - let (scans, child) = self.create_plan(&children[0], inputs, partition_count)?; + let (scans, shuffle_scans, child) = + self.create_plan(&children[0], inputs, partition_count)?; let limit: Arc = if offset == 0 { Arc::new(LocalLimitExec::new( Arc::clone(&child.native_plan), @@ -1050,12 +1058,14 @@ impl PhysicalPlanner { }; Ok(( scans, + shuffle_scans, Arc::new(SparkPlan::new(spark_plan.plan_id, limit, vec![child])), )) } OpStruct::Sort(sort) => { assert_eq!(children.len(), 1); - let (scans, child) = self.create_plan(&children[0], inputs, partition_count)?; + let (scans, shuffle_scans, child) = + self.create_plan(&children[0], inputs, partition_count)?; let exprs: Result, ExecutionError> = sort .sort_orders @@ -1079,6 +1089,7 @@ impl PhysicalPlanner { Ok(( scans, + shuffle_scans, Arc::new(SparkPlan::new( spark_plan.plan_id, sort_exec, @@ -1115,6 +1126,7 @@ impl PhysicalPlanner { if partition_files.partitioned_file.is_empty() { let empty_exec = Arc::new(EmptyExec::new(required_schema)); return Ok(( + vec![], vec![], Arc::new(SparkPlan::new(spark_plan.plan_id, empty_exec, vec![])), )); @@ -1205,6 +1217,7 @@ impl PhysicalPlanner { common.encryption_enabled, )?; Ok(( + vec![], vec![], Arc::new(SparkPlan::new(spark_plan.plan_id, scan, vec![])), )) @@ -1243,6 +1256,7 @@ impl PhysicalPlanner { &scan.csv_options.clone().unwrap(), )?; Ok(( + vec![], vec![], Arc::new(SparkPlan::new(spark_plan.plan_id, scan, vec![])), )) @@ -1276,6 +1290,7 @@ impl PhysicalPlanner { Ok(( vec![scan.clone()], + vec![], Arc::new(SparkPlan::new(spark_plan.plan_id, Arc::new(scan), vec![])), )) } @@ -1307,6 +1322,7 @@ impl PhysicalPlanner { )?; Ok(( + vec![], vec![], Arc::new(SparkPlan::new( spark_plan.plan_id, @@ -1317,7 +1333,8 @@ impl PhysicalPlanner { } OpStruct::ShuffleWriter(writer) => { assert_eq!(children.len(), 1); - let (scans, child) = self.create_plan(&children[0], inputs, partition_count)?; + let (scans, shuffle_scans, child) = + self.create_plan(&children[0], inputs, partition_count)?; let partitioning = self.create_partitioning( writer.partitioning.as_ref().unwrap(), @@ -1350,6 +1367,7 @@ impl PhysicalPlanner { Ok(( scans, + shuffle_scans, Arc::new(SparkPlan::new( spark_plan.plan_id, shuffle_writer, @@ -1359,7 +1377,8 @@ impl PhysicalPlanner { } OpStruct::ParquetWriter(writer) => { assert_eq!(children.len(), 1); - let (scans, child) = self.create_plan(&children[0], inputs, partition_count)?; + let (scans, shuffle_scans, child) = + self.create_plan(&children[0], inputs, partition_count)?; let codec = match writer.compression.try_into() { Ok(SparkCompressionCodec::None) => Ok(CompressionCodec::None), @@ -1396,6 +1415,7 @@ impl PhysicalPlanner { Ok(( scans, + shuffle_scans, Arc::new(SparkPlan::new( spark_plan.plan_id, parquet_writer, @@ -1405,7 +1425,8 @@ impl PhysicalPlanner { } OpStruct::Expand(expand) => { assert_eq!(children.len(), 1); - let (scans, child) = self.create_plan(&children[0], inputs, partition_count)?; + let (scans, shuffle_scans, child) = + self.create_plan(&children[0], inputs, partition_count)?; let mut projections = vec![]; let mut projection = vec![]; @@ -1448,12 +1469,14 @@ impl PhysicalPlanner { let expand = Arc::new(ExpandExec::new(projections, input, schema)); Ok(( scans, + shuffle_scans, Arc::new(SparkPlan::new(spark_plan.plan_id, expand, vec![child])), )) } OpStruct::Explode(explode) => { assert_eq!(children.len(), 1); - let (scans, child) = self.create_plan(&children[0], inputs, partition_count)?; + let (scans, shuffle_scans, child) = + self.create_plan(&children[0], inputs, partition_count)?; // Create the expression for the array to explode let child_expr = if let Some(child_expr) = &explode.child { @@ -1559,11 +1582,12 @@ impl PhysicalPlanner { Ok(( scans, + shuffle_scans, Arc::new(SparkPlan::new(spark_plan.plan_id, unnest_exec, vec![child])), )) } OpStruct::SortMergeJoin(join) => { - let (join_params, scans) = self.parse_join_parameters( + let (join_params, scans, shuffle_scans) = self.parse_join_parameters( inputs, children, &join.left_join_keys, @@ -1615,6 +1639,7 @@ impl PhysicalPlanner { )); Ok(( scans, + shuffle_scans, Arc::new(SparkPlan::new_with_additional( spark_plan.plan_id, coalesce_batches, @@ -1628,6 +1653,7 @@ impl PhysicalPlanner { } else { Ok(( scans, + shuffle_scans, Arc::new(SparkPlan::new( spark_plan.plan_id, join, @@ -1640,7 +1666,7 @@ impl PhysicalPlanner { } } OpStruct::HashJoin(join) => { - let (join_params, scans) = self.parse_join_parameters( + let (join_params, scans, shuffle_scans) = self.parse_join_parameters( inputs, children, &join.left_join_keys, @@ -1670,6 +1696,7 @@ impl PhysicalPlanner { if join.build_side == BuildSide::BuildLeft as i32 { Ok(( scans, + shuffle_scans, Arc::new(SparkPlan::new( spark_plan.plan_id, hash_join, @@ -1688,6 +1715,7 @@ impl PhysicalPlanner { Ok(( scans, + shuffle_scans, Arc::new(SparkPlan::new_with_additional( spark_plan.plan_id, swapped_hash_join, @@ -1698,7 +1726,8 @@ impl PhysicalPlanner { } } OpStruct::Window(wnd) => { - let (scans, child) = self.create_plan(&children[0], inputs, partition_count)?; + let (scans, shuffle_scans, child) = + self.create_plan(&children[0], inputs, partition_count)?; let input_schema = child.schema(); let sort_exprs: Result, ExecutionError> = wnd .order_by_list @@ -1736,9 +1765,37 @@ impl PhysicalPlanner { )?); Ok(( scans, + shuffle_scans, Arc::new(SparkPlan::new(spark_plan.plan_id, window_agg, vec![child])), )) } + OpStruct::ShuffleScan(scan) => { + let data_types = scan.fields.iter().map(to_arrow_datatype).collect_vec(); + + if self.exec_context_id != TEST_EXEC_CONTEXT_ID && inputs.is_empty() { + return Err(GeneralError("No input for shuffle scan".to_string())); + } + + let input_source = + if self.exec_context_id == TEST_EXEC_CONTEXT_ID && inputs.is_empty() { + None + } else { + Some(inputs.remove(0)) + }; + + let shuffle_scan = + ShuffleScanExec::new(self.exec_context_id, input_source, data_types)?; + + Ok(( + vec![], + vec![shuffle_scan.clone()], + Arc::new(SparkPlan::new( + spark_plan.plan_id, + Arc::new(shuffle_scan), + vec![], + )), + )) + } _ => Err(GeneralError(format!( "Unsupported or unregistered operator type: {:?}", spark_plan.op_struct @@ -1756,12 +1813,15 @@ impl PhysicalPlanner { join_type: i32, condition: &Option, partition_count: usize, - ) -> Result<(JoinParameters, Vec), ExecutionError> { + ) -> Result<(JoinParameters, Vec, Vec), ExecutionError> { assert_eq!(children.len(), 2); - let (mut left_scans, left) = self.create_plan(&children[0], inputs, partition_count)?; - let (mut right_scans, right) = self.create_plan(&children[1], inputs, partition_count)?; + let (mut left_scans, mut left_shuffle_scans, left) = + self.create_plan(&children[0], inputs, partition_count)?; + let (mut right_scans, mut right_shuffle_scans, right) = + self.create_plan(&children[1], inputs, partition_count)?; left_scans.append(&mut right_scans); + left_shuffle_scans.append(&mut right_shuffle_scans); let left_join_exprs: Vec<_> = left_join_keys .iter() @@ -1882,6 +1942,7 @@ impl PhysicalPlanner { join_filter, }, left_scans, + left_shuffle_scans, )) } @@ -3670,7 +3731,8 @@ mod tests { let input_array = DictionaryArray::new(keys, Arc::new(values)); let input_batch = InputBatch::Batch(vec![Arc::new(input_array)], row_count); - let (mut scans, datafusion_plan) = planner.create_plan(&op, &mut vec![], 1).unwrap(); + let (mut scans, _shuffle_scans, datafusion_plan) = + planner.create_plan(&op, &mut vec![], 1).unwrap(); scans[0].set_input_batch(input_batch); let session_ctx = SessionContext::new(); @@ -3744,7 +3806,8 @@ mod tests { let input_array = DictionaryArray::new(keys, Arc::new(values)); let input_batch = InputBatch::Batch(vec![Arc::new(input_array)], row_count); - let (mut scans, datafusion_plan) = planner.create_plan(&op, &mut vec![], 1).unwrap(); + let (mut scans, _shuffle_scans, datafusion_plan) = + planner.create_plan(&op, &mut vec![], 1).unwrap(); // Scan's schema is determined by the input batch, so we need to set it before execution. scans[0].set_input_batch(input_batch); @@ -3791,7 +3854,8 @@ mod tests { let op = create_filter(op_scan, 0); let planner = PhysicalPlanner::default(); - let (mut scans, datafusion_plan) = planner.create_plan(&op, &mut vec![], 1).unwrap(); + let (mut scans, _shuffle_scans, datafusion_plan) = + planner.create_plan(&op, &mut vec![], 1).unwrap(); let scan = &mut scans[0]; scan.set_input_batch(InputBatch::EOF); @@ -3876,7 +3940,8 @@ mod tests { let op = create_filter(op_scan, 0); let planner = PhysicalPlanner::default(); - let (_scans, filter_exec) = planner.create_plan(&op, &mut vec![], 1).unwrap(); + let (_scans, _shuffle_scans, filter_exec) = + planner.create_plan(&op, &mut vec![], 1).unwrap(); assert_eq!("FilterExec", filter_exec.native_plan.name()); assert_eq!(1, filter_exec.children.len()); @@ -3900,7 +3965,8 @@ mod tests { let planner = PhysicalPlanner::default(); - let (_scans, hash_join_exec) = planner.create_plan(&op_join, &mut vec![], 1).unwrap(); + let (_scans, _shuffle_scans, hash_join_exec) = + planner.create_plan(&op_join, &mut vec![], 1).unwrap(); assert_eq!("HashJoinExec", hash_join_exec.native_plan.name()); assert_eq!(2, hash_join_exec.children.len()); @@ -4014,7 +4080,7 @@ mod tests { })), }; - let (mut scans, datafusion_plan) = + let (mut scans, _shuffle_scans, datafusion_plan) = planner.create_plan(&projection, &mut vec![], 1).unwrap(); let mut stream = datafusion_plan.native_plan.execute(0, task_ctx).unwrap(); @@ -4140,7 +4206,7 @@ mod tests { }; // Create a physical plan - let (mut scans, datafusion_plan) = + let (mut scans, _shuffle_scans, datafusion_plan) = planner.create_plan(&projection, &mut vec![], 1).unwrap(); // Start executing the plan in a separate thread @@ -4631,7 +4697,7 @@ mod tests { }; // Create the physical plan - let (mut scans, datafusion_plan) = + let (mut scans, _shuffle_scans, datafusion_plan) = planner.create_plan(&projection, &mut vec![], 1).unwrap(); // Create test data: Date32 and Int8 columns diff --git a/native/core/src/execution/planner/operator_registry.rs b/native/core/src/execution/planner/operator_registry.rs index b34a80df95..cad5df40c5 100644 --- a/native/core/src/execution/planner/operator_registry.rs +++ b/native/core/src/execution/planner/operator_registry.rs @@ -25,11 +25,8 @@ use std::{ use datafusion_comet_proto::spark_operator::Operator; use jni::objects::GlobalRef; -use super::PhysicalPlanner; -use crate::execution::{ - operators::{ExecutionError, ScanExec}, - spark_plan::SparkPlan, -}; +use super::{PhysicalPlanner, PlanCreationResult}; +use crate::execution::operators::ExecutionError; /// Trait for building physical operators from Spark protobuf operators pub trait OperatorBuilder: Send + Sync { @@ -40,7 +37,7 @@ pub trait OperatorBuilder: Send + Sync { inputs: &mut Vec>, partition_count: usize, planner: &PhysicalPlanner, - ) -> Result<(Vec, Arc), ExecutionError>; + ) -> PlanCreationResult; } /// Enum to identify different operator types for registry dispatch @@ -100,7 +97,7 @@ impl OperatorRegistry { inputs: &mut Vec>, partition_count: usize, planner: &PhysicalPlanner, - ) -> Result<(Vec, Arc), ExecutionError> { + ) -> PlanCreationResult { let operator_type = get_operator_type(spark_operator).ok_or_else(|| { ExecutionError::GeneralError(format!( "Unsupported operator type: {:?}", @@ -153,5 +150,6 @@ fn get_operator_type(spark_operator: &Operator) -> Option { OpStruct::Window(_) => Some(OperatorType::Window), OpStruct::Explode(_) => None, // Not yet in OperatorType enum OpStruct::CsvScan(_) => Some(OperatorType::CsvScan), + OpStruct::ShuffleScan(_) => None, // Not yet in OperatorType enum } } diff --git a/native/core/src/execution/shuffle/codec.rs b/native/core/src/execution/shuffle/codec.rs index 33e6989d4c..52cc320e6c 100644 --- a/native/core/src/execution/shuffle/codec.rs +++ b/native/core/src/execution/shuffle/codec.rs @@ -16,17 +16,19 @@ // under the License. use crate::errors::{CometError, CometResult}; -use arrow::array::RecordBatch; +use arrow::array::{make_array, Array, ArrayRef, MutableArrayData, RecordBatch}; +use arrow::buffer::Buffer; +use arrow::datatypes::DataType; use arrow::datatypes::Schema; -use arrow::ipc::reader::StreamReader; -use arrow::ipc::writer::StreamWriter; +use arrow::record_batch::RecordBatchOptions; use bytes::Buf; use crc32fast::Hasher; use datafusion::common::DataFusionError; use datafusion::error::Result; use datafusion::physical_plan::metrics::Time; use simd_adler32::Adler32; -use std::io::{Cursor, Seek, SeekFrom, Write}; +use std::io::{Cursor, Read, Seek, SeekFrom, Write}; +use std::sync::Arc; #[derive(Debug, Clone)] pub enum CompressionCodec { @@ -42,6 +44,132 @@ pub struct ShuffleBlockWriter { header_bytes: Vec, } +/// Recursively writes raw Arrow ArrayData buffers to the given writer. +/// Arrays must be normalized to zero offset before calling this function. +fn write_array_data(data: &arrow::array::ArrayData, writer: &mut W) -> Result<()> { + debug_assert_eq!(data.offset(), 0, "shuffle arrays must have offset 0"); + + // Write null_count + let null_count = data.null_count() as u32; + writer.write_all(&null_count.to_le_bytes())?; + + // Write validity bitmap + if null_count > 0 { + if let Some(bitmap) = data.nulls() { + debug_assert_eq!(bitmap.offset(), 0, "null bitmap must have offset 0"); + let bitmap_bytes = bitmap.buffer().as_slice(); + let len = bitmap_bytes.len() as u32; + writer.write_all(&len.to_le_bytes())?; + writer.write_all(bitmap_bytes)?; + } else { + writer.write_all(&0u32.to_le_bytes())?; + } + } else { + writer.write_all(&0u32.to_le_bytes())?; + } + + // Write buffers + let num_buffers = data.buffers().len() as u32; + writer.write_all(&num_buffers.to_le_bytes())?; + for buffer in data.buffers() { + let len: u32 = buffer.len().try_into().map_err(|_| { + DataFusionError::Execution(format!("Buffer length {} exceeds u32::MAX", buffer.len())) + })?; + writer.write_all(&len.to_le_bytes())?; + writer.write_all(buffer.as_slice())?; + } + + // Write children + let num_children = data.child_data().len() as u32; + writer.write_all(&num_children.to_le_bytes())?; + for child in data.child_data() { + let child_num_rows = child.len() as u32; + writer.write_all(&child_num_rows.to_le_bytes())?; + write_array_data(child, writer)?; + } + + Ok(()) +} + +/// Ensures an array has zero offset in both its data and null buffer by +/// producing a physical copy when necessary. This is required because our +/// raw buffer format writes buffers verbatim and assumes offset 0. +fn normalize_array(col: &ArrayRef) -> Result { + let needs_copy = col.offset() != 0 || col.nulls().is_some_and(|nulls| nulls.offset() != 0); + if needs_copy { + // Use MutableArrayData::extend for a direct memcpy rather than + // take() which builds an index array and does per-element lookups. + let data = col.to_data(); + let mut mutable = MutableArrayData::new(vec![&data], false, col.len()); + mutable.extend(0, 0, col.len()); + Ok(make_array(mutable.freeze())) + } else { + Ok(Arc::clone(col)) + } +} + +// Column encoding tags for the raw shuffle format. +const COL_TAG_PLAIN: u8 = 0; +const COL_TAG_DICTIONARY: u8 = 1; + +/// Encode a dictionary key DataType as a single byte. +fn encode_key_type(dt: &DataType) -> Result { + match dt { + DataType::Int8 => Ok(0), + DataType::Int16 => Ok(1), + DataType::Int32 => Ok(2), + DataType::Int64 => Ok(3), + DataType::UInt8 => Ok(4), + DataType::UInt16 => Ok(5), + DataType::UInt32 => Ok(6), + DataType::UInt64 => Ok(7), + _ => Err(DataFusionError::Execution(format!( + "unsupported dictionary key type: {dt:?}" + ))), + } +} + +/// Decode a dictionary key DataType from the byte written by encode_key_type. +fn decode_key_type(tag: u8) -> Result { + match tag { + 0 => Ok(DataType::Int8), + 1 => Ok(DataType::Int16), + 2 => Ok(DataType::Int32), + 3 => Ok(DataType::Int64), + 4 => Ok(DataType::UInt8), + 5 => Ok(DataType::UInt16), + 6 => Ok(DataType::UInt32), + 7 => Ok(DataType::UInt64), + _ => Err(DataFusionError::Execution(format!( + "unknown dictionary key type tag: {tag}" + ))), + } +} + +/// Writes a RecordBatch in raw buffer format. Dictionary arrays are preserved +/// (not cast to value type) so they compress better. Each column is prefixed +/// with a tag byte indicating plain (0) or dictionary (1) encoding. +fn write_raw_batch(batch: &RecordBatch, writer: &mut W) -> Result<()> { + let num_rows = batch.num_rows() as u32; + writer.write_all(&num_rows.to_le_bytes())?; + + for col in batch.columns() { + let col = normalize_array(col)?; + match col.data_type() { + DataType::Dictionary(key_type, _) => { + writer.write_all(&[COL_TAG_DICTIONARY])?; + writer.write_all(&[encode_key_type(key_type)?])?; + } + _ => { + writer.write_all(&[COL_TAG_PLAIN])?; + } + } + write_array_data(&col.to_data(), writer)?; + } + + Ok(()) +} + impl ShuffleBlockWriter { pub fn try_new(schema: &Schema, codec: CompressionCodec) -> Result { let header_bytes = Vec::with_capacity(20); @@ -71,8 +199,13 @@ impl ShuffleBlockWriter { }) } - /// Writes given record batch as Arrow IPC bytes into given writer. + /// Writes given record batch in raw buffer format into given writer. /// Returns number of bytes written. + /// + /// The batch is first serialized to an intermediate buffer, then compressed + /// in one shot. This avoids creating a streaming compression encoder per + /// batch (expensive for Zstd ~128KB state) and gives us the uncompressed + /// size for a pre-allocation hint on the read side. pub fn write_batch( &self, batch: &RecordBatch, @@ -89,57 +222,49 @@ impl ShuffleBlockWriter { // write header output.write_all(&self.header_bytes)?; - let output = match &self.codec { + // Serialize raw batch to intermediate buffer + let mut raw_buf = Vec::new(); + write_raw_batch(batch, &mut raw_buf)?; + + // Write uncompressed size hint (u32), then compressed data + let uncompressed_len = raw_buf.len() as u32; + output.write_all(&uncompressed_len.to_le_bytes())?; + + match &self.codec { CompressionCodec::None => { - let mut arrow_writer = StreamWriter::try_new(output, &batch.schema())?; - arrow_writer.write(batch)?; - arrow_writer.finish()?; - arrow_writer.into_inner()? + output.write_all(&raw_buf)?; } CompressionCodec::Lz4Frame => { - let mut wtr = lz4_flex::frame::FrameEncoder::new(output); - let mut arrow_writer = StreamWriter::try_new(&mut wtr, &batch.schema())?; - arrow_writer.write(batch)?; - arrow_writer.finish()?; - wtr.finish().map_err(|e| { - DataFusionError::Execution(format!("lz4 compression error: {e}")) - })? + let compressed = lz4_flex::compress(&raw_buf); + output.write_all(&compressed)?; } - CompressionCodec::Zstd(level) => { - let encoder = zstd::Encoder::new(output, *level)?; - let mut arrow_writer = StreamWriter::try_new(encoder, &batch.schema())?; - arrow_writer.write(batch)?; - arrow_writer.finish()?; - let zstd_encoder = arrow_writer.into_inner()?; - zstd_encoder.finish()? + let compressed = zstd::bulk::compress(&raw_buf, *level)?; + output.write_all(&compressed)?; } - CompressionCodec::Snappy => { - let mut wtr = snap::write::FrameEncoder::new(output); - let mut arrow_writer = StreamWriter::try_new(&mut wtr, &batch.schema())?; - arrow_writer.write(batch)?; - arrow_writer.finish()?; + let mut wtr = snap::write::FrameEncoder::new(output.by_ref()); + wtr.write_all(&raw_buf)?; wtr.into_inner().map_err(|e| { DataFusionError::Execution(format!("snappy compression error: {e}")) - })? + })?; } }; - // fill ipc length + // fill block length let end_pos = output.stream_position()?; - let ipc_length = end_pos - start_pos - 8; + let block_length = end_pos - start_pos - 8; let max_size = i32::MAX as u64; - if ipc_length > max_size { + if block_length > max_size { return Err(DataFusionError::Execution(format!( - "Shuffle block size {ipc_length} exceeds maximum size of {max_size}. \ + "Shuffle block size {block_length} exceeds maximum size of {max_size}. \ Try reducing batch size or increasing compression level" ))); } - // fill ipc length + // fill block length output.seek(SeekFrom::Start(start_pos))?; - output.write_all(&ipc_length.to_le_bytes())?; + output.write_all(&block_length.to_le_bytes())?; output.seek(SeekFrom::Start(end_pos))?; timer.stop(); @@ -148,38 +273,193 @@ impl ShuffleBlockWriter { } } -pub fn read_ipc_compressed(bytes: &[u8]) -> Result { - match &bytes[0..4] { +// --------------------------------------------------------------------------- +// Read-side helpers +// --------------------------------------------------------------------------- + +fn read_u32(cursor: &mut &[u8]) -> Result { + if cursor.len() < 4 { + return Err(DataFusionError::Execution( + "unexpected end of shuffle block data".to_string(), + )); + } + let (bytes, rest) = cursor.split_at(4); + *cursor = rest; + Ok(u32::from_le_bytes(bytes.try_into().unwrap())) +} + +fn read_bytes<'a>(cursor: &mut &'a [u8], len: usize) -> Result<&'a [u8]> { + if cursor.len() < len { + return Err(DataFusionError::Execution( + "unexpected end of shuffle block data".to_string(), + )); + } + let (bytes, rest) = cursor.split_at(len); + *cursor = rest; + Ok(bytes) +} + +/// Returns child data types for nested Arrow types. +fn get_child_types(data_type: &DataType) -> Vec { + match data_type { + DataType::List(field) | DataType::LargeList(field) | DataType::FixedSizeList(field, _) => { + vec![field.data_type().clone()] + } + DataType::Map(field, _) => { + // Map's single child is a struct with key/value fields + vec![field.data_type().clone()] + } + DataType::Struct(fields) => fields.iter().map(|f| f.data_type().clone()).collect(), + DataType::Dictionary(_, value_type) => vec![value_type.as_ref().clone()], + _ => vec![], + } +} + +/// Reconstructs ArrayData from raw buffer format (reverse of write_array_data). +fn read_array_data( + cursor: &mut &[u8], + data_type: &DataType, + num_rows: usize, +) -> Result { + let null_count = read_u32(cursor)? as usize; + + // Read validity bitmap + let bitmap_len = read_u32(cursor)? as usize; + let null_buffer = if bitmap_len > 0 { + let bytes = read_bytes(cursor, bitmap_len)?; + Some(Buffer::from(bytes)) + } else { + None + }; + + // Read buffers + let num_buffers = read_u32(cursor)? as usize; + let mut buffers = Vec::with_capacity(num_buffers); + for _ in 0..num_buffers { + let buf_len = read_u32(cursor)? as usize; + let bytes = read_bytes(cursor, buf_len)?; + buffers.push(Buffer::from(bytes)); + } + + // Read children + let num_children = read_u32(cursor)? as usize; + let child_types = get_child_types(data_type); + let mut child_data = Vec::with_capacity(num_children); + for i in 0..num_children { + let child_num_rows = read_u32(cursor)? as usize; + let child_type = child_types.get(i).ok_or_else(|| { + DataFusionError::Execution(format!( + "unexpected child index {i} for data type {data_type:?}" + )) + })?; + child_data.push(read_array_data(cursor, child_type, child_num_rows)?); + } + + // Build ArrayData without validation (data came from our own writer) + let mut builder = arrow::array::ArrayData::builder(data_type.clone()) + .len(num_rows) + .null_count(null_count); + + if let Some(nb) = null_buffer { + builder = builder.null_bit_buffer(Some(nb)); + } + + for buf in buffers { + builder = builder.add_buffer(buf); + } + + for child in child_data { + builder = builder.add_child_data(child); + } + + // SAFETY: data was written by write_array_data from valid Arrow arrays + Ok(unsafe { builder.build_unchecked() }) +} + +/// Read a raw batch from decompressed bytes, given the expected schema. +/// Columns that were written as dictionary-encoded are reconstructed as +/// DictionaryArray. The returned batch schema may differ from the input +/// schema (Dictionary vs plain types) — callers that need plain types +/// should cast dictionary columns afterward. +fn read_raw_batch(bytes: &[u8], schema: &Arc) -> Result { + let mut cursor = bytes; + + let num_rows = read_u32(&mut cursor)? as usize; + + let mut columns: Vec = Vec::with_capacity(schema.fields().len()); + let mut fields: Vec = Vec::with_capacity(schema.fields().len()); + for field in schema.fields() { + // Read per-column tag + let tag = read_bytes(&mut cursor, 1)?[0]; + let data_type = match tag { + COL_TAG_PLAIN => field.data_type().clone(), + COL_TAG_DICTIONARY => { + let key_tag = read_bytes(&mut cursor, 1)?[0]; + let key_type = decode_key_type(key_tag)?; + DataType::Dictionary(Box::new(key_type), Box::new(field.data_type().clone())) + } + _ => { + return Err(DataFusionError::Execution(format!( + "unknown column tag: {tag}" + ))); + } + }; + let array_data = read_array_data(&mut cursor, &data_type, num_rows)?; + columns.push(make_array(array_data)); + fields.push(Arc::new(arrow::datatypes::Field::new( + field.name(), + data_type, + field.is_nullable(), + ))); + } + + let actual_schema = Arc::new(Schema::new(fields)); + let options = RecordBatchOptions::new().with_row_count(Some(num_rows)); + let batch = RecordBatch::try_new_with_options(actual_schema, columns, &options)?; + Ok(batch) +} + +/// Reads and decompresses a shuffle block written in raw buffer format. +/// The `bytes` slice starts at the codec tag (after the 8-byte length and +/// 8-byte field_count header that the JVM reads). +/// +/// Format: `[codec_tag: 4 bytes][uncompressed_len: u32][compressed_data...]` +pub fn read_shuffle_block(bytes: &[u8], schema: &Arc) -> Result { + let codec_tag = &bytes[0..4]; + // Read uncompressed size hint for pre-allocation + let uncompressed_len = u32::from_le_bytes(bytes[4..8].try_into().unwrap()) as usize; + let data = &bytes[8..]; + + match codec_tag { b"SNAP" => { - let decoder = snap::read::FrameDecoder::new(&bytes[4..]); - let mut reader = - unsafe { StreamReader::try_new(decoder, None)?.with_skip_validation(true) }; - reader.next().unwrap().map_err(|e| e.into()) + let decoder = snap::read::FrameDecoder::new(data); + let decompressed = read_all_with_capacity(decoder, uncompressed_len)?; + read_raw_batch(&decompressed, schema) } b"LZ4_" => { - let decoder = lz4_flex::frame::FrameDecoder::new(&bytes[4..]); - let mut reader = - unsafe { StreamReader::try_new(decoder, None)?.with_skip_validation(true) }; - reader.next().unwrap().map_err(|e| e.into()) + let decompressed = lz4_flex::decompress(data, uncompressed_len) + .map_err(|e| DataFusionError::Execution(format!("lz4 decompression error: {e}")))?; + read_raw_batch(&decompressed, schema) } b"ZSTD" => { - let decoder = zstd::Decoder::new(&bytes[4..])?; - let mut reader = - unsafe { StreamReader::try_new(decoder, None)?.with_skip_validation(true) }; - reader.next().unwrap().map_err(|e| e.into()) - } - b"NONE" => { - let mut reader = - unsafe { StreamReader::try_new(&bytes[4..], None)?.with_skip_validation(true) }; - reader.next().unwrap().map_err(|e| e.into()) + let decompressed = zstd::bulk::decompress(data, uncompressed_len)?; + read_raw_batch(&decompressed, schema) } + b"NONE" => read_raw_batch(data, schema), other => Err(DataFusionError::Execution(format!( "Failed to decode batch: invalid compression codec: {other:?}" ))), } } -/// Checksum algorithms for writing IPC bytes. +/// Read all bytes from a reader into a pre-allocated Vec. +fn read_all_with_capacity(mut reader: R, capacity: usize) -> Result> { + let mut buf = Vec::with_capacity(capacity); + reader.read_to_end(&mut buf)?; + Ok(buf) +} + +/// Checksum algorithms for writing shuffle bytes. #[derive(Clone)] pub(crate) enum Checksum { /// CRC32 checksum algorithm. @@ -237,3 +517,410 @@ impl Checksum { } } } + +#[cfg(test)] +mod tests { + use super::*; + use arrow::array::*; + use arrow::datatypes::{DataType, Field, Int32Type, Schema}; + use arrow::record_batch::RecordBatch; + use datafusion::physical_plan::metrics::Time; + use std::io::Cursor; + use std::sync::Arc; + + fn make_test_batch() -> (Arc, RecordBatch) { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Float64, false), + ])); + let batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(Int32Array::from(vec![Some(1), None, Some(3)])), + Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0])), + ], + ) + .unwrap(); + (schema, batch) + } + + fn roundtrip(codec: CompressionCodec) { + let (schema, batch) = make_test_batch(); + let writer = ShuffleBlockWriter::try_new(&schema, codec).unwrap(); + let mut buf = Cursor::new(Vec::new()); + let ipc_time = Time::new(); + writer.write_batch(&batch, &mut buf, &ipc_time).unwrap(); + + let bytes = buf.into_inner(); + // Skip 16-byte header: 8 compressed_length + 8 field_count + let body = &bytes[16..]; + + let decoded = read_shuffle_block(body, &schema).unwrap(); + assert_eq!(decoded.num_rows(), 3); + assert_eq!(decoded.num_columns(), 2); + + // Verify Int32 column (nullable) + let col0 = decoded + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(col0.value(0), 1); + assert!(col0.is_null(1)); + assert_eq!(col0.value(2), 3); + + // Verify Float64 column (non-nullable) + let col1 = decoded + .column(1) + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(col1.value(0), 1.0); + assert_eq!(col1.value(1), 2.0); + assert_eq!(col1.value(2), 3.0); + } + + #[test] + fn test_raw_roundtrip_primitives_none() { + roundtrip(CompressionCodec::None); + } + + #[test] + fn test_raw_roundtrip_primitives_lz4() { + roundtrip(CompressionCodec::Lz4Frame); + } + + #[test] + #[cfg_attr(miri, ignore)] + fn test_raw_roundtrip_primitives_zstd() { + roundtrip(CompressionCodec::Zstd(1)); + } + + #[test] + fn test_raw_roundtrip_primitives_snappy() { + roundtrip(CompressionCodec::Snappy); + } + + /// Generic roundtrip helper: writes a batch with ShuffleBlockWriter, + /// reads it back with read_shuffle_block, and asserts equality for all + /// four compression codecs. + fn roundtrip_test(schema: Arc, batch: &RecordBatch) { + let codecs = vec![ + CompressionCodec::None, + CompressionCodec::Lz4Frame, + CompressionCodec::Zstd(1), + CompressionCodec::Snappy, + ]; + for codec in codecs { + let writer = ShuffleBlockWriter::try_new(&schema, codec.clone()).unwrap(); + let mut buf = Cursor::new(Vec::new()); + let ipc_time = Time::new(); + writer.write_batch(batch, &mut buf, &ipc_time).unwrap(); + + let bytes = buf.into_inner(); + let body = &bytes[16..]; + + let decoded = read_shuffle_block(body, &schema).unwrap(); + assert_eq!(decoded.num_rows(), batch.num_rows()); + assert_eq!(decoded.num_columns(), batch.num_columns()); + for i in 0..batch.num_columns() { + assert_eq!( + batch.column(i).as_ref(), + decoded.column(i).as_ref(), + "column {i} mismatch with codec {:?}", + codec + ); + } + } + } + + #[test] + #[cfg_attr(miri, ignore)] + fn test_roundtrip_string_and_binary() { + let schema = Arc::new(Schema::new(vec![ + Field::new("s", DataType::Utf8, true), + Field::new("b", DataType::Binary, true), + ])); + let batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(StringArray::from(vec![ + Some("hello"), + None, + Some("world"), + Some(""), + ])), + Arc::new(BinaryArray::from(vec![ + Some(b"abc" as &[u8]), + Some(b"\x00\x01\x02"), + None, + Some(b""), + ])), + ], + ) + .unwrap(); + roundtrip_test(schema, &batch); + } + + #[test] + #[cfg_attr(miri, ignore)] + fn test_roundtrip_boolean_and_null() { + let schema = Arc::new(Schema::new(vec![ + Field::new("bool", DataType::Boolean, true), + Field::new("n", DataType::Null, true), + ])); + let batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(BooleanArray::from(vec![ + Some(true), + None, + Some(false), + Some(true), + ])), + Arc::new(NullArray::new(4)), + ], + ) + .unwrap(); + roundtrip_test(schema, &batch); + } + + #[test] + #[cfg_attr(miri, ignore)] + fn test_roundtrip_decimal_date_timestamp() { + let schema = Arc::new(Schema::new(vec![ + Field::new("dec", DataType::Decimal128(18, 3), true), + Field::new("date", DataType::Date32, true), + Field::new( + "ts", + DataType::Timestamp(arrow::datatypes::TimeUnit::Microsecond, None), + true, + ), + ])); + let batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new( + Decimal128Array::from(vec![Some(12345_i128), None, Some(-99999)]) + .with_precision_and_scale(18, 3) + .unwrap(), + ), + Arc::new(Date32Array::from(vec![Some(18000), None, Some(19000)])), + Arc::new(TimestampMicrosecondArray::from(vec![ + Some(1_000_000), + None, + Some(2_000_000), + ])), + ], + ) + .unwrap(); + roundtrip_test(schema, &batch); + } + + #[test] + #[cfg_attr(miri, ignore)] + fn test_roundtrip_nested_types() { + // List with nulls at the list level + let list_field = Field::new_list("l", Field::new("item", DataType::Int32, true), true); + + // Struct with nulls + let struct_field = Field::new( + "st", + DataType::Struct( + vec![ + Field::new("x", DataType::Int32, true), + Field::new("y", DataType::Utf8, true), + ] + .into(), + ), + true, + ); + + // Map + let map_field = Field::new( + "m", + DataType::Map( + Arc::new(Field::new( + "entries", + DataType::Struct( + vec![ + Field::new("keys", DataType::Utf8, false), + Field::new("values", DataType::Int32, true), + ] + .into(), + ), + false, + )), + false, + ), + true, + ); + + let schema = Arc::new(Schema::new(vec![list_field, struct_field, map_field])); + + // Build List + let list_arr = { + let mut builder = ListBuilder::new(Int32Builder::new()); + builder.values().append_value(1); + builder.values().append_value(2); + builder.append(true); + builder.append(false); // null list + builder.values().append_value(3); + builder.append(true); + builder.finish() + }; + + // Build Struct + let struct_arr = StructArray::from(vec![ + ( + Arc::new(Field::new("x", DataType::Int32, true)), + Arc::new(Int32Array::from(vec![Some(10), None, Some(30)])) as ArrayRef, + ), + ( + Arc::new(Field::new("y", DataType::Utf8, true)), + Arc::new(StringArray::from(vec![Some("a"), Some("b"), None])) as ArrayRef, + ), + ]); + // Apply null at row 1 + let struct_arr = StructArray::try_new( + struct_arr.fields().clone(), + struct_arr.columns().to_vec(), + Some(arrow::buffer::NullBuffer::from(vec![true, false, true])), + ) + .unwrap(); + + // Build Map + let map_arr = { + let key_builder = StringBuilder::new(); + let value_builder = Int32Builder::new(); + let mut builder = MapBuilder::new(None, key_builder, value_builder); + builder.keys().append_value("k1"); + builder.values().append_value(100); + builder.append(true).unwrap(); + builder.append(false).unwrap(); // null map entry + builder.keys().append_value("k2"); + builder.values().append_value(200); + builder.append(true).unwrap(); + builder.finish() + }; + + let batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![Arc::new(list_arr), Arc::new(struct_arr), Arc::new(map_arr)], + ) + .unwrap(); + + roundtrip_test(schema, &batch); + } + + #[test] + #[cfg_attr(miri, ignore)] + fn test_roundtrip_dictionary_preserved() { + // Dictionary should be preserved through write/read + let dict_schema = Arc::new(Schema::new(vec![Field::new( + "d", + DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)), + true, + )])); + + let keys = Int32Array::from(vec![Some(0), Some(1), None, Some(0)]); + let values = StringArray::from(vec!["foo", "bar"]); + let dict_arr = DictionaryArray::::try_new(keys, Arc::new(values)).unwrap(); + + let batch = + RecordBatch::try_new(Arc::clone(&dict_schema), vec![Arc::new(dict_arr)]).unwrap(); + + // The read schema uses the value type (Utf8) since that's what Spark knows about. + // The reader reconstructs a Dictionary type from the per-column tag. + let read_schema = Arc::new(Schema::new(vec![Field::new("d", DataType::Utf8, true)])); + + let codecs = vec![ + CompressionCodec::None, + CompressionCodec::Lz4Frame, + CompressionCodec::Zstd(1), + CompressionCodec::Snappy, + ]; + for codec in codecs { + let writer = ShuffleBlockWriter::try_new(&dict_schema, codec.clone()).unwrap(); + let mut buf = Cursor::new(Vec::new()); + let ipc_time = Time::new(); + writer.write_batch(&batch, &mut buf, &ipc_time).unwrap(); + + let bytes = buf.into_inner(); + let body = &bytes[16..]; + + let decoded = read_shuffle_block(body, &read_schema).unwrap(); + assert_eq!(decoded.num_rows(), 4); + + // Result should be a DictionaryArray (preserved, not cast) + let col = decoded + .column(0) + .as_any() + .downcast_ref::>() + .expect("expected DictionaryArray to be preserved"); + // Verify values by casting to string for comparison + let cast_col = arrow::compute::cast(col, &DataType::Utf8).expect("cast dict to utf8"); + let str_col = cast_col.as_any().downcast_ref::().unwrap(); + assert_eq!(str_col.value(0), "foo"); + assert_eq!(str_col.value(1), "bar"); + assert!(str_col.is_null(2)); + assert_eq!(str_col.value(3), "foo"); + } + } + + #[test] + #[cfg_attr(miri, ignore)] + fn test_roundtrip_sliced_batch() { + // Test that arrays with non-zero offsets (from slicing) roundtrip correctly. + // This is important because the shuffle writer uses debug_assert for offset==0, + // but in release builds sliced arrays could silently produce wrong results. + let schema = Arc::new(Schema::new(vec![ + Field::new("i", DataType::Int32, true), + Field::new("s", DataType::Utf8, true), + ])); + let full_batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(Int32Array::from(vec![ + Some(1), + None, + Some(3), + Some(4), + None, + Some(6), + ])), + Arc::new(StringArray::from(vec![ + Some("a"), + Some("bb"), + None, + Some("dddd"), + Some("eeeee"), + None, + ])), + ], + ) + .unwrap(); + + // Slice the batch to get arrays with non-zero offset + let sliced = full_batch.slice(2, 3); // rows: [Some(3), Some(4), None] and [None, Some("dddd"), Some("eeeee")] + assert_eq!(sliced.num_rows(), 3); + roundtrip_test(schema, &sliced); + } + + #[test] + fn test_empty_batch_returns_zero() { + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); + let batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![Arc::new(Int32Array::from(Vec::::new()))], + ) + .unwrap(); + assert_eq!(batch.num_rows(), 0); + + let writer = ShuffleBlockWriter::try_new(&schema, CompressionCodec::None).unwrap(); + let mut buf = Cursor::new(Vec::new()); + let ipc_time = Time::new(); + let bytes_written = writer.write_batch(&batch, &mut buf, &ipc_time).unwrap(); + assert_eq!(bytes_written, 0); + } +} diff --git a/native/core/src/execution/shuffle/metrics.rs b/native/core/src/execution/shuffle/metrics.rs index 33b51c3cd8..6c768bf92f 100644 --- a/native/core/src/execution/shuffle/metrics.rs +++ b/native/core/src/execution/shuffle/metrics.rs @@ -26,7 +26,7 @@ pub(super) struct ShufflePartitionerMetrics { /// Time to perform repartitioning pub(super) repart_time: Time, - /// Time encoding batches to IPC format + /// Time encoding batches to shuffle format pub(super) encode_time: Time, /// Time spent writing to disk. Maps to "shuffleWriteTime" in Spark SQL Metrics. diff --git a/native/core/src/execution/shuffle/mod.rs b/native/core/src/execution/shuffle/mod.rs index 6018cff50f..19e23bb72c 100644 --- a/native/core/src/execution/shuffle/mod.rs +++ b/native/core/src/execution/shuffle/mod.rs @@ -23,6 +23,6 @@ mod shuffle_writer; pub mod spark_unsafe; mod writers; -pub use codec::{read_ipc_compressed, CompressionCodec, ShuffleBlockWriter}; +pub use codec::{read_shuffle_block, CompressionCodec, ShuffleBlockWriter}; pub use comet_partitioning::CometPartitioning; pub use shuffle_writer::ShuffleWriterExec; diff --git a/native/core/src/execution/shuffle/partitioners/multi_partition.rs b/native/core/src/execution/shuffle/partitioners/multi_partition.rs index 9c366ad462..4da855b605 100644 --- a/native/core/src/execution/shuffle/partitioners/multi_partition.rs +++ b/native/core/src/execution/shuffle/partitioners/multi_partition.rs @@ -555,7 +555,7 @@ impl ShufflePartitioner for MultiPartitionShuffleRepartitioner { .await } - /// Writes buffered shuffled record batches into Arrow IPC bytes. + /// Writes buffered shuffled record batches to the output shuffle file. fn shuffle_write(&mut self) -> datafusion::common::Result<()> { with_trace("shuffle_write", self.tracing_enabled, || { let start_time = Instant::now(); diff --git a/native/core/src/execution/shuffle/shuffle_writer.rs b/native/core/src/execution/shuffle/shuffle_writer.rs index fe1bf0fccf..2b1cba7049 100644 --- a/native/core/src/execution/shuffle/shuffle_writer.rs +++ b/native/core/src/execution/shuffle/shuffle_writer.rs @@ -265,7 +265,7 @@ async fn external_shuffle( #[cfg(test)] mod test { use super::*; - use crate::execution::shuffle::{read_ipc_compressed, ShuffleBlockWriter}; + use crate::execution::shuffle::{read_shuffle_block, ShuffleBlockWriter}; use arrow::array::{Array, StringArray, StringBuilder}; use arrow::datatypes::{DataType, Field, Schema}; use arrow::record_batch::RecordBatch; @@ -285,8 +285,9 @@ mod test { #[test] #[cfg_attr(miri, ignore)] // miri can't call foreign function `ZSTD_createCCtx` - fn roundtrip_ipc() { + fn roundtrip_raw() { let batch = create_batch(8192); + let schema = batch.schema(); for codec in &[ CompressionCodec::None, CompressionCodec::Zstd(1), @@ -295,15 +296,14 @@ mod test { ] { let mut output = vec![]; let mut cursor = Cursor::new(&mut output); - let writer = - ShuffleBlockWriter::try_new(batch.schema().as_ref(), codec.clone()).unwrap(); + let writer = ShuffleBlockWriter::try_new(schema.as_ref(), codec.clone()).unwrap(); let length = writer .write_batch(&batch, &mut cursor, &Time::default()) .unwrap(); assert_eq!(length, output.len()); - let ipc_without_length_prefix = &output[16..]; - let batch2 = read_ipc_compressed(ipc_without_length_prefix).unwrap(); + let block_without_length_prefix = &output[16..]; + let batch2 = read_shuffle_block(block_without_length_prefix, &schema).unwrap(); assert_eq!(batch, batch2); } } @@ -587,7 +587,7 @@ mod test { } /// Test that batch coalescing in BufBatchWriter reduces output size by - /// writing fewer, larger IPC blocks instead of many small ones. + /// writing fewer, larger blocks instead of many small ones. #[test] #[cfg_attr(miri, ignore)] fn test_batch_coalescing_reduces_size() { @@ -651,7 +651,7 @@ mod test { buf_writer.flush(&encode_time, &write_time).unwrap(); } - // Coalesced output should be smaller due to fewer IPC schema blocks + // Coalesced output should be smaller due to fewer block headers assert!( coalesced_output.len() < uncoalesced_output.len(), "Coalesced output ({} bytes) should be smaller than uncoalesced ({} bytes)", @@ -659,9 +659,9 @@ mod test { uncoalesced_output.len() ); - // Verify both roundtrip correctly by reading all IPC blocks - let coalesced_rows = read_all_ipc_blocks(&coalesced_output); - let uncoalesced_rows = read_all_ipc_blocks(&uncoalesced_output); + // Verify both roundtrip correctly by reading all shuffle blocks + let coalesced_rows = read_all_shuffle_blocks(&coalesced_output, &schema); + let uncoalesced_rows = read_all_shuffle_blocks(&uncoalesced_output, &schema); assert_eq!( coalesced_rows, 5000, "Coalesced should contain all 5000 rows" @@ -672,22 +672,22 @@ mod test { ); } - /// Read all IPC blocks from a byte buffer written by BufBatchWriter/ShuffleBlockWriter, + /// Read all shuffle blocks from a byte buffer written by BufBatchWriter/ShuffleBlockWriter, /// returning the total number of rows. - fn read_all_ipc_blocks(data: &[u8]) -> usize { + fn read_all_shuffle_blocks(data: &[u8], schema: &Arc) -> usize { let mut offset = 0; let mut total_rows = 0; while offset < data.len() { - // First 8 bytes are the IPC length (little-endian u64) - let ipc_length = + // First 8 bytes are the block length (little-endian u64) + let block_length = u64::from_le_bytes(data[offset..offset + 8].try_into().unwrap()) as usize; // Skip the 8-byte length prefix; the next 8 bytes are field_count + codec header let block_start = offset + 8; - let block_end = block_start + ipc_length; - // read_ipc_compressed expects data starting after the 16-byte header + let block_end = block_start + block_length; + // read_shuffle_block expects data starting after the 16-byte header // (i.e., after length + field_count), at the codec tag - let ipc_data = &data[block_start + 8..block_end]; - let batch = read_ipc_compressed(ipc_data).unwrap(); + let block_data = &data[block_start + 8..block_end]; + let batch = read_shuffle_block(block_data, schema).unwrap(); total_rows += batch.num_rows(); offset = block_end; } diff --git a/native/core/src/execution/shuffle/writers/buf_batch_writer.rs b/native/core/src/execution/shuffle/writers/buf_batch_writer.rs index 8d056d7bb0..afbb3f09ed 100644 --- a/native/core/src/execution/shuffle/writers/buf_batch_writer.rs +++ b/native/core/src/execution/shuffle/writers/buf_batch_writer.rs @@ -27,7 +27,7 @@ use std::io::{Cursor, Seek, SeekFrom, Write}; /// Once the buffer exceeds the max size, the buffer will be flushed to the writer. /// /// Small batches are coalesced using Arrow's [`BatchCoalescer`] before serialization, -/// producing exactly `batch_size`-row output batches to reduce per-block IPC schema overhead. +/// producing exactly `batch_size`-row output batches to reduce per-block serialization overhead. /// The coalescer is lazily initialized on the first write. pub(crate) struct BufBatchWriter, W: Write> { shuffle_block_writer: S, diff --git a/native/core/src/jvm_bridge/mod.rs b/native/core/src/jvm_bridge/mod.rs index 00fe7b33c3..85c2ae7577 100644 --- a/native/core/src/jvm_bridge/mod.rs +++ b/native/core/src/jvm_bridge/mod.rs @@ -174,11 +174,13 @@ pub use comet_exec::*; mod batch_iterator; mod comet_metric_node; mod comet_task_memory_manager; +mod shuffle_block_iterator; use crate::{errors::CometError, JAVA_VM}; use batch_iterator::CometBatchIterator; pub use comet_metric_node::*; pub use comet_task_memory_manager::*; +use shuffle_block_iterator::CometShuffleBlockIterator; /// The JVM classes that are used in the JNI calls. #[allow(dead_code)] // we need to keep references to Java items to prevent GC @@ -204,6 +206,8 @@ pub struct JVMClasses<'a> { pub comet_exec: CometExec<'a>, /// The CometBatchIterator class. Used for iterating over the batches. pub comet_batch_iterator: CometBatchIterator<'a>, + /// The CometShuffleBlockIterator class. Used for iterating over shuffle blocks. + pub comet_shuffle_block_iterator: CometShuffleBlockIterator<'a>, /// The CometTaskMemoryManager used for interacting with JVM side to /// acquire & release native memory. pub comet_task_memory_manager: CometTaskMemoryManager<'a>, @@ -257,6 +261,7 @@ impl JVMClasses<'_> { comet_metric_node: CometMetricNode::new(env).unwrap(), comet_exec: CometExec::new(env).unwrap(), comet_batch_iterator: CometBatchIterator::new(env).unwrap(), + comet_shuffle_block_iterator: CometShuffleBlockIterator::new(env).unwrap(), comet_task_memory_manager: CometTaskMemoryManager::new(env).unwrap(), } }); diff --git a/native/core/src/jvm_bridge/shuffle_block_iterator.rs b/native/core/src/jvm_bridge/shuffle_block_iterator.rs new file mode 100644 index 0000000000..c3bb5af5fb --- /dev/null +++ b/native/core/src/jvm_bridge/shuffle_block_iterator.rs @@ -0,0 +1,62 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use jni::signature::Primitive; +use jni::{ + errors::Result as JniResult, + objects::{JClass, JMethodID}, + signature::ReturnType, + JNIEnv, +}; + +/// A struct that holds all the JNI methods and fields for JVM `CometShuffleBlockIterator` class. +#[allow(dead_code)] // we need to keep references to Java items to prevent GC +pub struct CometShuffleBlockIterator<'a> { + pub class: JClass<'a>, + pub method_has_next: JMethodID, + pub method_has_next_ret: ReturnType, + pub method_get_buffer: JMethodID, + pub method_get_buffer_ret: ReturnType, + pub method_get_current_block_length: JMethodID, + pub method_get_current_block_length_ret: ReturnType, +} + +impl<'a> CometShuffleBlockIterator<'a> { + pub const JVM_CLASS: &'static str = "org/apache/comet/CometShuffleBlockIterator"; + + pub fn new(env: &mut JNIEnv<'a>) -> JniResult> { + let class = env.find_class(Self::JVM_CLASS)?; + + Ok(CometShuffleBlockIterator { + class, + method_has_next: env.get_method_id(Self::JVM_CLASS, "hasNext", "()I")?, + method_has_next_ret: ReturnType::Primitive(Primitive::Int), + method_get_buffer: env.get_method_id( + Self::JVM_CLASS, + "getBuffer", + "()Ljava/nio/ByteBuffer;", + )?, + method_get_buffer_ret: ReturnType::Object, + method_get_current_block_length: env.get_method_id( + Self::JVM_CLASS, + "getCurrentBlockLength", + "()I", + )?, + method_get_current_block_length_ret: ReturnType::Primitive(Primitive::Int), + }) + } +} diff --git a/native/proto/src/proto/operator.proto b/native/proto/src/proto/operator.proto index 4afc1fefb7..344b9f0f21 100644 --- a/native/proto/src/proto/operator.proto +++ b/native/proto/src/proto/operator.proto @@ -52,6 +52,7 @@ message Operator { ParquetWriter parquet_writer = 113; Explode explode = 114; CsvScan csv_scan = 115; + ShuffleScan shuffle_scan = 116; } } @@ -85,6 +86,12 @@ message Scan { bool arrow_ffi_safe = 3; } +message ShuffleScan { + repeated spark.spark_expression.DataType fields = 1; + // Informational label for debug output (e.g., "CometShuffleExchangeExec [id=5]") + string source = 2; +} + // Common data shared by all partitions in split mode (sent once at planning) message NativeScanCommon { repeated SparkStructField required_schema = 1; diff --git a/spark/src/main/java/org/apache/comet/CometShuffleBlockIterator.java b/spark/src/main/java/org/apache/comet/CometShuffleBlockIterator.java new file mode 100644 index 0000000000..02526a6e63 --- /dev/null +++ b/spark/src/main/java/org/apache/comet/CometShuffleBlockIterator.java @@ -0,0 +1,138 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet; + +import java.io.Closeable; +import java.io.EOFException; +import java.io.IOException; +import java.io.InputStream; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.nio.channels.Channels; +import java.nio.channels.ReadableByteChannel; + +/** + * Provides raw compressed shuffle blocks to native code via JNI. + * + *

Reads block headers (compressed length + field count) from a shuffle InputStream and loads the + * compressed body into a DirectByteBuffer. Native code pulls blocks by calling hasNext() and + * getBuffer(). + * + *

The DirectByteBuffer returned by getBuffer() is only valid until the next hasNext() call. + * Native code must fully consume it (via read_shuffle_block which allocates new memory for the + * decompressed data) before pulling the next block. + */ +public class CometShuffleBlockIterator implements Closeable { + + private static final int INITIAL_BUFFER_SIZE = 128 * 1024; + + private final ReadableByteChannel channel; + private final InputStream inputStream; + private final ByteBuffer headerBuf = ByteBuffer.allocate(16).order(ByteOrder.LITTLE_ENDIAN); + private ByteBuffer dataBuf = ByteBuffer.allocateDirect(INITIAL_BUFFER_SIZE); + private boolean closed = false; + private int currentBlockLength = 0; + + public CometShuffleBlockIterator(InputStream in) { + this.inputStream = in; + this.channel = Channels.newChannel(in); + } + + /** + * Reads the next block header and loads the compressed body into the internal buffer. Called by + * native code via JNI. + * + *

Header format: 8-byte compressedLength (includes field count but not itself) + 8-byte + * fieldCount (discarded, schema comes from protobuf). + * + * @return the compressed body length in bytes (codec prefix + compressed IPC), or -1 if EOF + */ + public int hasNext() throws IOException { + if (closed) { + return -1; + } + + // Read 16-byte header + headerBuf.clear(); + while (headerBuf.hasRemaining()) { + int bytesRead = channel.read(headerBuf); + if (bytesRead < 0) { + if (headerBuf.position() == 0) { + return -1; + } + throw new EOFException("Data corrupt: unexpected EOF while reading batch header"); + } + } + headerBuf.flip(); + long compressedLength = headerBuf.getLong(); + // Field count discarded - schema determined by ShuffleScan protobuf fields + headerBuf.getLong(); + + long bytesToRead = compressedLength - 8; + if (bytesToRead > Integer.MAX_VALUE) { + throw new IllegalStateException( + "Native shuffle block size of " + + bytesToRead + + " exceeds maximum of " + + Integer.MAX_VALUE + + ". Try reducing shuffle batch size."); + } + + if (dataBuf.capacity() < bytesToRead) { + int newCapacity = (int) Math.min(bytesToRead * 2L, Integer.MAX_VALUE); + dataBuf = ByteBuffer.allocateDirect(newCapacity); + } + + dataBuf.clear(); + dataBuf.limit((int) bytesToRead); + while (dataBuf.hasRemaining()) { + int bytesRead = channel.read(dataBuf); + if (bytesRead < 0) { + throw new EOFException("Data corrupt: unexpected EOF while reading compressed batch"); + } + } + // Note: native side uses get_direct_buffer_address (base pointer) + currentBlockLength, + // not the buffer's position/limit. No flip needed. + + currentBlockLength = (int) bytesToRead; + return currentBlockLength; + } + + /** + * Returns the DirectByteBuffer containing the current block's compressed bytes (4-byte codec + * prefix + compressed IPC data). Called by native code via JNI. + */ + public ByteBuffer getBuffer() { + return dataBuf; + } + + /** Returns the length of the current block in bytes. Called by native code via JNI. */ + public int getCurrentBlockLength() { + return currentBlockLength; + } + + @Override + public void close() throws IOException { + if (!closed) { + closed = true; + inputStream.close(); + } + } +} diff --git a/spark/src/main/scala/org/apache/comet/CometExecIterator.scala b/spark/src/main/scala/org/apache/comet/CometExecIterator.scala index 44ebf7e36e..e198ac99ff 100644 --- a/spark/src/main/scala/org/apache/comet/CometExecIterator.scala +++ b/spark/src/main/scala/org/apache/comet/CometExecIterator.scala @@ -67,7 +67,8 @@ class CometExecIterator( numParts: Int, partitionIndex: Int, broadcastedHadoopConfForEncryption: Option[Broadcast[SerializableConfiguration]] = None, - encryptedFilePaths: Seq[String] = Seq.empty) + encryptedFilePaths: Seq[String] = Seq.empty, + shuffleBlockIterators: Map[Int, CometShuffleBlockIterator] = Map.empty) extends Iterator[ColumnarBatch] with Logging { @@ -78,8 +79,13 @@ class CometExecIterator( private val taskAttemptId = TaskContext.get().taskAttemptId private val taskCPUs = TaskContext.get().cpus() private val cometTaskMemoryManager = new CometTaskMemoryManager(id, taskAttemptId) - private val cometBatchIterators = inputs.map { iterator => - new CometBatchIterator(iterator, nativeUtil) + // Build a mixed array of iterators: CometShuffleBlockIterator for shuffle + // scan indices, CometBatchIterator for regular scan indices. + private val inputIterators: Array[Object] = inputs.zipWithIndex.map { + case (_, idx) if shuffleBlockIterators.contains(idx) => + shuffleBlockIterators(idx).asInstanceOf[Object] + case (iterator, _) => + new CometBatchIterator(iterator, nativeUtil).asInstanceOf[Object] }.toArray private val plan = { @@ -106,7 +112,7 @@ class CometExecIterator( nativeLib.createPlan( id, - cometBatchIterators, + inputIterators, protobufQueryPlan, protobufSparkConfigs, numParts, @@ -229,6 +235,7 @@ class CometExecIterator( currentBatch = null } nativeUtil.close() + shuffleBlockIterators.values.foreach(_.close()) nativeLib.releasePlan(plan) if (tracingEnabled) { diff --git a/spark/src/main/scala/org/apache/comet/Native.scala b/spark/src/main/scala/org/apache/comet/Native.scala index 55e0c70e72..3571d9f6cd 100644 --- a/spark/src/main/scala/org/apache/comet/Native.scala +++ b/spark/src/main/scala/org/apache/comet/Native.scala @@ -54,7 +54,7 @@ class Native extends NativeBase { // scalastyle:off @native def createPlan( id: Long, - iterators: Array[CometBatchIterator], + iterators: Array[Object], plan: Array[Byte], configMapProto: Array[Byte], partitionCount: Int, @@ -177,6 +177,7 @@ class Native extends NativeBase { length: Int, arrayAddrs: Array[Long], schemaAddrs: Array[Long], + schemaBytes: Array[Byte], tracingEnabled: Boolean): Long /** diff --git a/spark/src/main/scala/org/apache/comet/serde/operator/CometSink.scala b/spark/src/main/scala/org/apache/comet/serde/operator/CometSink.scala index ca9dbdad7c..dde36d9789 100644 --- a/spark/src/main/scala/org/apache/comet/serde/operator/CometSink.scala +++ b/spark/src/main/scala/org/apache/comet/serde/operator/CometSink.scala @@ -22,8 +22,12 @@ package org.apache.comet.serde.operator import scala.jdk.CollectionConverters._ import org.apache.spark.sql.comet.{CometNativeExec, CometSinkPlaceHolder} +import org.apache.spark.sql.comet.execution.shuffle.{CometNativeShuffle, CometShuffleExchangeExec} import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.adaptive.ShuffleQueryStageExec +import org.apache.spark.sql.execution.exchange.ReusedExchangeExec +import org.apache.comet.CometConf import org.apache.comet.CometSparkSessionExtensions.withInfo import org.apache.comet.ConfigEntry import org.apache.comet.serde.{CometOperatorSerde, OperatorOuterClass} @@ -86,15 +90,67 @@ abstract class CometSink[T <: SparkPlan] extends CometOperatorSerde[T] { object CometExchangeSink extends CometSink[SparkPlan] { - /** - * Exchange data is FFI safe because there is no use of mutable buffers involved. - * - * Source of broadcast exchange batches is ArrowStreamReader. - * - * Source of shuffle exchange batches is NativeBatchDecoderIterator. - */ override def isFfiSafe: Boolean = true + override def convert( + op: SparkPlan, + builder: Operator.Builder, + childOp: OperatorOuterClass.Operator*): Option[OperatorOuterClass.Operator] = { + if (shouldUseShuffleScan(op)) { + convertToShuffleScan(op, builder) + } else { + super.convert(op, builder, childOp: _*) + } + } + + private def shouldUseShuffleScan(op: SparkPlan): Boolean = { + if (!CometConf.COMET_SHUFFLE_DIRECT_READ_ENABLED.get()) return false + + // Extract the CometShuffleExchangeExec from the wrapper + val shuffleExec = op match { + case ShuffleQueryStageExec(_, s: CometShuffleExchangeExec, _) => Some(s) + case ShuffleQueryStageExec(_, ReusedExchangeExec(_, s: CometShuffleExchangeExec), _) => + Some(s) + case s: CometShuffleExchangeExec => Some(s) + case _ => None + } + + shuffleExec.exists(_.shuffleType == CometNativeShuffle) + } + + private def convertToShuffleScan( + op: SparkPlan, + builder: Operator.Builder): Option[OperatorOuterClass.Operator] = { + val supportedTypes = + op.output.forall(a => supportedDataType(a.dataType, allowComplex = true)) + + if (!supportedTypes) { + withInfo(op, "Unsupported data type for shuffle direct read") + return None + } + + val scanBuilder = OperatorOuterClass.ShuffleScan.newBuilder() + val source = op.simpleStringWithNodeId() + if (source.isEmpty) { + scanBuilder.setSource(op.getClass.getSimpleName) + } else { + scanBuilder.setSource(source) + } + + val scanTypes = op.output.flatMap { attr => + serializeDataType(attr.dataType) + } + + if (scanTypes.length == op.output.length) { + scanBuilder.addAllFields(scanTypes.asJava) + builder.clearChildren() + Some(builder.setShuffleScan(scanBuilder).build()) + } else { + withInfo(op, "unsupported data types for shuffle direct read") + None + } + } + override def createExec(nativeOp: Operator, op: SparkPlan): CometNativeExec = CometSinkPlaceHolder(nativeOp, op, op) } diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometExecRDD.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometExecRDD.scala index ad0c4f2afe..cb8652507f 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/CometExecRDD.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometExecRDD.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.execution.ScalarSubquery import org.apache.spark.sql.vectorized.ColumnarBatch import org.apache.spark.util.SerializableConfiguration -import org.apache.comet.CometExecIterator +import org.apache.comet.{CometExecIterator, CometShuffleBlockIterator} import org.apache.comet.serde.OperatorOuterClass /** @@ -64,7 +64,10 @@ private[spark] class CometExecRDD( nativeMetrics: CometMetricNode, subqueries: Seq[ScalarSubquery], broadcastedHadoopConfForEncryption: Option[Broadcast[SerializableConfiguration]] = None, - encryptedFilePaths: Seq[String] = Seq.empty) + encryptedFilePaths: Seq[String] = Seq.empty, + shuffleBlockIteratorFactories: Map[ + Int, + (TaskContext, Partition) => CometShuffleBlockIterator] = Map.empty) extends RDD[ColumnarBatch](sc, inputRDDs.map(rdd => new OneToOneDependency(rdd))) { // Determine partition count: from inputs if available, otherwise from parameter @@ -109,6 +112,12 @@ private[spark] class CometExecRDD( serializedPlan } + // Create shuffle block iterators for indices that have factories + val shuffleBlockIters = shuffleBlockIteratorFactories.map { case (idx, factory) => + val inputPart = partition.inputPartitions(idx) + idx -> factory(context, inputPart) + } + val it = new CometExecIterator( CometExec.newIterId, inputs, @@ -118,7 +127,8 @@ private[spark] class CometExecRDD( numPartitions, partition.index, broadcastedHadoopConfForEncryption, - encryptedFilePaths) + encryptedFilePaths, + shuffleBlockIters) // Register ScalarSubqueries so native code can look them up subqueries.foreach(sub => CometScalarSubquery.setSubquery(it.id, sub)) @@ -167,7 +177,10 @@ object CometExecRDD { nativeMetrics: CometMetricNode, subqueries: Seq[ScalarSubquery], broadcastedHadoopConfForEncryption: Option[Broadcast[SerializableConfiguration]] = None, - encryptedFilePaths: Seq[String] = Seq.empty): CometExecRDD = { + encryptedFilePaths: Seq[String] = Seq.empty, + shuffleBlockIteratorFactories: Map[ + Int, + (TaskContext, Partition) => CometShuffleBlockIterator] = Map.empty): CometExecRDD = { // scalastyle:on new CometExecRDD( @@ -181,6 +194,7 @@ object CometExecRDD { nativeMetrics, subqueries, broadcastedHadoopConfForEncryption, - encryptedFilePaths) + encryptedFilePaths, + shuffleBlockIteratorFactories) } } diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometBlockStoreShuffleReader.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometBlockStoreShuffleReader.scala index e95eb92d21..dfec6e4474 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometBlockStoreShuffleReader.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometBlockStoreShuffleReader.scala @@ -36,6 +36,8 @@ import org.apache.spark.storage.ShuffleBlockFetcherIterator import org.apache.spark.util.CompletionIterator import org.apache.comet.{CometConf, Native} +import org.apache.comet.serde.OperatorOuterClass +import org.apache.comet.serde.QueryPlanSerde.serializeDataType import org.apache.comet.vector.NativeUtil /** @@ -86,6 +88,17 @@ class CometBlockStoreShuffleReader[K, C]( fetchContinuousBlocksInBatch).toCompletionIterator } + /** Serialize the output schema as a ShuffleScan protobuf message. */ + private lazy val schemaBytes: Array[Byte] = { + import scala.jdk.CollectionConverters._ + val scanBuilder = OperatorOuterClass.ShuffleScan.newBuilder() + val scanTypes = dep.outputAttributes.flatMap { attr => + serializeDataType(attr.dataType) + } + scanBuilder.addAllFields(scanTypes.asJava) + scanBuilder.build().toByteArray + } + /** Read the combined key-values for this reduce task */ override def read(): Iterator[Product2[K, C]] = { var currentReadIterator: NativeBatchDecoderIterator = null @@ -114,6 +127,7 @@ class CometBlockStoreShuffleReader[K, C]( dep.decodeTime, nativeLib, nativeUtil, + schemaBytes, tracingEnabled) currentReadIterator }) @@ -153,6 +167,18 @@ class CometBlockStoreShuffleReader[K, C]( } } + /** + * Returns the raw concatenated InputStream of all shuffle blocks, bypassing the decode step. + * Used by ShuffleScan direct read path. + */ + def readAsRawStream(): InputStream = { + val streams = fetchIterator.map(_._2) + new java.io.SequenceInputStream(new java.util.Enumeration[InputStream] { + override def hasMoreElements: Boolean = streams.hasNext + override def nextElement(): InputStream = streams.next() + }) + } + private def fetchContinuousBlocksInBatch: Boolean = { val conf = SparkEnv.get.conf val serializerRelocatable = dep.serializer.supportsRelocationOfSerializedObjects diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala index d65a6b21f4..02d707b5a4 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala @@ -832,7 +832,8 @@ object CometShuffleExchangeExec shuffleWriterProcessor = ShuffleExchangeExec.createShuffleWriteProcessor(writeMetrics), shuffleType = CometColumnarShuffle, schema = Some(fromAttributes(outputAttributes)), - decodeTime = writeMetrics("decode_time")) + decodeTime = writeMetrics("decode_time"), + outputAttributes = outputAttributes) dependency } diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffledRowRDD.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffledRowRDD.scala index ba6fc588e2..6594982c85 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffledRowRDD.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffledRowRDD.scala @@ -34,7 +34,7 @@ import org.apache.spark.sql.vectorized.ColumnarBatch */ class CometShuffledBatchRDD( var dependency: ShuffleDependency[Int, _, _], - metrics: Map[String, SQLMetric], + val metrics: Map[String, SQLMetric], partitionSpecs: Array[ShufflePartitionSpec]) extends RDD[ColumnarBatch](dependency.rdd.context, Nil) { diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/NativeBatchDecoderIterator.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/NativeBatchDecoderIterator.scala index f96c8f16dd..ed2287ffc7 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/NativeBatchDecoderIterator.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/NativeBatchDecoderIterator.scala @@ -39,6 +39,7 @@ case class NativeBatchDecoderIterator( decodeTime: SQLMetric, nativeLib: Native, nativeUtil: NativeUtil, + schemaBytes: Array[Byte], tracingEnabled: Boolean) extends Iterator[ColumnarBatch] { @@ -160,6 +161,7 @@ case class NativeBatchDecoderIterator( bytesToRead.toInt, arrayAddrs, schemaAddrs, + schemaBytes, tracingEnabled) }) decodeTime.add(System.nanoTime() - startTime) diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala index da2ae21a95..2e195e73eb 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala @@ -25,6 +25,7 @@ import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import scala.jdk.CollectionConverters._ +import org.apache.spark.{Partition, SparkEnv, TaskContext} import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow @@ -33,14 +34,14 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide} import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec +import org.apache.spark.sql.comet.execution.shuffle.{CometBlockStoreShuffleReader, CometShuffledBatchRDD, CometShuffleExchangeExec, ShuffledRowRDDPartition} import org.apache.spark.sql.comet.util.Utils import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.adaptive.{AQEShuffleReadExec, BroadcastQueryStageExec, ShuffleQueryStageExec} import org.apache.spark.sql.execution.aggregate.{BaseAggregateExec, HashAggregateExec, ObjectHashAggregateExec} import org.apache.spark.sql.execution.exchange.ReusedExchangeExec import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, HashJoin, ShuffledHashJoinExec, SortMergeJoinExec} -import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} +import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics, SQLShuffleReadMetricsReporter} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{ArrayType, BooleanType, ByteType, DataType, DateType, DecimalType, DoubleType, FloatType, IntegerType, LongType, MapType, ShortType, StringType, TimestampNTZType} import org.apache.spark.sql.vectorized.ColumnarBatch @@ -50,7 +51,7 @@ import org.apache.spark.util.io.ChunkedByteBuffer import com.google.common.base.Objects import com.google.protobuf.CodedOutputStream -import org.apache.comet.{CometConf, CometExecIterator, CometRuntimeException, ConfigEntry} +import org.apache.comet.{CometConf, CometExecIterator, CometRuntimeException, CometShuffleBlockIterator, ConfigEntry} import org.apache.comet.CometSparkSessionExtensions.{isCometShuffleEnabled, withInfo} import org.apache.comet.parquet.CometParquetUtils import org.apache.comet.serde.{CometOperatorSerde, Compatible, Incompatible, OperatorOuterClass, SupportLevel, Unsupported} @@ -553,6 +554,11 @@ abstract class CometNativeExec extends CometExec { throw new CometRuntimeException(s"No input for CometNativeExec:\n $this") } + // Detect ShuffleScan indices and create factories for direct read + val shuffleScanIndices = findShuffleScanIndices(serializedPlanCopy) + val shuffleBlockIteratorFactories = + buildShuffleBlockIteratorFactories(sparkPlans, inputs, shuffleScanIndices) + // Unified RDD creation - CometExecRDD handles all cases val subqueries = collectSubqueries(this) CometExecRDD( @@ -566,7 +572,8 @@ abstract class CometNativeExec extends CometExec { nativeMetrics, subqueries, broadcastedHadoopConfForEncryption, - encryptedFilePaths) + encryptedFilePaths, + shuffleBlockIteratorFactories) } } @@ -606,6 +613,108 @@ abstract class CometNativeExec extends CometExec { } } + /** + * Walk the serialized protobuf plan depth-first to find which input indices correspond to + * ShuffleScan vs Scan leaf nodes. Each Scan or ShuffleScan leaf consumes one input in order. + */ + private def findShuffleScanIndices(planBytes: Array[Byte]): Set[Int] = { + val plan = OperatorOuterClass.Operator.parseFrom(planBytes) + var scanIndex = 0 + val indices = mutable.Set.empty[Int] + def walk(op: OperatorOuterClass.Operator): Unit = { + if (op.hasShuffleScan) { + indices += scanIndex + scanIndex += 1 + } else if (op.hasScan) { + scanIndex += 1 + } else { + op.getChildrenList.asScala.foreach(walk) + } + } + walk(plan) + indices.toSet + } + + /** + * Build factory functions that produce CometShuffleBlockIterator for each input index that is a + * ShuffleScan. Maps from input index to a factory that, given TaskContext and Partition, + * creates the iterator. + */ + private def buildShuffleBlockIteratorFactories( + sparkPlans: ArrayBuffer[SparkPlan], + inputs: ArrayBuffer[RDD[ColumnarBatch]], + shuffleScanIndices: Set[Int]) + : Map[Int, (TaskContext, Partition) => CometShuffleBlockIterator] = { + if (shuffleScanIndices.isEmpty) return Map.empty + + val factories = mutable.Map.empty[Int, (TaskContext, Partition) => CometShuffleBlockIterator] + + shuffleScanIndices.foreach { scanIdx => + if (scanIdx < inputs.length) { + inputs(scanIdx) match { + case rdd: CometShuffledBatchRDD => + val dep = rdd.dependency + val rddMetrics = rdd.metrics + factories(scanIdx) = (context, part) => { + val shufflePart = part.asInstanceOf[ShuffledRowRDDPartition] + val tempMetrics = + context.taskMetrics().createTempShuffleReadMetrics() + val sqlMetricsReporter = + new SQLShuffleReadMetricsReporter(tempMetrics, rddMetrics) + val reader = shufflePart.spec match { + case CoalescedPartitionSpec(startReducerIndex, endReducerIndex, _) => + SparkEnv.get.shuffleManager + .getReader( + dep.shuffleHandle, + startReducerIndex, + endReducerIndex, + context, + sqlMetricsReporter) + .asInstanceOf[CometBlockStoreShuffleReader[_, _]] + case PartialReducerPartitionSpec(reducerIndex, startMapIndex, endMapIndex, _) => + SparkEnv.get.shuffleManager + .getReader( + dep.shuffleHandle, + startMapIndex, + endMapIndex, + reducerIndex, + reducerIndex + 1, + context, + sqlMetricsReporter) + .asInstanceOf[CometBlockStoreShuffleReader[_, _]] + case PartialMapperPartitionSpec(mapIndex, startReducerIndex, endReducerIndex) => + SparkEnv.get.shuffleManager + .getReader( + dep.shuffleHandle, + mapIndex, + mapIndex + 1, + startReducerIndex, + endReducerIndex, + context, + sqlMetricsReporter) + .asInstanceOf[CometBlockStoreShuffleReader[_, _]] + case CoalescedMapperPartitionSpec(startMapIndex, endMapIndex, numReducers) => + SparkEnv.get.shuffleManager + .getReader( + dep.shuffleHandle, + startMapIndex, + endMapIndex, + 0, + numReducers, + context, + sqlMetricsReporter) + .asInstanceOf[CometBlockStoreShuffleReader[_, _]] + } + val rawStream = reader.readAsRawStream() + new CometShuffleBlockIterator(rawStream) + } + case _ => // Not a CometShuffledBatchRDD, skip + } + } + } + factories.toMap + } + /** * Find all plan nodes with per-partition planning data in the plan tree. Returns two maps keyed * by a unique identifier: one for common data (shared across partitions) and one for diff --git a/spark/src/test/scala/org/apache/comet/exec/CometNativeShuffleSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometNativeShuffleSuite.scala index 1cf43ea598..11f825e70d 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometNativeShuffleSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometNativeShuffleSuite.scala @@ -30,7 +30,7 @@ import org.apache.spark.SparkEnv import org.apache.spark.sql.{CometTestBase, DataFrame, Dataset, Row} import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper -import org.apache.spark.sql.functions.col +import org.apache.spark.sql.functions.{col, count, sum} import org.apache.comet.CometConf @@ -437,4 +437,19 @@ class CometNativeShuffleSuite extends CometTestBase with AdaptiveSparkPlanHelper } } } + + test("shuffle direct read produces same results as FFI path") { + Seq(true, false).foreach { directRead => + withSQLConf(CometConf.COMET_SHUFFLE_DIRECT_READ_ENABLED.key -> directRead.toString) { + val df = spark + .range(1000) + .selectExpr("id", "id % 10 as key", "cast(id as string) as value") + .repartition(4, col("key")) + .groupBy("key") + .agg(sum("id").as("total"), count("value").as("cnt")) + .orderBy("key") + checkSparkAnswer(df) + } + } + } }