diff --git a/aieng-eval-agents/aieng/agent_evals/knowledge_qa/agent.py b/aieng-eval-agents/aieng/agent_evals/knowledge_qa/agent.py index 43b0ebc..0c02325 100644 --- a/aieng-eval-agents/aieng/agent_evals/knowledge_qa/agent.py +++ b/aieng-eval-agents/aieng/agent_evals/knowledge_qa/agent.py @@ -1,249 +1,178 @@ """Knowledge-grounded QA agent using Google ADK with Google Search. -This module provides a proper ReAct agent that explicitly calls -Google Search and shows the reasoning process through observable tool calls. +This module provides a ReAct agent with built-in planning via Gemini's thinking +mode that explicitly calls tools and shows the reasoning process through observable +tool calls. """ import asyncio import logging +import time import uuid +import warnings from typing import Any from aieng.agent_evals.configs import Configs from aieng.agent_evals.tools import ( - GroundedResponse, GroundingChunk, + create_fetch_file_tool, create_google_search_tool, + create_grep_file_tool, + create_read_file_tool, + create_web_fetch_tool, ) from google.adk.agents import Agent +from google.adk.agents.context_cache_config import ContextCacheConfig +from google.adk.apps.app import App, EventsCompactionConfig +from google.adk.apps.llm_event_summarizer import LlmEventSummarizer +from google.adk.models import Gemini +from google.adk.planners import PlanReActPlanner from google.adk.runners import Runner from google.adk.sessions import InMemorySessionService from google.genai import types +from google.genai.errors import ClientError +from pydantic import BaseModel, Field +from tenacity import ( + RetryError, + before_sleep_log, + retry, + retry_if_exception, + stop_after_attempt, + wait_exponential_jitter, +) - -logger = logging.getLogger(__name__) - - -def _extract_tool_calls(event: Any) -> list[dict[str, Any]]: - """Extract tool calls from event function calls. - - Parameters - ---------- - event : Any - An event from the ADK runner. 
- - Returns - ------- - list[dict[str, Any]] - List of tool call dictionaries with 'name' and 'args' keys. - """ - if not hasattr(event, "get_function_calls"): - return [] - function_calls = event.get_function_calls() - if not function_calls: - return [] - - tool_calls = [] - for fc in function_calls: - tool_call_info = { - "name": getattr(fc, "name", "unknown"), - "args": getattr(fc, "args", {}), - } - tool_calls.append(tool_call_info) - logger.info(f"Tool call: {tool_call_info['name']}({tool_call_info['args']})") - return tool_calls +from .event_extraction import ( + extract_event_text, + extract_final_response, + extract_grounding_queries, + extract_grounding_sources, + extract_search_queries_from_tool_calls, + extract_sources_from_responses, + extract_thoughts_from_event, + extract_tool_calls, + resolve_source_urls, +) +from .plan_parsing import ( + PLANNING_TAG, + REPLANNING_TAG, + ResearchPlan, + StepStatus, + extract_final_answer_text, + extract_plan_text, + extract_reasoning_text, + parse_plan_steps_from_text, +) +from .retry import ( + API_RETRY_INITIAL_WAIT, + API_RETRY_JITTER, + API_RETRY_MAX_ATTEMPTS, + API_RETRY_MAX_WAIT, + MAX_EMPTY_RESPONSE_RETRIES, + is_context_overflow_error, + is_retryable_api_error, +) +from .system_instructions import build_system_instructions +from .token_tracker import TokenTracker -def _extract_search_queries_from_tool_calls(tool_calls: list[dict[str, Any]]) -> list[str]: - """Extract search queries from tool calls. +# Suppress experimental warnings from ADK +warnings.filterwarnings("ignore", message=r".*EXPERIMENTAL.*ContextCacheConfig.*") +warnings.filterwarnings("ignore", message=r".*EXPERIMENTAL.*EventsCompactionConfig.*") - Parameters - ---------- - tool_calls : list[dict[str, Any]] - List of tool call dictionaries. +logger = logging.getLogger(__name__) - Returns - ------- - list[str] - Search queries found in the tool calls. 
- """ - queries = [] - for tool_call in tool_calls: - tool_name = str(tool_call.get("name", "")) - tool_args = tool_call.get("args", {}) - if "search" in tool_name.lower() and isinstance(tool_args, dict): - query = tool_args.get("query", "") - if query: - queries.append(query) - return queries +class StepExecution(BaseModel): + """Record of executing a single research step. -def _extract_sources_from_responses(event: Any) -> list[GroundingChunk]: - """Extract sources from event function responses. + This model captures the execution trace for evaluation purposes. - Parameters + Attributes ---------- - event : Any - An event from the ADK runner. - - Returns - ------- - list[GroundingChunk] - Sources extracted from the function responses. + step_id : int + The step ID that was executed. + tool_used : str + The actual tool that was used. + input_query : str + The query or input provided to the tool. + output_summary : str + Summary of what the step produced. + sources_found : int + Number of sources discovered in this step. + duration_ms : int + Execution time in milliseconds. + raw_output : str + Raw output from the tool for debugging. 
""" - if not hasattr(event, "get_function_responses"): - return [] - function_responses = event.get_function_responses() - if not function_responses: - return [] - - sources = [] - for fr in function_responses: - response_data = getattr(fr, "response", {}) - if not isinstance(response_data, dict): - continue - # Extract sources from search tool response - for src in response_data.get("sources", []): - if isinstance(src, dict): - sources.append( - GroundingChunk( - title=src.get("title", ""), - uri=src.get("uri") or src.get("url") or "", - ) - ) - # Extract grounding_chunks if present - for chunk in response_data.get("grounding_chunks", []): - if isinstance(chunk, dict) and "web" in chunk: - sources.append( - GroundingChunk( - title=chunk["web"].get("title", ""), - uri=chunk["web"].get("uri", ""), - ) - ) - return sources + step_id: int + tool_used: str + input_query: str + output_summary: str = "" + sources_found: int = 0 + duration_ms: int = 0 + raw_output: str = "" -def _extract_grounding_sources(event: Any) -> list[GroundingChunk]: - """Extract sources from grounding metadata. - - Parameters - ---------- - event : Any - An event from the ADK runner. - - Returns - ------- - list[GroundingChunk] - Sources extracted from the grounding metadata. - """ - gm = getattr(event, "grounding_metadata", None) - if not gm and hasattr(event, "content") and event.content: - gm = getattr(event.content, "grounding_metadata", None) - if not gm: - return [] - - sources = [] - if hasattr(gm, "grounding_chunks") and gm.grounding_chunks: - for chunk in gm.grounding_chunks: - if hasattr(chunk, "web") and chunk.web: - sources.append( - GroundingChunk( - title=getattr(chunk.web, "title", "") or "", - uri=getattr(chunk.web, "uri", "") or "", - ) - ) - return sources +class AgentResponse(BaseModel): + """Response from the knowledge agent with execution trace. -def _extract_grounding_queries(event: Any) -> list[str]: - """Extract search queries from grounding metadata. 
+ Contains the answer text along with metadata about how the agent + arrived at the answer: the research plan, tool calls, sources, and reasoning. - Parameters + Attributes ---------- - event : Any - An event from the ADK runner. - - Returns - ------- - list[str] - Search queries from the grounding metadata. + text : str + The generated response text. + plan : ResearchPlan + The research plan created for the question. + execution_trace : list[StepExecution] + Record of each step's execution. + sources : list[GroundingChunk] + Web sources used in the response. + search_queries : list[str] + Search queries executed. + reasoning_chain : list[str] + Step-by-step reasoning trace. + tool_calls : list[dict] + Raw tool calls made during execution. + total_duration_ms : int + Total execution time in milliseconds. """ - gm = getattr(event, "grounding_metadata", None) - if not gm and hasattr(event, "content") and event.content: - gm = getattr(event.content, "grounding_metadata", None) - if not gm: - return [] - - queries = [] - if hasattr(gm, "web_search_queries") and gm.web_search_queries: - for q in gm.web_search_queries: - if q: - queries.append(q) - return queries - - -def _extract_final_response(event: Any) -> str | None: - """Extract final response text from event if it's a final response.""" - if not hasattr(event, "is_final_response") or not event.is_final_response(): - return None - if not hasattr(event, "content") or not event.content: - return None - if not hasattr(event.content, "parts") or not event.content.parts: - return None - return event.content.parts[0].text or "" - -SYSTEM_INSTRUCTIONS = """\ -You are a knowledge-grounded research assistant. Your role is to answer -questions accurately by searching the web for relevant information. - -## How to Answer Questions - -1. **Search First**: Always search the web before answering factual questions - that require current information. Do not rely solely on your training data. - -2. 
**Be Thorough**: For complex questions, search multiple times to gather - all relevant facts before synthesizing your answer. - -3. **Cite Sources**: Always mention which sources you used to answer the question. - -4. **Be Honest**: If you cannot find relevant information, say so clearly. - -5. **Synthesize Information**: When answering complex questions, synthesize - findings from multiple sources into a coherent response. - -## Response Format - -When answering questions: -- Provide a clear, direct answer first -- Include relevant context and details from your sources -- List the sources used at the end of your response -""" + text: str + plan: ResearchPlan + execution_trace: list[StepExecution] = Field(default_factory=list) + sources: list[GroundingChunk] = Field(default_factory=list) + search_queries: list[str] = Field(default_factory=list) + reasoning_chain: list[str] = Field(default_factory=list) + tool_calls: list[dict[str, Any]] = Field(default_factory=list) + total_duration_ms: int = 0 class KnowledgeGroundedAgent: - """A ReAct agent for knowledge-grounded QA using Google Search. + """A ReAct agent with built-in planning via Gemini's thinking mode. - This agent uses Google ADK with explicit Google Search tool calls, - making the reasoning process observable and traceable. + This agent uses Google ADK with BuiltInPlanner to enable Gemini's native + thinking capabilities, which plan and execute research in a unified loop. Parameters ---------- config : Configs, optional Configuration settings. If not provided, creates default config. model : str, optional - The model to use. If not provided, uses config.default_worker_model. - - Attributes - ---------- - config : Configs - The configuration settings. + The model to use for answering. If not provided, uses + config.default_worker_model. + enable_planning : bool, default True + Whether to enable the built-in planner (Gemini thinking mode). 
+ thinking_budget : int, default 8192 + Token budget for the model's thinking/planning phase. Examples -------- >>> from aieng.agent_evals.knowledge_qa import KnowledgeGroundedAgent >>> agent = KnowledgeGroundedAgent() - >>> response = agent.answer("Who won the 2024 Nobel Prize in Physics?") + >>> response = agent.answer("What are the Basel III capital requirements?") >>> print(response.text) """ @@ -251,8 +180,13 @@ def __init__( self, config: Configs | None = None, model: str | None = None, + enable_planning: bool = True, + enable_caching: bool = True, + enable_compaction: bool = True, + compaction_interval: int = 10, + thinking_budget: int = 8192, ) -> None: - """Initialize the knowledge-grounded agent. + """Initialize the knowledge-grounded agent with built-in planning. Parameters ---------- @@ -260,55 +194,191 @@ def __init__( Configuration settings. If not provided, creates default config. model : str, optional The model to use. If not provided, uses config.default_worker_model. + enable_planning : bool, default True + Whether to enable the built-in planner (Gemini thinking mode). + enable_caching : bool, default True + Whether to enable context caching for reduced latency and cost. + enable_compaction : bool, default True + Whether to enable context compaction. When enabled, ADK automatically + summarizes older events to prevent running out of context. + compaction_interval : int, default 10 + Number of invocations before triggering context compaction. + thinking_budget : int, default 8192 + Token budget for the model's thinking/planning phase. 
""" + self._enable_compaction = enable_compaction + self._compaction_interval = compaction_interval if config is None: config = Configs() # type: ignore[call-arg] self.config = config self.model = model or config.default_worker_model + self.temperature = config.default_temperature + self.enable_planning = enable_planning + self._thinking_budget = thinking_budget + + # Create tools - use function tool for search so agent sees actual URLs + self._search_tool = create_google_search_tool(config=config) + self._web_fetch_tool = create_web_fetch_tool() + self._fetch_file_tool = create_fetch_file_tool() + self._grep_file_tool = create_grep_file_tool() + self._read_file_tool = create_read_file_tool() + + # Create planner if enabled + planner = None + if enable_planning: + planner = PlanReActPlanner() + + # Create ADK agent with built-in planner + # Configure thinking for models that support it (gemini-2.5-*, gemini-3-*) + thinking_config = None + if thinking_budget > 0 and self._supports_thinking(self.model): + thinking_config = types.ThinkingConfig(thinking_budget=thinking_budget) - # Create the Google Search tool - self._search_tool = create_google_search_tool() - - # Create ADK agent with Google Search tool self._agent = Agent( - name="knowledge_qa_agent", + name="knowledge_qa", model=self.model, - instruction=SYSTEM_INSTRUCTIONS, - tools=[self._search_tool], + instruction=build_system_instructions(), + tools=[ + self._search_tool, + self._web_fetch_tool, + self._fetch_file_tool, + self._grep_file_tool, + self._read_file_tool, + ], + planner=planner, + generate_content_config=types.GenerateContentConfig( + temperature=self.temperature, + thinking_config=thinking_config, + ), ) + # Current research plan (populated from model's thinking for CLI display) + self._current_plan: ResearchPlan | None = None + + # Token tracking for context usage display + self._token_tracker = TokenTracker(model=self.model) + # Session service for conversation history self._session_service = 
InMemorySessionService() - # Runner orchestrates the ReAct loop - self._runner = Runner( - app_name="knowledge_qa", - agent=self._agent, + # Create App and Runner based on enabled features + self._app: App | None + if enable_caching or enable_compaction: + self._app, self._runner = self._create_app_and_runner(config, enable_caching, enable_compaction) + else: + self._app = None + self._runner = Runner( + app_name="knowledge_qa", + agent=self._agent, + session_service=self._session_service, + ) + + # Track active sessions + self._sessions: dict[str, str] = {} + + def _create_app_and_runner( + self, + config: Configs, + enable_caching: bool, + enable_compaction: bool, + ) -> tuple[App, Runner]: + """Create App and Runner with caching/compaction config.""" + app_kwargs: dict[str, Any] = { + "name": "knowledge_qa", + "root_agent": self._agent, + } + + if enable_caching: + app_kwargs["context_cache_config"] = ContextCacheConfig( + min_tokens=2048, + ttl_seconds=3600, # 1 hour (increased from 10 min) + cache_intervals=50, # 50 reuses (increased from 10) + ) + + if enable_compaction: + summarizer = LlmEventSummarizer(llm=Gemini(model=config.default_worker_model)) + app_kwargs["events_compaction_config"] = EventsCompactionConfig( + compaction_interval=self._compaction_interval, + overlap_size=1, + summarizer=summarizer, + ) + + app = App(**app_kwargs) + runner = Runner( + app=app, session_service=self._session_service, ) + return app, runner - # Track active sessions - self._sessions: dict[str, str] = {} # Maps external session_id to ADK session_id + @staticmethod + def _supports_thinking(model: str) -> bool: + """Check if a model supports thinking configuration. - async def _get_or_create_session_async(self, session_id: str | None = None) -> str: - """Get or create an ADK session for the given session ID. + Thinking is supported by gemini-2.5-* and gemini-3-* models. 
+ """ + model_lower = model.lower() + return "gemini-2.5" in model_lower or "gemini-3" in model_lower - Parameters - ---------- - session_id : str, optional - External session ID. If not provided, generates a new one. + def reset(self) -> None: + """Reset agent state for a new question. - Returns - ------- - str - The ADK session ID. + Clears session history and plan state to ensure clean execution + for each new question. Call this between evaluation examples. + """ + self._sessions.clear() + self._session_service = InMemorySessionService() + self._current_plan = None + self._token_tracker = TokenTracker(model=self.model) + + # Recreate runner with fresh session service + if self._app is not None: + self._runner = Runner( + app=self._app, + session_service=self._session_service, + ) + else: + self._runner = Runner( + app_name="knowledge_qa", + agent=self._agent, + session_service=self._session_service, + ) + logger.debug("Agent state reset for new question") + + @property + def current_plan(self) -> ResearchPlan | None: + """Get the current research plan if one exists.""" + return self._current_plan + + @property + def token_tracker(self) -> TokenTracker: + """Get the token tracker for context usage monitoring.""" + return self._token_tracker + + async def create_plan_async(self, question: str) -> ResearchPlan | None: + """Initialize plan tracking for CLI display. + + Creates an empty plan that will be populated from the model's + PLANNING output during answer_async(). The actual plan steps + are extracted from the model's first response. 
""" + if not self.enable_planning: + return None + + # Create empty plan - will be populated from model's PLANNING output + self._current_plan = ResearchPlan( + original_question=question, + steps=[], + reasoning="", + ) + return self._current_plan + + async def _get_or_create_session_async(self, session_id: str | None = None) -> str: + """Get or create an ADK session for the given session ID.""" if session_id is None: session_id = str(uuid.uuid4()) if session_id not in self._sessions: - # Create a new ADK session through the session service session = await self._session_service.create_session( app_name="knowledge_qa", user_id="user", @@ -318,12 +388,250 @@ async def _get_or_create_session_async(self, session_id: str | None = None) -> s return self._sessions[session_id] + def _update_plan_from_text(self, text: str, question: str, is_replan: bool = False) -> bool: + """Update the current plan from PLANNING or REPLANNING tagged text.""" + plan_text = extract_plan_text(text) + if not plan_text: + return False + + steps = parse_plan_steps_from_text(plan_text) + if not steps: + return False + + if is_replan and self._current_plan: + # Replanning: preserve completed steps, update remaining + completed_steps = [s for s in self._current_plan.steps if s.status == StepStatus.COMPLETED] + # Renumber new steps starting after completed ones + next_id = len(completed_steps) + 1 + for i, step in enumerate(steps): + step.step_id = next_id + i + self._current_plan.steps = completed_steps + steps + self._current_plan.reasoning = f"Replanned: {plan_text[:300]}" + logger.info(f"Replanned with {len(steps)} new steps (keeping {len(completed_steps)} completed)") + else: + # New plan + self._current_plan = ResearchPlan( + original_question=question, + steps=steps, + reasoning=plan_text[:500], + ) + logger.info(f"Extracted plan with {len(steps)} steps") + + # Mark first pending step as in progress + for step in self._current_plan.steps: + if step.status == StepStatus.PENDING: + step.status 
= StepStatus.IN_PROGRESS + break + + return True + + def _process_event_text_for_plan(self, text: str, question: str) -> None: + """Process event text to extract and update plan.""" + if not text or not self._current_plan: + return + + # Check for replanning first + if REPLANNING_TAG in text: + self._update_plan_from_text(text, question, is_replan=True) + # Check for initial planning (only if plan is empty) + elif PLANNING_TAG in text and len(self._current_plan.steps) == 0: + self._update_plan_from_text(text, question, is_replan=False) + + def _update_plan_step_from_tool_call(self, tool_name: str) -> None: + """Record tool call against current plan step.""" + if not self._current_plan or not self._current_plan.steps: + return + + # Find current in-progress step and record the tool used + for step in self._current_plan.steps: + if step.status == StepStatus.IN_PROGRESS: + # Append tool to actual_output + if step.actual_output: + step.actual_output += f", {tool_name}" + else: + step.actual_output = f"Used: {tool_name}" + break + + def _advance_plan_step_on_reasoning(self) -> None: + """Advance to next plan step when reasoning is detected.""" + if not self._current_plan or not self._current_plan.steps: + return + + # Find current in-progress step + for i, step in enumerate(self._current_plan.steps): + if step.status == StepStatus.IN_PROGRESS: + # Mark current step as completed + step.status = StepStatus.COMPLETED + # Mark next step as in progress (if exists) + if i + 1 < len(self._current_plan.steps): + self._current_plan.steps[i + 1].status = StepStatus.IN_PROGRESS + break + + def _create_execution_trace( + self, + tool_calls: list[dict[str, Any]], + total_duration_ms: int, + ) -> list[StepExecution]: + """Create execution trace from tool calls.""" + return [ + StepExecution( + step_id=i + 1, + tool_used=str(tc.get("name", "unknown")), + input_query=str(tc.get("args", {})), + output_summary=f"Tool call {i + 1}", + sources_found=0, + duration_ms=total_duration_ms // 
max(len(tool_calls), 1), + ) + for i, tc in enumerate(tool_calls) + ] + + def _process_event( + self, + event: Any, + question: str, + results: dict[str, Any], + ) -> None: + """Process a single event from the agent run loop. + + Updates results dict in place with extracted information. + """ + self._token_tracker.add_from_event(event) + event_text = extract_event_text(event) + + # Extract thoughts for reasoning chain + thoughts = extract_thoughts_from_event(event) + if thoughts: + results["reasoning_chain"].append(thoughts[:300]) + + # Process plan tags and reasoning + if event_text: + self._process_event_text_for_plan(event_text, question) + reasoning_text = extract_reasoning_text(event_text) + if reasoning_text: + results["reasoning_chain"].append(reasoning_text) + self._advance_plan_step_on_reasoning() + + # Extract tool calls + new_tool_calls = extract_tool_calls(event) + results["tool_calls"].extend(new_tool_calls) + results["search_queries"].extend(extract_search_queries_from_tool_calls(new_tool_calls)) + for tc in new_tool_calls: + self._update_plan_step_from_tool_call(tc.get("name", "")) + + # Extract sources + results["sources"].extend(extract_sources_from_responses(event)) + results["sources"].extend(extract_grounding_sources(event)) + for q in extract_grounding_queries(event): + if q not in results["search_queries"]: + results["search_queries"].append(q) + + # Extract final response - prefer tagged answer, fall back to final response + # Only overwrite with non-empty content to avoid losing valid responses + final_answer = extract_final_answer_text(event_text) if event_text else None + if final_answer: + results["final_response"] = final_answer + else: + text = extract_final_response(event) + if text: # Only set if non-empty (don't overwrite valid response with empty) + results["final_response"] = text + + async def _run_agent_once_inner( + self, + question: str, + adk_session_id: str, + ) -> dict[str, Any]: + """Run the agent once and collect results 
(inner implementation).""" + content = types.Content(role="user", parts=[types.Part(text=question)]) + + # Collect results in a mutable dict for _process_event + results: dict[str, Any] = { + "tool_calls": [], + "sources": [], + "search_queries": [], + "reasoning_chain": [], + "final_response": "", + } + + event_count = 0 + async for event in self._runner.run_async( + user_id="user", + session_id=adk_session_id, + new_message=content, + ): + event_count += 1 + self._process_event(event, question, results) + + logger.debug(f"Processed {event_count} events. Final response length: {len(results.get('final_response', ''))}") + return results + + async def _run_agent_once( + self, + question: str, + adk_session_id: str, + ) -> dict[str, Any]: + """Run the agent once with retry logic for rate limits and context overflow. + + Wraps _run_agent_once_inner with exponential backoff retry for + 429/RESOURCE_EXHAUSTED errors from the Gemini API. If a context overflow + error occurs, resets the session and retries once with fresh context. + """ + + # Create a retry-wrapped version of the inner method + @retry( + retry=retry_if_exception(is_retryable_api_error), + wait=wait_exponential_jitter( + initial=API_RETRY_INITIAL_WAIT, + max=API_RETRY_MAX_WAIT, + jitter=API_RETRY_JITTER, + ), + stop=stop_after_attempt(API_RETRY_MAX_ATTEMPTS), + before_sleep=before_sleep_log(logger, logging.WARNING), + reraise=False, # Don't reraise to avoid noisy stack traces + ) + async def _run_with_retry() -> dict[str, Any]: + return await self._run_agent_once_inner(question, adk_session_id) + + try: + return await _run_with_retry() + except RetryError as e: + # Clean error message when retries are exhausted + original_error = e.last_attempt.exception() + logger.error(f"API retry failed after {API_RETRY_MAX_ATTEMPTS} attempts: {original_error}") + raise RuntimeError( + f"API request failed after {API_RETRY_MAX_ATTEMPTS} retry attempts. 
Last error: {original_error}" + ) from original_error + except ClientError as e: + # Handle context overflow by resetting session + if is_context_overflow_error(e): + logger.warning(f"Context overflow detected: {e}") + logger.warning("Resetting session and retrying with fresh context...") + + # Create fresh session to clear accumulated history + self._session_service = InMemorySessionService() + new_session_id = await self._get_or_create_session_async() + + # Retry once with fresh session + try: + return await self._run_agent_once_inner(question, new_session_id) + except Exception as retry_error: + logger.error(f"Retry with fresh session failed: {retry_error}") + raise RuntimeError( + f"Context overflow error. Original error: {e}. " + f"Retry with fresh session also failed: {retry_error}" + ) from e + + # Re-raise non-context-overflow errors + raise + async def answer_async( self, question: str, session_id: str | None = None, - ) -> GroundedResponse: - """Answer a question using the ReAct loop asynchronously. + ) -> AgentResponse: + """Answer a question using built-in planning and tools. + + The agent uses PlanReAct planning to create and execute research steps. + The plan is captured and updated in real-time for CLI display. Parameters ---------- @@ -334,63 +642,91 @@ async def answer_async( Returns ------- - GroundedResponse - The response with text, tool calls, and sources. + AgentResponse + The response with plan, execution trace, and sources. 
""" - logger.info(f"Answering question (async): {question[:100]}...") + start_time = time.time() + logger.info(f"Answering question: {question[:100]}...") adk_session_id = await self._get_or_create_session_async(session_id) - # Create the user message - content = types.Content( - role="user", - parts=[types.Part(text=question)], - ) + if self._current_plan is None and self.enable_planning: + await self.create_plan_async(question) - # Collect events from the ReAct loop - tool_calls: list[dict[str, Any]] = [] - sources: list[GroundingChunk] = [] - search_queries: list[str] = [] - final_response = "" + # Run agent with retry logic for empty responses + results: dict[str, Any] = {} + current_session_id = adk_session_id - async for event in self._runner.run_async( - user_id="user", - session_id=adk_session_id, - new_message=content, - ): - logger.debug(f"Event: {event}") - - # Extract tool calls and search queries from function calls - new_tool_calls = _extract_tool_calls(event) - tool_calls.extend(new_tool_calls) - search_queries.extend(_extract_search_queries_from_tool_calls(new_tool_calls)) - - # Extract sources from function responses - sources.extend(_extract_sources_from_responses(event)) - - # Extract sources and queries from grounding metadata - sources.extend(_extract_grounding_sources(event)) - for q in _extract_grounding_queries(event): - if q not in search_queries: - search_queries.append(q) - - text = _extract_final_response(event) - if text is not None: - final_response = text - - return GroundedResponse( - text=final_response, - search_queries=search_queries, - sources=sources, - tool_calls=tool_calls, + for attempt in range(MAX_EMPTY_RESPONSE_RETRIES + 1): + results = await self._run_agent_once(question, current_session_id) + + # Check if we got a non-empty response + if results.get("final_response", "").strip(): + break + + # Empty response - log and retry if we have attempts left + if attempt < MAX_EMPTY_RESPONSE_RETRIES: + logger.warning( + f"Empty 
model response (attempt {attempt + 1}/{MAX_EMPTY_RESPONSE_RETRIES + 1}), " + "creating fresh session and retrying..." + ) + # Create fresh session for retry to avoid polluted history + fresh_session = await self._session_service.create_session( + app_name="knowledge_qa", + user_id="user", + state={}, + ) + current_session_id = fresh_session.id + # Reset plan for retry + if self.enable_planning: + await self.create_plan_async(question) + else: + # All retries exhausted - log detailed diagnostics + tool_call_count = len(results.get("tool_calls", [])) + reasoning_count = len(results.get("reasoning_chain", [])) + source_count = len(results.get("sources", [])) + logger.error( + f"Empty model response after {MAX_EMPTY_RESPONSE_RETRIES + 1} attempts. " + f"Tool calls: {tool_call_count}, Reasoning steps: {reasoning_count}, " + f"Sources: {source_count}. The model may have only produced thinking tokens." + ) + # Log reasoning chain to debug why no final answer was generated + if results.get("reasoning_chain"): + logger.error("Last reasoning steps (for debugging):") + for i, reasoning in enumerate(results["reasoning_chain"][-3:], 1): + logger.error(f" Reasoning {i}: {reasoning[:500]}...") # Log first 500 chars + + total_duration_ms = int((time.time() - start_time) * 1000) + + # Mark remaining steps as skipped (not completed - they weren't executed) + if self._current_plan: + for step in self._current_plan.steps: + if step.status in (StepStatus.PENDING, StepStatus.IN_PROGRESS): + step.status = StepStatus.SKIPPED + + # Resolve redirect URLs and build response + resolved_sources = await resolve_source_urls(results.get("sources", [])) + plan = self._current_plan or ResearchPlan(original_question=question, steps=[], reasoning="No planning enabled") + execution_trace = self._create_execution_trace(results.get("tool_calls", []), total_duration_ms) + self._current_plan = None + + return AgentResponse( + text=results.get("final_response", ""), + plan=plan, + 
execution_trace=execution_trace, + sources=resolved_sources, + search_queries=results.get("search_queries", []), + reasoning_chain=results.get("reasoning_chain", []), + tool_calls=results.get("tool_calls", []), + total_duration_ms=total_duration_ms, ) def answer( self, question: str, session_id: str | None = None, - ) -> GroundedResponse: - """Answer a question using the ReAct loop. + ) -> AgentResponse: + """Answer a question using built-in planning and tools (sync). Parameters ---------- @@ -401,15 +737,10 @@ def answer( Returns ------- - GroundedResponse - The response with text, tool calls, and sources. - - Notes - ----- - This is a synchronous wrapper around answer_async(). For Jupyter notebooks, - use `await agent.answer_async(question)` directly instead. + AgentResponse + The response with plan, execution trace, and sources. """ - logger.info(f"Answering question: {question[:100]}...") + logger.info(f"Answering question (sync): {question[:100]}...") return asyncio.run(self.answer_async(question, session_id)) @@ -417,9 +748,7 @@ class KnowledgeAgentManager: """Manages KnowledgeGroundedAgent lifecycle with lazy initialization. This class provides convenient lifecycle management for the knowledge agent, - with lazy initialization and state tracking. Unlike the general-purpose - AsyncClientManager (for infrastructure clients), this is specific to the - knowledge agent and is not a singleton. + with lazy initialization and state tracking. Parameters ---------- @@ -435,42 +764,50 @@ class KnowledgeAgentManager: >>> manager.close() """ - def __init__(self, config: Configs | None = None) -> None: + def __init__( + self, + config: Configs | None = None, + enable_caching: bool = True, + enable_planning: bool = True, + enable_compaction: bool = True, + ) -> None: """Initialize the client manager. Parameters ---------- config : Configs, optional Configuration object. If not provided, creates default config. 
+ enable_caching : bool, default True + Whether to enable context caching. + enable_planning : bool, default True + Whether to enable built-in planning (Gemini thinking mode). + enable_compaction : bool, default True + Whether to enable context compaction. """ self._config = config + self._enable_caching = enable_caching + self._enable_planning = enable_planning + self._enable_compaction = enable_compaction self._agent: KnowledgeGroundedAgent | None = None self._initialized = False @property def config(self) -> Configs: - """Get or create the config instance. - - Returns - ------- - Configs - The configuration settings. - """ + """Get or create the config instance.""" if self._config is None: self._config = Configs() # type: ignore[call-arg] return self._config @property def agent(self) -> KnowledgeGroundedAgent: - """Get or create the knowledge-grounded agent. - - Returns - ------- - KnowledgeGroundedAgent - The knowledge-grounded QA agent. - """ + """Get or create the knowledge-grounded agent.""" if self._agent is None: - self._agent = KnowledgeGroundedAgent(config=self.config) + self._agent = KnowledgeGroundedAgent( + config=self.config, + enable_caching=self._enable_caching, + enable_planning=self._enable_planning, + enable_compaction=self._enable_compaction, + ) self._initialized = True return self._agent @@ -480,11 +817,5 @@ def close(self) -> None: self._initialized = False def is_initialized(self) -> bool: - """Check if any clients have been initialized. - - Returns - ------- - bool - True if any clients have been initialized. - """ + """Check if any clients have been initialized.""" return self._initialized diff --git a/aieng-eval-agents/aieng/agent_evals/knowledge_qa/cli.py b/aieng-eval-agents/aieng/agent_evals/knowledge_qa/cli.py new file mode 100644 index 0000000..c8f9ed0 --- /dev/null +++ b/aieng-eval-agents/aieng/agent_evals/knowledge_qa/cli.py @@ -0,0 +1,1559 @@ +#!/usr/bin/env python3 +"""Knowledge Agent CLI. 
+ +Command-line interface for running and evaluating the Knowledge-Grounded QA Agent. + +Usage:: + + knowledge-qa ask "What is..." + knowledge-qa eval --samples 3 + knowledge-qa eval --ids 123 456 789 + knowledge-qa sample --ids 123 + knowledge-qa sample --category "Finance & Economics" --count 5 +""" + +import argparse +import asyncio +import io +import logging +import re +import sys +from importlib.metadata import version +from pathlib import Path + +from aieng.agent_evals.configs import Configs +from aieng.agent_evals.evaluation.trace import flush_traces +from aieng.agent_evals.knowledge_qa.deepsearchqa_grader import ( + EvaluationOutcome, + evaluate_deepsearchqa_async, +) +from aieng.agent_evals.langfuse import init_tracing +from dotenv import load_dotenv +from rich import box +from rich.console import Console, Group +from rich.live import Live +from rich.panel import Panel +from rich.status import Status +from rich.table import Table +from rich.text import Text + +from .agent import KnowledgeGroundedAgent +from .data import DeepSearchQADataset +from .plan_parsing import StepStatus + + +# Load .env file from current directory or parent directories +def _load_env() -> None: + """Load environment variables from .env file.""" + # Try current directory first, then walk up + for parent in [Path.cwd(), *Path.cwd().parents]: + env_file = parent / ".env" + if env_file.exists(): + load_dotenv(env_file) + return + # Fallback to default dotenv behavior + load_dotenv() + + +_load_env() + +console = Console() + +# Vector Institute cyan color +VECTOR_CYAN = "#00B4D8" + + +def get_version() -> str: + """Get the installed version of the package.""" + try: + return version("aieng-eval-agents") + except Exception: + return "dev" + + +def _get_model_config() -> tuple[str, str]: + """Get model names from config. + + Returns + ------- + tuple[str, str] + The worker model and evaluator model names from config. 
+ """ + try: + config = Configs() # type: ignore[call-arg] + return config.default_worker_model, config.default_evaluator_model + except Exception: + return "gemini-2.5-flash", "gemini-2.5-pro" + + +def display_banner() -> None: + """Display the CLI banner with version and model info.""" + ver = get_version() + worker_model, evaluator_model = _get_model_config() + + # Robot face with magnifying glass + line0 = Text() + line0.append(" ◯─◯ ", style=f"{VECTOR_CYAN} bold") + line0.append("knowledge-qa ", style="white bold") + line0.append(f"v{ver}", style="bright_black") + + line1 = Text() + line1.append(" ╱ 🔍 ╲ ", style=f"{VECTOR_CYAN} bold") + line1.append("Agent: ", style="dim") + line1.append(worker_model, style="cyan") + + line2 = Text() + line2.append(" │ │ ", style=f"{VECTOR_CYAN} bold") + line2.append("Evaluator: ", style="dim") + line2.append(evaluator_model, style="yellow") + + line3 = Text() + line3.append(" ╲__╱ ", style=f"{VECTOR_CYAN} bold") + line3.append("Vector Institute AI Engineering", style="bright_black") + + console.print() + console.print(line0) + console.print(line1) + console.print(line2) + console.print(line3) + console.print() + + +def display_tools_info() -> None: + """Display information about available tools.""" + console.print("[bold]Available Tools:[/bold]") + console.print() + + tools = [ + ("google_search", "blue", "Search the web for current information and sources"), + ("fetch_url", "green", "Fetch webpage content and save locally for searching"), + ("grep_file", "cyan", "Search within fetched files for matching patterns"), + ("read_file", "cyan", "Read sections of fetched files"), + ("read_pdf", "green", "Read and extract text from PDF documents"), + ] + + for name, color, desc in tools: + console.print(f" [{color}]{name:<16}[/{color}] {desc}") + + console.print() + + +def _parse_structured_answer(text: str) -> dict[str, str] | None: + """Parse structured answer format (ANSWER/SOURCES/REASONING). 
+ + Parameters + ---------- + text : str + The raw response text. + + Returns + ------- + dict[str, str] | None + Parsed sections or None if parsing fails. + """ + if not text: + return None + + # Check if text contains our structured format + text_upper = text.upper() + if "ANSWER:" not in text_upper: + return None + + result = {"answer": "", "sources": "", "reasoning": ""} + + # Find positions of each section (case-insensitive) + # Match ANSWER:, SOURCES:, REASONING: with flexible spacing + answer_match = re.search(r"ANSWER:\s*", text, re.IGNORECASE) + sources_match = re.search(r"SOURCES:\s*", text, re.IGNORECASE) + reasoning_match = re.search(r"REASONING:\s*", text, re.IGNORECASE) + + if answer_match: + start = answer_match.end() + # Find end - next section or end of text + end = len(text) + if sources_match and sources_match.start() > start: + end = min(end, sources_match.start()) + if reasoning_match and reasoning_match.start() > start: + end = min(end, reasoning_match.start()) + result["answer"] = text[start:end].strip() + + if sources_match: + start = sources_match.end() + end = len(text) + if reasoning_match and reasoning_match.start() > start: + end = min(end, reasoning_match.start()) + if answer_match and answer_match.start() > start: + end = min(end, answer_match.start()) + result["sources"] = text[start:end].strip() + + if reasoning_match: + start = reasoning_match.end() + end = len(text) + if sources_match and sources_match.start() > start: + end = min(end, sources_match.start()) + if answer_match and answer_match.start() > start: + end = min(end, answer_match.start()) + result["reasoning"] = text[start:end].strip() + + # Return None if we didn't extract any meaningful content + if not result["answer"]: + return None + + return result + + +class ToolCallHandler(logging.Handler): + """Custom logging handler that captures tool calls for rich display.""" + + def __init__(self): + super().__init__() + self.tool_calls: list[dict] = [] + + def emit(self, 
record): + """Process a log record, capturing tool calls for display.""" + msg = record.getMessage() + if "Tool call:" in msg: + try: + parts = msg.split("Tool call: ", 1)[1] + paren_idx = parts.find("(") + if paren_idx > 0: + tool_name = parts[:paren_idx] + args_str = parts[paren_idx + 1 : -1] + if len(args_str) > 80: + args_str = args_str[:77] + "..." + self.tool_calls.append( + { + "name": tool_name, + "args": args_str, + "completed": False, + "failed": False, + "error": None, + } + ) + except Exception: + pass + elif "Tool error:" in msg: + # Mark the most recent incomplete tool call as failed + try: + parts = msg.split("Tool error: ", 1)[1] + # Format: "tool_name failed - error message" + tool_part, error_msg = ( + parts.split(" failed - ", 1) if " failed - " in parts else (parts, "Unknown error") + ) + tool_name = tool_part.strip() + # Find the most recent matching incomplete tool call + for tc in reversed(self.tool_calls): + if tc["name"] == tool_name and not tc["completed"] and not tc["failed"]: + tc["failed"] = True + tc["error"] = error_msg[:60] + "..." if len(error_msg) > 60 else error_msg + break + except Exception: + pass + elif "Tool response:" in msg: + # Mark the most recent incomplete tool call as completed + try: + parts = msg.split("Tool response: ", 1)[1] + tool_name = parts.split(" ")[0] + # Find the most recent matching incomplete tool call + for tc in reversed(self.tool_calls): + if tc["name"] == tool_name and not tc["completed"] and not tc["failed"]: + tc["completed"] = True + break + except Exception: + pass + + def clear(self): + """Reset captured tool calls.""" + self.tool_calls = [] + + +def _parse_markdown_bold(text: str, base_style: str) -> Text: + """Parse markdown-style **bold** markers and return styled Rich Text. + + Parameters + ---------- + text : str + Text that may contain **bold** markers. + base_style : str + The base style to apply to non-bold text. 
+ + Returns + ------- + Text + Rich Text object with bold sections properly styled. + """ + result = Text() + # Match **text** patterns + pattern = r"\*\*([^*]+)\*\*" + last_end = 0 + + for match in re.finditer(pattern, text): + # Add text before the match with base style + if match.start() > last_end: + result.append(text[last_end : match.start()], style=base_style) + # Add the bold text (combine bold with base style) + bold_style = f"bold {base_style}" if base_style else "bold" + result.append(match.group(1), style=bold_style) + last_end = match.end() + + # Add remaining text after last match + if last_end < len(text): + result.append(text[last_end:], style=base_style) + + return result + + +def _create_plan_display(plan) -> Panel: + """Create a rich panel showing the research plan checklist. + + Parameters + ---------- + plan : ResearchPlan + The research plan to display. Step statuses are read directly from the plan. + + Returns + ------- + Panel + A rich panel with the plan checklist. + """ + lines = [] + + for step in plan.steps: + # Use the step's actual status from the plan (updated by the agent in real-time) + if step.status == StepStatus.COMPLETED: + icon, icon_style = "✓", "green" + desc_style = "dim" + elif step.status == StepStatus.FAILED: + icon, icon_style = "✗", "red" + desc_style = "red" + elif step.status == StepStatus.IN_PROGRESS: + icon, icon_style = "→", "bold yellow" + desc_style = "yellow" + elif step.status == StepStatus.SKIPPED: + icon, icon_style = "○", "dim" + desc_style = "strike dim" + else: + # PENDING - not yet started + icon, icon_style = "○", "dim" + desc_style = "dim" + + line = Text() + line.append(" ") + line.append(icon, style=icon_style) + line.append(f" {step.step_id}. 
", style="bold") + # Parse markdown bold markers in description + styled_desc = _parse_markdown_bold(step.description, desc_style) + line.append_text(styled_desc) + lines.append(line) + + content = Group(*lines) if lines else Text("No plan steps", style="dim") + + return Panel( + content, + title="[bold magenta]📋 Research Plan[/bold magenta]", + subtitle=f"[dim]{len(plan.steps)} steps[/dim]", + border_style="magenta", + padding=(0, 1), + ) + + +def _get_tool_display_info(name: str) -> tuple[str, str, str]: + """Get display name, icon, and style for a tool. + + Returns (display_name, icon, style). + """ + # Normalize tool name for display + display_name = "google_search" if name == "google_search_agent" else name + + # Tool icon and style lookup + tool_styles = { + "fetch_url": ("🌐", "green"), + "read_pdf": ("📄", "green"), + "grep_file": ("📑", "cyan"), + "read_file": ("📖", "cyan"), + "google_search": ("🔍", "blue"), + "google_search_agent": ("🔍", "blue"), + } + icon, style = tool_styles.get(name, ("🔧", "white")) + return display_name, icon, style + + +def _create_compact_question_panel( + question: str, example_id: int | None = None, answer_type: str | None = None +) -> Panel: + """Create a compact question panel for the live display. + + Parameters + ---------- + question : str + The question text. + example_id : int, optional + The example ID if in eval mode. + answer_type : str, optional + The answer type if in eval mode. + + Returns + ------- + Panel + A compact question panel. + """ + title = "[bold blue]📋 Question[/bold blue]" + if example_id is not None: + title = f"[bold blue]📋 Question (ID: {example_id})[/bold blue]" + + subtitle = f"[dim]Answer Type: {answer_type}[/dim]" if answer_type else None + + return Panel( + question, + title=title, + subtitle=subtitle, + border_style="blue", + padding=(0, 1), + ) + + +def _create_compact_ground_truth_panel(ground_truth: str) -> Panel: + """Create a compact ground truth panel for the live display. 
+ + Parameters + ---------- + ground_truth : str + The ground truth answer. + + Returns + ------- + Panel + A compact ground truth panel. + """ + # Truncate long ground truth for display + display_gt = ground_truth if len(ground_truth) <= 150 else ground_truth[:147] + "..." + + return Panel( + f"[yellow]{display_gt}[/yellow]", + title="[bold yellow]🎯 Ground Truth[/bold yellow]", + border_style="yellow", + padding=(0, 1), + ) + + +def create_tool_display( + tool_calls: list[dict], + plan=None, + context_percent: float | None = None, + question: str | None = None, + ground_truth: str | None = None, + example_id: int | None = None, + answer_type: str | None = None, +) -> Group | Panel: + """Create a rich display showing tool calls and optionally the plan. + + Parameters + ---------- + tool_calls : list[dict] + List of tool calls made so far. + plan : ResearchPlan, optional + If provided, shows the plan checklist above tool calls. + context_percent : float, optional + Percentage of context window remaining. + question : str, optional + The question being answered (for eval mode display). + ground_truth : str, optional + The ground truth answer (for eval mode display). + example_id : int, optional + The example ID (for eval mode display). + answer_type : str, optional + The answer type (for eval mode display). + + Returns + ------- + Group or Panel + The display content. 
+ """ + tool_content = _build_tool_calls_content(tool_calls, plan is not None) + + # Build subtitle with tool calls and context usage + subtitle_parts = [f"{len(tool_calls)} tool calls"] + if context_percent is not None: + # Color code based on remaining context + if context_percent > 50: + color = "green" + elif context_percent > 20: + color = "yellow" + else: + color = "red" + subtitle_parts.append(f"[{color}]{context_percent:.0f}% context left[/{color}]") + + tool_panel = Panel( + tool_content, + title="[bold cyan]🔧 Agent Working[/bold cyan]", + subtitle=f"[dim]{' | '.join(subtitle_parts)}[/dim]", + border_style="cyan", + padding=(0, 1), + ) + + # Build the display components + components: list[Panel | Text] = [] + + # Add question and ground truth panels if in eval mode + if question is not None: + components.append(_create_compact_question_panel(question, example_id, answer_type)) + components.append(Text("")) + + if ground_truth is not None: + components.append(_create_compact_ground_truth_panel(ground_truth)) + components.append(Text("")) + + # Add plan if available + if plan and plan.steps: + components.append(_create_plan_display(plan)) + components.append(Text("")) + + # Always add the tool panel + components.append(tool_panel) + + # Return as group if we have multiple components, otherwise just the tool panel + if len(components) > 1: + return Group(*components) + return tool_panel + + +def _build_tool_calls_content(tool_calls: list[dict], has_plan: bool) -> Group | Text: + """Build the content for tool calls display.""" + if not tool_calls: + return Text("Waiting for tool calls...", style="dim") + + lines = [] + display_calls = tool_calls[-6:] if has_plan else tool_calls[-8:] + if len(tool_calls) > len(display_calls): + lines.append(Text(f" ... 
({len(tool_calls) - len(display_calls)} earlier calls)", style="dim")) + + for tc in display_calls: + is_completed = tc.get("completed", False) + is_failed = tc.get("failed", False) + display_name, icon, style = _get_tool_display_info(tc["name"]) + + line = Text() + if is_failed: + line.append(" ✗ ", style="bold red") + line.append(f"{icon} ", style="red") + line.append(display_name, style="bold red") + line.append(f" {tc['args']}", style="dim red") + if tc.get("error"): + line.append(f" [{tc['error']}]", style="red") + elif is_completed: + line.append(" ✓ ", style="dim green") + line.append(f"{icon} ", style=style) + line.append(display_name, style=f"bold {style}") + line.append(f" {tc['args']}", style="dim") + else: + line.append(" → ", style="bold yellow") + line.append(f"{icon} ", style=style) + line.append(display_name, style=f"bold {style}") + line.append(f" {tc['args']}", style="dim") + lines.append(line) + + return Group(*lines) if lines else Text("No tool calls yet", style="dim") + + +def display_tool_usage(tool_calls: list[dict]) -> dict[str, int]: + """Display and return tool usage statistics.""" + tool_counts: dict[str, int] = {} + for tc in tool_calls: + name = tc.get("name", "unknown") + # Normalize google_search_agent to google_search for cleaner display + if name == "google_search_agent": + name = "google_search" + tool_counts[name] = tool_counts.get(name, 0) + 1 + + if tool_counts: + table = Table(title="🔧 Tool Usage", show_header=True, header_style="bold magenta", box=None) + table.add_column("Tool", style="cyan", no_wrap=True) + table.add_column("Calls", justify="right", style="bold") + + for tool, count in sorted(tool_counts.items()): + if tool in ("fetch_url", "read_pdf"): + table.add_row(f"[bold green]✓ {tool}[/bold green]", f"[green]{count}[/green]") + elif tool == "grep_file": + table.add_row(f"[bold cyan]✓ {tool}[/bold cyan]", f"[cyan]{count}[/cyan]") + elif "search" in tool.lower(): + table.add_row(f"[blue]{tool}[/blue]", str(count)) + 
else: + table.add_row(tool, str(count)) + + console.print(table) + + return tool_counts + + +def setup_logging() -> ToolCallHandler: + """Configure logging to capture tool calls without verbose output.""" + logging.basicConfig(level=logging.ERROR, format="%(message)s", force=True) + + # Suppress verbose logging from external libraries and tools + for logger_name in [ + "google.adk", + "google.genai", + "google.generativeai", + "httpx", + "httpcore", + "aieng.agent_evals.tools", + "aieng.agent_evals.knowledge_qa.web_tools", + ]: + _logger = logging.getLogger(logger_name) + _logger.setLevel(logging.CRITICAL) + _logger.propagate = False + # Clear any existing handlers + _logger.handlers.clear() + + # Set up custom handler for tool call capture + tool_handler = ToolCallHandler() + tool_handler.setLevel(logging.INFO) + + # Configure agent logger to only capture tool calls, suppress other messages + agent_logger = logging.getLogger("aieng.agent_evals.knowledge_qa.agent") + agent_logger.handlers.clear() + agent_logger.addHandler(tool_handler) + agent_logger.setLevel(logging.INFO) + agent_logger.propagate = False + + # Add a filter to suppress non-tool-call messages + class ToolCallOnlyFilter(logging.Filter): + def filter(self, record): + msg = record.getMessage() + return "Tool call:" in msg or "Tool response:" in msg or "Tool error:" in msg + + tool_handler.addFilter(ToolCallOnlyFilter()) + + return tool_handler + + +async def run_agent_with_display( + agent, + question: str, + tool_handler: ToolCallHandler, + show_plan: bool = False, + ground_truth: str | None = None, + example_id: int | None = None, + answer_type: str | None = None, + example_num: int | None = None, + total_examples: int | None = None, +): + """Run the agent with live tool call display. + + Parameters + ---------- + agent : KnowledgeGroundedAgent + The agent to run. + question : str + The question to answer. + tool_handler : ToolCallHandler + Handler for capturing tool calls. 
+ show_plan : bool + If True, display the research plan checklist during execution. + ground_truth : str, optional + The ground truth answer (for eval mode - shown in live display). + example_id : int, optional + The example ID (for eval mode - shown in live display). + answer_type : str, optional + The answer type (for eval mode - shown in live display). + example_num : int, optional + Current example number (for eval mode spinner display). + total_examples : int, optional + Total number of examples (for eval mode spinner display). + """ + live_console = Console(file=sys.stdout, force_terminal=True) + + # Show spinner while preparing (planning if enabled) + if example_num is not None and total_examples is not None: + spinner_text = f"[bold cyan]Example {example_num}/{total_examples}[/bold cyan]" + with Status(spinner_text, console=console, spinner="dots"): + if show_plan and hasattr(agent, "create_plan_async"): + await agent.create_plan_async(question) + elif show_plan and hasattr(agent, "create_plan_async"): + await agent.create_plan_async(question) + + # Capture stdout/stderr before Live to suppress agent output + original_stdout = sys.stdout + original_stderr = sys.stderr + sys.stdout = io.StringIO() + sys.stderr = io.StringIO() + + try: + with Live( + create_tool_display( + [], + plan=agent.current_plan if show_plan else None, + question=question, + ground_truth=ground_truth, + example_id=example_id, + answer_type=answer_type, + ), + console=live_console, + screen=True, + refresh_per_second=10, + ) as live: + task = asyncio.create_task(agent.answer_async(question)) + + while not task.done(): + current_plan = agent.current_plan if show_plan else None + # Get context percentage from token tracker if available + context_pct = None + if hasattr(agent, "token_tracker"): + context_pct = agent.token_tracker.usage.context_remaining_percent + live.update( + create_tool_display( + tool_handler.tool_calls, + plan=current_plan, + context_percent=context_pct, + 
question=question, + ground_truth=ground_truth, + example_id=example_id, + answer_type=answer_type, + ) + ) + await asyncio.sleep(0.1) + + return await task + finally: + sys.stdout = original_stdout + sys.stderr = original_stderr + + +def _setup_tracing(log_trace: bool) -> bool: + """Initialize Langfuse tracing if requested. + + Parameters + ---------- + log_trace : bool + Whether to enable tracing. + + Returns + ------- + bool + True if tracing was successfully enabled, False otherwise. + """ + if not log_trace: + return False + + enabled = init_tracing() + if enabled: + console.print("[green]✓ Langfuse tracing enabled[/green]\n") + else: + console.print("[yellow]⚠ Could not initialize Langfuse tracing[/yellow]\n") + return enabled + + +def _flush_tracing(tracing_enabled: bool) -> None: + """Flush traces to Langfuse if tracing was enabled. + + Parameters + ---------- + tracing_enabled : bool + Whether tracing is enabled. + """ + if not tracing_enabled: + return + + flush_traces() + console.print("\n[dim]Traces flushed to Langfuse[/dim]") + + +async def cmd_ask(question: str, show_plan: bool = False, log_trace: bool = False) -> int: + """Ask the agent a question. + + Parameters + ---------- + question : str + The question to ask. + show_plan : bool + Display the research plan checklist during execution. + log_trace : bool + Enable Langfuse tracing for this run. 
+ """ + display_banner() + tracing_enabled = _setup_tracing(log_trace) + + console.print( + Panel( + question, + title="[bold blue]📋 Question[/bold blue]", + border_style="blue", + padding=(1, 2), + ) + ) + console.print() + + tool_handler = setup_logging() + + agent = KnowledgeGroundedAgent(enable_planning=True) + + tool_handler.clear() + response = await run_agent_with_display(agent, question, tool_handler, show_plan=show_plan) + + # Display results + console.print() + display_tool_usage(response.tool_calls) + + console.print() + + # Parse structured answer format (ANSWER/SOURCES/REASONING) + answer_text = response.text + parsed_answer = _parse_structured_answer(answer_text) + + if parsed_answer: + # Display formatted answer with sections + answer_content = Text() + if parsed_answer.get("answer"): + # Parse markdown bold markers in the answer + answer_content = _parse_markdown_bold(parsed_answer["answer"], "white") + + console.print( + Panel( + answer_content, + title="[bold cyan]🤖 Answer[/bold cyan]", + subtitle=f"[dim]Duration: {response.total_duration_ms / 1000:.1f}s[/dim]", + border_style="cyan", + padding=(1, 2), + ) + ) + + if parsed_answer.get("reasoning"): + console.print( + Panel( + parsed_answer["reasoning"], + title="[bold dim]💭 Reasoning[/bold dim]", + border_style="dim", + padding=(0, 1), + ) + ) + else: + # Fallback to raw display if parsing fails + console.print( + Panel( + answer_text, + title="[bold cyan]🤖 Answer[/bold cyan]", + subtitle=f"[dim]Duration: {response.total_duration_ms / 1000:.1f}s[/dim]", + border_style="cyan", + padding=(1, 2), + ) + ) + + if response.sources: + console.print("\n[bold]Sources:[/bold]") + for src in response.sources[:5]: + if src.uri: + console.print(f" • [blue]{src.title or 'Source'}[/blue]: {src.uri}") + + _flush_tracing(tracing_enabled) + console.print("\n[bold green]✓ Complete[/bold green]") + return 0 + + +# Outcome display configuration +OUTCOME_COLORS = { + EvaluationOutcome.FULLY_CORRECT.value: "green", + 
EvaluationOutcome.CORRECT_WITH_EXTRANEOUS.value: "yellow", + EvaluationOutcome.PARTIALLY_CORRECT.value: "orange1", + EvaluationOutcome.FULLY_INCORRECT.value: "red", +} +OUTCOME_ICONS = { + EvaluationOutcome.FULLY_CORRECT.value: "✅", + EvaluationOutcome.CORRECT_WITH_EXTRANEOUS.value: "🟡", + EvaluationOutcome.PARTIALLY_CORRECT.value: "🔶", + EvaluationOutcome.FULLY_INCORRECT.value: "❌", +} + + +def _display_example_result(example, response, idx: int, total: int) -> dict[str, int]: + """Display the full results for an evaluated example. + + Parameters + ---------- + example : DSQAExample + The example that was evaluated. + response : AgentResponse + The agent's response. + idx : int + Current index (1-based). + total : int + Total number of examples. + + Returns + ------- + dict[str, int] + Tool usage counts. + """ + console.print(f"\n[bold cyan]━━━ Example {idx}/{total} - Results ━━━[/bold cyan]\n") + console.print( + Panel( + example.problem, + title=f"[bold blue]📋 Question (ID: {example.example_id})[/bold blue]", + subtitle=f"[dim]Answer Type: {example.answer_type}[/dim]", + border_style="blue", + padding=(1, 2), + ) + ) + console.print() + console.print( + Panel( + f"[yellow]{example.answer}[/yellow]", + title="[bold yellow]🎯 Ground Truth[/bold yellow]", + border_style="yellow", + padding=(1, 2), + ) + ) + console.print() + tool_counts = display_tool_usage(response.tool_calls) + console.print() + + # Parse structured answer format (ANSWER/SOURCES/REASONING) + parsed_answer = _parse_structured_answer(response.text) + + if parsed_answer and parsed_answer.get("answer"): + # Display formatted answer (parse markdown bold markers) + answer_content = _parse_markdown_bold(parsed_answer["answer"], "white") + console.print( + Panel( + answer_content, + title="[bold cyan]🤖 Answer[/bold cyan]", + subtitle=f"[dim]Duration: {response.total_duration_ms / 1000:.1f}s[/dim]", + border_style="cyan", + padding=(1, 2), + ) + ) + + # Display reasoning if present + if 
parsed_answer.get("reasoning"): + console.print() + console.print( + Panel( + parsed_answer["reasoning"], + title="[bold dim]💭 Reasoning[/bold dim]", + border_style="dim", + padding=(0, 1), + ) + ) + else: + # Fallback to raw display if parsing fails + console.print( + Panel( + response.text, + title="[bold cyan]🤖 Agent Response[/bold cyan]", + subtitle=f"[dim]Duration: {response.total_duration_ms / 1000:.1f}s[/dim]", + border_style="cyan", + padding=(1, 2), + ) + ) + + return tool_counts + + +def _display_eval_result(result) -> None: + """Display evaluation metrics for a result.""" + color = OUTCOME_COLORS.get(result.outcome.value, "white") + icon = OUTCOME_ICONS.get(result.outcome.value, "•") + + # Main metrics table + metrics_table = Table(show_header=False, box=None, padding=(0, 2)) + metrics_table.add_column("Metric", style="bold") + metrics_table.add_column("Value", justify="right") + metrics_table.add_row("Outcome", f"[{color}]{icon} {result.outcome.value}[/{color}]") + metrics_table.add_row("Precision", f"[bold]{result.precision:.2f}[/bold]") + metrics_table.add_row("Recall", f"[bold]{result.recall:.2f}[/bold]") + metrics_table.add_row("F1 Score", f"[bold]{result.f1_score:.2f}[/bold]") + + console.print(Panel(metrics_table, title="[bold magenta]📊 Evaluation[/bold magenta]", border_style="magenta")) + + # Display judge explanation if available + if result.explanation: + console.print() + console.print( + Panel( + result.explanation, + title="[bold blue]💭 Judge Explanation[/bold blue]", + border_style="blue", + padding=(1, 2), + ) + ) + + # Display correctness details if available + if result.correctness_details: + console.print() + details_table = Table( + title="🎯 Correctness Details", + show_header=True, + header_style="bold cyan", + box=None, + ) + details_table.add_column("Ground Truth Item", style="white") + details_table.add_column("Found", justify="center", width=8) + + for item, found in result.correctness_details.items(): + found_icon = 
"[green]✓[/green]" if found else "[red]✗[/red]" + details_table.add_row(item, found_icon) + + console.print(details_table) + + # Display extraneous items if any + if result.extraneous_items: + console.print() + extra_text = "\n".join(f" • {item}" for item in result.extraneous_items) + console.print( + Panel( + f"[yellow]{extra_text}[/yellow]", + title="[bold yellow]⚠️ Extraneous Items[/bold yellow]", + subtitle=f"[dim]{len(result.extraneous_items)} item(s) not in ground truth[/dim]", + border_style="yellow", + padding=(0, 2), + ) + ) + + +def _display_eval_summary(results: list) -> None: + """Display comprehensive summary table for multiple evaluation results. + + Shows per-sample results, outcome distribution, and aggregate metrics. + """ + console.print() + + # Per-sample results table + sample_table = Table( + title="[bold cyan]📋 Per-Sample Results[/bold cyan]", + show_header=True, + header_style="bold", + box=box.ROUNDED, + title_justify="left", + ) + sample_table.add_column("ID", style="dim", width=8) + sample_table.add_column("Outcome", width=26) + sample_table.add_column("Precision", justify="right", width=10) + sample_table.add_column("Recall", justify="right", width=10) + sample_table.add_column("F1", justify="right", width=10) + + for example_id, result, _ in results: + color = OUTCOME_COLORS.get(result.outcome.value, "white") + icon = OUTCOME_ICONS.get(result.outcome.value, "•") + sample_table.add_row( + str(example_id), + f"[{color}]{icon} {result.outcome.value}[/{color}]", + f"{result.precision:.2f}", + f"{result.recall:.2f}", + f"{result.f1_score:.2f}", + ) + + console.print(sample_table) + console.print() + + # Count outcomes + outcome_counts = { + EvaluationOutcome.FULLY_CORRECT.value: 0, + EvaluationOutcome.CORRECT_WITH_EXTRANEOUS.value: 0, + EvaluationOutcome.PARTIALLY_CORRECT.value: 0, + EvaluationOutcome.FULLY_INCORRECT.value: 0, + } + for _, result, _ in results: + if result.outcome.value in outcome_counts: + 
outcome_counts[result.outcome.value] += 1 + + total = len(results) + + # Outcome distribution table + outcome_table = Table( + title="[bold magenta]📊 Outcome Distribution[/bold magenta]", + show_header=True, + header_style="bold", + box=box.ROUNDED, + title_justify="left", + ) + outcome_table.add_column("Outcome", width=30) + outcome_table.add_column("Count", justify="right", width=8) + outcome_table.add_column("Percentage", justify="right", width=12) + + outcome_display = [ + (EvaluationOutcome.FULLY_CORRECT.value, "Fully Correct", "green"), + (EvaluationOutcome.CORRECT_WITH_EXTRANEOUS.value, "Correct with Extraneous", "yellow"), + (EvaluationOutcome.PARTIALLY_CORRECT.value, "Partially Correct", "orange1"), + (EvaluationOutcome.FULLY_INCORRECT.value, "Fully Incorrect", "red"), + ] + + for key, label, color in outcome_display: + count = outcome_counts[key] + pct = (count / total * 100) if total > 0 else 0 + icon = OUTCOME_ICONS.get(key, "•") + outcome_table.add_row( + f"[{color}]{icon} {label}[/{color}]", + f"[{color}]{count}[/{color}]", + f"[{color}]{pct:.1f}%[/{color}]", + ) + + console.print(outcome_table) + console.print() + + # Calculate aggregate metrics + avg_precision = sum(r.precision for _, r, _ in results) / total if total > 0 else 0 + avg_recall = sum(r.recall for _, r, _ in results) / total if total > 0 else 0 + avg_f1 = sum(r.f1_score for _, r, _ in results) / total if total > 0 else 0 + + # Aggregate metrics table + metrics_table = Table( + title="[bold green]📈 Aggregate Metrics[/bold green]", + show_header=True, + header_style="bold", + box=box.ROUNDED, + title_justify="left", + ) + metrics_table.add_column("Metric", width=20) + metrics_table.add_column("Value", justify="right", width=12) + + # Color code F1 based on performance + if avg_f1 >= 0.8: + f1_color = "green" + elif avg_f1 >= 0.5: + f1_color = "yellow" + else: + f1_color = "red" + + metrics_table.add_row("Samples Evaluated", f"[bold]{total}[/bold]") + metrics_table.add_row("Avg Precision", 
            f"[bold]{avg_precision:.3f}[/bold]")
    metrics_table.add_row("Avg Recall", f"[bold]{avg_recall:.3f}[/bold]")
    metrics_table.add_row("Avg F1 Score", f"[bold {f1_color}]{avg_f1:.3f}[/bold {f1_color}]")

    console.print(metrics_table)


