From b7ac59f5c104e8a52d026822c99a5769918637cf Mon Sep 17 00:00:00 2001 From: phernandez Date: Fri, 3 Jul 2026 18:02:01 -0500 Subject: [PATCH] fix(mcp): gate ChatGPT tools by clientInfo Signed-off-by: phernandez --- src/basic_memory/mcp/client_info.py | 93 ++++++++++++ src/basic_memory/mcp/server.py | 2 + src/basic_memory/mcp/tools/chatgpt_tools.py | 54 ++++++- .../mcp/test_chatgpt_tools_integration.py | 58 +++++-- tests/mcp/test_client_info.py | 114 ++++++++++++++ tests/mcp/tools/test_chatgpt_tools.py | 142 +++++++++++++++--- 6 files changed, 429 insertions(+), 34 deletions(-) create mode 100644 src/basic_memory/mcp/client_info.py create mode 100644 tests/mcp/test_client_info.py diff --git a/src/basic_memory/mcp/client_info.py b/src/basic_memory/mcp/client_info.py new file mode 100644 index 000000000..dd05acfe5 --- /dev/null +++ b/src/basic_memory/mcp/client_info.py @@ -0,0 +1,93 @@ +"""MCP client identity helpers for client-specific compatibility tools.""" + +from __future__ import annotations + +from collections.abc import Mapping +from typing import Any, cast + +import mcp.types as mt +from fastmcp import Context +from fastmcp.server.middleware import CallNext, Middleware, MiddlewareContext + +MCP_CLIENT_INFO_STATE_KEY = "mcp.client_info" +OPENAI_MCP_CLIENT_NAME = "openai-mcp" +ClientInfoState = dict[str, str | None] + + +class MCPClientInfoMiddleware(Middleware): + """Persist sanitized initialize clientInfo in FastMCP session state.""" + + async def on_initialize( + self, + context: MiddlewareContext[mt.InitializeRequest], + call_next: CallNext[mt.InitializeRequest, Any], + ) -> Any: + result = await call_next(context) + client_info = client_info_from_initialize(context.message) + if client_info is not None and context.fastmcp_context is not None: + await context.fastmcp_context.set_state( + MCP_CLIENT_INFO_STATE_KEY, + client_info, + ) + return result + + +async def is_openai_mcp_client(context: Context | None) -> bool: + """Return whether the current MCP session identified itself as OpenAI's MCP client.""" + if context is None: + return False + + return client_info_is_openai_mcp(await context.get_state(MCP_CLIENT_INFO_STATE_KEY)) + + +def client_info_from_initialize(message: mt.InitializeRequest) -> ClientInfoState | None: + """Extract the normalized clientInfo payload from an initialize request.""" + client_info = message.params.clientInfo + return _client_info_from_mapping( + { + "name": client_info.name, + "title": client_info.title, + "version": client_info.version, + } + ) + + +def client_info_is_openai_mcp(value: object | None) -> bool: + """Check the reported clientInfo name/title against the OpenAI MCP client label.""" + client_info = _client_info_from_mapping(value) + if client_info is None: + return False + + for key in ("name", "title"): + client_value = _normalize_client_value(client_info.get(key)) + if client_value == OPENAI_MCP_CLIENT_NAME: + return True + if client_value is not None and client_value.startswith(f"{OPENAI_MCP_CLIENT_NAME}/"): + return True + return False + + +def _client_info_from_mapping(value: object | None) -> ClientInfoState | None: + if not isinstance(value, Mapping): + return None + raw = cast(Mapping[str, object | None], value) + normalized: ClientInfoState = { + "name": _clean_optional_string(raw.get("name")), + "title": _clean_optional_string(raw.get("title")), + "version": _clean_optional_string(raw.get("version")), + } + if not normalized["name"] and not normalized["title"] and not normalized["version"]: + return None + return normalized + + +def _normalize_client_value(value: object | None) -> str | None: + text = _clean_optional_string(value) + return text.lower() if text else None + + +def _clean_optional_string(value: object | None) -> str | None: + if value is None: + return None + text = str(value).strip() + return text or None diff --git a/src/basic_memory/mcp/server.py b/src/basic_memory/mcp/server.py index aaee21af6..8fc8a73cb 100644 --- a/src/basic_memory/mcp/server.py +++ b/src/basic_memory/mcp/server.py @@ -13,6 +13,7 @@ from basic_memory import db from basic_memory.cli.auth import CLIAuth from basic_memory.db import scoped_session +from basic_memory.mcp.client_info import MCPClientInfoMiddleware from basic_memory.mcp.container import McpContainer, set_container from basic_memory.services.initialization import initialize_app import logfire @@ -145,3 +146,4 @@ async def lifespan(app: FastMCP): name="Basic Memory", lifespan=lifespan, ) +mcp.add_middleware(MCPClientInfoMiddleware()) diff --git a/src/basic_memory/mcp/tools/chatgpt_tools.py b/src/basic_memory/mcp/tools/chatgpt_tools.py index 1b0193b60..9ce77bd92 100644 --- a/src/basic_memory/mcp/tools/chatgpt_tools.py +++ b/src/basic_memory/mcp/tools/chatgpt_tools.py @@ -11,12 +11,48 @@ from fastmcp import Context from loguru import logger +from basic_memory.mcp.client_info import is_openai_mcp_client from basic_memory.mcp.server import mcp from basic_memory.mcp.tools.read_note import read_note from basic_memory.mcp.tools.search import search_notes from basic_memory.schemas.search import SearchResponse, SearchResult +_UNSUPPORTED_CLIENT_ERROR = "Unsupported MCP client" + + +def _text_content(payload: dict[str, Any]) -> List[Dict[str, Any]]: + return [{"type": "text", "text": json.dumps(payload, ensure_ascii=False)}] + + +def _unsupported_search_client_response() -> List[Dict[str, Any]]: + return _text_content( + { + "results": [], + "error": _UNSUPPORTED_CLIENT_ERROR, + "error_message": ( + "The search compatibility tool is only available to OpenAI MCP clients. " + "Use search_notes instead." + ), + } + ) + + +def _unsupported_fetch_client_response(identifier: str) -> List[Dict[str, Any]]: + return _text_content( + { + "id": identifier, + "title": "Unsupported MCP Client", + "text": ( + "The fetch compatibility tool is only available to OpenAI MCP clients. " + "Use read_note instead." + ), + "url": identifier, + "metadata": {"error": _UNSUPPORTED_CLIENT_ERROR}, + } + ) + + def _identifier_for_read_note(identifier: str) -> str: """Convert ChatGPT result ids into routable Basic Memory identifiers.""" stripped = identifier.strip() @@ -124,6 +160,10 @@ async def search( List with one dict: `{ "type": "text", "text": "{...JSON...}" }` where the JSON body contains `results`, `total_count`, and echo of `query`. """ + if not await is_openai_mcp_client(context): + logger.warning("Rejected ChatGPT search request from non-OpenAI MCP client") + return _unsupported_search_client_response() + logger.info(f"ChatGPT search request: query='{query}'") try: @@ -143,7 +183,7 @@ async def search( "error": "Search failed", "error_details": results[:500], # Truncate long error messages } - return [{"type": "text", "text": json.dumps(search_results, ensure_ascii=False)}] + return _text_content(search_results) raw_results = results.get("results", []) if isinstance(results, dict) else [] @@ -156,7 +196,7 @@ async def search( logger.info(f"Search completed: {len(formatted_results)} results returned") # Return in MCP content array format as required by OpenAI - return [{"type": "text", "text": json.dumps(search_results, ensure_ascii=False)}] + return _text_content(search_results) except Exception as e: logger.error(f"ChatGPT search failed for query '{query}': {e}") @@ -165,7 +205,7 @@ async def search( "error": "Internal search error", "error_message": str(e)[:200], } - return [{"type": "text", "text": json.dumps(error_results, ensure_ascii=False)}] + return _text_content(error_results) @mcp.tool( @@ -188,6 +228,10 @@ async def fetch( List with one dict: `{ "type": "text", "text": "{...JSON...}" }` where the JSON body includes `id`, `title`, `text`, `url`, and metadata. """ + if not await is_openai_mcp_client(context): + logger.warning("Rejected ChatGPT fetch request from non-OpenAI MCP client") + return _unsupported_fetch_client_response(id) + logger.info(f"ChatGPT fetch request: id='{id}'") try: @@ -206,7 +250,7 @@ async def fetch( logger.info(f"Fetch completed: id='{id}', content_length={len(document.get('text', ''))}") # Return in MCP content array format as required by OpenAI - return [{"type": "text", "text": json.dumps(document, ensure_ascii=False)}] + return _text_content(document) except Exception as e: logger.error(f"ChatGPT fetch failed for id '{id}': {e}") @@ -217,4 +261,4 @@ async def fetch( "url": id, "metadata": {"error": "Fetch failed"}, } - return [{"type": "text", "text": json.dumps(error_document, ensure_ascii=False)}] + return _text_content(error_document) diff --git a/test-int/mcp/test_chatgpt_tools_integration.py b/test-int/mcp/test_chatgpt_tools_integration.py index 164aad025..007b7adc5 100644 --- a/test-int/mcp/test_chatgpt_tools_integration.py +++ b/test-int/mcp/test_chatgpt_tools_integration.py @@ -9,6 +9,14 @@ import json import pytest from fastmcp import Client +from mcp import types as mt + + +def openai_mcp_client(mcp_server): + return Client( + mcp_server, + client_info=mt.Implementation(name="openai-mcp", version="1.0.0"), + ) def extract_mcp_json_content(mcp_result): @@ -29,7 +37,7 @@ def extract_mcp_json_content(mcp_result): async def test_chatgpt_search_basic(mcp_server, app, test_project): """Test basic ChatGPT search functionality with MCP content array format.""" - async with Client(mcp_server) as client: + async with openai_mcp_client(mcp_server) as client: # Create test notes for searching await client.call_tool( "write_note", @@ -99,7 +107,7 @@ async def test_chatgpt_search_basic(mcp_server, app, test_project): async def test_chatgpt_search_empty_results(mcp_server, app, test_project): """Test ChatGPT search with no matching results.""" - async with Client(mcp_server) as client: + async with openai_mcp_client(mcp_server) as client: # Search for non-existent content search_result = await client.call_tool( "search", @@ -119,7 +127,7 @@ async def test_chatgpt_search_empty_results(mcp_server, app, test_project): async def test_chatgpt_search_with_boolean_operators(mcp_server, app, test_project): """Test ChatGPT search with boolean operators.""" - async with Client(mcp_server) as client: + async with openai_mcp_client(mcp_server) as client: # Create test notes await client.call_tool( "write_note", @@ -164,7 +172,7 @@ async def test_chatgpt_search_with_boolean_operators(mcp_server, app, test_proje async def test_chatgpt_fetch_document(mcp_server, app, test_project): """Test ChatGPT fetch tool for retrieving full document content.""" - async with Client(mcp_server) as client: + async with openai_mcp_client(mcp_server) as client: # Create a test note note_content = """# Advanced Python Techniques @@ -224,7 +232,7 @@ def wrapper(*args, **kwargs): async def test_chatgpt_fetch_by_permalink(mcp_server, app, test_project): """Test ChatGPT fetch using permalink identifier.""" - async with Client(mcp_server) as client: + async with openai_mcp_client(mcp_server) as client: # Create a note with known content await client.call_tool( "write_note", @@ -268,7 +276,7 @@ async def test_chatgpt_fetch_by_permalink(mcp_server, app, test_project): async def test_chatgpt_fetch_nonexistent_document(mcp_server, app, test_project): """Test ChatGPT fetch with non-existent document ID.""" - async with Client(mcp_server) as client: + async with openai_mcp_client(mcp_server) as client: # Try to fetch a non-existent document fetch_result = await client.call_tool( "fetch", @@ -294,7 +302,7 @@ async def test_chatgpt_fetch_nonexistent_document(mcp_server, app, test_project) async def test_chatgpt_fetch_with_empty_title(mcp_server, app, test_project): """Test ChatGPT fetch handles documents with empty or missing titles.""" - async with Client(mcp_server) as client: + async with openai_mcp_client(mcp_server) as client: # Create a note without a title in the content await client.call_tool( "write_note", @@ -329,7 +337,7 @@ async def test_chatgpt_fetch_with_empty_title(mcp_server, app, test_project): async def test_chatgpt_search_pagination_default(mcp_server, app, test_project): """Test that ChatGPT search uses reasonable pagination defaults.""" - async with Client(mcp_server) as client: + async with openai_mcp_client(mcp_server) as client: # Create more than 10 notes to test pagination for i in range(15): await client.call_tool( @@ -362,7 +370,7 @@ async def test_chatgpt_search_pagination_default(mcp_server, app, test_project): async def test_chatgpt_tools_error_handling(mcp_server, app, test_project): """Test error handling in ChatGPT tools returns proper MCP format.""" - async with Client(mcp_server) as client: + async with openai_mcp_client(mcp_server) as client: # Test search with invalid query (if validation exists) # Using empty query to potentially trigger an error search_result = await client.call_tool( @@ -384,11 +392,41 @@ async def test_chatgpt_tools_error_handling(mcp_server, app, test_project): assert "results" in results_json # Should have results key even if empty +@pytest.mark.asyncio +async def test_chatgpt_tools_reject_default_client(mcp_server, app, test_project): + """Default FastMCP clients cannot use ChatGPT compatibility tools.""" + + async with Client(mcp_server) as client: + search_result = await client.call_tool( + "search", + { + "query": "Machine Learning", + }, + ) + + search_json = extract_mcp_json_content(search_result) + assert search_json["results"] == [] + assert search_json["error"] == "Unsupported MCP client" + assert "search_notes" in search_json["error_message"] + + fetch_result = await client.call_tool( + "fetch", + { + "id": "Machine Learning Fundamentals", + }, + ) + + document_json = extract_mcp_json_content(fetch_result) + assert document_json["id"] == "Machine Learning Fundamentals" + assert document_json["metadata"]["error"] == "Unsupported MCP client" + assert "read_note" in document_json["text"] + + @pytest.mark.asyncio async def test_chatgpt_integration_workflow(mcp_server, app, test_project): """Test complete workflow: search then fetch, as ChatGPT would use it.""" - async with Client(mcp_server) as client: + async with openai_mcp_client(mcp_server) as client: # Step 1: Create multiple documents docs = [ { diff --git a/tests/mcp/test_client_info.py b/tests/mcp/test_client_info.py new file mode 100644 index 000000000..e6ffd944c --- /dev/null +++ b/tests/mcp/test_client_info.py @@ -0,0 +1,114 @@ +"""Tests for MCP clientInfo capture and classification.""" + +from __future__ import annotations + +from typing import Any, cast + +import mcp.types as mt +import pytest +from fastmcp.server.middleware import CallNext, MiddlewareContext + +from basic_memory.mcp.client_info import ( + MCP_CLIENT_INFO_STATE_KEY, + MCPClientInfoMiddleware, + client_info_is_openai_mcp, + is_openai_mcp_client, +) + + +class FakeMCPContext: + """Small state-only context for middleware tests.""" + + def __init__(self) -> None: + self.state: dict[str, object] = {} + + async def set_state(self, key: str, value: object) -> None: + self.state[key] = value + + async def get_state(self, key: str) -> object | None: + return self.state.get(key) + + +def _initialize_context( + *, + name: str, + title: str | None, + version: str, + fastmcp_context: object | None, +) -> MiddlewareContext[mt.InitializeRequest]: + return MiddlewareContext( + message=mt.InitializeRequest( + params=mt.InitializeRequestParams( + protocolVersion="2025-06-18", + capabilities=mt.ClientCapabilities(), + clientInfo=mt.Implementation(name=name, title=title, version=version), + ) + ), + fastmcp_context=cast(Any, fastmcp_context), + source="client", + type="request", + method="initialize", + ) + + +def _as_initialize_next( + callback, +) -> CallNext[mt.InitializeRequest, Any]: + return cast(CallNext[mt.InitializeRequest, Any], callback) + + +@pytest.mark.asyncio +async def test_client_info_middleware_stores_initialize_state() -> None: + """The initialize clientInfo is available to later tool calls.""" + context = FakeMCPContext() + middleware = MCPClientInfoMiddleware() + + async def call_next(inner_context: MiddlewareContext[mt.InitializeRequest]) -> str: + return "ok" + + result = await middleware.on_initialize( + _initialize_context( + name="openai-mcp", + title=None, + version="1.0.0", + fastmcp_context=context, + ), + _as_initialize_next(call_next), + ) + + assert result == "ok" + assert context.state[MCP_CLIENT_INFO_STATE_KEY] == { + "name": "openai-mcp", + "title": None, + "version": "1.0.0", + } + + +@pytest.mark.asyncio +async def test_is_openai_mcp_client_reads_session_state() -> None: + """The gate accepts OpenAI's versioned clientInfo label.""" + context = FakeMCPContext() + await context.set_state( + MCP_CLIENT_INFO_STATE_KEY, + {"name": "openai-mcp/1.0.0", "title": None, "version": None}, + ) + + assert await is_openai_mcp_client(cast(Any, context)) is True + + +@pytest.mark.asyncio +async def test_is_openai_mcp_client_rejects_missing_context() -> None: + """No context means no authenticated MCP clientInfo to trust.""" + assert await is_openai_mcp_client(None) is False + + +def test_client_info_is_openai_mcp_accepts_title() -> None: + """Some clients place the useful label in title rather than name.""" + assert client_info_is_openai_mcp({"name": "mcp", "title": "openai-mcp", "version": "1.0.0"}) + + +def test_client_info_is_openai_mcp_rejects_other_clients() -> None: + """Claude/Codex-style labels do not pass the ChatGPT compatibility gate.""" + assert not client_info_is_openai_mcp({"name": "codex", "title": "Codex", "version": "5.0.0"}) + assert not client_info_is_openai_mcp({"name": None, "title": None, "version": None}) + assert not client_info_is_openai_mcp(None) diff --git a/tests/mcp/tools/test_chatgpt_tools.py b/tests/mcp/tools/test_chatgpt_tools.py index 836a8b260..394a252ef 100644 --- a/tests/mcp/tools/test_chatgpt_tools.py +++ b/tests/mcp/tools/test_chatgpt_tools.py @@ -1,14 +1,25 @@ """Tests for ChatGPT-compatible MCP tools.""" import json +from typing import Any, cast + import pytest +from basic_memory.mcp.client_info import MCP_CLIENT_INFO_STATE_KEY from basic_memory.mcp.tools import write_note from basic_memory.schemas.search import SearchResponse, SearchResult, SearchItemType +async def _openai_mcp_context(context_state) -> Any: + await context_state.set_state( + MCP_CLIENT_INFO_STATE_KEY, + {"name": "openai-mcp", "title": None, "version": "1.0.0"}, + ) + return cast(Any, context_state) + + @pytest.mark.asyncio -async def test_search_successful_results(client, test_project): +async def test_search_successful_results(client, test_project, context_state): """Test search with successful results returns proper MCP content array format.""" await write_note( project=test_project.name, @@ -25,7 +36,8 @@ async def test_search_successful_results(client, test_project): from basic_memory.mcp.tools.chatgpt_tools import search - result = await search("test content") + context = await _openai_mcp_context(context_state) + result = await search("test content", context=context) # Verify MCP content array format assert isinstance(result, list) @@ -43,7 +55,7 @@ async def test_search_successful_results(client, test_project): @pytest.mark.asyncio -async def test_search_with_error_response(monkeypatch, client, test_project): +async def test_search_with_error_response(monkeypatch, client, test_project, context_state): """Test search when underlying search_notes returns an error string.""" import basic_memory.mcp.tools.chatgpt_tools as chatgpt_tools @@ -54,7 +66,8 @@ async def fake_search_notes_fn(*args, **kwargs): monkeypatch.setattr(chatgpt_tools, "search_notes", fake_search_notes_fn) - result = await chatgpt_tools.search("invalid query") + context = await _openai_mcp_context(context_state) + result = await chatgpt_tools.search("invalid query", context=context) assert isinstance(result, list) assert len(result) == 1 @@ -67,7 +80,9 @@ async def fake_search_notes_fn(*args, **kwargs): @pytest.mark.asyncio -async def test_search_uses_dynamic_default_search_type(monkeypatch, client, test_project): +async def test_search_uses_dynamic_default_search_type( + monkeypatch, client, test_project, context_state +): """ChatGPT adapter should not hardcode search_type so search_notes can pick defaults.""" import basic_memory.mcp.tools.chatgpt_tools as chatgpt_tools @@ -79,7 +94,8 @@ async def fake_search_notes_fn(*args, **kwargs): monkeypatch.setattr(chatgpt_tools, "search_notes", fake_search_notes_fn) - result = await chatgpt_tools.search("default search mode query") + context = await _openai_mcp_context(context_state) + result = await chatgpt_tools.search("default search mode query", context=context) assert isinstance(result, list) assert "search_type" not in captured_kwargs @@ -87,7 +103,7 @@ async def fake_search_notes_fn(*args, **kwargs): @pytest.mark.asyncio async def test_search_delegates_to_search_notes_without_project_iteration( - monkeypatch, client, test_project + monkeypatch, client, test_project, context_state ): """ChatGPT search is only a compatibility wrapper around search_notes.""" import basic_memory.mcp.tools.chatgpt_tools as chatgpt_tools @@ -100,7 +116,8 @@ async def fake_search_notes_fn(*args, **kwargs): monkeypatch.setattr(chatgpt_tools, "search_notes", fake_search_notes_fn) - result = await chatgpt_tools.search("MCP Test Note") + context = await _openai_mcp_context(context_state) + result = await chatgpt_tools.search("MCP Test Note", context=context) content = json.loads(result[0]["text"]) assert content["results"] == [] @@ -110,12 +127,12 @@ async def fake_search_notes_fn(*args, **kwargs): "page": 1, "page_size": 10, "output_format": "json", - "context": None, + "context": context, } @pytest.mark.asyncio -async def test_fetch_successful_document(client, test_project): +async def test_fetch_successful_document(client, test_project, context_state): """Test fetch with successful document retrieval.""" await write_note( project=test_project.name, @@ -126,7 +143,8 @@ async def test_fetch_successful_document(client, test_project): from basic_memory.mcp.tools.chatgpt_tools import fetch - result = await fetch("docs/test-document") + context = await _openai_mcp_context(context_state) + result = await fetch("docs/test-document", context=context) assert isinstance(result, list) assert len(result) == 1 @@ -141,11 +159,12 @@ async def test_fetch_successful_document(client, test_project): @pytest.mark.asyncio -async def test_fetch_document_not_found(client, test_project): +async def test_fetch_document_not_found(client, test_project, context_state): """Test fetch when document is not found.""" from basic_memory.mcp.tools.chatgpt_tools import fetch - result = await fetch("nonexistent-doc") + context = await _openai_mcp_context(context_state) + result = await fetch("nonexistent-doc", context=context) assert isinstance(result, list) assert len(result) == 1 @@ -157,7 +176,9 @@ async def test_fetch_document_not_found(client, test_project): @pytest.mark.asyncio -async def test_fetch_routes_path_ids_as_memory_urls(monkeypatch, client, test_project): +async def test_fetch_routes_path_ids_as_memory_urls( + monkeypatch, client, test_project, context_state +): """Workspace-qualified search ids need memory URL routing during fetch.""" import basic_memory.mcp.tools.chatgpt_tools as chatgpt_tools @@ -169,7 +190,8 @@ async def fake_read_note(*, identifier: str, context=None): monkeypatch.setattr(chatgpt_tools, "read_note", fake_read_note) - result = await chatgpt_tools.fetch("team-paul/main/tests/mcp-test-note") + context = await _openai_mcp_context(context_state) + result = await chatgpt_tools.fetch("team-paul/main/tests/mcp-test-note", context=context) content = json.loads(result[0]["text"]) assert captured["identifier"] == "memory://team-paul/main/tests/mcp-test-note" @@ -215,6 +237,32 @@ def test_format_search_results_for_chatgpt(): assert formatted[1]["title"] == "Untitled" +def test_format_search_results_for_chatgpt_handles_dict_payloads(): + """Dict payloads must still normalize to the ChatGPT result array shape.""" + from basic_memory.mcp.tools.chatgpt_tools import _format_search_results_for_chatgpt + + formatted = _format_search_results_for_chatgpt( + {"results": [{"title": "Document One", "permalink": "docs/doc-one"}]} + ) + + assert formatted == [ + { + "id": "docs/doc-one", + "title": "Document One", + "url": "docs/doc-one", + } + ] + assert _format_search_results_for_chatgpt({"results": "not a list"}) == [] + + +def test_format_search_results_for_chatgpt_rejects_unknown_rows(): + """Unexpected row shapes should fail loudly instead of producing bad IDs.""" + from basic_memory.mcp.tools.chatgpt_tools import _format_search_results_for_chatgpt + + with pytest.raises(TypeError, match="Unexpected result type: object"): + _format_search_results_for_chatgpt(cast(Any, [object()])) + + def test_format_document_for_chatgpt(): """Test document formatting.""" from basic_memory.mcp.tools.chatgpt_tools import _format_document_for_chatgpt @@ -251,7 +299,9 @@ def test_format_document_untitled_fallback_for_empty_identifier(): @pytest.mark.asyncio -async def test_search_internal_exception_returns_error_payload(monkeypatch, client, test_project): +async def test_search_internal_exception_returns_error_payload( + monkeypatch, client, test_project, context_state +): """search() should return a structured error payload if an unexpected exception occurs.""" import basic_memory.mcp.tools.chatgpt_tools as chatgpt_tools @@ -260,7 +310,8 @@ async def boom(*args, **kwargs): monkeypatch.setattr(chatgpt_tools, "search_notes", boom) - result = await chatgpt_tools.search("anything") + context = await _openai_mcp_context(context_state) + result = await chatgpt_tools.search("anything", context=context) assert isinstance(result, list) content = json.loads(result[0]["text"]) assert content["error"] == "Internal search error" @@ -268,7 +319,9 @@ async def boom(*args, **kwargs): @pytest.mark.asyncio -async def test_fetch_internal_exception_returns_error_payload(monkeypatch, client, test_project): +async def test_fetch_internal_exception_returns_error_payload( + monkeypatch, client, test_project, context_state +): """fetch() should return a structured error payload if an unexpected exception occurs.""" import basic_memory.mcp.tools.chatgpt_tools as chatgpt_tools @@ -277,8 +330,59 @@ async def boom(*args, **kwargs): monkeypatch.setattr(chatgpt_tools, "read_note", boom) - result = await chatgpt_tools.fetch("docs/test") + context = await _openai_mcp_context(context_state) + result = await chatgpt_tools.fetch("docs/test", context=context) assert isinstance(result, list) content = json.loads(result[0]["text"]) assert content["id"] == "docs/test" assert content["metadata"]["error"] == "Fetch failed" + + +@pytest.mark.asyncio +async def test_search_rejects_non_openai_mcp_client(monkeypatch, context_state): + """Only OpenAI's MCP client can use the ChatGPT compatibility search tool.""" + import basic_memory.mcp.tools.chatgpt_tools as chatgpt_tools + + called = False + + async def fake_search_notes_fn(*args, **kwargs): + nonlocal called + called = True + return {"results": []} + + await context_state.set_state( + MCP_CLIENT_INFO_STATE_KEY, + {"name": "claude-code", "title": "Claude Code", "version": "1.2.3"}, + ) + monkeypatch.setattr(chatgpt_tools, "search_notes", fake_search_notes_fn) + + result = await chatgpt_tools.search("anything", context=cast(Any, context_state)) + + content = json.loads(result[0]["text"]) + assert called is False + assert content["results"] == [] + assert content["error"] == "Unsupported MCP client" + assert "search_notes" in content["error_message"] + + +@pytest.mark.asyncio +async def test_fetch_rejects_missing_client_info(monkeypatch, context_state): + """A missing clientInfo state is not enough to use the ChatGPT fetch shim.""" + import basic_memory.mcp.tools.chatgpt_tools as chatgpt_tools + + called = False + + async def fake_read_note(*args, **kwargs): + nonlocal called + called = True + return "# Should Not Be Read" + + monkeypatch.setattr(chatgpt_tools, "read_note", fake_read_note) + + result = await chatgpt_tools.fetch("docs/test", context=cast(Any, context_state)) + + content = json.loads(result[0]["text"]) + assert called is False + assert content["id"] == "docs/test" + assert content["metadata"]["error"] == "Unsupported MCP client" + assert "read_note" in content["text"]