diff --git a/core/src/main/scala/sttp/client4/compression/Compressor.scala b/core/src/main/scala/sttp/client4/compression/Compressor.scala index a5e5b59c87..6fcc9991dd 100644 --- a/core/src/main/scala/sttp/client4/compression/Compressor.scala +++ b/core/src/main/scala/sttp/client4/compression/Compressor.scala @@ -1,7 +1,6 @@ package sttp.client4.compression import sttp.client4._ -import java.nio.ByteBuffer /** Allows compressing bodies, using the supported encoding. The compressed bodies might use `R` capabilities (e.g. * streaming). @@ -53,12 +52,4 @@ object Compressor extends CompressorExtensions { private[compression] def streamsNotSupported: Nothing = throw new IllegalArgumentException("Streams are not supported") - private[compression] def byteBufferToArray(inputBuffer: ByteBuffer): Array[Byte] = - if (inputBuffer.hasArray()) { - inputBuffer.array() - } else { - val inputBytes = new Array[Byte](inputBuffer.remaining()) - inputBuffer.get(inputBytes) - inputBytes - } } diff --git a/core/src/main/scala/sttp/client4/internal/DigestAuthenticator.scala b/core/src/main/scala/sttp/client4/internal/DigestAuthenticator.scala index 96d19bb164..07540fd653 100644 --- a/core/src/main/scala/sttp/client4/internal/DigestAuthenticator.scala +++ b/core/src/main/scala/sttp/client4/internal/DigestAuthenticator.scala @@ -165,7 +165,7 @@ private[client4] class DigestAuthenticator private ( brb match { case StringBody(s, e, _) => s.getBytes(Charset.forName(e)) case ByteArrayBody(b, _) => b - case ByteBufferBody(b, _) => b.array() + case ByteBufferBody(b, _) => byteBufferToArray(b) case InputStreamBody(b, _) => toByteArray(b) case _: FileBody => throw new IllegalStateException("Qop auth-int cannot be used with a file body") } diff --git a/core/src/main/scala/sttp/client4/internal/package.scala b/core/src/main/scala/sttp/client4/internal/package.scala index b00751cc53..3f16ac23a2 100644 --- a/core/src/main/scala/sttp/client4/internal/package.scala +++ b/core/src/main/scala/sttp/client4/internal/package.scala @@ -40,13 +40,27 @@ package object internal { private[client4] def emptyInputStream(): InputStream = new ByteArrayInputStream(Array[Byte]()) + /** Returns the bytes between the buffer's current `position` and `limit` as an `Array[Byte]`. Safe for direct, + * read-only, partially-consumed, and sliced buffers. The source buffer's `position` and `limit` are preserved. For + * a fresh, full heap buffer the backing array is returned directly (no copy); otherwise a new array is allocated. + */ + private[client4] def byteBufferToArray(b: ByteBuffer): Array[Byte] = + if (b.hasArray && b.arrayOffset() == 0 && b.position() == 0 && b.remaining() == b.array().length) + b.array() + else { + val out = new Array[Byte](b.remaining()) + b.duplicate().get(out) + out + } + private[client4] def enqueueBytes( queue: Queue[Array[Byte]], bytes: ByteBuffer - ): Queue[Array[Byte]] = queue.enqueue(bytes.array()) + ): Queue[Array[Byte]] = queue.enqueue(byteBufferToArray(bytes)) private[client4] def concatBytes(queue: Queue[Array[Byte]]): Array[Byte] = { val size = queue.map(_.length).sum + // Heap buffer of exact size, fully filled and never read partially: array() is safe here. val bytes = ByteBuffer.allocate(size) queue.foreach(bytes.put) bytes.array() @@ -81,6 +95,9 @@ package object internal { private[client4] val IOBufferSize = 1024 implicit class RichByteBuffer(byteBuffer: ByteBuffer) { + /** Reads the buffer's remaining bytes into a fresh array, advancing the buffer's `position` to `limit`. Use + * [[byteBufferToArray]] instead when the buffer's position should be preserved. + */ def safeRead(): Array[Byte] = { val array = new Array[Byte](byteBuffer.remaining()) byteBuffer.get(array) diff --git a/core/src/main/scala/sttp/client4/testing/package.scala b/core/src/main/scala/sttp/client4/testing/package.scala index 08a6d2f9db..0595178b8e 100644 --- a/core/src/main/scala/sttp/client4/testing/package.scala +++ b/core/src/main/scala/sttp/client4/testing/package.scala @@ -1,6 +1,6 @@ package sttp.client4 -import sttp.client4.internal.toByteArray +import sttp.client4.internal.{byteBufferToArray, toByteArray} package object testing { implicit class RichTestingRequest[T, R](r: GenericRequest[T, R]) { @@ -13,7 +13,7 @@ package object testing { case NoBody => "" case StringBody(s, _, _) => s case ByteArrayBody(b, _) => new String(b) - case ByteBufferBody(b, _) => new String(b.array()) + case ByteBufferBody(b, _) => new String(byteBufferToArray(b)) case InputStreamBody(b, _) => new String(toByteArray(b)) case FileBody(f, _) => f.readAsString() case StreamBody(_) => @@ -30,7 +30,7 @@ package object testing { case NoBody => Array.emptyByteArray case StringBody(s, encoding, _) => s.getBytes(encoding) case ByteArrayBody(b, _) => b - case ByteBufferBody(b, _) => b.array() + case ByteBufferBody(b, _) => byteBufferToArray(b) case InputStreamBody(b, _) => toByteArray(b) case FileBody(f, _) => f.readAsByteArray() case StreamBody(_) => diff --git a/core/src/main/scalajs/sttp/client4/fetch/AbstractFetchBackend.scala b/core/src/main/scalajs/sttp/client4/fetch/AbstractFetchBackend.scala index b4fe9c2465..37720a111b 100644 --- a/core/src/main/scalajs/sttp/client4/fetch/AbstractFetchBackend.scala +++ b/core/src/main/scalajs/sttp/client4/fetch/AbstractFetchBackend.scala @@ -244,13 +244,6 @@ abstract class AbstractFetchBackend[F[_], S <: Streams[S]]( f.toDomFile } - // https://stackoverflow.com/questions/679298/gets-byte-array-from-a-bytebuffer-in-java - private def byteBufferToArray(bb: ByteBuffer): Array[Byte] = { - val b = new Array[Byte](bb.remaining()) - bb.get(b) - b - } - private def sendWebSocket[T](request: GenericRequest[T, R]): F[Response[T]] = { val queue = new JSSimpleQueue[F, WebSocketEvent] val ws = new JSWebSocket(request.uri.toString) diff --git a/core/src/main/scalajvm/sttp/client4/compression/defaultCompressors.scala b/core/src/main/scalajvm/sttp/client4/compression/defaultCompressors.scala index c9b9d351fa..371dd0ff28 100644 --- a/core/src/main/scalajvm/sttp/client4/compression/defaultCompressors.scala +++ b/core/src/main/scalajvm/sttp/client4/compression/defaultCompressors.scala @@ -4,6 +4,7 @@ import sttp.client4._ import sttp.model.Encodings import Compressor._ +import sttp.client4.internal.byteBufferToArray import java.util.zip.Deflater import java.util.zip.DeflaterInputStream import java.io.ByteArrayOutputStream diff --git a/core/src/main/scalajvm/sttp/client4/internal/httpclient/BodyToHttpClient.scala b/core/src/main/scalajvm/sttp/client4/internal/httpclient/BodyToHttpClient.scala index 69ea75d57b..c4405f8e12 100644 --- a/core/src/main/scalajvm/sttp/client4/internal/httpclient/BodyToHttpClient.scala +++ b/core/src/main/scalajvm/sttp/client4/internal/httpclient/BodyToHttpClient.scala @@ -5,6 +5,7 @@ import sttp.client4._ import sttp.client4.compression.Compressor import sttp.client4.httpclient.BodyProgressCallback import sttp.client4.internal.SttpToJavaConverters.toJavaSupplier +import sttp.client4.internal.byteBufferToArray import sttp.model.HeaderNames import sttp.monad.MonadError import sttp.monad.syntax._ @@ -31,8 +32,7 @@ private[client4] trait BodyToHttpClient[F[_], S, R] { case StringBody(b, _, _) => BodyPublishers.ofString(b).unit case ByteArrayBody(b, _) => BodyPublishers.ofByteArray(b).unit case ByteBufferBody(b, _) => - if (b.hasArray) BodyPublishers.ofByteArray(b.array(), 0, b.limit()).unit - else { val a = new Array[Byte](b.remaining()); b.get(a); BodyPublishers.ofByteArray(a).unit } + BodyPublishers.ofByteArray(byteBufferToArray(b)).unit case InputStreamBody(b, _) => BodyPublishers.ofInputStream(toJavaSupplier(() => b)).unit case FileBody(f, _) => BodyPublishers.ofFile(f.toFile.toPath).unit case StreamBody(s) => streamToPublisher(s.asInstanceOf[streams.BinaryStream]) diff --git a/core/src/main/scalajvm/sttp/client4/internal/httpclient/MultipartBodyBuilder.scala b/core/src/main/scalajvm/sttp/client4/internal/httpclient/MultipartBodyBuilder.scala index b35540dae8..8cdfdbf321 100644 --- a/core/src/main/scalajvm/sttp/client4/internal/httpclient/MultipartBodyBuilder.scala +++ b/core/src/main/scalajvm/sttp/client4/internal/httpclient/MultipartBodyBuilder.scala @@ -5,6 +5,7 @@ import sttp.client4.internal.httpclient import sttp.model.Part import sttp.model.Header import sttp.model.HeaderNames +import sttp.client4.internal.byteBufferToArray import sttp.client4.internal.throwNestedMultipartNotAllowed import sttp.client4.internal.Utf8 import sttp.monad.MonadError @@ -14,7 +15,6 @@ import java.net.http.HttpRequest import java.util.function.Supplier import java.io.File import java.io.InputStream -import java.nio.Buffer import java.nio.ByteBuffer import java.io.ByteArrayInputStream import java.util.UUID @@ -44,10 +44,7 @@ trait NonStreamMultipartBodyBuilder[BinaryStream, F[_]] extends MultipartBodyBui case ByteArrayBody(b, _) => multipartBuilder.addPart(p.name, supplier(new ByteArrayInputStream(b)), partHeaders) case ByteBufferBody(b, _) => - if ((b: Buffer).isReadOnly) - multipartBuilder.addPart(p.name, supplier(new ByteBufferBackedInputStream(b)), partHeaders) - else - multipartBuilder.addPart(p.name, supplier(new ByteArrayInputStream(b.array())), partHeaders) + multipartBuilder.addPart(p.name, supplier(new ByteArrayInputStream(byteBufferToArray(b))), partHeaders) case InputStreamBody(b, _) => multipartBuilder.addPart(p.name, supplier(b), partHeaders) case StreamBody(_) => throw new IllegalArgumentException("Multipart streaming bodies are not supported with this backend") @@ -96,11 +93,10 @@ trait StreamMultipartBodyBuilder[BinaryStream, F[_]] extends MultipartBodyBuilde case ByteArrayBody(b, _) => concatBytesToStream(accumulatedStream, encodeBytes(b, partHeaders, boundary)) case ByteBufferBody(b, _) => - if ((b: Buffer).isReadOnly) { - val buffer = new ByteBufferBackedInputStream(b) - concatStreams(accumulatedStream, encodeStream(inputStreamToStream(buffer), partHeaders, boundary)) - } else - concatBytesToStream(accumulatedStream, encodeBytes(b.array(), partHeaders, boundary)) + concatStreams( + accumulatedStream, + encodeStream(inputStreamToStream(new ByteBufferBackedInputStream(b.duplicate())), partHeaders, boundary) + ) case InputStreamBody(b, _) => concatStreams(accumulatedStream, encodeStream(inputStreamToStream(b), partHeaders, boundary)) case StreamBody(s) => @@ -144,7 +140,6 @@ trait StreamMultipartBodyBuilder[BinaryStream, F[_]] extends MultipartBodyBuilde } } -// https://stackoverflow.com/a/6603018/362531 private[httpclient] class ByteBufferBackedInputStream(buf: ByteBuffer) extends InputStream { override def read: Int = { if (!buf.hasRemaining) return -1 diff --git a/core/src/test/scala/sttp/client4/internal/ByteBufferToArrayTest.scala b/core/src/test/scala/sttp/client4/internal/ByteBufferToArrayTest.scala new file mode 100644 index 0000000000..43b7d9f0ca --- /dev/null +++ b/core/src/test/scala/sttp/client4/internal/ByteBufferToArrayTest.scala @@ -0,0 +1,59 @@ +package sttp.client4.internal + +import org.scalatest.flatspec.AnyFlatSpec +import org.scalatest.matchers.should.Matchers + +import java.nio.ByteBuffer + +class ByteBufferToArrayTest extends AnyFlatSpec with Matchers { + + it should "extract bytes from a direct buffer" in { + val direct = ByteBuffer.allocateDirect(5) + direct.put("ABCDE".getBytes) + direct.flip() + byteBufferToArray(direct) shouldBe "ABCDE".getBytes + } + + it should "extract bytes from a read-only buffer" in { + val readOnly = ByteBuffer.wrap("ABCDE".getBytes).asReadOnlyBuffer() + byteBufferToArray(readOnly) shouldBe "ABCDE".getBytes + } + + it should "return only the remaining slice when position > 0 and limit < capacity" in { + val partial = ByteBuffer.wrap("ABCDE".getBytes) + partial.position(2) + partial.limit(4) + byteBufferToArray(partial) shouldBe "CD".getBytes + } + + it should "respect arrayOffset for sliced heap buffers" in { + val original = ByteBuffer.wrap("ABCDE".getBytes) + original.position(2) + val sliced = original.slice() + sliced.arrayOffset() should be > 0 + byteBufferToArray(sliced) shouldBe "CDE".getBytes + } + + it should "return the backing array directly for a fresh full heap buffer" in { + val data = "ABCDE".getBytes + val full = ByteBuffer.wrap(data) + byteBufferToArray(full) should be theSameInstanceAs data + } + + it should "return a fresh array that does not alias storage for a partial buffer" in { + val partial = ByteBuffer.wrap("ABCDE".getBytes) + partial.position(1) + val out = byteBufferToArray(partial) + out(0) = 'Z'.toByte + partial.get(1) shouldBe 'B'.toByte + } + + it should "not mutate the source buffer's position or limit" in { + val partial = ByteBuffer.wrap("ABCDE".getBytes) + partial.position(1) + partial.limit(4) + val _ = byteBufferToArray(partial) + partial.position() shouldBe 1 + partial.limit() shouldBe 4 + } +} diff --git a/core/src/test/scala/sttp/client4/internal/DigestAuthenticatorTest.scala b/core/src/test/scala/sttp/client4/internal/DigestAuthenticatorTest.scala index aac521f212..33175f4553 100644 --- a/core/src/test/scala/sttp/client4/internal/DigestAuthenticatorTest.scala +++ b/core/src/test/scala/sttp/client4/internal/DigestAuthenticatorTest.scala @@ -5,6 +5,7 @@ import sttp.client4.internal.DigestAuthenticator.DigestAuthData import sttp.client4._ import sttp.model.{Header, HeaderNames, StatusCode} +import java.nio.ByteBuffer import scala.util.{Failure, Try} import org.scalatest.freespec.AnyFreeSpec import org.scalatest.matchers.should.Matchers @@ -155,6 +156,40 @@ class DigestAuthenticatorTest extends AnyFreeSpec with Matchers with OptionValue header.value.value should fullyMatch regex """Digest username="admin", realm="myrealm", uri="/", nonce="BBBBBB", qop=auth, response="[0-9a-f]+", cnonce="[0-9a-f]+", nc=000000\d\d, algorithm=MD5""" } + "auth-int" - { + val payload = "hello".getBytes("UTF-8") + val fixedCnonce = () => "0123456789abcdef" + val authIntChallenge = """Digest realm="myrealm", nonce="BBBBBB", algorithm=MD5, qop="auth-int"""" + + def authHeaderValueFor(setBody: Request[Either[String, String]] => Request[Either[String, String]]): String = { + val baseRequest = basicRequest.get(uri"http://google.com/").auth.digest("admin", "password") + val r = responseWithAuthenticateHeader(authIntChallenge) + DigestAuthenticator(DigestAuthData("admin", "password"), fixedCnonce) + .authenticate(setBody(baseRequest), r) + .value + .value + } + + "compute the digest for a direct ByteBufferBody (previously threw UnsupportedOperationException)" in { + val direct = ByteBuffer.allocateDirect(payload.length) + direct.put(payload) + direct.flip() + val viaDirectBuffer = authHeaderValueFor(_.body(direct)) + val viaArray = authHeaderValueFor(_.body(payload)) + viaDirectBuffer shouldBe viaArray + } + + "compute the digest over the partial buffer's remaining slice, not the full backing array" in { + val backing = ("XX".getBytes ++ payload ++ "YY".getBytes) + val partial = ByteBuffer.wrap(backing) + partial.position(2) + partial.limit(2 + payload.length) + val viaPartialBuffer = authHeaderValueFor(_.body(partial)) + val viaArray = authHeaderValueFor(_.body(payload)) + viaPartialBuffer shouldBe viaArray + } + } + "proxy" - { "should work" in { val request = basicRequest diff --git a/core/src/test/scala/sttp/client4/internal/EnqueueBytesTest.scala b/core/src/test/scala/sttp/client4/internal/EnqueueBytesTest.scala new file mode 100644 index 0000000000..506a6574b2 --- /dev/null +++ b/core/src/test/scala/sttp/client4/internal/EnqueueBytesTest.scala @@ -0,0 +1,26 @@ +package sttp.client4.internal + +import org.scalatest.flatspec.AnyFlatSpec +import org.scalatest.matchers.should.Matchers + +import java.nio.ByteBuffer +import scala.collection.immutable.Queue + +class EnqueueBytesTest extends AnyFlatSpec with Matchers { + + it should "append the bytes of a direct buffer" in { + val direct = ByteBuffer.allocateDirect(5) + direct.put("hello".getBytes) + direct.flip() + val result = enqueueBytes(Queue.empty[Array[Byte]], direct) + result.head shouldBe "hello".getBytes + } + + it should "append only the remaining slice of a partial heap buffer" in { + val partial = ByteBuffer.wrap("ABCDE".getBytes) + partial.position(2) + partial.limit(4) + val result = enqueueBytes(Queue.empty[Array[Byte]], partial) + result.head shouldBe "CD".getBytes + } +} diff --git a/core/src/test/scala/sttp/client4/testing/ForceBodyAsTest.scala b/core/src/test/scala/sttp/client4/testing/ForceBodyAsTest.scala new file mode 100644 index 0000000000..8e1a1cde00 --- /dev/null +++ b/core/src/test/scala/sttp/client4/testing/ForceBodyAsTest.scala @@ -0,0 +1,46 @@ +package sttp.client4.testing + +import org.scalatest.flatspec.AnyFlatSpec +import org.scalatest.matchers.should.Matchers +import sttp.client4._ + +import java.nio.ByteBuffer + +class ForceBodyAsTest extends AnyFlatSpec with Matchers { + + private val payload = "hello".getBytes("UTF-8") + + it should "forceBodyAsString: extract a direct ByteBuffer body" in { + val direct = ByteBuffer.allocateDirect(payload.length) + direct.put(payload) + direct.flip() + val request = basicRequest.post(uri"http://example.com/").body(direct) + request.forceBodyAsString shouldBe "hello" + } + + it should "forceBodyAsByteArray: extract a direct ByteBuffer body" in { + val direct = ByteBuffer.allocateDirect(payload.length) + direct.put(payload) + direct.flip() + val request = basicRequest.post(uri"http://example.com/").body(direct) + request.forceBodyAsByteArray shouldBe payload + } + + it should "forceBodyAsString: return only the remaining slice of a partial heap buffer" in { + val backing = "XX".getBytes ++ payload ++ "YY".getBytes + val partial = ByteBuffer.wrap(backing) + partial.position(2) + partial.limit(2 + payload.length) + val request = basicRequest.post(uri"http://example.com/").body(partial) + request.forceBodyAsString shouldBe "hello" + } + + it should "forceBodyAsByteArray: return only the remaining slice of a partial heap buffer" in { + val backing = "XX".getBytes ++ payload ++ "YY".getBytes + val partial = ByteBuffer.wrap(backing) + partial.position(2) + partial.limit(2 + payload.length) + val request = basicRequest.post(uri"http://example.com/").body(partial) + request.forceBodyAsByteArray shouldBe payload + } +} diff --git a/finagle-backend/src/main/scala/sttp/client4/finagle/FinagleBackend.scala b/finagle-backend/src/main/scala/sttp/client4/finagle/FinagleBackend.scala index 8a031ab1f6..0fb54fad6e 100644 --- a/finagle-backend/src/main/scala/sttp/client4/finagle/FinagleBackend.scala +++ b/finagle-backend/src/main/scala/sttp/client4/finagle/FinagleBackend.scala @@ -15,7 +15,7 @@ import com.twitter.io.Buf.{ByteArray, ByteBuffer} import com.twitter.util import com.twitter.util.{Duration, Future => TFuture} import sttp.capabilities.Effect -import sttp.client4.internal.{BodyFromResponseAs, FileHelpers, SttpFile, Utf8} +import sttp.client4.internal.{BodyFromResponseAs, FileHelpers, SttpFile, Utf8, byteBufferToArray} import sttp.client4.testing.BackendStub import sttp.client4.ws.{GotAWebSocketException, NotAWebSocketException} import sttp.client4.{wrappers, _} @@ -126,7 +126,7 @@ class FinagleBackend(client: Option[Client] = None) extends Backend[TFuture] { case StringBody(s, e, _) if e.equalsIgnoreCase(Utf8) => s case StringBody(s, e, _) => Source.fromBytes(s.getBytes(e)).mkString case ByteArrayBody(b, _) => Source.fromBytes(b).mkString - case ByteBufferBody(b, _) => Source.fromBytes(b.array()).mkString + case ByteBufferBody(b, _) => Source.fromBytes(byteBufferToArray(b)).mkString case InputStreamBody(is, _) => Source.fromInputStream(is).mkString case FileBody(f, _) => Source.fromFile(f.toFile).mkString case StreamBody(_) => throw new IllegalArgumentException("Streaming is not supported")