Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
149 changes: 86 additions & 63 deletions tests/ModelContextProtocol.Tests/Server/McpServerTaskTests.cs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
using ModelContextProtocol.Protocol;
using ModelContextProtocol.Server;
using Microsoft.Extensions.DependencyInjection;
using System.Collections.Concurrent;
using System.Runtime.InteropServices;
using System.Text.Json;
using System.Text.Json.Nodes;
Expand Down Expand Up @@ -538,106 +537,130 @@ public async Task CallToolHandler_CanBeSetToNull_ThenOtherCanBeSet()
/// </summary>
private sealed class InMemoryTaskStore
{
private readonly ConcurrentDictionary<string, TaskEntry> _tasks = new();
private readonly Dictionary<string, TaskEntry> _tasks = new();

public string CreateTask(McpTaskStatus initialStatus = McpTaskStatus.Working)
{
var taskId = Guid.NewGuid().ToString("N");
_tasks[taskId] = new TaskEntry
lock (_tasks)
{
Status = initialStatus,
CreatedAt = DateTimeOffset.UtcNow,
LastUpdatedAt = DateTimeOffset.UtcNow,
};
_tasks[taskId] = new TaskEntry
{
Status = initialStatus,
CreatedAt = DateTimeOffset.UtcNow,
LastUpdatedAt = DateTimeOffset.UtcNow,
};
}
return taskId;
}

public IEnumerable<string> GetAllTaskIds() => _tasks.Keys;

public GetTaskResult GetTask(string taskId)
public IEnumerable<string> GetAllTaskIds()
{
if (!_tasks.TryGetValue(taskId, out var entry))
lock (_tasks)
{
throw new McpException($"Unknown task: '{taskId}'");
return _tasks.Keys.ToArray();
}
}

return entry.Status switch
public GetTaskResult GetTask(string taskId)
{
lock (_tasks)
{
McpTaskStatus.Working => new WorkingTaskResult
{
TaskId = taskId,
CreatedAt = entry.CreatedAt,
LastUpdatedAt = entry.LastUpdatedAt,
PollIntervalMs = 50,
},
McpTaskStatus.Completed => new CompletedTaskResult
if (!_tasks.TryGetValue(taskId, out var entry))
{
TaskId = taskId,
CreatedAt = entry.CreatedAt,
LastUpdatedAt = entry.LastUpdatedAt,
Result = JsonSerializer.SerializeToElement(entry.Result, McpJsonUtilities.DefaultOptions),
},
McpTaskStatus.Failed => new FailedTaskResult
{
TaskId = taskId,
CreatedAt = entry.CreatedAt,
LastUpdatedAt = entry.LastUpdatedAt,
Error = entry.Error!.Value,
},
McpTaskStatus.Cancelled => new CancelledTaskResult
{
TaskId = taskId,
CreatedAt = entry.CreatedAt,
LastUpdatedAt = entry.LastUpdatedAt,
},
McpTaskStatus.InputRequired => new InputRequiredTaskResult
throw new McpException($"Unknown task: '{taskId}'");
}

return entry.Status switch
{
TaskId = taskId,
CreatedAt = entry.CreatedAt,
LastUpdatedAt = entry.LastUpdatedAt,
InputRequests = entry.InputRequests ?? new Dictionary<string, InputRequest>(),
},
_ => throw new InvalidOperationException($"Unexpected status: {entry.Status}")
};
McpTaskStatus.Working => new WorkingTaskResult
{
TaskId = taskId,
CreatedAt = entry.CreatedAt,
LastUpdatedAt = entry.LastUpdatedAt,
PollIntervalMs = 50,
},
McpTaskStatus.Completed => new CompletedTaskResult
{
TaskId = taskId,
CreatedAt = entry.CreatedAt,
LastUpdatedAt = entry.LastUpdatedAt,
Result = JsonSerializer.SerializeToElement(entry.Result, McpJsonUtilities.DefaultOptions),
},
McpTaskStatus.Failed => new FailedTaskResult
{
TaskId = taskId,
CreatedAt = entry.CreatedAt,
LastUpdatedAt = entry.LastUpdatedAt,
Error = entry.Error!.Value,
},
McpTaskStatus.Cancelled => new CancelledTaskResult
{
TaskId = taskId,
CreatedAt = entry.CreatedAt,
LastUpdatedAt = entry.LastUpdatedAt,
},
McpTaskStatus.InputRequired => new InputRequiredTaskResult
{
TaskId = taskId,
CreatedAt = entry.CreatedAt,
LastUpdatedAt = entry.LastUpdatedAt,
InputRequests = entry.InputRequests ?? new Dictionary<string, InputRequest>(),
},
_ => throw new InvalidOperationException($"Unexpected status: {entry.Status}")
};
}
}

public void CompleteTask(string taskId, CallToolResult result)
{
if (_tasks.TryGetValue(taskId, out var entry))
lock (_tasks)
{
entry.Status = McpTaskStatus.Completed;
entry.Result = result;
entry.LastUpdatedAt = DateTimeOffset.UtcNow;
if (_tasks.TryGetValue(taskId, out var entry))
{
entry.Result = result;
entry.LastUpdatedAt = DateTimeOffset.UtcNow;
entry.Status = McpTaskStatus.Completed;
}
}
}

public void FailTask(string taskId, JsonElement error)
{
if (_tasks.TryGetValue(taskId, out var entry))
lock (_tasks)
{
entry.Status = McpTaskStatus.Failed;
entry.Error = error;
entry.LastUpdatedAt = DateTimeOffset.UtcNow;
if (_tasks.TryGetValue(taskId, out var entry))
{
entry.Error = error;
entry.LastUpdatedAt = DateTimeOffset.UtcNow;
entry.Status = McpTaskStatus.Failed;
}
}
}

public void CancelTask(string taskId)
{
if (_tasks.TryGetValue(taskId, out var entry))
lock (_tasks)
{
entry.Status = McpTaskStatus.Cancelled;
entry.LastUpdatedAt = DateTimeOffset.UtcNow;
if (_tasks.TryGetValue(taskId, out var entry))
{
entry.LastUpdatedAt = DateTimeOffset.UtcNow;
entry.Status = McpTaskStatus.Cancelled;
}
}
}

public void ProvideInput(string taskId, IDictionary<string, InputResponse> inputResponses)
{
if (_tasks.TryGetValue(taskId, out var entry))
lock (_tasks)
{
entry.InputResponses = inputResponses;
// Transition back to working after receiving input
entry.Status = McpTaskStatus.Working;
entry.LastUpdatedAt = DateTimeOffset.UtcNow;
if (_tasks.TryGetValue(taskId, out var entry))
{
entry.InputResponses = inputResponses;
entry.LastUpdatedAt = DateTimeOffset.UtcNow;
// Transition back to working after receiving input
entry.Status = McpTaskStatus.Working;
}
}
}

Expand Down
Loading