Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 93 additions & 0 deletions src/basic_memory/mcp/client_info.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
"""MCP client identity helpers for client-specific compatibility tools."""

from __future__ import annotations

from collections.abc import Mapping
from typing import Any, cast

import mcp.types as mt
from fastmcp import Context
from fastmcp.server.middleware import CallNext, Middleware, MiddlewareContext

MCP_CLIENT_INFO_STATE_KEY = "mcp.client_info"
OPENAI_MCP_CLIENT_NAME = "openai-mcp"
ClientInfoState = dict[str, str | None]


class MCPClientInfoMiddleware(Middleware):
"""Persist sanitized initialize clientInfo in FastMCP session state."""

async def on_initialize(
self,
context: MiddlewareContext[mt.InitializeRequest],
call_next: CallNext[mt.InitializeRequest, Any],
) -> Any:
result = await call_next(context)
client_info = client_info_from_initialize(context.message)
if client_info is not None and context.fastmcp_context is not None:
await context.fastmcp_context.set_state(
MCP_CLIENT_INFO_STATE_KEY,
client_info,
)
return result


async def is_openai_mcp_client(context: Context | None) -> bool:
"""Return whether the current MCP session identified itself as OpenAI's MCP client."""
if context is None:
return False

return client_info_is_openai_mcp(await context.get_state(MCP_CLIENT_INFO_STATE_KEY))


def client_info_from_initialize(message: mt.InitializeRequest) -> ClientInfoState | None:
"""Extract the normalized clientInfo payload from an initialize request."""
client_info = message.params.clientInfo
return _client_info_from_mapping(
{
"name": client_info.name,
"title": client_info.title,
"version": client_info.version,
}
)


def client_info_is_openai_mcp(value: object | None) -> bool:
"""Check the reported clientInfo name/title against the OpenAI MCP client label."""
client_info = _client_info_from_mapping(value)
if client_info is None:
return False

for key in ("name", "title"):
client_value = _normalize_client_value(client_info.get(key))
if client_value == OPENAI_MCP_CLIENT_NAME:
return True
if client_value is not None and client_value.startswith(f"{OPENAI_MCP_CLIENT_NAME}/"):
return True
return False


def _client_info_from_mapping(value: object | None) -> ClientInfoState | None:
if not isinstance(value, Mapping):
return None
raw = cast(Mapping[str, object | None], value)
normalized: ClientInfoState = {
"name": _clean_optional_string(raw.get("name")),
"title": _clean_optional_string(raw.get("title")),
"version": _clean_optional_string(raw.get("version")),
}
if not normalized["name"] and not normalized["title"] and not normalized["version"]:
return None
return normalized


def _normalize_client_value(value: object | None) -> str | None:
text = _clean_optional_string(value)
return text.lower() if text else None


def _clean_optional_string(value: object | None) -> str | None:
if value is None:
return None
text = str(value).strip()
return text or None
2 changes: 2 additions & 0 deletions src/basic_memory/mcp/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from basic_memory import db
from basic_memory.cli.auth import CLIAuth
from basic_memory.db import scoped_session
from basic_memory.mcp.client_info import MCPClientInfoMiddleware
from basic_memory.mcp.container import McpContainer, set_container
from basic_memory.services.initialization import initialize_app
import logfire
Expand Down Expand Up @@ -145,3 +146,4 @@ async def lifespan(app: FastMCP):
name="Basic Memory",
lifespan=lifespan,
)
mcp.add_middleware(MCPClientInfoMiddleware())
54 changes: 49 additions & 5 deletions src/basic_memory/mcp/tools/chatgpt_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,48 @@
from fastmcp import Context
from loguru import logger

from basic_memory.mcp.client_info import is_openai_mcp_client
from basic_memory.mcp.server import mcp
from basic_memory.mcp.tools.read_note import read_note
from basic_memory.mcp.tools.search import search_notes
from basic_memory.schemas.search import SearchResponse, SearchResult


_UNSUPPORTED_CLIENT_ERROR = "Unsupported MCP client"


def _text_content(payload: dict[str, Any]) -> List[Dict[str, Any]]:
return [{"type": "text", "text": json.dumps(payload, ensure_ascii=False)}]


def _unsupported_search_client_response() -> List[Dict[str, Any]]:
return _text_content(
{
"results": [],
"error": _UNSUPPORTED_CLIENT_ERROR,
"error_message": (
"The search compatibility tool is only available to OpenAI MCP clients. "
"Use search_notes instead."
),
}
)


