From be870493341519a7a940ccfab96e4bf80ec06037 Mon Sep 17 00:00:00 2001 From: jencymaryjoseph <35571282+jencymaryjoseph@users.noreply.github.com> Date: Mon, 11 May 2026 15:26:03 -0700 Subject: [PATCH 1/2] Fix check-then-act race condition in parallel subscriber --- ...ignedUrlMultipartDownloaderSubscriber.java | 25 +++++++---- ...dUrlMultipartDownloaderSubscriberTest.java | 42 +++++++++++++++++++ 2 files changed, 58 insertions(+), 9 deletions(-) diff --git a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/ParallelPresignedUrlMultipartDownloaderSubscriber.java b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/ParallelPresignedUrlMultipartDownloaderSubscriber.java index 8a7c898b361..6a39ea80040 100644 --- a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/ParallelPresignedUrlMultipartDownloaderSubscriber.java +++ b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/ParallelPresignedUrlMultipartDownloaderSubscriber.java @@ -21,6 +21,7 @@ import java.util.concurrent.CompletableFuture; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.Semaphore; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import org.reactivestreams.Subscriber; @@ -64,7 +65,7 @@ public class ParallelPresignedUrlMultipartDownloaderSubscriber private final AtomicInteger partNumber = new AtomicInteger(0); private final AtomicInteger completedParts = new AtomicInteger(0); - private final AtomicInteger inFlightRequestsNum = new AtomicInteger(0); + private final Semaphore inFlightPermits; private final AtomicBoolean isCompletedExceptionally = new AtomicBoolean(false); private final AtomicBoolean processingPending = new AtomicBoolean(false); private final Map> inFlightRequests = new ConcurrentHashMap<>(); @@ -90,6 +91,7 @@ public ParallelPresignedUrlMultipartDownloaderSubscriber( this.configuredPartSizeInBytes = configuredPartSizeInBytes; this.resultFuture = resultFuture; this.maxInFlightParts = maxInFlightParts; + this.inFlightPermits = new Semaphore(maxInFlightParts); } @Override @@ -132,17 +134,18 @@ private void sendFirstRequest(AsyncResponseTransformer { if (error != null || isCompletedExceptionally.get()) { + inFlightPermits.release(); handlePartError(error, 0); return; } inFlightRequests.remove(0); - inFlightRequestsNum.decrementAndGet(); + inFlightPermits.release(); completedParts.incrementAndGet(); this.eTag = res.eTag(); @@ -188,7 +191,7 @@ private void processRequest(AsyncResponseTransformer= maxInFlightParts) { + if (!inFlightPermits.tryAcquire()) { pendingTransformers.offer(Pair.of(currentPart, transformer)); return; } @@ -200,6 +203,7 @@ private void processRequest(AsyncResponseTransformer transformer, int partIndex) { if (isCompletedExceptionally.get()) { + inFlightPermits.release(); return; } @@ -210,24 +214,25 @@ private void sendPartRequest(AsyncResponseTransformer { if (error != null || isCompletedExceptionally.get()) { + inFlightPermits.release(); handlePartError(error, partIndex); return; } Optional validationError = validatePartResponse(res, partIndex); if (validationError.isPresent()) { + inFlightPermits.release(); handlePartError(validationError.get(), partIndex); return; } log.debug(() -> "Completed part: " + partIndex); inFlightRequests.remove(partIndex); - inFlightRequestsNum.decrementAndGet(); + inFlightPermits.release(); int totalComplete = completedParts.incrementAndGet(); if (totalComplete == totalParts) { @@ -250,17 +255,19 @@ private void processPendingTransformers() { return; } try { - while (!pendingTransformers.isEmpty() && inFlightRequestsNum.get() < maxInFlightParts) { + while (!pendingTransformers.isEmpty() && inFlightPermits.tryAcquire()) { Pair> pendingPart = pendingTransformers.poll(); if (pendingPart != null && pendingPart.left() < totalParts) { sendPartRequest(pendingPart.right(), pendingPart.left()); + } else { + inFlightPermits.release(); } } } finally { processingPending.set(false); } - } while (!pendingTransformers.isEmpty() && inFlightRequestsNum.get() < maxInFlightParts); + } while (!pendingTransformers.isEmpty() && inFlightPermits.availablePermits() > 0); } private Optional validatePartResponse(GetObjectResponse response, int partIndex) { @@ -336,4 +343,4 @@ public void onError(Throwable t) { public void onComplete() { // Completion is handled by resultFuture } -} +} \ No newline at end of file diff --git a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/ParallelPresignedUrlMultipartDownloaderSubscriberTest.java b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/ParallelPresignedUrlMultipartDownloaderSubscriberTest.java index 0062029e1b8..f24b2b3a2d1 100644 --- a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/ParallelPresignedUrlMultipartDownloaderSubscriberTest.java +++ b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/ParallelPresignedUrlMultipartDownloaderSubscriberTest.java @@ -16,6 +16,7 @@ package software.amazon.awssdk.services.s3.internal.multipart; import static com.github.tomakehurst.wiremock.client.WireMock.aResponse; +import static com.github.tomakehurst.wiremock.client.WireMock.findAll; import static com.github.tomakehurst.wiremock.client.WireMock.get; import static com.github.tomakehurst.wiremock.client.WireMock.getRequestedFor; import static com.github.tomakehurst.wiremock.client.WireMock.matching; @@ -27,6 +28,7 @@ import com.github.tomakehurst.wiremock.junit5.WireMockRuntimeInfo; import com.github.tomakehurst.wiremock.junit5.WireMockTest; +import com.github.tomakehurst.wiremock.verification.LoggedRequest; import java.io.IOException; import java.net.MalformedURLException; import java.net.URI; @@ -35,6 +37,7 @@ import java.nio.file.Files; import java.nio.file.Path; import java.util.Arrays; +import java.util.List; import java.util.UUID; import java.util.concurrent.CompletionException; import org.junit.jupiter.api.AfterEach; @@ -234,6 +237,45 @@ void onNext_withNullTransformer_shouldThrowNPE() { .isInstanceOf(NullPointerException.class); } + @Test + void multiPartDownload_manyParts_shouldCompleteSuccessfully() throws Exception { + // 13 parts to exceed maxInFlightParts (10) + byte[] data = new byte[208]; // 13 × 16 bytes + Arrays.fill(data, (byte) 'X'); + + stubFor(get(urlEqualTo(PRESIGNED_URL_PATH)) + .withHeader("Range", matching("bytes=0-15")) + .willReturn(aResponse().withStatus(206) + .withHeader("Content-Length", "16") + .withHeader("Content-Range", "bytes 0-15/208") + .withHeader("ETag", "\"etag\"") + .withBody(Arrays.copyOfRange(data, 0, 16)))); + + for (int i = 1; i < 13; i++) { + int start = i * 16; + int end = start + 15; + stubFor(get(urlEqualTo(PRESIGNED_URL_PATH)) + .withHeader("Range", matching("bytes=" + start + "-" + end)) + .willReturn(aResponse().withStatus(206) + .withHeader("Content-Length", "16") + .withHeader("Content-Range", "bytes " + start + "-" + end + "/208") + .withHeader("ETag", "\"etag\"") + .withBody(Arrays.copyOfRange(data, start, end + 1)))); + } + + tempFile = createTempFileUnchecked(); + PresignedUrlDownloadRequest request = PresignedUrlDownloadRequest.builder() + .presignedUrl(presignedUrl) + .build(); + + s3AsyncClient.presignedUrlExtension() + .getObject(request, AsyncResponseTransformer.toFile(tempFile)) + .join(); + + assertThat(Files.readAllBytes(tempFile)).isEqualTo(data); + verify(13, getRequestedFor(urlEqualTo(PRESIGNED_URL_PATH))); + } + private static Path createTempFile() throws IOException { Path path = Files.createTempFile("parallel-test-" + UUID.randomUUID(), ".tmp"); Files.deleteIfExists(path); From b504628646637d63a00655d317082bd0a972c16e Mon Sep 17 00:00:00 2001 From: jencymaryjoseph <35571282+jencymaryjoseph@users.noreply.github.com> Date: Tue, 12 May 2026 12:10:07 -0700 Subject: [PATCH 2/2] address comments --- ...ignedUrlMultipartDownloaderSubscriber.java | 21 ++++++++++++------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/ParallelPresignedUrlMultipartDownloaderSubscriber.java b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/ParallelPresignedUrlMultipartDownloaderSubscriber.java index 6a39ea80040..bf1e9ddaf6a 100644 --- a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/ParallelPresignedUrlMultipartDownloaderSubscriber.java +++ b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/ParallelPresignedUrlMultipartDownloaderSubscriber.java @@ -130,22 +130,25 @@ private void sendFirstRequest(AsyncResponseTransformer "Sending first range request with range=" + partRequest.range()); + if (!inFlightPermits.tryAcquire()) { + throw new IllegalStateException("Failed to acquire permit for first request"); + } + CompletableFuture response = s3AsyncClient.presignedUrlExtension().getObject(partRequest, transformer); inFlightRequests.put(0, response); - inFlightPermits.tryAcquire(); CompletableFutureUtils.forwardExceptionTo(resultFuture, response); response.whenComplete((res, error) -> { + inFlightRequests.remove(0); + inFlightPermits.release(); + if (error != null || isCompletedExceptionally.get()) { - inFlightPermits.release(); handlePartError(error, 0); return; } - inFlightRequests.remove(0); - inFlightPermits.release(); completedParts.incrementAndGet(); this.eTag = res.eTag(); @@ -217,22 +220,21 @@ private void sendPartRequest(AsyncResponseTransformer { + inFlightRequests.remove(partIndex); + inFlightPermits.release(); + if (error != null || isCompletedExceptionally.get()) { - inFlightPermits.release(); handlePartError(error, partIndex); return; } Optional validationError = validatePartResponse(res, partIndex); if (validationError.isPresent()) { - inFlightPermits.release(); handlePartError(validationError.get(), partIndex); return; } log.debug(() -> "Completed part: " + partIndex); - inFlightRequests.remove(partIndex); - inFlightPermits.release(); int totalComplete = completedParts.incrementAndGet(); if (totalComplete == totalParts) { @@ -250,11 +252,14 @@ private void sendPartRequest(AsyncResponseTransformer> pendingPart = pendingTransformers.poll();