Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,35 @@ default AsyncRunnable thenRunRetryingWhile(
});
}

/**
* This method is equivalent to a while loop, where the condition is checked before each iteration.
* If the condition returns {@code false} on the first check, the body is never executed.
*
* @param whileCheck a condition to check before each iteration; the loop continues as long as this condition returns true
* @param loopBodyRunnable the asynchronous task to be executed in each iteration of the loop
* @return the composition of this and the looping branch
* @see AsyncCallbackLoop
*/
default AsyncRunnable thenRunWhileLoop(final BooleanSupplier whileCheck, final AsyncRunnable loopBodyRunnable) {
return thenRun(finalCallback -> {
LoopState loopState = new LoopState();
new AsyncCallbackLoop(loopState, iterationCallback -> {

if (loopState.breakAndCompleteIf(() -> !whileCheck.getAsBoolean(), iterationCallback)) {
return;
}
loopBodyRunnable.finish((result, t) -> {
if (t != null) {
iterationCallback.completeExceptionally(t);
return;
}
iterationCallback.complete(iterationCallback);
});

}).run(finalCallback);
});
}

/**
* This method is equivalent to a do-while loop, where the loop body is executed first and
* then the condition is checked to determine whether the loop should continue.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
/*
* Copyright 2008-present MongoDB, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.mongodb.internal.async;

import com.mongodb.annotations.NotThreadSafe;
import com.mongodb.lang.Nullable;

/**
* A trampoline that converts recursive callback invocations into an iterative loop,
* preventing stack overflow in async loops.
*
* <p>When async loop iterations complete synchronously on the same thread, callback
* recursion occurs: each iteration's {@code callback.onResult()} immediately triggers
* the next iteration, causing unbounded stack growth. For example, a 1000-iteration
* loop would create > 1000 stack frames and cause {@code StackOverflowError}.</p>
*
* <p>The trampoline intercepts this recursion: instead of executing the next iteration
* immediately (which would deepen the stack), it enqueues the continuation and returns, allowing
* the stack to unwind. A flat loop at the top then processes enqueued continuation iteratively,
* maintaining constant stack depth regardless of iteration count.</p>
*
* <p>Since async chains are sequential, at most one task is pending at any time.
* The trampoline uses a single slot rather than a queue.</p>
*
* The first call on a thread becomes the "trampoline owner" and runs the drain loop.
* Subsequent (re-entrant) calls on the same thread enqueue their continuation and return immediately.</p>
*
* <p>This class is not part of the public API and may be removed or changed at any time</p>
*/
@NotThreadSafe
public final class AsyncTrampoline {

private static final ThreadLocal<ContinuationHolder> TRAMPOLINE = new ThreadLocal<>();

private AsyncTrampoline() {}

/**
* Execute continuation through the trampoline. If no trampoline is active, become the owner
* and drain all enqueued continuations. If a trampoline is already active, enqueue and return.
*/
public static void run(final Runnable continuation) {
ContinuationHolder continuationHolder = TRAMPOLINE.get();
if (continuationHolder != null) {
continuationHolder.enqueue(continuation);
} else {
continuationHolder = new ContinuationHolder();
TRAMPOLINE.set(continuationHolder);
try {
continuation.run();
while (continuationHolder.continuation != null) {
Runnable continuationToRun = continuationHolder.continuation;
continuationHolder.continuation = null;
continuationToRun.run();
}
} finally {
TRAMPOLINE.remove();
}
}
}

/**
* A single-slot container for continuation.
* At most one continuation is pending at any time in a sequential async chain.
*/
@NotThreadSafe
private static final class ContinuationHolder {
@Nullable
private Runnable continuation;

void enqueue(final Runnable continuation) {
if (this.continuation != null) {
throw new AssertionError("Trampoline slot already occupied");
}
this.continuation = continuation;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
package com.mongodb.internal.async.function;

import com.mongodb.annotations.NotThreadSafe;
import com.mongodb.internal.async.AsyncTrampoline;
import com.mongodb.internal.async.SingleResultCallback;
import com.mongodb.lang.Nullable;

Expand Down Expand Up @@ -62,9 +63,11 @@ public void run(final SingleResultCallback<Void> callback) {
@NotThreadSafe
private class LoopingCallback implements SingleResultCallback<Void> {
private final SingleResultCallback<Void> wrapped;
private final Runnable nextIteration;

LoopingCallback(final SingleResultCallback<Void> callback) {
wrapped = callback;
nextIteration = () -> AsyncCallbackLoop.this.body.run(this);
}

@Override
Expand All @@ -80,7 +83,7 @@ public void onResult(@Nullable final Void result, @Nullable final Throwable t) {
return;
}
if (continueLooping) {
body.run(this);
AsyncTrampoline.run(nextIteration);
} else {
wrapped.onResult(result, null);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

import static com.mongodb.assertions.Assertions.assertNotNull;
import static com.mongodb.internal.async.AsyncRunnable.beginAsync;
import static org.junit.jupiter.api.Assertions.assertEquals;

abstract class AsyncFunctionsAbstractTest extends AsyncFunctionsTestBase {
private static final TimeoutContext TIMEOUT_CONTEXT = new TimeoutContext(new TimeoutSettings(0, 0, 0, 0L, 0));
Expand Down Expand Up @@ -723,6 +724,120 @@ void testTryCatchTestAndRethrow() {
});
}

@Test
void testWhile() {
// last iteration: 3 < 3 = 1
// 1(plainTest exception) + 1(plainTest false) + 1(sync exception) + 1(sync success) * 1(transition to next iteration) = 4
// 1(plainTest exception) + 1(plainTest false) + 1(sync exception) + 1(sync success) * 4(transition to next iteration) = 7
// 1(plainTest exception) + 1(plainTest false) + 1(sync exception) + 1(sync success) * 7(transition to next iteration) = 10
assertBehavesSameVariations(10,
() -> {
int counter = 0;
while (counter < 3 && plainTest(counter)) {
counter++;
sync(counter);
}
},
(callback) -> {
MutableValue<Integer> counter = new MutableValue<>(0);
beginAsync().thenRunWhileLoop(() -> counter.get() < 3 && plainTest(counter.get()), c2 -> {
counter.set(counter.get() + 1);
async(counter.get(), c2);
}).finish(callback);
});
}

@Test
void testWhileWithThenRun() {
// while: last iteration: 3 < 3 = 1
// 1(plainTest exception) + 1(plainTest false) + 1(sync exception) + 1(sync success) * 1(transition to next iteration) = 4
// 1(plainTest exception) + 1(plainTest false) + 1(sync exception) + 1(sync success) * 4(transition to next iteration) = 7
// 1(plainTest exception) + 1(plainTest false) + 1(sync exception) + 1(sync success) * 7(transition to next iteration) = 10
// trailing sync: 1(exception) + 1(success) = 2
// 6(while exception) + 4(while success) * 2(trailing sync) = 14
assertBehavesSameVariations(14,
() -> {
int counter = 0;
while (counter < 3 && plainTest(counter)) {
counter++;
sync(counter);
}
sync(counter + 1);
},
(callback) -> {
MutableValue<Integer> counter = new MutableValue<>(0);
beginAsync().thenRun(c -> {
beginAsync().thenRunWhileLoop(() -> counter.get() < 3 && plainTest(counter.get()), c2 -> {
counter.set(counter.get() + 1);
async(counter.get(), c2);
}).finish(c);
}).thenRun(c -> {
async(counter.get() + 1, c);
}).finish(callback);
});
}

@Test
void testNestedWhileLoops() {
// inner while: 4 success + 6 exception = 10
// last inner iteration: 3 < 3 = 1
// 1(outer plainTest exception) + 1(outer plainTest false) + (inner while) * 1(transition to next iteration) = 12
// 1(outer plainTest exception) + 1(outer plainTest false) + (inner while) * 12(transition to next iteration) = 56
// 1(outer plainTest exception) + 1(outer plainTest false) + (inner while) * 56(transition to next iteration) = 232
assertBehavesSameVariations(232,
() -> {
int outer = 0;
while (outer < 3 && plainTest(outer)) {
int inner = 0;
while (inner < 3 && plainTest(inner)) {
sync(outer + inner);
inner++;
}
outer++;
}
},
(callback) -> {
MutableValue<Integer> outer = new MutableValue<>(0);
beginAsync().thenRunWhileLoop(() -> outer.get() < 3 && plainTest(outer.get()), c -> {
MutableValue<Integer> inner = new MutableValue<>(0);
beginAsync().thenRunWhileLoop(
() -> inner.get() < 3 && plainTest(inner.get()),
c2 -> {
beginAsync().thenRun(c3 -> {
async(outer.get() + inner.get(), c3);
}).thenRun(c3 -> {
inner.set(inner.get() + 1);
c3.complete(c3);
}).finish(c2);
}
).thenRun(c2 -> {
outer.set(outer.get() + 1);
c2.complete(c2);
}).finish(c);
}).finish(callback);
});
}

@Test
void testWhileLoopStackConstant() {
int depthWith100 = maxStackDepthForIterations(100);
int depthWith10000 = maxStackDepthForIterations(10_000);
assertEquals(depthWith100, depthWith10000, "Stack depth should be constant regardless of iteration count (trampoline)");
}

private int maxStackDepthForIterations(final int iterations) {
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we move this method below all other test cases
so that we preserve the order

  1. tests first
  2. utility + private method second

Copy link
Copy Markdown
Member Author

@vbabanin vbabanin Apr 9, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree that separating utilities from tests is useful in general. However, these methods each have exactly one caller (the test directly above). I kept them inline for locality so the reader sees the test and its implementation detail together without scrolling. The codebase doesn't enforce a strict ordering convention - CrudProseTest.java#L357 and several other test classes use the same inline approach.

MutableValue<Integer> counter = new MutableValue<>(0);
MutableValue<Integer> maxDepth = new MutableValue<>(0);
beginAsync().thenRunWhileLoop(() -> counter.get() < iterations, c -> {
maxDepth.set(Math.max(maxDepth.get(), Thread.currentThread().getStackTrace().length));
counter.set(counter.get() + 1);
c.complete(c);
}).finish((v, t) -> {});

assertEquals(iterations, counter.get());
return maxDepth.get();
}

@Test
void testRetryLoop() {
assertBehavesSameVariations(InvocationTracker.DEPTH_LIMIT * 2 + 1,
Expand Down Expand Up @@ -768,6 +883,65 @@ void testDoWhileLoop() {
});
}

@Test
void testNestedDoWhileLoops() {
// inner do-while: 3 success + 5 exception = 8
// last outer iteration: 3 < 3 = 1
// 5(inner exception) + 3(inner success) * 1(transition to next iteration) = 8
// 5(inner exception) + 3(inner success) * (1(outer plainTest exception) + 1(outer plainTest false) + 8(transition to next iteration)) = 35
// 5(inner exception) + 3(inner success) * (1(outer plainTest exception) + 1(outer plainTest false) + 35(transition to next iteration)) = 116
assertBehavesSameVariations(116,
() -> {
int outer = 0;
do {
int inner = 0;
do {
sync(outer + inner);
inner++;
} while (inner < 3 && plainTest(inner));
outer++;
} while (outer < 3 && plainTest(outer));
},
(callback) -> {
MutableValue<Integer> outer = new MutableValue<>(0);
beginAsync().thenRunDoWhileLoop(c -> {
MutableValue<Integer> inner = new MutableValue<>(0);
beginAsync().thenRunDoWhileLoop(c2 -> {
beginAsync().thenRun(c3 -> {
async(outer.get() + inner.get(), c3);
}).thenRun(c3 -> {
inner.set(inner.get() + 1);
c3.complete(c3);
}).finish(c2);
}, () -> inner.get() < 3 && plainTest(inner.get())
).thenRun(c2 -> {
outer.set(outer.get() + 1);
c2.complete(c2);
}).finish(c);
}, () -> outer.get() < 3 && plainTest(outer.get())).finish(callback);
});
}

@Test
void testDoWhileLoopStackConstant() {
int depthWith100 = maxDoWhileStackDepthForIterations(100);
int depthWith10000 = maxDoWhileStackDepthForIterations(10_000);
assertEquals(depthWith100, depthWith10000,
"Stack depth should be constant regardless of iteration count");
}

private int maxDoWhileStackDepthForIterations(final int iterations) {
MutableValue<Integer> counter = new MutableValue<>(0);
MutableValue<Integer> maxDepth = new MutableValue<>(0);
beginAsync().thenRunDoWhileLoop(c -> {
maxDepth.set(Math.max(maxDepth.get(), Thread.currentThread().getStackTrace().length));
counter.set(counter.get() + 1);
c.complete(c);
}, () -> counter.get() < iterations).finish((v, t) -> {});
assertEquals(iterations, counter.get());
return maxDepth.get();
}

@Test
void testFinallyWithPlainInsideTry() {
// (in try: normal flow + exception + exception) * (in finally: normal + exception) = 6
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import java.util.function.Consumer;
import java.util.function.Supplier;

import static java.lang.String.format;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;
Expand Down Expand Up @@ -272,14 +273,16 @@ private <T> void assertBehavesSame(final Supplier<T> sync, final Runnable betwee
}

assertTrue(wasCalledFuture.isDone(), "callback should have been called");
assertEquals(expectedEvents, listener.getEventStrings(), "steps should have matched");
assertEquals(expectedValue, actualValue.get());
assertEquals(expectedException == null, actualException.get() == null,
"both or neither should have produced an exception");
format("both or neither should have produced an exception. Expected exception: %s, actual exception: %s",
expectedException,
actualException.get()));
if (expectedException != null) {
assertEquals(expectedException.getMessage(), actualException.get().getMessage());
assertEquals(expectedException.getClass(), actualException.get().getClass());
}
assertEquals(expectedEvents, listener.getEventStrings(), "steps should have matched");
assertEquals(expectedValue, actualValue.get());

listener.clear();
}
Expand Down