diff --git a/tests/client/test_output_schema_validation.py b/tests/client/test_output_schema_validation.py index 714352ad5..cc93d303b 100644 --- a/tests/client/test_output_schema_validation.py +++ b/tests/client/test_output_schema_validation.py @@ -38,175 +38,176 @@ def selective_mock(instance: Any = None, schema: Any = None, *args: Any, **kwarg yield -class TestClientOutputSchemaValidation: - """Test client-side validation of structured output from tools""" - - @pytest.mark.anyio - async def test_tool_structured_output_client_side_validation_basemodel(self): - """Test that client validates structured content against schema for BaseModel outputs""" - # Create a malicious low-level server that returns invalid structured content - server = Server("test-server") - - # Define the expected schema for our tool - output_schema = { - "type": "object", - "properties": {"name": {"type": "string", "title": "Name"}, "age": {"type": "integer", "title": "Age"}}, - "required": ["name", "age"], - "title": "UserOutput", - } - - @server.list_tools() - async def list_tools(): - return [ - Tool( - name="get_user", - description="Get user data", - input_schema={"type": "object"}, - output_schema=output_schema, - ) - ] - - @server.call_tool() - async def call_tool(name: str, arguments: dict[str, Any]): - # Return invalid structured content - age is string instead of integer - # The low-level server will wrap this in CallToolResult - return {"name": "John", "age": "invalid"} # Invalid: age should be int - - # Test that client validates the structured content - with bypass_server_output_validation(): - async with Client(server) as client: - # The client validates structured content and should raise an error - with pytest.raises(RuntimeError) as exc_info: - await client.call_tool("get_user", {}) - # Verify it's a validation error - assert "Invalid structured content returned by tool get_user" in str(exc_info.value) - - @pytest.mark.anyio - async def test_tool_structured_output_client_side_validation_primitive(self): - """Test that client validates structured content for primitive outputs""" - server = Server("test-server") - - # Primitive types are wrapped in {"result": value} - output_schema = { - "type": "object", - "properties": {"result": {"type": "integer", "title": "Result"}}, - "required": ["result"], - "title": "calculate_Output", - } - - @server.list_tools() - async def list_tools(): - return [ - Tool( - name="calculate", - description="Calculate something", - input_schema={"type": "object"}, - output_schema=output_schema, - ) - ] - - @server.call_tool() - async def call_tool(name: str, arguments: dict[str, Any]): - # Return invalid structured content - result is string instead of integer - return {"result": "not_a_number"} # Invalid: should be int - - with bypass_server_output_validation(): - async with Client(server) as client: - # The client validates structured content and should raise an error - with pytest.raises(RuntimeError) as exc_info: - await client.call_tool("calculate", {}) - assert "Invalid structured content returned by tool calculate" in str(exc_info.value) - - @pytest.mark.anyio - async def test_tool_structured_output_client_side_validation_dict_typed(self): - """Test that client validates dict[str, T] structured content""" - server = Server("test-server") - - # dict[str, int] schema - output_schema = {"type": "object", "additionalProperties": {"type": "integer"}, "title": "get_scores_Output"} - - @server.list_tools() - async def list_tools(): - return [ - Tool( - name="get_scores", - description="Get scores", - input_schema={"type": "object"}, - output_schema=output_schema, - ) - ] - - @server.call_tool() - async def call_tool(name: str, arguments: dict[str, Any]): - # Return invalid structured content - values should be integers - return {"alice": "100", "bob": "85"} # Invalid: values should be int - - with bypass_server_output_validation(): - async with Client(server) as client: - # The client validates structured content and should raise an error - with pytest.raises(RuntimeError) as exc_info: - await client.call_tool("get_scores", {}) - assert "Invalid structured content returned by tool get_scores" in str(exc_info.value) - - @pytest.mark.anyio - async def test_tool_structured_output_client_side_validation_missing_required(self): - """Test that client validates missing required fields""" - server = Server("test-server") - - output_schema = { - "type": "object", - "properties": {"name": {"type": "string"}, "age": {"type": "integer"}, "email": {"type": "string"}}, - "required": ["name", "age", "email"], # All fields required - "title": "PersonOutput", - } - - @server.list_tools() - async def list_tools(): - return [ - Tool( - name="get_person", - description="Get person data", - input_schema={"type": "object"}, - output_schema=output_schema, - ) - ] - - @server.call_tool() - async def call_tool(name: str, arguments: dict[str, Any]): - # Return structured content missing required field 'email' - return {"name": "John", "age": 30} # Missing required 'email' - - with bypass_server_output_validation(): - async with Client(server) as client: - # The client validates structured content and should raise an error - with pytest.raises(RuntimeError) as exc_info: - await client.call_tool("get_person", {}) - assert "Invalid structured content returned by tool get_person" in str(exc_info.value) - - @pytest.mark.anyio - async def test_tool_not_listed_warning(self, caplog: pytest.LogCaptureFixture): - """Test that client logs warning when tool is not in list_tools but has output_schema""" - server = Server("test-server") - - @server.list_tools() - async def list_tools() -> list[Tool]: - # Return empty list - tool is not listed - return [] - - @server.call_tool() - async def call_tool(name: str, arguments: dict[str, Any]) -> dict[str, Any]: - # Server still responds to the tool call with structured content - return {"result": 42} - - # Set logging level to capture warnings - caplog.set_level(logging.WARNING) - - with bypass_server_output_validation(): - async with Client(server) as client: - # Call a tool that wasn't listed - result = await client.call_tool("mystery_tool", {}) - assert result.structured_content == {"result": 42} - assert result.is_error is False - - # Check that warning was logged - assert "Tool mystery_tool not listed" in caplog.text +@pytest.mark.anyio +async def test_tool_structured_output_client_side_validation_basemodel(): + """Test that client validates structured content against schema for BaseModel outputs""" + # Create a malicious low-level server that returns invalid structured content + server = Server("test-server") + + # Define the expected schema for our tool + output_schema = { + "type": "object", + "properties": {"name": {"type": "string", "title": "Name"}, "age": {"type": "integer", "title": "Age"}}, + "required": ["name", "age"], + "title": "UserOutput", + } + + @server.list_tools() + async def list_tools(): + return [ + Tool( + name="get_user", + description="Get user data", + input_schema={"type": "object"}, + output_schema=output_schema, + ) + ] + + @server.call_tool() + async def call_tool(name: str, arguments: dict[str, Any]): + # Return invalid structured content - age is string instead of integer + # The low-level server will wrap this in CallToolResult + return {"name": "John", "age": "invalid"} # Invalid: age should be int + + # Test that client validates the structured content + with bypass_server_output_validation(): + async with Client(server) as client: + # The client validates structured content and should raise an error + with pytest.raises(RuntimeError) as exc_info: + await client.call_tool("get_user", {}) + # Verify it's a validation error + assert "Invalid structured content returned by tool get_user" in str(exc_info.value) + + +@pytest.mark.anyio +async def test_tool_structured_output_client_side_validation_primitive(): + """Test that client validates structured content for primitive outputs""" + server = Server("test-server") + + # Primitive types are wrapped in {"result": value} + output_schema = { + "type": "object", + "properties": {"result": {"type": "integer", "title": "Result"}}, + "required": ["result"], + "title": "calculate_Output", + } + + @server.list_tools() + async def list_tools(): + return [ + Tool( + name="calculate", + description="Calculate something", + input_schema={"type": "object"}, + output_schema=output_schema, + ) + ] + + @server.call_tool() + async def call_tool(name: str, arguments: dict[str, Any]): + # Return invalid structured content - result is string instead of integer + return {"result": "not_a_number"} # Invalid: should be int + + with bypass_server_output_validation(): + async with Client(server) as client: + # The client validates structured content and should raise an error + with pytest.raises(RuntimeError) as exc_info: + await client.call_tool("calculate", {}) + assert "Invalid structured content returned by tool calculate" in str(exc_info.value) + + +@pytest.mark.anyio +async def test_tool_structured_output_client_side_validation_dict_typed(): + """Test that client validates dict[str, T] structured content""" + server = Server("test-server") + + # dict[str, int] schema + output_schema = {"type": "object", "additionalProperties": {"type": "integer"}, "title": "get_scores_Output"} + + @server.list_tools() + async def list_tools(): + return [ + Tool( + name="get_scores", + description="Get scores", + input_schema={"type": "object"}, + output_schema=output_schema, + ) + ] + + @server.call_tool() + async def call_tool(name: str, arguments: dict[str, Any]): + # Return invalid structured content - values should be integers + return {"alice": "100", "bob": "85"} # Invalid: values should be int + + with bypass_server_output_validation(): + async with Client(server) as client: + # The client validates structured content and should raise an error + with pytest.raises(RuntimeError) as exc_info: + await client.call_tool("get_scores", {}) + assert "Invalid structured content returned by tool get_scores" in str(exc_info.value) + + +@pytest.mark.anyio +async def test_tool_structured_output_client_side_validation_missing_required(): + """Test that client validates missing required fields""" + server = Server("test-server") + + output_schema = { + "type": "object", + "properties": {"name": {"type": "string"}, "age": {"type": "integer"}, "email": {"type": "string"}}, + "required": ["name", "age", "email"], # All fields required + "title": "PersonOutput", + } + + @server.list_tools() + async def list_tools(): + return [ + Tool( + name="get_person", + description="Get person data", + input_schema={"type": "object"}, + output_schema=output_schema, + ) + ] + + @server.call_tool() + async def call_tool(name: str, arguments: dict[str, Any]): + # Return structured content missing required field 'email' + return {"name": "John", "age": 30} # Missing required 'email' + + with bypass_server_output_validation(): + async with Client(server) as client: + # The client validates structured content and should raise an error + with pytest.raises(RuntimeError) as exc_info: + await client.call_tool("get_person", {}) + assert "Invalid structured content returned by tool get_person" in str(exc_info.value) + + +@pytest.mark.anyio +async def test_tool_not_listed_warning(caplog: pytest.LogCaptureFixture): + """Test that client logs warning when tool is not in list_tools but has output_schema""" + server = Server("test-server") + + @server.list_tools() + async def list_tools() -> list[Tool]: + # Return empty list - tool is not listed + return [] + + @server.call_tool() + async def call_tool(name: str, arguments: dict[str, Any]) -> dict[str, Any]: + # Server still responds to the tool call with structured content + return {"result": 42} + + # Set logging level to capture warnings + caplog.set_level(logging.WARNING) + + with bypass_server_output_validation(): + async with Client(server) as client: + # Call a tool that wasn't listed + result = await client.call_tool("mystery_tool", {}) + assert result.structured_content == {"result": 42} + assert result.is_error is False + + # Check that warning was logged + assert "Tool mystery_tool not listed" in caplog.text diff --git a/tests/client/test_session_group.py b/tests/client/test_session_group.py index 5efc1d7d2..1046d43e3 100644 --- a/tests/client/test_session_group.py +++ b/tests/client/test_session_group.py @@ -25,360 +25,373 @@ def mock_exit_stack(): return mock.MagicMock(spec=contextlib.AsyncExitStack) +def test_client_session_group_init(): + mcp_session_group = ClientSessionGroup() + assert not mcp_session_group._tools + assert not mcp_session_group._resources + assert not mcp_session_group._prompts + assert not mcp_session_group._tool_to_session + + +def test_client_session_group_component_properties(): + # --- Mock Dependencies --- + mock_prompt = mock.Mock() + mock_resource = mock.Mock() + mock_tool = mock.Mock() + + # --- Prepare Session Group --- + mcp_session_group = ClientSessionGroup() + mcp_session_group._prompts = {"my_prompt": mock_prompt} + mcp_session_group._resources = {"my_resource": mock_resource} + mcp_session_group._tools = {"my_tool": mock_tool} + + # --- Assertions --- + assert mcp_session_group.prompts == {"my_prompt": mock_prompt} + assert mcp_session_group.resources == {"my_resource": mock_resource} + assert mcp_session_group.tools == {"my_tool": mock_tool} + + @pytest.mark.anyio -class TestClientSessionGroup: - def test_init(self): - mcp_session_group = ClientSessionGroup() - assert not mcp_session_group._tools - assert not mcp_session_group._resources - assert not mcp_session_group._prompts - assert not mcp_session_group._tool_to_session - - def test_component_properties(self): - # --- Mock Dependencies --- - mock_prompt = mock.Mock() - mock_resource = mock.Mock() - mock_tool = mock.Mock() - - # --- Prepare Session Group --- - mcp_session_group = ClientSessionGroup() - mcp_session_group._prompts = {"my_prompt": mock_prompt} - mcp_session_group._resources = {"my_resource": mock_resource} - mcp_session_group._tools = {"my_tool": mock_tool} - - # --- Assertions --- - assert mcp_session_group.prompts == {"my_prompt": mock_prompt} - assert mcp_session_group.resources == {"my_resource": mock_resource} - assert mcp_session_group.tools == {"my_tool": mock_tool} - - async def test_call_tool(self): - # --- Mock Dependencies --- - mock_session = mock.AsyncMock() - - # --- Prepare Session Group --- - def hook(name: str, server_info: types.Implementation) -> str: # pragma: no cover - return f"{(server_info.name)}-{name}" - - mcp_session_group = ClientSessionGroup(component_name_hook=hook) - mcp_session_group._tools = {"server1-my_tool": types.Tool(name="my_tool", input_schema={})} - mcp_session_group._tool_to_session = {"server1-my_tool": mock_session} - text_content = types.TextContent(type="text", text="OK") - mock_session.call_tool.return_value = types.CallToolResult(content=[text_content]) - - # --- Test Execution --- - result = await mcp_session_group.call_tool( - name="server1-my_tool", - arguments={ - "name": "value1", - "args": {}, - }, - ) +async def test_client_session_group_call_tool(): + # --- Mock Dependencies --- + mock_session = mock.AsyncMock() + + # --- Prepare Session Group --- + def hook(name: str, server_info: types.Implementation) -> str: # pragma: no cover + return f"{(server_info.name)}-{name}" + + mcp_session_group = ClientSessionGroup(component_name_hook=hook) + mcp_session_group._tools = {"server1-my_tool": types.Tool(name="my_tool", input_schema={})} + mcp_session_group._tool_to_session = {"server1-my_tool": mock_session} + text_content = types.TextContent(type="text", text="OK") + mock_session.call_tool.return_value = types.CallToolResult(content=[text_content]) + + # --- Test Execution --- + result = await mcp_session_group.call_tool( + name="server1-my_tool", + arguments={ + "name": "value1", + "args": {}, + }, + ) + + # --- Assertions --- + assert result.content == [text_content] + mock_session.call_tool.assert_called_once_with( + "my_tool", + arguments={"name": "value1", "args": {}}, + read_timeout_seconds=None, + progress_callback=None, + meta=None, + ) - # --- Assertions --- - assert result.content == [text_content] - mock_session.call_tool.assert_called_once_with( - "my_tool", - arguments={"name": "value1", "args": {}}, - read_timeout_seconds=None, - progress_callback=None, - meta=None, + +@pytest.mark.anyio +async def test_client_session_group_connect_to_server(mock_exit_stack: contextlib.AsyncExitStack): + """Test connecting to a server and aggregating components.""" + # --- Mock Dependencies --- + mock_server_info = mock.Mock(spec=types.Implementation) + mock_server_info.name = "TestServer1" + mock_session = mock.AsyncMock(spec=mcp.ClientSession) + mock_tool1 = mock.Mock(spec=types.Tool) + mock_tool1.name = "tool_a" + mock_resource1 = mock.Mock(spec=types.Resource) + mock_resource1.name = "resource_b" + mock_prompt1 = mock.Mock(spec=types.Prompt) + mock_prompt1.name = "prompt_c" + mock_session.list_tools.return_value = mock.AsyncMock(tools=[mock_tool1]) + mock_session.list_resources.return_value = mock.AsyncMock(resources=[mock_resource1]) + mock_session.list_prompts.return_value = mock.AsyncMock(prompts=[mock_prompt1]) + + # --- Test Execution --- + group = ClientSessionGroup(exit_stack=mock_exit_stack) + with mock.patch.object(group, "_establish_session", return_value=(mock_server_info, mock_session)): + await group.connect_to_server(StdioServerParameters(command="test")) + + # --- Assertions --- + assert mock_session in group._sessions + assert len(group.tools) == 1 + assert "tool_a" in group.tools + assert group.tools["tool_a"] == mock_tool1 + assert group._tool_to_session["tool_a"] == mock_session + assert len(group.resources) == 1 + assert "resource_b" in group.resources + assert group.resources["resource_b"] == mock_resource1 + assert len(group.prompts) == 1 + assert "prompt_c" in group.prompts + assert group.prompts["prompt_c"] == mock_prompt1 + mock_session.list_tools.assert_awaited_once() + mock_session.list_resources.assert_awaited_once() + mock_session.list_prompts.assert_awaited_once() + + +@pytest.mark.anyio +async def test_client_session_group_connect_to_server_with_name_hook(mock_exit_stack: contextlib.AsyncExitStack): + """Test connecting with a component name hook.""" + # --- Mock Dependencies --- + mock_server_info = mock.Mock(spec=types.Implementation) + mock_server_info.name = "HookServer" + mock_session = mock.AsyncMock(spec=mcp.ClientSession) + mock_tool = mock.Mock(spec=types.Tool) + mock_tool.name = "base_tool" + mock_session.list_tools.return_value = mock.AsyncMock(tools=[mock_tool]) + mock_session.list_resources.return_value = mock.AsyncMock(resources=[]) + mock_session.list_prompts.return_value = mock.AsyncMock(prompts=[]) + + # --- Test Setup --- + def name_hook(name: str, server_info: types.Implementation) -> str: + return f"{server_info.name}.{name}" + + # --- Test Execution --- + group = ClientSessionGroup(exit_stack=mock_exit_stack, component_name_hook=name_hook) + with mock.patch.object(group, "_establish_session", return_value=(mock_server_info, mock_session)): + await group.connect_to_server(StdioServerParameters(command="test")) + + # --- Assertions --- + assert mock_session in group._sessions + assert len(group.tools) == 1 + expected_tool_name = "HookServer.base_tool" + assert expected_tool_name in group.tools + assert group.tools[expected_tool_name] == mock_tool + assert group._tool_to_session[expected_tool_name] == mock_session + + +@pytest.mark.anyio +async def test_client_session_group_disconnect_from_server(): + """Test disconnecting from a server.""" + # --- Test Setup --- + group = ClientSessionGroup() + server_name = "ServerToDisconnect" + + # Manually populate state using standard mocks + mock_session1 = mock.MagicMock(spec=mcp.ClientSession) + mock_session2 = mock.MagicMock(spec=mcp.ClientSession) + mock_tool1 = mock.Mock(spec=types.Tool) + mock_tool1.name = "tool1" + mock_resource1 = mock.Mock(spec=types.Resource) + mock_resource1.name = "res1" + mock_prompt1 = mock.Mock(spec=types.Prompt) + mock_prompt1.name = "prm1" + mock_tool2 = mock.Mock(spec=types.Tool) + mock_tool2.name = "tool2" + mock_component_named_like_server = mock.Mock() + mock_session = mock.Mock(spec=mcp.ClientSession) + + group._tools = { + "tool1": mock_tool1, + "tool2": mock_tool2, + server_name: mock_component_named_like_server, + } + group._tool_to_session = { + "tool1": mock_session1, + "tool2": mock_session2, + server_name: mock_session1, + } + group._resources = { + "res1": mock_resource1, + server_name: mock_component_named_like_server, + } + group._prompts = { + "prm1": mock_prompt1, + server_name: mock_component_named_like_server, + } + group._sessions = { + mock_session: ClientSessionGroup._ComponentNames( + prompts=set({"prm1"}), + resources=set({"res1"}), + tools=set({"tool1", "tool2"}), ) + } - async def test_connect_to_server(self, mock_exit_stack: contextlib.AsyncExitStack): - """Test connecting to a server and aggregating components.""" - # --- Mock Dependencies --- - mock_server_info = mock.Mock(spec=types.Implementation) - mock_server_info.name = "TestServer1" - mock_session = mock.AsyncMock(spec=mcp.ClientSession) - mock_tool1 = mock.Mock(spec=types.Tool) - mock_tool1.name = "tool_a" - mock_resource1 = mock.Mock(spec=types.Resource) - mock_resource1.name = "resource_b" - mock_prompt1 = mock.Mock(spec=types.Prompt) - mock_prompt1.name = "prompt_c" - mock_session.list_tools.return_value = mock.AsyncMock(tools=[mock_tool1]) - mock_session.list_resources.return_value = mock.AsyncMock(resources=[mock_resource1]) - mock_session.list_prompts.return_value = mock.AsyncMock(prompts=[mock_prompt1]) - - # --- Test Execution --- - group = ClientSessionGroup(exit_stack=mock_exit_stack) - with mock.patch.object(group, "_establish_session", return_value=(mock_server_info, mock_session)): - await group.connect_to_server(StdioServerParameters(command="test")) + # --- Assertions --- + assert mock_session in group._sessions + assert "tool1" in group._tools + assert "tool2" in group._tools + assert "res1" in group._resources + assert "prm1" in group._prompts + + # --- Test Execution --- + await group.disconnect_from_server(mock_session) - # --- Assertions --- - assert mock_session in group._sessions - assert len(group.tools) == 1 - assert "tool_a" in group.tools - assert group.tools["tool_a"] == mock_tool1 - assert group._tool_to_session["tool_a"] == mock_session - assert len(group.resources) == 1 - assert "resource_b" in group.resources - assert group.resources["resource_b"] == mock_resource1 - assert len(group.prompts) == 1 - assert "prompt_c" in group.prompts - assert group.prompts["prompt_c"] == mock_prompt1 - mock_session.list_tools.assert_awaited_once() - mock_session.list_resources.assert_awaited_once() - mock_session.list_prompts.assert_awaited_once() - - async def test_connect_to_server_with_name_hook(self, mock_exit_stack: contextlib.AsyncExitStack): - """Test connecting with a component name hook.""" - # --- Mock Dependencies --- - mock_server_info = mock.Mock(spec=types.Implementation) - mock_server_info.name = "HookServer" - mock_session = mock.AsyncMock(spec=mcp.ClientSession) - mock_tool = mock.Mock(spec=types.Tool) - mock_tool.name = "base_tool" - mock_session.list_tools.return_value = mock.AsyncMock(tools=[mock_tool]) - mock_session.list_resources.return_value = mock.AsyncMock(resources=[]) - mock_session.list_prompts.return_value = mock.AsyncMock(prompts=[]) - - # --- Test Setup --- - def name_hook(name: str, server_info: types.Implementation) -> str: - return f"{server_info.name}.{name}" - - # --- Test Execution --- - group = ClientSessionGroup(exit_stack=mock_exit_stack, component_name_hook=name_hook) - with mock.patch.object(group, "_establish_session", return_value=(mock_server_info, mock_session)): + # --- Assertions --- + assert mock_session not in group._sessions + assert "tool1" not in group._tools + assert "tool2" not in group._tools + assert "res1" not in group._resources + assert "prm1" not in group._prompts + + +@pytest.mark.anyio +async def test_client_session_group_connect_to_server_duplicate_tool_raises_error( + mock_exit_stack: contextlib.AsyncExitStack, +): + """Test McpError raised when connecting a server with a dup name.""" + # --- Setup Pre-existing State --- + group = ClientSessionGroup(exit_stack=mock_exit_stack) + existing_tool_name = "shared_tool" + # Manually add a tool to simulate a previous connection + group._tools[existing_tool_name] = mock.Mock(spec=types.Tool) + group._tools[existing_tool_name].name = existing_tool_name + # Need a dummy session associated with the existing tool + mock_session = mock.MagicMock(spec=mcp.ClientSession) + group._tool_to_session[existing_tool_name] = mock_session + group._session_exit_stacks[mock_session] = mock.Mock(spec=contextlib.AsyncExitStack) + + # --- Mock New Connection Attempt --- + mock_server_info_new = mock.Mock(spec=types.Implementation) + mock_server_info_new.name = "ServerWithDuplicate" + mock_session_new = mock.AsyncMock(spec=mcp.ClientSession) + + # Configure the new session to return a tool with the *same name* + duplicate_tool = mock.Mock(spec=types.Tool) + duplicate_tool.name = existing_tool_name + mock_session_new.list_tools.return_value = mock.AsyncMock(tools=[duplicate_tool]) + # Keep other lists empty for simplicity + mock_session_new.list_resources.return_value = mock.AsyncMock(resources=[]) + mock_session_new.list_prompts.return_value = mock.AsyncMock(prompts=[]) + + # --- Test Execution and Assertion --- + with pytest.raises(McpError) as excinfo: + with mock.patch.object( + group, + "_establish_session", + return_value=(mock_server_info_new, mock_session_new), + ): await group.connect_to_server(StdioServerParameters(command="test")) - # --- Assertions --- - assert mock_session in group._sessions - assert len(group.tools) == 1 - expected_tool_name = "HookServer.base_tool" - assert expected_tool_name in group.tools - assert group.tools[expected_tool_name] == mock_tool - assert group._tool_to_session[expected_tool_name] == mock_session - - async def test_disconnect_from_server(self): # No mock arguments needed - """Test disconnecting from a server.""" - # --- Test Setup --- - group = ClientSessionGroup() - server_name = "ServerToDisconnect" - - # Manually populate state using standard mocks - mock_session1 = mock.MagicMock(spec=mcp.ClientSession) - mock_session2 = mock.MagicMock(spec=mcp.ClientSession) - mock_tool1 = mock.Mock(spec=types.Tool) - mock_tool1.name = "tool1" - mock_resource1 = mock.Mock(spec=types.Resource) - mock_resource1.name = "res1" - mock_prompt1 = mock.Mock(spec=types.Prompt) - mock_prompt1.name = "prm1" - mock_tool2 = mock.Mock(spec=types.Tool) - mock_tool2.name = "tool2" - mock_component_named_like_server = mock.Mock() - mock_session = mock.Mock(spec=mcp.ClientSession) - - group._tools = { - "tool1": mock_tool1, - "tool2": mock_tool2, - server_name: mock_component_named_like_server, - } - group._tool_to_session = { - "tool1": mock_session1, - "tool2": mock_session2, - server_name: mock_session1, - } - group._resources = { - "res1": mock_resource1, - server_name: mock_component_named_like_server, - } - group._prompts = { - "prm1": mock_prompt1, - server_name: mock_component_named_like_server, - } - group._sessions = { - mock_session: ClientSessionGroup._ComponentNames( - prompts=set({"prm1"}), - resources=set({"res1"}), - tools=set({"tool1", "tool2"}), - ) - } - - # --- Assertions --- - assert mock_session in group._sessions - assert "tool1" in group._tools - assert "tool2" in group._tools - assert "res1" in group._resources - assert "prm1" in group._prompts - - # --- Test Execution --- - await group.disconnect_from_server(mock_session) - - # --- Assertions --- - assert mock_session not in group._sessions - assert "tool1" not in group._tools - assert "tool2" not in group._tools - assert "res1" not in group._resources - assert "prm1" not in group._prompts - - async def test_connect_to_server_duplicate_tool_raises_error(self, mock_exit_stack: contextlib.AsyncExitStack): - """Test McpError raised when connecting a server with a dup name.""" - # --- Setup Pre-existing State --- - group = ClientSessionGroup(exit_stack=mock_exit_stack) - existing_tool_name = "shared_tool" - # Manually add a tool to simulate a previous connection - group._tools[existing_tool_name] = mock.Mock(spec=types.Tool) - group._tools[existing_tool_name].name = existing_tool_name - # Need a dummy session associated with the existing tool - mock_session = mock.MagicMock(spec=mcp.ClientSession) - group._tool_to_session[existing_tool_name] = mock_session - group._session_exit_stacks[mock_session] = mock.Mock(spec=contextlib.AsyncExitStack) - - # --- Mock New Connection Attempt --- - mock_server_info_new = mock.Mock(spec=types.Implementation) - mock_server_info_new.name = "ServerWithDuplicate" - mock_session_new = mock.AsyncMock(spec=mcp.ClientSession) - - # Configure the new session to return a tool with the *same name* - duplicate_tool = mock.Mock(spec=types.Tool) - duplicate_tool.name = existing_tool_name - mock_session_new.list_tools.return_value = mock.AsyncMock(tools=[duplicate_tool]) - # Keep other lists empty for simplicity - mock_session_new.list_resources.return_value = mock.AsyncMock(resources=[]) - mock_session_new.list_prompts.return_value = mock.AsyncMock(prompts=[]) - - # --- Test Execution and Assertion --- - with pytest.raises(McpError) as excinfo: - with mock.patch.object( - group, - "_establish_session", - return_value=(mock_server_info_new, mock_session_new), - ): - await group.connect_to_server(StdioServerParameters(command="test")) - - # Assert details about the raised error - assert excinfo.value.error.code == types.INVALID_PARAMS - assert existing_tool_name in excinfo.value.error.message - assert "already exist " in excinfo.value.error.message - - # Verify the duplicate tool was *not* added again (state should be unchanged) - assert len(group._tools) == 1 # Should still only have the original - assert group._tools[existing_tool_name] is not duplicate_tool # Ensure it's the original mock - - # No patching needed here - async def test_disconnect_non_existent_server(self): - """Test disconnecting a server that isn't connected.""" - session = mock.Mock(spec=mcp.ClientSession) - group = ClientSessionGroup() - with pytest.raises(McpError): - await group.disconnect_from_server(session) - - @pytest.mark.parametrize( - "server_params_instance, client_type_name, patch_target_for_client_func", - [ - ( - StdioServerParameters(command="test_stdio_cmd"), - "stdio", - "mcp.client.session_group.mcp.stdio_client", - ), - ( - SseServerParameters(url="http://test.com/sse", timeout=10.0), - "sse", - "mcp.client.session_group.sse_client", - ), # url, headers, timeout, sse_read_timeout - ( - StreamableHttpParameters(url="http://test.com/stream", terminate_on_close=False), - "streamablehttp", - "mcp.client.session_group.streamable_http_client", - ), # url, headers, timeout, sse_read_timeout, terminate_on_close - ], - ) - async def test_establish_session_parameterized( - self, - server_params_instance: StdioServerParameters | SseServerParameters | StreamableHttpParameters, - client_type_name: str, # Just for clarity or conditional logic if needed - patch_target_for_client_func: str, - ): - with mock.patch("mcp.client.session_group.mcp.ClientSession") as mock_ClientSession_class: - with mock.patch(patch_target_for_client_func) as mock_specific_client_func: - mock_client_cm_instance = mock.AsyncMock(name=f"{client_type_name}ClientCM") - mock_read_stream = mock.AsyncMock(name=f"{client_type_name}Read") - mock_write_stream = mock.AsyncMock(name=f"{client_type_name}Write") - - # streamable_http_client's __aenter__ returns three values - if client_type_name == "streamablehttp": - mock_extra_stream_val = mock.AsyncMock(name="StreamableExtra") - mock_client_cm_instance.__aenter__.return_value = ( - mock_read_stream, - mock_write_stream, - mock_extra_stream_val, - ) - else: - mock_client_cm_instance.__aenter__.return_value = ( - mock_read_stream, - mock_write_stream, - ) - - mock_client_cm_instance.__aexit__ = mock.AsyncMock(return_value=None) - mock_specific_client_func.return_value = mock_client_cm_instance - - # --- Mock mcp.ClientSession (class) --- - # mock_ClientSession_class is already provided by the outer patch - mock_raw_session_cm = mock.AsyncMock(name="RawSessionCM") - mock_ClientSession_class.return_value = mock_raw_session_cm - - mock_entered_session = mock.AsyncMock(name="EnteredSessionInstance") - mock_raw_session_cm.__aenter__.return_value = mock_entered_session - mock_raw_session_cm.__aexit__ = mock.AsyncMock(return_value=None) - - # Mock session.initialize() - mock_initialize_result = mock.AsyncMock(name="InitializeResult") - mock_initialize_result.server_info = types.Implementation(name="foo", version="1") - mock_entered_session.initialize.return_value = mock_initialize_result - - # --- Test Execution --- - group = ClientSessionGroup() - returned_server_info = None - returned_session = None - - async with contextlib.AsyncExitStack() as stack: - group._exit_stack = stack - ( - returned_server_info, - returned_session, - ) = await group._establish_session(server_params_instance, ClientSessionParameters()) - - # --- Assertions --- - # 1. Assert the correct specific client function was called - if client_type_name == "stdio": - assert isinstance(server_params_instance, StdioServerParameters) - mock_specific_client_func.assert_called_once_with(server_params_instance) - elif client_type_name == "sse": - assert isinstance(server_params_instance, SseServerParameters) - mock_specific_client_func.assert_called_once_with( - url=server_params_instance.url, - headers=server_params_instance.headers, - timeout=server_params_instance.timeout, - sse_read_timeout=server_params_instance.sse_read_timeout, - ) - elif client_type_name == "streamablehttp": # pragma: no branch - assert isinstance(server_params_instance, StreamableHttpParameters) - # Verify streamable_http_client was called with url, httpx_client, and terminate_on_close - # The http_client is created by the real create_mcp_http_client - call_args = mock_specific_client_func.call_args - assert call_args.kwargs["url"] == server_params_instance.url - assert call_args.kwargs["terminate_on_close"] == server_params_instance.terminate_on_close - assert isinstance(call_args.kwargs["http_client"], httpx.AsyncClient) - - mock_client_cm_instance.__aenter__.assert_awaited_once() - - # 2. Assert ClientSession was called correctly - mock_ClientSession_class.assert_called_once_with( + # Assert details about the raised error + assert excinfo.value.error.code == types.INVALID_PARAMS + assert existing_tool_name in excinfo.value.error.message + assert "already exist " in excinfo.value.error.message + + # Verify the duplicate tool was *not* added again (state should be unchanged) + assert len(group._tools) == 1 # Should still only have the original + assert group._tools[existing_tool_name] is not duplicate_tool # Ensure it's the original mock + + +@pytest.mark.anyio +async def test_client_session_group_disconnect_non_existent_server(): + """Test disconnecting a server that isn't connected.""" + session = mock.Mock(spec=mcp.ClientSession) + group = ClientSessionGroup() + with pytest.raises(McpError): + await group.disconnect_from_server(session) + + +@pytest.mark.anyio +@pytest.mark.parametrize( + "server_params_instance, client_type_name, patch_target_for_client_func", + [ + ( + StdioServerParameters(command="test_stdio_cmd"), + "stdio", + "mcp.client.session_group.mcp.stdio_client", + ), + ( + SseServerParameters(url="http://test.com/sse", timeout=10.0), + "sse", + "mcp.client.session_group.sse_client", + ), # url, headers, timeout, sse_read_timeout + ( + StreamableHttpParameters(url="http://test.com/stream", terminate_on_close=False), + "streamablehttp", + "mcp.client.session_group.streamable_http_client", + ), # url, headers, timeout, sse_read_timeout, terminate_on_close + ], +) +async def test_client_session_group_establish_session_parameterized( + server_params_instance: StdioServerParameters | SseServerParameters | StreamableHttpParameters, + client_type_name: str, # Just for clarity or conditional logic if needed + patch_target_for_client_func: str, +): + with mock.patch("mcp.client.session_group.mcp.ClientSession") as mock_ClientSession_class: + with mock.patch(patch_target_for_client_func) as mock_specific_client_func: + mock_client_cm_instance = mock.AsyncMock(name=f"{client_type_name}ClientCM") + mock_read_stream = mock.AsyncMock(name=f"{client_type_name}Read") + mock_write_stream = mock.AsyncMock(name=f"{client_type_name}Write") + + # streamable_http_client's __aenter__ returns three values + if client_type_name == "streamablehttp": + mock_extra_stream_val = mock.AsyncMock(name="StreamableExtra") + mock_client_cm_instance.__aenter__.return_value = ( mock_read_stream, mock_write_stream, - read_timeout_seconds=None, - sampling_callback=None, - elicitation_callback=None, - list_roots_callback=None, - logging_callback=None, - message_handler=None, - client_info=None, + mock_extra_stream_val, ) - mock_raw_session_cm.__aenter__.assert_awaited_once() - mock_entered_session.initialize.assert_awaited_once() + else: + mock_client_cm_instance.__aenter__.return_value = ( + mock_read_stream, + mock_write_stream, + ) + + mock_client_cm_instance.__aexit__ = mock.AsyncMock(return_value=None) + mock_specific_client_func.return_value = mock_client_cm_instance + + # --- Mock mcp.ClientSession (class) --- + # mock_ClientSession_class is already provided by the outer patch + mock_raw_session_cm = mock.AsyncMock(name="RawSessionCM") + mock_ClientSession_class.return_value = mock_raw_session_cm + + mock_entered_session = mock.AsyncMock(name="EnteredSessionInstance") + mock_raw_session_cm.__aenter__.return_value = mock_entered_session + mock_raw_session_cm.__aexit__ = mock.AsyncMock(return_value=None) + + # Mock session.initialize() + mock_initialize_result = mock.AsyncMock(name="InitializeResult") + mock_initialize_result.server_info = types.Implementation(name="foo", version="1") + mock_entered_session.initialize.return_value = mock_initialize_result + + # --- Test Execution --- + group = ClientSessionGroup() + returned_server_info = None + returned_session = None + + async with contextlib.AsyncExitStack() as stack: + group._exit_stack = stack + ( + returned_server_info, + returned_session, + ) = await group._establish_session(server_params_instance, ClientSessionParameters()) + + # --- Assertions --- + # 1. Assert the correct specific client function was called + if client_type_name == "stdio": + assert isinstance(server_params_instance, StdioServerParameters) + mock_specific_client_func.assert_called_once_with(server_params_instance) + elif client_type_name == "sse": + assert isinstance(server_params_instance, SseServerParameters) + mock_specific_client_func.assert_called_once_with( + url=server_params_instance.url, + headers=server_params_instance.headers, + timeout=server_params_instance.timeout, + sse_read_timeout=server_params_instance.sse_read_timeout, + ) + elif client_type_name == "streamablehttp": # pragma: no branch + assert isinstance(server_params_instance, StreamableHttpParameters) + # Verify streamable_http_client was called with url, httpx_client, and terminate_on_close + # The http_client is created by the real create_mcp_http_client + call_args = mock_specific_client_func.call_args + assert call_args.kwargs["url"] == server_params_instance.url + assert call_args.kwargs["terminate_on_close"] == server_params_instance.terminate_on_close + assert isinstance(call_args.kwargs["http_client"], httpx.AsyncClient) + + mock_client_cm_instance.__aenter__.assert_awaited_once() + + # 2. Assert ClientSession was called correctly + mock_ClientSession_class.assert_called_once_with( + mock_read_stream, + mock_write_stream, + read_timeout_seconds=None, + sampling_callback=None, + elicitation_callback=None, + list_roots_callback=None, + logging_callback=None, + message_handler=None, + client_info=None, + ) + mock_raw_session_cm.__aenter__.assert_awaited_once() + mock_entered_session.initialize.assert_awaited_once() - # 3. Assert returned values - assert returned_server_info is mock_initialize_result.server_info - assert returned_session is mock_entered_session + # 3. Assert returned values + assert returned_server_info is mock_initialize_result.server_info + assert returned_session is mock_entered_session diff --git a/tests/server/auth/middleware/test_auth_context.py b/tests/server/auth/middleware/test_auth_context.py index 236490922..66481bcf7 100644 --- a/tests/server/auth/middleware/test_auth_context.py +++ b/tests/server/auth/middleware/test_auth_context.py @@ -45,76 +45,75 @@ def valid_access_token() -> AccessToken: @pytest.mark.anyio -class TestAuthContextMiddleware: - """Tests for the AuthContextMiddleware class.""" +async def test_auth_context_middleware_with_authenticated_user(valid_access_token: AccessToken): + """Test middleware with an authenticated user in scope.""" + app = MockApp() + middleware = AuthContextMiddleware(app) - async def test_with_authenticated_user(self, valid_access_token: AccessToken): - """Test middleware with an authenticated user in scope.""" - app = MockApp() - middleware = AuthContextMiddleware(app) + # Create an authenticated user + user = AuthenticatedUser(valid_access_token) - # Create an authenticated user - user = AuthenticatedUser(valid_access_token) + scope: Scope = {"type": "http", "user": user} - scope: Scope = {"type": "http", "user": user} + # Create dummy async functions for receive and send + async def receive() -> Message: # pragma: no cover + return {"type": "http.request"} - # Create dummy async functions for receive and send - async def receive() -> Message: # pragma: no cover - return {"type": "http.request"} + async def send(message: Message) -> None: # pragma: no cover + pass - async def send(message: Message) -> None: # pragma: no cover - pass + # Verify context is empty before middleware + assert auth_context_var.get() is None + assert get_access_token() is None - # Verify context is empty before middleware - assert auth_context_var.get() is None - assert get_access_token() is None + # Run the middleware + await middleware(scope, receive, send) - # Run the middleware - await middleware(scope, receive, send) + # Verify the app was called + assert app.called + assert app.scope == scope + assert app.receive == receive + assert app.send == send - # Verify the app was called - assert app.called - assert app.scope == scope - assert app.receive == receive - assert app.send == send + # Verify the access token was available during the call + assert app.access_token_during_call == valid_access_token - # Verify the access token was available during the call - assert app.access_token_during_call == valid_access_token + # Verify context is reset after middleware + assert auth_context_var.get() is None + assert get_access_token() is None - # Verify context is reset after middleware - assert auth_context_var.get() is None - assert get_access_token() is None - async def test_with_no_user(self): - """Test middleware with no user in scope.""" - app = MockApp() - middleware = AuthContextMiddleware(app) +@pytest.mark.anyio +async def test_auth_context_middleware_with_no_user(): + """Test middleware with no user in scope.""" + app = MockApp() + middleware = AuthContextMiddleware(app) - scope: Scope = {"type": "http"} # No user + scope: Scope = {"type": "http"} # No user - # Create dummy async functions for receive and send - async def receive() -> Message: # pragma: no cover - return {"type": "http.request"} + # Create dummy async functions for receive and send + async def receive() -> Message: # pragma: no cover + return {"type": "http.request"} - async def send(message: Message) -> None: # pragma: no cover - pass + async def send(message: Message) -> None: # pragma: no cover + pass - # Verify context is empty before middleware - assert auth_context_var.get() is None - assert get_access_token() is None + # Verify context is empty before middleware + assert auth_context_var.get() is None + assert get_access_token() is None - # Run the middleware - await middleware(scope, receive, send) + # Run the middleware + await middleware(scope, receive, send) - # Verify the app was called - assert app.called - assert app.scope == scope - assert app.receive == receive - assert app.send == send + # Verify the app was called + assert app.called + assert app.scope == scope + assert app.receive == receive + assert app.send == send - # Verify the access token was not available during the call - assert app.access_token_during_call is None + # Verify the access token was not available during the call + assert app.access_token_during_call is None - # Verify context is still empty after middleware - assert auth_context_var.get() is None - assert get_access_token() is None + # Verify context is still empty after middleware + assert auth_context_var.get() is None + assert get_access_token() is None diff --git a/tests/server/auth/test_error_handling.py b/tests/server/auth/test_error_handling.py index f8c799147..8eafbcdbb 100644 --- a/tests/server/auth/test_error_handling.py +++ b/tests/server/auth/test_error_handling.py @@ -83,176 +83,120 @@ async def registered_client(client: httpx.AsyncClient) -> dict[str, Any]: return client_info -class TestRegistrationErrorHandling: - @pytest.mark.anyio - async def test_registration_error_handling(self, client: httpx.AsyncClient, oauth_provider: MockOAuthProvider): - # Mock the register_client method to raise a registration error - with unittest.mock.patch.object( - oauth_provider, - "register_client", - side_effect=RegistrationError( - error="invalid_redirect_uri", - error_description="The redirect URI is invalid", - ), - ): - # Prepare a client registration request - client_data = { - "redirect_uris": ["https://client.example.com/callback"], - "token_endpoint_auth_method": "client_secret_post", - "grant_types": ["authorization_code", "refresh_token"], - "response_types": ["code"], - "client_name": "Test Client", - } - - # Send the registration request - response = await client.post( - "/register", - json=client_data, - ) - - # Verify the response - assert response.status_code == 400, response.content - data = response.json() - assert data["error"] == "invalid_redirect_uri" - assert data["error_description"] == "The redirect URI is invalid" - - -class TestAuthorizeErrorHandling: - @pytest.mark.anyio - async def test_authorize_error_handling( - self, - client: httpx.AsyncClient, - oauth_provider: MockOAuthProvider, - registered_client: dict[str, Any], - pkce_challenge: dict[str, str], - ): - # Mock the authorize method to raise an authorize error - with unittest.mock.patch.object( - oauth_provider, - "authorize", - side_effect=AuthorizeError(error="access_denied", error_description="The user denied the request"), - ): - # Register the client - client_id = registered_client["client_id"] - redirect_uri = registered_client["redirect_uris"][0] - - # Prepare an authorization request - params = { - "client_id": client_id, - "redirect_uri": redirect_uri, - "response_type": "code", - "code_challenge": pkce_challenge["code_challenge"], - "code_challenge_method": "S256", - "state": "test_state", - } - - # Send the authorization request - response = await client.get("/authorize", params=params) - - # Verify the response is a redirect with error parameters - assert response.status_code == 302 - redirect_url = response.headers["location"] - parsed_url = urlparse(redirect_url) - query_params = parse_qs(parsed_url.query) - - assert query_params["error"][0] == "access_denied" - assert "error_description" in query_params - assert query_params["state"][0] == "test_state" - - -class TestTokenErrorHandling: - @pytest.mark.anyio - async def test_token_error_handling_auth_code( - self, - client: httpx.AsyncClient, - oauth_provider: MockOAuthProvider, - registered_client: dict[str, Any], - pkce_challenge: dict[str, str], +@pytest.mark.anyio +async def test_registration_error_handling(client: httpx.AsyncClient, oauth_provider: MockOAuthProvider): + # Mock the register_client method to raise a registration error + with unittest.mock.patch.object( + oauth_provider, + "register_client", + side_effect=RegistrationError( + error="invalid_redirect_uri", + error_description="The redirect URI is invalid", + ), ): - # Register the client and get an auth code - client_id = registered_client["client_id"] - client_secret = registered_client["client_secret"] - redirect_uri = registered_client["redirect_uris"][0] - - # First get an authorization code - auth_response = await client.get( - "/authorize", - params={ - "client_id": client_id, - "redirect_uri": redirect_uri, - "response_type": "code", - "code_challenge": pkce_challenge["code_challenge"], - "code_challenge_method": "S256", - "state": "test_state", - }, + # Prepare a client registration request + client_data = { + "redirect_uris": ["https://client.example.com/callback"], + "token_endpoint_auth_method": "client_secret_post", + "grant_types": ["authorization_code", "refresh_token"], + "response_types": ["code"], + "client_name": "Test Client", + } + + # Send the registration request + response = await client.post( + "/register", + json=client_data, ) - redirect_url = auth_response.headers["location"] - parsed_url = urlparse(redirect_url) - query_params = parse_qs(parsed_url.query) - code = query_params["code"][0] - - # Mock the exchange_authorization_code method to raise a token error - with unittest.mock.patch.object( - oauth_provider, - "exchange_authorization_code", - side_effect=TokenError( - error="invalid_grant", - error_description="The authorization code is invalid", - ), - ): - # Try to exchange the code for tokens - token_response = await client.post( - "/token", - data={ - "grant_type": "authorization_code", - "code": code, - "redirect_uri": redirect_uri, - "client_id": client_id, - "client_secret": client_secret, - "code_verifier": pkce_challenge["code_verifier"], - }, - ) - - # Verify the response - assert token_response.status_code == 400 - data = token_response.json() - assert data["error"] == "invalid_grant" - assert data["error_description"] == "The authorization code is invalid" - - @pytest.mark.anyio - async def test_token_error_handling_refresh_token( - self, - client: httpx.AsyncClient, - oauth_provider: MockOAuthProvider, - registered_client: dict[str, Any], - pkce_challenge: dict[str, str], + # Verify the response + assert response.status_code == 400, response.content + data = response.json() + assert data["error"] == "invalid_redirect_uri" + assert data["error_description"] == "The redirect URI is invalid" + + +@pytest.mark.anyio +async def test_authorize_error_handling( + client: httpx.AsyncClient, + oauth_provider: MockOAuthProvider, + registered_client: dict[str, Any], + pkce_challenge: dict[str, str], +): + # Mock the authorize method to raise an authorize error + with unittest.mock.patch.object( + oauth_provider, + "authorize", + side_effect=AuthorizeError(error="access_denied", error_description="The user denied the request"), ): - # Register the client and get tokens + # Register the client client_id = registered_client["client_id"] - client_secret = registered_client["client_secret"] redirect_uri = registered_client["redirect_uris"][0] - # First get an authorization code - auth_response = await client.get( - "/authorize", - params={ - "client_id": client_id, - "redirect_uri": redirect_uri, - "response_type": "code", - "code_challenge": pkce_challenge["code_challenge"], - "code_challenge_method": "S256", - "state": "test_state", - }, - ) - assert auth_response.status_code == 302, auth_response.content - - redirect_url = auth_response.headers["location"] + # Prepare an authorization request + params = { + "client_id": client_id, + "redirect_uri": redirect_uri, + "response_type": "code", + "code_challenge": pkce_challenge["code_challenge"], + "code_challenge_method": "S256", + "state": "test_state", + } + + # Send the authorization request + response = await client.get("/authorize", params=params) + + # Verify the response is a redirect with error parameters + assert response.status_code == 302 + redirect_url = response.headers["location"] parsed_url = urlparse(redirect_url) query_params = parse_qs(parsed_url.query) - code = query_params["code"][0] - # Exchange the code for tokens + assert query_params["error"][0] == "access_denied" + assert "error_description" in query_params + assert query_params["state"][0] == "test_state" + + +@pytest.mark.anyio +async def test_token_error_handling_auth_code( + client: httpx.AsyncClient, + oauth_provider: MockOAuthProvider, + registered_client: dict[str, Any], + pkce_challenge: dict[str, str], +): + # Register the client and get an auth code + client_id = registered_client["client_id"] + client_secret = registered_client["client_secret"] + redirect_uri = registered_client["redirect_uris"][0] + + # First get an authorization code + auth_response = await client.get( + "/authorize", + params={ + "client_id": client_id, + "redirect_uri": redirect_uri, + "response_type": "code", + "code_challenge": pkce_challenge["code_challenge"], + "code_challenge_method": "S256", + "state": "test_state", + }, + ) + + redirect_url = auth_response.headers["location"] + parsed_url = urlparse(redirect_url) + query_params = parse_qs(parsed_url.query) + code = query_params["code"][0] + + # Mock the exchange_authorization_code method to raise a token error + with unittest.mock.patch.object( + oauth_provider, + "exchange_authorization_code", + side_effect=TokenError( + error="invalid_grant", + error_description="The authorization code is invalid", + ), + ): + # Try to exchange the code for tokens token_response = await client.post( "/token", data={ @@ -265,31 +209,82 @@ async def test_token_error_handling_refresh_token( }, ) - tokens = token_response.json() - refresh_token = tokens["refresh_token"] - - # Mock the exchange_refresh_token method to raise a token error - with unittest.mock.patch.object( - oauth_provider, - "exchange_refresh_token", - side_effect=TokenError( - error="invalid_scope", - error_description="The requested scope is invalid", - ), - ): - # Try to use the refresh token - refresh_response = await client.post( - "/token", - data={ - "grant_type": "refresh_token", - "refresh_token": refresh_token, - "client_id": client_id, - "client_secret": client_secret, - }, - ) - - # Verify the response - assert refresh_response.status_code == 400 - data = refresh_response.json() - assert data["error"] == "invalid_scope" - assert data["error_description"] == "The requested scope is invalid" + # Verify the response + assert token_response.status_code == 400 + data = token_response.json() + assert data["error"] == "invalid_grant" + assert data["error_description"] == "The authorization code is invalid" + + +@pytest.mark.anyio +async def test_token_error_handling_refresh_token( + client: httpx.AsyncClient, + oauth_provider: MockOAuthProvider, + registered_client: dict[str, Any], + pkce_challenge: dict[str, str], +): + # Register the client and get tokens + client_id = registered_client["client_id"] + client_secret = registered_client["client_secret"] + redirect_uri = registered_client["redirect_uris"][0] + + # First get an authorization code + auth_response = await client.get( + "/authorize", + params={ + "client_id": client_id, + "redirect_uri": redirect_uri, + "response_type": "code", + "code_challenge": pkce_challenge["code_challenge"], + "code_challenge_method": "S256", + "state": "test_state", + }, + ) + assert auth_response.status_code == 302, auth_response.content + + redirect_url = auth_response.headers["location"] + parsed_url = urlparse(redirect_url) + query_params = parse_qs(parsed_url.query) + code = query_params["code"][0] + + # Exchange the code for tokens + token_response = await client.post( + "/token", + data={ + "grant_type": "authorization_code", + "code": code, + "redirect_uri": redirect_uri, + "client_id": client_id, + "client_secret": client_secret, + "code_verifier": pkce_challenge["code_verifier"], + }, + ) + + tokens = token_response.json() + refresh_token = tokens["refresh_token"] + + # Mock the exchange_refresh_token method to raise a token error + with unittest.mock.patch.object( + oauth_provider, + "exchange_refresh_token", + side_effect=TokenError( + error="invalid_scope", + error_description="The requested scope is invalid", + ), + ): + # Try to use the refresh token + refresh_response = await client.post( + "/token", + data={ + "grant_type": "refresh_token", + "refresh_token": refresh_token, + "client_id": client_id, + "client_secret": client_secret, + }, + ) + + # Verify the response + assert refresh_response.status_code == 400 + data = refresh_response.json() + assert data["error"] == "invalid_scope" + assert data["error_description"] == "The requested scope is invalid" diff --git a/tests/server/auth/test_protected_resource.py b/tests/server/auth/test_protected_resource.py index 594541420..413a80276 100644 --- a/tests/server/auth/test_protected_resource.py +++ b/tests/server/auth/test_protected_resource.py @@ -105,90 +105,94 @@ async def test_metadata_endpoint_without_path(root_resource_client: httpx.AsyncC ) -class TestMetadataUrlConstruction: - """Test URL construction utility function.""" - - def test_url_without_path(self): - """Test URL construction for resource without path component.""" - resource_url = AnyHttpUrl("https://example.com") - result = build_resource_metadata_url(resource_url) - assert str(result) == "https://example.com/.well-known/oauth-protected-resource" - - def test_url_with_path_component(self): - """Test URL construction for resource with path component.""" - resource_url = AnyHttpUrl("https://example.com/mcp") - result = build_resource_metadata_url(resource_url) - assert str(result) == "https://example.com/.well-known/oauth-protected-resource/mcp" - - def test_url_with_trailing_slash_only(self): - """Test URL construction for resource with trailing slash only.""" - resource_url = AnyHttpUrl("https://example.com/") - result = build_resource_metadata_url(resource_url) - # Trailing slash should be treated as empty path - assert str(result) == "https://example.com/.well-known/oauth-protected-resource" - - @pytest.mark.parametrize( - "resource_url,expected_url", - [ - ("https://example.com", "https://example.com/.well-known/oauth-protected-resource"), - ("https://example.com/", "https://example.com/.well-known/oauth-protected-resource"), - ("https://example.com/mcp", "https://example.com/.well-known/oauth-protected-resource/mcp"), - ("http://localhost:8001/mcp", "http://localhost:8001/.well-known/oauth-protected-resource/mcp"), - ], +# Tests for URL construction utility function + + +def test_metadata_url_construction_url_without_path(): + """Test URL construction for resource without path component.""" + resource_url = AnyHttpUrl("https://example.com") + result = build_resource_metadata_url(resource_url) + assert str(result) == "https://example.com/.well-known/oauth-protected-resource" + + +def test_metadata_url_construction_url_with_path_component(): + """Test URL construction for resource with path component.""" + resource_url = AnyHttpUrl("https://example.com/mcp") + result = build_resource_metadata_url(resource_url) + assert str(result) == "https://example.com/.well-known/oauth-protected-resource/mcp" + + +def test_metadata_url_construction_url_with_trailing_slash_only(): + """Test URL construction for resource with trailing slash only.""" + resource_url = AnyHttpUrl("https://example.com/") + result = build_resource_metadata_url(resource_url) + # Trailing slash should be treated as empty path + assert str(result) == "https://example.com/.well-known/oauth-protected-resource" + + +@pytest.mark.parametrize( + "resource_url,expected_url", + [ + ("https://example.com", "https://example.com/.well-known/oauth-protected-resource"), + ("https://example.com/", "https://example.com/.well-known/oauth-protected-resource"), + ("https://example.com/mcp", "https://example.com/.well-known/oauth-protected-resource/mcp"), + ("http://localhost:8001/mcp", "http://localhost:8001/.well-known/oauth-protected-resource/mcp"), + ], +) +def test_metadata_url_construction_various_resource_configurations(resource_url: str, expected_url: str): + """Test URL construction with various resource configurations.""" + result = build_resource_metadata_url(AnyHttpUrl(resource_url)) + assert str(result) == expected_url + + +# Tests for consistency between URL generation and route registration + + +def test_route_consistency_route_path_matches_metadata_url(): + """Test that route path matches the generated metadata URL.""" + resource_url = AnyHttpUrl("https://example.com/mcp") + + # Generate metadata URL + metadata_url = build_resource_metadata_url(resource_url) + + # Create routes + routes = create_protected_resource_routes( + resource_url=resource_url, + authorization_servers=[AnyHttpUrl("https://auth.example.com")], ) - def test_various_resource_configurations(self, resource_url: str, expected_url: str): - """Test URL construction with various resource configurations.""" - result = build_resource_metadata_url(AnyHttpUrl(resource_url)) - assert str(result) == expected_url - - -class TestRouteConsistency: - """Test consistency between URL generation and route registration.""" - - def test_route_path_matches_metadata_url(self): - """Test that route path matches the generated metadata URL.""" - resource_url = AnyHttpUrl("https://example.com/mcp") - - # Generate metadata URL - metadata_url = build_resource_metadata_url(resource_url) - - # Create routes - routes = create_protected_resource_routes( - resource_url=resource_url, - authorization_servers=[AnyHttpUrl("https://auth.example.com")], - ) - - # Extract path from metadata URL - metadata_path = urlparse(str(metadata_url)).path - - # Verify consistency - assert len(routes) == 1 - assert routes[0].path == metadata_path - - @pytest.mark.parametrize( - "resource_url,expected_path", - [ - ("https://example.com", "/.well-known/oauth-protected-resource"), - ("https://example.com/", "/.well-known/oauth-protected-resource"), - ("https://example.com/mcp", "/.well-known/oauth-protected-resource/mcp"), - ], + + # Extract path from metadata URL + metadata_path = urlparse(str(metadata_url)).path + + # Verify consistency + assert len(routes) == 1 + assert routes[0].path == metadata_path + + +@pytest.mark.parametrize( + "resource_url,expected_path", + [ + ("https://example.com", "/.well-known/oauth-protected-resource"), + ("https://example.com/", "/.well-known/oauth-protected-resource"), + ("https://example.com/mcp", "/.well-known/oauth-protected-resource/mcp"), + ], +) +def test_route_consistency_consistent_paths_for_various_resources(resource_url: str, expected_path: str): + """Test that URL generation and route creation are consistent.""" + resource_url_obj = AnyHttpUrl(resource_url) + + # Test URL generation + metadata_url = build_resource_metadata_url(resource_url_obj) + url_path = urlparse(str(metadata_url)).path + + # Test route creation + routes = create_protected_resource_routes( + resource_url=resource_url_obj, + authorization_servers=[AnyHttpUrl("https://auth.example.com")], ) - def test_consistent_paths_for_various_resources(self, resource_url: str, expected_path: str): - """Test that URL generation and route creation are consistent.""" - resource_url_obj = AnyHttpUrl(resource_url) - - # Test URL generation - metadata_url = build_resource_metadata_url(resource_url_obj) - url_path = urlparse(str(metadata_url)).path - - # Test route creation - routes = create_protected_resource_routes( - resource_url=resource_url_obj, - authorization_servers=[AnyHttpUrl("https://auth.example.com")], - ) - route_path = routes[0].path - - # Both should match expected path - assert url_path == expected_path - assert route_path == expected_path - assert url_path == route_path + route_path = routes[0].path + + # Both should match expected path + assert url_path == expected_path + assert route_path == expected_path + assert url_path == route_path diff --git a/tests/server/auth/test_provider.py b/tests/server/auth/test_provider.py index 89a7cbede..aaaeb413a 100644 --- a/tests/server/auth/test_provider.py +++ b/tests/server/auth/test_provider.py @@ -3,73 +3,77 @@ from mcp.server.auth.provider import construct_redirect_uri -class TestConstructRedirectUri: - """Tests for the construct_redirect_uri function.""" - - def test_construct_redirect_uri_no_existing_params(self): - """Test construct_redirect_uri with no existing query parameters.""" - base_uri = "http://localhost:8000/callback" - result = construct_redirect_uri(base_uri, code="auth_code", state="test_state") - - assert "http://localhost:8000/callback?code=auth_code&state=test_state" == result - - def test_construct_redirect_uri_with_existing_params(self): - """Test construct_redirect_uri with existing query parameters (regression test for #1279).""" - base_uri = "http://localhost:8000/callback?session_id=1234" - result = construct_redirect_uri(base_uri, code="auth_code", state="test_state") - - # Should preserve existing params and add new ones - assert "session_id=1234" in result - assert "code=auth_code" in result - assert "state=test_state" in result - assert result.startswith("http://localhost:8000/callback?") - - def test_construct_redirect_uri_multiple_existing_params(self): - """Test construct_redirect_uri with multiple existing query parameters.""" - base_uri = "http://localhost:8000/callback?session_id=1234&user=test" - result = construct_redirect_uri(base_uri, code="auth_code") - - assert "session_id=1234" in result - assert "user=test" in result - assert "code=auth_code" in result - - def test_construct_redirect_uri_with_none_values(self): - """Test construct_redirect_uri filters out None values.""" - base_uri = "http://localhost:8000/callback" - result = construct_redirect_uri(base_uri, code="auth_code", state=None) - - assert result == "http://localhost:8000/callback?code=auth_code" - assert "state" not in result - - def test_construct_redirect_uri_empty_params(self): - """Test construct_redirect_uri with no additional parameters.""" - base_uri = "http://localhost:8000/callback?existing=param" - result = construct_redirect_uri(base_uri) - - assert result == "http://localhost:8000/callback?existing=param" - - def test_construct_redirect_uri_duplicate_param_names(self): - """Test construct_redirect_uri when adding param that already exists.""" - base_uri = "http://localhost:8000/callback?code=existing" - result = construct_redirect_uri(base_uri, code="new_code") - - # Should contain both values (this is expected behavior of parse_qs/urlencode) - assert "code=existing" in result - assert "code=new_code" in result - - def test_construct_redirect_uri_multivalued_existing_params(self): - """Test construct_redirect_uri with existing multi-valued parameters.""" - base_uri = "http://localhost:8000/callback?scope=read&scope=write" - result = construct_redirect_uri(base_uri, code="auth_code") - - assert "scope=read" in result - assert "scope=write" in result - assert "code=auth_code" in result - - def test_construct_redirect_uri_encoded_values(self): - """Test construct_redirect_uri handles URL encoding properly.""" - base_uri = "http://localhost:8000/callback" - result = construct_redirect_uri(base_uri, state="test state with spaces") - - # urlencode uses + for spaces by default - assert "state=test+state+with+spaces" in result +def test_construct_redirect_uri_no_existing_params(): + """Test construct_redirect_uri with no existing query parameters.""" + base_uri = "http://localhost:8000/callback" + result = construct_redirect_uri(base_uri, code="auth_code", state="test_state") + + assert "http://localhost:8000/callback?code=auth_code&state=test_state" == result + + +def test_construct_redirect_uri_with_existing_params(): + """Test construct_redirect_uri with existing query parameters (regression test for #1279).""" + base_uri = "http://localhost:8000/callback?session_id=1234" + result = construct_redirect_uri(base_uri, code="auth_code", state="test_state") + + # Should preserve existing params and add new ones + assert "session_id=1234" in result + assert "code=auth_code" in result + assert "state=test_state" in result + assert result.startswith("http://localhost:8000/callback?") + + +def test_construct_redirect_uri_multiple_existing_params(): + """Test construct_redirect_uri with multiple existing query parameters.""" + base_uri = "http://localhost:8000/callback?session_id=1234&user=test" + result = construct_redirect_uri(base_uri, code="auth_code") + + assert "session_id=1234" in result + assert "user=test" in result + assert "code=auth_code" in result + + +def test_construct_redirect_uri_with_none_values(): + """Test construct_redirect_uri filters out None values.""" + base_uri = "http://localhost:8000/callback" + result = construct_redirect_uri(base_uri, code="auth_code", state=None) + + assert result == "http://localhost:8000/callback?code=auth_code" + assert "state" not in result + + +def test_construct_redirect_uri_empty_params(): + """Test construct_redirect_uri with no additional parameters.""" + base_uri = "http://localhost:8000/callback?existing=param" + result = construct_redirect_uri(base_uri) + + assert result == "http://localhost:8000/callback?existing=param" + + +def test_construct_redirect_uri_duplicate_param_names(): + """Test construct_redirect_uri when adding param that already exists.""" + base_uri = "http://localhost:8000/callback?code=existing" + result = construct_redirect_uri(base_uri, code="new_code") + + # Should contain both values (this is expected behavior of parse_qs/urlencode) + assert "code=existing" in result + assert "code=new_code" in result + + +def test_construct_redirect_uri_multivalued_existing_params(): + """Test construct_redirect_uri with existing multi-valued parameters.""" + base_uri = "http://localhost:8000/callback?scope=read&scope=write" + result = construct_redirect_uri(base_uri, code="auth_code") + + assert "scope=read" in result + assert "scope=write" in result + assert "code=auth_code" in result + + +def test_construct_redirect_uri_encoded_values(): + """Test construct_redirect_uri handles URL encoding properly.""" + base_uri = "http://localhost:8000/callback" + result = construct_redirect_uri(base_uri, state="test state with spaces") + + # urlencode uses + for spaces by default + assert "state=test+state+with+spaces" in result diff --git a/tests/server/lowlevel/test_helper_types.py b/tests/server/lowlevel/test_helper_types.py index 27a8081b6..e29273d3f 100644 --- a/tests/server/lowlevel/test_helper_types.py +++ b/tests/server/lowlevel/test_helper_types.py @@ -10,51 +10,50 @@ from mcp.server.lowlevel.helper_types import ReadResourceContents -class TestReadResourceContentsMetadata: - """Test ReadResourceContents meta field. +def test_read_resource_contents_with_metadata(): + """Test that ReadResourceContents accepts meta parameter. ReadResourceContents is an internal helper type used by the low-level MCP server. When a resource is read, the server creates a ReadResourceContents instance that contains the content, mime type, and now metadata. The low-level server then extracts the meta field and includes it in the protocol response as _meta. """ + # Bridge between Resource.meta and MCP protocol _meta field (helper_types.py:11) + metadata = {"version": "1.0", "cached": True} - def test_read_resource_contents_with_metadata(self): - """Test that ReadResourceContents accepts meta parameter.""" - # Bridge between Resource.meta and MCP protocol _meta field (helper_types.py:11) - metadata = {"version": "1.0", "cached": True} - - contents = ReadResourceContents( - content="test content", - mime_type="text/plain", - meta=metadata, - ) - - assert contents.meta is not None - assert contents.meta == metadata - assert contents.meta["version"] == "1.0" - assert contents.meta["cached"] is True - - def test_read_resource_contents_without_metadata(self): - """Test that ReadResourceContents meta defaults to None.""" - # Ensures backward compatibility - meta defaults to None, _meta omitted from protocol (helper_types.py:11) - contents = ReadResourceContents( - content="test content", - mime_type="text/plain", - ) - - assert contents.meta is None - - def test_read_resource_contents_with_bytes(self): - """Test that ReadResourceContents works with bytes content and meta.""" - # Verifies meta works with both str and bytes content (binary resources like images, PDFs) - metadata = {"encoding": "utf-8"} - - contents = ReadResourceContents( - content=b"binary content", - mime_type="application/octet-stream", - meta=metadata, - ) - - assert contents.content == b"binary content" - assert contents.meta == metadata + contents = ReadResourceContents( + content="test content", + mime_type="text/plain", + meta=metadata, + ) + + assert contents.meta is not None + assert contents.meta == metadata + assert contents.meta["version"] == "1.0" + assert contents.meta["cached"] is True + + +def test_read_resource_contents_without_metadata(): + """Test that ReadResourceContents meta defaults to None.""" + # Ensures backward compatibility - meta defaults to None, _meta omitted from protocol (helper_types.py:11) + contents = ReadResourceContents( + content="test content", + mime_type="text/plain", + ) + + assert contents.meta is None + + +def test_read_resource_contents_with_bytes(): + """Test that ReadResourceContents works with bytes content and meta.""" + # Verifies meta works with both str and bytes content (binary resources like images, PDFs) + metadata = {"encoding": "utf-8"} + + contents = ReadResourceContents( + content=b"binary content", + mime_type="application/octet-stream", + meta=metadata, + ) + + assert contents.content == b"binary content" + assert contents.meta == metadata diff --git a/tests/server/test_validation.py b/tests/server/test_validation.py index 11c61d93b..4583e470c 100644 --- a/tests/server/test_validation.py +++ b/tests/server/test_validation.py @@ -20,122 +20,132 @@ ToolUseContent, ) +# Tests for check_sampling_tools_capability function -class TestCheckSamplingToolsCapability: - """Tests for check_sampling_tools_capability function.""" - - def test_returns_false_when_caps_none(self) -> None: - """Returns False when client_caps is None.""" - assert check_sampling_tools_capability(None) is False - - def test_returns_false_when_sampling_none(self) -> None: - """Returns False when client_caps.sampling is None.""" - caps = ClientCapabilities() - assert check_sampling_tools_capability(caps) is False - - def test_returns_false_when_tools_none(self) -> None: - """Returns False when client_caps.sampling.tools is None.""" - caps = ClientCapabilities(sampling=SamplingCapability()) - assert check_sampling_tools_capability(caps) is False - - def test_returns_true_when_tools_present(self) -> None: - """Returns True when sampling.tools is present.""" - caps = ClientCapabilities(sampling=SamplingCapability(tools=SamplingToolsCapability())) - assert check_sampling_tools_capability(caps) is True - - -class TestValidateSamplingTools: - """Tests for validate_sampling_tools function.""" - - def test_no_error_when_tools_none(self) -> None: - """No error when tools and tool_choice are None.""" - validate_sampling_tools(None, None, None) # Should not raise - - def test_raises_when_tools_provided_but_no_capability(self) -> None: - """Raises McpError when tools provided but client doesn't support.""" - tool = Tool(name="test", input_schema={"type": "object"}) - with pytest.raises(McpError) as exc_info: - validate_sampling_tools(None, [tool], None) - assert "sampling tools capability" in str(exc_info.value) - - def test_raises_when_tool_choice_provided_but_no_capability(self) -> None: - """Raises McpError when tool_choice provided but client doesn't support.""" - with pytest.raises(McpError) as exc_info: - validate_sampling_tools(None, None, ToolChoice(mode="auto")) - assert "sampling tools capability" in str(exc_info.value) - - def test_no_error_when_capability_present(self) -> None: - """No error when client has sampling.tools capability.""" - caps = ClientCapabilities(sampling=SamplingCapability(tools=SamplingToolsCapability())) - tool = Tool(name="test", input_schema={"type": "object"}) - validate_sampling_tools(caps, [tool], ToolChoice(mode="auto")) # Should not raise - - -class TestValidateToolUseResultMessages: - """Tests for validate_tool_use_result_messages function.""" - - def test_no_error_for_empty_messages(self) -> None: - """No error when messages list is empty.""" - validate_tool_use_result_messages([]) # Should not raise - - def test_no_error_for_simple_text_messages(self) -> None: - """No error for simple text messages.""" - messages = [ - SamplingMessage(role="user", content=TextContent(type="text", text="Hello")), - SamplingMessage(role="assistant", content=TextContent(type="text", text="Hi")), - ] - validate_tool_use_result_messages(messages) # Should not raise - - def test_raises_when_tool_result_mixed_with_other_content(self) -> None: - """Raises when tool_result is mixed with other content types.""" - messages = [ - SamplingMessage( - role="user", - content=[ - ToolResultContent(type="tool_result", tool_use_id="123"), - TextContent(type="text", text="also this"), - ], - ), - ] - with pytest.raises(ValueError, match="only tool_result content"): - validate_tool_use_result_messages(messages) - - def test_raises_when_tool_result_without_previous_tool_use(self) -> None: - """Raises when tool_result appears without preceding tool_use.""" - messages = [ - SamplingMessage( - role="user", - content=ToolResultContent(type="tool_result", tool_use_id="123"), - ), - ] - with pytest.raises(ValueError, match="previous message containing tool_use"): - validate_tool_use_result_messages(messages) - - def test_raises_when_tool_result_ids_dont_match_tool_use(self) -> None: - """Raises when tool_result IDs don't match tool_use IDs.""" - messages = [ - SamplingMessage( - role="assistant", - content=ToolUseContent(type="tool_use", id="tool-1", name="test", input={}), - ), - SamplingMessage( - role="user", - content=ToolResultContent(type="tool_result", tool_use_id="tool-2"), - ), - ] - with pytest.raises(ValueError, match="do not match"): - validate_tool_use_result_messages(messages) - - def test_no_error_when_tool_result_matches_tool_use(self) -> None: - """No error when tool_result IDs match tool_use IDs.""" - messages = [ - SamplingMessage( - role="assistant", - content=ToolUseContent(type="tool_use", id="tool-1", name="test", input={}), - ), - SamplingMessage( - role="user", - content=ToolResultContent(type="tool_result", tool_use_id="tool-1"), - ), - ] - validate_tool_use_result_messages(messages) # Should not raise + +def test_check_sampling_tools_capability_returns_false_when_caps_none() -> None: + """Returns False when client_caps is None.""" + assert check_sampling_tools_capability(None) is False + + +def test_check_sampling_tools_capability_returns_false_when_sampling_none() -> None: + """Returns False when client_caps.sampling is None.""" + caps = ClientCapabilities() + assert check_sampling_tools_capability(caps) is False + + +def test_check_sampling_tools_capability_returns_false_when_tools_none() -> None: + """Returns False when client_caps.sampling.tools is None.""" + caps = ClientCapabilities(sampling=SamplingCapability()) + assert check_sampling_tools_capability(caps) is False + + +def test_check_sampling_tools_capability_returns_true_when_tools_present() -> None: + """Returns True when sampling.tools is present.""" + caps = ClientCapabilities(sampling=SamplingCapability(tools=SamplingToolsCapability())) + assert check_sampling_tools_capability(caps) is True + + +# Tests for validate_sampling_tools function + + +def test_validate_sampling_tools_no_error_when_tools_none() -> None: + """No error when tools and tool_choice are None.""" + validate_sampling_tools(None, None, None) # Should not raise + + +def test_validate_sampling_tools_raises_when_tools_provided_but_no_capability() -> None: + """Raises McpError when tools provided but client doesn't support.""" + tool = Tool(name="test", input_schema={"type": "object"}) + with pytest.raises(McpError) as exc_info: + validate_sampling_tools(None, [tool], None) + assert "sampling tools capability" in str(exc_info.value) + + +def test_validate_sampling_tools_raises_when_tool_choice_provided_but_no_capability() -> None: + """Raises McpError when tool_choice provided but client doesn't support.""" + with pytest.raises(McpError) as exc_info: + validate_sampling_tools(None, None, ToolChoice(mode="auto")) + assert "sampling tools capability" in str(exc_info.value) + + +def test_validate_sampling_tools_no_error_when_capability_present() -> None: + """No error when client has sampling.tools capability.""" + caps = ClientCapabilities(sampling=SamplingCapability(tools=SamplingToolsCapability())) + tool = Tool(name="test", input_schema={"type": "object"}) + validate_sampling_tools(caps, [tool], ToolChoice(mode="auto")) # Should not raise + + +# Tests for validate_tool_use_result_messages function + + +def test_validate_tool_use_result_messages_no_error_for_empty_messages() -> None: + """No error when messages list is empty.""" + validate_tool_use_result_messages([]) # Should not raise + + +def test_validate_tool_use_result_messages_no_error_for_simple_text_messages() -> None: + """No error for simple text messages.""" + messages = [ + SamplingMessage(role="user", content=TextContent(type="text", text="Hello")), + SamplingMessage(role="assistant", content=TextContent(type="text", text="Hi")), + ] + validate_tool_use_result_messages(messages) # Should not raise + + +def test_validate_tool_use_result_messages_raises_when_tool_result_mixed_with_other_content() -> None: + """Raises when tool_result is mixed with other content types.""" + messages = [ + SamplingMessage( + role="user", + content=[ + ToolResultContent(type="tool_result", tool_use_id="123"), + TextContent(type="text", text="also this"), + ], + ), + ] + with pytest.raises(ValueError, match="only tool_result content"): + validate_tool_use_result_messages(messages) + + +def test_validate_tool_use_result_messages_raises_when_tool_result_without_previous_tool_use() -> None: + """Raises when tool_result appears without preceding tool_use.""" + messages = [ + SamplingMessage( + role="user", + content=ToolResultContent(type="tool_result", tool_use_id="123"), + ), + ] + with pytest.raises(ValueError, match="previous message containing tool_use"): + validate_tool_use_result_messages(messages) + + +def test_validate_tool_use_result_messages_raises_when_tool_result_ids_dont_match_tool_use() -> None: + """Raises when tool_result IDs don't match tool_use IDs.""" + messages = [ + SamplingMessage( + role="assistant", + content=ToolUseContent(type="tool_use", id="tool-1", name="test", input={}), + ), + SamplingMessage( + role="user", + content=ToolResultContent(type="tool_result", tool_use_id="tool-2"), + ), + ] + with pytest.raises(ValueError, match="do not match"): + validate_tool_use_result_messages(messages) + + +def test_validate_tool_use_result_messages_no_error_when_tool_result_matches_tool_use() -> None: + """No error when tool_result IDs match tool_use IDs.""" + messages = [ + SamplingMessage( + role="assistant", + content=ToolUseContent(type="tool_use", id="tool-1", name="test", input={}), + ), + SamplingMessage( + role="user", + content=ToolResultContent(type="tool_result", tool_use_id="tool-1"), + ), + ] + validate_tool_use_result_messages(messages) # Should not raise diff --git a/tests/shared/test_auth.py b/tests/shared/test_auth.py index bd9f5a934..cd3c35332 100644 --- a/tests/shared/test_auth.py +++ b/tests/shared/test_auth.py @@ -3,59 +3,58 @@ from mcp.shared.auth import OAuthMetadata -class TestOAuthMetadata: - """Tests for OAuthMetadata parsing.""" +def test_oauth(): + """Should not throw when parsing OAuth metadata.""" + OAuthMetadata.model_validate( + { + "issuer": "https://example.com", + "authorization_endpoint": "https://example.com/oauth2/authorize", + "token_endpoint": "https://example.com/oauth2/token", + "scopes_supported": ["read", "write"], + "response_types_supported": ["code", "token"], + "token_endpoint_auth_methods_supported": ["client_secret_basic", "client_secret_post"], + } + ) - def test_oauth(self): - """Should not throw when parsing OAuth metadata.""" - OAuthMetadata.model_validate( - { - "issuer": "https://example.com", - "authorization_endpoint": "https://example.com/oauth2/authorize", - "token_endpoint": "https://example.com/oauth2/token", - "scopes_supported": ["read", "write"], - "response_types_supported": ["code", "token"], - "token_endpoint_auth_methods_supported": ["client_secret_basic", "client_secret_post"], - } - ) - def test_oidc(self): - """Should not throw when parsing OIDC metadata.""" - OAuthMetadata.model_validate( - { - "issuer": "https://example.com", - "authorization_endpoint": "https://example.com/oauth2/authorize", - "token_endpoint": "https://example.com/oauth2/token", - "end_session_endpoint": "https://example.com/logout", - "id_token_signing_alg_values_supported": ["RS256"], - "jwks_uri": "https://example.com/.well-known/jwks.json", - "response_types_supported": ["code", "token"], - "revocation_endpoint": "https://example.com/oauth2/revoke", - "scopes_supported": ["openid", "read", "write"], - "subject_types_supported": ["public"], - "token_endpoint_auth_methods_supported": ["client_secret_basic", "client_secret_post"], - "userinfo_endpoint": "https://example.com/oauth2/userInfo", - } - ) +def test_oidc(): + """Should not throw when parsing OIDC metadata.""" + OAuthMetadata.model_validate( + { + "issuer": "https://example.com", + "authorization_endpoint": "https://example.com/oauth2/authorize", + "token_endpoint": "https://example.com/oauth2/token", + "end_session_endpoint": "https://example.com/logout", + "id_token_signing_alg_values_supported": ["RS256"], + "jwks_uri": "https://example.com/.well-known/jwks.json", + "response_types_supported": ["code", "token"], + "revocation_endpoint": "https://example.com/oauth2/revoke", + "scopes_supported": ["openid", "read", "write"], + "subject_types_supported": ["public"], + "token_endpoint_auth_methods_supported": ["client_secret_basic", "client_secret_post"], + "userinfo_endpoint": "https://example.com/oauth2/userInfo", + } + ) - def test_oauth_with_jarm(self): - """Should not throw when parsing OAuth metadata that includes JARM response modes.""" - OAuthMetadata.model_validate( - { - "issuer": "https://example.com", - "authorization_endpoint": "https://example.com/oauth2/authorize", - "token_endpoint": "https://example.com/oauth2/token", - "scopes_supported": ["read", "write"], - "response_types_supported": ["code", "token"], - "response_modes_supported": [ - "query", - "fragment", - "form_post", - "query.jwt", - "fragment.jwt", - "form_post.jwt", - "jwt", - ], - "token_endpoint_auth_methods_supported": ["client_secret_basic", "client_secret_post"], - } - ) + +def test_oauth_with_jarm(): + """Should not throw when parsing OAuth metadata that includes JARM response modes.""" + OAuthMetadata.model_validate( + { + "issuer": "https://example.com", + "authorization_endpoint": "https://example.com/oauth2/authorize", + "token_endpoint": "https://example.com/oauth2/token", + "scopes_supported": ["read", "write"], + "response_types_supported": ["code", "token"], + "response_modes_supported": [ + "query", + "fragment", + "form_post", + "query.jwt", + "fragment.jwt", + "form_post.jwt", + "jwt", + ], + "token_endpoint_auth_methods_supported": ["client_secret_basic", "client_secret_post"], + } + ) diff --git a/tests/shared/test_auth_utils.py b/tests/shared/test_auth_utils.py index d658385cb..2c1c16dc3 100644 --- a/tests/shared/test_auth_utils.py +++ b/tests/shared/test_auth_utils.py @@ -4,109 +4,120 @@ from mcp.shared.auth_utils import check_resource_allowed, resource_url_from_server_url +# Tests for resource_url_from_server_url function -class TestResourceUrlFromServerUrl: - """Tests for resource_url_from_server_url function.""" - - def test_removes_fragment(self): - """Fragment should be removed per RFC 8707.""" - assert resource_url_from_server_url("https://example.com/path#fragment") == "https://example.com/path" - assert resource_url_from_server_url("https://example.com/#fragment") == "https://example.com/" - - def test_preserves_path(self): - """Path should be preserved.""" - assert ( - resource_url_from_server_url("https://example.com/path/to/resource") - == "https://example.com/path/to/resource" - ) - assert resource_url_from_server_url("https://example.com/") == "https://example.com/" - assert resource_url_from_server_url("https://example.com") == "https://example.com" - - def test_preserves_query(self): - """Query parameters should be preserved.""" - assert resource_url_from_server_url("https://example.com/path?foo=bar") == "https://example.com/path?foo=bar" - assert resource_url_from_server_url("https://example.com/?key=value") == "https://example.com/?key=value" - - def test_preserves_port(self): - """Non-default ports should be preserved.""" - assert resource_url_from_server_url("https://example.com:8443/path") == "https://example.com:8443/path" - assert resource_url_from_server_url("http://example.com:8080/") == "http://example.com:8080/" - - def test_lowercase_scheme_and_host(self): - """Scheme and host should be lowercase for canonical form.""" - assert resource_url_from_server_url("HTTPS://EXAMPLE.COM/path") == "https://example.com/path" - assert resource_url_from_server_url("Http://Example.Com:8080/") == "http://example.com:8080/" - - def test_handles_pydantic_urls(self): - """Should handle Pydantic URL types.""" - url = HttpUrl("https://example.com/path") - assert resource_url_from_server_url(url) == "https://example.com/path" - - -class TestCheckResourceAllowed: - """Tests for check_resource_allowed function.""" - - def test_identical_urls(self): - """Identical URLs should match.""" - assert check_resource_allowed("https://example.com/path", "https://example.com/path") is True - assert check_resource_allowed("https://example.com/", "https://example.com/") is True - assert check_resource_allowed("https://example.com", "https://example.com") is True - - def test_different_schemes(self): - """Different schemes should not match.""" - assert check_resource_allowed("https://example.com/path", "http://example.com/path") is False - assert check_resource_allowed("http://example.com/", "https://example.com/") is False - - def test_different_domains(self): - """Different domains should not match.""" - assert check_resource_allowed("https://example.com/path", "https://example.org/path") is False - assert check_resource_allowed("https://sub.example.com/", "https://example.com/") is False - - def test_different_ports(self): - """Different ports should not match.""" - assert check_resource_allowed("https://example.com:8443/path", "https://example.com/path") is False - assert check_resource_allowed("https://example.com:8080/", "https://example.com:8443/") is False - - def test_hierarchical_matching(self): - """Child paths should match parent paths.""" - # Parent resource allows child resources - assert check_resource_allowed("https://example.com/api/v1/users", "https://example.com/api") is True - assert check_resource_allowed("https://example.com/api/v1", "https://example.com/api") is True - assert check_resource_allowed("https://example.com/mcp/server", "https://example.com/mcp") is True - - # Exact match - assert check_resource_allowed("https://example.com/api", "https://example.com/api") is True - - # Parent cannot use child's token - assert check_resource_allowed("https://example.com/api", "https://example.com/api/v1") is False - assert check_resource_allowed("https://example.com/", "https://example.com/api") is False - - def test_path_boundary_matching(self): - """Path matching should respect boundaries.""" - # Should not match partial path segments - assert check_resource_allowed("https://example.com/apiextra", "https://example.com/api") is False - assert check_resource_allowed("https://example.com/api123", "https://example.com/api") is False - - # Should match with trailing slash - assert check_resource_allowed("https://example.com/api/", "https://example.com/api") is True - assert check_resource_allowed("https://example.com/api/v1", "https://example.com/api/") is True - - def test_trailing_slash_handling(self): - """Trailing slashes should be handled correctly.""" - # With and without trailing slashes - assert check_resource_allowed("https://example.com/api/", "https://example.com/api") is True - assert check_resource_allowed("https://example.com/api", "https://example.com/api/") is False - assert check_resource_allowed("https://example.com/api/v1", "https://example.com/api") is True - assert check_resource_allowed("https://example.com/api/v1", "https://example.com/api/") is True - - def test_case_insensitive_origin(self): - """Origin comparison should be case-insensitive.""" - assert check_resource_allowed("https://EXAMPLE.COM/path", "https://example.com/path") is True - assert check_resource_allowed("HTTPS://example.com/path", "https://example.com/path") is True - assert check_resource_allowed("https://Example.Com:8080/api", "https://example.com:8080/api") is True - - def test_empty_paths(self): - """Empty paths should be handled correctly.""" - assert check_resource_allowed("https://example.com", "https://example.com") is True - assert check_resource_allowed("https://example.com/", "https://example.com") is True - assert check_resource_allowed("https://example.com/api", "https://example.com") is True + +def test_resource_url_from_server_url_removes_fragment(): + """Fragment should be removed per RFC 8707.""" + assert resource_url_from_server_url("https://example.com/path#fragment") == "https://example.com/path" + assert resource_url_from_server_url("https://example.com/#fragment") == "https://example.com/" + + +def test_resource_url_from_server_url_preserves_path(): + """Path should be preserved.""" + assert ( + resource_url_from_server_url("https://example.com/path/to/resource") == "https://example.com/path/to/resource" + ) + assert resource_url_from_server_url("https://example.com/") == "https://example.com/" + assert resource_url_from_server_url("https://example.com") == "https://example.com" + + +def test_resource_url_from_server_url_preserves_query(): + """Query parameters should be preserved.""" + assert resource_url_from_server_url("https://example.com/path?foo=bar") == "https://example.com/path?foo=bar" + assert resource_url_from_server_url("https://example.com/?key=value") == "https://example.com/?key=value" + + +def test_resource_url_from_server_url_preserves_port(): + """Non-default ports should be preserved.""" + assert resource_url_from_server_url("https://example.com:8443/path") == "https://example.com:8443/path" + assert resource_url_from_server_url("http://example.com:8080/") == "http://example.com:8080/" + + +def test_resource_url_from_server_url_lowercase_scheme_and_host(): + """Scheme and host should be lowercase for canonical form.""" + assert resource_url_from_server_url("HTTPS://EXAMPLE.COM/path") == "https://example.com/path" + assert resource_url_from_server_url("Http://Example.Com:8080/") == "http://example.com:8080/" + + +def test_resource_url_from_server_url_handles_pydantic_urls(): + """Should handle Pydantic URL types.""" + url = HttpUrl("https://example.com/path") + assert resource_url_from_server_url(url) == "https://example.com/path" + + +# Tests for check_resource_allowed function + + +def test_check_resource_allowed_identical_urls(): + """Identical URLs should match.""" + assert check_resource_allowed("https://example.com/path", "https://example.com/path") is True + assert check_resource_allowed("https://example.com/", "https://example.com/") is True + assert check_resource_allowed("https://example.com", "https://example.com") is True + + +def test_check_resource_allowed_different_schemes(): + """Different schemes should not match.""" + assert check_resource_allowed("https://example.com/path", "http://example.com/path") is False + assert check_resource_allowed("http://example.com/", "https://example.com/") is False + + +def test_check_resource_allowed_different_domains(): + """Different domains should not match.""" + assert check_resource_allowed("https://example.com/path", "https://example.org/path") is False + assert check_resource_allowed("https://sub.example.com/", "https://example.com/") is False + + +def test_check_resource_allowed_different_ports(): + """Different ports should not match.""" + assert check_resource_allowed("https://example.com:8443/path", "https://example.com/path") is False + assert check_resource_allowed("https://example.com:8080/", "https://example.com:8443/") is False + + +def test_check_resource_allowed_hierarchical_matching(): + """Child paths should match parent paths.""" + # Parent resource allows child resources + assert check_resource_allowed("https://example.com/api/v1/users", "https://example.com/api") is True + assert check_resource_allowed("https://example.com/api/v1", "https://example.com/api") is True + assert check_resource_allowed("https://example.com/mcp/server", "https://example.com/mcp") is True + + # Exact match + assert check_resource_allowed("https://example.com/api", "https://example.com/api") is True + + # Parent cannot use child's token + assert check_resource_allowed("https://example.com/api", "https://example.com/api/v1") is False + assert check_resource_allowed("https://example.com/", "https://example.com/api") is False + + +def test_check_resource_allowed_path_boundary_matching(): + """Path matching should respect boundaries.""" + # Should not match partial path segments + assert check_resource_allowed("https://example.com/apiextra", "https://example.com/api") is False + assert check_resource_allowed("https://example.com/api123", "https://example.com/api") is False + + # Should match with trailing slash + assert check_resource_allowed("https://example.com/api/", "https://example.com/api") is True + assert check_resource_allowed("https://example.com/api/v1", "https://example.com/api/") is True + + +def test_check_resource_allowed_trailing_slash_handling(): + """Trailing slashes should be handled correctly.""" + # With and without trailing slashes + assert check_resource_allowed("https://example.com/api/", "https://example.com/api") is True + assert check_resource_allowed("https://example.com/api", "https://example.com/api/") is False + assert check_resource_allowed("https://example.com/api/v1", "https://example.com/api") is True + assert check_resource_allowed("https://example.com/api/v1", "https://example.com/api/") is True + + +def test_check_resource_allowed_case_insensitive_origin(): + """Origin comparison should be case-insensitive.""" + assert check_resource_allowed("https://EXAMPLE.COM/path", "https://example.com/path") is True + assert check_resource_allowed("HTTPS://example.com/path", "https://example.com/path") is True + assert check_resource_allowed("https://Example.Com:8080/api", "https://example.com:8080/api") is True + + +def test_check_resource_allowed_empty_paths(): + """Empty paths should be handled correctly.""" + assert check_resource_allowed("https://example.com", "https://example.com") is True + assert check_resource_allowed("https://example.com/", "https://example.com") is True + assert check_resource_allowed("https://example.com/api", "https://example.com") is True diff --git a/tests/shared/test_exceptions.py b/tests/shared/test_exceptions.py index 1a42e7aef..70d14c9cd 100644 --- a/tests/shared/test_exceptions.py +++ b/tests/shared/test_exceptions.py @@ -6,154 +6,159 @@ from mcp.types import URL_ELICITATION_REQUIRED, ElicitRequestURLParams, ErrorData -class TestUrlElicitationRequiredError: - """Tests for UrlElicitationRequiredError exception class.""" - - def test_create_with_single_elicitation(self) -> None: - """Test creating error with a single elicitation.""" - elicitation = ElicitRequestURLParams( +def test_url_elicitation_required_error_create_with_single_elicitation() -> None: + """Test creating error with a single elicitation.""" + elicitation = ElicitRequestURLParams( + mode="url", + message="Auth required", + url="https://example.com/auth", + elicitation_id="test-123", + ) + error = UrlElicitationRequiredError([elicitation]) + + assert error.error.code == URL_ELICITATION_REQUIRED + assert error.error.message == "URL elicitation required" + assert len(error.elicitations) == 1 + assert error.elicitations[0].elicitation_id == "test-123" + + +def test_url_elicitation_required_error_create_with_multiple_elicitations() -> None: + """Test creating error with multiple elicitations uses plural message.""" + elicitations = [ + ElicitRequestURLParams( mode="url", - message="Auth required", - url="https://example.com/auth", - elicitation_id="test-123", - ) - error = UrlElicitationRequiredError([elicitation]) - - assert error.error.code == URL_ELICITATION_REQUIRED - assert error.error.message == "URL elicitation required" - assert len(error.elicitations) == 1 - assert error.elicitations[0].elicitation_id == "test-123" - - def test_create_with_multiple_elicitations(self) -> None: - """Test creating error with multiple elicitations uses plural message.""" - elicitations = [ - ElicitRequestURLParams( - mode="url", - message="Auth 1", - url="https://example.com/auth1", - elicitation_id="test-1", - ), - ElicitRequestURLParams( - mode="url", - message="Auth 2", - url="https://example.com/auth2", - elicitation_id="test-2", - ), - ] - error = UrlElicitationRequiredError(elicitations) - - assert error.error.message == "URL elicitations required" # Plural - assert len(error.elicitations) == 2 - - def test_custom_message(self) -> None: - """Test creating error with a custom message.""" - elicitation = ElicitRequestURLParams( + message="Auth 1", + url="https://example.com/auth1", + elicitation_id="test-1", + ), + ElicitRequestURLParams( mode="url", - message="Auth required", - url="https://example.com/auth", - elicitation_id="test-123", - ) - error = UrlElicitationRequiredError([elicitation], message="Custom message") - - assert error.error.message == "Custom message" - - def test_from_error_data(self) -> None: - """Test reconstructing error from ErrorData.""" - error_data = ErrorData( - code=URL_ELICITATION_REQUIRED, - message="URL elicitation required", - data={ - "elicitations": [ - { - "mode": "url", - "message": "Auth required", - "url": "https://example.com/auth", - "elicitationId": "test-123", - } - ] - }, - ) - - error = UrlElicitationRequiredError.from_error(error_data) - - assert len(error.elicitations) == 1 - assert error.elicitations[0].elicitation_id == "test-123" - assert error.elicitations[0].url == "https://example.com/auth" - - def test_from_error_data_wrong_code(self) -> None: - """Test that from_error raises ValueError for wrong error code.""" - error_data = ErrorData( - code=-32600, # Wrong code - message="Some other error", - data={}, - ) - - with pytest.raises(ValueError, match="Expected error code"): - UrlElicitationRequiredError.from_error(error_data) - - def test_serialization_roundtrip(self) -> None: - """Test that error can be serialized and reconstructed.""" - original = UrlElicitationRequiredError( - [ - ElicitRequestURLParams( - mode="url", - message="Auth required", - url="https://example.com/auth", - elicitation_id="test-123", - ) + message="Auth 2", + url="https://example.com/auth2", + elicitation_id="test-2", + ), + ] + error = UrlElicitationRequiredError(elicitations) + + assert error.error.message == "URL elicitations required" # Plural + assert len(error.elicitations) == 2 + + +def test_url_elicitation_required_error_custom_message() -> None: + """Test creating error with a custom message.""" + elicitation = ElicitRequestURLParams( + mode="url", + message="Auth required", + url="https://example.com/auth", + elicitation_id="test-123", + ) + error = UrlElicitationRequiredError([elicitation], message="Custom message") + + assert error.error.message == "Custom message" + + +def test_url_elicitation_required_error_from_error_data() -> None: + """Test reconstructing error from ErrorData.""" + error_data = ErrorData( + code=URL_ELICITATION_REQUIRED, + message="URL elicitation required", + data={ + "elicitations": [ + { + "mode": "url", + "message": "Auth required", + "url": "https://example.com/auth", + "elicitationId": "test-123", + } ] - ) + }, + ) - # Simulate serialization over wire - error_data = original.error + error = UrlElicitationRequiredError.from_error(error_data) - # Reconstruct - reconstructed = UrlElicitationRequiredError.from_error(error_data) + assert len(error.elicitations) == 1 + assert error.elicitations[0].elicitation_id == "test-123" + assert error.elicitations[0].url == "https://example.com/auth" - assert reconstructed.elicitations[0].elicitation_id == original.elicitations[0].elicitation_id - assert reconstructed.elicitations[0].url == original.elicitations[0].url - assert reconstructed.elicitations[0].message == original.elicitations[0].message - def test_error_data_contains_elicitations(self) -> None: - """Test that error data contains properly serialized elicitations.""" - elicitation = ElicitRequestURLParams( - mode="url", - message="Please authenticate", - url="https://example.com/oauth", - elicitation_id="oauth-flow-1", - ) - error = UrlElicitationRequiredError([elicitation]) - - assert error.error.data is not None - assert "elicitations" in error.error.data - elicit_data = error.error.data["elicitations"][0] - assert elicit_data["mode"] == "url" - assert elicit_data["message"] == "Please authenticate" - assert elicit_data["url"] == "https://example.com/oauth" - assert elicit_data["elicitationId"] == "oauth-flow-1" - - def test_inherits_from_mcp_error(self) -> None: - """Test that UrlElicitationRequiredError inherits from McpError.""" - elicitation = ElicitRequestURLParams( - mode="url", - message="Auth required", - url="https://example.com/auth", - elicitation_id="test-123", - ) - error = UrlElicitationRequiredError([elicitation]) - - assert isinstance(error, McpError) - assert isinstance(error, Exception) - - def test_exception_message(self) -> None: - """Test that exception message is set correctly.""" - elicitation = ElicitRequestURLParams( - mode="url", - message="Auth required", - url="https://example.com/auth", - elicitation_id="test-123", - ) - error = UrlElicitationRequiredError([elicitation]) - - # The exception's string representation should match the message - assert str(error) == "URL elicitation required" +def test_url_elicitation_required_error_from_error_data_wrong_code() -> None: + """Test that from_error raises ValueError for wrong error code.""" + error_data = ErrorData( + code=-32600, # Wrong code + message="Some other error", + data={}, + ) + + with pytest.raises(ValueError, match="Expected error code"): + UrlElicitationRequiredError.from_error(error_data) + + +def test_url_elicitation_required_error_serialization_roundtrip() -> None: + """Test that error can be serialized and reconstructed.""" + original = UrlElicitationRequiredError( + [ + ElicitRequestURLParams( + mode="url", + message="Auth required", + url="https://example.com/auth", + elicitation_id="test-123", + ) + ] + ) + + # Simulate serialization over wire + error_data = original.error + + # Reconstruct + reconstructed = UrlElicitationRequiredError.from_error(error_data) + + assert reconstructed.elicitations[0].elicitation_id == original.elicitations[0].elicitation_id + assert reconstructed.elicitations[0].url == original.elicitations[0].url + assert reconstructed.elicitations[0].message == original.elicitations[0].message + + +def test_url_elicitation_required_error_data_contains_elicitations() -> None: + """Test that error data contains properly serialized elicitations.""" + elicitation = ElicitRequestURLParams( + mode="url", + message="Please authenticate", + url="https://example.com/oauth", + elicitation_id="oauth-flow-1", + ) + error = UrlElicitationRequiredError([elicitation]) + + assert error.error.data is not None + assert "elicitations" in error.error.data + elicit_data = error.error.data["elicitations"][0] + assert elicit_data["mode"] == "url" + assert elicit_data["message"] == "Please authenticate" + assert elicit_data["url"] == "https://example.com/oauth" + assert elicit_data["elicitationId"] == "oauth-flow-1" + + +def test_url_elicitation_required_error_inherits_from_mcp_error() -> None: + """Test that UrlElicitationRequiredError inherits from McpError.""" + elicitation = ElicitRequestURLParams( + mode="url", + message="Auth required", + url="https://example.com/auth", + elicitation_id="test-123", + ) + error = UrlElicitationRequiredError([elicitation]) + + assert isinstance(error, McpError) + assert isinstance(error, Exception) + + +def test_url_elicitation_required_error_exception_message() -> None: + """Test that exception message is set correctly.""" + elicitation = ElicitRequestURLParams( + mode="url", + message="Auth required", + url="https://example.com/auth", + elicitation_id="test-123", + ) + error = UrlElicitationRequiredError([elicitation]) + + # The exception's string representation should match the message + assert str(error) == "URL elicitation required" diff --git a/tests/shared/test_tool_name_validation.py b/tests/shared/test_tool_name_validation.py index 4746f3f9f..97b3dffcd 100644 --- a/tests/shared/test_tool_name_validation.py +++ b/tests/shared/test_tool_name_validation.py @@ -10,190 +10,199 @@ validate_tool_name, ) +# Tests for validate_tool_name function - valid names + + +@pytest.mark.parametrize( + "tool_name", + [ + "getUser", + "get_user_profile", + "user-profile-update", + "admin.tools.list", + "DATA_EXPORT_v2.1", + "a", + "a" * 128, + ], + ids=[ + "simple_alphanumeric", + "with_underscores", + "with_dashes", + "with_dots", + "mixed_characters", + "single_character", + "max_length_128", + ], +) +def test_validate_tool_name_accepts_valid_names(tool_name: str) -> None: + """Valid tool names should pass validation with no warnings.""" + result = validate_tool_name(tool_name) + assert result.is_valid is True + assert result.warnings == [] + + +# Tests for validate_tool_name function - invalid names + + +def test_validate_tool_name_rejects_empty_name() -> None: + """Empty names should be rejected.""" + result = validate_tool_name("") + assert result.is_valid is False + assert "Tool name cannot be empty" in result.warnings + + +def test_validate_tool_name_rejects_name_exceeding_max_length() -> None: + """Names exceeding 128 characters should be rejected.""" + result = validate_tool_name("a" * 129) + assert result.is_valid is False + assert any("exceeds maximum length of 128 characters (current: 129)" in w for w in result.warnings) + + +@pytest.mark.parametrize( + "tool_name,expected_char", + [ + ("get user profile", "' '"), + ("get,user,profile", "','"), + ("user/profile/update", "'/'"), + ("user@domain.com", "'@'"), + ], + ids=[ + "with_spaces", + "with_commas", + "with_slashes", + "with_at_symbol", + ], +) +def test_validate_tool_name_rejects_invalid_characters(tool_name: str, expected_char: str) -> None: + """Names with invalid characters should be rejected.""" + result = validate_tool_name(tool_name) + assert result.is_valid is False + assert any("invalid characters" in w and expected_char in w for w in result.warnings) + + +def test_validate_tool_name_rejects_multiple_invalid_chars() -> None: + """Names with multiple invalid chars should list all of them.""" + result = validate_tool_name("user name@domain,com") + assert result.is_valid is False + warning = next(w for w in result.warnings if "invalid characters" in w) + assert "' '" in warning + assert "'@'" in warning + assert "','" in warning + + +def test_validate_tool_name_rejects_unicode_characters() -> None: + """Names with unicode characters should be rejected.""" + result = validate_tool_name("user-\u00f1ame") # n with tilde + assert result.is_valid is False + + +# Tests for validate_tool_name function - warnings for problematic patterns + + +def test_validate_tool_name_warns_on_leading_dash() -> None: + """Names starting with dash should generate warning but be valid.""" + result = validate_tool_name("-get-user") + assert result.is_valid is True + assert any("starts or ends with a dash" in w for w in result.warnings) + + +def test_validate_tool_name_warns_on_trailing_dash() -> None: + """Names ending with dash should generate warning but be valid.""" + result = validate_tool_name("get-user-") + assert result.is_valid is True + assert any("starts or ends with a dash" in w for w in result.warnings) + + +def test_validate_tool_name_warns_on_leading_dot() -> None: + """Names starting with dot should generate warning but be valid.""" + result = validate_tool_name(".get.user") + assert result.is_valid is True + assert any("starts or ends with a dot" in w for w in result.warnings) + + +def test_validate_tool_name_warns_on_trailing_dot() -> None: + """Names ending with dot should generate warning but be valid.""" + result = validate_tool_name("get.user.") + assert result.is_valid is True + assert any("starts or ends with a dot" in w for w in result.warnings) + + +# Tests for issue_tool_name_warning function + + +def test_issue_tool_name_warning_logs_warnings(caplog: pytest.LogCaptureFixture) -> None: + """Warnings should be logged at WARNING level.""" + warnings = ["Warning 1", "Warning 2"] + with caplog.at_level(logging.WARNING): + issue_tool_name_warning("test-tool", warnings) -class TestValidateToolName: - """Tests for validate_tool_name function.""" - - class TestValidNames: - """Test cases for valid tool names.""" - - @pytest.mark.parametrize( - "tool_name", - [ - "getUser", - "get_user_profile", - "user-profile-update", - "admin.tools.list", - "DATA_EXPORT_v2.1", - "a", - "a" * 128, - ], - ids=[ - "simple_alphanumeric", - "with_underscores", - "with_dashes", - "with_dots", - "mixed_characters", - "single_character", - "max_length_128", - ], - ) - def test_accepts_valid_names(self, tool_name: str) -> None: - """Valid tool names should pass validation with no warnings.""" - result = validate_tool_name(tool_name) - assert result.is_valid is True - assert result.warnings == [] - - class TestInvalidNames: - """Test cases for invalid tool names.""" - - def test_rejects_empty_name(self) -> None: - """Empty names should be rejected.""" - result = validate_tool_name("") - assert result.is_valid is False - assert "Tool name cannot be empty" in result.warnings - - def test_rejects_name_exceeding_max_length(self) -> None: - """Names exceeding 128 characters should be rejected.""" - result = validate_tool_name("a" * 129) - assert result.is_valid is False - assert any("exceeds maximum length of 128 characters (current: 129)" in w for w in result.warnings) - - @pytest.mark.parametrize( - "tool_name,expected_char", - [ - ("get user profile", "' '"), - ("get,user,profile", "','"), - ("user/profile/update", "'/'"), - ("user@domain.com", "'@'"), - ], - ids=[ - "with_spaces", - "with_commas", - "with_slashes", - "with_at_symbol", - ], - ) - def test_rejects_invalid_characters(self, tool_name: str, expected_char: str) -> None: - """Names with invalid characters should be rejected.""" - result = validate_tool_name(tool_name) - assert result.is_valid is False - assert any("invalid characters" in w and expected_char in w for w in result.warnings) - - def test_rejects_multiple_invalid_chars(self) -> None: - """Names with multiple invalid chars should list all of them.""" - result = validate_tool_name("user name@domain,com") - assert result.is_valid is False - warning = next(w for w in result.warnings if "invalid characters" in w) - assert "' '" in warning - assert "'@'" in warning - assert "','" in warning - - def test_rejects_unicode_characters(self) -> None: - """Names with unicode characters should be rejected.""" - result = validate_tool_name("user-\u00f1ame") # n with tilde - assert result.is_valid is False - - class TestWarningsForProblematicPatterns: - """Test cases for valid names that generate warnings.""" - - def test_warns_on_leading_dash(self) -> None: - """Names starting with dash should generate warning but be valid.""" - result = validate_tool_name("-get-user") - assert result.is_valid is True - assert any("starts or ends with a dash" in w for w in result.warnings) - - def test_warns_on_trailing_dash(self) -> None: - """Names ending with dash should generate warning but be valid.""" - result = validate_tool_name("get-user-") - assert result.is_valid is True - assert any("starts or ends with a dash" in w for w in result.warnings) - - def test_warns_on_leading_dot(self) -> None: - """Names starting with dot should generate warning but be valid.""" - result = validate_tool_name(".get.user") - assert result.is_valid is True - assert any("starts or ends with a dot" in w for w in result.warnings) - - def test_warns_on_trailing_dot(self) -> None: - """Names ending with dot should generate warning but be valid.""" - result = validate_tool_name("get.user.") - assert result.is_valid is True - assert any("starts or ends with a dot" in w for w in result.warnings) - - -class TestIssueToolNameWarning: - """Tests for issue_tool_name_warning function.""" - - def test_logs_warnings(self, caplog: pytest.LogCaptureFixture) -> None: - """Warnings should be logged at WARNING level.""" - warnings = ["Warning 1", "Warning 2"] - with caplog.at_level(logging.WARNING): - issue_tool_name_warning("test-tool", warnings) - - assert 'Tool name validation warning for "test-tool"' in caplog.text - assert "- Warning 1" in caplog.text - assert "- Warning 2" in caplog.text - assert "Tool registration will proceed" in caplog.text - assert "SEP-986" in caplog.text - - def test_no_logging_for_empty_warnings(self, caplog: pytest.LogCaptureFixture) -> None: - """Empty warnings list should not produce any log output.""" - with caplog.at_level(logging.WARNING): - issue_tool_name_warning("test-tool", []) - - assert caplog.text == "" - - -class TestValidateAndWarnToolName: - """Tests for validate_and_warn_tool_name function.""" - - def test_returns_true_for_valid_name(self) -> None: - """Valid names should return True.""" - assert validate_and_warn_tool_name("valid-tool-name") is True - - def test_returns_false_for_invalid_name(self) -> None: - """Invalid names should return False.""" - assert validate_and_warn_tool_name("") is False - assert validate_and_warn_tool_name("a" * 129) is False - assert validate_and_warn_tool_name("invalid name") is False - - def test_logs_warnings_for_invalid_name(self, caplog: pytest.LogCaptureFixture) -> None: - """Invalid names should trigger warning logs.""" - with caplog.at_level(logging.WARNING): - validate_and_warn_tool_name("invalid name") - - assert "Tool name validation warning" in caplog.text - - def test_no_warnings_for_clean_valid_name(self, caplog: pytest.LogCaptureFixture) -> None: - """Clean valid names should not produce any log output.""" - with caplog.at_level(logging.WARNING): - result = validate_and_warn_tool_name("clean-tool-name") - - assert result is True - assert caplog.text == "" - - -class TestEdgeCases: - """Test edge cases and robustness.""" - - @pytest.mark.parametrize( - "tool_name,is_valid,expected_warning_fragment", - [ - ("...", True, "starts or ends with a dot"), - ("---", True, "starts or ends with a dash"), - ("///", False, "invalid characters"), - ("user@name123", False, "invalid characters"), - ], - ids=[ - "only_dots", - "only_dashes", - "only_slashes", - "mixed_valid_invalid", - ], - ) - def test_edge_cases(self, tool_name: str, is_valid: bool, expected_warning_fragment: str) -> None: - """Various edge cases should be handled correctly.""" - result = validate_tool_name(tool_name) - assert result.is_valid is is_valid - assert any(expected_warning_fragment in w for w in result.warnings) + assert 'Tool name validation warning for "test-tool"' in caplog.text + assert "- Warning 1" in caplog.text + assert "- Warning 2" in caplog.text + assert "Tool registration will proceed" in caplog.text + assert "SEP-986" in caplog.text + + +def test_issue_tool_name_warning_no_logging_for_empty_warnings(caplog: pytest.LogCaptureFixture) -> None: + """Empty warnings list should not produce any log output.""" + with caplog.at_level(logging.WARNING): + issue_tool_name_warning("test-tool", []) + + assert caplog.text == "" + + +# Tests for validate_and_warn_tool_name function + + +def test_validate_and_warn_tool_name_returns_true_for_valid_name() -> None: + """Valid names should return True.""" + assert validate_and_warn_tool_name("valid-tool-name") is True + + +def test_validate_and_warn_tool_name_returns_false_for_invalid_name() -> None: + """Invalid names should return False.""" + assert validate_and_warn_tool_name("") is False + assert validate_and_warn_tool_name("a" * 129) is False + assert validate_and_warn_tool_name("invalid name") is False + + +def test_validate_and_warn_tool_name_logs_warnings_for_invalid_name(caplog: pytest.LogCaptureFixture) -> None: + """Invalid names should trigger warning logs.""" + with caplog.at_level(logging.WARNING): + validate_and_warn_tool_name("invalid name") + + assert "Tool name validation warning" in caplog.text + + +def test_validate_and_warn_tool_name_no_warnings_for_clean_valid_name(caplog: pytest.LogCaptureFixture) -> None: + """Clean valid names should not produce any log output.""" + with caplog.at_level(logging.WARNING): + result = validate_and_warn_tool_name("clean-tool-name") + + assert result is True + assert caplog.text == "" + + +# Tests for edge cases + + +@pytest.mark.parametrize( + "tool_name,is_valid,expected_warning_fragment", + [ + ("...", True, "starts or ends with a dot"), + ("---", True, "starts or ends with a dash"), + ("///", False, "invalid characters"), + ("user@name123", False, "invalid characters"), + ], + ids=[ + "only_dots", + "only_dashes", + "only_slashes", + "mixed_valid_invalid", + ], +) +def test_edge_cases(tool_name: str, is_valid: bool, expected_warning_fragment: str) -> None: + """Various edge cases should be handled correctly.""" + result = validate_tool_name(tool_name) + assert result.is_valid is is_valid + assert any(expected_warning_fragment in w for w in result.warnings)