Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 52 additions & 2 deletions astrbot/core/agent/mcp_client.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import asyncio
import copy
import logging
from contextlib import AsyncExitStack
from datetime import timedelta
from typing import Generic
from typing import Any, Generic

from tenacity import (
before_sleep_log,
Expand Down Expand Up @@ -107,6 +108,55 @@ async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]:
return False, f"{e!s}"


def _normalize_mcp_input_schema(schema: dict[str, Any]) -> dict[str, Any]:
"""Normalize common non-standard MCP JSON Schema variants.

Some MCP servers incorrectly mark required properties with a boolean
`required: true` on the property schema itself. Draft 2020-12 requires the
parent object to declare `required` as an array of property names instead.
We lift those booleans to the parent object so the schema remains usable
without disabling validation entirely.
"""

def _normalize(node: Any) -> Any:
if isinstance(node, list):
return [_normalize(item) for item in node]

if not isinstance(node, dict):
return node

normalized = {key: _normalize(value) for key, value in node.items()}

properties = normalized.get("properties")
if isinstance(properties, dict):
required = normalized.get("required")
required_list = required[:] if isinstance(required, list) else []

for prop_name, prop_schema in properties.items():
if not isinstance(prop_schema, dict):
continue

prop_required = prop_schema.get("required")
if isinstance(prop_required, bool):
prop_schema.pop("required", None)
if prop_required:
required_list.append(prop_name)

if required_list:
seen: set[str] = set()
normalized["required"] = [
name
for name in required_list
if not (name in seen or seen.add(name))
Comment on lines +146 to +150
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nitpick: 在推导式中使用 seen.add 做去重的小技巧会影响可读性。

这种习惯用法对很多读者来说不太好理解,而 required_list 的规模很小且使用频率不高,因此这类微优化没有必要。建议采用更清晰的方式——例如显式循环,或者使用 normalized["required"] = list(dict.fromkeys(required_list))——在保持顺序的同时完成去重,对性能影响可以忽略不计。

Original comment in English

nitpick: The deduplication trick using seen.add in a comprehension impacts readability.

This idiom is hard to parse for many readers, and the small size / infrequent use of required_list means the micro-optimization isn’t needed. Consider a clearer approach—e.g. an explicit loop or normalized["required"] = list(dict.fromkeys(required_list))—to deduplicate while preserving order with negligible performance impact.

]
Comment on lines +146 to +151
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The method used for de-duplicating required_list while preserving order is a bit verbose. For Python 3.7+, you can use dict.fromkeys() which is a more modern, concise, and idiomatic way to achieve the same result.

Suggested change
seen: set[str] = set()
normalized["required"] = [
name
for name in required_list
if not (name in seen or seen.add(name))
]
normalized["required"] = list(dict.fromkeys(required_list))

elif isinstance(required, list):
normalized.pop("required", None)

return normalized

return _normalize(copy.deepcopy(schema))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The _normalize function is pure and does not mutate its input; it returns a new, modified copy of the schema. Therefore, the copy.deepcopy() call here is redundant. Removing it will improve performance by avoiding an unnecessary full copy of the schema object, which can be significant for large schemas.

Suggested change
return _normalize(copy.deepcopy(schema))
return _normalize(schema)



class MCPClient:
def __init__(self) -> None:
# Initialize session and client objects
Expand Down Expand Up @@ -382,7 +432,7 @@ def __init__(
super().__init__(
name=mcp_tool.name,
description=mcp_tool.description or "",
parameters=mcp_tool.inputSchema,
parameters=_normalize_mcp_input_schema(mcp_tool.inputSchema),
)
self.mcp_tool = mcp_tool
self.mcp_client = mcp_client
Expand Down
70 changes: 70 additions & 0 deletions tests/unit/test_mcp_client_schema.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
from types import SimpleNamespace
from unittest.mock import MagicMock

from astrbot.core.agent.mcp_client import MCPTool, _normalize_mcp_input_schema


class TestNormalizeMcpInputSchema:
    """Unit tests for `_normalize_mcp_input_schema`."""

    def test_lifts_property_level_required_booleans_to_parent_required_array(self):
        """Boolean `required` flags move into the parent's `required` array."""
        source_properties = {
            "stock_code": {"type": "string", "required": True},
            "market": {"type": "string", "required": False},
        }
        source = {"type": "object", "properties": source_properties}

        result = _normalize_mcp_input_schema(source)

        assert result["required"] == ["stock_code"]
        result_properties = result["properties"]
        assert "required" not in result_properties["stock_code"]
        assert "required" not in result_properties["market"]
        # The input schema must come back untouched.
        assert source["properties"]["stock_code"]["required"] is True

    def test_preserves_existing_required_arrays_while_fixing_nested_objects(self):
        """Existing `required` arrays are kept and extended at each level."""
        server_schema = {
            "type": "object",
            "required": ["transport"],
            "properties": {
                "transport": {"type": "string"},
                "stock_code": {"type": "string", "required": True},
                "market": {"type": "string", "required": False},
            },
        }
        source = {
            "type": "object",
            "required": ["server"],
            "properties": {"server": server_schema},
        }

        result = _normalize_mcp_input_schema(source)

        assert result["required"] == ["server"]
        server = result["properties"]["server"]
        # Lifted names are appended after the entries that were already there.
        assert server["required"] == ["transport", "stock_code"]
        assert "required" not in server["properties"]["stock_code"]
        assert "required" not in server["properties"]["market"]


class TestMCPToolSchemaNormalization:
    """Checks that `MCPTool` exposes a normalized schema as `parameters`."""

    def test_mcp_tool_accepts_property_level_required_booleans(self):
        """Property-level boolean `required` flags end up in `parameters["required"]`."""
        input_schema = {
            "type": "object",
            "properties": {
                "stock_code": {"type": "string", "required": True},
                "market": {"type": "string", "required": False},
            },
        }
        # A lightweight stand-in carrying only the attributes set here.
        fake_tool = SimpleNamespace(
            name="quote_lookup",
            description="Lookup a quote",
            inputSchema=input_schema,
        )

        tool = MCPTool(fake_tool, MagicMock(), "gf-securities")

        parameters = tool.parameters
        assert parameters["required"] == ["stock_code"]
        assert "required" not in parameters["properties"]["stock_code"]
        assert "required" not in parameters["properties"]["market"]