Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -43,20 +43,33 @@ public interface AsyncFunction<T, R> {
*/
default void finish(final T value, final SingleResultCallback<R> callback) {
final AtomicBoolean callbackInvoked = new AtomicBoolean(false);
try {
this.unsafeFinish(value, (v, e) -> {
// The trampoline bounds two sources of stack growth that occur when
// callbacks complete synchronously on the same thread:
//
// Chain unwinding (unsafeFinish nesting):
// finish -> unsafeFinish[C] -> unsafeFinish[B] -> unsafeFinish[A] -> ...
//
// Callback completion (onResult nesting):
// onResult[A] -> onResult[B] -> onResult[C] -> ...
//
// Without the trampoline, a 1000-step chain would produce ~2000 frames.
// With it, re-entrant calls are deferred to the drain loop, keeping depth constant.
AsyncTrampoline.execute(() -> {
try {
this.unsafeFinish(value, (v, e) -> {
if (!callbackInvoked.compareAndSet(false, true)) {
throw new AssertionError(String.format("Callback has been already completed. It could happen "
+ "if code throws an exception after invoking an async method. Value: %s", v), e);
}
AsyncTrampoline.complete(callback, v, e);
});
} catch (Throwable t) {
if (!callbackInvoked.compareAndSet(false, true)) {
throw new AssertionError(String.format("Callback has been already completed. It could happen "
+ "if code throws an exception after invoking an async method. Value: %s", v), e);
throw t;
} else {
AsyncTrampoline.complete(callback, null, t);
}
callback.onResult(v, e);
});
} catch (Throwable t) {
if (!callbackInvoked.compareAndSet(false, true)) {
throw t;
} else {
callback.completeExceptionally(t);
}
}
});
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -64,21 +64,34 @@ default void unsafeFinish(@Nullable final Void value, final SingleResultCallback
*/
default void finish(final SingleResultCallback<T> callback) {
final AtomicBoolean callbackInvoked = new AtomicBoolean(false);
try {
this.unsafeFinish((v, e) -> {
// The trampoline bounds two sources of stack growth that occur when
// callbacks complete synchronously on the same thread:
//
// Chain unwinding (unsafeFinish nesting):
// finish -> unsafeFinish[C] -> unsafeFinish[B] -> unsafeFinish[A] -> ...
//
// Callback completion (onResult nesting):
// onResult[A] -> onResult[B] -> onResult[C] -> ...
//
// Without the trampoline, a 1000-step chain would produce ~2000 frames.
// With it, re-entrant calls are deferred to the drain loop, keeping depth constant.
AsyncTrampoline.execute(() -> {
try {
this.unsafeFinish((v, e) -> {
if (!callbackInvoked.compareAndSet(false, true)) {
throw new AssertionError(String.format("Callback has been already completed. It could happen "
+ "if code throws an exception after invoking an async method. Value: %s", v), e);
}
AsyncTrampoline.complete(callback, v, e);
});
} catch (Throwable t) {
if (!callbackInvoked.compareAndSet(false, true)) {
throw new AssertionError(String.format("Callback has been already completed. It could happen "
+ "if code throws an exception after invoking an async method. Value: %s", v), e);
throw t;
} else {
AsyncTrampoline.complete(callback, null, t);
}
callback.onResult(v, e);
});
} catch (Throwable t) {
if (!callbackInvoked.compareAndSet(false, true)) {
throw t;
} else {
callback.completeExceptionally(t);
}
}
});
}

/**
Expand Down
114 changes: 114 additions & 0 deletions driver-core/src/main/com/mongodb/internal/async/AsyncTrampoline.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
/*
* Copyright 2008-present MongoDB, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.mongodb.internal.async;

import com.mongodb.annotations.NotThreadSafe;
import com.mongodb.assertions.Assertions;
import com.mongodb.lang.Nullable;

/**
* A trampoline that converts recursive invocations into an iterative loop,
* preventing stack overflow from deep async chains.
*
* <p>When async operations complete synchronously on the same thread, two types of
* recursion can occur:</p>
* <ol>
* <li><b>Chain unwinding</b>: Nested {@code unsafeFinish()} calls when executing
* a long chain (e.g., 1000 {@code thenRun()} steps)</li>
* <li><b>Callback completion</b>: Nested {@code callback.onResult()} calls when
* each step immediately triggers the next</li>
* </ol>
*
* <p>The trampoline intercepts both: instead of executing work immediately (which
* would deepen the stack), it enqueues the work and returns, allowing the stack to
* unwind. A flat loop at the top then processes enqueued work iteratively.</p>
*
* <p>Since async chains are sequential, at most one task is pending at any time.
* The trampoline uses a single slot rather than a queue.</p>
*
* <p>Usage: wrap work with {@link #execute(Runnable)} or {@link #complete(SingleResultCallback, Object, Throwable)}.
* The first call on a thread becomes the "trampoline owner" and runs the drain loop.
* Subsequent (re-entrant) calls on the same thread enqueue their work and return immediately.</p>
*
* <p>This class is not part of the public API and may be removed or changed at any time</p>
*/
@NotThreadSafe
public final class AsyncTrampoline {

private static final ThreadLocal<Bounce> TRAMPOLINE = new ThreadLocal<>();

private AsyncTrampoline() {
}

/**
* Execute work through the trampoline. If no trampoline is active, become the owner
* and drain all enqueued work. If a trampoline is already active, enqueue and return.
*/
public static void execute(final Runnable work) {
Bounce bounce = TRAMPOLINE.get();
if (bounce != null) {
// Re-entrant, enqueue and return
bounce.enqueue(work);
} else {
// Become the trampoline owner.
bounce = new Bounce();
TRAMPOLINE.set(bounce);
try {
bounce.enqueue(work);
// drain all work iteratively
while (bounce.hasWork()) {
bounce.runNext();
}
} finally {
TRAMPOLINE.remove();
}
}
}

public static <T> void complete(final SingleResultCallback<T> callback, @Nullable final T result, @Nullable final Throwable t) {
execute(() -> callback.onResult(result, t));
}

/**
* A single-slot container for deferred work.
* At most one task is pending at any time in a sequential async chain.
*/
@NotThreadSafe
private static final class Bounce {
@Nullable
private Runnable work;

void enqueue(final Runnable task) {
if (this.work != null) {
throw new AssertionError("Trampoline slot already occupied. "
+ "This indicates a bug: multiple concurrent operations in a sequential async chain.");
}
this.work = task;
}

boolean hasWork() {
return work != null;
}

void runNext() {
Runnable task = this.work;
this.work = null;
Assertions.assertNotNull(task);
task.run();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ default AsyncCompletionHandler<T> asHandler() {
return new AsyncCompletionHandler<T>() {
@Override
public void completed(@Nullable final T result) {
onResult(result, null);
complete(result);
}
@Override
public void failed(final Throwable t) {
Expand All @@ -62,14 +62,14 @@ default void complete(final SingleResultCallback<Void> callback) {
// is not accidentally used when "complete(T)" should have been used
// instead, since results are not marked nullable.
Assertions.assertTrue(callback == this);
this.onResult(null, null);
AsyncTrampoline.complete(this, null, null);
}

default void complete(@Nullable final T result) {
this.onResult(result, null);
AsyncTrampoline.complete(this, result, null);
}

default void completeExceptionally(final Throwable t) {
this.onResult(null, assertNotNull(t));
AsyncTrampoline.complete(this, null, assertNotNull(t));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
package com.mongodb.internal.async.function;

import com.mongodb.annotations.NotThreadSafe;
import com.mongodb.internal.async.AsyncTrampoline;
import com.mongodb.internal.async.SingleResultCallback;
import com.mongodb.lang.Nullable;

Expand Down Expand Up @@ -70,19 +71,19 @@ private class LoopingCallback implements SingleResultCallback<Void> {
@Override
public void onResult(@Nullable final Void result, @Nullable final Throwable t) {
if (t != null) {
wrapped.onResult(null, t);
AsyncTrampoline.complete(wrapped, null, t);
} else {
boolean continueLooping;
try {
continueLooping = state.advance();
} catch (Throwable e) {
wrapped.onResult(null, e);
AsyncTrampoline.complete(wrapped, null, e);
return;
}
if (continueLooping) {
body.run(this);
} else {
wrapped.onResult(result, null);
AsyncTrampoline.complete(wrapped, result, null);
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
package com.mongodb.internal.async.function;

import com.mongodb.annotations.NotThreadSafe;
import com.mongodb.internal.async.AsyncTrampoline;
import com.mongodb.internal.async.SingleResultCallback;
import com.mongodb.lang.NonNull;
import com.mongodb.lang.Nullable;
Expand Down Expand Up @@ -118,12 +119,12 @@ public void onResult(@Nullable final R result, @Nullable final Throwable t) {
try {
state.advanceOrThrow(t, onAttemptFailureOperator, retryPredicate);
} catch (Throwable failedResult) {
wrapped.onResult(null, failedResult);
AsyncTrampoline.complete(wrapped, null, failedResult);
return;
}
asyncFunction.get(this);
} else {
wrapped.onResult(result, null);
AsyncTrampoline.complete(wrapped, result, null);
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,14 @@
import com.mongodb.internal.TimeoutSettings;
import org.junit.jupiter.api.Test;

import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.BiConsumer;
import java.util.function.Consumer;
import java.util.function.Supplier;

import static com.mongodb.assertions.Assertions.assertNotNull;
import static com.mongodb.internal.async.AsyncRunnable.beginAsync;
import static org.junit.jupiter.api.Assertions.assertTrue;

abstract class AsyncFunctionsAbstractTest extends AsyncFunctionsTestBase {
private static final TimeoutContext TIMEOUT_CONTEXT = new TimeoutContext(new TimeoutSettings(0, 0, 0, 0L, 0));
Expand Down Expand Up @@ -990,4 +992,66 @@ void testDerivation() {
}).finish(callback);
});
}

@Test
void testStackDepthBounded() {
AtomicInteger maxDepth = new AtomicInteger(0);
AtomicInteger minDepth = new AtomicInteger(Integer.MAX_VALUE);
AtomicInteger maxMongoDepth = new AtomicInteger(0);
AtomicInteger minMongoDepth = new AtomicInteger(Integer.MAX_VALUE);
AtomicInteger stepCount = new AtomicInteger(0);
// Capture one sample of mongodb package frames for printing
String[][] sampleMongoFrames = {null};

AsyncRunnable chain = beginAsync();
for (int i = 0; i < 1000; i++) {
chain = chain.thenRun(c -> {
stepCount.incrementAndGet();
StackTraceElement[] stack = Thread.currentThread().getStackTrace();
int depth = stack.length;
maxDepth.updateAndGet(current -> Math.max(current, depth));
minDepth.updateAndGet(current -> Math.min(current, depth));
int mongoFrames = 0;
for (StackTraceElement frame : stack) {
if (frame.getClassName().startsWith("com.mongodb")) {
mongoFrames++;
}
}
int mf = mongoFrames;
maxMongoDepth.updateAndGet(current -> Math.max(current, mf));
minMongoDepth.updateAndGet(current -> Math.min(current, mf));
// Capture first sample
if (sampleMongoFrames[0] == null) {
String[] frames = new String[mf];
int idx = 0;
for (StackTraceElement frame : stack) {
if (frame.getClassName().startsWith("com.mongodb")) {
frames[idx++] = frame.getClassName() + "." + frame.getMethodName()
+ "(" + frame.getFileName() + ":" + frame.getLineNumber() + ")";
}
}
sampleMongoFrames[0] = frames;
}
c.complete(c);
});
}

chain.finish((v, e) -> {
assertTrue(stepCount.get() == 1000, "Expected 1000 steps, got " + stepCount.get());
int depth = maxDepth.get();
int mongoDepth = maxMongoDepth.get();
String summary = "Stack depth: min=" + minDepth.get() + ", max=" + depth
+ " | MongoDB frames: min=" + minMongoDepth.get() + ", max=" + mongoDepth;
System.out.println(summary);
if (sampleMongoFrames[0] != null) {
System.out.println("MongoDB stack frames (sample):");
for (int i = 0; i < sampleMongoFrames[0].length; i++) {
System.out.println(" " + (i + 1) + ". " + sampleMongoFrames[0][i]);
}
}
assertTrue(depth < 200,
"Stack depth too deep (min=" + minDepth.get() + ", max=" + depth
+ "). Trampoline may not be working correctly.");
});
}
}