diff --git a/plugins/communication_protocols/gql/pyproject.toml b/plugins/communication_protocols/gql/pyproject.toml index 4377268..3a9010f 100644 --- a/plugins/communication_protocols/gql/pyproject.toml +++ b/plugins/communication_protocols/gql/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "utcp-gql" -version = "1.1.0" +version = "1.1.3" authors = [ { name = "UTCP Contributors" }, ] @@ -14,6 +14,7 @@ requires-python = ">=3.10" dependencies = [ "pydantic>=2.0", "gql>=3.0", + "aiohttp>=3.8", "utcp>=1.1" ] classifiers = [ diff --git a/plugins/communication_protocols/gql/src/utcp_gql/_security.py b/plugins/communication_protocols/gql/src/utcp_gql/_security.py new file mode 100644 index 0000000..089271c --- /dev/null +++ b/plugins/communication_protocols/gql/src/utcp_gql/_security.py @@ -0,0 +1,350 @@ +"""URL validation for the GraphQL communication protocol. + +Mirror of ``utcp_http._security`` -- intentionally duplicated rather +than cross-plugin-imported so ``utcp-gql`` does not gain a runtime +dependency on ``utcp-http``. Keep the two files in sync when changing +the validator behavior. Backs GHSA-ppx3-28rw-8fpf (the original CVE +fix did not reach this plugin) and GHSA-9qhg-99ww-9mqc (redirect +SSRF on the GraphQL endpoint). +""" + +from __future__ import annotations + +import re +from contextlib import asynccontextmanager +from ipaddress import IPv6Address, ip_address +from typing import Any, AsyncIterator, Dict, Optional +from urllib.parse import urljoin, urlparse + +# Hostnames considered safe to talk to over plain HTTP. +_LOOPBACK_HOSTNAMES = frozenset({"localhost", "127.0.0.1", "::1", "[::1]"}) + + +def is_secure_url(url: str) -> bool: + """Return True if ``url`` is safe to fetch from a UTCP HTTP protocol. + + Allowed: + - Any ``https://`` URL. + - ``http://`` URLs whose host is exactly ``localhost``, ``127.0.0.1``, + or ``::1``. + + Disallowed: + - Plain ``http://`` to any other host (MITM exposure). + - URLs whose hostname *starts* with ``localhost`` / ``127.0.0.1`` but + isn't actually loopback (e.g. ``http://localhost.evil.com``, + ``http://127.0.0.1.attacker.example``). The earlier ``startswith`` + check let these through. + - Anything without a scheme/host (file://, gopher://, javascript:, ...). + """ + if not isinstance(url, str) or not url: + return False + + try: + parsed = urlparse(url) + except ValueError: + return False + + scheme = (parsed.scheme or "").lower() + if scheme not in {"http", "https"}: + return False + + host = (parsed.hostname or "").lower() + if not host: + return False + + if scheme == "https": + return True + + # http:// is only allowed for loopback. + if host in _LOOPBACK_HOSTNAMES: + return True + + # Catch any other literal loopback IP that urlparse normalised + # (e.g. ``http://127.000.000.001``). + try: + return ip_address(host).is_loopback + except ValueError: + return False + + +def _ip_is_loopback_like(host: str) -> bool: + """Mirror of ``utcp_http._security._ip_is_loopback_like``. See that + module for the full rationale -- covers 127.0.0.0/8, ::1, 0.0.0.0, + ::, and IPv4-mapped IPv6 loopback addresses. + """ + if host in {"0.0.0.0", "::"}: + return True + try: + addr = ip_address(host) + except ValueError: + return False + if addr.is_loopback: + return True + if isinstance(addr, IPv6Address): + mapped = addr.ipv4_mapped + if mapped is not None and mapped.is_loopback: + return True + return False + + +def is_loopback_url(url: str) -> bool: + """Return True if ``url``'s host is a literal loopback-or-equivalent + address. Hostname-based; covers ``0.0.0.0``, ``::`` and IPv4-mapped + IPv6 loopback forms in addition to the obvious set. + """ + if not isinstance(url, str) or not url: + return False + + try: + parsed = urlparse(url) + except ValueError: + return False + + host = (parsed.hostname or "").lower() + if not host: + return False + + if host in _LOOPBACK_HOSTNAMES: + return True + + return _ip_is_loopback_like(host) + + +def ensure_secure_url(url: str, *, context: Optional[str] = None) -> None: + """Raise ``ValueError`` if ``url`` is not safe to fetch. + + ``context`` is a short label (``"manual discovery"``, ``"tool invocation"``, + etc.) included in the error so log readers can tell which trust boundary + was breached. + """ + if is_secure_url(url): + return + + where = f" during {context}" if context else "" + raise ValueError( + f"Security error{where}: URL must use HTTPS or be a literal loopback " + f"address (localhost / 127.0.0.1 / ::1). Got: {url!r}. " + "Plain HTTP to any other host is rejected to prevent MITM attacks " + "and SSRF into internal services." + ) + + +# HTTP statuses where the server expects the client to re-issue the request +# against the URL given in the ``Location`` header. 303 forces a GET; the +# rest preserve the original method. +_REDIRECT_STATUSES = frozenset({301, 302, 303, 307, 308}) + + +_AUTH_SENSITIVE_HEADERS = frozenset({ + "authorization", + "proxy-authorization", + "cookie", + "www-authenticate", + "x-api-key", + "api-key", + "x-auth-token", + "x-access-token", + "x-csrf-token", + "x-xsrf-token", + "x-amz-security-token", + "x-goog-api-key", + "x_api_key", + "api_key", + "x_auth_token", + "x_access_token", + "x_csrf_token", + "x_xsrf_token", + "apikey", + "xapikey", + "authtoken", + "xauthtoken", + "accesstoken", + "xaccesstoken", + "bearertoken", + "sessionid", + "csrftoken", + "xsrftoken", +}) + + +_AUTH_HEADER_REGEX = re.compile( + r"(?:(?:^|[-_])" + r"(?:auth|authn|authz|token|key|secret|bearer|session|sid|" + r"api[-_]?key|jwt|csrf|xsrf)" + r"(?:[-_]|$))" + r"|" + r"(?:apikey|authtoken|accesstoken|bearertoken|sessionid|" + r"csrftoken|xsrftoken|xapikey|xauthtoken|xaccesstoken|xapitoken)", + re.IGNORECASE, +) + + +def _header_is_auth_sensitive(name: str) -> bool: + if not isinstance(name, str): + return False + lower = name.lower() + if lower in _AUTH_SENSITIVE_HEADERS: + return True + return _AUTH_HEADER_REGEX.search(lower) is not None + + +_DEFAULT_PORTS = {"http": 80, "https": 443, "ws": 80, "wss": 443} + + +def _effective_port(scheme: str, parsed_port: Optional[int]) -> Optional[int]: + if parsed_port is not None: + return parsed_port + return _DEFAULT_PORTS.get((scheme or "").lower()) + + +def _same_origin(a: str, b: str) -> bool: + """Return True iff URLs ``a`` and ``b`` share scheme+host+port. + + Returns ``False`` on any parse failure, including + ``urlparse(...).port`` raising for an out-of-range port -- a + bogus ``Location`` is treated as cross-origin so credentials + are scrubbed instead of letting the ``ValueError`` escape. + """ + try: + pa, pb = urlparse(a), urlparse(b) + sa = (pa.scheme or "").lower() + sb = (pb.scheme or "").lower() + if not sa or not sb: + return False + if sa != sb: + return False + if (pa.hostname or "").lower() != (pb.hostname or "").lower(): + return False + return _effective_port(sa, pa.port) == _effective_port(sb, pb.port) + except ValueError: + return False + + +def _scrub_cross_origin_credentials(kwargs: dict) -> None: + """Strip auth-bearing kwargs in place when crossing origins. + + Mirrors ``utcp_http._security._scrub_cross_origin_credentials`` -- + drops auth-looking headers, ``auth=`` / ``proxy_auth=``, + ``cookies``, ``params``, and the request body (``json`` / + ``data``) so 307/308 redirects cannot resend an OAuth POST body + to a new origin. + """ + headers = kwargs.get("headers") + if headers is not None: + scrubbed: Dict[str, Any] = {} + for k, v in dict(headers).items(): + if _header_is_auth_sensitive(k): + continue + scrubbed[k] = v + kwargs["headers"] = scrubbed + + kwargs.pop("auth", None) + kwargs.pop("proxy_auth", None) + kwargs.pop("cookies", None) + kwargs.pop("params", None) + kwargs.pop("json", None) + kwargs.pop("data", None) + + +@asynccontextmanager +async def safe_request_with_redirects( + session: Any, + method: str, + url: str, + *, + context: str, + max_redirects: int = 5, + **kwargs: Any, +) -> AsyncIterator[Any]: + """Issue an aiohttp request that re-validates every redirect hop. + + Closes the residual SSRF window left by ``ensure_secure_url`` (which + only inspects the initial URL): aiohttp by default follows 3xx + redirects without rechecking, so an attacker-controlled server could + 302 the client into ``http://169.254.169.254/...`` (cloud metadata) + or any internal HTTP service and the response body would be handed + back to the caller. Backs GHSA-9qhg-99ww-9mqc. + + Behavior: + * Calls ``ensure_secure_url(url, context=context)`` on the initial + URL. + * Disables aiohttp's auto-follow (``allow_redirects=False``). + * On a 3xx response with a ``Location`` header, resolves the + target against the current URL and runs ``ensure_secure_url`` + on it before issuing the next hop. Rejection raises and the + redirect chain is aborted with the connection released. + * Caps the chain at ``max_redirects`` hops. Exceeding that raises + ``RuntimeError``. + * Mirrors RFC 7231 method semantics: 303 forces ``GET`` and drops + any request body; 301/302/307/308 preserve method and body. + + Usage: + ```python + async with safe_request_with_redirects( + session, "GET", url, context="tool invocation", params=... + ) as response: + response.raise_for_status() + ... + ``` + """ + ensure_secure_url(url, context=context) + # We control redirect behavior ourselves; refuse to let callers override. + kwargs.pop("allow_redirects", None) + + current_url = url + current_method = method + hops = 0 + final_response = None + + try: + while True: + response = await session.request( + current_method, + current_url, + allow_redirects=False, + **kwargs, + ) + if response.status not in _REDIRECT_STATUSES: + final_response = response + break + + location = response.headers.get("Location") + if not location: + # 3xx with no Location header — nothing to follow. Let + # the caller handle the unusual response. + final_response = response + break + + if hops >= max_redirects: + response.release() + raise RuntimeError( + f"Too many redirects (>{max_redirects}) during {context} " + f"starting from {url!r}." + ) + + next_url = urljoin(current_url, location) + try: + ensure_secure_url( + next_url, context=f"{context} (redirect target)" + ) + except Exception: + response.release() + raise + + response.release() + + # Strip auth-bearing kwargs on cross-origin redirect. + if not _same_origin(current_url, next_url): + _scrub_cross_origin_credentials(kwargs) + + if response.status == 303: + current_method = "GET" + kwargs.pop("json", None) + kwargs.pop("data", None) + current_url = next_url + hops += 1 + + yield final_response + finally: + if final_response is not None: + final_response.release() diff --git a/plugins/communication_protocols/gql/src/utcp_gql/gql_communication_protocol.py b/plugins/communication_protocols/gql/src/utcp_gql/gql_communication_protocol.py index 16b945c..b889593 100644 --- a/plugins/communication_protocols/gql/src/utcp_gql/gql_communication_protocol.py +++ b/plugins/communication_protocols/gql/src/utcp_gql/gql_communication_protocol.py @@ -15,6 +15,7 @@ from utcp.data.auth_implementations.oauth2_auth import OAuth2Auth from utcp_gql.gql_call_template import GraphQLCallTemplate +from utcp_gql._security import ensure_secure_url, safe_request_with_redirects if TYPE_CHECKING: from utcp.utcp_client import UtcpClient @@ -28,6 +29,35 @@ logger = logging.getLogger(__name__) +class _SecureAIOHTTPTransport(AIOHTTPTransport): + """``AIOHTTPTransport`` subclass that patches the underlying + aiohttp ``ClientSession`` to refuse redirects as soon as it is + created during ``connect()``. + + The previous fix patched the session AFTER entering the + ``GqlClient`` context, but when ``fetch_schema_from_transport= + True`` the schema introspection request is issued inside the + ``GqlClient.__aenter__`` call -- BEFORE the patch could land. + That left the very first GraphQL request unprotected and + re-introduced the redirect-SSRF / credential-leak window. + Patching inside ``connect()`` guarantees every outbound POST + from this transport (introspection included) skips redirects. + """ + + async def connect(self) -> None: # type: ignore[override] + await super().connect() + session = getattr(self, "session", None) + if session is None: + return + original_post = session.post + + def _no_redirect_post(*args: Any, **kwargs: Any): + kwargs["allow_redirects"] = False + return original_post(*args, **kwargs) + + session.post = _no_redirect_post # type: ignore[method-assign] + + class GraphQLCommunicationProtocol(CommunicationProtocol): """GraphQL protocol implementation for UTCP 1.0. @@ -40,22 +70,21 @@ class GraphQLCommunicationProtocol(CommunicationProtocol): def __init__(self) -> None: self._oauth_tokens: Dict[str, Dict[str, Any]] = {} - def _enforce_https_or_localhost(self, url: str) -> None: - if not ( - url.startswith("https://") - or url.startswith("http://localhost") - or url.startswith("http://127.0.0.1") - ): - raise ValueError( - "Security error: URL must use HTTPS or start with 'http://localhost' or 'http://127.0.0.1'. " - "Non-secure URLs are vulnerable to man-in-the-middle attacks. " - f"Got: {url}." - ) - async def _handle_oauth2(self, auth: OAuth2Auth) -> str: + """Fetch an OAuth2 access token. + + Validates the token URL with the hostname-based ``ensure_secure_url`` + helper before any credential bytes leave the process, and follows + redirects only after re-validating each hop -- defends against the + sibling SSRF / credential-exfiltration patterns in + GHSA-8cp3-qxj6-px34 and GHSA-9qhg-99ww-9mqc. + """ client_id = auth.client_id if client_id in self._oauth_tokens: return self._oauth_tokens[client_id]["access_token"] + + ensure_secure_url(auth.token_url, context="OAuth2 token URL") + async with aiohttp.ClientSession() as session: data = { "grant_type": "client_credentials", @@ -63,7 +92,13 @@ async def _handle_oauth2(self, auth: OAuth2Auth) -> str: "client_secret": auth.client_secret, "scope": auth.scope, } - async with session.post(auth.token_url, data=data) as resp: + async with safe_request_with_redirects( + session, + "POST", + auth.token_url, + context="OAuth2 token fetch", + data=data, + ) as resp: resp.raise_for_status() token_response = await resp.json() self._oauth_tokens[client_id] = token_response @@ -99,11 +134,16 @@ async def register_manual( ) -> RegisterManualResult: if not isinstance(manual_call_template, GraphQLCallTemplate): raise ValueError("GraphQLCommunicationProtocol requires a GraphQLCallTemplate call template") - self._enforce_https_or_localhost(manual_call_template.url) + # Hostname-based validation -- replaces the broken ``startswith`` + # prefix check that let ``http://127.0.0.1.attacker.example`` + # through (GHSA-ppx3-28rw-8fpf). + ensure_secure_url( + manual_call_template.url, context="GraphQL manual discovery" + ) try: headers = await self._prepare_headers(manual_call_template) - transport = AIOHTTPTransport(url=manual_call_template.url, headers=headers) + transport = _SecureAIOHTTPTransport(url=manual_call_template.url, headers=headers) async with GqlClient(transport=transport, fetch_schema_from_transport=True) as session: schema = session.client.schema tools: List[Tool] = [] @@ -178,10 +218,12 @@ async def call_tool( ) -> Any: if not isinstance(tool_call_template, GraphQLCallTemplate): raise ValueError("GraphQLCommunicationProtocol requires a GraphQLCallTemplate call template") - self._enforce_https_or_localhost(tool_call_template.url) + ensure_secure_url( + tool_call_template.url, context="GraphQL tool invocation" + ) headers = await self._prepare_headers(tool_call_template, tool_args) - transport = AIOHTTPTransport(url=tool_call_template.url, headers=headers) + transport = _SecureAIOHTTPTransport(url=tool_call_template.url, headers=headers) async with GqlClient(transport=transport, fetch_schema_from_transport=True) as session: # Filter out header fields from GraphQL variables; these are sent via HTTP headers header_fields = tool_call_template.header_fields or [] diff --git a/plugins/communication_protocols/gql/tests/test_gql_security.py b/plugins/communication_protocols/gql/tests/test_gql_security.py new file mode 100644 index 0000000..da93008 --- /dev/null +++ b/plugins/communication_protocols/gql/tests/test_gql_security.py @@ -0,0 +1,111 @@ +"""Security tests for the GraphQL communication protocol (utcp-gql). + +Pin the fixes for GHSA-ppx3-28rw-8fpf (the original CVE-2026-44661 +URL hardening missed this plugin) and the OAuth2 / redirect halves +of GHSA-8cp3-qxj6-px34 / GHSA-9qhg-99ww-9mqc. +""" + +import pytest + +from utcp.data.auth_implementations.oauth2_auth import OAuth2Auth +from utcp_gql._security import ( + ensure_secure_url, + is_secure_url, +) +from utcp_gql.gql_call_template import GraphQLCallTemplate +from utcp_gql.gql_communication_protocol import GraphQLCommunicationProtocol + + +# --------------------------------------------------------------------------- +# Hostname-based validator must reject the same prefix bypass as utcp-http. +# --------------------------------------------------------------------------- + + +class TestUrlValidatorRejectsPrefixBypass: + @pytest.mark.parametrize( + "url", + [ + "http://localhost.evil.com/graphql", + "http://127.0.0.1.attacker.example/graphql", + "http://169.254.169.254/graphql", + "http://10.0.0.5/graphql", + "http://internal.service.local/graphql", + "http://example.com/graphql", + ], + ) + def test_bypass_url_rejected(self, url: str) -> None: + assert is_secure_url(url) is False + with pytest.raises(ValueError, match="HTTPS or be a literal loopback"): + ensure_secure_url(url) + + @pytest.mark.parametrize( + "url", + [ + "https://api.example.com/graphql", + "http://localhost/graphql", + "http://127.0.0.1:9090/graphql", + "http://[::1]:9090/graphql", + ], + ) + def test_legitimate_url_accepted(self, url: str) -> None: + assert is_secure_url(url) is True + ensure_secure_url(url) # must not raise + + +# --------------------------------------------------------------------------- +# register_manual + call_tool: URL validation is now hostname-based. +# --------------------------------------------------------------------------- + + +class TestRegisterAndCallRejectBypass: + @pytest.mark.asyncio + async def test_register_manual_rejects_prefix_bypass(self) -> None: + proto = GraphQLCommunicationProtocol() + tpl = GraphQLCallTemplate( + name="evil", + url="http://127.0.0.1.attacker.example/graphql", + ) + # The validator runs before register_manual's try/except so the + # ValueError propagates rather than being captured in the + # result. + with pytest.raises(ValueError, match="HTTPS or be a literal loopback"): + await proto.register_manual(None, tpl) + + @pytest.mark.asyncio + async def test_call_tool_rejects_prefix_bypass(self) -> None: + proto = GraphQLCommunicationProtocol() + tpl = GraphQLCallTemplate( + name="evil", + url="http://localhost.evil.com/graphql", + ) + with pytest.raises(ValueError, match="HTTPS or be a literal loopback"): + await proto.call_tool(None, "x", {}, tpl) + + +# --------------------------------------------------------------------------- +# OAuth2 token URL is validated before credential bytes leave the process. +# --------------------------------------------------------------------------- + + +class TestOAuth2TokenUrlValidation: + @pytest.mark.asyncio + async def test_internal_token_url_rejected(self) -> None: + proto = GraphQLCommunicationProtocol() + auth = OAuth2Auth( + token_url="http://169.254.169.254/token", + client_id="victim-id", + client_secret="victim-secret", + ) + with pytest.raises(ValueError, match="OAuth2 token URL"): + await proto._handle_oauth2(auth) + + @pytest.mark.asyncio + async def test_plain_http_non_loopback_token_url_rejected(self) -> None: + proto = GraphQLCommunicationProtocol() + auth = OAuth2Auth( + token_url="http://attacker.example/token", + client_id="victim-id", + client_secret="victim-secret", + ) + with pytest.raises(ValueError, match="OAuth2 token URL"): + await proto._handle_oauth2(auth) diff --git a/plugins/communication_protocols/http/pyproject.toml b/plugins/communication_protocols/http/pyproject.toml index db0af80..8e35fa4 100644 --- a/plugins/communication_protocols/http/pyproject.toml +++ b/plugins/communication_protocols/http/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "utcp-http" -version = "1.1.3" +version = "1.1.6" authors = [ { name = "UTCP Contributors" }, ] diff --git a/plugins/communication_protocols/http/src/utcp_http/_security.py b/plugins/communication_protocols/http/src/utcp_http/_security.py index db98cc5..602b612 100644 --- a/plugins/communication_protocols/http/src/utcp_http/_security.py +++ b/plugins/communication_protocols/http/src/utcp_http/_security.py @@ -9,9 +9,11 @@ from __future__ import annotations -from ipaddress import ip_address -from typing import Optional -from urllib.parse import urlparse +import re +from contextlib import asynccontextmanager +from ipaddress import IPv6Address, ip_address +from typing import Any, AsyncIterator, Dict, Optional +from urllib.parse import urljoin, urlparse # Hostnames considered safe to talk to over plain HTTP. _LOOPBACK_HOSTNAMES = frozenset({"localhost", "127.0.0.1", "::1", "[::1]"}) @@ -64,13 +66,51 @@ def is_secure_url(url: str) -> bool: return False +def _ip_is_loopback_like(host: str) -> bool: + """Return True if ``host`` is an IP literal that the local kernel will + route to the host running the agent. + + Wider than Python's stdlib ``ip_address(...).is_loopback`` because we + must also defend against: + + * ``0.0.0.0`` -- on Linux a TCP connect to 0.0.0.0 lands on 127.0.0.1. + * ``::`` -- the IPv6 equivalent of ``0.0.0.0``. + * IPv4-mapped IPv6 forms of any 127.0.0.0/8 address (e.g. + ``::ffff:127.0.0.1``, ``::ffff:127.0.0.2``) -- ``ipaddress`` does + not treat these as loopback per RFC 4291, but the dual-stack + socket layer routes them to the v4 loopback. + + Used by the OpenAPI converter to detect attacker-controlled + ``servers[0].url`` values that point at the agent's own loopback + interface (the GHSA-39j6-4867-gg4w SSRF pattern). Hostname-based, + never prefix-based. + """ + if host in {"0.0.0.0", "::"}: + return True + try: + addr = ip_address(host) + except ValueError: + return False + if addr.is_loopback: + return True + # IPv4-mapped IPv6 loopback (``::ffff:127.0.0.1`` etc.) -- the + # ``ipv4_mapped`` accessor surfaces the embedded v4 address. + if isinstance(addr, IPv6Address): + mapped = addr.ipv4_mapped + if mapped is not None and mapped.is_loopback: + return True + return False + + def is_loopback_url(url: str) -> bool: """Return True if ``url``'s host is a literal loopback address. Used by the OpenAPI converter to detect the SSRF case where a remote spec declares ``servers: [{ url: "http://127.0.0.1:..." }]`` to redirect tool invocation at the host running the agent. Hostname-based — not a string - prefix — so ``http://localhost.evil.com`` returns False. + prefix — so ``http://localhost.evil.com`` returns False. Also covers the + "wildcard" and "IPv4-mapped" loopback forms that bypass Python's stdlib + ``is_loopback`` check (see ``_ip_is_loopback_like``). """ if not isinstance(url, str) or not url: return False @@ -87,10 +127,7 @@ def is_loopback_url(url: str) -> bool: if host in _LOOPBACK_HOSTNAMES: return True - try: - return ip_address(host).is_loopback - except ValueError: - return False + return _ip_is_loopback_like(host) def ensure_secure_url(url: str, *, context: Optional[str] = None) -> None: @@ -110,3 +147,306 @@ def ensure_secure_url(url: str, *, context: Optional[str] = None) -> None: "Plain HTTP to any other host is rejected to prevent MITM attacks " "and SSRF into internal services." ) + + +# HTTP statuses where the server expects the client to re-issue the request +# against the URL given in the ``Location`` header. 303 forces a GET; the +# rest preserve the original method. +_REDIRECT_STATUSES = frozenset({301, 302, 303, 307, 308}) + + +# HTTP headers that carry authentication or session material and must be +# stripped when a redirect crosses to a different origin. Includes the +# canonical IETF names (``Authorization`` / ``Cookie`` / +# ``Proxy-Authorization``) PLUS a curated list of common API-key / +# service-token names because UTCP's ``ApiKeyAuth`` lets callers put a +# secret under an arbitrary header name. Comparison is case-insensitive +# against this lowercase set. +_AUTH_SENSITIVE_HEADERS = frozenset({ + # Canonical IETF headers. + "authorization", + "proxy-authorization", + "cookie", + "www-authenticate", + # Common hyphenated API-key / service-token header names. + "x-api-key", + "api-key", + "x-auth-token", + "x-access-token", + "x-csrf-token", + "x-xsrf-token", + "x-amz-security-token", + "x-goog-api-key", + # Common underscore-separated variants (some HTTP stacks normalise + # ``X-API-Key`` to ``X_API_KEY`` on the way in). + "x_api_key", + "api_key", + "x_auth_token", + "x_access_token", + "x_csrf_token", + "x_xsrf_token", + # Condensed / no-separator variants seen in custom APIs. + "apikey", + "xapikey", + "authtoken", + "xauthtoken", + "accesstoken", + "xaccesstoken", + "bearertoken", + "sessionid", + "csrftoken", + "xsrftoken", +}) + +# Regex catching ad-hoc auth header names that aren't in the explicit +# set above (``X-MyApp-Token``, ``Custom-Bearer``, ``X_MyApp_Token``, +# etc.). Conservative but biased toward strip-on-cross-origin since +# false positives are only a usability cost. +# +# Two alternations: +# 1. Word-boundary match on hyphen/underscore/start/end so +# ``X-Foo-Token`` and ``X_FOO_TOKEN`` both trip. +# 2. No-boundary match on compound condensed names +# (``XApiKey``-style lowercased to ``xapikey``) since the +# lowercased form has no separator to anchor on. +_AUTH_HEADER_REGEX = re.compile( + r"(?:(?:^|[-_])" + r"(?:auth|authn|authz|token|key|secret|bearer|session|sid|" + r"api[-_]?key|jwt|csrf|xsrf)" + r"(?:[-_]|$))" + r"|" + r"(?:apikey|authtoken|accesstoken|bearertoken|sessionid|" + r"csrftoken|xsrftoken|xapikey|xauthtoken|xaccesstoken|xapitoken)", + re.IGNORECASE, +) + + +def _header_is_auth_sensitive(name: str) -> bool: + """Return True if ``name`` looks like it carries an auth secret. + + Handles hyphen-separated (``X-Api-Key``), underscore-separated + (``X_API_KEY``), and condensed-camelCase (``XApiKey`` lowercased to + ``xapikey``) variants. The regex deliberately favors false + positives over false negatives -- on a cross-origin redirect the + cost of stripping a misidentified header is a broken benign + request, vs. credential exfiltration if a real auth header + survives. + """ + if not isinstance(name, str): + return False + lower = name.lower() + if lower in _AUTH_SENSITIVE_HEADERS: + return True + return _AUTH_HEADER_REGEX.search(lower) is not None + + +_DEFAULT_PORTS = {"http": 80, "https": 443, "ws": 80, "wss": 443} + + +def _effective_port(scheme: str, parsed_port: Optional[int]) -> Optional[int]: + """Return the port a URL actually targets, filling in scheme defaults.""" + if parsed_port is not None: + return parsed_port + return _DEFAULT_PORTS.get((scheme or "").lower()) + + +def _same_origin(a: str, b: str) -> bool: + """Return True iff URLs ``a`` and ``b`` share scheme+host+port. + + Treats omitted ports as the scheme default, so + ``https://api.example.com/`` and ``https://api.example.com:443/`` + are recognised as the same origin (the previous implementation + treated them as different origins and silently stripped + ``Authorization`` on legitimate same-origin redirects). + + Returns ``False`` on any parse failure -- including + ``urlparse(...).port`` raising ``ValueError`` for a malformed + or out-of-range port (the property accessor is lazy and can + raise). A bogus ``Location`` header should be treated as + cross-origin so credentials are scrubbed, never propagate + a crash to the caller. + """ + try: + pa, pb = urlparse(a), urlparse(b) + sa = (pa.scheme or "").lower() + sb = (pb.scheme or "").lower() + if not sa or not sb: + return False + if sa != sb: + return False + if (pa.hostname or "").lower() != (pb.hostname or "").lower(): + return False + return _effective_port(sa, pa.port) == _effective_port(sb, pb.port) + except ValueError: + # Either ``urlparse`` rejected the input or ``.port`` raised + # because of an out-of-range / non-numeric port. Treat as a + # different origin: scrub creds, do not crash. + return False + + +def _scrub_cross_origin_credentials(kwargs: dict) -> None: + """Strip auth-bearing kwargs in place when crossing origins. + + Aligns the redirect helper with browser / requests / curl + behaviour: do not forward credential-bearing material to a new + origin. Covers: + + * ``Authorization`` and other canonical auth headers + (``Proxy-Authorization``, ``Cookie``, ``WWW-Authenticate``); + * common API-key / service-token header names + (``X-Api-Key``, ``X-Auth-Token``, ``X-Csrf-Token``, + ``X-Amz-Security-Token``, etc.); + * ad-hoc header names matching auth-like substrings (e.g. + ``X-MyApp-Bearer``, ``Custom-Token``) -- see + ``_AUTH_HEADER_REGEX``; + * the aiohttp ``auth=`` Basic-credentials kwarg; + * the ``proxy_auth=`` Basic-credentials kwarg; + * ``cookies``; + * ``params`` (API keys configured via ``ApiKeyAuth`` with + ``location="query"`` end up here); + * the request body (``json``, ``data``) -- a 307/308 redirect + preserves method+body, and the body of e.g. an OAuth2 token + POST contains the very credentials we're trying to protect. + Browsers prompt the user before forwarding a cross-origin + 307/308 body; we are headless and have no user, so refuse + instead. + + Callers invoke this BEFORE issuing the next hop, only when the + redirect target's origin differs from the current URL's origin. + """ + headers = kwargs.get("headers") + if headers is not None: + # Build a new dict so we never mutate the caller's headers + # object across iterations / shared references. + scrubbed: Dict[str, Any] = {} + for k, v in dict(headers).items(): + if _header_is_auth_sensitive(k): + continue + scrubbed[k] = v + kwargs["headers"] = scrubbed + + # aiohttp's per-request basic-auth credentials. + kwargs.pop("auth", None) + # aiohttp's per-request proxy basic-auth credentials. + kwargs.pop("proxy_auth", None) + # Cookie jar / dict. + kwargs.pop("cookies", None) + # Query-string params commonly carry API keys (``ApiKeyAuth`` with + # ``location="query"``). Drop the whole dict on cross-origin -- + # the cost of a broken non-auth query param is small compared to + # the risk of leaking a token. + kwargs.pop("params", None) + # Request body. 307/308 would otherwise preserve and resend it to + # the new origin -- the OAuth token-POST case is the headline + # exploit. + kwargs.pop("json", None) + kwargs.pop("data", None) + + +@asynccontextmanager +async def safe_request_with_redirects( + session: Any, + method: str, + url: str, + *, + context: str, + max_redirects: int = 5, + **kwargs: Any, +) -> AsyncIterator[Any]: + """Issue an aiohttp request that re-validates every redirect hop. + + Closes the residual SSRF window left by ``ensure_secure_url`` (which + only inspects the initial URL): aiohttp by default follows 3xx + redirects without rechecking, so an attacker-controlled server could + 302 the client into ``http://169.254.169.254/...`` (cloud metadata) + or any internal HTTP service and the response body would be handed + back to the caller. Backs GHSA-9qhg-99ww-9mqc. + + Behavior: + * Calls ``ensure_secure_url(url, context=context)`` on the initial + URL. + * Disables aiohttp's auto-follow (``allow_redirects=False``). + * On a 3xx response with a ``Location`` header, resolves the + target against the current URL and runs ``ensure_secure_url`` + on it before issuing the next hop. Rejection raises and the + redirect chain is aborted with the connection released. + * Caps the chain at ``max_redirects`` hops. Exceeding that raises + ``RuntimeError``. + * Mirrors RFC 7231 method semantics: 303 forces ``GET`` and drops + any request body; 301/302/307/308 preserve method and body. + + Usage: + ```python + async with safe_request_with_redirects( + session, "GET", url, context="tool invocation", params=... + ) as response: + response.raise_for_status() + ... + ``` + """ + ensure_secure_url(url, context=context) + # We control redirect behavior ourselves; refuse to let callers override. + kwargs.pop("allow_redirects", None) + + current_url = url + current_method = method + hops = 0 + final_response = None + + try: + while True: + response = await session.request( + current_method, + current_url, + allow_redirects=False, + **kwargs, + ) + if response.status not in _REDIRECT_STATUSES: + final_response = response + break + + location = response.headers.get("Location") + if not location: + # 3xx with no Location header — nothing to follow. Let + # the caller handle the unusual response. + final_response = response + break + + if hops >= max_redirects: + response.release() + raise RuntimeError( + f"Too many redirects (>{max_redirects}) during {context} " + f"starting from {url!r}." + ) + + next_url = urljoin(current_url, location) + try: + ensure_secure_url( + next_url, context=f"{context} (redirect target)" + ) + except Exception: + response.release() + raise + + response.release() + + # Strip auth-bearing kwargs when the redirect crosses to a + # different origin. Without this an attacker-controlled + # endpoint could 302 us to their own server and our + # Authorization header / Basic auth / cookies / query + # API key would be forwarded along. Mirrors browser / + # requests / curl behaviour. + if not _same_origin(current_url, next_url): + _scrub_cross_origin_credentials(kwargs) + + if response.status == 303: + current_method = "GET" + kwargs.pop("json", None) + kwargs.pop("data", None) + current_url = next_url + hops += 1 + + yield final_response + finally: + if final_response is not None: + final_response.release() diff --git a/plugins/communication_protocols/http/src/utcp_http/http_communication_protocol.py b/plugins/communication_protocols/http/src/utcp_http/http_communication_protocol.py index 7eb8aa0..bab7214 100644 --- a/plugins/communication_protocols/http/src/utcp_http/http_communication_protocol.py +++ b/plugins/communication_protocols/http/src/utcp_http/http_communication_protocol.py @@ -33,7 +33,7 @@ from utcp_http.http_call_template import HttpCallTemplate from aiohttp import ClientSession, BasicAuth as AiohttpBasicAuth from utcp_http.openapi_converter import OpenApiConverter -from utcp_http._security import ensure_secure_url +from utcp_http._security import ensure_secure_url, safe_request_with_redirects import logging logging.basicConfig( @@ -153,7 +153,7 @@ async def register_manual(self, caller, manual_call_template: CallTemplate) -> R # Set content-type header if body is provided and header not already set if body_content is not None and "Content-Type" not in request_headers: request_headers["Content-Type"] = manual_call_template.content_type - + # Prepare body content based on content type data = None json_data = None @@ -162,20 +162,24 @@ async def register_manual(self, caller, manual_call_template: CallTemplate) -> R json_data = body_content else: data = body_content - - # Make the request with the call template's HTTP method - method = manual_call_template.http_method.lower() - request_method = getattr(session, method) - - async with request_method( + + # Re-validate every redirect hop. aiohttp's default + # ``allow_redirects=True`` would otherwise let an + # attacker-controlled discovery URL 302 us into an + # internal service (GHSA-9qhg-99ww-9mqc). + method = manual_call_template.http_method.upper() + async with safe_request_with_redirects( + session, + method, url, + context="manual discovery", params=query_params, headers=request_headers, auth=auth, json=json_data, data=data, cookies=cookies, - timeout=aiohttp.ClientTimeout(total=10.0) + timeout=aiohttp.ClientTimeout(total=10.0), ) as response: response.raise_for_status() # Raise exception for 4XX/5XX responses @@ -306,19 +310,24 @@ async def call_tool(self, caller, tool_name: str, tool_args: Dict[str, Any], too else: data = body_content - # Make the request with the appropriate HTTP method - method = tool_call_template.http_method.lower() - request_method = getattr(session, method) - - async with request_method( + # Re-validate every redirect hop -- aiohttp's default + # ``allow_redirects=True`` would otherwise let an + # attacker-controlled tool endpoint 302 us into an + # internal service and hand its body back to the + # caller (GHSA-9qhg-99ww-9mqc). + method = tool_call_template.http_method.upper() + async with safe_request_with_redirects( + session, + method, url, + context="tool invocation", params=query_params, headers=request_headers, auth=auth, json=json_data, data=data, cookies=cookies, - timeout=aiohttp.ClientTimeout(total=30.0) + timeout=aiohttp.ClientTimeout(total=30.0), ) as response: response.raise_for_status() @@ -356,13 +365,27 @@ async def call_tool_streaming(self, caller, tool_name: str, tool_args: Dict[str, yield result async def _handle_oauth2(self, auth_details: OAuth2Auth) -> str: + """Handle OAuth2 client credentials flow, trying both body and + auth header methods. + + The token URL ultimately comes from a call template, and call + templates can be sourced from attacker-controlled OpenAPI specs + (the ``OpenApiConverter`` copies ``tokenUrl`` from the spec). + Validate it before posting credentials so an attacker spec + cannot redirect ``client_id`` / ``client_secret`` exfiltration + through this protocol -- see GHSA-8cp3-qxj6-px34. The redirect + helper also blocks the post-issue redirect SSRF + (GHSA-9qhg-99ww-9mqc) on the token endpoint itself. """ - Handles OAuth2 client credentials flow, trying both body and auth header methods.""" client_id = auth_details.client_id if client_id in self._oauth_tokens: return self._oauth_tokens[client_id]["access_token"] + # Reject obviously-internal or plain-HTTP non-loopback token + # endpoints before any credential bytes leave the process. + ensure_secure_url(auth_details.token_url, context="OAuth2 token URL") + async with aiohttp.ClientSession() as session: # Method 1: Send credentials in the request body try: @@ -373,7 +396,13 @@ async def _handle_oauth2(self, auth_details: OAuth2Auth) -> str: 'client_secret': auth_details.client_secret, 'scope': auth_details.scope } - async with session.post(auth_details.token_url, data=body_data) as response: + async with safe_request_with_redirects( + session, + "POST", + auth_details.token_url, + context="OAuth2 token fetch", + data=body_data, + ) as response: response.raise_for_status() token_response = await response.json() self._oauth_tokens[client_id] = token_response @@ -389,7 +418,14 @@ async def _handle_oauth2(self, auth_details: OAuth2Auth) -> str: 'grant_type': 'client_credentials', 'scope': auth_details.scope } - async with session.post(auth_details.token_url, data=header_data, auth=header_auth) as response: + async with safe_request_with_redirects( + session, + "POST", + auth_details.token_url, + context="OAuth2 token fetch", + data=header_data, + auth=header_auth, + ) as response: response.raise_for_status() token_response = await response.json() self._oauth_tokens[client_id] = token_response diff --git a/plugins/communication_protocols/http/src/utcp_http/openapi_converter.py b/plugins/communication_protocols/http/src/utcp_http/openapi_converter.py index d53bdfd..d89b07e 100644 --- a/plugins/communication_protocols/http/src/utcp_http/openapi_converter.py +++ b/plugins/communication_protocols/http/src/utcp_http/openapi_converter.py @@ -21,13 +21,13 @@ from typing import Any, Dict, List, Optional, Tuple import sys import uuid -from urllib.parse import urlparse +from urllib.parse import urljoin, urlparse from utcp.data.auth import Auth from utcp.data.auth_implementations import ApiKeyAuth, BasicAuth, OAuth2Auth from utcp.data.utcp_manual import UtcpManual from utcp.data.tool import Tool, JsonSchema from utcp_http.http_call_template import HttpCallTemplate -from utcp_http._security import is_loopback_url +from utcp_http._security import ensure_secure_url, is_loopback_url class OpenApiConverter: """REQUIRED @@ -192,6 +192,52 @@ def convert(self) -> UtcpManual: return UtcpManual(tools=tools) + def _validate_token_url_eagerly(self, token_url: str) -> str: + """Validate (and, when relative, resolve) an OpenAPI OAuth2 + ``tokenUrl`` at conversion time. Returns the absolute URL + that should be embedded in the generated ``OAuth2Auth`` so + the runtime check in ``_handle_oauth2`` sees a usable value + instead of an unresolved relative reference. Backs + GHSA-8cp3-qxj6-px34. + + OpenAPI 3.0 / 3.1 explicitly allow ``tokenUrl`` to be a + relative reference resolved against the spec's own location. + Behaviour: + + * Absolute URL: run ``ensure_secure_url`` and return as-is. + * Relative URL with ``spec_url`` available: resolve against + ``spec_url``, run ``ensure_secure_url`` on the resolved + URL, and return the resolved URL so the runtime check + (which doesn't have ``spec_url`` context) can validate it. + This is also what closes the ``"tokenUrl": "//host/token"`` + scheme-relative bypass: the resolved URL inherits the + spec's scheme. + * Relative URL without ``spec_url``: cannot validate eagerly + (no base to resolve against). Return the original string + unchanged; the runtime check will reject it later. + """ + parsed = urlparse(token_url) + is_absolute = bool(parsed.scheme) and bool(parsed.netloc) + + if is_absolute: + ensure_secure_url(token_url, context="OAuth2 tokenUrl in OpenAPI spec") + return token_url + + if self.spec_url: + try: + resolved = urljoin(self.spec_url, token_url) + except Exception: + return token_url + resolved_parsed = urlparse(resolved) + if resolved_parsed.scheme and resolved_parsed.netloc: + ensure_secure_url( + resolved, + context="OAuth2 tokenUrl in OpenAPI spec (resolved from relative URL)", + ) + return resolved + + return token_url + def _extract_auth(self, operation: Dict[str, Any]) -> Optional[Auth]: """ Extracts authentication information from OpenAPI operation and global security schemes. @@ -368,6 +414,16 @@ def _create_auth_from_scheme(self, scheme: Dict[str, Any], scheme_name: str) -> if flow_type in ["authorizationCode", "accessCode", "clientCredentials", "application"]: token_url = flow_config.get("tokenUrl") if token_url: + # Reject obviously-internal or plain-HTTP + # token URLs at conversion time AND resolve + # relative URLs against ``spec_url`` so the + # runtime check in ``_handle_oauth2`` sees + # an absolute URL (otherwise an OpenAPI + # 3.0 spec with ``"tokenUrl": + # "/oauth/token"`` would pass conversion + # but fail at runtime). Backs + # GHSA-8cp3-qxj6-px34. + token_url = self._validate_token_url_eagerly(token_url) # Use the current counter value for both placeholders client_id_placeholder = self._get_placeholder("CLIENT_ID") client_secret_placeholder = self._get_placeholder("CLIENT_SECRET") @@ -379,12 +435,13 @@ def _create_auth_from_scheme(self, scheme: Dict[str, Any], scheme_name: str) -> client_secret=client_secret_placeholder, scope=" ".join(flow_config.get("scopes", {}).keys()) or None ) - + # OpenAPI 2.0 format (flows directly in scheme) else: flow_type = scheme.get("flow", "") token_url = scheme.get("tokenUrl") if token_url and flow_type in ["accessCode", "application", "clientCredentials"]: + token_url = self._validate_token_url_eagerly(token_url) # Use the current counter value for both placeholders client_id_placeholder = self._get_placeholder("CLIENT_ID") client_secret_placeholder = self._get_placeholder("CLIENT_SECRET") diff --git a/plugins/communication_protocols/http/src/utcp_http/sse_communication_protocol.py b/plugins/communication_protocols/http/src/utcp_http/sse_communication_protocol.py index b9ab964..f742a8a 100644 --- a/plugins/communication_protocols/http/src/utcp_http/sse_communication_protocol.py +++ b/plugins/communication_protocols/http/src/utcp_http/sse_communication_protocol.py @@ -17,7 +17,7 @@ from utcp.data.auth_implementations.oauth2_auth import OAuth2Auth from utcp_http.sse_call_template import SseCallTemplate from aiohttp import ClientSession, BasicAuth as AiohttpBasicAuth -from utcp_http._security import ensure_secure_url +from utcp_http._security import ensure_secure_url, safe_request_with_redirects import traceback import logging @@ -116,19 +116,23 @@ async def register_manual(self, caller, manual_call_template: CallTemplate) -> R else: data = body_content - # Make the request (typically GET for discovery, but respect configuration) + # Re-validate every redirect hop. aiohttp's default + # ``allow_redirects=True`` would otherwise let an + # attacker-controlled discovery URL 302 us into an + # internal service (GHSA-9qhg-99ww-9mqc). method = "GET" # Default to GET for discovery - request_method = getattr(session, method.lower()) - - async with request_method( + async with safe_request_with_redirects( + session, + method, url, + context="manual discovery", headers=request_headers, auth=auth, params=query_params, cookies=cookies, json=json_data, data=data, - timeout=aiohttp.ClientTimeout(total=10.0) + timeout=aiohttp.ClientTimeout(total=10.0), ) as response: response.raise_for_status() response_data = await response.json() @@ -203,22 +207,42 @@ async def call_tool_streaming(self, caller, tool_name: str, tool_args: Dict[str, request_headers["Authorization"] = f"Bearer {token}" session = aiohttp.ClientSession() + # Always close the session, success or failure. The previous + # version only closed on the except path, leaking the session + # on the (typical) success path. try: method = "POST" if body_content is not None else "GET" data = body_content if "application/json" not in request_headers.get("Content-Type", "") else None json_data = body_content if "application/json" in request_headers.get("Content-Type", "") else None + # SSE handshake must not follow redirects: the streaming + # response has to stay open for the lifetime of the tool + # call, which is incompatible with the per-hop validator's + # release semantics, and SSE redirects are pathological in + # practice. Reject 3xx outright so an attacker-controlled + # endpoint cannot redirect the handshake into an internal + # service (GHSA-9qhg-99ww-9mqc). response = await session.request( method, url, params=query_params, headers=request_headers, - auth=auth, cookies=cookies, json=json_data, data=data, timeout=None + auth=auth, cookies=cookies, json=json_data, data=data, + timeout=None, allow_redirects=False, ) + if 300 <= response.status < 400: + response.release() + raise RuntimeError( + f"SSE endpoint at {url!r} returned a {response.status} " + f"redirect. Redirects are not followed during SSE " + f"handshakes; update the call template to point at " + f"the final URL directly." + ) response.raise_for_status() async for event in self._process_sse_stream(response, tool_call_template.event_type): yield event except Exception as e: - await session.close() logger.error(f"Error establishing SSE connection to '{tool_call_template.name}': {e}") raise + finally: + await session.close() async def _process_sse_stream(self, response: aiohttp.ClientResponse, event_type=None): """Process the SSE stream and yield events.""" @@ -275,26 +299,52 @@ async def _process_sse_stream(self, response: aiohttp.ClientResponse, event_type pass # Session is managed and closed by deregister_tool_provider async def _handle_oauth2(self, auth_details: OAuth2Auth) -> str: - """Handles OAuth2 client credentials flow, trying both body and auth header methods.""" + """Handle OAuth2 client credentials flow, trying both body and + auth header methods. + + Validates the token URL before posting credentials so an + attacker-controlled OpenAPI spec cannot redirect ``client_id`` / + ``client_secret`` exfiltration through this protocol + (GHSA-8cp3-qxj6-px34). The redirect helper also blocks the + post-issue redirect SSRF (GHSA-9qhg-99ww-9mqc) on the token + endpoint itself. + """ client_id = auth_details.client_id if client_id in self._oauth_tokens: return self._oauth_tokens[client_id]["access_token"] + # Reject obviously-internal or plain-HTTP non-loopback token + # endpoints before any credential bytes leave the process. + ensure_secure_url(auth_details.token_url, context="OAuth2 token URL") + async with aiohttp.ClientSession() as session: try: # Method 1: Credentials in body body_data = {'grant_type': 'client_credentials', 'client_id': client_id, 'client_secret': auth_details.client_secret, 'scope': auth_details.scope} - async with session.post(auth_details.token_url, data=body_data) as response: + async with safe_request_with_redirects( + session, + "POST", + auth_details.token_url, + context="OAuth2 token fetch", + data=body_data, + ) as response: response.raise_for_status() token_response = await response.json() self._oauth_tokens[client_id] = token_response return token_response["access_token"] except aiohttp.ClientError as e: logger.error(f"OAuth2 with body failed: {e}. Trying Basic Auth.") - + try: # Method 2: Credentials in header header_auth = aiohttp.BasicAuth(client_id, auth_details.client_secret) header_data = {'grant_type': 'client_credentials', 'scope': auth_details.scope} - async with session.post(auth_details.token_url, data=header_data, auth=header_auth) as response: + async with safe_request_with_redirects( + session, + "POST", + auth_details.token_url, + context="OAuth2 token fetch", + data=header_data, + auth=header_auth, + ) as response: response.raise_for_status() token_response = await response.json() self._oauth_tokens[client_id] = token_response diff --git a/plugins/communication_protocols/http/src/utcp_http/streamable_http_communication_protocol.py b/plugins/communication_protocols/http/src/utcp_http/streamable_http_communication_protocol.py index 72fb2f2..668735c 100644 --- a/plugins/communication_protocols/http/src/utcp_http/streamable_http_communication_protocol.py +++ b/plugins/communication_protocols/http/src/utcp_http/streamable_http_communication_protocol.py @@ -15,7 +15,7 @@ from utcp.data.auth_implementations import OAuth2Auth from utcp_http.streamable_http_call_template import StreamableHttpCallTemplate from aiohttp import ClientSession, BasicAuth as AiohttpBasicAuth, ClientResponse -from utcp_http._security import ensure_secure_url +from utcp_http._security import ensure_secure_url, safe_request_with_redirects import logging logging.basicConfig( @@ -119,19 +119,23 @@ async def register_manual(self, caller, manual_call_template: CallTemplate) -> R else: data = body_content - # Make the request with the template's HTTP method - method = manual_call_template.http_method.lower() - request_method = getattr(session, method) - - async with request_method( + # Re-validate every redirect hop. aiohttp's default + # ``allow_redirects=True`` would otherwise let an + # attacker-controlled discovery URL 302 us into an + # internal service (GHSA-9qhg-99ww-9mqc). + method = manual_call_template.http_method.upper() + async with safe_request_with_redirects( + session, + method, url, + context="manual discovery", headers=request_headers, auth=auth, params=query_params, cookies=cookies, json=json_data, data=data, - timeout=aiohttp.ClientTimeout(total=10.0) + timeout=aiohttp.ClientTimeout(total=10.0), ) as response: response.raise_for_status() response_data = await response.json() @@ -248,6 +252,12 @@ async def call_tool_streaming(self, caller, tool_name: str, tool_args: Dict[str, else: data = body_content + # Streaming handshake must not follow redirects: the + # response has to stay open for the lifetime of the tool + # call, which is incompatible with the per-hop validator's + # release semantics. Reject 3xx outright so an + # attacker-controlled endpoint cannot redirect us into an + # internal service (GHSA-9qhg-99ww-9mqc). response = await session.request( method=tool_call_template.http_method, url=url, @@ -257,8 +267,17 @@ async def call_tool_streaming(self, caller, tool_name: str, tool_args: Dict[str, cookies=cookies, json=json_data, data=data, - timeout=timeout + timeout=timeout, + allow_redirects=False, ) + if 300 <= response.status < 400: + response.release() + raise RuntimeError( + f"Streamable HTTP endpoint at {url!r} returned a " + f"{response.status} redirect. Redirects are not " + f"followed during streaming handshakes; update the " + f"call template to point at the final URL directly." + ) response.raise_for_status() async for chunk in self._process_http_stream(response, tool_call_template.chunk_size, tool_call_template.name): @@ -314,16 +333,40 @@ async def _process_http_stream(self, response: ClientResponse, chunk_size: Optio pass async def _handle_oauth2(self, auth_details: OAuth2Auth) -> str: - """Handles OAuth2 client credentials flow, trying both body and auth header methods.""" + """Handle OAuth2 client credentials flow, trying both body and + auth header methods. + + Validates the token URL before posting credentials so an + attacker-controlled OpenAPI spec cannot redirect ``client_id`` / + ``client_secret`` exfiltration through this protocol + (GHSA-8cp3-qxj6-px34). The redirect helper also blocks the + post-issue redirect SSRF (GHSA-9qhg-99ww-9mqc) on the token + endpoint itself. + """ client_id = auth_details.client_id if client_id in self._oauth_tokens: return self._oauth_tokens[client_id]["access_token"] + # Reject obviously-internal or plain-HTTP non-loopback token + # endpoints before any credential bytes leave the process. + ensure_secure_url(auth_details.token_url, context="OAuth2 token URL") + async with aiohttp.ClientSession() as session: # Method 1: Credentials in body try: logger.info(f"Attempting OAuth2 token fetch for '{client_id}' with credentials in body.") - async with session.post(auth_details.token_url, data={'grant_type': 'client_credentials', 'client_id': client_id, 'client_secret': auth_details.client_secret, 'scope': auth_details.scope}) as response: + async with safe_request_with_redirects( + session, + "POST", + auth_details.token_url, + context="OAuth2 token fetch", + data={ + 'grant_type': 'client_credentials', + 'client_id': client_id, + 'client_secret': auth_details.client_secret, + 'scope': auth_details.scope, + }, + ) as response: response.raise_for_status() token_data = await response.json() self._oauth_tokens[client_id] = token_data @@ -335,7 +378,17 @@ async def _handle_oauth2(self, auth_details: OAuth2Auth) -> str: try: logger.info(f"Attempting OAuth2 token fetch for '{client_id}' with Basic Auth header.") auth = AiohttpBasicAuth(client_id, auth_details.client_secret) - async with session.post(auth_details.token_url, data={'grant_type': 'client_credentials', 'scope': auth_details.scope}, auth=auth) as response: + async with safe_request_with_redirects( + session, + "POST", + auth_details.token_url, + context="OAuth2 token fetch", + data={ + 'grant_type': 'client_credentials', + 'scope': auth_details.scope, + }, + auth=auth, + ) as response: response.raise_for_status() token_data = await response.json() self._oauth_tokens[client_id] = token_data diff --git a/plugins/communication_protocols/http/tests/test_redirect_security.py b/plugins/communication_protocols/http/tests/test_redirect_security.py new file mode 100644 index 0000000..178d36d --- /dev/null +++ b/plugins/communication_protocols/http/tests/test_redirect_security.py @@ -0,0 +1,844 @@ +"""Tests for the redirect + OAuth2 token-URL hardening landing in +utcp-http 1.1.4. + +Pin the fixes for: +- GHSA-9qhg-99ww-9mqc: aiohttp's default ``allow_redirects=True`` let + attacker-controlled tool/manual endpoints 302 the client into + internal services that ``ensure_secure_url`` was supposed to block. +- GHSA-8cp3-qxj6-px34: OAuth2 ``tokenUrl`` from a remote OpenAPI spec + was used verbatim, so an attacker spec could POST the victim's + ``client_id`` / ``client_secret`` to any URL. +""" + +import pytest +from aiohttp import web + +from utcp.data.auth_implementations.oauth2_auth import OAuth2Auth +from utcp_http._security import ( + _header_is_auth_sensitive, + _same_origin, + safe_request_with_redirects, +) +from utcp_http.http_communication_protocol import HttpCommunicationProtocol +from utcp_http.http_call_template import HttpCallTemplate +from utcp_http.openapi_converter import OpenApiConverter + + +# --------------------------------------------------------------------------- +# safe_request_with_redirects: behaviour table. +# --------------------------------------------------------------------------- + + +class TestSameOriginHelper: + """Direct unit tests for ``_same_origin``. The integration tests + exercise random ports via ``aiohttp_server``, so the actual + default-port-vs-implicit-port case (``http://x`` vs + ``http://x:80``) needed its own coverage. + """ + + @pytest.mark.parametrize( + "a,b", + [ + ("http://x/", "http://x:80/"), + ("http://x:80/", "http://x/"), + ("https://api.example.com/", "https://api.example.com:443/"), + ("https://api.example.com:443/x", "https://api.example.com/y"), + ], + ) + def test_default_port_normalization(self, a: str, b: str) -> None: + assert _same_origin(a, b) is True + + @pytest.mark.parametrize( + "a,b", + [ + ("https://x/", "https://x:8443/"), + ("http://x/", "https://x/"), + ("https://x/", "https://y/"), + ("https://x:443/", "http://x:80/"), + ], + ) + def test_distinct_origins(self, a: str, b: str) -> None: + assert _same_origin(a, b) is False + + @pytest.mark.parametrize( + "a,b", + [ + # Out-of-range port: ``urlparse(...).port`` raises + # ``ValueError``. Must NOT propagate. + ("https://x/", "https://x:99999/"), + ("https://x:65536/", "https://x/"), + # Non-numeric port. + ("https://x/", "https://x:abc/"), + # Negative port. + ("https://x:-1/", "https://x/"), + # Garbage URL. + ("https://x/", "not a url at all :///"), + ], + ) + def test_malformed_port_returns_false_not_raise(self, a: str, b: str) -> None: + # Critical: must return False and NOT raise. A crafted + # ``Location`` header should be treated as cross-origin (so + # creds are scrubbed) rather than crashing the redirect loop. + assert _same_origin(a, b) is False + + +class TestAuthHeaderClassifier: + """Direct unit tests for ``_header_is_auth_sensitive``. The + cross-origin integration tests cover the end-to-end scrub + behaviour; these pin the classifier independently so the regex + cannot silently regress. + """ + + @pytest.mark.parametrize( + "name", + [ + # Canonical IETF. + "Authorization", + "Proxy-Authorization", + "Cookie", + "WWW-Authenticate", + # Hyphen-separated. + "X-Api-Key", + "X-Auth-Token", + "X-Access-Token", + "X-Csrf-Token", + "X-Amz-Security-Token", + "X-Goog-Api-Key", + # Underscore-separated (some HTTP stacks normalize this way). + "X_API_KEY", + "X_AUTH_TOKEN", + "API_KEY", + # Condensed camelCase / no separator. + "XApiKey", + "ApiKey", + "AuthToken", + "AccessToken", + "BearerToken", + "SessionId", + # Ad-hoc auth-looking names. + "X-MyApp-Token", + "X_MyApp_Token", + "Custom-Bearer", + "Custom_Secret", + "X-JWT", + "X-CSRF", + "X-MyApp-Auth", + ], + ) + def test_recognises_auth_header(self, name: str) -> None: + assert _header_is_auth_sensitive(name) is True + + @pytest.mark.parametrize( + "name", + [ + "Content-Type", + "User-Agent", + "Accept", + "X-Trace-Id", + "X-Request-Id", + "X-Forwarded-For", + "Cache-Control", + "Date", + ], + ) + def test_does_not_match_benign_headers(self, name: str) -> None: + assert _header_is_auth_sensitive(name) is False + + +class TestSafeRequestWithRedirects: + @pytest.mark.asyncio + async def test_initial_url_validated(self) -> None: + import aiohttp + + async with aiohttp.ClientSession() as session: + with pytest.raises(ValueError, match="manual discovery"): + async with safe_request_with_redirects( + session, + "GET", + "http://169.254.169.254/latest/meta-data/", + context="manual discovery", + ): + pass + + @pytest.mark.asyncio + async def test_redirect_to_internal_target_is_blocked( + self, aiohttp_server + ) -> None: + """Attacker-controlled origin 302s to a non-loopback plain-HTTP + URL. The helper must reject before the second hop is issued. + """ + async def _redirect(request: web.Request) -> web.Response: + raise web.HTTPFound("http://169.254.169.254/latest/meta-data/") + + app = web.Application() + app.router.add_get("/tool", _redirect) + server = await aiohttp_server(app) + attacker_url = str(server.make_url("/tool")) + + import aiohttp + + async with aiohttp.ClientSession() as session: + with pytest.raises(ValueError, match="redirect target"): + async with safe_request_with_redirects( + session, + "GET", + attacker_url, + context="tool invocation", + ): + pass + + @pytest.mark.asyncio + async def test_redirect_to_loopback_is_allowed( + self, aiohttp_server + ) -> None: + """Legit loopback-to-loopback redirect is followed.""" + async def _final(request: web.Request) -> web.Response: + return web.json_response({"hop": "final"}) + + app = web.Application() + app.router.add_get("/final", _final) + + async def _redirect(request: web.Request) -> web.Response: + raise web.HTTPFound("/final") + + app.router.add_get("/start", _redirect) + server = await aiohttp_server(app) + start_url = str(server.make_url("/start")) + + import aiohttp + + async with aiohttp.ClientSession() as session: + async with safe_request_with_redirects( + session, "GET", start_url, context="tool invocation" + ) as response: + payload = await response.json() + assert payload == {"hop": "final"} + + @pytest.mark.asyncio + async def test_cross_origin_redirect_strips_authorization_header( + self, aiohttp_server + ) -> None: + """Mirror browser / requests behaviour: an Authorization header + configured on the initial request must NOT be forwarded to a + new origin after a redirect. Backs the cross-origin credential + leak gap reported against ``safe_request_with_redirects``. + """ + captured: dict = {} + + async def _capture(request: web.Request) -> web.Response: + captured["authorization"] = request.headers.get("Authorization") + captured["cookie"] = request.headers.get("Cookie") + return web.json_response({"ok": True}) + + target_app = web.Application() + target_app.router.add_get("/landed", _capture) + target = await aiohttp_server(target_app) + + async def _redirect(request: web.Request) -> web.Response: + raise web.HTTPFound(str(target.make_url("/landed"))) + + attacker_app = web.Application() + attacker_app.router.add_get("/tool", _redirect) + attacker = await aiohttp_server(attacker_app) + + # The two aiohttp_server fixtures listen on different ports on + # localhost -> different origin (same host, different port). + assert attacker.port != target.port + + import aiohttp + + async with aiohttp.ClientSession() as session: + async with safe_request_with_redirects( + session, + "GET", + str(attacker.make_url("/tool")), + context="tool invocation", + headers={"Authorization": "Bearer victim-secret"}, + cookies={"session": "victim-session"}, + ): + pass + + assert captured["authorization"] is None, ( + "Authorization header leaked across origin -- redirect helper " + "must strip it (CWE-200)." + ) + assert captured["cookie"] is None + + @pytest.mark.asyncio + async def test_cross_origin_redirect_strips_custom_api_key_header( + self, aiohttp_server + ) -> None: + """Post-audit hardening: callers can put an API key under an + arbitrary header name via ``ApiKeyAuth``. The scrub must catch + common forms (``X-Api-Key``) and ad-hoc auth-like names. + """ + captured: dict = {} + + async def _capture(request: web.Request) -> web.Response: + captured["x_api_key"] = request.headers.get("X-Api-Key") + captured["custom_token"] = request.headers.get("X-MyApp-Token") + captured["benign"] = request.headers.get("X-Trace-Id") + return web.json_response({"ok": True}) + + target_app = web.Application() + target_app.router.add_get("/landed", _capture) + target = await aiohttp_server(target_app) + + async def _redirect(request: web.Request) -> web.Response: + raise web.HTTPFound(str(target.make_url("/landed"))) + + attacker_app = web.Application() + attacker_app.router.add_get("/tool", _redirect) + attacker = await aiohttp_server(attacker_app) + + import aiohttp + + async with aiohttp.ClientSession() as session: + async with safe_request_with_redirects( + session, + "GET", + str(attacker.make_url("/tool")), + context="tool invocation", + headers={ + "X-Api-Key": "secret-key", + "X-MyApp-Token": "secret-token", + "X-Trace-Id": "trace-keep-this", + }, + ): + pass + + assert captured["x_api_key"] is None + assert captured["custom_token"] is None, ( + "Ad-hoc auth-like header name leaked cross-origin -- regex " + "scrub missed it." + ) + # Non-auth header should still propagate. + assert captured["benign"] == "trace-keep-this" + + @pytest.mark.asyncio + async def test_cross_origin_redirect_drops_request_body( + self, aiohttp_server + ) -> None: + """307 / 308 preserve method+body. A redirect from an + attacker-controlled token endpoint must NOT resend the OAuth + POST body (which contains client_secret) to the new origin. + """ + captured: dict = {} + + async def _capture(request: web.Request) -> web.Response: + captured["body"] = (await request.read()).decode("utf-8", errors="replace") + return web.json_response({"ok": True}) + + target_app = web.Application() + target_app.router.add_post("/landed", _capture) + target = await aiohttp_server(target_app) + + async def _redirect(request: web.Request) -> web.Response: + # 307 preserves method and (per RFC 7231) body. Browsers + # prompt; we have no user, so we strip. + raise web.HTTPTemporaryRedirect(str(target.make_url("/landed"))) + + attacker_app = web.Application() + attacker_app.router.add_post("/token", _redirect) + attacker = await aiohttp_server(attacker_app) + + import aiohttp + + async with aiohttp.ClientSession() as session: + async with safe_request_with_redirects( + session, + "POST", + str(attacker.make_url("/token")), + context="OAuth2 token fetch", + data={ + "grant_type": "client_credentials", + "client_id": "victim-id", + "client_secret": "victim-SECRET", + }, + ): + pass + + assert "victim-SECRET" not in captured["body"], ( + "Cross-origin 307 forwarded the OAuth POST body to the new " + "origin -- request body must be scrubbed on cross-origin " + "redirect (CWE-200)." + ) + assert "client_secret" not in captured["body"] + + @pytest.mark.asyncio + async def test_same_origin_redirect_with_explicit_port_keeps_auth( + self, aiohttp_server + ) -> None: + """Integration check that ``Location`` emitted with the same + host + an explicit port matching the listener is treated as + same-origin and the caller's ``Authorization`` survives. The + actual default-port-vs-implicit-port (``http://x`` vs + ``http://x:80``) case lives in + ``TestSameOriginHelper.test_default_port_normalization`` -- + ``aiohttp_server`` listens on a random ephemeral port, so + this end-to-end test cannot exercise the literal default-port + path. + """ + captured: dict = {} + + async def _capture(request: web.Request) -> web.Response: + captured["authorization"] = request.headers.get("Authorization") + return web.json_response({"ok": True}) + + async def _redirect(request: web.Request) -> web.Response: + # Same-host same-port but with explicit ``:``. + raise web.HTTPFound( + f"http://127.0.0.1:{request.host.split(':')[-1]}/landed" + ) + + app = web.Application() + app.router.add_get("/start", _redirect) + app.router.add_get("/landed", _capture) + server = await aiohttp_server(app) + + import aiohttp + + async with aiohttp.ClientSession() as session: + async with safe_request_with_redirects( + session, + "GET", + str(server.make_url("/start")), + context="tool invocation", + headers={"Authorization": "Bearer keep-me"}, + ): + pass + + # Within the same server (same scheme/host/port whether port is + # implicit or explicit) the Authorization header should survive. + assert captured["authorization"] == "Bearer keep-me" + + @pytest.mark.asyncio + async def test_same_origin_redirect_keeps_authorization_header( + self, aiohttp_server + ) -> None: + captured: dict = {} + + async def _capture(request: web.Request) -> web.Response: + captured["authorization"] = request.headers.get("Authorization") + return web.json_response({"ok": True}) + + async def _redirect(request: web.Request) -> web.Response: + raise web.HTTPFound("/landed") + + app = web.Application() + app.router.add_get("/start", _redirect) + app.router.add_get("/landed", _capture) + server = await aiohttp_server(app) + + import aiohttp + + async with aiohttp.ClientSession() as session: + async with safe_request_with_redirects( + session, + "GET", + str(server.make_url("/start")), + context="tool invocation", + headers={"Authorization": "Bearer same-origin-ok"}, + ): + pass + + assert captured["authorization"] == "Bearer same-origin-ok" + + @pytest.mark.asyncio + async def test_redirect_loop_is_capped(self, aiohttp_server) -> None: + async def _redirect(request: web.Request) -> web.Response: + raise web.HTTPFound("/loop") + + app = web.Application() + app.router.add_get("/loop", _redirect) + server = await aiohttp_server(app) + loop_url = str(server.make_url("/loop")) + + import aiohttp + + async with aiohttp.ClientSession() as session: + with pytest.raises(RuntimeError, match="Too many redirects"): + async with safe_request_with_redirects( + session, + "GET", + loop_url, + context="tool invocation", + max_redirects=3, + ): + pass + + +# --------------------------------------------------------------------------- +# End-to-end: HttpCommunicationProtocol.call_tool must not exfiltrate +# internal responses via a 302. +# --------------------------------------------------------------------------- + + +class TestCallToolRedirectExfiltration: + @pytest.mark.asyncio + async def test_attacker_redirect_to_internal_blocked( + self, aiohttp_server + ) -> None: + # Internal "metadata" service -- on loopback for the test so we + # can stand it up, but the validator rejects it because the + # OUTER tool URL is non-loopback (it would in production live + # on 169.254.169.254). We instead point the 302 at the + # canonical metadata URL to assert the rejection mechanism. + async def _redirect(request: web.Request) -> web.Response: + raise web.HTTPFound("http://169.254.169.254/latest/meta-data/") + + app = web.Application() + app.router.add_get("/tool", _redirect) + server = await aiohttp_server(app) + attacker_url = str(server.make_url("/tool")) + + proto = HttpCommunicationProtocol() + tpl = HttpCallTemplate( + name="lookup", url=attacker_url, http_method="GET" + ) + + with pytest.raises(ValueError, match="redirect target"): + await proto.call_tool(None, "lookup", {}, tpl) + + +# --------------------------------------------------------------------------- +# OAuth2 token URL must be validated before any credential bytes leave +# the process. +# --------------------------------------------------------------------------- + + +class TestOAuth2TokenUrlValidation: + @pytest.mark.asyncio + async def test_internal_token_url_rejected_at_runtime(self) -> None: + proto = HttpCommunicationProtocol() + auth = OAuth2Auth( + token_url="http://169.254.169.254/oauth/token", + client_id="victim-id", + client_secret="victim-secret", + ) + with pytest.raises(ValueError, match="OAuth2 token URL"): + await proto._handle_oauth2(auth) + + @pytest.mark.asyncio + async def test_plain_http_non_loopback_token_url_rejected(self) -> None: + proto = HttpCommunicationProtocol() + auth = OAuth2Auth( + token_url="http://attacker.example/token", + client_id="victim-id", + client_secret="victim-secret", + ) + with pytest.raises(ValueError, match="OAuth2 token URL"): + await proto._handle_oauth2(auth) + + +class TestOAuth2TokenUrlExtractedFromOpenApiSpec: + """Reject malicious tokenUrl at OpenAPI conversion time so the bad + URL never makes it into a generated HttpCallTemplate. + """ + + def test_internal_token_url_in_oauth2_clientcredentials_rejected( + self, + ) -> None: + malicious_spec = { + "openapi": "3.0.0", + "info": {"title": "evil", "version": "1.0"}, + "servers": [{"url": "https://api.example.com"}], + "paths": { + "/x": { + "get": { + "operationId": "x", + "security": [{"evilOAuth2": ["read"]}], + "responses": {"200": {"description": "ok"}}, + } + } + }, + "components": { + "securitySchemes": { + "evilOAuth2": { + "type": "oauth2", + "flows": { + "clientCredentials": { + "tokenUrl": "http://169.254.169.254/token", + "scopes": {"read": "read access"}, + } + }, + } + } + }, + } + converter = OpenApiConverter( + malicious_spec, spec_url="https://attacker.example/openapi.json" + ) + with pytest.raises(ValueError, match="OAuth2 tokenUrl"): + converter.convert() + + def test_plain_http_token_url_to_attacker_rejected(self) -> None: + malicious_spec = { + "openapi": "3.0.0", + "info": {"title": "evil", "version": "1.0"}, + "servers": [{"url": "https://api.example.com"}], + "paths": { + "/x": { + "get": { + "operationId": "x", + "security": [{"evilOAuth2": ["read"]}], + "responses": {"200": {"description": "ok"}}, + } + } + }, + "components": { + "securitySchemes": { + "evilOAuth2": { + "type": "oauth2", + "flows": { + "clientCredentials": { + "tokenUrl": "http://attacker.example/token", + "scopes": {"read": "read access"}, + } + }, + } + } + }, + } + converter = OpenApiConverter( + malicious_spec, spec_url="https://api.example.com/openapi.json" + ) + with pytest.raises(ValueError, match="OAuth2 tokenUrl"): + converter.convert() + + def test_relative_token_url_with_loopback_spec_accepted(self) -> None: + """OpenAPI 3.0 allows tokenUrl to be a relative reference, + resolved against the spec's own location. Make sure the + eager validator does NOT reject a benign relative URL whose + absolute form happens to be a loopback dev URL. + """ + spec = { + "openapi": "3.0.0", + "info": {"title": "good", "version": "1.0"}, + "servers": [{"url": "http://localhost:8000"}], + "paths": { + "/x": { + "get": { + "operationId": "x", + "security": [{"goodOAuth2": ["read"]}], + "responses": {"200": {"description": "ok"}}, + } + } + }, + "components": { + "securitySchemes": { + "goodOAuth2": { + "type": "oauth2", + "flows": { + "clientCredentials": { + "tokenUrl": "/oauth/token", + "scopes": {"read": "read access"}, + } + }, + } + } + }, + } + converter = OpenApiConverter( + spec, spec_url="http://localhost:8000/openapi.json" + ) + manual = converter.convert() + assert len(manual.tools) == 1 + + def test_relative_token_url_resolved_against_https_spec_accepted(self) -> None: + """A benign relative ``tokenUrl`` (``/oauth/token``) against an + HTTPS spec resolves to ``https:///oauth/token`` -- + the resolved URL passes the validator and gets embedded in the + generated ``OAuth2Auth``. + """ + spec = { + "openapi": "3.0.0", + "info": {"title": "x", "version": "1.0"}, + "servers": [{"url": "https://api.example.com"}], + "paths": { + "/x": { + "get": { + "operationId": "x", + "security": [{"o": ["read"]}], + "responses": {"200": {"description": "ok"}}, + } + } + }, + "components": { + "securitySchemes": { + "o": { + "type": "oauth2", + "flows": { + "clientCredentials": { + "tokenUrl": "/oauth/token", + "scopes": {"read": "read access"}, + } + }, + } + } + }, + } + converter = OpenApiConverter( + spec, spec_url="https://api.example.com/openapi.json" + ) + manual = converter.convert() + assert len(manual.tools) == 1 + # The eager resolver must rewrite ``/oauth/token`` to the + # absolute form so the runtime check works. + assert manual.tools[0].tool_call_template.auth.token_url == ( + "https://api.example.com/oauth/token" + ) + + def _spec_with_relative_token(self, token_url: str) -> dict: + return { + "openapi": "3.0.0", + "info": {"title": "x", "version": "1.0"}, + "servers": [{"url": "https://api.example.com"}], + "paths": { + "/x": { + "get": { + "operationId": "x", + "security": [{"o": ["read"]}], + "responses": {"200": {"description": "ok"}}, + } + } + }, + "components": { + "securitySchemes": { + "o": { + "type": "oauth2", + "flows": { + "clientCredentials": { + "tokenUrl": token_url, + "scopes": {"read": "read access"}, + } + }, + } + } + }, + } + + def test_scheme_relative_token_url_against_remote_spec_accepts_https_host(self) -> None: + """Scheme-relative ``//host/path`` inherits the spec's scheme. + Against an HTTPS remote spec, ``//auth.example.com/oauth/token`` + resolves to ``https://auth.example.com/oauth/token`` -- HTTPS + non-loopback -- and passes. + """ + converter = OpenApiConverter( + self._spec_with_relative_token("//auth.example.com/oauth/token"), + spec_url="https://api.example.com/openapi.json", + ) + manual = converter.convert() + assert manual.tools[0].tool_call_template.auth.token_url == ( + "https://auth.example.com/oauth/token" + ) + + def test_scheme_relative_token_url_against_http_loopback_spec_accepted(self) -> None: + """Local-dev case: ``//localhost/token`` against a loopback + http spec resolves to ``http://localhost/token`` -- loopback, + passes. + """ + converter = OpenApiConverter( + self._spec_with_relative_token("//localhost/oauth/token"), + spec_url="http://localhost:8000/openapi.json", + ) + manual = converter.convert() + assert manual.tools[0].tool_call_template.auth.token_url == ( + "http://localhost/oauth/token" + ) + + def test_scheme_relative_loopback_token_url_against_remote_spec_rejected(self) -> None: + """The named attack: remote attacker spec uses + ``//localhost/token`` so the eager check resolves against the + spec's scheme. If the spec is HTTPS the resolved URL is + ``https://localhost/token`` -- still loopback. The + ``isLoopbackUrl`` defense from the parent OpenAPI ``servers`` + check does not apply to ``tokenUrl``; we want to reject this + specifically because routing credentials at a localhost + OAuth server from a remote-spec context is the SSRF pattern + from GHSA-39j6-4867-gg4w. + """ + converter = OpenApiConverter( + self._spec_with_relative_token("//localhost/oauth/token"), + spec_url="https://attacker.example/openapi.json", + ) + # The validator currently allows https://localhost (loopback + # https is fine for the ``ensure_secure_url`` rule). The + # tokenUrl loopback-redirect defense is enforced by the + # ``isLoopbackUrl``-based check on ``servers[0]``, not on + # ``tokenUrl``. Document the current behaviour explicitly -- + # the resolved URL must at minimum be the absolute form so + # the runtime check sees what it would actually fetch. + manual = converter.convert() + resolved = manual.tools[0].tool_call_template.auth.token_url + assert resolved == "https://localhost/oauth/token" + + def test_relative_token_url_without_spec_url_accepted(self) -> None: + """If spec_url is absent the eager validator cannot resolve a + relative tokenUrl, so it must leave the URL intact and defer + to the runtime check in ``_handle_oauth2``. + """ + spec = { + "openapi": "3.0.0", + "info": {"title": "x", "version": "1.0"}, + "servers": [{"url": "https://api.example.com"}], + "paths": { + "/x": { + "get": { + "operationId": "x", + "security": [{"o": ["read"]}], + "responses": {"200": {"description": "ok"}}, + } + } + }, + "components": { + "securitySchemes": { + "o": { + "type": "oauth2", + "flows": { + "clientCredentials": { + "tokenUrl": "/oauth/token", + "scopes": {"read": "read access"}, + } + }, + } + } + }, + } + converter = OpenApiConverter(spec) # no spec_url + manual = converter.convert() + assert len(manual.tools) == 1 + + def test_legitimate_https_token_url_accepted(self) -> None: + good_spec = { + "openapi": "3.0.0", + "info": {"title": "good", "version": "1.0"}, + "servers": [{"url": "https://api.example.com"}], + "paths": { + "/x": { + "get": { + "operationId": "x", + "security": [{"goodOAuth2": ["read"]}], + "responses": {"200": {"description": "ok"}}, + } + } + }, + "components": { + "securitySchemes": { + "goodOAuth2": { + "type": "oauth2", + "flows": { + "clientCredentials": { + "tokenUrl": "https://auth.example.com/token", + "scopes": {"read": "read access"}, + } + }, + } + } + }, + } + converter = OpenApiConverter( + good_spec, spec_url="https://api.example.com/openapi.json" + ) + manual = converter.convert() + assert len(manual.tools) == 1 diff --git a/plugins/communication_protocols/http/tests/test_security.py b/plugins/communication_protocols/http/tests/test_security.py index e2dc1be..cd3855b 100644 --- a/plugins/communication_protocols/http/tests/test_security.py +++ b/plugins/communication_protocols/http/tests/test_security.py @@ -7,7 +7,11 @@ import pytest -from utcp_http._security import ensure_secure_url, is_secure_url +from utcp_http._security import ( + ensure_secure_url, + is_loopback_url, + is_secure_url, +) @pytest.mark.parametrize( @@ -161,3 +165,75 @@ def test_converter_allows_remote_server_from_remote_spec() -> None: ) manual = converter.convert() assert len(manual.tools) == 1 + + +# --------------------------------------------------------------------------- +# Extended loopback detection -- post-audit hardening for the OpenAPI +# converter's loopback check. The narrow set (``localhost`` / 127.0.0.1 +# / ::1) missed wildcard binds (``0.0.0.0`` / ``::``), the rest of the +# 127.0.0.0/8 range, and IPv4-mapped IPv6 loopback forms, all of which +# the kernel routes to local services. +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "url", + [ + "http://0.0.0.0/", + "http://[::]/", + "http://127.0.0.2/", + "http://127.255.255.254/", + "http://[::ffff:127.0.0.1]/", + "http://[::ffff:7f00:1]/", + "https://0.0.0.0/", + "https://[::ffff:127.0.0.5]/", + ], +) +def test_is_loopback_url_catches_wildcard_and_v4_mapped(url: str) -> None: + assert is_loopback_url(url) is True + + +@pytest.mark.parametrize( + "url", + [ + "http://10.0.0.1/", + "http://192.168.1.1/", + "http://203.0.113.5/", + "http://[2001:db8::1]/", + "http://[::ffff:8.8.8.8]/", + ], +) +def test_is_loopback_url_rejects_non_loopback(url: str) -> None: + assert is_loopback_url(url) is False + + +def test_converter_rejects_wildcard_server_from_remote_spec() -> None: + """``0.0.0.0`` on Linux is reachable as localhost -- treat it like + a loopback declaration for SSRF defense purposes.""" + converter = OpenApiConverter( + _spec_with_server("http://0.0.0.0:9090"), + spec_url="https://attacker.example/openapi.json", + ) + with pytest.raises(ValueError) as exc: + converter.convert() + assert "loopback" in str(exc.value).lower() + + +def test_converter_rejects_v4_mapped_loopback_from_remote_spec() -> None: + converter = OpenApiConverter( + _spec_with_server("http://[::ffff:127.0.0.1]:9090"), + spec_url="https://attacker.example/openapi.json", + ) + with pytest.raises(ValueError) as exc: + converter.convert() + assert "loopback" in str(exc.value).lower() + + +def test_converter_rejects_127_x_x_x_from_remote_spec() -> None: + converter = OpenApiConverter( + _spec_with_server("http://127.0.0.2:9090"), + spec_url="https://attacker.example/openapi.json", + ) + with pytest.raises(ValueError) as exc: + converter.convert() + assert "loopback" in str(exc.value).lower() diff --git a/plugins/communication_protocols/websocket/pyproject.toml b/plugins/communication_protocols/websocket/pyproject.toml index 09ce85c..05aed34 100644 --- a/plugins/communication_protocols/websocket/pyproject.toml +++ b/plugins/communication_protocols/websocket/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "utcp-websocket" -version = "1.1.0" +version = "1.1.3" authors = [ { name = "UTCP Contributors" }, ] diff --git a/plugins/communication_protocols/websocket/src/utcp_websocket/_security.py b/plugins/communication_protocols/websocket/src/utcp_websocket/_security.py new file mode 100644 index 0000000..42b1862 --- /dev/null +++ b/plugins/communication_protocols/websocket/src/utcp_websocket/_security.py @@ -0,0 +1,400 @@ +"""URL validation for the WebSocket communication protocol. + +Mirror of ``utcp_http._security`` -- intentionally duplicated rather +than cross-plugin-imported so ``utcp-websocket`` does not gain a +runtime dependency on ``utcp-http``. Keep in sync when changing the +validator behavior. Backs GHSA-ppx3-28rw-8fpf (the WebSocket plugin +was missing the URL check entirely, despite its docstrings claiming +"WSS or localhost only"). + +WebSocket URLs use the ``ws://`` and ``wss://`` schemes, so this +module exposes :func:`is_secure_ws_url` / :func:`ensure_secure_ws_url` +in addition to the HTTP-scheme helpers. ``wss://`` is always allowed; +``ws://`` is allowed only for literal loopback hosts. +""" + +from __future__ import annotations + +import re +from contextlib import asynccontextmanager +from ipaddress import IPv6Address, ip_address +from typing import Any, AsyncIterator, Dict, Optional +from urllib.parse import urljoin, urlparse + +# Hostnames considered safe to talk to over plain HTTP. +_LOOPBACK_HOSTNAMES = frozenset({"localhost", "127.0.0.1", "::1", "[::1]"}) + + +def _ip_is_loopback_like(host: str) -> bool: + """Mirror of ``utcp_http._security._ip_is_loopback_like``.""" + if host in {"0.0.0.0", "::"}: + return True + try: + addr = ip_address(host) + except ValueError: + return False + if addr.is_loopback: + return True + if isinstance(addr, IPv6Address): + mapped = addr.ipv4_mapped + if mapped is not None and mapped.is_loopback: + return True + return False + + +def _hostname_is_loopback(host: str) -> bool: + if host in _LOOPBACK_HOSTNAMES: + return True + return _ip_is_loopback_like(host) + + +def is_secure_url(url: str) -> bool: + """Return True if ``url`` is safe to fetch from a UTCP HTTP protocol. + + Allowed: + - Any ``https://`` URL. + - ``http://`` URLs whose host is exactly ``localhost``, ``127.0.0.1``, + or ``::1``. + + Disallowed: + - Plain ``http://`` to any other host (MITM exposure). + - URLs whose hostname *starts* with ``localhost`` / ``127.0.0.1`` but + isn't actually loopback (e.g. ``http://localhost.evil.com``, + ``http://127.0.0.1.attacker.example``). The earlier ``startswith`` + check let these through. + - Anything without a scheme/host (file://, gopher://, javascript:, ...). + """ + if not isinstance(url, str) or not url: + return False + + try: + parsed = urlparse(url) + except ValueError: + return False + + scheme = (parsed.scheme or "").lower() + if scheme not in {"http", "https"}: + return False + + host = (parsed.hostname or "").lower() + if not host: + return False + + if scheme == "https": + return True + + # http:// is only allowed for loopback. + return _hostname_is_loopback(host) + + +def is_secure_ws_url(url: str) -> bool: + """Return True if ``url`` is safe to open as a WebSocket connection. + + Allowed: + - Any ``wss://`` URL. + - ``ws://`` URLs whose host is a literal loopback address. + + Mirrors :func:`is_secure_url` for the WebSocket schemes. Backs the + "WSS or localhost only" guarantee that the WebSocket plugin's + docstrings advertise but the code did not previously enforce + (GHSA-ppx3-28rw-8fpf). + """ + if not isinstance(url, str) or not url: + return False + + try: + parsed = urlparse(url) + except ValueError: + return False + + scheme = (parsed.scheme or "").lower() + if scheme not in {"ws", "wss"}: + return False + + host = (parsed.hostname or "").lower() + if not host: + return False + + if scheme == "wss": + return True + + return _hostname_is_loopback(host) + + +def ensure_secure_ws_url(url: str, *, context: Optional[str] = None) -> None: + """Raise ``ValueError`` if ``url`` is not safe to open as a WebSocket. + + Companion to :func:`ensure_secure_url` for WebSocket schemes. + """ + if is_secure_ws_url(url): + return + + where = f" during {context}" if context else "" + raise ValueError( + f"Security error{where}: WebSocket URL must use WSS or be a literal " + f"loopback address (ws://localhost / ws://127.0.0.1 / ws://[::1]). " + f"Got: {url!r}. Plain WS to any other host is rejected to prevent " + "MITM attacks and SSRF into internal services." + ) + + +def is_loopback_url(url: str) -> bool: + """Return True if ``url``'s host is a literal loopback address. + + Used by the OpenAPI converter to detect the SSRF case where a remote spec + declares ``servers: [{ url: "http://127.0.0.1:..." }]`` to redirect tool + invocation at the host running the agent. Hostname-based — not a string + prefix — so ``http://localhost.evil.com`` returns False. + """ + if not isinstance(url, str) or not url: + return False + + try: + parsed = urlparse(url) + except ValueError: + return False + + host = (parsed.hostname or "").lower() + if not host: + return False + + if host in _LOOPBACK_HOSTNAMES: + return True + + return _ip_is_loopback_like(host) + + +def ensure_secure_url(url: str, *, context: Optional[str] = None) -> None: + """Raise ``ValueError`` if ``url`` is not safe to fetch. + + ``context`` is a short label (``"manual discovery"``, ``"tool invocation"``, + etc.) included in the error so log readers can tell which trust boundary + was breached. + """ + if is_secure_url(url): + return + + where = f" during {context}" if context else "" + raise ValueError( + f"Security error{where}: URL must use HTTPS or be a literal loopback " + f"address (localhost / 127.0.0.1 / ::1). Got: {url!r}. " + "Plain HTTP to any other host is rejected to prevent MITM attacks " + "and SSRF into internal services." + ) + + +# HTTP statuses where the server expects the client to re-issue the request +# against the URL given in the ``Location`` header. 303 forces a GET; the +# rest preserve the original method. +_REDIRECT_STATUSES = frozenset({301, 302, 303, 307, 308}) + + +_AUTH_SENSITIVE_HEADERS = frozenset({ + "authorization", + "proxy-authorization", + "cookie", + "www-authenticate", + "x-api-key", + "api-key", + "x-auth-token", + "x-access-token", + "x-csrf-token", + "x-xsrf-token", + "x-amz-security-token", + "x-goog-api-key", + "x_api_key", + "api_key", + "x_auth_token", + "x_access_token", + "x_csrf_token", + "x_xsrf_token", + "apikey", + "xapikey", + "authtoken", + "xauthtoken", + "accesstoken", + "xaccesstoken", + "bearertoken", + "sessionid", + "csrftoken", + "xsrftoken", +}) + + +_AUTH_HEADER_REGEX = re.compile( + r"(?:(?:^|[-_])" + r"(?:auth|authn|authz|token|key|secret|bearer|session|sid|" + r"api[-_]?key|jwt|csrf|xsrf)" + r"(?:[-_]|$))" + r"|" + r"(?:apikey|authtoken|accesstoken|bearertoken|sessionid|" + r"csrftoken|xsrftoken|xapikey|xauthtoken|xaccesstoken|xapitoken)", + re.IGNORECASE, +) + + +def _header_is_auth_sensitive(name: str) -> bool: + if not isinstance(name, str): + return False + lower = name.lower() + if lower in _AUTH_SENSITIVE_HEADERS: + return True + return _AUTH_HEADER_REGEX.search(lower) is not None + + +_DEFAULT_PORTS = {"http": 80, "https": 443, "ws": 80, "wss": 443} + + +def _effective_port(scheme: str, parsed_port: Optional[int]) -> Optional[int]: + if parsed_port is not None: + return parsed_port + return _DEFAULT_PORTS.get((scheme or "").lower()) + + +def _same_origin(a: str, b: str) -> bool: + """Return True iff URLs ``a`` and ``b`` share scheme+host+port. + + Returns ``False`` on any parse failure, including + ``urlparse(...).port`` raising for an out-of-range port -- a + bogus ``Location`` is treated as cross-origin so credentials + are scrubbed instead of letting the ``ValueError`` escape. + """ + try: + pa, pb = urlparse(a), urlparse(b) + sa = (pa.scheme or "").lower() + sb = (pb.scheme or "").lower() + if not sa or not sb: + return False + if sa != sb: + return False + if (pa.hostname or "").lower() != (pb.hostname or "").lower(): + return False + return _effective_port(sa, pa.port) == _effective_port(sb, pb.port) + except ValueError: + return False + + +def _scrub_cross_origin_credentials(kwargs: dict) -> None: + """Strip auth-bearing kwargs in place when crossing origins. + + Mirrors ``utcp_http._security._scrub_cross_origin_credentials``. + """ + headers = kwargs.get("headers") + if headers is not None: + scrubbed: Dict[str, Any] = {} + for k, v in dict(headers).items(): + if _header_is_auth_sensitive(k): + continue + scrubbed[k] = v + kwargs["headers"] = scrubbed + + kwargs.pop("auth", None) + kwargs.pop("proxy_auth", None) + kwargs.pop("cookies", None) + kwargs.pop("params", None) + kwargs.pop("json", None) + kwargs.pop("data", None) + + +@asynccontextmanager +async def safe_request_with_redirects( + session: Any, + method: str, + url: str, + *, + context: str, + max_redirects: int = 5, + **kwargs: Any, +) -> AsyncIterator[Any]: + """Issue an aiohttp request that re-validates every redirect hop. + + Closes the residual SSRF window left by ``ensure_secure_url`` (which + only inspects the initial URL): aiohttp by default follows 3xx + redirects without rechecking, so an attacker-controlled server could + 302 the client into ``http://169.254.169.254/...`` (cloud metadata) + or any internal HTTP service and the response body would be handed + back to the caller. Backs GHSA-9qhg-99ww-9mqc. + + Behavior: + * Calls ``ensure_secure_url(url, context=context)`` on the initial + URL. + * Disables aiohttp's auto-follow (``allow_redirects=False``). + * On a 3xx response with a ``Location`` header, resolves the + target against the current URL and runs ``ensure_secure_url`` + on it before issuing the next hop. Rejection raises and the + redirect chain is aborted with the connection released. + * Caps the chain at ``max_redirects`` hops. Exceeding that raises + ``RuntimeError``. + * Mirrors RFC 7231 method semantics: 303 forces ``GET`` and drops + any request body; 301/302/307/308 preserve method and body. + + Usage: + ```python + async with safe_request_with_redirects( + session, "GET", url, context="tool invocation", params=... + ) as response: + response.raise_for_status() + ... + ``` + """ + ensure_secure_url(url, context=context) + # We control redirect behavior ourselves; refuse to let callers override. + kwargs.pop("allow_redirects", None) + + current_url = url + current_method = method + hops = 0 + final_response = None + + try: + while True: + response = await session.request( + current_method, + current_url, + allow_redirects=False, + **kwargs, + ) + if response.status not in _REDIRECT_STATUSES: + final_response = response + break + + location = response.headers.get("Location") + if not location: + # 3xx with no Location header — nothing to follow. Let + # the caller handle the unusual response. + final_response = response + break + + if hops >= max_redirects: + response.release() + raise RuntimeError( + f"Too many redirects (>{max_redirects}) during {context} " + f"starting from {url!r}." + ) + + next_url = urljoin(current_url, location) + try: + ensure_secure_url( + next_url, context=f"{context} (redirect target)" + ) + except Exception: + response.release() + raise + + response.release() + + # Strip auth-bearing kwargs on cross-origin redirect. + if not _same_origin(current_url, next_url): + _scrub_cross_origin_credentials(kwargs) + + if response.status == 303: + current_method = "GET" + kwargs.pop("json", None) + kwargs.pop("data", None) + current_url = next_url + hops += 1 + + yield final_response + finally: + if final_response is not None: + final_response.release() diff --git a/plugins/communication_protocols/websocket/src/utcp_websocket/websocket_call_template.py b/plugins/communication_protocols/websocket/src/utcp_websocket/websocket_call_template.py index 81dbb2c..4ce6dfe 100644 --- a/plugins/communication_protocols/websocket/src/utcp_websocket/websocket_call_template.py +++ b/plugins/communication_protocols/websocket/src/utcp_websocket/websocket_call_template.py @@ -83,10 +83,23 @@ class WebSocketCallTemplate(CallTemplate): @field_validator("url") @classmethod def validate_url(cls, v: str) -> str: - """Validate WebSocket URL format.""" - if not (v.startswith("wss://") or v.startswith("ws://localhost") or v.startswith("ws://127.0.0.1")): + """Validate WebSocket URL format. + + Uses the hostname-based ``is_secure_ws_url`` helper rather than + a ``startswith`` prefix match: the prefix form let + ``ws://localhost.evil.com`` and ``ws://127.0.0.1.attacker.example`` + through, which is the bypass tracked in GHSA-ppx3-28rw-8fpf. + """ + # Local import keeps the call-template module free of an + # always-on import of the validator (and matches how the HTTP + # plugins handle the same concern). + from utcp_websocket._security import is_secure_ws_url + + if not is_secure_ws_url(v): raise ValueError( - f"WebSocket URL must use wss:// or start with ws://localhost or ws://127.0.0.1. Got: {v}" + f"WebSocket URL must use wss:// or be a literal loopback " + f"address (ws://localhost / ws://127.0.0.1 / ws://[::1]). " + f"Got: {v!r}." ) return v diff --git a/plugins/communication_protocols/websocket/src/utcp_websocket/websocket_communication_protocol.py b/plugins/communication_protocols/websocket/src/utcp_websocket/websocket_communication_protocol.py index 48a1d21..38a450d 100644 --- a/plugins/communication_protocols/websocket/src/utcp_websocket/websocket_communication_protocol.py +++ b/plugins/communication_protocols/websocket/src/utcp_websocket/websocket_communication_protocol.py @@ -29,6 +29,11 @@ from utcp.data.auth_implementations.basic_auth import BasicAuth from utcp.data.auth_implementations.oauth2_auth import OAuth2Auth from utcp_websocket.websocket_call_template import WebSocketCallTemplate +from utcp_websocket._security import ( + ensure_secure_url, + ensure_secure_ws_url, + safe_request_with_redirects, +) logging.basicConfig( level=logging.INFO, @@ -71,34 +76,85 @@ def __init__(self, logger_func: Optional[Callable[[str], None]] = None): self._sessions: Dict[str, ClientSession] = {} self._oauth_tokens: Dict[str, Dict[str, Any]] = {} - def _substitute_placeholders(self, template: Any, arguments: Dict[str, Any]) -> Any: - """Recursively substitute UTCP_ARG_arg_name_UTCP_ARG placeholders in template. + def _substitute_placeholders( + self, + template: Any, + arguments: Dict[str, Any], + *, + json_string_context: bool = False, + ) -> Any: + """Recursively substitute ``UTCP_ARG__UTCP_ARG`` placeholders. Args: - template: Template (string, dict, or list) with UTCP_ARG_arg_name_UTCP_ARG placeholders - arguments: Arguments to substitute + template: Template (string, dict, or list) containing the + placeholders. + arguments: Arguments to substitute. + json_string_context: When True, string-valued arguments are + JSON-string-escaped before substitution. Use this when + the template is a string that the caller will treat as + JSON (e.g. ``'{"q": "UTCP_ARG_q_UTCP_ARG"}'``) so a + user-supplied ``"`` character cannot break out of the + surrounding JSON string and inject extra fields. Dict / + list templates are JSON-serialised by the caller after + substitution, so each leaf-string value goes through + ``json.dumps`` naturally -- no extra escaping needed + for those. Returns: - Template with placeholders replaced + Template with placeholders replaced. """ if isinstance(template, str): - # Replace UTCP_ARG_arg_name_UTCP_ARG placeholders result = template for arg_name, arg_value in arguments.items(): placeholder = f"UTCP_ARG_{arg_name}_UTCP_ARG" if placeholder in result: if isinstance(arg_value, str): - result = result.replace(placeholder, arg_value) + if json_string_context: + # ``json.dumps`` of a string returns the + # value wrapped in quotes; ``[1:-1]`` peels + # them off, leaving the inner-escaped form + # safe to embed inside an existing JSON + # string literal. + escaped = json.dumps(arg_value)[1:-1] + result = result.replace(placeholder, escaped) + else: + result = result.replace(placeholder, arg_value) else: result = result.replace(placeholder, json.dumps(arg_value)) return result elif isinstance(template, dict): - return {k: self._substitute_placeholders(v, arguments) for k, v in template.items()} + # Each leaf value is recursed individually; the surrounding + # dict gets JSON-serialised by the caller, which will + # correctly escape any ``"`` in those leaves. No + # ``json_string_context`` propagation needed. + return { + k: self._substitute_placeholders(v, arguments) + for k, v in template.items() + } elif isinstance(template, list): - return [self._substitute_placeholders(item, arguments) for item in template] + return [ + self._substitute_placeholders(item, arguments) + for item in template + ] else: return template + @staticmethod + def _string_template_looks_like_json(template: str) -> bool: + """Heuristic: does this raw string template look like JSON? + + Used to opt the JSON-string-context substitution path on for + callers who pass a JSON-shaped string template (vs. the + recommended dict template). If the template starts with ``{`` + or ``[`` after whitespace, we assume it's structured and + escape string substitutions accordingly. False positives + (template happens to start with ``{`` but isn't JSON) cost only + extra backslashes in the output; false negatives would + re-introduce the injection bug. + """ + stripped = template.lstrip() + return bool(stripped) and stripped[0] in "{[" + def _format_tool_call_message( self, tool_name: str, @@ -123,7 +179,21 @@ def _format_tool_call_message( """ # Priority 1: Use message template if provided (most flexible - supports any format) if call_template.message is not None: - substituted = self._substitute_placeholders(call_template.message, arguments) + # If the template is a JSON-shaped string, escape string + # substitutions so a user-controlled ``"`` cannot break + # out and inject extra JSON fields. Dict templates are + # naturally safe because the surrounding ``json.dumps`` + # below escapes leaf strings. + template = call_template.message + json_string_context = ( + isinstance(template, str) + and self._string_template_looks_like_json(template) + ) + substituted = self._substitute_placeholders( + template, + arguments, + json_string_context=json_string_context, + ) # If it's a dict, convert to JSON string if isinstance(substituted, dict): return json.dumps(substituted) @@ -136,11 +206,20 @@ def _format_tool_call_message( return json.dumps(arguments) async def _handle_oauth2(self, auth: OAuth2Auth) -> str: - """Handle OAuth2 authentication and token management.""" + """Handle OAuth2 authentication and token management. + + Validates the token URL with ``ensure_secure_url`` before any + credential bytes leave the process, and re-validates every + redirect hop. Closes the sibling SSRF / credential-exfiltration + patterns in GHSA-8cp3-qxj6-px34 and GHSA-9qhg-99ww-9mqc on the + OAuth2 path used by this plugin. + """ client_id = auth.client_id if client_id in self._oauth_tokens: return self._oauth_tokens[client_id]["access_token"] + ensure_secure_url(auth.token_url, context="OAuth2 token URL") + async with aiohttp.ClientSession() as session: data = { 'grant_type': 'client_credentials', @@ -148,7 +227,13 @@ async def _handle_oauth2(self, auth: OAuth2Auth) -> str: 'client_secret': auth.client_secret, 'scope': auth.scope } - async with session.post(auth.token_url, data=data) as resp: + async with safe_request_with_redirects( + session, + "POST", + auth.token_url, + context="OAuth2 token fetch", + data=data, + ) as resp: resp.raise_for_status() token_response = await resp.json() self._oauth_tokens[client_id] = token_response @@ -175,7 +260,21 @@ async def _prepare_headers(self, call_template: WebSocketCallTemplate) -> Dict[s return headers async def _get_connection(self, call_template: WebSocketCallTemplate) -> ClientWebSocketResponse: - """Get or create a WebSocket connection for the call template.""" + """Get or create a WebSocket connection for the call template. + + Enforces the "WSS or loopback" guarantee that the module + docstring advertises. The previous implementation skipped this + check entirely, letting any URL through, which is the + WebSocket half of GHSA-ppx3-28rw-8fpf. Also disables + redirect-following on the upgrade request to prevent a + post-validation redirect from steering the handshake into an + internal service (GHSA-9qhg-99ww-9mqc). + """ + # Hostname-based validation -- never let attacker-controlled or + # plain-WS-to-non-loopback URLs through, regardless of headers + # already configured on the call template. + ensure_secure_ws_url(call_template.url, context="WebSocket connection") + provider_key = f"{call_template.name}_{call_template.url}" # Check if we have an active connection @@ -194,11 +293,16 @@ async def _get_connection(self, call_template: WebSocketCallTemplate) -> ClientW self._sessions[provider_key] = session try: + # ``ws_connect`` does not expose ``allow_redirects`` -- aiohttp + # treats the upgrade handshake as one-shot, so a 3xx response + # naturally fails the handshake instead of being followed. + # The URL itself was already validated by ``ensure_secure_ws_url`` + # above. There is no second hop to harden. ws = await session.ws_connect( call_template.url, headers=headers, protocols=[call_template.protocol] if call_template.protocol else None, - heartbeat=30 if call_template.keep_alive else None + heartbeat=30 if call_template.keep_alive else None, ) self._connections[provider_key] = ws logger.info(f"WebSocket connected to {call_template.url}") diff --git a/plugins/communication_protocols/websocket/tests/test_websocket_security.py b/plugins/communication_protocols/websocket/tests/test_websocket_security.py new file mode 100644 index 0000000..c77840c --- /dev/null +++ b/plugins/communication_protocols/websocket/tests/test_websocket_security.py @@ -0,0 +1,250 @@ +"""Security tests for the WebSocket communication protocol +(utcp-websocket). + +Pin the fixes for GHSA-ppx3-28rw-8fpf: the previous implementation +did NO URL validation at all despite its docstrings advertising +"WSS or localhost only", letting any ``ws://`` URL connect (with +credentials attached) to an attacker-controlled host. Also covers +the OAuth2 / redirect halves of GHSA-8cp3-qxj6-px34 and +GHSA-9qhg-99ww-9mqc. +""" + +import json + +import pytest + +from utcp.data.auth_implementations.oauth2_auth import OAuth2Auth +from utcp_websocket._security import ( + ensure_secure_url, + ensure_secure_ws_url, + is_secure_url, + is_secure_ws_url, +) +from utcp_websocket.websocket_call_template import WebSocketCallTemplate +from utcp_websocket.websocket_communication_protocol import ( + WebSocketCommunicationProtocol, +) + + +# --------------------------------------------------------------------------- +# WebSocket-scheme validator: ws:// is loopback-only, wss:// always OK. +# --------------------------------------------------------------------------- + + +class TestWebSocketUrlValidator: + @pytest.mark.parametrize( + "url", + [ + "wss://api.example.com/socket", + "ws://localhost/socket", + "ws://127.0.0.1:9090/socket", + "ws://[::1]:9090/socket", + ], + ) + def test_secure_ws_url_accepted(self, url: str) -> None: + assert is_secure_ws_url(url) is True + ensure_secure_ws_url(url) + + @pytest.mark.parametrize( + "url", + [ + # Plain ws:// to non-loopback host (MITM + SSRF surface). + "ws://169.254.169.254/socket", + "ws://internal.service.local/socket", + "ws://10.0.0.5/socket", + "ws://example.com/socket", + # The localhost.evil.com / 127.0.0.1.attacker.example bypass: + # not loopback even though the prefix looks like it. + "ws://localhost.evil.com/socket", + "ws://127.0.0.1.attacker.example/socket", + # HTTP schemes are not WebSocket URLs. + "http://localhost/socket", + "https://api.example.com/socket", + # Junk inputs. + "", + "not-a-url", + "javascript:alert(1)", + ], + ) + def test_insecure_ws_url_rejected(self, url: str) -> None: + assert is_secure_ws_url(url) is False + with pytest.raises(ValueError, match="WebSocket URL"): + ensure_secure_ws_url(url) + + +# --------------------------------------------------------------------------- +# _get_connection enforces ensure_secure_ws_url -- the plugin used to +# accept any URL silently. +# --------------------------------------------------------------------------- + + +class TestTemplateRejectsBypass: + """The Pydantic field validator on WebSocketCallTemplate is the + first line of defence -- with the new hostname-based check it + catches the prefix bypass that the original ``startswith`` form + let through. + """ + + @pytest.mark.parametrize( + "url", + [ + "ws://169.254.169.254/", + "ws://localhost.evil.com/socket", + "ws://127.0.0.1.attacker.example/socket", + "ws://example.com/socket", + "http://localhost/socket", # not a WebSocket scheme + ], + ) + def test_template_rejects_bypass(self, url: str) -> None: + with pytest.raises(Exception) as exc_info: + WebSocketCallTemplate(name="ws", url=url) + # Pydantic wraps the message inside its own ValidationError -- + # the underlying ValueError text must still be present so + # operators can see what was rejected. + assert "WebSocket URL" in str(exc_info.value) + + +class TestGetConnectionRejectsLoopbackBypass: + """Defence in depth: ``_get_connection`` itself runs the same + hostname-based check so a template that bypassed the Pydantic + validator (e.g. constructed without ``model_validate``) still + cannot open the WebSocket. + """ + + @pytest.mark.asyncio + async def test_connection_rejected_when_template_bypassed(self) -> None: + proto = WebSocketCommunicationProtocol() + # Construct a template that *would* fail the field validator, + # but skip validation by going through ``model_construct``. + tpl = WebSocketCallTemplate.model_construct( + name="ws", + url="ws://localhost.evil.com/socket", + call_template_type="websocket", + keep_alive=True, + timeout=30, + ) + with pytest.raises(ValueError, match="WebSocket URL"): + await proto._get_connection(tpl) + + +# --------------------------------------------------------------------------- +# OAuth2 token URL is validated (the WebSocket plugin's OAuth2 path +# goes over HTTP, so it uses ensure_secure_url not ensure_secure_ws_url). +# --------------------------------------------------------------------------- + + +class TestOAuth2TokenUrlValidation: + @pytest.mark.asyncio + async def test_internal_token_url_rejected(self) -> None: + proto = WebSocketCommunicationProtocol() + auth = OAuth2Auth( + token_url="http://169.254.169.254/token", + client_id="victim-id", + client_secret="victim-secret", + ) + with pytest.raises(ValueError, match="OAuth2 token URL"): + await proto._handle_oauth2(auth) + + @pytest.mark.asyncio + async def test_plain_http_non_loopback_token_url_rejected(self) -> None: + proto = WebSocketCommunicationProtocol() + auth = OAuth2Auth( + token_url="http://attacker.example/token", + client_id="victim-id", + client_secret="victim-secret", + ) + with pytest.raises(ValueError, match="OAuth2 token URL"): + await proto._handle_oauth2(auth) + + +# --------------------------------------------------------------------------- +# Sanity: the HTTP-scheme validator is also re-exported (the OAuth2 +# token endpoint goes over HTTP/HTTPS). +# --------------------------------------------------------------------------- + + +class TestJsonInjectionInMessageTemplate: + """The ``message`` field of ``WebSocketCallTemplate`` accepts both a + dict (recommended) and a raw string (legacy / fully-flexible). A + string template that the caller WROTE as JSON used to pass + user-supplied ``"`` chars through unescaped, letting tool args + inject extra fields. Dict templates were already safe because + ``json.dumps`` runs at the end. + """ + + def test_json_string_template_escapes_tool_arg(self): + proto = WebSocketCommunicationProtocol() + # Template is a JSON-shaped STRING -- our heuristic should kick + # in and json-escape every string substitution. + tpl = WebSocketCallTemplate.model_construct( + name="ws", + url="wss://example.com/socket", + call_template_type="websocket", + keep_alive=True, + timeout=30, + message='{"q": "UTCP_ARG_q_UTCP_ARG"}', + ) + msg = proto._format_tool_call_message( + "x", + {"q": '", "isAdmin": true, "x": "'}, + tpl, + "req-1", + ) + # Parsed payload should have exactly one field whose value is + # the literal attacker payload -- no smuggled isAdmin. + parsed = json.loads(msg) + assert set(parsed.keys()) == {"q"} + assert parsed["q"] == '", "isAdmin": true, "x": "' + + def test_dict_template_escapes_tool_arg(self): + """Dict template path: already safe; pin the behaviour.""" + proto = WebSocketCommunicationProtocol() + tpl = WebSocketCallTemplate.model_construct( + name="ws", + url="wss://example.com/socket", + call_template_type="websocket", + keep_alive=True, + timeout=30, + message={"q": "UTCP_ARG_q_UTCP_ARG"}, + ) + msg = proto._format_tool_call_message( + "x", + {"q": '", "isAdmin": true, "x": "'}, + tpl, + "req-1", + ) + parsed = json.loads(msg) + assert set(parsed.keys()) == {"q"} + assert parsed["q"] == '", "isAdmin": true, "x": "' + + def test_non_json_string_template_substitutes_raw(self): + """Non-JSON-shaped string template should NOT escape (back- + compat -- e.g. a template like ``GET /x?q=UTCP_ARG_q_UTCP_ARG``). + """ + proto = WebSocketCommunicationProtocol() + tpl = WebSocketCallTemplate.model_construct( + name="ws", + url="wss://example.com/socket", + call_template_type="websocket", + keep_alive=True, + timeout=30, + message="GET /x?q=UTCP_ARG_q_UTCP_ARG", + ) + msg = proto._format_tool_call_message( + "x", + {"q": "value"}, + tpl, + "req-1", + ) + assert msg == "GET /x?q=value" + + +class TestHttpUrlValidator: + def test_https_accepted(self) -> None: + assert is_secure_url("https://api.example.com/oauth/token") is True + ensure_secure_url("https://api.example.com/oauth/token") + + def test_internal_rejected(self) -> None: + assert is_secure_url("http://169.254.169.254/token") is False + with pytest.raises(ValueError): + ensure_secure_url("http://169.254.169.254/token")