async def cmd_eval(
    samples: int = 1,
    category: str = "Finance & Economics",
    ids: list[int] | None = None,
    show_plan: bool = False,
    log_trace: bool = False,
) -> int:
    """Run evaluation on DeepSearchQA samples.

    Parameters
    ----------
    samples : int
        Number of samples to evaluate (used when ids not specified).
    category : str
        Dataset category to filter by (used when ids not specified).
    ids : list[int], optional
        Specific example IDs to evaluate. If provided, samples and category are ignored.
    show_plan : bool
        Display the research plan checklist during execution.
    log_trace : bool
        Enable Langfuse tracing for this run.

    Returns
    -------
    int
        Exit code: 0 on success, 1 when no examples matched the criteria.
    """
    display_banner()
    tracing_enabled = _setup_tracing(log_trace)

    # Build info text based on selection mode (explicit IDs vs category+count)
    if ids:
        info_text = f"[bold]Evaluation Mode[/bold]\n\nExample IDs: [cyan]{', '.join(map(str, ids))}[/cyan]"
    else:
        info_text = (
            f"[bold]Evaluation Mode[/bold]\n\nCategory: [cyan]{category}[/cyan]\nSamples: [cyan]{samples}[/cyan]"
        )

    if show_plan:
        info_text += "\nPlan Display: [green]enabled[/green]"

    console.print(
        Panel(
            info_text,
            title="📊 DeepSearchQA Evaluation",
            border_style="blue",
        )
    )
    console.print()

    console.print("[bold blue]Loading dataset...[/bold blue]")
    dataset = DeepSearchQADataset()

    # Get examples by ID or by category; missing IDs are reported but do not abort
    if ids:
        examples = dataset.get_by_ids(ids)
        if len(examples) != len(ids):
            found_ids = {ex.example_id for ex in examples}
            missing_ids = [eid for eid in ids if eid not in found_ids]
            console.print(f"[yellow]Warning: IDs not found: {missing_ids}[/yellow]")
    else:
        examples = dataset.get_by_category(category)[:samples]

    if not examples:
        console.print("[bold red]Error: No examples found matching the criteria.[/bold red]")
        return 1

    console.print(f"[green]✓ Loaded {len(examples)} example(s)[/green]\n")

    console.print("[bold blue]Initializing agent...[/bold blue]")
    agent = KnowledgeGroundedAgent(enable_planning=True)
    console.print("[green]✓ Ready[/green]\n")

    tool_handler = setup_logging()
    results = []

    for i, example in enumerate(examples, 1):
        tool_handler.clear()
        agent.reset()  # Clear session state between examples

        try:
            response = await run_agent_with_display(
                agent,
                example.problem,
                tool_handler,
                show_plan=show_plan,
                ground_truth=example.answer,
                example_id=example.example_id,
                answer_type=example.answer_type,
                example_num=i,
                total_examples=len(examples),
            )

            # Display full results after Live display ends
            tool_counts = _display_example_result(example, response, i, len(examples))
            console.print("\n[bold blue]⏳ Evaluating...[/bold blue]\n")
            result = await evaluate_deepsearchqa_async(
                question=example.problem,
                answer=response.text,
                ground_truth=example.answer,
                answer_type=example.answer_type,
            )
            _display_eval_result(result)
            results.append((example.example_id, result, tool_counts))

        # Deliberate best-effort: a failure on one example is reported and the
        # run continues with the remaining examples.
        except Exception as e:
            console.print(f"[bold red]Error: {e}[/bold red]")

    if results:
        _display_eval_summary(results)

    _flush_tracing(tracing_enabled)
    console.print("\n[bold green]✓ Evaluation complete[/bold green]")
    return 0


