Skip to content

Commit cd635e6

Browse files
feat(search-optimization): cache tool catalog and parallelize per-account MCP fetches (#173)
* Add caching to fetch_tools to boost search performance
* Address Copilot review comments
* Add a search benchmark example to the repo
* Fix issues spotted by ruff
1 parent 17fc35b commit cd635e6

5 files changed

Lines changed: 406 additions & 13 deletions

File tree

examples/benchmark_search.py

Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
"""Benchmark: measure SDK search latency with caching.
2+
3+
Runs fetch_tools, local (BM25+TF-IDF) search, and semantic search N times,
4+
reports cold vs warm average latency and the speedup from caching.
5+
6+
Prerequisites:
7+
- STACKONE_API_KEY environment variable
8+
- STACKONE_ACCOUNT_ID environment variable
9+
10+
Run with:
11+
uv run python examples/benchmark_search.py # default 100 iterations
12+
uv run python examples/benchmark_search.py -n 50 # fewer for a quick check
13+
"""
14+
15+
from __future__ import annotations
16+
17+
import argparse
18+
import os
19+
import sys
20+
import time
21+
22+
try:
23+
from dotenv import load_dotenv
24+
25+
load_dotenv()
26+
except ModuleNotFoundError:
27+
pass
28+
29+
from stackone_ai import StackOneToolSet
30+
31+
# Representative natural-language queries cycled through (round-robin) by the
# search benchmarks; variety avoids measuring a single memoized query.
QUERIES = [
    "list events",
    "cancel a meeting",
    "send a message",
    "get current user",
    "list employees",
]
38+
39+
40+
def bench(fn, n: int) -> tuple[float, float, list[float]]:
    """Time ``n`` sequential calls of ``fn()``.

    Args:
        fn: Zero-argument callable to benchmark.
        n: Number of iterations; must be >= 1.

    Returns:
        ``(cold, warm_avg, all_times)`` where ``cold`` is the first call's
        duration, ``warm_avg`` is the mean of the remaining calls (or ``cold``
        when ``n == 1``), and ``all_times`` holds every per-call duration in
        seconds.

    Raises:
        ValueError: If ``n`` < 1 — there would be no cold call to measure
            (previously this surfaced as an opaque IndexError when the CLI
            was run with ``-n 0``).
    """
    if n < 1:
        raise ValueError("n must be >= 1")
    times: list[float] = []
    for _ in range(n):
        start = time.perf_counter()
        fn()
        times.append(time.perf_counter() - start)

    cold = times[0]
    warm_times = times[1:]
    # With a single iteration there is no warm sample; report the cold time.
    warm_avg = sum(warm_times) / len(warm_times) if warm_times else cold
    return cold, warm_avg, times
52+
53+
54+
def fmt_ms(seconds: float) -> str:
    """Render a duration given in seconds as right-aligned milliseconds."""
    millis = seconds * 1000
    return f"{millis:8.1f}ms"
56+
57+
58+
def main() -> int:
    """Run the three latency benchmarks and print a summary table.

    Returns:
        Process exit status: 0 on success, 1 when a required environment
        variable (STACKONE_API_KEY / STACKONE_ACCOUNT_ID) is missing.
    """
    parser = argparse.ArgumentParser(description="Benchmark SDK search latency")
    parser.add_argument(
        "--iterations", "-n", type=int, default=100, help="iterations per benchmark (default 100)"
    )
    args = parser.parse_args()
    n = args.iterations

    api_key = os.getenv("STACKONE_API_KEY")
    account_id = os.getenv("STACKONE_ACCOUNT_ID")

    if not api_key:
        print("Set STACKONE_API_KEY to run this benchmark.")
        return 1
    if not account_id:
        print("Set STACKONE_ACCOUNT_ID to run this benchmark.")
        return 1

    print(f"Benchmarking with account {account_id[:8]}..., {n} iterations each\n")

    ts = StackOneToolSet(
        api_key=api_key,
        account_id=account_id,
        search={"method": "auto", "top_k": 5},
    )

    results: list[tuple[str, float, float, float]] = []
    query_idx = 0

    def next_query() -> str:
        # Cycle through QUERIES so warm iterations don't all hit one query.
        nonlocal query_idx
        q = QUERIES[query_idx % len(QUERIES)]
        query_idx += 1
        return q

    def run_case(label: str, header: str, call) -> None:
        # Shared measurement harness for all three benchmarks. Each case
        # starts from a cleared catalog cache so the first call is a true
        # cold measurement and the remaining n-1 calls show cache benefit.
        # Resetting query_idx here is a no-op for the fetch_tools case
        # (it is still 0 at that point) and matches the original per-case
        # reset for the two search cases.
        nonlocal query_idx
        print(header)
        ts.clear_catalog_cache()
        query_idx = 0
        cold, warm_avg, _ = bench(call, n)
        speedup = cold / warm_avg if warm_avg > 0 else float("inf")
        results.append((label, cold, warm_avg, speedup))
        print(f" cold={fmt_ms(cold)} warm_avg={fmt_ms(warm_avg)} speedup={speedup:.0f}x")

    # --- 1. fetch_tools ---
    run_case("fetch_tools", f"[1/3] fetch_tools x{n} ...", lambda: ts.fetch_tools())

    # --- 2. local search (BM25 + TF-IDF) ---
    run_case(
        "search (local/BM25)",
        f"[2/3] search_tools (local) x{n} ...",
        lambda: ts.search_tools(next_query(), search="local"),
    )

    # --- 3. semantic search (auto) ---
    run_case(
        "search (semantic)",
        f"[3/3] search_tools (semantic/auto) x{n} ...",
        lambda: ts.search_tools(next_query(), search="auto"),
    )

    # --- Summary ---
    print("\n" + "=" * 65)
    print(f"{'Benchmark':<22} {'Cold':>10} {'Warm (avg)':>10} {'Speedup':>10}")
    print("-" * 65)
    for name, c, w, s in results:
        print(f"{name:<22} {fmt_ms(c):>10} {fmt_ms(w):>10} {s:>9.0f}x")
    print("=" * 65)

    print(f"\nWarm = average of {n - 1} calls after the first (cold) call.")
    print("Speedup = cold / warm_avg — shows the benefit of caching.\n")

    return 0
131+
132+
133+
if __name__ == "__main__":
    # Propagate main()'s status (0 on success, 1 on missing env vars) as the
    # process exit code.
    sys.exit(main())

examples/test_examples.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ def get_example_files() -> list[str]:
3434
"semantic_search_example.py": ["mcp"],
3535
"mcp_server.py": ["mcp"],
3636
"workday_integration.py": ["openai", "mcp"],
37+
"benchmark_search.py": ["mcp"],
3738
}
3839

