diff --git a/common/src/main/java/org/apache/comet/udf/CometUdfAllocator.java b/common/src/main/java/org/apache/comet/udf/CometUdfAllocator.java
new file mode 100644
index 0000000000..132245d7a6
--- /dev/null
+++ b/common/src/main/java/org/apache/comet/udf/CometUdfAllocator.java
@@ -0,0 +1,154 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.comet.udf;
+
+import java.io.IOException;
+import java.util.concurrent.ConcurrentHashMap;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.arrow.memory.AllocationListener;
+import org.apache.arrow.memory.BufferAllocator;
+import org.apache.spark.TaskContext;
+import org.apache.spark.comet.CometTaskContextShim;
+import org.apache.spark.memory.MemoryConsumer;
+import org.apache.spark.memory.MemoryMode;
+import org.apache.spark.memory.TaskMemoryManager;
+import org.apache.spark.util.TaskCompletionListener;
+
+/**
+ * Provides a per-Spark-task Arrow {@link BufferAllocator} whose allocations are accounted to the
+ * task's {@link TaskMemoryManager} via a {@link MemoryConsumer}.
+ *
+ *
One child allocator is created on the first {@link #acquire(TaskContext)} call within a Spark
+ * task and cached for the lifetime of that task. A task-completion listener closes the child
+ * allocator and removes the cache entry, so allocations and frees stay balanced with the task
+ * lifecycle. Subsequent UDF dispatches within the same task reuse the cached allocator without
+ * additional accounting overhead.
+ *
+ *
The allocator is a child of the project-wide {@code CometArrowAllocator}, with a custom {@link
+ * AllocationListener} attached. The listener forwards pre-allocation to {@code
+ * TaskMemoryManager.acquireExecutionMemory} and on-release to {@code releaseExecutionMemory}. A
+ * partial acquisition is released and the allocation aborted via {@link OutOfMemoryError}; this is
+ * the standard Arrow + Spark integration pattern.
+ */
+public final class CometUdfAllocator {
+
+ private static final Logger LOG = LoggerFactory.getLogger(CometUdfAllocator.class);
+
+ /** Keyed by {@code TaskContext.taskAttemptId()}. */
+ private static final ConcurrentHashMap CACHE = new ConcurrentHashMap<>();
+
+ private CometUdfAllocator() {}
+
+ /**
+ * Returns the per-task child allocator, creating it on first call. Safe to call from multiple
+ * worker threads within the same Spark task.
+ */
+ public static BufferAllocator acquire(TaskContext taskContext) {
+ long key = taskContext.taskAttemptId();
+ return CACHE.computeIfAbsent(key, k -> create(taskContext));
+ }
+
+ /** Visible for testing. */
+ static int cacheSize() {
+ return CACHE.size();
+ }
+
+ private static BufferAllocator create(TaskContext taskContext) {
+ TaskMemoryManager tmm = CometTaskContextShim.taskMemoryManager(taskContext);
+ TaskMemoryConsumer consumer = new TaskMemoryConsumer(tmm);
+ AllocationListener listener = new TaskAllocationListener(tmm, consumer);
+ BufferAllocator child =
+ org.apache.comet.package$.MODULE$.CometArrowAllocator()
+ .newChildAllocator(
+ "comet-udf-task-" + taskContext.taskAttemptId(), listener, 0L, Long.MAX_VALUE);
+ long key = taskContext.taskAttemptId();
+ try {
+ taskContext.addTaskCompletionListener(
+ (TaskCompletionListener) ctx -> closeAndRemove(key, child));
+ } catch (RuntimeException e) {
+ try {
+ child.close();
+ } catch (RuntimeException ignored) {
+ // do not mask the original failure
+ }
+ throw e;
+ }
+ return child;
+ }
+
+ private static void closeAndRemove(long key, BufferAllocator child) {
+ CACHE.remove(key);
+ try {
+ child.close();
+ } catch (RuntimeException e) {
+ // Arrow throws IllegalStateException if any buffers leaked. Log and swallow so we do not
+ // mask other task-completion errors.
+ LOG.warn(
+ "Comet UDF child allocator for taskAttemptId={} closed with leaked allocations", key, e);
+ }
+ }
+
+ /** Delegates Arrow allocation events to a Spark {@link MemoryConsumer}. */
+ private static final class TaskAllocationListener implements AllocationListener {
+ private final TaskMemoryManager tmm;
+ private final TaskMemoryConsumer consumer;
+
+ TaskAllocationListener(TaskMemoryManager tmm, TaskMemoryConsumer consumer) {
+ this.tmm = tmm;
+ this.consumer = consumer;
+ }
+
+ @Override
+ public void onPreAllocation(long size) {
+ long acquired = tmm.acquireExecutionMemory(size, consumer);
+ if (acquired < size) {
+ if (acquired > 0) {
+ tmm.releaseExecutionMemory(acquired, consumer);
+ }
+ throw new OutOfMemoryError(
+ "Comet UDF: failed to acquire " + size + " bytes from Spark TaskMemoryManager");
+ }
+ }
+
+ @Override
+ public void onRelease(long size) {
+ tmm.releaseExecutionMemory(size, consumer);
+ }
+ }
+
+ /** Non-spillable off-heap consumer for Arrow buffers exported via FFI. */
+ private static final class TaskMemoryConsumer extends MemoryConsumer {
+ TaskMemoryConsumer(TaskMemoryManager tmm) {
+ // pageSize = 0: this consumer never uses allocatePage(); only acquire/releaseExecutionMemory.
+ // Matches the existing NativeMemoryConsumer pattern in CometTaskMemoryManager.
+ super(tmm, 0L, MemoryMode.OFF_HEAP);
+ }
+
+ @Override
+ public long spill(long size, MemoryConsumer trigger) throws IOException {
+ // UDF result buffers are pinned for the duration of the call and exported via Arrow FFI;
+ // they cannot be spilled.
+ return 0L;
+ }
+ }
+}
diff --git a/common/src/main/java/org/apache/comet/udf/CometUdfBridge.java b/common/src/main/java/org/apache/comet/udf/CometUdfBridge.java
index 5e76819810..6443d5be50 100644
--- a/common/src/main/java/org/apache/comet/udf/CometUdfBridge.java
+++ b/common/src/main/java/org/apache/comet/udf/CometUdfBridge.java
@@ -20,6 +20,10 @@
package org.apache.comet.udf;
import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.atomic.AtomicBoolean;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
import org.apache.arrow.c.ArrowArray;
import org.apache.arrow.c.ArrowSchema;
@@ -42,6 +46,10 @@ public class CometUdfBridge {
// single shared instance per class is safe across native worker threads.
private static final ConcurrentHashMap INSTANCES = new ConcurrentHashMap<>();
+ private static final Logger LOG = LoggerFactory.getLogger(CometUdfBridge.class);
+
+ private static final AtomicBoolean WARNED_NO_TASK_CONTEXT = new AtomicBoolean(false);
+
/**
* Called from native via JNI.
*
@@ -114,7 +122,7 @@ private static void evaluateInternal(
}
});
- BufferAllocator allocator = org.apache.comet.package$.MODULE$.CometArrowAllocator();
+ BufferAllocator allocator = resolveAllocator();
ValueVector[] inputs = new ValueVector[inputArrayPtrs.length];
ValueVector result = null;
@@ -125,7 +133,7 @@ private static void evaluateInternal(
inputs[i] = Data.importVector(allocator, inArr, inSch, null);
}
- result = udf.evaluate(inputs, numRows);
+ result = udf.evaluate(allocator, inputs, numRows);
if (!(result instanceof FieldVector)) {
throw new RuntimeException(
"CometUDF.evaluate() must return a FieldVector, got: " + result.getClass().getName());
@@ -159,4 +167,18 @@ private static void evaluateInternal(
}
}
}
+
+ private static BufferAllocator resolveAllocator() {
+ TaskContext ctx = TaskContext.get();
+ if (ctx != null) {
+ return CometUdfAllocator.acquire(ctx);
+ }
+ if (WARNED_NO_TASK_CONTEXT.compareAndSet(false, true)) {
+ LOG.warn(
+ "CometUdfBridge invoked with no TaskContext on the calling thread; falling back to the "
+ + "unaccounted root allocator. UDF off-heap memory will not be charged to Spark's "
+ + "task memory manager.");
+ }
+ return org.apache.comet.package$.MODULE$.CometArrowAllocator();
+ }
}
diff --git a/common/src/main/scala/org/apache/comet/udf/CometUDF.scala b/common/src/main/scala/org/apache/comet/udf/CometUDF.scala
index 5b6652d90a..4408b0e2b8 100644
--- a/common/src/main/scala/org/apache/comet/udf/CometUDF.scala
+++ b/common/src/main/scala/org/apache/comet/udf/CometUDF.scala
@@ -19,6 +19,7 @@
package org.apache.comet.udf
+import org.apache.arrow.memory.BufferAllocator
import org.apache.arrow.vector.ValueVector
/**
@@ -28,6 +29,9 @@ import org.apache.arrow.vector.ValueVector
* - Vector arguments arrive at the row count of the current batch.
* - Scalar (literal-folded) arguments arrive as length-1 vectors and must be read at index 0.
* - The returned vector's length must match `numRows`.
+ * - `allocator` is the per-task Arrow allocator backed by Spark's task memory accounting (see
+ * `CometUdfAllocator`). Implementations must allocate any returned vector or temporary
+ * buffers from this allocator so off-heap usage is charged to the executing Spark task.
*
* `numRows` mirrors DataFusion's `ScalarFunctionArgs.number_rows` and is the batch row count.
* UDFs that always have at least one batch-length input can derive length from the inputs and
@@ -38,5 +42,5 @@ import org.apache.arrow.vector.ValueVector
* per class is cached and shared across native worker threads for the lifetime of the JVM.
*/
trait CometUDF {
- def evaluate(inputs: Array[ValueVector], numRows: Int): ValueVector
+ def evaluate(allocator: BufferAllocator, inputs: Array[ValueVector], numRows: Int): ValueVector
}
diff --git a/common/src/main/scala/org/apache/spark/comet/CometTaskContextShim.scala b/common/src/main/scala/org/apache/spark/comet/CometTaskContextShim.scala
index 9218fc5e78..6c3854fd62 100644
--- a/common/src/main/scala/org/apache/spark/comet/CometTaskContextShim.scala
+++ b/common/src/main/scala/org/apache/spark/comet/CometTaskContextShim.scala
@@ -20,22 +20,22 @@
package org.apache.spark.comet
import org.apache.spark.TaskContext
+import org.apache.spark.memory.TaskMemoryManager
/**
- * Package-private access shim for `TaskContext.setTaskContext` / `TaskContext.unset`.
+ * Package-private access shim for Spark APIs that are `protected[spark]` or `private[spark]`.
*
- * Both methods are declared `protected[spark]` on Spark's `TaskContext` companion, so they are
- * reachable from code inside the `org.apache.spark` package tree but not from `org.apache.comet`.
- * The Comet JVM UDF bridge needs to set the thread-local `TaskContext` on its caller thread (a
- * Tokio worker thread with no `TaskContext`) so the user's UDF body and any partition-sensitive
- * built-ins (`Rand`, `Uuid`, `MonotonicallyIncreasingID`, etc.) see the driving Spark task's
- * `TaskContext`. This shim lives in `org.apache.spark.comet` so it can call through to the
- * protected methods, and exposes plain public forwarders the bridge (which lives in
- * `org.apache.comet.udf`) can use.
+ * `TaskContext.setTaskContext` / `TaskContext.unset` are `protected[spark]` on the companion;
+ * `TaskContext.taskMemoryManager()` is `private[spark]` on the instance. Code outside the
+ * `org.apache.spark` package tree (e.g. `org.apache.comet.udf`) cannot call them directly. This
+ * shim lives in `org.apache.spark.comet` so it can forward through.
*/
object CometTaskContextShim {
def set(taskContext: TaskContext): Unit = TaskContext.setTaskContext(taskContext)
def unset(): Unit = TaskContext.unset()
+
+ def taskMemoryManager(taskContext: TaskContext): TaskMemoryManager =
+ taskContext.taskMemoryManager()
}
diff --git a/common/src/test/java/org/apache/comet/udf/CometUdfAllocatorTest.java b/common/src/test/java/org/apache/comet/udf/CometUdfAllocatorTest.java
new file mode 100644
index 0000000000..2938e7d18e
--- /dev/null
+++ b/common/src/test/java/org/apache/comet/udf/CometUdfAllocatorTest.java
@@ -0,0 +1,128 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.comet.udf;
+
+import java.util.Collections;
+import java.util.Iterator;
+
+import org.junit.AfterClass;
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+import org.apache.arrow.memory.BufferAllocator;
+import org.apache.arrow.vector.IntVector;
+import org.apache.spark.TaskContext;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.api.java.function.VoidFunction;
+import org.apache.spark.comet.CometTaskContextShim;
+import org.apache.spark.memory.TaskMemoryManager;
+import org.apache.spark.sql.SparkSession;
+import org.apache.spark.util.LongAccumulator;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+
+/**
+ * Verifies that {@link CometUdfAllocator} both charges the executing Spark task's {@link
+ * TaskMemoryManager} for allocations and tears down the per-task allocator on task completion.
+ * Lives in the {@code common} module to stay on the unshaded side of the Arrow relocation boundary;
+ * the spark module sees {@code BufferAllocator} as the shaded type.
+ */
+public class CometUdfAllocatorTest {
+
+ private static SparkSession spark;
+ private static JavaSparkContext jsc;
+
+ @BeforeClass
+ public static void setUp() {
+ spark =
+ SparkSession.builder()
+ .master("local[2]")
+ .appName("CometUdfAllocatorTest")
+ .config("spark.memory.offHeap.enabled", "true")
+ .config("spark.memory.offHeap.size", "67108864")
+ .getOrCreate();
+ jsc = new JavaSparkContext(spark.sparkContext());
+ }
+
+ @AfterClass
+ public static void tearDown() {
+ if (spark != null) {
+ spark.stop();
+ spark = null;
+ }
+ }
+
+ @Test
+ public void acquireRegistersMemoryConsumerThatChargesTask() {
+ LongAccumulator before = jsc.sc().longAccumulator("before");
+ LongAccumulator during = jsc.sc().longAccumulator("during");
+ LongAccumulator after = jsc.sc().longAccumulator("after");
+
+ jsc.parallelize(Collections.singletonList(0), 1)
+ .foreachPartition(
+ (VoidFunction>)
+ it -> {
+ TaskContext ctx = TaskContext.get();
+ TaskMemoryManager tmm = CometTaskContextShim.taskMemoryManager(ctx);
+ before.add(tmm.getMemoryConsumptionForThisTask());
+
+ BufferAllocator allocator = CometUdfAllocator.acquire(ctx);
+ IntVector vec = new IntVector("test", allocator);
+ try {
+ vec.allocateNew(1024);
+ vec.setValueCount(1024);
+ during.add(tmm.getMemoryConsumptionForThisTask());
+ } finally {
+ vec.close();
+ }
+ after.add(tmm.getMemoryConsumptionForThisTask());
+ });
+
+ assertTrue(
+ "task memory should grow while allocator holds buffers; before="
+ + before.value()
+ + " during="
+ + during.value(),
+ during.value() > before.value());
+ assertEquals(
+ "task memory should be released after vector closes", before.value(), after.value());
+ }
+
+ @Test
+ public void childAllocatorClosedAndUncachedOnTaskCompletion() {
+ LongAccumulator cacheSizeDuring = jsc.sc().longAccumulator("cacheSizeDuring");
+
+ jsc.parallelize(Collections.singletonList(0), 1)
+ .foreachPartition(
+ (VoidFunction>)
+ it -> {
+ TaskContext ctx = TaskContext.get();
+ CometUdfAllocator.acquire(ctx);
+ cacheSizeDuring.add((long) CometUdfAllocator.cacheSize());
+ });
+
+ assertTrue("allocator should be cached during the task", cacheSizeDuring.value() >= 1L);
+ assertEquals(
+ "TaskCompletionListener should remove the entry once the task ends",
+ 0,
+ CometUdfAllocator.cacheSize());
+ }
+}