def _display_sample_detailed(example, idx: int | None = None, total: int | None = None) -> None:
    """Display a single sample with full details.

    Parameters
    ----------
    example : DSQAExample
        The example to display.
    idx : int, optional
        Current index (1-based) for display in a list.
    total : int, optional
        Total number of examples being displayed.
    """
    # Header with index if provided
    if idx is not None and total is not None:
        console.print(f"\n[bold cyan]━━━ Sample {idx}/{total} ━━━[/bold cyan]\n")
    else:
        console.print()

    # Metadata table
    meta_table = Table(show_header=False, box=None, padding=(0, 2))
    meta_table.add_column("Field", style="bold dim")
    meta_table.add_column("Value")
    meta_table.add_row("ID", f"[cyan]{example.example_id}[/cyan]")
    meta_table.add_row("Category", f"[magenta]{example.problem_category}[/magenta]")
    meta_table.add_row("Answer Type", f"[blue]{example.answer_type}[/blue]")

    console.print(Panel(meta_table, title="[bold]📋 Metadata[/bold]", border_style="dim"))

    # Question
    console.print()
    console.print(
        Panel(
            example.problem,
            title="[bold blue]❓ Question[/bold blue]",
            border_style="blue",
            padding=(1, 2),
        )
    )

    # Ground truth answer
    console.print()
    console.print(
        Panel(
            f"[yellow]{example.answer}[/yellow]",
            title="[bold yellow]🎯 Ground Truth Answer[/bold yellow]",
            border_style="yellow",
            padding=(1, 2),
        )
    )


