From 827a8ca063d284b254e6c9d1120e6c2a2dd9b70c Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Wed, 18 Mar 2026 15:14:25 -0600 Subject: [PATCH 01/33] docs: add design spec for shuffle direct read optimization Adds a design document for bypassing Arrow FFI in the shuffle read path when both the shuffle writer and downstream operator are native. --- .../2026-03-18-shuffle-direct-read-design.md | 163 ++++++++++++++++++ 1 file changed, 163 insertions(+) create mode 100644 docs/superpowers/specs/2026-03-18-shuffle-direct-read-design.md diff --git a/docs/superpowers/specs/2026-03-18-shuffle-direct-read-design.md b/docs/superpowers/specs/2026-03-18-shuffle-direct-read-design.md new file mode 100644 index 0000000000..2f002a2d89 --- /dev/null +++ b/docs/superpowers/specs/2026-03-18-shuffle-direct-read-design.md @@ -0,0 +1,163 @@ +# Shuffle Direct Read: Bypass FFI for Native Shuffle Read Path + +## Problem + +When a native shuffle exchange feeds into a downstream native operator, shuffle data crosses the JVM/native FFI boundary twice: + +1. **Native to JVM**: `decodeShuffleBlock` JNI call decompresses Arrow IPC, creates a `RecordBatch`, and exports it via Arrow C Data Interface (per-column `FFI_ArrowArray` + `FFI_ArrowSchema` allocation, export, and import). +2. **JVM to Native**: `CometBatchIterator` re-exports the `ColumnarBatch` via Arrow C Data Interface back to native, where `ScanExec` imports and copies/unpacks the arrays. + +Each crossing involves per-column schema serialization, struct allocation, and array copying. For queries with many shuffle stages or wide schemas, this overhead is significant. + +## Solution + +Introduce a direct read path where native code consumes compressed shuffle blocks directly, bypassing Arrow FFI entirely. The JVM reads raw bytes from Spark's shuffle infrastructure and hands them to native via a `DirectByteBuffer` (zero-copy pointer access). Native decompresses and decodes in-place, feeding `RecordBatch` directly into the execution plan. 
+ +### Data Flow Comparison + +**Current path (double FFI):** + +``` +Shuffle stream + -> NativeBatchDecoderIterator (JVM) + -> JNI: decodeShuffleBlock + -> FFI export: RecordBatch -> ArrowArray/Schema (native -> JVM) + -> ColumnarBatch on JVM + -> CometBatchIterator + -> FFI export: ColumnarBatch -> ArrowArray/Schema (JVM -> native) + -> ScanExec imports + copies arrays + -> Native operators +``` + +**New path (zero FFI):** + +``` +Shuffle stream + -> CometShuffleBlockIterator (JVM) + -> reads header + compressed body into DirectByteBuffer + -> holds bytes, waits for native pull + +ShuffleScanExec (native, pull-based) + -> JNI callback: iterator.hasNext()/getBuffer() + -> read_ipc_compressed() -> RecordBatch + -> feeds directly into native execution plan +``` + +## Scope + +- Native shuffle (`CometNativeShuffle`) only. JVM columnar shuffle is excluded because its per-batch dictionary encoding decisions can change the schema between batches. +- Both paths (old and new) are retained. A config flag controls which is used. + +## Components + +### New JVM Components + +#### `CometShuffleBlockIterator` (Java) + +A new class that wraps a shuffle `InputStream` and exposes raw compressed blocks for native consumption. Absorbs the header-reading and buffer-management logic from `NativeBatchDecoderIterator`, but does not decode. + +JNI-callable interface: + +- `hasNext() -> int`: Reads the next block's header from the stream. The header is 16 bytes: 8-byte compressed length (includes the 8-byte field count but not itself) + 8-byte field count. The field count from the header is discarded — the schema is determined by the `ShuffleScan` protobuf's `fields` list, which is authoritative. Returns the compressed body length in bytes (i.e., `compressedLength - 8`, which includes the 4-byte codec prefix + compressed IPC data), or -1 for EOF. 
+- `getBuffer() -> ByteBuffer`: Returns the `DirectByteBuffer` containing the current block's compressed bytes (4-byte codec prefix + compressed IPC data). This buffer is only valid until the next `hasNext()` call — the caller must fully consume it (via `read_ipc_compressed()`, which decompresses into a new allocation) before pulling the next block. + +Uses its own `DirectByteBuffer` instance (not shared with `NativeBatchDecoderIterator`) with the same pooling strategy: initial 128KB, grows as needed, reset on close. + +**Lifecycle**: Implements `Closeable`. `close()` closes the underlying shuffle `InputStream` and resets the buffer. `CometBlockStoreShuffleReader` registers a task completion listener to close it, matching the existing pattern for `NativeBatchDecoderIterator`. + +### New Native Components + +#### `ShuffleScanExec` (Rust) + +Location: `native/core/src/execution/operators/shuffle_scan.rs` + +A new `ExecutionPlan` operator that replaces `ScanExec` at shuffle boundaries. On each `poll_next`: + +1. Calls JNI into `CometShuffleBlockIterator.hasNext()` to get the next block's byte length (or -1 for EOF). +2. Calls `CometShuffleBlockIterator.getBuffer()` to get a `DirectByteBuffer`. +3. Obtains the buffer's raw pointer via `JNIEnv::get_direct_buffer_address()` and creates a slice over it (zero-copy, same pattern as `decodeShuffleBlock`). +4. Calls `read_ipc_compressed()` to decompress and decode into a `RecordBatch`. This allocates new memory for the decompressed data — the `DirectByteBuffer` can be safely reused afterward. +5. Returns the `RecordBatch` directly to the downstream native operator. + +No `FFI_ArrowArray`, `FFI_ArrowSchema`, `ArrowImporter`, or `CometVector` involved. + +Implements `on_close` for cleanup (releasing the JNI `GlobalRef`), matching the `ScanExec` pattern. 
+ +#### `ShuffleScan` Protobuf Message + +Location: `native/proto/src/proto/operator.proto` + +New message alongside existing `Scan`: + +```protobuf +message ShuffleScan { + repeated spark.spark_expression.DataType fields = 1; + string source = 2; // Informational label (e.g., "CometShuffleExchangeExec [id=5]") +} +``` + +The `Operator` message gains a new `shuffle_scan` field in its oneof. + +### Modified JVM Components + +#### `CometExchangeSink` / `CometExecRule` + +The decision to use `ShuffleScan` vs `Scan` is made when `CometNativeExec` is constructed (not during the bottom-up conversion pass). At that point, the operator tree is already converted: `CometExecRule.convertBlock()` wraps a contiguous group of native operators into `CometNativeExec` and serializes the protobuf plan. The children (including `CometSinkPlaceHolder` wrapping shuffle exchanges) are already known. So the check is: when serializing a `CometSinkPlaceHolder` whose `originalPlan` is a `CometShuffleExchangeExec` with `shuffleType == CometNativeShuffle`, and the config flag is enabled, emit `ShuffleScan` instead of `Scan`. + +Conditions for `ShuffleScan`: + +1. Shuffle type is `CometNativeShuffle` +2. The sink is inside a `CometNativeExec` block (always true at serialization time — this is where sinks get serialized) +3. Config `spark.comet.shuffle.directRead.enabled` is true (default: true) + +#### `CometNativeExec` (operators.scala) + +When collecting input RDDs and creating iterators, distinguish the two cases: + +- `ShuffleScan` input: Wrap the shuffle RDD's `Iterator[ColumnarBatch]` stream in `CometShuffleBlockIterator` — but note that `CometShuffleBlockIterator` wraps the raw `InputStream` from shuffle blocks, not decoded `ColumnarBatch`. This means the RDD must provide the raw shuffle `InputStream` rather than going through `NativeBatchDecoderIterator`. 
The `CometShuffledBatchRDD` / `CometBlockStoreShuffleReader` needs a mode where it yields raw `InputStream` objects per block instead of decoded batches. +- `Scan` input: Wrap in `CometBatchIterator` (existing behavior) + +#### `CometExecIterator` — JNI Input Contract + +Currently `CometExecIterator` wraps all inputs as `CometBatchIterator` and passes them to `Native.createPlan()` as `Array[CometBatchIterator]`. To support `CometShuffleBlockIterator`: + +- Change the JNI parameter from `Array[CometBatchIterator]` to `Array[Object]`. On the native side in `createPlan`, the planner already knows from the protobuf whether each input is a `Scan` or `ShuffleScan`, so it knows which JNI methods to call on each `GlobalRef` — no type checking needed at runtime. +- `CometExecIterator` populates the array with either `CometBatchIterator` or `CometShuffleBlockIterator` based on whether the corresponding leaf in the protobuf plan is `Scan` or `ShuffleScan`. + +### Native Planner Changes + +In `planner.rs`, handle the `ShuffleScan` protobuf variant: + +- Consume an input from `inputs.remove(0)` (same pattern as `Scan`) +- Create `ShuffleScanExec` instead of `ScanExec` +- The `GlobalRef` points to a `CometShuffleBlockIterator` Java object + +## Fallback Behavior + +The new path is used only when all conditions above are met. Otherwise, the existing path is used unchanged. The most common fallback case is a shuffle whose output is consumed by a non-native Spark operator (e.g., `collect()`, or an unsupported operator), where the JVM needs a materialized `ColumnarBatch`. + +## Configuration + +| Config | Default | Description | +|--------|---------|-------------| +| `spark.comet.shuffle.directRead.enabled` | `true` | Use direct native read path for native shuffle when downstream operator is native | + +## Error Handling + +- `ShuffleScanExec` reuses `read_ipc_compressed()`, which handles corrupt data and unsupported codecs. 
+- JNI errors from `CometShuffleBlockIterator` (stream closed, EOF, I/O errors) propagate through the existing `try_unwrap_or_throw` pattern. +- If the JVM iterator throws, the exception surfaces as a Rust error and propagates through DataFusion's error handling. +- Empty batches (zero rows): `read_ipc_compressed()` calls `reader.next().unwrap()` which panics if the stream contains no batches. The shuffle writer never writes zero-row blocks (guarded by `if batch.num_rows() == 0 { return Ok(0) }` in `ShuffleBlockWriter.write_batch`), so this case does not arise. + +## Metrics + +`ShuffleScanExec` tracks and reports: + +- `decodeTime`: Time spent in `read_ipc_compressed()` (decompression + IPC decode). Same metric as `NativeBatchDecoderIterator` reports today. +- Shuffle read metrics (`recordsRead`, `bytesRead`) continue to be reported by `CometBlockStoreShuffleReader` and the `ShuffleBlockFetcherIterator`, which are upstream of the new code and unchanged. + +## Testing + +- Existing shuffle tests (`CometShuffleSuite`) run with the config defaulting to true, automatically covering the new path. +- Add a test that runs the same queries with the config flag on and off, asserting identical results. +- Add a Rust unit test for `ShuffleScanExec` with pre-built compressed IPC blocks (no JNI), using the `TEST_EXEC_CONTEXT_ID` pattern from `ScanExec` tests. 
From 3a3edb48a0feded6af9c06648a8bfce9aee24f77 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Wed, 18 Mar 2026 15:27:46 -0600 Subject: [PATCH 02/33] docs: add implementation plan for shuffle direct read --- .../plans/2026-03-18-shuffle-direct-read.md | 1011 +++++++++++++++++ 1 file changed, 1011 insertions(+) create mode 100644 docs/superpowers/plans/2026-03-18-shuffle-direct-read.md diff --git a/docs/superpowers/plans/2026-03-18-shuffle-direct-read.md b/docs/superpowers/plans/2026-03-18-shuffle-direct-read.md new file mode 100644 index 0000000000..647f122cc4 --- /dev/null +++ b/docs/superpowers/plans/2026-03-18-shuffle-direct-read.md @@ -0,0 +1,1011 @@ +# Shuffle Direct Read Implementation Plan + +> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking. + +**Goal:** Eliminate double Arrow FFI crossing at shuffle boundaries by having native code consume compressed IPC blocks directly from JVM-provided byte buffers. + +**Architecture:** A new `ShuffleScanExec` Rust operator pulls raw compressed bytes from a JVM `CometShuffleBlockIterator` via JNI, decompresses and decodes them in native code, and feeds `RecordBatch` directly into the execution plan. This bypasses the current path where data is decoded to JVM `ColumnarBatch` (FFI export), then re-exported back to native (FFI import). 
+ +**Tech Stack:** Scala, Java, Rust, Protobuf, JNI, Arrow IPC + +**Spec:** `docs/superpowers/specs/2026-03-18-shuffle-direct-read-design.md` + +--- + +### Task 1: Add config flag + +**Files:** +- Modify: `common/src/main/scala/org/apache/comet/CometConf.scala` + +- [ ] **Step 1: Add the config entry** + +Find the existing shuffle config entries (search for `COMET_EXEC_SHUFFLE_ENABLED`) and add nearby: + +```scala +val COMET_SHUFFLE_DIRECT_READ_ENABLED: ConfigEntry[Boolean] = + conf("spark.comet.shuffle.directRead.enabled") + .category(CATEGORY_EXEC) + .doc( + "When enabled, native operators that consume shuffle output will read " + + "compressed shuffle blocks directly in native code, bypassing Arrow FFI. " + + "Only applies to native shuffle (not JVM columnar shuffle). " + + "Requires spark.comet.exec.shuffle.enabled to be true.") + .booleanConf + .createWithDefault(true) +``` + +- [ ] **Step 2: Verify it compiles** + +Run: `./mvnw compile -DskipTests -pl common` +Expected: BUILD SUCCESS + +- [ ] **Step 3: Commit** + +```bash +git add common/src/main/scala/org/apache/comet/CometConf.scala +git commit -m "feat: add spark.comet.shuffle.directRead.enabled config" +``` + +--- + +### Task 2: Add ShuffleScan protobuf message + +**Files:** +- Modify: `native/proto/src/proto/operator.proto` + +- [ ] **Step 1: Add ShuffleScan message** + +Add after the existing `Scan` message (after line 86): + +```protobuf +message ShuffleScan { + repeated spark.spark_expression.DataType fields = 1; + // Informational label for debug output (e.g., "CometShuffleExchangeExec [id=5]") + string source = 2; +} +``` + +- [ ] **Step 2: Add shuffle_scan to the Operator oneof** + +In the `oneof op_struct` block (lines 38-55), add after `csv_scan = 115`: + +```protobuf + ShuffleScan shuffle_scan = 116; +``` + +- [ ] **Step 3: Rebuild protobuf and verify** + +Run: `make core` +Expected: Successful build with generated protobuf code. 
+ +- [ ] **Step 4: Commit** + +```bash +git add native/proto/src/proto/operator.proto +git commit -m "feat: add ShuffleScan protobuf message" +``` + +--- + +### Task 3: Create CometShuffleBlockIterator (Java) + +**Files:** +- Create: `spark/src/main/java/org/apache/comet/CometShuffleBlockIterator.java` + +- [ ] **Step 1: Create the class** + +```java +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet; + +import java.io.Closeable; +import java.io.EOFException; +import java.io.IOException; +import java.io.InputStream; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.nio.channels.Channels; +import java.nio.channels.ReadableByteChannel; + +/** + * Provides raw compressed shuffle blocks to native code via JNI. + * + *

Reads block headers (compressed length + field count) from a shuffle InputStream and loads + * the compressed body into a DirectByteBuffer. Native code pulls blocks by calling hasNext() + * and getBuffer(). + * + *

