From 232fe54f14fbbd5e8987fe5f3af68787ddca2a74 Mon Sep 17 00:00:00 2001 From: Pringled Date: Thu, 21 May 2026 09:06:24 +0200 Subject: [PATCH 1/6] Make model download non-blocking on serve --- src/semble/mcp.py | 108 +++++++++++++++++++++++++++++++++------------- tests/test_mcp.py | 64 +++++++++++++++++++++++++++ 2 files changed, 142 insertions(+), 30 deletions(-) diff --git a/src/semble/mcp.py b/src/semble/mcp.py index a9c533d..e41855b 100644 --- a/src/semble/mcp.py +++ b/src/semble/mcp.py @@ -113,28 +113,71 @@ async def find_related( async def serve(path: str | None = None, ref: str | None = None, include_text_files: bool = False) -> None: - """Start an MCP stdio server, optionally pre-indexing a default source.""" - model = await asyncio.to_thread(load_model) - cache = _IndexCache(model=model, include_text_files=include_text_files) - if path: - await cache.get(path, ref=ref) - if not _is_git_url(path): - await cache.start_watcher(path) + """Start an MCP stdio server, optionally pre-indexing a default source. + The MCP transport opens before the embedding model finishes loading, so the + ``initialize`` / ``tools/list`` handshake responds even on a cold HF cache. + Tool calls await the model implicitly via :class:`_IndexCache`. + """ + cache = _IndexCache(include_text_files=include_text_files) + + async def _load_and_prewarm() -> None: + try: + model = await asyncio.to_thread(load_model) + except Exception as exc: + logger.exception("Failed to load embedding model") + cache.set_model_error(exc) + return + cache.set_model(model) + if path: + try: + await cache.get(path, ref=ref) + except Exception: + logger.warning("Failed to pre-index %r at startup", path, exc_info=True) + if not _is_git_url(path): + await cache.start_watcher(path) + + init_task = asyncio.create_task(_load_and_prewarm()) server = create_server(cache, default_source=path) - await server.run_stdio_async() + try: + await server.run_stdio_async() + finally: + if not init_task.done(): + init_task.cancel() class _IndexCache: """Cache of indexed repos and local paths for the lifetime of the MCP server process.""" - def __init__(self, model: Encoder, include_text_files: bool = False) -> None: - """Initialise an empty cache with a shared embedding model.""" - self._model = model + def __init__(self, model: Encoder | None = None, include_text_files: bool = False) -> None: + """Initialise an empty cache; ``model`` may be supplied later via :meth:`set_model`.""" + self._model: Encoder | None = model + self._model_error: BaseException | None = None + self._model_ready = asyncio.Event() + if model is not None: + self._model_ready.set() self._include_text_files = include_text_files self._tasks: OrderedDict[str, asyncio.Task[SembleIndex]] = OrderedDict() # ordered for LRU eviction self._watcher_task: asyncio.Task[None] | None = None + def set_model(self, model: Encoder) -> None: + """Install the embedding model and unblock any tool calls awaiting it.""" + self._model = model + self._model_ready.set() + + def set_model_error(self, exc: BaseException) -> None: + """Mark model loading as failed; awaiting tool calls will raise ``exc``.""" + self._model_error = exc + self._model_ready.set() + + async def _await_model(self) -> Encoder: + """Block until the model is installed; re-raise the load error if it failed.""" + await self._model_ready.wait() + if self._model_error is not None: + raise self._model_error + assert self._model is not None + return self._model + def _compute_cache_key(self, source: str, ref: str | None = None) -> str: """Compute the canonical cache key for a source.""" is_git = _is_git_url(source) @@ -163,27 +206,32 @@ async def get(self, source: str, ref: str | None = None) -> SembleIndex: """Return an index for the requested source, building and caching it on first access.""" cache_key = self._compute_cache_key(source, ref) - if cache_key in self._tasks: - self._tasks.move_to_end(cache_key) - else: - if len(self._tasks) >= _CACHE_MAX_SIZE: - self._tasks.popitem(last=False) - if _is_git_url(source): - self._tasks[cache_key] = asyncio.create_task( - asyncio.to_thread( - SembleIndex.from_git, - source, - ref=ref, - model=self._model, - include_text_files=self._include_text_files, + if cache_key not in self._tasks: + model = await self._await_model() + # Re-check after the await: another caller may have populated the entry. + if cache_key not in self._tasks: + if len(self._tasks) >= _CACHE_MAX_SIZE: + self._tasks.popitem(last=False) + if _is_git_url(source): + self._tasks[cache_key] = asyncio.create_task( + asyncio.to_thread( + SembleIndex.from_git, + source, + ref=ref, + model=model, + include_text_files=self._include_text_files, + ) ) - ) - else: - self._tasks[cache_key] = asyncio.create_task( - asyncio.to_thread( - SembleIndex.from_path, cache_key, model=self._model, include_text_files=self._include_text_files + else: + self._tasks[cache_key] = asyncio.create_task( + asyncio.to_thread( + SembleIndex.from_path, + cache_key, + model=model, + include_text_files=self._include_text_files, + ) ) - ) + self._tasks.move_to_end(cache_key) task = self._tasks[cache_key] try: return await asyncio.shield(task) diff --git a/tests/test_mcp.py b/tests/test_mcp.py index ef4dcce..aa3d7d0 100644 --- a/tests/test_mcp.py +++ b/tests/test_mcp.py @@ -1,3 +1,4 @@ +import asyncio from pathlib import Path from typing import Any, AsyncGenerator from unittest.mock import AsyncMock, MagicMock, patch @@ -254,6 +255,69 @@ async def test_serve_runs_stdio(tmp_path: Path, with_path: bool) -> None: mock_run.assert_called_once() +@pytest.mark.anyio +async def test_serve_opens_stdio_before_model_loads() -> None: + """The MCP transport must open before load_model() completes. + + Regression test for slow-network startup: stdio is what carries the + initialize/tools/list handshake, so blocking on a cold HF model download + causes the MCP client to time out before tools are registered. + """ + import time + + stdio_opened = asyncio.Event() + + def blocking_load_model() -> Encoder: + # Spin until stdio has opened. If serve() blocks on us first, this never + # observes the event and the test fails fast. + deadline = time.monotonic() + 1.0 + while time.monotonic() < deadline: + if stdio_opened.is_set(): + return MagicMock(spec=Encoder) + time.sleep(0.005) + raise AssertionError("stdio did not open while load_model was in flight") + + async def fake_run_stdio() -> None: + stdio_opened.set() + # Yield briefly so the background init task can complete. + await asyncio.sleep(0.05) + + with ( + patch("semble.mcp.load_model", side_effect=blocking_load_model), + patch("mcp.server.fastmcp.FastMCP.run_stdio_async", side_effect=fake_run_stdio), + ): + await serve() + + assert stdio_opened.is_set() + + +@pytest.mark.anyio +async def test_index_cache_awaits_model(tmp_path: Path) -> None: + """get() blocks until set_model() is called, then proceeds.""" + cache = _IndexCache() # no model yet + fake_index = MagicMock() + with patch("semble.mcp.SembleIndex.from_path", return_value=fake_index): + get_task = asyncio.create_task(cache.get(str(tmp_path))) + # Yield so get_task runs up to the _await_model() point. + await asyncio.sleep(0.01) + assert not get_task.done(), "get() must block until the model is installed" + cache.set_model(MagicMock(spec=Encoder)) + result = await asyncio.wait_for(get_task, timeout=1.0) + assert result is fake_index + + +@pytest.mark.anyio +async def test_index_cache_propagates_model_error(tmp_path: Path) -> None: + """If model load fails, awaiting tool calls re-raise the original exception.""" + cache = _IndexCache() + get_task = asyncio.create_task(cache.get(str(tmp_path))) + await asyncio.sleep(0.01) + assert not get_task.done() + cache.set_model_error(RuntimeError("HF download failed")) + with pytest.raises(RuntimeError, match="HF download failed"): + await asyncio.wait_for(get_task, timeout=1.0) + + @pytest.mark.anyio @pytest.mark.parametrize( ("repo", "tool", "extra_args"), From 5cd3a2f0aba0561b0916e41db1860c8358a76ef9 Mon Sep 17 00:00:00 2001 From: Pringled Date: Thu, 21 May 2026 09:10:47 +0200 Subject: [PATCH 2/6] Simplify code --- src/semble/mcp.py | 27 +++++++-------------------- tests/test_mcp.py | 34 ++++++++++------------------------ uv.lock | 2 +- 3 files changed, 18 insertions(+), 45 deletions(-) diff --git a/src/semble/mcp.py b/src/semble/mcp.py index e41855b..78f1089 100644 --- a/src/semble/mcp.py +++ b/src/semble/mcp.py @@ -113,22 +113,19 @@ async def find_related( async def serve(path: str | None = None, ref: str | None = None, include_text_files: bool = False) -> None: - """Start an MCP stdio server, optionally pre-indexing a default source. - - The MCP transport opens before the embedding model finishes loading, so the - ``initialize`` / ``tools/list`` handshake responds even on a cold HF cache. - Tool calls await the model implicitly via :class:`_IndexCache`. - """ + """Start an MCP stdio server, optionally pre-indexing a default source.""" cache = _IndexCache(include_text_files=include_text_files) + # Pre-load the model and optionally pre-index the default source in parallel with starting the server. async def _load_and_prewarm() -> None: try: - model = await asyncio.to_thread(load_model) + cache._model = await asyncio.to_thread(load_model) except Exception as exc: logger.exception("Failed to load embedding model") - cache.set_model_error(exc) + cache._model_error = exc return - cache.set_model(model) + finally: + cache._model_ready.set() if path: try: await cache.get(path, ref=ref) @@ -150,7 +147,7 @@ class _IndexCache: """Cache of indexed repos and local paths for the lifetime of the MCP server process.""" def __init__(self, model: Encoder | None = None, include_text_files: bool = False) -> None: - """Initialise an empty cache; ``model`` may be supplied later via :meth:`set_model`.""" + """Initialise an empty cache.""" self._model: Encoder | None = model self._model_error: BaseException | None = None self._model_ready = asyncio.Event() @@ -160,16 +157,6 @@ def __init__(self, model: Encoder | None = None, include_text_files: bool = Fals self._tasks: OrderedDict[str, asyncio.Task[SembleIndex]] = OrderedDict() # ordered for LRU eviction self._watcher_task: asyncio.Task[None] | None = None - def set_model(self, model: Encoder) -> None: - """Install the embedding model and unblock any tool calls awaiting it.""" - self._model = model - self._model_ready.set() - - def set_model_error(self, exc: BaseException) -> None: - """Mark model loading as failed; awaiting tool calls will raise ``exc``.""" - self._model_error = exc - self._model_ready.set() - async def _await_model(self) -> Encoder: """Block until the model is installed; re-raise the load error if it failed.""" await self._model_ready.wait() diff --git a/tests/test_mcp.py b/tests/test_mcp.py index aa3d7d0..f5599fa 100644 --- a/tests/test_mcp.py +++ b/tests/test_mcp.py @@ -1,4 +1,5 @@ import asyncio +import threading from pathlib import Path from typing import Any, AsyncGenerator from unittest.mock import AsyncMock, MagicMock, patch @@ -257,29 +258,15 @@ async def test_serve_runs_stdio(tmp_path: Path, with_path: bool) -> None: @pytest.mark.anyio async def test_serve_opens_stdio_before_model_loads() -> None: - """The MCP transport must open before load_model() completes. - - Regression test for slow-network startup: stdio is what carries the - initialize/tools/list handshake, so blocking on a cold HF model download - causes the MCP client to time out before tools are registered. - """ - import time - - stdio_opened = asyncio.Event() + """Stdio must open before load_model() finishes — regression for #133.""" + stdio_opened = threading.Event() def blocking_load_model() -> Encoder: - # Spin until stdio has opened. If serve() blocks on us first, this never - # observes the event and the test fails fast. - deadline = time.monotonic() + 1.0 - while time.monotonic() < deadline: - if stdio_opened.is_set(): - return MagicMock(spec=Encoder) - time.sleep(0.005) - raise AssertionError("stdio did not open while load_model was in flight") + assert stdio_opened.wait(timeout=1.0), "stdio did not open" + return MagicMock(spec=Encoder) async def fake_run_stdio() -> None: stdio_opened.set() - # Yield briefly so the background init task can complete. await asyncio.sleep(0.05) with ( @@ -288,20 +275,18 @@ async def fake_run_stdio() -> None: ): await serve() - assert stdio_opened.is_set() - @pytest.mark.anyio async def test_index_cache_awaits_model(tmp_path: Path) -> None: - """get() blocks until set_model() is called, then proceeds.""" + """get() blocks until the model is installed, then proceeds.""" cache = _IndexCache() # no model yet fake_index = MagicMock() with patch("semble.mcp.SembleIndex.from_path", return_value=fake_index): get_task = asyncio.create_task(cache.get(str(tmp_path))) - # Yield so get_task runs up to the _await_model() point. await asyncio.sleep(0.01) assert not get_task.done(), "get() must block until the model is installed" - cache.set_model(MagicMock(spec=Encoder)) + cache._model = MagicMock(spec=Encoder) + cache._model_ready.set() result = await asyncio.wait_for(get_task, timeout=1.0) assert result is fake_index @@ -313,7 +298,8 @@ async def test_index_cache_propagates_model_error(tmp_path: Path) -> None: get_task = asyncio.create_task(cache.get(str(tmp_path))) await asyncio.sleep(0.01) assert not get_task.done() - cache.set_model_error(RuntimeError("HF download failed")) + cache._model_error = RuntimeError("HF download failed") + cache._model_ready.set() with pytest.raises(RuntimeError, match="HF download failed"): await asyncio.wait_for(get_task, timeout=1.0) diff --git a/uv.lock b/uv.lock index 9ce6273..a45dbf4 100644 --- a/uv.lock +++ b/uv.lock @@ -3171,7 +3171,7 @@ requires-dist = [ { name = "ruff", marker = "extra == 'dev'", specifier = ">=0.9.0" }, { name = "sentence-transformers", marker = "extra == 'benchmark'", specifier = ">=3.0" }, { name = "tiktoken", marker = "extra == 'benchmark'", specifier = ">=0.7" }, - { name = "tree-sitter", specifier = ">=0.25" }, + { name = "tree-sitter", specifier = ">=0.25,<0.26" }, { name = "tree-sitter-language-pack", specifier = ">=1.0,!=1.6.3,<1.8.0" }, { name = "vicinity", specifier = ">=0.4.4" }, { name = "watchfiles", marker = "extra == 'mcp'", specifier = ">=0.21" }, From fa8836cebb3f2092e580ea999546e8e96741cdcf Mon Sep 17 00:00:00 2001 From: Pringled Date: Thu, 21 May 2026 09:19:06 +0200 Subject: [PATCH 3/6] Update docstring --- src/semble/mcp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/semble/mcp.py b/src/semble/mcp.py index 78f1089..5993c6c 100644 --- a/src/semble/mcp.py +++ b/src/semble/mcp.py @@ -116,8 +116,8 @@ async def serve(path: str | None = None, ref: str | None = None, include_text_fi """Start an MCP stdio server, optionally pre-indexing a default source.""" cache = _IndexCache(include_text_files=include_text_files) - # Pre-load the model and optionally pre-index the default source in parallel with starting the server. async def _load_and_prewarm() -> None: + """Pre-load the model and optionally pre-index the default source in parallel with starting the server.""" try: cache._model = await asyncio.to_thread(load_model) except Exception as exc: From 91efdd8dce1768ab946ded861a395c01b10fbfe0 Mon Sep 17 00:00:00 2001 From: Pringled Date: Thu, 21 May 2026 09:27:12 +0200 Subject: [PATCH 4/6] Add more tests --- tests/test_mcp.py | 35 +++++++++++++++++++++++++++++------ 1 file changed, 29 insertions(+), 6 deletions(-) diff --git a/tests/test_mcp.py b/tests/test_mcp.py index f5599fa..3f2f5c7 100644 --- a/tests/test_mcp.py +++ b/tests/test_mcp.py @@ -242,14 +242,37 @@ async def test_tool_output( @pytest.mark.anyio -@pytest.mark.parametrize("with_path", [True, False], ids=["pre_index", "no_path"]) -async def test_serve_runs_stdio(tmp_path: Path, with_path: bool) -> None: - """serve() loads the model, runs stdio, and optionally pre-indexes when a path is given.""" +@pytest.mark.parametrize( + ("with_path", "load_err", "from_path_err", "stdio_yields"), + [ + (True, None, None, True), + (False, None, None, True), + (False, RuntimeError("boom"), None, True), + (True, None, RuntimeError("boom"), True), + (False, None, None, False), + ], + ids=["pre_index", "no_path", "model_load_fails", "prewarm_fails", "cancel_pending_init"], +) +async def test_serve_runs_stdio( + tmp_path: Path, + with_path: bool, + load_err: Exception | None, + from_path_err: Exception | None, + stdio_yields: bool, +) -> None: + """serve() runs stdio and handles all background init outcomes without raising.""" + + async def fake_stdio() -> None: + if stdio_yields: + await asyncio.sleep(0.05) # let the background init task run + + load_kwargs = {"side_effect": load_err} if load_err else {"return_value": MagicMock(spec=Encoder)} + fp_kwargs = {"side_effect": from_path_err} if from_path_err else {"return_value": MagicMock()} with ( - patch("semble.mcp.load_model", return_value=MagicMock(spec=Encoder)), - patch("semble.mcp.SembleIndex.from_path", return_value=MagicMock()), + patch("semble.mcp.load_model", **load_kwargs), + patch("semble.mcp.SembleIndex.from_path", **fp_kwargs), patch.object(_IndexCache, "start_watcher", new_callable=AsyncMock), - patch("mcp.server.fastmcp.FastMCP.run_stdio_async", new_callable=AsyncMock) as mock_run, + patch("mcp.server.fastmcp.FastMCP.run_stdio_async", side_effect=fake_stdio) as mock_run, ): await (serve(str(tmp_path)) if with_path else serve()) From 9524101930499b12a7dff026dbf6c41ec6bd9213 Mon Sep 17 00:00:00 2001 From: Pringled Date: Thu, 21 May 2026 09:28:04 +0200 Subject: [PATCH 5/6] Update docstring --- tests/test_mcp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_mcp.py b/tests/test_mcp.py index 3f2f5c7..57dd8b7 100644 --- a/tests/test_mcp.py +++ b/tests/test_mcp.py @@ -281,7 +281,7 @@ async def fake_stdio() -> None: @pytest.mark.anyio async def test_serve_opens_stdio_before_model_loads() -> None: - """Stdio must open before load_model() finishes — regression for #133.""" + """Stdio must open before load_model() finishes.""" stdio_opened = threading.Event() def blocking_load_model() -> Encoder: From 4069cf0f485bf050f2bc62c1c3585e3f0a398402 Mon Sep 17 00:00:00 2001 From: Pringled Date: Thu, 21 May 2026 09:52:45 +0200 Subject: [PATCH 6/6] Set force_download to False --- src/semble/index/dense.py | 2 +- tests/test_search.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/semble/index/dense.py b/src/semble/index/dense.py index 2427fee..fcec51a 100644 --- a/src/semble/index/dense.py +++ b/src/semble/index/dense.py @@ -20,7 +20,7 @@ def load_model(model_path: str | None = None) -> Encoder: # Disable HF progress bars since the model is loaded silently in the background during indexing. disable_progress_bars() try: - model = StaticModel.from_pretrained(model_path) + model = StaticModel.from_pretrained(model_path, force_download=False) finally: disable_progress_bars() return cast(Encoder, model) diff --git a/tests/test_search.py b/tests/test_search.py index 56bd2f1..2f40fa6 100644 --- a/tests/test_search.py +++ b/tests/test_search.py @@ -142,7 +142,7 @@ def test_load_model(model_path: str | None, expected_call_arg: str) -> None: fake_model = MagicMock(spec=Encoder) with patch("semble.index.dense.StaticModel.from_pretrained", return_value=fake_model) as mock_fp: result = load_model(model_path) - mock_fp.assert_called_once_with(expected_call_arg) + mock_fp.assert_called_once_with(expected_call_arg, force_download=False) assert result is fake_model