def cmd_sample(
    ids: list[int] | None = None,
    category: str | None = None,
    count: int = 5,
    random: bool = False,
    list_categories: bool = False,
) -> int:
    """View samples from the DeepSearchQA dataset.

    Parameters
    ----------
    ids : list[int], optional
        Specific example IDs to view.
    category : str, optional
        Filter by category.
    count : int
        Number of samples to show (default 5).
    random : bool
        If True, select random samples instead of first N.
    list_categories : bool
        If True, list all available categories and exit.

    Returns
    -------
    int
        Exit code (0 for success).
    """
    display_banner()

    console.print("[bold blue]Loading dataset...[/bold blue]")
    dataset = DeepSearchQADataset()
    console.print(f"[green]✓ Loaded {len(dataset)} total examples[/green]\n")

    # List categories mode: print the category table and exit early
    if list_categories:
        categories = dataset.get_categories()
        table = Table(title="📂 Available Categories", show_header=True, header_style="bold cyan")
        table.add_column("#", style="dim", width=4)
        table.add_column("Category", style="bold")
        table.add_column("Count", justify="right")

        for i, cat in enumerate(sorted(categories), 1):
            cat_count = len(dataset.get_by_category(cat))
            table.add_row(str(i), cat, str(cat_count))

        console.print(table)
        console.print(f"\n[dim]Total: {len(categories)} categories[/dim]")
        return 0

    # Get examples based on selection criteria (precedence: ids > category > random > first N)
    if ids:
        examples = dataset.get_by_ids(ids)
        if len(examples) != len(ids):
            found_ids = {ex.example_id for ex in examples}
            missing_ids = [eid for eid in ids if eid not in found_ids]
            console.print(f"[yellow]Warning: IDs not found: {missing_ids}[/yellow]\n")
        selection_desc = f"IDs: {', '.join(map(str, ids))}"
    elif category:
        all_in_category = dataset.get_by_category(category)
        if not all_in_category:
            console.print(f"[bold red]Error: Category '{category}' not found.[/bold red]")
            console.print("[dim]Use --list-categories to see available categories.[/dim]")
            return 1
        if random:
            # Local import because the `random` parameter shadows the module name
            import random as rand_module  # noqa: PLC0415

            examples = rand_module.sample(all_in_category, min(count, len(all_in_category)))
        else:
            examples = all_in_category[:count]
        selection_desc = f"Category: {category} ({len(all_in_category)} total)"
    elif random:
        examples = dataset.sample(n=count)
        selection_desc = f"Random {count} samples"
    else:
        examples = dataset.examples[:count]
        selection_desc = f"First {count} samples"

    if not examples:
        console.print("[bold red]No examples found matching the criteria.[/bold red]")
        return 1

    # Display selection info
    console.print(
        Panel(
            f"[bold]Selection:[/bold] {selection_desc}\n[bold]Showing:[/bold] {len(examples)} example(s)",
            title="📊 Dataset View",
            border_style="blue",
        )
    )

    # Display each example
    for i, example in enumerate(examples, 1):
        _display_sample_detailed(example, idx=i, total=len(examples))

    console.print("\n[bold green]✓ Done[/bold green]")
    return 0


