diff --git a/src/semble/index/dense.py b/src/semble/index/dense.py index 2427fee..fcec51a 100644 --- a/src/semble/index/dense.py +++ b/src/semble/index/dense.py @@ -20,7 +20,7 @@ def load_model(model_path: str | None = None) -> Encoder: # Disable HF progress bars since the model is loaded silently in the background during indexing. disable_progress_bars() try: - model = StaticModel.from_pretrained(model_path) + model = StaticModel.from_pretrained(model_path, force_download=False) finally: disable_progress_bars() return cast(Encoder, model) diff --git a/src/semble/mcp.py b/src/semble/mcp.py index a9c533d..5993c6c 100644 --- a/src/semble/mcp.py +++ b/src/semble/mcp.py @@ -114,27 +114,57 @@ async def find_related( async def serve(path: str | None = None, ref: str | None = None, include_text_files: bool = False) -> None: """Start an MCP stdio server, optionally pre-indexing a default source.""" - model = await asyncio.to_thread(load_model) - cache = _IndexCache(model=model, include_text_files=include_text_files) - if path: - await cache.get(path, ref=ref) - if not _is_git_url(path): - await cache.start_watcher(path) + cache = _IndexCache(include_text_files=include_text_files) + async def _load_and_prewarm() -> None: + """Pre-load the model and optionally pre-index the default source in parallel with starting the server.""" + try: + cache._model = await asyncio.to_thread(load_model) + except Exception as exc: + logger.exception("Failed to load embedding model") + cache._model_error = exc + return + finally: + cache._model_ready.set() + if path: + try: + await cache.get(path, ref=ref) + except Exception: + logger.warning("Failed to pre-index %r at startup", path, exc_info=True) + if not _is_git_url(path): + await cache.start_watcher(path) + + init_task = asyncio.create_task(_load_and_prewarm()) server = create_server(cache, default_source=path) - await server.run_stdio_async() + try: + await server.run_stdio_async() + finally: + if not init_task.done(): + init_task.cancel() class _IndexCache: """Cache of indexed repos and local paths for the lifetime of the MCP server process.""" - def __init__(self, model: Encoder, include_text_files: bool = False) -> None: - """Initialise an empty cache with a shared embedding model.""" - self._model = model + def __init__(self, model: Encoder | None = None, include_text_files: bool = False) -> None: + """Initialise an empty cache.""" + self._model: Encoder | None = model + self._model_error: BaseException | None = None + self._model_ready = asyncio.Event() + if model is not None: + self._model_ready.set() self._include_text_files = include_text_files self._tasks: OrderedDict[str, asyncio.Task[SembleIndex]] = OrderedDict() # ordered for LRU eviction self._watcher_task: asyncio.Task[None] | None = None + async def _await_model(self) -> Encoder: + """Block until the model is installed; re-raise the load error if it failed.""" + await self._model_ready.wait() + if self._model_error is not None: + raise self._model_error + assert self._model is not None + return self._model + def _compute_cache_key(self, source: str, ref: str | None = None) -> str: """Compute the canonical cache key for a source.""" is_git = _is_git_url(source) @@ -163,27 +193,32 @@ async def get(self, source: str, ref: str | None = None) -> SembleIndex: """Return an index for the requested source, building and caching it on first access.""" cache_key = self._compute_cache_key(source, ref) - if cache_key in self._tasks: - self._tasks.move_to_end(cache_key) - else: - if len(self._tasks) >= _CACHE_MAX_SIZE: - self._tasks.popitem(last=False) - if _is_git_url(source): - self._tasks[cache_key] = asyncio.create_task( - asyncio.to_thread( - SembleIndex.from_git, - source, - ref=ref, - model=self._model, - include_text_files=self._include_text_files, + if cache_key not in self._tasks: + model = await self._await_model() + # Re-check after the await: another caller may have populated the entry. + if cache_key not in self._tasks: + if len(self._tasks) >= _CACHE_MAX_SIZE: + self._tasks.popitem(last=False) + if _is_git_url(source): + self._tasks[cache_key] = asyncio.create_task( + asyncio.to_thread( + SembleIndex.from_git, + source, + ref=ref, + model=model, + include_text_files=self._include_text_files, + ) ) - ) - else: - self._tasks[cache_key] = asyncio.create_task( - asyncio.to_thread( - SembleIndex.from_path, cache_key, model=self._model, include_text_files=self._include_text_files + else: + self._tasks[cache_key] = asyncio.create_task( + asyncio.to_thread( + SembleIndex.from_path, + cache_key, + model=model, + include_text_files=self._include_text_files, + ) ) - ) + self._tasks.move_to_end(cache_key) task = self._tasks[cache_key] try: return await asyncio.shield(task) diff --git a/tests/test_mcp.py b/tests/test_mcp.py index ef4dcce..57dd8b7 100644 --- a/tests/test_mcp.py +++ b/tests/test_mcp.py @@ -1,3 +1,5 @@ +import asyncio +import threading from pathlib import Path from typing import Any, AsyncGenerator from unittest.mock import AsyncMock, MagicMock, patch @@ -240,20 +242,91 @@ async def test_tool_output( @pytest.mark.anyio -@pytest.mark.parametrize("with_path", [True, False], ids=["pre_index", "no_path"]) -async def test_serve_runs_stdio(tmp_path: Path, with_path: bool) -> None: - """serve() loads the model, runs stdio, and optionally pre-indexes when a path is given.""" +@pytest.mark.parametrize( + ("with_path", "load_err", "from_path_err", "stdio_yields"), + [ + (True, None, None, True), + (False, None, None, True), + (False, RuntimeError("boom"), None, True), + (True, None, RuntimeError("boom"), True), + (False, None, None, False), + ], + ids=["pre_index", "no_path", "model_load_fails", "prewarm_fails", "cancel_pending_init"], +) +async def test_serve_runs_stdio( + tmp_path: Path, + with_path: bool, + load_err: Exception | None, + from_path_err: Exception | None, + stdio_yields: bool, +) -> None: + """serve() runs stdio and handles all background init outcomes without raising.""" + + async def fake_stdio() -> None: + if stdio_yields: + await asyncio.sleep(0.05) # let the background init task run + + load_kwargs = {"side_effect": load_err} if load_err else {"return_value": MagicMock(spec=Encoder)} + fp_kwargs = {"side_effect": from_path_err} if from_path_err else {"return_value": MagicMock()} with ( - patch("semble.mcp.load_model", return_value=MagicMock(spec=Encoder)), - patch("semble.mcp.SembleIndex.from_path", return_value=MagicMock()), + patch("semble.mcp.load_model", **load_kwargs), + patch("semble.mcp.SembleIndex.from_path", **fp_kwargs), patch.object(_IndexCache, "start_watcher", new_callable=AsyncMock), - patch("mcp.server.fastmcp.FastMCP.run_stdio_async", new_callable=AsyncMock) as mock_run, + patch("mcp.server.fastmcp.FastMCP.run_stdio_async", side_effect=fake_stdio) as mock_run, ): await (serve(str(tmp_path)) if with_path else serve()) mock_run.assert_called_once() +@pytest.mark.anyio +async def test_serve_opens_stdio_before_model_loads() -> None: + """Stdio must open before load_model() finishes.""" + stdio_opened = threading.Event() + + def blocking_load_model() -> Encoder: + assert stdio_opened.wait(timeout=1.0), "stdio did not open" + return MagicMock(spec=Encoder) + + async def fake_run_stdio() -> None: + stdio_opened.set() + await asyncio.sleep(0.05) + + with ( + patch("semble.mcp.load_model", side_effect=blocking_load_model), + patch("mcp.server.fastmcp.FastMCP.run_stdio_async", side_effect=fake_run_stdio), + ): + await serve() + + +@pytest.mark.anyio +async def test_index_cache_awaits_model(tmp_path: Path) -> None: + """get() blocks until the model is installed, then proceeds.""" + cache = _IndexCache() # no model yet + fake_index = MagicMock() + with patch("semble.mcp.SembleIndex.from_path", return_value=fake_index): + get_task = asyncio.create_task(cache.get(str(tmp_path))) + await asyncio.sleep(0.01) + assert not get_task.done(), "get() must block until the model is installed" + cache._model = MagicMock(spec=Encoder) + cache._model_ready.set() + result = await asyncio.wait_for(get_task, timeout=1.0) + assert result is fake_index + + +@pytest.mark.anyio +async def test_index_cache_propagates_model_error(tmp_path: Path) -> None: + """If model load fails, awaiting tool calls re-raise the original exception.""" + cache = _IndexCache() + get_task = asyncio.create_task(cache.get(str(tmp_path))) + await asyncio.sleep(0.01) + assert not get_task.done() + cache._model_error = RuntimeError("HF download failed") + cache._model_ready.set() + with pytest.raises(RuntimeError, match="HF download failed"): + await asyncio.wait_for(get_task, timeout=1.0) + + @pytest.mark.anyio @pytest.mark.parametrize( ("repo", "tool", "extra_args"), diff --git a/tests/test_search.py b/tests/test_search.py index 56bd2f1..2f40fa6 100644 --- a/tests/test_search.py +++ b/tests/test_search.py @@ -142,7 +142,7 @@ def test_load_model(model_path: str | None, expected_call_arg: str) -> None: fake_model = MagicMock(spec=Encoder) with patch("semble.index.dense.StaticModel.from_pretrained", return_value=fake_model) as mock_fp: result = load_model(model_path) - mock_fp.assert_called_once_with(expected_call_arg) + mock_fp.assert_called_once_with(expected_call_arg, force_download=False) assert result is fake_model diff --git a/uv.lock b/uv.lock index 9ce6273..a45dbf4 100644 --- a/uv.lock +++ b/uv.lock @@ -3171,7 +3171,7 @@ requires-dist = [ { name = "ruff", marker = "extra == 'dev'", specifier = ">=0.9.0" }, { name = "sentence-transformers", marker = "extra == 'benchmark'", specifier = ">=3.0" }, { name = "tiktoken", marker = "extra == 'benchmark'", specifier = ">=0.7" }, - { name = "tree-sitter", specifier = ">=0.25" }, + { name = "tree-sitter", specifier = ">=0.25,<0.26" }, { name = "tree-sitter-language-pack", specifier = ">=1.0,!=1.6.3,<1.8.0" }, { name = "vicinity", specifier = ">=0.4.4" }, { name = "watchfiles", marker = "extra == 'mcp'", specifier = ">=0.21" },