diff --git a/tests/ModelContextProtocol.Tests/Server/McpServerTaskTests.cs b/tests/ModelContextProtocol.Tests/Server/McpServerTaskTests.cs index 7166e1f25..66304ed5f 100644 --- a/tests/ModelContextProtocol.Tests/Server/McpServerTaskTests.cs +++ b/tests/ModelContextProtocol.Tests/Server/McpServerTaskTests.cs @@ -1,7 +1,6 @@ using ModelContextProtocol.Protocol; using ModelContextProtocol.Server; using Microsoft.Extensions.DependencyInjection; -using System.Collections.Concurrent; using System.Runtime.InteropServices; using System.Text.Json; using System.Text.Json.Nodes; @@ -538,106 +537,130 @@ public async Task CallToolHandler_CanBeSetToNull_ThenOtherCanBeSet() /// private sealed class InMemoryTaskStore { - private readonly ConcurrentDictionary _tasks = new(); + private readonly Dictionary _tasks = new(); public string CreateTask(McpTaskStatus initialStatus = McpTaskStatus.Working) { var taskId = Guid.NewGuid().ToString("N"); - _tasks[taskId] = new TaskEntry + lock (_tasks) { - Status = initialStatus, - CreatedAt = DateTimeOffset.UtcNow, - LastUpdatedAt = DateTimeOffset.UtcNow, - }; + _tasks[taskId] = new TaskEntry + { + Status = initialStatus, + CreatedAt = DateTimeOffset.UtcNow, + LastUpdatedAt = DateTimeOffset.UtcNow, + }; + } return taskId; } - public IEnumerable GetAllTaskIds() => _tasks.Keys; - - public GetTaskResult GetTask(string taskId) + public IEnumerable GetAllTaskIds() { - if (!_tasks.TryGetValue(taskId, out var entry)) + lock (_tasks) { - throw new McpException($"Unknown task: '{taskId}'"); + return _tasks.Keys.ToArray(); } + } - return entry.Status switch + public GetTaskResult GetTask(string taskId) + { + lock (_tasks) { - McpTaskStatus.Working => new WorkingTaskResult - { - TaskId = taskId, - CreatedAt = entry.CreatedAt, - LastUpdatedAt = entry.LastUpdatedAt, - PollIntervalMs = 50, - }, - McpTaskStatus.Completed => new CompletedTaskResult + if (!_tasks.TryGetValue(taskId, out var entry)) { - TaskId = taskId, - CreatedAt = entry.CreatedAt, - LastUpdatedAt = entry.LastUpdatedAt, - Result = JsonSerializer.SerializeToElement(entry.Result, McpJsonUtilities.DefaultOptions), - }, - McpTaskStatus.Failed => new FailedTaskResult - { - TaskId = taskId, - CreatedAt = entry.CreatedAt, - LastUpdatedAt = entry.LastUpdatedAt, - Error = entry.Error!.Value, - }, - McpTaskStatus.Cancelled => new CancelledTaskResult - { - TaskId = taskId, - CreatedAt = entry.CreatedAt, - LastUpdatedAt = entry.LastUpdatedAt, - }, - McpTaskStatus.InputRequired => new InputRequiredTaskResult + throw new McpException($"Unknown task: '{taskId}'"); + } + + return entry.Status switch { - TaskId = taskId, - CreatedAt = entry.CreatedAt, - LastUpdatedAt = entry.LastUpdatedAt, - InputRequests = entry.InputRequests ?? new Dictionary(), - }, - _ => throw new InvalidOperationException($"Unexpected status: {entry.Status}") - }; + McpTaskStatus.Working => new WorkingTaskResult + { + TaskId = taskId, + CreatedAt = entry.CreatedAt, + LastUpdatedAt = entry.LastUpdatedAt, + PollIntervalMs = 50, + }, + McpTaskStatus.Completed => new CompletedTaskResult + { + TaskId = taskId, + CreatedAt = entry.CreatedAt, + LastUpdatedAt = entry.LastUpdatedAt, + Result = JsonSerializer.SerializeToElement(entry.Result, McpJsonUtilities.DefaultOptions), + }, + McpTaskStatus.Failed => new FailedTaskResult + { + TaskId = taskId, + CreatedAt = entry.CreatedAt, + LastUpdatedAt = entry.LastUpdatedAt, + Error = entry.Error!.Value, + }, + McpTaskStatus.Cancelled => new CancelledTaskResult + { + TaskId = taskId, + CreatedAt = entry.CreatedAt, + LastUpdatedAt = entry.LastUpdatedAt, + }, + McpTaskStatus.InputRequired => new InputRequiredTaskResult + { + TaskId = taskId, + CreatedAt = entry.CreatedAt, + LastUpdatedAt = entry.LastUpdatedAt, + InputRequests = entry.InputRequests ?? new Dictionary(), + }, + _ => throw new InvalidOperationException($"Unexpected status: {entry.Status}") + }; + } } public void CompleteTask(string taskId, CallToolResult result) { - if (_tasks.TryGetValue(taskId, out var entry)) + lock (_tasks) { - entry.Status = McpTaskStatus.Completed; - entry.Result = result; - entry.LastUpdatedAt = DateTimeOffset.UtcNow; + if (_tasks.TryGetValue(taskId, out var entry)) + { + entry.Result = result; + entry.LastUpdatedAt = DateTimeOffset.UtcNow; + entry.Status = McpTaskStatus.Completed; + } } } public void FailTask(string taskId, JsonElement error) { - if (_tasks.TryGetValue(taskId, out var entry)) + lock (_tasks) { - entry.Status = McpTaskStatus.Failed; - entry.Error = error; - entry.LastUpdatedAt = DateTimeOffset.UtcNow; + if (_tasks.TryGetValue(taskId, out var entry)) + { + entry.Error = error; + entry.LastUpdatedAt = DateTimeOffset.UtcNow; + entry.Status = McpTaskStatus.Failed; + } } } public void CancelTask(string taskId) { - if (_tasks.TryGetValue(taskId, out var entry)) + lock (_tasks) { - entry.Status = McpTaskStatus.Cancelled; - entry.LastUpdatedAt = DateTimeOffset.UtcNow; + if (_tasks.TryGetValue(taskId, out var entry)) + { + entry.LastUpdatedAt = DateTimeOffset.UtcNow; + entry.Status = McpTaskStatus.Cancelled; + } } } public void ProvideInput(string taskId, IDictionary inputResponses) { - if (_tasks.TryGetValue(taskId, out var entry)) + lock (_tasks) { - entry.InputResponses = inputResponses; - // Transition back to working after receiving input - entry.Status = McpTaskStatus.Working; - entry.LastUpdatedAt = DateTimeOffset.UtcNow; + if (_tasks.TryGetValue(taskId, out var entry)) + { + entry.InputResponses = inputResponses; + entry.LastUpdatedAt = DateTimeOffset.UtcNow; + // Transition back to working after receiving input + entry.Status = McpTaskStatus.Working; + } } }