diff --git a/driver-core/src/main/com/mongodb/internal/async/AsyncRunnable.java b/driver-core/src/main/com/mongodb/internal/async/AsyncRunnable.java index e404e2b8152..ecc83e7e005 100644 --- a/driver-core/src/main/com/mongodb/internal/async/AsyncRunnable.java +++ b/driver-core/src/main/com/mongodb/internal/async/AsyncRunnable.java @@ -16,6 +16,7 @@ package com.mongodb.internal.async; +import com.mongodb.assertions.Assertions; import com.mongodb.internal.TimeoutContext; import com.mongodb.internal.async.function.AsyncCallbackLoop; import com.mongodb.internal.async.function.LoopState; @@ -206,6 +207,15 @@ default AsyncRunnable thenRunIf(final Supplier condition, final AsyncRu }; } + /** + * @param condition The condition to check before each iteration + * @param body The body to run on each iteration + * @return the composition of this runnable and the loop, a runnable + */ + default AsyncRunnable loopWhile(final BooleanSupplier condition, final AsyncRunnable body) { + throw Assertions.fail("Not implemented"); + } + /** * @param supplier The supplier to supply using after this runnable * @return the composition of this runnable and the supplier, a supplier diff --git a/driver-core/src/main/com/mongodb/internal/async/function/AsyncCallbackLoop.java b/driver-core/src/main/com/mongodb/internal/async/function/AsyncCallbackLoop.java index a347a2a7e47..4ca496794af 100644 --- a/driver-core/src/main/com/mongodb/internal/async/function/AsyncCallbackLoop.java +++ b/driver-core/src/main/com/mongodb/internal/async/function/AsyncCallbackLoop.java @@ -41,50 +41,160 @@ public final class AsyncCallbackLoop implements AsyncCallbackRunnable { private final LoopState state; private final AsyncCallbackRunnable body; + private final ThreadLocal sameThreadDetector; /** * @param state The {@link LoopState} to be deemed as initial for the purpose of the new {@link AsyncCallbackLoop}. * @param body The body of the loop. */ public AsyncCallbackLoop(final LoopState state, final AsyncCallbackRunnable body) { - this.state = state; this.body = body; + this.state = state; + sameThreadDetector = ThreadLocal.withInitial(() -> SameThreadDetectionStatus.NEGATIVE); } @Override public void run(final SingleResultCallback callback) { - body.run(new LoopingCallback(callback)); + run(false, callback); } /** - * This callback is allowed to be completed more than once. + * Initiates a new iteration of the loop by invoking + * {@link #body}{@code .}{@link AsyncCallbackRunnable#run(SingleResultCallback) run}. + * The initiated iteration may be executed either synchronously or asynchronously with the method that initiated it: + *
    + *
  • synchronous execution—completion of the initiated iteration is guaranteed to happen-before the method completion; + *
      + *
    • Note that the formulations + *
        + *
      1. "completion of the initiated iteration is guaranteed to happen-before the method completion"
      2. + *
      3. "completion of the initiated iteration happens-before the method completion"
      4. + *
      + * are different: the former is about the program while the latter is about the execution, and follows from the former. + * For us the former is useful. + *
    • + *
    + *
  • + *
  • asynchronous execution—the aforementioned guarantee does not exist. + *
      + *
    • Note that the formulations + *
        + *
      1. "the aforementioned guarantee does not exist"
      2. + *
      3. "the aforementioned relation does not exist"
      4. + *
      + * are different: the former is about the program while the latter is about the execution, and follows from the former. + * For us the former is useful. + *
    • + *
    + *
  • + *
+ * + *

If another iteration is needed, it is initiated from the callback passed to + * {@link #body}{@code .}{@link AsyncCallbackRunnable#run(SingleResultCallback) run} + * by invoking {@link #run(boolean, SingleResultCallback)}. + * Completing the initiated iteration is {@linkplain SingleResultCallback#onResult(Object, Throwable) invoking} the callback. + * Thus, it is guaranteed that all iterations are executed sequentially with each other + * (that is, completion of one iteration happens-before initiation of the next one) + * regardless of them being executed synchronously or asynchronously with the method that initiated them. + * + *

Initiating any but the {@linkplain LoopState#isFirstIteration() first} iteration is done using trampolining, + * which allows us to do it iteratively rather than recursively, if iterations are executed synchronously, + * and ensures stack usage does not increase with the number of iterations. + * + * @return {@code true} iff it is known that another iteration must be initiated. + * This information is used only for trampolining, and is available only if the iteration executed synchronously. + * + *

It is impossible to detect whether an iteration is executed synchronously. + * It is, however, possible to detect whether an iteration is executed in the same thread as the method that initiated it, + * and we use this as a proxy indicator of synchronous execution. Unfortunately, this means we do not support / behave incorrectly + * if an iteration is executed synchronously but in a thread different from the one in which the method that + * initiated the iteration was invoked. + * + *

The above limitation should not be a problem in practice: + *

    + *
  • the only way to execute an iteration synchronously but in a different thread is to block the thread that + * initiated the iteration by waiting for completion of the iteration by that other thread;
  • + *
  • blocking a thread is forbidden in asynchronous code, and we do not do it;
  • + *
  • therefore, we would not have an iteration that is executed synchronously but in a different thread.
  • + *
*/ - @NotThreadSafe - private class LoopingCallback implements SingleResultCallback { - private final SingleResultCallback wrapped; - - LoopingCallback(final SingleResultCallback callback) { - wrapped = callback; - } - - @Override - public void onResult(@Nullable final Void result, @Nullable final Throwable t) { - if (t != null) { - wrapped.onResult(null, t); - } else { - boolean continueLooping; - try { - continueLooping = state.advance(); - } catch (Throwable e) { - wrapped.onResult(null, e); + boolean run(final boolean trampolining, final SingleResultCallback afterLoopCallback) { + // The `trampoliningResult` variable must be used only if the initiated iteration is executed synchronously with + // the current method, which must be detected separately. + // + // It may be tempting to detect whether the iteration was executed synchronously by reading from the variable + // and observing a write that is part of the callback execution. However, if the iteration is executed asynchronously with + // the current method, then the aforementioned conflicting write and read actions are not ordered by + // the happens-before relation, the execution contains a data race and the read is allowed to observe the write. + // If such observation happens when the iteration is executed asynchronously, then we have a false positive. + // Furthermore, depending on the nature of the value read, it may not be trustworthy. + // + // Making `trampoliningResult` a `volatile`, or even making it an `AtomicReference`/`AtomicInteger` and calling `compareAndSet` + // does not resolve the issue: it gets rid of the data race, but still leave us with a race condition + // that allows for false positives. + boolean[] trampoliningResult = {false}; + sameThreadDetector.set(SameThreadDetectionStatus.PROBING); + body.run((r, t) -> { + if (completeIfNeeded(afterLoopCallback, r, t)) { + // Bounce if we are trampolining and the iteration was executed synchronously, + // trampolining completes and so is the loop; + // otherwise, the loop simply completes. + return; + } + if (trampolining) { + boolean sameThread = sameThreadDetector.get().equals(SameThreadDetectionStatus.PROBING); + if (sameThread) { + // Bounce if we are trampolining and the iteration was executed synchronously; + // otherwise proceed to initiate trampolining. + sameThreadDetector.set(SameThreadDetectionStatus.POSITIVE); + trampoliningResult[0] = true; return; - } - if (continueLooping) { - body.run(this); } else { - wrapped.onResult(result, null); + sameThreadDetector.remove(); } } + // initiate trampolining + boolean anotherIterationNeeded; + do { + anotherIterationNeeded = run(true, afterLoopCallback); + } while (anotherIterationNeeded); + }); + try { + return sameThreadDetector.get().equals(SameThreadDetectionStatus.POSITIVE) && trampoliningResult[0]; + } finally { + sameThreadDetector.remove(); + } + } + + /** + * @return {@code true} iff the {@code afterLoopCallback} was + * {@linkplain SingleResultCallback#onResult(Object, Throwable) completed}. + */ + private boolean completeIfNeeded(final SingleResultCallback afterLoopCallback, + @Nullable final Void result, @Nullable final Throwable t) { + if (t != null) { + afterLoopCallback.onResult(null, t); + return true; + } else { + boolean anotherIterationNeeded; + try { + anotherIterationNeeded = state.advance(); + } catch (Throwable e) { + afterLoopCallback.onResult(null, e); + return true; + } + if (anotherIterationNeeded) { + return false; + } else { + afterLoopCallback.onResult(result, null); + return true; + } } } + + private enum SameThreadDetectionStatus { + NEGATIVE, + PROBING, + POSITIVE + } } diff --git a/driver-core/src/test/unit/com/mongodb/internal/async/AsyncFunctionsAbstractTest.java b/driver-core/src/test/unit/com/mongodb/internal/async/AsyncFunctionsAbstractTest.java index 9a9b7552d3e..f5669d2a5b8 100644 --- a/driver-core/src/test/unit/com/mongodb/internal/async/AsyncFunctionsAbstractTest.java +++ b/driver-core/src/test/unit/com/mongodb/internal/async/AsyncFunctionsAbstractTest.java @@ -18,14 +18,17 @@ import com.mongodb.MongoException; import com.mongodb.internal.TimeoutContext; import com.mongodb.internal.TimeoutSettings; +import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; +import java.util.concurrent.atomic.AtomicInteger; import java.util.function.BiConsumer; import java.util.function.Consumer; import java.util.function.Supplier; import static com.mongodb.assertions.Assertions.assertNotNull; import static com.mongodb.internal.async.AsyncRunnable.beginAsync; +import static org.junit.jupiter.api.Assertions.assertTrue; abstract class AsyncFunctionsAbstractTest extends AsyncFunctionsTestBase { private static final TimeoutContext TIMEOUT_CONTEXT = new TimeoutContext(new TimeoutSettings(0, 0, 0, 0L, 0)); @@ -724,7 +727,85 @@ void testTryCatchTestAndRethrow() { } @Test + @Disabled("Tests AsyncRunnable.loopWhile, but we agreed to improve and test AsyncRunnable.thenRunDoWhileLoop") + void testWhile() { + assertBehavesSameVariations(10, // TODO check expected variations + () -> { + int i = 0; + while (i < 3 && plainTest(i)) { + i++; + sync(i); + } + }, + (callback) -> { + final int[] i = new int[1]; + beginAsync().loopWhile(() -> i[0] < 3 && plainTest(i[0]), (c2) -> { + i[0]++; + async(i[0], c2); + }).finish(callback); + }); + } + + @Test + @Disabled("Tests AsyncRunnable.loopWhile, but we agreed to improve and test AsyncRunnable.thenRunDoWhileLoop") + void testWhile2() { + assertBehavesSameVariations(14, // TODO check expected variations + () -> { + int i = 0; + while (i < 3 && plainTest(i)) { + i++; + sync(i); + } + sync(i + 100); + }, + (callback) -> { + final int[] i = new int[1]; + beginAsync().thenRun(c -> { + beginAsync().loopWhile(() -> i[0] < 3 && plainTest(i[0]), (c2) -> { + i[0]++; + async(i[0], c2); + }).finish(c); + }).thenRun(c -> { + async(i[0] + 100, c); + }).finish(callback); + }); + } + + @Test + @Disabled("Tests AsyncRunnable.loopWhile, but we agreed to improve and test AsyncRunnable.thenRunDoWhileLoop") void testRetryLoop() { + assertBehavesSameVariations(InvocationTracker.DEPTH_LIMIT * 2 + 1, + () -> { + while (true) { + try { + sync(plainTest(0) ? 1 : 2); + } catch (RuntimeException e) { + if (e.getMessage().equals("exception-1")) { + continue; + } + throw e; + } + break; + } + }, + (callback) -> { + final boolean[] shouldContinue = new boolean[]{true}; + beginAsync().loopWhile(() -> shouldContinue[0], (c) -> { + beginAsync().thenRun(c2 -> { + async(plainTest(0) ? 1 : 2, c2); + }).thenRun(c2 -> { + shouldContinue[0] = false; + c2.complete(c2); + }).onErrorIf(e -> e.getMessage().equals("exception-1"), (e, c2) -> { + c2.complete(c2); + }).finish(c); + }).finish(callback); + }); + } + + @Test + void testThenRunRetryingWhile() { + for (int i = 0; i < 1000; i++) { assertBehavesSameVariations(InvocationTracker.DEPTH_LIMIT * 2 + 1, () -> { while (true) { @@ -746,10 +827,11 @@ void testRetryLoop() { e -> e.getMessage().equals("exception-1") ).finish(callback); }); - } + }} @Test void testDoWhileLoop() { + for (int i = 0; i < 1000; i++) { assertBehavesSameVariations(67, () -> { do { @@ -766,6 +848,25 @@ void testDoWhileLoop() { () -> plainTest(2) ).finish(finalCallback); }); + }} + + @Test + void testDoWhileLoop2() { + assertBehavesSameVariations(8, + () -> { + int i = 0; + do { + i++; + sync(i); + } while (i < 3 && plainTest(i)); + }, + (callback) -> { + final int[] i = new int[1]; + beginAsync().thenRunDoWhileLoop((c) -> { + i[0]++; + async(i[0], c); + }, () -> i[0] < 3 && plainTest(i[0])).finish(callback); + }); } @Test @@ -990,4 +1091,67 @@ void testDerivation() { }).finish(callback); }); } + + @Test + @Disabled("Tests AsyncRunnable.thenRun/finish, but we agreed to improve and test AsyncRunnable.thenRunDoWhileLoop") + void testStackDepthBounded() { + AtomicInteger maxDepth = new AtomicInteger(0); + AtomicInteger minDepth = new AtomicInteger(Integer.MAX_VALUE); + AtomicInteger maxMongoDepth = new AtomicInteger(0); + AtomicInteger minMongoDepth = new AtomicInteger(Integer.MAX_VALUE); + AtomicInteger stepCount = new AtomicInteger(0); + // Capture one sample of mongodb package frames for printing + String[][] sampleMongoFrames = {null}; + + AsyncRunnable chain = beginAsync(); + for (int i = 0; i < 1000; i++) { + chain = chain.thenRun(c -> { + stepCount.incrementAndGet(); + StackTraceElement[] stack = Thread.currentThread().getStackTrace(); + int depth = stack.length; + maxDepth.updateAndGet(current -> Math.max(current, depth)); + minDepth.updateAndGet(current -> Math.min(current, depth)); + int mongoFrames = 0; + for (StackTraceElement frame : stack) { + if (frame.getClassName().startsWith("com.mongodb")) { + mongoFrames++; + } + } + int mf = mongoFrames; + maxMongoDepth.updateAndGet(current -> Math.max(current, mf)); + minMongoDepth.updateAndGet(current -> Math.min(current, mf)); + // Capture first sample + if (sampleMongoFrames[0] == null) { + String[] frames = new String[mf]; + int idx = 0; + for (StackTraceElement frame : stack) { + if (frame.getClassName().startsWith("com.mongodb")) { + frames[idx++] = frame.getClassName() + "." + frame.getMethodName() + + "(" + frame.getFileName() + ":" + frame.getLineNumber() + ")"; + } + } + sampleMongoFrames[0] = frames; + } + c.complete(c); + }); + } + + chain.finish((v, e) -> { + assertTrue(stepCount.get() == 1000, "Expected 1000 steps, got " + stepCount.get()); + int depth = maxDepth.get(); + int mongoDepth = maxMongoDepth.get(); + String summary = "Stack depth: min=" + minDepth.get() + ", max=" + depth + + " | MongoDB frames: min=" + minMongoDepth.get() + ", max=" + mongoDepth; + System.out.printf(summary + "%n"); + if (sampleMongoFrames[0] != null) { + System.out.printf("MongoDB stack frames (sample):%n"); + for (int i = 0; i < sampleMongoFrames[0].length; i++) { + System.out.printf(" " + (i + 1) + ". " + sampleMongoFrames[0][i] + "%n"); + } + } + assertTrue(depth < 200, + "Stack depth too deep (min=" + minDepth.get() + ", max=" + depth + + "). Trampoline may not be working correctly."); + }); + } } diff --git a/driver-core/src/test/unit/com/mongodb/internal/async/AsyncLoopTest.java b/driver-core/src/test/unit/com/mongodb/internal/async/AsyncLoopTest.java new file mode 100644 index 00000000000..f8ea79e0e09 --- /dev/null +++ b/driver-core/src/test/unit/com/mongodb/internal/async/AsyncLoopTest.java @@ -0,0 +1,394 @@ +/* + * Copyright 2008-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.mongodb.internal.async; + +import com.mongodb.internal.async.function.AsyncCallbackLoop; +import com.mongodb.internal.async.function.LoopState; +import com.mongodb.internal.time.StartTime; +import com.mongodb.lang.Nullable; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.CsvSource; + +import java.io.IOException; +import java.io.PrintWriter; +import java.io.StringWriter; +import java.time.Duration; +import java.util.Objects; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.Future; +import java.util.concurrent.ScheduledThreadPoolExecutor; +import java.util.concurrent.ThreadLocalRandom; +import java.util.concurrent.TimeUnit; + +import static com.mongodb.internal.async.AsyncRunnable.beginAsync; +import static org.junit.jupiter.api.Assertions.assertTrue; + +class AsyncLoopTest { + private static final int MAX_STACK_DEPTH = 500; + + @ParameterizedTest + @CsvSource({ + "10" + }) + void testDemo(final int iterations) throws Exception { + System.err.printf("baselineStackDepth=%d%n%n", Thread.currentThread().getStackTrace().length); + CompletableFuture join = new CompletableFuture<>(); + LoopState loopState = new LoopState(); + new AsyncCallbackLoop(loopState, c -> { + int iteration = loopState.iteration(); + System.err.printf("iteration=%d, callStackDepth=%d%n", iteration, Thread.currentThread().getStackTrace().length); + if (!loopState.breakAndCompleteIf(() -> iteration == (iterations - 1), c)) { + c.complete(c); + } + }).run((r, t) -> { + System.err.printf("test callback completed callStackDepth=%d, r=%s, t=%s%n", + Thread.currentThread().getStackTrace().length, r, exceptionToString(t)); + complete(join, r, t); + }); + join.get(); + System.err.printf("%nDONE%n%n"); + } + + private enum IterationExecutionType { + SYNC_SAME_THREAD, + SYNC_DIFFERENT_THREAD, + ASYNC, + MIXED_SYNC_SAME_THREAD_AND_ASYNC + } + + private enum Verbocity { + VERBOSE, + COMPACT; + + /** + * Every {@value}s message is printed. + */ + private static final int COMPACTNESS = 50_000; + } + + private enum ThreadManagement { + NEW_THREAD_PER_TASK, + REUSE_THREADS + } + + @ParameterizedTest() + @CsvSource({ + "250_000, 0, SYNC_SAME_THREAD, 0, COMPACT, 0, REUSE_THREADS", + "250_000, 0, ASYNC, 0, COMPACT, 0, NEW_THREAD_PER_TASK", + "250_000, 0, ASYNC, 0, COMPACT, 1, REUSE_THREADS", + "250_000, 0, ASYNC, 0, COMPACT, 2, REUSE_THREADS", + "250_000, 0, MIXED_SYNC_SAME_THREAD_AND_ASYNC, 0, COMPACT, 0, NEW_THREAD_PER_TASK", + "250_000, 0, MIXED_SYNC_SAME_THREAD_AND_ASYNC, 0, COMPACT, 1, REUSE_THREADS", + "4, 0, ASYNC, 4, VERBOSE, 1, REUSE_THREADS", + "4, 4, ASYNC, 0, VERBOSE, 1, REUSE_THREADS", + "250_000, 0, SYNC_DIFFERENT_THREAD, 0, COMPACT, 0, NEW_THREAD_PER_TASK", + "250_000, 0, SYNC_DIFFERENT_THREAD, 0, COMPACT, 1, REUSE_THREADS", + }) + void thenRunDoWhileLoopTest( + final int counterInitialValue, + final int blockSyncPartOfIterationTotalSeconds, + final IterationExecutionType executionType, + final int delayAsyncExecutionTotalSeconds, + final Verbocity verbocity, + final int executorSize, + final ThreadManagement threadManagement) throws Exception { + Duration blockSyncPartOfIterationTotalDuration = Duration.ofSeconds(blockSyncPartOfIterationTotalSeconds); + if (executionType.equals(IterationExecutionType.SYNC_DIFFERENT_THREAD)) { + com.mongodb.assertions.Assertions.assertTrue( + (executorSize > 0 && threadManagement.equals(ThreadManagement.REUSE_THREADS)) + || (executorSize == 0 && threadManagement.equals(ThreadManagement.NEW_THREAD_PER_TASK))); + } + if (executionType.equals(IterationExecutionType.SYNC_SAME_THREAD)) { + com.mongodb.assertions.Assertions.assertTrue(executorSize == 0); + com.mongodb.assertions.Assertions.assertTrue(threadManagement.equals(ThreadManagement.REUSE_THREADS)); + } + if (!executionType.equals(IterationExecutionType.ASYNC)) { + com.mongodb.assertions.Assertions.assertTrue(delayAsyncExecutionTotalSeconds == 0); + } + if (threadManagement.equals(ThreadManagement.NEW_THREAD_PER_TASK)) { + com.mongodb.assertions.Assertions.assertTrue(executorSize == 0); + } + Duration delayAsyncExecutionTotalDuration = Duration.ofSeconds(delayAsyncExecutionTotalSeconds); + ScheduledExecutor executor = executorSize == 0 ? null : new ScheduledExecutor(executorSize, threadManagement); + try { + System.err.printf("baselineStackDepth=%d%n%n", Thread.currentThread().getStackTrace().length); + StartTime start = StartTime.now(); + CompletableFuture join = new CompletableFuture<>(); + asyncLoop(new Counter(counterInitialValue, verbocity), + blockSyncPartOfIterationTotalDuration, executionType, delayAsyncExecutionTotalDuration, verbocity, executor, + (r, t) -> { + int stackDepth = Thread.currentThread().getStackTrace().length; + System.err.printf("test callback completed callStackDepth=%s, r=%s, t=%s%n", + stackDepth, r, exceptionToString(t)); + assertTrue(stackDepth <= MAX_STACK_DEPTH); + complete(join, r, t); + }); + System.err.printf("\tasyncLoop method completed in %s%n", start.elapsed()); + join.get(); + System.err.printf("%nDONE%n%n"); + } finally { + if (executor != null) { + executor.shutdownNow(); + com.mongodb.assertions.Assertions.assertTrue(executor.awaitTermination(1, TimeUnit.MINUTES)); + } + } + } + + private static void asyncLoop( + final Counter counter, + final Duration blockSyncPartOfIterationTotalDuration, + final IterationExecutionType executionType, + final Duration delayAsyncExecutionTotalDuration, + final Verbocity verbocity, + @Nullable + final ScheduledExecutor executor, + final SingleResultCallback callback) { + beginAsync().thenRunDoWhileLoop(c -> { + sleep(blockSyncPartOfIterationTotalDuration.dividedBy(counter.initial())); + StartTime start = StartTime.now(); + asyncPartOfIteration(counter, executionType, delayAsyncExecutionTotalDuration, verbocity, executor, c); + if (verbocity.equals(Verbocity.VERBOSE)) { + System.err.printf("\tasyncPartOfIteration method completed in %s%n", start.elapsed()); + } + }, () -> !counter.done()).finish(callback); + } + + private static void asyncPartOfIteration( + final Counter counter, + final IterationExecutionType executionType, + final Duration delayAsyncExecutionTotalDuration, + final Verbocity verbocity, + @Nullable + final ScheduledExecutor executor, + final SingleResultCallback callback) { + Runnable asyncPartOfIteration = () -> { + counter.countDown(); + StartTime start = StartTime.now(); + callback.complete(callback); + if (verbocity.equals(Verbocity.VERBOSE)) { + System.err.printf("\tasyncPartOfIteration callback.complete method completed in %s%n", start.elapsed()); + } + }; + switch (executionType) { + case SYNC_SAME_THREAD: { + asyncPartOfIteration.run(); + break; + } + case SYNC_DIFFERENT_THREAD: { + if (executor == null) { + Thread thread = new Thread(asyncPartOfIteration); + thread.start(); + join(thread); + } else { + join(executor.submit(asyncPartOfIteration)); + } + break; + } + case ASYNC: { + if (executor == null) { + Thread thread = new Thread(() -> { + sleep(delayAsyncExecutionTotalDuration.dividedBy(counter.initial())); + asyncPartOfIteration.run(); + }); + thread.start(); + } else { + com.mongodb.assertions.Assertions.assertNotNull(executor).schedule(asyncPartOfIteration, + delayAsyncExecutionTotalDuration.dividedBy(counter.initial()).toNanos(), TimeUnit.NANOSECONDS); + } + break; + } + case MIXED_SYNC_SAME_THREAD_AND_ASYNC: { + if (ThreadLocalRandom.current().nextBoolean()) { + asyncPartOfIteration.run(); + } else { + if (executor == null) { + Thread thread = new Thread(() -> { + sleep(delayAsyncExecutionTotalDuration.dividedBy(counter.initial())); + asyncPartOfIteration.run(); + }); + thread.start(); + } else { + com.mongodb.assertions.Assertions.assertNotNull(executor).schedule(asyncPartOfIteration, + delayAsyncExecutionTotalDuration.dividedBy(counter.initial()).toNanos(), TimeUnit.NANOSECONDS); + } + } + break; + } + default: { + com.mongodb.assertions.Assertions.fail(executionType.toString()); + } + } + } + + private static final class Counter { + private final int initial; + private int current; + private boolean doneReturnedTrue; + private final Verbocity verbocity; + + Counter(final int initial, final Verbocity verbocity) { + this.initial = initial; + this.current = initial; + this.doneReturnedTrue = false; + this.verbocity = verbocity; + } + + int initial() { + return initial; + } + + void countDown() { + com.mongodb.assertions.Assertions.assertTrue(current > 0); + int previous = current; + int decremented = --current; + if (verbocity.equals(Verbocity.VERBOSE) || decremented % Verbocity.COMPACTNESS == 0) { + int stackDepth = Thread.currentThread().getStackTrace().length; + assertTrue(stackDepth <= MAX_STACK_DEPTH); + System.err.printf("counted %d->%d tid=%d callStackDepth=%d %n", + previous, decremented, Thread.currentThread().getId(), stackDepth); + } + } + + boolean done() { + if (current == 0) { + com.mongodb.assertions.Assertions.assertFalse(doneReturnedTrue); + int stackDepth = Thread.currentThread().getStackTrace().length; + assertTrue(stackDepth <= MAX_STACK_DEPTH); + System.err.printf("counting done callStackDepth=%d %n", stackDepth); + doneReturnedTrue = true; + return true; + } + return false; + } + } + + private static String exceptionToString(@Nullable final Throwable t) { + if (t == null) { + return Objects.toString(null); + } + try (StringWriter sw = new StringWriter(); + PrintWriter pw = new PrintWriter(sw)) { +// t.printStackTrace(pw); + pw.println(t); + pw.flush(); + return sw.toString(); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + private static void complete(final CompletableFuture future, @Nullable final T result, @Nullable final Throwable t) { + if (t != null) { + future.completeExceptionally(t); + } else { + future.complete(result); + } + } + + private static void join(final Thread thread) { + try { + thread.join(); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new RuntimeException(e); + } + } + + private static void join(final Future future) { + try { + future.get(); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new RuntimeException(e); + } catch (ExecutionException e) { + throw new RuntimeException(e); + } + } + + private static void sleep(final Duration duration) { + if (duration.isZero()) { + return; + } + try { + long durationNsPart = duration.getNano(); + long durationMsPartFromNsPart = TimeUnit.MILLISECONDS.convert(duration.getNano(), TimeUnit.NANOSECONDS); + long sleepMs = TimeUnit.MILLISECONDS.convert(duration.getSeconds(), TimeUnit.SECONDS) + durationMsPartFromNsPart; + int sleepNs = Math.toIntExact(durationNsPart - TimeUnit.NANOSECONDS.convert(durationMsPartFromNsPart, TimeUnit.MILLISECONDS)); + Thread.sleep(sleepMs, sleepNs); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new RuntimeException(e); + } + } + + /** + * This {@link ScheduledThreadPoolExecutor} propagates exceptions that caused termination of a task execution, + * causing the thread that executed the task to be terminated. + */ + private static final class ScheduledExecutor extends ScheduledThreadPoolExecutor { + ScheduledExecutor(final int size, final ThreadManagement threadManagement) { + super(size, r -> { + Thread thread = new Thread(() -> { + r.run(); + if (threadManagement.equals(ThreadManagement.NEW_THREAD_PER_TASK)) { + terminateCurrentThread(); + } + }); + thread.setUncaughtExceptionHandler((t, e) -> { + if (e instanceof ThreadTerminationException) { + return; + } + t.getThreadGroup().uncaughtException(t, e); + }); + return thread; + }); + } + + private static void terminateCurrentThread() { + throw ThreadTerminationException.INSTANCE; + } + + @Override + protected void afterExecute(final Runnable r, final Throwable t) { + if (t instanceof ThreadTerminationException) { + throw (ThreadTerminationException) t; + } else if (r instanceof Future) { + Future future = (Future) r; + if (future.isDone()) { + try { + future.get(); + } catch (ExecutionException e) { + Throwable cause = e.getCause(); + if (cause instanceof ThreadTerminationException) { + throw (ThreadTerminationException) cause; + } + } catch (Throwable e) { + // do nothing, we are not swallowing `e`, btw + } + } + } + } + + private static final class ThreadTerminationException extends RuntimeException { + static final ThreadTerminationException INSTANCE = new ThreadTerminationException(); + + private ThreadTerminationException() { + super(null, null, false, false); + } + } + } +}