diff --git a/common/src/main/scala/org/apache/comet/CometConf.scala b/common/src/main/scala/org/apache/comet/CometConf.scala index 4d2e37924a..bfe90181ff 100644 --- a/common/src/main/scala/org/apache/comet/CometConf.scala +++ b/common/src/main/scala/org/apache/comet/CometConf.scala @@ -343,6 +343,17 @@ object CometConf extends ShimCometConf { .booleanConf .createWithDefault(true) + val COMET_SHUFFLE_DIRECT_READ_ENABLED: ConfigEntry[Boolean] = + conf(s"$COMET_EXEC_CONFIG_PREFIX.shuffle.directRead.enabled") + .category(CATEGORY_SHUFFLE) + .doc( + "When enabled, native operators that consume shuffle output will read " + + "compressed shuffle blocks directly in native code, bypassing Arrow FFI. " + + "Applies to both native shuffle and JVM columnar shuffle. " + + "Requires spark.comet.exec.shuffle.enabled to be true.") + .booleanConf + .createWithDefault(true) + val COMET_SHUFFLE_MODE: ConfigEntry[String] = conf(s"$COMET_EXEC_CONFIG_PREFIX.shuffle.mode") .category(CATEGORY_SHUFFLE) .doc( diff --git a/native/Cargo.lock b/native/Cargo.lock index 5f99c614b3..5b3f7e885e 100644 --- a/native/Cargo.lock +++ b/native/Cargo.lock @@ -1885,7 +1885,7 @@ dependencies = [ [[package]] name = "datafusion-comet-common" -version = "0.14.0" +version = "0.15.0" dependencies = [ "arrow", "datafusion", @@ -1911,7 +1911,7 @@ dependencies = [ [[package]] name = "datafusion-comet-jni-bridge" -version = "0.14.0" +version = "0.15.0" dependencies = [ "arrow", "assertables", diff --git a/native/core/src/execution/jni_api.rs b/native/core/src/execution/jni_api.rs index 59ac674431..e0a395ebbf 100644 --- a/native/core/src/execution/jni_api.rs +++ b/native/core/src/execution/jni_api.rs @@ -82,7 +82,7 @@ use tokio::sync::mpsc; use crate::execution::memory_pools::{ create_memory_pool, handle_task_shared_pool_release, parse_memory_pool_config, MemoryPoolConfig, }; -use crate::execution::operators::ScanExec; +use crate::execution::operators::{ScanExec, ShuffleScanExec}; use 
crate::execution::shuffle::{read_ipc_compressed, CompressionCodec}; use crate::execution::spark_plan::SparkPlan; @@ -151,6 +151,8 @@ struct ExecutionContext { pub root_op: Option>, /// The input sources for the DataFusion plan pub scans: Vec, + /// The shuffle scan input sources for the DataFusion plan + pub shuffle_scans: Vec, /// The global reference of input sources for the DataFusion plan pub input_sources: Vec>, /// The record batch stream to pull results from @@ -311,6 +313,7 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan( partition_count: partition_count as usize, root_op: None, scans: vec![], + shuffle_scans: vec![], input_sources, stream: None, batch_receiver: None, @@ -491,6 +494,10 @@ fn pull_input_batches(exec_context: &mut ExecutionContext) -> Result<(), CometEr exec_context.scans.iter_mut().try_for_each(|scan| { scan.get_next_batch()?; Ok::<(), CometError>(()) + })?; + exec_context.shuffle_scans.iter_mut().try_for_each(|scan| { + scan.get_next_batch()?; + Ok::<(), CometError>(()) }) } @@ -539,7 +546,7 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_executePlan( let planner = PhysicalPlanner::new(Arc::clone(&exec_context.session_ctx), partition) .with_exec_id(exec_context_id); - let (scans, root_op) = planner.create_plan( + let (scans, shuffle_scans, root_op) = planner.create_plan( &exec_context.spark_plan, &mut exec_context.input_sources.clone(), exec_context.partition_count, @@ -548,6 +555,7 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_executePlan( exec_context.plan_creation_time += physical_plan_time; exec_context.scans = scans; + exec_context.shuffle_scans = shuffle_scans; if exec_context.explain_native { let formatted_plan_str = @@ -560,7 +568,7 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_executePlan( // so we should always execute partition 0. 
let stream = root_op.native_plan.execute(0, task_ctx)?; - if exec_context.scans.is_empty() { + if exec_context.scans.is_empty() && exec_context.shuffle_scans.is_empty() { // No JVM data sources — spawn onto tokio so the executor // thread parks in blocking_recv instead of busy-polling. // diff --git a/native/core/src/execution/operators/mod.rs b/native/core/src/execution/operators/mod.rs index 3c3814a2b5..4b2c06575d 100644 --- a/native/core/src/execution/operators/mod.rs +++ b/native/core/src/execution/operators/mod.rs @@ -32,4 +32,6 @@ pub use parquet_writer::ParquetWriterExec; mod csv_scan; pub mod projection; mod scan; +mod shuffle_scan; pub use csv_scan::init_csv_datasource_exec; +pub use shuffle_scan::ShuffleScanExec; diff --git a/native/core/src/execution/operators/projection.rs b/native/core/src/execution/operators/projection.rs index 6ba1bb5d59..194fa6769a 100644 --- a/native/core/src/execution/operators/projection.rs +++ b/native/core/src/execution/operators/projection.rs @@ -25,8 +25,7 @@ use jni::objects::GlobalRef; use crate::{ execution::{ - operators::{ExecutionError, ScanExec}, - planner::{operator_registry::OperatorBuilder, PhysicalPlanner}, + planner::{operator_registry::OperatorBuilder, PhysicalPlanner, PlanCreationResult}, spark_plan::SparkPlan, }, extract_op, @@ -42,12 +41,13 @@ impl OperatorBuilder for ProjectionBuilder { inputs: &mut Vec>, partition_count: usize, planner: &PhysicalPlanner, - ) -> Result<(Vec, Arc), ExecutionError> { + ) -> PlanCreationResult { let project = extract_op!(spark_plan, Projection); let children = &spark_plan.children; assert_eq!(children.len(), 1); - let (scans, child) = planner.create_plan(&children[0], inputs, partition_count)?; + let (scans, shuffle_scans, child) = + planner.create_plan(&children[0], inputs, partition_count)?; // Create projection expressions let exprs: Result, _> = project @@ -68,6 +68,7 @@ impl OperatorBuilder for ProjectionBuilder { Ok(( scans, + shuffle_scans, 
Arc::new(SparkPlan::new(spark_plan.plan_id, projection, vec![child])), )) } diff --git a/native/core/src/execution/operators/shuffle_scan.rs b/native/core/src/execution/operators/shuffle_scan.rs new file mode 100644 index 0000000000..824965d489 --- /dev/null +++ b/native/core/src/execution/operators/shuffle_scan.rs @@ -0,0 +1,508 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+
+use crate::{
+    errors::CometError,
+    execution::{
+        operators::ExecutionError, planner::TEST_EXEC_CONTEXT_ID,
+        shuffle::codec::read_ipc_compressed,
+    },
+    jvm_bridge::{jni_call, JVMClasses},
+};
+use arrow::array::ArrayRef;
+use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
+use datafusion::common::{arrow_datafusion_err, Result as DataFusionResult};
+use datafusion::physical_plan::execution_plan::{Boundedness, EmissionType};
+use datafusion::physical_plan::metrics::{
+    BaselineMetrics, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet, Time,
+};
+use datafusion::{
+    execution::TaskContext,
+    physical_expr::*,
+    physical_plan::{ExecutionPlan, *},
+};
+use futures::Stream;
+use jni::objects::{GlobalRef, JByteBuffer, JObject};
+use std::{
+    any::Any,
+    pin::Pin,
+    sync::{Arc, Mutex},
+    task::{Context, Poll},
+};
+
+use super::scan::InputBatch;
+
+/// ShuffleScanExec reads compressed shuffle blocks from JVM via JNI and decodes them natively.
+/// Unlike ScanExec which receives Arrow arrays via FFI, ShuffleScanExec receives raw compressed
+/// bytes from CometShuffleBlockIterator and decodes them using read_ipc_compressed().
+///
+/// Cloning is cheap with respect to the pending batch: `batch` lives behind an `Arc`, so every
+/// clone (including the one handed to the stream in `execute()`) shares the same slot.
+#[derive(Debug, Clone)]
+pub struct ShuffleScanExec {
+    /// The ID of the execution context that owns this shuffle scan.
+    pub exec_context_id: i64,
+    /// The input source: a global reference to a JVM CometShuffleBlockIterator object.
+    /// `None` when batches are injected through `set_input_batch()` (unit tests).
+    pub input_source: Option<Arc<GlobalRef>>,
+    /// The data types of columns in the shuffle output.
+    pub data_types: Vec<DataType>,
+    /// Schema of the shuffle output.
+    pub schema: SchemaRef,
+    /// The current input batch, populated by get_next_batch() before poll_next().
+    pub batch: Arc<Mutex<Option<InputBatch>>>,
+    /// Cache of plan properties.
+    cache: PlanProperties,
+    /// Metrics collector.
+    metrics: ExecutionPlanMetricsSet,
+    /// Baseline metrics.
+    baseline_metrics: BaselineMetrics,
+    /// Time spent decoding compressed shuffle blocks.
+    decode_time: Time,
+}
+
+impl ShuffleScanExec {
+    /// Creates a shuffle scan whose output columns have the given `data_types` and are
+    /// named `col_0`, `col_1`, ... in order.
+    ///
+    /// `input_source` is the JVM iterator to pull blocks from; pass `None` to feed batches
+    /// manually with `set_input_batch()`.
+    pub fn new(
+        exec_context_id: i64,
+        input_source: Option<Arc<GlobalRef>>,
+        data_types: Vec<DataType>,
+    ) -> Result<Self, CometError> {
+        let metrics_set = ExecutionPlanMetricsSet::default();
+        let baseline_metrics = BaselineMetrics::new(&metrics_set, 0);
+        let decode_time = MetricBuilder::new(&metrics_set).subset_time("decode_time", 0);
+
+        let schema = schema_from_data_types(&data_types);
+
+        let cache = PlanProperties::new(
+            EquivalenceProperties::new(Arc::clone(&schema)),
+            Partitioning::UnknownPartitioning(1),
+            EmissionType::Final,
+            Boundedness::Bounded,
+        );
+
+        Ok(Self {
+            exec_context_id,
+            input_source,
+            data_types,
+            batch: Arc::new(Mutex::new(None)),
+            cache,
+            metrics: metrics_set,
+            baseline_metrics,
+            schema,
+            decode_time,
+        })
+    }
+
+    /// Feeds input batch into this scan. Only used in unit tests.
+    ///
+    /// # Panics
+    ///
+    /// Panics if the batch slot is locked by another holder at the time of the call.
+    pub fn set_input_batch(&mut self, input: InputBatch) {
+        *self.batch.try_lock().unwrap() = Some(input);
+    }
+
+    /// Pull next input batch from JVM. Called externally before poll_next()
+    /// because JNI calls cannot happen from within poll_next on tokio threads.
+    ///
+    /// Does nothing when there is no input source, or when the previously pulled batch has
+    /// not been consumed yet.
+    ///
+    /// # Errors
+    ///
+    /// Returns any error raised while fetching or decoding the next block.
+    ///
+    /// # Panics
+    ///
+    /// Panics if the batch slot is locked by another holder at the time of the call.
+    pub fn get_next_batch(&mut self) -> Result<(), CometError> {
+        let input_source = match self.input_source.as_ref() {
+            Some(source) => source,
+            // Unit test mode - no JNI calls needed.
+            None => return Ok(()),
+        };
+        let mut timer = self.baseline_metrics.elapsed_compute().timer();
+
+        let mut current_batch = self.batch.try_lock().unwrap();
+        if current_batch.is_none() {
+            let next_batch = Self::get_next(
+                self.exec_context_id,
+                input_source.as_obj(),
+                &self.data_types,
+                &self.decode_time,
+            )?;
+            *current_batch = Some(next_batch);
+        }
+
+        timer.stop();
+
+        Ok(())
+    }
+
+    /// Invokes JNI calls to get the next compressed shuffle block and decode it.
+    ///
+    /// Returns `InputBatch::EOF` when running under `TEST_EXEC_CONTEXT_ID` or when the
+    /// iterator reports -1; otherwise returns the decoded columns with any
+    /// dictionary-encoded column unpacked to its value type.
+    ///
+    /// # Errors
+    ///
+    /// Returns an error when `iter` is null, when a JNI call fails, when the reported block
+    /// length is negative (other than the -1 EOF marker) or larger than the direct buffer,
+    /// when decoding fails, or when the decoded column count differs from `data_types`.
+    fn get_next(
+        exec_context_id: i64,
+        iter: &JObject,
+        data_types: &[DataType],
+        decode_time: &Time,
+    ) -> Result<InputBatch, CometError> {
+        if exec_context_id == TEST_EXEC_CONTEXT_ID {
+            return Ok(InputBatch::EOF);
+        }
+
+        if iter.is_null() {
+            return Err(CometError::from(ExecutionError::GeneralError(format!(
+                "Null shuffle block iterator object. Plan id: {exec_context_id}"
+            ))));
+        }
+
+        let mut env = JVMClasses::get_env()?;
+
+        // has_next() reads the next block and returns its length, or -1 if EOF
+        // SAFETY: `iter` was checked to be non-null above.
+        // NOTE(review): assumes `iter` refers to a live CometShuffleBlockIterator - confirm
+        // against the caller that supplies the global reference.
+        let block_length: i32 = unsafe {
+            jni_call!(&mut env,
+                comet_shuffle_block_iterator(iter).has_next() -> i32)?
+        };
+
+        if block_length == -1 {
+            return Ok(InputBatch::EOF);
+        }
+
+        // Any other negative value would wrap to a huge `usize` and make the slice below
+        // read far outside the buffer, so reject it instead of casting.
+        let length = usize::try_from(block_length).map_err(|_| {
+            CometError::from(ExecutionError::GeneralError(format!(
+                "Invalid shuffle block length {block_length}. Plan id: {exec_context_id}"
+            )))
+        })?;
+
+        // Get the DirectByteBuffer containing the compressed shuffle block
+        // SAFETY: same preconditions as the has_next() call above.
+        let buffer: JObject = unsafe {
+            jni_call!(&mut env,
+                comet_shuffle_block_iterator(iter).get_buffer() -> JObject)?
+        };
+
+        let byte_buffer = JByteBuffer::from(buffer);
+        let raw_pointer = env.get_direct_buffer_address(&byte_buffer)?;
+        let capacity = env.get_direct_buffer_capacity(&byte_buffer)?;
+        if length > capacity {
+            return Err(CometError::from(ExecutionError::GeneralError(format!(
+                "Shuffle block length {length} exceeds buffer capacity {capacity}. \
+                 Plan id: {exec_context_id}"
+            ))));
+        }
+
+        // SAFETY: `raw_pointer` is the base address JNI reported for `byte_buffer` and
+        // `length` was checked against that buffer's capacity just above.
+        // NOTE(review): assumes the JVM keeps the buffer alive and unmodified until the
+        // decode below has finished - confirm against CometShuffleBlockIterator.
+        let slice: &[u8] = unsafe { std::slice::from_raw_parts(raw_pointer, length) };
+
+        // Decode the compressed IPC data
+        let mut timer = decode_time.timer();
+        let batch = read_ipc_compressed(slice)?;
+        timer.stop();
+
+        let num_rows = batch.num_rows();
+
+        // Extract column arrays, unpacking any dictionary-encoded columns.
+        // Native shuffle may dictionary-encode string/binary columns for efficiency,
+        // but downstream DataFusion operators expect the value types declared in the
+        // schema (e.g. Utf8, not Dictionary).
+        let columns: Vec<ArrayRef> = batch
+            .columns()
+            .iter()
+            .map(|col| unpack_dictionary(col))
+            .collect();
+
+        // Checked in every build profile so a mismatch is reported here, next to its
+        // cause, rather than when the RecordBatch is assembled later.
+        if columns.len() != data_types.len() {
+            return Err(CometError::from(ExecutionError::GeneralError(format!(
+                "Shuffle block column count mismatch: got {} but expected {}",
+                columns.len(),
+                data_types.len()
+            ))));
+        }
+
+        Ok(InputBatch::new(columns, Some(num_rows)))
+    }
+}
+
+/// If `array` is dictionary-encoded, cast it to the value type. Otherwise return as-is.
+fn unpack_dictionary(array: &ArrayRef) -> ArrayRef {
+    if let DataType::Dictionary(_, value_type) = array.data_type() {
+        // NOTE(review): a failed cast is treated as a broken invariant and panics.
+        arrow::compute::cast(array, value_type.as_ref()).expect("failed to unpack dictionary array")
+    } else {
+        Arc::clone(array)
+    }
+}
+
+/// Builds a schema with one nullable field per entry of `data_types`, named
+/// `col_0`, `col_1`, ... in order.
+fn schema_from_data_types(data_types: &[DataType]) -> SchemaRef {
+    let fields = data_types
+        .iter()
+        .enumerate()
+        .map(|(idx, dt)| Field::new(format!("col_{idx}"), dt.clone(), true))
+        .collect::<Vec<Field>>();
+
+    Arc::new(Schema::new(fields))
+}
+
+impl ExecutionPlan for ShuffleScanExec {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn schema(&self) -> SchemaRef {
+        Arc::clone(&self.schema)
+    }
+
+    /// A shuffle scan is a leaf operator: it has no children.
+    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
+        vec![]
+    }
+
+    /// Returns `self` unchanged; the supplied children are ignored.
+    fn with_new_children(
+        self: Arc<Self>,
+        _: Vec<Arc<dyn ExecutionPlan>>,
+    ) -> datafusion::common::Result<Arc<dyn ExecutionPlan>> {
+        Ok(self)
+    }
+
+    /// Returns a stream backed by a clone of this scan. The clone shares the `batch`
+    /// slot, so batches pulled through `get_next_batch()` on the original are what the
+    /// stream yields.
+    fn execute(
+        &self,
+        partition: usize,
+        _: Arc<TaskContext>,
+    ) -> datafusion::common::Result<SendableRecordBatchStream> {
+        Ok(Box::pin(ShuffleScanStream::new(
+            self.clone(),
+            partition,
+            self.baseline_metrics.clone(),
+        )))
+    }
+
+    fn properties(&self) -> &PlanProperties {
+        &self.cache
+    }
+
+    fn name(&self) -> &str {
+        "ShuffleScanExec"
+    }
+
+    fn metrics(&self) -> Option<MetricsSet> {
+        Some(self.metrics.clone_inner())
+    }
+}
+
+impl DisplayAs for ShuffleScanExec {
+    /// Writes the column list as `col_N: <type>` entries. Every display format is
+    /// handled, so rendering the plan never panics.
+    fn fmt_as(&self, t: DisplayFormatType, f: &mut std::fmt::Formatter) -> std::fmt::Result {
+        let fields: Vec<String> = self
+            .data_types
+            .iter()
+            .enumerate()
+            .map(|(idx, dt)| format!("col_{idx}: {dt}"))
+            .collect();
+        match t {
+            DisplayFormatType::Default | DisplayFormatType::Verbose => {
+                write!(f, "ShuffleScanExec: schema=[{}]", fields.join(", "))
+            }
+            // NOTE(review): the tree renderer is expected to print the operator name
+            // itself, so only the details are written here - confirm against the
+            // DataFusion version in use.
+            DisplayFormatType::TreeRender => {
+                write!(f, "schema=[{}]", fields.join(", "))
+            }
+        }
+    }
+}
+
+/// An async stream that feeds decoded shuffle batches into the DataFusion plan.
+struct ShuffleScanStream {
+    /// The ShuffleScanExec producing input batches.
+    shuffle_scan: ShuffleScanExec,
+    /// Metrics.
+    baseline_metrics: BaselineMetrics,
+}
+
+impl ShuffleScanStream {
+    /// Wraps `shuffle_scan`; the partition index is accepted but not used.
+    pub fn new(
+        shuffle_scan: ShuffleScanExec,
+        _partition: usize,
+        baseline_metrics: BaselineMetrics,
+    ) -> Self {
+        Self {
+            shuffle_scan,
+            baseline_metrics,
+        }
+    }
+}
+
+impl Stream for ShuffleScanStream {
+    type Item = DataFusionResult<arrow::array::RecordBatch>;
+
+    /// Hands out the batch currently stored in the scan's slot and empties the slot.
+    ///
+    /// Returns `Poll::Pending` when the slot is empty.
+    /// NOTE(review): no waker is registered in that case; presumably the driver calls
+    /// `get_next_batch()` and polls again on its own - confirm against the caller.
+    ///
+    /// # Panics
+    ///
+    /// Panics if the batch slot is locked by another holder at the time of the call.
+    fn poll_next(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Option<Self::Item>> {
+        let mut timer = self.baseline_metrics.elapsed_compute().timer();
+        let mut scan_batch = self.shuffle_scan.batch.try_lock().unwrap();
+
+        let input_batch = &*scan_batch;
+        let input_batch = if let Some(batch) = input_batch {
+            batch
+        } else {
+            timer.stop();
+            return Poll::Pending;
+        };
+
+        let result = match input_batch {
+            InputBatch::EOF => Poll::Ready(None),
+            InputBatch::Batch(columns, num_rows) => {
+                self.baseline_metrics.record_output(*num_rows);
+                // The row count is passed explicitly so a batch without columns still
+                // reports its number of rows.
+                let options =
+                    arrow::array::RecordBatchOptions::new().with_row_count(Some(*num_rows));
+                let maybe_batch = arrow::array::RecordBatch::try_new_with_options(
+                    self.shuffle_scan.schema(),
+                    columns.clone(),
+                    &options,
+                )
+                .map_err(|e| arrow_datafusion_err!(e));
+                Poll::Ready(Some(maybe_batch))
+            }
+        };
+
+        // Empty the slot so the next get_next_batch() call fetches a fresh block.
+        *scan_batch = None;
+
+        timer.stop();
+
+        result
+    }
+}
+
+impl RecordBatchStream for ShuffleScanStream {
+    fn schema(&self) -> SchemaRef {
+        self.shuffle_scan.schema()
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use crate::execution::shuffle::codec::{CompressionCodec, ShuffleBlockWriter};
+    use arrow::array::{Int32Array, StringArray};
+    use arrow::datatypes::{DataType, Field, Schema};
+    use arrow::record_batch::RecordBatch;
+    use datafusion::physical_plan::metrics::Time;
+    use std::io::Cursor;
+    use std::sync::Arc;
+
+    use crate::execution::shuffle::codec::read_ipc_compressed;
+
+    #[test]
+    #[cfg_attr(miri, ignore)] // Miri cannot call FFI functions (zstd)
+    fn test_read_compressed_ipc_block() {
+        let schema = Arc::new(Schema::new(vec![
+            Field::new("id", DataType::Int32, false),
+            Field::new("name", DataType::Utf8, true),
+        ]));
+
let batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(Int32Array::from(vec![1, 2, 3])), + Arc::new(StringArray::from(vec!["a", "b", "c"])), + ], + ) + .unwrap(); + + // Write as compressed IPC + let writer = + ShuffleBlockWriter::try_new(&batch.schema(), CompressionCodec::Zstd(1)).unwrap(); + let mut buf = Cursor::new(Vec::new()); + let ipc_time = Time::new(); + writer.write_batch(&batch, &mut buf, &ipc_time).unwrap(); + + // Read back (skip 16-byte header: 8 compressed_length + 8 field_count) + let bytes = buf.into_inner(); + let body = &bytes[16..]; + + let decoded = read_ipc_compressed(body).unwrap(); + assert_eq!(decoded.num_rows(), 3); + assert_eq!(decoded.num_columns(), 2); + + // Verify data + let col0 = decoded + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(col0.value(0), 1); + assert_eq!(col0.value(1), 2); + assert_eq!(col0.value(2), 3); + } + + /// Tests that ShuffleScanExec correctly unpacks dictionary-encoded columns. + /// Native shuffle may dictionary-encode string/binary columns, but the schema + /// declares value types (e.g. Utf8). Without unpacking, RecordBatch creation + /// fails with a schema mismatch. + #[test] + #[cfg_attr(miri, ignore)] + fn test_dictionary_encoded_shuffle_block_is_unpacked() { + use super::*; + use arrow::array::StringDictionaryBuilder; + use arrow::datatypes::Int32Type; + use datafusion::physical_plan::ExecutionPlan; + use futures::StreamExt; + + // Build a batch with a dictionary-encoded string column (simulating what + // the native shuffle writer produces for string columns). 
+ let mut dict_builder = StringDictionaryBuilder::::new(); + dict_builder.append_value("hello"); + dict_builder.append_value("world"); + dict_builder.append_value("hello"); // repeated value, good for dictionary + let dict_array = dict_builder.finish(); + + // The IPC schema includes the dictionary type + let dict_schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new( + "name", + DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)), + true, + ), + ])); + let dict_batch = RecordBatch::try_new( + Arc::clone(&dict_schema), + vec![ + Arc::new(Int32Array::from(vec![1, 2, 3])), + Arc::new(dict_array), + ], + ) + .unwrap(); + + // Write as compressed IPC (preserves dictionary encoding) + let writer = + ShuffleBlockWriter::try_new(&dict_batch.schema(), CompressionCodec::Zstd(1)).unwrap(); + let mut buf = Cursor::new(Vec::new()); + let ipc_time = Time::new(); + writer + .write_batch(&dict_batch, &mut buf, &ipc_time) + .unwrap(); + let bytes = buf.into_inner(); + let body = &bytes[16..]; + + // Confirm that read_ipc_compressed returns dictionary-encoded arrays + let decoded = read_ipc_compressed(body).unwrap(); + assert!( + matches!(decoded.column(1).data_type(), DataType::Dictionary(_, _)), + "Expected dictionary-encoded column from IPC, got {:?}", + decoded.column(1).data_type() + ); + + // Create ShuffleScanExec with value types (Utf8, not Dictionary) — this is + // what the protobuf schema provides. 
+ let mut scan = ShuffleScanExec::new( + super::super::super::planner::TEST_EXEC_CONTEXT_ID, + None, + vec![DataType::Int32, DataType::Utf8], + ) + .unwrap(); + + // Feed the decoded batch through unpack_dictionary (simulating get_next) + let columns: Vec = decoded + .columns() + .iter() + .map(|col| super::unpack_dictionary(col)) + .collect(); + let input = InputBatch::new(columns, Some(decoded.num_rows())); + scan.set_input_batch(input); + + // Execute and verify the output RecordBatch has the expected schema + let rt = tokio::runtime::Runtime::new().unwrap(); + rt.block_on(async { + let ctx = Arc::new(TaskContext::default()); + let mut stream = scan.execute(0, ctx).unwrap(); + let result_batch = stream.next().await.unwrap().unwrap(); + + // Schema should have Utf8, not Dictionary + assert_eq!( + *result_batch.schema().field(1).data_type(), + DataType::Utf8, + "Expected Utf8 after dictionary unpacking" + ); + + // Verify data integrity + let col1 = result_batch + .column(1) + .as_any() + .downcast_ref::() + .expect("Column should be StringArray after unpacking"); + assert_eq!(col1.value(0), "hello"); + assert_eq!(col1.value(1), "world"); + assert_eq!(col1.value(2), "hello"); + }); + } +} diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs index e730fd0c89..5af31fcc22 100644 --- a/native/core/src/execution/planner.rs +++ b/native/core/src/execution/planner.rs @@ -25,7 +25,7 @@ use crate::execution::operators::init_csv_datasource_exec; use crate::execution::operators::IcebergScanExec; use crate::execution::{ expressions::subquery::Subquery, - operators::{ExecutionError, ExpandExec, ParquetWriterExec, ScanExec}, + operators::{ExecutionError, ExpandExec, ParquetWriterExec, ScanExec, ShuffleScanExec}, planner::expression_registry::ExpressionRegistry, planner::operator_registry::OperatorRegistry, serde::to_arrow_datatype, @@ -138,6 +138,8 @@ use url::Url; type PhyAggResult = Result, ExecutionError>; type PhyExprResult = Result, 
String)>, ExecutionError>; type PartitionPhyExprResult = Result>, ExecutionError>; +pub type PlanCreationResult = + Result<(Vec, Vec, Arc), ExecutionError>; struct JoinParameters { pub left: Arc, @@ -910,7 +912,7 @@ impl PhysicalPlanner { spark_plan: &'a Operator, inputs: &mut Vec>, partition_count: usize, - ) -> Result<(Vec, Arc), ExecutionError> { + ) -> PlanCreationResult { // Try to use the modular registry first - this automatically handles any registered operator types if OperatorRegistry::global().can_handle(spark_plan) { return OperatorRegistry::global().create_plan( @@ -926,7 +928,8 @@ impl PhysicalPlanner { match spark_plan.op_struct.as_ref().unwrap() { OpStruct::Filter(filter) => { assert_eq!(children.len(), 1); - let (scans, child) = self.create_plan(&children[0], inputs, partition_count)?; + let (scans, shuffle_scans, child) = + self.create_plan(&children[0], inputs, partition_count)?; let predicate = self.create_expr(filter.predicate.as_ref().unwrap(), child.schema())?; @@ -937,12 +940,14 @@ impl PhysicalPlanner { Ok(( scans, + shuffle_scans, Arc::new(SparkPlan::new(spark_plan.plan_id, filter, vec![child])), )) } OpStruct::HashAgg(agg) => { assert_eq!(children.len(), 1); - let (scans, child) = self.create_plan(&children[0], inputs, partition_count)?; + let (scans, shuffle_scans, child) = + self.create_plan(&children[0], inputs, partition_count)?; let group_exprs: PhyExprResult = agg .grouping_exprs @@ -993,6 +998,7 @@ impl PhysicalPlanner { if agg.result_exprs.is_empty() { Ok(( scans, + shuffle_scans, Arc::new(SparkPlan::new(spark_plan.plan_id, aggregate, vec![child])), )) } else { @@ -1009,6 +1015,7 @@ impl PhysicalPlanner { )?); Ok(( scans, + shuffle_scans, Arc::new(SparkPlan::new_with_additional( spark_plan.plan_id, projection, @@ -1027,7 +1034,8 @@ impl PhysicalPlanner { "Invalid limit/offset combination: [{num}. 
{offset}]" ))); } - let (scans, child) = self.create_plan(&children[0], inputs, partition_count)?; + let (scans, shuffle_scans, child) = + self.create_plan(&children[0], inputs, partition_count)?; let limit: Arc = if offset == 0 { Arc::new(LocalLimitExec::new( Arc::clone(&child.native_plan), @@ -1047,12 +1055,14 @@ impl PhysicalPlanner { }; Ok(( scans, + shuffle_scans, Arc::new(SparkPlan::new(spark_plan.plan_id, limit, vec![child])), )) } OpStruct::Sort(sort) => { assert_eq!(children.len(), 1); - let (scans, child) = self.create_plan(&children[0], inputs, partition_count)?; + let (scans, shuffle_scans, child) = + self.create_plan(&children[0], inputs, partition_count)?; let exprs: Result, ExecutionError> = sort .sort_orders @@ -1076,6 +1086,7 @@ impl PhysicalPlanner { Ok(( scans, + shuffle_scans, Arc::new(SparkPlan::new( spark_plan.plan_id, sort_exec, @@ -1112,6 +1123,7 @@ impl PhysicalPlanner { if partition_files.partitioned_file.is_empty() { let empty_exec = Arc::new(EmptyExec::new(required_schema)); return Ok(( + vec![], vec![], Arc::new(SparkPlan::new(spark_plan.plan_id, empty_exec, vec![])), )); @@ -1202,6 +1214,7 @@ impl PhysicalPlanner { common.encryption_enabled, )?; Ok(( + vec![], vec![], Arc::new(SparkPlan::new(spark_plan.plan_id, scan, vec![])), )) @@ -1240,6 +1253,7 @@ impl PhysicalPlanner { &scan.csv_options.clone().unwrap(), )?; Ok(( + vec![], vec![], Arc::new(SparkPlan::new(spark_plan.plan_id, scan, vec![])), )) @@ -1273,6 +1287,7 @@ impl PhysicalPlanner { Ok(( vec![scan.clone()], + vec![], Arc::new(SparkPlan::new(spark_plan.plan_id, Arc::new(scan), vec![])), )) } @@ -1304,6 +1319,7 @@ impl PhysicalPlanner { )?; Ok(( + vec![], vec![], Arc::new(SparkPlan::new( spark_plan.plan_id, @@ -1314,7 +1330,8 @@ impl PhysicalPlanner { } OpStruct::ShuffleWriter(writer) => { assert_eq!(children.len(), 1); - let (scans, child) = self.create_plan(&children[0], inputs, partition_count)?; + let (scans, shuffle_scans, child) = + self.create_plan(&children[0], inputs, 
partition_count)?; let partitioning = self.create_partitioning( writer.partitioning.as_ref().unwrap(), @@ -1347,6 +1364,7 @@ impl PhysicalPlanner { Ok(( scans, + shuffle_scans, Arc::new(SparkPlan::new( spark_plan.plan_id, shuffle_writer, @@ -1356,7 +1374,8 @@ impl PhysicalPlanner { } OpStruct::ParquetWriter(writer) => { assert_eq!(children.len(), 1); - let (scans, child) = self.create_plan(&children[0], inputs, partition_count)?; + let (scans, shuffle_scans, child) = + self.create_plan(&children[0], inputs, partition_count)?; let codec = match writer.compression.try_into() { Ok(SparkCompressionCodec::None) => Ok(CompressionCodec::None), @@ -1393,6 +1412,7 @@ impl PhysicalPlanner { Ok(( scans, + shuffle_scans, Arc::new(SparkPlan::new( spark_plan.plan_id, parquet_writer, @@ -1402,7 +1422,8 @@ impl PhysicalPlanner { } OpStruct::Expand(expand) => { assert_eq!(children.len(), 1); - let (scans, child) = self.create_plan(&children[0], inputs, partition_count)?; + let (scans, shuffle_scans, child) = + self.create_plan(&children[0], inputs, partition_count)?; let mut projections = vec![]; let mut projection = vec![]; @@ -1445,12 +1466,14 @@ impl PhysicalPlanner { let expand = Arc::new(ExpandExec::new(projections, input, schema)); Ok(( scans, + shuffle_scans, Arc::new(SparkPlan::new(spark_plan.plan_id, expand, vec![child])), )) } OpStruct::Explode(explode) => { assert_eq!(children.len(), 1); - let (scans, child) = self.create_plan(&children[0], inputs, partition_count)?; + let (scans, shuffle_scans, child) = + self.create_plan(&children[0], inputs, partition_count)?; // Create the expression for the array to explode let child_expr = if let Some(child_expr) = &explode.child { @@ -1556,11 +1579,12 @@ impl PhysicalPlanner { Ok(( scans, + shuffle_scans, Arc::new(SparkPlan::new(spark_plan.plan_id, unnest_exec, vec![child])), )) } OpStruct::SortMergeJoin(join) => { - let (join_params, scans) = self.parse_join_parameters( + let (join_params, scans, shuffle_scans) = 
self.parse_join_parameters( inputs, children, &join.left_join_keys, @@ -1612,6 +1636,7 @@ impl PhysicalPlanner { )); Ok(( scans, + shuffle_scans, Arc::new(SparkPlan::new_with_additional( spark_plan.plan_id, coalesce_batches, @@ -1625,6 +1650,7 @@ impl PhysicalPlanner { } else { Ok(( scans, + shuffle_scans, Arc::new(SparkPlan::new( spark_plan.plan_id, join, @@ -1637,7 +1663,7 @@ impl PhysicalPlanner { } } OpStruct::HashJoin(join) => { - let (join_params, scans) = self.parse_join_parameters( + let (join_params, scans, shuffle_scans) = self.parse_join_parameters( inputs, children, &join.left_join_keys, @@ -1667,6 +1693,7 @@ impl PhysicalPlanner { if join.build_side == BuildSide::BuildLeft as i32 { Ok(( scans, + shuffle_scans, Arc::new(SparkPlan::new( spark_plan.plan_id, hash_join, @@ -1685,6 +1712,7 @@ impl PhysicalPlanner { Ok(( scans, + shuffle_scans, Arc::new(SparkPlan::new_with_additional( spark_plan.plan_id, swapped_hash_join, @@ -1695,7 +1723,8 @@ impl PhysicalPlanner { } } OpStruct::Window(wnd) => { - let (scans, child) = self.create_plan(&children[0], inputs, partition_count)?; + let (scans, shuffle_scans, child) = + self.create_plan(&children[0], inputs, partition_count)?; let input_schema = child.schema(); let sort_exprs: Result, ExecutionError> = wnd .order_by_list @@ -1733,9 +1762,37 @@ impl PhysicalPlanner { )?); Ok(( scans, + shuffle_scans, Arc::new(SparkPlan::new(spark_plan.plan_id, window_agg, vec![child])), )) } + OpStruct::ShuffleScan(scan) => { + let data_types = scan.fields.iter().map(to_arrow_datatype).collect_vec(); + + if self.exec_context_id != TEST_EXEC_CONTEXT_ID && inputs.is_empty() { + return Err(GeneralError("No input for shuffle scan".to_string())); + } + + let input_source = + if self.exec_context_id == TEST_EXEC_CONTEXT_ID && inputs.is_empty() { + None + } else { + Some(inputs.remove(0)) + }; + + let shuffle_scan = + ShuffleScanExec::new(self.exec_context_id, input_source, data_types)?; + + Ok(( + vec![], + vec![shuffle_scan.clone()], + 
Arc::new(SparkPlan::new( + spark_plan.plan_id, + Arc::new(shuffle_scan), + vec![], + )), + )) + } _ => Err(GeneralError(format!( "Unsupported or unregistered operator type: {:?}", spark_plan.op_struct @@ -1753,12 +1810,15 @@ impl PhysicalPlanner { join_type: i32, condition: &Option, partition_count: usize, - ) -> Result<(JoinParameters, Vec), ExecutionError> { + ) -> Result<(JoinParameters, Vec, Vec), ExecutionError> { assert_eq!(children.len(), 2); - let (mut left_scans, left) = self.create_plan(&children[0], inputs, partition_count)?; - let (mut right_scans, right) = self.create_plan(&children[1], inputs, partition_count)?; + let (mut left_scans, mut left_shuffle_scans, left) = + self.create_plan(&children[0], inputs, partition_count)?; + let (mut right_scans, mut right_shuffle_scans, right) = + self.create_plan(&children[1], inputs, partition_count)?; left_scans.append(&mut right_scans); + left_shuffle_scans.append(&mut right_shuffle_scans); let left_join_exprs: Vec<_> = left_join_keys .iter() @@ -1879,6 +1939,7 @@ impl PhysicalPlanner { join_filter, }, left_scans, + left_shuffle_scans, )) } @@ -3649,7 +3710,8 @@ mod tests { let input_array = DictionaryArray::new(keys, Arc::new(values)); let input_batch = InputBatch::Batch(vec![Arc::new(input_array)], row_count); - let (mut scans, datafusion_plan) = planner.create_plan(&op, &mut vec![], 1).unwrap(); + let (mut scans, _shuffle_scans, datafusion_plan) = + planner.create_plan(&op, &mut vec![], 1).unwrap(); scans[0].set_input_batch(input_batch); let session_ctx = SessionContext::new(); @@ -3723,7 +3785,8 @@ mod tests { let input_array = DictionaryArray::new(keys, Arc::new(values)); let input_batch = InputBatch::Batch(vec![Arc::new(input_array)], row_count); - let (mut scans, datafusion_plan) = planner.create_plan(&op, &mut vec![], 1).unwrap(); + let (mut scans, _shuffle_scans, datafusion_plan) = + planner.create_plan(&op, &mut vec![], 1).unwrap(); // Scan's schema is determined by the input batch, so we need to set 
it before execution. scans[0].set_input_batch(input_batch); @@ -3770,7 +3833,8 @@ mod tests { let op = create_filter(op_scan, 0); let planner = PhysicalPlanner::default(); - let (mut scans, datafusion_plan) = planner.create_plan(&op, &mut vec![], 1).unwrap(); + let (mut scans, _shuffle_scans, datafusion_plan) = + planner.create_plan(&op, &mut vec![], 1).unwrap(); let scan = &mut scans[0]; scan.set_input_batch(InputBatch::EOF); @@ -3855,7 +3919,8 @@ mod tests { let op = create_filter(op_scan, 0); let planner = PhysicalPlanner::default(); - let (_scans, filter_exec) = planner.create_plan(&op, &mut vec![], 1).unwrap(); + let (_scans, _shuffle_scans, filter_exec) = + planner.create_plan(&op, &mut vec![], 1).unwrap(); assert_eq!("FilterExec", filter_exec.native_plan.name()); assert_eq!(1, filter_exec.children.len()); @@ -3879,7 +3944,8 @@ mod tests { let planner = PhysicalPlanner::default(); - let (_scans, hash_join_exec) = planner.create_plan(&op_join, &mut vec![], 1).unwrap(); + let (_scans, _shuffle_scans, hash_join_exec) = + planner.create_plan(&op_join, &mut vec![], 1).unwrap(); assert_eq!("HashJoinExec", hash_join_exec.native_plan.name()); assert_eq!(2, hash_join_exec.children.len()); @@ -3993,7 +4059,7 @@ mod tests { })), }; - let (mut scans, datafusion_plan) = + let (mut scans, _shuffle_scans, datafusion_plan) = planner.create_plan(&projection, &mut vec![], 1).unwrap(); let mut stream = datafusion_plan.native_plan.execute(0, task_ctx).unwrap(); @@ -4119,7 +4185,7 @@ mod tests { }; // Create a physical plan - let (mut scans, datafusion_plan) = + let (mut scans, _shuffle_scans, datafusion_plan) = planner.create_plan(&projection, &mut vec![], 1).unwrap(); // Start executing the plan in a separate thread @@ -4610,7 +4676,7 @@ mod tests { }; // Create the physical plan - let (mut scans, datafusion_plan) = + let (mut scans, _shuffle_scans, datafusion_plan) = planner.create_plan(&projection, &mut vec![], 1).unwrap(); // Create test data: Date32 and Int8 columns diff 
--git a/native/core/src/execution/planner/operator_registry.rs b/native/core/src/execution/planner/operator_registry.rs index b34a80df95..cad5df40c5 100644 --- a/native/core/src/execution/planner/operator_registry.rs +++ b/native/core/src/execution/planner/operator_registry.rs @@ -25,11 +25,8 @@ use std::{ use datafusion_comet_proto::spark_operator::Operator; use jni::objects::GlobalRef; -use super::PhysicalPlanner; -use crate::execution::{ - operators::{ExecutionError, ScanExec}, - spark_plan::SparkPlan, -}; +use super::{PhysicalPlanner, PlanCreationResult}; +use crate::execution::operators::ExecutionError; /// Trait for building physical operators from Spark protobuf operators pub trait OperatorBuilder: Send + Sync { @@ -40,7 +37,7 @@ pub trait OperatorBuilder: Send + Sync { inputs: &mut Vec>, partition_count: usize, planner: &PhysicalPlanner, - ) -> Result<(Vec, Arc), ExecutionError>; + ) -> PlanCreationResult; } /// Enum to identify different operator types for registry dispatch @@ -100,7 +97,7 @@ impl OperatorRegistry { inputs: &mut Vec>, partition_count: usize, planner: &PhysicalPlanner, - ) -> Result<(Vec, Arc), ExecutionError> { + ) -> PlanCreationResult { let operator_type = get_operator_type(spark_operator).ok_or_else(|| { ExecutionError::GeneralError(format!( "Unsupported operator type: {:?}", @@ -153,5 +150,6 @@ fn get_operator_type(spark_operator: &Operator) -> Option { OpStruct::Window(_) => Some(OperatorType::Window), OpStruct::Explode(_) => None, // Not yet in OperatorType enum OpStruct::CsvScan(_) => Some(OperatorType::CsvScan), + OpStruct::ShuffleScan(_) => None, // Not yet in OperatorType enum } } diff --git a/native/jni-bridge/src/lib.rs b/native/jni-bridge/src/lib.rs index 456fbdf688..a2e25c3e2f 100644 --- a/native/jni-bridge/src/lib.rs +++ b/native/jni-bridge/src/lib.rs @@ -181,10 +181,12 @@ pub use comet_exec::*; mod batch_iterator; mod comet_metric_node; mod comet_task_memory_manager; +mod shuffle_block_iterator; use 
batch_iterator::CometBatchIterator; pub use comet_metric_node::*; pub use comet_task_memory_manager::*; +use shuffle_block_iterator::CometShuffleBlockIterator; /// The JVM classes that are used in the JNI calls. #[allow(dead_code)] // we need to keep references to Java items to prevent GC @@ -210,6 +212,8 @@ pub struct JVMClasses<'a> { pub comet_exec: CometExec<'a>, /// The CometBatchIterator class. Used for iterating over the batches. pub comet_batch_iterator: CometBatchIterator<'a>, + /// The CometShuffleBlockIterator class. Used for iterating over shuffle blocks. + pub comet_shuffle_block_iterator: CometShuffleBlockIterator<'a>, /// The CometTaskMemoryManager used for interacting with JVM side to /// acquire & release native memory. pub comet_task_memory_manager: CometTaskMemoryManager<'a>, @@ -263,6 +267,7 @@ impl JVMClasses<'_> { comet_metric_node: CometMetricNode::new(env).unwrap(), comet_exec: CometExec::new(env).unwrap(), comet_batch_iterator: CometBatchIterator::new(env).unwrap(), + comet_shuffle_block_iterator: CometShuffleBlockIterator::new(env).unwrap(), comet_task_memory_manager: CometTaskMemoryManager::new(env).unwrap(), } }); diff --git a/native/jni-bridge/src/shuffle_block_iterator.rs b/native/jni-bridge/src/shuffle_block_iterator.rs new file mode 100644 index 0000000000..c3bb5af5fb --- /dev/null +++ b/native/jni-bridge/src/shuffle_block_iterator.rs @@ -0,0 +1,62 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use jni::signature::Primitive; +use jni::{ + errors::Result as JniResult, + objects::{JClass, JMethodID}, + signature::ReturnType, + JNIEnv, +}; + +/// A struct that holds all the JNI methods and fields for JVM `CometShuffleBlockIterator` class. +#[allow(dead_code)] // we need to keep references to Java items to prevent GC +pub struct CometShuffleBlockIterator<'a> { + pub class: JClass<'a>, + pub method_has_next: JMethodID, + pub method_has_next_ret: ReturnType, + pub method_get_buffer: JMethodID, + pub method_get_buffer_ret: ReturnType, + pub method_get_current_block_length: JMethodID, + pub method_get_current_block_length_ret: ReturnType, +} + +impl<'a> CometShuffleBlockIterator<'a> { + pub const JVM_CLASS: &'static str = "org/apache/comet/CometShuffleBlockIterator"; + + pub fn new(env: &mut JNIEnv<'a>) -> JniResult> { + let class = env.find_class(Self::JVM_CLASS)?; + + Ok(CometShuffleBlockIterator { + class, + method_has_next: env.get_method_id(Self::JVM_CLASS, "hasNext", "()I")?, + method_has_next_ret: ReturnType::Primitive(Primitive::Int), + method_get_buffer: env.get_method_id( + Self::JVM_CLASS, + "getBuffer", + "()Ljava/nio/ByteBuffer;", + )?, + method_get_buffer_ret: ReturnType::Object, + method_get_current_block_length: env.get_method_id( + Self::JVM_CLASS, + "getCurrentBlockLength", + "()I", + )?, + method_get_current_block_length_ret: ReturnType::Primitive(Primitive::Int), + }) + } +} diff --git a/native/proto/src/proto/operator.proto b/native/proto/src/proto/operator.proto index 4afc1fefb7..344b9f0f21 100644 --- 
a/native/proto/src/proto/operator.proto +++ b/native/proto/src/proto/operator.proto @@ -52,6 +52,7 @@ message Operator { ParquetWriter parquet_writer = 113; Explode explode = 114; CsvScan csv_scan = 115; + ShuffleScan shuffle_scan = 116; } } @@ -85,6 +86,12 @@ message Scan { bool arrow_ffi_safe = 3; } +message ShuffleScan { + repeated spark.spark_expression.DataType fields = 1; + // Informational label for debug output (e.g., "CometShuffleExchangeExec [id=5]") + string source = 2; +} + // Common data shared by all partitions in split mode (sent once at planning) message NativeScanCommon { repeated SparkStructField required_schema = 1; diff --git a/spark/src/main/java/org/apache/comet/CometShuffleBlockIterator.java b/spark/src/main/java/org/apache/comet/CometShuffleBlockIterator.java new file mode 100644 index 0000000000..9f72b20f51 --- /dev/null +++ b/spark/src/main/java/org/apache/comet/CometShuffleBlockIterator.java @@ -0,0 +1,142 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.comet; + +import java.io.Closeable; +import java.io.EOFException; +import java.io.IOException; +import java.io.InputStream; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.nio.channels.Channels; +import java.nio.channels.ReadableByteChannel; + +/** + * Provides raw compressed shuffle blocks to native code via JNI. + * + *

Reads block headers (compressed length + field count) from a shuffle InputStream and loads the + * compressed body into a DirectByteBuffer. Native code pulls blocks by calling hasNext() and + * getBuffer(). + * + *

The DirectByteBuffer returned by getBuffer() is only valid until the next hasNext() call. + * Native code must fully consume it (via read_ipc_compressed which allocates new memory for the + * decompressed data) before pulling the next block. + */ +public class CometShuffleBlockIterator implements Closeable { + + private static final int INITIAL_BUFFER_SIZE = 128 * 1024; + + private final ReadableByteChannel channel; + private final InputStream inputStream; + private final ByteBuffer headerBuf = ByteBuffer.allocate(16).order(ByteOrder.LITTLE_ENDIAN); + private ByteBuffer dataBuf = ByteBuffer.allocateDirect(INITIAL_BUFFER_SIZE); + private boolean closed = false; + private int currentBlockLength = 0; + + public CometShuffleBlockIterator(InputStream in) { + this.inputStream = in; + this.channel = Channels.newChannel(in); + } + + /** + * Reads the next block header and loads the compressed body into the internal buffer. Called by + * native code via JNI. + * + *

Header format: 8-byte compressedLength (includes field count but not itself) + 8-byte + * fieldCount (discarded, schema comes from protobuf). + * + * @return the compressed body length in bytes (codec prefix + compressed IPC), or -1 if EOF + */ + public int hasNext() throws IOException { + if (closed) { + return -1; + } + + // Read 16-byte header: clear() resets position=0, limit=capacity, + // preparing the buffer for channel.read() to fill it + headerBuf.clear(); + while (headerBuf.hasRemaining()) { + int bytesRead = channel.read(headerBuf); + if (bytesRead < 0) { + if (headerBuf.position() == 0) { + close(); + return -1; + } + throw new EOFException("Data corrupt: unexpected EOF while reading batch header"); + } + } + headerBuf.flip(); + long compressedLength = headerBuf.getLong(); + // Field count discarded - schema determined by ShuffleScan protobuf fields + headerBuf.getLong(); + + // Subtract 8 because compressedLength includes the 8-byte field count we already read + long bytesToRead = compressedLength - 8; + if (bytesToRead > Integer.MAX_VALUE) { + throw new IllegalStateException( + "Native shuffle block size of " + + bytesToRead + + " exceeds maximum of " + + Integer.MAX_VALUE + + ". Try reducing spark.comet.columnar.shuffle.batch.size."); + } + + currentBlockLength = (int) bytesToRead; + + if (dataBuf.capacity() < currentBlockLength) { + int newCapacity = (int) Math.min(bytesToRead * 2L, Integer.MAX_VALUE); + dataBuf = ByteBuffer.allocateDirect(newCapacity); + } + + dataBuf.clear(); + dataBuf.limit(currentBlockLength); + while (dataBuf.hasRemaining()) { + int bytesRead = channel.read(dataBuf); + if (bytesRead < 0) { + throw new EOFException("Data corrupt: unexpected EOF while reading compressed batch"); + } + } + // Note: native side uses get_direct_buffer_address (base pointer) + currentBlockLength, + // not the buffer's position/limit. No flip needed. 
+ + return currentBlockLength; + } + + /** + * Returns the DirectByteBuffer containing the current block's compressed bytes (4-byte codec + * prefix + compressed IPC data). Called by native code via JNI. + */ + public ByteBuffer getBuffer() { + return dataBuf; + } + + /** Returns the length of the current block in bytes. Called by native code via JNI. */ + public int getCurrentBlockLength() { + return currentBlockLength; + } + + @Override + public void close() throws IOException { + if (!closed) { + closed = true; + inputStream.close(); + } + } +} diff --git a/spark/src/main/scala/org/apache/comet/CometExecIterator.scala b/spark/src/main/scala/org/apache/comet/CometExecIterator.scala index 44ebf7e36e..e198ac99ff 100644 --- a/spark/src/main/scala/org/apache/comet/CometExecIterator.scala +++ b/spark/src/main/scala/org/apache/comet/CometExecIterator.scala @@ -67,7 +67,8 @@ class CometExecIterator( numParts: Int, partitionIndex: Int, broadcastedHadoopConfForEncryption: Option[Broadcast[SerializableConfiguration]] = None, - encryptedFilePaths: Seq[String] = Seq.empty) + encryptedFilePaths: Seq[String] = Seq.empty, + shuffleBlockIterators: Map[Int, CometShuffleBlockIterator] = Map.empty) extends Iterator[ColumnarBatch] with Logging { @@ -78,8 +79,13 @@ class CometExecIterator( private val taskAttemptId = TaskContext.get().taskAttemptId private val taskCPUs = TaskContext.get().cpus() private val cometTaskMemoryManager = new CometTaskMemoryManager(id, taskAttemptId) - private val cometBatchIterators = inputs.map { iterator => - new CometBatchIterator(iterator, nativeUtil) + // Build a mixed array of iterators: CometShuffleBlockIterator for shuffle + // scan indices, CometBatchIterator for regular scan indices. 
+ private val inputIterators: Array[Object] = inputs.zipWithIndex.map { + case (_, idx) if shuffleBlockIterators.contains(idx) => + shuffleBlockIterators(idx).asInstanceOf[Object] + case (iterator, _) => + new CometBatchIterator(iterator, nativeUtil).asInstanceOf[Object] }.toArray private val plan = { @@ -106,7 +112,7 @@ class CometExecIterator( nativeLib.createPlan( id, - cometBatchIterators, + inputIterators, protobufQueryPlan, protobufSparkConfigs, numParts, @@ -229,6 +235,7 @@ class CometExecIterator( currentBatch = null } nativeUtil.close() + shuffleBlockIterators.values.foreach(_.close()) nativeLib.releasePlan(plan) if (tracingEnabled) { diff --git a/spark/src/main/scala/org/apache/comet/Native.scala b/spark/src/main/scala/org/apache/comet/Native.scala index 55e0c70e72..f6800626d6 100644 --- a/spark/src/main/scala/org/apache/comet/Native.scala +++ b/spark/src/main/scala/org/apache/comet/Native.scala @@ -54,7 +54,7 @@ class Native extends NativeBase { // scalastyle:off @native def createPlan( id: Long, - iterators: Array[CometBatchIterator], + iterators: Array[Object], plan: Array[Byte], configMapProto: Array[Byte], partitionCount: Int, diff --git a/spark/src/main/scala/org/apache/comet/serde/operator/CometSink.scala b/spark/src/main/scala/org/apache/comet/serde/operator/CometSink.scala index ca9dbdad7c..71faab2a4d 100644 --- a/spark/src/main/scala/org/apache/comet/serde/operator/CometSink.scala +++ b/spark/src/main/scala/org/apache/comet/serde/operator/CometSink.scala @@ -22,8 +22,12 @@ package org.apache.comet.serde.operator import scala.jdk.CollectionConverters._ import org.apache.spark.sql.comet.{CometNativeExec, CometSinkPlaceHolder} +import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.adaptive.ShuffleQueryStageExec +import org.apache.spark.sql.execution.exchange.ReusedExchangeExec +import org.apache.comet.CometConf import 
org.apache.comet.CometSparkSessionExtensions.withInfo import org.apache.comet.ConfigEntry import org.apache.comet.serde.{CometOperatorSerde, OperatorOuterClass} @@ -86,15 +90,67 @@ abstract class CometSink[T <: SparkPlan] extends CometOperatorSerde[T] { object CometExchangeSink extends CometSink[SparkPlan] { - /** - * Exchange data is FFI safe because there is no use of mutable buffers involved. - * - * Source of broadcast exchange batches is ArrowStreamReader. - * - * Source of shuffle exchange batches is NativeBatchDecoderIterator. - */ override def isFfiSafe: Boolean = true + override def convert( + op: SparkPlan, + builder: Operator.Builder, + childOp: OperatorOuterClass.Operator*): Option[OperatorOuterClass.Operator] = { + if (shouldUseShuffleScan(op)) { + convertToShuffleScan(op, builder) + } else { + super.convert(op, builder, childOp: _*) + } + } + + private def shouldUseShuffleScan(op: SparkPlan): Boolean = { + if (!CometConf.COMET_SHUFFLE_DIRECT_READ_ENABLED.get()) return false + + // Extract the CometShuffleExchangeExec from the wrapper + val shuffleExec = op match { + case ShuffleQueryStageExec(_, s: CometShuffleExchangeExec, _) => Some(s) + case ShuffleQueryStageExec(_, ReusedExchangeExec(_, s: CometShuffleExchangeExec), _) => + Some(s) + case s: CometShuffleExchangeExec => Some(s) + case _ => None + } + + shuffleExec.isDefined + } + + private def convertToShuffleScan( + op: SparkPlan, + builder: Operator.Builder): Option[OperatorOuterClass.Operator] = { + val supportedTypes = + op.output.forall(a => supportedDataType(a.dataType, allowComplex = true)) + + if (!supportedTypes) { + withInfo(op, "Unsupported data type for shuffle direct read") + return None + } + + val scanBuilder = OperatorOuterClass.ShuffleScan.newBuilder() + val source = op.simpleStringWithNodeId() + if (source.isEmpty) { + scanBuilder.setSource(op.getClass.getSimpleName) + } else { + scanBuilder.setSource(source) + } + + val scanTypes = op.output.flatMap { attr => + 
serializeDataType(attr.dataType) + } + + if (scanTypes.length == op.output.length) { + scanBuilder.addAllFields(scanTypes.asJava) + builder.clearChildren() + Some(builder.setShuffleScan(scanBuilder).build()) + } else { + withInfo(op, s"unsupported data types in ${op.nodeName} for shuffle direct read") + None + } + } + override def createExec(nativeOp: Operator, op: SparkPlan): CometNativeExec = CometSinkPlaceHolder(nativeOp, op, op) } diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometExecRDD.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometExecRDD.scala index ad0c4f2afe..c5014818c4 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/CometExecRDD.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometExecRDD.scala @@ -22,6 +22,7 @@ package org.apache.spark.sql.comet import org.apache.spark._ import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD +import org.apache.spark.sql.comet.execution.shuffle.CometShuffledBatchRDD import org.apache.spark.sql.execution.ScalarSubquery import org.apache.spark.sql.vectorized.ColumnarBatch import org.apache.spark.util.SerializableConfiguration @@ -64,7 +65,8 @@ private[spark] class CometExecRDD( nativeMetrics: CometMetricNode, subqueries: Seq[ScalarSubquery], broadcastedHadoopConfForEncryption: Option[Broadcast[SerializableConfiguration]] = None, - encryptedFilePaths: Seq[String] = Seq.empty) + encryptedFilePaths: Seq[String] = Seq.empty, + shuffleScanIndices: Set[Int] = Set.empty) extends RDD[ColumnarBatch](sc, inputRDDs.map(rdd => new OneToOneDependency(rdd))) { // Determine partition count: from inputs if available, otherwise from parameter @@ -109,6 +111,15 @@ private[spark] class CometExecRDD( serializedPlan } + // Create shuffle block iterators for inputs that are CometShuffledBatchRDD + val shuffleBlockIters = shuffleScanIndices.flatMap { idx => + inputRDDs(idx) match { + case rdd: CometShuffledBatchRDD => + Some(idx -> 
rdd.computeAsShuffleBlockIterator(partition.inputPartitions(idx), context)) + case _ => None + } + }.toMap + val it = new CometExecIterator( CometExec.newIterId, inputs, @@ -118,7 +129,8 @@ private[spark] class CometExecRDD( numPartitions, partition.index, broadcastedHadoopConfForEncryption, - encryptedFilePaths) + encryptedFilePaths, + shuffleBlockIters) // Register ScalarSubqueries so native code can look them up subqueries.foreach(sub => CometScalarSubquery.setSubquery(it.id, sub)) @@ -167,7 +179,8 @@ object CometExecRDD { nativeMetrics: CometMetricNode, subqueries: Seq[ScalarSubquery], broadcastedHadoopConfForEncryption: Option[Broadcast[SerializableConfiguration]] = None, - encryptedFilePaths: Seq[String] = Seq.empty): CometExecRDD = { + encryptedFilePaths: Seq[String] = Seq.empty, + shuffleScanIndices: Set[Int] = Set.empty): CometExecRDD = { // scalastyle:on new CometExecRDD( @@ -181,6 +194,7 @@ object CometExecRDD { nativeMetrics, subqueries, broadcastedHadoopConfForEncryption, - encryptedFilePaths) + encryptedFilePaths, + shuffleScanIndices) } } diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometBlockStoreShuffleReader.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometBlockStoreShuffleReader.scala index e95eb92d21..14e656f038 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometBlockStoreShuffleReader.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometBlockStoreShuffleReader.scala @@ -153,6 +153,18 @@ class CometBlockStoreShuffleReader[K, C]( } } + /** + * Returns the raw concatenated InputStream of all shuffle blocks, bypassing the decode step. + * Used by ShuffleScan direct read path. 
+ */ + def readAsRawStream(): InputStream = { + val streams = fetchIterator.map(_._2) + new java.io.SequenceInputStream(new java.util.Enumeration[InputStream] { + override def hasMoreElements: Boolean = streams.hasNext + override def nextElement(): InputStream = streams.next() + }) + } + private def fetchContinuousBlocksInBatch: Boolean = { val conf = SparkEnv.get.conf val serializerRelocatable = dep.serializer.supportsRelocationOfSerializedObjects diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffledRowRDD.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffledRowRDD.scala index ba6fc588e2..7604910b06 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffledRowRDD.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffledRowRDD.scala @@ -27,6 +27,8 @@ import org.apache.spark.sql.execution.metric.{SQLMetric, SQLShuffleReadMetricsRe import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.vectorized.ColumnarBatch +import org.apache.comet.CometShuffleBlockIterator + /** * Different from [[org.apache.spark.sql.execution.ShuffledRowRDD]], this RDD is specialized for * reading shuffled data through [[CometBlockStoreShuffleReader]]. 
The shuffled data is read in an @@ -34,7 +36,7 @@ import org.apache.spark.sql.vectorized.ColumnarBatch */ class CometShuffledBatchRDD( var dependency: ShuffleDependency[Int, _, _], - metrics: Map[String, SQLMetric], + val metrics: Map[String, SQLMetric], partitionSpecs: Array[ShufflePartitionSpec]) extends RDD[ColumnarBatch](dependency.rdd.context, Nil) { @@ -90,12 +92,14 @@ class CometShuffledBatchRDD( } } - override def compute(split: Partition, context: TaskContext): Iterator[ColumnarBatch] = { + private def createReader( + split: Partition, + context: TaskContext): CometBlockStoreShuffleReader[_, _] = { val tempMetrics = context.taskMetrics().createTempShuffleReadMetrics() // `SQLShuffleReadMetricsReporter` will update its own metrics for SQL exchange operator, // as well as the `tempMetrics` for basic shuffle metrics. val sqlMetricsReporter = new SQLShuffleReadMetricsReporter(tempMetrics, metrics) - val reader = split.asInstanceOf[ShuffledRowRDDPartition].spec match { + split.asInstanceOf[ShuffledRowRDDPartition].spec match { case CoalescedPartitionSpec(startReducerIndex, endReducerIndex, _) => SparkEnv.get.shuffleManager .getReader( @@ -142,7 +146,21 @@ class CometShuffledBatchRDD( sqlMetricsReporter) .asInstanceOf[CometBlockStoreShuffleReader[_, _]] } + } + /** + * Creates a CometShuffleBlockIterator that provides raw compressed shuffle blocks for direct + * consumption by native code, bypassing Arrow FFI. 
+ */ + def computeAsShuffleBlockIterator( + split: Partition, + context: TaskContext): CometShuffleBlockIterator = { + val reader = createReader(split, context) + new CometShuffleBlockIterator(reader.readAsRawStream()) + } + + override def compute(split: Partition, context: TaskContext): Iterator[ColumnarBatch] = { + val reader = createReader(split, context) // TODO: Reads IPC by native code reader.read().asInstanceOf[Iterator[Product2[Int, ColumnarBatch]]].map(_._2) } diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala index da2ae21a95..2965e46988 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala @@ -553,6 +553,9 @@ abstract class CometNativeExec extends CometExec { throw new CometRuntimeException(s"No input for CometNativeExec:\n $this") } + // Detect ShuffleScan indices for direct read in CometExecRDD + val shuffleScanIndices = findShuffleScanIndices(serializedPlanCopy) + // Unified RDD creation - CometExecRDD handles all cases val subqueries = collectSubqueries(this) CometExecRDD( @@ -566,7 +569,8 @@ abstract class CometNativeExec extends CometExec { nativeMetrics, subqueries, broadcastedHadoopConfForEncryption, - encryptedFilePaths) + encryptedFilePaths, + shuffleScanIndices) } } @@ -606,6 +610,28 @@ abstract class CometNativeExec extends CometExec { } } + /** + * Walk the serialized protobuf plan depth-first to find which input indices correspond to + * ShuffleScan vs Scan leaf nodes. Each Scan or ShuffleScan leaf consumes one input in order. 
+ */ + private def findShuffleScanIndices(planBytes: Array[Byte]): Set[Int] = { + val plan = OperatorOuterClass.Operator.parseFrom(planBytes) + var scanIndex = 0 + val indices = mutable.Set.empty[Int] + def walk(op: OperatorOuterClass.Operator): Unit = { + if (op.hasShuffleScan) { + indices += scanIndex + scanIndex += 1 + } else if (op.hasScan) { + scanIndex += 1 + } else { + op.getChildrenList.asScala.foreach(walk) + } + } + walk(plan) + indices.toSet + } + /** * Find all plan nodes with per-partition planning data in the plan tree. Returns two maps keyed * by a unique identifier: one for common data (shared across partitions) and one for diff --git a/spark/src/test/scala/org/apache/comet/exec/CometNativeShuffleSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometNativeShuffleSuite.scala index 1cf43ea598..efb5fbca8a 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometNativeShuffleSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometNativeShuffleSuite.scala @@ -30,7 +30,7 @@ import org.apache.spark.SparkEnv import org.apache.spark.sql.{CometTestBase, DataFrame, Dataset, Row} import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper -import org.apache.spark.sql.functions.col +import org.apache.spark.sql.functions.{col, count, sum} import org.apache.comet.CometConf @@ -437,4 +437,41 @@ class CometNativeShuffleSuite extends CometTestBase with AdaptiveSparkPlanHelper } } } + + test("shuffle direct read produces same results as FFI path") { + Seq(true, false).foreach { directRead => + withSQLConf(CometConf.COMET_SHUFFLE_DIRECT_READ_ENABLED.key -> directRead.toString) { + val df = spark + .range(1000) + .selectExpr("id", "id % 10 as key", "cast(id as string) as value") + .repartition(4, col("key")) + .groupBy("key") + .agg(sum("id").as("total"), count("value").as("cnt")) + .orderBy("key") + checkSparkAnswer(df) + } + } + } + + test("shuffle direct read 
with multiple shuffles in plan") { + Seq(true, false).foreach { directRead => + withSQLConf(CometConf.COMET_SHUFFLE_DIRECT_READ_ENABLED.key -> directRead.toString) { + // Join two shuffled datasets to produce a plan with multiple shuffle reads + val left = spark + .range(100) + .selectExpr("id as l_id", "id % 10 as key") + .repartition(4, col("key")) + val right = spark + .range(100) + .selectExpr("id as r_id", "id % 10 as key") + .repartition(4, col("key")) + val df = left + .join(right, "key") + .groupBy("key") + .agg(count("l_id").as("cnt")) + .orderBy("key") + checkSparkAnswer(df) + } + } + } }