From 0c9de3867bf7283e525ef8c7dff4e01fde4a55c1 Mon Sep 17 00:00:00 2001 From: xiangyan99 Date: Mon, 23 Feb 2026 17:05:26 -0800 Subject: [PATCH 1/2] Added ModelContextProtocol.AspNetCore.Distributed library --- Directory.Packages.props | 8 + ModelContextProtocol.slnx | 2 + .../IListeningEndpointResolver.cs | 30 + .../Abstractions/ISessionAffinityBuilder.cs | 17 + .../Abstractions/ISessionStore.cs | 31 + .../Abstractions/SessionAffinityOptions.cs | 103 ++ .../SessionAffinityOptionsValidator.cs | 16 + .../Abstractions/SessionOwnerInfo.cs | 19 + .../HybridCacheSessionStore.cs | 125 +++ .../ListeningEndpointResolver.cs | 145 +++ .../MapSessionAffinityExtensions.cs | 38 + ...textProtocol.AspNetCore.Distributed.csproj | 30 + .../README.md | 116 +++ .../SemanticLogging.cs | 262 +++++ .../SerializerContext.cs | 17 + .../ServiceCollectionExtensions.cs | 77 ++ .../SessionAffinityBuilder.cs | 12 + .../SessionAffinityEndpointFilter.cs | 228 +++++ .../SessionOwnerInfoSerializer.cs | 36 + .../KeyedServiceTests.cs | 124 +++ .../ListeningEndpointResolverTests.cs | 705 ++++++++++++++ ...otocol.AspNetCore.Distributed.Tests.csproj | 43 + .../RealServerIntegrationTests.cs | 483 ++++++++++ .../SessionAffinityEndpointFilterTests.cs | 912 ++++++++++++++++++ .../SessionAffinityOptionsValidationTests.cs | 181 ++++ .../SessionOwnerInfoSerializerTests.cs | 283 ++++++ 26 files changed, 4043 insertions(+) create mode 100644 src/ModelContextProtocol.AspNetCore.Distributed/Abstractions/IListeningEndpointResolver.cs create mode 100644 src/ModelContextProtocol.AspNetCore.Distributed/Abstractions/ISessionAffinityBuilder.cs create mode 100644 src/ModelContextProtocol.AspNetCore.Distributed/Abstractions/ISessionStore.cs create mode 100644 src/ModelContextProtocol.AspNetCore.Distributed/Abstractions/SessionAffinityOptions.cs create mode 100644 src/ModelContextProtocol.AspNetCore.Distributed/Abstractions/SessionAffinityOptionsValidator.cs create mode 100644 src/ModelContextProtocol.AspNetCore.Distributed/Abstractions/SessionOwnerInfo.cs create mode 100644 src/ModelContextProtocol.AspNetCore.Distributed/HybridCacheSessionStore.cs create mode 100644 src/ModelContextProtocol.AspNetCore.Distributed/ListeningEndpointResolver.cs create mode 100644 src/ModelContextProtocol.AspNetCore.Distributed/MapSessionAffinityExtensions.cs create mode 100644 src/ModelContextProtocol.AspNetCore.Distributed/ModelContextProtocol.AspNetCore.Distributed.csproj create mode 100644 src/ModelContextProtocol.AspNetCore.Distributed/README.md create mode 100644 src/ModelContextProtocol.AspNetCore.Distributed/SemanticLogging.cs create mode 100644 src/ModelContextProtocol.AspNetCore.Distributed/SerializerContext.cs create mode 100644 src/ModelContextProtocol.AspNetCore.Distributed/ServiceCollectionExtensions.cs create mode 100644 src/ModelContextProtocol.AspNetCore.Distributed/SessionAffinityBuilder.cs create mode 100644 src/ModelContextProtocol.AspNetCore.Distributed/SessionAffinityEndpointFilter.cs create mode 100644 src/ModelContextProtocol.AspNetCore.Distributed/SessionOwnerInfoSerializer.cs create mode 100644 tests/ModelContextProtocol.AspNetCore.Distributed.Tests/KeyedServiceTests.cs create mode 100644 tests/ModelContextProtocol.AspNetCore.Distributed.Tests/ListeningEndpointResolverTests.cs create mode 100644 tests/ModelContextProtocol.AspNetCore.Distributed.Tests/ModelContextProtocol.AspNetCore.Distributed.Tests.csproj create mode 100644 tests/ModelContextProtocol.AspNetCore.Distributed.Tests/RealServerIntegrationTests.cs create mode 100644 tests/ModelContextProtocol.AspNetCore.Distributed.Tests/SessionAffinityEndpointFilterTests.cs create mode 100644 tests/ModelContextProtocol.AspNetCore.Distributed.Tests/SessionAffinityOptionsValidationTests.cs create mode 100644 tests/ModelContextProtocol.AspNetCore.Distributed.Tests/SessionOwnerInfoSerializerTests.cs diff --git a/Directory.Packages.props b/Directory.Packages.props index b9a66c78b..a39fd3317 100644 --- a/Directory.Packages.props +++ b/Directory.Packages.props @@ -49,6 +49,12 @@ + + + + + + @@ -91,6 +97,8 @@ + + diff --git a/ModelContextProtocol.slnx b/ModelContextProtocol.slnx index 1090c5377..13fe0b833 100644 --- a/ModelContextProtocol.slnx +++ b/ModelContextProtocol.slnx @@ -65,12 +65,14 @@ + + diff --git a/src/ModelContextProtocol.AspNetCore.Distributed/Abstractions/IListeningEndpointResolver.cs b/src/ModelContextProtocol.AspNetCore.Distributed/Abstractions/IListeningEndpointResolver.cs new file mode 100644 index 000000000..ece15da33 --- /dev/null +++ b/src/ModelContextProtocol.AspNetCore.Distributed/Abstractions/IListeningEndpointResolver.cs @@ -0,0 +1,30 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using Microsoft.AspNetCore.Hosting.Server; + +namespace ModelContextProtocol.AspNetCore.Distributed.Abstractions; + +/// +/// Resolves the listening endpoint address for the local server instance +/// that should be advertised to other instances for session affinity routing. +/// +public interface IListeningEndpointResolver +{ + /// + /// Resolves the local server address that should be advertised to other instances + /// for session affinity routing. + /// + /// The server instance to resolve addresses from. + /// Configuration options containing explicit address overrides. + /// A normalized address string in the format "scheme://host:port". + /// + /// The resolution strategy is: + /// + /// If is set, validate and return it + /// Otherwise, resolve from server bindings, preferring non-localhost HTTP addresses + /// Fall back to http://localhost:80 if no addresses are available + /// + /// + string ResolveListeningEndpoint(IServer server, SessionAffinityOptions options); +} diff --git a/src/ModelContextProtocol.AspNetCore.Distributed/Abstractions/ISessionAffinityBuilder.cs b/src/ModelContextProtocol.AspNetCore.Distributed/Abstractions/ISessionAffinityBuilder.cs new file mode 100644 index 000000000..1303b0ca9 --- /dev/null +++ b/src/ModelContextProtocol.AspNetCore.Distributed/Abstractions/ISessionAffinityBuilder.cs @@ -0,0 +1,17 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using Microsoft.Extensions.DependencyInjection; + +namespace ModelContextProtocol.AspNetCore.Distributed.Abstractions; + +/// +/// A builder for configuring MCP session affinity. +/// +public interface ISessionAffinityBuilder +{ + /// + /// Gets the host application builder. + /// + IServiceCollection Services { get; } +} diff --git a/src/ModelContextProtocol.AspNetCore.Distributed/Abstractions/ISessionStore.cs b/src/ModelContextProtocol.AspNetCore.Distributed/Abstractions/ISessionStore.cs new file mode 100644 index 000000000..301f55cb0 --- /dev/null +++ b/src/ModelContextProtocol.AspNetCore.Distributed/Abstractions/ISessionStore.cs @@ -0,0 +1,31 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace ModelContextProtocol.AspNetCore.Distributed.Abstractions; + +/// +/// Provides persistence for MCP session ownership. +/// +public interface ISessionStore +{ + /// + /// Gets the current owner of a session, or claims ownership if unclaimed. + /// + /// The session identifier. + /// A factory function that creates the owner information if the session is unclaimed. + /// Cancellation token. + /// The current or newly claimed owner information for the session. + Task GetOrClaimOwnershipAsync( + string sessionId, + Func> ownerInfoFactory, + CancellationToken cancellationToken = default + ); + + /// + /// Removes a session from the store. + /// + /// The session identifier to remove. + /// Cancellation token. + /// A task representing the asynchronous operation. + Task RemoveAsync(string sessionId, CancellationToken cancellationToken = default); +} diff --git a/src/ModelContextProtocol.AspNetCore.Distributed/Abstractions/SessionAffinityOptions.cs b/src/ModelContextProtocol.AspNetCore.Distributed/Abstractions/SessionAffinityOptions.cs new file mode 100644 index 000000000..05ecb7892 --- /dev/null +++ b/src/ModelContextProtocol.AspNetCore.Distributed/Abstractions/SessionAffinityOptions.cs @@ -0,0 +1,103 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.ComponentModel.DataAnnotations; +using Yarp.ReverseProxy.Configuration; +using Yarp.ReverseProxy.Forwarder; + +namespace ModelContextProtocol.AspNetCore.Distributed.Abstractions; + +/// +/// Configuration options for MCP session affinity routing behavior. +/// +public sealed class SessionAffinityOptions +{ + /// + /// Configuration for the YARP forwarder when routing requests to other silos. + /// If not set, a default configuration will be used. + /// + public ForwarderRequestConfig? ForwarderRequestConfig { get; set; } + + /// + /// Configuration for the HTTP client used when forwarding requests to other silos. + /// If not set, an empty configuration will be used. + /// + public HttpClientConfig? HttpClientConfig { get; set; } + + /// + /// The service key to use when resolving the service. + /// When set, the session store will use a keyed HybridCache service that can be configured + /// to use a specific distributed cache backend (e.g., Redis, SQL Server). + /// This enables scenarios where multiple cache instances are needed in a single application. + /// + /// + /// This property is used in conjunction with keyed HybridCache registration. + /// Register a keyed HybridCache instance using the standard DI keyed services APIs. + /// + public string? HybridCacheServiceKey { get; set; } + + /// + /// Explicitly sets the local server address that will be advertised to other instances + /// for session affinity routing. This address is stored in the distributed session store + /// and used by other servers to forward requests back to this instance. + /// + /// + /// + /// When set, this value takes precedence over automatic address resolution from server bindings. + /// This is useful in scenarios where: + /// + /// Running in containerized environments where internal addresses differ from advertised addresses + /// Using service meshes where specific addresses/ports must be used for routing + /// Multiple network interfaces are available and a specific one should be used + /// Running behind load balancers or proxies with address translation + /// + /// + /// + /// The value must be a valid absolute URI including scheme (http or https), host, and port. + /// Examples: + /// + /// http://pod-1.mcp-service.default.svc.cluster.local:8080 + /// http://10.0.1.5:5000 + /// https://server1.internal:443 + /// + /// + /// + /// If not set, the address will be automatically resolved from the server's configured + /// bindings, preferring HTTP over HTTPS for service mesh scenarios. + /// + /// + [HttpOrHttpsUri] + public string? LocalServerAddress { get; set; } +} + +/// +/// Validates that a string is a valid HTTP or HTTPS URI. +/// +[AttributeUsage(AttributeTargets.Property | AttributeTargets.Field | AttributeTargets.Parameter)] +internal sealed class HttpOrHttpsUriAttribute : ValidationAttribute +{ + protected override ValidationResult? IsValid(object? value, ValidationContext validationContext) + { + if (value is null or string { Length: 0 }) + { + return ValidationResult.Success; + } + + if (value is not string stringValue) + { + return new ValidationResult("Value must be a string."); + } + + if (!Uri.TryCreate(stringValue, UriKind.Absolute, out Uri? uri)) + { + return new ValidationResult($"'{stringValue}' is not a valid absolute URI."); + } + + if (uri.Scheme != Uri.UriSchemeHttp && uri.Scheme != Uri.UriSchemeHttps) + { + return new ValidationResult($"URI must use HTTP or HTTPS scheme. Found: {uri.Scheme}"); + } + + return ValidationResult.Success; + } +} diff --git a/src/ModelContextProtocol.AspNetCore.Distributed/Abstractions/SessionAffinityOptionsValidator.cs b/src/ModelContextProtocol.AspNetCore.Distributed/Abstractions/SessionAffinityOptionsValidator.cs new file mode 100644 index 000000000..fe46a205a --- /dev/null +++ b/src/ModelContextProtocol.AspNetCore.Distributed/Abstractions/SessionAffinityOptionsValidator.cs @@ -0,0 +1,16 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using Microsoft.Extensions.Options; + +namespace ModelContextProtocol.AspNetCore.Distributed.Abstractions; + +/// +/// Validator for that ensures configuration is valid. +/// Uses compile-time code generation for AOT compatibility. +/// The source generator will automatically validate data annotations on the options class. +/// +[OptionsValidator] +internal sealed partial class SessionAffinityOptionsValidator + : IValidateOptions +{ } diff --git a/src/ModelContextProtocol.AspNetCore.Distributed/Abstractions/SessionOwnerInfo.cs b/src/ModelContextProtocol.AspNetCore.Distributed/Abstractions/SessionOwnerInfo.cs new file mode 100644 index 000000000..e32abf2d1 --- /dev/null +++ b/src/ModelContextProtocol.AspNetCore.Distributed/Abstractions/SessionOwnerInfo.cs @@ -0,0 +1,19 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace ModelContextProtocol.AspNetCore.Distributed.Abstractions; + +/// +/// Identifies which server currently owns a session. +/// +public sealed record SessionOwnerInfo +{ + /// Unique identifier for the owner (server id, instance id, etc.). + public required string OwnerId { get; init; } + + /// Address (host[:port]) requests should be forwarded to. + public required string Address { get; init; } + + /// Timestamp showing when the owner claimed this session. + public DateTimeOffset? ClaimedAt { get; init; } +} diff --git a/src/ModelContextProtocol.AspNetCore.Distributed/HybridCacheSessionStore.cs b/src/ModelContextProtocol.AspNetCore.Distributed/HybridCacheSessionStore.cs new file mode 100644 index 000000000..d838699d3 --- /dev/null +++ b/src/ModelContextProtocol.AspNetCore.Distributed/HybridCacheSessionStore.cs @@ -0,0 +1,125 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using Microsoft.Extensions.Caching.Hybrid; +using Microsoft.Extensions.Logging; +using ModelContextProtocol.AspNetCore.Distributed.Abstractions; + +namespace ModelContextProtocol.AspNetCore.Distributed; + +/// +/// HybridCache-backed implementation of . +/// This implementation provides distributed session ownership across multiple servers +/// using HybridCache, which combines in-memory and distributed caching for optimal performance. +/// Sessions are stored with a configurable expiration time (default: 15 minutes). +/// +/// +/// HybridCache provides several advantages over IDistributedCache: +/// - Automatic serialization/deserialization +/// - Built-in stampede protection +/// - L1 (in-memory) + L2 (distributed) caching for better performance +/// - Tag-based cache invalidation support +/// +internal sealed class HybridCacheSessionStore : ISessionStore +{ + private static readonly TimeSpan DefaultSessionTimeout = TimeSpan.FromMinutes(15); + private readonly HybridCache _cache; + private readonly ILogger _logger; + private readonly HybridCacheEntryOptions _cacheEntryOptions; + + public HybridCacheSessionStore( + HybridCache cache, + ILogger logger, + TimeSpan? sessionTimeout = null + ) + { + _cache = cache; + _logger = logger; + var resolvedSessionTimeout = sessionTimeout ?? DefaultSessionTimeout; + _cacheEntryOptions = new() + { + Expiration = resolvedSessionTimeout, + // Allow L1 cache to expire sooner for better memory management + LocalCacheExpiration = TimeSpan.FromMinutes( + Math.Min(resolvedSessionTimeout.TotalMinutes / 2, 5) + ), + }; + } + + public async Task GetOrClaimOwnershipAsync( + string sessionId, + Func> ownerInfoFactory, + CancellationToken cancellationToken = default + ) + { + ArgumentNullException.ThrowIfNull(sessionId); + ArgumentNullException.ThrowIfNull(ownerInfoFactory); + + var key = $"mcp:session:{sessionId}"; + + try + { + // Track whether we created a new entry or retrieved an existing one + var wasCreated = false; + + // HybridCache.GetOrCreateAsync will check L1 (memory) first, then L2 (distributed) + // If not found, it will call the factory to create and cache the value + var owner = await _cache.GetOrCreateAsync( + key, + async cancel => + { + wasCreated = true; + + // Call the provided factory to create the owner info + var ownerInfo = await ownerInfoFactory(cancel); + + _logger.SessionClaimed(sessionId, ownerInfo.OwnerId); + + return ownerInfo; + }, + options: _cacheEntryOptions, + cancellationToken: cancellationToken + ); + + // HybridCache uses absolute expiration. We need to implement sliding expiration manually + // by re-setting the value with a new expiration time on each access. + // Only refresh if we retrieved an existing entry (not if we just created it). + if (!wasCreated) + { + await _cache.SetAsync( + key, + owner, + _cacheEntryOptions, + cancellationToken: cancellationToken + ); + } + + _logger.SessionOwnerRetrieved(sessionId, owner.OwnerId); + + return owner; + } + catch (Exception ex) when (ex is not OperationCanceledException) + { + _logger.FailedToRetrieveSessionOwner(sessionId, ex); + throw; + } + } + + public async Task RemoveAsync(string sessionId, CancellationToken cancellationToken = default) + { + ArgumentNullException.ThrowIfNull(sessionId); + + var key = $"mcp:session:{sessionId}"; + + try + { + await _cache.RemoveAsync(key, cancellationToken); + } + catch (Exception ex) when (ex is not OperationCanceledException) + { + _logger.FailedToRemoveSession(sessionId, ex); + // Don't rethrow - session removal is a best-effort cleanup operation + // The session will expire naturally if removal fails + } + } +} diff --git a/src/ModelContextProtocol.AspNetCore.Distributed/ListeningEndpointResolver.cs b/src/ModelContextProtocol.AspNetCore.Distributed/ListeningEndpointResolver.cs new file mode 100644 index 000000000..5d7d5eac3 --- /dev/null +++ b/src/ModelContextProtocol.AspNetCore.Distributed/ListeningEndpointResolver.cs @@ -0,0 +1,145 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Net; +using Microsoft.AspNetCore.Hosting.Server; +using Microsoft.AspNetCore.Hosting.Server.Features; +using ModelContextProtocol.AspNetCore.Distributed.Abstractions; + +namespace ModelContextProtocol.AspNetCore.Distributed; + +/// +/// Default implementation of that resolves +/// the local server listening endpoint from explicit configuration or server bindings. +/// +internal sealed class ListeningEndpointResolver : IListeningEndpointResolver +{ + /// + public string ResolveListeningEndpoint(IServer server, SessionAffinityOptions options) + { + ArgumentNullException.ThrowIfNull(server); + ArgumentNullException.ThrowIfNull(options); + + // Use explicit configuration if provided + if (!string.IsNullOrWhiteSpace(options.LocalServerAddress)) + { + return ValidateAndNormalizeAddress(options.LocalServerAddress); + } + + // Resolve from server bindings + return ResolveFromServerBindings(server); + } + + private static string ValidateAndNormalizeAddress(string address) + { + if (!Uri.TryCreate(address, UriKind.Absolute, out var uri)) + { + throw new ArgumentException( + $"LocalServerAddress '{address}' is not a valid absolute URI. " + + "It must include the scheme (http or https), host, and port (e.g., 'http://localhost:5000').", + nameof(address) + ); + } + + if (uri.Scheme != Uri.UriSchemeHttp && uri.Scheme != Uri.UriSchemeHttps) + { + throw new ArgumentException( + $"LocalServerAddress '{address}' must use either 'http' or 'https' scheme. " + + $"Got '{uri.Scheme}' instead.", + nameof(address) + ); + } + + // Normalize the address to include scheme, host, and port + // Remove any path, query, or fragment components as they're not needed for forwarding + var normalizedAddress = $"{uri.Scheme}://{uri.Host}:{uri.Port}"; + + return normalizedAddress; + } + + private static string ResolveFromServerBindings(IServer server) + { + var addressesFeature = server.Features.Get(); + if (addressesFeature is null || addressesFeature.Addresses.Count == 0) + { + // Fallback to http://localhost:80 if no addresses are available + return "http://localhost:80"; + } + + Uri? httpUri = null; + Uri? httpsUri = null; + Uri? localhostHttpUri = null; + Uri? localhostHttpsUri = null; + + foreach (var address in addressesFeature.Addresses) + { + if (Uri.TryCreate(address, UriKind.Absolute, out var uri)) + { + bool isLocalhost = IsLocalhostAddress(uri.Host); + + if (uri.Scheme == "http") + { + if (isLocalhost) + { + localhostHttpUri ??= uri; + } + else + { + httpUri ??= uri; + } + } + else if (uri.Scheme == "https") + { + if (isLocalhost) + { + localhostHttpsUri ??= uri; + } + else + { + httpsUri ??= uri; + } + } + } + } + + // Prefer external interfaces over localhost for reachability from other servers + // Prefer HTTP for internal routing in service mesh scenarios + // In service meshes, internal traffic is typically HTTP while external is HTTPS + var selectedUri = httpUri ?? httpsUri ?? localhostHttpUri ?? localhostHttpsUri; + if (selectedUri is null) + { + // Fallback if no valid URI found + return "http://localhost:80"; + } + + // Build address string in format "scheme://host:port" + var host = selectedUri.Host; + var port = selectedUri.Port; + var scheme = selectedUri.Scheme; + + return $"{scheme}://{host}:{port}"; + } + + private static bool IsLocalhostAddress(string host) + { + // Check for common localhost representations + if ( + string.Equals(host, "localhost", StringComparison.OrdinalIgnoreCase) + || host.EndsWith(".localhost", StringComparison.OrdinalIgnoreCase) + || string.Equals(host, "127.0.0.1", StringComparison.Ordinal) + || string.Equals(host, "::1", StringComparison.Ordinal) + || string.Equals(host, "[::1]", StringComparison.Ordinal) + ) + { + return true; + } + + // Try to parse as IP address and check if it's loopback + if (IPAddress.TryParse(host, out var ipAddress)) + { + return IPAddress.IsLoopback(ipAddress); + } + + return false; + } +} diff --git a/src/ModelContextProtocol.AspNetCore.Distributed/MapSessionAffinityExtensions.cs b/src/ModelContextProtocol.AspNetCore.Distributed/MapSessionAffinityExtensions.cs new file mode 100644 index 000000000..ab1b77453 --- /dev/null +++ b/src/ModelContextProtocol.AspNetCore.Distributed/MapSessionAffinityExtensions.cs @@ -0,0 +1,38 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using Microsoft.AspNetCore.Builder; +using Microsoft.AspNetCore.Http; +using Microsoft.Extensions.DependencyInjection; + +namespace ModelContextProtocol.AspNetCore.Distributed; + +/// +/// Extension methods for adding session affinity to MCP endpoints. +/// +public static class MapSessionAffinityExtensions +{ + /// + /// Adds session affinity to MCP endpoints. + /// This endpoint filter routes requests to the correct host based on session ownership. + /// Use this on the return value of MapMcp() to add session affinity routing. + /// Requires calling AddMcpHttpSessionAffinity() on the builder first. + /// + /// The endpoint convention builder from MapMcp(). + /// Returns the builder for chaining additional configurations. + public static IEndpointConventionBuilder WithSessionAffinity( + this IEndpointConventionBuilder builder + ) + { + ArgumentNullException.ThrowIfNull(builder); + builder.AddEndpointFilterFactory( + (routeHandlerContext, next) => + { + var filter = + routeHandlerContext.ApplicationServices.GetRequiredService(); + return (context) => filter.InvokeAsync(context, next); + } + ); + return builder; + } +} diff --git a/src/ModelContextProtocol.AspNetCore.Distributed/ModelContextProtocol.AspNetCore.Distributed.csproj b/src/ModelContextProtocol.AspNetCore.Distributed/ModelContextProtocol.AspNetCore.Distributed.csproj new file mode 100644 index 000000000..989fa32c6 --- /dev/null +++ b/src/ModelContextProtocol.AspNetCore.Distributed/ModelContextProtocol.AspNetCore.Distributed.csproj @@ -0,0 +1,30 @@ + + + + net10.0 + true + true + ModelContextProtocol.AspNetCore.Distributed + ASP.NET Core extensions for building enterprise-grade MCP servers that do not require external session affinity, yet work with session-aware features. + README.md + true + + + + + + + + + + + + + + + + + + + + diff --git a/src/ModelContextProtocol.AspNetCore.Distributed/README.md b/src/ModelContextProtocol.AspNetCore.Distributed/README.md new file mode 100644 index 000000000..ebc30850c --- /dev/null +++ b/src/ModelContextProtocol.AspNetCore.Distributed/README.md @@ -0,0 +1,116 @@ +# ModelContextProtocol.AspNetCore.Distributed + +Session-aware routing for Model Context Protocol (MCP) servers that need to run across multiple instances. This package builds on ASP.NET Core HybridCache and YARP so every MCP request reaches the server that owns the session state. + +## Why Use It + +- Keep in-memory session data (prompt history, tool context) with its owning instance +- Scale stateful MCP servers horizontally without changing tool handlers +- Forward requests automatically when the owning instance lives elsewhere +- Plug in any `IDistributedCache` (Redis, SQL Server, NCache, etc.) for distributed storage + +## Install + +```bash +dotnet add package ModelContextProtocol.AspNetCore.Distributed --prerelease +``` + +Add the distributed cache provider that matches your environment (for example `Microsoft.Extensions.Caching.StackExchangeRedis`). + +## Quick Start (Single Instance / Local Dev) + +```csharp +using ModelContextProtocol.AspNetCore.Distributed; + +var builder = WebApplication.CreateBuilder(args); + +builder.Services + .AddMcpServer() + .WithToolsFromAssembly() + .WithHttpTransport(); + +builder.Services.AddMcpHttpSessionAffinity(); // Tracks ownership + routing + +var app = builder.Build(); + +app.MapMcp() + .WithSessionAffinity(); // Add this to enable session affinity routing + +app.Run(); +``` + +No distributed cache is required until you add additional instances. + +## Production Checklist + +1. Register an L2 cache (Redis + Azure AD auth is the most battle-tested option). +2. Set `LocalServerAddress` to the routable address other replicas use (scheme, host, port). +3. Tune `ForwarderRequestConfig` and `HttpClientConfig` for your downstream SLAs. +4. Use `DefaultAzureCredential` locally and deployment-specific credentials in production. +5. Monitor HybridCache hit rate and distributed cache availability for early warning. + +### Minimal Redis Configuration + +```csharp +using Azure.Identity; +using Microsoft.Extensions.Caching.StackExchangeRedis; +using StackExchange.Redis; + +var redisCredential = builder.Environment.IsDevelopment() + ? new DefaultAzureCredential() + : new ManagedIdentityCredential(); + +var endpoint = builder.Configuration["Redis:Endpoint"] + ?? throw new InvalidOperationException("Redis:Endpoint is required."); + +var redisConfig = await ConfigurationOptions + .Parse(endpoint) + .ConfigureForAzureWithTokenCredentialAsync(redisCredential); + +redisConfig.Ssl = true; // Always require TLS in production + +builder.Services.AddStackExchangeRedisCache(options => +{ + options.ConfigurationOptions = redisConfig; + options.InstanceName = "MCP:"; +}); + +builder.Services.AddMcpHttpSessionAffinity(options => +{ + options.LocalServerAddress = builder.Configuration["Server:InternalAddress"] + ?? throw new InvalidOperationException("Server:InternalAddress is required."); +}); +``` + +`appsettings.json` + +```json +{ + "Redis": { + "Endpoint": "your-mcp-session-affinity.region.redis.azure.net:6380" + }, + "Server": { + "InternalAddress": "http://pod-1.mcp.default.svc.cluster.local:8080" + } +} +``` + +## Core Concepts + +- Session ownership: the first request with `Mcp-Session-Id` (header) or `sessionId` (query) claims the session and stores ownership in HybridCache. +- HybridCache tiers: L1 memory cache plus optional L2 distributed cache; tune expiration to control how long ownership survives inactivity. +- Forwarding: if the current node is not the owner, YARP forwards the request to the owning instance over HTTP(S). +- Stale detection: when an owning instance restarts, the affinity entry is discarded so clients can establish a fresh session and rebuild state. + +## Configuration Reference + +- `SessionAffinityOptions.LocalServerAddress`: required in multi-instance environments; must be a routable absolute URI. +- `ForwarderRequestConfig`: controls forwarding timeout, buffering, and HTTP version. +- `HttpClientConfig`: tune connection pooling for heavy cross-node routing. +- `HybridCacheOptions`: set `DefaultEntryOptions.Expiration` (L2) and `LocalCacheExpiration` (L1) to balance freshness versus resilience. + +## Observability + +- Enable `ModelContextProtocol.AspNetCore.Distributed` logs at `Information` by default and `Debug` for routing traces. +- Watch for `ResolvingSessionOwner`, `SessionEstablished`, and `ForwardingRequest` events to understand ownership decisions. +- Export HybridCache hit/miss metrics to confirm cache sizing and detect unusual churn. diff --git a/src/ModelContextProtocol.AspNetCore.Distributed/SemanticLogging.cs b/src/ModelContextProtocol.AspNetCore.Distributed/SemanticLogging.cs new file mode 100644 index 000000000..412222282 --- /dev/null +++ b/src/ModelContextProtocol.AspNetCore.Distributed/SemanticLogging.cs @@ -0,0 +1,262 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using Microsoft.Extensions.Logging; + +namespace ModelContextProtocol.AspNetCore.Distributed; + +/// +/// Defines semantic logging event IDs for the session state management module. +/// These IDs enable filtering, categorization, and structured logging across all session-related operations. +/// Uses a high base number (50000) to avoid conflicts with other libraries. +/// +internal enum LogEventId +{ + /// + /// Resolving the owner of an existing session. Event ID: 50001 + /// + ResolvingSessionOwner = 50001, + + /// + /// A new session has been claimed by an owner. Event ID: 50002 + /// + SessionClaimed = 50002, + + /// + /// Session is already owned by another server instance. Event ID: 50003 + /// + SessionOwnedByOther = 50003, + + /// + /// Failed to deserialize session owner information from cache. Event ID: 50004 + /// + FailedToDeserializeSessionOwner = 50004, + + /// + /// Session has been established on the current host. Event ID: 50100 + /// + SessionEstablished = 50100, + + /// + /// Forwarding a request to another server that owns the session. Event ID: 50101 + /// + ForwardingRequest = 50101, + + /// + /// Session owner information retrieved from cache. Event ID: 50102 + /// + SessionOwnerRetrieved = 50102, + + /// + /// Failed to retrieve session owner information from cache. Event ID: 50103 + /// + FailedToRetrieveSessionOwner = 50103, + + /// + /// Removing stale session after receiving 404 from remote endpoint. Event ID: 50104 + /// + RemovingStaleSession = 50104, + + /// + /// Failed to remove session from cache. Event ID: 50105 + /// + FailedToRemoveSession = 50105, + + /// + /// Removing a stale session that points to the local address but has an outdated OwnerId. Event ID: 50106 + /// + RemovingStaleLocalSession = 50106, +} + +/// +/// Semantic logging methods for session state management operations. +/// Uses structured logging with compile-time code generation via LoggerMessage attributes. +/// +internal static partial class SemanticLogging +{ + /// + /// Logs when resolving the owner of an existing session. + /// + /// The logger instance. + /// The session identifier. + /// The proposed owner identifier. + [LoggerMessage( + EventId = (int)LogEventId.ResolvingSessionOwner, + Level = LogLevel.Debug, + Message = "Resolving session owner for session {SessionId}, proposed owner: {OwnerId}" + )] + public static partial void ResolvingSessionOwner( + this ILogger logger, + string sessionId, + string ownerId + ); + + /// + /// Logs when a new session is claimed by the current server instance. + /// + /// The logger instance. + /// The session identifier. + /// The owner identifier of this server instance. + [LoggerMessage( + EventId = (int)LogEventId.SessionClaimed, + Level = LogLevel.Debug, + Message = "Session {SessionId} claimed by owner {OwnerId}" + )] + public static partial void SessionClaimed( + this ILogger logger, + string sessionId, + string ownerId + ); + + /// + /// Logs when an existing session is found to be owned by another server instance. + /// + /// The logger instance. + /// The session identifier. + /// The owner identifier of the server instance that owns the session. + [LoggerMessage( + EventId = (int)LogEventId.SessionOwnedByOther, + Level = LogLevel.Debug, + Message = "Session {SessionId} already owned by {OwnerId}" + )] + public static partial void SessionOwnedByOther( + this ILogger logger, + string sessionId, + string ownerId + ); + + /// + /// Logs when deserialization of session owner information fails. + /// This can occur if the cached data is corrupted or in an unexpected format. + /// + /// The logger instance. + /// The session identifier. + /// The exception that occurred during deserialization. + [LoggerMessage( + EventId = (int)LogEventId.FailedToDeserializeSessionOwner, + Level = LogLevel.Warning, + Message = "Failed to deserialize session owner for session {SessionId}" + )] + public static partial void FailedToDeserializeSessionOwner( + this ILogger logger, + string sessionId, + Exception ex + ); + + /// + /// Logs when a session is established on the current server instance. + /// + /// The logger instance. + /// The session identifier. + [LoggerMessage( + EventId = (int)LogEventId.SessionEstablished, + Level = LogLevel.Information, + Message = "Session established for session {SessionId} on this host" + )] + public static partial void SessionEstablished(this ILogger logger, string sessionId); + + /// + /// Logs when a request is being forwarded to another server instance that owns the session. + /// + /// The logger instance. + /// The destination server prefix (scheme://host:port). + /// The session identifier. + [LoggerMessage( + EventId = (int)LogEventId.ForwardingRequest, + Level = LogLevel.Information, + Message = "Forwarding request to {DestinationPrefix} for session {SessionId}" + )] + public static partial void ForwardingRequest( + this ILogger logger, + string destinationPrefix, + string sessionId + ); + + /// + /// Logs when session owner information is retrieved from the session store. + /// + /// The logger instance. + /// The session identifier. + /// The owner identifier of the retrieved session. + [LoggerMessage( + EventId = (int)LogEventId.SessionOwnerRetrieved, + Level = LogLevel.Debug, + Message = "Retrieved session owner {OwnerId} for session {SessionId}" + )] + public static partial void SessionOwnerRetrieved( + this ILogger logger, + string sessionId, + string ownerId + ); + + /// + /// Logs when retrieval of session owner information fails. + /// This can occur if there are cache connectivity issues or other errors. + /// + /// The logger instance. + /// The session identifier. + /// The exception that occurred during retrieval. + [LoggerMessage( + EventId = (int)LogEventId.FailedToRetrieveSessionOwner, + Level = LogLevel.Warning, + Message = "Failed to retrieve session owner for session {SessionId}" + )] + public static partial void FailedToRetrieveSessionOwner( + this ILogger logger, + string sessionId, + Exception ex + ); + + /// + /// Logs when removing a stale session after receiving a 404 from the remote endpoint. + /// This indicates the remote server no longer has the session, likely due to a process restart. + /// + /// The logger instance. + /// The session identifier. + /// The owner identifier of the server that returned 404. + [LoggerMessage( + EventId = (int)LogEventId.RemovingStaleSession, + Level = LogLevel.Warning, + Message = "Removing stale session {SessionId} after 404 response from owner {OwnerId}" + )] + public static partial void RemovingStaleSession( + this ILogger logger, + string sessionId, + string ownerId + ); + + /// + /// Logs when removal of a session from the cache fails. + /// + /// The logger instance. + /// The session identifier. + /// The exception that occurred during removal. + [LoggerMessage( + EventId = (int)LogEventId.FailedToRemoveSession, + Level = LogLevel.Warning, + Message = "Failed to remove session {SessionId} from cache" + )] + public static partial void FailedToRemoveSession( + this ILogger logger, + string sessionId, + Exception ex + ); + + /// + /// Logs when removing a stale session record that references this host with an outdated OwnerId. + /// This occurs when the application restarts and generates a new OwnerId, making the previous session unusable. + /// + /// The logger instance. + /// The session identifier. + /// The stale owner identifier from the cache. + [LoggerMessage( + EventId = (int)LogEventId.RemovingStaleLocalSession, + Level = LogLevel.Warning, + Message = "Removing stale session {SessionId} owned by previous instance {OldOwnerId}" + )] + public static partial void RemovingStaleLocalSession( + this ILogger logger, + string sessionId, + string oldOwnerId + ); +} diff --git a/src/ModelContextProtocol.AspNetCore.Distributed/SerializerContext.cs b/src/ModelContextProtocol.AspNetCore.Distributed/SerializerContext.cs new file mode 100644 index 000000000..f93947ed9 --- /dev/null +++ b/src/ModelContextProtocol.AspNetCore.Distributed/SerializerContext.cs @@ -0,0 +1,17 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Text.Json.Serialization; +using ModelContextProtocol.AspNetCore.Distributed.Abstractions; + +namespace ModelContextProtocol.AspNetCore.Distributed; + +/// +/// JSON serialization context for distributed session store. +/// +[JsonSourceGenerationOptions( + PropertyNamingPolicy = JsonKnownNamingPolicy.CamelCase, + DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull +)] +[JsonSerializable(typeof(SessionOwnerInfo))] +internal sealed partial class SerializerContext : JsonSerializerContext { } diff --git a/src/ModelContextProtocol.AspNetCore.Distributed/ServiceCollectionExtensions.cs b/src/ModelContextProtocol.AspNetCore.Distributed/ServiceCollectionExtensions.cs new file mode 100644 index 000000000..459977926 --- /dev/null +++ b/src/ModelContextProtocol.AspNetCore.Distributed/ServiceCollectionExtensions.cs @@ -0,0 +1,77 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using Microsoft.Extensions.Caching.Hybrid; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.DependencyInjection.Extensions; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Options; +using ModelContextProtocol.AspNetCore.Distributed.Abstractions; + +namespace ModelContextProtocol.AspNetCore.Distributed; + +/// +/// Extension methods for configuring MCP session affinity services. +/// +public static class ServiceCollectionExtensions +{ + /// + /// Adds the required services for MCP session affinity. + /// This includes YARP reverse proxy and the session affinity routing filter. + /// Uses HybridCache for session storage (L1 memory + L2 distributed caching). + /// + /// The host service collection. + /// Optional action to configure SessionAffinityOptions. + /// A builder for configuring MCP session affinity. + public static ISessionAffinityBuilder AddMcpHttpSessionAffinity( + this IServiceCollection services, + Action? configure = null + ) + { + ArgumentNullException.ThrowIfNull(services); + + // Configure options using the options pattern + if (configure is not null) + { + services.Configure(configure); + } + + // Add validation for SessionAffinityOptions using source-generated validator + services.TryAddSingleton< + IValidateOptions, + SessionAffinityOptionsValidator + >(); + + // Register HybridCache with default configuration + // This provides L1 (in-memory) + L2 (distributed) caching + // Consumers can add their own distributed cache (Redis, SQL Server, etc.) + // via AddStackExchangeRedisCache, AddDistributedSqlServerCache, etc. + // Use source-generated serialization for SessionOwnerInfo (AOT-compatible) + services + .AddHybridCache() + .AddSerializer(); + + // Register HybridCache session store + services.TryAdd( + ServiceDescriptor.Singleton(sp => + { + var options = sp.GetRequiredService>().Value; + var cache = options.HybridCacheServiceKey is null + ? sp.GetRequiredService() + : sp.GetRequiredKeyedService(options.HybridCacheServiceKey); + var logger = sp.GetRequiredService>(); + return new HybridCacheSessionStore(cache, logger); + }) + ); + + services.TryAddSingleton(); + + // Add YARP reverse proxy for request forwarding + services.AddReverseProxy(); + + // Register the endpoint filter for dependency injection + services.TryAddSingleton(); + + return new SessionAffinityBuilder(services); + } +} diff --git a/src/ModelContextProtocol.AspNetCore.Distributed/SessionAffinityBuilder.cs b/src/ModelContextProtocol.AspNetCore.Distributed/SessionAffinityBuilder.cs new file mode 100644 index 000000000..38673d397 --- /dev/null +++ b/src/ModelContextProtocol.AspNetCore.Distributed/SessionAffinityBuilder.cs @@ -0,0 +1,12 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using Microsoft.Extensions.DependencyInjection; +using ModelContextProtocol.AspNetCore.Distributed.Abstractions; + +namespace ModelContextProtocol.AspNetCore.Distributed; + +internal sealed class SessionAffinityBuilder(IServiceCollection services) : ISessionAffinityBuilder +{ + public IServiceCollection Services { get; } = services; +} diff --git a/src/ModelContextProtocol.AspNetCore.Distributed/SessionAffinityEndpointFilter.cs b/src/ModelContextProtocol.AspNetCore.Distributed/SessionAffinityEndpointFilter.cs new file mode 100644 index 000000000..74dc6ad75 --- /dev/null +++ b/src/ModelContextProtocol.AspNetCore.Distributed/SessionAffinityEndpointFilter.cs @@ -0,0 +1,228 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using Microsoft.AspNetCore.Hosting.Server; +using Microsoft.AspNetCore.Http; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Options; +using ModelContextProtocol.AspNetCore.Distributed.Abstractions; +using Yarp.ReverseProxy.Configuration; +using Yarp.ReverseProxy.Forwarder; + +namespace ModelContextProtocol.AspNetCore.Distributed; + +/// +/// Endpoint filter that implements session affinity for MCP requests. +/// Routes requests to the server that owns the session, or handles locally if this is the owner. +/// +internal sealed class SessionAffinityEndpointFilter : IEndpointFilter +{ + private const string McpSessionIdHeaderName = "Mcp-Session-Id"; + + private readonly ISessionStore _sessionStore; + private readonly string _localOwnerId; + private readonly IHttpForwarder _forwarder; + private readonly HttpMessageInvoker _httpClient; + private readonly ForwarderRequestConfig _forwarderRequestConfig; + private readonly ILogger _logger; + private readonly string _localAddress; + + public SessionAffinityEndpointFilter( + ISessionStore sessionStore, + IHttpForwarder forwarder, + IForwarderHttpClientFactory httpClientFactory, + IListeningEndpointResolver listeningEndpointResolver, + IServer server, + IOptions options, + ILogger logger + ) + { + ArgumentNullException.ThrowIfNull(options); + var optionsValue = options.Value; + + _sessionStore = sessionStore; + // IMPORTANT: The OwnerId (_localOwnerId) is regenerated as a new GUID each time the application restarts. + // Session ownership data does not persist across restarts, so stale session entries are cleared when encountered. + _localOwnerId = Guid.NewGuid().ToString(); + _forwarder = forwarder; + _httpClient = httpClientFactory.CreateClient( + new ForwarderHttpClientContext + { + NewConfig = optionsValue.HttpClientConfig ?? HttpClientConfig.Empty, + } + ); + _forwarderRequestConfig = + optionsValue.ForwarderRequestConfig ?? ForwarderRequestConfig.Empty; + _logger = logger; + + // Use the listening endpoint resolver to get the advertised address + // IServerAddressesFeature is populated before endpoint filters are created + // Note: LocalServerAddress can be set via IPostConfigureOptions for dynamic resolution + _localAddress = listeningEndpointResolver.ResolveListeningEndpoint(server, optionsValue); + } + + public async ValueTask InvokeAsync( + EndpointFilterInvocationContext context, + EndpointFilterDelegate next + ) + { + var httpContext = context.HttpContext; + var sessionId = ExtractSessionId(httpContext); + + // Resolve the owner of this session (if session ID exists) + SessionOwnerInfo? ownerInfo = null; + if (!string.IsNullOrEmpty(sessionId)) + { + _logger.ResolvingSessionOwner(sessionId, _localOwnerId); + + ownerInfo = await _sessionStore.GetOrClaimOwnershipAsync( + sessionId, + CreateSessionOwnerInfo, + httpContext.RequestAborted + ); + + if (ownerInfo.OwnerId != _localOwnerId) + { + if ( + string.Equals( + ownerInfo.Address, + _localAddress, + StringComparison.OrdinalIgnoreCase + ) + ) + { + // Application restart detected - the session points to this host but has a different OwnerId + _logger.RemovingStaleLocalSession(sessionId, ownerInfo.OwnerId); + await _sessionStore.RemoveAsync(sessionId, httpContext.RequestAborted); + ownerInfo = null; + sessionId = null; + } + else + { + _logger.SessionOwnedByOther(sessionId, ownerInfo.OwnerId); + } + } + } + + // Handle locally if no session ID or this host owns the session + if (ownerInfo is null || ownerInfo.OwnerId == _localOwnerId) + { + context.HttpContext.Response.OnStarting(async () => + { + // Check if the server set a session ID different from the original session ID + var responseSessionId = ExtractResponseSessionId(httpContext); + if ( + !string.IsNullOrEmpty(responseSessionId) + && !string.Equals(sessionId, responseSessionId, StringComparison.Ordinal) + ) + { + // Update the new session to point to this host + await _sessionStore.GetOrClaimOwnershipAsync( + responseSessionId, + CreateSessionOwnerInfo, + httpContext.RequestAborted + ); + _logger.SessionEstablished(responseSessionId); + } + }); + + return await next(context); + } + + // Forward to the owner - this writes directly to the response + _logger.ForwardingRequest(ownerInfo.Address, sessionId ?? "(none)"); + var error = await _forwarder.SendAsync( + httpContext, + ownerInfo.Address, + _httpClient, + _forwarderRequestConfig + ); + + if (error == ForwarderError.None) + { + // Check if the remote server returned 404 - indicates session no longer exists + // Only remove session if this is an MCP endpoint request (not a health check, metrics, etc.) + if ( + httpContext.Response.StatusCode == StatusCodes.Status404NotFound + && !string.IsNullOrEmpty(sessionId) + && IsMcpEndpointRequest(httpContext) + ) + { + _logger.RemovingStaleSession(sessionId, ownerInfo.OwnerId); + await _sessionStore.RemoveAsync(sessionId, httpContext.RequestAborted); + } + + // The forwarder has already written the response, return null to indicate completion + return null; + } + + return Results.StatusCode(StatusCodes.Status502BadGateway); + } + + private static bool IsMcpEndpointRequest(HttpContext context) + { + // The session affinity filter is only applied to MCP endpoints + // Check if the endpoint has MCP-related metadata or path patterns + var endpoint = context.GetEndpoint(); + if (endpoint is null) + { + return false; + } + + // Check for MCP-specific endpoint metadata + // The endpoint display name typically contains route pattern information + var displayName = endpoint.DisplayName; + if ( + !string.IsNullOrEmpty(displayName) + && ( + displayName.Contains("mcp", StringComparison.OrdinalIgnoreCase) + || displayName.Contains("sse", StringComparison.OrdinalIgnoreCase) + ) + ) + { + return true; + } + + return false; + } + + private static string? ExtractSessionId(HttpContext context) + { + // Try header first (for Streamable HTTP POST/GET/DELETE) + if (context.Request.Headers.TryGetValue(McpSessionIdHeaderName, out var header)) + { + return header.ToString(); + } + + // Try query string (for legacy SSE /message endpoint) + if (context.Request.Query.TryGetValue("sessionId", out var sessionId)) + { + return sessionId.ToString(); + } + + return null; + } + + private static string? ExtractResponseSessionId(HttpContext context) + { + // Check response headers for the session ID + if (context.Response.Headers.TryGetValue(McpSessionIdHeaderName, out var header)) + { + return header.ToString(); + } + + return null; + } + + private Task CreateSessionOwnerInfo(CancellationToken cancellationToken) + { + return Task.FromResult( + new SessionOwnerInfo + { + OwnerId = _localOwnerId, + Address = _localAddress, + ClaimedAt = DateTimeOffset.UtcNow, + } + ); + } +} diff --git a/src/ModelContextProtocol.AspNetCore.Distributed/SessionOwnerInfoSerializer.cs b/src/ModelContextProtocol.AspNetCore.Distributed/SessionOwnerInfoSerializer.cs new file mode 100644 index 000000000..a97b1e60f --- /dev/null +++ b/src/ModelContextProtocol.AspNetCore.Distributed/SessionOwnerInfoSerializer.cs @@ -0,0 +1,36 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Buffers; +using System.Text.Json; +using Microsoft.Extensions.Caching.Hybrid; +using ModelContextProtocol.AspNetCore.Distributed.Abstractions; + +namespace ModelContextProtocol.AspNetCore.Distributed; + +/// +/// Source-generated JSON serializer for . +/// Uses the generated for AOT-compatible, +/// high-performance serialization without reflection. +/// +internal sealed class SessionOwnerInfoSerializer : IHybridCacheSerializer +{ + /// + /// Deserializes a from a buffer using source-generated JSON. + /// + public SessionOwnerInfo Deserialize(ReadOnlySequence source) + { + var reader = new Utf8JsonReader(source); + return JsonSerializer.Deserialize(ref reader, SerializerContext.Default.SessionOwnerInfo) + ?? throw new InvalidOperationException("Failed to deserialize SessionOwnerInfo"); + } + + /// + /// Serializes a to a buffer using source-generated JSON. + /// + public void Serialize(SessionOwnerInfo value, IBufferWriter target) + { + using Utf8JsonWriter writer = new(target); + JsonSerializer.Serialize(writer, value, SerializerContext.Default.SessionOwnerInfo); + } +} diff --git a/tests/ModelContextProtocol.AspNetCore.Distributed.Tests/KeyedServiceTests.cs b/tests/ModelContextProtocol.AspNetCore.Distributed.Tests/KeyedServiceTests.cs new file mode 100644 index 000000000..27aa560ae --- /dev/null +++ b/tests/ModelContextProtocol.AspNetCore.Distributed.Tests/KeyedServiceTests.cs @@ -0,0 +1,124 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Options; +using ModelContextProtocol.AspNetCore.Distributed.Abstractions; +using Xunit; + +namespace ModelContextProtocol.AspNetCore.Distributed.Tests; + +public class KeyedServiceTests +{ + [Fact] + public void AddMcpHttpSessionAffinity_WithHybridCacheServiceKey_ConfiguresOptionsCorrectly() + { + // Arrange + var services = new ServiceCollection(); + services.AddLogging(); + services.AddHybridCache(); // Default cache + + // Act - Use non-keyed registration but specify key in options + services.AddMcpHttpSessionAffinity(options => + { + options.HybridCacheServiceKey = "my-cache"; + }); + + var provider = services.BuildServiceProvider(); + + // Assert + var options = provider.GetRequiredService>().Value; + Assert.Equal("my-cache", options.HybridCacheServiceKey); + } + + [Fact] + public void AddMcpHttpSessionAffinity_WithoutHybridCacheServiceKey_UsesDefaultCache() + { + // Arrange + var services = new ServiceCollection(); + services.AddLogging(); + services.AddHybridCache(); // Default cache + + // Act + services.AddMcpHttpSessionAffinity(); // Should use default + + var provider = services.BuildServiceProvider(); + + // Assert + var sessionStore = provider.GetService(); + Assert.NotNull(sessionStore); + Assert.IsType(sessionStore); + } + + [Fact] + public async Task SessionStore_StoreAndRetrieve_WorksCorrectly() + { + // Arrange + var services = new ServiceCollection(); + services.AddLogging(); + services.AddHybridCache(); + services.AddMcpHttpSessionAffinity(); + + var provider = services.BuildServiceProvider(); + var sessionStore = provider.GetRequiredService(); + + var sessionId = "test-session-123"; + var ownerInfo = new SessionOwnerInfo + { + OwnerId = "server-1", + Address = "http://server-1:5000", + ClaimedAt = DateTimeOffset.UtcNow, + }; + + // Act - Claim ownership using the factory pattern + var setResult = await sessionStore.GetOrClaimOwnershipAsync( + sessionId, + async ct => + { + await Task.Yield(); // Simulate async work + return ownerInfo; + }, + TestContext.Current.CancellationToken + ); + + var getResult = await sessionStore.GetOrClaimOwnershipAsync( + sessionId, + async ct => + { + await Task.Yield(); + // This factory shouldn't be called since the session already exists + throw new InvalidOperationException( + "Factory should not be called for existing session" + ); + }, + TestContext.Current.CancellationToken + ); + + // Assert + Assert.NotNull(setResult); + Assert.NotNull(getResult); + Assert.Equal(ownerInfo.OwnerId, setResult.OwnerId); + Assert.Equal(ownerInfo.OwnerId, getResult.OwnerId); + Assert.Equal(ownerInfo.Address, getResult.Address); + } + + [Fact] + public void AddMcpHttpSessionAffinity_RegistersReverseProxyServices() + { + // Arrange + var services = new ServiceCollection(); + services.AddLogging(); + services.AddHybridCache(); + + // Act + services.AddMcpHttpSessionAffinity(); + + // Assert - Verify reverse proxy services are registered + var descriptor = services.FirstOrDefault(d => + d.ServiceType.FullName?.Contains("ReverseProxy", StringComparison.Ordinal) == true + ); + + Assert.NotNull(descriptor); + Assert.Equal(ServiceLifetime.Singleton, descriptor!.Lifetime); + } +} diff --git a/tests/ModelContextProtocol.AspNetCore.Distributed.Tests/ListeningEndpointResolverTests.cs b/tests/ModelContextProtocol.AspNetCore.Distributed.Tests/ListeningEndpointResolverTests.cs new file mode 100644 index 000000000..fb030634f --- /dev/null +++ b/tests/ModelContextProtocol.AspNetCore.Distributed.Tests/ListeningEndpointResolverTests.cs @@ -0,0 +1,705 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using Microsoft.AspNetCore.Hosting.Server; +using Microsoft.AspNetCore.Hosting.Server.Features; +using Microsoft.AspNetCore.Http.Features; +using ModelContextProtocol.AspNetCore.Distributed; +using ModelContextProtocol.AspNetCore.Distributed.Abstractions; +using NSubstitute; +using Xunit; + +namespace ModelContextProtocol.AspNetCore.Distributed.Tests; + +public sealed class ListeningEndpointResolverTests +{ + private readonly ListeningEndpointResolver _resolver; + + public ListeningEndpointResolverTests() + { + _resolver = new ListeningEndpointResolver(); + } + + #region Explicit Configuration Tests + + [Fact] + public void ResolveListeningEndpoint_WithValidExplicitAddress_ReturnsNormalizedAddress() + { + // Arrange + var server = CreateServer(); + var options = new SessionAffinityOptions + { + LocalServerAddress = "http://pod-1.service.cluster.local:8080", + }; + + // Act + var result = _resolver.ResolveListeningEndpoint(server, options); + + // Assert + Assert.Equal("http://pod-1.service.cluster.local:8080", result); + } + + [Fact] + public void ResolveListeningEndpoint_WithExplicitHttpsAddress_ReturnsNormalizedAddress() + { + // Arrange + var server = CreateServer(); + var options = new SessionAffinityOptions + { + LocalServerAddress = "https://secure.example.com:443", + }; + + // Act + var result = _resolver.ResolveListeningEndpoint(server, options); + + // Assert + Assert.Equal("https://secure.example.com:443", result); + } + + [Fact] + public void ResolveListeningEndpoint_WithExplicitAddressWithPath_RemovesPath() + { + // Arrange + var server = CreateServer(); + var options = new SessionAffinityOptions + { + LocalServerAddress = "http://example.com:5000/api/mcp", + }; + + // Act + var result = _resolver.ResolveListeningEndpoint(server, options); + + // Assert + Assert.Equal("http://example.com:5000", result); + } + + [Fact] + public void ResolveListeningEndpoint_WithExplicitAddressWithQueryString_RemovesQuery() + { + // Arrange + var server = CreateServer(); + var options = new SessionAffinityOptions + { + LocalServerAddress = "http://example.com:5000?param=value", + }; + + // Act + var result = _resolver.ResolveListeningEndpoint(server, options); + + // Assert + Assert.Equal("http://example.com:5000", result); + } + + [Fact] + public void ResolveListeningEndpoint_WithExplicitAddressWithFragment_RemovesFragment() + { + // Arrange + var server = CreateServer(); + var options = new SessionAffinityOptions + { + LocalServerAddress = "http://example.com:5000#section", + }; + + // Act + var result = _resolver.ResolveListeningEndpoint(server, options); + + // Assert + Assert.Equal("http://example.com:5000", result); + } + + [Fact] + public void ResolveListeningEndpoint_WithInvalidUri_ThrowsArgumentException() + { + // Arrange + var server = CreateServer(); + var options = new SessionAffinityOptions { LocalServerAddress = "not a valid uri" }; + + // Act & Assert + Assert.Throws(() => + _resolver.ResolveListeningEndpoint(server, options) + ); + } + + [Fact] + public void ResolveListeningEndpoint_WithRelativeUri_ThrowsArgumentException() + { + // Arrange + var server = CreateServer(); + var options = new SessionAffinityOptions { LocalServerAddress = "/api/mcp" }; + + // Act & Assert + Assert.Throws(() => + _resolver.ResolveListeningEndpoint(server, options) + ); + } + + [Fact] + public void ResolveListeningEndpoint_WithInvalidScheme_ThrowsArgumentException() + { + // Arrange + var server = CreateServer(); + var options = new SessionAffinityOptions { LocalServerAddress = "ftp://example.com:21" }; + + // Act & Assert + Assert.Throws(() => + _resolver.ResolveListeningEndpoint(server, options) + ); + } + + [Fact] + public void ResolveListeningEndpoint_WithIPv4Address_ReturnsNormalizedAddress() + { + // Arrange + var server = CreateServer(); + var options = new SessionAffinityOptions + { + LocalServerAddress = "http://192.168.1.100:5000", + }; + + // Act + var result = _resolver.ResolveListeningEndpoint(server, options); + + // Assert + Assert.Equal("http://192.168.1.100:5000", result); + } + + [Fact] + public void ResolveListeningEndpoint_WithIPv6Address_ReturnsNormalizedAddress() + { + // Arrange + var server = CreateServer(); + var options = new SessionAffinityOptions + { + LocalServerAddress = "http://[2001:db8::1]:8080", + }; + + // Act + var result = _resolver.ResolveListeningEndpoint(server, options); + + // Assert + Assert.Equal("http://[2001:db8::1]:8080", result); + } + + #endregion + + #region Server Binding Resolution Tests + + [Fact] + public void ResolveListeningEndpoint_WithHttpBinding_ReturnsHttpAddress() + { + // Arrange + var server = CreateServer("http://0.0.0.0:5000"); + var options = new SessionAffinityOptions(); + + // Act + var result = _resolver.ResolveListeningEndpoint(server, options); + + // Assert + Assert.Equal("http://0.0.0.0:5000", result); + } + + [Fact] + public void ResolveListeningEndpoint_WithHttpsBinding_ReturnsHttpsAddress() + { + // Arrange + var server = CreateServer("https://0.0.0.0:5001"); + var options = new SessionAffinityOptions(); + + // Act + var result = _resolver.ResolveListeningEndpoint(server, options); + + // Assert + Assert.Equal("https://0.0.0.0:5001", result); + } + + [Fact] + public void ResolveListeningEndpoint_WithMultipleBindings_PrefersHttpOverHttps() + { + // Arrange - HTTP should be preferred for service mesh scenarios + var server = CreateServer("https://0.0.0.0:5001", "http://0.0.0.0:5000"); + var options = new SessionAffinityOptions(); + + // Act + var result = _resolver.ResolveListeningEndpoint(server, options); + + // Assert + Assert.Equal("http://0.0.0.0:5000", result); + } + + [Fact] + public void ResolveListeningEndpoint_WithHttpAndHttpsBindings_PrefersHttp() + { + // Arrange + var server = CreateServer("http://10.0.1.5:5000", "https://10.0.1.5:5001"); + var options = new SessionAffinityOptions(); + + // Act + var result = _resolver.ResolveListeningEndpoint(server, options); + + // Assert + Assert.Equal("http://10.0.1.5:5000", result); + } + + [Fact] + public void ResolveListeningEndpoint_PrefersExternalOverLocalhost() + { + // Arrange - External interfaces should be preferred over localhost + var server = CreateServer("http://localhost:5000", "http://192.168.1.100:5000"); + var options = new SessionAffinityOptions(); + + // Act + var result = _resolver.ResolveListeningEndpoint(server, options); + + // Assert + Assert.Equal("http://192.168.1.100:5000", result); + } + + [Fact] + public void ResolveListeningEndpoint_WithOnlyLocalhostBinding_ReturnsLocalhostAddress() + { + // Arrange + var server = CreateServer("http://localhost:5000"); + var options = new SessionAffinityOptions(); + + // Act + var result = _resolver.ResolveListeningEndpoint(server, options); + + // Assert + Assert.Equal("http://localhost:5000", result); + } + + [Fact] + public void ResolveListeningEndpoint_WithOnlyHttpsAndLocalhostBindings_PrefersLocalhostHttps() + { + // Arrange - Only localhost HTTPS available + var server = CreateServer("https://localhost:5001"); + var options = new SessionAffinityOptions(); + + // Act + var result = _resolver.ResolveListeningEndpoint(server, options); + + // Assert + Assert.Equal("https://localhost:5001", result); + } + + [Fact] + public void ResolveListeningEndpoint_WithComplexBindings_FollowsPriorityOrder() + { + // Arrange - Priority: external HTTP > external HTTPS > localhost HTTP > localhost HTTPS + var server = CreateServer( + "https://localhost:5443", + "http://localhost:5000", + "https://10.0.1.5:5001", + "http://10.0.1.5:5000" + ); + var options = new SessionAffinityOptions(); + + // Act + var result = _resolver.ResolveListeningEndpoint(server, options); + + // Assert + Assert.Equal("http://10.0.1.5:5000", result); + } + + [Fact] + public void ResolveListeningEndpoint_WithNoAddresses_ReturnsFallbackAddress() + { + // Arrange + var server = CreateServer(); // No addresses + var options = new SessionAffinityOptions(); + + // Act + var result = _resolver.ResolveListeningEndpoint(server, options); + + // Assert + Assert.Equal("http://localhost:80", result); + } + + [Fact] + public void ResolveListeningEndpoint_WithNullServerAddressesFeature_ReturnsFallbackAddress() + { + // Arrange + var server = CreateServerWithoutAddressesFeature(); + var options = new SessionAffinityOptions(); + + // Act + var result = _resolver.ResolveListeningEndpoint(server, options); + + // Assert + Assert.Equal("http://localhost:80", result); + } + + #endregion + + #region Localhost Detection Tests + + [Fact] + public void ResolveListeningEndpoint_WithLocalhostVariants_DetectsAllAsLocalhost() + { + // Test various localhost representations + var testCases = new[] + { + "http://localhost:5000", + "http://LOCALHOST:5000", + "http://127.0.0.1:5000", + "http://[::1]:5000", + "http://subdomain.localhost:5000", + }; + + foreach (var localhostAddress in testCases) + { + // Arrange - Add an external address to verify localhost is NOT preferred + var server = CreateServer(localhostAddress, "http://10.0.1.5:5000"); + var options = new SessionAffinityOptions(); + + // Act + var result = _resolver.ResolveListeningEndpoint(server, options); + + // Assert - Should prefer external address over localhost variants + Assert.Equal( + "http://10.0.1.5:5000", + result); + } + } + + [Fact] + public void ResolveListeningEndpoint_WithIPv4Loopback_TreatedAsLocalhost() + { + // Arrange + var server = CreateServer("http://127.0.0.1:5000", "http://192.168.1.100:5000"); + var options = new SessionAffinityOptions(); + + // Act + var result = _resolver.ResolveListeningEndpoint(server, options); + + // Assert - External address should be preferred + Assert.Equal("http://192.168.1.100:5000", result); + } + + [Fact] + public void ResolveListeningEndpoint_WithIPv6Loopback_TreatedAsLocalhost() + { + // Arrange + var server = CreateServer("http://[::1]:5000", "http://[2001:db8::1]:5000"); + var options = new SessionAffinityOptions(); + + // Act + var result = _resolver.ResolveListeningEndpoint(server, options); + + // Assert - External address should be preferred + Assert.Equal("http://[2001:db8::1]:5000", result); + } + + #endregion + + #region Null/Empty Validation Tests + + [Fact] + public void ResolveListeningEndpoint_WithNullServer_ThrowsArgumentNullException() + { + // Arrange + var options = new SessionAffinityOptions(); + + // Act & Assert + Assert.Throws(() => + _resolver.ResolveListeningEndpoint(null!, options) + ); + } + + [Fact] + public void ResolveListeningEndpoint_WithNullOptions_ThrowsArgumentNullException() + { + // Arrange + var server = CreateServer(); + + // Act & Assert + Assert.Throws(() => + _resolver.ResolveListeningEndpoint(server, null!) + ); + } + + [Fact] + public void ResolveListeningEndpoint_WithEmptyStringAddress_IgnoresAndResolvesFromServer() + { + // Arrange + var server = CreateServer("http://10.0.1.5:5000"); + var options = new SessionAffinityOptions { LocalServerAddress = "" }; + + // Act + var result = _resolver.ResolveListeningEndpoint(server, options); + + // Assert + Assert.Equal("http://10.0.1.5:5000", result); + } + + [Fact] + public void ResolveListeningEndpoint_WithWhitespaceAddress_IgnoresAndResolvesFromServer() + { + // Arrange + var server = CreateServer("http://10.0.1.5:5000"); + var options = new SessionAffinityOptions { LocalServerAddress = " " }; + + // Act + var result = _resolver.ResolveListeningEndpoint(server, options); + + // Assert + Assert.Equal("http://10.0.1.5:5000", result); + } + + #endregion + + #region Priority and Selection Tests + + [Fact] + public void ResolveListeningEndpoint_WithOnlyHttpsExternal_ReturnsHttpsAddress() + { + // Arrange - No HTTP available, should return HTTPS + var server = CreateServer("https://10.0.1.5:5001"); + var options = new SessionAffinityOptions(); + + // Act + var result = _resolver.ResolveListeningEndpoint(server, options); + + // Assert + Assert.Equal("https://10.0.1.5:5001", result); + } + + [Fact] + public void ResolveListeningEndpoint_WithMultipleHttpBindings_ReturnsFirstHttpBinding() + { + // Arrange + var server = CreateServer("http://10.0.1.5:5000", "http://10.0.1.6:5000"); + var options = new SessionAffinityOptions(); + + // Act + var result = _resolver.ResolveListeningEndpoint(server, options); + + // Assert - First HTTP binding should be selected + Assert.Equal("http://10.0.1.5:5000", result); + } + + [Fact] + public void ResolveListeningEndpoint_WithInvalidBindings_SkipsInvalidAndUsesValid() + { + // Arrange + var mockFeature = Substitute.For(); + mockFeature.Addresses.Returns( + new List + { + "not-a-valid-uri", + "http://valid.example.com:5000", + "also-invalid", + } + ); + + var server = Substitute.For(); + var features = new FeatureCollection(); + features.Set(mockFeature); + server.Features.Returns(features); + + var options = new SessionAffinityOptions(); + + // Act + var result = _resolver.ResolveListeningEndpoint(server, options); + + // Assert + Assert.Equal("http://valid.example.com:5000", result); + } + + [Fact] + public void ResolveListeningEndpoint_WithAllInvalidBindings_ReturnsFallback() + { + // Arrange + var mockFeature = Substitute.For(); + mockFeature.Addresses.Returns(new List { "not-a-valid-uri", "also-invalid", "still-not-valid" }); + + var server = Substitute.For(); + var features = new FeatureCollection(); + features.Set(mockFeature); + server.Features.Returns(features); + + var options = new SessionAffinityOptions(); + + // Act + var result = _resolver.ResolveListeningEndpoint(server, options); + + // Assert + Assert.Equal("http://localhost:80", result); + } + + #endregion + + #region Edge Case Tests + + [Fact] + public void ResolveListeningEndpoint_WithNonStandardPorts_PreservesPort() + { + // Arrange + var server = CreateServer("http://example.com:8888"); + var options = new SessionAffinityOptions(); + + // Act + var result = _resolver.ResolveListeningEndpoint(server, options); + + // Assert + Assert.Equal("http://example.com:8888", result); + } + + [Fact] + public void ResolveListeningEndpoint_WithDefaultHttpPort_IncludesPort80Explicitly() + { + // Arrange - Port 80 should be explicitly included in the result + var server = CreateServer("http://example.com:80"); + var options = new SessionAffinityOptions(); + + // Act + var result = _resolver.ResolveListeningEndpoint(server, options); + + // Assert - Verify port 80 is explicitly included + Assert.Equal("http://example.com:80", result); + Assert.True( + result.EndsWith(":80", StringComparison.Ordinal), + "Port 80 should be explicitly included" + ); + } + + [Fact] + public void ResolveListeningEndpoint_WithDefaultHttpsPort_IncludesPort443Explicitly() + { + // Arrange - Port 443 should be explicitly included in the result + var server = CreateServer("https://example.com:443"); + var options = new SessionAffinityOptions(); + + // Act + var result = _resolver.ResolveListeningEndpoint(server, options); + + // Assert - Verify port 443 is explicitly included + Assert.Equal("https://example.com:443", result); + Assert.True( + result.EndsWith(":443", StringComparison.Ordinal), + "Port 443 should be explicitly included" + ); + } + + [Fact] + public void ResolveListeningEndpoint_WithWildcardAddress_ReturnsWildcardAddress() + { + // Arrange - Wildcard addresses (0.0.0.0, [::]) are valid and should be preserved + var server = CreateServer("http://0.0.0.0:5000"); + var options = new SessionAffinityOptions(); + + // Act + var result = _resolver.ResolveListeningEndpoint(server, options); + + // Assert - IPv4 wildcard should be preserved + Assert.Equal("http://0.0.0.0:5000", result); + Assert.True( + result.Contains("0.0.0.0", StringComparison.Ordinal), + "Should preserve IPv4 wildcard address" + ); + } + + [Fact] + public void ResolveListeningEndpoint_WithIPv6WildcardAddress_ReturnsWildcardAddress() + { + // Arrange + var server = CreateServer("http://[::]:5000"); + var options = new SessionAffinityOptions(); + + // Act + var result = _resolver.ResolveListeningEndpoint(server, options); + + // Assert - IPv6 wildcard should be preserved + Assert.Equal("http://[::]:5000", result); + Assert.True( + result.Contains("[::]", StringComparison.Ordinal), + "Should preserve IPv6 wildcard address" + ); + } + + #endregion + + #region Integration Tests + + [Fact] + public void ResolveListeningEndpoint_ExplicitConfigTakesPrecedenceOverServerBindings() + { + // Arrange - Even with server bindings, explicit config should win + var server = CreateServer("http://10.0.1.5:5000", "https://10.0.1.5:5001"); + var options = new SessionAffinityOptions + { + LocalServerAddress = "http://custom.example.com:9000", + }; + + // Act + var result = _resolver.ResolveListeningEndpoint(server, options); + + // Assert + Assert.Equal("http://custom.example.com:9000", result); + } + + [Fact] + public void ResolveListeningEndpoint_ServiceMeshScenario_PrefersHttpForInternalRouting() + { + // Arrange - Typical service mesh: HTTPS external, HTTP internal + var server = CreateServer( + "https://external.service.mesh:443", + "http://internal.service.mesh:8080" + ); + var options = new SessionAffinityOptions(); + + // Act + var result = _resolver.ResolveListeningEndpoint(server, options); + + // Assert - HTTP should be preferred for internal service mesh routing + Assert.Equal("http://internal.service.mesh:8080", result); + } + + [Fact] + public void ResolveListeningEndpoint_KubernetesScenario_UsesExplicitPodAddress() + { + // Arrange - Kubernetes pod with explicit service address + var server = CreateServer("http://0.0.0.0:8080"); + var options = new SessionAffinityOptions + { + LocalServerAddress = "http://pod-1.mcp-service.default.svc.cluster.local:8080", + }; + + // Act + var result = _resolver.ResolveListeningEndpoint(server, options); + + // Assert + Assert.Equal("http://pod-1.mcp-service.default.svc.cluster.local:8080", result); + } + + #endregion + + #region Helper Methods + + private static IServer CreateServer(params string[] addresses) + { + var mockFeature = Substitute.For(); + mockFeature.Addresses.Returns(new List(addresses)); + + var server = Substitute.For(); + var features = new FeatureCollection(); + features.Set(mockFeature); + server.Features.Returns(features); + + return server; + } + + private static IServer CreateServerWithoutAddressesFeature() + { + var server = Substitute.For(); + var features = new FeatureCollection(); + // Deliberately not setting IServerAddressesFeature + server.Features.Returns(features); + + return server; + } + + #endregion +} diff --git a/tests/ModelContextProtocol.AspNetCore.Distributed.Tests/ModelContextProtocol.AspNetCore.Distributed.Tests.csproj b/tests/ModelContextProtocol.AspNetCore.Distributed.Tests/ModelContextProtocol.AspNetCore.Distributed.Tests.csproj new file mode 100644 index 000000000..d97a94291 --- /dev/null +++ b/tests/ModelContextProtocol.AspNetCore.Distributed.Tests/ModelContextProtocol.AspNetCore.Distributed.Tests.csproj @@ -0,0 +1,43 @@ + + + + net10.0 + enable + enable + false + true + ModelContextProtocol.AspNetCore.Distributed.Tests + true + $(RepoRoot)Open.snk + + + + true + + + + + + + + + runtime; build; native; contentfiles; analyzers; buildtransitive + all + + + + + + + + runtime; build; native; contentfiles; analyzers; buildtransitive + all + + + + + + + + + diff --git a/tests/ModelContextProtocol.AspNetCore.Distributed.Tests/RealServerIntegrationTests.cs b/tests/ModelContextProtocol.AspNetCore.Distributed.Tests/RealServerIntegrationTests.cs new file mode 100644 index 000000000..11bc4098b --- /dev/null +++ b/tests/ModelContextProtocol.AspNetCore.Distributed.Tests/RealServerIntegrationTests.cs @@ -0,0 +1,483 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.ComponentModel; +using System.Globalization; +using System.Net; +using Microsoft.AspNetCore.Builder; +using Microsoft.AspNetCore.Hosting; +using Microsoft.AspNetCore.Hosting.Server; +using Microsoft.AspNetCore.Hosting.Server.Features; +using Microsoft.Extensions.Caching.Distributed; +using Microsoft.Extensions.Caching.Memory; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Hosting; +using Microsoft.Extensions.Options; +using ModelContextProtocol.AspNetCore.Distributed; +using ModelContextProtocol.Client; +using ModelContextProtocol.Protocol; +using ModelContextProtocol.Server; +using Xunit; +using DescriptionAttribute = System.ComponentModel.DescriptionAttribute; + +namespace ModelContextProtocol.AspNetCore.Distributed.Tests; + +/// +/// Integration tests for session affinity using real Kestrel servers. +/// These tests verify that multiple WebApplication instances can share session state +/// via a shared IDistributedCache and that each client maintains its own session. +/// +public sealed class RealServerIntegrationTests +{ + [Fact] + public async Task MultipleServersWithSharedCacheMaintainSeparateClientSessions() + { + // Arrange - Create shared distributed cache for session affinity + var sharedCache = new MemoryDistributedCache( + Options.Create(new MemoryDistributedCacheOptions()) + ); + + // Create two real Kestrel servers with shared cache + await using var host1 = await CreateKestrelServerAsync(sharedCache, "server-1"); + await using var host2 = await CreateKestrelServerAsync(sharedCache, "server-2"); + + // Create two separate MCP clients connecting to different servers + using var httpClient1 = new HttpClient(); + using var httpClient2 = new HttpClient(); + + await using var transport1 = new HttpClientTransport( + new HttpClientTransportOptions { Endpoint = host1.McpEndpoint }, + httpClient1 + ); + + await using var transport2 = new HttpClientTransport( + new HttpClientTransportOptions { Endpoint = host2.McpEndpoint }, + httpClient2 + ); + + await using var mcpClient1 = await McpClient.CreateAsync( + transport1, + cancellationToken: CancellationToken.None + ); + await using var mcpClient2 = await McpClient.CreateAsync( + transport2, + cancellationToken: CancellationToken.None + ); + + // Act - Each client calls tools on its connected server + // First, verify which server each client is connected to + var client1Server = await mcpClient1.CallToolAsync( + "get_server_id", + cancellationToken: CancellationToken.None + ); + var client2Server = await mcpClient2.CallToolAsync( + "get_server_id", + cancellationToken: CancellationToken.None + ); + + // Assert - Each server identifies correctly + Assert.Equal("server-1", ((TextContentBlock)client1Server.Content[0]).Text); + Assert.Equal("server-2", ((TextContentBlock)client2Server.Content[0]).Text); + } + + [Fact] + public async Task SingleClientCanCallToolsSuccessfully() + { + // Arrange + var sharedCache = new MemoryDistributedCache( + Options.Create(new MemoryDistributedCacheOptions()) + ); + + await using var host = await CreateKestrelServerAsync(sharedCache, "server-test"); + + using var httpClient = new HttpClient(); + + await using var transport = new HttpClientTransport( + new HttpClientTransportOptions { Endpoint = host.McpEndpoint }, + httpClient + ); + + await using var mcpClient = await McpClient.CreateAsync( + transport, + cancellationToken: CancellationToken.None + ); + + // Act - Make multiple requests with the same client + var serverId = await mcpClient.CallToolAsync( + "get_server_id", + cancellationToken: CancellationToken.None + ); + var counter1 = await mcpClient.CallToolAsync( + "increment_counter", + cancellationToken: CancellationToken.None + ); + var counter2 = await mcpClient.CallToolAsync( + "increment_counter", + cancellationToken: CancellationToken.None + ); + + // Assert - Tools execute successfully + Assert.Equal("server-test", ((TextContentBlock)serverId.Content[0]).Text); + Assert.Equal("1", ((TextContentBlock)counter1.Content[0]).Text); + // Note: Counter resets because tools are scoped per request, not per session + // This demonstrates that multiple tool calls work correctly + Assert.NotNull(((TextContentBlock)counter2.Content[0]).Text); + } + + [Fact] + public async Task MultipleClientsWithSameSessionIdStickToSameServer() + { + // This test demonstrates that session affinity works by manually + // simulating what would happen with load balancing: different clients + // connect to different servers initially, but if they share a session ID, + // subsequent requests would be redirected to the session owner. + + // Arrange - Create shared distributed cache for session affinity + var sharedCache = new MemoryDistributedCache( + Options.Create(new MemoryDistributedCacheOptions()) + ); + + // Create two real Kestrel servers with shared cache + await using var host1 = await CreateKestrelServerAsync(sharedCache, "server-1"); + await using var host2 = await CreateKestrelServerAsync(sharedCache, "server-2"); + + // Create first client connecting to server-1 + using var httpClient1 = new HttpClient(); + await using var transport1 = new HttpClientTransport( + new HttpClientTransportOptions { Endpoint = host1.McpEndpoint }, + httpClient1 + ); + await using var mcpClient1 = await McpClient.CreateAsync( + transport1, + cancellationToken: CancellationToken.None + ); + + // Act - First client establishes session on server-1 + var firstResponse = await mcpClient1.CallToolAsync( + "get_server_id", + cancellationToken: CancellationToken.None + ); + var firstServerId = ((TextContentBlock)firstResponse.Content[0]).Text; + + // Make more requests with same client - should stay on same server + var secondResponse = await mcpClient1.CallToolAsync( + "get_server_id", + cancellationToken: CancellationToken.None + ); + var secondServerId = ((TextContentBlock)secondResponse.Content[0]).Text; + + var thirdResponse = await mcpClient1.CallToolAsync( + "get_server_id", + cancellationToken: CancellationToken.None + ); + var thirdServerId = ((TextContentBlock)thirdResponse.Content[0]).Text; + + // Assert - All requests from same client stay on same server + Assert.Equal("server-1", firstServerId); + Assert.Equal(firstServerId, secondServerId); + Assert.Equal(firstServerId, thirdServerId); + } + + [Fact] + public async Task LoadBalancingDistributesNewClientsAcrossDifferentServers() + { + // Arrange + var sharedCache = new MemoryDistributedCache( + Options.Create(new MemoryDistributedCacheOptions()) + ); + + await using var host1 = await CreateKestrelServerAsync(sharedCache, "server-1"); + await using var host2 = await CreateKestrelServerAsync(sharedCache, "server-2"); + + // Create multiple clients with load balancing + var serverEndpoints = new[] { host1.BaseAddress, host2.BaseAddress }; + var clients = + new List<(HttpClient HttpClient, HttpClientTransport Transport, McpClient McpClient)>(); + var loadBalancers = new List(); + + try + { + // Create 4 clients with load balancing + for (int i = 0; i < 4; i++) + { + var requestCount = i; // Capture for closure + var loadBalancer = new RoundRobinLoadBalancingHandler( + serverEndpoints, + () => requestCount + ); + loadBalancers.Add(loadBalancer); + + var httpClient = new HttpClient(loadBalancer); + var transport = new HttpClientTransport( + new HttpClientTransportOptions { Endpoint = host1.McpEndpoint }, + httpClient + ); + var mcpClient = await McpClient.CreateAsync( + transport, + cancellationToken: CancellationToken.None + ); + + clients.Add((httpClient, transport, mcpClient)); + } + + // Act - Each client makes a request + var serverIds = new List(); + foreach (var (_, _, mcpClient) in clients) + { + var response = await mcpClient.CallToolAsync( + "get_server_id", + cancellationToken: CancellationToken.None + ); + serverIds.Add(((TextContentBlock)response.Content[0]).Text); + } + + // Assert - Should have both servers represented + Assert.Contains("server-1", serverIds); + Assert.Contains("server-2", serverIds); + Assert.Equal(4, serverIds.Count); + } + finally + { + // Cleanup + foreach (var (httpClient, transport, mcpClient) in clients) + { + await mcpClient.DisposeAsync(); + await transport.DisposeAsync(); + httpClient.Dispose(); + } + + foreach (var loadBalancer in loadBalancers) + { + loadBalancer.Dispose(); + } + } + } + + [Fact] + public async Task SessionAffinityPreservesConnectionToOriginalServer() + { + // This test demonstrates that once a client establishes a session with a server, + // it maintains that connection across multiple requests (simulating what would + // happen with session affinity if requests were load balanced). + + // Arrange - Create shared distributed cache for session affinity + var sharedCache = new MemoryDistributedCache( + Options.Create(new MemoryDistributedCacheOptions()) + ); + + await using var host1 = await CreateKestrelServerAsync(sharedCache, "server-1"); + await using var host2 = await CreateKestrelServerAsync(sharedCache, "server-2"); + + // Create client connecting to server-1 + using var httpClient1 = new HttpClient(); + await using var transport1 = new HttpClientTransport( + new HttpClientTransportOptions { Endpoint = host1.McpEndpoint }, + httpClient1 + ); + await using var mcpClient1 = await McpClient.CreateAsync( + transport1, + cancellationToken: CancellationToken.None + ); + + // Act - Client makes multiple requests, all should stay on server-1 + var firstResponse = await mcpClient1.CallToolAsync( + "get_server_id", + cancellationToken: CancellationToken.None + ); + var firstServerId = ((TextContentBlock)firstResponse.Content[0]).Text; + + var secondResponse = await mcpClient1.CallToolAsync( + "get_server_id", + cancellationToken: CancellationToken.None + ); + var secondServerId = ((TextContentBlock)secondResponse.Content[0]).Text; + + var thirdResponse = await mcpClient1.CallToolAsync( + "get_server_id", + cancellationToken: CancellationToken.None + ); + var thirdServerId = ((TextContentBlock)thirdResponse.Content[0]).Text; + + // Assert - All requests stay on the same server (session affinity in action) + Assert.Equal("server-1", firstServerId); + Assert.Equal("server-1", secondServerId); + Assert.Equal("server-1", thirdServerId); + } + + private static async Task CreateKestrelServerAsync( + IDistributedCache sharedCache, + string serverId + ) + { + var hostBuilder = new HostBuilder().ConfigureWebHost(webHost => + { + webHost.UseKestrel(options => + { + options.Listen(new IPEndPoint(IPAddress.Loopback, 0)); // Let the OS select an available port + }); + + webHost.ConfigureServices(services => + { + // Use shared distributed cache for session affinity across servers + services.AddSingleton(sharedCache); + + // Add MCP server with tools + services.AddMcpServer().WithTools().WithHttpTransport(); + + // Add session affinity (listening endpoint resolver will determine address) + services.AddMcpHttpSessionAffinity(); + + // Register server-specific state (identifies which server instance this is) + services.AddSingleton(new ServerState { ServerId = serverId }); + }); + + webHost.Configure(app => + { + // Enable routing middleware + app.UseRouting(); + + // Map MCP endpoints with session affinity + app.UseEndpoints(endpoints => + { + endpoints.MapMcp("mcp").WithSessionAffinity(); + }); + }); + }); + + var host = await hostBuilder.StartAsync(); + var baseAddress = ResolveBaseAddress(host); + return new KestrelServerHandle(host, baseAddress); + } + + private static Uri ResolveBaseAddress(IHost host) + { + var server = host.Services.GetRequiredService(); + var addressesFeature = server.Features.Get(); + if (addressesFeature is null || addressesFeature.Addresses.Count == 0) + { + throw new InvalidOperationException("Kestrel server did not expose any addresses."); + } + + foreach (var address in addressesFeature.Addresses) + { + if (Uri.TryCreate(address, UriKind.Absolute, out var uri)) + { + return NormalizeBaseAddress(uri); + } + } + + throw new InvalidOperationException("Failed to resolve a valid server address."); + } + + private static Uri NormalizeBaseAddress(Uri uri) + { + var builder = new UriBuilder(uri) + { + Path = "/", + Query = null, + Fragment = null, + }; + + return builder.Uri; + } + + [McpServerToolType] + private sealed class TestTools + { + private readonly ServerState _serverState; + private int _counter; + +#pragma warning disable S1144 // Constructor used via dependency injection + public TestTools(ServerState serverState) +#pragma warning restore S1144 + { + _serverState = serverState; + } + + [McpServerTool] + [Description("Returns the ID of the server handling the request")] + public string GetServerId() => _serverState.ServerId; + + [McpServerTool] + [Description("Increments a counter and returns the new value")] + public string IncrementCounter() => + Interlocked.Increment(ref _counter).ToString(CultureInfo.InvariantCulture); + } + + private sealed class ServerState + { + public required string ServerId { get; init; } + } + + private sealed class KestrelServerHandle : IAsyncDisposable + { + private readonly IHost _host; + + public KestrelServerHandle(IHost host, Uri baseAddress) + { + _host = host; + BaseAddress = baseAddress; + McpEndpoint = new Uri(baseAddress, "mcp"); + } + + public Uri BaseAddress { get; } + + public Uri McpEndpoint { get; } + + public async ValueTask DisposeAsync() + { + try + { + await _host.StopAsync(); + } + finally + { + _host.Dispose(); + } + } + } + + /// + /// HTTP handler that implements client-side round-robin load balancing across multiple servers. + /// Modifies request URIs to alternate between different port numbers. + /// + private sealed class RoundRobinLoadBalancingHandler : DelegatingHandler + { + private readonly Uri[] _endpoints; + private readonly Func _getRequestCount; + +#pragma warning disable CA2000 // DelegatingHandler takes ownership of the inner handler + public RoundRobinLoadBalancingHandler(Uri[] endpoints, Func getRequestCount) + : base(new HttpClientHandler()) +#pragma warning restore CA2000 + { + _endpoints = endpoints; + _getRequestCount = getRequestCount; + } + + protected override Task SendAsync( + HttpRequestMessage request, + CancellationToken cancellationToken + ) + { + if (request.RequestUri != null && _endpoints.Length > 0) + { + // Round-robin: alternate between endpoints based on request count + var requestCount = _getRequestCount(); + var selectedEndpoint = _endpoints[requestCount % _endpoints.Length]; + + // Modify the request URI to use the selected endpoint + var builder = new UriBuilder(request.RequestUri) + { + Scheme = selectedEndpoint.Scheme, + Host = selectedEndpoint.Host, + Port = selectedEndpoint.Port, + }; + + request.RequestUri = builder.Uri; + } + + return base.SendAsync(request, cancellationToken); + } + } +} diff --git a/tests/ModelContextProtocol.AspNetCore.Distributed.Tests/SessionAffinityEndpointFilterTests.cs b/tests/ModelContextProtocol.AspNetCore.Distributed.Tests/SessionAffinityEndpointFilterTests.cs new file mode 100644 index 000000000..179abb668 --- /dev/null +++ b/tests/ModelContextProtocol.AspNetCore.Distributed.Tests/SessionAffinityEndpointFilterTests.cs @@ -0,0 +1,912 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Net.Http; +using System.Threading; +using Microsoft.AspNetCore.Hosting.Server; +using Microsoft.AspNetCore.Hosting.Server.Features; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Http.Features; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging.Abstractions; +using Microsoft.Extensions.Options; +using ModelContextProtocol.AspNetCore.Distributed; +using ModelContextProtocol.AspNetCore.Distributed.Abstractions; +using NSubstitute; +using Xunit; +using Yarp.ReverseProxy.Forwarder; + +namespace ModelContextProtocol.AspNetCore.Distributed.Tests; + +public sealed class SessionAffinityEndpointFilterTests +{ + [Fact] + public async Task InvokeAsync_WithoutSessionId_CallsNextAndSkipsStore() + { + var httpContext = CreateHttpContext(); + httpContext.Request.Scheme = "http"; + var invocationContext = new TestEndpointFilterInvocationContext(httpContext); + + var sessionStore = Substitute.For(); + using var forwarder = new TestHttpForwarder(); + using var httpClientFactory = new TestForwarderHttpClientFactory(); + using var server = new TestServer("http://127.0.0.1:5000"); + + var filter = CreateFilter(sessionStore, forwarder, httpClientFactory, server); + + var nextCalled = false; + var result = await filter.InvokeAsync( + invocationContext, + ctx => + { + nextCalled = true; + return ValueTask.FromResult(null); + } + ); + + Assert.Null(result); + Assert.True(nextCalled); + await sessionStore.DidNotReceive().GetOrClaimOwnershipAsync( + Arg.Any(), + Arg.Any>>(), + Arg.Any()); + Assert.Empty(forwarder.Calls); + } + + [Fact] + public async Task InvokeAsync_WhenSessionClaimedLocally_CallsNext() + { + const string sessionId = "session-1"; + + var httpContext = CreateHttpContext(); + httpContext.Request.Scheme = "http"; + httpContext.Request.Headers["Mcp-Session-Id"] = sessionId; + var invocationContext = new TestEndpointFilterInvocationContext(httpContext); + + // The session store will call the factory to create owner info + // We need to capture and return whatever owner info the filter creates + SessionOwnerInfo? capturedOwnerInfo = null; + + var sessionStore = Substitute.For(); + sessionStore + .GetOrClaimOwnershipAsync( + sessionId, + Arg.Any>>(), + Arg.Any()) + .Returns(callInfo => + { + var factory = callInfo.ArgAt>>(1); + var ct = callInfo.ArgAt(2); + // Call the factory to get the owner info that the filter would create + capturedOwnerInfo = factory(ct).GetAwaiter().GetResult(); + return Task.FromResult(capturedOwnerInfo); + }); + + using var forwarder = new TestHttpForwarder(); + using var httpClientFactory = new TestForwarderHttpClientFactory(); + using var server = new TestServer("http://localhost:5000"); + + var filter = CreateFilter(sessionStore, forwarder, httpClientFactory, server); + + var nextCalled = false; + var result = await filter.InvokeAsync( + invocationContext, + ctx => + { + nextCalled = true; + return ValueTask.FromResult(null); + } + ); + + // When the session is claimed locally, the filter should call next + Assert.Null(result); + Assert.True(nextCalled); + Assert.NotNull(capturedOwnerInfo); + Assert.Equal("http://localhost:5000", capturedOwnerInfo.Address); + + // Verify the session ownership was checked/claimed + await sessionStore.Received(1).GetOrClaimOwnershipAsync( + sessionId, + Arg.Any>>(), + Arg.Any()); + + Assert.Empty(forwarder.Calls); + } + + [Fact] + public async Task InvokeAsync_WhenSessionOwnedElsewhere_ForwardsRequest() + { + const string sessionId = "session-remote"; + var remoteOwner = new SessionOwnerInfo + { + OwnerId = "remote-owner", + Address = "http://remotehost:8080", + ClaimedAt = DateTimeOffset.UtcNow, + }; + + var httpContext = CreateHttpContext(); + httpContext.Request.Scheme = "http"; + httpContext.Request.Headers["Mcp-Session-Id"] = sessionId; + var invocationContext = new TestEndpointFilterInvocationContext(httpContext); + + var sessionStore = Substitute.For(); + sessionStore + .GetOrClaimOwnershipAsync( + sessionId, + Arg.Any>>(), + Arg.Any()) + .Returns(Task.FromResult(remoteOwner)); + + using var forwarder = new TestHttpForwarder(); + using var httpClientFactory = new TestForwarderHttpClientFactory(); + using var server = new TestServer("http://127.0.0.1:5000"); + + var filter = CreateFilter(sessionStore, forwarder, httpClientFactory, server); + + var nextCalled = false; + var result = await filter.InvokeAsync( + invocationContext, + ctx => + { + nextCalled = true; + return ValueTask.FromResult(null); + } + ); + + Assert.Null(result); + Assert.False(nextCalled); + await sessionStore.Received(1).GetOrClaimOwnershipAsync( + sessionId, + Arg.Any>>(), + Arg.Any()); + Assert.Single(forwarder.Calls); + Assert.Equal("http://remotehost:8080", forwarder.Calls[0].Destination); + } + + [Fact] + public async Task InvokeAsync_WhenForwarderFails_ReturnsBadGatewayResult() + { + const string sessionId = "session-error"; + + var httpContext = CreateHttpContext(); + httpContext.Request.Scheme = "http"; + httpContext.Request.Headers["Mcp-Session-Id"] = sessionId; + var invocationContext = new TestEndpointFilterInvocationContext(httpContext); + + var sessionStore = Substitute.For(); + sessionStore + .GetOrClaimOwnershipAsync( + sessionId, + Arg.Any>>(), + Arg.Any()) + .Returns(Task.FromResult( + new SessionOwnerInfo + { + OwnerId = "remote", + Address = "http://remotehost:8080", + ClaimedAt = DateTimeOffset.UtcNow, + })); + + using var forwarder = new TestHttpForwarder { NextResult = ForwarderError.RequestCanceled }; + using var httpClientFactory = new TestForwarderHttpClientFactory(); + using var server = new TestServer("http://127.0.0.1:5000"); + + var filter = CreateFilter(sessionStore, forwarder, httpClientFactory, server); + + var result = await filter.InvokeAsync( + invocationContext, + _ => ValueTask.FromResult(null) + ); + + Assert.NotNull(result); + Assert.IsAssignableFrom(result); + + await ((IResult)result!).ExecuteAsync(httpContext); + Assert.Equal(StatusCodes.Status502BadGateway, httpContext.Response.StatusCode); + + await sessionStore.Received(1).GetOrClaimOwnershipAsync( + sessionId, + Arg.Any>>(), + Arg.Any()); + Assert.Single(forwarder.Calls); + Assert.Equal("http://remotehost:8080", forwarder.Calls[0].Destination); + } + + [Fact] + public async Task InvokeAsync_When404FromMcpEndpoint_RemovesStaleSession() + { + const string sessionId = "session-stale"; + var remoteOwner = new SessionOwnerInfo + { + OwnerId = "remote-owner", + Address = "http://remotehost:8080", + ClaimedAt = DateTimeOffset.UtcNow, + }; + + var httpContext = CreateHttpContext(); + httpContext.Request.Scheme = "http"; + httpContext.Request.Path = "/mcp"; + httpContext.Request.Headers["Mcp-Session-Id"] = sessionId; + + // Set up endpoint with MCP-related display name + var endpoint = new Endpoint( + requestDelegate: null, + metadata: new EndpointMetadataCollection(), + displayName: "POST /mcp" + ); + httpContext.SetEndpoint(endpoint); + + var invocationContext = new TestEndpointFilterInvocationContext(httpContext); + + var sessionStore = Substitute.For(); + sessionStore + .GetOrClaimOwnershipAsync( + sessionId, + Arg.Any>>(), + Arg.Any()) + .Returns(Task.FromResult(remoteOwner)); + + sessionStore + .RemoveAsync(sessionId, Arg.Any()) + .Returns(Task.CompletedTask); + + using var forwarder = new TestHttpForwarder + { + NextStatusCode = StatusCodes.Status404NotFound, + }; + using var httpClientFactory = new TestForwarderHttpClientFactory(); + using var server = new TestServer("http://127.0.0.1:5000"); + + var filter = CreateFilter(sessionStore, forwarder, httpClientFactory, server); + + var result = await filter.InvokeAsync( + invocationContext, + _ => ValueTask.FromResult(null) + ); + + Assert.Null(result); + Assert.Equal(StatusCodes.Status404NotFound, httpContext.Response.StatusCode); + + // Verify session was removed + await sessionStore.Received(1).RemoveAsync(sessionId, Arg.Any()); + + await sessionStore.Received(1).GetOrClaimOwnershipAsync( + sessionId, + Arg.Any>>(), + Arg.Any()); + } + + [Fact] + public async Task InvokeAsync_When404FromSseEndpoint_RemovesStaleSession() + { + const string sessionId = "session-stale-sse"; + var remoteOwner = new SessionOwnerInfo + { + OwnerId = "remote-owner", + Address = "http://remotehost:8080", + ClaimedAt = DateTimeOffset.UtcNow, + }; + + var httpContext = CreateHttpContext(); + httpContext.Request.Scheme = "http"; + httpContext.Request.Path = "/sse"; + httpContext.Request.Headers["Mcp-Session-Id"] = sessionId; + + // Set up endpoint with SSE-related display name + var endpoint = new Endpoint( + requestDelegate: null, + metadata: new EndpointMetadataCollection(), + displayName: "GET /sse" + ); + httpContext.SetEndpoint(endpoint); + + var invocationContext = new TestEndpointFilterInvocationContext(httpContext); + + var sessionStore = Substitute.For(); + sessionStore + .GetOrClaimOwnershipAsync( + sessionId, + Arg.Any>>(), + Arg.Any()) + .Returns(Task.FromResult(remoteOwner)); + + sessionStore + .RemoveAsync(sessionId, Arg.Any()) + .Returns(Task.CompletedTask); + + using var forwarder = new TestHttpForwarder + { + NextStatusCode = StatusCodes.Status404NotFound, + }; + using var httpClientFactory = new TestForwarderHttpClientFactory(); + using var server = new TestServer("http://127.0.0.1:5000"); + + var filter = CreateFilter(sessionStore, forwarder, httpClientFactory, server); + + var result = await filter.InvokeAsync( + invocationContext, + _ => ValueTask.FromResult(null) + ); + + Assert.Null(result); + Assert.Equal(StatusCodes.Status404NotFound, httpContext.Response.StatusCode); + + // Verify session was removed + await sessionStore.Received(1).RemoveAsync(sessionId, Arg.Any()); + } + + [Fact] + public async Task InvokeAsync_When404FromNonMcpEndpoint_DoesNotRemoveSession() + { + const string sessionId = "session-health"; + var remoteOwner = new SessionOwnerInfo + { + OwnerId = "remote-owner", + Address = "remotehost:8080", + ClaimedAt = DateTimeOffset.UtcNow, + }; + + var httpContext = CreateHttpContext(); + httpContext.Request.Scheme = "http"; + httpContext.Request.Path = "/health"; + httpContext.Request.Headers["Mcp-Session-Id"] = sessionId; + + // Set up endpoint with non-MCP display name + var endpoint = new Endpoint( + requestDelegate: null, + metadata: new EndpointMetadataCollection(), + displayName: "GET /health" + ); + httpContext.SetEndpoint(endpoint); + + var invocationContext = new TestEndpointFilterInvocationContext(httpContext); + + var sessionStore = Substitute.For(); + sessionStore + .GetOrClaimOwnershipAsync( + sessionId, + Arg.Any>>(), + Arg.Any()) + .Returns(Task.FromResult(remoteOwner)); + + using var forwarder = new TestHttpForwarder + { + NextStatusCode = StatusCodes.Status404NotFound, + }; + using var httpClientFactory = new TestForwarderHttpClientFactory(); + using var server = new TestServer("http://127.0.0.1:5000"); + + var filter = CreateFilter(sessionStore, forwarder, httpClientFactory, server); + + var result = await filter.InvokeAsync( + invocationContext, + _ => ValueTask.FromResult(null) + ); + + Assert.Null(result); + Assert.Equal(StatusCodes.Status404NotFound, httpContext.Response.StatusCode); + + // Verify session was NOT removed (only GetOrClaimOwnershipAsync was called) + await sessionStore.Received(1).GetOrClaimOwnershipAsync( + sessionId, + Arg.Any>>(), + Arg.Any()); + await sessionStore.DidNotReceive().RemoveAsync(Arg.Any(), Arg.Any()); + } + + [Fact] + public async Task InvokeAsync_When200FromMcpEndpoint_DoesNotRemoveSession() + { + const string sessionId = "session-success"; + var remoteOwner = new SessionOwnerInfo + { + OwnerId = "remote-owner", + Address = "http://remotehost:8080", + ClaimedAt = DateTimeOffset.UtcNow, + }; + + var httpContext = CreateHttpContext(); + httpContext.Request.Scheme = "http"; + httpContext.Request.Path = "/mcp"; + httpContext.Request.Headers["Mcp-Session-Id"] = sessionId; + + // Set up endpoint with MCP-related display name + var endpoint = new Endpoint( + requestDelegate: null, + metadata: new EndpointMetadataCollection(), + displayName: "POST /mcp" + ); + httpContext.SetEndpoint(endpoint); + + var invocationContext = new TestEndpointFilterInvocationContext(httpContext); + + var sessionStore = Substitute.For(); + sessionStore + .GetOrClaimOwnershipAsync( + sessionId, + Arg.Any>>(), + Arg.Any()) + .Returns(Task.FromResult(remoteOwner)); + + using var forwarder = new TestHttpForwarder { NextStatusCode = StatusCodes.Status200OK }; + using var httpClientFactory = new TestForwarderHttpClientFactory(); + using var server = new TestServer("http://127.0.0.1:5000"); + + var filter = CreateFilter(sessionStore, forwarder, httpClientFactory, server); + + var result = await filter.InvokeAsync( + invocationContext, + _ => ValueTask.FromResult(null) + ); + + Assert.Null(result); + Assert.Equal(StatusCodes.Status200OK, httpContext.Response.StatusCode); + + // Verify session was NOT removed (only GetOrClaimOwnershipAsync was called) + await sessionStore.Received(1).GetOrClaimOwnershipAsync( + sessionId, + Arg.Any>>(), + Arg.Any()); + await sessionStore.DidNotReceive().RemoveAsync(Arg.Any(), Arg.Any()); + } + + [Fact] + public async Task InvokeAsync_When404WithoutSessionId_DoesNotRemoveSession() + { + var httpContext = CreateHttpContext(); + httpContext.Request.Scheme = "http"; + httpContext.Request.Path = "/mcp"; + + // Set up endpoint with MCP-related display name + var endpoint = new Endpoint( + requestDelegate: null, + metadata: new EndpointMetadataCollection(), + displayName: "POST /mcp" + ); + httpContext.SetEndpoint(endpoint); + + var invocationContext = new TestEndpointFilterInvocationContext(httpContext); + + var sessionStore = Substitute.For(); + + using var forwarder = new TestHttpForwarder + { + NextStatusCode = StatusCodes.Status404NotFound, + }; + using var httpClientFactory = new TestForwarderHttpClientFactory(); + using var server = new TestServer("http://127.0.0.1:5000"); + + var filter = CreateFilter(sessionStore, forwarder, httpClientFactory, server); + + var nextCalled = false; + var result = await filter.InvokeAsync( + invocationContext, + ctx => + { + nextCalled = true; + ctx.HttpContext.Response.StatusCode = StatusCodes.Status404NotFound; + return ValueTask.FromResult(null); + } + ); + + Assert.Null(result); + Assert.True(nextCalled); + Assert.Equal(StatusCodes.Status404NotFound, httpContext.Response.StatusCode); + + // Verify no session store operations were performed + await sessionStore.DidNotReceive().GetOrClaimOwnershipAsync( + Arg.Any(), + Arg.Any>>(), + Arg.Any()); + await sessionStore.DidNotReceive().RemoveAsync(Arg.Any(), Arg.Any()); + } + + [Fact] + public async Task InvokeAsync_WhenServerUsesHttps_PrefersHttpForServiceMesh() + { + const string sessionId = "session-https-test"; + + var httpContext = CreateHttpContext(); + httpContext.Request.Scheme = "https"; + httpContext.Request.Headers["Mcp-Session-Id"] = sessionId; + var invocationContext = new TestEndpointFilterInvocationContext(httpContext); + + // The session store will call the factory to create owner info + SessionOwnerInfo? capturedOwnerInfo = null; + + var sessionStore = Substitute.For(); + sessionStore + .GetOrClaimOwnershipAsync( + sessionId, + Arg.Any>>(), + Arg.Any()) + .Returns(callInfo => + { + var factory = callInfo.ArgAt>>(1); + var ct = callInfo.ArgAt(2); + capturedOwnerInfo = factory(ct).GetAwaiter().GetResult(); + return Task.FromResult(capturedOwnerInfo); + }); + + using var forwarder = new TestHttpForwarder(); + using var httpClientFactory = new TestForwarderHttpClientFactory(); + // Server listening on both HTTP and HTTPS + using var server = new TestServer("http://localhost:5000", "https://localhost:5001"); + + var filter = CreateFilter(sessionStore, forwarder, httpClientFactory, server); + + var nextCalled = false; + var result = await filter.InvokeAsync( + invocationContext, + ctx => + { + nextCalled = true; + return ValueTask.FromResult(null); + } + ); + + Assert.Null(result); + Assert.True(nextCalled); + Assert.NotNull(capturedOwnerInfo); + // Should prefer HTTP for internal service mesh routing + Assert.Equal("http://localhost:5000", capturedOwnerInfo.Address); + } + + [Fact] + public async Task InvokeAsync_WhenServerUsesOnlyHttps_UsesHttpsScheme() + { + const string sessionId = "session-https-only"; + + var httpContext = CreateHttpContext(); + httpContext.Request.Scheme = "https"; + httpContext.Request.Headers["Mcp-Session-Id"] = sessionId; + var invocationContext = new TestEndpointFilterInvocationContext(httpContext); + + SessionOwnerInfo? capturedOwnerInfo = null; + + var sessionStore = Substitute.For(); + sessionStore + .GetOrClaimOwnershipAsync( + sessionId, + Arg.Any>>(), + Arg.Any()) + .Returns(callInfo => + { + var factory = callInfo.ArgAt>>(1); + var ct = callInfo.ArgAt(2); + capturedOwnerInfo = factory(ct).GetAwaiter().GetResult(); + return Task.FromResult(capturedOwnerInfo); + }); + + using var forwarder = new TestHttpForwarder(); + using var httpClientFactory = new TestForwarderHttpClientFactory(); + // Server listening only on HTTPS + using var server = new TestServer("https://localhost:5001"); + + var filter = CreateFilter(sessionStore, forwarder, httpClientFactory, server); + + var nextCalled = false; + var result = await filter.InvokeAsync( + invocationContext, + ctx => + { + nextCalled = true; + return ValueTask.FromResult(null); + } + ); + + Assert.Null(result); + Assert.True(nextCalled); + Assert.NotNull(capturedOwnerInfo); + // Should use HTTPS when that's the only available scheme + Assert.Equal("https://localhost:5001", capturedOwnerInfo.Address); + } + + [Fact] + public async Task InvokeAsync_WhenLocalServerAddressConfigured_UsesConfiguredAddress() + { + const string sessionId = "session-explicit-address"; + const string configuredAddress = "http://pod-1.mcp-service.default.svc.cluster.local:8080"; + + var httpContext = CreateHttpContext(); + httpContext.Request.Scheme = "http"; + httpContext.Request.Headers["Mcp-Session-Id"] = sessionId; + var invocationContext = new TestEndpointFilterInvocationContext(httpContext); + + SessionOwnerInfo? capturedOwnerInfo = null; + + var sessionStore = Substitute.For(); + sessionStore + .GetOrClaimOwnershipAsync( + sessionId, + Arg.Any>>(), + Arg.Any()) + .Returns(callInfo => + { + var factory = callInfo.ArgAt>>(1); + var ct = callInfo.ArgAt(2); + capturedOwnerInfo = factory(ct).GetAwaiter().GetResult(); + return Task.FromResult(capturedOwnerInfo); + }); + + using var forwarder = new TestHttpForwarder(); + using var httpClientFactory = new TestForwarderHttpClientFactory(); + using var server = new TestServer("http://localhost:5000"); + + var options = new SessionAffinityOptions { LocalServerAddress = configuredAddress }; + + var filter = CreateFilter( + sessionStore, + forwarder, + httpClientFactory, + server, + options + ); + + var nextCalled = false; + var result = await filter.InvokeAsync( + invocationContext, + ctx => + { + nextCalled = true; + return ValueTask.FromResult(null); + } + ); + + Assert.Null(result); + Assert.True(nextCalled); + Assert.NotNull(capturedOwnerInfo); + // Should use the explicitly configured address, not the server binding + Assert.Equal(configuredAddress, capturedOwnerInfo.Address); + } + + [Fact] + public async Task InvokeAsync_WithStaleSessionOwnership_ReclaimsAndHandlesLocally() + { + // Simulates application restart scenario: + // Session exists in cache with same Address but different OwnerId (stale) + const string sessionId = "session123"; + const string localAddress = "http://localhost:5000"; + const string staleOwnerId = "old-guid"; + + var httpContext = CreateHttpContext(); + httpContext.Request.Scheme = "http"; + httpContext.Request.Headers["Mcp-Session-Id"] = sessionId; + + // Set up endpoint with MCP-related display name + var endpoint = new Endpoint( + requestDelegate: null, + metadata: new EndpointMetadataCollection(), + displayName: "POST /mcp/v1/sse" + ); + httpContext.SetEndpoint(endpoint); + + var invocationContext = new TestEndpointFilterInvocationContext(httpContext); + + var sessionStore = Substitute.For(); + var getOrClaimCallCount = 0; + + // First call returns stale ownership info + sessionStore + .GetOrClaimOwnershipAsync( + sessionId, + Arg.Any>>(), + Arg.Any()) + .Returns(callInfo => + { + var factory = callInfo.ArgAt>>(1); + var ct = callInfo.ArgAt(2); + getOrClaimCallCount++; + if (getOrClaimCallCount == 1) + { + // Return stale ownership with old OwnerId but same address + return Task.FromResult( + new SessionOwnerInfo + { + OwnerId = staleOwnerId, + Address = localAddress, + ClaimedAt = DateTimeOffset.UtcNow.AddMinutes(-10), + } + ); + } + else + { + // Second call after RemoveAsync - return new ownership + return factory(ct); + } + }); + + // Expect RemoveAsync to be called to clear stale entry + sessionStore + .RemoveAsync(sessionId, Arg.Any()) + .Returns(Task.CompletedTask); + + using var forwarder = new TestHttpForwarder(); + using var httpClientFactory = new TestForwarderHttpClientFactory(); + using var server = new TestServer(localAddress); + + var filter = CreateFilter(sessionStore, forwarder, httpClientFactory, server); + + var nextCalled = false; + var result = await filter.InvokeAsync( + invocationContext, + ctx => + { + nextCalled = true; + return ValueTask.FromResult(null); + } + ); + + Assert.Null(result); + Assert.True(nextCalled); + + // Verify interactions + Assert.Equal(1, getOrClaimCallCount); + await sessionStore.Received(1).RemoveAsync(sessionId, Arg.Any()); + + // Should NOT forward since we reclaimed locally + Assert.Empty(forwarder.Calls); + } + + private static SessionAffinityEndpointFilter CreateFilter( + ISessionStore sessionStore, + IHttpForwarder forwarder, + IForwarderHttpClientFactory httpClientFactory, + IServer server, + SessionAffinityOptions? options = null + ) + { + return new SessionAffinityEndpointFilter( + sessionStore, + forwarder, + httpClientFactory, + new ListeningEndpointResolver(), + server, + Options.Create(options ?? new SessionAffinityOptions()), + NullLogger.Instance + ); + } + + private static DefaultHttpContext CreateHttpContext() + { + var context = new DefaultHttpContext(); + var services = new ServiceCollection(); + services.AddLogging(); + context.RequestServices = services.BuildServiceProvider(); + return context; + } + + private sealed class TestEndpointFilterInvocationContext : EndpointFilterInvocationContext + { + private readonly List _arguments; + + public TestEndpointFilterInvocationContext( + HttpContext httpContext, + IEnumerable? arguments = null + ) + { + HttpContext = httpContext; + _arguments = arguments?.ToList() ?? []; + } + + public override HttpContext HttpContext { get; } + + public override T GetArgument(int index) + { + return (T)_arguments[index]!; + } + + public override IList Arguments => _arguments; + } + + private sealed class TestServer : IServer + { + public TestServer(params string[] addresses) + { + var addressesFeature = new TestServerAddressesFeature(addresses); + Features = new FeatureCollection(); + Features.Set(addressesFeature); + } + + public IFeatureCollection Features { get; } + + public void Dispose() + { + if (Features.Get() is TestServerAddressesFeature feature) + { + feature.Addresses.Clear(); + } + } + + public Task StartAsync( + IHttpApplication application, + CancellationToken cancellationToken + ) + where TContext : notnull + { + throw new NotSupportedException(); + } + + public Task StopAsync(CancellationToken cancellationToken) => Task.CompletedTask; + } + + private sealed class TestServerAddressesFeature : IServerAddressesFeature + { + public TestServerAddressesFeature(IEnumerable addresses) + { + foreach (var address in addresses) + { + Addresses.Add(address); + } + } + + public ICollection Addresses { get; } = []; + + public bool PreferHostingUrls { get; set; } + } + + private sealed class TestForwarderHttpClientFactory : IForwarderHttpClientFactory, IDisposable + { + private readonly HttpMessageInvoker _invoker = new(new TestHttpMessageHandler()); + + public HttpMessageInvoker CreateClient(ForwarderHttpClientContext context) => _invoker; + + public void Dispose() + { + _invoker.Dispose(); + } + + private sealed class TestHttpMessageHandler : HttpMessageHandler + { + protected override Task SendAsync( + HttpRequestMessage request, + CancellationToken cancellationToken + ) + { + throw new NotSupportedException(); + } + } + } + + private sealed class TestHttpForwarder : IHttpForwarder, IDisposable + { + private readonly List _calls = []; + + public List Calls => _calls; + + public ForwarderError NextResult { get; set; } = ForwarderError.None; + + public int NextStatusCode { get; set; } = StatusCodes.Status200OK; + + public void Dispose() + { + _calls.Clear(); + } + + // Explicit interface implementation for IHttpForwarder.SendAsync + // The 4-parameter extension method calls this internally, so we need to track the call + ValueTask IHttpForwarder.SendAsync( + HttpContext context, + string destinationPrefix, + HttpMessageInvoker httpClient, + ForwarderRequestConfig requestConfig, + HttpTransformer transformer + ) + { + _calls.Add(new ForwarderCall(context, destinationPrefix)); + + // Set the response status code if forwarder succeeds + if (NextResult == ForwarderError.None) + { + context.Response.StatusCode = NextStatusCode; + } + + return new ValueTask(NextResult); + } + + public readonly record struct ForwarderCall(HttpContext Context, string Destination); + } +} diff --git a/tests/ModelContextProtocol.AspNetCore.Distributed.Tests/SessionAffinityOptionsValidationTests.cs b/tests/ModelContextProtocol.AspNetCore.Distributed.Tests/SessionAffinityOptionsValidationTests.cs new file mode 100644 index 000000000..271d30508 --- /dev/null +++ b/tests/ModelContextProtocol.AspNetCore.Distributed.Tests/SessionAffinityOptionsValidationTests.cs @@ -0,0 +1,181 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Options; +using ModelContextProtocol.AspNetCore.Distributed.Abstractions; +using Xunit; + +namespace ModelContextProtocol.AspNetCore.Distributed.Tests; + +public class SessionAffinityOptionsValidationTests +{ + [Fact] + public void Validate_WithValidHttpUri_Succeeds() + { + // Arrange + SessionAffinityOptions options = new() { LocalServerAddress = "http://localhost:5000" }; + SessionAffinityOptionsValidator validator = new(); + + // Act + ValidateOptionsResult result = validator.Validate(null, options); + + // Assert + Assert.True(result.Succeeded); + } + + [Fact] + public void Validate_WithValidHttpsUri_Succeeds() + { + // Arrange + SessionAffinityOptions options = new() + { + LocalServerAddress = "https://server1.internal:443", + }; + SessionAffinityOptionsValidator validator = new(); + + // Act + ValidateOptionsResult result = validator.Validate(null, options); + + // Assert + Assert.True(result.Succeeded); + } + + [Fact] + public void Validate_WithNullLocalServerAddress_Succeeds() + { + // Arrange + SessionAffinityOptions options = new() { LocalServerAddress = null }; + SessionAffinityOptionsValidator validator = new(); + + // Act + ValidateOptionsResult result = validator.Validate(null, options); + + // Assert + Assert.True(result.Succeeded); + } + + [Fact] + public void Validate_WithEmptyLocalServerAddress_Succeeds() + { + // Arrange + SessionAffinityOptions options = new() { LocalServerAddress = string.Empty }; + SessionAffinityOptionsValidator validator = new(); + + // Act + ValidateOptionsResult result = validator.Validate(null, options); + + // Assert + Assert.True(result.Succeeded); + } + + [Fact] + public void Validate_WithInvalidUri_Fails() + { + // Arrange + SessionAffinityOptions options = new() { LocalServerAddress = "not a valid uri" }; + SessionAffinityOptionsValidator validator = new(); + + // Act + ValidateOptionsResult result = validator.Validate(null, options); + + // Assert + Assert.False(result.Succeeded); + Assert.True( + result.FailureMessage?.Contains("not a valid absolute URI", StringComparison.Ordinal) + ); + } + + [Fact] + public void Validate_WithRelativeUri_Fails() + { + // Arrange + SessionAffinityOptions options = new() { LocalServerAddress = "/relative/path" }; + SessionAffinityOptionsValidator validator = new(); + + // Act + ValidateOptionsResult result = validator.Validate(null, options); + + // Assert + Assert.False(result.Succeeded); + } + + [Fact] + public void Validate_WithFtpScheme_Fails() + { + // Arrange + SessionAffinityOptions options = new() { LocalServerAddress = "ftp://server:21" }; + SessionAffinityOptionsValidator validator = new(); + + // Act + ValidateOptionsResult result = validator.Validate(null, options); + + // Assert + Assert.False(result.Succeeded); + Assert.True(result.FailureMessage?.Contains("HTTP or HTTPS", StringComparison.Ordinal)); + } + + [Fact] + public void Validate_WithWsScheme_Fails() + { + // Arrange + SessionAffinityOptions options = new() { LocalServerAddress = "ws://server:8080" }; + SessionAffinityOptionsValidator validator = new(); + + // Act + ValidateOptionsResult result = validator.Validate(null, options); + + // Assert + Assert.False(result.Succeeded); + Assert.True(result.FailureMessage?.Contains("HTTP or HTTPS", StringComparison.Ordinal)); + } + + [Fact] + public void AddMcpHttpSessionAffinity_RegistersValidator() + { + // Arrange + ServiceCollection services = []; + services.AddLogging(); + services.AddHybridCache(); + + // Act + services.AddMcpHttpSessionAffinity(); + ServiceProvider provider = services.BuildServiceProvider(); + + // Assert + var validators = provider.GetServices>(); + Assert.True( + validators.Any(v => v is SessionAffinityOptionsValidator), + "SessionAffinityOptionsValidator should be registered" + ); + } + + [Fact] + public void ValidationAttribute_WithValidHttpUri_Succeeds() + { + // Arrange + var attribute = new HttpOrHttpsUriAttribute(); + var context = new System.ComponentModel.DataAnnotations.ValidationContext(new object()); + + // Act + var result = attribute.GetValidationResult("http://localhost:5000", context); + + // Assert + Assert.Equal(System.ComponentModel.DataAnnotations.ValidationResult.Success, result); + } + + [Fact] + public void ValidationAttribute_WithInvalidScheme_Fails() + { + // Arrange + var attribute = new HttpOrHttpsUriAttribute(); + var context = new System.ComponentModel.DataAnnotations.ValidationContext(new object()); + + // Act + var result = attribute.GetValidationResult("ftp://server:21", context); + + // Assert + Assert.NotNull(result); + Assert.True(result.ErrorMessage?.Contains("HTTP or HTTPS", StringComparison.Ordinal)); + } +} diff --git a/tests/ModelContextProtocol.AspNetCore.Distributed.Tests/SessionOwnerInfoSerializerTests.cs b/tests/ModelContextProtocol.AspNetCore.Distributed.Tests/SessionOwnerInfoSerializerTests.cs new file mode 100644 index 000000000..086f7d117 --- /dev/null +++ b/tests/ModelContextProtocol.AspNetCore.Distributed.Tests/SessionOwnerInfoSerializerTests.cs @@ -0,0 +1,283 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Buffers; +using System.Text; +using ModelContextProtocol.AspNetCore.Distributed.Abstractions; +using Xunit; + +namespace ModelContextProtocol.AspNetCore.Distributed.Tests; + +public class SessionOwnerInfoSerializerTests +{ + private readonly SessionOwnerInfoSerializer _serializer; + + public SessionOwnerInfoSerializerTests() + { + _serializer = new SessionOwnerInfoSerializer(); + } + + [Fact] + public void Serialize_ValidSessionOwnerInfo_WritesToBuffer() + { + // Arrange + var sessionOwner = new SessionOwnerInfo + { + OwnerId = "test-owner-123", + Address = "http://localhost:5000", + ClaimedAt = new DateTimeOffset(2025, 10, 24, 12, 0, 0, TimeSpan.Zero), + }; + var buffer = new ArrayBufferWriter(); + + // Act + _serializer.Serialize(sessionOwner, buffer); + + // Assert + var json = Encoding.UTF8.GetString(buffer.WrittenSpan); + Assert.True(json.Contains("test-owner-123", StringComparison.Ordinal)); + Assert.True(json.Contains("http://localhost:5000", StringComparison.Ordinal)); + Assert.True(json.Contains("2025-10-24", StringComparison.Ordinal)); + } + + [Fact] + public void Serialize_SessionOwnerInfoWithNullClaimedAt_WritesToBuffer() + { + // Arrange + var sessionOwner = new SessionOwnerInfo + { + OwnerId = "owner-456", + Address = "https://example.com:8080", + ClaimedAt = null, + }; + var buffer = new ArrayBufferWriter(); + + // Act + _serializer.Serialize(sessionOwner, buffer); + + // Assert + var json = Encoding.UTF8.GetString(buffer.WrittenSpan); + Assert.True(json.Contains("owner-456", StringComparison.Ordinal)); + Assert.True(json.Contains("https://example.com:8080", StringComparison.Ordinal)); + // ClaimedAt should not be present due to JsonIgnoreCondition.WhenWritingNull + Assert.False(json.Contains("claimedAt", StringComparison.Ordinal)); + } + + [Fact] + public void Deserialize_ValidJson_ReturnsSessionOwnerInfo() + { + // Arrange + var json = """ + { + "ownerId": "deserialized-owner", + "address": "http://server:3000", + "claimedAt": "2025-10-24T15:30:00Z" + } + """; + var bytes = Encoding.UTF8.GetBytes(json); + var sequence = new ReadOnlySequence(bytes); + + // Act + var result = _serializer.Deserialize(sequence); + + // Assert + Assert.NotNull(result); + Assert.Equal("deserialized-owner", result.OwnerId); + Assert.Equal("http://server:3000", result.Address); + Assert.NotNull(result.ClaimedAt); + Assert.Equal( + new DateTimeOffset(2025, 10, 24, 15, 30, 0, TimeSpan.Zero), + result.ClaimedAt + ); + } + + [Fact] + public void Deserialize_JsonWithoutClaimedAt_ReturnsSessionOwnerInfoWithNullClaimedAt() + { + // Arrange + var json = """ + { + "ownerId": "owner-without-timestamp", + "address": "http://localhost:9000" + } + """; + var bytes = Encoding.UTF8.GetBytes(json); + var sequence = new ReadOnlySequence(bytes); + + // Act + var result = _serializer.Deserialize(sequence); + + // Assert + Assert.NotNull(result); + Assert.Equal("owner-without-timestamp", result.OwnerId); + Assert.Equal("http://localhost:9000", result.Address); + Assert.Null(result.ClaimedAt); + } + + [Fact] + public void RoundTrip_SerializeAndDeserialize_PreservesData() + { + // Arrange + var original = new SessionOwnerInfo + { + OwnerId = "roundtrip-test", + Address = "https://roundtrip.example.com:443", + ClaimedAt = DateTimeOffset.UtcNow, + }; + var buffer = new ArrayBufferWriter(); + + // Act - Serialize + _serializer.Serialize(original, buffer); + + // Act - Deserialize + var sequence = new ReadOnlySequence(buffer.WrittenMemory); + var deserialized = _serializer.Deserialize(sequence); + + // Assert + Assert.NotNull(deserialized); + Assert.Equal(original.OwnerId, deserialized.OwnerId); + Assert.Equal(original.Address, deserialized.Address); + Assert.Equal(original.ClaimedAt, deserialized.ClaimedAt); + } + + [Fact] + public void RoundTrip_WithNullClaimedAt_PreservesData() + { + // Arrange + var original = new SessionOwnerInfo + { + OwnerId = "null-timestamp-test", + Address = "http://test.local", + ClaimedAt = null, + }; + var buffer = new ArrayBufferWriter(); + + // Act - Serialize + _serializer.Serialize(original, buffer); + + // Act - Deserialize + var sequence = new ReadOnlySequence(buffer.WrittenMemory); + var deserialized = _serializer.Deserialize(sequence); + + // Assert + Assert.NotNull(deserialized); + Assert.Equal(original.OwnerId, deserialized.OwnerId); + Assert.Equal(original.Address, deserialized.Address); + Assert.Null(deserialized.ClaimedAt); + } + + [Fact] + public void Serialize_UsesCamelCaseNaming() + { + // Arrange + var sessionOwner = new SessionOwnerInfo + { + OwnerId = "case-test", + Address = "http://localhost", + ClaimedAt = DateTimeOffset.UtcNow, + }; + var buffer = new ArrayBufferWriter(); + + // Act + _serializer.Serialize(sessionOwner, buffer); + + // Assert + var json = Encoding.UTF8.GetString(buffer.WrittenSpan); + // Verify camelCase naming policy is applied + Assert.True(json.Contains("\"ownerId\"", StringComparison.Ordinal)); + Assert.True(json.Contains("\"address\"", StringComparison.Ordinal)); + Assert.True(json.Contains("\"claimedAt\"", StringComparison.Ordinal)); + // Should not contain PascalCase + Assert.False(json.Contains("\"OwnerId\"", StringComparison.Ordinal)); + Assert.False(json.Contains("\"Address\"", StringComparison.Ordinal)); + Assert.False(json.Contains("\"ClaimedAt\"", StringComparison.Ordinal)); + } + + [Fact] + public void Deserialize_NullJson_ThrowsInvalidOperationException() + { + // Arrange + var json = "null"; + var bytes = Encoding.UTF8.GetBytes(json); + var sequence = new ReadOnlySequence(bytes); + + // Act & Assert + Assert.Throws(() => _serializer.Deserialize(sequence)); + } + + [Fact] + public void Deserialize_MultiSegmentBuffer_ReturnsSessionOwnerInfo() + { + // Arrange + var json = """ + { + "ownerId": "multi-segment-test", + "address": "http://multi.example.com", + "claimedAt": "2025-10-24T10:00:00Z" + } + """; + var bytes = Encoding.UTF8.GetBytes(json); + + // Create a multi-segment buffer + var segment1 = new ReadOnlyMemory(bytes, 0, bytes.Length / 2); + var segment2 = new ReadOnlyMemory( + bytes, + bytes.Length / 2, + bytes.Length - bytes.Length / 2 + ); + var sequence = CreateMultiSegmentSequence(segment1, segment2); + + // Act + var result = _serializer.Deserialize(sequence); + + // Assert + Assert.NotNull(result); + Assert.Equal("multi-segment-test", result.OwnerId); + Assert.Equal("http://multi.example.com", result.Address); + } + + [Fact] + public void Serialize_SpecialCharactersInAddress_EncodesCorrectly() + { + // Arrange + var sessionOwner = new SessionOwnerInfo + { + OwnerId = "special-chars", + Address = "http://server/path?query=value&foo=bar", + ClaimedAt = null, + }; + var buffer = new ArrayBufferWriter(); + + // Act + _serializer.Serialize(sessionOwner, buffer); + var sequence = new ReadOnlySequence(buffer.WrittenMemory); + var deserialized = _serializer.Deserialize(sequence); + + // Assert + Assert.Equal(sessionOwner.Address, deserialized.Address); + } + + private static ReadOnlySequence CreateMultiSegmentSequence( + ReadOnlyMemory segment1, + ReadOnlyMemory segment2 + ) + { + var first = new BufferSegment(segment1); + var second = first.Append(segment2); + return new ReadOnlySequence(first, 0, second, second.Memory.Length); + } + + private sealed class BufferSegment : ReadOnlySequenceSegment + { + public BufferSegment(ReadOnlyMemory memory) + { + Memory = memory; + } + + public BufferSegment Append(ReadOnlyMemory memory) + { + var segment = new BufferSegment(memory) { RunningIndex = RunningIndex + Memory.Length }; + Next = segment; + return segment; + } + } +} From 8a7a423b4bc6578962d60c134d43a7f9c1fe5e49 Mon Sep 17 00:00:00 2001 From: xiangyan99 Date: Mon, 23 Feb 2026 17:07:33 -0800 Subject: [PATCH 2/2] update readme --- README.md | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index d282e9af3..9433d63f9 100644 --- a/README.md +++ b/README.md @@ -6,12 +6,14 @@ The official C# SDK for the [Model Context Protocol](https://modelcontextprotoco ## Packages -This SDK consists of three main packages: +This SDK consists of four main packages: - **[ModelContextProtocol](https://www.nuget.org/packages/ModelContextProtocol/absoluteLatest)** [![NuGet preview version](https://img.shields.io/nuget/vpre/ModelContextProtocol.svg)](https://www.nuget.org/packages/ModelContextProtocol/absoluteLatest) - The main package with hosting and dependency injection extensions. This is the right fit for most projects that don't need HTTP server capabilities. This README serves as documentation for this package. - **[ModelContextProtocol.AspNetCore](https://www.nuget.org/packages/ModelContextProtocol.AspNetCore/absoluteLatest)** [![NuGet preview version](https://img.shields.io/nuget/vpre/ModelContextProtocol.AspNetCore.svg)](https://www.nuget.org/packages/ModelContextProtocol.AspNetCore/absoluteLatest) - The library for HTTP-based MCP servers. [Documentation](src/ModelContextProtocol.AspNetCore/README.md) +- **[ModelContextProtocol.AspNetCore.Distributed](https://www.nuget.org/packages/ModelContextProtocol.AspNetCore.Distributed/absoluteLatest)** [![NuGet preview version](https://img.shields.io/nuget/vpre/ModelContextProtocol.AspNetCore.Distributed.svg)](https://www.nuget.org/packages/ModelContextProtocol.AspNetCore.Distributed/absoluteLatest) - Session-aware routing for MCP servers running across multiple instances, built on ASP.NET Core HybridCache and YARP. [Documentation](src/ModelContextProtocol.AspNetCore.Distributed/README.md) + - **[ModelContextProtocol.Core](https://www.nuget.org/packages/ModelContextProtocol.Core/absoluteLatest)** [![NuGet preview version](https://img.shields.io/nuget/vpre/ModelContextProtocol.Core.svg)](https://www.nuget.org/packages/ModelContextProtocol.Core/absoluteLatest) - For people who only need to use the client or low-level server APIs and want the minimum number of dependencies. [Documentation](src/ModelContextProtocol.Core/README.md) > [!NOTE]