def _display_help() -> None:
    """Display colorful help message using Rich."""
    console.print()

    # Commands table
    commands_table = Table(
        show_header=True,
        header_style="bold cyan",
        box=None,
        padding=(0, 2),
    )
    commands_table.add_column("Command", style="bold green", width=12)
    commands_table.add_column("Description")

    commands_table.add_row("ask", "Ask the agent a question")
    commands_table.add_row("eval", "Run evaluation on DeepSearchQA")
    commands_table.add_row("sample", "View samples from the DeepSearchQA dataset")
    commands_table.add_row("tools", "Display available tools")

    console.print("[bold]Commands:[/bold]")
    console.print(commands_table)
    console.print()

    # Options
    console.print("[bold]Options:[/bold]")
    console.print(" [cyan]-h, --help[/cyan] Show this help message")
    console.print(" [cyan]--version[/cyan] Show version number")
    console.print()

    # Usage examples
    console.print("[bold]Examples:[/bold]")
    console.print(' [dim]$[/dim] knowledge-qa [green]ask[/green] [yellow]"What is quantum computing?"[/yellow]')
    console.print(
        ' [dim]$[/dim] knowledge-qa [green]ask[/green] [yellow]"What is AI?"[/yellow] [cyan]--log-trace[/cyan]'
    )
    console.print(" [dim]$[/dim] knowledge-qa [green]eval[/green] [cyan]--samples[/cyan] 3")
    console.print(" [dim]$[/dim] knowledge-qa [green]eval[/green] [cyan]--ids[/cyan] 123 456 [cyan]--show-plan[/cyan]")
    console.print(" [dim]$[/dim] knowledge-qa [green]eval[/green] [cyan]--samples[/cyan] 5 [cyan]--log-trace[/cyan]")
    console.print(
        ' [dim]$[/dim] knowledge-qa [green]sample[/green] [cyan]--category[/cyan] [yellow]"Finance & Economics"[/yellow]'
    )
    console.print()


def main() -> int:
    """Run the Knowledge Agent CLI.

    Returns
    -------
    int
        Exit code from the dispatched subcommand (0 on success).
    """
    parser = argparse.ArgumentParser(
        prog="knowledge-qa",
        description="Knowledge-Grounded QA Agent CLI",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        add_help=False,  # We'll handle help ourselves
    )
    parser.add_argument(
        "-h",
        "--help",
        action="store_true",
        help="Show this help message and exit",
    )
    parser.add_argument(
        "--version",
        action="store_true",
        help="Show version number and exit",
    )

    subparsers = parser.add_subparsers(dest="command", help="Available commands")

    # Ask command
    ask_parser = subparsers.add_parser("ask", help="Ask the agent a question")
    ask_parser.add_argument("question", type=str, help="The question to ask")
    ask_parser.add_argument(
        "--show-plan",
        action="store_true",
        help="Display the research plan checklist during execution",
    )
    ask_parser.add_argument(
        "--log-trace",
        action="store_true",
        help="Enable Langfuse tracing for this run",
    )

    # Eval command
    eval_parser = subparsers.add_parser("eval", help="Run evaluation on DeepSearchQA")
    eval_parser.add_argument(
        "--samples",
        type=int,
        default=1,
        help="Number of samples to evaluate (default: 1, ignored if --ids is used)",
    )
    eval_parser.add_argument(
        "--category",
        type=str,
        default="Finance & Economics",
        help="Dataset category (default: Finance & Economics, ignored if --ids is used)",
    )
    eval_parser.add_argument(
        "--ids",
        type=int,
        nargs="+",
        metavar="ID",
        help="Specific example ID(s) to evaluate (overrides --samples and --category)",
    )
    eval_parser.add_argument(
        "--show-plan",
        action="store_true",
        help="Display the research plan checklist during execution",
    )
    eval_parser.add_argument(
        "--log-trace",
        action="store_true",
        help="Enable Langfuse tracing for this run",
    )

    # Sample command
    sample_parser = subparsers.add_parser("sample", help="View samples from the DeepSearchQA dataset")
    sample_parser.add_argument(
        "--ids",
        type=int,
        nargs="+",
        metavar="ID",
        help="Specific example ID(s) to view",
    )
    sample_parser.add_argument(
        "--category",
        type=str,
        help="Filter samples by category",
    )
    sample_parser.add_argument(
        "--count",
        type=int,
        default=5,
        help="Number of samples to show (default: 5)",
    )
    sample_parser.add_argument(
        "--random",
        action="store_true",
        help="Select random samples instead of first N",
    )
    sample_parser.add_argument(
        "--list-categories",
        action="store_true",
        help="List all available categories and exit",
    )

    # Tools command
    subparsers.add_parser("tools", help="Display available tools")

    args = parser.parse_args()

    # Dispatch; async commands are driven via asyncio.run
    if args.command == "ask":
        return asyncio.run(cmd_ask(args.question, args.show_plan, args.log_trace))
    if args.command == "eval":
        return asyncio.run(cmd_eval(args.samples, args.category, args.ids, args.show_plan, args.log_trace))
    if args.command == "sample":
        return cmd_sample(
            ids=args.ids,
            category=args.category,
            count=args.count,
            random=args.random,
            list_categories=args.list_categories,
        )
    if args.command == "tools":
        display_banner()
        display_tools_info()
        return 0

    # Show help for no command or explicit --help
    display_banner()
    if args.version:
        console.print(f"[bold]knowledge-qa[/bold] v{get_version()}")
        return 0
    _display_help()
    return 0


if __name__ == "__main__":
    sys.exit(main())


