diff --git a/README.md b/README.md index 52bdea52..2316b226 100644 --- a/README.md +++ b/README.md @@ -349,13 +349,13 @@ Different transport protocols can be configured with specific settings using spe ##### JSON-RPC Transport Configuration -For the JSON-RPC transport, to use the default `JdkA2AHttpClient`, provide a `JSONRPCTransportConfig` created with its default constructor. +For the JSON-RPC transport, to use the default `JdkHttpClient`, provide a `JSONRPCTransportConfig` created with its default constructor. To use a custom HTTP client implementation, simply create a `JSONRPCTransportConfig` as follows: ```java -// Create a custom HTTP client -A2AHttpClient customHttpClient = ... +// Create a custom HTTP client builder +HttpClientBuilder httpClientBuilder = ... // Configure the client settings ClientConfig clientConfig = new ClientConfig.Builder() @@ -365,7 +365,7 @@ ClientConfig clientConfig = new ClientConfig.Builder() Client client = Client .builder(agentCard) .clientConfig(clientConfig) - .withTransport(JSONRPCTransport.class, new JSONRPCTransportConfig(customHttpClient)) + .withTransport(JSONRPCTransport.class, new JSONRPCTransportConfig(httpClientBuilder)) .build(); ``` @@ -396,13 +396,13 @@ Client client = Client ##### HTTP+JSON/REST Transport Configuration -For the HTTP+JSON/REST transport, if you'd like to use the default `JdkA2AHttpClient`, provide a `RestTransportConfig` created with its default constructor. +For the HTTP+JSON/REST transport, if you'd like to use the default `JdkHttpClient`, provide a `RestTransportConfig` created with its default constructor. To use a custom HTTP client implementation, simply create a `RestTransportConfig` as follows: ```java // Create a custom HTTP client -A2AHttpClient customHttpClient = ... +HttpClientBuilder httpClientBuilder = ... // Configure the client settings ClientConfig clientConfig = new ClientConfig.Builder() @@ -412,7 +412,7 @@ ClientConfig clientConfig = new ClientConfig.Builder() Client client = Client .builder(agentCard) .clientConfig(clientConfig) - .withTransport(RestTransport.class, new RestTransportConfig(customHttpClient)) + .withTransport(RestTransport.class, new RestTransportConfig(httpClientBuilder)) .build(); ``` diff --git a/client/base/src/main/java/io/a2a/A2A.java b/client/base/src/main/java/io/a2a/A2A.java index 063527c2..158daac1 100644 --- a/client/base/src/main/java/io/a2a/A2A.java +++ b/client/base/src/main/java/io/a2a/A2A.java @@ -3,11 +3,9 @@ import java.util.Collections; import java.util.List; import java.util.Map; -import java.util.UUID; import io.a2a.client.http.A2ACardResolver; -import io.a2a.client.http.A2AHttpClient; -import io.a2a.client.http.JdkA2AHttpClient; +import io.a2a.client.http.HttpClient; import io.a2a.spec.A2AClientError; import io.a2a.spec.A2AClientJSONError; import io.a2a.spec.AgentCard; @@ -139,20 +137,7 @@ private static Message toMessage(List> parts, Message.Role role, String * @throws A2AClientJSONError f the response body cannot be decoded as JSON or validated against the AgentCard schema */ public static AgentCard getAgentCard(String agentUrl) throws A2AClientError, A2AClientJSONError { - return getAgentCard(new JdkA2AHttpClient(), agentUrl); - } - - /** - * Get the agent card for an A2A agent. - * - * @param httpClient the http client to use - * @param agentUrl the base URL for the agent whose agent card we want to retrieve - * @return the agent card - * @throws A2AClientError If an HTTP error occurs fetching the card - * @throws A2AClientJSONError f the response body cannot be decoded as JSON or validated against the AgentCard schema - */ - public static AgentCard getAgentCard(A2AHttpClient httpClient, String agentUrl) throws A2AClientError, A2AClientJSONError { - return getAgentCard(httpClient, agentUrl, null, null); + return getAgentCard(HttpClient.createHttpClient(agentUrl), null, null); } /** @@ -160,30 +145,29 @@ public static AgentCard getAgentCard(A2AHttpClient httpClient, String agentUrl) * * @param agentUrl the base URL for the agent whose agent card we want to retrieve * @param relativeCardPath optional path to the agent card endpoint relative to the base - * agent URL, defaults to ".well-known/agent-card.json" + * agent URL, defaults to "/.well-known/agent-card.json" * @param authHeaders the HTTP authentication headers to use * @return the agent card * @throws A2AClientError If an HTTP error occurs fetching the card * @throws A2AClientJSONError f the response body cannot be decoded as JSON or validated against the AgentCard schema */ public static AgentCard getAgentCard(String agentUrl, String relativeCardPath, Map authHeaders) throws A2AClientError, A2AClientJSONError { - return getAgentCard(new JdkA2AHttpClient(), agentUrl, relativeCardPath, authHeaders); + return getAgentCard(HttpClient.createHttpClient(agentUrl), relativeCardPath, authHeaders); } /** * Get the agent card for an A2A agent. * * @param httpClient the http client to use - * @param agentUrl the base URL for the agent whose agent card we want to retrieve * @param relativeCardPath optional path to the agent card endpoint relative to the base - * agent URL, defaults to ".well-known/agent-card.json" + * agent URL, defaults to "/.well-known/agent-card.json" * @param authHeaders the HTTP authentication headers to use * @return the agent card * @throws A2AClientError If an HTTP error occurs fetching the card * @throws A2AClientJSONError f the response body cannot be decoded as JSON or validated against the AgentCard schema */ - public static AgentCard getAgentCard(A2AHttpClient httpClient, String agentUrl, String relativeCardPath, Map authHeaders) throws A2AClientError, A2AClientJSONError { - A2ACardResolver resolver = new A2ACardResolver(httpClient, agentUrl, relativeCardPath, authHeaders); + public static AgentCard getAgentCard(HttpClient httpClient, String relativeCardPath, Map authHeaders) throws A2AClientError, A2AClientJSONError { + A2ACardResolver resolver = new A2ACardResolver(httpClient, relativeCardPath, authHeaders); return resolver.getAgentCard(); } } diff --git a/client/base/src/test/java/io/a2a/client/ClientBuilderTest.java b/client/base/src/test/java/io/a2a/client/ClientBuilderTest.java index 1c7ed38a..b8f849cb 100644 --- a/client/base/src/test/java/io/a2a/client/ClientBuilderTest.java +++ b/client/base/src/test/java/io/a2a/client/ClientBuilderTest.java @@ -1,7 +1,8 @@ package io.a2a.client; import io.a2a.client.config.ClientConfig; -import io.a2a.client.http.JdkA2AHttpClient; +import io.a2a.client.http.HttpClientBuilder; +import io.a2a.client.http.jdk.JdkHttpClientBuilder; import io.a2a.client.transport.grpc.GrpcTransport; import io.a2a.client.transport.grpc.GrpcTransportConfigBuilder; import io.a2a.client.transport.jsonrpc.JSONRPCTransport; @@ -71,13 +72,40 @@ public void shouldNotFindConfigurationTransport() throws A2AClientException { } @Test - public void shouldCreateJSONRPCClient() throws A2AClientException { + public void shouldNotCreateJSONRPCClient_nullHttpClientFactory() throws A2AClientException { + Assertions.assertThrows(IllegalArgumentException.class, + () -> { + Client + .builder(card) + .clientConfig(new ClientConfig.Builder().setUseClientPreference(true).build()) + .withTransport(JSONRPCTransport.class, new JSONRPCTransportConfigBuilder() + .addInterceptor(null) + .httpClientBuilder(null)) + .build(); + }); + } + + @Test + public void shouldCreateJSONRPCClient_defaultHttpClientFactory() throws A2AClientException { + Client client = Client + .builder(card) + .clientConfig(new ClientConfig.Builder().setUseClientPreference(true).build()) + .withTransport(JSONRPCTransport.class, new JSONRPCTransportConfigBuilder() + .addInterceptor(null) + .httpClientBuilder(HttpClientBuilder.DEFAULT_FACTORY)) + .build(); + + Assertions.assertNotNull(client); + } + + @Test + public void shouldCreateJSONRPCClient_withHttpClientFactory() throws A2AClientException { Client client = Client .builder(card) .clientConfig(new ClientConfig.Builder().setUseClientPreference(true).build()) .withTransport(JSONRPCTransport.class, new JSONRPCTransportConfigBuilder() .addInterceptor(null) - .httpClient(null)) + .httpClientBuilder(new JdkHttpClientBuilder())) .build(); Assertions.assertNotNull(client); @@ -88,7 +116,7 @@ public void shouldCreateClient_differentConfigurations() throws A2AClientExcepti Client client = Client .builder(card) .withTransport(JSONRPCTransport.class, new JSONRPCTransportConfigBuilder()) - .withTransport(JSONRPCTransport.class, new JSONRPCTransportConfig(new JdkA2AHttpClient())) + .withTransport(JSONRPCTransport.class, new JSONRPCTransportConfig()) .build(); Assertions.assertNotNull(client); diff --git a/client/transport/grpc/src/main/java/io/a2a/client/transport/grpc/GrpcTransport.java b/client/transport/grpc/src/main/java/io/a2a/client/transport/grpc/GrpcTransport.java index 2023339d..d1943f27 100644 --- a/client/transport/grpc/src/main/java/io/a2a/client/transport/grpc/GrpcTransport.java +++ b/client/transport/grpc/src/main/java/io/a2a/client/transport/grpc/GrpcTransport.java @@ -11,8 +11,8 @@ import java.util.function.Consumer; import java.util.stream.Collectors; +import io.a2a.client.transport.spi.AbstractClientTransport; import io.a2a.client.transport.spi.interceptors.ClientCallContext; -import io.a2a.client.transport.spi.ClientTransport; import io.a2a.client.transport.spi.interceptors.ClientCallInterceptor; import io.a2a.client.transport.spi.interceptors.PayloadAndHeaders; import io.a2a.client.transport.spi.interceptors.auth.AuthInterceptor; @@ -50,7 +50,7 @@ import io.grpc.stub.MetadataUtils; import io.grpc.stub.StreamObserver; -public class GrpcTransport implements ClientTransport { +public class GrpcTransport extends AbstractClientTransport { private static final Metadata.Key AUTHORIZATION_METADATA_KEY = Metadata.Key.of( AuthInterceptor.AUTHORIZATION, @@ -60,7 +60,6 @@ public class GrpcTransport implements ClientTransport { Metadata.ASCII_STRING_MARSHALLER); private final A2AServiceBlockingV2Stub blockingStub; private final A2AServiceStub asyncStub; - private final List interceptors; private AgentCard agentCard; public GrpcTransport(Channel channel, AgentCard agentCard) { @@ -68,11 +67,11 @@ public GrpcTransport(Channel channel, AgentCard agentCard) { } public GrpcTransport(Channel channel, AgentCard agentCard, List interceptors) { + super(interceptors); checkNotNullParam("channel", channel); this.asyncStub = A2AServiceGrpc.newStub(channel); this.blockingStub = A2AServiceGrpc.newBlockingV2Stub(channel); this.agentCard = agentCard; - this.interceptors = interceptors; } @Override @@ -365,17 +364,4 @@ private String getTaskPushNotificationConfigName(String taskId, String pushNotif return name.toString(); } - private PayloadAndHeaders applyInterceptors(String methodName, Object payload, - AgentCard agentCard, ClientCallContext clientCallContext) { - PayloadAndHeaders payloadAndHeaders = new PayloadAndHeaders(payload, - clientCallContext != null ? clientCallContext.getHeaders() : null); - if (interceptors != null && ! interceptors.isEmpty()) { - for (ClientCallInterceptor interceptor : interceptors) { - payloadAndHeaders = interceptor.intercept(methodName, payloadAndHeaders.getPayload(), - payloadAndHeaders.getHeaders(), agentCard, clientCallContext); - } - } - return payloadAndHeaders; - } - } \ No newline at end of file diff --git a/client/transport/jsonrpc/src/main/java/io/a2a/client/transport/jsonrpc/JSONRPCTransport.java b/client/transport/jsonrpc/src/main/java/io/a2a/client/transport/jsonrpc/JSONRPCTransport.java index 8464911f..a1ff52bb 100644 --- a/client/transport/jsonrpc/src/main/java/io/a2a/client/transport/jsonrpc/JSONRPCTransport.java +++ b/client/transport/jsonrpc/src/main/java/io/a2a/client/transport/jsonrpc/JSONRPCTransport.java @@ -3,21 +3,23 @@ import static io.a2a.util.Assert.checkNotNullParam; import java.io.IOException; +import java.net.URI; import java.util.List; import java.util.Map; +import java.util.concurrent.ExecutionException; +import java.util.function.BiConsumer; import java.util.function.Consumer; import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.core.type.TypeReference; import io.a2a.client.http.A2ACardResolver; +import io.a2a.client.transport.spi.AbstractClientTransport; import io.a2a.client.transport.spi.interceptors.ClientCallContext; import io.a2a.client.transport.spi.interceptors.ClientCallInterceptor; import io.a2a.client.transport.spi.interceptors.PayloadAndHeaders; -import io.a2a.client.http.A2AHttpClient; -import io.a2a.client.http.A2AHttpResponse; -import io.a2a.client.http.JdkA2AHttpClient; -import io.a2a.client.transport.spi.ClientTransport; +import io.a2a.client.http.HttpClient; +import io.a2a.client.http.HttpResponse; import io.a2a.spec.A2AClientError; import io.a2a.spec.A2AClientException; import io.a2a.spec.AgentCard; @@ -59,8 +61,9 @@ import java.util.concurrent.atomic.AtomicReference; import io.a2a.util.Utils; +import org.jspecify.annotations.Nullable; -public class JSONRPCTransport implements ClientTransport { +public class JSONRPCTransport extends AbstractClientTransport { private static final TypeReference SEND_MESSAGE_RESPONSE_REFERENCE = new TypeReference<>() {}; private static final TypeReference GET_TASK_RESPONSE_REFERENCE = new TypeReference<>() {}; @@ -71,9 +74,8 @@ public class JSONRPCTransport implements ClientTransport { private static final TypeReference DELETE_TASK_PUSH_NOTIFICATION_CONFIG_RESPONSE_REFERENCE = new TypeReference<>() {}; private static final TypeReference GET_AUTHENTICATED_EXTENDED_CARD_RESPONSE_REFERENCE = new TypeReference<>() {}; - private final A2AHttpClient httpClient; - private final String agentUrl; - private final List interceptors; + private final HttpClient httpClient; + private final String agentPath; private AgentCard agentCard; private boolean needsExtendedCard = false; @@ -81,21 +83,26 @@ public JSONRPCTransport(String agentUrl) { this(null, null, agentUrl, null); } - public JSONRPCTransport(AgentCard agentCard) { - this(null, agentCard, agentCard.url(), null); - } - - public JSONRPCTransport(A2AHttpClient httpClient, AgentCard agentCard, - String agentUrl, List interceptors) { - this.httpClient = httpClient == null ? new JdkA2AHttpClient() : httpClient; + public JSONRPCTransport(@Nullable HttpClient httpClient, @Nullable AgentCard agentCard, + String agentUrl, @Nullable List interceptors) { + super(interceptors); + this.httpClient = httpClient == null ? HttpClient.createHttpClient(agentUrl) : httpClient; this.agentCard = agentCard; - this.agentUrl = agentUrl; - this.interceptors = interceptors; + + String sAgentPath = URI.create(agentUrl).getPath(); + + // Strip the last slash if one is provided + if (sAgentPath.endsWith("/")) { + this.agentPath = sAgentPath.substring(0, sAgentPath.length() - 1); + } else { + this.agentPath = sAgentPath; + } + this.needsExtendedCard = agentCard == null || agentCard.supportsAuthenticatedExtendedCard(); } @Override - public EventKind sendMessage(MessageSendParams request, ClientCallContext context) throws A2AClientException { + public EventKind sendMessage(MessageSendParams request, @Nullable ClientCallContext context) throws A2AClientException { checkNotNullParam("request", request); SendMessageRequest sendMessageRequest = new SendMessageRequest.Builder() .jsonrpc(JSONRPCMessage.JSONRPC_VERSION) @@ -103,8 +110,7 @@ public EventKind sendMessage(MessageSendParams request, ClientCallContext contex .params(request) .build(); // id will be randomly generated - PayloadAndHeaders payloadAndHeaders = applyInterceptors(SendMessageRequest.METHOD, sendMessageRequest, - agentCard, context); + PayloadAndHeaders payloadAndHeaders = applyInterceptors(SendMessageRequest.METHOD, sendMessageRequest, agentCard, context); try { String httpResponseBody = sendPostRequest(payloadAndHeaders); @@ -119,7 +125,7 @@ public EventKind sendMessage(MessageSendParams request, ClientCallContext contex @Override public void sendMessageStreaming(MessageSendParams request, Consumer eventConsumer, - Consumer errorConsumer, ClientCallContext context) throws A2AClientException { + Consumer errorConsumer, @Nullable ClientCallContext context) throws A2AClientException { checkNotNullParam("request", request); checkNotNullParam("eventConsumer", eventConsumer); SendStreamingMessageRequest sendStreamingMessageRequest = new SendStreamingMessageRequest.Builder() @@ -128,29 +134,33 @@ public void sendMessageStreaming(MessageSendParams request, Consumer> ref = new AtomicReference<>(); + AtomicReference> ref = new AtomicReference<>(); SSEEventListener sseEventListener = new SSEEventListener(eventConsumer, errorConsumer); try { - A2AHttpClient.PostBuilder builder = createPostBuilder(payloadAndHeaders); - ref.set(builder.postAsyncSSE( - msg -> sseEventListener.onMessage(msg, ref.get()), - throwable -> sseEventListener.onError(throwable, ref.get()), - () -> { - // We don't need to do anything special on completion - })); + HttpClient.PostRequestBuilder builder = createPostBuilder(payloadAndHeaders).asSSE(); + ref.set(builder.send() + .whenComplete(new BiConsumer() { + @Override + public void accept(HttpResponse httpResponse, Throwable throwable) { + if (httpResponse != null) { + httpResponse.bodyAsSse( + msg -> sseEventListener.onMessage(msg, ref.get()), + cause -> sseEventListener.onError(cause, ref.get())); + } else { + errorConsumer.accept(throwable); + } + } + })); } catch (IOException e) { throw new A2AClientException("Failed to send streaming message request: " + e, e); - } catch (InterruptedException e) { - throw new A2AClientException("Send streaming message request timed out: " + e, e); } } @Override - public Task getTask(TaskQueryParams request, ClientCallContext context) throws A2AClientException { + public Task getTask(TaskQueryParams request, @Nullable ClientCallContext context) throws A2AClientException { checkNotNullParam("request", request); GetTaskRequest getTaskRequest = new GetTaskRequest.Builder() .jsonrpc(JSONRPCMessage.JSONRPC_VERSION) @@ -158,8 +168,7 @@ public Task getTask(TaskQueryParams request, ClientCallContext context) throws A .params(request) .build(); // id will be randomly generated - PayloadAndHeaders payloadAndHeaders = applyInterceptors(GetTaskRequest.METHOD, getTaskRequest, - agentCard, context); + PayloadAndHeaders payloadAndHeaders = applyInterceptors(GetTaskRequest.METHOD, getTaskRequest, agentCard, context); try { String httpResponseBody = sendPostRequest(payloadAndHeaders); @@ -173,7 +182,7 @@ public Task getTask(TaskQueryParams request, ClientCallContext context) throws A } @Override - public Task cancelTask(TaskIdParams request, ClientCallContext context) throws A2AClientException { + public Task cancelTask(TaskIdParams request, @Nullable ClientCallContext context) throws A2AClientException { checkNotNullParam("request", request); CancelTaskRequest cancelTaskRequest = new CancelTaskRequest.Builder() .jsonrpc(JSONRPCMessage.JSONRPC_VERSION) @@ -181,8 +190,7 @@ public Task cancelTask(TaskIdParams request, ClientCallContext context) throws A .params(request) .build(); // id will be randomly generated - PayloadAndHeaders payloadAndHeaders = applyInterceptors(CancelTaskRequest.METHOD, cancelTaskRequest, - agentCard, context); + PayloadAndHeaders payloadAndHeaders = applyInterceptors(CancelTaskRequest.METHOD, cancelTaskRequest, agentCard, context); try { String httpResponseBody = sendPostRequest(payloadAndHeaders); @@ -197,7 +205,7 @@ public Task cancelTask(TaskIdParams request, ClientCallContext context) throws A @Override public TaskPushNotificationConfig setTaskPushNotificationConfiguration(TaskPushNotificationConfig request, - ClientCallContext context) throws A2AClientException { + @Nullable ClientCallContext context) throws A2AClientException { checkNotNullParam("request", request); SetTaskPushNotificationConfigRequest setTaskPushNotificationRequest = new SetTaskPushNotificationConfigRequest.Builder() .jsonrpc(JSONRPCMessage.JSONRPC_VERSION) @@ -222,7 +230,7 @@ public TaskPushNotificationConfig setTaskPushNotificationConfiguration(TaskPushN @Override public TaskPushNotificationConfig getTaskPushNotificationConfiguration(GetTaskPushNotificationConfigParams request, - ClientCallContext context) throws A2AClientException { + @Nullable ClientCallContext context) throws A2AClientException { checkNotNullParam("request", request); GetTaskPushNotificationConfigRequest getTaskPushNotificationRequest = new GetTaskPushNotificationConfigRequest.Builder() .jsonrpc(JSONRPCMessage.JSONRPC_VERSION) @@ -248,7 +256,7 @@ public TaskPushNotificationConfig getTaskPushNotificationConfiguration(GetTaskPu @Override public List listTaskPushNotificationConfigurations( ListTaskPushNotificationConfigParams request, - ClientCallContext context) throws A2AClientException { + @Nullable ClientCallContext context) throws A2AClientException { checkNotNullParam("request", request); ListTaskPushNotificationConfigRequest listTaskPushNotificationRequest = new ListTaskPushNotificationConfigRequest.Builder() .jsonrpc(JSONRPCMessage.JSONRPC_VERSION) @@ -273,7 +281,7 @@ public List listTaskPushNotificationConfigurations( @Override public void deleteTaskPushNotificationConfigurations(DeleteTaskPushNotificationConfigParams request, - ClientCallContext context) throws A2AClientException { + @Nullable ClientCallContext context) throws A2AClientException { checkNotNullParam("request", request); DeleteTaskPushNotificationConfigRequest deleteTaskPushNotificationRequest = new DeleteTaskPushNotificationConfigRequest.Builder() .jsonrpc(JSONRPCMessage.JSONRPC_VERSION) @@ -296,7 +304,7 @@ public void deleteTaskPushNotificationConfigurations(DeleteTaskPushNotificationC @Override public void resubscribe(TaskIdParams request, Consumer eventConsumer, - Consumer errorConsumer, ClientCallContext context) throws A2AClientException { + Consumer errorConsumer, @Nullable ClientCallContext context) throws A2AClientException { checkNotNullParam("request", request); checkNotNullParam("eventConsumer", eventConsumer); checkNotNullParam("errorConsumer", errorConsumer); @@ -309,30 +317,33 @@ public void resubscribe(TaskIdParams request, Consumer event PayloadAndHeaders payloadAndHeaders = applyInterceptors(TaskResubscriptionRequest.METHOD, taskResubscriptionRequest, agentCard, context); - AtomicReference> ref = new AtomicReference<>(); + AtomicReference> ref = new AtomicReference<>(); SSEEventListener sseEventListener = new SSEEventListener(eventConsumer, errorConsumer); try { - A2AHttpClient.PostBuilder builder = createPostBuilder(payloadAndHeaders); - ref.set(builder.postAsyncSSE( - msg -> sseEventListener.onMessage(msg, ref.get()), - throwable -> sseEventListener.onError(throwable, ref.get()), - () -> { - // We don't need to do anything special on completion - })); + HttpClient.PostRequestBuilder builder = createPostBuilder(payloadAndHeaders).asSSE(); + ref.set(builder.send().whenComplete(new BiConsumer() { + @Override + public void accept(HttpResponse httpResponse, Throwable throwable) { + if (httpResponse != null) { + httpResponse.bodyAsSse( + msg -> sseEventListener.onMessage(msg, ref.get()), + cause -> sseEventListener.onError(cause, ref.get())); + } else { + errorConsumer.accept(throwable); + } + } + })); } catch (IOException e) { throw new A2AClientException("Failed to send task resubscription request: " + e, e); - } catch (InterruptedException e) { - throw new A2AClientException("Task resubscription request timed out: " + e, e); } } @Override - public AgentCard getAgentCard(ClientCallContext context) throws A2AClientException { - A2ACardResolver resolver; + public AgentCard getAgentCard(@Nullable ClientCallContext context) throws A2AClientException { try { if (agentCard == null) { - resolver = new A2ACardResolver(httpClient, agentUrl, null, getHttpHeaders(context)); + A2ACardResolver resolver = new A2ACardResolver(httpClient, agentPath, getHttpHeaders(context)); agentCard = resolver.getAgentCard(); needsExtendedCard = agentCard.supportsAuthenticatedExtendedCard(); } @@ -368,30 +379,25 @@ public void close() { // no-op } - private PayloadAndHeaders applyInterceptors(String methodName, Object payload, - AgentCard agentCard, ClientCallContext clientCallContext) { - PayloadAndHeaders payloadAndHeaders = new PayloadAndHeaders(payload, getHttpHeaders(clientCallContext)); - if (interceptors != null && ! interceptors.isEmpty()) { - for (ClientCallInterceptor interceptor : interceptors) { - payloadAndHeaders = interceptor.intercept(methodName, payloadAndHeaders.getPayload(), - payloadAndHeaders.getHeaders(), agentCard, clientCallContext); + private String sendPostRequest(PayloadAndHeaders payloadAndHeaders) throws IOException, InterruptedException { + HttpClient.PostRequestBuilder builder = createPostBuilder(payloadAndHeaders); + try { + HttpResponse response = builder.send().get(); + if (!response.success()) { + throw new IOException("Request failed " + response.statusCode()); } - } - return payloadAndHeaders; - } + return response.body(); - private String sendPostRequest(PayloadAndHeaders payloadAndHeaders) throws IOException, InterruptedException { - A2AHttpClient.PostBuilder builder = createPostBuilder(payloadAndHeaders); - A2AHttpResponse response = builder.post(); - if (!response.success()) { - throw new IOException("Request failed " + response.status()); + } catch (ExecutionException e) { + if (e.getCause() instanceof IOException) { + throw (IOException) e.getCause(); + } + throw new IOException("Failed to send request", e.getCause()); } - return response.body(); } - private A2AHttpClient.PostBuilder createPostBuilder(PayloadAndHeaders payloadAndHeaders) throws JsonProcessingException { - A2AHttpClient.PostBuilder postBuilder = httpClient.createPost() - .url(agentUrl) + private HttpClient.PostRequestBuilder createPostBuilder(PayloadAndHeaders payloadAndHeaders) throws JsonProcessingException { + HttpClient.PostRequestBuilder postBuilder = httpClient.post(agentPath) .addHeader("Content-Type", "application/json") .body(Utils.OBJECT_MAPPER.writeValueAsString(payloadAndHeaders.getPayload())); @@ -414,7 +420,7 @@ private > T unmarshalResponse(String response, Type return value; } - private Map getHttpHeaders(ClientCallContext context) { + private Map getHttpHeaders(@Nullable ClientCallContext context) { return context != null ? context.getHeaders() : null; } } \ No newline at end of file diff --git a/client/transport/jsonrpc/src/main/java/io/a2a/client/transport/jsonrpc/JSONRPCTransportConfig.java b/client/transport/jsonrpc/src/main/java/io/a2a/client/transport/jsonrpc/JSONRPCTransportConfig.java index efd3bbdf..2cdc4183 100644 --- a/client/transport/jsonrpc/src/main/java/io/a2a/client/transport/jsonrpc/JSONRPCTransportConfig.java +++ b/client/transport/jsonrpc/src/main/java/io/a2a/client/transport/jsonrpc/JSONRPCTransportConfig.java @@ -1,21 +1,24 @@ package io.a2a.client.transport.jsonrpc; +import io.a2a.client.http.HttpClientBuilder; import io.a2a.client.transport.spi.ClientTransportConfig; -import io.a2a.client.http.A2AHttpClient; +import io.a2a.util.Assert; +import org.jspecify.annotations.Nullable; public class JSONRPCTransportConfig extends ClientTransportConfig { - private final A2AHttpClient httpClient; + private final HttpClientBuilder httpClientBuilder; - public JSONRPCTransportConfig() { - this.httpClient = null; + public JSONRPCTransportConfig(HttpClientBuilder httpClientBuilder) { + Assert.checkNotNullParam("httpClientBuilder", httpClientBuilder); + this.httpClientBuilder = httpClientBuilder; } - public JSONRPCTransportConfig(A2AHttpClient httpClient) { - this.httpClient = httpClient; + public JSONRPCTransportConfig() { + this.httpClientBuilder = HttpClientBuilder.DEFAULT_FACTORY; } - public A2AHttpClient getHttpClient() { - return httpClient; + public HttpClientBuilder getHttpClientBuilder() { + return this.httpClientBuilder; } } \ No newline at end of file diff --git a/client/transport/jsonrpc/src/main/java/io/a2a/client/transport/jsonrpc/JSONRPCTransportConfigBuilder.java b/client/transport/jsonrpc/src/main/java/io/a2a/client/transport/jsonrpc/JSONRPCTransportConfigBuilder.java index 64153620..ed1956e3 100644 --- a/client/transport/jsonrpc/src/main/java/io/a2a/client/transport/jsonrpc/JSONRPCTransportConfigBuilder.java +++ b/client/transport/jsonrpc/src/main/java/io/a2a/client/transport/jsonrpc/JSONRPCTransportConfigBuilder.java @@ -1,27 +1,23 @@ package io.a2a.client.transport.jsonrpc; -import io.a2a.client.http.A2AHttpClient; -import io.a2a.client.http.JdkA2AHttpClient; +import io.a2a.client.http.HttpClientBuilder; import io.a2a.client.transport.spi.ClientTransportConfigBuilder; +import io.a2a.util.Assert; public class JSONRPCTransportConfigBuilder extends ClientTransportConfigBuilder { - private A2AHttpClient httpClient; + private HttpClientBuilder httpClientBuilder = HttpClientBuilder.DEFAULT_FACTORY; - public JSONRPCTransportConfigBuilder httpClient(A2AHttpClient httpClient) { - this.httpClient = httpClient; + public JSONRPCTransportConfigBuilder httpClientBuilder(HttpClientBuilder httpClientBuilder) { + Assert.checkNotNullParam("httpClientBuilder", httpClientBuilder); + this.httpClientBuilder = httpClientBuilder; return this; } @Override public JSONRPCTransportConfig build() { - // No HTTP client provided, fallback to the default one (JDK-based implementation) - if (httpClient == null) { - httpClient = new JdkA2AHttpClient(); - } - - JSONRPCTransportConfig config = new JSONRPCTransportConfig(httpClient); + JSONRPCTransportConfig config = new JSONRPCTransportConfig(httpClientBuilder); config.setInterceptors(this.interceptors); return config; } diff --git a/client/transport/jsonrpc/src/main/java/io/a2a/client/transport/jsonrpc/JSONRPCTransportProvider.java b/client/transport/jsonrpc/src/main/java/io/a2a/client/transport/jsonrpc/JSONRPCTransportProvider.java index 97c22866..66de8dcf 100644 --- a/client/transport/jsonrpc/src/main/java/io/a2a/client/transport/jsonrpc/JSONRPCTransportProvider.java +++ b/client/transport/jsonrpc/src/main/java/io/a2a/client/transport/jsonrpc/JSONRPCTransportProvider.java @@ -1,6 +1,7 @@ package io.a2a.client.transport.jsonrpc; -import io.a2a.client.http.JdkA2AHttpClient; +import io.a2a.client.http.HttpClient; +import io.a2a.client.http.HttpClientBuilder; import io.a2a.client.transport.spi.ClientTransportProvider; import io.a2a.spec.A2AClientException; import io.a2a.spec.AgentCard; @@ -9,12 +10,20 @@ public class JSONRPCTransportProvider implements ClientTransportProvider { @Override - public JSONRPCTransport create(JSONRPCTransportConfig clientTransportConfig, AgentCard agentCard, String agentUrl) throws A2AClientException { - if (clientTransportConfig == null) { - clientTransportConfig = new JSONRPCTransportConfig(new JdkA2AHttpClient()); + public JSONRPCTransport create(JSONRPCTransportConfig transportConfig, AgentCard agentCard, String agentUrl) throws A2AClientException { + if (transportConfig == null) { + transportConfig = new JSONRPCTransportConfig(); } - return new JSONRPCTransport(clientTransportConfig.getHttpClient(), agentCard, agentUrl, clientTransportConfig.getInterceptors()); + HttpClientBuilder httpClientBuilder = transportConfig.getHttpClientBuilder(); + + try { + final HttpClient httpClient = httpClientBuilder.create(agentUrl); + + return new JSONRPCTransport(httpClient, agentCard, agentUrl, transportConfig.getInterceptors()); + } catch (Exception ex) { + throw new A2AClientException("Failed to create JSONRPC transport", ex); + } } @Override diff --git a/client/transport/jsonrpc/src/main/java/io/a2a/client/transport/jsonrpc/sse/SSEEventListener.java b/client/transport/jsonrpc/src/main/java/io/a2a/client/transport/jsonrpc/sse/SSEEventListener.java index 99ca546c..af88c732 100644 --- a/client/transport/jsonrpc/src/main/java/io/a2a/client/transport/jsonrpc/sse/SSEEventListener.java +++ b/client/transport/jsonrpc/src/main/java/io/a2a/client/transport/jsonrpc/sse/SSEEventListener.java @@ -2,6 +2,9 @@ import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.JsonNode; +import io.a2a.client.http.HttpResponse; +import io.a2a.client.http.sse.DataEvent; +import io.a2a.client.http.sse.Event; import io.a2a.spec.JSONRPCError; import io.a2a.spec.StreamingEventKind; import io.a2a.spec.TaskStatusUpdateEvent; @@ -23,22 +26,28 @@ public SSEEventListener(Consumer eventHandler, this.errorHandler = errorHandler; } - public void onMessage(String message, Future completableFuture) { - try { - handleMessage(OBJECT_MAPPER.readTree(message),completableFuture); - } catch (JsonProcessingException e) { - log.warning("Failed to parse JSON message: " + message); + public void onMessage(Event event, Future completableFuture) { + log.fine("Streaming message received: " + event); + + if (event instanceof DataEvent) { + try { + handleMessage(OBJECT_MAPPER.readTree(((DataEvent) event).getData()), completableFuture); + } catch (JsonProcessingException e) { + log.warning("Failed to parse JSON message: " + ((DataEvent) event).getData()); + } } } - public void onError(Throwable throwable, Future future) { + public void onError(Throwable throwable, Future future) { if (errorHandler != null) { errorHandler.accept(throwable); } - future.cancel(true); // close SSE channel + if (future != null) { + future.cancel(true); // close SSE channel + } } - private void handleMessage(JsonNode jsonNode, Future future) { + private void handleMessage(JsonNode jsonNode, Future future) { try { if (jsonNode.has("error")) { JSONRPCError error = OBJECT_MAPPER.treeToValue(jsonNode.get("error"), JSONRPCError.class); diff --git a/client/transport/jsonrpc/src/test/java/io/a2a/client/transport/jsonrpc/sse/SSEEventListenerTest.java b/client/transport/jsonrpc/src/test/java/io/a2a/client/transport/jsonrpc/sse/SSEEventListenerTest.java index 8c4c1495..0acfcea0 100644 --- a/client/transport/jsonrpc/src/test/java/io/a2a/client/transport/jsonrpc/sse/SSEEventListenerTest.java +++ b/client/transport/jsonrpc/src/test/java/io/a2a/client/transport/jsonrpc/sse/SSEEventListenerTest.java @@ -13,6 +13,8 @@ import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicReference; +import io.a2a.client.http.HttpResponse; +import io.a2a.client.http.sse.DataEvent; import io.a2a.client.transport.jsonrpc.JsonStreamingMessages; import io.a2a.spec.Artifact; import io.a2a.spec.JSONRPCError; @@ -43,7 +45,7 @@ public void testOnEventWithTaskResult() throws Exception { JsonStreamingMessages.STREAMING_TASK_EVENT.indexOf("{")); // Call the onEvent method directly - listener.onMessage(eventData, null); + listener.onMessage(new DataEvent(null, eventData, null), null); // Verify the event was processed correctly assertNotNull(receivedEvent.get()); @@ -68,7 +70,7 @@ public void testOnEventWithMessageResult() throws Exception { JsonStreamingMessages.STREAMING_MESSAGE_EVENT.indexOf("{")); // Call onEvent method - listener.onMessage(eventData, null); + listener.onMessage(new DataEvent(null, eventData, null), null); // Verify the event was processed correctly assertNotNull(receivedEvent.get()); @@ -96,7 +98,7 @@ public void testOnEventWithTaskStatusUpdateEventEvent() throws Exception { JsonStreamingMessages.STREAMING_STATUS_UPDATE_EVENT.indexOf("{")); // Call onEvent method - listener.onMessage(eventData, null); + listener.onMessage(new DataEvent(null, eventData, null), null); // Verify the event was processed correctly assertNotNull(receivedEvent.get()); @@ -122,7 +124,7 @@ public void testOnEventWithTaskArtifactUpdateEventEvent() throws Exception { JsonStreamingMessages.STREAMING_ARTIFACT_UPDATE_EVENT.indexOf("{")); // Call onEvent method - listener.onMessage(eventData, null); + listener.onMessage(new DataEvent(null, eventData, null), null); // Verify the event was processed correctly assertNotNull(receivedEvent.get()); @@ -154,7 +156,7 @@ public void testOnEventWithError() throws Exception { JsonStreamingMessages.STREAMING_ERROR_EVENT.indexOf("{")); // Call onEvent method - listener.onMessage(eventData, null); + listener.onMessage(new DataEvent(null, eventData, null), null); // Verify the error was processed correctly assertNotNull(receivedError.get()); @@ -217,7 +219,7 @@ public void testOnEventWithFinalTaskStatusUpdateEventEventCancels() throws Excep // Call onEvent method CancelCapturingFuture future = new CancelCapturingFuture(); - listener.onMessage(eventData, future); + listener.onMessage(new DataEvent(null, eventData, null), future); // Verify the event was processed correctly assertNotNull(receivedEvent.get()); @@ -232,7 +234,7 @@ public void testOnEventWithFinalTaskStatusUpdateEventEventCancels() throws Excep } - private static class CancelCapturingFuture implements Future { + private static class CancelCapturingFuture implements Future { private boolean cancelHandlerCalled; public CancelCapturingFuture() { @@ -255,12 +257,12 @@ public boolean isDone() { } @Override - public Void get() throws InterruptedException, ExecutionException { + public HttpResponse get() throws InterruptedException, ExecutionException { return null; } @Override - public Void get(long timeout, TimeUnit unit) throws InterruptedException, ExecutionException, TimeoutException { + public HttpResponse get(long timeout, TimeUnit unit) throws InterruptedException, ExecutionException, TimeoutException { return null; } } diff --git a/client/transport/rest/src/main/java/io/a2a/client/transport/rest/RestErrorMapper.java b/client/transport/rest/src/main/java/io/a2a/client/transport/rest/RestErrorMapper.java index 965cc296..85bf962b 100644 --- a/client/transport/rest/src/main/java/io/a2a/client/transport/rest/RestErrorMapper.java +++ b/client/transport/rest/src/main/java/io/a2a/client/transport/rest/RestErrorMapper.java @@ -4,7 +4,7 @@ import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.datatype.jsr310.JavaTimeModule; -import io.a2a.client.http.A2AHttpResponse; +import io.a2a.client.http.HttpResponse; import io.a2a.spec.A2AClientException; import io.a2a.spec.AuthenticatedExtendedCardNotConfiguredError; import io.a2a.spec.ContentTypeNotSupportedError; @@ -28,8 +28,8 @@ public class RestErrorMapper { private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper().registerModule(new JavaTimeModule()); - public static A2AClientException mapRestError(A2AHttpResponse response) { - return RestErrorMapper.mapRestError(response.body(), response.status()); + public static A2AClientException mapRestError(HttpResponse response) { + return RestErrorMapper.mapRestError(response.body(), response.statusCode()); } public static A2AClientException mapRestError(String body, int code) { diff --git a/client/transport/rest/src/main/java/io/a2a/client/transport/rest/RestTransport.java b/client/transport/rest/src/main/java/io/a2a/client/transport/rest/RestTransport.java index f659589b..912c0082 100644 --- a/client/transport/rest/src/main/java/io/a2a/client/transport/rest/RestTransport.java +++ b/client/transport/rest/src/main/java/io/a2a/client/transport/rest/RestTransport.java @@ -7,11 +7,10 @@ import com.google.protobuf.MessageOrBuilder; import com.google.protobuf.util.JsonFormat; import io.a2a.client.http.A2ACardResolver; -import io.a2a.client.http.A2AHttpClient; -import io.a2a.client.http.A2AHttpResponse; -import io.a2a.client.http.JdkA2AHttpClient; +import io.a2a.client.http.HttpClient; +import io.a2a.client.http.HttpResponse; import io.a2a.client.transport.rest.sse.RestSSEEventListener; -import io.a2a.client.transport.spi.ClientTransport; +import io.a2a.client.transport.spi.AbstractClientTransport; import io.a2a.client.transport.spi.interceptors.ClientCallContext; import io.a2a.client.transport.spi.interceptors.ClientCallInterceptor; import io.a2a.client.transport.spi.interceptors.PayloadAndHeaders; @@ -38,8 +37,11 @@ import io.a2a.spec.SetTaskPushNotificationConfigRequest; import io.a2a.util.Utils; import java.io.IOException; +import java.net.URI; import java.util.Collections; import java.util.List; +import java.util.concurrent.ExecutionException; +import java.util.function.BiConsumer; import java.util.logging.Logger; import java.util.Map; import java.util.concurrent.CompletableFuture; @@ -47,25 +49,31 @@ import java.util.function.Consumer; import org.jspecify.annotations.Nullable; -public class RestTransport implements ClientTransport { +public class RestTransport extends AbstractClientTransport { private static final Logger log = Logger.getLogger(RestTransport.class.getName()); - private final A2AHttpClient httpClient; - private final String agentUrl; - private @Nullable final List interceptors; - private AgentCard agentCard; + private final HttpClient httpClient; + private final String agentPath; + private @Nullable AgentCard agentCard; private boolean needsExtendedCard = false; - public RestTransport(AgentCard agentCard) { - this(null, agentCard, agentCard.url(), null); + public RestTransport(String agentUrl) { + this(null, null, agentUrl, null); } - public RestTransport(@Nullable A2AHttpClient httpClient, AgentCard agentCard, + public RestTransport(@Nullable HttpClient httpClient, @Nullable AgentCard agentCard, String agentUrl, @Nullable List interceptors) { - this.httpClient = httpClient == null ? new JdkA2AHttpClient() : httpClient; + super(interceptors); + this.httpClient = httpClient == null ? HttpClient.createHttpClient(agentUrl) : httpClient; this.agentCard = agentCard; - this.agentUrl = agentUrl.endsWith("/") ? agentUrl.substring(0, agentUrl.length() - 1) : agentUrl; - this.interceptors = interceptors; + String sAgentPath = URI.create(agentUrl).getPath(); + + // Strip the last slash if one is provided + if (sAgentPath.endsWith("/")) { + this.agentPath = sAgentPath.substring(0, sAgentPath.length() - 1); + } else { + this.agentPath = sAgentPath; + } } @Override @@ -74,7 +82,7 @@ public EventKind sendMessage(MessageSendParams messageSendParams, @Nullable Clie io.a2a.grpc.SendMessageRequest.Builder builder = io.a2a.grpc.SendMessageRequest.newBuilder(ProtoUtils.ToProto.sendMessageRequest(messageSendParams)); PayloadAndHeaders payloadAndHeaders = applyInterceptors(io.a2a.spec.SendMessageRequest.METHOD, builder, agentCard, context); try { - String httpResponseBody = sendPostRequest(agentUrl + "/v1/message:send", payloadAndHeaders); + String httpResponseBody = sendPostRequest("/v1/message:send", payloadAndHeaders); io.a2a.grpc.SendMessageResponse.Builder responseBuilder = io.a2a.grpc.SendMessageResponse.newBuilder(); JsonFormat.parser().merge(httpResponseBody, responseBuilder); if (responseBuilder.hasMsg()) { @@ -86,7 +94,7 @@ public EventKind sendMessage(MessageSendParams messageSendParams, @Nullable Clie throw new A2AClientException("Failed to send message, wrong response:" + httpResponseBody); } catch (A2AClientException e) { throw e; - } catch (IOException | InterruptedException e) { + } catch (IOException | InterruptedException | ExecutionException e) { throw new A2AClientException("Failed to send message: " + e, e); } } @@ -99,20 +107,24 @@ public void sendMessageStreaming(MessageSendParams messageSendParams, Consumer> ref = new AtomicReference<>(); + AtomicReference> ref = new AtomicReference<>(); RestSSEEventListener sseEventListener = new RestSSEEventListener(eventConsumer, errorConsumer); try { - A2AHttpClient.PostBuilder postBuilder = createPostBuilder(agentUrl + "/v1/message:stream", payloadAndHeaders); - ref.set(postBuilder.postAsyncSSE( - msg -> sseEventListener.onMessage(msg, ref.get()), - throwable -> sseEventListener.onError(throwable, ref.get()), - () -> { - // We don't need to do anything special on completion - })); + HttpClient.PostRequestBuilder postBuilder = createPostBuilder("/v1/message:stream", payloadAndHeaders).asSSE(); + ref.set(postBuilder.send().whenComplete(new BiConsumer() { + @Override + public void accept(HttpResponse httpResponse, Throwable throwable) { + if (httpResponse != null) { + httpResponse.bodyAsSse( + msg -> sseEventListener.onMessage(msg, ref.get()), + cause -> sseEventListener.onError(cause, ref.get())); + } else { + errorConsumer.accept(throwable); + } + } + })); } catch (IOException e) { throw new A2AClientException("Failed to send streaming message request: " + e, e); - } catch (InterruptedException e) { - throw new A2AClientException("Send streaming message request timed out: " + e, e); } } @@ -124,19 +136,20 @@ public Task getTask(TaskQueryParams taskQueryParams, @Nullable ClientCallContext PayloadAndHeaders payloadAndHeaders = applyInterceptors(io.a2a.spec.GetTaskRequest.METHOD, builder, agentCard, context); try { - String url; + String path; if (taskQueryParams.historyLength() != null) { - url = agentUrl + String.format("/v1/tasks/%1s?historyLength=%2d", taskQueryParams.id(), taskQueryParams.historyLength()); + path = String.format("/v1/tasks/%1s?historyLength=%2d", taskQueryParams.id(), taskQueryParams.historyLength()); } else { - url = agentUrl + String.format("/v1/tasks/%1s", taskQueryParams.id()); + path = String.format("/v1/tasks/%1s", taskQueryParams.id()); } - A2AHttpClient.GetBuilder getBuilder = httpClient.createGet().url(url); + HttpClient.GetRequestBuilder getBuilder = httpClient.get(agentPath + path); if (payloadAndHeaders.getHeaders() != null) { for (Map.Entry entry : payloadAndHeaders.getHeaders().entrySet()) { getBuilder.addHeader(entry.getKey(), entry.getValue()); } } - A2AHttpResponse response = getBuilder.get(); + CompletableFuture responseFut = getBuilder.send(); + HttpResponse response = responseFut.get(); if (!response.success()) { throw RestErrorMapper.mapRestError(response); } @@ -146,7 +159,7 @@ public Task getTask(TaskQueryParams taskQueryParams, @Nullable ClientCallContext return ProtoUtils.FromProto.task(responseBuilder); } catch (A2AClientException e) { throw e; - } catch (IOException | InterruptedException e) { + } catch (IOException | InterruptedException | ExecutionException e) { throw new A2AClientException("Failed to get task: " + e, e); } } @@ -159,13 +172,13 @@ public Task cancelTask(TaskIdParams taskIdParams, @Nullable ClientCallContext co PayloadAndHeaders payloadAndHeaders = applyInterceptors(io.a2a.spec.CancelTaskRequest.METHOD, builder, agentCard, context); try { - String httpResponseBody = sendPostRequest(agentUrl + String.format("/v1/tasks/%1s:cancel", taskIdParams.id()), payloadAndHeaders); + String httpResponseBody = sendPostRequest(String.format("/v1/tasks/%1s:cancel", taskIdParams.id()), payloadAndHeaders); io.a2a.grpc.Task.Builder responseBuilder = io.a2a.grpc.Task.newBuilder(); JsonFormat.parser().merge(httpResponseBody, responseBuilder); return ProtoUtils.FromProto.task(responseBuilder); } catch (A2AClientException e) { throw e; - } catch (IOException | InterruptedException e) { + } catch (IOException | InterruptedException | ExecutionException e) { throw new A2AClientException("Failed to cancel task: " + e, e); } } @@ -181,13 +194,13 @@ public TaskPushNotificationConfig setTaskPushNotificationConfiguration(TaskPushN } PayloadAndHeaders payloadAndHeaders = applyInterceptors(SetTaskPushNotificationConfigRequest.METHOD, builder, agentCard, context); try { - String httpResponseBody = sendPostRequest(agentUrl + String.format("/v1/tasks/%1s/pushNotificationConfigs", request.taskId()), payloadAndHeaders); + String httpResponseBody = sendPostRequest(String.format("/v1/tasks/%1s/pushNotificationConfigs", request.taskId()), payloadAndHeaders); io.a2a.grpc.TaskPushNotificationConfig.Builder responseBuilder = io.a2a.grpc.TaskPushNotificationConfig.newBuilder(); JsonFormat.parser().merge(httpResponseBody, responseBuilder); return ProtoUtils.FromProto.taskPushNotificationConfig(responseBuilder); } catch (A2AClientException e) { throw e; - } catch (IOException | InterruptedException e) { + } catch (IOException | InterruptedException | ExecutionException e) { throw new A2AClientException("Failed to set task push notification config: " + e, e); } } @@ -200,14 +213,17 @@ public TaskPushNotificationConfig getTaskPushNotificationConfiguration(GetTaskPu PayloadAndHeaders payloadAndHeaders = applyInterceptors(io.a2a.spec.GetTaskPushNotificationConfigRequest.METHOD, builder, agentCard, context); try { - String url = agentUrl + String.format("/v1/tasks/%1s/pushNotificationConfigs/%2s", request.id(), request.pushNotificationConfigId()); - A2AHttpClient.GetBuilder getBuilder = httpClient.createGet().url(url); + String path = String.format("/v1/tasks/%1s/pushNotificationConfigs/%2s", request.id(), request.pushNotificationConfigId()); + HttpClient.GetRequestBuilder getBuilder = httpClient.get(agentPath + path); if (payloadAndHeaders.getHeaders() != null) { for (Map.Entry entry : payloadAndHeaders.getHeaders().entrySet()) { getBuilder.addHeader(entry.getKey(), entry.getValue()); } } - A2AHttpResponse response = getBuilder.get(); + + CompletableFuture responseFut = getBuilder.send(); + HttpResponse response = responseFut.get(); + if (!response.success()) { throw RestErrorMapper.mapRestError(response); } @@ -217,7 +233,7 @@ public TaskPushNotificationConfig getTaskPushNotificationConfiguration(GetTaskPu return ProtoUtils.FromProto.taskPushNotificationConfig(responseBuilder); } catch (A2AClientException e) { throw e; - } catch (IOException | InterruptedException e) { + } catch (IOException | InterruptedException | ExecutionException e) { throw new A2AClientException("Failed to get push notifications: " + e, e); } } @@ -230,14 +246,16 @@ public List listTaskPushNotificationConfigurations(L PayloadAndHeaders payloadAndHeaders = applyInterceptors(io.a2a.spec.ListTaskPushNotificationConfigRequest.METHOD, builder, agentCard, context); try { - String url = agentUrl + String.format("/v1/tasks/%1s/pushNotificationConfigs", request.id()); - A2AHttpClient.GetBuilder getBuilder = httpClient.createGet().url(url); + String path = String.format("/v1/tasks/%1s/pushNotificationConfigs", request.id()); + HttpClient.GetRequestBuilder getBuilder = httpClient.get(agentPath + path); if (payloadAndHeaders.getHeaders() != null) { for (Map.Entry entry : payloadAndHeaders.getHeaders().entrySet()) { getBuilder.addHeader(entry.getKey(), entry.getValue()); } } - A2AHttpResponse response = getBuilder.get(); + CompletableFuture responseFut = getBuilder.send(); + HttpResponse response = responseFut.get(); + if (!response.success()) { throw RestErrorMapper.mapRestError(response); } @@ -247,7 +265,7 @@ public List listTaskPushNotificationConfigurations(L return ProtoUtils.FromProto.listTaskPushNotificationConfigParams(responseBuilder); } catch (A2AClientException e) { throw e; - } catch (IOException | InterruptedException e) { + } catch (IOException | InterruptedException | ExecutionException e) { throw new A2AClientException("Failed to list push notifications: " + e, e); } } @@ -259,20 +277,22 @@ public void deleteTaskPushNotificationConfigurations(DeleteTaskPushNotificationC PayloadAndHeaders payloadAndHeaders = applyInterceptors(io.a2a.spec.DeleteTaskPushNotificationConfigRequest.METHOD, builder, agentCard, context); try { - String url = agentUrl + String.format("/v1/tasks/%1s/pushNotificationConfigs/%2s", request.id(), request.pushNotificationConfigId()); - A2AHttpClient.DeleteBuilder deleteBuilder = httpClient.createDelete().url(url); + String path = String.format("/v1/tasks/%1s/pushNotificationConfigs/%2s", request.id(), request.pushNotificationConfigId()); + HttpClient.DeleteRequestBuilder deleteBuilder = httpClient.delete(agentPath + path); if (payloadAndHeaders.getHeaders() != null) { for (Map.Entry entry : payloadAndHeaders.getHeaders().entrySet()) { deleteBuilder.addHeader(entry.getKey(), entry.getValue()); } } - A2AHttpResponse response = deleteBuilder.delete(); + CompletableFuture responseFut = deleteBuilder.send(); + HttpResponse response = responseFut.get(); + if (!response.success()) { throw RestErrorMapper.mapRestError(response); } } catch (A2AClientException e) { throw e; - } catch (IOException | InterruptedException e) { + } catch (IOException | InterruptedException | ExecutionException e) { throw new A2AClientException("Failed to delete push notification config: " + e, e); } } @@ -285,21 +305,25 @@ public void resubscribe(TaskIdParams request, Consumer event builder.setName("tasks/" + request.id()); PayloadAndHeaders payloadAndHeaders = applyInterceptors(io.a2a.spec.TaskResubscriptionRequest.METHOD, builder, agentCard, context); - AtomicReference> ref = new AtomicReference<>(); + AtomicReference> ref = new AtomicReference<>(); RestSSEEventListener sseEventListener = new RestSSEEventListener(eventConsumer, errorConsumer); try { - String url = agentUrl + String.format("/v1/tasks/%1s:subscribe", request.id()); - A2AHttpClient.PostBuilder postBuilder = createPostBuilder(url, payloadAndHeaders); - ref.set(postBuilder.postAsyncSSE( - msg -> sseEventListener.onMessage(msg, ref.get()), - throwable -> sseEventListener.onError(throwable, ref.get()), - () -> { - // We don't need to do anything special on completion - })); + String path = String.format("/v1/tasks/%1s:subscribe", request.id()); + HttpClient.PostRequestBuilder postBuilder = createPostBuilder(path, payloadAndHeaders).asSSE(); + ref.set(postBuilder.send().whenComplete(new BiConsumer() { + @Override + public void accept(HttpResponse httpResponse, Throwable throwable) { + if (httpResponse != null) { + httpResponse.bodyAsSse( + msg -> sseEventListener.onMessage(msg, ref.get()), + cause -> sseEventListener.onError(cause, ref.get())); + } else { + errorConsumer.accept(throwable); + } + } + })); } catch (IOException e) { throw new A2AClientException("Failed to send streaming message request: " + e, e); - } catch (InterruptedException e) { - throw new A2AClientException("Send streaming message request timed out: " + e, e); } } @@ -308,7 +332,7 @@ public AgentCard getAgentCard(@Nullable ClientCallContext context) throws A2ACli A2ACardResolver resolver; try { if (agentCard == null) { - resolver = new A2ACardResolver(httpClient, agentUrl, null, getHttpHeaders(context)); + resolver = new A2ACardResolver(httpClient, agentPath, getHttpHeaders(context)); agentCard = resolver.getAgentCard(); needsExtendedCard = agentCard.supportsAuthenticatedExtendedCard(); } @@ -317,14 +341,16 @@ public AgentCard getAgentCard(@Nullable ClientCallContext context) throws A2ACli } PayloadAndHeaders payloadAndHeaders = applyInterceptors(io.a2a.spec.GetTaskRequest.METHOD, null, agentCard, context); - String url = agentUrl + String.format("/v1/card"); - A2AHttpClient.GetBuilder getBuilder = httpClient.createGet().url(url); + + HttpClient.GetRequestBuilder getBuilder = httpClient.get(agentPath + "/v1/card"); if (payloadAndHeaders.getHeaders() != null) { for (Map.Entry entry : payloadAndHeaders.getHeaders().entrySet()) { getBuilder.addHeader(entry.getKey(), entry.getValue()); } } - A2AHttpResponse response = getBuilder.get(); + CompletableFuture responseFut = getBuilder.send(); + HttpResponse response = responseFut.get(); + if (!response.success()) { throw RestErrorMapper.mapRestError(response); } @@ -332,7 +358,7 @@ public AgentCard getAgentCard(@Nullable ClientCallContext context) throws A2ACli agentCard = Utils.OBJECT_MAPPER.readValue(httpResponseBody, AgentCard.class); needsExtendedCard = false; return agentCard; - } catch (IOException | InterruptedException e) { + } catch (IOException | InterruptedException | ExecutionException e) { throw new A2AClientException("Failed to get authenticated extended agent card: " + e, e); } catch (A2AClientError e) { throw new A2AClientException("Failed to get agent card: " + e, e); @@ -344,21 +370,11 @@ public void close() { // no-op } - private PayloadAndHeaders applyInterceptors(String methodName, @Nullable MessageOrBuilder payload, - AgentCard agentCard, @Nullable ClientCallContext clientCallContext) { - PayloadAndHeaders payloadAndHeaders = new PayloadAndHeaders(payload, getHttpHeaders(clientCallContext)); - if (interceptors != null && !interceptors.isEmpty()) { - for (ClientCallInterceptor interceptor : interceptors) { - payloadAndHeaders = interceptor.intercept(methodName, payloadAndHeaders.getPayload(), - payloadAndHeaders.getHeaders(), agentCard, clientCallContext); - } - } - return payloadAndHeaders; - } + private String sendPostRequest(String path, PayloadAndHeaders payloadAndHeaders) throws IOException, InterruptedException, ExecutionException { + HttpClient.PostRequestBuilder builder = createPostBuilder(path, payloadAndHeaders); + CompletableFuture responseFut = builder.send(); - private String sendPostRequest(String url, PayloadAndHeaders payloadAndHeaders) throws IOException, InterruptedException { - A2AHttpClient.PostBuilder builder = createPostBuilder(url, payloadAndHeaders); - A2AHttpResponse response = builder.post(); + HttpResponse response = responseFut.get(); if (!response.success()) { log.fine("Error on POST processing " + JsonFormat.printer().print((MessageOrBuilder) payloadAndHeaders.getPayload())); throw RestErrorMapper.mapRestError(response); @@ -366,10 +382,9 @@ private String sendPostRequest(String url, PayloadAndHeaders payloadAndHeaders) return response.body(); } - private A2AHttpClient.PostBuilder createPostBuilder(String url, PayloadAndHeaders payloadAndHeaders) throws JsonProcessingException, InvalidProtocolBufferException { + private HttpClient.PostRequestBuilder createPostBuilder(String path, PayloadAndHeaders payloadAndHeaders) throws JsonProcessingException, InvalidProtocolBufferException { log.fine(JsonFormat.printer().print((MessageOrBuilder) payloadAndHeaders.getPayload())); - A2AHttpClient.PostBuilder postBuilder = httpClient.createPost() - .url(url) + HttpClient.PostRequestBuilder postBuilder = httpClient.post(agentPath + path) .addHeader("Content-Type", "application/json") .body(JsonFormat.printer().print((MessageOrBuilder) payloadAndHeaders.getPayload())); diff --git a/client/transport/rest/src/main/java/io/a2a/client/transport/rest/RestTransportConfig.java b/client/transport/rest/src/main/java/io/a2a/client/transport/rest/RestTransportConfig.java index d097b010..21b694ce 100644 --- a/client/transport/rest/src/main/java/io/a2a/client/transport/rest/RestTransportConfig.java +++ b/client/transport/rest/src/main/java/io/a2a/client/transport/rest/RestTransportConfig.java @@ -1,22 +1,23 @@ package io.a2a.client.transport.rest; -import io.a2a.client.http.A2AHttpClient; +import io.a2a.client.http.HttpClientBuilder; import io.a2a.client.transport.spi.ClientTransportConfig; -import org.jspecify.annotations.Nullable; +import io.a2a.util.Assert; public class RestTransportConfig extends ClientTransportConfig { - private final @Nullable A2AHttpClient httpClient; + private final HttpClientBuilder httpClientBuilder; - public RestTransportConfig() { - this.httpClient = null; + public RestTransportConfig(HttpClientBuilder httpClientBuilder) { + Assert.checkNotNullParam("httpClientBuilder", httpClientBuilder); + this.httpClientBuilder = httpClientBuilder; } - public RestTransportConfig(A2AHttpClient httpClient) { - this.httpClient = httpClient; + public RestTransportConfig() { + this.httpClientBuilder = HttpClientBuilder.DEFAULT_FACTORY; } - public @Nullable A2AHttpClient getHttpClient() { - return httpClient; + public HttpClientBuilder getHttpClientBuilder() { + return httpClientBuilder; } } \ No newline at end of file diff --git a/client/transport/rest/src/main/java/io/a2a/client/transport/rest/RestTransportConfigBuilder.java b/client/transport/rest/src/main/java/io/a2a/client/transport/rest/RestTransportConfigBuilder.java index 68150f18..edcbcd1c 100644 --- a/client/transport/rest/src/main/java/io/a2a/client/transport/rest/RestTransportConfigBuilder.java +++ b/client/transport/rest/src/main/java/io/a2a/client/transport/rest/RestTransportConfigBuilder.java @@ -1,27 +1,24 @@ package io.a2a.client.transport.rest; -import io.a2a.client.http.A2AHttpClient; -import io.a2a.client.http.JdkA2AHttpClient; +import io.a2a.client.http.HttpClientBuilder; import io.a2a.client.transport.spi.ClientTransportConfigBuilder; -import org.jspecify.annotations.Nullable; + +import io.a2a.util.Assert; public class RestTransportConfigBuilder extends ClientTransportConfigBuilder { - private @Nullable A2AHttpClient httpClient; + private HttpClientBuilder httpClientBuilder = io.a2a.client.http.HttpClientBuilder.DEFAULT_FACTORY; + + public RestTransportConfigBuilder httpClientBuilder(HttpClientBuilder httpClientBuilder) { + Assert.checkNotNullParam("httpClientBuilder", httpClientBuilder); + this.httpClientBuilder = httpClientBuilder; - public RestTransportConfigBuilder httpClient(A2AHttpClient httpClient) { - this.httpClient = httpClient; return this; } @Override public RestTransportConfig build() { - // No HTTP client provided, fallback to the default one (JDK-based implementation) - if (httpClient == null) { - httpClient = new JdkA2AHttpClient(); - } - - RestTransportConfig config = new RestTransportConfig(httpClient); + RestTransportConfig config = new RestTransportConfig(this.httpClientBuilder); config.setInterceptors(this.interceptors); return config; } diff --git a/client/transport/rest/src/main/java/io/a2a/client/transport/rest/RestTransportProvider.java b/client/transport/rest/src/main/java/io/a2a/client/transport/rest/RestTransportProvider.java index 99d15596..cd03086c 100644 --- a/client/transport/rest/src/main/java/io/a2a/client/transport/rest/RestTransportProvider.java +++ b/client/transport/rest/src/main/java/io/a2a/client/transport/rest/RestTransportProvider.java @@ -1,6 +1,7 @@ package io.a2a.client.transport.rest; -import io.a2a.client.http.JdkA2AHttpClient; +import io.a2a.client.http.HttpClient; +import io.a2a.client.http.HttpClientBuilder; import io.a2a.client.transport.spi.ClientTransportProvider; import io.a2a.spec.A2AClientException; import io.a2a.spec.AgentCard; @@ -14,12 +15,20 @@ public String getTransportProtocol() { } @Override - public RestTransport create(RestTransportConfig clientTransportConfig, AgentCard agentCard, String agentUrl) throws A2AClientException { - RestTransportConfig transportConfig = clientTransportConfig; - if (transportConfig == null) { - transportConfig = new RestTransportConfig(new JdkA2AHttpClient()); + public RestTransport create(RestTransportConfig transportConfig, AgentCard agentCard, String agentUrl) throws A2AClientException { + if (transportConfig == null) { + transportConfig = new RestTransportConfig(); + } + + HttpClientBuilder httpClientBuilder = transportConfig.getHttpClientBuilder(); + + try { + final HttpClient httpClient = httpClientBuilder.create(agentUrl); + + return new RestTransport(httpClient, agentCard, agentUrl, transportConfig.getInterceptors()); + } catch (Exception ex) { + throw new A2AClientException("Failed to create REST transport", ex); } - return new RestTransport(clientTransportConfig.getHttpClient(), agentCard, agentUrl, transportConfig.getInterceptors()); } @Override diff --git a/client/transport/rest/src/main/java/io/a2a/client/transport/rest/sse/RestSSEEventListener.java b/client/transport/rest/src/main/java/io/a2a/client/transport/rest/sse/RestSSEEventListener.java index d0b130ee..2afd586e 100644 --- a/client/transport/rest/src/main/java/io/a2a/client/transport/rest/sse/RestSSEEventListener.java +++ b/client/transport/rest/src/main/java/io/a2a/client/transport/rest/sse/RestSSEEventListener.java @@ -1,20 +1,19 @@ package io.a2a.client.transport.rest.sse; -import static io.a2a.grpc.StreamResponse.PayloadCase.ARTIFACT_UPDATE; -import static io.a2a.grpc.StreamResponse.PayloadCase.MSG; -import static io.a2a.grpc.StreamResponse.PayloadCase.STATUS_UPDATE; -import static io.a2a.grpc.StreamResponse.PayloadCase.TASK; - import java.util.concurrent.Future; import java.util.function.Consumer; import java.util.logging.Logger; import com.google.protobuf.InvalidProtocolBufferException; import com.google.protobuf.util.JsonFormat; +import io.a2a.client.http.HttpResponse; +import io.a2a.client.http.sse.DataEvent; +import io.a2a.client.http.sse.Event; import io.a2a.client.transport.rest.RestErrorMapper; import io.a2a.grpc.StreamResponse; import io.a2a.grpc.utils.ProtoUtils; import io.a2a.spec.StreamingEventKind; +import io.a2a.spec.TaskStatusUpdateEvent; import org.jspecify.annotations.Nullable; public class RestSSEEventListener { @@ -29,18 +28,21 @@ public RestSSEEventListener(Consumer eventHandler, this.errorHandler = errorHandler; } - public void onMessage(String message, @Nullable Future completableFuture) { - try { - log.fine("Streaming message received: " + message); - io.a2a.grpc.StreamResponse.Builder builder = io.a2a.grpc.StreamResponse.newBuilder(); - JsonFormat.parser().merge(message, builder); - handleMessage(builder.build()); - } catch (InvalidProtocolBufferException e) { - errorHandler.accept(RestErrorMapper.mapRestError(message, 500)); + public void onMessage(Event event, @Nullable Future completableFuture) { + log.fine("Streaming message received: " + event); + + if (event instanceof DataEvent) { + try { + io.a2a.grpc.StreamResponse.Builder builder = io.a2a.grpc.StreamResponse.newBuilder(); + JsonFormat.parser().merge(((DataEvent) event).getData(), builder); + handleMessage(builder.build(), completableFuture); + } catch (InvalidProtocolBufferException e) { + errorHandler.accept(RestErrorMapper.mapRestError(((DataEvent) event).getData(), 500)); + } } } - public void onError(Throwable throwable, @Nullable Future future) { + public void onError(Throwable throwable, @Nullable Future future) { if (errorHandler != null) { errorHandler.accept(throwable); } @@ -49,15 +51,19 @@ public void onError(Throwable throwable, @Nullable Future future) { } } - private void handleMessage(StreamResponse response) { + private void handleMessage(StreamResponse response, @Nullable Future future) { StreamingEventKind event; switch (response.getPayloadCase()) { case MSG -> event = ProtoUtils.FromProto.message(response.getMsg()); case TASK -> event = ProtoUtils.FromProto.task(response.getTask()); - case STATUS_UPDATE -> + case STATUS_UPDATE -> { event = ProtoUtils.FromProto.taskStatusUpdateEvent(response.getStatusUpdate()); + if (((TaskStatusUpdateEvent) event).isFinal() && future != null) { + future.cancel(true); // close SSE channel + } + } case ARTIFACT_UPDATE -> event = ProtoUtils.FromProto.taskArtifactUpdateEvent(response.getArtifactUpdate()); default -> { @@ -68,5 +74,4 @@ private void handleMessage(StreamResponse response) { } eventHandler.accept(event); } - } diff --git a/client/transport/rest/src/test/java/io/a2a/client/transport/rest/RestTransportTest.java b/client/transport/rest/src/test/java/io/a2a/client/transport/rest/RestTransportTest.java index a296553c..ae938cb4 100644 --- a/client/transport/rest/src/test/java/io/a2a/client/transport/rest/RestTransportTest.java +++ b/client/transport/rest/src/test/java/io/a2a/client/transport/rest/RestTransportTest.java @@ -1,6 +1,5 @@ package io.a2a.client.transport.rest; - import static io.a2a.client.transport.rest.JsonRestMessages.CANCEL_TASK_TEST_REQUEST; import static io.a2a.client.transport.rest.JsonRestMessages.CANCEL_TASK_TEST_RESPONSE; import static io.a2a.client.transport.rest.JsonRestMessages.GET_TASK_PUSH_NOTIFICATION_CONFIG_TEST_RESPONSE; @@ -22,9 +21,6 @@ import static org.mockserver.model.HttpResponse.response; import io.a2a.client.transport.spi.interceptors.ClientCallContext; -import io.a2a.spec.AgentCapabilities; -import io.a2a.spec.AgentCard; -import io.a2a.spec.AgentSkill; import io.a2a.spec.Artifact; import io.a2a.spec.DeleteTaskPushNotificationConfigParams; import io.a2a.spec.EventKind; @@ -67,28 +63,7 @@ public class RestTransportTest { private static final Logger log = Logger.getLogger(RestTransportTest.class.getName()); private ClientAndServer server; - private static final AgentCard CARD = new AgentCard.Builder() - .name("Hello World Agent") - .description("Just a hello world agent") - .url("http://localhost:4001") - .version("1.0.0") - .documentationUrl("http://example.com/docs") - .capabilities(new AgentCapabilities.Builder() - .streaming(true) - .pushNotifications(true) - .stateTransitionHistory(true) - .build()) - .defaultInputModes(Collections.singletonList("text")) - .defaultOutputModes(Collections.singletonList("text")) - .skills(Collections.singletonList(new AgentSkill.Builder() - .id("hello_world") - .name("Returns hello world") - .description("just returns hello world") - .tags(Collections.singletonList("hello world")) - .examples(List.of("hi", "hello world")) - .build())) - .protocolVersion("0.3.0") - .build(); + private static final String AGENT_URL = "http://localhost:4001"; @BeforeEach public void setUp() throws IOException { @@ -129,7 +104,7 @@ public void testSendMessage() throws Exception { MessageSendParams messageSendParams = new MessageSendParams(message, null, null); ClientCallContext context = null; - RestTransport instance = new RestTransport(CARD); + RestTransport instance = new RestTransport(AGENT_URL); EventKind result = instance.sendMessage(messageSendParams, context); assertEquals("task", result.getKind()); Task task = (Task) result; @@ -170,7 +145,7 @@ public void testCancelTask() throws Exception { .withBody(CANCEL_TASK_TEST_RESPONSE) ); ClientCallContext context = null; - RestTransport instance = new RestTransport(CARD); + RestTransport instance = new RestTransport(AGENT_URL); Task task = instance.cancelTask(new TaskIdParams("de38c76d-d54c-436c-8b9f-4c2703648d64", new HashMap<>()), context); assertEquals("de38c76d-d54c-436c-8b9f-4c2703648d64", task.getId()); @@ -196,7 +171,7 @@ public void testGetTask() throws Exception { ); ClientCallContext context = null; TaskQueryParams request = new TaskQueryParams("de38c76d-d54c-436c-8b9f-4c2703648d64", 10); - RestTransport instance = new RestTransport(CARD); + RestTransport instance = new RestTransport(AGENT_URL); Task task = instance.getTask(request, context); assertEquals("de38c76d-d54c-436c-8b9f-4c2703648d64", task.getId()); assertEquals(TaskState.COMPLETED, task.getStatus().state()); @@ -248,7 +223,7 @@ public void testSendMessageStreaming() throws Exception { .withBody(SEND_MESSAGE_STREAMING_TEST_RESPONSE) ); - RestTransport client = new RestTransport(CARD); + RestTransport client = new RestTransport(AGENT_URL); Message message = new Message.Builder() .role(Message.Role.USER) .parts(Collections.singletonList(new TextPart("tell me some jokes"))) @@ -298,7 +273,7 @@ public void testSetTaskPushNotificationConfiguration() throws Exception { .withStatusCode(200) .withBody(SET_TASK_PUSH_NOTIFICATION_CONFIG_TEST_RESPONSE) ); - RestTransport client = new RestTransport(CARD); + RestTransport client = new RestTransport(AGENT_URL); TaskPushNotificationConfig pushedConfig = new TaskPushNotificationConfig( "de38c76d-d54c-436c-8b9f-4c2703648d64", new PushNotificationConfig.Builder() @@ -331,7 +306,7 @@ public void testGetTaskPushNotificationConfiguration() throws Exception { .withBody(GET_TASK_PUSH_NOTIFICATION_CONFIG_TEST_RESPONSE) ); - RestTransport client = new RestTransport(CARD); + RestTransport client = new RestTransport(AGENT_URL); TaskPushNotificationConfig taskPushNotificationConfig = client.getTaskPushNotificationConfiguration( new GetTaskPushNotificationConfigParams("de38c76d-d54c-436c-8b9f-4c2703648d64", "10", new HashMap<>()), null); @@ -359,7 +334,7 @@ public void testListTaskPushNotificationConfigurations() throws Exception { .withBody(LIST_TASK_PUSH_NOTIFICATION_CONFIG_TEST_RESPONSE) ); - RestTransport client = new RestTransport(CARD); + RestTransport client = new RestTransport(AGENT_URL); List taskPushNotificationConfigs = client.listTaskPushNotificationConfigurations( new ListTaskPushNotificationConfigParams("de38c76d-d54c-436c-8b9f-4c2703648d64", new HashMap<>()), null); assertEquals(2, taskPushNotificationConfigs.size()); @@ -395,7 +370,7 @@ public void testDeleteTaskPushNotificationConfigurations() throws Exception { .withStatusCode(200) ); ClientCallContext context = null; - RestTransport instance = new RestTransport(CARD); + RestTransport instance = new RestTransport(AGENT_URL); instance.deleteTaskPushNotificationConfigurations(new DeleteTaskPushNotificationConfigParams("de38c76d-d54c-436c-8b9f-4c2703648d64", "10"), context); } @@ -418,7 +393,7 @@ public void testResubscribe() throws Exception { .withBody(TASK_RESUBSCRIPTION_REQUEST_TEST_RESPONSE) ); - RestTransport client = new RestTransport(CARD); + RestTransport client = new RestTransport(AGENT_URL); TaskIdParams taskIdParams = new TaskIdParams("task-1234"); AtomicReference receivedEvent = new AtomicReference<>(); diff --git a/client/transport/spi/src/main/java/io/a2a/client/transport/spi/AbstractClientTransport.java b/client/transport/spi/src/main/java/io/a2a/client/transport/spi/AbstractClientTransport.java new file mode 100644 index 00000000..fff6f284 --- /dev/null +++ b/client/transport/spi/src/main/java/io/a2a/client/transport/spi/AbstractClientTransport.java @@ -0,0 +1,31 @@ +package io.a2a.client.transport.spi; + +import io.a2a.client.transport.spi.interceptors.ClientCallContext; +import io.a2a.client.transport.spi.interceptors.ClientCallInterceptor; +import io.a2a.client.transport.spi.interceptors.PayloadAndHeaders; +import io.a2a.spec.AgentCard; +import org.jspecify.annotations.Nullable; + +import java.util.List; + +public abstract class AbstractClientTransport implements ClientTransport { + + private final @Nullable List interceptors; + + public AbstractClientTransport(@Nullable List interceptors) { + this.interceptors = interceptors; + } + + protected PayloadAndHeaders applyInterceptors(String methodName, @Nullable Object payload, + @Nullable AgentCard agentCard, @Nullable ClientCallContext clientCallContext) { + PayloadAndHeaders payloadAndHeaders = new PayloadAndHeaders(payload, + clientCallContext != null ? clientCallContext.getHeaders() : null); + if (interceptors != null && ! interceptors.isEmpty()) { + for (ClientCallInterceptor interceptor : interceptors) { + payloadAndHeaders = interceptor.intercept(methodName, payloadAndHeaders.getPayload(), + payloadAndHeaders.getHeaders(), agentCard, clientCallContext); + } + } + return payloadAndHeaders; + } +} diff --git a/client/transport/spi/src/main/java/io/a2a/client/transport/spi/interceptors/ClientCallInterceptor.java b/client/transport/spi/src/main/java/io/a2a/client/transport/spi/interceptors/ClientCallInterceptor.java index 41141298..b8a8de79 100644 --- a/client/transport/spi/src/main/java/io/a2a/client/transport/spi/interceptors/ClientCallInterceptor.java +++ b/client/transport/spi/src/main/java/io/a2a/client/transport/spi/interceptors/ClientCallInterceptor.java @@ -23,5 +23,5 @@ public abstract class ClientCallInterceptor { * @return the potentially modified payload and headers */ public abstract PayloadAndHeaders intercept(String methodName, @Nullable Object payload, Map headers, - AgentCard agentCard, @Nullable ClientCallContext clientCallContext); + @Nullable AgentCard agentCard, @Nullable ClientCallContext clientCallContext); } diff --git a/client/transport/spi/src/main/java/io/a2a/client/transport/spi/interceptors/PayloadAndHeaders.java b/client/transport/spi/src/main/java/io/a2a/client/transport/spi/interceptors/PayloadAndHeaders.java index 4783cb71..816ad3e5 100644 --- a/client/transport/spi/src/main/java/io/a2a/client/transport/spi/interceptors/PayloadAndHeaders.java +++ b/client/transport/spi/src/main/java/io/a2a/client/transport/spi/interceptors/PayloadAndHeaders.java @@ -10,7 +10,7 @@ public class PayloadAndHeaders { private final @Nullable Object payload; private final Map headers; - public PayloadAndHeaders(@Nullable Object payload, Map headers) { + public PayloadAndHeaders(@Nullable Object payload, @Nullable Map headers) { this.payload = payload; this.headers = headers == null ? Collections.emptyMap() : new HashMap<>(headers); } diff --git a/client/transport/spi/src/main/java/io/a2a/client/transport/spi/interceptors/auth/AuthInterceptor.java b/client/transport/spi/src/main/java/io/a2a/client/transport/spi/interceptors/auth/AuthInterceptor.java index d2f2a576..8fda4ca4 100644 --- a/client/transport/spi/src/main/java/io/a2a/client/transport/spi/interceptors/auth/AuthInterceptor.java +++ b/client/transport/spi/src/main/java/io/a2a/client/transport/spi/interceptors/auth/AuthInterceptor.java @@ -33,7 +33,7 @@ public AuthInterceptor(final CredentialService credentialService) { @Override public PayloadAndHeaders intercept(String methodName, @Nullable Object payload, Map headers, - AgentCard agentCard, @Nullable ClientCallContext clientCallContext) { + @Nullable AgentCard agentCard, @Nullable ClientCallContext clientCallContext) { Map updatedHeaders = new HashMap<>(headers == null ? new HashMap<>() : headers); if (agentCard == null || agentCard.security() == null || agentCard.securitySchemes() == null) { return new PayloadAndHeaders(payload, updatedHeaders); diff --git a/extras/README.md b/extras/README.md index 3f85e4f9..19807a8f 100644 --- a/extras/README.md +++ b/extras/README.md @@ -6,4 +6,5 @@ Please see the README's of each child directory for more details. [`task-store-database-jpa`](./task-store-database-jpa/README.md) - Replaces the default `InMemoryTaskStore` with a `TaskStore` backed by a RDBMS. It uses JPA to interact with the RDBMS. [`push-notification-config-store-database-jpa`](./push-notification-config-store-database-jpa/README.md) - Replaces the default `InMemoryPushNotificationConfigStore` with a `PushNotificationConfigStore` backed by a RDBMS. It uses JPA to interact with the RDBMS. -[`queue-manager-replicated`](./queue-manager-replicated/README.md) - Replaces the default `InMemoryQueueManager` with a `QueueManager` supporting replication to other A2A servers implementing the same agent. You can write your own `ReplicationStrategy`, or use the provided `MicroProfile Reactive Messaging implementation`. \ No newline at end of file +[`queue-manager-replicated`](./queue-manager-replicated/README.md) - Replaces the default `InMemoryQueueManager` with a `QueueManager` supporting replication to other A2A servers implementing the same agent. You can write your own `ReplicationStrategy`, or use the provided `MicroProfile Reactive Messaging implementation`. +[`vertx-http-client`](./vertx-http-client/README.md) - Replaces the default `HttpClient` JDK implementation with a http-client implementation backed by Vertx, better suited for Quarkus applications. \ No newline at end of file diff --git a/extras/http-client-vertx/README.md b/extras/http-client-vertx/README.md new file mode 100644 index 00000000..f18b90d3 --- /dev/null +++ b/extras/http-client-vertx/README.md @@ -0,0 +1,76 @@ +# A2A Java SDK - Vertx HTTP Client + +This module provides an HTTP client implementation of the `HttpClient` interface that relies on Vertx for the HTTP transport communication. + +By default, the A2A client is relying on the default JDK HttpClient implementation. While this one is convenient for most of use-cases, it may still +be relevant to switch to the Vertx based implementation, especially when your current code is already relying on Vertx or if your A2A server is based on Quarkus which, itself, heavily relies on Vertx. + +## Quick Start + +This section will get you up and running quickly with a `Client` using the `VertxHttpClient` implementation. + +### 1. Add Dependency + +Add this module to your project's `pom.xml`: + +```xml + + io.github.a2asdk + a2a-java-extras-http-client-vertx + ${a2a.version} + +``` + +### 2. Configure Client + +##### JSON-RPC Transport Configuration + +For the JSON-RPC transport, to use the default `JdkHttpClient`, provide a `JSONRPCTransportConfig` created with its default constructor. + +To use a custom HTTP client implementation, simply create a `JSONRPCTransportConfig` as follows: + +```java +import io.a2a.client.http.vertx.VertxHttpClientBuilder; + +// Create a Vertx HTTP client +HttpClientBuilder vertxHttpClientBuilder = new VertxHttpClientBuilder(); + +// Configure the client settings +ClientConfig clientConfig = new ClientConfig.Builder() + .setAcceptedOutputModes(List.of("text")) + .build(); + +Client client = Client + .builder(agentCard) + .clientConfig(clientConfig) + .withTransport(JSONRPCTransport.class, new JSONRPCTransportConfig(vertxHttpClientBuilder)) + .build(); +``` + +## Configuration Options + +This implementation allows to pass the Vertx context you want to rely on, but also the HTTPClientOptions, in case +you want / need to provide some extended configuration's properties such as a better of management of SSL Context, or an HTTP proxy. + +```java +import io.a2a.client.http.vertx.VertxHttpClientBuilder; +import io.vertx.core.Vertx; +import io.vertx.core.http.HttpClientOptions; +import io.vertx.core.net.ProxyOptions; + +// Create a Vertx HTTP client +HttpClientBuilder vertxHttpClientBuilder = new VertxHttpClientBuilder() + .vertx(Vertx.vertx()) + .options(new HttpClientOptions().setProxyOptions(new ProxyOptions().setHost("host").setPort("1234"))); + + // Configure the client settings + ClientConfig clientConfig = new ClientConfig.Builder() + .setAcceptedOutputModes(List.of("text")) + .build(); + + Client client = Client + .builder(agentCard) + .clientConfig(clientConfig) + .withTransport(JSONRPCTransport.class, new JSONRPCTransportConfig(vertxHttpClientBuilder)) + .build(); +``` diff --git a/extras/http-client-vertx/pom.xml b/extras/http-client-vertx/pom.xml new file mode 100644 index 00000000..656f1943 --- /dev/null +++ b/extras/http-client-vertx/pom.xml @@ -0,0 +1,57 @@ + + + 4.0.0 + + + io.github.a2asdk + a2a-java-sdk-parent + 0.3.0.Beta3-SNAPSHOT + ../../pom.xml + + a2a-java-extras-http-client-vertx + + jar + + Java A2A Extras: Vertx HTTP Client + Java SDK for the Agent2Agent Protocol (A2A) - Extras - Vertx HTTP Client + + + + ${project.groupId} + a2a-java-sdk-http-client + + + + ${project.groupId} + a2a-java-sdk-client + test + + + + ${project.groupId} + a2a-java-sdk-tests-client-common + test-jar + test + + + + io.vertx + vertx-core + + + + org.junit.jupiter + junit-jupiter-api + test + + + + org.wiremock + wiremock + 3.13.1 + test + + + \ No newline at end of file diff --git a/extras/http-client-vertx/src/main/java/io/a2a/client/http/vertx/VertxHttpClient.java b/extras/http-client-vertx/src/main/java/io/a2a/client/http/vertx/VertxHttpClient.java new file mode 100644 index 00000000..62284dc6 --- /dev/null +++ b/extras/http-client-vertx/src/main/java/io/a2a/client/http/vertx/VertxHttpClient.java @@ -0,0 +1,216 @@ +package io.a2a.client.http.vertx; + +import io.a2a.client.http.HttpClient; +import io.a2a.client.http.HttpResponse; +import io.a2a.client.http.sse.Event; +import io.a2a.client.http.vertx.sse.SSEHandler; +import io.a2a.common.A2AErrorMessages; +import io.netty.handler.codec.http.HttpHeaderNames; +import io.netty.handler.codec.http.HttpHeaderValues; +import io.vertx.core.*; +import io.vertx.core.http.*; + +import java.io.IOException; +import java.net.*; +import java.util.HashMap; +import java.util.Map; +import java.util.Optional; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; +import java.util.function.Consumer; +import java.util.function.Function; + +import static java.net.HttpURLConnection.HTTP_FORBIDDEN; +import static java.net.HttpURLConnection.HTTP_UNAUTHORIZED; + +public class VertxHttpClient implements HttpClient { + + private final io.vertx.core.http.HttpClient client; + + private final Vertx vertx; + + VertxHttpClient(String baseUrl, Vertx vertx, HttpClientOptions options) { + this.vertx = vertx; + this.client = initClient(baseUrl, options); + } + + private io.vertx.core.http.HttpClient initClient(String baseUrl, HttpClientOptions options) { + URL targetUrl = buildUrl(baseUrl); + + return this.vertx.createHttpClient(options + .setDefaultHost(targetUrl.getHost()) + .setDefaultPort(targetUrl.getPort() != -1 ? targetUrl.getPort() : targetUrl.getDefaultPort()) + .setSsl(isSecureProtocol(targetUrl.getProtocol()))); + } + + @Override + public GetRequestBuilder get(String path) { + return new VertxGetRequestBuilder(path); + } + + @Override + public PostRequestBuilder post(String path) { + return new VertxPostRequestBuilder(path); + } + + @Override + public DeleteRequestBuilder delete(String path) { + return new VertxDeleteRequestBuilder(path); + } + + private static final URLStreamHandler URL_HANDLER = new URLStreamHandler() { + protected URLConnection openConnection(URL u) { + return null; + } + }; + + private static URL buildUrl(String uri) { + try { + return new URL(null, uri, URL_HANDLER); + } catch (MalformedURLException var2) { + throw new IllegalArgumentException("URI [" + uri + "] is not valid"); + } + } + + private static boolean isSecureProtocol(String protocol) { + return protocol.charAt(protocol.length() - 1) == 's' && protocol.length() > 2; + } + + private abstract class VertxRequestBuilder> implements RequestBuilder { + protected final Future request; + protected final Map headers = new HashMap<>(); + + public VertxRequestBuilder(String path, HttpMethod method) { + this.request = client.request(method, path); + } + + @Override + public T addHeader(String name, String value) { + headers.put(name, value); + return self(); + } + + @Override + public T addHeaders(Map headers) { + if (headers != null && ! headers.isEmpty()) { + for (Map.Entry entry : headers.entrySet()) { + addHeader(entry.getKey(), entry.getValue()); + } + } + return self(); + } + + @SuppressWarnings("unchecked") + T self() { + return (T) this; + } + + protected Future sendRequest() { + return sendRequest(Optional.empty()); + } + + protected Future sendRequest(Optional body) { + return request + .compose(new Function>() { + @Override + public Future apply(HttpClientRequest request) { + // Prepare the request + request.headers().addAll(headers); + + if (body.isPresent()) { + return request.send(body.get()); + } else { + return request.send(); + } + } + }); + } + + @Override + public CompletableFuture send() { + return sendRequest() + .compose(RESPONSE_MAPPER) + .toCompletionStage() + .toCompletableFuture(); + } + } + + private class VertxGetRequestBuilder extends VertxRequestBuilder implements GetRequestBuilder { + + public VertxGetRequestBuilder(String path) { + super(path, HttpMethod.GET); + } + } + + private class VertxDeleteRequestBuilder extends VertxRequestBuilder implements DeleteRequestBuilder { + + public VertxDeleteRequestBuilder(String path) { + super(path, HttpMethod.DELETE); + } + } + + private class VertxPostRequestBuilder extends VertxRequestBuilder implements PostRequestBuilder { + String body = ""; + + public VertxPostRequestBuilder(String path) { + super(path, HttpMethod.POST); + } + + @Override + public PostRequestBuilder body(String body) { + this.body = body; + return this; + } + + @Override + public CompletableFuture send() { + return sendRequest(Optional.of(this.body)) + .compose(RESPONSE_MAPPER) + .toCompletionStage() + .toCompletableFuture(); + } + } + + private final Function> RESPONSE_MAPPER = response -> { + if (response.statusCode() == HTTP_UNAUTHORIZED) { + return Future.failedFuture(new IOException(A2AErrorMessages.AUTHENTICATION_FAILED)); + } else if (response.statusCode() == HTTP_FORBIDDEN) { + return Future.failedFuture(new IOException(A2AErrorMessages.AUTHORIZATION_FAILED)); + } + + return Future.succeededFuture(new VertxHttpResponse(response)); + }; + + private record VertxHttpResponse(HttpClientResponse response)implements HttpResponse { + + @Override + public int statusCode() { + return response.statusCode(); + } + + @Override + public String body() { + try { + return response.body().toCompletionStage().toCompletableFuture().get().toString(); + + } catch (InterruptedException e) { + throw new RuntimeException(e); + } catch (ExecutionException e) { + throw new RuntimeException(e); + } + } + + @Override + public void bodyAsSse(Consumer eventConsumer, Consumer errorConsumer) { + String contentType = response.headers().get(HttpHeaderNames.CONTENT_TYPE.toString()); + + if (contentType != null && HttpHeaderValues.TEXT_EVENT_STREAM.contentEqualsIgnoreCase(contentType)) { + final SSEHandler handler = new SSEHandler(eventConsumer); + + response.handler(handler).exceptionHandler(errorConsumer::accept); + } else { + throw new IllegalStateException("Response is not an event-stream response."); + } + } + } +} diff --git a/extras/http-client-vertx/src/main/java/io/a2a/client/http/vertx/VertxHttpClientBuilder.java b/extras/http-client-vertx/src/main/java/io/a2a/client/http/vertx/VertxHttpClientBuilder.java new file mode 100644 index 00000000..c727612b --- /dev/null +++ b/extras/http-client-vertx/src/main/java/io/a2a/client/http/vertx/VertxHttpClientBuilder.java @@ -0,0 +1,30 @@ +package io.a2a.client.http.vertx; + +import io.a2a.client.http.HttpClient; +import io.a2a.client.http.HttpClientBuilder; +import io.vertx.core.Vertx; +import io.vertx.core.http.HttpClientOptions; + +public class VertxHttpClientBuilder implements HttpClientBuilder { + + private Vertx vertx; + + private HttpClientOptions options; + + public VertxHttpClientBuilder vertx(Vertx vertx) { + this.vertx = vertx; + return this; + } + + public VertxHttpClientBuilder options(HttpClientOptions options) { + this.options = options; + return this; + } + + @Override + public HttpClient create(String url) { + return new VertxHttpClient(url, + vertx != null ? vertx : Vertx.vertx(), + options != null ? options : new HttpClientOptions()); + } +} diff --git a/extras/http-client-vertx/src/main/java/io/a2a/client/http/vertx/sse/SSEHandler.java b/extras/http-client-vertx/src/main/java/io/a2a/client/http/vertx/sse/SSEHandler.java new file mode 100644 index 00000000..d9ee3df8 --- /dev/null +++ b/extras/http-client-vertx/src/main/java/io/a2a/client/http/vertx/sse/SSEHandler.java @@ -0,0 +1,124 @@ +package io.a2a.client.http.vertx.sse; + +import io.a2a.client.http.sse.CommentEvent; +import io.a2a.client.http.sse.DataEvent; +import io.a2a.client.http.sse.Event; +import io.vertx.core.Handler; +import io.vertx.core.buffer.Buffer; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.List; +import java.util.function.Consumer; + +public class SSEHandler implements Handler { + + + private static final Logger LOG = LoggerFactory.getLogger(SSEHandler.class); + + private static final String UTF8_BOM = "\uFEFF"; + + private static final String DEFAULT_EVENT_NAME = "message"; + + private String currentEventName = DEFAULT_EVENT_NAME; + private final StringBuilder dataBuffer = new StringBuilder(); + + private String lastEventId = ""; + + private final Consumer eventConsumer; + + public SSEHandler(Consumer eventConsumer) { + this.eventConsumer = eventConsumer; + } + + private void handleFieldValue(String fieldName, String value) { + switch (fieldName) { + case "event": + currentEventName = value; + break; + case "data": + dataBuffer.append(value).append("\n"); + break; + case "id": + if (!value.contains("\0")) { + lastEventId = value; + } + break; + case "retry": + // ignored + break; + } + } + + private String stripLeadingSpaceIfPresent(String field) { + if (field.charAt(0) == ' ') { + return field.substring(1); + } + return field; + } + + private String removeLeadingBom(String input) { + if (input.startsWith(UTF8_BOM)) { + return input.substring(UTF8_BOM.length()); + } + return input; + } + + private String removeTrailingNewline(String input) { + if (input.endsWith("\n")) { + return input.substring(0, input.length() - 1); + } + return input; + } + + private Buffer buffer = Buffer.buffer(); + + @Override + public void handle(Buffer chunk) { + buffer.appendBuffer(chunk); + int separatorIndex; + // The separator for events is a double newline + String separator = "\n\n"; + while ((separatorIndex = buffer.toString().indexOf(separator)) != -1) { + Buffer eventData = buffer.getBuffer(0, separatorIndex); + parse(eventData.toString()); + buffer = buffer.getBuffer(separatorIndex + separator.length(), buffer.length()); + } + } + + private void parse(String input) { + String[] parts = input.split("\n"); + + for (String part : parts) { + LOG.debug("got line `{}`", part); + String line = removeTrailingNewline(removeLeadingBom(part)); + + if (line.startsWith(":")) { + eventConsumer.accept(new CommentEvent(line.substring(1).trim())); + } else if (line.contains(":")) { + List lineParts = List.of(line.split(":", 2)); + if (lineParts.size() == 2) { + handleFieldValue(lineParts.get(0), stripLeadingSpaceIfPresent(lineParts.get(1))); + } + } else { + handleFieldValue(line, ""); + } + } + + LOG.debug( + "broadcasting new event named {} lastEventId is {}", + currentEventName, + lastEventId + ); + + if (!dataBuffer.isEmpty()) { + // Remove trailing newline + dataBuffer.setLength(dataBuffer.length() - 1); + eventConsumer.accept(new DataEvent(currentEventName, dataBuffer.toString(), lastEventId)); + } + + // reset + dataBuffer.setLength(0); + currentEventName = DEFAULT_EVENT_NAME; + } +} diff --git a/extras/http-client-vertx/src/test/java/io/a2a/client/http/vertx/ClientBuilderTest.java b/extras/http-client-vertx/src/test/java/io/a2a/client/http/vertx/ClientBuilderTest.java new file mode 100644 index 00000000..68a8399d --- /dev/null +++ b/extras/http-client-vertx/src/test/java/io/a2a/client/http/vertx/ClientBuilderTest.java @@ -0,0 +1,62 @@ +package io.a2a.client.http.vertx; + +import io.a2a.client.Client; +import io.a2a.client.config.ClientConfig; +import io.a2a.client.transport.jsonrpc.JSONRPCTransport; +import io.a2a.client.transport.jsonrpc.JSONRPCTransportConfigBuilder; +import io.a2a.spec.A2AClientException; +import io.a2a.spec.AgentCapabilities; +import io.a2a.spec.AgentCard; +import io.a2a.spec.AgentInterface; +import io.a2a.spec.AgentSkill; +import io.a2a.spec.TransportProtocol; +import io.vertx.core.Vertx; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; + +import java.util.Collections; +import java.util.List; + +/** + * The purpose of this one is to make sure that the Vertx http implementation can be integrated into + * the Client builder when creating a new instance of the Client. + */ +public class ClientBuilderTest { + + private final AgentCard card = new AgentCard.Builder() + .name("Hello World Agent") + .description("Just a hello world agent") + .url("http://localhost:9999") + .version("1.0.0") + .documentationUrl("http://example.com/docs") + .capabilities(new AgentCapabilities.Builder() + .streaming(true) + .pushNotifications(true) + .stateTransitionHistory(true) + .build()) + .defaultInputModes(Collections.singletonList("text")) + .defaultOutputModes(Collections.singletonList("text")) + .skills(Collections.singletonList(new AgentSkill.Builder() + .id("hello_world") + .name("Returns hello world") + .description("just returns hello world") + .tags(Collections.singletonList("hello world")) + .examples(List.of("hi", "hello world")) + .build())) + .protocolVersion("0.3.0") + .additionalInterfaces(List.of( + new AgentInterface(TransportProtocol.JSONRPC.asString(), "http://localhost:9999"))) + .build(); + + @Test + public void shouldCreateJSONRPCClient() throws A2AClientException { + Client client = Client + .builder(card) + .clientConfig(new ClientConfig.Builder().build()) + .withTransport(JSONRPCTransport.class, new JSONRPCTransportConfigBuilder() + .httpClientBuilder(new VertxHttpClientBuilder().vertx(Vertx.vertx()))) + .build(); + + Assertions.assertNotNull(client); + } +} diff --git a/extras/http-client-vertx/src/test/java/io/a2a/client/http/vertx/VertxHttpClientTest.java b/extras/http-client-vertx/src/test/java/io/a2a/client/http/vertx/VertxHttpClientTest.java new file mode 100644 index 00000000..6a94f13e --- /dev/null +++ b/extras/http-client-vertx/src/test/java/io/a2a/client/http/vertx/VertxHttpClientTest.java @@ -0,0 +1,13 @@ +package io.a2a.client.http.vertx; + +import io.a2a.client.http.HttpClientBuilder; +import io.a2a.client.http.common.AbstractHttpClientTest; +import io.vertx.core.http.HttpClientOptions; + +public class VertxHttpClientTest extends AbstractHttpClientTest { + + protected HttpClientBuilder getHttpClientBuilder() { + return new VertxHttpClientBuilder() + .options(new HttpClientOptions().setMaxChunkSize(24)); + } +} diff --git a/extras/push-notification-config-store-database-jpa/src/test/java/io/a2a/extras/pushnotificationconfigstore/database/jpa/JpaPushNotificationConfigStoreTest.java b/extras/push-notification-config-store-database-jpa/src/test/java/io/a2a/extras/pushnotificationconfigstore/database/jpa/JpaPushNotificationConfigStoreTest.java index 70f9d1e5..383c4405 100644 --- a/extras/push-notification-config-store-database-jpa/src/test/java/io/a2a/extras/pushnotificationconfigstore/database/jpa/JpaPushNotificationConfigStoreTest.java +++ b/extras/push-notification-config-store-database-jpa/src/test/java/io/a2a/extras/pushnotificationconfigstore/database/jpa/JpaPushNotificationConfigStoreTest.java @@ -5,13 +5,15 @@ import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.never; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; +import static org.mockito.Mockito.*; +import io.a2a.client.http.HttpClient; +import io.a2a.client.http.HttpResponse; +import io.a2a.server.http.HttpClientManager; import org.mockito.ArgumentCaptor; import java.util.List; +import java.util.concurrent.CompletableFuture; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Disabled; @@ -22,8 +24,6 @@ import jakarta.inject.Inject; import jakarta.transaction.Transactional; -import io.a2a.client.http.A2AHttpClient; -import io.a2a.client.http.A2AHttpResponse; import io.a2a.server.tasks.BasePushNotificationSender; import io.a2a.server.tasks.PushNotificationConfigStore; import io.a2a.spec.PushNotificationConfig; @@ -41,18 +41,18 @@ public class JpaPushNotificationConfigStoreTest { private BasePushNotificationSender notificationSender; @Mock - private A2AHttpClient mockHttpClient; + private HttpClientManager clientManager; @Mock - private A2AHttpClient.PostBuilder mockPostBuilder; + private HttpClient.PostRequestBuilder mockPostBuilder; @Mock - private A2AHttpResponse mockHttpResponse; + private HttpResponse mockHttpResponse; @BeforeEach public void setUp() { MockitoAnnotations.openMocks(this); - notificationSender = new BasePushNotificationSender(configStore, mockHttpClient); + notificationSender = new BasePushNotificationSender(configStore, clientManager); } @Test @@ -232,21 +232,22 @@ public void testSendNotificationSuccess() throws Exception { PushNotificationConfig config = createSamplePushConfig("http://notify.me/here", "cfg1", null); configStore.setInfo(taskId, config); + HttpClient mockHttpClient = mock(HttpClient.class); + when(clientManager.getOrCreate(any())).thenReturn(mockHttpClient); + // Mock successful HTTP response - when(mockHttpClient.createPost()).thenReturn(mockPostBuilder); - when(mockPostBuilder.url(any(String.class))).thenReturn(mockPostBuilder); + when(mockHttpClient.post(any())).thenReturn(mockPostBuilder); when(mockPostBuilder.body(any(String.class))).thenReturn(mockPostBuilder); - when(mockPostBuilder.post()).thenReturn(mockHttpResponse); + when(mockPostBuilder.send()).thenReturn(CompletableFuture.completedFuture(mockHttpResponse)); when(mockHttpResponse.success()).thenReturn(true); notificationSender.sendNotification(task); // Verify HTTP client was called ArgumentCaptor bodyCaptor = ArgumentCaptor.forClass(String.class); - verify(mockHttpClient).createPost(); - verify(mockPostBuilder).url(config.url()); + verify(mockHttpClient).post(any()); verify(mockPostBuilder).body(bodyCaptor.capture()); - verify(mockPostBuilder).post(); + verify(mockPostBuilder).send(); // Verify the request body contains the task data String sentBody = bodyCaptor.getValue(); @@ -263,11 +264,13 @@ public void testSendNotificationWithToken() throws Exception { PushNotificationConfig config = createSamplePushConfig("http://notify.me/here", "cfg1", "unique_token"); configStore.setInfo(taskId, config); + HttpClient mockHttpClient = mock(HttpClient.class); + when(clientManager.getOrCreate(any())).thenReturn(mockHttpClient); + // Mock successful HTTP response - when(mockHttpClient.createPost()).thenReturn(mockPostBuilder); - when(mockPostBuilder.url(any(String.class))).thenReturn(mockPostBuilder); + when(mockHttpClient.post(any())).thenReturn(mockPostBuilder); when(mockPostBuilder.body(any(String.class))).thenReturn(mockPostBuilder); - when(mockPostBuilder.post()).thenReturn(mockHttpResponse); + when(mockPostBuilder.send()).thenReturn(CompletableFuture.completedFuture(mockHttpResponse)); when(mockHttpResponse.success()).thenReturn(true); notificationSender.sendNotification(task); @@ -279,10 +282,9 @@ public void testSendNotificationWithToken() throws Exception { // For now, just verify basic HTTP client interaction ArgumentCaptor bodyCaptor = ArgumentCaptor.forClass(String.class); - verify(mockHttpClient).createPost(); - verify(mockPostBuilder).url(config.url()); + verify(mockHttpClient).post(any()); verify(mockPostBuilder).body(bodyCaptor.capture()); - verify(mockPostBuilder).post(); + verify(mockPostBuilder).send(); // Verify the request body contains the task data String sentBody = bodyCaptor.getValue(); @@ -299,7 +301,7 @@ public void testSendNotificationNoConfig() throws Exception { notificationSender.sendNotification(task); // Verify HTTP client was never called - verify(mockHttpClient, never()).createPost(); + verify(clientManager, never()).getOrCreate(any()); } @Test diff --git a/http-client/pom.xml b/http-client/pom.xml index 4e138b09..e8c4541e 100644 --- a/http-client/pom.xml +++ b/http-client/pom.xml @@ -29,8 +29,8 @@ - org.mock-server - mockserver-netty + org.wiremock + wiremock test diff --git a/http-client/src/main/java/io/a2a/client/http/A2ACardResolver.java b/http-client/src/main/java/io/a2a/client/http/A2ACardResolver.java index 5d94686b..d938bb93 100644 --- a/http-client/src/main/java/io/a2a/client/http/A2ACardResolver.java +++ b/http-client/src/main/java/io/a2a/client/http/A2ACardResolver.java @@ -2,10 +2,10 @@ import static io.a2a.util.Utils.unmarshalFrom; -import java.io.IOException; import java.net.URI; import java.net.URISyntaxException; import java.util.Map; +import java.util.concurrent.ExecutionException; import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.core.type.TypeReference; @@ -15,63 +15,77 @@ import org.jspecify.annotations.Nullable; public class A2ACardResolver { - private final A2AHttpClient httpClient; - private final String url; + private final HttpClient httpClient; private final @Nullable Map authHeaders; - + private final String agentCardPath; private static final String DEFAULT_AGENT_CARD_PATH = "/.well-known/agent-card.json"; private static final TypeReference AGENT_CARD_TYPE_REFERENCE = new TypeReference<>() {}; /** * Get the agent card for an A2A agent. - * The {@code JdkA2AHttpClient} will be used to fetch the agent card. + * The {@code HttpClient} will be used to fetch the agent card. * * @param baseUrl the base URL for the agent whose agent card we want to retrieve * @throws A2AClientError if the URL for the agent is invalid */ public A2ACardResolver(String baseUrl) throws A2AClientError { - this(new JdkA2AHttpClient(), baseUrl, null, null); + this.httpClient = HttpClient.createHttpClient(baseUrl); + this.authHeaders = null; + + try { + String agentCardPath = new URI(baseUrl).getPath(); + + if (agentCardPath.endsWith("/")) { + agentCardPath = agentCardPath.substring(0, agentCardPath.length() - 1); + } + + if (agentCardPath.isEmpty()) { + this.agentCardPath = DEFAULT_AGENT_CARD_PATH; + } else if (agentCardPath.endsWith(DEFAULT_AGENT_CARD_PATH)) { + this.agentCardPath = agentCardPath; + } else { + this.agentCardPath = agentCardPath + DEFAULT_AGENT_CARD_PATH; + } + } catch (URISyntaxException e) { + throw new A2AClientError("Invalid agent URL", e); + } } /** - /**Get the agent card for an A2A agent. - * * @param httpClient the http client to use - * @param baseUrl the base URL for the agent whose agent card we want to retrieve * @throws A2AClientError if the URL for the agent is invalid */ - public A2ACardResolver(A2AHttpClient httpClient, String baseUrl) throws A2AClientError { - this(httpClient, baseUrl, null, null); + A2ACardResolver(HttpClient httpClient) throws A2AClientError { + this(httpClient, null, null); } /** * @param httpClient the http client to use - * @param baseUrl the base URL for the agent whose agent card we want to retrieve * @param agentCardPath optional path to the agent card endpoint relative to the base * agent URL, defaults to ".well-known/agent-card.json" * @throws A2AClientError if the URL for the agent is invalid */ - public A2ACardResolver(A2AHttpClient httpClient, String baseUrl, String agentCardPath) throws A2AClientError { - this(httpClient, baseUrl, agentCardPath, null); + public A2ACardResolver(HttpClient httpClient, String agentCardPath) throws A2AClientError { + this(httpClient, agentCardPath, null); } /** * @param httpClient the http client to use - * @param baseUrl the base URL for the agent whose agent card we want to retrieve * @param agentCardPath optional path to the agent card endpoint relative to the base * agent URL, defaults to ".well-known/agent-card.json" * @param authHeaders the HTTP authentication headers to use. May be {@code null} * @throws A2AClientError if the URL for the agent is invalid */ - public A2ACardResolver(A2AHttpClient httpClient, String baseUrl, @Nullable String agentCardPath, - @Nullable Map authHeaders) throws A2AClientError { + public A2ACardResolver(HttpClient httpClient, @Nullable String agentCardPath, + @Nullable Map authHeaders) throws A2AClientError { this.httpClient = httpClient; - String effectiveAgentCardPath = agentCardPath == null || agentCardPath.isEmpty() ? DEFAULT_AGENT_CARD_PATH : agentCardPath; - try { - this.url = new URI(baseUrl).resolve(effectiveAgentCardPath).toString(); - } catch (URISyntaxException e) { - throw new A2AClientError("Invalid agent URL", e); + if (agentCardPath == null || agentCardPath.isEmpty()) { + this.agentCardPath = DEFAULT_AGENT_CARD_PATH; + } else if (agentCardPath.endsWith(DEFAULT_AGENT_CARD_PATH)) { + this.agentCardPath = agentCardPath; + } else { + this.agentCardPath = agentCardPath + DEFAULT_AGENT_CARD_PATH; } this.authHeaders = authHeaders; } @@ -84,8 +98,7 @@ public A2ACardResolver(A2AHttpClient httpClient, String baseUrl, @Nullable Strin * @throws A2AClientJSONError f the response body cannot be decoded as JSON or validated against the AgentCard schema */ public AgentCard getAgentCard() throws A2AClientError, A2AClientJSONError { - A2AHttpClient.GetBuilder builder = httpClient.createGet() - .url(url) + HttpClient.GetRequestBuilder builder = httpClient.get(agentCardPath) .addHeader("Content-Type", "application/json"); if (authHeaders != null) { @@ -95,13 +108,14 @@ public AgentCard getAgentCard() throws A2AClientError, A2AClientJSONError { } String body; + try { - A2AHttpResponse response = builder.get(); + HttpResponse response = builder.send().get(); if (!response.success()) { - throw new A2AClientError("Failed to obtain agent card: " + response.status()); + throw new A2AClientError("Failed to obtain agent card: " + response.statusCode()); } body = response.body(); - } catch (IOException | InterruptedException e) { + } catch (InterruptedException | ExecutionException e) { throw new A2AClientError("Failed to obtain agent card", e); } @@ -110,8 +124,5 @@ public AgentCard getAgentCard() throws A2AClientError, A2AClientJSONError { } catch (JsonProcessingException e) { throw new A2AClientJSONError("Could not unmarshal agent card response", e); } - } - - } diff --git a/http-client/src/main/java/io/a2a/client/http/A2AHttpClient.java b/http-client/src/main/java/io/a2a/client/http/A2AHttpClient.java deleted file mode 100644 index 52c252a8..00000000 --- a/http-client/src/main/java/io/a2a/client/http/A2AHttpClient.java +++ /dev/null @@ -1,42 +0,0 @@ -package io.a2a.client.http; - -import java.io.IOException; -import java.util.Map; -import java.util.concurrent.CompletableFuture; -import java.util.function.Consumer; - -public interface A2AHttpClient { - - GetBuilder createGet(); - - PostBuilder createPost(); - - DeleteBuilder createDelete(); - - interface Builder> { - T url(String s); - T addHeaders(Map headers); - T addHeader(String name, String value); - } - - interface GetBuilder extends Builder { - A2AHttpResponse get() throws IOException, InterruptedException; - CompletableFuture getAsyncSSE( - Consumer messageConsumer, - Consumer errorConsumer, - Runnable completeRunnable) throws IOException, InterruptedException; - } - - interface PostBuilder extends Builder { - PostBuilder body(String body); - A2AHttpResponse post() throws IOException, InterruptedException; - CompletableFuture postAsyncSSE( - Consumer messageConsumer, - Consumer errorConsumer, - Runnable completeRunnable) throws IOException, InterruptedException; - } - - interface DeleteBuilder extends Builder { - A2AHttpResponse delete() throws IOException, InterruptedException; - } -} diff --git a/http-client/src/main/java/io/a2a/client/http/A2AHttpResponse.java b/http-client/src/main/java/io/a2a/client/http/A2AHttpResponse.java deleted file mode 100644 index 171fceeb..00000000 --- a/http-client/src/main/java/io/a2a/client/http/A2AHttpResponse.java +++ /dev/null @@ -1,9 +0,0 @@ -package io.a2a.client.http; - -public interface A2AHttpResponse { - int status(); - - boolean success(); - - String body(); -} diff --git a/http-client/src/main/java/io/a2a/client/http/HttpClient.java b/http-client/src/main/java/io/a2a/client/http/HttpClient.java new file mode 100644 index 00000000..1cb14fde --- /dev/null +++ b/http-client/src/main/java/io/a2a/client/http/HttpClient.java @@ -0,0 +1,45 @@ +package io.a2a.client.http; + +import java.util.Map; +import java.util.concurrent.CompletableFuture; + +public interface HttpClient { + + static HttpClient createHttpClient(String baseUrl) { + return HttpClientBuilder.DEFAULT_FACTORY.create(baseUrl); + } + + GetRequestBuilder get(String path); + + PostRequestBuilder post(String path); + + DeleteRequestBuilder delete(String path); + + interface RequestBuilder> { + CompletableFuture send(); + + T addHeader(String name, String value); + + T addHeaders(Map headers); + } + + interface GetRequestBuilder extends RequestBuilder { + + } + + interface PostRequestBuilder extends RequestBuilder { + PostRequestBuilder body(String body); + + default PostRequestBuilder asSSE() { + return addHeader("Accept", "text/event-stream"); + } + + default CompletableFuture send(String body) { + return this.body(body).send(); + } + } + + interface DeleteRequestBuilder extends RequestBuilder { + + } +} diff --git a/http-client/src/main/java/io/a2a/client/http/HttpClientBuilder.java b/http-client/src/main/java/io/a2a/client/http/HttpClientBuilder.java new file mode 100644 index 00000000..1e894a9d --- /dev/null +++ b/http-client/src/main/java/io/a2a/client/http/HttpClientBuilder.java @@ -0,0 +1,10 @@ +package io.a2a.client.http; + +import io.a2a.client.http.jdk.JdkHttpClientBuilder; + +public interface HttpClientBuilder { + + HttpClientBuilder DEFAULT_FACTORY = new JdkHttpClientBuilder(); + + HttpClient create(String url); +} diff --git a/http-client/src/main/java/io/a2a/client/http/HttpResponse.java b/http-client/src/main/java/io/a2a/client/http/HttpResponse.java new file mode 100644 index 00000000..3e2f35f6 --- /dev/null +++ b/http-client/src/main/java/io/a2a/client/http/HttpResponse.java @@ -0,0 +1,17 @@ +package io.a2a.client.http; + +import io.a2a.client.http.sse.Event; + +import java.util.function.Consumer; + +public interface HttpResponse { + int statusCode(); + + default boolean success() { + return statusCode() >= 200 && statusCode() < 300; + } + + String body(); + + void bodyAsSse(Consumer eventConsumer, Consumer errorConsumer); +} diff --git a/http-client/src/main/java/io/a2a/client/http/JdkA2AHttpClient.java b/http-client/src/main/java/io/a2a/client/http/JdkA2AHttpClient.java deleted file mode 100644 index 9b800374..00000000 --- a/http-client/src/main/java/io/a2a/client/http/JdkA2AHttpClient.java +++ /dev/null @@ -1,311 +0,0 @@ -package io.a2a.client.http; - -import static java.net.HttpURLConnection.HTTP_FORBIDDEN; -import static java.net.HttpURLConnection.HTTP_MULT_CHOICE; -import static java.net.HttpURLConnection.HTTP_OK; -import static java.net.HttpURLConnection.HTTP_UNAUTHORIZED; - -import java.io.IOException; -import java.net.URI; -import java.net.http.HttpClient; -import java.net.http.HttpRequest; -import java.net.http.HttpResponse; -import java.net.http.HttpResponse.BodyHandler; -import java.net.http.HttpResponse.BodyHandlers; -import java.net.http.HttpResponse.BodySubscribers; -import java.nio.ByteBuffer; -import java.nio.charset.StandardCharsets; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.concurrent.CompletableFuture; -import java.util.concurrent.Flow; -import java.util.function.Consumer; -import org.jspecify.annotations.Nullable; - -import io.a2a.common.A2AErrorMessages; - -public class JdkA2AHttpClient implements A2AHttpClient { - - private final HttpClient httpClient; - - public JdkA2AHttpClient() { - httpClient = HttpClient.newBuilder() - .version(HttpClient.Version.HTTP_2) - .followRedirects(HttpClient.Redirect.NORMAL) - .build(); - } - - @Override - public GetBuilder createGet() { - return new JdkGetBuilder(); - } - - @Override - public PostBuilder createPost() { - return new JdkPostBuilder(); - } - - @Override - public DeleteBuilder createDelete() { - return new JdkDeleteBuilder(); - } - - private abstract class JdkBuilder> implements Builder { - private String url = ""; - private Map headers = new HashMap<>(); - - @Override - public T url(String url) { - this.url = url; - return self(); - } - - @Override - public T addHeader(String name, String value) { - headers.put(name, value); - return self(); - } - - @Override - public T addHeaders(Map headers) { - if(headers != null && ! headers.isEmpty()) { - for (Map.Entry entry : headers.entrySet()) { - addHeader(entry.getKey(), entry.getValue()); - } - } - return self(); - } - - @SuppressWarnings("unchecked") - T self() { - return (T) this; - } - - protected HttpRequest.Builder createRequestBuilder() throws IOException { - HttpRequest.Builder builder = HttpRequest.newBuilder() - .uri(URI.create(url)); - for (Map.Entry headerEntry : headers.entrySet()) { - builder.header(headerEntry.getKey(), headerEntry.getValue()); - } - return builder; - } - - protected CompletableFuture asyncRequest( - HttpRequest request, - Consumer messageConsumer, - Consumer errorConsumer, - Runnable completeRunnable - ) { - Flow.Subscriber subscriber = new Flow.Subscriber() { - private Flow.@Nullable Subscription subscription; - private volatile boolean errorRaised = false; - - @Override - public void onSubscribe(Flow.Subscription subscription) { - this.subscription = subscription; - this.subscription.request(1); - } - - @Override - public void onNext(String item) { - // SSE messages sometimes start with "data:". Strip that off - if (item != null && item.startsWith("data:")) { - item = item.substring(5).trim(); - if (!item.isEmpty()) { - messageConsumer.accept(item); - } - } - if (subscription != null) { - subscription.request(1); - } - } - - @Override - public void onError(Throwable throwable) { - if (!errorRaised) { - errorRaised = true; - errorConsumer.accept(throwable); - } - if (subscription != null) { - subscription.cancel(); - } - } - - @Override - public void onComplete() { - if (!errorRaised) { - completeRunnable.run(); - } - if (subscription != null) { - subscription.cancel(); - } - } - }; - - // Create a custom body handler that checks status before processing body - BodyHandler bodyHandler = responseInfo -> { - // Check for authentication/authorization errors only - if (responseInfo.statusCode() == HTTP_UNAUTHORIZED || responseInfo.statusCode() == HTTP_FORBIDDEN) { - final String errorMessage; - if (responseInfo.statusCode() == HTTP_UNAUTHORIZED) { - errorMessage = A2AErrorMessages.AUTHENTICATION_FAILED; - } else { - errorMessage = A2AErrorMessages.AUTHORIZATION_FAILED; - } - // Return a body subscriber that immediately signals error - return BodySubscribers.fromSubscriber(new Flow.Subscriber>() { - @Override - public void onSubscribe(Flow.Subscription subscription) { - subscriber.onError(new IOException(errorMessage)); - } - - @Override - public void onNext(List item) { - // Should not be called - } - - @Override - public void onError(Throwable throwable) { - // Should not be called - } - - @Override - public void onComplete() { - // Should not be called - } - }); - } else { - // For all other status codes (including other errors), proceed with normal line subscriber - return BodyHandlers.fromLineSubscriber(subscriber).apply(responseInfo); - } - }; - - // Send the response async, and let the subscriber handle the lines. - return httpClient.sendAsync(request, bodyHandler) - .thenAccept(response -> { - // Handle non-authentication/non-authorization errors here - if (!isSuccessStatus(response.statusCode()) && - response.statusCode() != HTTP_UNAUTHORIZED && - response.statusCode() != HTTP_FORBIDDEN) { - subscriber.onError(new IOException("Request failed with status " + response.statusCode() + ":" + response.body())); - } - }); - } - } - - private class JdkGetBuilder extends JdkBuilder implements A2AHttpClient.GetBuilder { - - private HttpRequest.Builder createRequestBuilder(boolean SSE) throws IOException { - HttpRequest.Builder builder = super.createRequestBuilder().GET(); - if (SSE) { - builder.header("Accept", "text/event-stream"); - } - return builder; - } - - @Override - public A2AHttpResponse get() throws IOException, InterruptedException { - HttpRequest request = createRequestBuilder(false) - .build(); - HttpResponse response = - httpClient.send(request, BodyHandlers.ofString(StandardCharsets.UTF_8)); - return new JdkHttpResponse(response); - } - - @Override - public CompletableFuture getAsyncSSE( - Consumer messageConsumer, - Consumer errorConsumer, - Runnable completeRunnable) throws IOException, InterruptedException { - HttpRequest request = createRequestBuilder(true) - .build(); - return super.asyncRequest(request, messageConsumer, errorConsumer, completeRunnable); - } - - } - - private class JdkDeleteBuilder extends JdkBuilder implements A2AHttpClient.DeleteBuilder { - - @Override - public A2AHttpResponse delete() throws IOException, InterruptedException { - HttpRequest request = super.createRequestBuilder().DELETE().build(); - HttpResponse response = - httpClient.send(request, BodyHandlers.ofString(StandardCharsets.UTF_8)); - return new JdkHttpResponse(response); - } - - } - - private class JdkPostBuilder extends JdkBuilder implements A2AHttpClient.PostBuilder { - String body = ""; - - @Override - public PostBuilder body(String body) { - this.body = body; - return self(); - } - - private HttpRequest.Builder createRequestBuilder(boolean SSE) throws IOException { - HttpRequest.Builder builder = super.createRequestBuilder() - .POST(HttpRequest.BodyPublishers.ofString(body, StandardCharsets.UTF_8)); - if (SSE) { - builder.header("Accept", "text/event-stream"); - } - return builder; - } - - @Override - public A2AHttpResponse post() throws IOException, InterruptedException { - HttpRequest request = createRequestBuilder(false) - .POST(HttpRequest.BodyPublishers.ofString(body, StandardCharsets.UTF_8)) - .build(); - HttpResponse response = - httpClient.send(request, BodyHandlers.ofString(StandardCharsets.UTF_8)); - - if (response.statusCode() == HTTP_UNAUTHORIZED) { - throw new IOException(A2AErrorMessages.AUTHENTICATION_FAILED); - } else if (response.statusCode() == HTTP_FORBIDDEN) { - throw new IOException(A2AErrorMessages.AUTHORIZATION_FAILED); - } - - return new JdkHttpResponse(response); - } - - @Override - public CompletableFuture postAsyncSSE( - Consumer messageConsumer, - Consumer errorConsumer, - Runnable completeRunnable) throws IOException, InterruptedException { - HttpRequest request = createRequestBuilder(true) - .build(); - return super.asyncRequest(request, messageConsumer, errorConsumer, completeRunnable); - } - } - - private record JdkHttpResponse(HttpResponse response) implements A2AHttpResponse { - - @Override - public int status() { - return response.statusCode(); - } - - @Override - public boolean success() {// Send the request and get the response - return success(response); - } - - static boolean success(HttpResponse response) { - return response.statusCode() >= HTTP_OK && response.statusCode() < HTTP_MULT_CHOICE; - } - - @Override - public String body() { - return response.body(); - } - } - - private static boolean isSuccessStatus(int statusCode) { - return statusCode >= HTTP_OK && statusCode < HTTP_MULT_CHOICE; - } -} diff --git a/http-client/src/main/java/io/a2a/client/http/jdk/JdkHttpClient.java b/http-client/src/main/java/io/a2a/client/http/jdk/JdkHttpClient.java new file mode 100644 index 00000000..83e31208 --- /dev/null +++ b/http-client/src/main/java/io/a2a/client/http/jdk/JdkHttpClient.java @@ -0,0 +1,260 @@ +package io.a2a.client.http.jdk; + +import static java.net.HttpURLConnection.HTTP_FORBIDDEN; +import static java.net.HttpURLConnection.HTTP_MULT_CHOICE; +import static java.net.HttpURLConnection.HTTP_OK; +import static java.net.HttpURLConnection.HTTP_UNAUTHORIZED; + +import io.a2a.client.http.HttpClient; +import io.a2a.client.http.HttpResponse; +import io.a2a.client.http.jdk.sse.SSEHandler; +import io.a2a.client.http.sse.Event; + +import java.io.IOException; +import java.net.*; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse.BodyHandler; +import java.net.http.HttpResponse.BodyHandlers; +import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionStage; +import java.util.concurrent.Flow; +import java.util.function.Consumer; +import java.util.function.Function; + +import io.a2a.common.A2AErrorMessages; + +class JdkHttpClient implements HttpClient { + + private final java.net.http.HttpClient httpClient; + private final String baseUrl; + + JdkHttpClient(String baseUrl) { + this.httpClient = java.net.http.HttpClient.newBuilder() + .version(java.net.http.HttpClient.Version.HTTP_2) + .followRedirects(java.net.http.HttpClient.Redirect.NORMAL) + .build(); + + URL targetUrl = buildUrl(baseUrl); + this.baseUrl = targetUrl.getProtocol() + "://" + targetUrl.getAuthority(); + } + + String getBaseUrl() { + return baseUrl; + } + + private static final URLStreamHandler URL_HANDLER = new URLStreamHandler() { + protected URLConnection openConnection(URL u) { + return null; + } + }; + + private static URL buildUrl(String uri) { + try { + return new URL(null, uri, URL_HANDLER); + } catch (MalformedURLException var2) { + throw new IllegalArgumentException("URI [" + uri + "] is not valid"); + } + } + + @Override + public GetRequestBuilder get(String path) { + return new JdkGetRequestBuilder(path); + } + + @Override + public PostRequestBuilder post(String path) { + return new JdkPostRequestBuilder(path); + } + + @Override + public DeleteRequestBuilder delete(String path) { + return new JdkDeleteBuilder(path); + } + + private abstract class JdkRequestBuilder> implements RequestBuilder { + private final String path; + protected final Map headers = new HashMap<>(); + + public JdkRequestBuilder(String path) { + this.path = path; + } + + @Override + public T addHeader(String name, String value) { + headers.put(name, value); + return self(); + } + + @Override + public T addHeaders(Map headers) { + if (headers != null && !headers.isEmpty()) { + for (Map.Entry entry : headers.entrySet()) { + addHeader(entry.getKey(), entry.getValue()); + } + } + return self(); + } + + @SuppressWarnings("unchecked") + T self() { + return (T) this; + } + + protected HttpRequest.Builder createRequestBuilder() { + HttpRequest.Builder builder = HttpRequest.newBuilder() + .uri(URI.create(baseUrl + path)); + for (Map.Entry headerEntry : headers.entrySet()) { + builder.header(headerEntry.getKey(), headerEntry.getValue()); + } + return builder; + } + } + + private class JdkGetRequestBuilder extends JdkRequestBuilder implements GetRequestBuilder { + + public JdkGetRequestBuilder(String path) { + super(path); + } + + @Override + public CompletableFuture send() { + HttpRequest request = super.createRequestBuilder().GET().build(); + return httpClient + .sendAsync(request, BodyHandlers.ofString(StandardCharsets.UTF_8)) + .thenCompose(RESPONSE_MAPPER); + } + } + + private class JdkDeleteBuilder extends JdkRequestBuilder implements DeleteRequestBuilder { + + public JdkDeleteBuilder(String path) { + super(path); + } + + @Override + public CompletableFuture send() { + HttpRequest request = super.createRequestBuilder().DELETE().build(); + return httpClient + .sendAsync(request, BodyHandlers.ofString(StandardCharsets.UTF_8)) + .thenCompose(RESPONSE_MAPPER); + } + } + + private class JdkPostRequestBuilder extends JdkRequestBuilder implements PostRequestBuilder { + String body = ""; + + public JdkPostRequestBuilder(String path) { + super(path); + } + + @Override + public PostRequestBuilder body(String body) { + this.body = body; + return this; + } + + @Override + public CompletableFuture send() { + final HttpRequest request = super.createRequestBuilder() + .POST(HttpRequest.BodyPublishers.ofString(body, StandardCharsets.UTF_8)) + .build(); + + final BodyHandler bodyHandler; + + final String contentTypeHeader = this.headers.get("Accept"); + if ("text/event-stream".equalsIgnoreCase(contentTypeHeader)) { + bodyHandler = BodyHandlers.ofPublisher(); + } else { + bodyHandler = BodyHandlers.ofString(StandardCharsets.UTF_8); + } + + return httpClient.sendAsync(request, bodyHandler).thenCompose(RESPONSE_MAPPER); + } + } + + private final static Function, CompletionStage> RESPONSE_MAPPER = response -> { + if (response.statusCode() == HTTP_UNAUTHORIZED) { + return CompletableFuture.failedStage(new IOException(A2AErrorMessages.AUTHENTICATION_FAILED)); + } else if (response.statusCode() == HTTP_FORBIDDEN) { + return CompletableFuture.failedStage(new IOException(A2AErrorMessages.AUTHORIZATION_FAILED)); + } + + return CompletableFuture.completedFuture(new JdkHttpResponse(response)); + }; + + private record JdkHttpResponse(java.net.http.HttpResponse response) implements HttpResponse { + + @Override + public int statusCode() { + return response.statusCode(); + } + + static boolean success(java.net.http.HttpResponse response) { + return response.statusCode() >= HTTP_OK && response.statusCode() < HTTP_MULT_CHOICE; + } + + @Override + public String body() { + if (response.body() instanceof String) { + return (String) response.body(); + } + + throw new IllegalStateException(); + } + + @Override + public void bodyAsSse(Consumer eventConsumer, Consumer errorConsumer) { + if (success()) { + Optional contentTypeOpt = response.headers().firstValue("Content-Type"); + + if (contentTypeOpt.isPresent() && contentTypeOpt.get().equalsIgnoreCase("text/event-stream")) { + Flow.Publisher> publisher = (Flow.Publisher>) response.body(); + + SSEHandler sseHandler = new SSEHandler(); + sseHandler.subscribe(new Flow.Subscriber<>() { + private Flow.Subscription subscription; + + @Override + public void onSubscribe(Flow.Subscription subscription) { + this.subscription = subscription; + subscription.request(1); + } + + @Override + public void onNext(Event item) { + eventConsumer.accept(item); + subscription.request(1); + } + + @Override + public void onError(Throwable throwable) { + errorConsumer.accept(throwable); + subscription.cancel(); + } + + @Override + public void onComplete() { + subscription.cancel(); + } + }); + + publisher.subscribe(java.net.http.HttpResponse.BodySubscribers.fromLineSubscriber(sseHandler)); + } else { + errorConsumer.accept(new IOException("Response is not an event-stream response: Content-Type[" + contentTypeOpt.orElse("unknown") + "]")); + } + } else { + errorConsumer.accept(new IOException("Request failed: status[" + response.statusCode() + "]")); + } + } + } + + private static boolean isSuccessStatus(int statusCode) { + return statusCode >= HTTP_OK && statusCode < HTTP_MULT_CHOICE; + } +} diff --git a/http-client/src/main/java/io/a2a/client/http/jdk/JdkHttpClientBuilder.java b/http-client/src/main/java/io/a2a/client/http/jdk/JdkHttpClientBuilder.java new file mode 100644 index 00000000..21f50ade --- /dev/null +++ b/http-client/src/main/java/io/a2a/client/http/jdk/JdkHttpClientBuilder.java @@ -0,0 +1,12 @@ +package io.a2a.client.http.jdk; + +import io.a2a.client.http.HttpClient; +import io.a2a.client.http.HttpClientBuilder; + +public class JdkHttpClientBuilder implements HttpClientBuilder { + + @Override + public HttpClient create(String url) { + return new JdkHttpClient(url); + } +} diff --git a/http-client/src/main/java/io/a2a/client/http/jdk/sse/SSEHandler.java b/http-client/src/main/java/io/a2a/client/http/jdk/sse/SSEHandler.java new file mode 100644 index 00000000..b7975bae --- /dev/null +++ b/http-client/src/main/java/io/a2a/client/http/jdk/sse/SSEHandler.java @@ -0,0 +1,120 @@ +package io.a2a.client.http.jdk.sse; + +import io.a2a.client.http.sse.CommentEvent; +import io.a2a.client.http.sse.DataEvent; +import io.a2a.client.http.sse.Event; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.List; +import java.util.concurrent.Flow; +import java.util.concurrent.SubmissionPublisher; + +public class SSEHandler extends SubmissionPublisher + implements Flow.Processor { + + public static final String EVENT_STREAM_MEDIA_TYPE = "text/event-stream"; + + private static final Logger LOG = LoggerFactory.getLogger(SSEHandler.class); + + private static final String UTF8_BOM = "\uFEFF"; + + private static final String DEFAULT_EVENT_NAME = "message"; + + private Flow.Subscription subscription; + + private String currentEventName = DEFAULT_EVENT_NAME; + private final StringBuilder dataBuffer = new StringBuilder(); + + private String lastEventId = ""; + + @Override + public void onSubscribe(Flow.Subscription subscription) { + this.subscription = subscription; + subscription.request(1); + } + + @Override + public void onNext(String input) { + LOG.debug("got line `{}`", input); + String line = removeTrailingNewline(removeLeadingBom(input)); + + if (line.startsWith(":")) { + submit(new CommentEvent(line.substring(1).trim())); + } else if (line.isBlank()) { + LOG.debug( + "broadcasting new event named {} lastEventId is {}", + currentEventName, + lastEventId + ); + + String dataString = dataBuffer.toString(); + if (!dataString.isEmpty()) { + submit(new DataEvent(currentEventName, dataBuffer.toString(), lastEventId)); + } + //reset things + dataBuffer.setLength(0); + currentEventName = DEFAULT_EVENT_NAME; + } else if (line.contains(":")) { + List lineParts = List.of(line.split(":", 2)); + if (lineParts.size() == 2) { + handleFieldValue(lineParts.get(0), stripLeadingSpaceIfPresent(lineParts.get(1))); + } + } else { + handleFieldValue(line, ""); + } + subscription.request(1); + } + + private void handleFieldValue(String fieldName, String value) { + switch (fieldName) { + case "event": + currentEventName = value; + break; + case "data": + dataBuffer.append(value).append("\n"); + break; + case "id": + if (!value.contains("\0")) { + lastEventId = value; + } + break; + case "retry": + // ignored + break; + } + } + + @Override + public void onError(Throwable throwable) { + LOG.debug("Error in SSE handler {}", throwable.getMessage()); + closeExceptionally(throwable); + } + + @Override + public void onComplete() { + LOG.debug("SSE handler complete"); + close(); + } + + private String stripLeadingSpaceIfPresent(String field) { + if (field.charAt(0) == ' ') { + return field.substring(1); + } + return field; + } + + private String removeLeadingBom(String input) { + if (input.startsWith(UTF8_BOM)) { + return input.substring(UTF8_BOM.length()); + } + return input; + } + + private String removeTrailingNewline(String input) { + if (input.endsWith("\n")) { + return input.substring(0, input.length() - 1); + } + return input; + } +} diff --git a/http-client/src/main/java/io/a2a/client/http/sse/CommentEvent.java b/http-client/src/main/java/io/a2a/client/http/sse/CommentEvent.java new file mode 100644 index 00000000..0a0b3e68 --- /dev/null +++ b/http-client/src/main/java/io/a2a/client/http/sse/CommentEvent.java @@ -0,0 +1,54 @@ +package io.a2a.client.http.sse; + +import java.util.Objects; +import java.util.StringJoiner; + +/** + * Represents an SSE Comment + * This is a line starting with a colon (:) + */ +public class CommentEvent extends Event { + + private final String comment; + + @Override + Type getType() { + return Type.COMMENT; + } + + public CommentEvent(String comment) { + this.comment = comment; + } + + /** + * + * @return the contents of the last line starting with `:` (omitting the colon) + */ + public String getComment() { + return comment; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + CommentEvent that = (CommentEvent) o; + return Objects.equals(comment, that.comment); + } + + @Override + public int hashCode() { + return Objects.hash(comment); + } + + @Override + public String toString() { + return new StringJoiner(", ", CommentEvent.class.getSimpleName() + "[", "]") + .add("comment='" + comment + "'") + .toString(); + } +} diff --git a/http-client/src/main/java/io/a2a/client/http/sse/DataEvent.java b/http-client/src/main/java/io/a2a/client/http/sse/DataEvent.java new file mode 100644 index 00000000..daae9c5c --- /dev/null +++ b/http-client/src/main/java/io/a2a/client/http/sse/DataEvent.java @@ -0,0 +1,81 @@ +package io.a2a.client.http.sse; + +import java.util.Objects; +import java.util.StringJoiner; + +/** + * Represents an SSE DataEvent + * It contains three fields: event name, data, and lastEventId + */ +public class DataEvent extends Event { + + private final String eventName; + private final String data; + private final String lastEventId; + + public DataEvent(String eventName, String data, String lastEventId) { + this.eventName = eventName; + this.data = data; + this.lastEventId = lastEventId; + } + + @Override + Type getType() { + return Type.DATA; + } + + /** + * + * @return the content of the last line starting with `event:` + */ + public String getEventName() { + return eventName; + } + + /** + * + * @return the accumulated contents of data buffers from lines starting with `data:` + */ + public String getData() { + return data; + } + + /** + * + * @return the last event id sent in a line starting with `id:` + */ + public String getLastEventId() { + return lastEventId; + } + + @Override + public String toString() { + return new StringJoiner(", ", DataEvent.class.getSimpleName() + "[", "]") + .add("eventName='" + eventName + "'") + .add("data='" + data + "'") + .add("lastEventId='" + lastEventId + "'") + .toString(); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + DataEvent event = (DataEvent) o; + return ( + Objects.equals(getType(), event.getType()) && + Objects.equals(eventName, event.eventName) && + Objects.equals(data, event.data) && + Objects.equals(lastEventId, event.lastEventId) + ); + } + + @Override + public int hashCode() { + return Objects.hash(getType(), eventName, data, lastEventId); + } +} diff --git a/http-client/src/main/java/io/a2a/client/http/sse/Event.java b/http-client/src/main/java/io/a2a/client/http/sse/Event.java new file mode 100644 index 00000000..66920e9c --- /dev/null +++ b/http-client/src/main/java/io/a2a/client/http/sse/Event.java @@ -0,0 +1,11 @@ +package io.a2a.client.http.sse; + +public abstract class Event { + + enum Type { + COMMENT, + DATA, + } + + abstract Type getType(); +} diff --git a/http-client/src/test/java/io/a2a/client/http/A2ACardResolverTest.java b/http-client/src/test/java/io/a2a/client/http/A2ACardResolverTest.java index 99d26ada..6f921c8d 100644 --- a/http-client/src/test/java/io/a2a/client/http/A2ACardResolverTest.java +++ b/http-client/src/test/java/io/a2a/client/http/A2ACardResolverTest.java @@ -1,20 +1,20 @@ package io.a2a.client.http; +import static com.github.tomakehurst.wiremock.client.WireMock.*; import static io.a2a.util.Utils.OBJECT_MAPPER; import static io.a2a.util.Utils.unmarshalFrom; -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertFalse; -import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.*; -import java.io.IOException; -import java.util.concurrent.CompletableFuture; -import java.util.function.Consumer; import com.fasterxml.jackson.core.type.TypeReference; +import com.github.tomakehurst.wiremock.WireMockServer; +import com.github.tomakehurst.wiremock.core.WireMockConfiguration; import io.a2a.spec.A2AClientError; import io.a2a.spec.A2AClientJSONError; import io.a2a.spec.AgentCard; -import java.util.Map; + +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; public class A2ACardResolverTest { @@ -22,54 +22,90 @@ public class A2ACardResolverTest { private static final String AGENT_CARD_PATH = "/.well-known/agent-card.json"; private static final TypeReference AGENT_CARD_TYPE_REFERENCE = new TypeReference<>() {}; + private WireMockServer server; + + @BeforeEach + public void setUp() { + server = new WireMockServer(WireMockConfiguration.options().dynamicPort()); + server.start(); + + configureFor("localhost", server.port()); + } + + @AfterEach + public void tearDown() { + if (server != null) { + server.stop(); + } + } + @Test public void testConstructorStripsSlashes() throws Exception { - TestHttpClient client = new TestHttpClient(); - client.body = JsonMessages.AGENT_CARD; + HttpClient client = HttpClient.createHttpClient("http://localhost:" + server.port()); + + givenThat(get(urlPathEqualTo(AGENT_CARD_PATH)) + .willReturn(okForContentType("application/json", JsonMessages.AGENT_CARD))); + + givenThat(get(urlPathEqualTo("/subpath" + AGENT_CARD_PATH)) + .willReturn(okForContentType("application/json", JsonMessages.AGENT_CARD))); - A2ACardResolver resolver = new A2ACardResolver(client, "http://example.com/"); + A2ACardResolver resolver = new A2ACardResolver(client); AgentCard card = resolver.getAgentCard(); - assertEquals("http://example.com" + AGENT_CARD_PATH, client.url); + assertNotNull(card); + verify(getRequestedFor(urlEqualTo(AGENT_CARD_PATH)) + .withHeader("Content-Type", equalTo("application/json"))); - resolver = new A2ACardResolver(client, "http://example.com"); + resolver = new A2ACardResolver(client, AGENT_CARD_PATH); card = resolver.getAgentCard(); - assertEquals("http://example.com" + AGENT_CARD_PATH, client.url); + assertNotNull(card); + verify(getRequestedFor(urlEqualTo(AGENT_CARD_PATH)) + .withHeader("Content-Type", equalTo("application/json"))); - // baseUrl with trailing slash, agentCardParth with leading slash - resolver = new A2ACardResolver(client, "http://example.com/", AGENT_CARD_PATH); + + resolver = new A2ACardResolver("http://localhost:" + server.port()); card = resolver.getAgentCard(); - assertEquals("http://example.com" + AGENT_CARD_PATH, client.url); + assertNotNull(card); + verify(getRequestedFor(urlEqualTo(AGENT_CARD_PATH)) + .withHeader("Content-Type", equalTo("application/json"))); - // baseUrl without trailing slash, agentCardPath with leading slash - resolver = new A2ACardResolver(client, "http://example.com", AGENT_CARD_PATH); + resolver = new A2ACardResolver("http://localhost:" + server.port() + AGENT_CARD_PATH); card = resolver.getAgentCard(); - assertEquals("http://example.com" + AGENT_CARD_PATH, client.url); + assertNotNull(card); + verify(getRequestedFor(urlEqualTo(AGENT_CARD_PATH)) + .withHeader("Content-Type", equalTo("application/json"))); - // baseUrl with trailing slash, agentCardPath without leading slash - resolver = new A2ACardResolver(client, "http://example.com/", AGENT_CARD_PATH.substring(1)); + // baseUrl with trailing slash + resolver = new A2ACardResolver("http://localhost:" + server.port() + "/"); card = resolver.getAgentCard(); - assertEquals("http://example.com" + AGENT_CARD_PATH, client.url); + assertNotNull(card); + verify(getRequestedFor(urlEqualTo(AGENT_CARD_PATH)) + .withHeader("Content-Type", equalTo("application/json"))); - // baseUrl without trailing slash, agentCardPath without leading slash - resolver = new A2ACardResolver(client, "http://example.com", AGENT_CARD_PATH.substring(1)); + // Sub-path + // baseUrl with trailing slash + resolver = new A2ACardResolver("http://localhost:" + server.port() + "/subpath"); card = resolver.getAgentCard(); - assertEquals("http://example.com" + AGENT_CARD_PATH, client.url); + assertNotNull(card); + verify(getRequestedFor(urlEqualTo("/subpath" + AGENT_CARD_PATH)) + .withHeader("Content-Type", equalTo("application/json"))); } @Test public void testGetAgentCardSuccess() throws Exception { - TestHttpClient client = new TestHttpClient(); - client.body = JsonMessages.AGENT_CARD; + HttpClient client = HttpClient.createHttpClient("http://localhost:" + server.port()); + + givenThat(get(urlPathEqualTo(AGENT_CARD_PATH)) + .willReturn(okForContentType("application/json", JsonMessages.AGENT_CARD))); - A2ACardResolver resolver = new A2ACardResolver(client, "http://example.com/"); + A2ACardResolver resolver = new A2ACardResolver(client); AgentCard card = resolver.getAgentCard(); AgentCard expectedCard = unmarshalFrom(JsonMessages.AGENT_CARD, AGENT_CARD_TYPE_REFERENCE); @@ -77,14 +113,19 @@ public void testGetAgentCardSuccess() throws Exception { String requestCardString = OBJECT_MAPPER.writeValueAsString(card); assertEquals(expected, requestCardString); + + verify(getRequestedFor(urlEqualTo(AGENT_CARD_PATH)) + .withHeader("Content-Type", equalTo("application/json"))); } @Test public void testGetAgentCardJsonDecodeError() throws Exception { - TestHttpClient client = new TestHttpClient(); - client.body = "X" + JsonMessages.AGENT_CARD; + HttpClient client = HttpClient.createHttpClient("http://localhost:" + server.port()); - A2ACardResolver resolver = new A2ACardResolver(client, "http://example.com/"); + givenThat(get(urlPathEqualTo(AGENT_CARD_PATH)) + .willReturn(okForContentType("application/json", "X" + JsonMessages.AGENT_CARD))); + + A2ACardResolver resolver = new A2ACardResolver(client); boolean success = false; try { @@ -93,15 +134,20 @@ public void testGetAgentCardJsonDecodeError() throws Exception { } catch (A2AClientJSONError expected) { } assertFalse(success); + + verify(getRequestedFor(urlEqualTo(AGENT_CARD_PATH)) + .withHeader("Content-Type", equalTo("application/json"))); } @Test public void testGetAgentCardRequestError() throws Exception { - TestHttpClient client = new TestHttpClient(); - client.status = 503; + HttpClient client = HttpClient.createHttpClient("http://localhost:" + server.port()); + + givenThat(get(urlPathEqualTo(AGENT_CARD_PATH)) + .willReturn(status(503))); - A2ACardResolver resolver = new A2ACardResolver(client, "http://example.com/"); + A2ACardResolver resolver = new A2ACardResolver(client); String msg = null; try { @@ -110,71 +156,9 @@ public void testGetAgentCardRequestError() throws Exception { msg = expected.getMessage(); } assertTrue(msg.contains("503")); - } - - private static class TestHttpClient implements A2AHttpClient { - int status = 200; - String body; - String url; - - @Override - public GetBuilder createGet() { - return new TestGetBuilder(); - } - - @Override - public PostBuilder createPost() { - return null; - } - @Override - public DeleteBuilder createDelete() { - return null; - } - - class TestGetBuilder implements A2AHttpClient.GetBuilder { - - @Override - public A2AHttpResponse get() throws IOException, InterruptedException { - return new A2AHttpResponse() { - @Override - public int status() { - return status; - } - - @Override - public boolean success() { - return status == 200; - } - - @Override - public String body() { - return body; - } - }; - } - - @Override - public CompletableFuture getAsyncSSE(Consumer messageConsumer, Consumer errorConsumer, Runnable completeRunnable) throws IOException, InterruptedException { - return null; - } - - @Override - public GetBuilder url(String s) { - url = s; - return this; - } - - @Override - public GetBuilder addHeader(String name, String value) { - return this; - } - - @Override - public GetBuilder addHeaders(Map headers) { - return this; - } - } + verify(getRequestedFor(urlEqualTo(AGENT_CARD_PATH)) + .withHeader("Content-Type", equalTo("application/json"))); } } diff --git a/http-client/src/test/java/io/a2a/client/http/jdk/JdkHttpClientTest.java b/http-client/src/test/java/io/a2a/client/http/jdk/JdkHttpClientTest.java new file mode 100644 index 00000000..7ca4e3d4 --- /dev/null +++ b/http-client/src/test/java/io/a2a/client/http/jdk/JdkHttpClientTest.java @@ -0,0 +1,31 @@ +package io.a2a.client.http.jdk; + +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; + +public class JdkHttpClientTest { + + @Test + public void testBaseUrlNormalization() { + String baseUrl = "http://localhost:8080"; + + JdkHttpClient client = new JdkHttpClient(baseUrl); + Assertions.assertEquals(baseUrl, client.getBaseUrl()); + + baseUrl = "http://localhost"; + client = new JdkHttpClient(baseUrl); + Assertions.assertEquals("http://localhost", client.getBaseUrl()); + + baseUrl = "https://localhost"; + client = new JdkHttpClient(baseUrl); + Assertions.assertEquals("https://localhost", client.getBaseUrl()); + + baseUrl = "https://localhost:443"; + client = new JdkHttpClient(baseUrl); + Assertions.assertEquals("https://localhost:443", client.getBaseUrl()); + + baseUrl = "https://localhost:80/test"; + client = new JdkHttpClient(baseUrl); + Assertions.assertEquals("https://localhost:80", client.getBaseUrl()); + } +} \ No newline at end of file diff --git a/pom.xml b/pom.xml index 913b0c8c..74dc84f3 100644 --- a/pom.xml +++ b/pom.xml @@ -55,6 +55,7 @@ 3.1.0 5.13.4 5.17.0 + 3.13.1 5.15.0 1.1.1 1.7.1 @@ -247,6 +248,12 @@ ${mockserver.version} test + + org.wiremock + wiremock + ${wiremock.version} + test + ch.qos.logback logback-classic @@ -265,6 +272,18 @@ test ${project.version} + + ${project.groupId} + a2a-java-sdk-tests-client-common + ${project.version} + + + ${project.groupId} + a2a-java-sdk-tests-client-common + test-jar + test + ${project.version} + ${project.groupId} a2a-java-sdk-server-common @@ -437,6 +456,7 @@ extras/task-store-database-jpa extras/push-notification-config-store-database-jpa extras/queue-manager-replicated + extras/http-client-vertx http-client reference/common reference/grpc @@ -447,6 +467,7 @@ spec-grpc tck tests/server-common + tests/client-common transport/jsonrpc transport/grpc transport/rest diff --git a/server-common/src/main/java/io/a2a/server/http/HttpClientManager.java b/server-common/src/main/java/io/a2a/server/http/HttpClientManager.java new file mode 100644 index 00000000..fd02caf1 --- /dev/null +++ b/server-common/src/main/java/io/a2a/server/http/HttpClientManager.java @@ -0,0 +1,59 @@ +package io.a2a.server.http; + +import io.a2a.client.http.HttpClient; +import io.a2a.util.Assert; +import jakarta.enterprise.context.ApplicationScoped; + +import java.net.URI; +import java.net.URL; +import java.util.Map; +import java.util.Objects; +import java.util.concurrent.ConcurrentHashMap; +import java.util.function.Function; + +@ApplicationScoped +public class HttpClientManager { + + private final Map clients = new ConcurrentHashMap<>(); + + public HttpClient getOrCreate(String url) { + Assert.checkNotNullParam("url", url); + + try { + return clients.computeIfAbsent(Endpoint.from(URI.create(url).toURL()), new Function() { + @Override + public HttpClient apply(Endpoint edpt) { + return HttpClient.createHttpClient(url); + } + }); + } catch (Exception ex) { + throw new IllegalArgumentException("URL is malformed: [" + url + "]"); + } + } + + private static class Endpoint { + private final String host; + private final int port; + + public Endpoint(String host, int port) { + this.host = host; + this.port = port; + } + + public static Endpoint from(URL url) { + return new Endpoint(url.getHost(), url.getPort() != -1 ? url.getPort() : url.getDefaultPort()); + } + + @Override + public boolean equals(Object o) { + if (o == null || getClass() != o.getClass()) return false; + Endpoint endpoint = (Endpoint) o; + return port == endpoint.port && Objects.equals(host, endpoint.host); + } + + @Override + public int hashCode() { + return Objects.hash(host, port); + } + } +} diff --git a/server-common/src/main/java/io/a2a/server/tasks/BasePushNotificationSender.java b/server-common/src/main/java/io/a2a/server/tasks/BasePushNotificationSender.java index 4afaf3b4..bb304b44 100644 --- a/server-common/src/main/java/io/a2a/server/tasks/BasePushNotificationSender.java +++ b/server-common/src/main/java/io/a2a/server/tasks/BasePushNotificationSender.java @@ -1,18 +1,19 @@ package io.a2a.server.tasks; import static io.a2a.common.A2AHeaders.X_A2A_NOTIFICATION_TOKEN; + +import io.a2a.server.http.HttpClientManager; import jakarta.enterprise.context.ApplicationScoped; import jakarta.inject.Inject; -import java.io.IOException; +import java.net.URI; import java.util.List; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutionException; import com.fasterxml.jackson.core.JsonProcessingException; -import io.a2a.client.http.A2AHttpClient; -import io.a2a.client.http.JdkA2AHttpClient; +import io.a2a.client.http.HttpClient; import io.a2a.spec.PushNotificationConfig; import io.a2a.spec.Task; import io.a2a.util.Utils; @@ -25,18 +26,13 @@ public class BasePushNotificationSender implements PushNotificationSender { private static final Logger LOGGER = LoggerFactory.getLogger(BasePushNotificationSender.class); - private final A2AHttpClient httpClient; private final PushNotificationConfigStore configStore; + private final HttpClientManager clientManager; @Inject - public BasePushNotificationSender(PushNotificationConfigStore configStore) { - this.httpClient = new JdkA2AHttpClient(); - this.configStore = configStore; - } - - public BasePushNotificationSender(PushNotificationConfigStore configStore, A2AHttpClient httpClient) { + public BasePushNotificationSender(PushNotificationConfigStore configStore, HttpClientManager clientManager) { this.configStore = configStore; - this.httpClient = httpClient; + this.clientManager = clientManager; } @Override @@ -68,10 +64,13 @@ private CompletableFuture dispatch(Task task, PushNotificationConfig pu } private boolean dispatchNotification(Task task, PushNotificationConfig pushInfo) { - String url = pushInfo.url(); - String token = pushInfo.token(); + final String url = pushInfo.url(); + final String token = pushInfo.token(); - A2AHttpClient.PostBuilder postBuilder = httpClient.createPost(); + // Delegate to the HTTP client manager to better manage client's connection pool. + final HttpClient client = clientManager.getOrCreate(url); + final URI uri = URI.create(url); + HttpClient.PostRequestBuilder postBuilder = client.post(uri.getPath()); if (token != null && !token.isBlank()) { postBuilder.addHeader(X_A2A_NOTIFICATION_TOKEN, token); } @@ -89,10 +88,10 @@ private boolean dispatchNotification(Task task, PushNotificationConfig pushInfo) try { postBuilder - .url(url) .body(body) - .post(); - } catch (IOException | InterruptedException e) { + .send() + .get(); + } catch (ExecutionException | InterruptedException e) { LOGGER.debug("Error pushing data to " + url + ": {}", e.getMessage(), e); return false; } diff --git a/server-common/src/test/java/io/a2a/server/http/HttpClientManagerTest.java b/server-common/src/test/java/io/a2a/server/http/HttpClientManagerTest.java new file mode 100644 index 00000000..f244dfdf --- /dev/null +++ b/server-common/src/test/java/io/a2a/server/http/HttpClientManagerTest.java @@ -0,0 +1,52 @@ +package io.a2a.server.http; + +import io.a2a.client.http.HttpClient; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; + +public class HttpClientManagerTest { + + private final HttpClientManager clientManager = new HttpClientManager(); + + @Test + public void testThrowsIllegalArgument() { + Assertions.assertThrows( + IllegalArgumentException.class, + () -> clientManager.getOrCreate(null) + ); + } + + @Test + public void testValidateCacheInstance() { + HttpClient client1 = clientManager.getOrCreate("http://localhost:8000"); + HttpClient client2 = clientManager.getOrCreate("http://localhost:8000"); + HttpClient client3 = clientManager.getOrCreate("http://localhost:8001"); + HttpClient client4 = clientManager.getOrCreate("http://remote_agent:8001"); + + Assertions.assertSame(client1, client2); + Assertions.assertNotSame(client1, client3); + Assertions.assertNotSame(client1, client4); + Assertions.assertNotSame(client3, client4); + } + + @Test + public void testValidateCacheNoPort() { + HttpClient client1 = clientManager.getOrCreate("https://localhost"); + HttpClient client2 = clientManager.getOrCreate("https://localhost:443"); + HttpClient client3 = clientManager.getOrCreate("http://localhost"); + HttpClient client4 = clientManager.getOrCreate("http://localhost:80"); + + Assertions.assertSame(client1, client2); + Assertions.assertNotSame(client1, client3); + Assertions.assertSame(client3, client4); + Assertions.assertNotSame(client2, client4); + } + + @Test + public void testThrowsInvalidUrl() { + Assertions.assertThrows( + IllegalArgumentException.class, + () -> clientManager.getOrCreate("this_is_invalid") + ); + } +} diff --git a/server-common/src/test/java/io/a2a/server/requesthandlers/AbstractA2ARequestHandlerTest.java b/server-common/src/test/java/io/a2a/server/requesthandlers/AbstractA2ARequestHandlerTest.java index 9f12ee79..3dd5a97c 100644 --- a/server-common/src/test/java/io/a2a/server/requesthandlers/AbstractA2ARequestHandlerTest.java +++ b/server-common/src/test/java/io/a2a/server/requesthandlers/AbstractA2ARequestHandlerTest.java @@ -3,6 +3,10 @@ import java.io.IOException; import java.io.InputStream; import java.net.URL; + +import io.a2a.client.http.sse.Event; +import io.a2a.server.http.HttpClientManager; +import jakarta.enterprise.context.Dependent; import java.util.ArrayList; import java.util.Collections; import java.util.List; @@ -13,10 +17,8 @@ import java.util.concurrent.Executors; import java.util.function.Consumer; -import jakarta.enterprise.context.Dependent; - -import io.a2a.client.http.A2AHttpClient; -import io.a2a.client.http.A2AHttpResponse; +import io.a2a.client.http.HttpClient; +import io.a2a.client.http.HttpResponse; import io.a2a.server.agentexecution.AgentExecutor; import io.a2a.server.agentexecution.RequestContext; import io.a2a.server.events.EventQueue; @@ -42,6 +44,11 @@ import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.BeforeEach; +import org.mockito.Mock; +import org.mockito.Mockito; +import org.mockito.MockitoAnnotations; + +import static org.mockito.ArgumentMatchers.any; public class AbstractA2ARequestHandlerTest { @@ -61,6 +68,9 @@ public class AbstractA2ARequestHandlerTest { private static final String PREFERRED_TRANSPORT = "preferred-transport"; private static final String A2A_REQUESTHANDLER_TEST_PROPERTIES = "/a2a-requesthandler-test.properties"; + @Mock + private HttpClientManager clientManager; + protected AgentExecutor executor; protected TaskStore taskStore; protected RequestHandler requestHandler; @@ -73,6 +83,8 @@ public class AbstractA2ARequestHandlerTest { @BeforeEach public void init() { + MockitoAnnotations.openMocks(this); + executor = new AgentExecutor() { @Override public void execute(RequestContext context, EventQueue eventQueue) throws JSONRPCError { @@ -92,8 +104,10 @@ public void cancel(RequestContext context, EventQueue eventQueue) throws JSONRPC taskStore = new InMemoryTaskStore(); queueManager = new InMemoryQueueManager(); httpClient = new TestHttpClient(); + + Mockito.when(clientManager.getOrCreate(any())).thenReturn(httpClient); PushNotificationConfigStore pushConfigStore = new InMemoryPushNotificationConfigStore(); - PushNotificationSender pushSender = new BasePushNotificationSender(pushConfigStore, httpClient); + PushNotificationSender pushSender = new BasePushNotificationSender(pushConfigStore, clientManager); requestHandler = new DefaultRequestHandler(executor, taskStore, queueManager, pushConfigStore, pushSender, internalExecutor); } @@ -148,75 +162,79 @@ protected interface AgentExecutorMethod { @Dependent @IfBuildProfile("test") - protected static class TestHttpClient implements A2AHttpClient { + protected static class TestHttpClient implements HttpClient { public final List tasks = Collections.synchronizedList(new ArrayList<>()); public volatile CountDownLatch latch; @Override - public GetBuilder createGet() { + public GetRequestBuilder get(String path) { return null; } @Override - public PostBuilder createPost() { - return new TestHttpClient.TestPostBuilder(); + public PostRequestBuilder post(String path) { + return new TestPostRequestBuilder(); } @Override - public DeleteBuilder createDelete() { + public DeleteRequestBuilder delete(String path) { return null; } - class TestPostBuilder implements A2AHttpClient.PostBuilder { + class TestPostRequestBuilder implements PostRequestBuilder { + private volatile String body; @Override - public PostBuilder body(String body) { + public PostRequestBuilder body(String body) { this.body = body; return this; } @Override - public A2AHttpResponse post() throws IOException, InterruptedException { - tasks.add(Utils.OBJECT_MAPPER.readValue(body, Task.TYPE_REFERENCE)); + public CompletableFuture send() { + CompletableFuture future = new CompletableFuture<>(); + try { - return new A2AHttpResponse() { - @Override - public int status() { - return 200; - } - - @Override - public boolean success() { - return true; - } - - @Override - public String body() { - return ""; - } - }; + tasks.add(Utils.OBJECT_MAPPER.readValue(body, Task.TYPE_REFERENCE)); + + future.complete( + new HttpResponse() { + @Override + public int statusCode() { + return 200; + } + + @Override + public boolean success() { + return true; + } + + @Override + public String body() { + return ""; + } + + @Override + public void bodyAsSse(Consumer eventConsumer, Consumer errorConsumer) { + + } + }); + } catch (Exception ex) { + future.completeExceptionally(ex); } finally { latch.countDown(); } - } - - @Override - public CompletableFuture postAsyncSSE(Consumer messageConsumer, Consumer errorConsumer, Runnable completeRunnable) throws IOException, InterruptedException { - return null; - } - @Override - public PostBuilder url(String s) { - return this; + return future; } @Override - public PostBuilder addHeader(String name, String value) { + public PostRequestBuilder addHeader(String name, String value) { return this; } @Override - public PostBuilder addHeaders(Map headers) { + public PostRequestBuilder addHeaders(Map headers) { return this; } diff --git a/server-common/src/test/java/io/a2a/server/tasks/InMemoryPushNotificationConfigStoreTest.java b/server-common/src/test/java/io/a2a/server/tasks/InMemoryPushNotificationConfigStoreTest.java index 9156f78b..81d27be1 100644 --- a/server-common/src/test/java/io/a2a/server/tasks/InMemoryPushNotificationConfigStoreTest.java +++ b/server-common/src/test/java/io/a2a/server/tasks/InMemoryPushNotificationConfigStoreTest.java @@ -9,17 +9,19 @@ import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import io.a2a.client.http.HttpClient; +import io.a2a.client.http.HttpResponse; +import io.a2a.server.http.HttpClientManager; import org.mockito.ArgumentCaptor; import java.util.List; +import java.util.concurrent.CompletableFuture; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.mockito.Mock; import org.mockito.MockitoAnnotations; -import io.a2a.client.http.A2AHttpClient; -import io.a2a.client.http.A2AHttpResponse; import io.a2a.common.A2AHeaders; import io.a2a.spec.PushNotificationConfig; import io.a2a.spec.Task; @@ -32,35 +34,39 @@ class InMemoryPushNotificationConfigStoreTest { private BasePushNotificationSender notificationSender; @Mock - private A2AHttpClient mockHttpClient; + private HttpClientManager clientManager; @Mock - private A2AHttpClient.PostBuilder mockPostBuilder; + private HttpClient mockHttpClient; @Mock - private A2AHttpResponse mockHttpResponse; + private HttpClient.PostRequestBuilder mockPostBuilder; + + @Mock + private HttpResponse mockHttpResponse; @BeforeEach public void setUp() { MockitoAnnotations.openMocks(this); configStore = new InMemoryPushNotificationConfigStore(); - notificationSender = new BasePushNotificationSender(configStore, mockHttpClient); + notificationSender = new BasePushNotificationSender(configStore, clientManager); } private void setupBasicMockHttpResponse() throws Exception { - when(mockHttpClient.createPost()).thenReturn(mockPostBuilder); - when(mockPostBuilder.url(any(String.class))).thenReturn(mockPostBuilder); + when(clientManager.getOrCreate(any())).thenReturn(mockHttpClient); + when(mockHttpClient.post(any())).thenReturn(mockPostBuilder); +// when(mockPostBuilder.url(any(String.class))).thenReturn(mockPostBuilder); when(mockPostBuilder.body(any(String.class))).thenReturn(mockPostBuilder); - when(mockPostBuilder.post()).thenReturn(mockHttpResponse); + when(mockPostBuilder.send()).thenReturn(CompletableFuture.completedFuture(mockHttpResponse)); when(mockHttpResponse.success()).thenReturn(true); } private void verifyHttpCallWithoutToken(PushNotificationConfig config, Task task, String expectedToken) throws Exception { ArgumentCaptor bodyCaptor = ArgumentCaptor.forClass(String.class); - verify(mockHttpClient).createPost(); - verify(mockPostBuilder).url(config.url()); + verify(mockHttpClient).post(any()); +// verify(mockPostBuilder).url(config.url()); verify(mockPostBuilder).body(bodyCaptor.capture()); - verify(mockPostBuilder).post(); + verify(mockPostBuilder).send(); // Verify that addHeader was never called for authentication token verify(mockPostBuilder, never()).addHeader(A2AHeaders.X_A2A_NOTIFICATION_TOKEN, expectedToken); @@ -229,21 +235,23 @@ public void testSendNotificationSuccess() throws Exception { PushNotificationConfig config = createSamplePushConfig("http://notify.me/here", "cfg1", null); configStore.setInfo(taskId, config); + when(clientManager.getOrCreate(any())).thenReturn(mockHttpClient); + // Mock successful HTTP response - when(mockHttpClient.createPost()).thenReturn(mockPostBuilder); - when(mockPostBuilder.url(any(String.class))).thenReturn(mockPostBuilder); + when(mockHttpClient.post(any())).thenReturn(mockPostBuilder); +// when(mockPostBuilder.url(any(String.class))).thenReturn(mockPostBuilder); when(mockPostBuilder.body(any(String.class))).thenReturn(mockPostBuilder); - when(mockPostBuilder.post()).thenReturn(mockHttpResponse); + when(mockPostBuilder.send()).thenReturn(CompletableFuture.completedFuture(mockHttpResponse)); when(mockHttpResponse.success()).thenReturn(true); notificationSender.sendNotification(task); // Verify HTTP client was called ArgumentCaptor bodyCaptor = ArgumentCaptor.forClass(String.class); - verify(mockHttpClient).createPost(); - verify(mockPostBuilder).url(config.url()); + verify(mockHttpClient).post(any()); +// verify(mockPostBuilder).url(config.url()); verify(mockPostBuilder).body(bodyCaptor.capture()); - verify(mockPostBuilder).post(); + verify(mockPostBuilder).send(); // Verify the request body contains the task data String sentBody = bodyCaptor.getValue(); @@ -258,24 +266,26 @@ public void testSendNotificationWithToken() throws Exception { PushNotificationConfig config = createSamplePushConfig("http://notify.me/here", "cfg1", "unique_token"); configStore.setInfo(taskId, config); + when(clientManager.getOrCreate(any())).thenReturn(mockHttpClient); + // Mock successful HTTP response - when(mockHttpClient.createPost()).thenReturn(mockPostBuilder); - when(mockPostBuilder.url(any(String.class))).thenReturn(mockPostBuilder); + when(mockHttpClient.post(any())).thenReturn(mockPostBuilder); +// when(mockPostBuilder.url(any(String.class))).thenReturn(mockPostBuilder); when(mockPostBuilder.body(any(String.class))).thenReturn(mockPostBuilder); when(mockPostBuilder.addHeader(any(String.class), any(String.class))).thenReturn(mockPostBuilder); - when(mockPostBuilder.post()).thenReturn(mockHttpResponse); + when(mockPostBuilder.send()).thenReturn(CompletableFuture.completedFuture(mockHttpResponse)); when(mockHttpResponse.success()).thenReturn(true); notificationSender.sendNotification(task); // Verify HTTP client was called with proper authentication ArgumentCaptor bodyCaptor = ArgumentCaptor.forClass(String.class); - verify(mockHttpClient).createPost(); - verify(mockPostBuilder).url(config.url()); + verify(mockHttpClient).post(any()); +// verify(mockPostBuilder).url(config.url()); verify(mockPostBuilder).body(bodyCaptor.capture()); // Verify that the token is included in request headers as X-A2A-Notification-Token verify(mockPostBuilder).addHeader(A2AHeaders.X_A2A_NOTIFICATION_TOKEN, config.token()); - verify(mockPostBuilder).post(); + verify(mockPostBuilder).send(); // Verify the request body contains the task data String sentBody = bodyCaptor.getValue(); @@ -291,7 +301,7 @@ public void testSendNotificationNoConfig() throws Exception { notificationSender.sendNotification(task); // Verify HTTP client was never called - verify(mockHttpClient, never()).createPost(); + verify(mockHttpClient, never()).post(any()); } @Test diff --git a/server-common/src/test/java/io/a2a/server/tasks/PushNotificationSenderTest.java b/server-common/src/test/java/io/a2a/server/tasks/PushNotificationSenderTest.java index 2ab974ed..f01d2422 100644 --- a/server-common/src/test/java/io/a2a/server/tasks/PushNotificationSenderTest.java +++ b/server-common/src/test/java/io/a2a/server/tasks/PushNotificationSenderTest.java @@ -2,6 +2,8 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.*; import java.io.IOException; import java.util.ArrayList; @@ -13,20 +15,27 @@ import java.util.concurrent.TimeUnit; import java.util.function.Consumer; +import io.a2a.client.http.HttpClient; +import io.a2a.client.http.HttpResponse; +import io.a2a.client.http.sse.Event; +import io.a2a.server.http.HttpClientManager; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; -import io.a2a.client.http.A2AHttpClient; -import io.a2a.client.http.A2AHttpResponse; import io.a2a.common.A2AHeaders; import io.a2a.util.Utils; import io.a2a.spec.PushNotificationConfig; import io.a2a.spec.Task; import io.a2a.spec.TaskState; import io.a2a.spec.TaskStatus; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; public class PushNotificationSenderTest { + @Mock + private HttpClientManager clientManager; + private TestHttpClient testHttpClient; private InMemoryPushNotificationConfigStore configStore; private BasePushNotificationSender sender; @@ -34,7 +43,7 @@ public class PushNotificationSenderTest { /** * Simple test implementation of A2AHttpClient that captures HTTP calls for verification */ - private static class TestHttpClient implements A2AHttpClient { + private static class TestHttpClient implements HttpClient { final List tasks = Collections.synchronizedList(new ArrayList<>()); final List urls = Collections.synchronizedList(new ArrayList<>()); final List> headers = Collections.synchronizedList(new ArrayList<>()); @@ -42,85 +51,85 @@ private static class TestHttpClient implements A2AHttpClient { volatile boolean shouldThrowException = false; @Override - public GetBuilder createGet() { + public GetRequestBuilder get(String path) { return null; } @Override - public PostBuilder createPost() { + public PostRequestBuilder post(String path) { return new TestPostBuilder(); } @Override - public DeleteBuilder createDelete() { + public DeleteRequestBuilder delete(String path) { return null; } - class TestPostBuilder implements A2AHttpClient.PostBuilder { + class TestPostBuilder implements HttpClient.PostRequestBuilder { private volatile String body; - private volatile String url; private final Map requestHeaders = new java.util.HashMap<>(); @Override - public PostBuilder body(String body) { + public PostRequestBuilder body(String body) { this.body = body; return this; } @Override - public A2AHttpResponse post() throws IOException, InterruptedException { + public CompletableFuture send() { + CompletableFuture future = new CompletableFuture<>(); + if (shouldThrowException) { - throw new IOException("Simulated network error"); + future.completeExceptionally(new IOException("Simulated network error")); + return future; } try { Task task = Utils.OBJECT_MAPPER.readValue(body, Task.TYPE_REFERENCE); tasks.add(task); - urls.add(url); headers.add(new java.util.HashMap<>(requestHeaders)); - - return new A2AHttpResponse() { - @Override - public int status() { - return 200; - } - - @Override - public boolean success() { - return true; - } - - @Override - public String body() { - return ""; - } - }; + + future.complete( + new HttpResponse() { + @Override + public int statusCode() { + return 200; + } + + @Override + public boolean success() { + return true; + } + + @Override + public String body() { + return ""; + } + + @Override + public void bodyAsSse(Consumer eventConsumer, Consumer errorConsumer) { + + } + }); + } catch (Exception e) { + future.completeExceptionally(e); } finally { if (latch != null) { latch.countDown(); } } - } - @Override - public CompletableFuture postAsyncSSE(Consumer messageConsumer, Consumer errorConsumer, Runnable completeRunnable) throws IOException, InterruptedException { - return null; + return future; } @Override - public PostBuilder url(String url) { - this.url = url; - return this; - } - - @Override - public PostBuilder addHeader(String name, String value) { + public PostRequestBuilder addHeader(String name, String value) { requestHeaders.put(name, value); return this; } @Override - public PostBuilder addHeaders(Map headers) { + public PostRequestBuilder addHeaders(Map headers) { requestHeaders.putAll(headers); return this; } @@ -129,9 +138,10 @@ public PostBuilder addHeaders(Map headers) { @BeforeEach public void setUp() { + MockitoAnnotations.openMocks(this); testHttpClient = new TestHttpClient(); configStore = new InMemoryPushNotificationConfigStore(); - sender = new BasePushNotificationSender(configStore, testHttpClient); + sender = new BasePushNotificationSender(configStore, clientManager); } private void testSendNotificationWithInvalidToken(String token, String testName) throws InterruptedException { @@ -141,7 +151,9 @@ private void testSendNotificationWithInvalidToken(String token, String testName) // Set up the configuration in the store configStore.setInfo(taskId, config); - + + when(clientManager.getOrCreate(any())).thenReturn(testHttpClient); + // Set up latch to wait for async completion testHttpClient.latch = new CountDownLatch(1); @@ -185,7 +197,9 @@ public void testSendNotificationSuccess() throws InterruptedException { // Set up the configuration in the store configStore.setInfo(taskId, config); - + + when(clientManager.getOrCreate(any())).thenReturn(testHttpClient); + // Set up latch to wait for async completion testHttpClient.latch = new CountDownLatch(1); @@ -210,7 +224,9 @@ public void testSendNotificationWithTokenSuccess() throws InterruptedException { // Set up the configuration in the store configStore.setInfo(taskId, config); - + + when(clientManager.getOrCreate(any())).thenReturn(testHttpClient); + // Set up latch to wait for async completion testHttpClient.latch = new CountDownLatch(1); @@ -263,22 +279,27 @@ public void testSendNotificationMultipleConfigs() throws InterruptedException { // Set up multiple configurations in the store configStore.setInfo(taskId, config1); configStore.setInfo(taskId, config2); - + + TestHttpClient httpClient = spy(testHttpClient); + when(clientManager.getOrCreate(any())).thenReturn(httpClient); + // Set up latch to wait for async completion (2 calls expected) - testHttpClient.latch = new CountDownLatch(2); + httpClient.latch = new CountDownLatch(2); sender.sendNotification(taskData); // Wait for the async operations to complete - assertTrue(testHttpClient.latch.await(5, TimeUnit.SECONDS), "HTTP calls should complete within 5 seconds"); + assertTrue(httpClient.latch.await(5, TimeUnit.SECONDS), "HTTP calls should complete within 5 seconds"); // Verify both tasks were sent via HTTP - assertEquals(2, testHttpClient.tasks.size()); - assertEquals(2, testHttpClient.urls.size()); - assertTrue(testHttpClient.urls.containsAll(java.util.List.of("http://notify.me/cfg1", "http://notify.me/cfg2"))); + assertEquals(2, httpClient.tasks.size()); + //assertEquals(2, testHttpClient.urls.size()); + verify(httpClient).post("/cfg1"); + verify(httpClient).post("/cfg2"); + // assertTrue(testHttpClient.urls.containsAll(java.util.List.of("http://notify.me/cfg1", "http://notify.me/cfg2"))); // Both tasks should be identical (same task sent to different endpoints) - for (Task sentTask : testHttpClient.tasks) { + for (Task sentTask : httpClient.tasks) { assertEquals(taskData.getId(), sentTask.getId()); assertEquals(taskData.getContextId(), sentTask.getContextId()); assertEquals(taskData.getStatus().state(), sentTask.getStatus().state()); diff --git a/tests/client-common/pom.xml b/tests/client-common/pom.xml new file mode 100644 index 00000000..f003c7b6 --- /dev/null +++ b/tests/client-common/pom.xml @@ -0,0 +1,60 @@ + + + 4.0.0 + + + io.github.a2asdk + a2a-java-sdk-parent + 0.3.0.Beta3-SNAPSHOT + ../../pom.xml + + a2a-java-sdk-tests-client-common + + jar + + Java A2A SDK Client Tests Common + Java SDK for the Agent2Agent Protocol (A2A) - SDK - Client Tests Common + + + + ${project.groupId} + a2a-java-sdk-http-client + test + + + org.junit.jupiter + junit-jupiter-api + test + + + org.wiremock + wiremock + test + + + + + + + org.apache.maven.plugins + maven-surefire-plugin + + true + + + + org.apache.maven.plugins + maven-jar-plugin + + + + test-jar + + + + + + + \ No newline at end of file diff --git a/tests/client-common/src/test/java/io/a2a/client/http/common/AbstractHttpClientTest.java b/tests/client-common/src/test/java/io/a2a/client/http/common/AbstractHttpClientTest.java new file mode 100644 index 00000000..9b21d4f1 --- /dev/null +++ b/tests/client-common/src/test/java/io/a2a/client/http/common/AbstractHttpClientTest.java @@ -0,0 +1,187 @@ +package io.a2a.client.http.common; + +import com.github.tomakehurst.wiremock.WireMockServer; +import com.github.tomakehurst.wiremock.core.WireMockConfiguration; +import io.a2a.client.http.HttpClientBuilder; +import io.a2a.client.http.HttpResponse; +import io.a2a.client.http.sse.Event; +import org.junit.jupiter.api.*; + +import java.net.HttpURLConnection; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.BiConsumer; +import java.util.function.Consumer; + +import static com.github.tomakehurst.wiremock.client.WireMock.*; +import static org.junit.jupiter.api.Assertions.*; + +public abstract class AbstractHttpClientTest { + + private static final String AGENT_CARD_PATH = "/.well-known/agent-card.json"; + + private WireMockServer server; + + @BeforeEach + public void setUp() { + server = new WireMockServer(WireMockConfiguration.options().dynamicPort()); + server.start(); + + configureFor("localhost", server.port()); + } + + @AfterEach + public void tearDown() { + if (server != null) { + server.stop(); + } + } + + protected abstract HttpClientBuilder getHttpClientBuilder(); + + private String getServerUrl() { + return "http://localhost:" + server.port(); + } + + /** + * This test is disabled until we can make the http-client layer fully async + */ + @Test + @Disabled + public void testGetWithBodyResponse() throws Exception { + givenThat(get(urlPathEqualTo(AGENT_CARD_PATH)) + .willReturn(okForContentType("application/json", JsonMessages.AGENT_CARD))); + + CountDownLatch latch = new CountDownLatch(1); + getHttpClientBuilder() + .create(getServerUrl()) + .get(AGENT_CARD_PATH) + .send() + .thenAccept(new Consumer() { + @Override + public void accept(HttpResponse httpResponse) { + String body = httpResponse.body(); + + Assertions.assertEquals(JsonMessages.AGENT_CARD, body); + latch.countDown(); + } + }); + + boolean dataReceived = latch.await(5, TimeUnit.SECONDS); + assertTrue(dataReceived); + + } + + @Test + public void testA2AClientSendStreamingMessage() throws Exception { + String eventStream = + JsonStreamingMessages.SEND_MESSAGE_STREAMING_TEST_RESPONSE + + JsonStreamingMessages.TASK_RESUBSCRIPTION_REQUEST_TEST_RESPONSE; + + givenThat(post(urlPathEqualTo("/")) + .willReturn(okForContentType("text/event-stream", eventStream))); + + CountDownLatch latch = new CountDownLatch(2); + AtomicReference errorRef = new AtomicReference<>(); + + getHttpClientBuilder() + .create(getServerUrl()) + .post("/") + .send() + .thenAccept(new Consumer() { + @Override + public void accept(HttpResponse httpResponse) { + httpResponse.bodyAsSse(new Consumer() { + @Override + public void accept(Event event) { + System.out.println(event); + latch.countDown(); + } + }, new Consumer() { + @Override + public void accept(Throwable throwable) { + errorRef.set(throwable); + latch.countDown(); + } + }); + } + }); + + boolean dataReceived = latch.await(5, TimeUnit.SECONDS); + assertTrue(dataReceived); + assertNull(errorRef.get(), "Should not receive errors during SSE stream"); + } + + @Test + public void testUnauthorizedClient_post() throws Exception { + givenThat(post(urlPathEqualTo("/")) + .willReturn(aResponse().withStatus(HttpURLConnection.HTTP_UNAUTHORIZED))); + + CountDownLatch latch = new CountDownLatch(1); + AtomicReference errorRef = new AtomicReference<>(); + AtomicReference responseRef = new AtomicReference<>(); + + getHttpClientBuilder() + // Enforce that the client will be receiving the SSE stream into multiple chunks + // .options(new HttpClientOptions().setMaxChunkSize(24)) + .create(getServerUrl()) + .post("/") + .send() + .whenComplete(new BiConsumer() { + @Override + public void accept(HttpResponse httpResponse, Throwable throwable) { + if (throwable != null) { + errorRef.set(throwable); + } + + if (httpResponse != null) { + responseRef.set(httpResponse); + } + + latch.countDown(); + } + }); + + boolean callCompleted = latch.await(5, TimeUnit.SECONDS); + assertTrue(callCompleted); + assertNull(responseRef.get(), "Should not receive response when unauthorized"); + assertNotNull(errorRef.get(), "Should not receive errors during SSE stream"); + } + + @Test + public void testUnauthorizedClient_get() throws Exception { + givenThat(get(urlPathEqualTo("/")) + .willReturn(aResponse().withStatus(HttpURLConnection.HTTP_UNAUTHORIZED))); + + CountDownLatch latch = new CountDownLatch(1); + AtomicReference errorRef = new AtomicReference<>(); + AtomicReference responseRef = new AtomicReference<>(); + + getHttpClientBuilder() + // Enforce that the client will be receiving the SSE stream into multiple chunks + // .options(new HttpClientOptions().setMaxChunkSize(24)) + .create(getServerUrl()) + .get("/") + .send() + .whenComplete(new BiConsumer() { + @Override + public void accept(HttpResponse httpResponse, Throwable throwable) { + if (throwable != null) { + errorRef.set(throwable); + } + + if (httpResponse != null) { + responseRef.set(httpResponse); + } + + latch.countDown(); + } + }); + + boolean callCompleted = latch.await(5, TimeUnit.SECONDS); + assertTrue(callCompleted); + assertNull(responseRef.get(), "Should not receive response when unauthorized"); + assertNotNull(errorRef.get(), "Should not receive errors during SSE stream"); + } +} diff --git a/tests/client-common/src/test/java/io/a2a/client/http/common/JsonMessages.java b/tests/client-common/src/test/java/io/a2a/client/http/common/JsonMessages.java new file mode 100644 index 00000000..0ab9d811 --- /dev/null +++ b/tests/client-common/src/test/java/io/a2a/client/http/common/JsonMessages.java @@ -0,0 +1,85 @@ +package io.a2a.client.http.common; + +/** + * Request and response messages used by the tests. These have been created following examples from + * the A2A sample messages. + */ +public class JsonMessages { + + static final String AGENT_CARD = """ + { + "protocolVersion": "0.2.9", + "name": "GeoSpatial Route Planner Agent", + "description": "Provides advanced route planning, traffic analysis, and custom map generation services. This agent can calculate optimal routes, estimate travel times considering real-time traffic, and create personalized maps with points of interest.", + "url": "https://georoute-agent.example.com/a2a/v1", + "preferredTransport": "JSONRPC", + "additionalInterfaces" : [ + {"url": "https://georoute-agent.example.com/a2a/v1", "transport": "JSONRPC"}, + {"url": "https://georoute-agent.example.com/a2a/grpc", "transport": "GRPC"}, + {"url": "https://georoute-agent.example.com/a2a/json", "transport": "HTTP+JSON"} + ], + "provider": { + "organization": "Example Geo Services Inc.", + "url": "https://www.examplegeoservices.com" + }, + "iconUrl": "https://georoute-agent.example.com/icon.png", + "version": "1.2.0", + "documentationUrl": "https://docs.examplegeoservices.com/georoute-agent/api", + "capabilities": { + "streaming": true, + "pushNotifications": true, + "stateTransitionHistory": false + }, + "securitySchemes": { + "google": { + "type": "openIdConnect", + "openIdConnectUrl": "https://accounts.google.com/.well-known/openid-configuration" + } + }, + "security": [{ "google": ["openid", "profile", "email"] }], + "defaultInputModes": ["application/json", "text/plain"], + "defaultOutputModes": ["application/json", "image/png"], + "skills": [ + { + "id": "route-optimizer-traffic", + "name": "Traffic-Aware Route Optimizer", + "description": "Calculates the optimal driving route between two or more locations, taking into account real-time traffic conditions, road closures, and user preferences (e.g., avoid tolls, prefer highways).", + "tags": ["maps", "routing", "navigation", "directions", "traffic"], + "examples": [ + "Plan a route from '1600 Amphitheatre Parkway, Mountain View, CA' to 'San Francisco International Airport' avoiding tolls.", + "{\\"origin\\": {\\"lat\\": 37.422, \\"lng\\": -122.084}, \\"destination\\": {\\"lat\\": 37.7749, \\"lng\\": -122.4194}, \\"preferences\\": [\\"avoid_ferries\\"]}" + ], + "inputModes": ["application/json", "text/plain"], + "outputModes": [ + "application/json", + "application/vnd.geo+json", + "text/html" + ] + }, + { + "id": "custom-map-generator", + "name": "Personalized Map Generator", + "description": "Creates custom map images or interactive map views based on user-defined points of interest, routes, and style preferences. Can overlay data layers.", + "tags": ["maps", "customization", "visualization", "cartography"], + "examples": [ + "Generate a map of my upcoming road trip with all planned stops highlighted.", + "Show me a map visualizing all coffee shops within a 1-mile radius of my current location." + ], + "inputModes": ["application/json"], + "outputModes": [ + "image/png", + "image/jpeg", + "application/json", + "text/html" + ] + } + ], + "supportsAuthenticatedExtendedCard": true, + "signatures": [ + { + "protected": "eyJhbGciOiJFUzI1NiIsInR5cCI6IkpPU0UiLCJraWQiOiJrZXktMSIsImprdSI6Imh0dHBzOi8vZXhhbXBsZS5jb20vYWdlbnQvandrcy5qc29uIn0", + "signature": "QFdkNLNszlGj3z3u0YQGt_T9LixY3qtdQpZmsTdDHDe3fXV9y9-B3m2-XgCpzuhiLt8E0tV6HXoZKHv4GtHgKQ" + } + ] + }"""; +} \ No newline at end of file diff --git a/tests/client-common/src/test/java/io/a2a/client/http/common/JsonStreamingMessages.java b/tests/client-common/src/test/java/io/a2a/client/http/common/JsonStreamingMessages.java new file mode 100644 index 00000000..15ae5c38 --- /dev/null +++ b/tests/client-common/src/test/java/io/a2a/client/http/common/JsonStreamingMessages.java @@ -0,0 +1,15 @@ +package io.a2a.client.http.common; + +/** + * Contains JSON strings for testing SSE streaming. + */ +public class JsonStreamingMessages { + + static final String SEND_MESSAGE_STREAMING_TEST_RESPONSE = + "event: message\n" + + "data: {\"jsonrpc\":\"2.0\",\"id\":1,\"result\":{\"id\":\"2\",\"contextId\":\"context-1234\",\"status\":{\"state\":\"completed\"},\"artifacts\":[{\"artifactId\":\"artifact-1\",\"name\":\"joke\",\"parts\":[{\"kind\":\"text\",\"text\":\"Why did the chicken cross the road? To get to the other side!\"}]}],\"metadata\":{},\"kind\":\"task\"}}\n\n"; + + static final String TASK_RESUBSCRIPTION_REQUEST_TEST_RESPONSE = + "event: message\n" + + "data: {\"jsonrpc\":\"2.0\",\"id\":1,\"result\":{\"id\":\"2\",\"contextId\":\"context-5678\",\"status\":{\"state\":\"completed\"},\"artifacts\":[{\"artifactId\":\"artifact-1\",\"name\":\"joke\",\"parts\":[{\"kind\":\"text\",\"text\":\"Why did the chicken cross the road? To get to the other side!\"}]}],\"metadata\":{},\"kind\":\"task\"}}\n\n"; +} \ No newline at end of file diff --git a/tests/server-common/src/test/java/io/a2a/server/apps/common/TestHttpClient.java b/tests/server-common/src/test/java/io/a2a/server/apps/common/TestHttpClient.java index f161307a..d79c5df5 100644 --- a/tests/server-common/src/test/java/io/a2a/server/apps/common/TestHttpClient.java +++ b/tests/server-common/src/test/java/io/a2a/server/apps/common/TestHttpClient.java @@ -1,6 +1,5 @@ package io.a2a.server.apps.common; -import java.io.IOException; import java.util.ArrayList; import java.util.Collections; import java.util.List; @@ -8,86 +7,91 @@ import java.util.concurrent.CountDownLatch; import java.util.function.Consumer; +import io.a2a.client.http.sse.Event; import jakarta.enterprise.context.Dependent; import jakarta.enterprise.inject.Alternative; -import io.a2a.client.http.A2AHttpClient; -import io.a2a.client.http.A2AHttpResponse; +import io.a2a.client.http.HttpClient; +import io.a2a.client.http.HttpResponse; import io.a2a.spec.Task; import io.a2a.util.Utils; import java.util.Map; @Dependent @Alternative -public class TestHttpClient implements A2AHttpClient { +public class TestHttpClient implements HttpClient { final List tasks = Collections.synchronizedList(new ArrayList<>()); volatile CountDownLatch latch; @Override - public GetBuilder createGet() { + public GetRequestBuilder get(String path) { return null; } @Override - public PostBuilder createPost() { - return new TestPostBuilder(); + public PostRequestBuilder post(String path) { + return new TestPostRequestBuilder(); } @Override - public DeleteBuilder createDelete() { + public DeleteRequestBuilder delete(String path) { return null; } - class TestPostBuilder implements A2AHttpClient.PostBuilder { + class TestPostRequestBuilder implements PostRequestBuilder { + private volatile String body; @Override - public PostBuilder body(String body) { + public PostRequestBuilder body(String body) { this.body = body; return this; } @Override - public A2AHttpResponse post() throws IOException, InterruptedException { - tasks.add(Utils.OBJECT_MAPPER.readValue(body, Task.TYPE_REFERENCE)); + public CompletableFuture send() { + CompletableFuture future = new CompletableFuture<>(); + try { - return new A2AHttpResponse() { - @Override - public int status() { - return 200; - } - - @Override - public boolean success() { - return true; - } - - @Override - public String body() { - return ""; - } - }; + tasks.add(Utils.OBJECT_MAPPER.readValue(body, Task.TYPE_REFERENCE)); + + future.complete( + new HttpResponse() { + @Override + public int statusCode() { + return 200; + } + + @Override + public boolean success() { + return true; + } + + @Override + public String body() { + return ""; + } + + @Override + public void bodyAsSse(Consumer eventConsumer, Consumer errorConsumer) { + + } + }); + } catch (Exception ex) { + future.completeExceptionally(ex); } finally { latch.countDown(); } - } - - @Override - public CompletableFuture postAsyncSSE(Consumer messageConsumer, Consumer errorConsumer, Runnable completeRunnable) throws IOException, InterruptedException { - return null; - } - @Override - public PostBuilder url(String s) { - return this; + return future; } @Override - public PostBuilder addHeader(String name, String value) { + public PostRequestBuilder addHeader(String name, String value) { return this; } @Override - public PostBuilder addHeaders(Map headers) { + public PostRequestBuilder addHeaders(Map headers) { return this; } } diff --git a/transport/jsonrpc/src/test/java/io/a2a/transport/jsonrpc/handler/JSONRPCHandlerTest.java b/transport/jsonrpc/src/test/java/io/a2a/transport/jsonrpc/handler/JSONRPCHandlerTest.java index 9d12824b..e330180c 100644 --- a/transport/jsonrpc/src/test/java/io/a2a/transport/jsonrpc/handler/JSONRPCHandlerTest.java +++ b/transport/jsonrpc/src/test/java/io/a2a/transport/jsonrpc/handler/JSONRPCHandlerTest.java @@ -22,8 +22,7 @@ import io.a2a.server.events.EventConsumer; import io.a2a.server.requesthandlers.AbstractA2ARequestHandlerTest; import io.a2a.server.requesthandlers.DefaultRequestHandler; -import io.a2a.server.tasks.ResultAggregator; -import io.a2a.server.tasks.TaskUpdater; +import io.a2a.server.tasks.*; import io.a2a.spec.AgentCard; import io.a2a.spec.Artifact; import io.a2a.spec.AuthenticatedExtendedCardNotConfiguredError;