Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.function.Function;

import io.modelcontextprotocol.util.Assert;

Expand Down Expand Up @@ -47,27 +47,18 @@ private DefaultServerTransportSecurityValidator(List<String> allowedOrigins, Lis
}

@Override
public void validateHeaders(Map<String, List<String>> headers) throws ServerTransportSecurityException {
boolean missingHost = true;
for (Map.Entry<String, List<String>> entry : headers.entrySet()) {
if (ORIGIN_HEADER.equalsIgnoreCase(entry.getKey())) {
List<String> values = entry.getValue();
if (values == null || values.isEmpty()) {
throw new ServerTransportSecurityException(403, "Invalid Origin header");
}
validateOrigin(values.get(0));
}
else if (HOST_HEADER.equalsIgnoreCase(entry.getKey())) {
missingHost = false;
List<String> values = entry.getValue();
if (values == null || values.isEmpty()) {
throw new ServerTransportSecurityException(421, "Invalid Host header");
}
validateHost(values.get(0));
}
public void validateHeaders(Function<String, List<String>> headerAccessor) throws ServerTransportSecurityException {
List<String> originValues = headerAccessor.apply(ORIGIN_HEADER);
if (originValues != null && !originValues.isEmpty()) {
validateOrigin(originValues.get(0));
}
if (!allowedHosts.isEmpty() && missingHost) {
throw new ServerTransportSecurityException(421, "Invalid Host header");

if (!allowedHosts.isEmpty()) {
List<String> hostValues = headerAccessor.apply(HOST_HEADER);
if (hostValues == null || hostValues.isEmpty()) {
throw new ServerTransportSecurityException(421, "Invalid Host header");
}
validateHost(hostValues.get(0));
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,31 +6,68 @@

import java.util.List;
import java.util.Map;
import java.util.function.Function;

/**
* Interface for validating HTTP requests in server transports. Implementations can
* validate Origin headers, Host headers, or any other security-related headers according
* to the MCP specification.
*
* <p>
* New implementations should override {@link #validateHeaders(Function)
* validateHeaders(Function)} for more efficient, case-insensitive header access. The
* older {@link #validateHeaders(Map) validateHeaders(Map)} is deprecated and will be
* removed in a future major version.
*
* @author Daniel Garnier-Moiroux
* @see DefaultServerTransportSecurityValidator
* @see ServerTransportSecurityException
*/
@FunctionalInterface
public interface ServerTransportSecurityValidator {

/**
* A no-op validator that accepts all requests without validation.
*/
ServerTransportSecurityValidator NOOP = headers -> {
ServerTransportSecurityValidator NOOP = new ServerTransportSecurityValidator() {
};

/**
* Validates the HTTP headers from an incoming request.
*
* <p>
* The default implementation converts the map into a case-insensitive header accessor
* and delegates to {@link #validateHeaders(Function)}.
* @param headers A map of header names to their values (multi-valued headers
* supported)
* @throws ServerTransportSecurityException if validation fails
* @deprecated Use {@link #validateHeaders(Function)} instead for more efficient,
* case-insensitive header access. This method will be removed in a future major
* version.
*/
@Deprecated
default void validateHeaders(Map<String, List<String>> headers) throws ServerTransportSecurityException {
validateHeaders(name -> headers.entrySet()
.stream()
.filter(e -> e.getKey().equalsIgnoreCase(name))
.map(Map.Entry::getValue)
.findFirst()
.orElse(List.of()));
}

/**
* Validates the HTTP headers from an incoming request using a header accessor
* function.
*
* <p>
* New implementations should override this method. Header name lookup through the
* accessor should be case-insensitive (e.g., when backed by
* {@code HttpServletRequest.getHeaders}).
* @param headerAccessor A function that returns the list of values for a given header
* name, or an empty list if the header is not present.
* @throws ServerTransportSecurityException if validation fails
*/
void validateHeaders(Map<String, List<String>> headers) throws ServerTransportSecurityException;
default void validateHeaders(Function<String, List<String>> headerAccessor)
throws ServerTransportSecurityException {
}

}
Loading