# --- (patch continues into new file: aieng/agent_evals/knowledge_qa/retry.py) ---
"""Retry configuration and error handling for API calls.
+ +This module provides retry logic for handling rate limits, quota exhaustion, +and context overflow errors when interacting with the Gemini API. +""" + +from google.genai.errors import ClientError + + +# Max retries for empty model responses +MAX_EMPTY_RESPONSE_RETRIES = 2 + +# API retry configuration for rate limit and quota exhaustion +API_RETRY_MAX_ATTEMPTS = 5 +API_RETRY_INITIAL_WAIT = 1 # seconds +API_RETRY_MAX_WAIT = 60 # seconds +API_RETRY_JITTER = 5 # seconds + + +def is_retryable_api_error(exception: BaseException) -> bool: + """Check if an exception is a retryable API error (rate limit/quota exhaustion). + + Does NOT retry context overflow or cache expiration - those need session + reset instead. + + Parameters + ---------- + exception : BaseException + The exception to check. + + Returns + ------- + bool + True if the exception should trigger a retry (429/RESOURCE_EXHAUSTED errors). + """ + if isinstance(exception, ClientError): + error_str = str(exception).lower() + + # Don't retry context overflow - needs session reset, not retry + if "token count exceeds" in error_str or ("invalid_argument" in error_str and "token" in error_str): + return False + + # Don't retry cache expiration - needs session reset, not retry + if "cache" in error_str and "expired" in error_str: + return False + + # Check for rate limit indicators + if "429" in error_str or "resource_exhausted" in error_str or "quota" in error_str: + return True + return False + + +def is_context_overflow_error(exception: BaseException) -> bool: + """Check if an exception is a context overflow error. + + Parameters + ---------- + exception : BaseException + The exception to check. + + Returns + ------- + bool + True if the exception is due to context window overflow. 
+ """ + if isinstance(exception, ClientError): + error_str = str(exception).lower() + return "token count exceeds" in error_str or ("invalid_argument" in error_str and "token" in error_str) + return False diff --git a/aieng-eval-agents/aieng/agent_evals/knowledge_qa/token_tracker.py b/aieng-eval-agents/aieng/agent_evals/knowledge_qa/token_tracker.py new file mode 100644 index 0000000..4cc3a8b --- /dev/null +++ b/aieng-eval-agents/aieng/agent_evals/knowledge_qa/token_tracker.py @@ -0,0 +1,175 @@ +"""Token usage tracking for Gemini models. + +This module provides utilities for tracking token usage and context +window consumption during agent execution. +""" + +import logging +import os +from typing import Any + +from google.genai import Client +from pydantic import BaseModel + + +logger = logging.getLogger(__name__) +DEFAULT_MODEL = os.environ.get("DEFAULT_WORKER_MODEL", "gemini-2.5-flash") + +# Known context limits for Gemini models (as of 2025) +# Used as fallback if API fetch fails +KNOWN_MODEL_LIMITS: dict[str, int] = { + "gemini-2.5-pro": 1_048_576, + "gemini-2.5-flash": 1_048_576, +} + + +class TokenUsage(BaseModel): + """Token usage statistics. + + Attributes + ---------- + latest_prompt_tokens : int + Prompt tokens from the most recent API call. This represents the + actual current context size since each call includes full history. + latest_cached_tokens : int + Cached tokens from the most recent API call. + total_prompt_tokens : int + Cumulative prompt tokens across all calls (for cost tracking). + total_completion_tokens : int + Cumulative completion tokens across all calls. + total_tokens : int + Cumulative total tokens across all calls. + context_limit : int + Maximum context window size for the model. 
    """

    latest_prompt_tokens: int = 0
    latest_cached_tokens: int = 0
    total_prompt_tokens: int = 0
    total_completion_tokens: int = 0
    total_tokens: int = 0
    # Conservative fallback only; TokenTracker fetches the real per-model limit
    # (KNOWN_MODEL_LIMITS lists gemini-2.5-flash at 1_048_576, not 1_000_000).
    context_limit: int = 1_000_000

    @property
    def context_used_percent(self) -> float:
        """Calculate percentage of context window currently used.

        Uses the latest prompt tokens (total, including cached) since cached
        tokens still occupy space in the context window. Caching only affects
        processing speed and billing, not the context window limit.
        """
        # Guard against division by zero if a limit was never resolved
        if self.context_limit == 0:
            return 0.0
        return (self.latest_prompt_tokens / self.context_limit) * 100

    @property
    def context_remaining_percent(self) -> float:
        """Calculate percentage of context window remaining."""
        return max(0.0, 100.0 - self.context_used_percent)


class TokenTracker:
    """Tracks token usage across agent interactions.

    Parameters
    ----------
    model : str
        The model name to track tokens for.

    Examples
    --------
    >>> tracker = TokenTracker()  # Uses DEFAULT_WORKER_MODEL from .env
    >>> tracker.add_from_event(event)
    >>> print(f"Context remaining: {tracker.usage.context_remaining_percent:.1f}%")
    """

    def __init__(self, model: str | None = None) -> None:
        """Initialize the token tracker.

        Parameters
        ----------
        model : str, optional
            The model name to fetch context limits for.
            Defaults to DEFAULT_WORKER_MODEL from environment.
        """
        self._model = model or DEFAULT_MODEL
        self._usage = TokenUsage()
        self._fetch_model_limits()

    def _fetch_model_limits(self) -> None:
        """Fetch model context limits from the API, with known fallbacks."""
        client = None
        try:
            client = Client()
            model_info = client.models.get(model=self._model)
            if model_info.input_token_limit:
                self._usage.context_limit = model_info.input_token_limit
                logger.debug(f"Model {self._model} context limit: {self._usage.context_limit}")
                return
        except Exception as e:
            # Best-effort: fall through to the static fallback table below
            logger.warning(f"Failed to fetch model limits from API: {e}")
        finally:
            # Close the client so underlying HTTP resources are released
            # NOTE(review): original comment said "aiohttp session leaks" — the
            # google-genai SDK's transport is not shown here; confirm which HTTP
            # library actually backs Client before citing it.
            if client is not None:
                client.close()

        # Use known fallback if available
        if self._model in KNOWN_MODEL_LIMITS:
            self._usage.context_limit = KNOWN_MODEL_LIMITS[self._model]
            logger.info(f"Using known limit for {self._model}: {self._usage.context_limit}")
        else:
            logger.warning(f"Unknown model {self._model}, using default limit: {self._usage.context_limit}")

    @property
    def usage(self) -> TokenUsage:
        """Get current token usage statistics."""
        return self._usage

    def add_from_event(self, event: Any) -> None:
        """Add token usage from an ADK event.

        Updates both the latest token counts (for context tracking) and
        cumulative totals (for cost tracking).

        Parameters
        ----------
        event : Any
            An event from the ADK runner that may contain usage_metadata.
        """
        if not hasattr(event, "usage_metadata") or event.usage_metadata is None:
            return

        metadata = event.usage_metadata

        # Extract token counts from this API call; `or 0` normalizes None values
        prompt = getattr(metadata, "prompt_token_count", 0) or 0
        cached = getattr(metadata, "cached_content_token_count", 0) or 0
        completion = getattr(metadata, "candidates_token_count", 0) or 0
        total = getattr(metadata, "total_token_count", 0) or 0

        # Update LATEST tokens - this reflects current context size
        # Each API call includes full conversation history, so the latest
        # prompt_token_count is the actual current context usage
        self._usage.latest_prompt_tokens = prompt
        self._usage.latest_cached_tokens = cached

        # Accumulate totals for cost/usage tracking
        self._usage.total_prompt_tokens += prompt
        self._usage.total_completion_tokens += completion
        self._usage.total_tokens += total

        logger.debug(
            f"Token update: prompt={prompt} (cached={cached}), context: {self._usage.context_used_percent:.1f}% used"
        )

    def reset(self) -> None:
        """Reset all token counts (keeps context limit)."""
        context_limit = self._usage.context_limit
        self._usage = TokenUsage(
            latest_prompt_tokens=0,
            latest_cached_tokens=0,
            total_prompt_tokens=0,
            total_completion_tokens=0,
            total_tokens=0,
            context_limit=context_limit,
        )


# --- (patch continues into existing file: tests/.../knowledge_qa/test_agent.py) ---
"""Tests for the Knowledge-Grounded QA Agent and data models."""

from unittest.mock import MagicMock, patch

import pytest
from aieng.agent_evals.knowledge_qa.agent import (
    AgentResponse,
    KnowledgeAgentManager,
    KnowledgeGroundedAgent,
    StepExecution,
)
from aieng.agent_evals.knowledge_qa.plan_parsing import (
    ResearchPlan,
    ResearchStep,
    StepStatus,
)
from aieng.agent_evals.tools import GroundingChunk


# =============================================================================
# Data Model Tests
# =============================================================================


class TestStepStatus:
    """Tests for the StepStatus constants."""

    def test_status_constants(self):
        """Test that status constants are defined."""
        assert StepStatus.PENDING == "pending"
        assert StepStatus.IN_PROGRESS == "in_progress"
        assert StepStatus.COMPLETED == "completed"
        assert StepStatus.FAILED == "failed"
        assert StepStatus.SKIPPED == "skipped"


class TestResearchStep:
    """Tests for the ResearchStep model."""

    def test_research_step_creation(self):
        """Test creating a research step."""
        step = ResearchStep(
            step_id=1,
            description="Search for financial regulations",
            step_type="research",
            depends_on=[],
            expected_output="List of relevant regulations",
        )
        assert step.step_id == 1
        assert step.description == "Search for financial regulations"
        assert step.step_type == "research"
        assert step.depends_on == []
        assert step.expected_output == "List of relevant regulations"

    def test_research_step_with_dependencies(self):
        """Test creating a step with dependencies."""
        step = ResearchStep(
            step_id=3,
            description="Synthesize findings",
            step_type="synthesis",
            depends_on=[1, 2],
            expected_output="Comprehensive answer",
        )
        assert step.depends_on == [1, 2]

    def test_research_step_defaults(self):
        """Test default values for research step."""
        step = ResearchStep(
            step_id=1,
            description="Test step",
            step_type="research",
        )
        assert step.depends_on == []
        assert step.expected_output == ""
        assert step.status == StepStatus.PENDING
        assert step.actual_output == ""
        assert step.attempts == 0
        assert step.failure_reason == ""

    def test_research_step_with_tracking_fields(self):
        """Test creating a step with tracking fields."""
        step = ResearchStep(
            step_id=1,
            description="Test step",
            step_type="research",
            status=StepStatus.COMPLETED,
            actual_output="Found 5 results",
            attempts=2,
            failure_reason="",
        )
        assert step.status == StepStatus.COMPLETED
        assert step.actual_output == "Found 5 results"
        assert step.attempts == 2

    def test_research_step_failed_status(self):
        """Test creating a step with failed status."""
        step = ResearchStep(
            step_id=1,
            description="Fetch document",
            step_type="research",
            status=StepStatus.FAILED,
            attempts=3,
            failure_reason="404 Not Found",
        )
        assert step.status == StepStatus.FAILED
        assert step.attempts == 3
        assert step.failure_reason == "404 Not Found"


class TestResearchPlan:
    """Tests for the ResearchPlan model."""

    def test_research_plan_creation(self):
        """Test creating a research plan."""
        plan = ResearchPlan(
            original_question="What caused the 2008 financial crisis?",
            steps=[
                ResearchStep(
                    step_id=1,
                    description="Research subprime mortgages",
                    step_type="research",
                ),
                ResearchStep(
                    step_id=2,
                    description="Look up Dodd-Frank regulations",
                    step_type="research",
                ),
            ],
            reasoning="Complex question requiring multiple sources",
        )
        assert plan.original_question == "What caused the 2008 financial crisis?"
        assert len(plan.steps) == 2
        assert plan.reasoning != ""

    def test_research_plan_defaults(self):
        """Test default values for research plan."""
        plan = ResearchPlan(
            original_question="Simple question",
        )
        assert plan.steps == []
        assert plan.reasoning == ""

    def test_get_step_found(self):
        """Test getting an existing step by ID."""
        plan = ResearchPlan(
            original_question="Test",
            steps=[
                ResearchStep(step_id=1, description="Step 1", step_type="research"),
                ResearchStep(step_id=2, description="Step 2", step_type="research"),
            ],
        )
        step = plan.get_step(2)
        assert step is not None
        assert step.description == "Step 2"

    def test_get_step_not_found(self):
        """Test getting a non-existent step by ID."""
        plan = ResearchPlan(
            original_question="Test",
            steps=[ResearchStep(step_id=1, description="Step 1", step_type="research")],
        )
        step = plan.get_step(99)
        assert step is None

    def test_update_step_status(self):
        """Test updating a step's status."""
        plan = ResearchPlan(
            original_question="Test",
            steps=[ResearchStep(step_id=1, description="Step 1", step_type="research")],
        )
        result = plan.update_step(1, status=StepStatus.COMPLETED)
        assert result is True
        assert plan.steps[0].status == StepStatus.COMPLETED

    def test_update_step_all_fields(self):
        """Test updating all tracking fields of a step."""
        plan = ResearchPlan(
            original_question="Test",
            steps=[ResearchStep(step_id=1, description="Step 1", step_type="research")],
        )
        result = plan.update_step(
            1,
            status=StepStatus.FAILED,
            actual_output="Found some results",
            failure_reason="Timeout error",
            increment_attempts=True,
        )
        assert result is True
        assert plan.steps[0].status == StepStatus.FAILED
        assert plan.steps[0].actual_output == "Found some results"
        assert plan.steps[0].failure_reason == "Timeout error"
        assert plan.steps[0].attempts == 1

    def test_update_step_not_found(self):
        """Test updating a non-existent step."""
        plan = ResearchPlan(
            original_question="Test",
            steps=[],
        )
        result = plan.update_step(99, status=StepStatus.COMPLETED)
        assert result is False

    def test_update_step_description(self):
        """Test updating a step's description."""
        plan = ResearchPlan(
            original_question="Test",
            steps=[ResearchStep(step_id=1, description="Original", step_type="research")],
        )
        result = plan.update_step(1, description="Updated description")
        assert result is True
        assert plan.steps[0].description == "Updated description"

    def test_get_pending_steps_no_dependencies(self):
        """Test getting pending steps when none have dependencies."""
        plan = ResearchPlan(
            original_question="Test",
            steps=[
                ResearchStep(step_id=1, description="Step 1", step_type="research"),
                ResearchStep(step_id=2, description="Step 2", step_type="research"),
            ],
        )
        pending = plan.get_pending_steps()
        assert len(pending) == 2

    def test_get_pending_steps_with_dependencies(self):
        """Test getting pending steps with dependency filtering."""
        plan = ResearchPlan(
            original_question="Test",
            steps=[
                ResearchStep(step_id=1, description="Step 1", step_type="research"),
                ResearchStep(step_id=2, description="Step 2", step_type="research", depends_on=[1]),
                ResearchStep(step_id=3, description="Step 3", step_type="synthesis", depends_on=[1, 2]),
            ],
        )
        pending = plan.get_pending_steps()
        assert len(pending) == 1
        assert pending[0].step_id == 1

    def test_get_pending_steps_after_completion(self):
        """Test getting pending steps after some complete."""
        plan = ResearchPlan(
            original_question="Test",
            steps=[
                ResearchStep(step_id=1, description="Step 1", step_type="research", status=StepStatus.COMPLETED),
                ResearchStep(step_id=2, description="Step 2", step_type="research", depends_on=[1]),
                ResearchStep(step_id=3, description="Step 3", step_type="synthesis", depends_on=[1, 2]),
            ],
        )
        pending = plan.get_pending_steps()
        assert len(pending) == 1
        assert pending[0].step_id == 2

    def test_get_steps_by_status(self):
        """Test getting steps by status."""
        plan = ResearchPlan(
            original_question="Test",
            steps=[
                ResearchStep(step_id=1, description="Step 1", step_type="research", status=StepStatus.COMPLETED),
                ResearchStep(step_id=2, description="Step 2", step_type="research", status=StepStatus.FAILED),
                ResearchStep(step_id=3, description="Step 3", step_type="synthesis", status=StepStatus.PENDING),
            ],
        )
        completed = plan.get_steps_by_status(StepStatus.COMPLETED)
        failed = plan.get_steps_by_status(StepStatus.FAILED)
        pending = plan.get_steps_by_status(StepStatus.PENDING)

        assert len(completed) == 1
        assert completed[0].step_id == 1
        assert len(failed) == 1
        assert failed[0].step_id == 2
        assert len(pending) == 1
        assert pending[0].step_id == 3

    def test_is_complete_all_done(self):
        """Test is_complete when all steps are in terminal states."""
        # COMPLETED, FAILED and SKIPPED all count as terminal states
        plan = ResearchPlan(
            original_question="Test",
            steps=[
                ResearchStep(step_id=1, description="Step 1", step_type="research", status=StepStatus.COMPLETED),
                ResearchStep(step_id=2, description="Step 2", step_type="research", status=StepStatus.FAILED),
                ResearchStep(step_id=3, description="Step 3", step_type="synthesis", status=StepStatus.SKIPPED),
            ],
        )
        assert plan.is_complete() is True

    def test_is_complete_with_pending(self):
        """Test is_complete when some steps are pending."""
        plan = ResearchPlan(
            original_question="Test",
            steps=[
                ResearchStep(step_id=1, description="Step 1", step_type="research", status=StepStatus.COMPLETED),
                ResearchStep(step_id=2, description="Step 2", step_type="research", status=StepStatus.PENDING),
            ],
        )
        assert plan.is_complete() is False

    def test_is_complete_with_in_progress(self):
        """Test is_complete when some steps are in progress."""
        plan = ResearchPlan(
            original_question="Test",
            steps=[
                ResearchStep(step_id=1, description="Step 1", step_type="research", status=StepStatus.IN_PROGRESS),
            ],
        )
        assert plan.is_complete() is False


class TestStepExecution:
    """Tests for the StepExecution model."""

    def test_step_execution_creation(self):
        """Test creating a step execution record."""
        execution = StepExecution(
            step_id=1,
            tool_used="web_search",
            input_query="2008 financial crisis causes",
            output_summary="Found 5 relevant articles",
            sources_found=5,
            duration_ms=1500,
            raw_output="Raw search results...",
        )
        assert execution.step_id == 1
        assert execution.tool_used == "web_search"
        assert execution.input_query == "2008 financial crisis causes"
        assert execution.output_summary == "Found 5 relevant articles"
        assert execution.sources_found == 5
        assert execution.duration_ms == 1500

    def test_step_execution_defaults(self):
        """Test default values for step execution."""
        execution = StepExecution(
            step_id=1,
            tool_used="finance_knowledge",
            input_query="Basel III",
        )
        assert execution.output_summary == ""
        assert execution.sources_found == 0
        assert execution.duration_ms == 0
        assert execution.raw_output == ""


# =============================================================================
# Agent Tests
# =============================================================================


class TestKnowledgeGroundedAgent:
    # ... (lines unchanged by the patch are elided by the hunk header here) ...
    def mock_config(self):
        config = MagicMock()
        config.gemini_api_key = "test-api-key"
        config.default_worker_model = "gemini-2.5-flash"
        config.default_temperature = 0.0
        return config

    @patch("aieng.agent_evals.knowledge_qa.agent.PlanReActPlanner")
    @patch("aieng.agent_evals.knowledge_qa.agent.Runner")
    @patch("aieng.agent_evals.knowledge_qa.agent.InMemorySessionService")
    @patch("aieng.agent_evals.knowledge_qa.agent.Agent")
    @patch("aieng.agent_evals.knowledge_qa.agent.create_read_file_tool")
    @patch("aieng.agent_evals.knowledge_qa.agent.create_grep_file_tool")
    @patch("aieng.agent_evals.knowledge_qa.agent.create_fetch_file_tool")
