Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import io.a2a.spec.GetTaskPushNotificationConfigParams;
import io.a2a.spec.Message;
import io.a2a.spec.PushNotificationConfig;
import io.a2a.spec.StreamingEventKind;
import io.a2a.spec.Task;
import io.a2a.spec.TaskPushNotificationConfig;
import io.a2a.spec.TextPart;
Expand Down Expand Up @@ -88,9 +89,9 @@ public void testDirectNotificationTrigger() {
mockPushNotificationSender.sendNotification(testTask);

// Verify it was captured
Queue<Task> captured = mockPushNotificationSender.getCapturedTasks();
Queue<StreamingEventKind> captured = mockPushNotificationSender.getCapturedEvents();
assertEquals(1, captured.size());
assertEquals("direct-test-task", captured.peek().getId());
assertEquals("direct-test-task", ((Task)captured.peek()).getId());
}

@Test
Expand Down Expand Up @@ -151,7 +152,7 @@ public void testJpaDatabasePushNotificationConfigStoreIntegration() throws Excep
boolean notificationReceived = false;

while (System.currentTimeMillis() < end) {
if (!mockPushNotificationSender.getCapturedTasks().isEmpty()) {
if (!mockPushNotificationSender.getCapturedEvents().isEmpty()) {
notificationReceived = true;
break;
}
Expand All @@ -161,10 +162,12 @@ public void testJpaDatabasePushNotificationConfigStoreIntegration() throws Excep
assertTrue(notificationReceived, "Timeout waiting for push notification.");

// Step 6: Verify the captured notification
Queue<Task> capturedTasks = mockPushNotificationSender.getCapturedTasks();
Queue<StreamingEventKind> capturedTasks = mockPushNotificationSender.getCapturedEvents();

// Verify the notification contains the correct task with artifacts
Task notifiedTaskWithArtifact = capturedTasks.stream()
.filter(e -> Task.TASK.equals(e.getKind()))
.map(e -> (Task)e)
.filter(t -> taskId.equals(t.getId()) && t.getArtifacts() != null && t.getArtifacts().size() > 0)
.findFirst()
.orElse(null);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import jakarta.enterprise.inject.Alternative;

import io.a2a.server.tasks.PushNotificationSender;
import io.a2a.spec.Task;
import io.a2a.spec.StreamingEventKind;

/**
* Mock implementation of PushNotificationSender for integration testing.
Expand All @@ -19,18 +19,18 @@
@Priority(100)
public class MockPushNotificationSender implements PushNotificationSender {

private final Queue<Task> capturedTasks = new ConcurrentLinkedQueue<>();
private final Queue<StreamingEventKind> capturedEvents = new ConcurrentLinkedQueue<>();

@Override
public void sendNotification(Task task) {
capturedTasks.add(task);
public void sendNotification(StreamingEventKind kind) {
capturedEvents.add(kind);
}

public Queue<Task> getCapturedTasks() {
return capturedTasks;
public Queue<StreamingEventKind> getCapturedEvents() {
return capturedEvents;
}

public void clear() {
capturedTasks.clear();
capturedEvents.clear();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,25 +3,33 @@
import static io.a2a.client.http.A2AHttpClient.APPLICATION_JSON;
import static io.a2a.client.http.A2AHttpClient.CONTENT_TYPE;
import static io.a2a.common.A2AHeaders.X_A2A_NOTIFICATION_TOKEN;
import jakarta.enterprise.context.ApplicationScoped;
import jakarta.inject.Inject;
import static io.a2a.spec.Message.MESSAGE;
import static io.a2a.spec.Task.TASK;
import static io.a2a.spec.TaskArtifactUpdateEvent.ARTIFACT_UPDATE;
import static io.a2a.spec.TaskStatusUpdateEvent.STATUS_UPDATE;

import java.io.IOException;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;

import jakarta.enterprise.context.ApplicationScoped;
import jakarta.inject.Inject;

import com.fasterxml.jackson.core.JsonProcessingException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import io.a2a.client.http.A2AHttpClient;
import io.a2a.client.http.JdkA2AHttpClient;
import io.a2a.spec.Message;
import io.a2a.spec.PushNotificationConfig;
import io.a2a.spec.StreamingEventKind;
import io.a2a.spec.Task;
import io.a2a.spec.TaskArtifactUpdateEvent;
import io.a2a.spec.TaskStatusUpdateEvent;
import io.a2a.util.Utils;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

@ApplicationScoped
public class BasePushNotificationSender implements PushNotificationSender {

Expand All @@ -42,34 +50,44 @@ public BasePushNotificationSender(PushNotificationConfigStore configStore, A2AHt
}

@Override
public void sendNotification(Task task) {
List<PushNotificationConfig> pushConfigs = configStore.getInfo(task.getId());
public void sendNotification(StreamingEventKind kind) {
String taskId = switch (kind.getKind()) {
case TASK -> ((Task)kind).getId();
case MESSAGE -> ((Message)kind).getTaskId();
case STATUS_UPDATE -> ((TaskStatusUpdateEvent)kind).getTaskId();
case ARTIFACT_UPDATE -> ((TaskArtifactUpdateEvent)kind).getTaskId();
default -> null;
};
Comment on lines +54 to +60
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The switch statement on kind.getKind() uses a default -> null case, which will cause push notifications for any new or unhandled StreamingEventKind types to be silently dropped. This could lead to hard-to-debug issues.

Since StreamingEventKind is a sealed interface, you can use a pattern-matching switch statement on the kind object itself. This is safer as the compiler will enforce that all permitted subtypes are handled, eliminating the need for a default case and preventing silent failures.

        String taskId = switch (kind) {
            case Task t -> t.getId();
            case Message m -> m.getTaskId();
            case TaskStatusUpdateEvent e -> e.getTaskId();
            case TaskArtifactUpdateEvent e -> e.getTaskId();
        };

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good suggestion but this requires Java 21+ while the project targets Java 17.

if (taskId == null) {
return;
Comment on lines +54 to +62
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The switch statement with casting can be simplified using pattern matching for switch, which is available in modern Java versions. This would make the code more concise and type-safe. Since StreamingEventKind is a sealed interface, the switch expression is exhaustive, so the default case and the subsequent null check for taskId are no longer necessary.

Suggested change
String taskId = switch (kind.getKind()) {
case TASK -> ((Task)kind).getId();
case MESSAGE -> ((Message)kind).getTaskId();
case STATUS_UPDATE -> ((TaskStatusUpdateEvent)kind).getTaskId();
case ARTIFACT_UPDATE -> ((TaskArtifactUpdateEvent)kind).getTaskId();
default -> null;
};
if (taskId == null) {
return;
String taskId = switch (kind) {
case Task task -> task.getId();
case Message message -> message.getTaskId();
case TaskStatusUpdateEvent event -> event.getTaskId();
case TaskArtifactUpdateEvent event -> event.getTaskId();
};

}
List<PushNotificationConfig> pushConfigs = configStore.getInfo(taskId);
if (pushConfigs == null || pushConfigs.isEmpty()) {
return;
}

List<CompletableFuture<Boolean>> dispatchResults = pushConfigs
.stream()
.map(pushConfig -> dispatch(task, pushConfig))
.map(pushConfig -> dispatch(kind, pushConfig))
.toList();
CompletableFuture<Void> allFutures = CompletableFuture.allOf(dispatchResults.toArray(new CompletableFuture[0]));
CompletableFuture<Boolean> dispatchResult = allFutures.thenApply(v -> dispatchResults.stream()
.allMatch(CompletableFuture::join));
try {
boolean allSent = dispatchResult.get();
if (! allSent) {
LOGGER.warn("Some push notifications failed to send for taskId: " + task.getId());
LOGGER.warn("Some push notifications failed to send for taskId: " + taskId);
}
} catch (InterruptedException | ExecutionException e) {
LOGGER.warn("Some push notifications failed to send for taskId " + task.getId() + ": {}", e.getMessage(), e);
LOGGER.warn("Some push notifications failed to send for taskId " + taskId + ": {}", e.getMessage(), e);
}
}

private CompletableFuture<Boolean> dispatch(Task task, PushNotificationConfig pushInfo) {
return CompletableFuture.supplyAsync(() -> dispatchNotification(task, pushInfo));
private CompletableFuture<Boolean> dispatch(StreamingEventKind kind, PushNotificationConfig pushInfo) {
return CompletableFuture.supplyAsync(() -> dispatchNotification(kind, pushInfo));
}

private boolean dispatchNotification(Task task, PushNotificationConfig pushInfo) {
private boolean dispatchNotification(StreamingEventKind kind, PushNotificationConfig pushInfo) {
String url = pushInfo.url();
String token = pushInfo.token();

Expand All @@ -80,7 +98,7 @@ private boolean dispatchNotification(Task task, PushNotificationConfig pushInfo)

String body;
try {
body = Utils.OBJECT_MAPPER.writeValueAsString(task);
body = Utils.marshalFrom(kind);
} catch (JsonProcessingException e) {
LOGGER.debug("Error writing value as string: {}", e.getMessage(), e);
return false;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
package io.a2a.server.tasks;

import io.a2a.spec.Task;
import io.a2a.spec.StreamingEventKind;

/**
* Interface for sending push notifications for tasks.
*/
public interface PushNotificationSender {

/**
* Sends a push notification containing the latest task state.
* @param task the task
* Sends a push notification containing payload about a task.
* @param kind the payload to push
*/
void sendNotification(Task task);
void sendNotification(StreamingEventKind kind);
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Properties;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CountDownLatch;
Expand All @@ -15,6 +16,11 @@

import jakarta.enterprise.context.Dependent;

import io.quarkus.arc.profile.IfBuildProfile;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeEach;

import io.a2a.client.http.A2AHttpClient;
import io.a2a.client.http.A2AHttpResponse;
import io.a2a.server.agentexecution.AgentExecutor;
Expand All @@ -30,20 +36,15 @@
import io.a2a.server.tasks.TaskStore;
import io.a2a.spec.AgentCapabilities;
import io.a2a.spec.AgentCard;
import io.a2a.spec.Event;
import io.a2a.spec.JSONRPCError;
import io.a2a.spec.Message;
import io.a2a.spec.StreamingEventKind;
import io.a2a.spec.Task;
import io.a2a.spec.TaskState;
import io.a2a.spec.TaskStatus;
import io.a2a.spec.Event;
import io.a2a.spec.TextPart;
import io.a2a.util.Utils;
import io.quarkus.arc.profile.IfBuildProfile;
import java.util.Map;

import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeEach;

public class AbstractA2ARequestHandlerTest {

Expand Down Expand Up @@ -199,7 +200,10 @@ public PostBuilder body(String body) {

@Override
public A2AHttpResponse post() throws IOException, InterruptedException {
tasks.add(Utils.OBJECT_MAPPER.readValue(body, Task.TYPE_REFERENCE));
StreamingEventKind kind = Utils.unmarshalStreamingEventKindFrom(body);
if (kind instanceof Task task) {
tasks.add(task);
}
try {
return new A2AHttpResponse() {
@Override
Expand Down
Loading