diff --git a/README.md b/README.md index 27466f91a..0298620f2 100644 --- a/README.md +++ b/README.md @@ -412,6 +412,23 @@ Client client = Client .build(); ``` +To customize the default JDK HTTP client without replacing the SDK implementation, provide +your own `java.net.http.HttpClient` to `JdkA2AHttpClient`: + +```java +HttpClient jdkHttpClient = HttpClient.newBuilder() + .connectTimeout(Duration.ofSeconds(5)) + .followRedirects(HttpClient.Redirect.NORMAL) + .version(HttpClient.Version.HTTP_2) + .build(); + +Client client = Client + .builder(agentCard) + .withTransport(JSONRPCTransport.class, new JSONRPCTransportConfig( + new JdkA2AHttpClient(jdkHttpClient))) + .build(); +``` + ##### gRPC Transport Configuration For the gRPC transport, you must configure a channel factory: diff --git a/http-client/src/main/java/io/a2a/client/http/JdkA2AHttpClient.java b/http-client/src/main/java/io/a2a/client/http/JdkA2AHttpClient.java index c04596360..e41f36a1c 100644 --- a/http-client/src/main/java/io/a2a/client/http/JdkA2AHttpClient.java +++ b/http-client/src/main/java/io/a2a/client/http/JdkA2AHttpClient.java @@ -1,5 +1,6 @@ package io.a2a.client.http; +import static io.a2a.util.Assert.checkNotNullParam; import static java.net.HttpURLConnection.HTTP_FORBIDDEN; import static java.net.HttpURLConnection.HTTP_MULT_CHOICE; import static java.net.HttpURLConnection.HTTP_OK; @@ -61,10 +62,20 @@ public class JdkA2AHttpClient implements A2AHttpClient { * */ public JdkA2AHttpClient() { - httpClient = HttpClient.newBuilder() + this(HttpClient.newBuilder() .version(HttpClient.Version.HTTP_2) .followRedirects(HttpClient.Redirect.NORMAL) - .build(); + .build()); + } + + /** + * Creates a new JDK-based HTTP client using a caller-provided JDK {@link HttpClient}. + * + * @param httpClient the JDK HTTP client to delegate requests to + * @throws IllegalArgumentException if {@code httpClient} is {@code null} + */ + public JdkA2AHttpClient(HttpClient httpClient) { + this.httpClient = checkNotNullParam("httpClient", httpClient); } @Override diff --git a/http-client/src/test/java/io/a2a/client/http/JdkA2AHttpClientTest.java b/http-client/src/test/java/io/a2a/client/http/JdkA2AHttpClientTest.java new file mode 100644 index 000000000..48e10a86c --- /dev/null +++ b/http-client/src/test/java/io/a2a/client/http/JdkA2AHttpClientTest.java @@ -0,0 +1,91 @@ +package io.a2a.client.http; + +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockserver.integration.ClientAndServer; + +import java.io.IOException; +import java.net.Proxy; +import java.net.ProxySelector; +import java.net.SocketAddress; +import java.net.URI; +import java.net.http.HttpClient; +import java.util.List; +import java.util.concurrent.atomic.AtomicInteger; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.mockserver.model.HttpRequest.request; +import static org.mockserver.model.HttpResponse.response; + +public class JdkA2AHttpClientTest { + + private ClientAndServer server; + + @AfterEach + public void tearDown() { + if (server != null) { + server.stop(); + } + } + + @Test + public void testDefaultConstructorCreatesUsableClient() throws Exception { + server = ClientAndServer.startClientAndServer(0); + server.when(request().withMethod("GET").withPath("/default")) + .respond(response().withStatusCode(200).withBody("ok")); + + JdkA2AHttpClient client = new JdkA2AHttpClient(); + + A2AHttpResponse response = client.createGet() + .url("http://localhost:" + server.getLocalPort() + "/default") + .get(); + + assertEquals(200, response.status()); + assertEquals("ok", response.body()); + } + + @Test + public void testConstructorUsesProvidedHttpClient() throws Exception { + server = ClientAndServer.startClientAndServer(0); + server.when(request().withMethod("GET").withPath("/custom")) + .respond(response().withStatusCode(200).withBody("ok")); + + TrackingProxySelector proxySelector = new TrackingProxySelector(); + HttpClient providedClient = HttpClient.newBuilder() + .proxy(proxySelector) + .build(); + + JdkA2AHttpClient client = new JdkA2AHttpClient(providedClient); + + A2AHttpResponse response = client.createGet() + .url("http://localhost:" + server.getLocalPort() + "/custom") + .get(); + + assertEquals(200, response.status()); + assertEquals("ok", response.body()); + assertEquals(1, proxySelector.selectCount.get(), + "Provided HttpClient should be used for request execution"); + } + + @Test + public void testConstructorRejectsNullHttpClient() { + assertThrows(IllegalArgumentException.class, () -> new JdkA2AHttpClient(null), "foo"); + } + + private static final class TrackingProxySelector extends ProxySelector { + private final AtomicInteger selectCount = new AtomicInteger(); + + @Override + public List select(URI uri) { + selectCount.incrementAndGet(); + return List.of(Proxy.NO_PROXY); + } + + @Override + public void connectFailed(URI uri, SocketAddress sa, IOException ioe) { + throw new AssertionError("Proxy connection should not fail in this test", ioe); + } + } +}