From 81ced72252d6c19fc08309e8ca5c1feb9c05246c Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Thu, 14 May 2026 08:10:58 -0600 Subject: [PATCH 1/6] feat(udf): expose TaskMemoryManager forwarder on CometTaskContextShim --- .../spark/comet/CometTaskContextShim.scala | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/common/src/main/scala/org/apache/spark/comet/CometTaskContextShim.scala b/common/src/main/scala/org/apache/spark/comet/CometTaskContextShim.scala index 9218fc5e78..6c3854fd62 100644 --- a/common/src/main/scala/org/apache/spark/comet/CometTaskContextShim.scala +++ b/common/src/main/scala/org/apache/spark/comet/CometTaskContextShim.scala @@ -20,22 +20,22 @@ package org.apache.spark.comet import org.apache.spark.TaskContext +import org.apache.spark.memory.TaskMemoryManager /** - * Package-private access shim for `TaskContext.setTaskContext` / `TaskContext.unset`. + * Package-private access shim for Spark APIs that are `protected[spark]` or `private[spark]`. * - * Both methods are declared `protected[spark]` on Spark's `TaskContext` companion, so they are - * reachable from code inside the `org.apache.spark` package tree but not from `org.apache.comet`. - * The Comet JVM UDF bridge needs to set the thread-local `TaskContext` on its caller thread (a - * Tokio worker thread with no `TaskContext`) so the user's UDF body and any partition-sensitive - * built-ins (`Rand`, `Uuid`, `MonotonicallyIncreasingID`, etc.) see the driving Spark task's - * `TaskContext`. This shim lives in `org.apache.spark.comet` so it can call through to the - * protected methods, and exposes plain public forwarders the bridge (which lives in - * `org.apache.comet.udf`) can use. + * `TaskContext.setTaskContext` / `TaskContext.unset` are `protected[spark]` on the companion; + * `TaskContext.taskMemoryManager()` is `private[spark]` on the instance. Code outside the + * `org.apache.spark` package tree (e.g. `org.apache.comet.udf`) cannot call them directly. This + * shim lives in `org.apache.spark.comet` so it can forward through. */ object CometTaskContextShim { def set(taskContext: TaskContext): Unit = TaskContext.setTaskContext(taskContext) def unset(): Unit = TaskContext.unset() + + def taskMemoryManager(taskContext: TaskContext): TaskMemoryManager = + taskContext.taskMemoryManager() } From f634c1cdca36c46da933f077334e97371ec75402 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Thu, 14 May 2026 08:15:36 -0600 Subject: [PATCH 2/6] feat(udf): add CometUdfAllocator for per-task Arrow accounting --- .../apache/comet/udf/CometUdfAllocator.java | 154 ++++++++++++++++++ 1 file changed, 154 insertions(+) create mode 100644 common/src/main/java/org/apache/comet/udf/CometUdfAllocator.java diff --git a/common/src/main/java/org/apache/comet/udf/CometUdfAllocator.java b/common/src/main/java/org/apache/comet/udf/CometUdfAllocator.java new file mode 100644 index 0000000000..132245d7a6 --- /dev/null +++ b/common/src/main/java/org/apache/comet/udf/CometUdfAllocator.java @@ -0,0 +1,154 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.udf; + +import java.io.IOException; +import java.util.concurrent.ConcurrentHashMap; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.arrow.memory.AllocationListener; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.spark.TaskContext; +import org.apache.spark.comet.CometTaskContextShim; +import org.apache.spark.memory.MemoryConsumer; +import org.apache.spark.memory.MemoryMode; +import org.apache.spark.memory.TaskMemoryManager; +import org.apache.spark.util.TaskCompletionListener; + +/** + * Provides a per-Spark-task Arrow {@link BufferAllocator} whose allocations are accounted to the + * task's {@link TaskMemoryManager} via a {@link MemoryConsumer}. + * + *

One child allocator is created on the first {@link #acquire(TaskContext)} call within a Spark + * task and cached for the lifetime of that task. A task-completion listener closes the child + * allocator and removes the cache entry, so allocations and frees stay balanced with the task + * lifecycle. Subsequent UDF dispatches within the same task reuse the cached allocator without + * additional accounting overhead. + * + *

The allocator is a child of the project-wide {@code CometArrowAllocator}, with a custom {@link + * AllocationListener} attached. The listener forwards pre-allocation to {@code + * TaskMemoryManager.acquireExecutionMemory} and on-release to {@code releaseExecutionMemory}. A + * partial acquisition is released and the allocation aborted via {@link OutOfMemoryError}; this is + * the standard Arrow + Spark integration pattern. + */ +public final class CometUdfAllocator { + + private static final Logger LOG = LoggerFactory.getLogger(CometUdfAllocator.class); + + /** Keyed by {@code TaskContext.taskAttemptId()}. */ + private static final ConcurrentHashMap CACHE = new ConcurrentHashMap<>(); + + private CometUdfAllocator() {} + + /** + * Returns the per-task child allocator, creating it on first call. Safe to call from multiple + * worker threads within the same Spark task. + */ + public static BufferAllocator acquire(TaskContext taskContext) { + long key = taskContext.taskAttemptId(); + return CACHE.computeIfAbsent(key, k -> create(taskContext)); + } + + /** Visible for testing. */ + static int cacheSize() { + return CACHE.size(); + } + + private static BufferAllocator create(TaskContext taskContext) { + TaskMemoryManager tmm = CometTaskContextShim.taskMemoryManager(taskContext); + TaskMemoryConsumer consumer = new TaskMemoryConsumer(tmm); + AllocationListener listener = new TaskAllocationListener(tmm, consumer); + BufferAllocator child = + org.apache.comet.package$.MODULE$.CometArrowAllocator() + .newChildAllocator( + "comet-udf-task-" + taskContext.taskAttemptId(), listener, 0L, Long.MAX_VALUE); + long key = taskContext.taskAttemptId(); + try { + taskContext.addTaskCompletionListener( + (TaskCompletionListener) ctx -> closeAndRemove(key, child)); + } catch (RuntimeException e) { + try { + child.close(); + } catch (RuntimeException ignored) { + // do not mask the original failure + } + throw e; + } + return child; + } + + private static void closeAndRemove(long key, BufferAllocator child) { + CACHE.remove(key); + try { + child.close(); + } catch (RuntimeException e) { + // Arrow throws IllegalStateException if any buffers leaked. Log and swallow so we do not + // mask other task-completion errors. + LOG.warn( + "Comet UDF child allocator for taskAttemptId={} closed with leaked allocations", key, e); + } + } + + /** Delegates Arrow allocation events to a Spark {@link MemoryConsumer}. */ + private static final class TaskAllocationListener implements AllocationListener { + private final TaskMemoryManager tmm; + private final TaskMemoryConsumer consumer; + + TaskAllocationListener(TaskMemoryManager tmm, TaskMemoryConsumer consumer) { + this.tmm = tmm; + this.consumer = consumer; + } + + @Override + public void onPreAllocation(long size) { + long acquired = tmm.acquireExecutionMemory(size, consumer); + if (acquired < size) { + if (acquired > 0) { + tmm.releaseExecutionMemory(acquired, consumer); + } + throw new OutOfMemoryError( + "Comet UDF: failed to acquire " + size + " bytes from Spark TaskMemoryManager"); + } + } + + @Override + public void onRelease(long size) { + tmm.releaseExecutionMemory(size, consumer); + } + } + + /** Non-spillable off-heap consumer for Arrow buffers exported via FFI. */ + private static final class TaskMemoryConsumer extends MemoryConsumer { + TaskMemoryConsumer(TaskMemoryManager tmm) { + // pageSize = 0: this consumer never uses allocatePage(); only acquire/releaseExecutionMemory. + // Matches the existing NativeMemoryConsumer pattern in CometTaskMemoryManager. + super(tmm, 0L, MemoryMode.OFF_HEAP); + } + + @Override + public long spill(long size, MemoryConsumer trigger) throws IOException { + // UDF result buffers are pinned for the duration of the call and exported via Arrow FFI; + // they cannot be spilled. + return 0L; + } + } +} From 90988c4fd276899e2189163162e70e386a3b4d55 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Thu, 14 May 2026 08:22:34 -0600 Subject: [PATCH 3/6] feat(udf): route bridge allocations through per-task allocator Add BufferAllocator as first parameter to CometUDF.evaluate so implementations receive the per-task allocator (CometUdfAllocator) and can charge off-heap usage to Spark's task memory manager. CometUdfBridge.evaluateInternal now calls resolveAllocator() which uses CometUdfAllocator.acquire(ctx) when a TaskContext is available, falling back to the root allocator with a one-time warning otherwise. --- .../org/apache/comet/udf/CometUdfBridge.java | 26 +++++++++++++++++-- .../scala/org/apache/comet/udf/CometUDF.scala | 6 ++++- 2 files changed, 29 insertions(+), 3 deletions(-) diff --git a/common/src/main/java/org/apache/comet/udf/CometUdfBridge.java b/common/src/main/java/org/apache/comet/udf/CometUdfBridge.java index 5e76819810..6443d5be50 100644 --- a/common/src/main/java/org/apache/comet/udf/CometUdfBridge.java +++ b/common/src/main/java/org/apache/comet/udf/CometUdfBridge.java @@ -20,6 +20,10 @@ package org.apache.comet.udf; import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicBoolean; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import org.apache.arrow.c.ArrowArray; import org.apache.arrow.c.ArrowSchema; @@ -42,6 +46,10 @@ public class CometUdfBridge { // single shared instance per class is safe across native worker threads. private static final ConcurrentHashMap INSTANCES = new ConcurrentHashMap<>(); + private static final Logger LOG = LoggerFactory.getLogger(CometUdfBridge.class); + + private static final AtomicBoolean WARNED_NO_TASK_CONTEXT = new AtomicBoolean(false); + /** * Called from native via JNI. * @@ -114,7 +122,7 @@ private static void evaluateInternal( } }); - BufferAllocator allocator = org.apache.comet.package$.MODULE$.CometArrowAllocator(); + BufferAllocator allocator = resolveAllocator(); ValueVector[] inputs = new ValueVector[inputArrayPtrs.length]; ValueVector result = null; @@ -125,7 +133,7 @@ private static void evaluateInternal( inputs[i] = Data.importVector(allocator, inArr, inSch, null); } - result = udf.evaluate(inputs, numRows); + result = udf.evaluate(allocator, inputs, numRows); if (!(result instanceof FieldVector)) { throw new RuntimeException( "CometUDF.evaluate() must return a FieldVector, got: " + result.getClass().getName()); @@ -159,4 +167,18 @@ private static void evaluateInternal( } } } + + private static BufferAllocator resolveAllocator() { + TaskContext ctx = TaskContext.get(); + if (ctx != null) { + return CometUdfAllocator.acquire(ctx); + } + if (WARNED_NO_TASK_CONTEXT.compareAndSet(false, true)) { + LOG.warn( + "CometUdfBridge invoked with no TaskContext on the calling thread; falling back to the " + + "unaccounted root allocator. UDF off-heap memory will not be charged to Spark's " + + "task memory manager."); + } + return org.apache.comet.package$.MODULE$.CometArrowAllocator(); + } } diff --git a/common/src/main/scala/org/apache/comet/udf/CometUDF.scala b/common/src/main/scala/org/apache/comet/udf/CometUDF.scala index 5b6652d90a..4408b0e2b8 100644 --- a/common/src/main/scala/org/apache/comet/udf/CometUDF.scala +++ b/common/src/main/scala/org/apache/comet/udf/CometUDF.scala @@ -19,6 +19,7 @@ package org.apache.comet.udf +import org.apache.arrow.memory.BufferAllocator import org.apache.arrow.vector.ValueVector /** @@ -28,6 +29,9 @@ import org.apache.arrow.vector.ValueVector * - Vector arguments arrive at the row count of the current batch. * - Scalar (literal-folded) arguments arrive as length-1 vectors and must be read at index 0. * - The returned vector's length must match `numRows`. + * - `allocator` is the per-task Arrow allocator backed by Spark's task memory accounting (see + * `CometUdfAllocator`). Implementations must allocate any returned vector or temporary + * buffers from this allocator so off-heap usage is charged to the executing Spark task. * * `numRows` mirrors DataFusion's `ScalarFunctionArgs.number_rows` and is the batch row count. * UDFs that always have at least one batch-length input can derive length from the inputs and @@ -38,5 +42,5 @@ import org.apache.arrow.vector.ValueVector * per class is cached and shared across native worker threads for the lifetime of the JVM. */ trait CometUDF { - def evaluate(inputs: Array[ValueVector], numRows: Int): ValueVector + def evaluate(allocator: BufferAllocator, inputs: Array[ValueVector], numRows: Int): ValueVector } From 432566741e31c85fe1726c2f1814acc58c46522f Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Thu, 14 May 2026 08:28:40 -0600 Subject: [PATCH 4/6] test(udf): cover per-task allocator accounting and cleanup --- .../comet/udf/CometUdfAllocatorSuite.scala | 104 ++++++++++++++++++ 1 file changed, 104 insertions(+) create mode 100644 spark/src/test/scala/org/apache/comet/udf/CometUdfAllocatorSuite.scala diff --git a/spark/src/test/scala/org/apache/comet/udf/CometUdfAllocatorSuite.scala b/spark/src/test/scala/org/apache/comet/udf/CometUdfAllocatorSuite.scala new file mode 100644 index 0000000000..cc91a977ce --- /dev/null +++ b/spark/src/test/scala/org/apache/comet/udf/CometUdfAllocatorSuite.scala @@ -0,0 +1,104 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.udf + +import org.scalatest.BeforeAndAfterAll +import org.scalatest.funsuite.AnyFunSuite + +import org.apache.arrow.memory.BufferAllocator +import org.apache.arrow.vector.IntVector +import org.apache.spark.TaskContext +import org.apache.spark.comet.CometTaskContextShim +import org.apache.spark.sql.SparkSession +import org.apache.spark.util.LongAccumulator + +class CometUdfAllocatorSuite extends AnyFunSuite with BeforeAndAfterAll { + + private var spark: SparkSession = _ + + override def beforeAll(): Unit = { + super.beforeAll() + spark = SparkSession + .builder() + .master("local[2]") + .appName("CometUdfAllocatorSuite") + .config("spark.memory.offHeap.enabled", "true") + .config("spark.memory.offHeap.size", "64m") + .getOrCreate() + } + + override def afterAll(): Unit = { + if (spark != null) { + spark.stop() + spark = null + } + super.afterAll() + } + + test("acquire registers a MemoryConsumer that charges the task") { + val sc = spark.sparkContext + val taskMemBefore: LongAccumulator = sc.longAccumulator("taskMemBefore") + val taskMemDuring: LongAccumulator = sc.longAccumulator("taskMemDuring") + val taskMemAfter: LongAccumulator = sc.longAccumulator("taskMemAfter") + + sc.parallelize(Seq(0), 1).foreachPartition { _: Iterator[Int] => + val ctx = TaskContext.get() + // TaskContext.taskMemoryManager() is private[spark]; use the shim from a non-spark package. + val tmm = CometTaskContextShim.taskMemoryManager(ctx) + taskMemBefore.add(tmm.getMemoryConsumptionForThisTask) + + val allocator: BufferAllocator = CometUdfAllocator.acquire(ctx) + val vec = new IntVector("test", allocator) + try { + vec.allocateNew(1024) + vec.setValueCount(1024) + taskMemDuring.add(tmm.getMemoryConsumptionForThisTask) + } finally { + vec.close() + } + taskMemAfter.add(tmm.getMemoryConsumptionForThisTask) + } + + assert( + taskMemDuring.value > taskMemBefore.value, + s"task memory should grow while UDF allocator holds buffers; " + + s"before=${taskMemBefore.value} during=${taskMemDuring.value}") + assert( + taskMemAfter.value == taskMemBefore.value, + s"task memory should be released after the allocation closes; " + + s"before=${taskMemBefore.value} after=${taskMemAfter.value}") + } + + test("child allocator is closed and uncached on task completion") { + val sc = spark.sparkContext + val cacheSizeDuring: LongAccumulator = sc.longAccumulator("cacheSizeDuring") + + sc.parallelize(Seq(0), 1).foreachPartition { _: Iterator[Int] => + val ctx = TaskContext.get() + val _ = CometUdfAllocator.acquire(ctx) + cacheSizeDuring.add(CometUdfAllocator.cacheSize().toLong) + } + + assert(cacheSizeDuring.value >= 1L, "allocator should be cached during the task") + assert( + CometUdfAllocator.cacheSize() == 0, + "TaskCompletionListener should remove the entry once the task ends") + } +} From 6d14adb4ece31202844be5c0f9c96ed01b33933f Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Thu, 14 May 2026 08:36:03 -0600 Subject: [PATCH 5/6] style(udf): drop redundant s interpolators in allocator test assertions --- .../scala/org/apache/comet/udf/CometUdfAllocatorSuite.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/spark/src/test/scala/org/apache/comet/udf/CometUdfAllocatorSuite.scala b/spark/src/test/scala/org/apache/comet/udf/CometUdfAllocatorSuite.scala index cc91a977ce..508b6985c4 100644 --- a/spark/src/test/scala/org/apache/comet/udf/CometUdfAllocatorSuite.scala +++ b/spark/src/test/scala/org/apache/comet/udf/CometUdfAllocatorSuite.scala @@ -78,11 +78,11 @@ class CometUdfAllocatorSuite extends AnyFunSuite with BeforeAndAfterAll { assert( taskMemDuring.value > taskMemBefore.value, - s"task memory should grow while UDF allocator holds buffers; " + + "task memory should grow while UDF allocator holds buffers; " + s"before=${taskMemBefore.value} during=${taskMemDuring.value}") assert( taskMemAfter.value == taskMemBefore.value, - s"task memory should be released after the allocation closes; " + + "task memory should be released after the allocation closes; " + s"before=${taskMemBefore.value} after=${taskMemAfter.value}") } From 80476e3e5104a73fee225be6308e5229b472f314 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Thu, 14 May 2026 09:58:58 -0600 Subject: [PATCH 6/6] test(udf): move per-task allocator suite into common module to avoid shading boundary --- .../comet/udf/CometUdfAllocatorTest.java | 128 ++++++++++++++++++ .../comet/udf/CometUdfAllocatorSuite.scala | 104 -------------- 2 files changed, 128 insertions(+), 104 deletions(-) create mode 100644 common/src/test/java/org/apache/comet/udf/CometUdfAllocatorTest.java delete mode 100644 spark/src/test/scala/org/apache/comet/udf/CometUdfAllocatorSuite.scala diff --git a/common/src/test/java/org/apache/comet/udf/CometUdfAllocatorTest.java b/common/src/test/java/org/apache/comet/udf/CometUdfAllocatorTest.java new file mode 100644 index 0000000000..2938e7d18e --- /dev/null +++ b/common/src/test/java/org/apache/comet/udf/CometUdfAllocatorTest.java @@ -0,0 +1,128 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.udf; + +import java.util.Collections; +import java.util.Iterator; + +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Test; + +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.vector.IntVector; +import org.apache.spark.TaskContext; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.VoidFunction; +import org.apache.spark.comet.CometTaskContextShim; +import org.apache.spark.memory.TaskMemoryManager; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.util.LongAccumulator; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +/** + * Verifies that {@link CometUdfAllocator} both charges the executing Spark task's {@link + * TaskMemoryManager} for allocations and tears down the per-task allocator on task completion. + * Lives in the {@code common} module to stay on the unshaded side of the Arrow relocation boundary; + * the spark module sees {@code BufferAllocator} as the shaded type. + */ +public class CometUdfAllocatorTest { + + private static SparkSession spark; + private static JavaSparkContext jsc; + + @BeforeClass + public static void setUp() { + spark = + SparkSession.builder() + .master("local[2]") + .appName("CometUdfAllocatorTest") + .config("spark.memory.offHeap.enabled", "true") + .config("spark.memory.offHeap.size", "67108864") + .getOrCreate(); + jsc = new JavaSparkContext(spark.sparkContext()); + } + + @AfterClass + public static void tearDown() { + if (spark != null) { + spark.stop(); + spark = null; + } + } + + @Test + public void acquireRegistersMemoryConsumerThatChargesTask() { + LongAccumulator before = jsc.sc().longAccumulator("before"); + LongAccumulator during = jsc.sc().longAccumulator("during"); + LongAccumulator after = jsc.sc().longAccumulator("after"); + + jsc.parallelize(Collections.singletonList(0), 1) + .foreachPartition( + (VoidFunction>) + it -> { + TaskContext ctx = TaskContext.get(); + TaskMemoryManager tmm = CometTaskContextShim.taskMemoryManager(ctx); + before.add(tmm.getMemoryConsumptionForThisTask()); + + BufferAllocator allocator = CometUdfAllocator.acquire(ctx); + IntVector vec = new IntVector("test", allocator); + try { + vec.allocateNew(1024); + vec.setValueCount(1024); + during.add(tmm.getMemoryConsumptionForThisTask()); + } finally { + vec.close(); + } + after.add(tmm.getMemoryConsumptionForThisTask()); + }); + + assertTrue( + "task memory should grow while allocator holds buffers; before=" + + before.value() + + " during=" + + during.value(), + during.value() > before.value()); + assertEquals( + "task memory should be released after vector closes", before.value(), after.value()); + } + + @Test + public void childAllocatorClosedAndUncachedOnTaskCompletion() { + LongAccumulator cacheSizeDuring = jsc.sc().longAccumulator("cacheSizeDuring"); + + jsc.parallelize(Collections.singletonList(0), 1) + .foreachPartition( + (VoidFunction>) + it -> { + TaskContext ctx = TaskContext.get(); + CometUdfAllocator.acquire(ctx); + cacheSizeDuring.add((long) CometUdfAllocator.cacheSize()); + }); + + assertTrue("allocator should be cached during the task", cacheSizeDuring.value() >= 1L); + assertEquals( + "TaskCompletionListener should remove the entry once the task ends", + 0, + CometUdfAllocator.cacheSize()); + } +} diff --git a/spark/src/test/scala/org/apache/comet/udf/CometUdfAllocatorSuite.scala b/spark/src/test/scala/org/apache/comet/udf/CometUdfAllocatorSuite.scala deleted file mode 100644 index 508b6985c4..0000000000 --- a/spark/src/test/scala/org/apache/comet/udf/CometUdfAllocatorSuite.scala +++ /dev/null @@ -1,104 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.apache.comet.udf - -import org.scalatest.BeforeAndAfterAll -import org.scalatest.funsuite.AnyFunSuite - -import org.apache.arrow.memory.BufferAllocator -import org.apache.arrow.vector.IntVector -import org.apache.spark.TaskContext -import org.apache.spark.comet.CometTaskContextShim -import org.apache.spark.sql.SparkSession -import org.apache.spark.util.LongAccumulator - -class CometUdfAllocatorSuite extends AnyFunSuite with BeforeAndAfterAll { - - private var spark: SparkSession = _ - - override def beforeAll(): Unit = { - super.beforeAll() - spark = SparkSession - .builder() - .master("local[2]") - .appName("CometUdfAllocatorSuite") - .config("spark.memory.offHeap.enabled", "true") - .config("spark.memory.offHeap.size", "64m") - .getOrCreate() - } - - override def afterAll(): Unit = { - if (spark != null) { - spark.stop() - spark = null - } - super.afterAll() - } - - test("acquire registers a MemoryConsumer that charges the task") { - val sc = spark.sparkContext - val taskMemBefore: LongAccumulator = sc.longAccumulator("taskMemBefore") - val taskMemDuring: LongAccumulator = sc.longAccumulator("taskMemDuring") - val taskMemAfter: LongAccumulator = sc.longAccumulator("taskMemAfter") - - sc.parallelize(Seq(0), 1).foreachPartition { _: Iterator[Int] => - val ctx = TaskContext.get() - // TaskContext.taskMemoryManager() is private[spark]; use the shim from a non-spark package. - val tmm = CometTaskContextShim.taskMemoryManager(ctx) - taskMemBefore.add(tmm.getMemoryConsumptionForThisTask) - - val allocator: BufferAllocator = CometUdfAllocator.acquire(ctx) - val vec = new IntVector("test", allocator) - try { - vec.allocateNew(1024) - vec.setValueCount(1024) - taskMemDuring.add(tmm.getMemoryConsumptionForThisTask) - } finally { - vec.close() - } - taskMemAfter.add(tmm.getMemoryConsumptionForThisTask) - } - - assert( - taskMemDuring.value > taskMemBefore.value, - "task memory should grow while UDF allocator holds buffers; " + - s"before=${taskMemBefore.value} during=${taskMemDuring.value}") - assert( - taskMemAfter.value == taskMemBefore.value, - "task memory should be released after the allocation closes; " + - s"before=${taskMemBefore.value} after=${taskMemAfter.value}") - } - - test("child allocator is closed and uncached on task completion") { - val sc = spark.sparkContext - val cacheSizeDuring: LongAccumulator = sc.longAccumulator("cacheSizeDuring") - - sc.parallelize(Seq(0), 1).foreachPartition { _: Iterator[Int] => - val ctx = TaskContext.get() - val _ = CometUdfAllocator.acquire(ctx) - cacheSizeDuring.add(CometUdfAllocator.cacheSize().toLong) - } - - assert(cacheSizeDuring.value >= 1L, "allocator should be cached during the task") - assert( - CometUdfAllocator.cacheSize() == 0, - "TaskCompletionListener should remove the entry once the task ends") - } -}