diff --git a/driver-core/src/main/com/mongodb/internal/async/AsyncFunction.java b/driver-core/src/main/com/mongodb/internal/async/AsyncFunction.java index 7203d3a4945..19f37d7f7ec 100644 --- a/driver-core/src/main/com/mongodb/internal/async/AsyncFunction.java +++ b/driver-core/src/main/com/mongodb/internal/async/AsyncFunction.java @@ -43,20 +43,33 @@ public interface AsyncFunction { */ default void finish(final T value, final SingleResultCallback callback) { final AtomicBoolean callbackInvoked = new AtomicBoolean(false); - try { - this.unsafeFinish(value, (v, e) -> { + // The trampoline bounds two sources of stack growth that occur when + // callbacks complete synchronously on the same thread: + // + // Chain unwinding (unsafeFinish nesting): + // finish -> unsafeFinish[C] -> unsafeFinish[B] -> unsafeFinish[A] -> ... + // + // Callback completion (onResult nesting): + // onResult[A] -> onResult[B] -> onResult[C] -> ... + // + // Without the trampoline, a 1000-step chain would produce ~2000 frames. + // With it, re-entrant calls are deferred to the drain loop, keeping depth constant. + AsyncTrampoline.execute(() -> { + try { + this.unsafeFinish(value, (v, e) -> { + if (!callbackInvoked.compareAndSet(false, true)) { + throw new AssertionError(String.format("Callback has been already completed. It could happen " + + "if code throws an exception after invoking an async method. Value: %s", v), e); + } + AsyncTrampoline.complete(callback, v, e); + }); + } catch (Throwable t) { if (!callbackInvoked.compareAndSet(false, true)) { - throw new AssertionError(String.format("Callback has been already completed. It could happen " - + "if code throws an exception after invoking an async method. Value: %s", v), e); + throw t; + } else { + AsyncTrampoline.complete(callback, null, t); } - callback.onResult(v, e); - }); - } catch (Throwable t) { - if (!callbackInvoked.compareAndSet(false, true)) { - throw t; - } else { - callback.completeExceptionally(t); } - } + }); } } diff --git a/driver-core/src/main/com/mongodb/internal/async/AsyncSupplier.java b/driver-core/src/main/com/mongodb/internal/async/AsyncSupplier.java index 6dd89e4d9b0..b00b9c19ae7 100644 --- a/driver-core/src/main/com/mongodb/internal/async/AsyncSupplier.java +++ b/driver-core/src/main/com/mongodb/internal/async/AsyncSupplier.java @@ -64,21 +64,34 @@ default void unsafeFinish(@Nullable final Void value, final SingleResultCallback */ default void finish(final SingleResultCallback callback) { final AtomicBoolean callbackInvoked = new AtomicBoolean(false); - try { - this.unsafeFinish((v, e) -> { + // The trampoline bounds two sources of stack growth that occur when + // callbacks complete synchronously on the same thread: + // + // Chain unwinding (unsafeFinish nesting): + // finish -> unsafeFinish[C] -> unsafeFinish[B] -> unsafeFinish[A] -> ... + // + // Callback completion (onResult nesting): + // onResult[A] -> onResult[B] -> onResult[C] -> ... + // + // Without the trampoline, a 1000-step chain would produce ~2000 frames. + // With it, re-entrant calls are deferred to the drain loop, keeping depth constant. + AsyncTrampoline.execute(() -> { + try { + this.unsafeFinish((v, e) -> { + if (!callbackInvoked.compareAndSet(false, true)) { + throw new AssertionError(String.format("Callback has been already completed. It could happen " + + "if code throws an exception after invoking an async method. Value: %s", v), e); + } + AsyncTrampoline.complete(callback, v, e); + }); + } catch (Throwable t) { if (!callbackInvoked.compareAndSet(false, true)) { - throw new AssertionError(String.format("Callback has been already completed. It could happen " - + "if code throws an exception after invoking an async method. Value: %s", v), e); + throw t; + } else { + AsyncTrampoline.complete(callback, null, t); } - callback.onResult(v, e); - }); - } catch (Throwable t) { - if (!callbackInvoked.compareAndSet(false, true)) { - throw t; - } else { - callback.completeExceptionally(t); } - } + }); } /** diff --git a/driver-core/src/main/com/mongodb/internal/async/AsyncTrampoline.java b/driver-core/src/main/com/mongodb/internal/async/AsyncTrampoline.java new file mode 100644 index 00000000000..6806f04001b --- /dev/null +++ b/driver-core/src/main/com/mongodb/internal/async/AsyncTrampoline.java @@ -0,0 +1,114 @@ +/* + * Copyright 2008-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.mongodb.internal.async; + +import com.mongodb.annotations.NotThreadSafe; +import com.mongodb.assertions.Assertions; +import com.mongodb.lang.Nullable; + +/** + * A trampoline that converts recursive invocations into an iterative loop, + * preventing stack overflow from deep async chains. + * + *

When async operations complete synchronously on the same thread, two types of + * recursion can occur:

+ *
    + *
  1. Chain unwinding: Nested {@code unsafeFinish()} calls when executing + * a long chain (e.g., 1000 {@code thenRun()} steps)
  2. + *
  3. Callback completion: Nested {@code callback.onResult()} calls when + * each step immediately triggers the next
  4. + *
+ * + *

The trampoline intercepts both: instead of executing work immediately (which + * would deepen the stack), it enqueues the work and returns, allowing the stack to + * unwind. A flat loop at the top then processes enqueued work iteratively.

+ * + *

Since async chains are sequential, at most one task is pending at any time. + * The trampoline uses a single slot rather than a queue.

+ * + *

Usage: wrap work with {@link #execute(Runnable)} or {@link #complete(SingleResultCallback, Object, Throwable)}. + * The first call on a thread becomes the "trampoline owner" and runs the drain loop. + * Subsequent (re-entrant) calls on the same thread enqueue their work and return immediately.

+ * + *

This class is not part of the public API and may be removed or changed at any time

+ */ +@NotThreadSafe +public final class AsyncTrampoline { + + private static final ThreadLocal TRAMPOLINE = new ThreadLocal<>(); + + private AsyncTrampoline() { + } + + /** + * Execute work through the trampoline. If no trampoline is active, become the owner + * and drain all enqueued work. If a trampoline is already active, enqueue and return. + */ + public static void execute(final Runnable work) { + Bounce bounce = TRAMPOLINE.get(); + if (bounce != null) { + // Re-entrant, enqueue and return + bounce.enqueue(work); + } else { + // Become the trampoline owner. + bounce = new Bounce(); + TRAMPOLINE.set(bounce); + try { + bounce.enqueue(work); + // drain all work iteratively + while (bounce.hasWork()) { + bounce.runNext(); + } + } finally { + TRAMPOLINE.remove(); + } + } + } + + public static void complete(final SingleResultCallback callback, @Nullable final T result, @Nullable final Throwable t) { + execute(() -> callback.onResult(result, t)); + } + + /** + * A single-slot container for deferred work. + * At most one task is pending at any time in a sequential async chain. + */ + @NotThreadSafe + private static final class Bounce { + @Nullable + private Runnable work; + + void enqueue(final Runnable task) { + if (this.work != null) { + throw new AssertionError("Trampoline slot already occupied. " + + "This indicates a bug: multiple concurrent operations in a sequential async chain."); + } + this.work = task; + } + + boolean hasWork() { + return work != null; + } + + void runNext() { + Runnable task = this.work; + this.work = null; + Assertions.assertNotNull(task); + task.run(); + } + } +} \ No newline at end of file diff --git a/driver-core/src/main/com/mongodb/internal/async/SingleResultCallback.java b/driver-core/src/main/com/mongodb/internal/async/SingleResultCallback.java index 11da1c97f75..a934615cec2 100644 --- a/driver-core/src/main/com/mongodb/internal/async/SingleResultCallback.java +++ b/driver-core/src/main/com/mongodb/internal/async/SingleResultCallback.java @@ -48,7 +48,7 @@ default AsyncCompletionHandler asHandler() { return new AsyncCompletionHandler() { @Override public void completed(@Nullable final T result) { - onResult(result, null); + complete(result); } @Override public void failed(final Throwable t) { @@ -62,14 +62,14 @@ default void complete(final SingleResultCallback callback) { // is not accidentally used when "complete(T)" should have been used // instead, since results are not marked nullable. Assertions.assertTrue(callback == this); - this.onResult(null, null); + AsyncTrampoline.complete(this, null, null); } default void complete(@Nullable final T result) { - this.onResult(result, null); + AsyncTrampoline.complete(this, result, null); } default void completeExceptionally(final Throwable t) { - this.onResult(null, assertNotNull(t)); + AsyncTrampoline.complete(this, null, assertNotNull(t)); } } diff --git a/driver-core/src/main/com/mongodb/internal/async/function/AsyncCallbackLoop.java b/driver-core/src/main/com/mongodb/internal/async/function/AsyncCallbackLoop.java index a347a2a7e47..164749e8a2a 100644 --- a/driver-core/src/main/com/mongodb/internal/async/function/AsyncCallbackLoop.java +++ b/driver-core/src/main/com/mongodb/internal/async/function/AsyncCallbackLoop.java @@ -16,6 +16,7 @@ package com.mongodb.internal.async.function; import com.mongodb.annotations.NotThreadSafe; +import com.mongodb.internal.async.AsyncTrampoline; import com.mongodb.internal.async.SingleResultCallback; import com.mongodb.lang.Nullable; @@ -70,19 +71,19 @@ private class LoopingCallback implements SingleResultCallback { @Override public void onResult(@Nullable final Void result, @Nullable final Throwable t) { if (t != null) { - wrapped.onResult(null, t); + AsyncTrampoline.complete(wrapped, null, t); } else { boolean continueLooping; try { continueLooping = state.advance(); } catch (Throwable e) { - wrapped.onResult(null, e); + AsyncTrampoline.complete(wrapped, null, e); return; } if (continueLooping) { body.run(this); } else { - wrapped.onResult(result, null); + AsyncTrampoline.complete(wrapped, result, null); } } } diff --git a/driver-core/src/main/com/mongodb/internal/async/function/RetryingAsyncCallbackSupplier.java b/driver-core/src/main/com/mongodb/internal/async/function/RetryingAsyncCallbackSupplier.java index 16f6f2e7086..0834e7b2b02 100644 --- a/driver-core/src/main/com/mongodb/internal/async/function/RetryingAsyncCallbackSupplier.java +++ b/driver-core/src/main/com/mongodb/internal/async/function/RetryingAsyncCallbackSupplier.java @@ -16,6 +16,7 @@ package com.mongodb.internal.async.function; import com.mongodb.annotations.NotThreadSafe; +import com.mongodb.internal.async.AsyncTrampoline; import com.mongodb.internal.async.SingleResultCallback; import com.mongodb.lang.NonNull; import com.mongodb.lang.Nullable; @@ -118,12 +119,12 @@ public void onResult(@Nullable final R result, @Nullable final Throwable t) { try { state.advanceOrThrow(t, onAttemptFailureOperator, retryPredicate); } catch (Throwable failedResult) { - wrapped.onResult(null, failedResult); + AsyncTrampoline.complete(wrapped, null, failedResult); return; } asyncFunction.get(this); } else { - wrapped.onResult(result, null); + AsyncTrampoline.complete(wrapped, result, null); } } } diff --git a/driver-core/src/test/unit/com/mongodb/internal/async/AsyncFunctionsAbstractTest.java b/driver-core/src/test/unit/com/mongodb/internal/async/AsyncFunctionsAbstractTest.java index 9a9b7552d3e..818923a197f 100644 --- a/driver-core/src/test/unit/com/mongodb/internal/async/AsyncFunctionsAbstractTest.java +++ b/driver-core/src/test/unit/com/mongodb/internal/async/AsyncFunctionsAbstractTest.java @@ -20,12 +20,14 @@ import com.mongodb.internal.TimeoutSettings; import org.junit.jupiter.api.Test; +import java.util.concurrent.atomic.AtomicInteger; import java.util.function.BiConsumer; import java.util.function.Consumer; import java.util.function.Supplier; import static com.mongodb.assertions.Assertions.assertNotNull; import static com.mongodb.internal.async.AsyncRunnable.beginAsync; +import static org.junit.jupiter.api.Assertions.assertTrue; abstract class AsyncFunctionsAbstractTest extends AsyncFunctionsTestBase { private static final TimeoutContext TIMEOUT_CONTEXT = new TimeoutContext(new TimeoutSettings(0, 0, 0, 0L, 0)); @@ -990,4 +992,66 @@ void testDerivation() { }).finish(callback); }); } + + @Test + void testStackDepthBounded() { + AtomicInteger maxDepth = new AtomicInteger(0); + AtomicInteger minDepth = new AtomicInteger(Integer.MAX_VALUE); + AtomicInteger maxMongoDepth = new AtomicInteger(0); + AtomicInteger minMongoDepth = new AtomicInteger(Integer.MAX_VALUE); + AtomicInteger stepCount = new AtomicInteger(0); + // Capture one sample of mongodb package frames for printing + String[][] sampleMongoFrames = {null}; + + AsyncRunnable chain = beginAsync(); + for (int i = 0; i < 1000; i++) { + chain = chain.thenRun(c -> { + stepCount.incrementAndGet(); + StackTraceElement[] stack = Thread.currentThread().getStackTrace(); + int depth = stack.length; + maxDepth.updateAndGet(current -> Math.max(current, depth)); + minDepth.updateAndGet(current -> Math.min(current, depth)); + int mongoFrames = 0; + for (StackTraceElement frame : stack) { + if (frame.getClassName().startsWith("com.mongodb")) { + mongoFrames++; + } + } + int mf = mongoFrames; + maxMongoDepth.updateAndGet(current -> Math.max(current, mf)); + minMongoDepth.updateAndGet(current -> Math.min(current, mf)); + // Capture first sample + if (sampleMongoFrames[0] == null) { + String[] frames = new String[mf]; + int idx = 0; + for (StackTraceElement frame : stack) { + if (frame.getClassName().startsWith("com.mongodb")) { + frames[idx++] = frame.getClassName() + "." + frame.getMethodName() + + "(" + frame.getFileName() + ":" + frame.getLineNumber() + ")"; + } + } + sampleMongoFrames[0] = frames; + } + c.complete(c); + }); + } + + chain.finish((v, e) -> { + assertTrue(stepCount.get() == 1000, "Expected 1000 steps, got " + stepCount.get()); + int depth = maxDepth.get(); + int mongoDepth = maxMongoDepth.get(); + String summary = "Stack depth: min=" + minDepth.get() + ", max=" + depth + + " | MongoDB frames: min=" + minMongoDepth.get() + ", max=" + mongoDepth; + System.out.println(summary); + if (sampleMongoFrames[0] != null) { + System.out.println("MongoDB stack frames (sample):"); + for (int i = 0; i < sampleMongoFrames[0].length; i++) { + System.out.println(" " + (i + 1) + ". " + sampleMongoFrames[0][i]); + } + } + assertTrue(depth < 200, + "Stack depth too deep (min=" + minDepth.get() + ", max=" + depth + + "). Trampoline may not be working correctly."); + }); + } }