diff --git a/server-common/src/main/java/io/a2a/server/events/EventConsumer.java b/server-common/src/main/java/io/a2a/server/events/EventConsumer.java index 6fdfe8ea4..0577e4f28 100644 --- a/server-common/src/main/java/io/a2a/server/events/EventConsumer.java +++ b/server-common/src/main/java/io/a2a/server/events/EventConsumer.java @@ -23,6 +23,7 @@ public class EventConsumer { private volatile boolean cancelled = false; private volatile boolean agentCompleted = false; private volatile int pollTimeoutsAfterAgentCompleted = 0; + private volatile @Nullable TaskState lastSeenTaskState = null; private static final String ERROR_MSG = "Agent did not return any response"; private static final int NO_WAIT = -1; @@ -89,7 +90,12 @@ public Flow.Publisher consumeAll() { // // IMPORTANT: In replicated scenarios, remote events may arrive AFTER local agent completes! // Use grace period to allow for Kafka replication delays (can be 400-500ms) - if (agentCompleted && queueSize == 0) { + // + // CRITICAL: Do NOT close if task is in interrupted state (INPUT_REQUIRED, AUTH_REQUIRED) + // Per A2A spec, interrupted states are NOT terminal - the stream must stay open + // for future state updates even after agent completes (agent will be re-invoked later). + boolean isInterruptedState = lastSeenTaskState != null && lastSeenTaskState.isInterrupted(); + if (agentCompleted && queueSize == 0 && !isInterruptedState) { pollTimeoutsAfterAgentCompleted++; if (pollTimeoutsAfterAgentCompleted >= MAX_POLL_TIMEOUTS_AFTER_AGENT_COMPLETED) { LOGGER.debug("Agent completed with {} consecutive poll timeouts and empty queue, closing for graceful completion (queue={})", @@ -102,6 +108,10 @@ public Flow.Publisher consumeAll() { LOGGER.debug("Agent completed but grace period active ({}/{} timeouts), continuing to poll (queue={})", pollTimeoutsAfterAgentCompleted, MAX_POLL_TIMEOUTS_AFTER_AGENT_COMPLETED, System.identityHashCode(queue)); } + } else if (agentCompleted && isInterruptedState) { + LOGGER.debug("Agent completed but task is in interrupted state ({}), stream must remain open (queue={})", + lastSeenTaskState, System.identityHashCode(queue)); + pollTimeoutsAfterAgentCompleted = 0; // Reset counter } else if (agentCompleted && queueSize > 0) { LOGGER.debug("Agent completed but queue has {} pending events, resetting timeout counter and continuing to poll (queue={})", queueSize, System.identityHashCode(queue)); @@ -115,6 +125,13 @@ public Flow.Publisher consumeAll() { LOGGER.debug("EventConsumer received event: {} (queue={})", event.getClass().getSimpleName(), System.identityHashCode(queue)); + // Track the latest task state for grace period logic + if (event instanceof Task task) { + lastSeenTaskState = task.status().state(); + } else if (event instanceof TaskStatusUpdateEvent tue) { + lastSeenTaskState = tue.status().state(); + } + // Defensive logging for error handling if (event instanceof Throwable thr) { LOGGER.debug("EventConsumer detected Throwable event: {} - triggering tube.fail()", @@ -195,17 +212,21 @@ public Flow.Publisher consumeAll() { /** * Determines if a task is in a state for terminating the stream. - *

A task is terminating if:

- *
    - *
  • Its state is final (e.g., completed, canceled, rejected, failed), OR
  • - *
  • Its state is interrupted (e.g., input-required)
  • - *
+ *

+ * Per A2A Protocol Specification 3.1.6 (SubscribeToTask): + * "The stream MUST terminate when the task reaches a terminal state + * (completed, failed, canceled, or rejected)." + *

+ * Interrupted states (INPUT_REQUIRED, AUTH_REQUIRED) are NOT terminal. + * The stream should remain open to deliver future state updates when + * the task resumes after receiving the required input or authorization. + * * @param task the task to check - * @return true if the task has a final state or an interrupted state, false otherwise + * @return true if the task has a terminal/final state, false otherwise */ private boolean isStreamTerminatingTask(Task task) { TaskState state = task.status().state(); - return state.isFinal() || state == TaskState.TASK_STATE_INPUT_REQUIRED; + return state.isFinal(); } public EnhancedRunnable.DoneCallback createAgentRunnableDoneCallback() { diff --git a/server-common/src/test/java/io/a2a/server/events/EventConsumerTest.java b/server-common/src/test/java/io/a2a/server/events/EventConsumerTest.java index e31fa9b4b..17242ba89 100644 --- a/server-common/src/test/java/io/a2a/server/events/EventConsumerTest.java +++ b/server-common/src/test/java/io/a2a/server/events/EventConsumerTest.java @@ -238,26 +238,32 @@ public void testConsumeMessageEvents() throws Exception { @Test public void testConsumeTaskInputRequired() { + // Per A2A Protocol Specification 3.1.6 (SubscribeToTask): + // "The stream MUST terminate when the task reaches a terminal state + // (completed, failed, canceled, or rejected)." + // + // INPUT_REQUIRED is an interrupted state, NOT a terminal state. + // The stream should remain open to deliver future state updates. Task task = Task.builder() .id(TASK_ID) .contextId("session-xyz") .status(new TaskStatus(TaskState.TASK_STATE_INPUT_REQUIRED)) .build(); - List events = List.of( - task, - TaskArtifactUpdateEvent.builder() + TaskArtifactUpdateEvent artifactEvent = TaskArtifactUpdateEvent.builder() .taskId(TASK_ID) .contextId("session-xyz") .artifact(Artifact.builder() .artifactId("11") .parts(new TextPart("text")) .build()) - .build(), - TaskStatusUpdateEvent.builder() + .build(); + TaskStatusUpdateEvent completedEvent = TaskStatusUpdateEvent.builder() .taskId(TASK_ID) .contextId("session-xyz") .status(new TaskStatus(TaskState.TASK_STATE_COMPLETED)) - .build()); + .build(); + List events = List.of(task, artifactEvent, completedEvent); + for (Event event : events) { eventQueue.enqueueEvent(event); } @@ -269,9 +275,12 @@ public void testConsumeTaskInputRequired() { publisher.subscribe(getSubscriber(receivedEvents, error)); assertNull(error.get()); - // The stream is closed after the input_required task - assertEquals(1, receivedEvents.size()); + // Stream should remain open for INPUT_REQUIRED and deliver all events + // until the terminal COMPLETED state is reached + assertEquals(3, receivedEvents.size()); assertSame(task, receivedEvents.get(0)); + assertSame(artifactEvent, receivedEvents.get(1)); + assertSame(completedEvent, receivedEvents.get(2)); } private Flow.Subscriber getSubscriber(List receivedEvents, AtomicReference error) { diff --git a/spec/src/main/java/io/a2a/spec/TaskState.java b/spec/src/main/java/io/a2a/spec/TaskState.java index ca5354167..831ac0fb5 100644 --- a/spec/src/main/java/io/a2a/spec/TaskState.java +++ b/spec/src/main/java/io/a2a/spec/TaskState.java @@ -6,11 +6,17 @@ * TaskState represents the discrete states a task can be in during its execution lifecycle. * States are categorized as either transitional (non-final) or terminal (final), where * terminal states indicate that the task has reached its end state and will not transition further. + * A subset of transitional states are also marked as interrupted, indicating the task execution + * has paused and requires external action before proceeding. *

- * Transitional States: + * Active Transitional States: *

    *
  • TASK_STATE_SUBMITTED: Task has been received by the agent and is queued for processing
  • *
  • TASK_STATE_WORKING: Agent is actively processing the task and may produce incremental results
  • + *
+ *

+ * Interrupted States: + *

    *
  • TASK_STATE_INPUT_REQUIRED: Agent needs additional input from the user to continue
  • *
  • TASK_STATE_AUTH_REQUIRED: Agent requires authentication or authorization before proceeding
  • *
@@ -25,7 +31,8 @@ * *

* The {@link #isFinal()} method can be used to determine if a state is terminal, which is - * important for event queue management and client polling logic. + * important for event queue management and client polling logic. The {@link #isInterrupted()} + * method identifies states where the task is paused awaiting external action. * * @see TaskStatus * @see Task @@ -33,36 +40,38 @@ */ public enum TaskState { /** Task has been received and is queued for processing (transitional state). */ - TASK_STATE_SUBMITTED(false), + TASK_STATE_SUBMITTED(false, false), /** Agent is actively processing the task (transitional state). */ - TASK_STATE_WORKING(false), + TASK_STATE_WORKING(false, false), - /** Agent requires additional input from the user to continue (transitional state). */ - TASK_STATE_INPUT_REQUIRED(false), + /** Agent requires additional input from the user to continue (interrupted state). */ + TASK_STATE_INPUT_REQUIRED(false, true), - /** Agent requires authentication or authorization to proceed (transitional state). */ - TASK_STATE_AUTH_REQUIRED(false), + /** Agent requires authentication or authorization to proceed (interrupted state). */ + TASK_STATE_AUTH_REQUIRED(false, true), /** Task completed successfully (terminal state). */ - TASK_STATE_COMPLETED(true), + TASK_STATE_COMPLETED(true, false), /** Task was canceled by user or system (terminal state). */ - TASK_STATE_CANCELED(true), + TASK_STATE_CANCELED(true, false), /** Task failed due to an error (terminal state). */ - TASK_STATE_FAILED(true), + TASK_STATE_FAILED(true, false), /** Task was rejected by the agent (terminal state). */ - TASK_STATE_REJECTED(true), + TASK_STATE_REJECTED(true, false), /** Task state is unknown or cannot be determined (terminal state). */ - UNRECOGNIZED(true); + UNRECOGNIZED(true, false); private final boolean isFinal; + private final boolean isInterrupted; - TaskState(boolean isFinal) { + TaskState(boolean isFinal, boolean isInterrupted) { this.isFinal = isFinal; + this.isInterrupted = isInterrupted; } /** @@ -71,10 +80,32 @@ public enum TaskState { * Terminal states indicate that the task has completed its lifecycle and will * not transition to any other state. This is used by the event queue system * to determine when to close queues and by clients to know when to stop polling. + *

+ * Terminal states: COMPLETED, FAILED, CANCELED, REJECTED, UNRECOGNIZED. * * @return {@code true} if this is a terminal state, {@code false} else. */ public boolean isFinal(){ return isFinal; } + + /** + * Determines whether this state is an interrupted state. + *

+ * Interrupted states indicate that the task execution has paused and requires + * external action before proceeding. The task may resume after the required + * action is provided. Interrupted states are NOT terminal - streams should + * remain open to deliver state updates. + *

+ * Interrupted states: INPUT_REQUIRED, AUTH_REQUIRED. + *

+ * Per A2A Protocol Specification 4.1.3 (TaskState): + * "TASK_STATE_INPUT_REQUIRED: This is an interrupted state." + * "TASK_STATE_AUTH_REQUIRED: This is an interrupted state." + * + * @return {@code true} if this is an interrupted state, {@code false} else. + */ + public boolean isInterrupted() { + return isInterrupted; + } } \ No newline at end of file diff --git a/tests/server-common/src/test/java/io/a2a/server/apps/common/AbstractA2AServerTest.java b/tests/server-common/src/test/java/io/a2a/server/apps/common/AbstractA2AServerTest.java index f7cdacc61..35f60ec29 100644 --- a/tests/server-common/src/test/java/io/a2a/server/apps/common/AbstractA2AServerTest.java +++ b/tests/server-common/src/test/java/io/a2a/server/apps/common/AbstractA2AServerTest.java @@ -875,6 +875,163 @@ public void testSubscribeExistingTaskSuccessWithClientConsumers() throws Excepti } } + /** + * Tests that SubscribeToTask stream stays open for interrupted states (INPUT_REQUIRED, AUTH_REQUIRED) + * and only terminates on terminal states. + *

+ * Per A2A Protocol Specification 3.1.6 (SubscribeToTask): + * "The stream MUST terminate when the task reaches a terminal state (completed, failed, canceled, or rejected)." + *

+ * Interrupted states are NOT terminal - the stream should remain open to deliver future state updates. + *

+ * This test addresses issue #754: Stream was incorrectly closing immediately for INPUT_REQUIRED state. + * The bug had two parts: + * 1. isStreamTerminatingTask() incorrectly treated INPUT_REQUIRED as terminating + * 2. Grace period logic closed queue after agent completion, even for interrupted states + */ + @Test + @Timeout(value = 3, unit = TimeUnit.MINUTES) + public void testSubscribeToTaskWithInterruptedStateKeepsStreamOpen() throws Exception { + // Use a taskId with the pattern the test agent recognizes + // When we send a message with a taskId to a non-existent task, it creates + // a new task with that ID, and context.getTask() is still null on first invocation + String taskId = "input-required-test-" + UUID.randomUUID(); + + try { + // Create initial message with the special taskId pattern + // Use non-streaming client so agent can emit INPUT_REQUIRED and return immediately + // This ensures context.getTask() == null on first agent invocation + Message message = Message.builder(MESSAGE) + .taskId(taskId) + .contextId("test-context") + .parts(new TextPart("Trigger INPUT_REQUIRED")) + .build(); + + // Send message with non-streaming client - agent will emit INPUT_REQUIRED and complete + AtomicReference finalStateRef = new AtomicReference<>(); + AtomicReference sendErrorRef = new AtomicReference<>(); + CountDownLatch sendLatch = new CountDownLatch(1); + + getNonStreamingClient().sendMessage(message, List.of((event, agentCard) -> { + if (event instanceof TaskEvent te) { + finalStateRef.set(te.getTask().status().state()); + sendLatch.countDown(); + } else if (event instanceof TaskUpdateEvent tue) { + if (tue.getUpdateEvent() instanceof TaskStatusUpdateEvent statusUpdate) { + finalStateRef.set(statusUpdate.status().state()); + } + } + }), error -> { + if (!isStreamClosedError(error)) { + sendErrorRef.set(error); + } + sendLatch.countDown(); + }); + + assertTrue(sendLatch.await(15, TimeUnit.SECONDS), "SendMessage should complete"); + assertNull(sendErrorRef.get(), "SendMessage should not error"); + TaskState finalState = finalStateRef.get(); + assertNotNull(finalState, "Final state should be captured"); + assertEquals(TaskState.TASK_STATE_INPUT_REQUIRED, finalState, + "Task should be in INPUT_REQUIRED state after agent completes"); + + // CRITICAL: At this point the agent has completed with INPUT_REQUIRED state + // The grace period logic should NOT close the queue because INPUT_REQUIRED + // is an interrupted state, not a terminal state + + // Wait 2 seconds - longer than the grace period (1.5 seconds) + // Before fix: queue would close after grace period + // After fix: queue stays open because task is in interrupted state + Thread.sleep(2000); + + // Track events received through subscription stream + CopyOnWriteArrayList receivedEvents = new CopyOnWriteArrayList<>(); + AtomicBoolean receivedInitialTask = new AtomicBoolean(false); + AtomicBoolean streamClosedPrematurely = new AtomicBoolean(false); + AtomicReference subscribeErrorRef = new AtomicReference<>(); + CountDownLatch completionLatch = new CountDownLatch(1); + + // Consumer to track all events from subscription + BiConsumer consumer = (event, agentCard) -> { + if (event instanceof TaskEvent taskEvent) { + if (!receivedInitialTask.get()) { + receivedInitialTask.set(true); + // First event should be the initial task snapshot in INPUT_REQUIRED state + assertEquals(TaskState.TASK_STATE_INPUT_REQUIRED, + taskEvent.getTask().status().state(), + "Initial task should be in INPUT_REQUIRED state"); + return; + } + } else if (event instanceof TaskUpdateEvent taskUpdateEvent) { + io.a2a.spec.UpdateEvent updateEvent = taskUpdateEvent.getUpdateEvent(); + receivedEvents.add(updateEvent); + + // Check if this is the final terminal state + if (updateEvent instanceof TaskStatusUpdateEvent tue && tue.isFinal()) { + completionLatch.countDown(); + } + } + }; + + // Error handler to detect premature stream closure + Consumer errorHandler = error -> { + if (!isStreamClosedError(error)) { + subscribeErrorRef.set(error); + } + // If completion latch hasn't been counted down yet, stream closed prematurely + if (completionLatch.getCount() > 0) { + streamClosedPrematurely.set(true); + } + completionLatch.countDown(); + }; + + // Subscribe to the task - this is AFTER agent completed with INPUT_REQUIRED + CountDownLatch subscriptionLatch = new CountDownLatch(1); + awaitStreamingSubscription() + .whenComplete((unused, throwable) -> subscriptionLatch.countDown()); + + getClient().subscribeToTask(new TaskIdParams(taskId), List.of(consumer), errorHandler); + + // Wait for subscription to be established + assertTrue(subscriptionLatch.await(15, TimeUnit.SECONDS), "Subscription should be established"); + + // Verify stream received initial task and is still open + assertTrue(receivedInitialTask.get(), "Should receive initial task snapshot"); + assertFalse(streamClosedPrematurely.get(), + "Stream should NOT close for INPUT_REQUIRED state (interrupted, not terminal)"); + + // Send a follow-up message to provide the required input + // This will trigger the agent again, which will emit COMPLETED + Message followUpMessage = Message.builder() + .messageId("input-response-" + UUID.randomUUID()) + .role(Message.Role.ROLE_USER) + .parts(new TextPart("User input")) + .taskId(taskId) + .build(); + + getClient().sendMessage(followUpMessage, List.of(), error -> {}); + + // Stream should now close after receiving COMPLETED event + assertTrue(completionLatch.await(30, TimeUnit.SECONDS), + "Stream should close after terminal state"); + + // Verify we received the COMPLETED update + assertTrue(receivedEvents.size() >= 1, + "Should receive at least COMPLETED status update"); + + // Find the COMPLETED event + boolean foundCompleted = receivedEvents.stream() + .filter(e -> e instanceof TaskStatusUpdateEvent) + .map(e -> (TaskStatusUpdateEvent) e) + .anyMatch(tue -> tue.status().state() == TaskState.TASK_STATE_COMPLETED); + assertTrue(foundCompleted, "Should receive COMPLETED status update"); + + assertNull(subscribeErrorRef.get(), "Should not have any errors"); + } finally { + deleteTaskInTaskStore(taskId); + } + } + @Test public void testSubscribeNoExistingTaskError() throws Exception { CountDownLatch errorLatch = new CountDownLatch(1); diff --git a/tests/server-common/src/test/java/io/a2a/server/apps/common/AgentExecutorProducer.java b/tests/server-common/src/test/java/io/a2a/server/apps/common/AgentExecutorProducer.java index 74fe78eab..d69d49b9b 100644 --- a/tests/server-common/src/test/java/io/a2a/server/apps/common/AgentExecutorProducer.java +++ b/tests/server-common/src/test/java/io/a2a/server/apps/common/AgentExecutorProducer.java @@ -74,27 +74,23 @@ public void execute(RequestContext context, AgentEmitter agentEmitter) throws A2 // Special handling for input-required test if (taskId != null && taskId.startsWith("input-required-test")) { - // First call: context.getTask() == null (new task) - if (context.getTask() == null) { - // Go directly to INPUT_REQUIRED without intermediate WORKING state - // This avoids race condition where blocking call interrupts on WORKING - // before INPUT_REQUIRED is persisted to TaskStore - agentEmitter.requiresInput(agentEmitter.newAgentMessage( - List.of(new TextPart("Please provide additional information")), - context.getMessage().metadata())); - // Return immediately - queue stays open because task is in INPUT_REQUIRED state - return; - } else { - String input = extractTextFromMessage(context.getMessage()); - if(! "User input".equals(input)) { - throw new InvalidParamsError("We didn't get the expected input"); - } - // Second call: context.getTask() != null (input provided) + String input = extractTextFromMessage(context.getMessage()); + // Second call: user provided the required input - complete the task + if ("User input".equals(input)) { // Go directly to COMPLETED without intermediate WORKING state - // This avoids the same race condition as the first call + // This avoids race condition where blocking call interrupts on WORKING agentEmitter.complete(); return; } + // First call: any other message - emit INPUT_REQUIRED + // Go directly to INPUT_REQUIRED without intermediate WORKING state + // This avoids race condition where blocking call interrupts on WORKING + // before INPUT_REQUIRED is persisted to TaskStore + agentEmitter.requiresInput(agentEmitter.newAgentMessage( + List.of(new TextPart("Please provide additional information")), + context.getMessage().metadata())); + // Return immediately - queue stays open because task is in INPUT_REQUIRED state + return; } // Special handling for auth-required test