diff --git a/src/openai/lib/_validation.py b/src/openai/lib/_validation.py new file mode 100644 index 0000000000..8532cd3354 --- /dev/null +++ b/src/openai/lib/_validation.py @@ -0,0 +1,110 @@ +"""Client-side validation helpers for API request parameters. + +These catch common configuration mistakes early, before sending the request +to the API, and surface clear error messages instead of opaque 500 errors. +""" + +from __future__ import annotations + +import re +from typing import Any, Iterable, Optional + +_PROTOCOL_RE = re.compile(r"^https?://", re.IGNORECASE) +_PATH_RE = re.compile(r"/.*$") + + +def validate_network_policy_allowlist( + allowed_domains: Iterable[str], + *, + source: str = "network_policy.allowed_domains", +) -> None: + """Validate ``allowed_domains`` entries before sending to the API. + + Raises :class:`ValueError` when common configuration mistakes are + detected, such as: + + * an empty domain list + * entries that include a protocol prefix (``http://`` / ``https://``) + * entries that include a URL path + + These mistakes would otherwise surface as an opaque ``500`` server error + (see https://github.com/openai/openai-python/issues/2920). + """ + domains = list(allowed_domains) + + if not domains: + raise ValueError( + f"{source} must contain at least one domain. " + "If you do not need network access, omit the network_policy " + "or use {\"type\": \"disabled\"} instead." + ) + + for domain in domains: + if not isinstance(domain, str) or not domain.strip(): + raise ValueError( + f"{source} contains an invalid entry: {domain!r}. " + "Each entry must be a non-empty domain string (e.g. \"example.com\")." + ) + + if _PROTOCOL_RE.match(domain): + bare = _PROTOCOL_RE.sub("", domain).rstrip("/") + raise ValueError( + f"{source} entry {domain!r} must be a bare domain without " + f"a protocol prefix. Use {bare!r} instead." + ) + + if _PATH_RE.search(domain): + raise ValueError( + f"{source} entry {domain!r} must be a domain name " + "without a path (e.g. \"example.com\", not \"example.com/path\")." + ) + + +def validate_shell_tool(tool: Any) -> None: + """Run validation checks on a shell tool dict before it is sent to the API. + + Currently validates the ``network_policy.allowed_domains`` field when + an allowlist policy is specified. + """ + if not isinstance(tool, dict): + return + + env: Optional[dict[str, Any]] = tool.get("environment") + if not isinstance(env, dict): + return + + policy: Optional[dict[str, Any]] = env.get("network_policy") + if not isinstance(policy, dict): + return + + if policy.get("type") != "allowlist": + return + + domains = policy.get("allowed_domains") + if domains is not None: + validate_network_policy_allowlist( + domains, + source="shell tool network_policy.allowed_domains", + ) + + +def validate_tools(tools: Iterable[Any]) -> None: + """Validate a list of tool dicts before they are sent to the API.""" + for tool in tools: + if not isinstance(tool, dict): + continue + + tool_type = tool.get("type") + if tool_type == "shell": + validate_shell_tool(tool) + elif tool_type == "code_interpreter": + container = tool.get("container") + if isinstance(container, dict): + policy = container.get("network_policy") + if isinstance(policy, dict) and policy.get("type") == "allowlist": + domains = policy.get("allowed_domains") + if domains is not None: + validate_network_policy_allowlist( + domains, + source="code_interpreter container network_policy.allowed_domains", + ) diff --git a/src/openai/resources/responses/responses.py b/src/openai/resources/responses/responses.py index 5d34909fd1..d8068b3a5e 100644 --- a/src/openai/resources/responses/responses.py +++ b/src/openai/resources/responses/responses.py @@ -30,6 +30,7 @@ ) from ..._streaming import Stream, AsyncStream from ...lib._tools import PydanticFunctionTool, ResponsesPydanticFunctionTool +from ...lib._validation import validate_tools from .input_tokens import ( InputTokens, AsyncInputTokens, @@ -3537,8 +3538,13 @@ def _make_tools(tools: Iterable[ParseableToolParam] | Omit) -> List[ToolParam] | if not is_given(tools): return omit + # Materialise once so that validation doesn't consume a one-shot iterator + tools_list = list(tools) if not isinstance(tools, list) else tools + + validate_tools(tools_list) + converted_tools: List[ToolParam] = [] - for tool in tools: + for tool in tools_list: if tool["type"] != "function": converted_tools.append(tool) continue diff --git a/tests/test_shell_tool_allowlist.py b/tests/test_shell_tool_allowlist.py new file mode 100644 index 0000000000..684e2de16b --- /dev/null +++ b/tests/test_shell_tool_allowlist.py @@ -0,0 +1,312 @@ +"""Tests for shell tool network_policy allowlist validation and serialization. + +Covers https://github.com/openai/openai-python/issues/2920 +""" + +from __future__ import annotations + +import json +from typing import Any + +import httpx +import pytest + +import openai +from openai._utils import maybe_transform +from openai.lib._validation import ( + validate_tools, + validate_shell_tool, + validate_network_policy_allowlist, +) +from openai.types.responses.response_create_params import ResponseCreateParamsNonStreaming + + +# --------------------------------------------------------------------------- +# Validation unit tests +# --------------------------------------------------------------------------- + + +class TestValidateNetworkPolicyAllowlist: + def test_valid_single_domain(self) -> None: + validate_network_policy_allowlist(["example.com"]) + + def test_valid_multiple_domains(self) -> None: + validate_network_policy_allowlist(["pypi.org", "files.pythonhosted.org", "github.com"]) + + def test_valid_subdomain(self) -> None: + validate_network_policy_allowlist(["api.example.com"]) + + def test_empty_list_raises(self) -> None: + with pytest.raises(ValueError, match="must contain at least one domain"): + validate_network_policy_allowlist([]) + + def test_http_prefix_raises(self) -> None: + with pytest.raises(ValueError, match="bare domain without a protocol prefix"): + validate_network_policy_allowlist(["http://example.com"]) + + def test_https_prefix_raises(self) -> None: + with pytest.raises(ValueError, match="bare domain without a protocol prefix"): + validate_network_policy_allowlist(["https://example.com"]) + + def test_protocol_suggestion_strips_prefix(self) -> None: + with pytest.raises(ValueError, match=r"Use 'example\.com' instead"): + validate_network_policy_allowlist(["https://example.com"]) + + def test_domain_with_path_raises(self) -> None: + with pytest.raises(ValueError, match="without a path"): + validate_network_policy_allowlist(["example.com/api/v1"]) + + def test_empty_string_raises(self) -> None: + with pytest.raises(ValueError, match="invalid entry"): + validate_network_policy_allowlist([""]) + + def test_whitespace_only_raises(self) -> None: + with pytest.raises(ValueError, match="invalid entry"): + validate_network_policy_allowlist([" "]) + + +class TestValidateShellTool: + def test_valid_shell_tool_with_allowlist(self) -> None: + validate_shell_tool({ + "type": "shell", + "environment": { + "type": "container_auto", + "network_policy": { + "type": "allowlist", + "allowed_domains": ["google.com"], + }, + }, + }) + + def test_shell_tool_with_disabled_policy(self) -> None: + validate_shell_tool({ + "type": "shell", + "environment": { + "type": "container_auto", + "network_policy": {"type": "disabled"}, + }, + }) + + def test_shell_tool_without_environment(self) -> None: + validate_shell_tool({"type": "shell"}) + + def test_shell_tool_without_network_policy(self) -> None: + validate_shell_tool({ + "type": "shell", + "environment": {"type": "container_auto"}, + }) + + def test_shell_tool_with_bad_domain_raises(self) -> None: + with pytest.raises(ValueError, match="protocol prefix"): + validate_shell_tool({ + "type": "shell", + "environment": { + "type": "container_auto", + "network_policy": { + "type": "allowlist", + "allowed_domains": ["https://example.com"], + }, + }, + }) + + +class TestValidateTools: + def test_valid_tools_pass(self) -> None: + validate_tools([ + { + "type": "shell", + "environment": { + "type": "container_auto", + "network_policy": { + "type": "allowlist", + "allowed_domains": ["example.com"], + }, + }, + }, + {"type": "web_search"}, + ]) + + def test_non_dict_tools_are_skipped(self) -> None: + validate_tools(["not_a_dict"]) # type: ignore[list-item] + + def test_code_interpreter_allowlist_validated(self) -> None: + with pytest.raises(ValueError, match="protocol prefix"): + validate_tools([ + { + "type": "code_interpreter", + "container": { + "network_policy": { + "type": "allowlist", + "allowed_domains": ["https://pypi.org"], + }, + }, + } + ]) + + +# --------------------------------------------------------------------------- +# Serialization tests — prove the library sends the correct JSON +# --------------------------------------------------------------------------- + + +class _CaptureTransport(httpx.BaseTransport): + """Transport that records the last request and returns a minimal valid response.""" + + last_request: httpx.Request | None = None + + def handle_request(self, request: httpx.Request) -> httpx.Response: + self.last_request = request + return httpx.Response(200, json={ + "id": "resp_test", + "object": "response", + "created_at": 1234567890, + "status": "completed", + "model": "gpt-5.2", + "output": [], + "output_text": "", + "parallel_tool_calls": True, + "tool_choice": "auto", + "tools": [], + }) + + +def _captured_body(transport: _CaptureTransport) -> dict[str, Any]: + assert transport.last_request is not None + return json.loads(transport.last_request.content) + + +class TestShellToolSerialization: + """Ensure shell tool with allowlist is serialized exactly as the API expects.""" + + def _make_client(self) -> tuple[openai.OpenAI, _CaptureTransport]: + transport = _CaptureTransport() + client = openai.OpenAI( + api_key="test-key", + http_client=httpx.Client(transport=transport), + ) + return client, transport + + def test_allowlist_network_policy_serialization(self) -> None: + client, transport = self._make_client() + client.responses.create( + model="gpt-5.2", + input="test", + tools=[{ + "type": "shell", + "environment": { + "type": "container_auto", + "network_policy": { + "type": "allowlist", + "allowed_domains": ["google.com"], + }, + }, + }], + ) + body = _captured_body(transport) + tool = body["tools"][0] + assert tool["type"] == "shell" + assert tool["environment"]["type"] == "container_auto" + policy = tool["environment"]["network_policy"] + assert policy["type"] == "allowlist" + assert policy["allowed_domains"] == ["google.com"] + + def test_allowlist_with_multiple_domains(self) -> None: + client, transport = self._make_client() + domains = ["pypi.org", "files.pythonhosted.org", "github.com"] + client.responses.create( + model="gpt-5.2", + input="test", + tools=[{ + "type": "shell", + "environment": { + "type": "container_auto", + "network_policy": { + "type": "allowlist", + "allowed_domains": domains, + }, + }, + }], + ) + body = _captured_body(transport) + assert body["tools"][0]["environment"]["network_policy"]["allowed_domains"] == domains + + def test_allowlist_with_domain_secrets(self) -> None: + client, transport = self._make_client() + client.responses.create( + model="gpt-5.2", + input="test", + tools=[{ + "type": "shell", + "environment": { + "type": "container_auto", + "network_policy": { + "type": "allowlist", + "allowed_domains": ["httpbin.org"], + "domain_secrets": [ + {"domain": "httpbin.org", "name": "API_KEY", "value": "secret-123"}, + ], + }, + }, + }], + ) + body = _captured_body(transport) + policy = body["tools"][0]["environment"]["network_policy"] + assert policy["type"] == "allowlist" + assert policy["allowed_domains"] == ["httpbin.org"] + assert policy["domain_secrets"] == [ + {"domain": "httpbin.org", "name": "API_KEY", "value": "secret-123"}, + ] + + def test_disabled_network_policy_serialization(self) -> None: + client, transport = self._make_client() + client.responses.create( + model="gpt-5.2", + input="test", + tools=[{ + "type": "shell", + "environment": { + "type": "container_auto", + "network_policy": {"type": "disabled"}, + }, + }], + ) + body = _captured_body(transport) + assert body["tools"][0]["environment"]["network_policy"] == {"type": "disabled"} + + def test_shell_without_environment_serialization(self) -> None: + client, transport = self._make_client() + client.responses.create( + model="gpt-5.2", + input="test", + tools=[{"type": "shell"}], + ) + body = _captured_body(transport) + assert body["tools"][0] == {"type": "shell"} + + +class TestTransformAllowlist: + """Verify that maybe_transform preserves allowlist fields exactly.""" + + def test_transform_preserves_all_fields(self) -> None: + params = { + "model": "gpt-5.2", + "input": "test", + "tools": [ + { + "type": "shell", + "environment": { + "type": "container_auto", + "network_policy": { + "type": "allowlist", + "allowed_domains": ["example.com", "api.example.com"], + }, + }, + } + ], + } + result = maybe_transform(params, ResponseCreateParamsNonStreaming) + tool = result["tools"][0] + assert tool["environment"]["network_policy"] == { + "type": "allowlist", + "allowed_domains": ["example.com", "api.example.com"], + }