3940

stackone_ai/toolset.py

Lines changed: 42 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,6 @@ class _ExecuteTool(StackOneTool):
170170
"""LLM-callable tool that executes a StackOne tool by name."""
171171

172172
_toolset: Any = PrivateAttr(default=None)
173-
_cached_tools: Any = PrivateAttr(default=None)
174173

175174
def execute(
176175
self, arguments: str | JsonDict | None = None, *, options: JsonDict | None = None
@@ -185,10 +184,8 @@ def execute(
185184
parsed = _ExecuteInput(**raw_params)
186185
tool_name = parsed.tool_name
187186

188-
if self._cached_tools is None:
189-
self._cached_tools = self._toolset.fetch_tools(account_ids=self._toolset._account_ids)
190-
191-
target = self._cached_tools.get_tool(parsed.tool_name)
187+
tools = self._toolset.fetch_tools(account_ids=self._toolset._account_ids)
188+
target = tools.get_tool(parsed.tool_name)
192189

193190
if target is None:
194191
return {
@@ -602,6 +599,8 @@ def __init__(
602599
execute_timeout = execute.get("timeout") if execute else None
603600
self._timeout: float = timeout if timeout is not None else (execute_timeout or 60.0)
604601
self._tools_cache: Tools | None = None
602+
self._catalog_cache: dict[tuple[Any, ...], Tools] = {}
603+
self._tool_index_cache: tuple[int, Any] | None = None
605604

606605
def set_accounts(self, account_ids: list[str]) -> StackOneToolSet:
607606
"""Set account IDs for filtering tools
@@ -613,8 +612,18 @@ def set_accounts(self, account_ids: list[str]) -> StackOneToolSet:
613612
This toolset instance for chaining
614613
"""
615614
self._account_ids = account_ids
615+
self.clear_catalog_cache()
616616
return self
617617

618+
def clear_catalog_cache(self) -> None:
619+
"""Invalidate cached tool catalog and local search index.
620+
621+
Call when linked accounts change outside of ``set_accounts`` or when
622+
you need to force a fresh fetch from the StackOne MCP endpoint.
623+
"""
624+
self._catalog_cache.clear()
625+
self._tool_index_cache = None
626+
618627
def get_search_tool(self, *, search: SearchMode | None = None) -> SearchTool:
619628
"""Get a callable search tool that returns Tools collections.
620629
@@ -802,7 +811,10 @@ def _local_search(
802811
if not available_connectors:
803812
return Tools([])
804813

805-
index = ToolIndex(list(all_tools))
814+
cache_key = id(all_tools)
815+
if self._tool_index_cache is None or self._tool_index_cache[0] != cache_key:
816+
self._tool_index_cache = (cache_key, ToolIndex(list(all_tools)))
817+
index = self._tool_index_cache[1]
806818
results = index.search(
807819
query,
808820
limit=top_k if top_k is not None else 5,
@@ -1171,22 +1183,41 @@ def fetch_tools(
11711183
else:
11721184
account_scope = [None]
11731185

1186+
cache_key = (
1187+
tuple(sorted(account_scope, key=lambda a: (a is None, a))),
1188+
tuple(sorted(p.lower() for p in providers)) if providers else None,
1189+
tuple(sorted(actions)) if actions else None,
1190+
)
1191+
cached = self._catalog_cache.get(cache_key)
1192+
if cached is not None:
1193+
return cached
1194+
11741195
endpoint = f"{self.base_url.rstrip('/')}/mcp"
1175-
all_tools: list[StackOneTool] = []
11761196

1177-
for account in account_scope:
1197+
def _fetch_for_account(account: str | None) -> list[StackOneTool]:
11781198
headers = self._build_mcp_headers(account)
11791199
catalog = _fetch_mcp_tools(endpoint, headers)
1180-
for tool_def in catalog:
1181-
all_tools.append(self._create_rpc_tool(tool_def, account))
1200+
return [self._create_rpc_tool(tool_def, account) for tool_def in catalog]
1201+
1202+
all_tools: list[StackOneTool] = []
1203+
if len(account_scope) == 1:
1204+
all_tools.extend(_fetch_for_account(account_scope[0]))
1205+
else:
1206+
max_workers = min(len(account_scope), 10)
1207+
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as pool:
1208+
futures = [pool.submit(_fetch_for_account, acc) for acc in account_scope]
1209+
for future in futures:
1210+
all_tools.extend(future.result())
11821211

11831212
if providers:
11841213
all_tools = [tool for tool in all_tools if self._filter_by_provider(tool.name, providers)]
11851214

11861215
if actions:
11871216
all_tools = [tool for tool in all_tools if self._filter_by_action(tool.name, actions)]
11881217

1189-
return Tools(all_tools)
1218+
result = Tools(all_tools)
1219+
self._catalog_cache[cache_key] = result
1220+
return result
11901221

11911222
except ToolsetError:
11921223
raise

tests/test_agent_tools.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -268,7 +268,11 @@ def test_invalid_json_returns_error_dict(self):
268268

269269
assert "error" in result
270270

271-
def test_caches_fetched_tools(self):
271+
def test_delegates_catalog_lookup_to_toolset(self):
272+
# _ExecuteTool no longer holds a local cache; the toolset's catalog
273+
# cache (see StackOneToolSet._catalog_cache) is the single source of
274+
# truth. Verify execute always defers to the toolset so it benefits
275+
# from that shared cache.
272276
toolset = MagicMock()
273277
toolset.api_key = "test-key"
274278
toolset._account_ids = []
@@ -286,7 +290,8 @@ def test_caches_fetched_tools(self):
286290
execute.execute({"tool_name": "test_tool"})
287291
execute.execute({"tool_name": "test_tool"})
288292

289-
toolset.fetch_tools.assert_called_once()
293+
assert toolset.fetch_tools.call_count == 2
294+
toolset.fetch_tools.assert_called_with(account_ids=[])
290295

291296
def test_passes_account_ids_from_toolset(self):
292297
toolset = MagicMock()

0 commit comments

Comments
 (0)