def _unsupported_fetch_client_response(identifier: str) -> List[Dict[str, Any]]:
return _text_content(
{
"id": identifier,
"title": "Unsupported MCP Client",
"text": (
"The fetch compatibility tool is only available to OpenAI MCP clients. "
"Use read_note instead."
),
"url": identifier,
"metadata": {"error": _UNSUPPORTED_CLIENT_ERROR},
}
)


def _identifier_for_read_note(identifier: str) -> str:
"""Convert ChatGPT result ids into routable Basic Memory identifiers."""
stripped = identifier.strip()
Expand Down Expand Up @@ -124,6 +160,10 @@ async def search(
List with one dict: `{ "type": "text", "text": "{...JSON...}" }`
where the JSON body contains `results`, `total_count`, and echo of `query`.
"""
if not await is_openai_mcp_client(context):
logger.warning("Rejected ChatGPT search request from non-OpenAI MCP client")
return _unsupported_search_client_response()

logger.info(f"ChatGPT search request: query='{query}'")

try:
Expand All @@ -143,7 +183,7 @@ async def search(
"error": "Search failed",
"error_details": results[:500], # Truncate long error messages
}
return [{"type": "text", "text": json.dumps(search_results, ensure_ascii=False)}]
return _text_content(search_results)

raw_results = results.get("results", []) if isinstance(results, dict) else []

Expand All @@ -156,7 +196,7 @@ async def search(
logger.info(f"Search completed: {len(formatted_results)} results returned")

# Return in MCP content array format as required by OpenAI
return [{"type": "text", "text": json.dumps(search_results, ensure_ascii=False)}]
return _text_content(search_results)

except Exception as e:
logger.error(f"ChatGPT search failed for query '{query}': {e}")
Expand All @@ -165,7 +205,7 @@ async def search(
"error": "Internal search error",
"error_message": str(e)[:200],
}
return [{"type": "text", "text": json.dumps(error_results, ensure_ascii=False)}]
return _text_content(error_results)


@mcp.tool(
Expand All @@ -188,6 +228,10 @@ async def fetch(
List with one dict: `{ "type": "text", "text": "{...JSON...}" }`
where the JSON body includes `id`, `title`, `text`, `url`, and metadata.
"""
if not await is_openai_mcp_client(context):
logger.warning("Rejected ChatGPT fetch request from non-OpenAI MCP client")
return _unsupported_fetch_client_response(id)

logger.info(f"ChatGPT fetch request: id='{id}'")

try:
Expand All @@ -206,7 +250,7 @@ async def fetch(
logger.info(f"Fetch completed: id='{id}', content_length={len(document.get('text', ''))}")

# Return in MCP content array format as required by OpenAI
return [{"type": "text", "text": json.dumps(document, ensure_ascii=False)}]
return _text_content(document)

except Exception as e:
logger.error(f"ChatGPT fetch failed for id '{id}': {e}")
Expand All @@ -217,4 +261,4 @@ async def fetch(
"url": id,
"metadata": {"error": "Fetch failed"},
}
return [{"type": "text", "text": json.dumps(error_document, ensure_ascii=False)}]
return _text_content(error_document)
58 changes: 48 additions & 10 deletions test-int/mcp/test_chatgpt_tools_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,14 @@
import json
import pytest
from fastmcp import Client
from mcp import types as mt


def openai_mcp_client(mcp_server):
return Client(
mcp_server,
client_info=mt.Implementation(name="openai-mcp", version="1.0.0"),
)


def extract_mcp_json_content(mcp_result):
Expand All @@ -29,7 +37,7 @@ def extract_mcp_json_content(mcp_result):
async def test_chatgpt_search_basic(mcp_server, app, test_project):
"""Test basic ChatGPT search functionality with MCP content array format."""

async with Client(mcp_server) as client:
async with openai_mcp_client(mcp_server) as client:
# Create test notes for searching
await client.call_tool(
"write_note",
Expand Down Expand Up @@ -99,7 +107,7 @@ async def test_chatgpt_search_basic(mcp_server, app, test_project):
async def test_chatgpt_search_empty_results(mcp_server, app, test_project):
"""Test ChatGPT search with no matching results."""

async with Client(mcp_server) as client:
async with openai_mcp_client(mcp_server) as client:
# Search for non-existent content
search_result = await client.call_tool(
"search",
Expand All @@ -119,7 +127,7 @@ async def test_chatgpt_search_empty_results(mcp_server, app, test_project):
async def test_chatgpt_search_with_boolean_operators(mcp_server, app, test_project):
"""Test ChatGPT search with boolean operators."""

async with Client(mcp_server) as client:
async with openai_mcp_client(mcp_server) as client:
# Create test notes
await client.call_tool(
"write_note",
Expand Down Expand Up @@ -164,7 +172,7 @@ async def test_chatgpt_search_with_boolean_operators(mcp_server, app, test_proje
async def test_chatgpt_fetch_document(mcp_server, app, test_project):
"""Test ChatGPT fetch tool for retrieving full document content."""

async with Client(mcp_server) as client:
async with openai_mcp_client(mcp_server) as client:
# Create a test note
note_content = """# Advanced Python Techniques

Expand Down Expand Up @@ -224,7 +232,7 @@ def wrapper(*args, **kwargs):
async def test_chatgpt_fetch_by_permalink(mcp_server, app, test_project):
"""Test ChatGPT fetch using permalink identifier."""

async with Client(mcp_server) as client:
async with openai_mcp_client(mcp_server) as client:
# Create a note with known content
await client.call_tool(
"write_note",
Expand Down Expand Up @@ -268,7 +276,7 @@ async def test_chatgpt_fetch_by_permalink(mcp_server, app, test_project):
async def test_chatgpt_fetch_nonexistent_document(mcp_server, app, test_project):
"""Test ChatGPT fetch with non-existent document ID."""

async with Client(mcp_server) as client:
async with openai_mcp_client(mcp_server) as client:
# Try to fetch a non-existent document
fetch_result = await client.call_tool(
"fetch",
Expand All @@ -294,7 +302,7 @@ async def test_chatgpt_fetch_nonexistent_document(mcp_server, app, test_project)
async def test_chatgpt_fetch_with_empty_title(mcp_server, app, test_project):
"""Test ChatGPT fetch handles documents with empty or missing titles."""

async with Client(mcp_server) as client:
async with openai_mcp_client(mcp_server) as client:
# Create a note without a title in the content
await client.call_tool(
"write_note",
Expand Down Expand Up @@ -329,7 +337,7 @@ async def test_chatgpt_fetch_with_empty_title(mcp_server, app, test_project):
async def test_chatgpt_search_pagination_default(mcp_server, app, test_project):
"""Test that ChatGPT search uses reasonable pagination defaults."""

async with Client(mcp_server) as client:
async with openai_mcp_client(mcp_server) as client:
# Create more than 10 notes to test pagination
for i in range(15):
await client.call_tool(
Expand Down Expand Up @@ -362,7 +370,7 @@ async def test_chatgpt_search_pagination_default(mcp_server, app, test_project):
async def test_chatgpt_tools_error_handling(mcp_server, app, test_project):
"""Test error handling in ChatGPT tools returns proper MCP format."""

async with Client(mcp_server) as client:
async with openai_mcp_client(mcp_server) as client:
# Test search with invalid query (if validation exists)
# Using empty query to potentially trigger an error
search_result = await client.call_tool(
Expand All @@ -384,11 +392,41 @@ async def test_chatgpt_tools_error_handling(mcp_server, app, test_project):
assert "results" in results_json # Should have results key even if empty


@pytest.mark.asyncio
async def test_chatgpt_tools_reject_default_client(mcp_server, app, test_project):
"""Default FastMCP clients cannot use ChatGPT compatibility tools."""

async with Client(mcp_server) as client:
search_result = await client.call_tool(
"search",
{
"query": "Machine Learning",
},
)

search_json = extract_mcp_json_content(search_result)
assert search_json["results"] == []
assert search_json["error"] == "Unsupported MCP client"
assert "search_notes" in search_json["error_message"]

fetch_result = await client.call_tool(
"fetch",
{
"id": "Machine Learning Fundamentals",
},
)

document_json = extract_mcp_json_content(fetch_result)
assert document_json["id"] == "Machine Learning Fundamentals"
assert document_json["metadata"]["error"] == "Unsupported MCP client"
assert "read_note" in document_json["text"]


@pytest.mark.asyncio
async def test_chatgpt_integration_workflow(mcp_server, app, test_project):
"""Test complete workflow: search then fetch, as ChatGPT would use it."""

async with Client(mcp_server) as client:
async with openai_mcp_client(mcp_server) as client:
# Step 1: Create multiple documents
docs = [
{
Expand Down
Loading
Loading