The DirectByteBuffer returned by getBuffer() is only valid until the next hasNext() call. + * Native code must fully consume it (via read_ipc_compressed which allocates new memory for + * the decompressed data) before pulling the next block. + */ +public class CometShuffleBlockIterator implements Closeable { + + private static final int INITIAL_BUFFER_SIZE = 128 * 1024; + + private final ReadableByteChannel channel; + private final InputStream inputStream; + private final ByteBuffer headerBuf = + ByteBuffer.allocate(16).order(ByteOrder.LITTLE_ENDIAN); + private ByteBuffer dataBuf = ByteBuffer.allocateDirect(INITIAL_BUFFER_SIZE); + private boolean closed = false; + private int currentBlockLength = 0; + + public CometShuffleBlockIterator(InputStream in) { + this.inputStream = in; + this.channel = Channels.newChannel(in); + } + + /** + * Reads the next block header and loads the compressed body into the internal buffer. + * Called by native code via JNI. + * + *

Header format: 8-byte compressedLength (includes field count but not itself) + + * 8-byte fieldCount (discarded, schema comes from protobuf). + * + * @return the compressed body length in bytes (codec prefix + compressed IPC), or -1 if EOF + */ + public int hasNext() throws IOException { + if (closed) { + return -1; + } + + // Read 16-byte header + headerBuf.clear(); + while (headerBuf.hasRemaining()) { + int bytesRead = channel.read(headerBuf); + if (bytesRead < 0) { + if (headerBuf.position() == 0) { + return -1; + } + throw new EOFException( + "Data corrupt: unexpected EOF while reading batch header"); + } + } + headerBuf.flip(); + long compressedLength = headerBuf.getLong(); + // Field count discarded - schema determined by ShuffleScan protobuf fields + headerBuf.getLong(); + + long bytesToRead = compressedLength - 8; + if (bytesToRead > Integer.MAX_VALUE) { + throw new IllegalStateException( + "Native shuffle block size of " + bytesToRead + " exceeds maximum of " + + Integer.MAX_VALUE + ". Try reducing shuffle batch size."); + } + + if (dataBuf.capacity() < bytesToRead) { + int newCapacity = (int) Math.min(bytesToRead * 2L, Integer.MAX_VALUE); + dataBuf = ByteBuffer.allocateDirect(newCapacity); + } + + dataBuf.clear(); + dataBuf.limit((int) bytesToRead); + while (dataBuf.hasRemaining()) { + int bytesRead = channel.read(dataBuf); + if (bytesRead < 0) { + throw new EOFException( + "Data corrupt: unexpected EOF while reading compressed batch"); + } + } + // Note: native side uses get_direct_buffer_address (base pointer) + currentBlockLength, + // not the buffer's position/limit. No flip needed. + + currentBlockLength = (int) bytesToRead; + return currentBlockLength; + } + + /** + * Returns the DirectByteBuffer containing the current block's compressed bytes + * (4-byte codec prefix + compressed IPC data). + * Called by native code via JNI. + */ + public ByteBuffer getBuffer() { + return dataBuf; + } + + /** + * Returns the length of the current block in bytes. 
+ * Called by native code via JNI. + */ + public int getCurrentBlockLength() { + return currentBlockLength; + } + + @Override + public void close() throws IOException { + if (!closed) { + closed = true; + inputStream.close(); + if (dataBuf.capacity() > INITIAL_BUFFER_SIZE) { + dataBuf = ByteBuffer.allocateDirect(INITIAL_BUFFER_SIZE); + } + } + } +} +``` + +- [ ] **Step 2: Verify it compiles** + +Run: `./mvnw compile -DskipTests` +Expected: BUILD SUCCESS + +- [ ] **Step 3: Commit** + +```bash +git add spark/src/main/java/org/apache/comet/CometShuffleBlockIterator.java +git commit -m "feat: add CometShuffleBlockIterator for raw shuffle block access" +``` + +--- + +### Task 4: Add JNI bridge for CometShuffleBlockIterator (Rust) + +**Files:** +- Create: `native/core/src/jvm_bridge/shuffle_block_iterator.rs` +- Modify: `native/core/src/jvm_bridge/mod.rs` + +- [ ] **Step 1: Create the JNI bridge struct** + +Create `native/core/src/jvm_bridge/shuffle_block_iterator.rs`: + +```rust +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use jni::signature::Primitive; +use jni::{ + errors::Result as JniResult, + objects::{JClass, JMethodID}, + signature::ReturnType, + JNIEnv, +}; + +/// JNI method IDs for `CometShuffleBlockIterator`. +#[allow(dead_code)] +pub struct CometShuffleBlockIterator<'a> { + pub class: JClass<'a>, + pub method_has_next: JMethodID, + pub method_has_next_ret: ReturnType, + pub method_get_buffer: JMethodID, + pub method_get_buffer_ret: ReturnType, + pub method_get_current_block_length: JMethodID, + pub method_get_current_block_length_ret: ReturnType, +} + +impl<'a> CometShuffleBlockIterator<'a> { + pub const JVM_CLASS: &'static str = "org/apache/comet/CometShuffleBlockIterator"; + + pub fn new(env: &mut JNIEnv<'a>) -> JniResult<CometShuffleBlockIterator<'a>> { + let class = env.find_class(Self::JVM_CLASS)?; + + Ok(CometShuffleBlockIterator { + class, + method_has_next: env.get_method_id(Self::JVM_CLASS, "hasNext", "()I")?, + method_has_next_ret: ReturnType::Primitive(Primitive::Int), + method_get_buffer: env.get_method_id( + Self::JVM_CLASS, + "getBuffer", + "()Ljava/nio/ByteBuffer;", + )?, + method_get_buffer_ret: ReturnType::Object, + method_get_current_block_length: env.get_method_id( + Self::JVM_CLASS, + "getCurrentBlockLength", + "()I", + )?, + method_get_current_block_length_ret: ReturnType::Primitive(Primitive::Int), + }) + } +} +``` + +- [ ] **Step 2: Register in mod.rs** + +In `native/core/src/jvm_bridge/mod.rs`: + +Add `mod shuffle_block_iterator;` alongside the existing `mod batch_iterator;` (line 174). + +Add `use shuffle_block_iterator::CometShuffleBlockIterator as CometShuffleBlockIteratorBridge;` (to avoid name collision with the operator). 
+ +Add a field to the `JVMClasses` struct (around line 206): +```rust +pub comet_shuffle_block_iterator: CometShuffleBlockIteratorBridge<'a>, +``` + +Initialize it in `JVMClasses::init` alongside the existing `comet_batch_iterator` init (around line 259): +```rust +comet_shuffle_block_iterator: CometShuffleBlockIteratorBridge::new(env).unwrap(), +``` + +- [ ] **Step 3: Add a `jni_call!` compatible accessor** + +Check how `comet_batch_iterator` is called in `scan.rs`. The `jni_call!` macro uses the field name from `JVMClasses`. Ensure `comet_shuffle_block_iterator` follows the same pattern. You may need to add a module in the `jni_bridge` macros — look at how `jni_call!(&mut env, comet_batch_iterator(iter).has_next() -> i32)` is defined and add equivalent patterns for `comet_shuffle_block_iterator`. + +Check `native/core/src/jvm_bridge/` for macro definitions (likely in a separate file or in `mod.rs`) that define the `jni_call!` dispatch for each class. + +- [ ] **Step 4: Verify it compiles** + +Run: `cd native && cargo build` +Expected: Successful build. + +- [ ] **Step 5: Commit** + +```bash +git add native/core/src/jvm_bridge/shuffle_block_iterator.rs +git add native/core/src/jvm_bridge/mod.rs +git commit -m "feat: add JNI bridge for CometShuffleBlockIterator" +``` + +--- + +### Task 5: Create ShuffleScanExec (Rust) + +**Files:** +- Create: `native/core/src/execution/operators/shuffle_scan.rs` +- Modify: `native/core/src/execution/operators/mod.rs` + +**Design decision — pre-pull pattern:** `ShuffleScanExec` MUST use the pre-pull pattern (same as `ScanExec`). The comment at `jni_api.rs:483-488` explains why: JNI calls cannot happen from within `poll_next` on tokio threads. So `ShuffleScanExec` stores a `batch: Arc<Mutex<Option<InputBatch>>>` and `get_next_batch()` is called from `pull_input_batches` before each `poll_next`. + +- [ ] **Step 1: Create shuffle_scan.rs** + +Use `scan.rs` as the template. 
The key differences: +- `get_next_batch` calls `hasNext()`/`getBuffer()`/`getCurrentBlockLength()` on `CometShuffleBlockIterator` instead of Arrow FFI methods on `CometBatchIterator` +- After getting the `DirectByteBuffer`, call `read_ipc_compressed()` to decode +- No `arrow_ffi_safe` flag, no selection vectors, no `copy_or_unpack_array` +- Track `decode_time` metric + +The core `get_next` method: + +```rust +fn get_next( + exec_context_id: i64, + iter: &JObject, + data_types: &[DataType], +) -> Result<InputBatch, CometError> { + let mut env = JVMClasses::get_env()?; + + // Call hasNext() — returns block length or -1 for EOF + let block_length: i32 = unsafe { + jni_call!(&mut env, comet_shuffle_block_iterator(iter).has_next() -> i32)? + }; + + if block_length < 0 { + return Ok(InputBatch::EOF); + } + + // Get the DirectByteBuffer + let buffer: JByteBuffer = unsafe { + jni_call!(&mut env, comet_shuffle_block_iterator(iter).get_buffer() -> JObject)? + }.into(); + + // Get raw pointer to the buffer data + let raw_pointer = env.get_direct_buffer_address(&buffer)?; + let length = block_length as usize; + let slice: &[u8] = unsafe { std::slice::from_raw_parts(raw_pointer, length) }; + + // Decompress and decode the IPC block + let batch = read_ipc_compressed(slice)?; + + // Convert RecordBatch columns to InputBatch + let arrays: Vec<ArrayRef> = batch.columns().to_vec(); + let num_rows = batch.num_rows(); + + Ok(InputBatch::new(arrays, Some(num_rows))) +} +``` + +For the `ExecutionPlan` trait implementation, follow `ScanExec` closely: +- `schema()` returns schema built from `data_types` +- `execute()` returns a `ScanStream` (reuse the same stream type from `scan.rs`) +- The `ScanStream` checks `self.batch` mutex on each `poll_next`, takes the batch if available + +- [ ] **Step 2: Register the module** + +In `native/core/src/execution/operators/mod.rs`, add: + +```rust +mod shuffle_scan; +pub use shuffle_scan::ShuffleScanExec; +``` + +- [ ] **Step 3: Verify it compiles** + +Run: `cd native && cargo build` 
+Expected: Successful build. + +- [ ] **Step 4: Commit** + +```bash +git add native/core/src/execution/operators/shuffle_scan.rs +git add native/core/src/execution/operators/mod.rs +git commit -m "feat: add ShuffleScanExec native operator for direct shuffle read" +``` + +--- + +### Task 6: Wire ShuffleScanExec into the native planner and pre-pull + +**Files:** +- Modify: `native/core/src/execution/planner.rs` +- Modify: `native/core/src/execution/jni_api.rs` + +**Design decision — separate scan vectors:** The planner's `create_plan` currently returns `(Vec<ScanExec>, Arc<SparkPlan>)`. Change the return type to include shuffle scans: `(Vec<ScanExec>, Vec<ShuffleScanExec>, Arc<SparkPlan>)`. All intermediate operators pass both vectors through. `ExecutionContext` gets a new `shuffle_scans: Vec<ShuffleScanExec>` field, and `pull_input_batches` iterates both. + +- [ ] **Step 1: Update create_plan return type** + +In `planner.rs`, change the `create_plan` return type (line 915): + +```rust +) -> Result<(Vec<ScanExec>, Vec<ShuffleScanExec>, Arc<SparkPlan>), ExecutionError> +``` + +Update every match arm that calls `create_plan` recursively or returns results: +- Single-child operators (Filter, Project, Sort, etc.): destructure as `let (scans, shuffle_scans, child) = ...` and pass both through +- Multi-child operators (joins via `parse_join_parameters`): concatenate both scan vectors from left and right children +- `Scan` arm: returns `(vec![scan.clone()], vec![], ...)` +- Add `ShuffleScan` arm (see step 2) + +This is a mechanical change across many match arms. Each `Ok((scans, ...))` becomes `Ok((scans, shuffle_scans, ...))`. + +Also update `parse_join_parameters` return type similarly. 
+ +- [ ] **Step 2: Add ShuffleScan match arm** + +```rust +OpStruct::ShuffleScan(scan) => { + let data_types = scan.fields.iter().map(to_arrow_datatype).collect_vec(); + + if self.exec_context_id != TEST_EXEC_CONTEXT_ID && inputs.is_empty() { + return Err(GeneralError("No input for shuffle scan".to_string())); + } + + let input_source = + if self.exec_context_id == TEST_EXEC_CONTEXT_ID && inputs.is_empty() { + None + } else { + Some(inputs.remove(0)) + }; + + let shuffle_scan = ShuffleScanExec::new( + self.exec_context_id, + input_source, + &scan.source, + data_types, + )?; + + Ok(( + vec![], + vec![shuffle_scan.clone()], + Arc::new(SparkPlan::new(spark_plan.plan_id, Arc::new(shuffle_scan), vec![])), + )) +} +``` + +- [ ] **Step 3: Update ExecutionContext and pull_input_batches** + +In `jni_api.rs`: + +Add `shuffle_scans: Vec<ShuffleScanExec>` field to `ExecutionContext` struct (after `scans` on line 153). Initialize as `shuffle_scans: vec![]` in the constructor (line 313). + +Where `create_plan` results are stored (line 542-550): + +```rust +let (scans, shuffle_scans, root_op) = planner.create_plan(...)?; +exec_context.scans = scans; +exec_context.shuffle_scans = shuffle_scans; +``` + +Update `pull_input_batches` (line 490): + +```rust +fn pull_input_batches(exec_context: &mut ExecutionContext) -> Result<(), CometError> { + exec_context.scans.iter_mut().try_for_each(|scan| { + scan.get_next_batch()?; + Ok::<(), CometError>(()) + })?; + exec_context.shuffle_scans.iter_mut().try_for_each(|scan| { + scan.get_next_batch()?; + Ok::<(), CometError>(()) + }) +} +``` + +Also update the `exec_context.scans.is_empty()` check (line 563) to also check `shuffle_scans`: + +```rust +if exec_context.scans.is_empty() && exec_context.shuffle_scans.is_empty() { +``` + +- [ ] **Step 4: Verify it compiles** + +Run: `cd native && cargo build` +Expected: Successful build. 
+ +- [ ] **Step 5: Commit** + +```bash +git add native/core/src/execution/planner.rs +git add native/core/src/execution/jni_api.rs +git commit -m "feat: wire ShuffleScanExec into planner and pre-pull mechanism" +``` + +--- + +### Task 7: Emit ShuffleScan from JVM serde + +**Files:** +- Modify: `spark/src/main/scala/org/apache/comet/serde/operator/CometSink.scala` + +The `CometExchangeSink.convert()` receives the outer operator (e.g., `ShuffleQueryStageExec`) not the inner `CometShuffleExchangeExec`. We must unwrap to check `shuffleType`. + +- [ ] **Step 1: Override convert in CometExchangeSink** + +Replace the `CometExchangeSink` object (lines 87-100) with: + +```scala +object CometExchangeSink extends CometSink[SparkPlan] { + + override def isFfiSafe: Boolean = true + + override def convert( + op: SparkPlan, + builder: Operator.Builder, + childOp: OperatorOuterClass.Operator*): Option[OperatorOuterClass.Operator] = { + if (shouldUseShuffleScan(op)) { + convertToShuffleScan(op, builder) + } else { + super.convert(op, builder, childOp: _*) + } + } + + private def shouldUseShuffleScan(op: SparkPlan): Boolean = { + if (!CometConf.COMET_SHUFFLE_DIRECT_READ_ENABLED.get()) return false + + // Extract the CometShuffleExchangeExec from the wrapper + val shuffleExec = op match { + case ShuffleQueryStageExec(_, s: CometShuffleExchangeExec, _) => Some(s) + case ShuffleQueryStageExec(_, ReusedExchangeExec(_, s: CometShuffleExchangeExec), _) => + Some(s) + case s: CometShuffleExchangeExec => Some(s) + case _ => None + } + + shuffleExec.exists(_.shuffleType == CometNativeShuffle) + } + + private def convertToShuffleScan( + op: SparkPlan, + builder: Operator.Builder): Option[OperatorOuterClass.Operator] = { + val supportedTypes = + op.output.forall(a => supportedDataType(a.dataType, allowComplex = true)) + + if (!supportedTypes) { + withInfo(op, "Unsupported data type for shuffle direct read") + return None + } + + val scanBuilder = OperatorOuterClass.ShuffleScan.newBuilder() + 
val source = op.simpleStringWithNodeId() + if (source.isEmpty) { + scanBuilder.setSource(op.getClass.getSimpleName) + } else { + scanBuilder.setSource(source) + } + + val scanTypes = op.output.flatMap { attr => + serializeDataType(attr.dataType) + } + + if (scanTypes.length == op.output.length) { + scanBuilder.addAllFields(scanTypes.asJava) + builder.clearChildren() + Some(builder.setShuffleScan(scanBuilder).build()) + } else { + withInfo(op, "unsupported data types for shuffle direct read") + // Fall back to regular Scan + None + } + } + + override def createExec(nativeOp: Operator, op: SparkPlan): CometNativeExec = + CometSinkPlaceHolder(nativeOp, op, op) +} +``` + +Add necessary imports at the top of the file: +```scala +import org.apache.spark.sql.comet.execution.shuffle.{CometNativeShuffle, CometShuffleExchangeExec} +import org.apache.spark.sql.execution.exchange.{ReusedExchangeExec, ShuffleExchangeExec} +import org.apache.spark.sql.execution.adaptive.ShuffleQueryStageExec +import org.apache.comet.CometConf +``` + +- [ ] **Step 2: Verify it compiles** + +Run: `./mvnw compile -DskipTests` +Expected: BUILD SUCCESS + +- [ ] **Step 3: Commit** + +```bash +git add spark/src/main/scala/org/apache/comet/serde/operator/CometSink.scala +git commit -m "feat: emit ShuffleScan protobuf for native shuffle with direct read" +``` + +--- + +### Task 8: Wire CometShuffleBlockIterator into JVM execution path + +**Files:** +- Modify: `spark/src/main/scala/org/apache/comet/Native.scala` +- Modify: `spark/src/main/scala/org/apache/comet/CometExecIterator.scala` +- Modify: `spark/src/main/scala/org/apache/spark/sql/comet/CometExecRDD.scala` +- Modify: `spark/src/main/scala/org/apache/spark/sql/comet/operators.scala` +- Modify: `spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometBlockStoreShuffleReader.scala` + +This task connects the JVM plumbing so that `ShuffleScan` inputs get `CometShuffleBlockIterator` (wrapping raw `InputStream`) instead of 
`CometBatchIterator` (wrapping decoded `ColumnarBatch`). + +**Key insight**: Currently all inputs flow through `RDD[ColumnarBatch]`. For shuffle direct read, we need the raw `InputStream` before decoding. The approach: add a parallel input channel for raw shuffle streams alongside the existing `ColumnarBatch` inputs. + +- [ ] **Step 1: Change Native.scala createPlan signature** + +In `spark/src/main/scala/org/apache/comet/Native.scala` (line 57), change: + +```scala +iterators: Array[CometBatchIterator], +``` +to: +```scala +iterators: Array[Object], +``` + +The JNI side (`jni_api.rs:190`) already uses `JObjectArray`, so no Rust changes needed. + +- [ ] **Step 2: Add shuffle stream inputs to CometExecIterator** + +In `spark/src/main/scala/org/apache/comet/CometExecIterator.scala`, add a parameter for shuffle block iterators that should be used instead of regular batch iterators at specific input positions: + +```scala +class CometExecIterator( + val id: Long, + inputs: Seq[Iterator[ColumnarBatch]], + numOutputCols: Int, + protobufQueryPlan: Array[Byte], + nativeMetrics: CometMetricNode, + numParts: Int, + partitionIndex: Int, + broadcastedHadoopConfForEncryption: Option[Broadcast[SerializableConfiguration]] = None, + encryptedFilePaths: Seq[String] = Seq.empty, + shuffleBlockIterators: Map[Int, CometShuffleBlockIterator] = Map.empty) +``` + +Replace the `cometBatchIterators` construction (lines 81-83): + +```scala +private val nativeIterators: Array[Object] = { + val result = new Array[Object](inputs.size) + inputs.zipWithIndex.foreach { case (iterator, idx) => + result(idx) = shuffleBlockIterators.getOrElse( + idx, + new CometBatchIterator(iterator, nativeUtil)) + } + result +} +``` + +Change `nativeLib.createPlan(id, cometBatchIterators, ...)` (line 109) to use `nativeIterators`. 
+ +In the `close()` method, also close `CometShuffleBlockIterator` instances: +```scala +shuffleBlockIterators.values.foreach { iter => + try { iter.close() } catch { case _: Exception => } +} +``` + +- [ ] **Step 3: Add shuffle stream support to CometExecRDD** + +In `spark/src/main/scala/org/apache/spark/sql/comet/CometExecRDD.scala`, add a parameter to carry shuffle block iterator factories: + +```scala +private[spark] class CometExecRDD( + sc: SparkContext, + var inputRDDs: Seq[RDD[ColumnarBatch]], + ... + encryptedFilePaths: Seq[String] = Seq.empty, + shuffleBlockIteratorFactories: Map[Int, (TaskContext, Partition) => CometShuffleBlockIterator] = Map.empty) +``` + +In the `compute` method (line 112), pass them to `CometExecIterator`: + +```scala +// Create shuffle block iterators for this partition +val shuffleBlockIters = shuffleBlockIteratorFactories.map { case (idx, factory) => + idx -> factory(context, partition.inputPartitions(idx)) +} + +val it = new CometExecIterator( + CometExec.newIterId, + inputs, + numOutputCols, + actualPlan, + nativeMetrics, + numPartitions, + partition.index, + broadcastedHadoopConfForEncryption, + encryptedFilePaths, + shuffleBlockIters) +``` + +- [ ] **Step 4: Identify ShuffleScan inputs in operators.scala** + +In `spark/src/main/scala/org/apache/spark/sql/comet/operators.scala`, in `CometNativeExec.doExecuteColumnar` (around line 480): + +After `foreachUntilCometInput(this)(sparkPlans += _)`, determine which inputs correspond to `ShuffleScan` operators. 
Parse the serialized protobuf plan to find `ShuffleScan` leaf positions:

```scala
import org.apache.comet.serde.OperatorOuterClass

// Find which input indices correspond to ShuffleScan operators
val shuffleScanIndices: Set[Int] = {
  val plan = OperatorOuterClass.Operator.parseFrom(serializedPlanCopy)
  var scanIndex = 0
  val indices = scala.collection.mutable.Set.empty[Int]
  def walk(op: OperatorOuterClass.Operator): Unit = {
    if (op.hasShuffleScan) {
      indices += scanIndex
      scanIndex += 1
    } else if (op.hasScan) {
      scanIndex += 1
    } else {
      // Recurse into children in order
      (0 until op.getChildrenCount).foreach(i => walk(op.getChildren(i)))
    }
  }
  walk(plan)
  indices.toSet
}
```

> NOTE(review): this walk assumes its depth-first, left-to-right traversal visits
> `Scan`/`ShuffleScan` leaves in exactly the order the planner binds input iterators —
> confirm against the native planner before relying on the index scheme.

Then in the `sparkPlans.zipWithIndex.foreach` loop (line 523), for plans at shuffle scan indices, create a factory that produces `CometShuffleBlockIterator`:

```scala
val shuffleBlockIteratorFactories = scala.collection.mutable.Map.empty[Int, (TaskContext, Partition) => CometShuffleBlockIterator]

sparkPlans.zipWithIndex.foreach { case (plan, idx) =>
  plan match {
    // ... existing cases ...
    case _ if shuffleScanIndices.contains(inputIndexForPlan(idx)) =>
      // Still add the RDD for partition tracking, but also register
      // a factory for the raw InputStream
      val rdd = plan.executeColumnar()
      inputs += rdd
      // The factory creates a CometShuffleBlockIterator from the raw shuffle stream
      // We need to get the raw InputStream - see Step 5
      shuffleBlockIteratorFactories(inputs.size - 1) = ...
      // ... remaining cases ...
  }
}
```

> NOTE(review): `inputIndexForPlan` is not defined anywhere in this plan — if each entry of
> `sparkPlans` corresponds to exactly one input, use `idx` directly; otherwise define the
> mapping explicitly before this step. Also, the `// ... remaining cases ...` comment is
> placed inside the `case` body above; the remaining match cases belong outside it.

The tricky part is getting the raw `InputStream` from the shuffle read. See Step 5.

- [ ] **Step 5: Add raw InputStream mode to CometBlockStoreShuffleReader**

In `spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometBlockStoreShuffleReader.scala`:

The current `read()` method creates `NativeBatchDecoderIterator` which decodes blocks. 
For direct read, we need a mode that yields the raw `InputStream` wrapped in `CometShuffleBlockIterator`.

Add a method:

```scala
def readRawStreams(): Iterator[CometShuffleBlockIterator] = {
  fetchIterator.map { case (_, inputStream) =>
    new CometShuffleBlockIterator(inputStream)
  }
}
```

The challenge is that `CometShuffledBatchRDD` calls `reader.read()` which returns `Iterator[Product2[Int, ColumnarBatch]]`. For the direct read path, we need a different RDD that calls `readRawStreams()` instead.

**Approach**: Create `CometShuffledRawStreamRDD` — a simple RDD that wraps the shuffle reader and yields `CometShuffleBlockIterator` objects per block. Then in `operators.scala`, instead of using the ColumnarBatch RDD, create a `CometShuffledRawStreamRDD` and pass its iterator-producing factory to `CometExecRDD`.

Alternatively, since `CometShuffleBlockIterator` wraps a single `InputStream` that may contain multiple blocks, and `fetchIterator` yields one `InputStream` per shuffle block, the simplest approach is to **concatenate all InputStreams into one** per partition:

```scala
def readAsRawStream(): InputStream = {
  val streams = fetchIterator.map(_._2)
  new SequenceInputStream(java.util.Collections.enumeration(
    streams.toList.asJava))
}
```

> NOTE(review): `streams.toList` forces `fetchIterator` eagerly, opening every fetched block
> stream for the partition before the first byte is consumed, which defeats the incremental
> fetch/backpressure of the shuffle fetch iterator. Prefer a lazy `java.util.Enumeration`
> that pulls from `fetchIterator` in `nextElement()`, and confirm that letting
> `SequenceInputStream` close each exhausted sub-stream is compatible with Spark's
> shuffle-fetch stream lifecycle.

Then in the factory: `(ctx, part) => new CometShuffleBlockIterator(reader.readAsRawStream())`

But the reader is created per-partition in `CometShuffledBatchRDD.compute()`. The factory approach means the reader creation must be deferred.

**Simplest concrete approach**: Instead of a factory, create a new RDD `CometShuffledRawRDD` that returns `Iterator[CometShuffleBlockIterator]`. Pass this as a separate input alongside the regular `ColumnarBatch` inputs:

```scala
// In CometExecRDD, add:
shuffleRawInputRDDs: Seq[(Int, RDD[CometShuffleBlockIterator])]
```

In `compute`, create iterators from these RDDs and pass them to `CometExecIterator` via the `shuffleBlockIterators` map. 
+ +This is the most invasive part of the implementation. The exact approach should be determined by reading the code at implementation time, as there are multiple valid paths. The key constraint: the raw `InputStream` from `fetchIterator` must reach `CometShuffleBlockIterator` without going through `NativeBatchDecoderIterator`. + +- [ ] **Step 6: Verify it compiles** + +Run: `./mvnw compile -DskipTests` +Expected: BUILD SUCCESS + +- [ ] **Step 7: Commit** + +```bash +git add spark/src/main/scala/org/apache/comet/Native.scala +git add spark/src/main/scala/org/apache/comet/CometExecIterator.scala +git add spark/src/main/scala/org/apache/spark/sql/comet/CometExecRDD.scala +git add spark/src/main/scala/org/apache/spark/sql/comet/operators.scala +git add spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometBlockStoreShuffleReader.scala +git commit -m "feat: wire CometShuffleBlockIterator into JVM execution path" +``` + +--- + +### Task 9: End-to-end testing + +**Files:** +- Modify: Appropriate test suite (find the right suite by searching for existing shuffle tests) + +- [ ] **Step 1: Build everything** + +Run: `make` +Expected: Successful build of both native and JVM. + +- [ ] **Step 2: Run existing shuffle tests** + +Run: `./mvnw test -Dsuites="org.apache.comet.exec.CometShuffleSuite"` +Expected: All existing tests pass (they now use the new direct read path by default). + +If tests fail, debug by setting `spark.comet.shuffle.directRead.enabled=false` to confirm the old path still works, then investigate the new path. 
+ +- [ ] **Step 3: Add comparison test** + +Add a test that runs the same queries with direct read enabled and disabled: + +```scala +test("shuffle direct read produces same results as FFI path") { + Seq(true, false).foreach { directRead => + withSQLConf( + CometConf.COMET_SHUFFLE_DIRECT_READ_ENABLED.key -> directRead.toString) { + val df = spark.range(1000) + .selectExpr("id", "id % 10 as key", "cast(id as string) as value") + .repartition(4, col("key")) + .groupBy("key") + .agg(sum("id").as("total"), count("value").as("cnt")) + .orderBy("key") + checkSparkAnswer(df) + } + } +} +``` + +- [ ] **Step 4: Add Rust unit test for ShuffleScanExec** + +In `native/core/src/execution/operators/shuffle_scan.rs`, add a `#[cfg(test)]` module: + +```rust +#[cfg(test)] +mod tests { + use super::*; + use crate::execution::shuffle::codec::{CompressionCodec, ShuffleBlockWriter}; + use arrow::array::{Int32Array, StringArray}; + use arrow::datatypes::{Field, Schema}; + use arrow::record_batch::RecordBatch; + use std::io::Cursor; + use std::sync::Arc; + + #[test] + fn test_read_compressed_ipc_block() { + // Create a test RecordBatch + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("name", DataType::Utf8, true), + ])); + let batch = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(Int32Array::from(vec![1, 2, 3])), + Arc::new(StringArray::from(vec!["a", "b", "c"])), + ], + ).unwrap(); + + // Write it as compressed IPC using ShuffleBlockWriter + let writer = ShuffleBlockWriter::try_new( + &batch.schema(), CompressionCodec::Zstd(1) + ).unwrap(); + let mut buf = Cursor::new(Vec::new()); + let ipc_time = datafusion::physical_plan::metrics::Time::new(); + writer.write_batch(&batch, &mut buf, &ipc_time).unwrap(); + + // Read back the body (skip the 16-byte header) + let bytes = buf.into_inner(); + let body = &bytes[16..]; // Skip compressed_length(8) + field_count(8) + + // Decode using read_ipc_compressed + let decoded = 
read_ipc_compressed(body).unwrap(); + assert_eq!(decoded.num_rows(), 3); + assert_eq!(decoded.num_columns(), 2); + } +} +``` + +- [ ] **Step 5: Run all tests** + +Run: `make test` + +- [ ] **Step 6: Run clippy** + +Run: `cd native && cargo clippy --all-targets --workspace -- -D warnings` +Expected: No warnings. + +- [ ] **Step 7: Format** + +Run: `make format` + +- [ ] **Step 8: Commit** + +```bash +git add -A +git commit -m "test: add shuffle direct read tests" +``` + +--- + +## Implementation Notes + +### Task 8 is the hardest + +The core challenge is routing raw `InputStream` from Spark's shuffle infrastructure through to `CometShuffleBlockIterator` without going through the decode path. The current RDD pipeline (`CometShuffledBatchRDD` → `CometBlockStoreShuffleReader.read()` → `NativeBatchDecoderIterator`) always decodes. You need to intercept before `NativeBatchDecoderIterator` is created. + +The most surgical approach: in `CometBlockStoreShuffleReader`, add a `readRaw()` method that returns the raw `InputStream` (or a `CometShuffleBlockIterator` wrapping it) instead of decoded batches. Then create a parallel RDD (`CometShuffledRawRDD`) that calls `readRaw()` in its `compute` method and pass it through to `CometExecIterator`. + +### Metrics + +`ShuffleScanExec` should track `decode_time` using DataFusion's `Time` metric. Register it in `ShuffleScanExec::new` via `MetricBuilder` following the pattern in `ScanExec`. + +### Order of tasks + +Tasks 1-7 can be done sequentially. Task 8 depends on all previous tasks. Task 9 validates everything. 
From cb2fe12a887edd863a3b707caa2280500150ee7f Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Wed, 18 Mar 2026 15:30:11 -0600 Subject: [PATCH 03/33] feat: add spark.comet.shuffle.directRead.enabled config --- .../src/main/scala/org/apache/comet/CometConf.scala | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/common/src/main/scala/org/apache/comet/CometConf.scala b/common/src/main/scala/org/apache/comet/CometConf.scala index 4d2e37924a..ad3774567c 100644 --- a/common/src/main/scala/org/apache/comet/CometConf.scala +++ b/common/src/main/scala/org/apache/comet/CometConf.scala @@ -343,6 +343,17 @@ object CometConf extends ShimCometConf { .booleanConf .createWithDefault(true) + val COMET_SHUFFLE_DIRECT_READ_ENABLED: ConfigEntry[Boolean] = + conf("spark.comet.shuffle.directRead.enabled") + .category(CATEGORY_EXEC) + .doc( + "When enabled, native operators that consume shuffle output will read " + + "compressed shuffle blocks directly in native code, bypassing Arrow FFI. " + + "Only applies to native shuffle (not JVM columnar shuffle). 
" + + "Requires spark.comet.exec.shuffle.enabled to be true.") + .booleanConf + .createWithDefault(true) + val COMET_SHUFFLE_MODE: ConfigEntry[String] = conf(s"$COMET_EXEC_CONFIG_PREFIX.shuffle.mode") .category(CATEGORY_SHUFFLE) .doc( From 191bbe1a32033ca2664baf274340f2862d9d6597 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Wed, 18 Mar 2026 15:31:29 -0600 Subject: [PATCH 04/33] feat: add ShuffleScan protobuf message --- native/core/src/execution/planner/operator_registry.rs | 1 + native/proto/src/proto/operator.proto | 7 +++++++ 2 files changed, 8 insertions(+) diff --git a/native/core/src/execution/planner/operator_registry.rs b/native/core/src/execution/planner/operator_registry.rs index b34a80df95..e20624b6c9 100644 --- a/native/core/src/execution/planner/operator_registry.rs +++ b/native/core/src/execution/planner/operator_registry.rs @@ -153,5 +153,6 @@ fn get_operator_type(spark_operator: &Operator) -> Option { OpStruct::Window(_) => Some(OperatorType::Window), OpStruct::Explode(_) => None, // Not yet in OperatorType enum OpStruct::CsvScan(_) => Some(OperatorType::CsvScan), + OpStruct::ShuffleScan(_) => None, // Not yet in OperatorType enum } } diff --git a/native/proto/src/proto/operator.proto b/native/proto/src/proto/operator.proto index 4afc1fefb7..344b9f0f21 100644 --- a/native/proto/src/proto/operator.proto +++ b/native/proto/src/proto/operator.proto @@ -52,6 +52,7 @@ message Operator { ParquetWriter parquet_writer = 113; Explode explode = 114; CsvScan csv_scan = 115; + ShuffleScan shuffle_scan = 116; } } @@ -85,6 +86,12 @@ message Scan { bool arrow_ffi_safe = 3; } +message ShuffleScan { + repeated spark.spark_expression.DataType fields = 1; + // Informational label for debug output (e.g., "CometShuffleExchangeExec [id=5]") + string source = 2; +} + // Common data shared by all partitions in split mode (sent once at planning) message NativeScanCommon { repeated SparkStructField required_schema = 1; From 7ac1d93595607970d038e2858d68f584a3aed3c3 Mon Sep 
17 00:00:00 2001 From: Andy Grove Date: Wed, 18 Mar 2026 15:33:22 -0600 Subject: [PATCH 05/33] feat: add CometShuffleBlockIterator for raw shuffle block access --- .../comet/CometShuffleBlockIterator.java | 141 ++++++++++++++++++ 1 file changed, 141 insertions(+) create mode 100644 spark/src/main/java/org/apache/comet/CometShuffleBlockIterator.java diff --git a/spark/src/main/java/org/apache/comet/CometShuffleBlockIterator.java b/spark/src/main/java/org/apache/comet/CometShuffleBlockIterator.java new file mode 100644 index 0000000000..5de5e05c4e --- /dev/null +++ b/spark/src/main/java/org/apache/comet/CometShuffleBlockIterator.java @@ -0,0 +1,141 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet; + +import java.io.Closeable; +import java.io.EOFException; +import java.io.IOException; +import java.io.InputStream; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.nio.channels.Channels; +import java.nio.channels.ReadableByteChannel; + +/** + * Provides raw compressed shuffle blocks to native code via JNI. + * + *

Reads block headers (compressed length + field count) from a shuffle InputStream and loads the + * compressed body into a DirectByteBuffer. Native code pulls blocks by calling hasNext() and + * getBuffer(). + * + *

The DirectByteBuffer returned by getBuffer() is only valid until the next hasNext() call. + * Native code must fully consume it (via read_ipc_compressed which allocates new memory for the + * decompressed data) before pulling the next block. + */ +public class CometShuffleBlockIterator implements Closeable { + + private static final int INITIAL_BUFFER_SIZE = 128 * 1024; + + private final ReadableByteChannel channel; + private final InputStream inputStream; + private final ByteBuffer headerBuf = ByteBuffer.allocate(16).order(ByteOrder.LITTLE_ENDIAN); + private ByteBuffer dataBuf = ByteBuffer.allocateDirect(INITIAL_BUFFER_SIZE); + private boolean closed = false; + private int currentBlockLength = 0; + + public CometShuffleBlockIterator(InputStream in) { + this.inputStream = in; + this.channel = Channels.newChannel(in); + } + + /** + * Reads the next block header and loads the compressed body into the internal buffer. Called by + * native code via JNI. + * + *

Header format: 8-byte compressedLength (includes field count but not itself) + 8-byte + * fieldCount (discarded, schema comes from protobuf). + * + * @return the compressed body length in bytes (codec prefix + compressed IPC), or -1 if EOF + */ + public int hasNext() throws IOException { + if (closed) { + return -1; + } + + // Read 16-byte header + headerBuf.clear(); + while (headerBuf.hasRemaining()) { + int bytesRead = channel.read(headerBuf); + if (bytesRead < 0) { + if (headerBuf.position() == 0) { + return -1; + } + throw new EOFException("Data corrupt: unexpected EOF while reading batch header"); + } + } + headerBuf.flip(); + long compressedLength = headerBuf.getLong(); + // Field count discarded - schema determined by ShuffleScan protobuf fields + headerBuf.getLong(); + + long bytesToRead = compressedLength - 8; + if (bytesToRead > Integer.MAX_VALUE) { + throw new IllegalStateException( + "Native shuffle block size of " + + bytesToRead + + " exceeds maximum of " + + Integer.MAX_VALUE + + ". Try reducing shuffle batch size."); + } + + if (dataBuf.capacity() < bytesToRead) { + int newCapacity = (int) Math.min(bytesToRead * 2L, Integer.MAX_VALUE); + dataBuf = ByteBuffer.allocateDirect(newCapacity); + } + + dataBuf.clear(); + dataBuf.limit((int) bytesToRead); + while (dataBuf.hasRemaining()) { + int bytesRead = channel.read(dataBuf); + if (bytesRead < 0) { + throw new EOFException("Data corrupt: unexpected EOF while reading compressed batch"); + } + } + // Note: native side uses get_direct_buffer_address (base pointer) + currentBlockLength, + // not the buffer's position/limit. No flip needed. + + currentBlockLength = (int) bytesToRead; + return currentBlockLength; + } + + /** + * Returns the DirectByteBuffer containing the current block's compressed bytes (4-byte codec + * prefix + compressed IPC data). Called by native code via JNI. + */ + public ByteBuffer getBuffer() { + return dataBuf; + } + + /** Returns the length of the current block in bytes. 
Called by native code via JNI. */ + public int getCurrentBlockLength() { + return currentBlockLength; + } + + @Override + public void close() throws IOException { + if (!closed) { + closed = true; + inputStream.close(); + if (dataBuf.capacity() > INITIAL_BUFFER_SIZE) { + dataBuf = ByteBuffer.allocateDirect(INITIAL_BUFFER_SIZE); + } + } + } +} From 98bab7348af98282241570056a0a1234b9363da9 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Wed, 18 Mar 2026 15:34:38 -0600 Subject: [PATCH 06/33] feat: add JNI bridge for CometShuffleBlockIterator --- native/core/src/jvm_bridge/mod.rs | 5 ++ .../src/jvm_bridge/shuffle_block_iterator.rs | 56 +++++++++++++++++++ 2 files changed, 61 insertions(+) create mode 100644 native/core/src/jvm_bridge/shuffle_block_iterator.rs diff --git a/native/core/src/jvm_bridge/mod.rs b/native/core/src/jvm_bridge/mod.rs index 00fe7b33c3..85c2ae7577 100644 --- a/native/core/src/jvm_bridge/mod.rs +++ b/native/core/src/jvm_bridge/mod.rs @@ -174,11 +174,13 @@ pub use comet_exec::*; mod batch_iterator; mod comet_metric_node; mod comet_task_memory_manager; +mod shuffle_block_iterator; use crate::{errors::CometError, JAVA_VM}; use batch_iterator::CometBatchIterator; pub use comet_metric_node::*; pub use comet_task_memory_manager::*; +use shuffle_block_iterator::CometShuffleBlockIterator; /// The JVM classes that are used in the JNI calls. #[allow(dead_code)] // we need to keep references to Java items to prevent GC @@ -204,6 +206,8 @@ pub struct JVMClasses<'a> { pub comet_exec: CometExec<'a>, /// The CometBatchIterator class. Used for iterating over the batches. pub comet_batch_iterator: CometBatchIterator<'a>, + /// The CometShuffleBlockIterator class. Used for iterating over shuffle blocks. + pub comet_shuffle_block_iterator: CometShuffleBlockIterator<'a>, /// The CometTaskMemoryManager used for interacting with JVM side to /// acquire & release native memory. 
pub comet_task_memory_manager: CometTaskMemoryManager<'a>, @@ -257,6 +261,7 @@ impl JVMClasses<'_> { comet_metric_node: CometMetricNode::new(env).unwrap(), comet_exec: CometExec::new(env).unwrap(), comet_batch_iterator: CometBatchIterator::new(env).unwrap(), + comet_shuffle_block_iterator: CometShuffleBlockIterator::new(env).unwrap(), comet_task_memory_manager: CometTaskMemoryManager::new(env).unwrap(), } }); diff --git a/native/core/src/jvm_bridge/shuffle_block_iterator.rs b/native/core/src/jvm_bridge/shuffle_block_iterator.rs new file mode 100644 index 0000000000..02fcf8ca27 --- /dev/null +++ b/native/core/src/jvm_bridge/shuffle_block_iterator.rs @@ -0,0 +1,56 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use jni::signature::Primitive; +use jni::{ + errors::Result as JniResult, + objects::{JClass, JMethodID}, + signature::ReturnType, + JNIEnv, +}; + +/// A struct that holds all the JNI methods and fields for JVM `CometShuffleBlockIterator` class. 
+#[allow(dead_code)] // we need to keep references to Java items to prevent GC +pub struct CometShuffleBlockIterator<'a> { + pub class: JClass<'a>, + pub method_has_next: JMethodID, + pub method_has_next_ret: ReturnType, + pub method_get_buffer: JMethodID, + pub method_get_buffer_ret: ReturnType, + pub method_get_current_block_length: JMethodID, + pub method_get_current_block_length_ret: ReturnType, +} + +impl<'a> CometShuffleBlockIterator<'a> { + pub const JVM_CLASS: &'static str = "org/apache/comet/CometShuffleBlockIterator"; + + pub fn new(env: &mut JNIEnv<'a>) -> JniResult> { + let class = env.find_class(Self::JVM_CLASS)?; + + Ok(CometShuffleBlockIterator { + class, + method_has_next: env.get_method_id(Self::JVM_CLASS, "hasNext", "()I")?, + method_has_next_ret: ReturnType::Primitive(Primitive::Int), + method_get_buffer: env + .get_method_id(Self::JVM_CLASS, "getBuffer", "()Ljava/nio/ByteBuffer;")?, + method_get_buffer_ret: ReturnType::Object, + method_get_current_block_length: env + .get_method_id(Self::JVM_CLASS, "getCurrentBlockLength", "()I")?, + method_get_current_block_length_ret: ReturnType::Primitive(Primitive::Int), + }) + } +} From c01cf1d47f1d12d2d1a6222d7468a3be9a176c96 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Wed, 18 Mar 2026 15:58:35 -0600 Subject: [PATCH 07/33] feat: add ShuffleScanExec native operator for direct shuffle read Add a new ShuffleScanExec operator that pulls compressed shuffle blocks from JVM via CometShuffleBlockIterator and decodes them natively using read_ipc_compressed(). Uses the pre-pull pattern (get_next_batch called externally before poll_next) to avoid JNI calls on tokio threads. 
--- native/core/src/execution/operators/mod.rs | 2 + .../src/execution/operators/shuffle_scan.rs | 348 ++++++++++++++++++ 2 files changed, 350 insertions(+) create mode 100644 native/core/src/execution/operators/shuffle_scan.rs diff --git a/native/core/src/execution/operators/mod.rs b/native/core/src/execution/operators/mod.rs index 07ee995367..ad3ec3f08b 100644 --- a/native/core/src/execution/operators/mod.rs +++ b/native/core/src/execution/operators/mod.rs @@ -34,7 +34,9 @@ pub use parquet_writer::ParquetWriterExec; mod csv_scan; pub mod projection; mod scan; +mod shuffle_scan; pub use csv_scan::init_csv_datasource_exec; +pub use shuffle_scan::ShuffleScanExec; /// Error returned during executing operators. #[derive(thiserror::Error, Debug)] diff --git a/native/core/src/execution/operators/shuffle_scan.rs b/native/core/src/execution/operators/shuffle_scan.rs new file mode 100644 index 0000000000..4a8d09111b --- /dev/null +++ b/native/core/src/execution/operators/shuffle_scan.rs @@ -0,0 +1,348 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use crate::{ + errors::CometError, + execution::{ + operators::ExecutionError, planner::TEST_EXEC_CONTEXT_ID, + shuffle::codec::read_ipc_compressed, + }, + jvm_bridge::{jni_call, JVMClasses}, +}; +use arrow::array::ArrayRef; +use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; +use datafusion::common::{arrow_datafusion_err, Result as DataFusionResult}; +use datafusion::physical_plan::execution_plan::{Boundedness, EmissionType}; +use datafusion::physical_plan::metrics::{ + BaselineMetrics, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet, Time, +}; +use datafusion::{ + execution::TaskContext, + physical_expr::*, + physical_plan::{ExecutionPlan, *}, +}; +use futures::Stream; +use jni::objects::{GlobalRef, JByteBuffer, JObject}; +use std::{ + any::Any, + pin::Pin, + sync::{Arc, Mutex}, + task::{Context, Poll}, +}; + +use super::scan::InputBatch; + +/// ShuffleScanExec reads compressed shuffle blocks from JVM via JNI and decodes them natively. +/// Unlike ScanExec which receives Arrow arrays via FFI, ShuffleScanExec receives raw compressed +/// bytes from CometShuffleBlockIterator and decodes them using read_ipc_compressed(). +#[derive(Debug, Clone)] +pub struct ShuffleScanExec { + /// The ID of the execution context that owns this subquery. + pub exec_context_id: i64, + /// The input source: a global reference to a JVM CometShuffleBlockIterator object. + pub input_source: Option>, + /// The data types of columns in the shuffle output. + pub data_types: Vec, + /// Schema of the shuffle output. + pub schema: SchemaRef, + /// The current input batch, populated by get_next_batch() before poll_next(). + pub batch: Arc>>, + /// Cache of plan properties. + cache: PlanProperties, + /// Metrics collector. + metrics: ExecutionPlanMetricsSet, + /// Baseline metrics. + baseline_metrics: BaselineMetrics, + /// Time spent decoding compressed shuffle blocks. 
+ decode_time: Time, +} + +impl ShuffleScanExec { + pub fn new( + exec_context_id: i64, + input_source: Option>, + data_types: Vec, + ) -> Result { + let metrics_set = ExecutionPlanMetricsSet::default(); + let baseline_metrics = BaselineMetrics::new(&metrics_set, 0); + let decode_time = MetricBuilder::new(&metrics_set).subset_time("decode_time", 0); + + let schema = schema_from_data_types(&data_types); + + let cache = PlanProperties::new( + EquivalenceProperties::new(Arc::clone(&schema)), + Partitioning::UnknownPartitioning(1), + EmissionType::Final, + Boundedness::Bounded, + ); + + Ok(Self { + exec_context_id, + input_source, + data_types, + batch: Arc::new(Mutex::new(None)), + cache, + metrics: metrics_set, + baseline_metrics, + schema, + decode_time, + }) + } + + /// Feeds input batch into this scan. Only used in unit tests. + pub fn set_input_batch(&mut self, input: InputBatch) { + *self.batch.try_lock().unwrap() = Some(input); + } + + /// Pull next input batch from JVM. Called externally before poll_next() + /// because JNI calls cannot happen from within poll_next on tokio threads. + pub fn get_next_batch(&mut self) -> Result<(), CometError> { + if self.input_source.is_none() { + // Unit test mode - no JNI calls needed. + return Ok(()); + } + let mut timer = self.baseline_metrics.elapsed_compute().timer(); + + let mut current_batch = self.batch.try_lock().unwrap(); + if current_batch.is_none() { + let next_batch = Self::get_next( + self.exec_context_id, + self.input_source.as_ref().unwrap().as_obj(), + &self.data_types, + &self.decode_time, + )?; + *current_batch = Some(next_batch); + } + + timer.stop(); + + Ok(()) + } + + /// Invokes JNI calls to get the next compressed shuffle block and decode it. 
+ fn get_next( + exec_context_id: i64, + iter: &JObject, + data_types: &[DataType], + decode_time: &Time, + ) -> Result { + if exec_context_id == TEST_EXEC_CONTEXT_ID { + return Ok(InputBatch::EOF); + } + + if iter.is_null() { + return Err(CometError::from(ExecutionError::GeneralError(format!( + "Null shuffle block iterator object. Plan id: {exec_context_id}" + )))); + } + + let mut env = JVMClasses::get_env()?; + + // has_next() returns block length or -1 if no more blocks + let block_length: i32 = unsafe { + jni_call!(&mut env, + comet_shuffle_block_iterator(iter).has_next() -> i32)? + }; + + if block_length == -1 { + return Ok(InputBatch::EOF); + } + + // Get the DirectByteBuffer containing the compressed shuffle block + let buffer: JObject = unsafe { + jni_call!(&mut env, + comet_shuffle_block_iterator(iter).get_buffer() -> JObject)? + }; + + // Get the actual block length (may differ from has_next return value) + let length: i32 = unsafe { + jni_call!(&mut env, + comet_shuffle_block_iterator(iter).get_current_block_length() -> i32)? + }; + + let byte_buffer = JByteBuffer::from(buffer); + let raw_pointer = env.get_direct_buffer_address(&byte_buffer)?; + let length = length as usize; + let slice: &[u8] = unsafe { std::slice::from_raw_parts(raw_pointer, length) }; + + // Decode the compressed IPC data + let mut timer = decode_time.timer(); + let batch = read_ipc_compressed(slice)?; + timer.stop(); + + let num_rows = batch.num_rows(); + + // The read_ipc_compressed already produces owned arrays, so we skip the + // header (field count + codec) that was already consumed by read_ipc_compressed. + // Extract column arrays from the RecordBatch. 
+ let columns: Vec = batch.columns().to_vec(); + + debug_assert_eq!( + columns.len(), + data_types.len(), + "Shuffle block column count mismatch: got {} but expected {}", + columns.len(), + data_types.len() + ); + + Ok(InputBatch::new(columns, Some(num_rows))) + } +} + +fn schema_from_data_types(data_types: &[DataType]) -> SchemaRef { + let fields = data_types + .iter() + .enumerate() + .map(|(idx, dt)| Field::new(format!("col_{idx}"), dt.clone(), true)) + .collect::>(); + + Arc::new(Schema::new(fields)) +} + +impl ExecutionPlan for ShuffleScanExec { + fn as_any(&self) -> &dyn Any { + self + } + + fn schema(&self) -> SchemaRef { + Arc::clone(&self.schema) + } + + fn children(&self) -> Vec<&Arc> { + vec![] + } + + fn with_new_children( + self: Arc, + _: Vec>, + ) -> datafusion::common::Result> { + Ok(self) + } + + fn execute( + &self, + partition: usize, + _: Arc, + ) -> datafusion::common::Result { + Ok(Box::pin(ShuffleScanStream::new( + self.clone(), + self.schema(), + partition, + self.baseline_metrics.clone(), + ))) + } + + fn properties(&self) -> &PlanProperties { + &self.cache + } + + fn name(&self) -> &str { + "ShuffleScanExec" + } + + fn metrics(&self) -> Option { + Some(self.metrics.clone_inner()) + } +} + +impl DisplayAs for ShuffleScanExec { + fn fmt_as(&self, t: DisplayFormatType, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match t { + DisplayFormatType::Default | DisplayFormatType::Verbose => { + let fields: Vec = self + .data_types + .iter() + .enumerate() + .map(|(idx, dt)| format!("col_{idx}: {dt}")) + .collect(); + write!(f, "ShuffleScanExec: schema=[{}]", fields.join(", "))?; + } + DisplayFormatType::TreeRender => unimplemented!(), + } + Ok(()) + } +} + +/// An async stream that feeds decoded shuffle batches into the DataFusion plan. +struct ShuffleScanStream { + /// The ShuffleScanExec producing input batches. + shuffle_scan: ShuffleScanExec, + /// Schema of the output. + schema: SchemaRef, + /// Metrics. 
+ baseline_metrics: BaselineMetrics, +} + +impl ShuffleScanStream { + pub fn new( + shuffle_scan: ShuffleScanExec, + schema: SchemaRef, + _partition: usize, + baseline_metrics: BaselineMetrics, + ) -> Self { + Self { + shuffle_scan, + schema, + baseline_metrics, + } + } +} + +impl Stream for ShuffleScanStream { + type Item = DataFusionResult; + + fn poll_next(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { + let mut timer = self.baseline_metrics.elapsed_compute().timer(); + let mut scan_batch = self.shuffle_scan.batch.try_lock().unwrap(); + + let input_batch = &*scan_batch; + let input_batch = if let Some(batch) = input_batch { + batch + } else { + timer.stop(); + return Poll::Pending; + }; + + let result = match input_batch { + InputBatch::EOF => Poll::Ready(None), + InputBatch::Batch(columns, num_rows) => { + self.baseline_metrics.record_output(*num_rows); + let options = arrow::array::RecordBatchOptions::new() + .with_row_count(Some(*num_rows)); + let maybe_batch = arrow::array::RecordBatch::try_new_with_options( + Arc::clone(&self.schema), + columns.clone(), + &options, + ) + .map_err(|e| arrow_datafusion_err!(e)); + Poll::Ready(Some(maybe_batch)) + } + }; + + *scan_batch = None; + + timer.stop(); + + result + } +} + +impl RecordBatchStream for ShuffleScanStream { + fn schema(&self) -> SchemaRef { + Arc::clone(&self.schema) + } +} From e1c9111203d88214802d82c35d86075f6ef61861 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Wed, 18 Mar 2026 16:06:22 -0600 Subject: [PATCH 08/33] feat: wire ShuffleScanExec into planner and pre-pull mechanism --- native/core/src/execution/jni_api.rs | 14 ++- .../src/execution/operators/projection.rs | 9 +- native/core/src/execution/planner.rs | 114 ++++++++++++++---- .../execution/planner/operator_registry.rs | 11 +- 4 files changed, 110 insertions(+), 38 deletions(-) diff --git a/native/core/src/execution/jni_api.rs b/native/core/src/execution/jni_api.rs index 361deae182..d20cf128b5 100644 --- 
a/native/core/src/execution/jni_api.rs +++ b/native/core/src/execution/jni_api.rs @@ -82,7 +82,7 @@ use tokio::sync::mpsc; use crate::execution::memory_pools::{ create_memory_pool, handle_task_shared_pool_release, parse_memory_pool_config, MemoryPoolConfig, }; -use crate::execution::operators::ScanExec; +use crate::execution::operators::{ScanExec, ShuffleScanExec}; use crate::execution::shuffle::{read_ipc_compressed, CompressionCodec}; use crate::execution::spark_plan::SparkPlan; @@ -151,6 +151,8 @@ struct ExecutionContext { pub root_op: Option>, /// The input sources for the DataFusion plan pub scans: Vec, + /// The shuffle scan input sources for the DataFusion plan + pub shuffle_scans: Vec, /// The global reference of input sources for the DataFusion plan pub input_sources: Vec>, /// The record batch stream to pull results from @@ -311,6 +313,7 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan( partition_count: partition_count as usize, root_op: None, scans: vec![], + shuffle_scans: vec![], input_sources, stream: None, batch_receiver: None, @@ -491,6 +494,10 @@ fn pull_input_batches(exec_context: &mut ExecutionContext) -> Result<(), CometEr exec_context.scans.iter_mut().try_for_each(|scan| { scan.get_next_batch()?; Ok::<(), CometError>(()) + })?; + exec_context.shuffle_scans.iter_mut().try_for_each(|scan| { + scan.get_next_batch()?; + Ok::<(), CometError>(()) }) } @@ -539,7 +546,7 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_executePlan( let planner = PhysicalPlanner::new(Arc::clone(&exec_context.session_ctx), partition) .with_exec_id(exec_context_id); - let (scans, root_op) = planner.create_plan( + let (scans, shuffle_scans, root_op) = planner.create_plan( &exec_context.spark_plan, &mut exec_context.input_sources.clone(), exec_context.partition_count, @@ -548,6 +555,7 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_executePlan( exec_context.plan_creation_time += physical_plan_time; exec_context.scans = scans; 
+ exec_context.shuffle_scans = shuffle_scans; if exec_context.explain_native { let formatted_plan_str = @@ -560,7 +568,7 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_executePlan( // so we should always execute partition 0. let stream = root_op.native_plan.execute(0, task_ctx)?; - if exec_context.scans.is_empty() { + if exec_context.scans.is_empty() && exec_context.shuffle_scans.is_empty() { // No JVM data sources — spawn onto tokio so the executor // thread parks in blocking_recv instead of busy-polling. // diff --git a/native/core/src/execution/operators/projection.rs b/native/core/src/execution/operators/projection.rs index 6ba1bb5d59..4169ed8d40 100644 --- a/native/core/src/execution/operators/projection.rs +++ b/native/core/src/execution/operators/projection.rs @@ -25,8 +25,7 @@ use jni::objects::GlobalRef; use crate::{ execution::{ - operators::{ExecutionError, ScanExec}, - planner::{operator_registry::OperatorBuilder, PhysicalPlanner}, + planner::{operator_registry::OperatorBuilder, PlanCreationResult, PhysicalPlanner}, spark_plan::SparkPlan, }, extract_op, @@ -42,12 +41,13 @@ impl OperatorBuilder for ProjectionBuilder { inputs: &mut Vec>, partition_count: usize, planner: &PhysicalPlanner, - ) -> Result<(Vec, Arc), ExecutionError> { + ) -> PlanCreationResult { let project = extract_op!(spark_plan, Projection); let children = &spark_plan.children; assert_eq!(children.len(), 1); - let (scans, child) = planner.create_plan(&children[0], inputs, partition_count)?; + let (scans, shuffle_scans, child) = + planner.create_plan(&children[0], inputs, partition_count)?; // Create projection expressions let exprs: Result, _> = project @@ -68,6 +68,7 @@ impl OperatorBuilder for ProjectionBuilder { Ok(( scans, + shuffle_scans, Arc::new(SparkPlan::new(spark_plan.plan_id, projection, vec![child])), )) } diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs index bd37755922..e19891a0d6 100644 --- 
a/native/core/src/execution/planner.rs +++ b/native/core/src/execution/planner.rs @@ -27,7 +27,7 @@ use crate::{ errors::ExpressionError, execution::{ expressions::subquery::Subquery, - operators::{ExecutionError, ExpandExec, ParquetWriterExec, ScanExec}, + operators::{ExecutionError, ExpandExec, ParquetWriterExec, ScanExec, ShuffleScanExec}, planner::expression_registry::ExpressionRegistry, planner::operator_registry::OperatorRegistry, serde::to_arrow_datatype, @@ -141,6 +141,8 @@ use url::Url; type PhyAggResult = Result, ExecutionError>; type PhyExprResult = Result, String)>, ExecutionError>; type PartitionPhyExprResult = Result>, ExecutionError>; +pub type PlanCreationResult = + Result<(Vec, Vec, Arc), ExecutionError>; struct JoinParameters { pub left: Arc, @@ -913,7 +915,7 @@ impl PhysicalPlanner { spark_plan: &'a Operator, inputs: &mut Vec>, partition_count: usize, - ) -> Result<(Vec, Arc), ExecutionError> { + ) -> PlanCreationResult { // Try to use the modular registry first - this automatically handles any registered operator types if OperatorRegistry::global().can_handle(spark_plan) { return OperatorRegistry::global().create_plan( @@ -929,7 +931,8 @@ impl PhysicalPlanner { match spark_plan.op_struct.as_ref().unwrap() { OpStruct::Filter(filter) => { assert_eq!(children.len(), 1); - let (scans, child) = self.create_plan(&children[0], inputs, partition_count)?; + let (scans, shuffle_scans, child) = + self.create_plan(&children[0], inputs, partition_count)?; let predicate = self.create_expr(filter.predicate.as_ref().unwrap(), child.schema())?; @@ -940,12 +943,14 @@ impl PhysicalPlanner { Ok(( scans, + shuffle_scans, Arc::new(SparkPlan::new(spark_plan.plan_id, filter, vec![child])), )) } OpStruct::HashAgg(agg) => { assert_eq!(children.len(), 1); - let (scans, child) = self.create_plan(&children[0], inputs, partition_count)?; + let (scans, shuffle_scans, child) = + self.create_plan(&children[0], inputs, partition_count)?; let group_exprs: PhyExprResult = agg 
.grouping_exprs @@ -996,6 +1001,7 @@ impl PhysicalPlanner { if agg.result_exprs.is_empty() { Ok(( scans, + shuffle_scans, Arc::new(SparkPlan::new(spark_plan.plan_id, aggregate, vec![child])), )) } else { @@ -1012,6 +1018,7 @@ impl PhysicalPlanner { )?); Ok(( scans, + shuffle_scans, Arc::new(SparkPlan::new_with_additional( spark_plan.plan_id, projection, @@ -1030,7 +1037,8 @@ impl PhysicalPlanner { "Invalid limit/offset combination: [{num}. {offset}]" ))); } - let (scans, child) = self.create_plan(&children[0], inputs, partition_count)?; + let (scans, shuffle_scans, child) = + self.create_plan(&children[0], inputs, partition_count)?; let limit: Arc = if offset == 0 { Arc::new(LocalLimitExec::new( Arc::clone(&child.native_plan), @@ -1050,12 +1058,14 @@ impl PhysicalPlanner { }; Ok(( scans, + shuffle_scans, Arc::new(SparkPlan::new(spark_plan.plan_id, limit, vec![child])), )) } OpStruct::Sort(sort) => { assert_eq!(children.len(), 1); - let (scans, child) = self.create_plan(&children[0], inputs, partition_count)?; + let (scans, shuffle_scans, child) = + self.create_plan(&children[0], inputs, partition_count)?; let exprs: Result, ExecutionError> = sort .sort_orders @@ -1079,6 +1089,7 @@ impl PhysicalPlanner { Ok(( scans, + shuffle_scans, Arc::new(SparkPlan::new( spark_plan.plan_id, sort_exec, @@ -1115,6 +1126,7 @@ impl PhysicalPlanner { if partition_files.partitioned_file.is_empty() { let empty_exec = Arc::new(EmptyExec::new(required_schema)); return Ok(( + vec![], vec![], Arc::new(SparkPlan::new(spark_plan.plan_id, empty_exec, vec![])), )); @@ -1205,6 +1217,7 @@ impl PhysicalPlanner { common.encryption_enabled, )?; Ok(( + vec![], vec![], Arc::new(SparkPlan::new(spark_plan.plan_id, scan, vec![])), )) @@ -1243,6 +1256,7 @@ impl PhysicalPlanner { &scan.csv_options.clone().unwrap(), )?; Ok(( + vec![], vec![], Arc::new(SparkPlan::new(spark_plan.plan_id, scan, vec![])), )) @@ -1276,6 +1290,7 @@ impl PhysicalPlanner { Ok(( vec![scan.clone()], + vec![], 
Arc::new(SparkPlan::new(spark_plan.plan_id, Arc::new(scan), vec![])), )) } @@ -1307,6 +1322,7 @@ impl PhysicalPlanner { )?; Ok(( + vec![], vec![], Arc::new(SparkPlan::new( spark_plan.plan_id, @@ -1317,7 +1333,8 @@ impl PhysicalPlanner { } OpStruct::ShuffleWriter(writer) => { assert_eq!(children.len(), 1); - let (scans, child) = self.create_plan(&children[0], inputs, partition_count)?; + let (scans, shuffle_scans, child) = + self.create_plan(&children[0], inputs, partition_count)?; let partitioning = self.create_partitioning( writer.partitioning.as_ref().unwrap(), @@ -1350,6 +1367,7 @@ impl PhysicalPlanner { Ok(( scans, + shuffle_scans, Arc::new(SparkPlan::new( spark_plan.plan_id, shuffle_writer, @@ -1359,7 +1377,8 @@ impl PhysicalPlanner { } OpStruct::ParquetWriter(writer) => { assert_eq!(children.len(), 1); - let (scans, child) = self.create_plan(&children[0], inputs, partition_count)?; + let (scans, shuffle_scans, child) = + self.create_plan(&children[0], inputs, partition_count)?; let codec = match writer.compression.try_into() { Ok(SparkCompressionCodec::None) => Ok(CompressionCodec::None), @@ -1396,6 +1415,7 @@ impl PhysicalPlanner { Ok(( scans, + shuffle_scans, Arc::new(SparkPlan::new( spark_plan.plan_id, parquet_writer, @@ -1405,7 +1425,8 @@ impl PhysicalPlanner { } OpStruct::Expand(expand) => { assert_eq!(children.len(), 1); - let (scans, child) = self.create_plan(&children[0], inputs, partition_count)?; + let (scans, shuffle_scans, child) = + self.create_plan(&children[0], inputs, partition_count)?; let mut projections = vec![]; let mut projection = vec![]; @@ -1448,12 +1469,14 @@ impl PhysicalPlanner { let expand = Arc::new(ExpandExec::new(projections, input, schema)); Ok(( scans, + shuffle_scans, Arc::new(SparkPlan::new(spark_plan.plan_id, expand, vec![child])), )) } OpStruct::Explode(explode) => { assert_eq!(children.len(), 1); - let (scans, child) = self.create_plan(&children[0], inputs, partition_count)?; + let (scans, shuffle_scans, child) = + 
self.create_plan(&children[0], inputs, partition_count)?; // Create the expression for the array to explode let child_expr = if let Some(child_expr) = &explode.child { @@ -1559,11 +1582,12 @@ impl PhysicalPlanner { Ok(( scans, + shuffle_scans, Arc::new(SparkPlan::new(spark_plan.plan_id, unnest_exec, vec![child])), )) } OpStruct::SortMergeJoin(join) => { - let (join_params, scans) = self.parse_join_parameters( + let (join_params, scans, shuffle_scans) = self.parse_join_parameters( inputs, children, &join.left_join_keys, @@ -1615,6 +1639,7 @@ impl PhysicalPlanner { )); Ok(( scans, + shuffle_scans, Arc::new(SparkPlan::new_with_additional( spark_plan.plan_id, coalesce_batches, @@ -1628,6 +1653,7 @@ impl PhysicalPlanner { } else { Ok(( scans, + shuffle_scans, Arc::new(SparkPlan::new( spark_plan.plan_id, join, @@ -1640,7 +1666,7 @@ impl PhysicalPlanner { } } OpStruct::HashJoin(join) => { - let (join_params, scans) = self.parse_join_parameters( + let (join_params, scans, shuffle_scans) = self.parse_join_parameters( inputs, children, &join.left_join_keys, @@ -1670,6 +1696,7 @@ impl PhysicalPlanner { if join.build_side == BuildSide::BuildLeft as i32 { Ok(( scans, + shuffle_scans, Arc::new(SparkPlan::new( spark_plan.plan_id, hash_join, @@ -1688,6 +1715,7 @@ impl PhysicalPlanner { Ok(( scans, + shuffle_scans, Arc::new(SparkPlan::new_with_additional( spark_plan.plan_id, swapped_hash_join, @@ -1698,7 +1726,8 @@ impl PhysicalPlanner { } } OpStruct::Window(wnd) => { - let (scans, child) = self.create_plan(&children[0], inputs, partition_count)?; + let (scans, shuffle_scans, child) = + self.create_plan(&children[0], inputs, partition_count)?; let input_schema = child.schema(); let sort_exprs: Result, ExecutionError> = wnd .order_by_list @@ -1736,9 +1765,42 @@ impl PhysicalPlanner { )?); Ok(( scans, + shuffle_scans, Arc::new(SparkPlan::new(spark_plan.plan_id, window_agg, vec![child])), )) } + OpStruct::ShuffleScan(scan) => { + let data_types = 
scan.fields.iter().map(to_arrow_datatype).collect_vec(); + + if self.exec_context_id != TEST_EXEC_CONTEXT_ID && inputs.is_empty() { + return Err(GeneralError( + "No input for shuffle scan".to_string(), + )); + } + + let input_source = + if self.exec_context_id == TEST_EXEC_CONTEXT_ID && inputs.is_empty() { + None + } else { + Some(inputs.remove(0)) + }; + + let shuffle_scan = ShuffleScanExec::new( + self.exec_context_id, + input_source, + data_types, + )?; + + Ok(( + vec![], + vec![shuffle_scan.clone()], + Arc::new(SparkPlan::new( + spark_plan.plan_id, + Arc::new(shuffle_scan), + vec![], + )), + )) + } _ => Err(GeneralError(format!( "Unsupported or unregistered operator type: {:?}", spark_plan.op_struct @@ -1756,12 +1818,15 @@ impl PhysicalPlanner { join_type: i32, condition: &Option, partition_count: usize, - ) -> Result<(JoinParameters, Vec), ExecutionError> { + ) -> Result<(JoinParameters, Vec, Vec), ExecutionError> { assert_eq!(children.len(), 2); - let (mut left_scans, left) = self.create_plan(&children[0], inputs, partition_count)?; - let (mut right_scans, right) = self.create_plan(&children[1], inputs, partition_count)?; + let (mut left_scans, mut left_shuffle_scans, left) = + self.create_plan(&children[0], inputs, partition_count)?; + let (mut right_scans, mut right_shuffle_scans, right) = + self.create_plan(&children[1], inputs, partition_count)?; left_scans.append(&mut right_scans); + left_shuffle_scans.append(&mut right_shuffle_scans); let left_join_exprs: Vec<_> = left_join_keys .iter() @@ -1882,6 +1947,7 @@ impl PhysicalPlanner { join_filter, }, left_scans, + left_shuffle_scans, )) } @@ -3670,7 +3736,7 @@ mod tests { let input_array = DictionaryArray::new(keys, Arc::new(values)); let input_batch = InputBatch::Batch(vec![Arc::new(input_array)], row_count); - let (mut scans, datafusion_plan) = planner.create_plan(&op, &mut vec![], 1).unwrap(); + let (mut scans, _shuffle_scans, datafusion_plan) = planner.create_plan(&op, &mut vec![], 1).unwrap(); 
scans[0].set_input_batch(input_batch); let session_ctx = SessionContext::new(); @@ -3744,7 +3810,7 @@ mod tests { let input_array = DictionaryArray::new(keys, Arc::new(values)); let input_batch = InputBatch::Batch(vec![Arc::new(input_array)], row_count); - let (mut scans, datafusion_plan) = planner.create_plan(&op, &mut vec![], 1).unwrap(); + let (mut scans, _shuffle_scans, datafusion_plan) = planner.create_plan(&op, &mut vec![], 1).unwrap(); // Scan's schema is determined by the input batch, so we need to set it before execution. scans[0].set_input_batch(input_batch); @@ -3791,7 +3857,7 @@ mod tests { let op = create_filter(op_scan, 0); let planner = PhysicalPlanner::default(); - let (mut scans, datafusion_plan) = planner.create_plan(&op, &mut vec![], 1).unwrap(); + let (mut scans, _shuffle_scans, datafusion_plan) = planner.create_plan(&op, &mut vec![], 1).unwrap(); let scan = &mut scans[0]; scan.set_input_batch(InputBatch::EOF); @@ -3876,7 +3942,7 @@ mod tests { let op = create_filter(op_scan, 0); let planner = PhysicalPlanner::default(); - let (_scans, filter_exec) = planner.create_plan(&op, &mut vec![], 1).unwrap(); + let (_scans, _shuffle_scans, filter_exec) = planner.create_plan(&op, &mut vec![], 1).unwrap(); assert_eq!("FilterExec", filter_exec.native_plan.name()); assert_eq!(1, filter_exec.children.len()); @@ -3900,7 +3966,7 @@ mod tests { let planner = PhysicalPlanner::default(); - let (_scans, hash_join_exec) = planner.create_plan(&op_join, &mut vec![], 1).unwrap(); + let (_scans, _shuffle_scans, hash_join_exec) = planner.create_plan(&op_join, &mut vec![], 1).unwrap(); assert_eq!("HashJoinExec", hash_join_exec.native_plan.name()); assert_eq!(2, hash_join_exec.children.len()); @@ -4014,7 +4080,7 @@ mod tests { })), }; - let (mut scans, datafusion_plan) = + let (mut scans, _shuffle_scans, datafusion_plan) = planner.create_plan(&projection, &mut vec![], 1).unwrap(); let mut stream = datafusion_plan.native_plan.execute(0, task_ctx).unwrap(); @@ -4140,7 
+4206,7 @@ mod tests { }; // Create a physical plan - let (mut scans, datafusion_plan) = + let (mut scans, _shuffle_scans, datafusion_plan) = planner.create_plan(&projection, &mut vec![], 1).unwrap(); // Start executing the plan in a separate thread @@ -4631,7 +4697,7 @@ mod tests { }; // Create the physical plan - let (mut scans, datafusion_plan) = + let (mut scans, _shuffle_scans, datafusion_plan) = planner.create_plan(&projection, &mut vec![], 1).unwrap(); // Create test data: Date32 and Int8 columns diff --git a/native/core/src/execution/planner/operator_registry.rs b/native/core/src/execution/planner/operator_registry.rs index e20624b6c9..cad5df40c5 100644 --- a/native/core/src/execution/planner/operator_registry.rs +++ b/native/core/src/execution/planner/operator_registry.rs @@ -25,11 +25,8 @@ use std::{ use datafusion_comet_proto::spark_operator::Operator; use jni::objects::GlobalRef; -use super::PhysicalPlanner; -use crate::execution::{ - operators::{ExecutionError, ScanExec}, - spark_plan::SparkPlan, -}; +use super::{PhysicalPlanner, PlanCreationResult}; +use crate::execution::operators::ExecutionError; /// Trait for building physical operators from Spark protobuf operators pub trait OperatorBuilder: Send + Sync { @@ -40,7 +37,7 @@ pub trait OperatorBuilder: Send + Sync { inputs: &mut Vec>, partition_count: usize, planner: &PhysicalPlanner, - ) -> Result<(Vec, Arc), ExecutionError>; + ) -> PlanCreationResult; } /// Enum to identify different operator types for registry dispatch @@ -100,7 +97,7 @@ impl OperatorRegistry { inputs: &mut Vec>, partition_count: usize, planner: &PhysicalPlanner, - ) -> Result<(Vec, Arc), ExecutionError> { + ) -> PlanCreationResult { let operator_type = get_operator_type(spark_operator).ok_or_else(|| { ExecutionError::GeneralError(format!( "Unsupported operator type: {:?}", From 9a9812a0bcd877327cf165a269266480c7ee2eb4 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Wed, 18 Mar 2026 16:09:31 -0600 Subject: [PATCH 09/33] feat: emit 
ShuffleScan protobuf for native shuffle with direct read --- .../comet/serde/operator/CometSink.scala | 70 +++++++++++++++++-- 1 file changed, 63 insertions(+), 7 deletions(-) diff --git a/spark/src/main/scala/org/apache/comet/serde/operator/CometSink.scala b/spark/src/main/scala/org/apache/comet/serde/operator/CometSink.scala index ca9dbdad7c..dde36d9789 100644 --- a/spark/src/main/scala/org/apache/comet/serde/operator/CometSink.scala +++ b/spark/src/main/scala/org/apache/comet/serde/operator/CometSink.scala @@ -22,8 +22,12 @@ package org.apache.comet.serde.operator import scala.jdk.CollectionConverters._ import org.apache.spark.sql.comet.{CometNativeExec, CometSinkPlaceHolder} +import org.apache.spark.sql.comet.execution.shuffle.{CometNativeShuffle, CometShuffleExchangeExec} import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.adaptive.ShuffleQueryStageExec +import org.apache.spark.sql.execution.exchange.ReusedExchangeExec +import org.apache.comet.CometConf import org.apache.comet.CometSparkSessionExtensions.withInfo import org.apache.comet.ConfigEntry import org.apache.comet.serde.{CometOperatorSerde, OperatorOuterClass} @@ -86,15 +90,67 @@ abstract class CometSink[T <: SparkPlan] extends CometOperatorSerde[T] { object CometExchangeSink extends CometSink[SparkPlan] { - /** - * Exchange data is FFI safe because there is no use of mutable buffers involved. - * - * Source of broadcast exchange batches is ArrowStreamReader. - * - * Source of shuffle exchange batches is NativeBatchDecoderIterator. 
- */ override def isFfiSafe: Boolean = true + override def convert( + op: SparkPlan, + builder: Operator.Builder, + childOp: OperatorOuterClass.Operator*): Option[OperatorOuterClass.Operator] = { + if (shouldUseShuffleScan(op)) { + convertToShuffleScan(op, builder) + } else { + super.convert(op, builder, childOp: _*) + } + } + + private def shouldUseShuffleScan(op: SparkPlan): Boolean = { + if (!CometConf.COMET_SHUFFLE_DIRECT_READ_ENABLED.get()) return false + + // Extract the CometShuffleExchangeExec from the wrapper + val shuffleExec = op match { + case ShuffleQueryStageExec(_, s: CometShuffleExchangeExec, _) => Some(s) + case ShuffleQueryStageExec(_, ReusedExchangeExec(_, s: CometShuffleExchangeExec), _) => + Some(s) + case s: CometShuffleExchangeExec => Some(s) + case _ => None + } + + shuffleExec.exists(_.shuffleType == CometNativeShuffle) + } + + private def convertToShuffleScan( + op: SparkPlan, + builder: Operator.Builder): Option[OperatorOuterClass.Operator] = { + val supportedTypes = + op.output.forall(a => supportedDataType(a.dataType, allowComplex = true)) + + if (!supportedTypes) { + withInfo(op, "Unsupported data type for shuffle direct read") + return None + } + + val scanBuilder = OperatorOuterClass.ShuffleScan.newBuilder() + val source = op.simpleStringWithNodeId() + if (source.isEmpty) { + scanBuilder.setSource(op.getClass.getSimpleName) + } else { + scanBuilder.setSource(source) + } + + val scanTypes = op.output.flatMap { attr => + serializeDataType(attr.dataType) + } + + if (scanTypes.length == op.output.length) { + scanBuilder.addAllFields(scanTypes.asJava) + builder.clearChildren() + Some(builder.setShuffleScan(scanBuilder).build()) + } else { + withInfo(op, "unsupported data types for shuffle direct read") + None + } + } + override def createExec(nativeOp: Operator, op: SparkPlan): CometNativeExec = CometSinkPlaceHolder(nativeOp, op, op) } From e098cd5df5d5e03b3f50443e22b64ca108089b2f Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Wed, 18 
Mar 2026 16:17:38 -0600 Subject: [PATCH 10/33] feat: wire CometShuffleBlockIterator into JVM execution path --- .../org/apache/comet/CometExecIterator.scala | 15 +- .../main/scala/org/apache/comet/Native.scala | 2 +- .../apache/spark/sql/comet/CometExecRDD.scala | 24 +++- .../CometBlockStoreShuffleReader.scala | 11 ++ .../apache/spark/sql/comet/operators.scala | 133 +++++++++++++++++- 5 files changed, 171 insertions(+), 14 deletions(-) diff --git a/spark/src/main/scala/org/apache/comet/CometExecIterator.scala b/spark/src/main/scala/org/apache/comet/CometExecIterator.scala index 44ebf7e36e..e198ac99ff 100644 --- a/spark/src/main/scala/org/apache/comet/CometExecIterator.scala +++ b/spark/src/main/scala/org/apache/comet/CometExecIterator.scala @@ -67,7 +67,8 @@ class CometExecIterator( numParts: Int, partitionIndex: Int, broadcastedHadoopConfForEncryption: Option[Broadcast[SerializableConfiguration]] = None, - encryptedFilePaths: Seq[String] = Seq.empty) + encryptedFilePaths: Seq[String] = Seq.empty, + shuffleBlockIterators: Map[Int, CometShuffleBlockIterator] = Map.empty) extends Iterator[ColumnarBatch] with Logging { @@ -78,8 +79,13 @@ class CometExecIterator( private val taskAttemptId = TaskContext.get().taskAttemptId private val taskCPUs = TaskContext.get().cpus() private val cometTaskMemoryManager = new CometTaskMemoryManager(id, taskAttemptId) - private val cometBatchIterators = inputs.map { iterator => - new CometBatchIterator(iterator, nativeUtil) + // Build a mixed array of iterators: CometShuffleBlockIterator for shuffle + // scan indices, CometBatchIterator for regular scan indices. 
+ private val inputIterators: Array[Object] = inputs.zipWithIndex.map { + case (_, idx) if shuffleBlockIterators.contains(idx) => + shuffleBlockIterators(idx).asInstanceOf[Object] + case (iterator, _) => + new CometBatchIterator(iterator, nativeUtil).asInstanceOf[Object] }.toArray private val plan = { @@ -106,7 +112,7 @@ class CometExecIterator( nativeLib.createPlan( id, - cometBatchIterators, + inputIterators, protobufQueryPlan, protobufSparkConfigs, numParts, @@ -229,6 +235,7 @@ class CometExecIterator( currentBatch = null } nativeUtil.close() + shuffleBlockIterators.values.foreach(_.close()) nativeLib.releasePlan(plan) if (tracingEnabled) { diff --git a/spark/src/main/scala/org/apache/comet/Native.scala b/spark/src/main/scala/org/apache/comet/Native.scala index 55e0c70e72..f6800626d6 100644 --- a/spark/src/main/scala/org/apache/comet/Native.scala +++ b/spark/src/main/scala/org/apache/comet/Native.scala @@ -54,7 +54,7 @@ class Native extends NativeBase { // scalastyle:off @native def createPlan( id: Long, - iterators: Array[CometBatchIterator], + iterators: Array[Object], plan: Array[Byte], configMapProto: Array[Byte], partitionCount: Int, diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometExecRDD.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometExecRDD.scala index ad0c4f2afe..cb8652507f 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/CometExecRDD.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometExecRDD.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.execution.ScalarSubquery import org.apache.spark.sql.vectorized.ColumnarBatch import org.apache.spark.util.SerializableConfiguration -import org.apache.comet.CometExecIterator +import org.apache.comet.{CometExecIterator, CometShuffleBlockIterator} import org.apache.comet.serde.OperatorOuterClass /** @@ -64,7 +64,10 @@ private[spark] class CometExecRDD( nativeMetrics: CometMetricNode, subqueries: Seq[ScalarSubquery], broadcastedHadoopConfForEncryption: 
Option[Broadcast[SerializableConfiguration]] = None, - encryptedFilePaths: Seq[String] = Seq.empty) + encryptedFilePaths: Seq[String] = Seq.empty, + shuffleBlockIteratorFactories: Map[ + Int, + (TaskContext, Partition) => CometShuffleBlockIterator] = Map.empty) extends RDD[ColumnarBatch](sc, inputRDDs.map(rdd => new OneToOneDependency(rdd))) { // Determine partition count: from inputs if available, otherwise from parameter @@ -109,6 +112,12 @@ private[spark] class CometExecRDD( serializedPlan } + // Create shuffle block iterators for indices that have factories + val shuffleBlockIters = shuffleBlockIteratorFactories.map { case (idx, factory) => + val inputPart = partition.inputPartitions(idx) + idx -> factory(context, inputPart) + } + val it = new CometExecIterator( CometExec.newIterId, inputs, @@ -118,7 +127,8 @@ private[spark] class CometExecRDD( numPartitions, partition.index, broadcastedHadoopConfForEncryption, - encryptedFilePaths) + encryptedFilePaths, + shuffleBlockIters) // Register ScalarSubqueries so native code can look them up subqueries.foreach(sub => CometScalarSubquery.setSubquery(it.id, sub)) @@ -167,7 +177,10 @@ object CometExecRDD { nativeMetrics: CometMetricNode, subqueries: Seq[ScalarSubquery], broadcastedHadoopConfForEncryption: Option[Broadcast[SerializableConfiguration]] = None, - encryptedFilePaths: Seq[String] = Seq.empty): CometExecRDD = { + encryptedFilePaths: Seq[String] = Seq.empty, + shuffleBlockIteratorFactories: Map[ + Int, + (TaskContext, Partition) => CometShuffleBlockIterator] = Map.empty): CometExecRDD = { // scalastyle:on new CometExecRDD( @@ -181,6 +194,7 @@ object CometExecRDD { nativeMetrics, subqueries, broadcastedHadoopConfForEncryption, - encryptedFilePaths) + encryptedFilePaths, + shuffleBlockIteratorFactories) } } diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometBlockStoreShuffleReader.scala 
b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometBlockStoreShuffleReader.scala index e95eb92d21..647d4a0856 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometBlockStoreShuffleReader.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometBlockStoreShuffleReader.scala @@ -21,6 +21,8 @@ package org.apache.spark.sql.comet.execution.shuffle import java.io.InputStream +import scala.jdk.CollectionConverters._ + import org.apache.spark.{InterruptibleIterator, MapOutputTracker, SparkEnv, TaskContext} import org.apache.spark.internal.{config, Logging} import org.apache.spark.io.CompressionCodec @@ -153,6 +155,15 @@ class CometBlockStoreShuffleReader[K, C]( } } + /** + * Returns the raw concatenated InputStream of all shuffle blocks, bypassing the decode step. + * Used by ShuffleScan direct read path. + */ + def readAsRawStream(): InputStream = { + val streams = fetchIterator.map(_._2).toList + new java.io.SequenceInputStream(java.util.Collections.enumeration(streams.asJava)) + } + private def fetchContinuousBlocksInBatch: Boolean = { val conf = SparkEnv.get.conf val serializerRelocatable = dep.serializer.supportsRelocationOfSerializedObjects diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala index da2ae21a95..a0cb14bbd0 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala @@ -25,6 +25,7 @@ import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import scala.jdk.CollectionConverters._ +import org.apache.spark.{Partition, SparkEnv, TaskContext} import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow @@ -33,14 +34,14 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, import 
org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide} import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec +import org.apache.spark.sql.comet.execution.shuffle.{CometBlockStoreShuffleReader, CometShuffledBatchRDD, CometShuffleExchangeExec, ShuffledRowRDDPartition} import org.apache.spark.sql.comet.util.Utils import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.adaptive.{AQEShuffleReadExec, BroadcastQueryStageExec, ShuffleQueryStageExec} import org.apache.spark.sql.execution.aggregate.{BaseAggregateExec, HashAggregateExec, ObjectHashAggregateExec} import org.apache.spark.sql.execution.exchange.ReusedExchangeExec import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, HashJoin, ShuffledHashJoinExec, SortMergeJoinExec} -import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} +import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics, SQLShuffleReadMetricsReporter} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{ArrayType, BooleanType, ByteType, DataType, DateType, DecimalType, DoubleType, FloatType, IntegerType, LongType, MapType, ShortType, StringType, TimestampNTZType} import org.apache.spark.sql.vectorized.ColumnarBatch @@ -50,7 +51,7 @@ import org.apache.spark.util.io.ChunkedByteBuffer import com.google.common.base.Objects import com.google.protobuf.CodedOutputStream -import org.apache.comet.{CometConf, CometExecIterator, CometRuntimeException, ConfigEntry} +import org.apache.comet.{CometConf, CometExecIterator, CometRuntimeException, CometShuffleBlockIterator, ConfigEntry} import org.apache.comet.CometSparkSessionExtensions.{isCometShuffleEnabled, withInfo} import org.apache.comet.parquet.CometParquetUtils import org.apache.comet.serde.{CometOperatorSerde, Compatible, Incompatible, OperatorOuterClass, SupportLevel, Unsupported} @@ 
-553,6 +554,11 @@ abstract class CometNativeExec extends CometExec { throw new CometRuntimeException(s"No input for CometNativeExec:\n $this") } + // Detect ShuffleScan indices and create factories for direct read + val shuffleScanIndices = findShuffleScanIndices(serializedPlanCopy) + val shuffleBlockIteratorFactories = + buildShuffleBlockIteratorFactories(sparkPlans, inputs, shuffleScanIndices) + // Unified RDD creation - CometExecRDD handles all cases val subqueries = collectSubqueries(this) CometExecRDD( @@ -566,7 +572,8 @@ abstract class CometNativeExec extends CometExec { nativeMetrics, subqueries, broadcastedHadoopConfForEncryption, - encryptedFilePaths) + encryptedFilePaths, + shuffleBlockIteratorFactories) } } @@ -606,6 +613,124 @@ abstract class CometNativeExec extends CometExec { } } + /** + * Walk the serialized protobuf plan depth-first to find which input indices correspond to + * ShuffleScan vs Scan leaf nodes. Each Scan or ShuffleScan leaf consumes one input in order. + */ + private def findShuffleScanIndices(planBytes: Array[Byte]): Set[Int] = { + val plan = OperatorOuterClass.Operator.parseFrom(planBytes) + var scanIndex = 0 + val indices = mutable.Set.empty[Int] + def walk(op: OperatorOuterClass.Operator): Unit = { + if (op.hasShuffleScan) { + indices += scanIndex + scanIndex += 1 + } else if (op.hasScan) { + scanIndex += 1 + } else { + op.getChildrenList.asScala.foreach(walk) + } + } + walk(plan) + indices.toSet + } + + /** + * Build factory functions that produce CometShuffleBlockIterator for each input index that is a + * ShuffleScan. Maps from input index to a factory that, given TaskContext and Partition, + * creates the iterator. 
+ */ + private def buildShuffleBlockIteratorFactories( + sparkPlans: ArrayBuffer[SparkPlan], + inputs: ArrayBuffer[RDD[ColumnarBatch]], + shuffleScanIndices: Set[Int]) + : Map[Int, (TaskContext, Partition) => CometShuffleBlockIterator] = { + if (shuffleScanIndices.isEmpty) return Map.empty + + // Build the mapping from sparkPlans index to inputs index + // (CometNativeExec entries are skipped in inputs) + var inputIdx = 0 + val sparkPlanToInputIdx = mutable.Map.empty[Int, Int] + sparkPlans.zipWithIndex.foreach { case (plan, spIdx) => + plan match { + case _: CometNativeExec => // skipped, no input + case _ => + sparkPlanToInputIdx(spIdx) = inputIdx + inputIdx += 1 + } + } + + val factories = mutable.Map.empty[Int, (TaskContext, Partition) => CometShuffleBlockIterator] + + shuffleScanIndices.foreach { scanIdx => + if (scanIdx < inputs.length) { + inputs(scanIdx) match { + case rdd: CometShuffledBatchRDD => + val dep = rdd.dependency + factories(scanIdx) = (context, part) => { + val shufflePart = + part + .asInstanceOf[CometExecPartition] + .inputPartitions(scanIdx) + .asInstanceOf[ShuffledRowRDDPartition] + val tempMetrics = + context.taskMetrics().createTempShuffleReadMetrics() + val sqlMetricsReporter = + new SQLShuffleReadMetricsReporter(tempMetrics, Map.empty) + val reader = shufflePart.spec match { + case CoalescedPartitionSpec(startReducerIndex, endReducerIndex, _) => + SparkEnv.get.shuffleManager + .getReader( + dep.shuffleHandle, + startReducerIndex, + endReducerIndex, + context, + sqlMetricsReporter) + .asInstanceOf[CometBlockStoreShuffleReader[_, _]] + case PartialReducerPartitionSpec(reducerIndex, startMapIndex, endMapIndex, _) => + SparkEnv.get.shuffleManager + .getReader( + dep.shuffleHandle, + startMapIndex, + endMapIndex, + reducerIndex, + reducerIndex + 1, + context, + sqlMetricsReporter) + .asInstanceOf[CometBlockStoreShuffleReader[_, _]] + case PartialMapperPartitionSpec(mapIndex, startReducerIndex, endReducerIndex) => + SparkEnv.get.shuffleManager 
+ .getReader( + dep.shuffleHandle, + mapIndex, + mapIndex + 1, + startReducerIndex, + endReducerIndex, + context, + sqlMetricsReporter) + .asInstanceOf[CometBlockStoreShuffleReader[_, _]] + case CoalescedMapperPartitionSpec(startMapIndex, endMapIndex, numReducers) => + SparkEnv.get.shuffleManager + .getReader( + dep.shuffleHandle, + startMapIndex, + endMapIndex, + 0, + numReducers, + context, + sqlMetricsReporter) + .asInstanceOf[CometBlockStoreShuffleReader[_, _]] + } + val rawStream = reader.readAsRawStream() + new CometShuffleBlockIterator(rawStream) + } + case _ => // Not a CometShuffledBatchRDD, skip + } + } + } + factories.toMap + } + /** * Find all plan nodes with per-partition planning data in the plan tree. Returns two maps keyed * by a unique identifier: one for common data (shared across partitions) and one for From bf7040ffbf1bf93d0b6672a45506fc8dc7a6756c Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Wed, 18 Mar 2026 16:37:44 -0600 Subject: [PATCH 11/33] test: add shuffle direct read tests Fix two bugs discovered during testing: - ClassCastException: factory closure incorrectly cast Partition to CometExecPartition before extracting ShuffledRowRDDPartition; the partition passed to the factory is already the unwrapped partition from the input RDD - NoSuchElementException in SQLShuffleReadMetricsReporter: metrics field in CometShuffledBatchRDD was not exposed as a val, causing Map.empty to be used instead of the real shuffle metrics map Add Scala integration test that runs a repartition+aggregate query with direct read enabled and disabled to verify result parity. Add Rust unit test for read_ipc_compressed codec round-trip. 
--- .../src/execution/operators/projection.rs | 2 +- .../src/execution/operators/shuffle_scan.rs | 58 ++++++++++++++++++- native/core/src/execution/planner.rs | 26 ++++----- .../src/jvm_bridge/shuffle_block_iterator.rs | 14 +++-- .../shuffle/CometShuffledRowRDD.scala | 2 +- .../apache/spark/sql/comet/operators.scala | 9 +-- .../comet/exec/CometNativeShuffleSuite.scala | 17 +++++- 7 files changed, 100 insertions(+), 28 deletions(-) diff --git a/native/core/src/execution/operators/projection.rs b/native/core/src/execution/operators/projection.rs index 4169ed8d40..194fa6769a 100644 --- a/native/core/src/execution/operators/projection.rs +++ b/native/core/src/execution/operators/projection.rs @@ -25,7 +25,7 @@ use jni::objects::GlobalRef; use crate::{ execution::{ - planner::{operator_registry::OperatorBuilder, PlanCreationResult, PhysicalPlanner}, + planner::{operator_registry::OperatorBuilder, PhysicalPlanner, PlanCreationResult}, spark_plan::SparkPlan, }, extract_op, diff --git a/native/core/src/execution/operators/shuffle_scan.rs b/native/core/src/execution/operators/shuffle_scan.rs index 4a8d09111b..567a6e22f4 100644 --- a/native/core/src/execution/operators/shuffle_scan.rs +++ b/native/core/src/execution/operators/shuffle_scan.rs @@ -321,8 +321,8 @@ impl Stream for ShuffleScanStream { InputBatch::EOF => Poll::Ready(None), InputBatch::Batch(columns, num_rows) => { self.baseline_metrics.record_output(*num_rows); - let options = arrow::array::RecordBatchOptions::new() - .with_row_count(Some(*num_rows)); + let options = + arrow::array::RecordBatchOptions::new().with_row_count(Some(*num_rows)); let maybe_batch = arrow::array::RecordBatch::try_new_with_options( Arc::clone(&self.schema), columns.clone(), @@ -346,3 +346,57 @@ impl RecordBatchStream for ShuffleScanStream { Arc::clone(&self.schema) } } + +#[cfg(test)] +mod tests { + use crate::execution::shuffle::codec::{CompressionCodec, ShuffleBlockWriter}; + use arrow::array::{Int32Array, StringArray}; + use 
arrow::datatypes::{DataType, Field, Schema}; + use arrow::record_batch::RecordBatch; + use datafusion::physical_plan::metrics::Time; + use std::io::Cursor; + use std::sync::Arc; + + use crate::execution::shuffle::codec::read_ipc_compressed; + + #[test] + fn test_read_compressed_ipc_block() { + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("name", DataType::Utf8, true), + ])); + let batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(Int32Array::from(vec![1, 2, 3])), + Arc::new(StringArray::from(vec!["a", "b", "c"])), + ], + ) + .unwrap(); + + // Write as compressed IPC + let writer = + ShuffleBlockWriter::try_new(&batch.schema(), CompressionCodec::Zstd(1)).unwrap(); + let mut buf = Cursor::new(Vec::new()); + let ipc_time = Time::new(); + writer.write_batch(&batch, &mut buf, &ipc_time).unwrap(); + + // Read back (skip 16-byte header: 8 compressed_length + 8 field_count) + let bytes = buf.into_inner(); + let body = &bytes[16..]; + + let decoded = read_ipc_compressed(body).unwrap(); + assert_eq!(decoded.num_rows(), 3); + assert_eq!(decoded.num_columns(), 2); + + // Verify data + let col0 = decoded + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(col0.value(0), 1); + assert_eq!(col0.value(1), 2); + assert_eq!(col0.value(2), 3); + } +} diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs index e19891a0d6..b5892d763c 100644 --- a/native/core/src/execution/planner.rs +++ b/native/core/src/execution/planner.rs @@ -1773,9 +1773,7 @@ impl PhysicalPlanner { let data_types = scan.fields.iter().map(to_arrow_datatype).collect_vec(); if self.exec_context_id != TEST_EXEC_CONTEXT_ID && inputs.is_empty() { - return Err(GeneralError( - "No input for shuffle scan".to_string(), - )); + return Err(GeneralError("No input for shuffle scan".to_string())); } let input_source = @@ -1785,11 +1783,8 @@ impl PhysicalPlanner { Some(inputs.remove(0)) }; - let 
shuffle_scan = ShuffleScanExec::new( - self.exec_context_id, - input_source, - data_types, - )?; + let shuffle_scan = + ShuffleScanExec::new(self.exec_context_id, input_source, data_types)?; Ok(( vec![], @@ -3736,7 +3731,8 @@ mod tests { let input_array = DictionaryArray::new(keys, Arc::new(values)); let input_batch = InputBatch::Batch(vec![Arc::new(input_array)], row_count); - let (mut scans, _shuffle_scans, datafusion_plan) = planner.create_plan(&op, &mut vec![], 1).unwrap(); + let (mut scans, _shuffle_scans, datafusion_plan) = + planner.create_plan(&op, &mut vec![], 1).unwrap(); scans[0].set_input_batch(input_batch); let session_ctx = SessionContext::new(); @@ -3810,7 +3806,8 @@ mod tests { let input_array = DictionaryArray::new(keys, Arc::new(values)); let input_batch = InputBatch::Batch(vec![Arc::new(input_array)], row_count); - let (mut scans, _shuffle_scans, datafusion_plan) = planner.create_plan(&op, &mut vec![], 1).unwrap(); + let (mut scans, _shuffle_scans, datafusion_plan) = + planner.create_plan(&op, &mut vec![], 1).unwrap(); // Scan's schema is determined by the input batch, so we need to set it before execution. 
scans[0].set_input_batch(input_batch); @@ -3857,7 +3854,8 @@ mod tests { let op = create_filter(op_scan, 0); let planner = PhysicalPlanner::default(); - let (mut scans, _shuffle_scans, datafusion_plan) = planner.create_plan(&op, &mut vec![], 1).unwrap(); + let (mut scans, _shuffle_scans, datafusion_plan) = + planner.create_plan(&op, &mut vec![], 1).unwrap(); let scan = &mut scans[0]; scan.set_input_batch(InputBatch::EOF); @@ -3942,7 +3940,8 @@ mod tests { let op = create_filter(op_scan, 0); let planner = PhysicalPlanner::default(); - let (_scans, _shuffle_scans, filter_exec) = planner.create_plan(&op, &mut vec![], 1).unwrap(); + let (_scans, _shuffle_scans, filter_exec) = + planner.create_plan(&op, &mut vec![], 1).unwrap(); assert_eq!("FilterExec", filter_exec.native_plan.name()); assert_eq!(1, filter_exec.children.len()); @@ -3966,7 +3965,8 @@ mod tests { let planner = PhysicalPlanner::default(); - let (_scans, _shuffle_scans, hash_join_exec) = planner.create_plan(&op_join, &mut vec![], 1).unwrap(); + let (_scans, _shuffle_scans, hash_join_exec) = + planner.create_plan(&op_join, &mut vec![], 1).unwrap(); assert_eq!("HashJoinExec", hash_join_exec.native_plan.name()); assert_eq!(2, hash_join_exec.children.len()); diff --git a/native/core/src/jvm_bridge/shuffle_block_iterator.rs b/native/core/src/jvm_bridge/shuffle_block_iterator.rs index 02fcf8ca27..c3bb5af5fb 100644 --- a/native/core/src/jvm_bridge/shuffle_block_iterator.rs +++ b/native/core/src/jvm_bridge/shuffle_block_iterator.rs @@ -45,11 +45,17 @@ impl<'a> CometShuffleBlockIterator<'a> { class, method_has_next: env.get_method_id(Self::JVM_CLASS, "hasNext", "()I")?, method_has_next_ret: ReturnType::Primitive(Primitive::Int), - method_get_buffer: env - .get_method_id(Self::JVM_CLASS, "getBuffer", "()Ljava/nio/ByteBuffer;")?, + method_get_buffer: env.get_method_id( + Self::JVM_CLASS, + "getBuffer", + "()Ljava/nio/ByteBuffer;", + )?, method_get_buffer_ret: ReturnType::Object, - method_get_current_block_length: env 
- .get_method_id(Self::JVM_CLASS, "getCurrentBlockLength", "()I")?, + method_get_current_block_length: env.get_method_id( + Self::JVM_CLASS, + "getCurrentBlockLength", + "()I", + )?, method_get_current_block_length_ret: ReturnType::Primitive(Primitive::Int), }) } diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffledRowRDD.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffledRowRDD.scala index ba6fc588e2..6594982c85 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffledRowRDD.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffledRowRDD.scala @@ -34,7 +34,7 @@ import org.apache.spark.sql.vectorized.ColumnarBatch */ class CometShuffledBatchRDD( var dependency: ShuffleDependency[Int, _, _], - metrics: Map[String, SQLMetric], + val metrics: Map[String, SQLMetric], partitionSpecs: Array[ShufflePartitionSpec]) extends RDD[ColumnarBatch](dependency.rdd.context, Nil) { diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala index a0cb14bbd0..9edaf447c5 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala @@ -667,16 +667,13 @@ abstract class CometNativeExec extends CometExec { inputs(scanIdx) match { case rdd: CometShuffledBatchRDD => val dep = rdd.dependency + val rddMetrics = rdd.metrics factories(scanIdx) = (context, part) => { - val shufflePart = - part - .asInstanceOf[CometExecPartition] - .inputPartitions(scanIdx) - .asInstanceOf[ShuffledRowRDDPartition] + val shufflePart = part.asInstanceOf[ShuffledRowRDDPartition] val tempMetrics = context.taskMetrics().createTempShuffleReadMetrics() val sqlMetricsReporter = - new SQLShuffleReadMetricsReporter(tempMetrics, Map.empty) + new SQLShuffleReadMetricsReporter(tempMetrics, rddMetrics) val reader = 
shufflePart.spec match { case CoalescedPartitionSpec(startReducerIndex, endReducerIndex, _) => SparkEnv.get.shuffleManager diff --git a/spark/src/test/scala/org/apache/comet/exec/CometNativeShuffleSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometNativeShuffleSuite.scala index 1cf43ea598..11f825e70d 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometNativeShuffleSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometNativeShuffleSuite.scala @@ -30,7 +30,7 @@ import org.apache.spark.SparkEnv import org.apache.spark.sql.{CometTestBase, DataFrame, Dataset, Row} import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper -import org.apache.spark.sql.functions.col +import org.apache.spark.sql.functions.{col, count, sum} import org.apache.comet.CometConf @@ -437,4 +437,19 @@ class CometNativeShuffleSuite extends CometTestBase with AdaptiveSparkPlanHelper } } } + + test("shuffle direct read produces same results as FFI path") { + Seq(true, false).foreach { directRead => + withSQLConf(CometConf.COMET_SHUFFLE_DIRECT_READ_ENABLED.key -> directRead.toString) { + val df = spark + .range(1000) + .selectExpr("id", "id % 10 as key", "cast(id as string) as value") + .repartition(4, col("key")) + .groupBy("key") + .agg(sum("id").as("total"), count("value").as("cnt")) + .orderBy("key") + checkSparkAnswer(df) + } + } + } } From c91d1b9af8826315e1d419859f85eef8c8a54356 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Wed, 18 Mar 2026 16:45:08 -0600 Subject: [PATCH 12/33] refactor: simplify shuffle direct read code - Remove redundant getCurrentBlockLength() JNI call (reuse hasNext() return value) - Make readAsRawStream() lazy instead of materializing all streams to a List - Remove pointless DirectByteBuffer re-allocation in close() - Remove dead sparkPlanToInputIdx map --- native/core/src/execution/operators/shuffle_scan.rs | 10 ++-------- 
.../org/apache/comet/CometShuffleBlockIterator.java | 3 --- .../shuffle/CometBlockStoreShuffleReader.scala | 7 +++++-- .../org/apache/spark/sql/comet/operators.scala | 13 ------------- 4 files changed, 7 insertions(+), 26 deletions(-) diff --git a/native/core/src/execution/operators/shuffle_scan.rs b/native/core/src/execution/operators/shuffle_scan.rs index 567a6e22f4..80c72a6d31 100644 --- a/native/core/src/execution/operators/shuffle_scan.rs +++ b/native/core/src/execution/operators/shuffle_scan.rs @@ -152,7 +152,7 @@ impl ShuffleScanExec { let mut env = JVMClasses::get_env()?; - // has_next() returns block length or -1 if no more blocks + // has_next() reads the next block and returns its length, or -1 if EOF let block_length: i32 = unsafe { jni_call!(&mut env, comet_shuffle_block_iterator(iter).has_next() -> i32)? @@ -168,15 +168,9 @@ impl ShuffleScanExec { comet_shuffle_block_iterator(iter).get_buffer() -> JObject)? }; - // Get the actual block length (may differ from has_next return value) - let length: i32 = unsafe { - jni_call!(&mut env, - comet_shuffle_block_iterator(iter).get_current_block_length() -> i32)? 
- }; - let byte_buffer = JByteBuffer::from(buffer); let raw_pointer = env.get_direct_buffer_address(&byte_buffer)?; - let length = length as usize; + let length = block_length as usize; let slice: &[u8] = unsafe { std::slice::from_raw_parts(raw_pointer, length) }; // Decode the compressed IPC data diff --git a/spark/src/main/java/org/apache/comet/CometShuffleBlockIterator.java b/spark/src/main/java/org/apache/comet/CometShuffleBlockIterator.java index 5de5e05c4e..f9abef1c36 100644 --- a/spark/src/main/java/org/apache/comet/CometShuffleBlockIterator.java +++ b/spark/src/main/java/org/apache/comet/CometShuffleBlockIterator.java @@ -133,9 +133,6 @@ public void close() throws IOException { if (!closed) { closed = true; inputStream.close(); - if (dataBuf.capacity() > INITIAL_BUFFER_SIZE) { - dataBuf = ByteBuffer.allocateDirect(INITIAL_BUFFER_SIZE); - } } } } diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometBlockStoreShuffleReader.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometBlockStoreShuffleReader.scala index 647d4a0856..ecf37efe2a 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometBlockStoreShuffleReader.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometBlockStoreShuffleReader.scala @@ -160,8 +160,11 @@ class CometBlockStoreShuffleReader[K, C]( * Used by ShuffleScan direct read path. 
*/ def readAsRawStream(): InputStream = { - val streams = fetchIterator.map(_._2).toList - new java.io.SequenceInputStream(java.util.Collections.enumeration(streams.asJava)) + val streams = fetchIterator.map(_._2) + new java.io.SequenceInputStream(new java.util.Enumeration[InputStream] { + override def hasMoreElements: Boolean = streams.hasNext + override def nextElement(): InputStream = streams.next() + }) } private def fetchContinuousBlocksInBatch: Boolean = { diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala index 9edaf447c5..2e195e73eb 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala @@ -647,19 +647,6 @@ abstract class CometNativeExec extends CometExec { : Map[Int, (TaskContext, Partition) => CometShuffleBlockIterator] = { if (shuffleScanIndices.isEmpty) return Map.empty - // Build the mapping from sparkPlans index to inputs index - // (CometNativeExec entries are skipped in inputs) - var inputIdx = 0 - val sparkPlanToInputIdx = mutable.Map.empty[Int, Int] - sparkPlans.zipWithIndex.foreach { case (plan, spIdx) => - plan match { - case _: CometNativeExec => // skipped, no input - case _ => - sparkPlanToInputIdx(spIdx) = inputIdx - inputIdx += 1 - } - } - val factories = mutable.Map.empty[Int, (TaskContext, Partition) => CometShuffleBlockIterator] shuffleScanIndices.foreach { scanIdx => From b41889dcd6f466b35769692b71517ed1f56f5779 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Wed, 18 Mar 2026 17:03:26 -0600 Subject: [PATCH 13/33] style: remove unused import --- .../comet/execution/shuffle/CometBlockStoreShuffleReader.scala | 2 -- 1 file changed, 2 deletions(-) diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometBlockStoreShuffleReader.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometBlockStoreShuffleReader.scala 
index ecf37efe2a..14e656f038 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometBlockStoreShuffleReader.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometBlockStoreShuffleReader.scala @@ -21,8 +21,6 @@ package org.apache.spark.sql.comet.execution.shuffle import java.io.InputStream -import scala.jdk.CollectionConverters._ - import org.apache.spark.{InterruptibleIterator, MapOutputTracker, SparkEnv, TaskContext} import org.apache.spark.internal.{config, Logging} import org.apache.spark.io.CompressionCodec From 6e24a270f400588d1035e307ad50fac4a85164f4 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Wed, 18 Mar 2026 18:21:20 -0600 Subject: [PATCH 14/33] remove design doc --- .../2026-03-18-shuffle-direct-read-design.md | 163 ------------------ 1 file changed, 163 deletions(-) delete mode 100644 docs/superpowers/specs/2026-03-18-shuffle-direct-read-design.md diff --git a/docs/superpowers/specs/2026-03-18-shuffle-direct-read-design.md b/docs/superpowers/specs/2026-03-18-shuffle-direct-read-design.md deleted file mode 100644 index 2f002a2d89..0000000000 --- a/docs/superpowers/specs/2026-03-18-shuffle-direct-read-design.md +++ /dev/null @@ -1,163 +0,0 @@ -# Shuffle Direct Read: Bypass FFI for Native Shuffle Read Path - -## Problem - -When a native shuffle exchange feeds into a downstream native operator, shuffle data crosses the JVM/native FFI boundary twice: - -1. **Native to JVM**: `decodeShuffleBlock` JNI call decompresses Arrow IPC, creates a `RecordBatch`, and exports it via Arrow C Data Interface (per-column `FFI_ArrowArray` + `FFI_ArrowSchema` allocation, export, and import). -2. **JVM to Native**: `CometBatchIterator` re-exports the `ColumnarBatch` via Arrow C Data Interface back to native, where `ScanExec` imports and copies/unpacks the arrays. - -Each crossing involves per-column schema serialization, struct allocation, and array copying. 
For queries with many shuffle stages or wide schemas, this overhead is significant. - -## Solution - -Introduce a direct read path where native code consumes compressed shuffle blocks directly, bypassing Arrow FFI entirely. The JVM reads raw bytes from Spark's shuffle infrastructure and hands them to native via a `DirectByteBuffer` (zero-copy pointer access). Native decompresses and decodes in-place, feeding `RecordBatch` directly into the execution plan. - -### Data Flow Comparison - -**Current path (double FFI):** - -``` -Shuffle stream - -> NativeBatchDecoderIterator (JVM) - -> JNI: decodeShuffleBlock - -> FFI export: RecordBatch -> ArrowArray/Schema (native -> JVM) - -> ColumnarBatch on JVM - -> CometBatchIterator - -> FFI export: ColumnarBatch -> ArrowArray/Schema (JVM -> native) - -> ScanExec imports + copies arrays - -> Native operators -``` - -**New path (zero FFI):** - -``` -Shuffle stream - -> CometShuffleBlockIterator (JVM) - -> reads header + compressed body into DirectByteBuffer - -> holds bytes, waits for native pull - -ShuffleScanExec (native, pull-based) - -> JNI callback: iterator.hasNext()/getBuffer() - -> read_ipc_compressed() -> RecordBatch - -> feeds directly into native execution plan -``` - -## Scope - -- Native shuffle (`CometNativeShuffle`) only. JVM columnar shuffle is excluded because its per-batch dictionary encoding decisions can change the schema between batches. -- Both paths (old and new) are retained. A config flag controls which is used. - -## Components - -### New JVM Components - -#### `CometShuffleBlockIterator` (Java) - -A new class that wraps a shuffle `InputStream` and exposes raw compressed blocks for native consumption. Absorbs the header-reading and buffer-management logic from `NativeBatchDecoderIterator`, but does not decode. - -JNI-callable interface: - -- `hasNext() -> int`: Reads the next block's header from the stream. 
The header is 16 bytes: 8-byte compressed length (includes the 8-byte field count but not itself) + 8-byte field count. The field count from the header is discarded — the schema is determined by the `ShuffleScan` protobuf's `fields` list, which is authoritative. Returns the compressed body length in bytes (i.e., `compressedLength - 8`, which includes the 4-byte codec prefix + compressed IPC data), or -1 for EOF. -- `getBuffer() -> ByteBuffer`: Returns the `DirectByteBuffer` containing the current block's compressed bytes (4-byte codec prefix + compressed IPC data). This buffer is only valid until the next `hasNext()` call — the caller must fully consume it (via `read_ipc_compressed()`, which decompresses into a new allocation) before pulling the next block. - -Uses its own `DirectByteBuffer` instance (not shared with `NativeBatchDecoderIterator`) with the same pooling strategy: initial 128KB, grows as needed, reset on close. - -**Lifecycle**: Implements `Closeable`. `close()` closes the underlying shuffle `InputStream` and resets the buffer. `CometBlockStoreShuffleReader` registers a task completion listener to close it, matching the existing pattern for `NativeBatchDecoderIterator`. - -### New Native Components - -#### `ShuffleScanExec` (Rust) - -Location: `native/core/src/execution/operators/shuffle_scan.rs` - -A new `ExecutionPlan` operator that replaces `ScanExec` at shuffle boundaries. On each `poll_next`: - -1. Calls JNI into `CometShuffleBlockIterator.hasNext()` to get the next block's byte length (or -1 for EOF). -2. Calls `CometShuffleBlockIterator.getBuffer()` to get a `DirectByteBuffer`. -3. Obtains the buffer's raw pointer via `JNIEnv::get_direct_buffer_address()` and creates a slice over it (zero-copy, same pattern as `decodeShuffleBlock`). -4. Calls `read_ipc_compressed()` to decompress and decode into a `RecordBatch`. This allocates new memory for the decompressed data — the `DirectByteBuffer` can be safely reused afterward. -5. 
Returns the `RecordBatch` directly to the downstream native operator. - -No `FFI_ArrowArray`, `FFI_ArrowSchema`, `ArrowImporter`, or `CometVector` involved. - -Implements `on_close` for cleanup (releasing the JNI `GlobalRef`), matching the `ScanExec` pattern. - -#### `ShuffleScan` Protobuf Message - -Location: `native/proto/src/proto/operator.proto` - -New message alongside existing `Scan`: - -```protobuf -message ShuffleScan { - repeated spark.spark_expression.DataType fields = 1; - string source = 2; // Informational label (e.g., "CometShuffleExchangeExec [id=5]") -} -``` - -The `Operator` message gains a new `shuffle_scan` field in its oneof. - -### Modified JVM Components - -#### `CometExchangeSink` / `CometExecRule` - -The decision to use `ShuffleScan` vs `Scan` is made when `CometNativeExec` is constructed (not during the bottom-up conversion pass). At that point, the operator tree is already converted: `CometExecRule.convertBlock()` wraps a contiguous group of native operators into `CometNativeExec` and serializes the protobuf plan. The children (including `CometSinkPlaceHolder` wrapping shuffle exchanges) are already known. So the check is: when serializing a `CometSinkPlaceHolder` whose `originalPlan` is a `CometShuffleExchangeExec` with `shuffleType == CometNativeShuffle`, and the config flag is enabled, emit `ShuffleScan` instead of `Scan`. - -Conditions for `ShuffleScan`: - -1. Shuffle type is `CometNativeShuffle` -2. The sink is inside a `CometNativeExec` block (always true at serialization time — this is where sinks get serialized) -3. 
Config `spark.comet.shuffle.directRead.enabled` is true (default: true) - -#### `CometNativeExec` (operators.scala) - -When collecting input RDDs and creating iterators, distinguish the two cases: - -- `ShuffleScan` input: Wrap the shuffle RDD's `Iterator[ColumnarBatch]` stream in `CometShuffleBlockIterator` — but note that `CometShuffleBlockIterator` wraps the raw `InputStream` from shuffle blocks, not decoded `ColumnarBatch`. This means the RDD must provide the raw shuffle `InputStream` rather than going through `NativeBatchDecoderIterator`. The `CometShuffledBatchRDD` / `CometBlockStoreShuffleReader` needs a mode where it yields raw `InputStream` objects per block instead of decoded batches. -- `Scan` input: Wrap in `CometBatchIterator` (existing behavior) - -#### `CometExecIterator` — JNI Input Contract - -Currently `CometExecIterator` wraps all inputs as `CometBatchIterator` and passes them to `Native.createPlan()` as `Array[CometBatchIterator]`. To support `CometShuffleBlockIterator`: - -- Change the JNI parameter from `Array[CometBatchIterator]` to `Array[Object]`. On the native side in `createPlan`, the planner already knows from the protobuf whether each input is a `Scan` or `ShuffleScan`, so it knows which JNI methods to call on each `GlobalRef` — no type checking needed at runtime. -- `CometExecIterator` populates the array with either `CometBatchIterator` or `CometShuffleBlockIterator` based on whether the corresponding leaf in the protobuf plan is `Scan` or `ShuffleScan`. - -### Native Planner Changes - -In `planner.rs`, handle the `ShuffleScan` protobuf variant: - -- Consume an input from `inputs.remove(0)` (same pattern as `Scan`) -- Create `ShuffleScanExec` instead of `ScanExec` -- The `GlobalRef` points to a `CometShuffleBlockIterator` Java object - -## Fallback Behavior - -The new path is used only when all conditions above are met. Otherwise, the existing path is used unchanged. 
The most common fallback case is a shuffle whose output is consumed by a non-native Spark operator (e.g., `collect()`, or an unsupported operator), where the JVM needs a materialized `ColumnarBatch`. - -## Configuration - -| Config | Default | Description | -|--------|---------|-------------| -| `spark.comet.shuffle.directRead.enabled` | `true` | Use direct native read path for native shuffle when downstream operator is native | - -## Error Handling - -- `ShuffleScanExec` reuses `read_ipc_compressed()`, which handles corrupt data and unsupported codecs. -- JNI errors from `CometShuffleBlockIterator` (stream closed, EOF, I/O errors) propagate through the existing `try_unwrap_or_throw` pattern. -- If the JVM iterator throws, the exception surfaces as a Rust error and propagates through DataFusion's error handling. -- Empty batches (zero rows): `read_ipc_compressed()` calls `reader.next().unwrap()` which panics if the stream contains no batches. The shuffle writer never writes zero-row blocks (guarded by `if batch.num_rows() == 0 { return Ok(0) }` in `ShuffleBlockWriter.write_batch`), so this case does not arise. - -## Metrics - -`ShuffleScanExec` tracks and reports: - -- `decodeTime`: Time spent in `read_ipc_compressed()` (decompression + IPC decode). Same metric as `NativeBatchDecoderIterator` reports today. -- Shuffle read metrics (`recordsRead`, `bytesRead`) continue to be reported by `CometBlockStoreShuffleReader` and the `ShuffleBlockFetcherIterator`, which are upstream of the new code and unchanged. - -## Testing - -- Existing shuffle tests (`CometShuffleSuite`) run with the config defaulting to true, automatically covering the new path. -- Add a test that runs the same queries with the config flag on and off, asserting identical results. -- Add a Rust unit test for `ShuffleScanExec` with pre-built compressed IPC blocks (no JNI), using the `TEST_EXEC_CONTEXT_ID` pattern from `ScanExec` tests. 
From 33c2f11a00e69b41133edb8e1c7b9c495fd607ab Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Wed, 18 Mar 2026 18:23:24 -0600 Subject: [PATCH 15/33] Remove doc --- .../plans/2026-03-18-shuffle-direct-read.md | 1011 ----------------- 1 file changed, 1011 deletions(-) delete mode 100644 docs/superpowers/plans/2026-03-18-shuffle-direct-read.md diff --git a/docs/superpowers/plans/2026-03-18-shuffle-direct-read.md b/docs/superpowers/plans/2026-03-18-shuffle-direct-read.md deleted file mode 100644 index 647f122cc4..0000000000 --- a/docs/superpowers/plans/2026-03-18-shuffle-direct-read.md +++ /dev/null @@ -1,1011 +0,0 @@ -# Shuffle Direct Read Implementation Plan - -> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking. - -**Goal:** Eliminate double Arrow FFI crossing at shuffle boundaries by having native code consume compressed IPC blocks directly from JVM-provided byte buffers. - -**Architecture:** A new `ShuffleScanExec` Rust operator pulls raw compressed bytes from a JVM `CometShuffleBlockIterator` via JNI, decompresses and decodes them in native code, and feeds `RecordBatch` directly into the execution plan. This bypasses the current path where data is decoded to JVM `ColumnarBatch` (FFI export), then re-exported back to native (FFI import). 
- -**Tech Stack:** Scala, Java, Rust, Protobuf, JNI, Arrow IPC - -**Spec:** `docs/superpowers/specs/2026-03-18-shuffle-direct-read-design.md` - ---- - -### Task 1: Add config flag - -**Files:** -- Modify: `common/src/main/scala/org/apache/comet/CometConf.scala` - -- [ ] **Step 1: Add the config entry** - -Find the existing shuffle config entries (search for `COMET_EXEC_SHUFFLE_ENABLED`) and add nearby: - -```scala -val COMET_SHUFFLE_DIRECT_READ_ENABLED: ConfigEntry[Boolean] = - conf("spark.comet.shuffle.directRead.enabled") - .category(CATEGORY_EXEC) - .doc( - "When enabled, native operators that consume shuffle output will read " + - "compressed shuffle blocks directly in native code, bypassing Arrow FFI. " + - "Only applies to native shuffle (not JVM columnar shuffle). " + - "Requires spark.comet.exec.shuffle.enabled to be true.") - .booleanConf - .createWithDefault(true) -``` - -- [ ] **Step 2: Verify it compiles** - -Run: `./mvnw compile -DskipTests -pl common` -Expected: BUILD SUCCESS - -- [ ] **Step 3: Commit** - -```bash -git add common/src/main/scala/org/apache/comet/CometConf.scala -git commit -m "feat: add spark.comet.shuffle.directRead.enabled config" -``` - ---- - -### Task 2: Add ShuffleScan protobuf message - -**Files:** -- Modify: `native/proto/src/proto/operator.proto` - -- [ ] **Step 1: Add ShuffleScan message** - -Add after the existing `Scan` message (after line 86): - -```protobuf -message ShuffleScan { - repeated spark.spark_expression.DataType fields = 1; - // Informational label for debug output (e.g., "CometShuffleExchangeExec [id=5]") - string source = 2; -} -``` - -- [ ] **Step 2: Add shuffle_scan to the Operator oneof** - -In the `oneof op_struct` block (lines 38-55), add after `csv_scan = 115`: - -```protobuf - ShuffleScan shuffle_scan = 116; -``` - -- [ ] **Step 3: Rebuild protobuf and verify** - -Run: `make core` -Expected: Successful build with generated protobuf code. 
- -- [ ] **Step 4: Commit** - -```bash -git add native/proto/src/proto/operator.proto -git commit -m "feat: add ShuffleScan protobuf message" -``` - ---- - -### Task 3: Create CometShuffleBlockIterator (Java) - -**Files:** -- Create: `spark/src/main/java/org/apache/comet/CometShuffleBlockIterator.java` - -- [ ] **Step 1: Create the class** - -```java -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.apache.comet; - -import java.io.Closeable; -import java.io.EOFException; -import java.io.IOException; -import java.io.InputStream; -import java.nio.ByteBuffer; -import java.nio.ByteOrder; -import java.nio.channels.Channels; -import java.nio.channels.ReadableByteChannel; - -/** - * Provides raw compressed shuffle blocks to native code via JNI. - * - *

Reads block headers (compressed length + field count) from a shuffle InputStream and loads - * the compressed body into a DirectByteBuffer. Native code pulls blocks by calling hasNext() - * and getBuffer(). - * - *

The DirectByteBuffer returned by getBuffer() is only valid until the next hasNext() call. - * Native code must fully consume it (via read_ipc_compressed which allocates new memory for - * the decompressed data) before pulling the next block. - */ -public class CometShuffleBlockIterator implements Closeable { - - private static final int INITIAL_BUFFER_SIZE = 128 * 1024; - - private final ReadableByteChannel channel; - private final InputStream inputStream; - private final ByteBuffer headerBuf = - ByteBuffer.allocate(16).order(ByteOrder.LITTLE_ENDIAN); - private ByteBuffer dataBuf = ByteBuffer.allocateDirect(INITIAL_BUFFER_SIZE); - private boolean closed = false; - private int currentBlockLength = 0; - - public CometShuffleBlockIterator(InputStream in) { - this.inputStream = in; - this.channel = Channels.newChannel(in); - } - - /** - * Reads the next block header and loads the compressed body into the internal buffer. - * Called by native code via JNI. - * - *

Header format: 8-byte compressedLength (includes field count but not itself) + - * 8-byte fieldCount (discarded, schema comes from protobuf). - * - * @return the compressed body length in bytes (codec prefix + compressed IPC), or -1 if EOF - */ - public int hasNext() throws IOException { - if (closed) { - return -1; - } - - // Read 16-byte header - headerBuf.clear(); - while (headerBuf.hasRemaining()) { - int bytesRead = channel.read(headerBuf); - if (bytesRead < 0) { - if (headerBuf.position() == 0) { - return -1; - } - throw new EOFException( - "Data corrupt: unexpected EOF while reading batch header"); - } - } - headerBuf.flip(); - long compressedLength = headerBuf.getLong(); - // Field count discarded - schema determined by ShuffleScan protobuf fields - headerBuf.getLong(); - - long bytesToRead = compressedLength - 8; - if (bytesToRead > Integer.MAX_VALUE) { - throw new IllegalStateException( - "Native shuffle block size of " + bytesToRead + " exceeds maximum of " - + Integer.MAX_VALUE + ". Try reducing shuffle batch size."); - } - - if (dataBuf.capacity() < bytesToRead) { - int newCapacity = (int) Math.min(bytesToRead * 2L, Integer.MAX_VALUE); - dataBuf = ByteBuffer.allocateDirect(newCapacity); - } - - dataBuf.clear(); - dataBuf.limit((int) bytesToRead); - while (dataBuf.hasRemaining()) { - int bytesRead = channel.read(dataBuf); - if (bytesRead < 0) { - throw new EOFException( - "Data corrupt: unexpected EOF while reading compressed batch"); - } - } - // Note: native side uses get_direct_buffer_address (base pointer) + currentBlockLength, - // not the buffer's position/limit. No flip needed. - - currentBlockLength = (int) bytesToRead; - return currentBlockLength; - } - - /** - * Returns the DirectByteBuffer containing the current block's compressed bytes - * (4-byte codec prefix + compressed IPC data). - * Called by native code via JNI. - */ - public ByteBuffer getBuffer() { - return dataBuf; - } - - /** - * Returns the length of the current block in bytes. 
- * Called by native code via JNI. - */ - public int getCurrentBlockLength() { - return currentBlockLength; - } - - @Override - public void close() throws IOException { - if (!closed) { - closed = true; - inputStream.close(); - if (dataBuf.capacity() > INITIAL_BUFFER_SIZE) { - dataBuf = ByteBuffer.allocateDirect(INITIAL_BUFFER_SIZE); - } - } - } -} -``` - -- [ ] **Step 2: Verify it compiles** - -Run: `./mvnw compile -DskipTests` -Expected: BUILD SUCCESS - -- [ ] **Step 3: Commit** - -```bash -git add spark/src/main/java/org/apache/comet/CometShuffleBlockIterator.java -git commit -m "feat: add CometShuffleBlockIterator for raw shuffle block access" -``` - ---- - -### Task 4: Add JNI bridge for CometShuffleBlockIterator (Rust) - -**Files:** -- Create: `native/core/src/jvm_bridge/shuffle_block_iterator.rs` -- Modify: `native/core/src/jvm_bridge/mod.rs` - -- [ ] **Step 1: Create the JNI bridge struct** - -Create `native/core/src/jvm_bridge/shuffle_block_iterator.rs`: - -```rust -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. 
- -use jni::signature::Primitive; -use jni::{ -    errors::Result as JniResult, -    objects::{JClass, JMethodID}, -    signature::ReturnType, -    JNIEnv, -}; - -/// JNI method IDs for `CometShuffleBlockIterator`. -#[allow(dead_code)] -pub struct CometShuffleBlockIterator<'a> { -    pub class: JClass<'a>, -    pub method_has_next: JMethodID, -    pub method_has_next_ret: ReturnType, -    pub method_get_buffer: JMethodID, -    pub method_get_buffer_ret: ReturnType, -    pub method_get_current_block_length: JMethodID, -    pub method_get_current_block_length_ret: ReturnType, -} - -impl<'a> CometShuffleBlockIterator<'a> { -    pub const JVM_CLASS: &'static str = "org/apache/comet/CometShuffleBlockIterator"; - -    pub fn new(env: &mut JNIEnv<'a>) -> JniResult<CometShuffleBlockIterator<'a>> { -        let class = env.find_class(Self::JVM_CLASS)?; - -        Ok(CometShuffleBlockIterator { -            class, -            method_has_next: env.get_method_id(Self::JVM_CLASS, "hasNext", "()I")?, -            method_has_next_ret: ReturnType::Primitive(Primitive::Int), -            method_get_buffer: env.get_method_id( -                Self::JVM_CLASS, -                "getBuffer", -                "()Ljava/nio/ByteBuffer;", -            )?, -            method_get_buffer_ret: ReturnType::Object, -            method_get_current_block_length: env.get_method_id( -                Self::JVM_CLASS, -                "getCurrentBlockLength", -                "()I", -            )?, -            method_get_current_block_length_ret: ReturnType::Primitive(Primitive::Int), -        }) -    } -} -``` - -- [ ] **Step 2: Register in mod.rs** - -In `native/core/src/jvm_bridge/mod.rs`: - -Add `mod shuffle_block_iterator;` alongside the existing `mod batch_iterator;` (line 174). - -Add `use shuffle_block_iterator::CometShuffleBlockIterator as CometShuffleBlockIteratorBridge;` (to avoid name collision with the operator). 
- -Add a field to the `JVMClasses` struct (around line 206): -```rust -pub comet_shuffle_block_iterator: CometShuffleBlockIteratorBridge<'a>, -``` - -Initialize it in `JVMClasses::init` alongside the existing `comet_batch_iterator` init (around line 259): -```rust -comet_shuffle_block_iterator: CometShuffleBlockIteratorBridge::new(env).unwrap(), -``` - -- [ ] **Step 3: Add a `jni_call!` compatible accessor** - -Check how `comet_batch_iterator` is called in `scan.rs`. The `jni_call!` macro uses the field name from `JVMClasses`. Ensure `comet_shuffle_block_iterator` follows the same pattern. You may need to add a module in the `jni_bridge` macros — look at how `jni_call!(&mut env, comet_batch_iterator(iter).has_next() -> i32)` is defined and add equivalent patterns for `comet_shuffle_block_iterator`. - -Check `native/core/src/jvm_bridge/` for macro definitions (likely in a separate file or in `mod.rs`) that define the `jni_call!` dispatch for each class. - -- [ ] **Step 4: Verify it compiles** - -Run: `cd native && cargo build` -Expected: Successful build. - -- [ ] **Step 5: Commit** - -```bash -git add native/core/src/jvm_bridge/shuffle_block_iterator.rs -git add native/core/src/jvm_bridge/mod.rs -git commit -m "feat: add JNI bridge for CometShuffleBlockIterator" -``` - ---- - -### Task 5: Create ShuffleScanExec (Rust) - -**Files:** -- Create: `native/core/src/execution/operators/shuffle_scan.rs` -- Modify: `native/core/src/execution/operators/mod.rs` - -**Design decision — pre-pull pattern:** `ShuffleScanExec` MUST use the pre-pull pattern (same as `ScanExec`). The comment at `jni_api.rs:483-488` explains why: JNI calls cannot happen from within `poll_next` on tokio threads. So `ShuffleScanExec` stores a `batch: Arc<Mutex<Option<InputBatch>>>` and `get_next_batch()` is called from `pull_input_batches` before each `poll_next`. - -- [ ] **Step 1: Create shuffle_scan.rs** - -Use `scan.rs` as the template. 
The key differences: -- `get_next_batch` calls `hasNext()`/`getBuffer()`/`getCurrentBlockLength()` on `CometShuffleBlockIterator` instead of Arrow FFI methods on `CometBatchIterator` -- After getting the `DirectByteBuffer`, call `read_ipc_compressed()` to decode -- No `arrow_ffi_safe` flag, no selection vectors, no `copy_or_unpack_array` -- Track `decode_time` metric - -The core `get_next` method: - -```rust -fn get_next( -    exec_context_id: i64, -    iter: &JObject, -    data_types: &[DataType], -) -> Result<InputBatch, CometError> { -    let mut env = JVMClasses::get_env()?; - -    // Call hasNext() — returns block length or -1 for EOF -    let block_length: i32 = unsafe { -        jni_call!(&mut env, comet_shuffle_block_iterator(iter).has_next() -> i32)? -    }; - -    if block_length < 0 { -        return Ok(InputBatch::EOF); -    } - -    // Get the DirectByteBuffer -    let buffer: JByteBuffer = unsafe { -        jni_call!(&mut env, comet_shuffle_block_iterator(iter).get_buffer() -> JObject)? -    }.into(); - -    // Get raw pointer to the buffer data -    let raw_pointer = env.get_direct_buffer_address(&buffer)?; -    let length = block_length as usize; -    let slice: &[u8] = unsafe { std::slice::from_raw_parts(raw_pointer, length) }; - -    // Decompress and decode the IPC block -    let batch = read_ipc_compressed(slice)?; - -    // Convert RecordBatch columns to InputBatch -    let arrays: Vec<ArrayRef> = batch.columns().to_vec(); -    let num_rows = batch.num_rows(); - -    Ok(InputBatch::new(arrays, Some(num_rows))) -} -``` - -For the `ExecutionPlan` trait implementation, follow `ScanExec` closely: -- `schema()` returns schema built from `data_types` -- `execute()` returns a `ScanStream` (reuse the same stream type from `scan.rs`) -- The `ScanStream` checks `self.batch` mutex on each `poll_next`, takes the batch if available - -- [ ] **Step 2: Register the module** - -In `native/core/src/execution/operators/mod.rs`, add: - -```rust -mod shuffle_scan; -pub use shuffle_scan::ShuffleScanExec; -``` - -- [ ] **Step 3: Verify it compiles** - -Run: `cd native && cargo build` 
-Expected: Successful build. - -- [ ] **Step 4: Commit** - -```bash -git add native/core/src/execution/operators/shuffle_scan.rs -git add native/core/src/execution/operators/mod.rs -git commit -m "feat: add ShuffleScanExec native operator for direct shuffle read" -``` - ---- - -### Task 6: Wire ShuffleScanExec into the native planner and pre-pull - -**Files:** -- Modify: `native/core/src/execution/planner.rs` -- Modify: `native/core/src/execution/jni_api.rs` - -**Design decision — separate scan vectors:** The planner's `create_plan` currently returns `(Vec<ScanExec>, Arc<SparkPlan>)`. Change the return type to include shuffle scans: `(Vec<ScanExec>, Vec<ShuffleScanExec>, Arc<SparkPlan>)`. All intermediate operators pass both vectors through. `ExecutionContext` gets a new `shuffle_scans: Vec<ShuffleScanExec>` field, and `pull_input_batches` iterates both. - -- [ ] **Step 1: Update create_plan return type** - -In `planner.rs`, change the `create_plan` return type (line 915): - -```rust -) -> Result<(Vec<ScanExec>, Vec<ShuffleScanExec>, Arc<SparkPlan>), ExecutionError> -``` - -Update every match arm that calls `create_plan` recursively or returns results: -- Single-child operators (Filter, Project, Sort, etc.): destructure as `let (scans, shuffle_scans, child) = ...` and pass both through -- Multi-child operators (joins via `parse_join_parameters`): concatenate both scan vectors from left and right children -- `Scan` arm: returns `(vec![scan.clone()], vec![], ...)` -- Add `ShuffleScan` arm (see step 2) - -This is a mechanical change across many match arms. Each `Ok((scans, ...))` becomes `Ok((scans, shuffle_scans, ...))`. - -Also update `parse_join_parameters` return type similarly. 
- -- [ ] **Step 2: Add ShuffleScan match arm** - -```rust -OpStruct::ShuffleScan(scan) => { -    let data_types = scan.fields.iter().map(to_arrow_datatype).collect_vec(); - -    if self.exec_context_id != TEST_EXEC_CONTEXT_ID && inputs.is_empty() { -        return Err(GeneralError("No input for shuffle scan".to_string())); -    } - -    let input_source = -        if self.exec_context_id == TEST_EXEC_CONTEXT_ID && inputs.is_empty() { -            None -        } else { -            Some(inputs.remove(0)) -        }; - -    let shuffle_scan = ShuffleScanExec::new( -        self.exec_context_id, -        input_source, -        &scan.source, -        data_types, -    )?; - -    Ok(( -        vec![], -        vec![shuffle_scan.clone()], -        Arc::new(SparkPlan::new(spark_plan.plan_id, Arc::new(shuffle_scan), vec![])), -    )) -} -``` - -- [ ] **Step 3: Update ExecutionContext and pull_input_batches** - -In `jni_api.rs`: - -Add `shuffle_scans: Vec<ShuffleScanExec>` field to `ExecutionContext` struct (after `scans` on line 153). Initialize as `shuffle_scans: vec![]` in the constructor (line 313). - -Where `create_plan` results are stored (line 542-550): - -```rust -let (scans, shuffle_scans, root_op) = planner.create_plan(...)?; -exec_context.scans = scans; -exec_context.shuffle_scans = shuffle_scans; -``` - -Update `pull_input_batches` (line 490): - -```rust -fn pull_input_batches(exec_context: &mut ExecutionContext) -> Result<(), CometError> { -    exec_context.scans.iter_mut().try_for_each(|scan| { -        scan.get_next_batch()?; -        Ok::<(), CometError>(()) -    })?; -    exec_context.shuffle_scans.iter_mut().try_for_each(|scan| { -        scan.get_next_batch()?; -        Ok::<(), CometError>(()) -    }) -} -``` - -Also update the `exec_context.scans.is_empty()` check (line 563) to also check `shuffle_scans`: - -```rust -if exec_context.scans.is_empty() && exec_context.shuffle_scans.is_empty() { -``` - -- [ ] **Step 4: Verify it compiles** - -Run: `cd native && cargo build` -Expected: Successful build. 
- -- [ ] **Step 5: Commit** - -```bash -git add native/core/src/execution/planner.rs -git add native/core/src/execution/jni_api.rs -git commit -m "feat: wire ShuffleScanExec into planner and pre-pull mechanism" -``` - ---- - -### Task 7: Emit ShuffleScan from JVM serde - -**Files:** -- Modify: `spark/src/main/scala/org/apache/comet/serde/operator/CometSink.scala` - -The `CometExchangeSink.convert()` receives the outer operator (e.g., `ShuffleQueryStageExec`) not the inner `CometShuffleExchangeExec`. We must unwrap to check `shuffleType`. - -- [ ] **Step 1: Override convert in CometExchangeSink** - -Replace the `CometExchangeSink` object (lines 87-100) with: - -```scala -object CometExchangeSink extends CometSink[SparkPlan] { - - override def isFfiSafe: Boolean = true - - override def convert( - op: SparkPlan, - builder: Operator.Builder, - childOp: OperatorOuterClass.Operator*): Option[OperatorOuterClass.Operator] = { - if (shouldUseShuffleScan(op)) { - convertToShuffleScan(op, builder) - } else { - super.convert(op, builder, childOp: _*) - } - } - - private def shouldUseShuffleScan(op: SparkPlan): Boolean = { - if (!CometConf.COMET_SHUFFLE_DIRECT_READ_ENABLED.get()) return false - - // Extract the CometShuffleExchangeExec from the wrapper - val shuffleExec = op match { - case ShuffleQueryStageExec(_, s: CometShuffleExchangeExec, _) => Some(s) - case ShuffleQueryStageExec(_, ReusedExchangeExec(_, s: CometShuffleExchangeExec), _) => - Some(s) - case s: CometShuffleExchangeExec => Some(s) - case _ => None - } - - shuffleExec.exists(_.shuffleType == CometNativeShuffle) - } - - private def convertToShuffleScan( - op: SparkPlan, - builder: Operator.Builder): Option[OperatorOuterClass.Operator] = { - val supportedTypes = - op.output.forall(a => supportedDataType(a.dataType, allowComplex = true)) - - if (!supportedTypes) { - withInfo(op, "Unsupported data type for shuffle direct read") - return None - } - - val scanBuilder = OperatorOuterClass.ShuffleScan.newBuilder() - 
val source = op.simpleStringWithNodeId() - if (source.isEmpty) { - scanBuilder.setSource(op.getClass.getSimpleName) - } else { - scanBuilder.setSource(source) - } - - val scanTypes = op.output.flatMap { attr => - serializeDataType(attr.dataType) - } - - if (scanTypes.length == op.output.length) { - scanBuilder.addAllFields(scanTypes.asJava) - builder.clearChildren() - Some(builder.setShuffleScan(scanBuilder).build()) - } else { - withInfo(op, "unsupported data types for shuffle direct read") - // Fall back to regular Scan - None - } - } - - override def createExec(nativeOp: Operator, op: SparkPlan): CometNativeExec = - CometSinkPlaceHolder(nativeOp, op, op) -} -``` - -Add necessary imports at the top of the file: -```scala -import org.apache.spark.sql.comet.execution.shuffle.{CometNativeShuffle, CometShuffleExchangeExec} -import org.apache.spark.sql.execution.exchange.{ReusedExchangeExec, ShuffleExchangeExec} -import org.apache.spark.sql.execution.adaptive.ShuffleQueryStageExec -import org.apache.comet.CometConf -``` - -- [ ] **Step 2: Verify it compiles** - -Run: `./mvnw compile -DskipTests` -Expected: BUILD SUCCESS - -- [ ] **Step 3: Commit** - -```bash -git add spark/src/main/scala/org/apache/comet/serde/operator/CometSink.scala -git commit -m "feat: emit ShuffleScan protobuf for native shuffle with direct read" -``` - ---- - -### Task 8: Wire CometShuffleBlockIterator into JVM execution path - -**Files:** -- Modify: `spark/src/main/scala/org/apache/comet/Native.scala` -- Modify: `spark/src/main/scala/org/apache/comet/CometExecIterator.scala` -- Modify: `spark/src/main/scala/org/apache/spark/sql/comet/CometExecRDD.scala` -- Modify: `spark/src/main/scala/org/apache/spark/sql/comet/operators.scala` -- Modify: `spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometBlockStoreShuffleReader.scala` - -This task connects the JVM plumbing so that `ShuffleScan` inputs get `CometShuffleBlockIterator` (wrapping raw `InputStream`) instead of 
`CometBatchIterator` (wrapping decoded `ColumnarBatch`). - -**Key insight**: Currently all inputs flow through `RDD[ColumnarBatch]`. For shuffle direct read, we need the raw `InputStream` before decoding. The approach: add a parallel input channel for raw shuffle streams alongside the existing `ColumnarBatch` inputs. - -- [ ] **Step 1: Change Native.scala createPlan signature** - -In `spark/src/main/scala/org/apache/comet/Native.scala` (line 57), change: - -```scala -iterators: Array[CometBatchIterator], -``` -to: -```scala -iterators: Array[Object], -``` - -The JNI side (`jni_api.rs:190`) already uses `JObjectArray`, so no Rust changes needed. - -- [ ] **Step 2: Add shuffle stream inputs to CometExecIterator** - -In `spark/src/main/scala/org/apache/comet/CometExecIterator.scala`, add a parameter for shuffle block iterators that should be used instead of regular batch iterators at specific input positions: - -```scala -class CometExecIterator( - val id: Long, - inputs: Seq[Iterator[ColumnarBatch]], - numOutputCols: Int, - protobufQueryPlan: Array[Byte], - nativeMetrics: CometMetricNode, - numParts: Int, - partitionIndex: Int, - broadcastedHadoopConfForEncryption: Option[Broadcast[SerializableConfiguration]] = None, - encryptedFilePaths: Seq[String] = Seq.empty, - shuffleBlockIterators: Map[Int, CometShuffleBlockIterator] = Map.empty) -``` - -Replace the `cometBatchIterators` construction (lines 81-83): - -```scala -private val nativeIterators: Array[Object] = { - val result = new Array[Object](inputs.size) - inputs.zipWithIndex.foreach { case (iterator, idx) => - result(idx) = shuffleBlockIterators.getOrElse( - idx, - new CometBatchIterator(iterator, nativeUtil)) - } - result -} -``` - -Change `nativeLib.createPlan(id, cometBatchIterators, ...)` (line 109) to use `nativeIterators`. 
- -In the `close()` method, also close `CometShuffleBlockIterator` instances: -```scala -shuffleBlockIterators.values.foreach { iter => - try { iter.close() } catch { case _: Exception => } -} -``` - -- [ ] **Step 3: Add shuffle stream support to CometExecRDD** - -In `spark/src/main/scala/org/apache/spark/sql/comet/CometExecRDD.scala`, add a parameter to carry shuffle block iterator factories: - -```scala -private[spark] class CometExecRDD( - sc: SparkContext, - var inputRDDs: Seq[RDD[ColumnarBatch]], - ... - encryptedFilePaths: Seq[String] = Seq.empty, - shuffleBlockIteratorFactories: Map[Int, (TaskContext, Partition) => CometShuffleBlockIterator] = Map.empty) -``` - -In the `compute` method (line 112), pass them to `CometExecIterator`: - -```scala -// Create shuffle block iterators for this partition -val shuffleBlockIters = shuffleBlockIteratorFactories.map { case (idx, factory) => - idx -> factory(context, partition.inputPartitions(idx)) -} - -val it = new CometExecIterator( - CometExec.newIterId, - inputs, - numOutputCols, - actualPlan, - nativeMetrics, - numPartitions, - partition.index, - broadcastedHadoopConfForEncryption, - encryptedFilePaths, - shuffleBlockIters) -``` - -- [ ] **Step 4: Identify ShuffleScan inputs in operators.scala** - -In `spark/src/main/scala/org/apache/spark/sql/comet/operators.scala`, in `CometNativeExec.doExecuteColumnar` (around line 480): - -After `foreachUntilCometInput(this)(sparkPlans += _)`, determine which inputs correspond to `ShuffleScan` operators. 
Parse the serialized protobuf plan to find `ShuffleScan` leaf positions: - -```scala -import org.apache.comet.serde.OperatorOuterClass - -// Find which input indices correspond to ShuffleScan operators -val shuffleScanIndices: Set[Int] = { - val plan = OperatorOuterClass.Operator.parseFrom(serializedPlanCopy) - var scanIndex = 0 - val indices = scala.collection.mutable.Set.empty[Int] - def walk(op: OperatorOuterClass.Operator): Unit = { - if (op.hasShuffleScan) { - indices += scanIndex - scanIndex += 1 - } else if (op.hasScan) { - scanIndex += 1 - } else { - // Recurse into children in order - (0 until op.getChildrenCount).foreach(i => walk(op.getChildren(i))) - } - } - walk(plan) - indices.toSet -} -``` - -Then in the `sparkPlans.zipWithIndex.foreach` loop (line 523), for plans at shuffle scan indices, create a factory that produces `CometShuffleBlockIterator`: - -```scala -val shuffleBlockIteratorFactories = scala.collection.mutable.Map.empty[Int, (TaskContext, Partition) => CometShuffleBlockIterator] - -sparkPlans.zipWithIndex.foreach { case (plan, idx) => - plan match { - // ... existing cases ... - case _ if shuffleScanIndices.contains(inputIndexForPlan(idx)) => - // Still add the RDD for partition tracking, but also register - // a factory for the raw InputStream - val rdd = plan.executeColumnar() - inputs += rdd - // The factory creates a CometShuffleBlockIterator from the raw shuffle stream - // We need to get the raw InputStream - see Step 5 - shuffleBlockIteratorFactories(inputs.size - 1) = ... - // ... remaining cases ... - } -} -``` - -The tricky part is getting the raw `InputStream` from the shuffle read. See Step 5. - -- [ ] **Step 5: Add raw InputStream mode to CometBlockStoreShuffleReader** - -In `spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometBlockStoreShuffleReader.scala`: - -The current `read()` method creates `NativeBatchDecoderIterator` which decodes blocks. 
For direct read, we need a mode that yields the raw `InputStream` wrapped in `CometShuffleBlockIterator`. - -Add a method: - -```scala -def readRawStreams(): Iterator[CometShuffleBlockIterator] = { - fetchIterator.map { case (_, inputStream) => - new CometShuffleBlockIterator(inputStream) - } -} -``` - -The challenge is that `CometShuffledBatchRDD` calls `reader.read()` which returns `Iterator[Product2[Int, ColumnarBatch]]`. For the direct read path, we need a different RDD that calls `readRawStreams()` instead. - -**Approach**: Create `CometShuffledRawStreamRDD` — a simple RDD that wraps the shuffle reader and yields `CometShuffleBlockIterator` objects per block. Then in `operators.scala`, instead of using the ColumnarBatch RDD, create a `CometShuffledRawStreamRDD` and pass its iterator-producing factory to `CometExecRDD`. - -Alternatively, since `CometShuffleBlockIterator` wraps a single `InputStream` that may contain multiple blocks, and `fetchIterator` yields one `InputStream` per shuffle block, the simplest approach is to **concatenate all InputStreams into one** per partition: - -```scala -def readAsRawStream(): InputStream = { - val streams = fetchIterator.map(_._2) - new SequenceInputStream(java.util.Collections.enumeration( - streams.toList.asJava)) -} -``` - -Then in the factory: `(ctx, part) => new CometShuffleBlockIterator(reader.readAsRawStream())` - -But the reader is created per-partition in `CometShuffledBatchRDD.compute()`. The factory approach means the reader creation must be deferred. - -**Simplest concrete approach**: Instead of a factory, create a new RDD `CometShuffledRawRDD` that returns `Iterator[CometShuffleBlockIterator]`. Pass this as a separate input alongside the regular `ColumnarBatch` inputs: - -```scala -// In CometExecRDD, add: -shuffleRawInputRDDs: Seq[(Int, RDD[CometShuffleBlockIterator])] -``` - -In `compute`, create iterators from these RDDs and pass them to `CometExecIterator` via the `shuffleBlockIterators` map. 
- -This is the most invasive part of the implementation. The exact approach should be determined by reading the code at implementation time, as there are multiple valid paths. The key constraint: the raw `InputStream` from `fetchIterator` must reach `CometShuffleBlockIterator` without going through `NativeBatchDecoderIterator`. - -- [ ] **Step 6: Verify it compiles** - -Run: `./mvnw compile -DskipTests` -Expected: BUILD SUCCESS - -- [ ] **Step 7: Commit** - -```bash -git add spark/src/main/scala/org/apache/comet/Native.scala -git add spark/src/main/scala/org/apache/comet/CometExecIterator.scala -git add spark/src/main/scala/org/apache/spark/sql/comet/CometExecRDD.scala -git add spark/src/main/scala/org/apache/spark/sql/comet/operators.scala -git add spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometBlockStoreShuffleReader.scala -git commit -m "feat: wire CometShuffleBlockIterator into JVM execution path" -``` - ---- - -### Task 9: End-to-end testing - -**Files:** -- Modify: Appropriate test suite (find the right suite by searching for existing shuffle tests) - -- [ ] **Step 1: Build everything** - -Run: `make` -Expected: Successful build of both native and JVM. - -- [ ] **Step 2: Run existing shuffle tests** - -Run: `./mvnw test -Dsuites="org.apache.comet.exec.CometShuffleSuite"` -Expected: All existing tests pass (they now use the new direct read path by default). - -If tests fail, debug by setting `spark.comet.shuffle.directRead.enabled=false` to confirm the old path still works, then investigate the new path. 
- -- [ ] **Step 3: Add comparison test** - -Add a test that runs the same queries with direct read enabled and disabled: - -```scala -test("shuffle direct read produces same results as FFI path") { - Seq(true, false).foreach { directRead => - withSQLConf( - CometConf.COMET_SHUFFLE_DIRECT_READ_ENABLED.key -> directRead.toString) { - val df = spark.range(1000) - .selectExpr("id", "id % 10 as key", "cast(id as string) as value") - .repartition(4, col("key")) - .groupBy("key") - .agg(sum("id").as("total"), count("value").as("cnt")) - .orderBy("key") - checkSparkAnswer(df) - } - } -} -``` - -- [ ] **Step 4: Add Rust unit test for ShuffleScanExec** - -In `native/core/src/execution/operators/shuffle_scan.rs`, add a `#[cfg(test)]` module: - -```rust -#[cfg(test)] -mod tests { - use super::*; - use crate::execution::shuffle::codec::{CompressionCodec, ShuffleBlockWriter}; - use arrow::array::{Int32Array, StringArray}; - use arrow::datatypes::{Field, Schema}; - use arrow::record_batch::RecordBatch; - use std::io::Cursor; - use std::sync::Arc; - - #[test] - fn test_read_compressed_ipc_block() { - // Create a test RecordBatch - let schema = Arc::new(Schema::new(vec![ - Field::new("id", DataType::Int32, false), - Field::new("name", DataType::Utf8, true), - ])); - let batch = RecordBatch::try_new( - schema.clone(), - vec![ - Arc::new(Int32Array::from(vec![1, 2, 3])), - Arc::new(StringArray::from(vec!["a", "b", "c"])), - ], - ).unwrap(); - - // Write it as compressed IPC using ShuffleBlockWriter - let writer = ShuffleBlockWriter::try_new( - &batch.schema(), CompressionCodec::Zstd(1) - ).unwrap(); - let mut buf = Cursor::new(Vec::new()); - let ipc_time = datafusion::physical_plan::metrics::Time::new(); - writer.write_batch(&batch, &mut buf, &ipc_time).unwrap(); - - // Read back the body (skip the 16-byte header) - let bytes = buf.into_inner(); - let body = &bytes[16..]; // Skip compressed_length(8) + field_count(8) - - // Decode using read_ipc_compressed - let decoded = 
read_ipc_compressed(body).unwrap(); - assert_eq!(decoded.num_rows(), 3); - assert_eq!(decoded.num_columns(), 2); - } -} -``` - -- [ ] **Step 5: Run all tests** - -Run: `make test` - -- [ ] **Step 6: Run clippy** - -Run: `cd native && cargo clippy --all-targets --workspace -- -D warnings` -Expected: No warnings. - -- [ ] **Step 7: Format** - -Run: `make format` - -- [ ] **Step 8: Commit** - -```bash -git add -A -git commit -m "test: add shuffle direct read tests" -``` - ---- - -## Implementation Notes - -### Task 8 is the hardest - -The core challenge is routing raw `InputStream` from Spark's shuffle infrastructure through to `CometShuffleBlockIterator` without going through the decode path. The current RDD pipeline (`CometShuffledBatchRDD` → `CometBlockStoreShuffleReader.read()` → `NativeBatchDecoderIterator`) always decodes. You need to intercept before `NativeBatchDecoderIterator` is created. - -The most surgical approach: in `CometBlockStoreShuffleReader`, add a `readRaw()` method that returns the raw `InputStream` (or a `CometShuffleBlockIterator` wrapping it) instead of decoded batches. Then create a parallel RDD (`CometShuffledRawRDD`) that calls `readRaw()` in its `compute` method and pass it through to `CometExecIterator`. - -### Metrics - -`ShuffleScanExec` should track `decode_time` using DataFusion's `Time` metric. Register it in `ShuffleScanExec::new` via `MetricBuilder` following the pattern in `ScanExec`. - -### Order of tasks - -Tasks 1-7 can be done sequentially. Task 8 depends on all previous tasks. Task 9 validates everything. From 19cb04b0d34af995046942005111873e757b7518 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Wed, 18 Mar 2026 19:11:43 -0600 Subject: [PATCH 16/33] test: skip miri-incompatible zstd FFI test Skip test_read_compressed_ipc_block under Miri since it calls foreign zstd functions that Miri cannot execute. 
--- native/core/src/execution/operators/shuffle_scan.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/native/core/src/execution/operators/shuffle_scan.rs b/native/core/src/execution/operators/shuffle_scan.rs index 80c72a6d31..163fc9992a 100644 --- a/native/core/src/execution/operators/shuffle_scan.rs +++ b/native/core/src/execution/operators/shuffle_scan.rs @@ -354,6 +354,7 @@ mod tests { use crate::execution::shuffle::codec::read_ipc_compressed; #[test] + #[cfg_attr(miri, ignore)] // Miri cannot call FFI functions (zstd) fn test_read_compressed_ipc_block() { let schema = Arc::new(Schema::new(vec![ Field::new("id", DataType::Int32, false), From b67cd4058c0a4d465026a92e49ea6b597cd6208b Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Thu, 19 Mar 2026 09:40:56 -0600 Subject: [PATCH 17/33] feat: replace Arrow IPC with raw buffer format in shuffle write/read --- .../src/execution/operators/shuffle_scan.rs | 8 +- native/core/src/execution/shuffle/codec.rs | 361 ++++++++++++++++-- native/core/src/execution/shuffle/mod.rs | 2 +- 3 files changed, 338 insertions(+), 33 deletions(-) diff --git a/native/core/src/execution/operators/shuffle_scan.rs b/native/core/src/execution/operators/shuffle_scan.rs index 163fc9992a..b638567166 100644 --- a/native/core/src/execution/operators/shuffle_scan.rs +++ b/native/core/src/execution/operators/shuffle_scan.rs @@ -351,11 +351,11 @@ mod tests { use std::io::Cursor; use std::sync::Arc; - use crate::execution::shuffle::codec::read_ipc_compressed; + use crate::execution::shuffle::codec::read_shuffle_block; #[test] #[cfg_attr(miri, ignore)] // Miri cannot call FFI functions (zstd) - fn test_read_compressed_ipc_block() { + fn test_read_compressed_block() { let schema = Arc::new(Schema::new(vec![ Field::new("id", DataType::Int32, false), Field::new("name", DataType::Utf8, true), @@ -369,7 +369,7 @@ mod tests { ) .unwrap(); - // Write as compressed IPC + // Write as compressed raw batch let writer = 
ShuffleBlockWriter::try_new(&batch.schema(), CompressionCodec::Zstd(1)).unwrap(); let mut buf = Cursor::new(Vec::new()); @@ -380,7 +380,7 @@ mod tests { let bytes = buf.into_inner(); let body = &bytes[16..]; - let decoded = read_ipc_compressed(body).unwrap(); + let decoded = read_shuffle_block(body, &schema).unwrap(); assert_eq!(decoded.num_rows(), 3); assert_eq!(decoded.num_columns(), 2); diff --git a/native/core/src/execution/shuffle/codec.rs b/native/core/src/execution/shuffle/codec.rs index 33e6989d4c..003f82db23 100644 --- a/native/core/src/execution/shuffle/codec.rs +++ b/native/core/src/execution/shuffle/codec.rs @@ -16,17 +16,18 @@ // under the License. use crate::errors::{CometError, CometResult}; -use arrow::array::RecordBatch; -use arrow::datatypes::Schema; -use arrow::ipc::reader::StreamReader; -use arrow::ipc::writer::StreamWriter; +use arrow::array::{make_array, Array, ArrayRef, RecordBatch}; +use arrow::buffer::Buffer; +use arrow::compute::cast; +use arrow::datatypes::{DataType, Schema}; use bytes::Buf; use crc32fast::Hasher; use datafusion::common::DataFusionError; use datafusion::error::Result; use datafusion::physical_plan::metrics::Time; use simd_adler32::Adler32; -use std::io::{Cursor, Seek, SeekFrom, Write}; +use std::io::{Cursor, Read, Seek, SeekFrom, Write}; +use std::sync::Arc; #[derive(Debug, Clone)] pub enum CompressionCodec { @@ -42,6 +43,73 @@ pub struct ShuffleBlockWriter { header_bytes: Vec, } +/// Recursively writes raw Arrow ArrayData buffers to the given writer. 
+fn write_array_data(data: &arrow::array::ArrayData, writer: &mut W) -> Result<()> { + debug_assert_eq!(data.offset(), 0, "shuffle arrays must have offset 0"); + + // Write null_count + let null_count = data.null_count() as u32; + writer.write_all(&null_count.to_le_bytes())?; + + // Write validity bitmap + if null_count > 0 { + if let Some(bitmap) = data.nulls() { + let bitmap_bytes = bitmap.buffer().as_slice(); + let len = bitmap_bytes.len() as u32; + writer.write_all(&len.to_le_bytes())?; + writer.write_all(bitmap_bytes)?; + } else { + writer.write_all(&0u32.to_le_bytes())?; + } + } else { + writer.write_all(&0u32.to_le_bytes())?; + } + + // Write buffers + let num_buffers = data.buffers().len() as u32; + writer.write_all(&num_buffers.to_le_bytes())?; + for buffer in data.buffers() { + let len: u32 = buffer.len().try_into().map_err(|_| { + DataFusionError::Execution(format!( + "Buffer length {} exceeds u32::MAX", + buffer.len() + )) + })?; + writer.write_all(&len.to_le_bytes())?; + writer.write_all(buffer.as_slice())?; + } + + // Write children + let num_children = data.child_data().len() as u32; + writer.write_all(&num_children.to_le_bytes())?; + for child in data.child_data() { + let child_num_rows = child.len() as u32; + writer.write_all(&child_num_rows.to_le_bytes())?; + write_array_data(child, writer)?; + } + + Ok(()) +} + +/// Writes a RecordBatch in raw buffer format. Dictionary arrays are cast to their value type. +fn write_raw_batch(batch: &RecordBatch, writer: &mut W) -> Result<()> { + let num_rows = batch.num_rows() as u32; + writer.write_all(&num_rows.to_le_bytes())?; + + for col in batch.columns() { + // Cast dictionary arrays to their value type + let col = match col.data_type() { + DataType::Dictionary(_, value_type) => { + cast(col.as_ref(), value_type.as_ref())? 
+ } + _ => Arc::clone(col), + }; + write_array_data(&col.to_data(), writer)?; + } + + Ok(()) +} + impl ShuffleBlockWriter { pub fn try_new(schema: &Schema, codec: CompressionCodec) -> Result { let header_bytes = Vec::with_capacity(20); @@ -71,7 +139,7 @@ impl ShuffleBlockWriter { }) } - /// Writes given record batch as Arrow IPC bytes into given writer. + /// Writes given record batch in raw buffer format into given writer. /// Returns number of bytes written. pub fn write_batch( &self, @@ -91,55 +159,46 @@ impl ShuffleBlockWriter { let output = match &self.codec { CompressionCodec::None => { - let mut arrow_writer = StreamWriter::try_new(output, &batch.schema())?; - arrow_writer.write(batch)?; - arrow_writer.finish()?; - arrow_writer.into_inner()? + write_raw_batch(batch, output)?; + output } CompressionCodec::Lz4Frame => { let mut wtr = lz4_flex::frame::FrameEncoder::new(output); - let mut arrow_writer = StreamWriter::try_new(&mut wtr, &batch.schema())?; - arrow_writer.write(batch)?; - arrow_writer.finish()?; + write_raw_batch(batch, &mut wtr)?; wtr.finish().map_err(|e| { DataFusionError::Execution(format!("lz4 compression error: {e}")) })? } CompressionCodec::Zstd(level) => { - let encoder = zstd::Encoder::new(output, *level)?; - let mut arrow_writer = StreamWriter::try_new(encoder, &batch.schema())?; - arrow_writer.write(batch)?; - arrow_writer.finish()?; - let zstd_encoder = arrow_writer.into_inner()?; - zstd_encoder.finish()? + let mut encoder = zstd::Encoder::new(output, *level)?; + write_raw_batch(batch, &mut encoder)?; + encoder.finish()? } CompressionCodec::Snappy => { let mut wtr = snap::write::FrameEncoder::new(output); - let mut arrow_writer = StreamWriter::try_new(&mut wtr, &batch.schema())?; - arrow_writer.write(batch)?; - arrow_writer.finish()?; + write_raw_batch(batch, &mut wtr)?; wtr.into_inner().map_err(|e| { DataFusionError::Execution(format!("snappy compression error: {e}")) })? 
} }; - // fill ipc length + // fill block length let end_pos = output.stream_position()?; - let ipc_length = end_pos - start_pos - 8; + let block_length = end_pos - start_pos - 8; let max_size = i32::MAX as u64; - if ipc_length > max_size { + if block_length > max_size { return Err(DataFusionError::Execution(format!( - "Shuffle block size {ipc_length} exceeds maximum size of {max_size}. \ + "Shuffle block size {block_length} exceeds maximum size of {max_size}. \ Try reducing batch size or increasing compression level" ))); } - // fill ipc length + // fill block length output.seek(SeekFrom::Start(start_pos))?; - output.write_all(&ipc_length.to_le_bytes())?; + output.write_all(&block_length.to_le_bytes())?; output.seek(SeekFrom::Start(end_pos))?; timer.stop(); @@ -148,7 +207,135 @@ impl ShuffleBlockWriter { } } +// --------------------------------------------------------------------------- +// Read-side helpers +// --------------------------------------------------------------------------- + +fn read_u32(cursor: &mut &[u8]) -> Result { + if cursor.len() < 4 { + return Err(DataFusionError::Execution( + "unexpected end of shuffle block data".to_string(), + )); + } + let (bytes, rest) = cursor.split_at(4); + *cursor = rest; + Ok(u32::from_le_bytes(bytes.try_into().unwrap())) +} + +fn read_bytes<'a>(cursor: &mut &'a [u8], len: usize) -> Result<&'a [u8]> { + if cursor.len() < len { + return Err(DataFusionError::Execution( + "unexpected end of shuffle block data".to_string(), + )); + } + let (bytes, rest) = cursor.split_at(len); + *cursor = rest; + Ok(bytes) +} + +/// Returns child data types for nested Arrow types. 
+fn get_child_types(data_type: &DataType) -> Vec { + match data_type { + DataType::List(field) + | DataType::LargeList(field) + | DataType::FixedSizeList(field, _) => { + vec![field.data_type().clone()] + } + DataType::Map(field, _) => { + // Map's single child is a struct with key/value fields + vec![field.data_type().clone()] + } + DataType::Struct(fields) => fields.iter().map(|f| f.data_type().clone()).collect(), + _ => vec![], + } +} + +/// Reconstructs ArrayData from raw buffer format (reverse of write_array_data). +fn read_array_data( + cursor: &mut &[u8], + data_type: &DataType, + num_rows: usize, +) -> Result { + let null_count = read_u32(cursor)? as usize; + + // Read validity bitmap + let bitmap_len = read_u32(cursor)? as usize; + let null_buffer = if bitmap_len > 0 { + let bytes = read_bytes(cursor, bitmap_len)?; + Some(Buffer::from(bytes)) + } else { + None + }; + + // Read buffers + let num_buffers = read_u32(cursor)? as usize; + let mut buffers = Vec::with_capacity(num_buffers); + for _ in 0..num_buffers { + let buf_len = read_u32(cursor)? as usize; + let bytes = read_bytes(cursor, buf_len)?; + buffers.push(Buffer::from(bytes)); + } + + // Read children + let num_children = read_u32(cursor)? as usize; + let child_types = get_child_types(data_type); + let mut child_data = Vec::with_capacity(num_children); + for i in 0..num_children { + let child_num_rows = read_u32(cursor)? 
as usize; + let child_type = child_types.get(i).ok_or_else(|| { + DataFusionError::Execution(format!( + "unexpected child index {i} for data type {data_type:?}" + )) + })?; + child_data.push(read_array_data(cursor, child_type, child_num_rows)?); + } + + // Build ArrayData without validation (data came from our own writer) + let mut builder = arrow::array::ArrayData::builder(data_type.clone()) + .len(num_rows) + .null_count(null_count); + + if let Some(nb) = null_buffer { + builder = builder.null_bit_buffer(Some(nb)); + } + + for buf in buffers { + builder = builder.add_buffer(buf); + } + + for child in child_data { + builder = builder.add_child_data(child); + } + + // SAFETY: data was written by write_array_data from valid Arrow arrays + Ok(unsafe { builder.build_unchecked() }) +} + +/// Read a raw batch from decompressed bytes, given the expected schema. +fn read_raw_batch(bytes: &[u8], schema: &Schema) -> Result { + let mut cursor = bytes; + + let num_rows = read_u32(&mut cursor)? as usize; + + let mut columns: Vec = Vec::with_capacity(schema.fields().len()); + for field in schema.fields() { + let array_data = read_array_data(&mut cursor, field.data_type(), num_rows)?; + columns.push(make_array(array_data)); + } + + let batch = RecordBatch::try_new(Arc::new(schema.clone()), columns)?; + Ok(batch) +} + +/// Reads and decompresses a shuffle block. The `bytes` slice starts at the codec tag +/// (after the 8-byte length and 8-byte field_count header that the JVM reads). +/// +/// This is kept temporarily for backward compatibility while callers are migrated. pub fn read_ipc_compressed(bytes: &[u8]) -> Result { + // NOTE: This function cannot work with the new raw format because it needs a schema. + // It is kept as a stub that delegates to the old IPC path for backward compatibility. + // Callers should migrate to read_shuffle_block(). 
+ use arrow::ipc::reader::StreamReader; match &bytes[0..4] { b"SNAP" => { let decoder = snap::read::FrameDecoder::new(&bytes[4..]); @@ -179,6 +366,40 @@ pub fn read_ipc_compressed(bytes: &[u8]) -> Result { } } +/// Reads and decompresses a shuffle block written in raw buffer format. +/// The `bytes` slice starts at the codec tag (after the 8-byte length and +/// 8-byte field_count header that the JVM reads). +pub fn read_shuffle_block(bytes: &[u8], schema: &Schema) -> Result { + match &bytes[0..4] { + b"SNAP" => { + let decoder = snap::read::FrameDecoder::new(&bytes[4..]); + let decompressed = read_all(decoder)?; + read_raw_batch(&decompressed, schema) + } + b"LZ4_" => { + let decoder = lz4_flex::frame::FrameDecoder::new(&bytes[4..]); + let decompressed = read_all(decoder)?; + read_raw_batch(&decompressed, schema) + } + b"ZSTD" => { + let decoder = zstd::Decoder::new(&bytes[4..])?; + let decompressed = read_all(decoder)?; + read_raw_batch(&decompressed, schema) + } + b"NONE" => read_raw_batch(&bytes[4..], schema), + other => Err(DataFusionError::Execution(format!( + "Failed to decode batch: invalid compression codec: {other:?}" + ))), + } +} + +/// Read all bytes from a reader into a Vec. +fn read_all(mut reader: R) -> Result> { + let mut buf = Vec::new(); + reader.read_to_end(&mut buf)?; + Ok(buf) +} + /// Checksum algorithms for writing IPC bytes. 
#[derive(Clone)] pub(crate) enum Checksum { @@ -237,3 +458,87 @@ impl Checksum { } } } + +#[cfg(test)] +mod tests { + use super::*; + use arrow::array::{Float64Array, Int32Array}; + use arrow::datatypes::{DataType, Field, Schema}; + use arrow::record_batch::RecordBatch; + use datafusion::physical_plan::metrics::Time; + use std::io::Cursor; + use std::sync::Arc; + + fn make_test_batch() -> (Arc, RecordBatch) { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Float64, false), + ])); + let batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(Int32Array::from(vec![Some(1), None, Some(3)])), + Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0])), + ], + ) + .unwrap(); + (schema, batch) + } + + fn roundtrip(codec: CompressionCodec) { + let (schema, batch) = make_test_batch(); + let writer = ShuffleBlockWriter::try_new(&schema, codec).unwrap(); + let mut buf = Cursor::new(Vec::new()); + let ipc_time = Time::new(); + writer.write_batch(&batch, &mut buf, &ipc_time).unwrap(); + + let bytes = buf.into_inner(); + // Skip 16-byte header: 8 compressed_length + 8 field_count + let body = &bytes[16..]; + + let decoded = read_shuffle_block(body, &schema).unwrap(); + assert_eq!(decoded.num_rows(), 3); + assert_eq!(decoded.num_columns(), 2); + + // Verify Int32 column (nullable) + let col0 = decoded + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(col0.value(0), 1); + assert!(col0.is_null(1)); + assert_eq!(col0.value(2), 3); + + // Verify Float64 column (non-nullable) + let col1 = decoded + .column(1) + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(col1.value(0), 1.0); + assert_eq!(col1.value(1), 2.0); + assert_eq!(col1.value(2), 3.0); + } + + #[test] + fn test_raw_roundtrip_primitives_none() { + roundtrip(CompressionCodec::None); + } + + #[test] + fn test_raw_roundtrip_primitives_lz4() { + roundtrip(CompressionCodec::Lz4Frame); + } + + #[test] + #[cfg_attr(miri, ignore)] + 
fn test_raw_roundtrip_primitives_zstd() { + roundtrip(CompressionCodec::Zstd(1)); + } + + #[test] + fn test_raw_roundtrip_primitives_snappy() { + roundtrip(CompressionCodec::Snappy); + } +} diff --git a/native/core/src/execution/shuffle/mod.rs b/native/core/src/execution/shuffle/mod.rs index 6018cff50f..dd3c9bf849 100644 --- a/native/core/src/execution/shuffle/mod.rs +++ b/native/core/src/execution/shuffle/mod.rs @@ -23,6 +23,6 @@ mod shuffle_writer; pub mod spark_unsafe; mod writers; -pub use codec::{read_ipc_compressed, CompressionCodec, ShuffleBlockWriter}; +pub use codec::{read_ipc_compressed, read_shuffle_block, CompressionCodec, ShuffleBlockWriter}; pub use comet_partitioning::CometPartitioning; pub use shuffle_writer::ShuffleWriterExec; From 0ccbe153115e45e8c919e1568c363523d933d552 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Thu, 19 Mar 2026 09:44:55 -0600 Subject: [PATCH 18/33] feat: replace Arrow IPC with raw buffer format in shuffle write/read Replace Arrow IPC StreamWriter/StreamReader with a lightweight raw buffer format that writes Arrow ArrayData buffers directly. The new format has minimal per-block overhead (~16 bytes per column vs ~200-800 bytes for IPC schema flatbuffers). The outer block header (compressed_length + field_count) is unchanged for JVM compatibility. 
Key changes: - write_array_data: recursively serializes ArrayData (validity + buffers + children) - read_array_data: reconstructs ArrayData from raw buffers using known schema - Dictionary arrays are cast to value type before writing - read_shuffle_block replaces read_ipc_compressed (takes schema parameter) - read_ipc_compressed retained temporarily for callers not yet migrated --- native/core/src/execution/shuffle/codec.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/native/core/src/execution/shuffle/codec.rs b/native/core/src/execution/shuffle/codec.rs index 003f82db23..2157342a5c 100644 --- a/native/core/src/execution/shuffle/codec.rs +++ b/native/core/src/execution/shuffle/codec.rs @@ -19,7 +19,8 @@ use crate::errors::{CometError, CometResult}; use arrow::array::{make_array, Array, ArrayRef, RecordBatch}; use arrow::buffer::Buffer; use arrow::compute::cast; -use arrow::datatypes::{DataType, Schema}; +use arrow::datatypes::DataType; +use arrow::datatypes::Schema; use bytes::Buf; use crc32fast::Hasher; use datafusion::common::DataFusionError; From e62943c56ea2877eb7f1387ba48434893e5b53bd Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Thu, 19 Mar 2026 09:46:47 -0600 Subject: [PATCH 19/33] test: add roundtrip tests for raw shuffle format covering all data types --- native/core/src/execution/shuffle/codec.rs | 289 ++++++++++++++++++++- 1 file changed, 287 insertions(+), 2 deletions(-) diff --git a/native/core/src/execution/shuffle/codec.rs b/native/core/src/execution/shuffle/codec.rs index 2157342a5c..fcfdf3ad51 100644 --- a/native/core/src/execution/shuffle/codec.rs +++ b/native/core/src/execution/shuffle/codec.rs @@ -463,8 +463,8 @@ impl Checksum { #[cfg(test)] mod tests { use super::*; - use arrow::array::{Float64Array, Int32Array}; - use arrow::datatypes::{DataType, Field, Schema}; + use arrow::array::*; + use arrow::datatypes::{DataType, Field, Int32Type, Schema}; use arrow::record_batch::RecordBatch; use 
datafusion::physical_plan::metrics::Time; use std::io::Cursor; @@ -542,4 +542,289 @@ mod tests { fn test_raw_roundtrip_primitives_snappy() { roundtrip(CompressionCodec::Snappy); } + + /// Generic roundtrip helper: writes a batch with ShuffleBlockWriter, + /// reads it back with read_shuffle_block, and asserts equality for all + /// four compression codecs. + fn roundtrip_test(schema: Arc, batch: &RecordBatch) { + let codecs = vec![ + CompressionCodec::None, + CompressionCodec::Lz4Frame, + CompressionCodec::Zstd(1), + CompressionCodec::Snappy, + ]; + for codec in codecs { + let writer = ShuffleBlockWriter::try_new(&schema, codec.clone()).unwrap(); + let mut buf = Cursor::new(Vec::new()); + let ipc_time = Time::new(); + writer.write_batch(batch, &mut buf, &ipc_time).unwrap(); + + let bytes = buf.into_inner(); + let body = &bytes[16..]; + + let decoded = read_shuffle_block(body, &schema).unwrap(); + assert_eq!(decoded.num_rows(), batch.num_rows()); + assert_eq!(decoded.num_columns(), batch.num_columns()); + for i in 0..batch.num_columns() { + assert_eq!( + batch.column(i).as_ref(), + decoded.column(i).as_ref(), + "column {i} mismatch with codec {:?}", + codec + ); + } + } + } + + #[test] + #[cfg_attr(miri, ignore)] + fn test_roundtrip_string_and_binary() { + let schema = Arc::new(Schema::new(vec![ + Field::new("s", DataType::Utf8, true), + Field::new("b", DataType::Binary, true), + ])); + let batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(StringArray::from(vec![ + Some("hello"), + None, + Some("world"), + Some(""), + ])), + Arc::new(BinaryArray::from(vec![ + Some(b"abc" as &[u8]), + Some(b"\x00\x01\x02"), + None, + Some(b""), + ])), + ], + ) + .unwrap(); + roundtrip_test(schema, &batch); + } + + #[test] + #[cfg_attr(miri, ignore)] + fn test_roundtrip_boolean_and_null() { + let schema = Arc::new(Schema::new(vec![ + Field::new("bool", DataType::Boolean, true), + Field::new("n", DataType::Null, true), + ])); + let batch = RecordBatch::try_new( + 
Arc::clone(&schema), + vec![ + Arc::new(BooleanArray::from(vec![ + Some(true), + None, + Some(false), + Some(true), + ])), + Arc::new(NullArray::new(4)), + ], + ) + .unwrap(); + roundtrip_test(schema, &batch); + } + + #[test] + #[cfg_attr(miri, ignore)] + fn test_roundtrip_decimal_date_timestamp() { + let schema = Arc::new(Schema::new(vec![ + Field::new("dec", DataType::Decimal128(18, 3), true), + Field::new("date", DataType::Date32, true), + Field::new( + "ts", + DataType::Timestamp(arrow::datatypes::TimeUnit::Microsecond, None), + true, + ), + ])); + let batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new( + Decimal128Array::from(vec![Some(12345_i128), None, Some(-99999)]) + .with_precision_and_scale(18, 3) + .unwrap(), + ), + Arc::new(Date32Array::from(vec![Some(18000), None, Some(19000)])), + Arc::new(TimestampMicrosecondArray::from(vec![ + Some(1_000_000), + None, + Some(2_000_000), + ])), + ], + ) + .unwrap(); + roundtrip_test(schema, &batch); + } + + #[test] + #[cfg_attr(miri, ignore)] + fn test_roundtrip_nested_types() { + // List with nulls at the list level + let list_field = Field::new_list("l", Field::new("item", DataType::Int32, true), true); + + // Struct with nulls + let struct_field = Field::new( + "st", + DataType::Struct( + vec![ + Field::new("x", DataType::Int32, true), + Field::new("y", DataType::Utf8, true), + ] + .into(), + ), + true, + ); + + // Map + let map_field = Field::new( + "m", + DataType::Map( + Arc::new(Field::new( + "entries", + DataType::Struct( + vec![ + Field::new("keys", DataType::Utf8, false), + Field::new("values", DataType::Int32, true), + ] + .into(), + ), + false, + )), + false, + ), + true, + ); + + let schema = Arc::new(Schema::new(vec![list_field, struct_field, map_field])); + + // Build List + let list_arr = { + let mut builder = ListBuilder::new(Int32Builder::new()); + builder.values().append_value(1); + builder.values().append_value(2); + builder.append(true); + builder.append(false); // null list 
+ builder.values().append_value(3); + builder.append(true); + builder.finish() + }; + + // Build Struct + let struct_arr = StructArray::from(vec![ + ( + Arc::new(Field::new("x", DataType::Int32, true)), + Arc::new(Int32Array::from(vec![Some(10), None, Some(30)])) as ArrayRef, + ), + ( + Arc::new(Field::new("y", DataType::Utf8, true)), + Arc::new(StringArray::from(vec![Some("a"), Some("b"), None])) as ArrayRef, + ), + ]); + // Apply null at row 1 + let struct_arr = StructArray::try_new( + struct_arr.fields().clone(), + struct_arr.columns().to_vec(), + Some(arrow::buffer::NullBuffer::from(vec![true, false, true])), + ) + .unwrap(); + + // Build Map + let map_arr = { + let key_builder = StringBuilder::new(); + let value_builder = Int32Builder::new(); + let mut builder = MapBuilder::new(None, key_builder, value_builder); + builder.keys().append_value("k1"); + builder.values().append_value(100); + builder.append(true).unwrap(); + builder.append(false).unwrap(); // null map entry + builder.keys().append_value("k2"); + builder.values().append_value(200); + builder.append(true).unwrap(); + builder.finish() + }; + + let batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(list_arr), + Arc::new(struct_arr), + Arc::new(map_arr), + ], + ) + .unwrap(); + + roundtrip_test(schema, &batch); + } + + #[test] + #[cfg_attr(miri, ignore)] + fn test_roundtrip_dictionary_cast() { + // Dictionary should be cast to plain Utf8 on write + let dict_schema = Arc::new(Schema::new(vec![Field::new( + "d", + DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)), + true, + )])); + + let keys = Int32Array::from(vec![Some(0), Some(1), None, Some(0)]); + let values = StringArray::from(vec!["foo", "bar"]); + let dict_arr = + DictionaryArray::::try_new(keys, Arc::new(values)).unwrap(); + + let batch = + RecordBatch::try_new(Arc::clone(&dict_schema), vec![Arc::new(dict_arr)]).unwrap(); + + // The writer casts dict to plain string, so the read schema must be Utf8 + 
let read_schema = Arc::new(Schema::new(vec![Field::new("d", DataType::Utf8, true)])); + + let codecs = vec![ + CompressionCodec::None, + CompressionCodec::Lz4Frame, + CompressionCodec::Zstd(1), + CompressionCodec::Snappy, + ]; + for codec in codecs { + let writer = ShuffleBlockWriter::try_new(&dict_schema, codec.clone()).unwrap(); + let mut buf = Cursor::new(Vec::new()); + let ipc_time = Time::new(); + writer.write_batch(&batch, &mut buf, &ipc_time).unwrap(); + + let bytes = buf.into_inner(); + let body = &bytes[16..]; + + let decoded = read_shuffle_block(body, &read_schema).unwrap(); + assert_eq!(decoded.num_rows(), 4); + + // Result should be a plain StringArray, not a DictionaryArray + let col = decoded + .column(0) + .as_any() + .downcast_ref::() + .expect("expected plain StringArray after dict cast"); + assert_eq!(col.value(0), "foo"); + assert_eq!(col.value(1), "bar"); + assert!(col.is_null(2)); + assert_eq!(col.value(3), "foo"); + } + } + + #[test] + fn test_empty_batch_returns_zero() { + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); + let batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![Arc::new(Int32Array::from(Vec::::new()))], + ) + .unwrap(); + assert_eq!(batch.num_rows(), 0); + + let writer = ShuffleBlockWriter::try_new(&schema, CompressionCodec::None).unwrap(); + let mut buf = Cursor::new(Vec::new()); + let ipc_time = Time::new(); + let bytes_written = writer.write_batch(&batch, &mut buf, &ipc_time).unwrap(); + assert_eq!(bytes_written, 0); + } } From ae4363d6f0b87bb76283f84f38841dcc5395d4ca Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Thu, 19 Mar 2026 09:50:18 -0600 Subject: [PATCH 20/33] refactor: update all Rust callers from read_ipc_compressed to read_shuffle_block --- native/core/src/execution/jni_api.rs | 3 +- .../src/execution/operators/shuffle_scan.rs | 12 +++---- native/core/src/execution/shuffle/mod.rs | 2 +- .../src/execution/shuffle/shuffle_writer.rs | 35 ++++++++++--------- 4 files 
changed, 27 insertions(+), 25 deletions(-) diff --git a/native/core/src/execution/jni_api.rs b/native/core/src/execution/jni_api.rs index d20cf128b5..3914901229 100644 --- a/native/core/src/execution/jni_api.rs +++ b/native/core/src/execution/jni_api.rs @@ -83,7 +83,8 @@ use crate::execution::memory_pools::{ create_memory_pool, handle_task_shared_pool_release, parse_memory_pool_config, MemoryPoolConfig, }; use crate::execution::operators::{ScanExec, ShuffleScanExec}; -use crate::execution::shuffle::{read_ipc_compressed, CompressionCodec}; +use crate::execution::shuffle::codec::read_ipc_compressed; +use crate::execution::shuffle::CompressionCodec; use crate::execution::spark_plan::SparkPlan; use crate::execution::tracing::{log_memory_usage, trace_begin, trace_end, with_trace}; diff --git a/native/core/src/execution/operators/shuffle_scan.rs b/native/core/src/execution/operators/shuffle_scan.rs index b638567166..3f903193bd 100644 --- a/native/core/src/execution/operators/shuffle_scan.rs +++ b/native/core/src/execution/operators/shuffle_scan.rs @@ -19,7 +19,7 @@ use crate::{ errors::CometError, execution::{ operators::ExecutionError, planner::TEST_EXEC_CONTEXT_ID, - shuffle::codec::read_ipc_compressed, + shuffle::codec::read_shuffle_block, }, jvm_bridge::{jni_call, JVMClasses}, }; @@ -48,7 +48,7 @@ use super::scan::InputBatch; /// ShuffleScanExec reads compressed shuffle blocks from JVM via JNI and decodes them natively. /// Unlike ScanExec which receives Arrow arrays via FFI, ShuffleScanExec receives raw compressed -/// bytes from CometShuffleBlockIterator and decodes them using read_ipc_compressed(). +/// bytes from CometShuffleBlockIterator and decodes them using read_shuffle_block(). #[derive(Debug, Clone)] pub struct ShuffleScanExec { /// The ID of the execution context that owns this subquery. 
@@ -124,6 +124,7 @@ impl ShuffleScanExec { self.input_source.as_ref().unwrap().as_obj(), &self.data_types, &self.decode_time, + &self.schema, )?; *current_batch = Some(next_batch); } @@ -139,6 +140,7 @@ impl ShuffleScanExec { iter: &JObject, data_types: &[DataType], decode_time: &Time, + schema: &Schema, ) -> Result { if exec_context_id == TEST_EXEC_CONTEXT_ID { return Ok(InputBatch::EOF); @@ -173,15 +175,13 @@ impl ShuffleScanExec { let length = block_length as usize; let slice: &[u8] = unsafe { std::slice::from_raw_parts(raw_pointer, length) }; - // Decode the compressed IPC data + // Decode the compressed shuffle block let mut timer = decode_time.timer(); - let batch = read_ipc_compressed(slice)?; + let batch = read_shuffle_block(slice, schema)?; timer.stop(); let num_rows = batch.num_rows(); - // The read_ipc_compressed already produces owned arrays, so we skip the - // header (field count + codec) that was already consumed by read_ipc_compressed. // Extract column arrays from the RecordBatch. 
let columns: Vec = batch.columns().to_vec(); diff --git a/native/core/src/execution/shuffle/mod.rs b/native/core/src/execution/shuffle/mod.rs index dd3c9bf849..19e23bb72c 100644 --- a/native/core/src/execution/shuffle/mod.rs +++ b/native/core/src/execution/shuffle/mod.rs @@ -23,6 +23,6 @@ mod shuffle_writer; pub mod spark_unsafe; mod writers; -pub use codec::{read_ipc_compressed, read_shuffle_block, CompressionCodec, ShuffleBlockWriter}; +pub use codec::{read_shuffle_block, CompressionCodec, ShuffleBlockWriter}; pub use comet_partitioning::CometPartitioning; pub use shuffle_writer::ShuffleWriterExec; diff --git a/native/core/src/execution/shuffle/shuffle_writer.rs b/native/core/src/execution/shuffle/shuffle_writer.rs index fe1bf0fccf..7aefb3e403 100644 --- a/native/core/src/execution/shuffle/shuffle_writer.rs +++ b/native/core/src/execution/shuffle/shuffle_writer.rs @@ -265,7 +265,7 @@ async fn external_shuffle( #[cfg(test)] mod test { use super::*; - use crate::execution::shuffle::{read_ipc_compressed, ShuffleBlockWriter}; + use crate::execution::shuffle::{read_shuffle_block, ShuffleBlockWriter}; use arrow::array::{Array, StringArray, StringBuilder}; use arrow::datatypes::{DataType, Field, Schema}; use arrow::record_batch::RecordBatch; @@ -285,8 +285,9 @@ mod test { #[test] #[cfg_attr(miri, ignore)] // miri can't call foreign function `ZSTD_createCCtx` - fn roundtrip_ipc() { + fn roundtrip_raw() { let batch = create_batch(8192); + let schema = batch.schema(); for codec in &[ CompressionCodec::None, CompressionCodec::Zstd(1), @@ -296,14 +297,14 @@ mod test { let mut output = vec![]; let mut cursor = Cursor::new(&mut output); let writer = - ShuffleBlockWriter::try_new(batch.schema().as_ref(), codec.clone()).unwrap(); + ShuffleBlockWriter::try_new(schema.as_ref(), codec.clone()).unwrap(); let length = writer .write_batch(&batch, &mut cursor, &Time::default()) .unwrap(); assert_eq!(length, output.len()); - let ipc_without_length_prefix = &output[16..]; - let batch2 = 
read_ipc_compressed(ipc_without_length_prefix).unwrap(); + let block_without_length_prefix = &output[16..]; + let batch2 = read_shuffle_block(block_without_length_prefix, &schema).unwrap(); assert_eq!(batch, batch2); } } @@ -651,7 +652,7 @@ mod test { buf_writer.flush(&encode_time, &write_time).unwrap(); } - // Coalesced output should be smaller due to fewer IPC schema blocks + // Coalesced output should be smaller due to fewer block headers assert!( coalesced_output.len() < uncoalesced_output.len(), "Coalesced output ({} bytes) should be smaller than uncoalesced ({} bytes)", @@ -659,9 +660,9 @@ mod test { uncoalesced_output.len() ); - // Verify both roundtrip correctly by reading all IPC blocks - let coalesced_rows = read_all_ipc_blocks(&coalesced_output); - let uncoalesced_rows = read_all_ipc_blocks(&uncoalesced_output); + // Verify both roundtrip correctly by reading all shuffle blocks + let coalesced_rows = read_all_shuffle_blocks(&coalesced_output, &schema); + let uncoalesced_rows = read_all_shuffle_blocks(&uncoalesced_output, &schema); assert_eq!( coalesced_rows, 5000, "Coalesced should contain all 5000 rows" @@ -672,22 +673,22 @@ mod test { ); } - /// Read all IPC blocks from a byte buffer written by BufBatchWriter/ShuffleBlockWriter, + /// Read all shuffle blocks from a byte buffer written by BufBatchWriter/ShuffleBlockWriter, /// returning the total number of rows. 
- fn read_all_ipc_blocks(data: &[u8]) -> usize { + fn read_all_shuffle_blocks(data: &[u8], schema: &Schema) -> usize { let mut offset = 0; let mut total_rows = 0; while offset < data.len() { - // First 8 bytes are the IPC length (little-endian u64) - let ipc_length = + // First 8 bytes are the block length (little-endian u64) + let block_length = u64::from_le_bytes(data[offset..offset + 8].try_into().unwrap()) as usize; // Skip the 8-byte length prefix; the next 8 bytes are field_count + codec header let block_start = offset + 8; - let block_end = block_start + ipc_length; - // read_ipc_compressed expects data starting after the 16-byte header + let block_end = block_start + block_length; + // read_shuffle_block expects data starting after the 16-byte header // (i.e., after length + field_count), at the codec tag - let ipc_data = &data[block_start + 8..block_end]; - let batch = read_ipc_compressed(ipc_data).unwrap(); + let block_data = &data[block_start + 8..block_end]; + let batch = read_shuffle_block(block_data, schema).unwrap(); total_rows += batch.num_rows(); offset = block_end; } From c12ae25aa8fa23b0de21d49397a971a9fe55c540 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Thu, 19 Mar 2026 10:08:18 -0600 Subject: [PATCH 21/33] feat: pass schema to decodeShuffleBlock for raw shuffle format --- native/core/src/execution/jni_api.rs | 24 ++++++++++-- native/core/src/execution/shuffle/codec.rs | 39 ------------------- .../main/scala/org/apache/comet/Native.scala | 1 + .../CometBlockStoreShuffleReader.scala | 14 +++++++ .../shuffle/NativeBatchDecoderIterator.scala | 2 + 5 files changed, 37 insertions(+), 43 deletions(-) diff --git a/native/core/src/execution/jni_api.rs b/native/core/src/execution/jni_api.rs index 3914901229..f50f290795 100644 --- a/native/core/src/execution/jni_api.rs +++ b/native/core/src/execution/jni_api.rs @@ -28,7 +28,7 @@ use crate::{ }; use arrow::array::{Array, RecordBatch, UInt32Array}; use arrow::compute::{take, TakeOptions}; -use 
arrow::datatypes::DataType as ArrowDataType; +use arrow::datatypes::{DataType as ArrowDataType, Field, Schema}; use datafusion::common::{DataFusionError, Result as DataFusionResult, ScalarValue}; use datafusion::execution::disk_manager::DiskManagerMode; use datafusion::execution::memory_pool::MemoryPool; @@ -39,7 +39,8 @@ use datafusion::{ physical_plan::{display::DisplayableExecutionPlan, SendableRecordBatchStream}, prelude::{SessionConfig, SessionContext}, }; -use datafusion_comet_proto::spark_operator::Operator; +use datafusion_comet_proto::spark_operator::{self, Operator}; +use prost::Message; use datafusion_spark::function::bitwise::bit_count::SparkBitCount; use datafusion_spark::function::bitwise::bit_get::SparkBitGet; use datafusion_spark::function::bitwise::bitwise_not::SparkBitwiseNot; @@ -83,7 +84,7 @@ use crate::execution::memory_pools::{ create_memory_pool, handle_task_shared_pool_release, parse_memory_pool_config, MemoryPoolConfig, }; use crate::execution::operators::{ScanExec, ShuffleScanExec}; -use crate::execution::shuffle::codec::read_ipc_compressed; +use crate::execution::shuffle::read_shuffle_block; use crate::execution::shuffle::CompressionCodec; use crate::execution::spark_plan::SparkPlan; @@ -887,14 +888,29 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_decodeShuffleBlock( length: jint, array_addrs: JLongArray, schema_addrs: JLongArray, + schema_bytes: JByteArray, tracing_enabled: jboolean, ) -> jlong { try_unwrap_or_throw(&e, |mut env| { + // Parse the schema from protobuf bytes + let schema_vec = env.convert_byte_array(&schema_bytes)?; + let shuffle_scan = + spark_operator::ShuffleScan::decode(schema_vec.as_slice()).map_err(|e| { + CometError::Internal(format!("Failed to parse shuffle schema: {e}")) + })?; + let fields: Vec = shuffle_scan + .fields + .iter() + .enumerate() + .map(|(i, dt)| Field::new(format!("c{i}"), to_arrow_datatype(dt), true)) + .collect(); + let schema = Schema::new(fields); + 
with_trace("decodeShuffleBlock", tracing_enabled != JNI_FALSE, || { let raw_pointer = env.get_direct_buffer_address(&byte_buffer)?; let length = length as usize; let slice: &[u8] = unsafe { std::slice::from_raw_parts(raw_pointer, length) }; - let batch = read_ipc_compressed(slice)?; + let batch = read_shuffle_block(slice, &schema)?; prepare_output(&mut env, array_addrs, schema_addrs, batch, false) }) }) diff --git a/native/core/src/execution/shuffle/codec.rs b/native/core/src/execution/shuffle/codec.rs index fcfdf3ad51..61b65a12cd 100644 --- a/native/core/src/execution/shuffle/codec.rs +++ b/native/core/src/execution/shuffle/codec.rs @@ -328,45 +328,6 @@ fn read_raw_batch(bytes: &[u8], schema: &Schema) -> Result { Ok(batch) } -/// Reads and decompresses a shuffle block. The `bytes` slice starts at the codec tag -/// (after the 8-byte length and 8-byte field_count header that the JVM reads). -/// -/// This is kept temporarily for backward compatibility while callers are migrated. -pub fn read_ipc_compressed(bytes: &[u8]) -> Result { - // NOTE: This function cannot work with the new raw format because it needs a schema. - // It is kept as a stub that delegates to the old IPC path for backward compatibility. - // Callers should migrate to read_shuffle_block(). 
- use arrow::ipc::reader::StreamReader; - match &bytes[0..4] { - b"SNAP" => { - let decoder = snap::read::FrameDecoder::new(&bytes[4..]); - let mut reader = - unsafe { StreamReader::try_new(decoder, None)?.with_skip_validation(true) }; - reader.next().unwrap().map_err(|e| e.into()) - } - b"LZ4_" => { - let decoder = lz4_flex::frame::FrameDecoder::new(&bytes[4..]); - let mut reader = - unsafe { StreamReader::try_new(decoder, None)?.with_skip_validation(true) }; - reader.next().unwrap().map_err(|e| e.into()) - } - b"ZSTD" => { - let decoder = zstd::Decoder::new(&bytes[4..])?; - let mut reader = - unsafe { StreamReader::try_new(decoder, None)?.with_skip_validation(true) }; - reader.next().unwrap().map_err(|e| e.into()) - } - b"NONE" => { - let mut reader = - unsafe { StreamReader::try_new(&bytes[4..], None)?.with_skip_validation(true) }; - reader.next().unwrap().map_err(|e| e.into()) - } - other => Err(DataFusionError::Execution(format!( - "Failed to decode batch: invalid compression codec: {other:?}" - ))), - } -} - /// Reads and decompresses a shuffle block written in raw buffer format. /// The `bytes` slice starts at the codec tag (after the 8-byte length and /// 8-byte field_count header that the JVM reads). 
diff --git a/spark/src/main/scala/org/apache/comet/Native.scala b/spark/src/main/scala/org/apache/comet/Native.scala index f6800626d6..3571d9f6cd 100644 --- a/spark/src/main/scala/org/apache/comet/Native.scala +++ b/spark/src/main/scala/org/apache/comet/Native.scala @@ -177,6 +177,7 @@ class Native extends NativeBase { length: Int, arrayAddrs: Array[Long], schemaAddrs: Array[Long], + schemaBytes: Array[Byte], tracingEnabled: Boolean): Long /** diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometBlockStoreShuffleReader.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometBlockStoreShuffleReader.scala index 14e656f038..dfec6e4474 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometBlockStoreShuffleReader.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometBlockStoreShuffleReader.scala @@ -36,6 +36,8 @@ import org.apache.spark.storage.ShuffleBlockFetcherIterator import org.apache.spark.util.CompletionIterator import org.apache.comet.{CometConf, Native} +import org.apache.comet.serde.OperatorOuterClass +import org.apache.comet.serde.QueryPlanSerde.serializeDataType import org.apache.comet.vector.NativeUtil /** @@ -86,6 +88,17 @@ class CometBlockStoreShuffleReader[K, C]( fetchContinuousBlocksInBatch).toCompletionIterator } + /** Serialize the output schema as a ShuffleScan protobuf message. 
*/ + private lazy val schemaBytes: Array[Byte] = { + import scala.jdk.CollectionConverters._ + val scanBuilder = OperatorOuterClass.ShuffleScan.newBuilder() + val scanTypes = dep.outputAttributes.flatMap { attr => + serializeDataType(attr.dataType) + } + scanBuilder.addAllFields(scanTypes.asJava) + scanBuilder.build().toByteArray + } + /** Read the combined key-values for this reduce task */ override def read(): Iterator[Product2[K, C]] = { var currentReadIterator: NativeBatchDecoderIterator = null @@ -114,6 +127,7 @@ class CometBlockStoreShuffleReader[K, C]( dep.decodeTime, nativeLib, nativeUtil, + schemaBytes, tracingEnabled) currentReadIterator }) diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/NativeBatchDecoderIterator.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/NativeBatchDecoderIterator.scala index f96c8f16dd..ed2287ffc7 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/NativeBatchDecoderIterator.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/NativeBatchDecoderIterator.scala @@ -39,6 +39,7 @@ case class NativeBatchDecoderIterator( decodeTime: SQLMetric, nativeLib: Native, nativeUtil: NativeUtil, + schemaBytes: Array[Byte], tracingEnabled: Boolean) extends Iterator[ColumnarBatch] { @@ -160,6 +161,7 @@ case class NativeBatchDecoderIterator( bytesToRead.toInt, arrayAddrs, schemaAddrs, + schemaBytes, tracingEnabled) }) decodeTime.add(System.nanoTime() - startTime) From a7e9659a978f13841b12ddf50ec8144aaeb255df Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Thu, 19 Mar 2026 10:28:34 -0600 Subject: [PATCH 22/33] docs: update comments to reflect raw buffer shuffle format --- native/core/src/execution/shuffle/codec.rs | 2 +- native/core/src/execution/shuffle/metrics.rs | 2 +- .../core/src/execution/shuffle/partitioners/multi_partition.rs | 2 +- native/core/src/execution/shuffle/shuffle_writer.rs | 2 +- 
native/core/src/execution/shuffle/writers/buf_batch_writer.rs | 2 +- .../main/java/org/apache/comet/CometShuffleBlockIterator.java | 2 +- 6 files changed, 6 insertions(+), 6 deletions(-) diff --git a/native/core/src/execution/shuffle/codec.rs b/native/core/src/execution/shuffle/codec.rs index 61b65a12cd..f01494d6f3 100644 --- a/native/core/src/execution/shuffle/codec.rs +++ b/native/core/src/execution/shuffle/codec.rs @@ -362,7 +362,7 @@ fn read_all(mut reader: R) -> Result> { Ok(buf) } -/// Checksum algorithms for writing IPC bytes. +/// Checksum algorithms for writing shuffle bytes. #[derive(Clone)] pub(crate) enum Checksum { /// CRC32 checksum algorithm. diff --git a/native/core/src/execution/shuffle/metrics.rs b/native/core/src/execution/shuffle/metrics.rs index 33b51c3cd8..6c768bf92f 100644 --- a/native/core/src/execution/shuffle/metrics.rs +++ b/native/core/src/execution/shuffle/metrics.rs @@ -26,7 +26,7 @@ pub(super) struct ShufflePartitionerMetrics { /// Time to perform repartitioning pub(super) repart_time: Time, - /// Time encoding batches to IPC format + /// Time encoding batches to shuffle format pub(super) encode_time: Time, /// Time spent writing to disk. Maps to "shuffleWriteTime" in Spark SQL Metrics. diff --git a/native/core/src/execution/shuffle/partitioners/multi_partition.rs b/native/core/src/execution/shuffle/partitioners/multi_partition.rs index 9c366ad462..4da855b605 100644 --- a/native/core/src/execution/shuffle/partitioners/multi_partition.rs +++ b/native/core/src/execution/shuffle/partitioners/multi_partition.rs @@ -555,7 +555,7 @@ impl ShufflePartitioner for MultiPartitionShuffleRepartitioner { .await } - /// Writes buffered shuffled record batches into Arrow IPC bytes. + /// Writes buffered shuffled record batches to the output shuffle file. 
fn shuffle_write(&mut self) -> datafusion::common::Result<()> { with_trace("shuffle_write", self.tracing_enabled, || { let start_time = Instant::now(); diff --git a/native/core/src/execution/shuffle/shuffle_writer.rs b/native/core/src/execution/shuffle/shuffle_writer.rs index 7aefb3e403..0ae6a996ed 100644 --- a/native/core/src/execution/shuffle/shuffle_writer.rs +++ b/native/core/src/execution/shuffle/shuffle_writer.rs @@ -588,7 +588,7 @@ mod test { } /// Test that batch coalescing in BufBatchWriter reduces output size by - /// writing fewer, larger IPC blocks instead of many small ones. + /// writing fewer, larger blocks instead of many small ones. #[test] #[cfg_attr(miri, ignore)] fn test_batch_coalescing_reduces_size() { diff --git a/native/core/src/execution/shuffle/writers/buf_batch_writer.rs b/native/core/src/execution/shuffle/writers/buf_batch_writer.rs index 8d056d7bb0..afbb3f09ed 100644 --- a/native/core/src/execution/shuffle/writers/buf_batch_writer.rs +++ b/native/core/src/execution/shuffle/writers/buf_batch_writer.rs @@ -27,7 +27,7 @@ use std::io::{Cursor, Seek, SeekFrom, Write}; /// Once the buffer exceeds the max size, the buffer will be flushed to the writer. /// /// Small batches are coalesced using Arrow's [`BatchCoalescer`] before serialization, -/// producing exactly `batch_size`-row output batches to reduce per-block IPC schema overhead. +/// producing exactly `batch_size`-row output batches to reduce per-block serialization overhead. /// The coalescer is lazily initialized on the first write. pub(crate) struct BufBatchWriter, W: Write> { shuffle_block_writer: S, diff --git a/spark/src/main/java/org/apache/comet/CometShuffleBlockIterator.java b/spark/src/main/java/org/apache/comet/CometShuffleBlockIterator.java index f9abef1c36..02526a6e63 100644 --- a/spark/src/main/java/org/apache/comet/CometShuffleBlockIterator.java +++ b/spark/src/main/java/org/apache/comet/CometShuffleBlockIterator.java @@ -36,7 +36,7 @@ * getBuffer(). * *

The DirectByteBuffer returned by getBuffer() is only valid until the next hasNext() call. - * Native code must fully consume it (via read_ipc_compressed which allocates new memory for the + * Native code must fully consume it (via read_shuffle_block which allocates new memory for the * decompressed data) before pulling the next block. */ public class CometShuffleBlockIterator implements Closeable { From b1ccfd6332642efb67c493adb8a38c307a3816ce Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Thu, 19 Mar 2026 10:42:13 -0600 Subject: [PATCH 23/33] format --- native/core/src/execution/jni_api.rs | 8 +++---- native/core/src/execution/shuffle/codec.rs | 22 +++++-------------- .../src/execution/shuffle/shuffle_writer.rs | 3 +-- 3 files changed, 9 insertions(+), 24 deletions(-) diff --git a/native/core/src/execution/jni_api.rs b/native/core/src/execution/jni_api.rs index f50f290795..1f35a3bca4 100644 --- a/native/core/src/execution/jni_api.rs +++ b/native/core/src/execution/jni_api.rs @@ -40,7 +40,6 @@ use datafusion::{ prelude::{SessionConfig, SessionContext}, }; use datafusion_comet_proto::spark_operator::{self, Operator}; -use prost::Message; use datafusion_spark::function::bitwise::bit_count::SparkBitCount; use datafusion_spark::function::bitwise::bit_get::SparkBitGet; use datafusion_spark::function::bitwise::bitwise_not::SparkBitwiseNot; @@ -73,6 +72,7 @@ use jni::{ sys::{jboolean, jdouble, jint, jlong}, JNIEnv, }; +use prost::Message; use std::collections::HashMap; use std::path::PathBuf; use std::time::{Duration, Instant}; @@ -894,10 +894,8 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_decodeShuffleBlock( try_unwrap_or_throw(&e, |mut env| { // Parse the schema from protobuf bytes let schema_vec = env.convert_byte_array(&schema_bytes)?; - let shuffle_scan = - spark_operator::ShuffleScan::decode(schema_vec.as_slice()).map_err(|e| { - CometError::Internal(format!("Failed to parse shuffle schema: {e}")) - })?; + let shuffle_scan = 
spark_operator::ShuffleScan::decode(schema_vec.as_slice()) + .map_err(|e| CometError::Internal(format!("Failed to parse shuffle schema: {e}")))?; let fields: Vec = shuffle_scan .fields .iter() diff --git a/native/core/src/execution/shuffle/codec.rs b/native/core/src/execution/shuffle/codec.rs index f01494d6f3..8995168b0d 100644 --- a/native/core/src/execution/shuffle/codec.rs +++ b/native/core/src/execution/shuffle/codec.rs @@ -71,10 +71,7 @@ fn write_array_data(data: &arrow::array::ArrayData, writer: &mut W) -> writer.write_all(&num_buffers.to_le_bytes())?; for buffer in data.buffers() { let len: u32 = buffer.len().try_into().map_err(|_| { - DataFusionError::Execution(format!( - "Buffer length {} exceeds u32::MAX", - buffer.len() - )) + DataFusionError::Execution(format!("Buffer length {} exceeds u32::MAX", buffer.len())) })?; writer.write_all(&len.to_le_bytes())?; writer.write_all(buffer.as_slice())?; @@ -100,9 +97,7 @@ fn write_raw_batch(batch: &RecordBatch, writer: &mut W) -> Result<()> for col in batch.columns() { // Cast dictionary arrays to their value type let col = match col.data_type() { - DataType::Dictionary(_, value_type) => { - cast(col.as_ref(), value_type.as_ref())? - } + DataType::Dictionary(_, value_type) => cast(col.as_ref(), value_type.as_ref())?, _ => Arc::clone(col), }; write_array_data(&col.to_data(), writer)?; @@ -237,9 +232,7 @@ fn read_bytes<'a>(cursor: &mut &'a [u8], len: usize) -> Result<&'a [u8]> { /// Returns child data types for nested Arrow types. 
fn get_child_types(data_type: &DataType) -> Vec { match data_type { - DataType::List(field) - | DataType::LargeList(field) - | DataType::FixedSizeList(field, _) => { + DataType::List(field) | DataType::LargeList(field) | DataType::FixedSizeList(field, _) => { vec![field.data_type().clone()] } DataType::Map(field, _) => { @@ -709,11 +702,7 @@ mod tests { let batch = RecordBatch::try_new( Arc::clone(&schema), - vec![ - Arc::new(list_arr), - Arc::new(struct_arr), - Arc::new(map_arr), - ], + vec![Arc::new(list_arr), Arc::new(struct_arr), Arc::new(map_arr)], ) .unwrap(); @@ -732,8 +721,7 @@ mod tests { let keys = Int32Array::from(vec![Some(0), Some(1), None, Some(0)]); let values = StringArray::from(vec!["foo", "bar"]); - let dict_arr = - DictionaryArray::::try_new(keys, Arc::new(values)).unwrap(); + let dict_arr = DictionaryArray::::try_new(keys, Arc::new(values)).unwrap(); let batch = RecordBatch::try_new(Arc::clone(&dict_schema), vec![Arc::new(dict_arr)]).unwrap(); diff --git a/native/core/src/execution/shuffle/shuffle_writer.rs b/native/core/src/execution/shuffle/shuffle_writer.rs index 0ae6a996ed..97b497b181 100644 --- a/native/core/src/execution/shuffle/shuffle_writer.rs +++ b/native/core/src/execution/shuffle/shuffle_writer.rs @@ -296,8 +296,7 @@ mod test { ] { let mut output = vec![]; let mut cursor = Cursor::new(&mut output); - let writer = - ShuffleBlockWriter::try_new(schema.as_ref(), codec.clone()).unwrap(); + let writer = ShuffleBlockWriter::try_new(schema.as_ref(), codec.clone()).unwrap(); let length = writer .write_batch(&batch, &mut cursor, &Time::default()) .unwrap(); From a80220971d70d8b31fada0563f5b6e661e9495d5 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Thu, 19 Mar 2026 10:46:41 -0600 Subject: [PATCH 24/33] fix: handle zero-column batches in raw shuffle format Use RecordBatch::try_new_with_options with explicit row_count instead of try_new so that zero-column batches (produced by Spark when query results are unused) do not fail with "must either 
specify a row count or at least one column". --- native/core/src/execution/shuffle/codec.rs | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/native/core/src/execution/shuffle/codec.rs b/native/core/src/execution/shuffle/codec.rs index 8995168b0d..56c7354ada 100644 --- a/native/core/src/execution/shuffle/codec.rs +++ b/native/core/src/execution/shuffle/codec.rs @@ -17,6 +17,7 @@ use crate::errors::{CometError, CometResult}; use arrow::array::{make_array, Array, ArrayRef, RecordBatch}; +use arrow::record_batch::RecordBatchOptions; use arrow::buffer::Buffer; use arrow::compute::cast; use arrow::datatypes::DataType; @@ -317,7 +318,10 @@ fn read_raw_batch(bytes: &[u8], schema: &Schema) -> Result { columns.push(make_array(array_data)); } - let batch = RecordBatch::try_new(Arc::new(schema.clone()), columns)?; + let options = + RecordBatchOptions::new().with_row_count(Some(num_rows)); + let batch = + RecordBatch::try_new_with_options(Arc::new(schema.clone()), columns, &options)?; Ok(batch) } From ebd3c615a75bad6e515d32d0f1813004e13e9dec Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Thu, 19 Mar 2026 10:47:08 -0600 Subject: [PATCH 25/33] format --- native/core/src/execution/shuffle/codec.rs | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/native/core/src/execution/shuffle/codec.rs b/native/core/src/execution/shuffle/codec.rs index 56c7354ada..04f30ff7fa 100644 --- a/native/core/src/execution/shuffle/codec.rs +++ b/native/core/src/execution/shuffle/codec.rs @@ -17,11 +17,11 @@ use crate::errors::{CometError, CometResult}; use arrow::array::{make_array, Array, ArrayRef, RecordBatch}; -use arrow::record_batch::RecordBatchOptions; use arrow::buffer::Buffer; use arrow::compute::cast; use arrow::datatypes::DataType; use arrow::datatypes::Schema; +use arrow::record_batch::RecordBatchOptions; use bytes::Buf; use crc32fast::Hasher; use datafusion::common::DataFusionError; @@ -318,10 +318,8 @@ fn read_raw_batch(bytes: &[u8], schema: 
&Schema) -> Result { columns.push(make_array(array_data)); } - let options = - RecordBatchOptions::new().with_row_count(Some(num_rows)); - let batch = - RecordBatch::try_new_with_options(Arc::new(schema.clone()), columns, &options)?; + let options = RecordBatchOptions::new().with_row_count(Some(num_rows)); + let batch = RecordBatch::try_new_with_options(Arc::new(schema.clone()), columns, &options)?; Ok(batch) } From 38fe95c462707fb37647fda70a6e901beae62ce1 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Thu, 19 Mar 2026 11:30:37 -0600 Subject: [PATCH 26/33] fix: pass outputAttributes in CometColumnarShuffle dependency CometColumnarShuffle was not setting outputAttributes on the CometShuffleDependency, leaving it as Seq.empty. This caused the shuffle reader to pass an empty schema to the native decodeShuffleBlock, resulting in "Output column count mismatch: expected N, got 0" errors. --- .../sql/comet/execution/shuffle/CometShuffleExchangeExec.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala index d65a6b21f4..02d707b5a4 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala @@ -832,7 +832,8 @@ object CometShuffleExchangeExec shuffleWriterProcessor = ShuffleExchangeExec.createShuffleWriteProcessor(writeMetrics), shuffleType = CometColumnarShuffle, schema = Some(fromAttributes(outputAttributes)), - decodeTime = writeMetrics("decode_time")) + decodeTime = writeMetrics("decode_time"), + outputAttributes = outputAttributes) dependency } From 517620b143058d85816f366e70e838030a3a876c Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Thu, 19 Mar 2026 11:57:00 -0600 Subject: [PATCH 27/33] fix: handle 
non-zero null buffer offset in raw shuffle format The null bitmap in Arrow arrays can have a non-zero bit offset even when ArrayData.offset() is 0 (e.g. after RecordBatch::slice). The raw shuffle writer was copying the bitmap bytes verbatim, but the reader assumes bits start at offset 0. This caused shifted null bitmaps, corrupting data during shuffle and producing wrong query results (e.g. TPC-DS q6 counts off by 1). Fix by detecting non-zero bitmap offsets and emitting a re-aligned copy. Add a roundtrip test with sliced batches to cover this case. --- native/core/src/execution/shuffle/codec.rs | 69 ++++++++++++++++++++-- 1 file changed, 64 insertions(+), 5 deletions(-) diff --git a/native/core/src/execution/shuffle/codec.rs b/native/core/src/execution/shuffle/codec.rs index 04f30ff7fa..c0e6ff4ad3 100644 --- a/native/core/src/execution/shuffle/codec.rs +++ b/native/core/src/execution/shuffle/codec.rs @@ -31,6 +31,7 @@ use simd_adler32::Adler32; use std::io::{Cursor, Read, Seek, SeekFrom, Write}; use std::sync::Arc; + #[derive(Debug, Clone)] pub enum CompressionCodec { None, @@ -46,6 +47,10 @@ pub struct ShuffleBlockWriter { } /// Recursively writes raw Arrow ArrayData buffers to the given writer. +/// +/// The null buffer may have a non-zero bit offset (e.g. from `RecordBatch::slice`), +/// even when `data.offset() == 0`. We emit a zero-offset copy of the bitmap so +/// the reader can consume it without tracking offsets. 
fn write_array_data(data: &arrow::array::ArrayData, writer: &mut W) -> Result<()> { debug_assert_eq!(data.offset(), 0, "shuffle arrays must have offset 0"); @@ -53,13 +58,28 @@ fn write_array_data(data: &arrow::array::ArrayData, writer: &mut W) -> let null_count = data.null_count() as u32; writer.write_all(&null_count.to_le_bytes())?; - // Write validity bitmap + // Write validity bitmap (always emitted at bit-offset 0) if null_count > 0 { if let Some(bitmap) = data.nulls() { - let bitmap_bytes = bitmap.buffer().as_slice(); - let len = bitmap_bytes.len() as u32; - writer.write_all(&len.to_le_bytes())?; - writer.write_all(bitmap_bytes)?; + if bitmap.offset() == 0 { + // Fast path: bitmap is already aligned, write raw bytes + let bitmap_bytes = bitmap.buffer().as_slice(); + let len = bitmap_bytes.len() as u32; + writer.write_all(&len.to_le_bytes())?; + writer.write_all(bitmap_bytes)?; + } else { + // Bitmap has a non-zero bit offset — produce a zero-offset copy + let num_bits = bitmap.len(); + let num_bytes = num_bits.div_ceil(8); + writer.write_all(&(num_bytes as u32).to_le_bytes())?; + let mut buf = vec![0u8; num_bytes]; + for i in 0..num_bits { + if bitmap.is_valid(i) { + buf[i / 8] |= 1 << (i % 8); + } + } + writer.write_all(&buf)?; + } } else { writer.write_all(&0u32.to_le_bytes())?; } @@ -762,6 +782,45 @@ mod tests { } } + #[test] + #[cfg_attr(miri, ignore)] + fn test_roundtrip_sliced_batch() { + // Test that arrays with non-zero offsets (from slicing) roundtrip correctly. + // This is important because the shuffle writer uses debug_assert for offset==0, + // but in release builds sliced arrays could silently produce wrong results. 
+ let schema = Arc::new(Schema::new(vec![ + Field::new("i", DataType::Int32, true), + Field::new("s", DataType::Utf8, true), + ])); + let full_batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(Int32Array::from(vec![ + Some(1), + None, + Some(3), + Some(4), + None, + Some(6), + ])), + Arc::new(StringArray::from(vec![ + Some("a"), + Some("bb"), + None, + Some("dddd"), + Some("eeeee"), + None, + ])), + ], + ) + .unwrap(); + + // Slice the batch to get arrays with non-zero offset + let sliced = full_batch.slice(2, 3); // rows: [Some(3), Some(4), None] and [None, Some("dddd"), Some("eeeee")] + assert_eq!(sliced.num_rows(), 3); + roundtrip_test(schema, &sliced); + } + #[test] fn test_empty_batch_returns_zero() { let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); From 61c4bdf4b98b56c8366b928870f8121826b37308 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Thu, 19 Mar 2026 12:17:52 -0600 Subject: [PATCH 28/33] format --- native/core/src/execution/shuffle/codec.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/native/core/src/execution/shuffle/codec.rs b/native/core/src/execution/shuffle/codec.rs index c0e6ff4ad3..35a16483eb 100644 --- a/native/core/src/execution/shuffle/codec.rs +++ b/native/core/src/execution/shuffle/codec.rs @@ -31,7 +31,6 @@ use simd_adler32::Adler32; use std::io::{Cursor, Read, Seek, SeekFrom, Write}; use std::sync::Arc; - #[derive(Debug, Clone)] pub enum CompressionCodec { None, From 3744fbee1b1346e7180f4b4eb073139defdba8f6 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Thu, 19 Mar 2026 12:49:56 -0600 Subject: [PATCH 29/33] fix: normalize arrays to zero offset before writing raw shuffle format Arrays from RecordBatch::slice can have non-zero offsets in both the ArrayData and the null bitmap. The raw shuffle format writes buffers verbatim assuming offset 0, causing data corruption when offsets are present. 
Use take() to produce zero-offset copies when needed, similar to prepare_output in jni_api.rs. This fixes TPC-DS q64 failures where the debug_assert fired and data mismatch errors from shifted null bitmaps. --- native/core/src/execution/shuffle/codec.rs | 61 +++++++++++----------- 1 file changed, 30 insertions(+), 31 deletions(-) diff --git a/native/core/src/execution/shuffle/codec.rs b/native/core/src/execution/shuffle/codec.rs index 35a16483eb..4f56580fe0 100644 --- a/native/core/src/execution/shuffle/codec.rs +++ b/native/core/src/execution/shuffle/codec.rs @@ -16,9 +16,9 @@ // under the License. use crate::errors::{CometError, CometResult}; -use arrow::array::{make_array, Array, ArrayRef, RecordBatch}; +use arrow::array::{make_array, Array, ArrayRef, RecordBatch, UInt32Array}; use arrow::buffer::Buffer; -use arrow::compute::cast; +use arrow::compute::{cast, take}; use arrow::datatypes::DataType; use arrow::datatypes::Schema; use arrow::record_batch::RecordBatchOptions; @@ -46,10 +46,7 @@ pub struct ShuffleBlockWriter { } /// Recursively writes raw Arrow ArrayData buffers to the given writer. -/// -/// The null buffer may have a non-zero bit offset (e.g. from `RecordBatch::slice`), -/// even when `data.offset() == 0`. We emit a zero-offset copy of the bitmap so -/// the reader can consume it without tracking offsets. +/// Arrays must be normalized to zero offset before calling this function. 
fn write_array_data(data: &arrow::array::ArrayData, writer: &mut W) -> Result<()> { debug_assert_eq!(data.offset(), 0, "shuffle arrays must have offset 0"); @@ -57,28 +54,14 @@ fn write_array_data(data: &arrow::array::ArrayData, writer: &mut W) -> let null_count = data.null_count() as u32; writer.write_all(&null_count.to_le_bytes())?; - // Write validity bitmap (always emitted at bit-offset 0) + // Write validity bitmap if null_count > 0 { if let Some(bitmap) = data.nulls() { - if bitmap.offset() == 0 { - // Fast path: bitmap is already aligned, write raw bytes - let bitmap_bytes = bitmap.buffer().as_slice(); - let len = bitmap_bytes.len() as u32; - writer.write_all(&len.to_le_bytes())?; - writer.write_all(bitmap_bytes)?; - } else { - // Bitmap has a non-zero bit offset — produce a zero-offset copy - let num_bits = bitmap.len(); - let num_bytes = num_bits.div_ceil(8); - writer.write_all(&(num_bytes as u32).to_le_bytes())?; - let mut buf = vec![0u8; num_bytes]; - for i in 0..num_bits { - if bitmap.is_valid(i) { - buf[i / 8] |= 1 << (i % 8); - } - } - writer.write_all(&buf)?; - } + debug_assert_eq!(bitmap.offset(), 0, "null bitmap must have offset 0"); + let bitmap_bytes = bitmap.buffer().as_slice(); + let len = bitmap_bytes.len() as u32; + writer.write_all(&len.to_le_bytes())?; + writer.write_all(bitmap_bytes)?; } else { writer.write_all(&0u32.to_le_bytes())?; } @@ -109,17 +92,33 @@ fn write_array_data(data: &arrow::array::ArrayData, writer: &mut W) -> Ok(()) } +/// Ensures an array has zero offset in both its data and null buffer by +/// producing a physical copy when necessary. This is required because our +/// raw buffer format writes buffers verbatim and assumes offset 0. 
+fn normalize_array(col: &ArrayRef) -> Result { + // Cast dictionary arrays to their value type + let col = match col.data_type() { + DataType::Dictionary(_, value_type) => cast(col.as_ref(), value_type.as_ref())?, + _ => Arc::clone(col), + }; + + let needs_copy = col.offset() != 0 || col.nulls().is_some_and(|nulls| nulls.offset() != 0); + if needs_copy { + let indices = UInt32Array::from_iter_values(0..col.len() as u32); + Ok(take(col.as_ref(), &indices, None)?) + } else { + Ok(col) + } +} + /// Writes a RecordBatch in raw buffer format. Dictionary arrays are cast to their value type. +/// Arrays with non-zero offsets (e.g. from slicing) are copied to ensure offset 0. fn write_raw_batch(batch: &RecordBatch, writer: &mut W) -> Result<()> { let num_rows = batch.num_rows() as u32; writer.write_all(&num_rows.to_le_bytes())?; for col in batch.columns() { - // Cast dictionary arrays to their value type - let col = match col.data_type() { - DataType::Dictionary(_, value_type) => cast(col.as_ref(), value_type.as_ref())?, - _ => Arc::clone(col), - }; + let col = normalize_array(col)?; write_array_data(&col.to_data(), writer)?; } From 24528a8f1203a6fed291ee66b026b6dbedda0eda Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Thu, 19 Mar 2026 13:30:21 -0600 Subject: [PATCH 30/33] perf: optimize shuffle codec hot paths - Replace take() with MutableArrayData::extend in normalize_array for zero-offset copies. MutableArrayData does a direct memcpy instead of building an index array and doing per-element lookups. - Cache parsed shuffle schema in thread-local storage. The schema bytes are identical for every decodeShuffleBlock call within a shuffle reader, avoiding repeated protobuf decode and Field/Schema allocation on every batch. - Change read_shuffle_block to accept Arc instead of &Schema, eliminating a full Schema clone (all field names and types) on every batch decode. 
--- native/core/src/execution/jni_api.rs | 49 ++++++++++++++----- .../src/execution/operators/shuffle_scan.rs | 2 +- native/core/src/execution/shuffle/codec.rs | 18 ++++--- .../src/execution/shuffle/shuffle_writer.rs | 2 +- 4 files changed, 51 insertions(+), 20 deletions(-) diff --git a/native/core/src/execution/jni_api.rs b/native/core/src/execution/jni_api.rs index 1f35a3bca4..b866d98a9a 100644 --- a/native/core/src/execution/jni_api.rs +++ b/native/core/src/execution/jni_api.rs @@ -877,6 +877,43 @@ pub extern "system" fn Java_org_apache_comet_Native_sortRowPartitionsNative( }) } +// Thread-local cache for the parsed shuffle schema. The schema bytes are +// identical for every call within a given shuffle reader, so we avoid +// re-parsing protobuf and re-allocating Field/Schema objects on each batch. +thread_local! { + static CACHED_SHUFFLE_SCHEMA: std::cell::RefCell, Arc)>> = + const { std::cell::RefCell::new(None) }; +} + +/// Parse shuffle schema from protobuf bytes, using a thread-local cache. 
+fn get_or_parse_shuffle_schema( + env: &mut JNIEnv, + schema_bytes: &JByteArray, +) -> CometResult> { + let schema_vec = env.convert_byte_array(schema_bytes)?; + + CACHED_SHUFFLE_SCHEMA.with(|cache| { + let mut cache = cache.borrow_mut(); + if let Some((ref cached_bytes, ref cached_schema)) = *cache { + if cached_bytes == &schema_vec { + return Ok(Arc::clone(cached_schema)); + } + } + + let shuffle_scan = spark_operator::ShuffleScan::decode(schema_vec.as_slice()) + .map_err(|e| CometError::Internal(format!("Failed to parse shuffle schema: {e}")))?; + let fields: Vec = shuffle_scan + .fields + .iter() + .enumerate() + .map(|(i, dt)| Field::new(format!("c{i}"), to_arrow_datatype(dt), true)) + .collect(); + let schema = Arc::new(Schema::new(fields)); + *cache = Some((schema_vec, Arc::clone(&schema))); + Ok(schema) + }) +} + #[no_mangle] /// Used by Comet native shuffle reader /// # Safety @@ -892,17 +929,7 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_decodeShuffleBlock( tracing_enabled: jboolean, ) -> jlong { try_unwrap_or_throw(&e, |mut env| { - // Parse the schema from protobuf bytes - let schema_vec = env.convert_byte_array(&schema_bytes)?; - let shuffle_scan = spark_operator::ShuffleScan::decode(schema_vec.as_slice()) - .map_err(|e| CometError::Internal(format!("Failed to parse shuffle schema: {e}")))?; - let fields: Vec = shuffle_scan - .fields - .iter() - .enumerate() - .map(|(i, dt)| Field::new(format!("c{i}"), to_arrow_datatype(dt), true)) - .collect(); - let schema = Schema::new(fields); + let schema = get_or_parse_shuffle_schema(&mut env, &schema_bytes)?; with_trace("decodeShuffleBlock", tracing_enabled != JNI_FALSE, || { let raw_pointer = env.get_direct_buffer_address(&byte_buffer)?; diff --git a/native/core/src/execution/operators/shuffle_scan.rs b/native/core/src/execution/operators/shuffle_scan.rs index 3f903193bd..9f4102e601 100644 --- a/native/core/src/execution/operators/shuffle_scan.rs +++ 
b/native/core/src/execution/operators/shuffle_scan.rs @@ -140,7 +140,7 @@ impl ShuffleScanExec { iter: &JObject, data_types: &[DataType], decode_time: &Time, - schema: &Schema, + schema: &Arc, ) -> Result { if exec_context_id == TEST_EXEC_CONTEXT_ID { return Ok(InputBatch::EOF); diff --git a/native/core/src/execution/shuffle/codec.rs b/native/core/src/execution/shuffle/codec.rs index 4f56580fe0..59ecf65ada 100644 --- a/native/core/src/execution/shuffle/codec.rs +++ b/native/core/src/execution/shuffle/codec.rs @@ -16,9 +16,9 @@ // under the License. use crate::errors::{CometError, CometResult}; -use arrow::array::{make_array, Array, ArrayRef, RecordBatch, UInt32Array}; +use arrow::array::{make_array, Array, ArrayRef, MutableArrayData, RecordBatch}; use arrow::buffer::Buffer; -use arrow::compute::{cast, take}; +use arrow::compute::cast; use arrow::datatypes::DataType; use arrow::datatypes::Schema; use arrow::record_batch::RecordBatchOptions; @@ -104,8 +104,12 @@ fn normalize_array(col: &ArrayRef) -> Result { let needs_copy = col.offset() != 0 || col.nulls().is_some_and(|nulls| nulls.offset() != 0); if needs_copy { - let indices = UInt32Array::from_iter_values(0..col.len() as u32); - Ok(take(col.as_ref(), &indices, None)?) + // Use MutableArrayData::extend for a direct memcpy rather than + // take() which builds an index array and does per-element lookups. + let data = col.to_data(); + let mut mutable = MutableArrayData::new(vec![&data], false, col.len()); + mutable.extend(0, 0, col.len()); + Ok(make_array(mutable.freeze())) } else { Ok(col) } @@ -325,7 +329,7 @@ fn read_array_data( } /// Read a raw batch from decompressed bytes, given the expected schema. -fn read_raw_batch(bytes: &[u8], schema: &Schema) -> Result { +fn read_raw_batch(bytes: &[u8], schema: &Arc) -> Result { let mut cursor = bytes; let num_rows = read_u32(&mut cursor)? 
as usize; @@ -337,14 +341,14 @@ fn read_raw_batch(bytes: &[u8], schema: &Schema) -> Result { } let options = RecordBatchOptions::new().with_row_count(Some(num_rows)); - let batch = RecordBatch::try_new_with_options(Arc::new(schema.clone()), columns, &options)?; + let batch = RecordBatch::try_new_with_options(Arc::clone(schema), columns, &options)?; Ok(batch) } /// Reads and decompresses a shuffle block written in raw buffer format. /// The `bytes` slice starts at the codec tag (after the 8-byte length and /// 8-byte field_count header that the JVM reads). -pub fn read_shuffle_block(bytes: &[u8], schema: &Schema) -> Result { +pub fn read_shuffle_block(bytes: &[u8], schema: &Arc) -> Result { match &bytes[0..4] { b"SNAP" => { let decoder = snap::read::FrameDecoder::new(&bytes[4..]); diff --git a/native/core/src/execution/shuffle/shuffle_writer.rs b/native/core/src/execution/shuffle/shuffle_writer.rs index 97b497b181..2b1cba7049 100644 --- a/native/core/src/execution/shuffle/shuffle_writer.rs +++ b/native/core/src/execution/shuffle/shuffle_writer.rs @@ -674,7 +674,7 @@ mod test { /// Read all shuffle blocks from a byte buffer written by BufBatchWriter/ShuffleBlockWriter, /// returning the total number of rows. - fn read_all_shuffle_blocks(data: &[u8], schema: &Schema) -> usize { + fn read_all_shuffle_blocks(data: &[u8], schema: &Arc) -> usize { let mut offset = 0; let mut total_rows = 0; while offset < data.len() { From 36800650911af6e21e96dd4acbf4be1d3225a389 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Thu, 19 Mar 2026 13:45:37 -0600 Subject: [PATCH 31/33] perf: preserve dictionary encoding in raw shuffle format MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Previously, dictionary arrays were cast to their value types before writing to the shuffle format. This expanded the data (e.g. 100 unique strings × 8192 rows became 8192 full string copies) and reduced compression effectiveness. 
Now dictionary arrays are written natively with a per-column tag byte indicating the encoding (plain=0, dictionary=1) and key type. The reader reconstructs DictionaryArrays from these tags. For the ShuffleScanExec path, dictionary arrays flow directly to DataFusion which handles them natively. For the JNI decodeShuffleBlock path, dictionary columns are cast to value types after decode since the JVM expects plain Arrow types. The cast cost is the same but the serialized data is much smaller, saving IO and compression time. --- native/core/src/execution/jni_api.rs | 44 +++++++- native/core/src/execution/shuffle/codec.rs | 116 +++++++++++++++++---- 2 files changed, 137 insertions(+), 23 deletions(-) diff --git a/native/core/src/execution/jni_api.rs b/native/core/src/execution/jni_api.rs index b866d98a9a..c9db6cc8aa 100644 --- a/native/core/src/execution/jni_api.rs +++ b/native/core/src/execution/jni_api.rs @@ -26,7 +26,7 @@ use crate::{ }, jvm_bridge::{jni_new_global_ref, JVMClasses}, }; -use arrow::array::{Array, RecordBatch, UInt32Array}; +use arrow::array::{Array, ArrayRef, RecordBatch, UInt32Array}; use arrow::compute::{take, TakeOptions}; use arrow::datatypes::{DataType as ArrowDataType, Field, Schema}; use datafusion::common::{DataFusionError, Result as DataFusionResult, ScalarValue}; @@ -877,6 +877,42 @@ pub extern "system" fn Java_org_apache_comet_Native_sortRowPartitionsNative( }) } +/// Casts any dictionary-encoded columns in a RecordBatch to their value types. +/// Used by the JNI decode path where the JVM expects plain Arrow types. 
+fn unpack_dictionary_columns(batch: RecordBatch) -> DataFusionResult<RecordBatch> {
+    let mut needs_cast = false;
+    for col in batch.columns() {
+        if matches!(col.data_type(), ArrowDataType::Dictionary(_, _)) {
+            needs_cast = true;
+            break;
+        }
+    }
+    if !needs_cast {
+        return Ok(batch);
+    }
+
+    let mut new_columns: Vec<ArrayRef> = Vec::with_capacity(batch.num_columns());
+    let mut new_fields: Vec<Arc<Field>> = Vec::with_capacity(batch.num_columns());
+    for (col, field) in batch.columns().iter().zip(batch.schema().fields()) {
+        match col.data_type() {
+            ArrowDataType::Dictionary(_, value_type) => {
+                new_columns.push(arrow::compute::cast(col.as_ref(), value_type)?);
+                new_fields.push(Arc::new(arrow::datatypes::Field::new(
+                    field.name(),
+                    value_type.as_ref().clone(),
+                    field.is_nullable(),
+                )));
+            }
+            _ => {
+                new_columns.push(Arc::clone(col));
+                new_fields.push(Arc::clone(field));
+            }
+        }
+    }
+    let schema = Arc::new(Schema::new(new_fields));
+    Ok(RecordBatch::try_new(schema, new_columns)?)
+}
+
 // Thread-local cache for the parsed shuffle schema. The schema bytes are
 // identical for every call within a given shuffle reader, so we avoid
 // re-parsing protobuf and re-allocating Field/Schema objects on each batch.
@@ -936,6 +972,12 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_decodeShuffleBlock(
         let length = length as usize;
         let slice: &[u8] = unsafe { std::slice::from_raw_parts(raw_pointer, length) };
         let batch = read_shuffle_block(slice, &schema)?;
+
+        // The raw shuffle format may preserve dictionary-encoded columns.
+        // The JVM side expects plain (non-dictionary) types, so cast any
+        // dictionary columns to their value types before FFI export.
+ let batch = unpack_dictionary_columns(batch)?; + prepare_output(&mut env, array_addrs, schema_addrs, batch, false) }) }) diff --git a/native/core/src/execution/shuffle/codec.rs b/native/core/src/execution/shuffle/codec.rs index 59ecf65ada..44196c87eb 100644 --- a/native/core/src/execution/shuffle/codec.rs +++ b/native/core/src/execution/shuffle/codec.rs @@ -18,7 +18,6 @@ use crate::errors::{CometError, CometResult}; use arrow::array::{make_array, Array, ArrayRef, MutableArrayData, RecordBatch}; use arrow::buffer::Buffer; -use arrow::compute::cast; use arrow::datatypes::DataType; use arrow::datatypes::Schema; use arrow::record_batch::RecordBatchOptions; @@ -96,12 +95,6 @@ fn write_array_data(data: &arrow::array::ArrayData, writer: &mut W) -> /// producing a physical copy when necessary. This is required because our /// raw buffer format writes buffers verbatim and assumes offset 0. fn normalize_array(col: &ArrayRef) -> Result { - // Cast dictionary arrays to their value type - let col = match col.data_type() { - DataType::Dictionary(_, value_type) => cast(col.as_ref(), value_type.as_ref())?, - _ => Arc::clone(col), - }; - let needs_copy = col.offset() != 0 || col.nulls().is_some_and(|nulls| nulls.offset() != 0); if needs_copy { // Use MutableArrayData::extend for a direct memcpy rather than @@ -111,18 +104,66 @@ fn normalize_array(col: &ArrayRef) -> Result { mutable.extend(0, 0, col.len()); Ok(make_array(mutable.freeze())) } else { - Ok(col) + Ok(Arc::clone(col)) + } +} + +// Column encoding tags for the raw shuffle format. +const COL_TAG_PLAIN: u8 = 0; +const COL_TAG_DICTIONARY: u8 = 1; + +/// Encode a dictionary key DataType as a single byte. 
+fn encode_key_type(dt: &DataType) -> Result<u8> {
+    match dt {
+        DataType::Int8 => Ok(0),
+        DataType::Int16 => Ok(1),
+        DataType::Int32 => Ok(2),
+        DataType::Int64 => Ok(3),
+        DataType::UInt8 => Ok(4),
+        DataType::UInt16 => Ok(5),
+        DataType::UInt32 => Ok(6),
+        DataType::UInt64 => Ok(7),
+        _ => Err(DataFusionError::Execution(format!(
+            "unsupported dictionary key type: {dt:?}"
+        ))),
+    }
+}
+
+/// Decode a dictionary key DataType from the byte written by encode_key_type.
+fn decode_key_type(tag: u8) -> Result<DataType> {
+    match tag {
+        0 => Ok(DataType::Int8),
+        1 => Ok(DataType::Int16),
+        2 => Ok(DataType::Int32),
+        3 => Ok(DataType::Int64),
+        4 => Ok(DataType::UInt8),
+        5 => Ok(DataType::UInt16),
+        6 => Ok(DataType::UInt32),
+        7 => Ok(DataType::UInt64),
+        _ => Err(DataFusionError::Execution(format!(
+            "unknown dictionary key type tag: {tag}"
+        ))),
+    }
 }
 
-/// Writes a RecordBatch in raw buffer format. Dictionary arrays are cast to their value type.
-/// Arrays with non-zero offsets (e.g. from slicing) are copied to ensure offset 0.
+/// Writes a RecordBatch in raw buffer format. Dictionary arrays are preserved
+/// (not cast to value type) so they compress better. Each column is prefixed
+/// with a tag byte indicating plain (0) or dictionary (1) encoding.
fn write_raw_batch(batch: &RecordBatch, writer: &mut W) -> Result<()> { let num_rows = batch.num_rows() as u32; writer.write_all(&num_rows.to_le_bytes())?; for col in batch.columns() { let col = normalize_array(col)?; + match col.data_type() { + DataType::Dictionary(key_type, _) => { + writer.write_all(&[COL_TAG_DICTIONARY])?; + writer.write_all(&[encode_key_type(key_type)?])?; + } + _ => { + writer.write_all(&[COL_TAG_PLAIN])?; + } + } write_array_data(&col.to_data(), writer)?; } @@ -263,6 +304,7 @@ fn get_child_types(data_type: &DataType) -> Vec { vec![field.data_type().clone()] } DataType::Struct(fields) => fields.iter().map(|f| f.data_type().clone()).collect(), + DataType::Dictionary(_, value_type) => vec![value_type.as_ref().clone()], _ => vec![], } } @@ -329,19 +371,45 @@ fn read_array_data( } /// Read a raw batch from decompressed bytes, given the expected schema. +/// Columns that were written as dictionary-encoded are reconstructed as +/// DictionaryArray. The returned batch schema may differ from the input +/// schema (Dictionary vs plain types) — callers that need plain types +/// should cast dictionary columns afterward. fn read_raw_batch(bytes: &[u8], schema: &Arc) -> Result { let mut cursor = bytes; let num_rows = read_u32(&mut cursor)? 
as usize; let mut columns: Vec = Vec::with_capacity(schema.fields().len()); + let mut fields: Vec = Vec::with_capacity(schema.fields().len()); for field in schema.fields() { - let array_data = read_array_data(&mut cursor, field.data_type(), num_rows)?; + // Read per-column tag + let tag = read_bytes(&mut cursor, 1)?[0]; + let data_type = match tag { + COL_TAG_PLAIN => field.data_type().clone(), + COL_TAG_DICTIONARY => { + let key_tag = read_bytes(&mut cursor, 1)?[0]; + let key_type = decode_key_type(key_tag)?; + DataType::Dictionary(Box::new(key_type), Box::new(field.data_type().clone())) + } + _ => { + return Err(DataFusionError::Execution(format!( + "unknown column tag: {tag}" + ))); + } + }; + let array_data = read_array_data(&mut cursor, &data_type, num_rows)?; columns.push(make_array(array_data)); + fields.push(Arc::new(arrow::datatypes::Field::new( + field.name(), + data_type, + field.is_nullable(), + ))); } + let actual_schema = Arc::new(Schema::new(fields)); let options = RecordBatchOptions::new().with_row_count(Some(num_rows)); - let batch = RecordBatch::try_new_with_options(Arc::clone(schema), columns, &options)?; + let batch = RecordBatch::try_new_with_options(actual_schema, columns, &options)?; Ok(batch) } @@ -735,8 +803,8 @@ mod tests { #[test] #[cfg_attr(miri, ignore)] - fn test_roundtrip_dictionary_cast() { - // Dictionary should be cast to plain Utf8 on write + fn test_roundtrip_dictionary_preserved() { + // Dictionary should be preserved through write/read let dict_schema = Arc::new(Schema::new(vec![Field::new( "d", DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)), @@ -750,7 +818,8 @@ mod tests { let batch = RecordBatch::try_new(Arc::clone(&dict_schema), vec![Arc::new(dict_arr)]).unwrap(); - // The writer casts dict to plain string, so the read schema must be Utf8 + // The read schema uses the value type (Utf8) since that's what Spark knows about. + // The reader reconstructs a Dictionary type from the per-column tag. 
let read_schema = Arc::new(Schema::new(vec![Field::new("d", DataType::Utf8, true)])); let codecs = vec![ @@ -771,16 +840,19 @@ mod tests { let decoded = read_shuffle_block(body, &read_schema).unwrap(); assert_eq!(decoded.num_rows(), 4); - // Result should be a plain StringArray, not a DictionaryArray + // Result should be a DictionaryArray (preserved, not cast) let col = decoded .column(0) .as_any() - .downcast_ref::() - .expect("expected plain StringArray after dict cast"); - assert_eq!(col.value(0), "foo"); - assert_eq!(col.value(1), "bar"); - assert!(col.is_null(2)); - assert_eq!(col.value(3), "foo"); + .downcast_ref::>() + .expect("expected DictionaryArray to be preserved"); + // Verify values by casting to string for comparison + let cast_col = arrow::compute::cast(col, &DataType::Utf8).expect("cast dict to utf8"); + let str_col = cast_col.as_any().downcast_ref::().unwrap(); + assert_eq!(str_col.value(0), "foo"); + assert_eq!(str_col.value(1), "bar"); + assert!(str_col.is_null(2)); + assert_eq!(str_col.value(3), "foo"); } } From 172c41bc40a21fba8e9004c49b5766c336ed1596 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Thu, 19 Mar 2026 13:58:59 -0600 Subject: [PATCH 32/33] perf: pre-allocate decompression buffer and avoid per-batch encoder Two medium-priority optimizations: 1. Write uncompressed size as a u32 header before compressed data. The reader uses this to pre-allocate the decompression buffer to exact size, eliminating ~18 reallocations (doubling strategy) per 256KB block. For Zstd, use bulk decompress which is also faster than streaming. 2. Serialize raw batch to intermediate buffer first, then compress in one shot. This avoids creating a streaming compression encoder per batch (Zstd allocates ~128KB internal state per encoder). For Zstd, use bulk::compress which reuses internal context. Also batches many small write_all calls into a single buffer, reducing overhead through compression codec state machines. 
--- native/core/src/execution/shuffle/codec.rs | 65 +++++++++++++--------- 1 file changed, 40 insertions(+), 25 deletions(-) diff --git a/native/core/src/execution/shuffle/codec.rs b/native/core/src/execution/shuffle/codec.rs index 44196c87eb..7ff5c581ac 100644 --- a/native/core/src/execution/shuffle/codec.rs +++ b/native/core/src/execution/shuffle/codec.rs @@ -201,6 +201,11 @@ impl ShuffleBlockWriter { /// Writes given record batch in raw buffer format into given writer. /// Returns number of bytes written. + /// + /// The batch is first serialized to an intermediate buffer, then compressed + /// in one shot. This avoids creating a streaming compression encoder per + /// batch (expensive for Zstd ~128KB state) and gives us the uncompressed + /// size for a pre-allocation hint on the read side. pub fn write_batch( &self, batch: &RecordBatch, @@ -217,31 +222,35 @@ impl ShuffleBlockWriter { // write header output.write_all(&self.header_bytes)?; - let output = match &self.codec { + // Serialize raw batch to intermediate buffer + let mut raw_buf = Vec::new(); + write_raw_batch(batch, &mut raw_buf)?; + + // Write uncompressed size hint (u32), then compressed data + let uncompressed_len = raw_buf.len() as u32; + output.write_all(&uncompressed_len.to_le_bytes())?; + + match &self.codec { CompressionCodec::None => { - write_raw_batch(batch, output)?; - output + output.write_all(&raw_buf)?; } CompressionCodec::Lz4Frame => { - let mut wtr = lz4_flex::frame::FrameEncoder::new(output); - write_raw_batch(batch, &mut wtr)?; + let mut wtr = lz4_flex::frame::FrameEncoder::new(output.by_ref()); + wtr.write_all(&raw_buf)?; wtr.finish().map_err(|e| { DataFusionError::Execution(format!("lz4 compression error: {e}")) - })? + })?; } - CompressionCodec::Zstd(level) => { - let mut encoder = zstd::Encoder::new(output, *level)?; - write_raw_batch(batch, &mut encoder)?; - encoder.finish()? 
+ let compressed = zstd::bulk::compress(&raw_buf, *level)?; + output.write_all(&compressed)?; } - CompressionCodec::Snappy => { - let mut wtr = snap::write::FrameEncoder::new(output); - write_raw_batch(batch, &mut wtr)?; + let mut wtr = snap::write::FrameEncoder::new(output.by_ref()); + wtr.write_all(&raw_buf)?; wtr.into_inner().map_err(|e| { DataFusionError::Execution(format!("snappy compression error: {e}")) - })? + })?; } }; @@ -416,33 +425,39 @@ fn read_raw_batch(bytes: &[u8], schema: &Arc) -> Result { /// Reads and decompresses a shuffle block written in raw buffer format. /// The `bytes` slice starts at the codec tag (after the 8-byte length and /// 8-byte field_count header that the JVM reads). +/// +/// Format: `[codec_tag: 4 bytes][uncompressed_len: u32][compressed_data...]` pub fn read_shuffle_block(bytes: &[u8], schema: &Arc) -> Result { - match &bytes[0..4] { + let codec_tag = &bytes[0..4]; + // Read uncompressed size hint for pre-allocation + let uncompressed_len = u32::from_le_bytes(bytes[4..8].try_into().unwrap()) as usize; + let data = &bytes[8..]; + + match codec_tag { b"SNAP" => { - let decoder = snap::read::FrameDecoder::new(&bytes[4..]); - let decompressed = read_all(decoder)?; + let decoder = snap::read::FrameDecoder::new(data); + let decompressed = read_all_with_capacity(decoder, uncompressed_len)?; read_raw_batch(&decompressed, schema) } b"LZ4_" => { - let decoder = lz4_flex::frame::FrameDecoder::new(&bytes[4..]); - let decompressed = read_all(decoder)?; + let decoder = lz4_flex::frame::FrameDecoder::new(data); + let decompressed = read_all_with_capacity(decoder, uncompressed_len)?; read_raw_batch(&decompressed, schema) } b"ZSTD" => { - let decoder = zstd::Decoder::new(&bytes[4..])?; - let decompressed = read_all(decoder)?; + let decompressed = zstd::bulk::decompress(data, uncompressed_len)?; read_raw_batch(&decompressed, schema) } - b"NONE" => read_raw_batch(&bytes[4..], schema), + b"NONE" => read_raw_batch(data, schema), other => 
Err(DataFusionError::Execution(format!( "Failed to decode batch: invalid compression codec: {other:?}" ))), } } -/// Read all bytes from a reader into a Vec. -fn read_all(mut reader: R) -> Result> { - let mut buf = Vec::new(); +/// Read all bytes from a reader into a pre-allocated Vec. +fn read_all_with_capacity(mut reader: R, capacity: usize) -> Result> { + let mut buf = Vec::with_capacity(capacity); reader.read_to_end(&mut buf)?; Ok(buf) } From 9af4c40360c381d274a6b33665cc4af12a02566b Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Thu, 19 Mar 2026 14:04:52 -0600 Subject: [PATCH 33/33] perf: use bulk LZ4 block compression instead of streaming frame LZ4 is the default shuffle compression codec. Switch from lz4_flex FrameEncoder/FrameDecoder (streaming, per-batch encoder allocation) to lz4_flex::compress/decompress (block-level, no encoder state). Combined with the uncompressed size header, the decompressor allocates exactly once to the right size. --- native/core/src/execution/shuffle/codec.rs | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/native/core/src/execution/shuffle/codec.rs b/native/core/src/execution/shuffle/codec.rs index 7ff5c581ac..52cc320e6c 100644 --- a/native/core/src/execution/shuffle/codec.rs +++ b/native/core/src/execution/shuffle/codec.rs @@ -235,11 +235,8 @@ impl ShuffleBlockWriter { output.write_all(&raw_buf)?; } CompressionCodec::Lz4Frame => { - let mut wtr = lz4_flex::frame::FrameEncoder::new(output.by_ref()); - wtr.write_all(&raw_buf)?; - wtr.finish().map_err(|e| { - DataFusionError::Execution(format!("lz4 compression error: {e}")) - })?; + let compressed = lz4_flex::compress(&raw_buf); + output.write_all(&compressed)?; } CompressionCodec::Zstd(level) => { let compressed = zstd::bulk::compress(&raw_buf, *level)?; @@ -440,8 +437,8 @@ pub fn read_shuffle_block(bytes: &[u8], schema: &Arc) -> Result { - let decoder = lz4_flex::frame::FrameDecoder::new(data); - let decompressed = read_all_with_capacity(decoder, 
uncompressed_len)?; + let decompressed = lz4_flex::decompress(data, uncompressed_len) + .map_err(|e| DataFusionError::Execution(format!("lz4 decompression error: {e}")))?; read_raw_batch(&decompressed, schema) } b"ZSTD" => {