@patch("aieng.agent_evals.knowledge_qa.agent.create_web_fetch_tool") @patch("aieng.agent_evals.knowledge_qa.agent.create_google_search_tool") def test_agent_initialization( self, - mock_create_tool, + mock_create_search_tool, + mock_create_web_fetch_tool, + mock_create_fetch_file_tool, + mock_create_grep_file_tool, + mock_create_read_file_tool, mock_agent_class, - mock_session_service, - mock_runner_class, + _mock_session_service, + _mock_runner_class, + mock_planner, mock_config, ): - """Test initializing the agent.""" - mock_tool = MagicMock() - mock_create_tool.return_value = mock_tool + """Test initializing the agent with all tools.""" + mock_search_tool = MagicMock() + mock_web_fetch_tool = MagicMock() + mock_create_search_tool.return_value = mock_search_tool + mock_create_web_fetch_tool.return_value = mock_web_fetch_tool - KnowledgeGroundedAgent(config=mock_config) + agent = KnowledgeGroundedAgent(config=mock_config, enable_caching=False, enable_compaction=False) - # Verify tool was created - mock_create_tool.assert_called_once() + # Verify all tools were created + mock_create_search_tool.assert_called_once() + mock_create_web_fetch_tool.assert_called_once() + mock_create_fetch_file_tool.assert_called_once() + mock_create_grep_file_tool.assert_called_once() + mock_create_read_file_tool.assert_called_once() # Verify ADK Agent was created with correct params mock_agent_class.assert_called_once() call_kwargs = mock_agent_class.call_args[1] - assert call_kwargs["name"] == "knowledge_qa_agent" - assert call_kwargs["model"] == "gemini-2.5-flash" - assert call_kwargs["instruction"] == SYSTEM_INSTRUCTIONS - assert mock_tool in call_kwargs["tools"] - - # Verify session service and runner were created - mock_session_service.assert_called_once() - mock_runner_class.assert_called_once() - - @patch("aieng.agent_evals.knowledge_qa.agent.Runner") - @patch("aieng.agent_evals.knowledge_qa.agent.InMemorySessionService") - @patch("aieng.agent_evals.knowledge_qa.agent.Agent") - 
@patch("aieng.agent_evals.knowledge_qa.agent.create_google_search_tool") - def test_agent_with_custom_model( - self, - mock_create_tool, - mock_agent_class, - mock_session_service, - mock_runner_class, - mock_config, - ): - """Test initializing with a custom model.""" - KnowledgeGroundedAgent(config=mock_config, model="gemini-2.5-pro") - - call_kwargs = mock_agent_class.call_args[1] - assert call_kwargs["model"] == "gemini-2.5-pro" - - @pytest.mark.asyncio - @patch("aieng.agent_evals.knowledge_qa.agent.Runner") - @patch("aieng.agent_evals.knowledge_qa.agent.InMemorySessionService") - @patch("aieng.agent_evals.knowledge_qa.agent.Agent") - @patch("aieng.agent_evals.knowledge_qa.agent.create_google_search_tool") - async def test_get_or_create_session( - self, - mock_create_tool, - mock_agent_class, - mock_session_service_class, - mock_runner_class, - mock_config, - ): - """Test session creation and retrieval.""" - # Mock the session service's create_session method - mock_session = MagicMock() - mock_session.id = "mock-session-id-1" - mock_session_service = MagicMock() - mock_session_service.create_session = AsyncMock(return_value=mock_session) - mock_session_service_class.return_value = mock_session_service - - agent = KnowledgeGroundedAgent(config=mock_config) - - # Create a new session - session1 = await agent._get_or_create_session_async("test-session-1") - assert session1 is not None - - # Same session ID should return same ADK session (cached) - session2 = await agent._get_or_create_session_async("test-session-1") - assert session1 == session2 - - # Different session ID should create new session - mock_session.id = "mock-session-id-2" - session3 = await agent._get_or_create_session_async("test-session-2") - assert session3 != session1 - - @pytest.mark.asyncio - @patch("aieng.agent_evals.knowledge_qa.agent.Runner") - @patch("aieng.agent_evals.knowledge_qa.agent.InMemorySessionService") - @patch("aieng.agent_evals.knowledge_qa.agent.Agent") - 
@patch("aieng.agent_evals.knowledge_qa.agent.create_google_search_tool") - async def test_get_or_create_session_generates_id( - self, - mock_create_tool, - mock_agent_class, - mock_session_service_class, - mock_runner_class, - mock_config, - ): - """Test that session ID is generated if not provided.""" - # Mock the session service's create_session method - mock_session = MagicMock() - mock_session.id = "mock-session-id" - mock_session_service = MagicMock() - mock_session_service.create_session = AsyncMock(return_value=mock_session) - mock_session_service_class.return_value = mock_session_service - - agent = KnowledgeGroundedAgent(config=mock_config) - - session = await agent._get_or_create_session_async(None) - assert session is not None - - @pytest.mark.asyncio - @patch("aieng.agent_evals.knowledge_qa.agent.Runner") - @patch("aieng.agent_evals.knowledge_qa.agent.InMemorySessionService") - @patch("aieng.agent_evals.knowledge_qa.agent.Agent") - @patch("aieng.agent_evals.knowledge_qa.agent.create_google_search_tool") - async def test_answer_async( - self, - mock_create_tool, - mock_agent_class, - mock_session_service_class, - mock_runner_class, - mock_config, - ): - """Test async answer method.""" - # Mock the session service's create_session method - mock_session = MagicMock() - mock_session.id = "mock-session-id" - mock_session_service = MagicMock() - mock_session_service.create_session = AsyncMock(return_value=mock_session) - mock_session_service_class.return_value = mock_session_service - - # Create mock event with final response - mock_event = MagicMock() - mock_event.is_final_response.return_value = True - mock_event.content.parts = [MagicMock(text="Paris is the capital of France.")] - - # Make runner.run_async return an async generator - async def mock_run_async(*args, **kwargs): - yield mock_event - - mock_runner = MagicMock() - mock_runner.run_async = mock_run_async - mock_runner_class.return_value = mock_runner - - agent = 
KnowledgeGroundedAgent(config=mock_config) - response = await agent.answer_async("What is the capital of France?") - - assert isinstance(response, GroundedResponse) - assert response.text == "Paris is the capital of France." - - @pytest.mark.asyncio - @patch("aieng.agent_evals.knowledge_qa.agent.Runner") - @patch("aieng.agent_evals.knowledge_qa.agent.InMemorySessionService") - @patch("aieng.agent_evals.knowledge_qa.agent.Agent") - @patch("aieng.agent_evals.knowledge_qa.agent.create_google_search_tool") - async def test_answer_async_extracts_function_calls( - self, - mock_create_tool, - mock_agent_class, - mock_session_service_class, - mock_runner_class, - mock_config, - ): - """Test that function calls are extracted from events.""" - # Mock session service - mock_session = MagicMock() - mock_session.id = "mock-session-id" - mock_session_service = MagicMock() - mock_session_service.create_session = AsyncMock(return_value=mock_session) - mock_session_service_class.return_value = mock_session_service - - # Create mock function call - mock_function_call = MagicMock() - mock_function_call.name = "google_search" - mock_function_call.args = {"query": "capital of France"} - - # Create mock event with function call - mock_tool_event = MagicMock() - mock_tool_event.is_final_response.return_value = False - mock_tool_event.get_function_calls.return_value = [mock_function_call] - mock_tool_event.get_function_responses.return_value = None - mock_tool_event.grounding_metadata = None - mock_tool_event.content = None - - # Create mock final event - mock_final_event = MagicMock() - mock_final_event.is_final_response.return_value = True - mock_final_event.get_function_calls.return_value = None - mock_final_event.get_function_responses.return_value = None - mock_final_event.grounding_metadata = None - mock_final_event.content.parts = [MagicMock(text="Paris is the capital.")] - - async def mock_run_async(*args, **kwargs): - yield mock_tool_event - yield mock_final_event - - mock_runner 
= MagicMock() - mock_runner.run_async = mock_run_async - mock_runner_class.return_value = mock_runner - - agent = KnowledgeGroundedAgent(config=mock_config) - response = await agent.answer_async("What is the capital of France?") + assert call_kwargs["name"] == "knowledge_qa" + assert mock_search_tool in call_kwargs["tools"] + assert mock_web_fetch_tool in call_kwargs["tools"] - assert len(response.tool_calls) == 1 - assert response.tool_calls[0]["name"] == "google_search" - assert response.tool_calls[0]["args"] == {"query": "capital of France"} - assert "capital of France" in response.search_queries + # Verify BuiltInPlanner was created (planning enabled by default) + mock_planner.assert_called_once() + assert agent.enable_planning is True - @pytest.mark.asyncio + @patch("aieng.agent_evals.knowledge_qa.agent.PlanReActPlanner") @patch("aieng.agent_evals.knowledge_qa.agent.Runner") @patch("aieng.agent_evals.knowledge_qa.agent.InMemorySessionService") @patch("aieng.agent_evals.knowledge_qa.agent.Agent") + @patch("aieng.agent_evals.knowledge_qa.agent.create_read_file_tool") + @patch("aieng.agent_evals.knowledge_qa.agent.create_grep_file_tool") + @patch("aieng.agent_evals.knowledge_qa.agent.create_fetch_file_tool") + @patch("aieng.agent_evals.knowledge_qa.agent.create_web_fetch_tool") @patch("aieng.agent_evals.knowledge_qa.agent.create_google_search_tool") - async def test_answer_async_extracts_sources_from_function_responses( + def test_agent_without_planning( self, - mock_create_tool, + _mock_create_search_tool, + _mock_create_web_fetch_tool, + _mock_create_fetch_file_tool, + _mock_create_grep_file_tool, + _mock_create_read_file_tool, mock_agent_class, - mock_session_service_class, - mock_runner_class, + _mock_session_service, + _mock_runner_class, + mock_planner, mock_config, ): - """Test that sources are extracted from function responses.""" - # Mock session service - mock_session = MagicMock() - mock_session.id = "mock-session-id" - mock_session_service = 
MagicMock() - mock_session_service.create_session = AsyncMock(return_value=mock_session) - mock_session_service_class.return_value = mock_session_service - - # Create mock function response with sources - mock_function_response = MagicMock() - mock_function_response.response = { - "sources": [ - {"title": "Wikipedia - Paris", "uri": "https://en.wikipedia.org/wiki/Paris"}, - {"title": "Travel Guide", "url": "https://example.com/paris"}, - ] - } - - # Create mock event with function response - mock_response_event = MagicMock() - mock_response_event.is_final_response.return_value = False - mock_response_event.get_function_calls.return_value = None - mock_response_event.get_function_responses.return_value = [mock_function_response] - mock_response_event.grounding_metadata = None - mock_response_event.content = None - - # Create mock final event - mock_final_event = MagicMock() - mock_final_event.is_final_response.return_value = True - mock_final_event.get_function_calls.return_value = None - mock_final_event.get_function_responses.return_value = None - mock_final_event.grounding_metadata = None - mock_final_event.content.parts = [MagicMock(text="Paris is the capital.")] - - async def mock_run_async(*args, **kwargs): - yield mock_response_event - yield mock_final_event - - mock_runner = MagicMock() - mock_runner.run_async = mock_run_async - mock_runner_class.return_value = mock_runner - - agent = KnowledgeGroundedAgent(config=mock_config) - response = await agent.answer_async("What is the capital of France?") - - assert len(response.sources) == 2 - assert response.sources[0].title == "Wikipedia - Paris" - assert response.sources[0].uri == "https://en.wikipedia.org/wiki/Paris" - assert response.sources[1].title == "Travel Guide" - assert response.sources[1].uri == "https://example.com/paris" - - @pytest.mark.asyncio - @patch("aieng.agent_evals.knowledge_qa.agent.Runner") - @patch("aieng.agent_evals.knowledge_qa.agent.InMemorySessionService") - 
@patch("aieng.agent_evals.knowledge_qa.agent.Agent") - @patch("aieng.agent_evals.knowledge_qa.agent.create_google_search_tool") - async def test_answer_async_extracts_grounding_chunks_from_responses( - self, - mock_create_tool, - mock_agent_class, - mock_session_service_class, - mock_runner_class, - mock_config, - ): - """Test that grounding_chunks are extracted from function responses.""" - # Mock session service - mock_session = MagicMock() - mock_session.id = "mock-session-id" - mock_session_service = MagicMock() - mock_session_service.create_session = AsyncMock(return_value=mock_session) - mock_session_service_class.return_value = mock_session_service - - # Create mock function response with grounding_chunks - mock_function_response = MagicMock() - mock_function_response.response = { - "grounding_chunks": [ - {"web": {"title": "Official Site", "uri": "https://official.com"}}, - {"web": {"title": "News Article", "uri": "https://news.com/article"}}, - ] - } - - # Create mock event with function response - mock_response_event = MagicMock() - mock_response_event.is_final_response.return_value = False - mock_response_event.get_function_calls.return_value = None - mock_response_event.get_function_responses.return_value = [mock_function_response] - mock_response_event.grounding_metadata = None - mock_response_event.content = None - - # Create mock final event - mock_final_event = MagicMock() - mock_final_event.is_final_response.return_value = True - mock_final_event.get_function_calls.return_value = None - mock_final_event.get_function_responses.return_value = None - mock_final_event.grounding_metadata = None - mock_final_event.content.parts = [MagicMock(text="Answer.")] - - async def mock_run_async(*args, **kwargs): - yield mock_response_event - yield mock_final_event - - mock_runner = MagicMock() - mock_runner.run_async = mock_run_async - mock_runner_class.return_value = mock_runner - - agent = KnowledgeGroundedAgent(config=mock_config) - response = await 
agent.answer_async("Test question") - - assert len(response.sources) == 2 - assert response.sources[0].title == "Official Site" - assert response.sources[0].uri == "https://official.com" - assert response.sources[1].title == "News Article" - assert response.sources[1].uri == "https://news.com/article" - - @pytest.mark.asyncio - @patch("aieng.agent_evals.knowledge_qa.agent.Runner") - @patch("aieng.agent_evals.knowledge_qa.agent.InMemorySessionService") - @patch("aieng.agent_evals.knowledge_qa.agent.Agent") - @patch("aieng.agent_evals.knowledge_qa.agent.create_google_search_tool") - async def test_answer_async_extracts_grounding_metadata( - self, - mock_create_tool, - mock_agent_class, - mock_session_service_class, - mock_runner_class, - mock_config, - ): - """Test that grounding metadata is extracted from events.""" - # Mock session service - mock_session = MagicMock() - mock_session.id = "mock-session-id" - mock_session_service = MagicMock() - mock_session_service.create_session = AsyncMock(return_value=mock_session) - mock_session_service_class.return_value = mock_session_service - - # Create mock grounding chunk - mock_web_chunk = MagicMock() - mock_web_chunk.title = "Grounded Source" - mock_web_chunk.uri = "https://grounded.com" - - mock_grounding_chunk = MagicMock() - mock_grounding_chunk.web = mock_web_chunk - - # Create mock grounding metadata - mock_grounding_metadata = MagicMock() - mock_grounding_metadata.grounding_chunks = [mock_grounding_chunk] - mock_grounding_metadata.web_search_queries = ["grounded query"] - - # Create mock event with grounding metadata - mock_grounding_event = MagicMock() - mock_grounding_event.is_final_response.return_value = False - mock_grounding_event.get_function_calls.return_value = None - mock_grounding_event.get_function_responses.return_value = None - mock_grounding_event.grounding_metadata = mock_grounding_metadata - mock_grounding_event.content = None - - # Create mock final event - mock_final_event = MagicMock() - 
mock_final_event.is_final_response.return_value = True - mock_final_event.get_function_calls.return_value = None - mock_final_event.get_function_responses.return_value = None - mock_final_event.grounding_metadata = None - mock_final_event.content.parts = [MagicMock(text="Final answer.")] - - async def mock_run_async(*args, **kwargs): - yield mock_grounding_event - yield mock_final_event - - mock_runner = MagicMock() - mock_runner.run_async = mock_run_async - mock_runner_class.return_value = mock_runner - - agent = KnowledgeGroundedAgent(config=mock_config) - response = await agent.answer_async("Test question") - - assert len(response.sources) == 1 - assert response.sources[0].title == "Grounded Source" - assert response.sources[0].uri == "https://grounded.com" - assert "grounded query" in response.search_queries + """Test initializing the agent without planning.""" + agent = KnowledgeGroundedAgent( + config=mock_config, enable_planning=False, enable_caching=False, enable_compaction=False + ) - @pytest.mark.asyncio - @patch("aieng.agent_evals.knowledge_qa.agent.Runner") - @patch("aieng.agent_evals.knowledge_qa.agent.InMemorySessionService") - @patch("aieng.agent_evals.knowledge_qa.agent.Agent") - @patch("aieng.agent_evals.knowledge_qa.agent.create_google_search_tool") - async def test_answer_async_extracts_grounding_metadata_from_content( - self, - mock_create_tool, - mock_agent_class, - mock_session_service_class, - mock_runner_class, - mock_config, - ): - """Test grounding metadata extraction from event.content.""" - # Mock session service - mock_session = MagicMock() - mock_session.id = "mock-session-id" - mock_session_service = MagicMock() - mock_session_service.create_session = AsyncMock(return_value=mock_session) - mock_session_service_class.return_value = mock_session_service - - # Create mock grounding chunk on content - mock_web_chunk = MagicMock() - mock_web_chunk.title = "Content Source" - mock_web_chunk.uri = "https://content-source.com" - - 
mock_grounding_chunk = MagicMock() - mock_grounding_chunk.web = mock_web_chunk - - mock_grounding_metadata = MagicMock() - mock_grounding_metadata.grounding_chunks = [mock_grounding_chunk] - mock_grounding_metadata.web_search_queries = ["content query"] - - # Create mock event with grounding metadata on content (not event directly) - mock_content = MagicMock() - mock_content.grounding_metadata = mock_grounding_metadata - - mock_event = MagicMock() - mock_event.is_final_response.return_value = False - mock_event.get_function_calls.return_value = None - mock_event.get_function_responses.return_value = None - mock_event.grounding_metadata = None # Not on event directly - mock_event.content = mock_content - - # Create mock final event - mock_final_event = MagicMock() - mock_final_event.is_final_response.return_value = True - mock_final_event.get_function_calls.return_value = None - mock_final_event.get_function_responses.return_value = None - mock_final_event.grounding_metadata = None - mock_final_event.content.parts = [MagicMock(text="Answer.")] - mock_final_event.content.grounding_metadata = None - - async def mock_run_async(*args, **kwargs): - yield mock_event - yield mock_final_event - - mock_runner = MagicMock() - mock_runner.run_async = mock_run_async - mock_runner_class.return_value = mock_runner - - agent = KnowledgeGroundedAgent(config=mock_config) - response = await agent.answer_async("Test question") + # BuiltInPlanner should not be created when planning disabled + mock_planner.assert_not_called() + assert agent.enable_planning is False - assert len(response.sources) == 1 - assert response.sources[0].title == "Content Source" - assert response.sources[0].uri == "https://content-source.com" - assert "content query" in response.search_queries + # ADK Agent should be created with planner=None + call_kwargs = mock_agent_class.call_args[1] + assert call_kwargs["planner"] is None - @pytest.mark.asyncio + 
@patch("aieng.agent_evals.knowledge_qa.agent.PlanReActPlanner") @patch("aieng.agent_evals.knowledge_qa.agent.Runner") @patch("aieng.agent_evals.knowledge_qa.agent.InMemorySessionService") @patch("aieng.agent_evals.knowledge_qa.agent.Agent") + @patch("aieng.agent_evals.knowledge_qa.agent.create_read_file_tool") + @patch("aieng.agent_evals.knowledge_qa.agent.create_grep_file_tool") + @patch("aieng.agent_evals.knowledge_qa.agent.create_fetch_file_tool") + @patch("aieng.agent_evals.knowledge_qa.agent.create_web_fetch_tool") @patch("aieng.agent_evals.knowledge_qa.agent.create_google_search_tool") - async def test_answer_async_handles_multiple_search_tool_names( + def test_agent_with_custom_model( self, - mock_create_tool, + _mock_create_search_tool, + _mock_create_web_fetch_tool, + _mock_create_fetch_file_tool, + _mock_create_grep_file_tool, + _mock_create_read_file_tool, mock_agent_class, - mock_session_service_class, - mock_runner_class, + _mock_session_service, + _mock_runner_class, + _mock_planner, mock_config, ): - """Test that search queries are extracted from various search tool names.""" - # Mock session service - mock_session = MagicMock() - mock_session.id = "mock-session-id" - mock_session_service = MagicMock() - mock_session_service.create_session = AsyncMock(return_value=mock_session) - mock_session_service_class.return_value = mock_session_service - - # Create mock function calls with different search tool names - mock_fc1 = MagicMock() - mock_fc1.name = "google_search" - mock_fc1.args = {"query": "query one"} - - mock_fc2 = MagicMock() - mock_fc2.name = "web_search" - mock_fc2.args = {"query": "query two"} - - mock_fc3 = MagicMock() - mock_fc3.name = "SearchTool" - mock_fc3.args = {"query": "query three"} - - # Event with all function calls - mock_tool_event = MagicMock() - mock_tool_event.is_final_response.return_value = False - mock_tool_event.get_function_calls.return_value = [mock_fc1, mock_fc2, mock_fc3] - 
mock_tool_event.get_function_responses.return_value = None - mock_tool_event.grounding_metadata = None - mock_tool_event.content = None - - # Final event - mock_final_event = MagicMock() - mock_final_event.is_final_response.return_value = True - mock_final_event.get_function_calls.return_value = None - mock_final_event.get_function_responses.return_value = None - mock_final_event.grounding_metadata = None - mock_final_event.content.parts = [MagicMock(text="Done.")] - - async def mock_run_async(*args, **kwargs): - yield mock_tool_event - yield mock_final_event - - mock_runner = MagicMock() - mock_runner.run_async = mock_run_async - mock_runner_class.return_value = mock_runner - - agent = KnowledgeGroundedAgent(config=mock_config) - response = await agent.answer_async("Test") - - assert len(response.tool_calls) == 3 - assert "query one" in response.search_queries - assert "query two" in response.search_queries - assert "query three" in response.search_queries + """Test initializing with a custom model.""" + agent = KnowledgeGroundedAgent( + config=mock_config, model="gemini-2.5-pro", enable_caching=False, enable_compaction=False + ) - @pytest.mark.asyncio - @patch("aieng.agent_evals.knowledge_qa.agent.Runner") - @patch("aieng.agent_evals.knowledge_qa.agent.InMemorySessionService") - @patch("aieng.agent_evals.knowledge_qa.agent.Agent") - @patch("aieng.agent_evals.knowledge_qa.agent.create_google_search_tool") - async def test_answer_async_handles_empty_events( - self, - mock_create_tool, - mock_agent_class, - mock_session_service_class, - mock_runner_class, - mock_config, - ): - """Test that empty events are handled gracefully.""" - # Mock session service - mock_session = MagicMock() - mock_session.id = "mock-session-id" - mock_session_service = MagicMock() - mock_session_service.create_session = AsyncMock(return_value=mock_session) - mock_session_service_class.return_value = mock_session_service - - # Create events with no data - mock_empty_event = MagicMock() - 
mock_empty_event.is_final_response.return_value = False - mock_empty_event.get_function_calls.return_value = [] - mock_empty_event.get_function_responses.return_value = [] - mock_empty_event.grounding_metadata = None - mock_empty_event.content = None - - # Final event - mock_final_event = MagicMock() - mock_final_event.is_final_response.return_value = True - mock_final_event.get_function_calls.return_value = None - mock_final_event.get_function_responses.return_value = None - mock_final_event.grounding_metadata = None - mock_final_event.content.parts = [MagicMock(text="Final.")] - - async def mock_run_async(*args, **kwargs): - yield mock_empty_event - yield mock_final_event - - mock_runner = MagicMock() - mock_runner.run_async = mock_run_async - mock_runner_class.return_value = mock_runner - - agent = KnowledgeGroundedAgent(config=mock_config) - response = await agent.answer_async("Test") - - assert isinstance(response, GroundedResponse) - assert response.text == "Final." - assert response.tool_calls == [] - assert response.search_queries == [] - assert response.sources == [] + call_kwargs = mock_agent_class.call_args[1] + assert call_kwargs["model"] == "gemini-2.5-pro" + assert agent.model == "gemini-2.5-pro" class TestKnowledgeAgentManager: """Tests for the KnowledgeAgentManager class.""" + @patch("aieng.agent_evals.knowledge_qa.agent.PlanReActPlanner") @patch("aieng.agent_evals.knowledge_qa.agent.Runner") @patch("aieng.agent_evals.knowledge_qa.agent.InMemorySessionService") @patch("aieng.agent_evals.knowledge_qa.agent.Agent") + @patch("aieng.agent_evals.knowledge_qa.agent.create_read_file_tool") + @patch("aieng.agent_evals.knowledge_qa.agent.create_grep_file_tool") + @patch("aieng.agent_evals.knowledge_qa.agent.create_fetch_file_tool") + @patch("aieng.agent_evals.knowledge_qa.agent.create_web_fetch_tool") @patch("aieng.agent_evals.knowledge_qa.agent.create_google_search_tool") - def test_lazy_initialization( - self, - mock_create_tool, - mock_agent_class, - 
mock_session_service, - mock_runner_class, - ): - """Test that clients are lazily initialized.""" + def test_lazy_initialization(self, *_mocks): + """Test that agent is lazily initialized.""" with patch("aieng.agent_evals.knowledge_qa.agent.Configs") as mock_config_class: - mock_config_class.return_value = MagicMock() + mock_config = MagicMock() + mock_config.default_worker_model = "gemini-2.5-flash" + mock_config.default_temperature = 0.0 + mock_config_class.return_value = mock_config - manager = KnowledgeAgentManager() + manager = KnowledgeAgentManager(enable_caching=False, enable_compaction=False) # Should not be initialized yet assert not manager.is_initialized() @@ -644,49 +502,140 @@ def test_lazy_initialization( # Now should be initialized assert manager.is_initialized() + @patch("aieng.agent_evals.knowledge_qa.agent.PlanReActPlanner") @patch("aieng.agent_evals.knowledge_qa.agent.Runner") @patch("aieng.agent_evals.knowledge_qa.agent.InMemorySessionService") @patch("aieng.agent_evals.knowledge_qa.agent.Agent") + @patch("aieng.agent_evals.knowledge_qa.agent.create_read_file_tool") + @patch("aieng.agent_evals.knowledge_qa.agent.create_grep_file_tool") + @patch("aieng.agent_evals.knowledge_qa.agent.create_fetch_file_tool") + @patch("aieng.agent_evals.knowledge_qa.agent.create_web_fetch_tool") @patch("aieng.agent_evals.knowledge_qa.agent.create_google_search_tool") - def test_close( - self, - mock_create_tool, - mock_agent_class, - mock_session_service, - mock_runner_class, - ): + def test_close(self, *_mocks): """Test closing the client manager.""" with patch("aieng.agent_evals.knowledge_qa.agent.Configs") as mock_config_class: - mock_config_class.return_value = MagicMock() + mock_config = MagicMock() + mock_config.default_worker_model = "gemini-2.5-flash" + mock_config.default_temperature = 0.0 + mock_config_class.return_value = mock_config - manager = KnowledgeAgentManager() + manager = KnowledgeAgentManager(enable_caching=False, enable_compaction=False) _ = 
manager.agent assert manager.is_initialized() manager.close() assert not manager.is_initialized() - @patch("aieng.agent_evals.knowledge_qa.agent.Runner") - @patch("aieng.agent_evals.knowledge_qa.agent.InMemorySessionService") - @patch("aieng.agent_evals.knowledge_qa.agent.Agent") - @patch("aieng.agent_evals.knowledge_qa.agent.create_google_search_tool") - def test_agent_reuse( - self, - mock_create_tool, - mock_agent_class, - mock_session_service, - mock_runner_class, - ): - """Test that agent is reused on multiple accesses.""" - with patch("aieng.agent_evals.knowledge_qa.agent.Configs") as mock_config_class: - mock_config_class.return_value = MagicMock() - manager = KnowledgeAgentManager() +class TestAgentResponse: + """Tests for the AgentResponse model.""" + + def test_response_creation(self): + """Test creating an enhanced response.""" + plan = ResearchPlan( + original_question="Test question", + steps=[], + reasoning="Test reasoning", + ) + + response = AgentResponse( + text="Test answer.", + plan=plan, + sources=[GroundingChunk(title="Source", uri="https://example.com")], + search_queries=["test query"], + reasoning_chain=["Step 1"], + tool_calls=[{"name": "google_search", "args": {"query": "test"}}], + total_duration_ms=1000, + ) + + assert response.text == "Test answer." + assert response.plan.original_question == "Test question" + assert len(response.sources) == 1 + assert response.sources[0].uri == "https://example.com" + assert response.search_queries == ["test query"] + assert response.total_duration_ms == 1000 + + +class TestPlanStepStatusOnEarlyTermination: + """Tests for plan steps marked correctly when agent terminates early.""" + + def test_remaining_steps_marked_as_skipped(self): + """Test remaining steps are marked SKIPPED on early termination. + + When the agent finds the answer early and terminates before completing + all planned steps, the remaining steps should be marked as SKIPPED + to accurately reflect that they were not executed. 
+ """ + # Create a plan with multiple steps + plan = ResearchPlan( + original_question="Test question", + steps=[ + ResearchStep( + step_id=1, + description="Search for initial info", + status=StepStatus.COMPLETED, + ), + ResearchStep( + step_id=2, + description="Verify the information", + status=StepStatus.PENDING, + ), + ResearchStep( + step_id=3, + description="Cross-check with another source", + status=StepStatus.PENDING, + ), + ResearchStep( + step_id=4, + description="Synthesize findings", + status=StepStatus.IN_PROGRESS, + ), + ], + reasoning="Multi-step research plan", + ) - agent1 = manager.agent - agent2 = manager.agent + # Simulate the agent's early termination logic + # (this is what happens in agent.py lines 629-633) + for step in plan.steps: + if step.status in (StepStatus.PENDING, StepStatus.IN_PROGRESS): + step.status = StepStatus.SKIPPED + + # Verify step 1 is still completed (it was executed) + assert plan.steps[0].status == StepStatus.COMPLETED + + # Verify remaining steps are marked as SKIPPED, not COMPLETED + assert plan.steps[1].status == StepStatus.SKIPPED + assert plan.steps[2].status == StepStatus.SKIPPED + assert plan.steps[3].status == StepStatus.SKIPPED + + def test_plan_is_complete_with_skipped_steps(self): + """Test that a plan with SKIPPED steps is considered complete.""" + plan = ResearchPlan( + original_question="Test", + steps=[ + ResearchStep(step_id=1, description="Step 1", status=StepStatus.COMPLETED), + ResearchStep(step_id=2, description="Step 2", status=StepStatus.SKIPPED), + ResearchStep(step_id=3, description="Step 3", status=StepStatus.SKIPPED), + ], + ) - assert agent1 is agent2 + # SKIPPED is a terminal status, so the plan should be complete + assert plan.is_complete() + + def test_get_steps_by_status_skipped(self): + """Test getting steps by SKIPPED status.""" + plan = ResearchPlan( + original_question="Test", + steps=[ + ResearchStep(step_id=1, description="Step 1", status=StepStatus.COMPLETED), + 
ResearchStep(step_id=2, description="Step 2", status=StepStatus.SKIPPED), + ResearchStep(step_id=3, description="Step 3", status=StepStatus.SKIPPED), + ], + ) + + skipped_steps = plan.get_steps_by_status(StepStatus.SKIPPED) + assert len(skipped_steps) == 2 + assert all(s.status == StepStatus.SKIPPED for s in skipped_steps) @pytest.mark.integration_test @@ -698,23 +647,17 @@ class TestKnowledgeGroundedAgentIntegration: def test_agent_creation_real(self): """Test creating a real agent instance.""" - from aieng.agent_evals.knowledge_qa import ( # noqa: PLC0415 - KnowledgeGroundedAgent, - ) - agent = KnowledgeGroundedAgent() assert agent is not None assert agent.model == "gemini-2.5-flash" + assert agent.enable_planning is True @pytest.mark.asyncio async def test_answer_real_question(self): """Test answering a real question.""" - from aieng.agent_evals.knowledge_qa import ( # noqa: PLC0415 - KnowledgeGroundedAgent, - ) - agent = KnowledgeGroundedAgent() response = await agent.answer_async("What is the capital of France?") assert response.text assert "Paris" in response.text + assert isinstance(response, AgentResponse)