Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/semble/index/dense.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def load_model(model_path: str | None = None) -> Encoder:
# Disable HF progress bars since the model is loaded silently in the background during indexing.
disable_progress_bars()
try:
model = StaticModel.from_pretrained(model_path)
model = StaticModel.from_pretrained(model_path, force_download=False)
finally:
disable_progress_bars()
return cast(Encoder, model)
Expand Down
93 changes: 64 additions & 29 deletions src/semble/mcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,27 +114,57 @@ async def find_related(

async def serve(path: str | None = None, ref: str | None = None, include_text_files: bool = False) -> None:
"""Start an MCP stdio server, optionally pre-indexing a default source."""
model = await asyncio.to_thread(load_model)
cache = _IndexCache(model=model, include_text_files=include_text_files)
if path:
await cache.get(path, ref=ref)
if not _is_git_url(path):
await cache.start_watcher(path)
cache = _IndexCache(include_text_files=include_text_files)
Comment thread
Pringled marked this conversation as resolved.

async def _load_and_prewarm() -> None:
"""Pre-load the model and optionally pre-index the default source in parallel with starting the server."""
try:
cache._model = await asyncio.to_thread(load_model)
except Exception as exc:
logger.exception("Failed to load embedding model")
cache._model_error = exc
return
finally:
cache._model_ready.set()
if path:
try:
await cache.get(path, ref=ref)
except Exception:
logger.warning("Failed to pre-index %r at startup", path, exc_info=True)
if not _is_git_url(path):
await cache.start_watcher(path)

init_task = asyncio.create_task(_load_and_prewarm())
server = create_server(cache, default_source=path)
await server.run_stdio_async()
try:
await server.run_stdio_async()
finally:
if not init_task.done():
init_task.cancel()


class _IndexCache:
"""Cache of indexed repos and local paths for the lifetime of the MCP server process."""

def __init__(self, model: Encoder, include_text_files: bool = False) -> None:
"""Initialise an empty cache with a shared embedding model."""
self._model = model
def __init__(self, model: Encoder | None = None, include_text_files: bool = False) -> None:
"""Initialise an empty cache."""
self._model: Encoder | None = model
self._model_error: BaseException | None = None
self._model_ready = asyncio.Event()
if model is not None:
self._model_ready.set()
Comment thread
Pringled marked this conversation as resolved.
self._include_text_files = include_text_files
self._tasks: OrderedDict[str, asyncio.Task[SembleIndex]] = OrderedDict() # ordered for LRU eviction
self._watcher_task: asyncio.Task[None] | None = None

async def _await_model(self) -> Encoder:
"""Block until the model is installed; re-raise the load error if it failed."""
await self._model_ready.wait()
if self._model_error is not None:
raise self._model_error
assert self._model is not None
return self._model

def _compute_cache_key(self, source: str, ref: str | None = None) -> str:
"""Compute the canonical cache key for a source."""
is_git = _is_git_url(source)
Expand Down Expand Up @@ -163,27 +193,32 @@ async def get(self, source: str, ref: str | None = None) -> SembleIndex:
"""Return an index for the requested source, building and caching it on first access."""
cache_key = self._compute_cache_key(source, ref)

if cache_key in self._tasks:
self._tasks.move_to_end(cache_key)
else:
if len(self._tasks) >= _CACHE_MAX_SIZE:
self._tasks.popitem(last=False)
if _is_git_url(source):
self._tasks[cache_key] = asyncio.create_task(
asyncio.to_thread(
SembleIndex.from_git,
source,
ref=ref,
model=self._model,
include_text_files=self._include_text_files,
if cache_key not in self._tasks:
model = await self._await_model()
# Re-check after the await: another caller may have populated the entry.
if cache_key not in self._tasks:
if len(self._tasks) >= _CACHE_MAX_SIZE:
self._tasks.popitem(last=False)
if _is_git_url(source):
self._tasks[cache_key] = asyncio.create_task(
asyncio.to_thread(
SembleIndex.from_git,
source,
ref=ref,
model=model,
include_text_files=self._include_text_files,
)
)
)
else:
self._tasks[cache_key] = asyncio.create_task(
asyncio.to_thread(
SembleIndex.from_path, cache_key, model=self._model, include_text_files=self._include_text_files
else:
self._tasks[cache_key] = asyncio.create_task(
asyncio.to_thread(
SembleIndex.from_path,
cache_key,
model=model,
include_text_files=self._include_text_files,
)
)
)
self._tasks.move_to_end(cache_key)
task = self._tasks[cache_key]
try:
return await asyncio.shield(task)
Expand Down
85 changes: 79 additions & 6 deletions tests/test_mcp.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import asyncio
import threading
from pathlib import Path
from typing import Any, AsyncGenerator
from unittest.mock import AsyncMock, MagicMock, patch
Expand Down Expand Up @@ -240,20 +242,91 @@ async def test_tool_output(


@pytest.mark.anyio
@pytest.mark.parametrize("with_path", [True, False], ids=["pre_index", "no_path"])
async def test_serve_runs_stdio(tmp_path: Path, with_path: bool) -> None:
"""serve() loads the model, runs stdio, and optionally pre-indexes when a path is given."""
@pytest.mark.parametrize(
("with_path", "load_err", "from_path_err", "stdio_yields"),
[
(True, None, None, True),
(False, None, None, True),
(False, RuntimeError("boom"), None, True),
(True, None, RuntimeError("boom"), True),
(False, None, None, False),
],
ids=["pre_index", "no_path", "model_load_fails", "prewarm_fails", "cancel_pending_init"],
)
async def test_serve_runs_stdio(
tmp_path: Path,
with_path: bool,
load_err: Exception | None,
from_path_err: Exception | None,
stdio_yields: bool,
) -> None:
"""serve() runs stdio and handles all background init outcomes without raising."""

async def fake_stdio() -> None:
if stdio_yields:
await asyncio.sleep(0.05) # let the background init task run

load_kwargs = {"side_effect": load_err} if load_err else {"return_value": MagicMock(spec=Encoder)}
fp_kwargs = {"side_effect": from_path_err} if from_path_err else {"return_value": MagicMock()}
with (
patch("semble.mcp.load_model", return_value=MagicMock(spec=Encoder)),
patch("semble.mcp.SembleIndex.from_path", return_value=MagicMock()),
patch("semble.mcp.load_model", **load_kwargs),
patch("semble.mcp.SembleIndex.from_path", **fp_kwargs),
patch.object(_IndexCache, "start_watcher", new_callable=AsyncMock),
patch("mcp.server.fastmcp.FastMCP.run_stdio_async", new_callable=AsyncMock) as mock_run,
patch("mcp.server.fastmcp.FastMCP.run_stdio_async", side_effect=fake_stdio) as mock_run,
):
await (serve(str(tmp_path)) if with_path else serve())

mock_run.assert_called_once()


@pytest.mark.anyio
async def test_serve_opens_stdio_before_model_loads() -> None:
"""Stdio must open before load_model() finishes."""
stdio_opened = threading.Event()

def blocking_load_model() -> Encoder:
assert stdio_opened.wait(timeout=1.0), "stdio did not open"
return MagicMock(spec=Encoder)

async def fake_run_stdio() -> None:
stdio_opened.set()
await asyncio.sleep(0.05)

with (
patch("semble.mcp.load_model", side_effect=blocking_load_model),
patch("mcp.server.fastmcp.FastMCP.run_stdio_async", side_effect=fake_run_stdio),
):
await serve()


@pytest.mark.anyio
async def test_index_cache_awaits_model(tmp_path: Path) -> None:
"""get() blocks until the model is installed, then proceeds."""
cache = _IndexCache() # no model yet
fake_index = MagicMock()
with patch("semble.mcp.SembleIndex.from_path", return_value=fake_index):
get_task = asyncio.create_task(cache.get(str(tmp_path)))
await asyncio.sleep(0.01)
assert not get_task.done(), "get() must block until the model is installed"
cache._model = MagicMock(spec=Encoder)
cache._model_ready.set()
result = await asyncio.wait_for(get_task, timeout=1.0)
assert result is fake_index


@pytest.mark.anyio
async def test_index_cache_propagates_model_error(tmp_path: Path) -> None:
"""If model load fails, awaiting tool calls re-raise the original exception."""
cache = _IndexCache()
get_task = asyncio.create_task(cache.get(str(tmp_path)))
await asyncio.sleep(0.01)
assert not get_task.done()
cache._model_error = RuntimeError("HF download failed")
cache._model_ready.set()
with pytest.raises(RuntimeError, match="HF download failed"):
await asyncio.wait_for(get_task, timeout=1.0)


@pytest.mark.anyio
@pytest.mark.parametrize(
("repo", "tool", "extra_args"),
Expand Down
2 changes: 1 addition & 1 deletion tests/test_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def test_load_model(model_path: str | None, expected_call_arg: str) -> None:
fake_model = MagicMock(spec=Encoder)
with patch("semble.index.dense.StaticModel.from_pretrained", return_value=fake_model) as mock_fp:
result = load_model(model_path)
mock_fp.assert_called_once_with(expected_call_arg)
mock_fp.assert_called_once_with(expected_call_arg, force_download=False)
assert result is fake_model


Expand Down
2 changes: 1 addition & 1 deletion uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading