diff --git a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/ParallelPresignedUrlMultipartDownloaderSubscriber.java b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/ParallelPresignedUrlMultipartDownloaderSubscriber.java index 8a7c898b361..bf1e9ddaf6a 100644 --- a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/ParallelPresignedUrlMultipartDownloaderSubscriber.java +++ b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/ParallelPresignedUrlMultipartDownloaderSubscriber.java @@ -21,6 +21,7 @@ import java.util.concurrent.CompletableFuture; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.Semaphore; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import org.reactivestreams.Subscriber; @@ -64,7 +65,7 @@ public class ParallelPresignedUrlMultipartDownloaderSubscriber private final AtomicInteger partNumber = new AtomicInteger(0); private final AtomicInteger completedParts = new AtomicInteger(0); - private final AtomicInteger inFlightRequestsNum = new AtomicInteger(0); + private final Semaphore inFlightPermits; private final AtomicBoolean isCompletedExceptionally = new AtomicBoolean(false); private final AtomicBoolean processingPending = new AtomicBoolean(false); private final Map> inFlightRequests = new ConcurrentHashMap<>(); @@ -90,6 +91,7 @@ public ParallelPresignedUrlMultipartDownloaderSubscriber( this.configuredPartSizeInBytes = configuredPartSizeInBytes; this.resultFuture = resultFuture; this.maxInFlightParts = maxInFlightParts; + this.inFlightPermits = new Semaphore(maxInFlightParts); } @Override @@ -128,21 +130,25 @@ private void sendFirstRequest(AsyncResponseTransformer "Sending first range request with range=" + partRequest.range()); + if (!inFlightPermits.tryAcquire()) { + throw new IllegalStateException("Failed to acquire permit for first request"); + } + CompletableFuture response = s3AsyncClient.presignedUrlExtension().getObject(partRequest, transformer); inFlightRequests.put(0, response); - inFlightRequestsNum.incrementAndGet(); CompletableFutureUtils.forwardExceptionTo(resultFuture, response); response.whenComplete((res, error) -> { + inFlightRequests.remove(0); + inFlightPermits.release(); + if (error != null || isCompletedExceptionally.get()) { handlePartError(error, 0); return; } - inFlightRequests.remove(0); - inFlightRequestsNum.decrementAndGet(); completedParts.incrementAndGet(); this.eTag = res.eTag(); @@ -188,7 +194,7 @@ private void processRequest(AsyncResponseTransformer= maxInFlightParts) { + if (!inFlightPermits.tryAcquire()) { pendingTransformers.offer(Pair.of(currentPart, transformer)); return; } @@ -200,6 +206,7 @@ private void processRequest(AsyncResponseTransformer transformer, int partIndex) { if (isCompletedExceptionally.get()) { + inFlightPermits.release(); return; } @@ -210,10 +217,12 @@ private void sendPartRequest(AsyncResponseTransformer { + inFlightRequests.remove(partIndex); + inFlightPermits.release(); + if (error != null || isCompletedExceptionally.get()) { handlePartError(error, partIndex); return; @@ -226,8 +235,6 @@ private void sendPartRequest(AsyncResponseTransformer "Completed part: " + partIndex); - inFlightRequests.remove(partIndex); - inFlightRequestsNum.decrementAndGet(); int totalComplete = completedParts.incrementAndGet(); if (totalComplete == totalParts) { @@ -245,22 +252,27 @@ private void sendPartRequest(AsyncResponseTransformer> pendingPart = pendingTransformers.poll(); if (pendingPart != null && pendingPart.left() < totalParts) { sendPartRequest(pendingPart.right(), pendingPart.left()); + } else { + inFlightPermits.release(); } } } finally { processingPending.set(false); } - } while (!pendingTransformers.isEmpty() && inFlightRequestsNum.get() < maxInFlightParts); + } while (!pendingTransformers.isEmpty() && inFlightPermits.availablePermits() > 0); } private Optional validatePartResponse(GetObjectResponse response, int partIndex) { @@ -336,4 +348,4 @@ public void onError(Throwable t) { public void onComplete() { // Completion is handled by resultFuture } -} +} \ No newline at end of file diff --git a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/ParallelPresignedUrlMultipartDownloaderSubscriberTest.java b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/ParallelPresignedUrlMultipartDownloaderSubscriberTest.java index 0062029e1b8..f24b2b3a2d1 100644 --- a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/ParallelPresignedUrlMultipartDownloaderSubscriberTest.java +++ b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/ParallelPresignedUrlMultipartDownloaderSubscriberTest.java @@ -16,6 +16,7 @@ package software.amazon.awssdk.services.s3.internal.multipart; import static com.github.tomakehurst.wiremock.client.WireMock.aResponse; +import static com.github.tomakehurst.wiremock.client.WireMock.findAll; import static com.github.tomakehurst.wiremock.client.WireMock.get; import static com.github.tomakehurst.wiremock.client.WireMock.getRequestedFor; import static com.github.tomakehurst.wiremock.client.WireMock.matching; @@ -27,6 +28,7 @@ import com.github.tomakehurst.wiremock.junit5.WireMockRuntimeInfo; import com.github.tomakehurst.wiremock.junit5.WireMockTest; +import com.github.tomakehurst.wiremock.verification.LoggedRequest; import java.io.IOException; import java.net.MalformedURLException; import java.net.URI; @@ -35,6 +37,7 @@ import java.nio.file.Files; import java.nio.file.Path; import java.util.Arrays; +import java.util.List; import java.util.UUID; import java.util.concurrent.CompletionException; import org.junit.jupiter.api.AfterEach; @@ -234,6 +237,45 @@ void onNext_withNullTransformer_shouldThrowNPE() { .isInstanceOf(NullPointerException.class); } + @Test + void multiPartDownload_manyParts_shouldCompleteSuccessfully() throws Exception { + // 13 parts to exceed maxInFlightParts (10) + byte[] data = new byte[208]; // 13 × 16 bytes + Arrays.fill(data, (byte) 'X'); + + stubFor(get(urlEqualTo(PRESIGNED_URL_PATH)) + .withHeader("Range", matching("bytes=0-15")) + .willReturn(aResponse().withStatus(206) + .withHeader("Content-Length", "16") + .withHeader("Content-Range", "bytes 0-15/208") + .withHeader("ETag", "\"etag\"") + .withBody(Arrays.copyOfRange(data, 0, 16)))); + + for (int i = 1; i < 13; i++) { + int start = i * 16; + int end = start + 15; + stubFor(get(urlEqualTo(PRESIGNED_URL_PATH)) + .withHeader("Range", matching("bytes=" + start + "-" + end)) + .willReturn(aResponse().withStatus(206) + .withHeader("Content-Length", "16") + .withHeader("Content-Range", "bytes " + start + "-" + end + "/208") + .withHeader("ETag", "\"etag\"") + .withBody(Arrays.copyOfRange(data, start, end + 1)))); + } + + tempFile = createTempFileUnchecked(); + PresignedUrlDownloadRequest request = PresignedUrlDownloadRequest.builder() + .presignedUrl(presignedUrl) + .build(); + + s3AsyncClient.presignedUrlExtension() + .getObject(request, AsyncResponseTransformer.toFile(tempFile)) + .join(); + + assertThat(Files.readAllBytes(tempFile)).isEqualTo(data); + verify(13, getRequestedFor(urlEqualTo(PRESIGNED_URL_PATH))); + } + private static Path createTempFile() throws IOException { Path path = Files.createTempFile("parallel-test-" + UUID.randomUUID(), ".tmp"); Files.deleteIfExists(path);