diff --git a/examples/llmagent_with_streaming_progress_tool/.env b/examples/llmagent_with_streaming_progress_tool/.env new file mode 100644 index 00000000..0a57a178 --- /dev/null +++ b/examples/llmagent_with_streaming_progress_tool/.env @@ -0,0 +1,5 @@ +# Set TRPC_AGENT_API_KEY、TRPC_AGENT_BASE_URL、TRPC_AGENT_MODEL_NAME +TRPC_AGENT_API_KEY=your-api-key +TRPC_AGENT_BASE_URL=your-base-url +TRPC_AGENT_MODEL_NAME=your-model-name + diff --git a/examples/llmagent_with_streaming_progress_tool/README.md b/examples/llmagent_with_streaming_progress_tool/README.md new file mode 100644 index 00000000..14e36bd2 --- /dev/null +++ b/examples/llmagent_with_streaming_progress_tool/README.md @@ -0,0 +1,67 @@ +# Streaming Progress Tool + +This example shows how to expose a **long-running tool that streams progress +events to the user in real time**, using `StreamingProgressTool`. + +The wrapped function is an `async def` generator (`async def fn(...): yield ...`). +Every `yield` is surfaced to the runner as a `partial=True` Event tagged with +`custom_metadata={"tool_progress": True, ...}`. The **last** yielded value is +*also* the final `function_response` returned to the LLM. + +```text +yield progress_1 --> partial Event (live) +yield progress_2 --> partial Event (live) +yield progress_3 --> partial Event (live) AND final function_response +``` + +This is different from the other two streaming-ish tools shipped with the SDK: + +| Class | What gets streamed | +| --------------------------- | --------------------------------------------------- | +| `StreamingFunctionTool` | The *arguments* the LLM is generating for the call. | +| `LongRunningFunctionTool` | Nothing intermediate; just marks the call as slow. | +| **`StreamingProgressTool`** | The tool's *own* execution progress. | + +## Run + +```bash +cd examples/llmagent_with_streaming_progress_tool +cp ../mcp_tools/.env .env # or write your own +# edit .env to set TRPC_AGENT_API_KEY / BASE_URL / MODEL_NAME +python run_agent.py +``` + +Expected output (abridged): + +``` +User: Please crawl https://example.com and fetch the first 5 pages. +[crawl_site] ⏳ {'status': 'started', 'url': 'https://example.com', 'max_pages': 5} +[crawl_site] ⏳ {'status': 'fetched', 'page': 1, 'total': 5, ...} +[crawl_site] ⏳ {'status': 'fetched', 'page': 2, 'total': 5, ...} +... +[tool-result] crawl_site → {'status': 'done', 'url': '...', 'pages_fetched': 5, ...} +Assistant: I crawled example.com and fetched 5 pages. ... +``` + +## How to consume progress events on the client side + +Filter on `event.partial` + `custom_metadata.tool_progress` to detect a +progress chunk. The raw value the tool yielded is available in +`custom_metadata['payload']` (for `dict`/`BaseModel` yields) and as a JSON +string in `event.content.parts[0].text` for plain-text consumers. + +```python +async for event in runner.run_async(...): + meta = event.custom_metadata or {} + if event.partial and meta.get("tool_progress"): + print(meta["tool_name"], meta.get("payload") or event.get_text()) + continue + # ...handle final events as usual +``` + +Notes: +- Progress events are NOT persisted into session history (they are partial). +- The LLM only ever sees the **last** yielded value as the tool response. +- If a batch contains a progress-streaming tool, the framework forces + sequential tool execution to keep interim events in deterministic order, + even if the agent has `parallel_tool_calls=True`. diff --git a/examples/llmagent_with_streaming_progress_tool/agent/__init__.py b/examples/llmagent_with_streaming_progress_tool/agent/__init__.py new file mode 100644 index 00000000..bc6e483f --- /dev/null +++ b/examples/llmagent_with_streaming_progress_tool/agent/__init__.py @@ -0,0 +1,5 @@ +# Tencent is pleased to support the open source community by making tRPC-Agent-Python available. +# +# Copyright (C) 2026 Tencent. All rights reserved. +# +# tRPC-Agent-Python is licensed under Apache-2.0. diff --git a/examples/llmagent_with_streaming_progress_tool/agent/agent.py b/examples/llmagent_with_streaming_progress_tool/agent/agent.py new file mode 100644 index 00000000..5870eb49 --- /dev/null +++ b/examples/llmagent_with_streaming_progress_tool/agent/agent.py @@ -0,0 +1,51 @@ +# Tencent is pleased to support the open source community by making tRPC-Agent-Python available. +# +# Copyright (C) 2026 Tencent. All rights reserved. +# +# tRPC-Agent-Python is licensed under Apache-2.0. +"""Agent that uses StreamingProgressTool for a long-running task.""" + +from trpc_agent_sdk.agents import LlmAgent +from trpc_agent_sdk.models import LLMModel +from trpc_agent_sdk.models import OpenAIModel +from trpc_agent_sdk.tools import StreamingProgressTool +from trpc_agent_sdk.types import GenerateContentConfig + +from .config import get_model_config +from .prompts import INSTRUCTION +from .tools import crawl_site + + +def _create_model() -> LLMModel: + api_key, url, model_name = get_model_config() + return OpenAIModel(model_name=model_name, api_key=api_key, base_url=url) + + +def create_agent() -> LlmAgent: + """Build the agent. ``crawl_site`` is wrapped in + ``StreamingProgressTool(skip_summarization=True)`` so that: + + 1. Every ``yield`` becomes a partial Event the caller renders live. + 2. The last ``yield`` is also the final ``function_response`` event – + persisted to the session as the canonical record of this turn. + 3. ``skip_summarization=True`` makes :class:`LlmAgent` exit the + conversation loop immediately after the tool returns, so the LLM + is **not** asked to re-summarise the streamed output (which the + user has already seen). + """ + crawl_tool = StreamingProgressTool(crawl_site, skip_summarization=True) + + return LlmAgent( + name="streaming_crawler", + description="Crawls a site step-by-step and streams progress to the user.", + model=_create_model(), + instruction=INSTRUCTION, + tools=[crawl_tool], + generate_content_config=GenerateContentConfig( + temperature=0.3, + max_output_tokens=1000, + ), + ) + + +root_agent = create_agent() diff --git a/examples/llmagent_with_streaming_progress_tool/agent/config.py b/examples/llmagent_with_streaming_progress_tool/agent/config.py new file mode 100644 index 00000000..694938b3 --- /dev/null +++ b/examples/llmagent_with_streaming_progress_tool/agent/config.py @@ -0,0 +1,19 @@ +# Tencent is pleased to support the open source community by making tRPC-Agent-Python available. +# +# Copyright (C) 2026 Tencent. All rights reserved. +# +# tRPC-Agent-Python is licensed under Apache-2.0. +"""Agent config module.""" + +import os + + +def get_model_config() -> tuple[str, str, str]: + """Get model config from environment variables.""" + api_key = os.getenv('TRPC_AGENT_API_KEY', '') + url = os.getenv('TRPC_AGENT_BASE_URL', '') + model_name = os.getenv('TRPC_AGENT_MODEL_NAME', '') + if not api_key or not url or not model_name: + raise ValueError("TRPC_AGENT_API_KEY, TRPC_AGENT_BASE_URL, and TRPC_AGENT_MODEL_NAME " + "must be set in environment variables (e.g. via a .env file).") + return api_key, url, model_name diff --git a/examples/llmagent_with_streaming_progress_tool/agent/prompts.py b/examples/llmagent_with_streaming_progress_tool/agent/prompts.py new file mode 100644 index 00000000..45563bf7 --- /dev/null +++ b/examples/llmagent_with_streaming_progress_tool/agent/prompts.py @@ -0,0 +1,12 @@ +# Tencent is pleased to support the open source community by making tRPC-Agent-Python available. +# +# Copyright (C) 2026 Tencent. All rights reserved. +# +# tRPC-Agent-Python is licensed under Apache-2.0. +"""Prompts for the streaming-progress tool demo.""" + +INSTRUCTION = ( + "You are a helpful crawling assistant. When the user asks you to crawl, fetch, " + "or inspect a website, ALWAYS call the `crawl_site` tool. Pass a sensible " + "`max_pages` (default to 5 if unspecified). After the tool finishes, " + "summarise what was fetched in 1-2 sentences in the user's language.") diff --git a/examples/llmagent_with_streaming_progress_tool/agent/tools.py b/examples/llmagent_with_streaming_progress_tool/agent/tools.py new file mode 100644 index 00000000..12f786ef --- /dev/null +++ b/examples/llmagent_with_streaming_progress_tool/agent/tools.py @@ -0,0 +1,55 @@ +# Tencent is pleased to support the open source community by making tRPC-Agent-Python available. +# +# Copyright (C) 2026 Tencent. All rights reserved. +# +# tRPC-Agent-Python is licensed under Apache-2.0. +"""A long-running tool that streams progress events to the user. + +The tool simulates a multi-step site crawl. Each step yields a structured +progress payload that the framework surfaces as a partial Event in real time. +The **last** yielded value is also the final tool response fed back to the LLM. +""" + +from __future__ import annotations + +import asyncio +import random +from typing import AsyncIterator + + +async def crawl_site(url: str, max_pages: int = 5) -> AsyncIterator[dict]: + """Crawl ``url`` and stream progress for every page fetched. + + Use this for long-running fetches where the user benefits from seeing + incremental progress instead of staring at a spinner. + + Args: + url: The site URL to crawl (any string for demo purposes). + max_pages: How many pages to simulate fetching. Defaults to 5. + + Yields: + dict: One progress payload per step. The final payload is also the + return value the LLM sees. + """ + yield {"status": "started", "url": url, "max_pages": max_pages} + + fetched_titles: list[str] = [] + for page_index in range(1, max_pages + 1): + # Simulate variable per-page latency so the streaming is observable. + await asyncio.sleep(random.uniform(0.4, 1.2)) + title = f"{url} - page {page_index}" + fetched_titles.append(title) + yield { + "status": "fetched", + "page": page_index, + "total": max_pages, + "title": title, + "progress": round(page_index / max_pages, 2), + } + + yield { + "status": "done", + "url": url, + "pages_fetched": len(fetched_titles), + "titles": fetched_titles, + } diff --git a/examples/llmagent_with_streaming_progress_tool/run_agent.py b/examples/llmagent_with_streaming_progress_tool/run_agent.py new file mode 100644 index 00000000..94378b46 --- /dev/null +++ b/examples/llmagent_with_streaming_progress_tool/run_agent.py @@ -0,0 +1,105 @@ +# Tencent is pleased to support the open source community by making tRPC-Agent-Python available. +# +# Copyright (C) 2026 Tencent. All rights reserved. +# +# tRPC-Agent-Python is licensed under Apache-2.0. +"""Demo: streaming progress events from a long-running tool. + +Run with:: + + cd examples/llmagent_with_streaming_progress_tool + python run_agent.py + +Make sure ``TRPC_AGENT_API_KEY``, ``TRPC_AGENT_BASE_URL`` and +``TRPC_AGENT_MODEL_NAME`` are set in your environment or .env file. +""" + +import asyncio +import sys +import uuid +from pathlib import Path + +from dotenv import load_dotenv +from trpc_agent_sdk.runners import Runner +from trpc_agent_sdk.sessions import InMemorySessionService +from trpc_agent_sdk.types import Content +from trpc_agent_sdk.types import Part + +load_dotenv() + +sys.path.append(str(Path(__file__).parent)) + + +async def run_streaming_progress_demo() -> None: + """Issue one query and pretty-print every event surfaced by the runner. + + Three event kinds matter here: + 1. ``event.partial`` with ``custom_metadata.tool_progress`` → a *progress* + chunk from the streaming tool. Print it live; do NOT treat it as a + tool response. + 2. ``event.partial`` text from the LLM (no ``tool_progress`` marker) → + streaming model output. + 3. ``partial=False`` events with a ``function_response`` part → the final + tool result; with a ``text`` part → the model's final reply. + """ + + app_name = "streaming_progress_demo" + from agent.agent import root_agent + + session_service = InMemorySessionService() + runner = Runner(app_name=app_name, agent=root_agent, session_service=session_service) + + user_id = "demo_user" + session_id = str(uuid.uuid4()) + await session_service.create_session(app_name=app_name, user_id=user_id, session_id=session_id) + + query = "Please crawl https://example.com and fetch the first 5 pages." + print("=" * 60) + print(f"User: {query}") + print("=" * 60) + + user_content = Content(parts=[Part.from_text(text=query)]) + + async for event in runner.run_async(user_id=user_id, session_id=session_id, new_message=user_content): + meta = event.custom_metadata or {} + + # --- 1. Tool progress (partial, comes from StreamingProgressTool) --- + if event.partial and meta.get("tool_progress"): + payload = meta.get("payload") + tool_name = meta.get("tool_name", "?") + print(f"[{tool_name}] ⏳ {payload if payload is not None else event.get_text()}") + continue + + if not event.content or not event.content.parts: + continue + + # --- 2. Streaming LLM text --- + if event.partial: + for part in event.content.parts: + if part.text: + print(part.text, end="", flush=True) + continue + + # --- 3. Final events --- + for part in event.content.parts: + if part.function_call: + print(f"\n[tool-call] {part.function_call.name}({part.function_call.args})") + elif part.function_response: + print(f"\n[tool-result] {part.function_response.name} → " + f"{part.function_response.response}") + elif part.text: + print(f"\nAssistant: {part.text}") + + print("\n" + "-" * 60) + + +if __name__ == "__main__": + print(""" ++--------------------------------------------------------------+ +| StreamingProgressTool Demo (long-running tool) | +| | +| Watch the tool yield progress events live, then the LLM | +| summarises the final result. | ++--------------------------------------------------------------+ +""") + asyncio.run(run_streaming_progress_demo()) diff --git a/examples/llmagent_with_streaming_progress_tool/verify.py b/examples/llmagent_with_streaming_progress_tool/verify.py new file mode 100644 index 00000000..9585a45d --- /dev/null +++ b/examples/llmagent_with_streaming_progress_tool/verify.py @@ -0,0 +1,307 @@ +# Tencent is pleased to support the open source community by making tRPC-Agent-Python available. +# +# Copyright (C) 2026 Tencent. All rights reserved. +# +# tRPC-Agent-Python is licensed under Apache-2.0. +"""End-to-end verification of StreamingProgressTool(skip_summarization=True). + +What this script proves, in one run: + +1. **Live streaming reaches the caller** + The wrapped async generator's yields show up as partial events with + ``custom_metadata.tool_progress=True`` *while* the tool is still running. + +2. **The LLM is not asked to re-summarise the streamed output** + Because the tool is constructed with ``skip_summarization=True``, the + final ``function_response`` event has ``actions.skip_summarization=True``. + :class:`LlmAgent` exits the conversation loop immediately after the + tool returns – the caller will NOT see any "Assistant: ..." text after + the tool result. + +3. **The session keeps the final tool result, not the partials** + Partial progress events have ``partial=True`` so session services + skip them. The final ``function_response`` event is non-partial and + IS persisted. We assert this against the in-memory session at the + end of turn 1. + +4. **Next turn can use the persisted data** + In turn 2 we ask "Summarise what you fetched earlier". The LLM has + access to the full tool response from turn 1 via session history and + answers based on it. The accumulated ``titles`` list we put in the + final yield is the source of truth. + +Run:: + + cd examples/llmagent_with_streaming_progress_tool + # Ensure .env defines TRPC_AGENT_API_KEY / BASE_URL / MODEL_NAME + python verify.py +""" + +from __future__ import annotations + +import asyncio +import sys +import uuid +from pathlib import Path + +from dotenv import load_dotenv +from trpc_agent_sdk.runners import Runner +from trpc_agent_sdk.sessions import InMemorySessionService +from trpc_agent_sdk.types import Content +from trpc_agent_sdk.types import Part + +load_dotenv() + +sys.path.append(str(Path(__file__).parent)) + +APP_NAME = "streaming_progress_verify" +USER_ID = "verify_user" + + +def _section(title: str) -> None: + line = "=" * 70 + print(f"\n{line}\n {title}\n{line}") + + +def _bullet(ok: bool, msg: str) -> None: + mark = "[PASS]" if ok else "[FAIL]" + print(f" {mark} {msg}") + + +async def _run_turn( + runner: Runner, + session_id: str, + query: str, + label: str, +) -> tuple[list[dict], list, str, str]: + """Run one user turn. + + Returns a 4-tuple ``(live_progress, all_events, post_tool_text, all_llm_text)``: + + - ``live_progress`` : every progress chunk that arrived live, proving + the streaming pipe fired. + - ``all_events`` : everything yielded by the runner, in order. + - ``post_tool_text``: LLM-authored text that arrived *after* the + function_response event. We use this to detect whether the LLM + tried to "re-summarise" the tool output. + - ``all_llm_text`` : every LLM-authored text chunk across the whole + turn (used when the turn has no tool call at all). + """ + _section(label) + print(f"User: {query}") + + live_progress: list[dict] = [] + all_events: list = [] + tool_done = False + post_tool_text_parts: list[str] = [] + all_llm_text_parts: list[str] = [] + + user_content = Content(parts=[Part.from_text(text=query)]) + async for event in runner.run_async(user_id=USER_ID, session_id=session_id, new_message=user_content): + all_events.append(event) + meta = event.custom_metadata or {} + + # (1) Live progress + if event.partial and meta.get("tool_progress"): + payload = meta.get("payload") + live_progress.append(payload if isinstance(payload, dict) else {"raw": payload}) + print(f" [live] {meta.get('tool_name')} -> {payload}") + continue + + if not event.content or not event.content.parts: + continue + + # Streaming LLM text (partial assistant message) + if event.partial: + for part in event.content.parts: + if part.text: + all_llm_text_parts.append(part.text) + if tool_done: + post_tool_text_parts.append(part.text) + print(part.text, end="", flush=True) + continue + + # Final non-partial events + for part in event.content.parts: + if part.function_call: + print(f"\n [tool-call] {part.function_call.name}({part.function_call.args})") + elif part.function_response: + tool_done = True + print(f"\n [tool-result] {part.function_response.name} -> " + f"keys={list((part.function_response.response or {}).keys())}") + elif part.text: + all_llm_text_parts.append(part.text) + if tool_done: + post_tool_text_parts.append(part.text) + print(f"\n Assistant: {part.text}") + + print() # newline after streaming + return ( + live_progress, + all_events, + "".join(post_tool_text_parts).strip(), + "".join(all_llm_text_parts).strip(), + ) + + +def _verify_turn1( + live_progress: list[dict], + all_events: list, + post_tool_text: str, + persisted_events: list, +) -> bool: + """Run all assertions for turn 1 and report PASS/FAIL per check. + + ``persisted_events`` is the list of events as stored in session – the + caller MUST refetch the session via ``session_service.get_session`` + after the turn finishes (the InMemorySessionService returns a deep + copy on creation, so any reference we held from ``create_session`` + is frozen at zero events). + """ + _section("Turn-1 verification") + + ok = True + + has_progress = len(live_progress) > 0 + _bullet(has_progress, f"Streaming pipe fired: {len(live_progress)} live progress event(s) received.") + ok &= has_progress + + no_followup = post_tool_text == "" + _bullet( + no_followup, + "skip_summarization stopped the LLM from re-summarising " + f"(post-tool LLM text length = {len(post_tool_text)} chars).", + ) + if not no_followup: + print(f" Unexpected follow-up text: {post_tool_text!r}") + ok &= no_followup + + persisted_partials = [e for e in persisted_events if e.partial] + function_response_events = [ + e for e in persisted_events if e.content and any(p.function_response for p in e.content.parts) + ] + + no_partials_persisted = len(persisted_partials) == 0 + _bullet( + no_partials_persisted, + f"No partial progress events leaked into session storage " + f"({len(persisted_partials)} found, expected 0).", + ) + ok &= no_partials_persisted + + has_function_response = len(function_response_events) == 1 + _bullet( + has_function_response, + f"Exactly one final function_response event persisted " + f"({len(function_response_events)} found).", + ) + ok &= has_function_response + + if function_response_events: + fr_part = next(p for p in function_response_events[0].content.parts if p.function_response) + response = fr_part.function_response.response or {} + has_titles = isinstance(response.get("titles"), list) and len(response["titles"]) > 0 + _bullet( + has_titles, + f"Final tool response contains the accumulated titles list " + f"(len={len(response.get('titles', []))}).", + ) + ok &= has_titles + print(f" Persisted response payload: {response}") + + fe = function_response_events[0] + has_skip = bool(fe.actions and fe.actions.skip_summarization) + _bullet( + has_skip, + "Final tool event carries actions.skip_summarization=True.", + ) + ok &= has_skip + + print() + print(f" Total events captured from runner: {len(all_events)}") + print(f" Total events persisted in session: {len(persisted_events)}") + return ok + + +def _verify_turn2(all_llm_text: str, persisted_events: list) -> bool: + """Turn 2 asks the LLM to summarise; verify it had context to do so.""" + _section("Turn-2 verification") + ok = True + + has_assistant_reply = len(all_llm_text) > 0 + _bullet( + has_assistant_reply, + f"LLM produced an answer in turn 2 ({len(all_llm_text)} chars).", + ) + ok &= has_assistant_reply + + references_data = "page" in all_llm_text.lower() or "example.com" in all_llm_text.lower() + _bullet( + references_data, + "Turn-2 answer references turn-1 tool data " + "(mentions 'page' or 'example.com').", + ) + ok &= references_data + + # Bonus visibility: count how many events the session now holds (user, + # tool_call, tool_response from turn 1 plus user + assistant from turn 2). + print(f" Total events now persisted in session: {len(persisted_events)}") + return ok + + +async def _refetch_events(session_service: InMemorySessionService, session_id: str) -> list: + """Pull the *live* event list from the session service. + + ``InMemorySessionService.create_session`` returns a ``copy.deepcopy`` + of the stored session, so the Session object the caller holds at + creation time is frozen at zero events. Subsequent appends happen on + a different object kept inside the service. We refetch through the + public ``get_session`` API (which also returns a fresh deepcopy that + DOES reflect the latest events). + """ + s = await session_service.get_session( + app_name=APP_NAME, + user_id=USER_ID, + session_id=session_id, + ) + return list(s.events) if s else [] + + +async def main() -> None: + from agent.agent import root_agent + + session_service = InMemorySessionService() + runner = Runner(app_name=APP_NAME, agent=root_agent, session_service=session_service) + + session_id = str(uuid.uuid4()) + await session_service.create_session(app_name=APP_NAME, user_id=USER_ID, session_id=session_id) + + # ----- Turn 1: streaming tool runs ----- + live, evs, post_tool_text, _all_llm_text1 = await _run_turn( + runner, + session_id, + "Please crawl https://example.com and fetch the first 5 pages.", + "Turn 1 -- streaming crawl", + ) + persisted_after_turn1 = await _refetch_events(session_service, session_id) + turn1_ok = _verify_turn1(live, evs, post_tool_text, persisted_after_turn1) + + # ----- Turn 2: ask the LLM to summarise; it must use turn-1 data ----- + _live2, _evs2, _post_tool_text2, all_llm_text2 = await _run_turn( + runner, + session_id, + "Based ONLY on the previous crawl results, list every page title you fetched.", + "Turn 2 -- LLM reads persisted tool data", + ) + persisted_after_turn2 = await _refetch_events(session_service, session_id) + turn2_ok = _verify_turn2(all_llm_text2, persisted_after_turn2) + + _section("Result") + print(f" Turn 1: {'PASS' if turn1_ok else 'FAIL'}") + print(f" Turn 2: {'PASS' if turn2_ok else 'FAIL'}") + if not (turn1_ok and turn2_ok): + sys.exit(1) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/mempalace_mcp/.env b/examples/mempalace_mcp/.env new file mode 100644 index 00000000..4705bd4a --- /dev/null +++ b/examples/mempalace_mcp/.env @@ -0,0 +1,8 @@ +# Set TRPC_AGENT_API_KEY、TRPC_AGENT_BASE_URL、TRPC_AGENT_MODEL_NAME +TRPC_AGENT_API_KEY=your-api-key +TRPC_AGENT_BASE_URL=your-base-url +TRPC_AGENT_MODEL_NAME=your-model-name + +# Optional: point MemPalace MCP server at a custom palace directory. +# If unset, MemPalace falls back to ~/.mempalace/palace. +# MEMPALACE_PALACE_PATH=/absolute/path/to/palace diff --git a/examples/mempalace_mcp/README.md b/examples/mempalace_mcp/README.md new file mode 100644 index 00000000..56654185 --- /dev/null +++ b/examples/mempalace_mcp/README.md @@ -0,0 +1,277 @@ +# MemPalace MCP 接入示例 + +本示例演示如何让 **tRPC-Agent** 通过 **MCP (Model Context Protocol)** 直接调用 +[MemPalace](https://github.com/MemPalace/mempalace) 官方提供的 MCP 服务端,使用 +其 30 个原生工具能力(palace 读写、知识图谱、agent diary、跨 wing 导航等)。 + +跟 [`examples/memory_service_with_mempalace`](../memory_service_with_mempalace) +的区别: + +| 接入方式 | 谁触发记忆操作 | 适合场景 | +| --- | --- | --- | +| `MempalaceMemoryService`(隐式) | Runner 自动 store / 模型通过 `load_memory_tool` 检索 | 跨 session 长期记忆,对话自动归档 | +| **本示例:MCP toolset(显式)** | 模型按需调用 MemPalace MCP 工具 | 需要细粒度控制:搜索、归档、知识图谱、agent 日记 | + +两种方式可以叠加使用,互不冲突。 + +--- + +## 工作原理 + +``` +LlmAgent --stdio--> mempalace mcp (子进程, MCP server) + └─ ChromaDB (本地 palace) + └─ SQLite (knowledge graph) +``` + +- `MempalaceMCPToolset` 继承自 `trpc_agent_sdk.tools.MCPToolset`,用 + `StdioConnectionParams` 在启动时把 `mempalace mcp` 作为子进程拉起。 +- 服务端通过 stdio 暴露 ~30 个 MCP 工具,trpc-agent 自动转换为 `LlmAgent` 的工具 + 声明,模型即可用函数调用语法触发。 +- 数据落在本地 MemPalace(默认 `~/.mempalace/palace`),全程零云端调用。 + +工具详细参数参考 MemPalace 官方文档:[MCP Tools Reference](https://mempalaceofficial.com/reference/mcp-tools)。 + +--- + +## 准备工作 + +### 1. 安装依赖 + +在仓库根目录: + +```bash +pip install -e ".[mempalace]" +``` + +`mempalace` 包会带上 `mempalace` CLI 命令到当前 Python 环境的 PATH。 + +### 2. 初始化 palace(首次使用) + +```bash +mempalace init +``` + +如需自定义存储路径,可以指定一个目录: + +```bash +export MEMPALACE_PALACE_PATH=/absolute/path/to/palace +mempalace --palace "$MEMPALACE_PALACE_PATH" init +``` + +### 3. 配置模型 key + +复制并填写 `.env`: + +```env +TRPC_AGENT_API_KEY=your-api-key +TRPC_AGENT_BASE_URL=your-base-url +TRPC_AGENT_MODEL_NAME=your-model-name +# MEMPALACE_PALACE_PATH=/absolute/path/to/palace # 可选 +``` + +--- + +## 启动 MemPalace MCP Server + +> ⚠️ **重要**:`mempalace mcp`(带空格)**不是** MCP server,它只是打印设置帮助。 +> 真正的 server 入口是 `mempalace-mcp`(带连字符)或 `python -m mempalace.mcp_server`。 + +MemPalace MCP server 有 **3 种启动方式**,本示例使用第 1 种,**完全无需手动操作**: + +### 方式 1:自动 stdio 子进程(本示例采用,推荐) + +`MempalaceMCPToolset` 在 `LlmAgent` 启动时**自动**把 server 作为子进程拉起,通过 +stdin/stdout 与之通信;`Runner` 关闭时子进程也跟着退出。**你什么都不用做,跑 `python3 run_agent.py` 即可。** + +[`agent/tools.py`](agent/tools.py) 优先用模块直跑,避免依赖 CLI shim 的命名: + +```python +# 等价于在 shell 里执行: +# python -m mempalace.mcp_server [--palace /path/to/palace] +McpStdioServerParameters( + command=sys.executable, + args=["-m", "mempalace.mcp_server", *(["--palace", palace_path] if palace_path else [])], + env=env, +) +``` + +如果你环境里只有 `mempalace-mcp` 这个二进制(没有装 Python 模块),上面的代码会自动回退到: + +```python +McpStdioServerParameters(command="mempalace-mcp", args=[...], env=env) +``` + +### 方式 2:手动启动 stdio server(用于调试) + +要确认 MemPalace MCP server 本身可用,先在终端单独跑一下: + +```bash +# 推荐写法:直接跑模块 +python -m mempalace.mcp_server + +# 自定义 palace 路径: +python -m mempalace.mcp_server --palace /absolute/path/to/palace + +# 如果你的环境装了 CLI shim 也可以: +mempalace-mcp +mempalace-mcp --palace /absolute/path/to/palace +``` + +server 启动后会在 stdout 上**安静地等待** JSON-RPC 消息——看不到任何 banner 才是对的, +stdio 协议要求 stdout 纯净,否则 MCP 客户端会无法解析。 +用 `Ctrl+C` 结束即可。 + +**如何区分**: + +| 命令 | 行为 | +|---|---| +| `mempalace mcp`(带空格) | ❌ 只打印帮助文本,**不是** server | +| `mempalace-mcp`(带连字符) | ✅ 真正启动 stdio server | +| `python -m mempalace.mcp_server` | ✅ 真正启动 stdio server(最稳) | + +### 方式 3:作为常驻 HTTP server(多 agent 共享同一 palace) + +如果你希望多个 agent 共享同一个 MemPalace,可以让 MCP server 跑成 HTTP 服务(具体 +CLI 选项请参考 MemPalace 官方文档当前版本:[mempalace mcp](https://mempalaceofficial.com/reference/cli))。然后把 +`MempalaceMCPToolset` 改为使用 `StreamableHTTPConnectionParams` 连接已存在的 server: + +```python +from trpc_agent_sdk.tools import StreamableHTTPConnectionParams + +self._connection_params = StreamableHTTPConnectionParams( + url="http://localhost:8000/mcp", + timeout=5, + sse_read_timeout=60 * 5, + terminate_on_close=False, # 不关闭外部 server +) +``` + +参考 [`examples/mcp_tools/agent/tools.py`](../mcp_tools/agent/tools.py) 里 +`SseMCPToolset` / `StreamableHttpMCPToolset` 的写法。 + +--- + +## 运行示例 + +```bash +cd examples/mempalace_mcp +python3 run_agent.py +``` + +示例会跑 7 轮独立 session,逐步触发以下 MCP 工具: + +| 轮次 | 用户提问 | 触发的 MCP 工具(典型) | +| --- | --- | --- | +| 1 | 让 agent 给出 palace 总览 | `mempalace_status` | +| 2 | 让 agent 记住一条偏好 | `mempalace_add_drawer` | +| 3 | 问 agent 自己的工作习惯 | `mempalace_search` | +| 4 | 写入一条三元组事实 | `mempalace_kg_add` | +| 5 | 查询 Alice 的相关关系 | `mempalace_kg_query` | +| 6 | 让 agent 写日记 | `mempalace_diary_write` | +| 7 | 读回最近的日记 | `mempalace_diary_read` | + +> 工具的实际调用顺序由模型决定,提示语只是引导。 + +--- + +## 关键代码 + +`agent/tools.py` —— 把 MemPalace MCP server 包装成 trpc-agent 的 `MCPToolset`: + +```python +class MempalaceMCPToolset(MCPToolset): + def __init__(self, palace_path=None, tool_filter=_DEFAULT_TOOL_FILTER): + super().__init__() + env = os.environ.copy() + if palace_path: + env["MEMPALACE_PALACE_PATH"] = palace_path + self._connection_params = StdioConnectionParams( + server_params=McpStdioServerParameters( + command="mempalace", + args=["mcp"], + env=env, + ), + timeout=10, + ) + if tool_filter is not None: + self._tool_filter = tool_filter +``` + +`agent/agent.py` —— 把 toolset 挂到 `LlmAgent`: + +```python +def create_agent() -> LlmAgent: + palace_path = os.getenv("MEMPALACE_PALACE_PATH") or None + return LlmAgent( + name="mempalace_assistant", + model=_create_model(), + instruction=INSTRUCTION, + tools=[MempalaceMCPToolset(palace_path=palace_path)], + ) +``` + +--- + +## 自定义 + +### 暴露全部 30 个工具 + +默认只暴露 9 个高频工具以节省模型上下文。要解锁全部: + +```python +MempalaceMCPToolset(tool_filter=None) +``` + +### 只暴露知识图谱相关工具 + +```python +MempalaceMCPToolset( + tool_filter=[ + "mempalace_kg_add", + "mempalace_kg_query", + "mempalace_kg_invalidate", + "mempalace_kg_timeline", + "mempalace_kg_stats", + ], +) +``` + +### 改用其他传输方式 + +如果你希望 MCP server 作为独立 HTTP 服务运行(而不是子进程),可以参考 +[`examples/mcp_tools/agent/tools.py`](../mcp_tools/agent/tools.py) 里的 +`SseConnectionParams` / `StreamableHTTPConnectionParams` 模式自行替换。 + +--- + +## 故障排查 + +- **`mempalace: command not found`**:未安装或装到了不同 Python 环境。 + 解决:`pip install -e ".[mempalace]"`,或确保运行 `python3 run_agent.py` 用的是 + 同一个解释器。 + +- **想先确认 MCP server 本身能起来**:在终端单独跑一下 + ```bash + python -m mempalace.mcp_server + ``` + 正常情况下进程会**挂起且不输出任何内容**(stdio 协议要求 stdout 纯净), + `Ctrl+C` 退出即可。如果立刻报错或退出,说明 MemPalace 自身环境有问题(缺少 + 模型文件、palace 未初始化等)。 + 注意:**不要**用 `mempalace mcp`(带空格)做预检,那个命令只是打印帮助文本。 + +- **MCP 启动超时**:palace 第一次初始化、加载 embedding 模型会较慢。 + 解决:先在终端跑一次 `mempalace status` 让 embedding 模型预热,再启动 demo;或在 + `agent/tools.py` 里把 `StdioConnectionParams(timeout=10)` 调大。 + +- **工具被模型忽略 / 不调用**:模型可能更倾向于直接回答。 + 解决:在 `.env` 切换到更强的模型,或者在 prompt 里加更明确的工具触发暗示。 + +--- + +## 相关链接 + +- MemPalace 官方文档:[mempalaceofficial.com](https://mempalaceofficial.com/) +- MCP 工具完整列表:[MCP Tools Reference](https://mempalaceofficial.com/reference/mcp-tools) +- MemPalace 集成介绍(含隐式 Memory Service 路径):[`docs/mkdocs/zh/mempalace.md`](../../docs/mkdocs/zh/mempalace.md) +- tRPC-Agent 通用 MCP 示例:[`examples/mcp_tools/`](../mcp_tools/) diff --git a/examples/mempalace_mcp/agent/__init__.py b/examples/mempalace_mcp/agent/__init__.py new file mode 100644 index 00000000..bc6e483f --- /dev/null +++ b/examples/mempalace_mcp/agent/__init__.py @@ -0,0 +1,5 @@ +# Tencent is pleased to support the open source community by making tRPC-Agent-Python available. +# +# Copyright (C) 2026 Tencent. All rights reserved. +# +# tRPC-Agent-Python is licensed under Apache-2.0. diff --git a/examples/mempalace_mcp/agent/agent.py b/examples/mempalace_mcp/agent/agent.py new file mode 100644 index 00000000..a13633e5 --- /dev/null +++ b/examples/mempalace_mcp/agent/agent.py @@ -0,0 +1,39 @@ +# Tencent is pleased to support the open source community by making tRPC-Agent-Python available. +# +# Copyright (C) 2026 Tencent. All rights reserved. +# +# tRPC-Agent-Python is licensed under Apache-2.0. +"""Agent module: LlmAgent wired to the MemPalace MCP toolset.""" + +import os + +from trpc_agent_sdk.agents import LlmAgent +from trpc_agent_sdk.models import LLMModel +from trpc_agent_sdk.models import OpenAIModel + +from .config import get_model_config +from .prompts import INSTRUCTION +from .tools import MempalaceMCPToolset + + +def _create_model() -> LLMModel: + """Create a model from .env config.""" + api_key, url, model_name = get_model_config() + return OpenAIModel(model_name=model_name, api_key=api_key, base_url=url) + + +def create_agent() -> LlmAgent: + """Create an LlmAgent backed by the MemPalace MCP server.""" + palace_path = os.getenv("MEMPALACE_PALACE_PATH") or None + mempalace_toolset = MempalaceMCPToolset(palace_path=palace_path) + + return LlmAgent( + name="mempalace_assistant", + description="A personal memory assistant powered by the MemPalace MCP server.", + model=_create_model(), + instruction=INSTRUCTION, + tools=[mempalace_toolset], + ) + + +root_agent = create_agent() diff --git a/examples/mempalace_mcp/agent/config.py b/examples/mempalace_mcp/agent/config.py new file mode 100644 index 00000000..d4006350 --- /dev/null +++ b/examples/mempalace_mcp/agent/config.py @@ -0,0 +1,19 @@ +# Tencent is pleased to support the open source community by making tRPC-Agent-Python available. +# +# Copyright (C) 2026 Tencent. All rights reserved. +# +# tRPC-Agent-Python is licensed under Apache-2.0. +"""Agent config module.""" + +import os + + +def get_model_config() -> tuple[str, str, str]: + """Get model config from environment variables.""" + api_key = os.getenv("TRPC_AGENT_API_KEY", "") + url = os.getenv("TRPC_AGENT_BASE_URL", "") + model_name = os.getenv("TRPC_AGENT_MODEL_NAME", "") + if not api_key or not url or not model_name: + raise ValueError("TRPC_AGENT_API_KEY, TRPC_AGENT_BASE_URL, " + "and TRPC_AGENT_MODEL_NAME must be set in environment variables") + return api_key, url, model_name diff --git a/examples/mempalace_mcp/agent/prompts.py b/examples/mempalace_mcp/agent/prompts.py new file mode 100644 index 00000000..c2661959 --- /dev/null +++ b/examples/mempalace_mcp/agent/prompts.py @@ -0,0 +1,25 @@ +# Tencent is pleased to support the open source community by making tRPC-Agent-Python available. +# +# Copyright (C) 2026 Tencent. All rights reserved. +# +# tRPC-Agent-Python is licensed under Apache-2.0. +"""Prompts for the MemPalace MCP demo agent.""" + +INSTRUCTION = """ +You are a personal memory assistant powered by the MemPalace MCP server. + +MemPalace organizes memory as: + Palace -> Wing (person/project) -> Room (topic) -> Drawer (verbatim content). + +Whenever the user asks you to remember, file, or store something, prefer: + - `mempalace_add_drawer` for verbatim facts (with explicit `wing` and `room`). + - `mempalace_kg_add` for relational facts shaped as (subject, predicate, object). + +Whenever the user asks you to recall, retrieve, or check what you know, prefer: + - `mempalace_search` for free-form recall (scope with wing/room if obvious). + - `mempalace_kg_query` when the user asks about an entity's relationships. + - `mempalace_status` / `mempalace_list_wings` for overview questions. + +If a tool returns empty results, say so clearly. Do not invent memories. +Keep replies short and cite which MemPalace tool you used. +""".strip() diff --git a/examples/mempalace_mcp/agent/tools.py b/examples/mempalace_mcp/agent/tools.py new file mode 100644 index 00000000..c56f05c8 --- /dev/null +++ b/examples/mempalace_mcp/agent/tools.py @@ -0,0 +1,93 @@ +# Tencent is pleased to support the open source community by making tRPC-Agent-Python available. +# +# Copyright (C) 2026 Tencent. All rights reserved. +# +# tRPC-Agent-Python is licensed under Apache-2.0. +"""MCP toolset that connects to the official MemPalace MCP server over stdio. + +MemPalace ships an MCP server with ~30 tools (palace read/write, knowledge +graph, agent diary, navigation). The server is launched on demand by the +`MCPToolset` as a child process. + +Why `python -m mempalace.mcp_server` is the default +--------------------------------------------------- +In current MemPalace releases the CLI entry point that actually runs the +server is `mempalace-mcp` (with a hyphen). `mempalace mcp` (with a space) only +prints setup help — it does NOT start the server. To avoid relying on a +specific CLI name we launch the module directly with the same Python +interpreter, which works for every recent MemPalace version. +""" + +import os +import shutil +import sys + +from trpc_agent_sdk.tools import MCPToolset +from trpc_agent_sdk.tools import McpStdioServerParameters +from trpc_agent_sdk.tools import StdioConnectionParams + + +# A small, curated default. Set to `None` to expose every tool the MemPalace +# server advertises. Trim the list to reduce model token usage and to keep the +# demo focused on a few representative tools. +_DEFAULT_TOOL_FILTER = [ + "mempalace_status", + "mempalace_list_wings", + "mempalace_search", + "mempalace_add_drawer", + "mempalace_kg_add", + "mempalace_kg_query", + "mempalace_kg_timeline", + "mempalace_diary_write", + "mempalace_diary_read", +] + + +def _resolve_server_command(palace_path: str | None) -> tuple[str, list[str]]: + """Pick the best available command to launch the MemPalace MCP server. + + Priority: + 1. `python -m mempalace.mcp_server [--palace PATH]` (always works if the + `mempalace` package is importable from the current interpreter). + 2. `mempalace-mcp` CLI shim (fallback if it is on PATH). + """ + extra_args = ["--palace", palace_path] if palace_path else [] + + # Sanity check that the `mempalace` package is importable from the current + # interpreter. If not, we still try the CLI shim but warn early. + try: + import importlib.util # noqa: WPS433 + if importlib.util.find_spec("mempalace") is not None: + return sys.executable, ["-m", "mempalace.mcp_server", *extra_args] + except Exception: # pragma: no cover - defensive + pass + + if shutil.which("mempalace-mcp"): + return "mempalace-mcp", extra_args + + raise RuntimeError("Cannot find the MemPalace MCP server. Install MemPalace into the same Python environment " + "(`pip install -e \".[mempalace]\"`) so that `python -m mempalace.mcp_server` works.") + + +class MempalaceMCPToolset(MCPToolset): + """Stdio-based MCP toolset bound to the MemPalace MCP server.""" + + def __init__(self, palace_path: str | None = None, tool_filter: list[str] | None = _DEFAULT_TOOL_FILTER) -> None: + super().__init__() + + env = os.environ.copy() + if palace_path: + env["MEMPALACE_PALACE_PATH"] = palace_path + + command, args = _resolve_server_command(palace_path) + stdio_server_params = McpStdioServerParameters( + command=command, + args=args, + env=env, + ) + self._connection_params = StdioConnectionParams( + server_params=stdio_server_params, + timeout=30, # palace warm-up + embedding model load can take a few seconds + ) + if tool_filter is not None: + self._tool_filter = tool_filter diff --git a/examples/mempalace_mcp/run_agent.py b/examples/mempalace_mcp/run_agent.py new file mode 100644 index 00000000..320fcdca --- /dev/null +++ b/examples/mempalace_mcp/run_agent.py @@ -0,0 +1,107 @@ +#!/usr/bin/env python3 + +# Tencent is pleased to support the open source community by making tRPC-Agent-Python available. +# +# Copyright (C) 2026 Tencent. All rights reserved. +# +# tRPC-Agent-Python is licensed under Apache-2.0. +"""Run the MemPalace MCP demo. + +This example talks to the official MemPalace MCP server (`mempalace mcp`) +through tRPC-Agent's stdio MCP toolset. The demo exercises a handful of +representative tools: status, drawer write, semantic search, KG add/query, and +agent diary. +""" + +import asyncio +import sys +import uuid +from pathlib import Path + +from dotenv import load_dotenv +from trpc_agent_sdk.runners import Runner +from trpc_agent_sdk.sessions import InMemorySessionService +from trpc_agent_sdk.types import Content +from trpc_agent_sdk.types import Part + +load_dotenv() + +sys.path.append(str(Path(__file__).parent)) + + +def _truncate(text: str, max_length: int = 320) -> str: + text = str(text) + return text if len(text) <= max_length else text[:max_length] + "..." + + +async def run_mempalace_mcp_agent() -> None: + """Run the MemPalace MCP demo.""" + + from agent.agent import root_agent + + app_name = "mempalace_mcp_demo" + session_service = InMemorySessionService() + runner = Runner(app_name=app_name, agent=root_agent, session_service=session_service) + + user_id = "demo_user" + + # Each query is run in its own session so we can also confirm cross-session + # recall works through the MemPalace MCP server. + demo_queries = [ + # Overview + "Give me a one-line overview of my MemPalace.", + # Write a verbatim drawer + "Please remember this for me: I prefer working in the morning. " + "Store it under wing 'preferences', room 'work_habits'.", + # Search the drawer back + "What do you know about my work habits?", + # Knowledge graph: add a fact + "Record a fact: Alice prefers blue.", + # Knowledge graph: query the fact + "What facts do you know about Alice?", + # Agent diary: write + "Write a diary entry as agent 'mempalace_assistant': " + "Today I helped the user file two memories about their preferences.", + # Agent diary: read + "Read back the latest diary entries for agent 'mempalace_assistant'.", + ] + + try: + for query in demo_queries: + session_id = str(uuid.uuid4()) + await session_service.create_session( + app_name=app_name, + user_id=user_id, + session_id=session_id, + ) + + print(f"\n🆔 Session: {session_id[:8]}") + print(f"📝 User: {query}") + print("🤖 Assistant: ", end="", flush=True) + + user_content = Content(parts=[Part.from_text(text=query)]) + async for event in runner.run_async(user_id=user_id, session_id=session_id, new_message=user_content): + if not event.content or not event.content.parts: + continue + + if event.partial: + for part in event.content.parts: + if part.text: + print(part.text, end="", flush=True) + continue + + for part in event.content.parts: + if part.thought: + continue + if part.function_call: + print(f"\n🔧 [MCP tool: {part.function_call.name}({part.function_call.args})]") + elif part.function_response: + print(f"📊 [Result: {_truncate(part.function_response.response)}]") + + print("\n" + "-" * 60) + finally: + await runner.close() + + +if __name__ == "__main__": + asyncio.run(run_mempalace_mcp_agent()) diff --git a/tests/agents/core/test_tools_processor.py b/tests/agents/core/test_tools_processor.py index 5ebe5122..8dd9255f 100644 --- a/tests/agents/core/test_tools_processor.py +++ b/tests/agents/core/test_tools_processor.py @@ -28,7 +28,7 @@ def _compat_get_skill_processor_parameters(agent_context): from trpc_agent_sdk.events import Event, EventActions from trpc_agent_sdk.models import LLMModel, LlmRequest, LlmResponse, ModelRegistry from trpc_agent_sdk.sessions import InMemorySessionService -from trpc_agent_sdk.tools import BaseTool, FunctionTool +from trpc_agent_sdk.tools import BaseTool, FunctionTool, StreamingProgressTool from trpc_agent_sdk.types import Content, FunctionCall, Part @@ -274,6 +274,157 @@ def test_error_event_without_tool_info(self, invocation_context): # --------------------------------------------------------------------------- +# --------------------------------------------------------------------------- +# execute_tools_async - progress-streaming tool path +# --------------------------------------------------------------------------- + + +async def _crawl_stream(url: str): + """Async generator that yields N+1 progress events.""" + yield {"status": "started", "url": url} + yield {"status": "step", "i": 1} + yield {"status": "step", "i": 2} + yield {"status": "done", "url": url, "steps": 2} + + +class TestExecuteToolsStreamingProgress: + """execute_tools_async should surface partial progress events plus a final + function_response event for tools that set is_progress_streaming=True.""" + + def test_streaming_tool_yields_partials_then_final(self, invocation_context): + tool = StreamingProgressTool(_crawl_stream) + proc = ToolsProcessor([tool]) + fc = FunctionCall(id="call-1", name="_crawl_stream", args={"url": "https://x"}) + + async def run(): + events = [] + async for event in proc.execute_tools_async([fc], invocation_context): + events.append(event) + return events + + events = asyncio.run(run()) + + # 4 yields → 3 partials (the last yield is reserved for the final event) + # + 1 final function_response event = 4 events total. + assert len(events) == 4, f"expected 4 events, got {len(events)}" + + partials = events[:-1] + final = events[-1] + + for ev in partials: + assert ev.partial is True + meta = ev.custom_metadata or {} + assert meta.get("tool_progress") is True + assert meta.get("tool_name") == "_crawl_stream" + assert meta.get("tool_call_id") == "call-1" + assert ev.content is not None + # Text part contains JSON serialization of the payload. + assert ev.content.parts[0].text + + # First partial carries the 'started' payload, second carries i=1, ... + first_payload = (partials[0].custom_metadata or {}).get("payload") + assert first_payload == {"status": "started", "url": "https://x"} + + # Final event is a non-partial function_response carrying the LAST yield. + assert final.partial is not True + assert final.content is not None + fr = final.content.parts[0].function_response + assert fr is not None + assert fr.name == "_crawl_stream" + assert fr.id == "call-1" + assert fr.response == {"status": "done", "url": "https://x", "steps": 2} + + def test_streaming_tool_error_yields_error_event(self, invocation_context): + async def boom(query: str): + yield {"status": "started"} + raise RuntimeError("kaboom") + yield {"unreachable": True} # pragma: no cover # noqa: E501 + + tool = StreamingProgressTool(boom) + proc = ToolsProcessor([tool]) + fc = FunctionCall(id="call-err", name="boom", args={"query": "x"}) + + async def run(): + events = [] + async for event in proc.execute_tools_async([fc], invocation_context): + events.append(event) + return events + + events = asyncio.run(run()) + # We expect: the 'started' partial WAS NOT yielded yet (it's still + # the buffered value when boom() raises), but a tool_execution_error + # event SHOULD be produced. + assert any(ev.error_code == "tool_execution_error" for ev in events) + + def test_streaming_tool_runs_outside_parallel_batch(self, invocation_context): + # When the agent has parallel_tool_calls=True and the batch mixes a + # streaming and a non-streaming tool, the legacy parallel path must + # process the non-streaming tool exactly as before; the streaming + # tool is handled by a separate, sequential phase that surfaces + # partial progress events as well as a final function_response. + streaming = StreamingProgressTool(_crawl_stream) + non_streaming = FunctionTool(sample_tool) + proc = ToolsProcessor([streaming, non_streaming]) + + # _StubAgent is a pydantic model that doesn't declare + # parallel_tool_calls; bypass validation to flip the runtime flag. + object.__setattr__(invocation_context.agent, "parallel_tool_calls", True) + try: + fc_stream = FunctionCall(id="c-stream", name="_crawl_stream", args={"url": "https://a"}) + fc_normal = FunctionCall(id="c-normal", name="sample_tool", args={"name": "a", "value": "b"}) + + async def run(): + events = [] + async for event in proc.execute_tools_async([fc_stream, fc_normal], invocation_context): + events.append(event) + return events + + events = asyncio.run(run()) + + # The non-streaming call is the only thing in the parallel batch, + # so it surfaces as a single function_response event without + # interleaving partials. + non_streaming_finals = [ + ev for ev in events if ev.partial is not True and ev.content and any( + p.function_response and p.function_response.id == "c-normal" for p in ev.content.parts) + ] + assert len(non_streaming_finals) == 1 + + # The streaming call yields partials AND its own final event. + stream_partials = [ + ev for ev in events + if ev.partial and (ev.custom_metadata or {}).get("tool_call_id") == "c-stream" + ] + stream_finals = [ + ev for ev in events if ev.partial is not True and ev.content and any( + p.function_response and p.function_response.id == "c-stream" for p in ev.content.parts) + ] + assert stream_partials, "expected at least one partial progress event" + assert len(stream_finals) == 1 + finally: + object.__setattr__(invocation_context.agent, "parallel_tool_calls", False) + + def test_streaming_tool_not_found_yields_error_event(self, invocation_context): + # When the LLM names a streaming tool that doesn't exist, the + # streaming phase must still surface a tool_not_found error event + # rather than silently dropping the call. We can't easily fake the + # "name is a streaming tool but no resolution" case, so we just + # verify the standard not-found error still works when no streaming + # tools are registered. + proc = ToolsProcessor([]) + fc = FunctionCall(id="missing", name="ghost_streaming_tool", args={}) + + async def run(): + events = [] + async for ev in proc.execute_tools_async([fc], invocation_context): + events.append(ev) + return events + + events = asyncio.run(run()) + assert len(events) == 1 + assert events[0].error_code == "tool_not_found" + + class TestUpdateStreamingToolNames: def test_no_streaming_tools(self): proc = ToolsProcessor([]) diff --git a/tests/sessions/test_session_summarizer.py b/tests/sessions/test_session_summarizer.py index e8fc460d..97462cad 100644 --- a/tests/sessions/test_session_summarizer.py +++ b/tests/sessions/test_session_summarizer.py @@ -183,16 +183,23 @@ def test_basic_extraction(self): assert "user: What is AI?" in text assert "agent: AI is artificial intelligence." in text - def test_skip_summarization_events(self): + def test_skip_summarization_events_are_still_included(self): + # ``skip_summarization=True`` means "the agent loop should not call + # the LLM again to summarize this tool response" (a control-flow + # concern). It must NOT cause the *session* summarizer to drop the + # event from the summary input, because these events usually carry + # the actual user-visible final answer (e.g. AgentTool / + # StreamingProgressTool outputs). Dropping them would strip the + # most informative content from the resulting session summary. model = _make_model_mock() summarizer = SessionSummarizer(model=model) events = [ _make_event(author="user", text="Question"), - _make_event(author="agent", text="Skipped", skip_summarization=True), + _make_event(author="agent", text="FinalAnswerFromSubAgent", skip_summarization=True), _make_event(author="agent", text="Included"), ] text = summarizer._extract_conversation_text(events) - assert "Skipped" not in text + assert "FinalAnswerFromSubAgent" in text assert "Included" in text def test_empty_events(self): diff --git a/tests/tools/test_streaming_progress_tool.py b/tests/tools/test_streaming_progress_tool.py new file mode 100644 index 00000000..3b961dfb --- /dev/null +++ b/tests/tools/test_streaming_progress_tool.py @@ -0,0 +1,138 @@ +# Tencent is pleased to support the open source community by making tRPC-Agent-Python available. +# +# Copyright (C) 2026 Tencent. All rights reserved. +# +# tRPC-Agent-Python is licensed under Apache-2.0. + +from __future__ import annotations + +from typing import AsyncIterator +from unittest.mock import MagicMock + +import pytest + +from trpc_agent_sdk.context import InvocationContext +from trpc_agent_sdk.tools._function_tool import FunctionTool +from trpc_agent_sdk.tools._streaming_progress_tool import StreamingProgressTool + + +async def streaming_func(query: str) -> AsyncIterator[dict]: + """Stream a few progress updates.""" + yield {"status": "started", "query": query} + yield {"status": "step", "step": 1} + yield {"status": "step", "step": 2} + yield {"status": "done", "query": query, "steps": 2} + + +async def streaming_func_no_yield() -> AsyncIterator[dict]: + """Generator that never yields.""" + if False: # pragma: no cover + yield {} + + +async def streaming_func_str() -> AsyncIterator[str]: + """Stream a few string progress updates.""" + yield "starting" + yield "halfway" + yield "completed" + + +def regular_sync_func(x: int) -> int: + """Not a generator.""" + return x + + +async def regular_async_func(x: int) -> int: + """Not a generator.""" + return x + + +class TestStreamingProgressToolInit: + + def test_init_with_async_generator(self): + tool = StreamingProgressTool(streaming_func) + assert tool.name == "streaming_func" + assert tool.is_progress_streaming is True + assert tool.func is streaming_func + + def test_init_rejects_sync_function(self): + with pytest.raises(TypeError, match="async def.*generator"): + StreamingProgressTool(regular_sync_func) # type: ignore[arg-type] + + def test_init_rejects_plain_async_function(self): + with pytest.raises(TypeError, match="async def.*generator"): + StreamingProgressTool(regular_async_func) # type: ignore[arg-type] + + def test_is_progress_streaming_property(self): + tool = StreamingProgressTool(streaming_func) + assert tool.is_progress_streaming is True + + def test_base_function_tool_is_not_progress_streaming(self): + # Plain FunctionTool must not silently inherit the streaming flag. + ft = FunctionTool(regular_sync_func) + assert getattr(ft, "is_progress_streaming", False) is False + + +class TestStreamingProgressToolExecution: + + @pytest.mark.asyncio + async def test_run_streaming_yields_all_values(self): + tool = StreamingProgressTool(streaming_func) + ctx = MagicMock(spec=InvocationContext) + ctx.agent = MagicMock() + + out = [] + async for value in tool.run_streaming(tool_context=ctx, args={"query": "hi"}): + out.append(value) + assert out == [ + {"status": "started", "query": "hi"}, + {"status": "step", "step": 1}, + {"status": "step", "step": 2}, + {"status": "done", "query": "hi", "steps": 2}, + ] + + @pytest.mark.asyncio + async def test_run_streaming_with_string_payloads(self): + tool = StreamingProgressTool(streaming_func_str) + ctx = MagicMock(spec=InvocationContext) + ctx.agent = MagicMock() + + out = [] + async for value in tool.run_streaming(tool_context=ctx, args={}): + out.append(value) + assert out == ["starting", "halfway", "completed"] + + @pytest.mark.asyncio + async def test_run_streaming_missing_mandatory_arg(self): + tool = StreamingProgressTool(streaming_func) + ctx = MagicMock(spec=InvocationContext) + ctx.agent = MagicMock() + + out = [] + async for value in tool.run_streaming(tool_context=ctx, args={}): + out.append(value) + # Exactly one error payload, no exception bubbled up. + assert len(out) == 1 + assert "error" in out[0] + assert "missing" in out[0]["error"].lower() + + @pytest.mark.asyncio + async def test_run_async_impl_refuses_direct_invocation(self): + # Single-responsibility: streaming tools must NOT be drainable via + # the synchronous tool path. The only entry point is run_streaming(), + # which ToolsProcessor.execute_tools_async calls. + tool = StreamingProgressTool(streaming_func) + ctx = MagicMock(spec=InvocationContext) + ctx.agent = MagicMock() + + with pytest.raises(RuntimeError, match="does not support direct"): + await tool._run_async_impl(tool_context=ctx, args={"query": "hi"}) + + +class TestStreamingProgressToolDeclaration: + + def test_get_declaration_includes_function_name(self): + tool = StreamingProgressTool(streaming_func) + decl = tool._get_declaration() + assert decl is not None + assert decl.name == "streaming_func" diff --git a/trpc_agent_sdk/agents/_llm_agent.py b/trpc_agent_sdk/agents/_llm_agent.py index e5713d4c..3546c1d6 100644 --- a/trpc_agent_sdk/agents/_llm_agent.py +++ b/trpc_agent_sdk/agents/_llm_agent.py @@ -570,8 +570,11 @@ def accumulate_content(event: Event) -> None: # Execute tools and yield results (Runner will store them automatically) last_tool_event = None + any_skip_summarization = False async for tool_event in extended_tools_processor.execute_tools_async(collected_tool_calls, ctx): last_tool_event = tool_event + if tool_event.actions and tool_event.actions.skip_summarization: + any_skip_summarization = True # Check if this event contains responses from long-running tools if tool_event.content and tool_event.content.parts: @@ -637,6 +640,17 @@ def accumulate_content(event: Event) -> None: logger.debug("disable_react_tool set, exiting after tool execution for external control") return + # Honor skip_summarization on tool responses: when any tool + # in this batch declares "the tool output IS the final + # answer, do not ask the LLM to summarize it" (e.g. + # `AgentTool(skip_summarization=True)` or + # `StreamingProgressTool(skip_summarization=True)`), + # end this agent without another LLM follow-up call. + # See `EventActions.skip_summarization` docstring. + if any_skip_summarization: + logger.debug("Tool returned skip_summarization=True, exiting without LLM follow-up") + return + # Continue the multi-turn loop for next LLM call with tool results in history logger.debug("Tool execution completed, continuing conversation") continue diff --git a/trpc_agent_sdk/agents/core/_tools_processor.py b/trpc_agent_sdk/agents/core/_tools_processor.py index 29307ef2..348df782 100644 --- a/trpc_agent_sdk/agents/core/_tools_processor.py +++ b/trpc_agent_sdk/agents/core/_tools_processor.py @@ -16,6 +16,7 @@ import asyncio import json import time +from typing import Any from typing import AsyncGenerator from typing import List from typing import Optional @@ -193,54 +194,100 @@ async def execute_tools_async( logger.debug("Starting execution of %s tool calls", len(tool_calls)) + # Split the batch by execution model. Progress-streaming tools are + # **never** mixed into the legacy parallel/sequential path: they have + # a different control flow (one tool call -> many events) that does + # not compose with the "1 call -> 1 event, then merge" parallel + # design. The non-streaming bucket is fed through the legacy path + # *verbatim* so that we do not regress any existing behavior. + streaming_calls, non_streaming_calls = self._split_calls_by_streaming(tool_calls, resolved_tools) + # Capture state before tool execution state_begin = dict(context.session.state) - parallel_tool_calls: bool = getattr(context.agent, "parallel_tool_calls", False) + # ---- Phase 1: legacy path for non-streaming tools (unchanged) ---- + if non_streaming_calls: + parallel_tool_calls: bool = getattr(context.agent, "parallel_tool_calls", False) + if parallel_tool_calls: + # Parallel execution: collect all events and merge them + function_response_events: list[Event] = [] + async with asyncio.TaskGroup() as tg: + for tool_call in non_streaming_calls: + tg.create_task(self.__invoke_tools(context, resolved_tools, tool_call, + function_response_events)) + + # Handle merging and tracing based on number of events + if function_response_events: + if len(function_response_events) == 1: + yield function_response_events[0] + else: + merged_event = self._merge_parallel_function_response_events(function_response_events) + state_end = dict(context.session.state) + if merged_event.actions and merged_event.actions.state_delta: + state_end.update(merged_event.actions.state_delta) + with tracer.start_as_current_span( + "execute_tool (merged)", + attributes={"gen_ai.operation.name": "execute_tool"}, + ): + trace_merged_tool_calls( + response_event_id=merged_event.id, + function_response_event=merged_event, + state_begin=state_begin, + state_end=state_end, + ) + yield merged_event + else: + # Sequential execution: yield each event immediately after execution + for tool_call in non_streaming_calls: + function_response_events: list[Event] = [] + result_event = await self.__invoke_tools(context, resolved_tools, tool_call, + function_response_events) + if result_event: + yield result_event + + # ---- Phase 2: uniform streaming path for progress-streaming tools ---- + # Streaming tools are always executed **sequentially among themselves**. + # Interleaving their partials would force the consumer to demux events + # by tool_call_id; we deliberately keep ordering deterministic instead. + # See StreamingProgressTool docstring for the per-tool contract. + for tool_call in streaming_calls: + tool = await self._find_tool(tool_call, resolved_tools) + if tool is None: + yield self._create_error_event( + context, + "tool_not_found", + f"Tool '{tool_call.name}' not found", + tool_call.id, + tool_call.name, + ) + continue + async for ev in self._execute_progress_streaming_tool(tool_call, tool, context): + yield ev - if parallel_tool_calls: - # Parallel execution: collect all events and merge them - function_response_events: list[Event] = [] - async with asyncio.TaskGroup() as tg: - for tool_call in tool_calls: - tg.create_task(self.__invoke_tools(context, resolved_tools, tool_call, function_response_events)) + @staticmethod + def _split_calls_by_streaming( + tool_calls: List[FunctionCall], + resolved_tools: List[BaseTool], + ) -> tuple[List[FunctionCall], List[FunctionCall]]: + """Partition ``tool_calls`` into ``(streaming, non_streaming)`` lists. - # Handle merging and tracing based on number of events - if not function_response_events: - return + Calls whose target tool cannot be resolved (e.g. typo from the LLM) + are placed in the **non_streaming** bucket so that the legacy path + keeps producing the canonical ``tool_not_found`` error event. - if len(function_response_events) == 1: - # Single tool call - yield the event directly - yield function_response_events[0] + Relative order within each list is preserved so downstream tracing + stays predictable. + """ + by_name = {t.name: t for t in resolved_tools if isinstance(t, BaseTool)} + streaming: List[FunctionCall] = [] + non_streaming: List[FunctionCall] = [] + for tc in tool_calls: + tool = by_name.get(tc.name) + if tool is not None and tool.is_progress_streaming: + streaming.append(tc) else: - # Multiple tool calls - merge them and add merged tracing - merged_event = self._merge_parallel_function_response_events(function_response_events) - - # Compute state after merged tool execution - state_end = dict(context.session.state) - if merged_event.actions and merged_event.actions.state_delta: - state_end.update(merged_event.actions.state_delta) - - # Add merged tool call tracing - with tracer.start_as_current_span( - "execute_tool (merged)", - attributes={"gen_ai.operation.name": "execute_tool"}, - ): - trace_merged_tool_calls( - response_event_id=merged_event.id, - function_response_event=merged_event, - state_begin=state_begin, - state_end=state_end, - ) - - yield merged_event - else: - # Sequential execution: yield each event immediately after execution - for tool_call in tool_calls: - function_response_events: list[Event] = [] - result_event = await self.__invoke_tools(context, resolved_tools, tool_call, function_response_events) - if result_event: - yield result_event + non_streaming.append(tc) + return streaming, non_streaming async def find_tool(self, context: InvocationContext, tool_call: FunctionCall) -> Optional[BaseTool]: """Find the appropriate tool for a tool call. @@ -403,6 +450,202 @@ async def _execute_tool(self, tool_call: FunctionCall, tool: BaseTool, context: return error_event + async def _execute_progress_streaming_tool( + self, + tool_call: FunctionCall, + tool: BaseTool, + context: InvocationContext, + ) -> AsyncGenerator[Event, None]: + """Execute a progress-streaming tool, surfacing every yield as a partial event. + + Contract with :class:`StreamingProgressTool`: + + - Every value yielded by the tool's async generator becomes a + ``partial=True`` Event with ``custom_metadata.tool_progress=True``. + These events are *not* persisted into session history and are *not* + fed back to the LLM as tool responses. + - The **last** yielded value is additionally used to build the final + function_response event (``partial=False``, with a real + ``function_response`` Part) that closes this tool call. + + Args: + tool_call: The LLM-issued FunctionCall to execute. + tool: The resolved StreamingProgressTool instance. + context: The invocation context. + + Yields: + Event: zero or more partial progress events, followed by exactly + one final function_response event (or an error event). + """ + with tracer.start_as_current_span( + f"execute_tool {tool.name} (streaming)", + attributes={ + "gen_ai.operation.name": "execute_tool", + "gen_ai.tool.name": tool.name, + "gen_ai.tool.description": tool.description or "", + }, + ): + state_begin = dict(context.session.state) + + if isinstance(tool_call.args, str): + arguments = json.loads(tool_call.args) + else: + arguments = tool_call.args or {} + + context.function_call_id = tool_call.id + start_time = time.monotonic() + + run_streaming = getattr(tool, "run_streaming", None) + if run_streaming is None: + # Defensive: a tool advertising is_progress_streaming=True + # without run_streaming() is broken. Fall back to non-streaming. + logger.warning( + "Tool %s sets is_progress_streaming=True but exposes no run_streaming(); " + "falling back to non-streaming execution.", + tool.name, + ) + final_event = await self._execute_tool(tool_call, tool, context) + yield final_event + return + + last_value: Any = None + progress_count = 0 + skip_summarization = bool(getattr(tool, "skip_summarization", False)) + + try: + # Drain the streaming generator. Buffer the previously-seen + # value and emit it as a partial event only when a *next* + # value arrives, so that the last value is reserved for the + # final function_response event. + async for value in run_streaming(tool_context=context, args=arguments): + if last_value is not None: + yield self._build_progress_event(context, tool_call, tool, last_value) + progress_count += 1 + last_value = value + + execution_time = time.monotonic() - start_time + report_execute_tool( + context, + tool, + duration_s=execution_time, + error_type=None, + ) + + final_result = last_value if last_value is not None else {} + if not isinstance(final_result, dict): + final_result = {"result": final_result} + + part_function_response = Part.from_function_response(name=tool_call.name, response=final_result) + part_function_response.function_response.id = tool_call.id + + final_event = Event( + invocation_id=context.invocation_id, + author=context.agent.name, + content=Content(role="user", parts=[part_function_response]), + custom_metadata={ + "execution_time": execution_time, + "progress_events": progress_count, + }, + branch=context.branch, + ) + + if context.state.has_delta(): + final_event.actions.state_delta.update(context.state._delta) # pylint: disable=protected-access + if context.event_actions.skip_summarization or skip_summarization: + # Either the tool declared the streamed output as the final + # answer at construction time, or it asked for it via the + # event_actions context bag during execution. + final_event.actions.skip_summarization = True + if context.event_actions.transfer_to_agent: + final_event.actions.transfer_to_agent = context.event_actions.transfer_to_agent + if context.event_actions.artifact_delta: + final_event.actions.artifact_delta.update(context.event_actions.artifact_delta) + + state_end = dict(context.session.state) + if final_event.actions and final_event.actions.state_delta: + state_end.update(final_event.actions.state_delta) + + trace_tool_call( + tool=tool, + args=arguments, + function_response_event=final_event, + state_begin=state_begin, + state_end=state_end, + ) + + yield final_event + + except Exception as ex: # pylint: disable=broad-except + report_execute_tool( + context, + tool, + duration_s=time.monotonic() - start_time, + error_type=type(ex).__name__, + ) + error_event = self._create_error_event( + context, + "tool_execution_error", + str(ex), + tool_call.id, + tool_call.name, + ) + state_end = dict(context.session.state) + trace_tool_call( + tool=tool, + args=arguments, + function_response_event=error_event, + state_begin=state_begin, + state_end=state_end, + ) + logger.error("Error executing streaming tool %s: %s", tool_call.name, ex, exc_info=True) + yield error_event + + @staticmethod + def _build_progress_event( + context: InvocationContext, + tool_call: FunctionCall, + tool: BaseTool, + value: Any, + ) -> Event: + """Wrap a single value yielded by a streaming tool into a partial Event. + + Rules: + - ``str`` → rendered as a text Part directly. + - ``dict`` / anything else → rendered as JSON text Part; the raw + value is also attached under ``custom_metadata['payload']`` so + structured consumers can read it without re-parsing. + - The event is marked ``partial=True`` so session services skip + persisting it and the LLM never sees it as a tool response. + - ``custom_metadata`` carries ``tool_progress=True``, ``tool_name``, + ``tool_call_id`` to make filtering on the consumer side trivial. + """ + if isinstance(value, str): + text = value + payload: Optional[Any] = None + else: + try: + text = json.dumps(value, ensure_ascii=False, default=str) + except (TypeError, ValueError): + text = str(value) + payload = value + + custom_metadata = { + "tool_progress": True, + "tool_name": tool.name, + "tool_call_id": tool_call.id, + } + if payload is not None: + custom_metadata["payload"] = payload + + return Event( + invocation_id=context.invocation_id, + author=context.agent.name, + content=Content(role="model", parts=[Part(text=text)]), + partial=True, + branch=context.branch, + custom_metadata=custom_metadata, + ) + def _merge_parallel_function_response_events(self, function_response_events: List[Event]) -> Event: """Merge multiple function response events into a single event. diff --git a/trpc_agent_sdk/sessions/_session_summarizer.py b/trpc_agent_sdk/sessions/_session_summarizer.py index 5f1dacfd..6251beec 100644 --- a/trpc_agent_sdk/sessions/_session_summarizer.py +++ b/trpc_agent_sdk/sessions/_session_summarizer.py @@ -253,10 +253,6 @@ def _extract_conversation_text(self, events: List[Event]) -> str: if not event.content or not event.content.parts: continue - # Skip events that should not be included in summary - if event.actions and event.actions.skip_summarization: - continue - # Extract text、tool_call、tool_response from event parts event_text = "" for part in event.content.parts: diff --git a/trpc_agent_sdk/tools/__init__.py b/trpc_agent_sdk/tools/__init__.py index 4412110d..1ced02e8 100644 --- a/trpc_agent_sdk/tools/__init__.py +++ b/trpc_agent_sdk/tools/__init__.py @@ -33,6 +33,7 @@ from ._registry import register_tool_set from ._set_model_response_tool import SetModelResponseTool from ._streaming_function_tool import StreamingFunctionTool +from ._streaming_progress_tool import StreamingProgressTool from ._tool_adapter import convert_toolunion_to_tool_list from ._tool_adapter import create_tool from ._tool_adapter import create_toolset @@ -91,6 +92,7 @@ "register_tool_set", "SetModelResponseTool", "StreamingFunctionTool", + "StreamingProgressTool", "convert_toolunion_to_tool_list", "create_tool", "create_toolset", diff --git a/trpc_agent_sdk/tools/_base_tool.py b/trpc_agent_sdk/tools/_base_tool.py index 311922cb..6fa91a40 100644 --- a/trpc_agent_sdk/tools/_base_tool.py +++ b/trpc_agent_sdk/tools/_base_tool.py @@ -112,6 +112,27 @@ def is_streaming(self) -> bool: """ return False + @property + def is_progress_streaming(self) -> bool: + """Whether this tool streams **its own execution progress** as events. + + When True, the framework drives this tool via a dedicated streaming + path that emits one ``partial=True`` Event per ``yield`` from the + tool, plus a final ``function_response`` event built from the last + yielded value. Such tools are also pulled out of the normal + sequential / parallel execution batches and handled uniformly. + + Subclasses that produce progress events (e.g. + :class:`StreamingProgressTool`) override this to return True. + Filter / batch logic should use this flag to skip or specially + handle progress-streaming tools. + + Returns: + bool: True if the tool emits intermediate progress events, + False otherwise. + """ + return False + @property def api_variant(self) -> str: """Get API variant.""" diff --git a/trpc_agent_sdk/tools/_streaming_progress_tool.py b/trpc_agent_sdk/tools/_streaming_progress_tool.py new file mode 100644 index 00000000..b0439be6 --- /dev/null +++ b/trpc_agent_sdk/tools/_streaming_progress_tool.py @@ -0,0 +1,213 @@ +# Tencent is pleased to support the open source community by making tRPC-Agent-Python available. +# +# Copyright (C) 2026 Tencent. All rights reserved. +# +# tRPC-Agent-Python is licensed under Apache-2.0. +"""Streaming-progress function tool. + +A `StreamingProgressTool` lets a long-running tool push intermediate progress +events to the user **while** it is still executing, in the same way that the +LLM streams text. This is **different** from the two existing streaming-ish +tools shipped in this package: + +| Class | What gets streamed | +| --------------------------- | --------------------------------------------------- | +| ``StreamingFunctionTool`` | The *arguments* the LLM is generating for the call. | +| ``LongRunningFunctionTool`` | Just marks the call as long-running; one final | +| | result, no intermediate events. | +| ``StreamingProgressTool`` | The tool's *own* execution progress (this file). | + +Usage +----- + +The wrapped function must be an ``async def`` generator (i.e. uses ``yield``). +Each yielded value becomes a partial ``Event`` surfaced to the caller in real +time. The **last** yielded value is *also* used as the final +``function_response`` returned to the LLM: + +.. code-block:: python + + import asyncio + from typing import AsyncIterator + + from trpc_agent_sdk.tools import StreamingProgressTool + + + async def crawl_site(url: str) -> AsyncIterator[dict]: + '''Crawl a website and report progress.''' + yield {"status": "started", "url": url} + total = 5 + for i in range(total): + await asyncio.sleep(1) + yield {"status": "fetching", "page": i + 1, "total": total} + # Last yield = both the final progress event AND the function_response + # that is fed back to the LLM. + yield {"status": "done", "url": url, "pages": total} + + + tool = StreamingProgressTool(crawl_site) + agent = LlmAgent(name="crawler", model=model, tools=[tool]) + +Consuming the partial progress events from the runner side looks like: + +.. code-block:: python + + async for event in runner.run_async(...): + if event.partial and event.custom_metadata.get("tool_progress"): + print("[progress]", event.custom_metadata["tool_name"], + event.get_text() or event.custom_metadata.get("payload")) + +The yielded value can be: + +- ``dict``: surfaced verbatim under ``event.custom_metadata['payload']`` and + also serialised as JSON text on a ``Part`` so plain text consumers see it. +- ``str``: surfaced as a regular text ``Part``. +- ``BaseModel``: ``.model_dump()`` is used to coerce to ``dict``. + +Constraints +----------- + +- ``parallel_tool_calls=True`` is *not* recommended together with progress + streaming. The framework will fall back to sequential execution when at + least one progress-streaming tool is invoked in a batch, otherwise + intermediate events from concurrent tools would interleave unpredictably. +- The wrapped function MUST be an async generator. A regular ``async def`` + that returns a value will raise ``TypeError`` at construction time, with a + hint to use ``LongRunningFunctionTool`` or ``FunctionTool`` instead. +""" + +from __future__ import annotations + +import inspect +from typing import Any +from typing import AsyncIterator +from typing import Callable +from typing import Dict +from typing import Optional +from typing_extensions import override + +from pydantic import BaseModel + +from trpc_agent_sdk.context import InvocationContext +from trpc_agent_sdk.filter import BaseFilter + +from ._constants import TOOL_CONTEXT +from ._function_tool import FunctionTool +from .utils import convert_pydantic_args +from .utils import get_mandatory_args + + +class StreamingProgressTool(FunctionTool): + """A function tool that yields intermediate progress events. + + See module docstring for usage and rationale. + """ + + def __init__( + self, + func: Callable[..., AsyncIterator[Any]], + filters_name: Optional[list[str]] = None, + filters: Optional[list[BaseFilter]] = None, + *, + skip_summarization: bool = False, + ): + """Wrap ``func`` (an async generator) into a streaming-progress tool. + + Args: + func: The wrapped ``async def`` generator function. Each ``yield`` + becomes one streaming event; the last yield is also used as + the final ``function_response``. + filters_name: Optional filter names (forwarded to FunctionTool). + filters: Optional filter instances (forwarded to FunctionTool). + skip_summarization: If True, the framework treats the tool's + streamed output as the **final** user-facing answer and + stops the agent loop after this tool finishes; no LLM + follow-up call is made. Use this when the user has already + consumed the streaming output and an LLM summary would just + be redundant. Implemented by setting + ``event.actions.skip_summarization=True`` on the final + ``function_response`` event, which + :meth:`LlmAgent._run_async_impl` checks to terminate the + conversation loop early. + """ + if not inspect.isasyncgenfunction(func): + raise TypeError("StreamingProgressTool requires an `async def` *generator* function " + f"(one that uses `yield`). Got: {type(func).__name__}. " + "If your tool only returns a single result, use `FunctionTool` (fast) " + "or `LongRunningFunctionTool` (for long but non-streaming work) instead.") + super().__init__(func, filters_name=filters_name, filters=filters) + self._skip_summarization = bool(skip_summarization) + + @property + def is_progress_streaming(self) -> bool: + """Marks this tool as one that yields progress events during execution.""" + return True + + @property + def skip_summarization(self) -> bool: + """Whether the final tool event should set ``skip_summarization=True``. + + When True the LlmAgent loop exits after this tool returns, without + calling the LLM to summarize the streamed output. + """ + return self._skip_summarization + + @override + async def _run_async_impl(self, *, tool_context: InvocationContext, args: Dict[str, Any]) -> Any: + """Refuse direct invocation. + + Progress-streaming tools must be driven through + ``ToolsProcessor`` (which calls :meth:`run_streaming`). Allowing + ``_run_async_impl`` to silently drain the generator would violate + single-responsibility: this class would have two ways to be executed + with subtly different semantics (no partial events surfaced, no + ``function_response`` event built, callers thinking they got a + "normal" tool result). + + If you need a single-shot tool, wrap the function with + :class:`FunctionTool` or :class:`LongRunningFunctionTool` instead. + """ + raise RuntimeError(f"{type(self).__name__} (`{self.name}`) does not support direct " + "invocation via `run_async` / `_run_async_impl`. It must be " + "executed through `ToolsProcessor.execute_tools_async`, which " + "drives it via `run_streaming` and surfaces interim progress " + "events. If you only need a one-shot result, use `FunctionTool` " + "or `LongRunningFunctionTool` instead.") + + async def run_streaming( + self, + *, + tool_context: InvocationContext, + args: Dict[str, Any], + ) -> AsyncIterator[Any]: + """Yield progress values produced by the wrapped async generator. + + The framework wraps each yielded value into a partial ``Event`` and + surfaces it through ``Runner.run_async``. The *last* yielded value is + additionally used as the final ``function_response`` part fed back to + the LLM. + + Mandatory-argument validation, ``tool_context`` injection and + pydantic-arg coercion mirror :class:`FunctionTool`. + """ + args_to_call = args.copy() + signature = inspect.signature(self.func) + + if TOOL_CONTEXT in signature.parameters: + args_to_call[TOOL_CONTEXT] = tool_context + + args_to_call = convert_pydantic_args(args_to_call, signature) + + mandatory_args = get_mandatory_args(self.func) + missing = [arg for arg in mandatory_args if arg not in args_to_call] + if missing: + yield { + "error": (f"Invoking `{self.name}()` failed: missing mandatory input parameters: " + f"{', '.join(missing)}. Please call again with all required arguments.") + } + return + + async for value in self.func(**args_to_call): + if isinstance(value, BaseModel): + value = value.model_dump() + yield value