diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows/HandoffWorkflowBuilder.cs b/dotnet/src/Microsoft.Agents.AI.Workflows/HandoffWorkflowBuilder.cs index 4e9f201053..4c93414c63 100644 --- a/dotnet/src/Microsoft.Agents.AI.Workflows/HandoffWorkflowBuilder.cs +++ b/dotnet/src/Microsoft.Agents.AI.Workflows/HandoffWorkflowBuilder.cs @@ -8,6 +8,10 @@ using Microsoft.Extensions.AI; using Microsoft.Shared.Diagnostics; +using ExecutorFactoryFunc = System.Func, + string, + System.Threading.Tasks.ValueTask>; + namespace Microsoft.Agents.AI.Workflows; internal static class DiagnosticConstants @@ -233,6 +237,57 @@ public TBuilder WithHandoff(AIAgent from, AIAgent to, string? handoffReason = nu return (TBuilder)this; } + private Dictionary CreateExecutorBindings(WorkflowBuilder builder) + { + HandoffAgentExecutorOptions options = new(this.HandoffInstructions, + this._emitAgentResponseEvents, + this._emitAgentResponseUpdateEvents, + this._toolCallFilteringBehavior); + + // There are two types of ids being used in this method, and it is critical that we are clear about + // which one we are using, and where. + // AgentId...: comes from AIAgent.Id, is often an unreadable machine identifier (e.g. a Guid), and is used to address + // the handoffs + // ExecutorId: uses AIAgent.GetDescriptiveId() to use a friendlier name in telemetry, and is used for ExecutorBinding, + // which are subsequently used in building the workflow + + // The outgoing dictionary maps from AgentId => ExecutorBinding + return this._allAgents.ToDictionary(keySelector: a => a.Id, elementSelector: CreateFactoryBinding); + + ExecutorBinding CreateFactoryBinding(AIAgent agent) + { + if (!this._targets.TryGetValue(agent, out HashSet? handoffs)) + { + handoffs = new(); + } + + // Use the ExecutorId as the placeholder id for a (possibly) future-bound factory + builder.AddSwitch(HandoffAgentExecutor.IdFor(agent), (SwitchBuilder sb) => + { + foreach (HandoffTarget handoff in handoffs) + { + sb.AddCase(state => state?.RequestedHandoffTargetAgentId == handoff.Target.Id, // Use AgentId for target matching + HandoffAgentExecutor.IdFor(handoff.Target)); // Use ExecutorId in for routing at the workflow level + } + + sb.WithDefault(HandoffEndExecutor.ExecutorId); + }); + + ExecutorFactoryFunc factory = + (config, sessionId) => new( + new HandoffAgentExecutor(agent, + handoffs, + options)); + + // Make sure to use ExecutorId when binding the executor, not AgentId + ExecutorBinding binding = factory.BindExecutor(HandoffAgentExecutor.IdFor(agent)); + + builder.BindExecutor(binding); + + return binding; + } + } + /// /// Builds a composed of agents that operate via handoffs, with the next /// agent to process messages selected by the current agent. @@ -240,17 +295,12 @@ public TBuilder WithHandoff(AIAgent from, AIAgent to, string? handoffReason = nu /// The workflow built based on the handoffs in the builder. public Workflow Build() { - HandoffsStartExecutor start = new(this._returnToPrevious); - HandoffsEndExecutor end = new(this._returnToPrevious); + HandoffStartExecutor start = new(this._returnToPrevious); + HandoffEndExecutor end = new(this._returnToPrevious); WorkflowBuilder builder = new(start); - HandoffAgentExecutorOptions options = new(this.HandoffInstructions, - this._emitAgentResponseEvents, - this._emitAgentResponseUpdateEvents, - this._toolCallFilteringBehavior); - - // Create an AgentExecutor for each agent. - Dictionary executors = this._allAgents.ToDictionary(a => a.Id, a => new HandoffAgentExecutor(a, options)); + // Create an factory-based ExecutorBinding for each agent. + Dictionary executors = this.CreateExecutorBindings(builder); // Connect the start executor to the initial agent (or use dynamic routing when ReturnToPrevious is enabled). if (this._returnToPrevious) @@ -263,7 +313,7 @@ public Workflow Build() if (agent.Id != initialAgentId) { string agentId = agent.Id; - sb.AddCase(state => state?.CurrentAgentId == agentId, executors[agentId]); + sb.AddCase(state => state?.PreviousAgentId == agentId, executors[agentId]); } } @@ -275,13 +325,6 @@ public Workflow Build() builder.AddEdge(start, executors[this._initialAgent.Id]); } - // Initialize each executor with its handoff targets to the other executors. - foreach (var agent in this._allAgents) - { - executors[agent.Id].Initialize(builder, end, executors, - this._targets.TryGetValue(agent, out HashSet? targets) ? targets : []); - } - // Build the workflow. return builder.WithOutputFrom(end).Build(); } diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/AIAgentHostExecutor.cs b/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/AIAgentHostExecutor.cs index 3f3d83fbee..b7d2911537 100644 --- a/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/AIAgentHostExecutor.cs +++ b/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/AIAgentHostExecutor.cs @@ -19,6 +19,9 @@ public static bool ShouldEmitStreamingEvents(this TurnToken token, bool? agentSe public static bool ShouldEmitStreamingEvents(bool? turnTokenSetting, bool? agentSetting) => turnTokenSetting ?? agentSetting ?? false; + + public static bool ShouldEmitStreamingEvents(this HandoffState handoffState, bool? agentSetting) + => handoffState.TurnToken.ShouldEmitStreamingEvents(agentSetting); } internal sealed class AIAgentHostExecutor : ChatProtocolExecutor @@ -81,7 +84,11 @@ private ValueTask HandleUserInputResponseAsync( // resumes can be processed in one invocation. return this.ProcessTurnMessagesAsync(async (pendingMessages, ctx, ct) => { - pendingMessages.Add(new ChatMessage(ChatRole.User, [response])); + pendingMessages.Add(new ChatMessage(ChatRole.User, [response]) + { + CreatedAt = DateTimeOffset.UtcNow, + MessageId = Guid.NewGuid().ToString("N"), + }); await this.ContinueTurnAsync(pendingMessages, ctx, this._currentTurnEmitEvents ?? false, ct).ConfigureAwait(false); @@ -104,7 +111,12 @@ private ValueTask HandleFunctionResultAsync( // resumes can be processed in one invocation. return this.ProcessTurnMessagesAsync(async (pendingMessages, ctx, ct) => { - pendingMessages.Add(new ChatMessage(ChatRole.Tool, [result])); + pendingMessages.Add(new ChatMessage(ChatRole.Tool, [result]) + { + AuthorName = this._agent.Name ?? this._agent.Id, + CreatedAt = DateTimeOffset.UtcNow, + MessageId = Guid.NewGuid().ToString("N"), + }); await this.ContinueTurnAsync(pendingMessages, ctx, this._currentTurnEmitEvents ?? false, ct).ConfigureAwait(false); @@ -186,16 +198,13 @@ protected override ValueTask TakeTurnAsync(List messages, IWorkflow TurnExtensions.ShouldEmitStreamingEvents(turnTokenSetting: emitEvents, this._options.EmitAgentUpdateEvents), cancellationToken); - private async ValueTask InvokeAgentAsync(IEnumerable messages, IWorkflowContext context, bool emitEvents, CancellationToken cancellationToken = default) + private async ValueTask InvokeAgentAsync(IEnumerable messages, IWorkflowContext context, bool emitUpdateEvents, CancellationToken cancellationToken = default) { -#pragma warning disable MEAI001 - Dictionary userInputRequests = new(); - Dictionary functionCalls = new(); AgentResponse response; + AIAgentUnservicedRequestsCollector collector = new(this._userInputHandler, this._functionCallHandler); - if (emitEvents) + if (emitUpdateEvents) { -#pragma warning disable MEAI001 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. // Run the agent in streaming mode only when agent run update events are to be emitted. IAsyncEnumerable agentStream = this._agent.RunStreamingAsync( messages, @@ -206,7 +215,7 @@ await this.EnsureSessionAsync(context, cancellationToken).ConfigureAwait(false), await foreach (AgentResponseUpdate update in agentStream.ConfigureAwait(false)) { await context.YieldOutputAsync(update, cancellationToken).ConfigureAwait(false); - ExtractUnservicedRequests(update.Contents); + collector.ProcessAgentResponseUpdate(update); updates.Add(update); } @@ -220,7 +229,7 @@ await this.EnsureSessionAsync(context, cancellationToken).ConfigureAwait(false), cancellationToken: cancellationToken) .ConfigureAwait(false); - ExtractUnservicedRequests(response.Messages.SelectMany(message => message.Contents)); + collector.ProcessAgentResponse(response); } if (this._options.EmitAgentResponseEvents) @@ -228,45 +237,8 @@ await this.EnsureSessionAsync(context, cancellationToken).ConfigureAwait(false), await context.YieldOutputAsync(response, cancellationToken).ConfigureAwait(false); } - if (userInputRequests.Count > 0 || functionCalls.Count > 0) - { - Task userInputTask = this._userInputHandler?.ProcessRequestContentsAsync(userInputRequests, context, cancellationToken) ?? Task.CompletedTask; - Task functionCallTask = this._functionCallHandler?.ProcessRequestContentsAsync(functionCalls, context, cancellationToken) ?? Task.CompletedTask; - - await Task.WhenAll(userInputTask, functionCallTask) - .ConfigureAwait(false); - } + await collector.SubmitAsync(context, cancellationToken).ConfigureAwait(false); return response; - - void ExtractUnservicedRequests(IEnumerable contents) - { - foreach (AIContent content in contents) - { - if (content is ToolApprovalRequestContent userInputRequest) - { - // It is an error to simultaneously have multiple outstanding user input requests with the same ID. - userInputRequests.Add(userInputRequest.RequestId, userInputRequest); - } - else if (content is ToolApprovalResponseContent userInputResponse) - { - // If the set of messages somehow already has a corresponding user input response, remove it. - _ = userInputRequests.Remove(userInputResponse.RequestId); - } - else if (content is FunctionCallContent functionCall) - { - // For function calls, we emit an event to notify the workflow. - // - // possibility 1: this will be handled inline by the agent abstraction - // possibility 2: this will not be handled inline by the agent abstraction - functionCalls.Add(functionCall.CallId, functionCall); - } - else if (content is FunctionResultContent functionResult) - { - _ = functionCalls.Remove(functionResult.CallId); - } - } - } -#pragma warning restore MEAI001 } } diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/AIAgentUnservicedRequestsCollector.cs b/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/AIAgentUnservicedRequestsCollector.cs new file mode 100644 index 0000000000..7e4f8c8c9d --- /dev/null +++ b/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/AIAgentUnservicedRequestsCollector.cs @@ -0,0 +1,78 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.AI; + +namespace Microsoft.Agents.AI.Workflows.Specialized; + +internal sealed class AIAgentUnservicedRequestsCollector(AIContentExternalHandler? userInputHandler, + AIContentExternalHandler? functionCallHandler) +{ + private readonly Dictionary _userInputRequests = []; + private readonly Dictionary _functionCalls = []; + + public Task SubmitAsync(IWorkflowContext context, CancellationToken cancellationToken) + { + Task userInputTask = userInputHandler != null && this._userInputRequests.Count > 0 + ? userInputHandler.ProcessRequestContentsAsync(this._userInputRequests, context, cancellationToken) + : Task.CompletedTask; + + Task functionCallTask = functionCallHandler != null && this._functionCalls.Count > 0 + ? functionCallHandler.ProcessRequestContentsAsync(this._functionCalls, context, cancellationToken) + : Task.CompletedTask; + + return Task.WhenAll(userInputTask, functionCallTask); + } + + public void ProcessAgentResponseUpdate(AgentResponseUpdate update, Func? functionCallFilter = null) + => this.ProcessAIContents(update.Contents, functionCallFilter); + + public void ProcessAgentResponse(AgentResponse response) + => this.ProcessAIContents(response.Messages.SelectMany(message => message.Contents)); + + public void ProcessAIContents(IEnumerable contents, Func? functionCallFilter = null) + { + foreach (AIContent content in contents) + { + if (content is ToolApprovalRequestContent userInputRequest) + { + if (this._userInputRequests.ContainsKey(userInputRequest.RequestId)) + { + throw new InvalidOperationException($"ToolApprovalRequestContent with duplicate RequestId: {userInputRequest.RequestId}"); + } + + // It is an error to simultaneously have multiple outstanding user input requests with the same ID. + this._userInputRequests.Add(userInputRequest.RequestId, userInputRequest); + } + else if (content is ToolApprovalResponseContent userInputResponse) + { + // If the set of messages somehow already has a corresponding user input response, remove it. + _ = this._userInputRequests.Remove(userInputResponse.RequestId); + } + else if (content is FunctionCallContent functionCall) + { + // For function calls, we emit an event to notify the workflow. + // + // possibility 1: this will be handled inline by the agent abstraction + // possibility 2: this will not be handled inline by the agent abstraction + if (functionCallFilter == null || functionCallFilter(functionCall)) + { + if (this._functionCalls.ContainsKey(functionCall.CallId)) + { + throw new InvalidOperationException($"FunctionCallContent with duplicate CallId: {functionCall.CallId}"); + } + + this._functionCalls.Add(functionCall.CallId, functionCall); + } + } + else if (content is FunctionResultContent functionResult) + { + _ = this._functionCalls.Remove(functionResult.CallId); + } + } + } +} diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/HandoffAgentExecutor.cs b/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/HandoffAgentExecutor.cs index e885b894fd..eac2eb5687 100644 --- a/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/HandoffAgentExecutor.cs +++ b/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/HandoffAgentExecutor.cs @@ -3,7 +3,6 @@ using System; using System.Collections.Generic; using System.ComponentModel; -using System.Diagnostics; using System.Diagnostics.CodeAnalysis; using System.Linq; using System.Text.Json; @@ -166,128 +165,331 @@ private class FilterCandidateState(string callId) } } +internal struct AgentInvocationResult(AgentResponse agentResponse, string? handoffTargetId) +{ + public AgentResponse Response => agentResponse; + + public string? HandoffTargetId => handoffTargetId; + + [MemberNotNullWhen(true, nameof(HandoffTargetId))] + public bool IsHandoffRequested => this.HandoffTargetId != null; +} + +internal record HandoffAgentHostState(HandoffState? CurrentTurnState, List FilteredIncomingMessages, List TurnMessages) +{ + public HandoffState PrepareHandoff(AgentInvocationResult invocationResult, string currentAgentId) + { + if (this.CurrentTurnState == null) + { + throw new InvalidOperationException("Cannot create a handoff request: Out of turn."); + } + + IEnumerable allMessages = [.. this.CurrentTurnState.Messages, .. this.TurnMessages, .. invocationResult.Response.Messages]; + + return new(this.CurrentTurnState.TurnToken, invocationResult.HandoffTargetId, allMessages.ToList(), currentAgentId); + } +} + /// Executor used to represent an agent in a handoffs workflow, responding to events. [Experimental(DiagnosticConstants.ExperimentalFeatureDiagnostic)] -internal sealed class HandoffAgentExecutor( - AIAgent agent, - HandoffAgentExecutorOptions options) : Executor(agent.GetDescriptiveId(), declareCrossRunShareable: true), IResettableExecutor +internal sealed class HandoffAgentExecutor : + StatefulExecutor { private static readonly JsonElement s_handoffSchema = AIFunctionFactory.Create( ([Description("The reason for the handoff")] string? reasonForHandoff) => { }).JsonSchema; - private readonly AIAgent _agent = agent; + public static string IdFor(AIAgent agent) => agent.GetDescriptiveId(); + + private readonly AIAgent _agent; + private readonly ChatClientAgentRunOptions? _agentOptions; + + private readonly HandoffAgentExecutorOptions _options; + private readonly HashSet _handoffFunctionNames = []; private readonly Dictionary _handoffFunctionToAgentId = []; - private ChatClientAgentRunOptions? _agentOptions; - - public void Initialize( - WorkflowBuilder builder, - Executor end, - Dictionary executors, - HashSet handoffs) => - builder.AddSwitch(this, sb => + + private static HandoffAgentHostState InitialStateFactory() => new(null, [], []); + + public HandoffAgentExecutor(AIAgent agent, HashSet handoffs, HandoffAgentExecutorOptions options) + : base(IdFor(agent), InitialStateFactory) + { + this._agent = agent; + this._options = options; + + this._agentOptions = CreateAgentHandoffContext(this._options.HandoffInstructions, handoffs, this._handoffFunctionNames, this._handoffFunctionToAgentId); + } + + private static ChatClientAgentRunOptions? CreateAgentHandoffContext(string? handoffInstructions, HashSet handoffs, HashSet functionNames, Dictionary functionToAgentId) + { + ChatClientAgentRunOptions? result = null; + + if (handoffs.Count != 0) { - if (handoffs.Count != 0) + result = new() { - Debug.Assert(this._agentOptions is null); - this._agentOptions = new() + ChatOptions = new() { - ChatOptions = new() - { - AllowMultipleToolCalls = false, - Instructions = options.HandoffInstructions, - Tools = [], - }, - }; + AllowMultipleToolCalls = false, + Instructions = handoffInstructions, + Tools = [], + }, + }; + + int index = 0; + foreach (HandoffTarget handoff in handoffs) + { + index++; + var handoffFunc = AIFunctionFactory.CreateDeclaration($"{HandoffWorkflowBuilder.FunctionPrefix}{index}", handoff.Reason, s_handoffSchema); - int index = 0; - foreach (HandoffTarget handoff in handoffs) + functionNames.Add(handoffFunc.Name); + functionToAgentId[handoffFunc.Name] = handoff.Target.Id; + + result.ChatOptions.Tools.Add(handoffFunc); + } + } + + return result; + } + + private AIContentExternalHandler? _userInputHandler; + private AIContentExternalHandler? _functionCallHandler; + + protected override ProtocolBuilder ConfigureProtocol(ProtocolBuilder protocolBuilder) + { + return this.ConfigureUserInputHandling(base.ConfigureProtocol(protocolBuilder)) + .SendsMessage(); + } + + private ProtocolBuilder ConfigureUserInputHandling(ProtocolBuilder protocolBuilder) + { + this._userInputHandler = new AIContentExternalHandler( + ref protocolBuilder, + portId: $"{this.Id}_UserInput", + intercepted: false, + handler: this.HandleUserInputResponseAsync); + + this._functionCallHandler = new AIContentExternalHandler( + ref protocolBuilder, + portId: $"{this.Id}_FunctionCall", + intercepted: false, // TODO: Use this instead of manual function handling for handoff? + handler: this.HandleFunctionResultAsync); + + return protocolBuilder; + } + + private ValueTask HandleUserInputResponseAsync( + ToolApprovalResponseContent response, + IWorkflowContext context, + CancellationToken cancellationToken) + { + if (!this._userInputHandler!.MarkRequestAsHandled(response.RequestId)) + { + throw new InvalidOperationException($"No pending ToolApprovalRequest found with id '{response.RequestId}'."); + } + + // Merge the external response with any already-buffered regular messages so mixed-content + // resumes can be processed in one invocation. + return this.InvokeWithStateAsync((state, ctx, ct) => + { + state.TurnMessages.Add(new ChatMessage(ChatRole.User, [response]) + { + CreatedAt = DateTimeOffset.UtcNow, + MessageId = Guid.NewGuid().ToString("N"), + }); + + return this.ContinueTurnAsync(state, ctx, ct); + }, context, skipCache: false, cancellationToken); + } + + private ValueTask HandleFunctionResultAsync( + FunctionResultContent result, + IWorkflowContext context, + CancellationToken cancellationToken) + { + if (!this._functionCallHandler!.MarkRequestAsHandled(result.CallId)) + { + throw new InvalidOperationException($"No pending FunctionCall found with id '{result.CallId}'."); + } + + // Merge the external response with any already-buffered regular messages so mixed-content + // resumes can be processed in one invocation. + return this.InvokeWithStateAsync((state, ctx, ct) => + { + state.TurnMessages.Add( + new ChatMessage(ChatRole.Tool, [result]) { - index++; - var handoffFunc = AIFunctionFactory.CreateDeclaration($"{HandoffWorkflowBuilder.FunctionPrefix}{index}", handoff.Reason, s_handoffSchema); + AuthorName = this._agent.Name ?? this._agent.Id, + CreatedAt = DateTimeOffset.UtcNow, + MessageId = Guid.NewGuid().ToString("N"), + }); - this._handoffFunctionNames.Add(handoffFunc.Name); - this._handoffFunctionToAgentId[handoffFunc.Name] = handoff.Target.Id; + return this.ContinueTurnAsync(state, ctx, ct); + }, context, skipCache: false, cancellationToken); + } + + private async ValueTask ContinueTurnAsync(HandoffAgentHostState state, IWorkflowContext context, CancellationToken cancellationToken) + { + List? roleChanges = state.FilteredIncomingMessages.ChangeAssistantToUserForOtherParticipants(this._agent.Name ?? this._agent.Id); - this._agentOptions.ChatOptions.Tools.Add(handoffFunc); + bool emitUpdateEvents = state.CurrentTurnState!.ShouldEmitStreamingEvents(this._options.EmitAgentResponseUpdateEvents); + AgentInvocationResult result = await this.InvokeAgentAsync([.. state.FilteredIncomingMessages, .. state.TurnMessages], context, emitUpdateEvents, cancellationToken) + .ConfigureAwait(false); - sb.AddCase(state => state?.InvokedHandoff == handoffFunc.Name, executors[handoff.Target.Id]); - } + if (this.HasOutstandingRequests && result.IsHandoffRequested) + { + throw new InvalidOperationException("Cannot request a handoff while holding pending requests."); + } + + roleChanges.ResetUserToAssistantForChangedRoles(); + + // We send on the HandoffState even if handoff is not requested because we might be terminating the processing, but this only + // happens if we have no outstanding requests. + if (!this.HasOutstandingRequests) + { + HandoffState outgoingState = state.PrepareHandoff(result, this._agent.Id); + + await context.SendMessageAsync(outgoingState, cancellationToken).ConfigureAwait(false); + + // reset the state for the next handoff (return-to-current is modeled as a new handoff turn, as opposed to "HITL", which + // can be a bit confusing.) + return null; + } + + state.TurnMessages.AddRange(result.Response.Messages); + return state; + } + + public override ValueTask HandleAsync(HandoffState message, IWorkflowContext context, CancellationToken cancellationToken = default) + { + return this.InvokeWithStateAsync(InvokeContinueTurnAsync, context, skipCache: false, cancellationToken); + + ValueTask InvokeContinueTurnAsync(HandoffAgentHostState state, IWorkflowContext context, CancellationToken cancellationToken) + { + // Check that we are not getting this message while in the middle of a turn + if (state.CurrentTurnState != null) + { + throw new InvalidOperationException("Cannot have multiple simultaneous conversations in Handoff Orchestration."); } - sb.WithDefault(end); - }); + // If a handoff was invoked by a previous agent, filter out the handoff function + // call and tool result messages before sending to the underlying agent. These + // are internal workflow mechanics that confuse the target model into ignoring the + // original user question. + HandoffMessagesFilter handoffMessagesFilter = new(this._options.ToolCallFilteringBehavior); + IEnumerable messagesForAgent = message.RequestedHandoffTargetAgentId is not null + ? handoffMessagesFilter.FilterMessages(message.Messages) + : message.Messages; + + // This works because the runtime guarantees that a given executor instance will process messages serially, + // though there is no global cross-executor ordering guarantee (and in turn, no canonical message delivery order) + state = new(message, messagesForAgent.ToList(), []); + + return this.ContinueTurnAsync(state, context, cancellationToken); + } + } + + private const string UserInputRequestStateKey = nameof(_userInputHandler); + private const string FunctionCallRequestStateKey = nameof(_functionCallHandler); + + protected internal override async ValueTask OnCheckpointingAsync(IWorkflowContext context, CancellationToken cancellationToken = default) + { + Task userInputRequestsTask = this._userInputHandler?.OnCheckpointingAsync(UserInputRequestStateKey, context, cancellationToken).AsTask() ?? Task.CompletedTask; + Task functionCallRequestsTask = this._functionCallHandler?.OnCheckpointingAsync(FunctionCallRequestStateKey, context, cancellationToken).AsTask() ?? Task.CompletedTask; + + Task baseTask = base.OnCheckpointingAsync(context, cancellationToken).AsTask(); + await Task.WhenAll(userInputRequestsTask, functionCallRequestsTask, baseTask).ConfigureAwait(false); + } - public override async ValueTask HandleAsync(HandoffState message, IWorkflowContext context, CancellationToken cancellationToken = default) + protected internal override async ValueTask OnCheckpointRestoredAsync(IWorkflowContext context, CancellationToken cancellationToken = default) { + Task userInputRestoreTask = this._userInputHandler?.OnCheckpointRestoredAsync(UserInputRequestStateKey, context, cancellationToken).AsTask() ?? Task.CompletedTask; + Task functionCallRestoreTask = this._functionCallHandler?.OnCheckpointRestoredAsync(FunctionCallRequestStateKey, context, cancellationToken).AsTask() ?? Task.CompletedTask; + + await Task.WhenAll(userInputRestoreTask, functionCallRestoreTask).ConfigureAwait(false); + await base.OnCheckpointRestoredAsync(context, cancellationToken).ConfigureAwait(false); + } + private bool HasOutstandingRequests => (this._userInputHandler?.HasPendingRequests == true) + || (this._functionCallHandler?.HasPendingRequests == true); + + private async ValueTask InvokeAgentAsync(IEnumerable messages, IWorkflowContext context, bool emitUpdateEvents, CancellationToken cancellationToken = default) + { + AgentResponse response; + + AIAgentUnservicedRequestsCollector collector = new(this._userInputHandler, this._functionCallHandler); + + IAsyncEnumerable agentStream = this._agent.RunStreamingAsync( + messages, + options: this._agentOptions, + cancellationToken: cancellationToken); + string? requestedHandoff = null; List updates = []; - List allMessages = message.Messages; - - List? roleChanges = allMessages.ChangeAssistantToUserForOtherParticipants(this._agent.Name ?? this._agent.Id); - - // If a handoff was invoked by a previous agent, filter out the handoff function - // call and tool result messages before sending to the underlying agent. These - // are internal workflow mechanics that confuse the target model into ignoring the - // original user question. - HandoffMessagesFilter handoffMessagesFilter = new(options.ToolCallFilteringBehavior); - IEnumerable messagesForAgent = message.InvokedHandoff is not null - ? handoffMessagesFilter.FilterMessages(allMessages) - : allMessages; - - await foreach (var update in this._agent.RunStreamingAsync(messagesForAgent, - options: this._agentOptions, - cancellationToken: cancellationToken) - .ConfigureAwait(false)) + List candidateRequests = []; + await foreach (AgentResponseUpdate update in agentStream.ConfigureAwait(false)) { await AddUpdateAsync(update, cancellationToken).ConfigureAwait(false); - foreach (var fcc in update.Contents.OfType() - .Where(fcc => this._handoffFunctionNames.Contains(fcc.Name))) + collector.ProcessAgentResponseUpdate(update, CollectHandoffRequestsFilter); + + bool CollectHandoffRequestsFilter(FunctionCallContent candidateHandoffRequest) { - requestedHandoff = fcc.Name; - await AddUpdateAsync( - new AgentResponseUpdate - { - AgentId = this._agent.Id, - AuthorName = this._agent.Name ?? this._agent.Id, - Contents = [new FunctionResultContent(fcc.CallId, "Transferred.")], - CreatedAt = DateTimeOffset.UtcNow, - MessageId = Guid.NewGuid().ToString("N"), - Role = ChatRole.Tool, - }, - cancellationToken - ) - .ConfigureAwait(false); + bool isHandoffRequest = this._handoffFunctionNames.Contains(candidateHandoffRequest.Name); + if (isHandoffRequest) + { + candidateRequests.Add(candidateHandoffRequest); + } + + return !isHandoffRequest; } } - AgentResponse agentResponse = updates.ToAgentResponse(); + if (candidateRequests.Count > 1) + { + string message = $"Duplicate handoff requests in single turn ([{string.Join(", ", candidateRequests.Select(request => request.Name))}]). Using last ({candidateRequests.Last().Name})"; + await context.AddEventAsync(new WorkflowWarningEvent(message), cancellationToken).ConfigureAwait(false); + } - if (options.EmitAgentResponseEvents) + if (candidateRequests.Count > 0) { - await context.YieldOutputAsync(agentResponse, cancellationToken).ConfigureAwait(false); + FunctionCallContent handoffRequest = candidateRequests[candidateRequests.Count - 1]; + requestedHandoff = handoffRequest.Name; + + await AddUpdateAsync( + new AgentResponseUpdate + { + AgentId = this._agent.Id, + AuthorName = this._agent.Name ?? this._agent.Id, + Contents = [new FunctionResultContent(handoffRequest.CallId, "Transferred.")], + CreatedAt = DateTimeOffset.UtcNow, + MessageId = Guid.NewGuid().ToString("N"), + Role = ChatRole.Tool, + }, + cancellationToken + ) + .ConfigureAwait(false); } - allMessages.AddRange(agentResponse.Messages); + response = updates.ToAgentResponse(); - roleChanges.ResetUserToAssistantForChangedRoles(); + if (this._options.EmitAgentResponseEvents) + { + await context.YieldOutputAsync(response, cancellationToken).ConfigureAwait(false); + } - string currentAgentId = requestedHandoff is not null && this._handoffFunctionToAgentId.TryGetValue(requestedHandoff, out string? targetAgentId) - ? targetAgentId - : this._agent.Id; + await collector.SubmitAsync(context, cancellationToken).ConfigureAwait(false); - return new(message.TurnToken, requestedHandoff, allMessages, currentAgentId); + return new(response, LookupHandoffTarget(requestedHandoff)); - async Task AddUpdateAsync(AgentResponseUpdate update, CancellationToken cancellationToken) + ValueTask AddUpdateAsync(AgentResponseUpdate update, CancellationToken cancellationToken) { updates.Add(update); - if (message.TurnToken.ShouldEmitStreamingEvents(options.EmitAgentResponseUpdateEvents)) - { - await context.YieldOutputAsync(update, cancellationToken).ConfigureAwait(false); - } + + return emitUpdateEvents ? context.YieldOutputAsync(update, cancellationToken) : default; } - } - public ValueTask ResetAsync() => default; + string? LookupHandoffTarget(string? requestedHandoff) + => requestedHandoff != null + ? this._handoffFunctionToAgentId.TryGetValue(requestedHandoff, out string? targetId) ? targetId : null + : null; + } } diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/HandoffsEndExecutor.cs b/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/HandoffEndExecutor.cs similarity index 86% rename from dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/HandoffsEndExecutor.cs rename to dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/HandoffEndExecutor.cs index 4a43c00a72..0ba8fc3501 100644 --- a/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/HandoffsEndExecutor.cs +++ b/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/HandoffEndExecutor.cs @@ -8,7 +8,7 @@ namespace Microsoft.Agents.AI.Workflows.Specialized; /// Executor used at the end of a handoff workflow to raise a final completed event. -internal sealed class HandoffsEndExecutor(bool returnToPrevious) : Executor(ExecutorId, declareCrossRunShareable: true), IResettableExecutor +internal sealed class HandoffEndExecutor(bool returnToPrevious) : Executor(ExecutorId, declareCrossRunShareable: true), IResettableExecutor { public const string ExecutorId = "HandoffEnd"; @@ -21,9 +21,9 @@ private async ValueTask HandleAsync(HandoffState handoff, IWorkflowContext conte { if (returnToPrevious) { - await context.QueueStateUpdateAsync(HandoffConstants.CurrentAgentTrackerKey, - handoff.CurrentAgentId, - HandoffConstants.CurrentAgentTrackerScope, + await context.QueueStateUpdateAsync(HandoffConstants.PreviousAgentTrackerKey, + handoff.PreviousAgentId, + HandoffConstants.PreviousAgentTrackerScope, cancellationToken) .ConfigureAwait(false); } diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/HandoffsStartExecutor.cs b/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/HandoffStartExecutor.cs similarity index 71% rename from dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/HandoffsStartExecutor.cs rename to dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/HandoffStartExecutor.cs index 87c3b4566b..063f73bb6f 100644 --- a/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/HandoffsStartExecutor.cs +++ b/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/HandoffStartExecutor.cs @@ -9,12 +9,12 @@ namespace Microsoft.Agents.AI.Workflows.Specialized; internal static class HandoffConstants { - internal const string CurrentAgentTrackerKey = "LastAgentId"; - internal const string CurrentAgentTrackerScope = "HandoffOrchestration"; + internal const string PreviousAgentTrackerKey = "LastAgentId"; + internal const string PreviousAgentTrackerScope = "HandoffOrchestration"; } /// Executor used at the start of a handoffs workflow to accumulate messages and emit them as HandoffState upon receiving a turn token. -internal sealed class HandoffsStartExecutor(bool returnToPrevious) : ChatProtocolExecutor(ExecutorId, DefaultOptions, declareCrossRunShareable: true), IResettableExecutor +internal sealed class HandoffStartExecutor(bool returnToPrevious) : ChatProtocolExecutor(ExecutorId, DefaultOptions, declareCrossRunShareable: true), IResettableExecutor { internal const string ExecutorId = "HandoffStart"; @@ -32,15 +32,15 @@ protected override ValueTask TakeTurnAsync(List messages, IWorkflow if (returnToPrevious) { return context.InvokeWithStateAsync( - async (string? currentAgentId, IWorkflowContext context, CancellationToken cancellationToken) => + async (string? previousAgentId, IWorkflowContext context, CancellationToken cancellationToken) => { - HandoffState handoffState = new(new(emitEvents), null, messages, currentAgentId); + HandoffState handoffState = new(new(emitEvents), null, messages, previousAgentId); await context.SendMessageAsync(handoffState, cancellationToken).ConfigureAwait(false); - return currentAgentId; + return previousAgentId; }, - HandoffConstants.CurrentAgentTrackerKey, - HandoffConstants.CurrentAgentTrackerScope, + HandoffConstants.PreviousAgentTrackerKey, + HandoffConstants.PreviousAgentTrackerScope, cancellationToken); } diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/HandoffState.cs b/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/HandoffState.cs index 56e2fef9df..644bc7df0e 100644 --- a/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/HandoffState.cs +++ b/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/HandoffState.cs @@ -7,6 +7,6 @@ namespace Microsoft.Agents.AI.Workflows.Specialized; internal sealed record class HandoffState( TurnToken TurnToken, - string? InvokedHandoff, + string? RequestedHandoffTargetAgentId, List Messages, - string? CurrentAgentId = null); + string? PreviousAgentId = null); diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows/StatefulExecutor.cs b/dotnet/src/Microsoft.Agents.AI.Workflows/StatefulExecutor.cs index 3ed23cc019..d1d239506f 100644 --- a/dotnet/src/Microsoft.Agents.AI.Workflows/StatefulExecutor.cs +++ b/dotnet/src/Microsoft.Agents.AI.Workflows/StatefulExecutor.cs @@ -113,6 +113,12 @@ protected async ValueTask InvokeWithStateAsync( { if (!skipCache && !context.ConcurrentRunsEnabled) { + if (this._stateCache is null) + { + this._stateCache = await context.ReadOrInitStateAsync(this.StateKey, this._initialStateFactory, this.Options.ScopeName, cancellationToken) + .ConfigureAwait(false); + } + TState newState = await invocation(this._stateCache ?? this._initialStateFactory(), context, cancellationToken).ConfigureAwait(false) @@ -168,9 +174,12 @@ public abstract class StatefulExecutor(string id, /// protected override ProtocolBuilder ConfigureProtocol(ProtocolBuilder protocolBuilder) { - protocolBuilder.RouteBuilder.AddHandler(this.HandleAsync); + Func handlerDelegate = this.HandleAsync; - return protocolBuilder.SendsMessageTypes(sentMessageTypes ?? []) + return protocolBuilder.ConfigureRoutes(routeBuilder => routeBuilder.AddHandler(handlerDelegate)) + .AddMethodAttributeTypes(handlerDelegate.Method) + .AddClassAttributeTypes(this.GetType()) + .SendsMessageTypes(sentMessageTypes ?? []) .YieldsOutputTypes(outputTypes ?? []); } @@ -203,19 +212,12 @@ public abstract class StatefulExecutor(string id, /// protected override ProtocolBuilder ConfigureProtocol(ProtocolBuilder protocolBuilder) { - protocolBuilder.RouteBuilder.AddHandler(this.HandleAsync); - - if (this.Options.AutoSendMessageHandlerResultObject) - { - protocolBuilder.SendsMessage(); - } - - if (this.Options.AutoYieldOutputHandlerResultObject) - { - protocolBuilder.YieldsOutput(); - } - - return protocolBuilder.SendsMessageTypes(sentMessageTypes ?? []).YieldsOutputTypes(outputTypes ?? []); + Func> handlerDelegate = this.HandleAsync; + return protocolBuilder.ConfigureRoutes(routeBuilder => routeBuilder.AddHandler(handlerDelegate)) + .AddMethodAttributeTypes(handlerDelegate.Method) + .AddClassAttributeTypes(this.GetType()) + .SendsMessageTypes(sentMessageTypes ?? []) + .YieldsOutputTypes(outputTypes ?? []); } /// diff --git a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/AgentWorkflowBuilderTests.cs b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/AgentWorkflowBuilderTests.cs index 7f06145a8e..c857811b08 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/AgentWorkflowBuilderTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/AgentWorkflowBuilderTests.cs @@ -9,6 +9,7 @@ using System.Text.RegularExpressions; using System.Threading; using System.Threading.Tasks; +using FluentAssertions; using Microsoft.Agents.AI.Workflows.InProc; using Microsoft.Extensions.AI; @@ -147,7 +148,7 @@ from i in Enumerable.Range(1, numAgents) for (int iter = 0; iter < 3; iter++) { const string UserInput = "abc"; - (string updateText, List? result, _) = await RunWorkflowAsync(workflow, [new ChatMessage(ChatRole.User, UserInput)]); + (string updateText, List? result, _, _) = await RunWorkflowAsync(workflow, [new ChatMessage(ChatRole.User, UserInput)]); Assert.NotNull(result); Assert.Equal(numAgents + 1, result.Count); @@ -225,7 +226,7 @@ public async Task BuildConcurrent_AgentsRunInParallelAsync() barrier.Value = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); remaining.Value = 2; - (string updateText, List? result, _) = await RunWorkflowAsync(workflow, [new ChatMessage(ChatRole.User, "abc")]); + (string updateText, List? result, _, _) = await RunWorkflowAsync(workflow, [new ChatMessage(ChatRole.User, "abc")]); Assert.NotEmpty(updateText); Assert.NotNull(result); @@ -258,7 +259,7 @@ public async Task Handoffs_NoTransfers_ResponseServedByOriginalAgentAsync() }), description: "nop")) .Build(); - (string updateText, List? result, _) = await RunWorkflowAsync(workflow, [new ChatMessage(ChatRole.User, "abc")]); + (string updateText, List? result, _, _) = await RunWorkflowAsync(workflow, [new ChatMessage(ChatRole.User, "abc")]); Assert.Equal("Hello from agent1", updateText); Assert.NotNull(result); @@ -296,7 +297,7 @@ public async Task Handoffs_OneTransfer_ResponseServedBySecondAgentAsync() .WithHandoff(initialAgent, nextAgent) .Build(); - (string updateText, List? result, _) = await RunWorkflowAsync(workflow, [new ChatMessage(ChatRole.User, "abc")]); + (string updateText, List? result, _, _) = await RunWorkflowAsync(workflow, [new ChatMessage(ChatRole.User, "abc")]); Assert.Equal("Hello from agent2", updateText); Assert.NotNull(result); @@ -406,7 +407,7 @@ public async Task Handoffs_TwoTransfers_HandoffTargetsDoNotReceiveHandoffFunctio .WithHandoff(secondAgent, thirdAgent) .Build(); - (string updateText, _, _) = await RunWorkflowAsync(workflow, [new ChatMessage(ChatRole.User, "abc")]); + (string updateText, _, _, _) = await RunWorkflowAsync(workflow, [new ChatMessage(ChatRole.User, "abc")]); Assert.Contains("Hello from agent3", updateText); @@ -604,7 +605,7 @@ public async Task Handoffs_TwoTransfers_ResponseServedByThirdAgentAsync() .WithHandoff(secondAgent, thirdAgent) .Build(); - (string updateText, List? result, _) = await RunWorkflowAsync(workflow, [new ChatMessage(ChatRole.User, "abc")]); + (string updateText, List? result, _, _) = await RunWorkflowAsync(workflow, [new ChatMessage(ChatRole.User, "abc")]); Assert.Equal("Hello from agent3", updateText); Assert.NotNull(result); @@ -634,6 +635,232 @@ public async Task Handoffs_TwoTransfers_ResponseServedByThirdAgentAsync() Assert.Contains("thirdAgent", result[5].AuthorName); } + [Fact] + public async Task Handoffs_TwoTransfers_SecondAgentUserApproval_ResponseServedByThirdAgentAsync() + { + var initialAgent = new ChatClientAgent(new MockChatClient((messages, options) => + { + ChatMessage message = Assert.Single(messages); + Assert.Equal("abc", Assert.IsType(Assert.Single(message.Contents)).Text); + + string? transferFuncName = options?.Tools?.FirstOrDefault(t => t.Name.StartsWith("handoff_to_", StringComparison.Ordinal))?.Name; + Assert.NotNull(transferFuncName); + + // Only a handoff function call. + return new(new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("call1", transferFuncName)])); + }), name: "initialAgent"); + + bool secondAgentInvoked = false; + + const string SomeOtherFunctionCallId = "call2first"; + + AIFunction someOtherFunction = new ApprovalRequiredAIFunction(AIFunctionFactory.Create(SomeOtherFunction)); + + var secondAgent = new ChatClientAgent(new MockChatClient((messages, options) => + { + if (!secondAgentInvoked) + { + secondAgentInvoked = true; + return new(new ChatMessage(ChatRole.Assistant, [new FunctionCallContent(SomeOtherFunctionCallId, someOtherFunction.Name)])); + } + + // Second agent should receive the conversation so far (including previous assistant + tool messages eventually). + string? transferFuncName = options?.Tools?.FirstOrDefault(t => t.Name.StartsWith("handoff_to_", StringComparison.Ordinal))?.Name; + Assert.NotNull(transferFuncName); + + return new(new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("call2", transferFuncName)])); + }), name: "secondAgent", description: "The second agent", tools: [someOtherFunction]); + + var thirdAgent = new ChatClientAgent(new MockChatClient((messages, options) => + new(new ChatMessage(ChatRole.Assistant, "Hello from agent3"))), + name: "thirdAgent", + description: "The third / final agent"); + + var workflow = + AgentWorkflowBuilder.CreateHandoffBuilderWith(initialAgent) + .WithHandoff(initialAgent, secondAgent) + .WithHandoff(secondAgent, thirdAgent) + .Build(); + + CheckpointManager checkpointManager = CheckpointManager.CreateInMemory(); + const ExecutionEnvironment Environment = ExecutionEnvironment.InProcess_Lockstep; + + (string updateText, List? result, CheckpointInfo? lastCheckpoint, List requests) = + await RunWorkflowCheckpointedAsync(workflow, [new ChatMessage(ChatRole.User, "abc")], Environment, checkpointManager); + + Assert.Null(result); + Assert.NotNull(requests); + + requests.Should().HaveCount(1); + ExternalRequest request = requests[0].Request; + + ToolApprovalRequestContent approvalRequest = + request.Data.As().Should().NotBeNull() + .And.Subject.As(); + + approvalRequest.ToolCall.CallId.Should().Be(SomeOtherFunctionCallId); + + ExternalResponse response = request.CreateResponse(approvalRequest.CreateResponse(false, "Denied")); + + (updateText, result, _, requests) = + await RunWorkflowCheckpointedAsync(workflow, response, Environment, checkpointManager, lastCheckpoint); + + Assert.Equal("Hello from agent3", updateText); + Assert.NotNull(result); + + // User + (assistant empty + tool) for each of first two agents + final assistant with text. + Assert.Equal(10, result.Count); + + Assert.Equal(ChatRole.User, result[0].Role); + Assert.Equal("abc", result[0].Text); + + Assert.Equal(ChatRole.Assistant, result[1].Role); + Assert.Equal("", result[1].Text); + Assert.Contains("initialAgent", result[1].AuthorName); + + Assert.Equal(ChatRole.Tool, result[2].Role); + Assert.Contains("initialAgent", result[2].AuthorName); + + // Non-handoff tool invocation (and user denial) + Assert.Equal(ChatRole.Assistant, result[3].Role); + Assert.Equal("", result[3].Text); + Assert.Contains("secondAgent", result[3].AuthorName); + + Assert.Equal(ChatRole.User, result[4].Role); + Assert.Equal("", result[4].Text); + + // Rejected tool call + Assert.Equal(ChatRole.Assistant, result[5].Role); + Assert.Equal("", result[5].Text); + Assert.Contains("secondAgent", result[5].AuthorName); + + Assert.Equal(ChatRole.Tool, result[6].Role); + Assert.Contains("secondAgent", result[6].AuthorName); + + // Handoff invocation + Assert.Equal(ChatRole.Assistant, result[7].Role); + Assert.Equal("", result[7].Text); + Assert.Contains("secondAgent", result[7].AuthorName); + + Assert.Equal(ChatRole.Tool, result[8].Role); + Assert.Contains("secondAgent", result[8].AuthorName); + + Assert.Equal(ChatRole.Assistant, result[9].Role); + Assert.Equal("Hello from agent3", result[9].Text); + Assert.Contains("thirdAgent", result[9].AuthorName); + + static bool SomeOtherFunction() => true; + } + + [Fact] + public async Task Handoffs_TwoTransfers_SecondAgentToolCall_ResponseServedByThirdAgentAsync() + { + var initialAgent = new ChatClientAgent(new MockChatClient((messages, options) => + { + ChatMessage message = Assert.Single(messages); + Assert.Equal("abc", Assert.IsType(Assert.Single(message.Contents)).Text); + + string? transferFuncName = options?.Tools?.FirstOrDefault(t => t.Name.StartsWith("handoff_to_", StringComparison.Ordinal))?.Name; + Assert.NotNull(transferFuncName); + + // Only a handoff function call. + return new(new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("call1", transferFuncName)])); + }), name: "initialAgent"); + + bool secondAgentInvoked = false; + + const string SomeOtherFunctionName = "SomeOtherFunction"; + const string SomeOtherFunctionCallId = "call2first"; + + JsonElement otherFunctionSchema = AIFunctionFactory.Create(() => true).JsonSchema; + AIFunctionDeclaration someOtherFunction = AIFunctionFactory.CreateDeclaration(SomeOtherFunctionName, "Another function", otherFunctionSchema); + + var secondAgent = new ChatClientAgent(new MockChatClient((messages, options) => + { + if (!secondAgentInvoked) + { + secondAgentInvoked = true; + return new(new ChatMessage(ChatRole.Assistant, [new FunctionCallContent(SomeOtherFunctionCallId, SomeOtherFunctionName)])); + } + + // Second agent should receive the conversation so far (including previous assistant + tool messages eventually). + string? transferFuncName = options?.Tools?.FirstOrDefault(t => t.Name.StartsWith("handoff_to_", StringComparison.Ordinal))?.Name; + Assert.NotNull(transferFuncName); + + return new(new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("call2", transferFuncName)])); + }), name: "secondAgent", description: "The second agent", tools: [someOtherFunction]); + + var thirdAgent = new ChatClientAgent(new MockChatClient((messages, options) => + new(new ChatMessage(ChatRole.Assistant, "Hello from agent3"))), + name: "thirdAgent", + description: "The third / final agent"); + + var workflow = + AgentWorkflowBuilder.CreateHandoffBuilderWith(initialAgent) + .WithHandoff(initialAgent, secondAgent) + .WithHandoff(secondAgent, thirdAgent) + .Build(); + + CheckpointManager checkpointManager = CheckpointManager.CreateInMemory(); + const ExecutionEnvironment Environment = ExecutionEnvironment.InProcess_Lockstep; + + (string updateText, List? result, CheckpointInfo? lastCheckpoint, List requests) = + await RunWorkflowCheckpointedAsync(workflow, [new ChatMessage(ChatRole.User, "abc")], Environment, checkpointManager); + + Assert.Null(result); + Assert.NotNull(requests); + + requests.Should().HaveCount(1); + ExternalRequest request = requests[0].Request; + + FunctionCallContent functionCall = request.Data.As().Should().NotBeNull() + .And.Subject.As(); + + functionCall.CallId.Should().Be(SomeOtherFunctionCallId); + functionCall.Name.Should().Be(SomeOtherFunctionName); + + ExternalResponse response = request.CreateResponse(new FunctionResultContent(functionCall.CallId, true)); + + (updateText, result, _, requests) = + await RunWorkflowCheckpointedAsync(workflow, response, Environment, checkpointManager, lastCheckpoint); + + Assert.Equal("Hello from agent3", updateText); + Assert.NotNull(result); + + // User + (assistant empty + tool) for each of first two agents + final assistant with text. + Assert.Equal(8, result.Count); + + Assert.Equal(ChatRole.User, result[0].Role); + Assert.Equal("abc", result[0].Text); + + Assert.Equal(ChatRole.Assistant, result[1].Role); + Assert.Equal("", result[1].Text); + Assert.Contains("initialAgent", result[1].AuthorName); + + Assert.Equal(ChatRole.Tool, result[2].Role); + Assert.Contains("initialAgent", result[2].AuthorName); + + // Non-handoff tool invocation + Assert.Equal(ChatRole.Assistant, result[3].Role); + Assert.Equal("", result[3].Text); + Assert.Contains("secondAgent", result[3].AuthorName); + + Assert.Equal(ChatRole.Tool, result[4].Role); + Assert.Contains("secondAgent", result[4].AuthorName); + + // Handoff invocation + Assert.Equal(ChatRole.Assistant, result[5].Role); + Assert.Equal("", result[5].Text); + Assert.Contains("secondAgent", result[5].AuthorName); + + Assert.Equal(ChatRole.Tool, result[6].Role); + Assert.Contains("secondAgent", result[6].AuthorName); + + Assert.Equal(ChatRole.Assistant, result[7].Role); + Assert.Equal("Hello from agent3", result[7].Text); + Assert.Contains("thirdAgent", result[7].AuthorName); + } + [Theory] [InlineData(1)] [InlineData(2)] @@ -651,7 +878,7 @@ public async Task BuildGroupChat_AgentsRunInOrderAsync(int maxIterations) for (int iter = 0; iter < 3; iter++) { const string UserInput = "abc"; - (string updateText, List? result, _) = await RunWorkflowAsync(workflow, [new ChatMessage(ChatRole.User, UserInput)]); + (string updateText, List? result, _, _) = await RunWorkflowAsync(workflow, [new ChatMessage(ChatRole.User, UserInput)]); Assert.NotNull(result); Assert.Equal(maxIterations + 1, result.Count); @@ -832,7 +1059,7 @@ public async Task Handoffs_ReturnToPrevious_Enabled_AfterHandoffBackToCoordinato Assert.Equal(1, specialistCallCount); // specialist NOT called } - private sealed record WorkflowRunResult(string UpdateText, List? Result, CheckpointInfo? LastCheckpoint); + private sealed record WorkflowRunResult(string UpdateText, List? Result, CheckpointInfo? LastCheckpoint, List PendingRequests); private static Task RunWorkflowCheckpointedAsync( Workflow workflow, List input, ExecutionEnvironment executionEnvironment, CheckpointManager checkpointManager, CheckpointInfo? fromCheckpoint = null) @@ -843,6 +1070,15 @@ private static Task RunWorkflowCheckpointedAsync( return RunWorkflowCheckpointedAsync(workflow, input, environment, fromCheckpoint); } + private static Task RunWorkflowCheckpointedAsync( + Workflow workflow, ExternalResponse response, ExecutionEnvironment executionEnvironment, CheckpointManager checkpointManager, CheckpointInfo? fromCheckpoint = null) + { + InProcessExecutionEnvironment environment = executionEnvironment.ToWorkflowExecutionEnvironment() + .WithCheckpointing(checkpointManager); + + return RunWorkflowCheckpointedAsync(workflow, response, environment, fromCheckpoint); + } + private static async Task RunWorkflowCheckpointedAsync( Workflow workflow, List input, InProcessExecutionEnvironment environment, CheckpointInfo? fromCheckpoint = null) { @@ -853,15 +1089,39 @@ private static async Task RunWorkflowCheckpointedAsync( await run.TrySendMessageAsync(input); await run.TrySendMessageAsync(new TurnToken(emitEvents: true)); + return await ProcessWorkflowRunAsync(run); + } + + private static async Task RunWorkflowCheckpointedAsync( + Workflow workflow, ExternalResponse response, InProcessExecutionEnvironment environment, CheckpointInfo? fromCheckpoint = null) + { + await using StreamingRun run = + fromCheckpoint != null ? await environment.ResumeStreamingAsync(workflow, fromCheckpoint) + : await environment.OpenStreamingAsync(workflow); + + await run.SendResponseAsync(response); + + return await ProcessWorkflowRunAsync(run); + } + + private static async Task ProcessWorkflowRunAsync(StreamingRun run) + { StringBuilder sb = new(); WorkflowOutputEvent? output = null; CheckpointInfo? lastCheckpoint = null; - await foreach (WorkflowEvent evt in run.WatchStreamAsync().ConfigureAwait(false)) + + List pendingRequests = []; + + await foreach (WorkflowEvent evt in run.WatchStreamAsync(blockOnPendingRequest: false).ConfigureAwait(false)) { switch (evt) { - case AgentResponseUpdateEvent executorComplete: - sb.Append(executorComplete.Data); + case AgentResponseUpdateEvent responseUpdate: + sb.Append(responseUpdate.Data); + break; + + case RequestInfoEvent requestInfo: + pendingRequests.Add(requestInfo); break; case WorkflowOutputEvent e: @@ -878,7 +1138,7 @@ private static async Task RunWorkflowCheckpointedAsync( } } - return new(sb.ToString(), output?.As>(), lastCheckpoint); + return new(sb.ToString(), output?.As>(), lastCheckpoint, pendingRequests); } private static Task RunWorkflowAsync( diff --git a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/HandoffAgentExecutorTests.cs b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/HandoffAgentExecutorTests.cs index 8bdbe23c5f..1a5b2ea4d1 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/HandoffAgentExecutorTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/HandoffAgentExecutorTests.cs @@ -29,7 +29,7 @@ public async Task Test_HandoffAgentExecutor_EmitsStreamingUpdatesIFFConfiguredAs emitAgentResponseUpdateEvents: executorSetting, HandoffToolCallFilteringBehavior.None); - HandoffAgentExecutor executor = new(agent, options); + HandoffAgentExecutor executor = new(agent, [], options); testContext.ConfigureExecutor(executor); // Act @@ -57,7 +57,7 @@ public async Task Test_HandoffAgentExecutor_EmitsResponseIFFConfiguredAsync(bool emitAgentResponseUpdateEvents: false, HandoffToolCallFilteringBehavior.None); - HandoffAgentExecutor executor = new(agent, options); + HandoffAgentExecutor executor = new(agent, [], options); testContext.ConfigureExecutor(executor); // Act