diff --git a/predicate/__init__.py b/predicate/__init__.py index 2b479b2..3ca2149 100644 --- a/predicate/__init__.py +++ b/predicate/__init__.py @@ -118,7 +118,17 @@ # Ordinal support (Phase 3) from .ordinal import OrdinalIntent, boost_ordinal_elements, detect_ordinal_intent, select_by_ordinal from .overlay import clear_overlay, show_overlay +from .overlay_dismissal import OverlayDismissResult, dismiss_overlays, dismiss_overlays_before_agent from .permissions import PermissionPolicy +from .pruning import ( + CategoryDetectionResult, + PrunedSnapshotContext, + PruningTaskCategory, + SkeletonDomNode, + classify_task_category, + prune_snapshot_for_task, + serialize_pruned_snapshot, +) from .query import find, query from .read import extract, extract_async, read, read_best_effort from .recorder import Recorder, Trace, TraceStep, record @@ -250,6 +260,10 @@ "screenshot", "show_overlay", "clear_overlay", + # Overlay dismissal (proactive popup/banner removal) + "OverlayDismissResult", + "dismiss_overlays", + "dismiss_overlays_before_agent", # Text Search "find_text_rect", "TextRectSearchResult", @@ -313,6 +327,13 @@ "save_storage_state", # Formatting (v0.12.0+) "format_snapshot_for_llm", + "CategoryDetectionResult", + "PrunedSnapshotContext", + "PruningTaskCategory", + "SkeletonDomNode", + "classify_task_category", + "prune_snapshot_for_task", + "serialize_pruned_snapshot", # Agent Config (v0.12.0+) "AgentConfig", # Enums diff --git a/predicate/agent_runtime.py b/predicate/agent_runtime.py index 8f82362..0414151 100644 --- a/predicate/agent_runtime.py +++ b/predicate/agent_runtime.py @@ -326,6 +326,40 @@ async def get_url(self) -> str: self._cached_url = url return url + async def read_markdown(self, max_chars: int = 8000) -> str | None: + """ + Read page content as markdown for semantic understanding. 
+ + This extracts the page HTML and converts it to markdown format, + which is useful for LLM planning to understand page context + (e.g., product listings, form fields, navigation structure). + + Args: + max_chars: Maximum characters to return (default 8000). + Truncates from the end if content exceeds this limit. + + Returns: + Markdown string if successful, None if extraction fails. + """ + try: + page = getattr(self.backend, "page", None) + if page is None: + return None + + # Import here to avoid circular dependency + from .read import _fallback_read_from_page_async + + result = await _fallback_read_from_page_async(page, output_format="markdown") + if result is None or result.status != "success": + return None + + content = result.content + if len(content) > max_chars: + content = content[:max_chars] + return content + except Exception: + return None + async def get_viewport_height(self) -> int: """ Get current viewport height in pixels. @@ -398,19 +432,23 @@ async def click(self, element_id: int) -> None: await self.record_action(f"CLICK({element_id})") - async def type(self, element_id: int, text: str) -> None: + async def type(self, element_id: int, text: str, *, delay_ms: float | None = None) -> None: """ Type text into an element. 
Args: element_id: Element ID from snapshot text: Text to type + delay_ms: Optional delay between keystrokes in milliseconds """ # First click to focus await self.click(element_id) # Then type - await self.backend.type_text(text) + if delay_ms is None: + await self.backend.type_text(text) + else: + await self.backend.type_text(text, delay_ms=delay_ms) await self.record_action(f"TYPE({element_id}, '{text[:20]}...')" if len(text) > 20 else f"TYPE({element_id}, '{text}')") async def press(self, key: str) -> None: diff --git a/predicate/agents/__init__.py b/predicate/agents/__init__.py index fa895ad..40a16f3 100644 --- a/predicate/agents/__init__.py +++ b/predicate/agents/__init__.py @@ -6,8 +6,9 @@ - RuntimeAgent (execution loop and bounded vision fallback) Agent types: -- PredicateBrowserAgent: Single-executor agent with manual step definitions +- PredicateAgent: Branded alias for PlannerExecutorAgent (recommended for external use) - PlannerExecutorAgent: Two-tier agent with LLM-generated plans +- PredicateBrowserAgent: Single-executor agent with manual step definitions Task abstractions: - AutomationTask: Generic task model for browser automation @@ -67,6 +68,9 @@ get_config_preset, ) +# Branded alias for PlannerExecutorAgent +PredicateAgent = PlannerExecutorAgent + __all__ = [ # Automation Task "AutomationTask", @@ -95,6 +99,7 @@ "PlanStep", "PlannerExecutorAgent", "PlannerExecutorConfig", + "PredicateAgent", # Branded alias for PlannerExecutorAgent "PredicateSpec", "RecoveryNavigationConfig", "RetryConfig", diff --git a/predicate/agents/automation_task.py b/predicate/agents/automation_task.py index a0d0497..2a8b522 100644 --- a/predicate/agents/automation_task.py +++ b/predicate/agents/automation_task.py @@ -179,6 +179,32 @@ class AutomationTask: # Domain hints for heuristics (e.g., ["ecommerce", "amazon"]) domain_hints: tuple[str, ...] 
= field(default_factory=tuple) + # Force a specific pruning category (overrides auto-detection) + force_pruning_category: str | None = None + + def pruning_category_hint(self): + """ + Return the pruning-oriented category for this task. + + If force_pruning_category is set, returns that category directly. + Otherwise, uses rule-based normalization from task text and hints. + """ + from ..pruning import PruningTaskCategory, classify_task_category + + # If a category is forced, return it directly + if self.force_pruning_category: + try: + return PruningTaskCategory(self.force_pruning_category) + except ValueError: + pass # Invalid category, fall through to auto-detection + + return classify_task_category( + task_text=self.task, + current_url=self.starting_url, + domain_hints=self.domain_hints, + task_category=self.category, + ).category + @classmethod def from_webbench_task(cls, task: Any) -> "AutomationTask": """ diff --git a/predicate/agents/planner_executor_agent.py b/predicate/agents/planner_executor_agent.py index a3ad280..a5f7f09 100644 --- a/predicate/agents/planner_executor_agent.py +++ b/predicate/agents/planner_executor_agent.py @@ -17,8 +17,10 @@ import asyncio import base64 import hashlib +import inspect import json import re +import sys import time import uuid from abc import ABC, abstractmethod @@ -26,13 +28,21 @@ from dataclasses import dataclass, field from datetime import datetime from enum import Enum +from types import SimpleNamespace from typing import Any, Literal, Protocol, runtime_checkable from pydantic import BaseModel, Field +from ..actions import clear_async, type_text_async from ..agent_runtime import AgentRuntime from ..llm_provider import LLMProvider, LLMResponse from ..models import Snapshot, SnapshotOptions, StepHookContext +from ..pruning import ( + PrunedSnapshotContext, + PruningTaskCategory, + classify_task_category, + prune_with_recovery, +) from ..trace_event_builder import TraceEventBuilder from ..tracing import Tracer from 
..verification import ( @@ -667,9 +677,16 @@ class PlannerExecutorConfig: planner_max_tokens: int = 2048 planner_temperature: float = 0.0 + # Page context for planning: when enabled, extracts page content as markdown + # during initial planning to help the planner understand page type and structure. + # This adds token cost but improves plan quality for complex pages. + use_page_context: bool = False + page_context_max_chars: int = 8000 # Max chars of markdown to include + # Executor LLM settings executor_max_tokens: int = 96 executor_temperature: float = 0.0 + type_delay_ms: float | None = 17.0 # Stabilization (wait for DOM to settle after actions) stabilize_enabled: bool = True @@ -770,6 +787,8 @@ class SnapshotContext: snapshot_success: bool = True requires_vision: bool = False vision_reason: str | None = None + pruning_category: str | None = None + pruned_node_count: int = 0 def is_stale(self, max_age_seconds: float = 5.0) -> bool: """Check if snapshot is too old to reuse.""" @@ -875,6 +894,16 @@ class StepOutcome: duration_ms: int = 0 url_before: str | None = None url_after: str | None = None + extracted_data: Any | None = None + + +@dataclass +class SearchSubmitTelemetry: + """Tracks search-submit behavior for debugging and diagnostics.""" + + first_submit_method: Literal["click", "enter"] | None = None + retry_submit_method: Literal["click", "enter"] | None = None + observed_search_results_dom: bool = False @dataclass @@ -936,6 +965,112 @@ def build_predicate(spec: PredicateSpec | dict[str, Any]) -> Predicate: raise ValueError(f"Unsupported predicate: {name}") +# --------------------------------------------------------------------------- +# Extraction Keywords for Markdown-based Text Extraction +# --------------------------------------------------------------------------- + +# Keywords that indicate a simple text extraction task suitable for read_markdown() +# These tasks don't need LLM-based extraction - just return the page content as markdown 
+TEXT_EXTRACTION_KEYWORDS = frozenset([ + # Direct extraction verbs + "extract", + "read", + "parse", + "scrape", + "get", + "fetch", + "retrieve", + "capture", + "grab", + "copy", + "pull", + # Question words that indicate reading content + "what is", + "what are", + "what's", + "show me", + "tell me", + "find", + "list", + "display", + # Content-specific patterns + "title", + "headline", + "heading", + "text", + "content", + "body", + "paragraph", + "article", + "post", + "message", + "description", + "summary", + "excerpt", + # Data extraction patterns + "price", + "cost", + "amount", + "name", + "label", + "value", + "number", + "date", + "time", + "address", + "email", + "phone", + "rating", + "review", + "comment", + "author", + "username", + # Table/list extraction + "table", + "row", + "column", + "item", + "entry", + "record", +]) + + +def _is_text_extraction_task(task: str) -> bool: + """ + Determine if a task is a simple text extraction that can use read_markdown(). + + Returns True if the task contains keywords indicating text extraction, + where returning the page markdown is sufficient without LLM-based extraction. 
+ + Args: + task: The task description to analyze + + Returns: + True if this is a text extraction task suitable for read_markdown() + """ + if not task: + return False + + task_lower = task.lower() + + # Check for extraction keyword patterns using word boundary matching + # to avoid false positives (e.g., "time" in "sentiment") + for keyword in TEXT_EXTRACTION_KEYWORDS: + # Multi-word keywords (like "what is") use substring matching + if " " in keyword: + if keyword in task_lower: + return True + else: + # Single-word keywords use word boundary matching via regex + # Match keyword at word boundaries, allowing for plurals (optional 's' or 'es') + # e.g., "title" matches "title", "titles", "title's" + pattern = rf"\b{re.escape(keyword)}(s|es)?\b" + if re.search(pattern, task_lower): + return True + + return False + + # --------------------------------------------------------------------------- # Plan Normalization and Validation # --------------------------------------------------------------------------- @@ -1073,6 +1208,7 @@ def normalize_plan(plan_dict: dict[str, Any]) -> dict[str, Any]: "INPUT": "TYPE_AND_SUBMIT", "TYPE_TEXT": "TYPE_AND_SUBMIT", "ENTER_TEXT": "TYPE_AND_SUBMIT", + "EXTRACT_TEXT": "EXTRACT", "GOTO": "NAVIGATE", "GO_TO": "NAVIGATE", "OPEN": "NAVIGATE", @@ -1204,10 +1340,20 @@ def build_planner_prompt( auth_state: str = "unknown", strict: bool = False, schema_errors: str | None = None, + page_context: str | None = None, ) -> tuple[str, str]: """ Build system and user prompts for the Planner LLM. + Args: + task: Task description + start_url: Starting URL + site_type: Type of site (general, e-commerce, etc.) 
+ auth_state: Authentication state + strict: If True, emphasize JSON-only output + schema_errors: Errors from previous parsing attempt + page_context: Optional markdown content of the current page for context + Returns: (system_prompt, user_prompt) """ @@ -1234,23 +1380,59 @@ def build_planner_prompt( ========================================= For shopping/purchase tasks, include ALL necessary steps in order: 1. NAVIGATE to the site (if not already there) -2. TYPE_AND_SUBMIT search query in search box +2. Find the product - choose ONE of these approaches IN PRIORITY ORDER: + a) DIRECT MATCH (BEST): Scan page for text closely matching the goal. CLICK any product/category with matching text. + b) CATEGORY BROWSE: If no exact match, click a category link that relates to the goal (e.g., "Tablecloths" for "vinyl tablecloth") + c) SEARCH: ONLY if you see an input EXPLICITLY labeled "Search" with placeholder="Search..." or aria-label="Search" + d) SEARCH ICON: Only if you see a magnifying glass icon linked to search 3. CLICK on specific product from results (not filters or categories) 4. CLICK "Add to Cart" button on product page 5. CLICK "Proceed to Checkout" or cart icon 6. Handle login/signup if required (may need CLICK + TYPE_AND_SUBMIT) 7. CLICK through checkout process +CRITICAL - CATEGORY NAVIGATION (MOST RELIABLE FOR HOMEPAGES): +- On homepage/landing pages, browse via CATEGORY LINKS - this is the MOST RELIABLE method +- Look for category links like "Rally Home Goods", "Tablecloths", "Kitchen", "Catalog", etc. 
+- Category links are usually in the main navigation, sidebar, or footer and are always clickable +- Example: Goal "vinyl tablecloth" → Click "Rally Home Goods" or "Catalog" category first +- After clicking a category, THEN look for the specific product + +SECONDARY - Direct Product Click (ONLY on collection/category pages): +- If a product appears on a CATEGORY/COLLECTION page (not homepage), click it directly +- WARNING: Products in "Hot Products", carousels, or grid sections on HOME PAGES are often NOT clickable +- The snapshot may not capture product titles in carousels - use category navigation instead + +AVOID - Searching on sites without visible search box: +- Many e-commerce sites hide search or don't have search at all +- If you don't see a clear "Search" textbox in the page markdown, DO NOT try to search +- Prefer category navigation over searching - it's more reliable + +CRITICAL - Search Box Identification (ONLY WHEN NO MATCHING TEXT): +- Only use TYPE_AND_SUBMIT if you see an input EXPLICITLY labeled for SEARCH +- Valid search indicators: placeholder="Search...", aria-label="Search", text "Search products" +- DO NOT type into fields with these labels (they are NOT search boxes): + * "Your email address", "Email", "Newsletter", "Subscribe" + * "Zip code", "Location", "Enter your email" + * Any field asking for personal information +- If unsure whether a field is a search box, DO NOT use it - click products/categories instead + Common mistakes to AVOID: - Do NOT skip "Add to Cart" step - clicking a product link is NOT adding to cart - Do NOT combine multiple distinct actions into one step - Do NOT confuse filter/category clicks with product selection +- Do NOT assume a search box exists - if none is clearly visible, click products/categories directly +- Do NOT hallucinate search boxes - if page content doesn't show an obvious search input, use direct navigation +- Do NOT type into email/newsletter/subscription fields - they are NOT search boxes +- Do NOT 
use search when matching text is visible - click directly instead - Each distinct user action should be its own step -Intent hints are critical - use clear hints like: +Intent hints are critical - ALWAYS include the specific product/element name: +- intent: "Click product Vinyl Tablecloth" (GOOD - includes product name) +- intent: "Click on product title or image" (BAD - too generic, will click wrong product) +- intent: "Click category link Tablecloths" (GOOD - includes category name) - intent: "Click Add to Cart button" - intent: "Click Proceed to Checkout" -- intent: "Click on product title or image" - intent: "Click sign in button" """ elif is_search_task: @@ -1263,6 +1445,46 @@ def build_planner_prompt( 2. TYPE_AND_SUBMIT the search query 3. Wait for/verify search results 4. If selecting a result: CLICK on specific result item +""" + + # Check for extraction tasks + is_extraction_task = any(keyword in task_lower for keyword in [ + "extract", "get the", "what is", "read the", "find the text", "scrape", + "title of", "price of", "name of", "content of", + ]) + + if is_extraction_task: + domain_guidance = """ + +IMPORTANT: Extraction Task Planning Rules +========================================= +For extraction tasks where data is already visible on the page: + +1. If the data you need is VISIBLE in the page context markdown above: + - Use EXTRACT directly as the ONLY step - no clicking needed + - The EXTRACT action will read the visible text from the page + +2. If you need to navigate to see the data: + - First CLICK or NAVIGATE to the right page + - Then use EXTRACT + +CRITICAL: Do NOT click on links to external sites when extracting. 
+- Hacker News post titles link to EXTERNAL sites, not to HN pages +- To extract a title that's visible, use EXTRACT directly on the current page +- Only click if you need to navigate to an HN item page (e.g., for comments) + +Example for "Extract the title of the first post": +{ + "steps": [ + { + "id": 1, + "goal": "Extract the first post title from the page", + "action": "EXTRACT", + "target": "first post title", + "verify": [] + } + ] +} """ system = f"""You are the PLANNER. Output a JSON execution plan for the web automation task. @@ -1289,10 +1511,10 @@ def build_planner_prompt( }}, {{ "id": 3, - "goal": "Click on result", + "goal": "Click on product from results", "action": "CLICK", "intent": "Click on product title", - "verify": [{{"predicate": "url_contains", "args": ["/product/"]}}] + "verify": [] }} ] }} @@ -1303,15 +1525,44 @@ def build_planner_prompt( - {{"predicate": "not_exists", "args": ["text~'error'"]}} DO NOT use string format like "url_contains('text')" - use object format only. + +CRITICAL - url_contains RULES: +1. Use ONLY generic keywords, NEVER site-specific paths like "/product/", "/products/", "/collections/" +2. Different sites use different URL patterns - don't guess the path structure +3. For product pages: use "verify": [] (empty) or use the product keyword like ["snow-blower"] +4. For search: "search" or "query=" work across most sites +5. For checkout: "checkout" or "cart" work across most sites +6. NEVER use paths like "/product/", "/products/", "/p/", "/dp/" - these are site-specific + +Examples: +- GOOD: {{"predicate": "url_contains", "args": ["snow-blower"]}} - uses product keyword +- GOOD: {{"predicate": "url_contains", "args": ["search"]}} - generic search indicator +- BAD: {{"predicate": "url_contains", "args": ["/product/"]}} - site-specific path +- BAD: {{"predicate": "url_contains", "args": ["/products/vinyl-tablecloth"]}} - guessing path structure {domain_guidance} Return ONLY valid JSON. 
No prose, no code fences, no markdown.""" + # Build page context section if provided + page_context_section = "" + if page_context: + page_context_section = f""" + +Current Page Content: +The following is a markdown representation of the current page content. Use this to understand +the page structure, available elements (buttons, links, forms), and content to inform your plan. +Note: This may be truncated if the page is large. + +--- +{page_context} +--- +""" + user = f"""Task: {task} {schema_note} Starting URL: {start_url or "browser's current page"} Site type: {site_type} Auth state: {auth_state} - +{page_context_section} Output a JSON plan to accomplish this task. Each step should represent ONE distinct action.""" return system, user @@ -1354,23 +1605,28 @@ def build_stepwise_planner_prompt( system = """You are a browser automation planner. Decide the NEXT action. Actions: -- CLICK: Click an element. Set "intent" to element type/role (e.g., "invoice link", "submit button"). Optionally set "input" to the specific text to match. -- TYPE_AND_SUBMIT: Type and submit. Set "intent" to element type and "input" to text to type. +- CLICK: Click an element. Set "intent" to element type/role. Set "input" to EXACT text from elements list. +- TYPE_AND_SUBMIT: Type and submit. ONLY use if you see a "searchbox" or "textbox" with "search" in the text. - SCROLL: Scroll page. Set "direction" to "up" or "down". - DONE: Goal achieved. Return this when the goal is complete. 
+CRITICAL RULE FOR CLICK: +- The "input" field MUST contain text that ACTUALLY APPEARS in the elements list below +- Do NOT guess or invent text - copy EXACT text from an element +- If product title "vinyl tablecloth" is NOT in the elements list, click a category link instead (e.g., "Catalog", "Home Goods") +- Only click a specific product if you see its EXACT name in the elements + Output ONLY valid JSON (no markdown, no ```): -{"action":"CLICK","intent":"invoice link","input":"INV-2024-001","reasoning":"click first invoice"} -{"action":"CLICK","intent":"submit button","reasoning":"submit the form"} -{"action":"TYPE_AND_SUBMIT","intent":"search box","input":"wireless keyboard","reasoning":"search for product"} +{"action":"CLICK","intent":"category link","input":"Catalog","reasoning":"browse products via category"} +{"action":"CLICK","intent":"product link","input":"Vinyl Round Tablecloth","reasoning":"found exact product name"} {"action":"DONE","intent":"completed","reasoning":"goal achieved"} RULES: -1. Look at ACTUAL elements shown - pick one that matches your intent -2. For CLICK: "intent" = element type (link, button, etc.), "input" = specific text (optional) -3. CRITICAL: Do NOT repeat the same action twice. If history shows an action was already done (e.g., "CLICK Route To Review"), do NOT do it again. -4. CRITICAL: If the goal is to click a button (e.g., "click Route to Review") and history shows you already clicked it, return DONE immediately. -5. Return DONE when: (a) you clicked the target button, (b) you typed/submitted text, (c) page state shows goal is achieved +1. ONLY use text that appears EXACTLY in the elements list - do NOT invent names +2. For shopping: start with category links (Catalog, Shop Now, Home Goods) to find products +3. ONLY use TYPE_AND_SUBMIT if you see a textbox labeled "search" +4. Do NOT type into "email" or "newsletter" fields +5. Do NOT repeat the same action twice 6. 
Output ONLY JSON - no tags, no markdown, no prose""" user = f"""Goal: {goal} @@ -1385,11 +1641,53 @@ def build_stepwise_planner_prompt( return system, user +def _get_category_executor_hints(category: str | None) -> str: + """ + Get category-specific hints for the executor. + + These hints guide the executor to prioritize certain element types + based on the detected task category, improving accuracy without + adding tokens to the planner. + """ + if not category: + return "" + + category_lower = category.lower() if isinstance(category, str) else str(category).lower() + + hints = { + "shopping": ( + "Priority: 'Add to Cart', 'Buy Now', 'Add to Bag', product links, price elements." + ), + "checkout": ( + "Priority: 'Checkout', 'Proceed to Checkout', 'Place Order', payment fields." + ), + "form_filling": ( + "Priority: input fields, textboxes, submit/send buttons, form labels." + ), + "search": ( + "Priority: search box, search button, result links, filter controls." + ), + "auth": ( + "Priority: username/email field, password field, sign in/login button." + ), + "extraction": ( + "Priority: data elements, table cells, list items, content containers." + ), + "navigation": ( + "Priority: navigation links, menu items, breadcrumbs." + ), + } + + return hints.get(category_lower, "") + + def build_executor_prompt( goal: str, intent: str | None, compact_context: str, input_text: str | None = None, + category: str | None = None, + action_type: str | None = None, ) -> tuple[str, str]: """ Build system and user prompts for the Executor LLM. @@ -1398,23 +1696,43 @@ def build_executor_prompt( goal: Human-readable goal for this step intent: Intent hint for element selection (optional) compact_context: Compact representation of page elements - input_text: Text to type for TYPE_AND_SUBMIT actions (optional) + input_text: For TYPE_AND_SUBMIT: text to type. 
For CLICK: target text to match (optional) + category: Task category for category-specific hints (optional) + action_type: Action type (CLICK, TYPE_AND_SUBMIT, etc.) to determine prompt variant Returns: (system_prompt, user_prompt) """ intent_line = f"Intent: {intent}\n" if intent else "" - input_line = f"Text to type: \"{input_text}\"\n" if input_text else "" + + # For CLICK actions, input_text is target to match (not text to type) + is_type_action = action_type in ("TYPE_AND_SUBMIT", "TYPE") + if is_type_action and input_text: + input_line = f"Text to type: \"{input_text}\"\n" + elif input_text: + input_line = f"Target to find: \"{input_text}\"\n" + else: + input_line = "" + + # Get category-specific hints + category_hints = _get_category_executor_hints(category) + category_line = f"{category_hints}\n" if category_hints else "" # Tight prompt optimized for small local models (4B-7B) # Key: explicit format, no reasoning, clear failure consequence - if input_text: + if is_type_action and input_text: # TYPE action needed - find the INPUT element (textbox/combobox), not the submit button system = ( "You are an executor for browser automation.\n" "Task: Find the INPUT element (textbox, combobox, searchbox) to type into.\n" "Return ONLY ONE line: TYPE(, \"text\")\n" "IMPORTANT: Return the ID of the INPUT/TEXTBOX element, NOT the submit button.\n" + "CRITICAL - AVOID these fields (they are NOT search boxes):\n" + "- Fields with 'email', 'newsletter', 'subscribe', 'signup' in the text\n" + "- Fields labeled 'Your email address', 'Email', 'Enter your email'\n" + "- Fields in footer/newsletter sections\n" + "ONLY use fields explicitly labeled for SEARCH (placeholder='Search', aria='Search').\n" + "If NO search field exists, return NONE instead of guessing.\n" "If you output anything else, the action fails.\n" "Do NOT output or any reasoning.\n" "No prose, no markdown, no extra whitespace.\n" @@ -1422,25 +1740,112 @@ def build_executor_prompt( ) else: # CLICK action (most 
common) - system = ( - "You are an executor for browser automation.\n" - "Return ONLY a single-line CLICK(id) action.\n" - "If you output anything else, the action fails.\n" - "Do NOT output or any reasoning.\n" - "No prose, no markdown, no extra whitespace.\n" - "Output MUST match exactly: CLICK() with no spaces.\n" - "Example: CLICK(12)" + # Check if this is a search-related action (from intent or goal) + search_keywords = ["search", "magnify", "magnifier", "find"] + is_search_action = ( + (intent and any(kw in intent.lower() for kw in search_keywords)) + or any(kw in goal.lower() for kw in search_keywords) + ) + # Check if this is a product click action (from intent or goal) + product_keywords = ["product", "item", "result", "listing"] + is_product_action = ( + (intent and any(kw in intent.lower() for kw in product_keywords)) + or any(kw in goal.lower() for kw in product_keywords) ) + # Check if this is an Add to Cart action + add_to_cart_keywords = ["add to cart", "add to bag", "add to basket", "buy now"] + is_add_to_cart_action = ( + (intent and any(kw in intent.lower() for kw in add_to_cart_keywords)) + or any(kw in goal.lower() for kw in add_to_cart_keywords) + ) + # Check if intent asks to match text (e.g., "Click element with text matching [keyword]") + is_text_matching_action = intent and "matching" in intent.lower() + # Check if input_text specifies a target to match (for CLICK actions, input_text is target text) + has_target_text = bool(input_text) + + if is_search_action: + system = ( + "You are an executor for browser automation.\n" + "Return ONLY a single-line CLICK(id) action.\n" + "If you output anything else, the action fails.\n" + "Do NOT output or any reasoning.\n" + "SEARCH ICON HINTS: Look for links/buttons with 'search' in text/href, " + "or icon-only elements (text='0' or empty) with 'search' in href.\n" + "Output MUST match exactly: CLICK() with no spaces.\n" + "Example: CLICK(12)" + ) + elif is_text_matching_action or has_target_text: + # 
When planner specifies target text (input field), executor must match it + target_text = input_text or "" + system = ( + "You are an executor for browser automation.\n" + "Return ONLY a single-line CLICK(id) action.\n" + "If you output anything else, the action fails.\n" + "Do NOT output or any reasoning.\n" + f"CRITICAL: Find an element with text matching '{target_text}'.\n" + "- Look for: product titles, category names, link text, button labels\n" + "- Text must contain the target words (case-insensitive partial match is OK)\n" + "- If NO element contains the target text, return NONE instead of clicking something random\n" + "Output: CLICK() or NONE\n" + "Example: CLICK(42) or NONE" + ) + elif is_product_action: + # Product click action without specific target - guide executor to find product cards/links + system = ( + "You are an executor for browser automation.\n" + "Return ONLY a single-line CLICK(id) action.\n" + "If you output anything else, the action fails.\n" + "Do NOT output or any reasoning.\n" + "PRODUCT CLICK HINTS:\n" + "- Look for LINK elements (role=link) with product IDs in href (e.g., /7027762, /dp/B...)\n" + "- Prefer links with delivery info text like 'Delivery', 'Ships to Store', 'Get it...'\n" + "- These are inside product cards and will navigate to product detail pages\n" + "- AVOID buttons like 'Search', 'Shop', category buttons, or filter buttons\n" + "- AVOID image slider options (slider image 1, 2, etc.)\n" + "Output MUST match exactly: CLICK() with no spaces.\n" + "Example: CLICK(1268)" + ) + elif is_add_to_cart_action: + # Add to Cart action - may need to click product first if on search results page + system = ( + "You are an executor for browser automation.\n" + "Return ONLY a single-line CLICK(id) action.\n" + "If you output anything else, the action fails.\n" + "Do NOT output or any reasoning.\n" + "ADD TO CART HINTS:\n" + "- FIRST: Look for buttons with text: 'Add to Cart', 'Add to Bag', 'Add to Basket', 'Buy Now'\n" + "- If 
found, click that button directly\n" + "- FALLBACK: If NO 'Add to Cart' button is visible, you are likely on a SEARCH RESULTS page\n" + " - In this case, click a PRODUCT LINK to go to the product details page first\n" + " - Look for LINK elements with product IDs in href (e.g., /7027762, /dp/B...)\n" + " - Prefer links with product names, prices, or delivery info\n" + "- AVOID: 'Search' buttons, category buttons, filter buttons, pagination\n" + "Output MUST match exactly: CLICK() with no spaces.\n" + "Example: CLICK(42)" + ) + else: + system = ( + "You are an executor for browser automation.\n" + "Return ONLY a single-line CLICK(id) action.\n" + "If you output anything else, the action fails.\n" + "Do NOT output or any reasoning.\n" + "No prose, no markdown, no extra whitespace.\n" + "Output MUST match exactly: CLICK() with no spaces.\n" + "Example: CLICK(12)" + ) # Choose the appropriate closing instruction based on action type - if input_text: + if is_type_action and input_text: # For TYPE actions, explicitly ask for TYPE with the text action_instruction = f'Return TYPE(id, "{input_text}"):' + elif input_text: + # For CLICK with target text, remind to match target or return NONE + action_instruction = f'Return CLICK(id) for element matching "{input_text}", or NONE if not found:' else: action_instruction = "Return CLICK(id):" user = f"""Goal: {goal} -{intent_line}{input_line} +{intent_line}{category_line}{input_line} Elements: {compact_context} @@ -1575,6 +1980,9 @@ def __init__( # Current automation task (for run-level context) self._current_task: AutomationTask | None = None + # Cached pruning category (run-scoped, avoids re-classification per step) + self._cached_pruning_category: PruningTaskCategory | None = None + # Token usage tracking self._token_collector = _TokenUsageCollector() @@ -1606,6 +2014,93 @@ def _record_token_usage(self, role: str, resp: LLMResponse) -> None: except Exception: pass # Don't fail on token tracking errors + def 
_detect_pruning_category( + self, + snap: Snapshot, + goal: str, + ) -> PruningTaskCategory | None: + """Resolve the pruning category from task context, then goal-based rules. + + The category is cached for the duration of a run to ensure consistency + and avoid re-classification on every step. + """ + # Return cached category if available + if self._cached_pruning_category is not None: + return self._cached_pruning_category + + if self._current_task is not None: + try: + category = self._current_task.pruning_category_hint() + if category != PruningTaskCategory.GENERIC: + self._cached_pruning_category = category + if self.config.verbose: + print(f" [CATEGORY] Detected category from task hint: {category.value}", flush=True) + return category + except Exception: + pass + + result = classify_task_category( + task_text=self._current_task.task, + current_url=self._current_task.starting_url or getattr(snap, "url", "") or "", + domain_hints=self._current_task.domain_hints, + task_category=self._current_task.category, + ) + else: + result = classify_task_category( + task_text=goal, + current_url=getattr(snap, "url", "") or "", + ) + + if result.category == PruningTaskCategory.GENERIC: + return None + + # Cache the category for this run + self._cached_pruning_category = result.category + if self.config.verbose: + print(f" [CATEGORY] Detected category: {result.category.value} (confidence={result.confidence:.2f})", flush=True) + return result.category + + def _get_cached_category_str(self) -> str | None: + """Get the cached category as a string for executor hints.""" + if self._cached_pruning_category is not None: + return self._cached_pruning_category.value + return None + + def _build_pruned_context( + self, + snap: Snapshot, + goal: str, + ) -> PrunedSnapshotContext | None: + """Build a category-specific pruned context when task intent is known. + + Uses automatic over-pruning recovery via relaxation levels if the + initial pruning leaves too few elements. 
+ """ + if self._context_formatter is not None: + return None + + category = self._detect_pruning_category(snap, goal) + if category is None: + return None + + try: + ctx = prune_with_recovery( + snap, + goal=goal, + category=category, + max_relaxation=3, + verbose=self.config.verbose, + ) + if self.config.verbose and ctx.relaxation_level == 0: + print( + f" [PRUNING] {ctx.raw_element_count} -> {ctx.pruned_element_count} elements " + f"(category={category.value})", + flush=True, + ) + return ctx + except Exception: + return None + def _format_context(self, snap: Snapshot, goal: str) -> str: """ Format snapshot for LLM context. @@ -1616,6 +2111,10 @@ def _format_context(self, snap: Snapshot, goal: str) -> str: if self._context_formatter is not None: return self._context_formatter(snap, goal) + pruned_context = self._build_pruned_context(snap, goal) + if pruned_context is not None and pruned_context.nodes: + return pruned_context.prompt_block + import re # Filter to interactive elements @@ -1911,6 +2410,10 @@ def _parse_action(self, text: str) -> tuple[str, list[Any]]: if "FINISH" in text: return "FINISH", [] + # NONE - executor couldn't find a suitable element (e.g., no search box found) + if text.upper() == "NONE" or "NONE" in text.upper(): + return "NONE", [] + return "UNKNOWN", [text] # ------------------------------------------------------------------------- @@ -2125,6 +2628,7 @@ async def _snapshot_with_escalation( max_limit = cfg.limit_max if cfg.enabled else cfg.limit_base # Disable escalation if not enabled last_snap: Snapshot | None = None last_compact: str = "" + last_pruned_context: PrunedSnapshotContext | None = None screenshot_b64: str | None = None requires_vision = False vision_reason: str | None = None @@ -2153,8 +2657,14 @@ async def _snapshot_with_escalation( # Format context FIRST - we always want the compact representation # even if vision fallback is required, so the planner can see available elements - compact = self._format_context(snap, goal) 
+ pruned_context = self._build_pruned_context(snap, goal) + compact = ( + pruned_context.prompt_block + if pruned_context is not None and pruned_context.nodes + else self._format_context(snap, goal) + ) last_compact = compact + last_pruned_context = pruned_context # Check for vision fallback needs_vision, reason = detect_snapshot_failure(snap) @@ -2172,7 +2682,8 @@ async def _snapshot_with_escalation( # a specific target element was found. Intent heuristics are only used # for scroll-after-escalation AFTER limit escalation is exhausted. elements = getattr(snap, "elements", []) or [] - if len(elements) >= 10: + pruned_node_count = len(pruned_context.nodes) if pruned_context is not None else 0 + if len(elements) >= 10 and (pruned_context is None or pruned_node_count > 0): break # Escalate limit @@ -2258,7 +2769,12 @@ async def _snapshot_with_escalation( continue last_snap = snap - last_compact = self._format_context(snap, goal) + last_pruned_context = self._build_pruned_context(snap, goal) + last_compact = ( + last_pruned_context.prompt_block + if last_pruned_context is not None and last_pruned_context.nodes + else self._format_context(snap, goal) + ) # Extract screenshot if capture_screenshot: @@ -2307,6 +2823,16 @@ async def _snapshot_with_escalation( snapshot_success=not requires_vision, requires_vision=requires_vision, vision_reason=vision_reason, + pruning_category=( + last_pruned_context.category.value + if last_pruned_context is not None + else None + ), + pruned_node_count=( + len(last_pruned_context.nodes) + if last_pruned_context is not None + else 0 + ), ) # ------------------------------------------------------------------------- @@ -2319,6 +2845,7 @@ async def plan( *, start_url: str | None = None, max_attempts: int = 2, + page_context: str | None = None, ) -> Plan: """ Generate execution plan for the given task. 
@@ -2327,6 +2854,7 @@ async def plan( task: Task description start_url: Starting URL max_attempts: Maximum planning attempts + page_context: Optional markdown content of current page for better planning Returns: Plan object with steps @@ -2342,6 +2870,7 @@ async def plan( start_url=start_url, strict=(attempt > 1), schema_errors=last_errors or None, + page_context=page_context if attempt == 1 else None, # Only include on first attempt ) if self.config.verbose: @@ -2432,7 +2961,31 @@ async def replan( system = """You are the PLANNER. Output a JSON patch to edit an existing plan. Edit ONLY the failed step and optionally the next step. -Return ONLY a JSON object with mode="patch" and replace_steps array.""" +Return ONLY a JSON object with mode="patch" and replace_steps array. + +IMPORTANT - Alternative approaches when CLICK fails: +- If a product/category navigation failed, USE SITE SEARCH instead: + * Replace the failed CLICK with a TYPE_AND_SUBMIT to search for the product + * This is the MOST RELIABLE fallback - site search works on all websites +- If clicking a specific element failed: + * Try a different selector or button (e.g., "Quick Shop", "View Details") +- Don't just retry the same approach with minor changes""" + + # Extract product/item name from step for search suggestion + product_hint = "" + step_labels = " ".join([ + failed_step.goal or "", + failed_step.target or "", + failed_step.intent or "", + ]).lower() + # Common patterns to extract product name + for pattern in [r"snow\s*blower", r"product\s+(\w+)", r"click\s+(?:on\s+)?(.+?)(?:\s+product|\s+category)?$"]: + match = re.search(pattern, step_labels, re.IGNORECASE) + if match: + product_hint = match.group(0) if match.lastindex is None else match.group(1) + break + if not product_hint and failed_step.target: + product_hint = failed_step.target user = f"""Task: {task} @@ -2441,16 +2994,25 @@ async def replan( - Step goal: {failed_step.goal} - Reason: {failure_reason} -Return JSON patch: +IMPORTANT: 
The element could not be found or clicked. The current page likely doesn't have the target. +The BEST approach is to USE SITE SEARCH to find the product directly. + +RECOMMENDED: Replace the failed step with a site search: {{ "mode": "patch", "replace_steps": [ {{ "id": {failed_step.id}, - "step": {{ "id": {failed_step.id}, "goal": "...", "action": "...", "verify": [...] }} + "step": {{ "id": {failed_step.id}, "goal": "Search for {product_hint or 'the product'}", "action": "TYPE_AND_SUBMIT", "input": "{product_hint or 'product name'}", "intent": "Type in search box and submit", "verify": [{{"predicate": "url_contains", "args": ["search"]}}] }} }} ] -}}""" +}} + +Alternative approaches (if search doesn't apply): +1. Click "Catalog" or "Shop All" to browse products +2. Click "Quick Shop" or "View Details" buttons + +Return JSON patch:""" for attempt in range(1, max_attempts + 1): resp = self.planner.generate( @@ -2611,6 +3173,8 @@ async def _scroll_to_find_element( step.intent, ctx.compact_representation, input_text=step.input, + category=self._get_cached_category_str(), + action_type=step.action, ) resp = self.executor.generate( sys_prompt, @@ -2683,6 +3247,8 @@ async def _execute_optional_substeps( substep.intent, ctx.compact_representation, input_text=substep.input, + category=self._get_cached_category_str(), + action_type=substep.action, ) resp = self.executor.generate( sys_prompt, @@ -2800,6 +3366,9 @@ async def _attempt_modal_dismissal( if role not in ("button", "link"): continue + if self._is_global_nav_cart_link(el): + continue + text = (getattr(el, "text", "") or "").lower() aria_label = (getattr(el, "aria_label", "") or getattr(el, "ariaLabel", "") or "").lower() href = (getattr(el, "href", "") or "").lower() @@ -2908,6 +3477,599 @@ async def _attempt_modal_dismissal( print(" [MODAL] All dismissal attempts exhausted", flush=True) return False + def _is_global_nav_cart_link(self, el: Any) -> bool: + """ + Detect persistent header/nav cart links that should 
not be treated as + drawer-local checkout controls. + """ + href = (getattr(el, "href", "") or "").lower() + text = (getattr(el, "text", "") or "").lower().strip() + aria_label = (getattr(el, "aria_label", "") or getattr(el, "ariaLabel", "") or "").lower().strip() + label = text or aria_label + + layout = getattr(el, "layout", None) + region = (getattr(layout, "region", "") or "").lower() + doc_y = getattr(el, "doc_y", None) + + if "nav_cart" in href or "ref_=nav_cart" in href: + return True + + if region in {"header", "nav"} and ( + "cart" in href or label in {"cart", "0 items in cart"} or "items in cart" in label + ): + return True + + try: + if doc_y is not None and float(doc_y) <= 120 and ( + "cart" in href or label in {"cart", "0 items in cart"} or "items in cart" in label + ): + return True + except (TypeError, ValueError): + pass + + return False + + def _looks_like_search_submission(self, step: PlanStep, element: Any | None) -> bool: + """Detect TYPE_AND_SUBMIT actions that are likely site search submissions.""" + role = (getattr(element, "role", "") or "").lower() if element is not None else "" + if role in {"searchbox", "combobox"}: + return True + + labels = " ".join( + str(part or "") + for part in ( + step.goal, + step.intent, + step.input, + getattr(element, "text", None), + getattr(element, "name", None), + getattr(element, "aria_label", None), + getattr(element, "ariaLabel", None), + ) + ).lower() + return "search" in labels + + def _is_add_to_cart_step(self, step: PlanStep) -> bool: + """Detect if a step is an Add to Cart action.""" + add_to_cart_keywords = ["add to cart", "add to bag", "add to basket", "buy now"] + labels = " ".join( + str(part or "").lower() + for part in (step.goal, step.intent, step.input) + ) + return any(kw in labels for kw in add_to_cart_keywords) + + def _is_search_results_url(self, url: str) -> bool: + """Check if URL looks like a search results page.""" + url_lower = url.lower() + # Common patterns for search results 
pages + search_patterns = [ + "search", + "query=", + "q=", + "s=", + "/s?", + "keyword=", + "keywords=", + "results", + ] + return any(pattern in url_lower for pattern in search_patterns) + + def _is_category_navigation_step(self, step: PlanStep) -> bool: + """Check if this step is navigating to a category/section.""" + nav_keywords = [ + "navigate to", "go to", "click category", "category link", + "click on", "browse", "section", "department" + ] + labels = " ".join( + str(part or "").lower() + for part in (step.goal, step.intent) + ) + return any(kw in labels for kw in nav_keywords) + + def _url_change_matches_intent(self, step: PlanStep, pre_url: str, post_url: str) -> bool: + """ + Check if URL change actually matches the step's intent. + + For category navigation, the new URL should contain keywords from the target. + This prevents accepting unrelated URL changes as successful navigation. + """ + # Extract target keywords from step + target = step.target or "" + intent = step.intent or "" + goal = step.goal or "" + + post_url_lower = post_url.lower() + + # Special case: checkout/cart related steps + # These steps may go to /cart first before /checkout, which is valid + checkout_keywords = ["checkout", "proceed to checkout", "cart", "view cart"] + step_labels = f"{goal} {intent}".lower() + is_checkout_step = any(kw in step_labels for kw in checkout_keywords) + if is_checkout_step: + # Accept cart or checkout URLs as valid for checkout steps + checkout_url_patterns = ["cart", "checkout", "basket", "bag"] + if any(pattern in post_url_lower for pattern in checkout_url_patterns): + return True + + # Get keywords from target (e.g., "Outdoor Power Equipment" -> ["outdoor", "power", "equipment"]) + target_words = set( + word.lower() for word in re.split(r'[\s\-_]+', target) + if len(word) >= 3 # Skip short words like "to", "and" + ) + + # Also check predicates for expected URL patterns + expected_patterns = [] + for pred in (step.verify or []): + if pred.predicate == 
"url_contains" and pred.args: + expected_patterns.append(pred.args[0].lower()) + + # If predicates specify URL patterns, check those + if expected_patterns: + if any(pattern in post_url_lower for pattern in expected_patterns): + return True + # For non-checkout steps, reject URL changes that don't match predicates + # But only if we have a target to validate against + if target_words: + return False + # No target and no predicate match - be permissive + return True + + # Otherwise check if target keywords appear in URL + if target_words: + # At least one target word should appear in URL + if any(word in post_url_lower for word in target_words): + return True + # URL doesn't contain any target keywords - suspicious + return False + + # No target specified - can't validate, allow fallback + return True + + def _find_submit_button_for_type_and_submit( + self, + *, + elements: list[Any], + input_element_id: int | None, + step: PlanStep, + ) -> int | None: + """Find an explicit search/submit control for search-style TYPE_AND_SUBMIT steps.""" + selected_element = None + for el in elements: + if getattr(el, "id", None) == input_element_id: + selected_element = el + break + + if not self._looks_like_search_submission(step, selected_element): + return None + + candidates: list[tuple[int, int]] = [] + for el in elements: + el_id = getattr(el, "id", None) + if el_id is None or el_id == input_element_id: + continue + + role = (getattr(el, "role", "") or "").lower() + if role not in {"button", "link"}: + continue + + label = " ".join( + str(part or "") + for part in ( + getattr(el, "text", None), + getattr(el, "name", None), + getattr(el, "aria_label", None), + getattr(el, "ariaLabel", None), + ) + ).lower() + href = (getattr(el, "href", "") or "").lower() + + score = 0 + if "submit search" in label: + score += 120 + if "search" in label: + score += 80 + if "submit" in label: + score += 60 + if label.strip() in {"go", "search"}: + score += 50 + if "/search" in href or "search?" 
in href or "q=" in href: + score += 40 + + if score > 0: + score += int(getattr(el, "importance", 0) or 0) // 100 + candidates.append((int(el_id), score)) + + if not candidates: + return None + + candidates.sort(key=lambda item: item[1], reverse=True) + return candidates[0][0] + + def _type_and_submit_url_change_looks_valid( + self, + *, + pre_url: str, + post_url: str, + step: PlanStep, + element: Any | None, + typed_text: str, + ) -> bool: + """ + Allow URL-change fallback for TYPE_AND_SUBMIT only when the resulting URL + still matches the expected semantics of the typed action. + """ + if not self._looks_like_search_submission(step, element): + return True + + from urllib.parse import quote_plus, urlparse + + post_lower = post_url.lower() + if any(marker in post_lower for marker in ("/search", "?q=", "&q=", "query=", "search=", "keyword=")): + return True + + encoded_query = quote_plus((typed_text or "").strip().lower()) + if encoded_query and encoded_query in post_lower: + return True + + parsed = urlparse(post_url) + searchable = f"{parsed.path}?{parsed.query}".lower() + tokens = [tok for tok in re.split(r"[^a-z0-9]+", (typed_text or "").lower()) if len(tok) >= 3] + if tokens: + matched = sum(1 for tok in tokens[:4] if tok in searchable) + if matched >= min(2, len(tokens[:4])): + return True + + return False + + def _choose_type_and_submit_submit_method( + self, + *, + elements: list[Any], + input_element_id: int | None, + step: PlanStep, + prefer_alternate_of: Literal["click", "enter"] | None = None, + ) -> tuple[Literal["click", "enter"], int | None]: + """Choose the submit method for TYPE_AND_SUBMIT, optionally preferring the alternate path. + + NOTE: For search-like submissions, we prefer Enter key by default (matching WebBench behavior). + Many search boxes (e.g., lifeisgood.com) don't have a proper submit button, or clicking the + "submit" button navigates to a category page instead of performing a search. 
Pressing Enter + is more reliable for search inputs. + + When prefer_alternate_of is set, we try to return the opposite method for retry purposes: + - If prefer_alternate_of="enter", try to return "click" (if a submit button exists) + - If prefer_alternate_of="click", return "enter" + """ + submit_button_id = self._find_submit_button_for_type_and_submit( + elements=elements, + input_element_id=input_element_id, + step=step, + ) + + # For search-like submissions, prefer Enter key by default (more reliable) + # Only fall back to button click if Enter doesn't work (via prefer_alternate_of) + default_method: Literal["click", "enter"] = "enter" + + # Handle retry case: prefer the alternate method + if prefer_alternate_of == "enter" and submit_button_id is not None: + # First attempt used Enter, retry with click (if button available) + return "click", submit_button_id + if prefer_alternate_of == "click": + # First attempt used click, retry with Enter + return "enter", None + + return default_method, submit_button_id + + def _get_runtime_page(self, runtime: AgentRuntime) -> Any | None: + """Best-effort access to the live browser page for immediate URL observation.""" + backend = getattr(runtime, "backend", None) + candidates = [ + getattr(backend, "page", None), + getattr(backend, "_page", None), + getattr(runtime, "_legacy_page", None), + ] + for candidate in candidates: + if candidate is None: + continue + if type(candidate).__module__.startswith("unittest.mock"): + continue + if inspect.getattr_static(candidate, "url", None) is not None: + return candidate + return None + + async def _read_focused_input_value(self, runtime: AgentRuntime) -> str | None: + """Best-effort read of the currently focused input value.""" + page = self._get_runtime_page(runtime) + if page is None: + return None + try: + value = await page.evaluate( + """ + () => { + const el = document.activeElement; + if (!el) return null; + if ("value" in el) return el.value ?? 
""; + return null; + } + """ + ) + except Exception: + return None + return value if isinstance(value, str) else None + + def _normalize_input_value(self, value: str | None) -> str: + """Normalize input text for equality checks across controlled inputs.""" + return " ".join((value or "").strip().lower().split()) + + async def _clear_and_type_search_input( + self, + *, + runtime: AgentRuntime, + input_element_id: int, + text: str, + ) -> bool: + """Clear and type into a search input using the live page when available.""" + page = self._get_runtime_page(runtime) + if page is None: + return False + + browser_like = getattr(runtime, "_legacy_browser", None) or SimpleNamespace(page=page) + + try: + await runtime.click(input_element_id) + except Exception: + return False + + try: + clear_result = await clear_async(browser_like, int(input_element_id), take_snapshot=False) + if getattr(clear_result, "success", False): + await runtime.record_action(f"CLEAR({input_element_id})") + except Exception: + pass + + try: + select_all_key = "Meta+A" if sys.platform == "darwin" else "Control+A" + await page.keyboard.press(select_all_key) + await page.keyboard.press("Backspace") + await runtime.record_action(f"PRESS({select_all_key})") + await runtime.record_action('PRESS("Backspace")') + except Exception: + pass + + try: + delay_ms = float(self.config.type_delay_ms or 0) + type_result = await type_text_async( + browser_like, + int(input_element_id), + str(text), + take_snapshot=False, + delay_ms=delay_ms, + ) + if not getattr(type_result, "success", False): + return False + await runtime.record_action( + f"TYPE({input_element_id}, '{text[:20]}...')" if len(text) > 20 else f"TYPE({input_element_id}, '{text}')" + ) + return True + except Exception: + return False + + async def _submit_if_already_typed( + self, + *, + runtime: AgentRuntime, + elements: list[Any], + input_element_id: int, + step: PlanStep, + text: str, + pre_url: str, + typed_element: Any | None, + telemetry: 
SearchSubmitTelemetry, + ) -> bool: + """Submit without retyping when the focused input already contains the desired text.""" + page = self._get_runtime_page(runtime) + if page is None: + return False + + try: + await runtime.click(input_element_id) + except Exception: + return False + + try: + await page.wait_for_timeout(80) + except Exception: + pass + + current_value = await self._read_focused_input_value(runtime) + if self._normalize_input_value(current_value) != self._normalize_input_value(text): + return False + + await self._submit_type_and_submit( + runtime=runtime, + elements=elements, + input_element_id=input_element_id, + step=step, + text=text, + pre_url=pre_url, + typed_element=typed_element, + telemetry=telemetry, + ) + await runtime.record_action(f"SUBMIT_ALREADY_TYPED({input_element_id})") + return True + + def _snapshot_looks_like_search_results(self, snapshot: Any, typed_text: str) -> bool: + """Best-effort heuristic for product/search results-like pages.""" + elements = getattr(snapshot, "elements", []) or [] + if not elements: + return False + + tokens = [tok for tok in re.split(r"[^a-z0-9]+", (typed_text or "").lower()) if len(tok) >= 3] + product_like_matches = 0 + token_matches = 0 + + for el in elements: + role = (getattr(el, "role", "") or "").lower() + if role != "link": + continue + href = (getattr(el, "href", "") or "").lower() + label = " ".join( + str(part or "") + for part in ( + getattr(el, "text", None), + getattr(el, "name", None), + getattr(el, "aria_label", None), + getattr(el, "ariaLabel", None), + ) + ).lower() + blob = f"{href} {label}" + + if any(p in href for p in ("/product/", "/products/", "/p/", "/dp/")): + product_like_matches += 1 + if tokens: + token_matches += sum(1 for tok in tokens[:4] if tok in blob) + + return product_like_matches > 0 or token_matches >= 2 + + async def _capture_search_results_snapshot_evidence( + self, + *, + runtime: AgentRuntime, + typed_text: str, + telemetry: SearchSubmitTelemetry, + ) -> 
None: + """Capture one post-submit snapshot to track results-like evidence.""" + try: + snap = await runtime.snapshot(emit_trace=False) + except Exception: + return + + telemetry.observed_search_results_dom = self._snapshot_looks_like_search_results(snap, typed_text) + + async def _submit_type_and_submit( + self, + *, + runtime: AgentRuntime, + elements: list[Any], + input_element_id: int, + step: PlanStep, + text: str, + pre_url: str, + typed_element: Any | None, + telemetry: SearchSubmitTelemetry, + prefer_alternate_of: Literal["click", "enter"] | None = None, + ) -> None: + """Submit a search-like TYPE_AND_SUBMIT using the chosen method and record telemetry.""" + submit_method, submit_target = self._choose_type_and_submit_submit_method( + elements=elements, + input_element_id=input_element_id, + step=step, + prefer_alternate_of=prefer_alternate_of, + ) + + if submit_method == "click" and submit_target is not None: + await runtime.click(submit_target) + if self.config.verbose: + print(f" [ACTION] TYPE_AND_SUBMIT submit via CLICK({submit_target})", flush=True) + # Wait briefly for page load after clicking submit button + page = self._get_runtime_page(runtime) + if page is not None: + try: + await page.wait_for_load_state("domcontentloaded", timeout=2000) + except Exception: + pass + else: + await runtime.press("Enter") + if self.config.verbose: + print(" [ACTION] TYPE_AND_SUBMIT submit via PRESS(Enter)", flush=True) + submit_method = "enter" + + # Wait briefly for URL change after Enter + page = self._get_runtime_page(runtime) + if page is not None: + try: + await page.wait_for_url(lambda url: url != pre_url, timeout=3000) + if self.config.verbose: + print(f" [SEARCH] URL changed to: {page.url}", flush=True) + except Exception: + if self.config.verbose: + print(" [SEARCH] URL unchanged after Enter", flush=True) + + if prefer_alternate_of is None: + telemetry.first_submit_method = submit_method + else: + telemetry.retry_submit_method = submit_method + + await 
runtime.stabilize() + await self._capture_search_results_snapshot_evidence( + runtime=runtime, + typed_text=text, + telemetry=telemetry, + ) + + async def _retry_search_widget_submission( + self, + *, + runtime: AgentRuntime, + elements: list[Any], + input_element_id: int, + step: PlanStep, + text: str, + pre_url: str, + typed_element: Any | None, + telemetry: SearchSubmitTelemetry, + ) -> bool: + """Retry a failed search submission once with a clean field and the alternate submit method.""" + if telemetry.first_submit_method not in {"click", "enter"}: + return False + + # Take a fresh snapshot in case DOM changed + fresh_elements = elements + try: + fresh_snap = await runtime.snapshot(emit_trace=False) + fresh_elements = getattr(fresh_snap, "elements", []) or elements + if self.config.verbose: + print(f" [SEARCH-RETRY] Fresh snapshot: {len(fresh_elements)} elements", flush=True) + except Exception: + pass + + retry_method, _ = self._choose_type_and_submit_submit_method( + elements=fresh_elements, + input_element_id=input_element_id, + step=step, + prefer_alternate_of=telemetry.first_submit_method, + ) + if retry_method == telemetry.first_submit_method: + if self.config.verbose: + print(" [SEARCH-RETRY] No alternate submit method available", flush=True) + return False + + select_all_key = "Meta+A" if sys.platform == "darwin" else "Control+A" + if self.config.verbose: + print(f" [SEARCH-RETRY] Retrying search via alternate submit method ({retry_method})", flush=True) + typed_ok = await self._clear_and_type_search_input( + runtime=runtime, + input_element_id=input_element_id, + text=text, + ) + if not typed_ok: + await runtime.click(input_element_id) + await runtime.press(select_all_key) + await runtime.press("Backspace") + await runtime.type(input_element_id, text, delay_ms=self.config.type_delay_ms) + await self._submit_type_and_submit( + runtime=runtime, + elements=fresh_elements, + input_element_id=input_element_id, + step=step, + text=text, + pre_url=pre_url, + 
typed_element=typed_element, + telemetry=telemetry, + prefer_alternate_of=telemetry.first_submit_method, + ) + return await self._verify_step(runtime, step) + def _looks_like_overlay_dismiss_intent(self, *, goal: str, intent: str) -> bool: """ Detect steps that are dismissing overlays, modals, cookie banners, or popups. @@ -3143,6 +4305,9 @@ async def _execute_step( used_heuristics = False error: str | None = None verification_passed = False + extraction_succeeded = False + extracted_data: Any | None = None + search_submit_telemetry = SearchSubmitTelemetry() try: # Pre-step verification check: skip if predicates already pass @@ -3250,7 +4415,92 @@ async def _execute_step( element_id: int | None = None executor_text: str | None = None # Text from executor response (for TYPE actions) - if action_type in ("CLICK", "TYPE_AND_SUBMIT"): + if action_type == "EXTRACT": + action_taken = "EXTRACT" + # Determine extraction query from step goal or task + extract_query = step.goal or ( + self._current_task.task if self._current_task is not None else "Extract relevant data from the current page" + ) + + # Check if this is a text extraction task that can use markdown-based extraction + use_markdown_extraction = _is_text_extraction_task(extract_query) + + if use_markdown_extraction: + # Step 1: Get page content as markdown (faster than snapshot-based extraction) + markdown_content = await runtime.read_markdown(max_chars=8000) + if markdown_content: + if self.config.verbose: + preview = markdown_content[:160].replace("\n", " ") + print(f" [ACTION] EXTRACT - got markdown: {preview}...", flush=True) + + # Step 2: Use LLM (executor) to extract specific data from markdown + extraction_prompt = f"""You are a text extraction assistant. Given the page content in markdown format, extract the specific information requested. + +PAGE CONTENT (MARKDOWN): +{markdown_content} + +EXTRACTION REQUEST: +{extract_query} + +INSTRUCTIONS: +1. Read the markdown content carefully +2. 
Find and extract ONLY the specific information requested +3. Return ONLY the extracted text, nothing else +4. If the information is not found, return "NOT_FOUND" + +EXTRACTED TEXT:""" + + resp = self.executor.generate( + "You extract specific text from markdown content. Return only the extracted text.", + extraction_prompt, + temperature=0.0, + max_new_tokens=500, + ) + self._record_token_usage("extract", resp) + + extracted_text = resp.content.strip() + if extracted_text and extracted_text != "NOT_FOUND": + extraction_succeeded = True + extracted_data = {"text": extracted_text, "query": extract_query} + if self.config.verbose: + print(f" [ACTION] EXTRACT ok: {extracted_text[:160]}", flush=True) + else: + error = f"Could not find requested data: {extract_query}" + else: + error = "Failed to extract markdown from page" + else: + # Use LLM-based extraction for complex extraction tasks + page = ( + getattr(getattr(runtime, "backend", None), "page", None) + or getattr(getattr(runtime, "backend", None), "_page", None) + or getattr(runtime, "_legacy_page", None) + ) + if page is None: + error = "No page available for EXTRACT" + else: + from types import SimpleNamespace + + from ..read import extract_async + + browser_like = SimpleNamespace(page=page) + result = await extract_async( + browser_like, + self.planner, + query=extract_query, + schema=None, + ) + llm_resp = getattr(result, "llm_response", None) + if llm_resp is not None: + self._record_token_usage("extract", llm_resp) + if result.ok: + extraction_succeeded = True + extracted_data = result.data + if self.config.verbose: + preview = str(result.raw or "")[:160] + print(f" [ACTION] EXTRACT ok: {preview}", flush=True) + else: + error = result.error or "Extraction failed" + elif action_type in ("CLICK", "TYPE_AND_SUBMIT"): # Try intent heuristics first (if available) elements = getattr(ctx.snapshot, "elements", []) or [] url = getattr(ctx.snapshot, "url", "") or "" @@ -3268,6 +4518,8 @@ async def _execute_step( 
step.intent, ctx.compact_representation, input_text=step.input, + category=self._get_cached_category_str(), + action_type=step.action, ) if self.config.verbose: @@ -3372,14 +4624,47 @@ async def _execute_step( elif action_type == "TYPE" and element_id is not None: # Use text from executor response first, then fall back to step.input text = executor_text or step.input or "" - await runtime.type(element_id, text) + typed_element = None + for el in elements: + if getattr(el, "id", None) == element_id: + typed_element = el + break # If original plan action was TYPE_AND_SUBMIT, press Enter to submit if original_action == "TYPE_AND_SUBMIT": - # Press Enter to submit (WebBench approach - simpler and more reliable) - await runtime.press("Enter") + submitted_without_retyping = False + if self._looks_like_search_submission(step, typed_element): + submitted_without_retyping = await self._submit_if_already_typed( + runtime=runtime, + elements=elements, + input_element_id=element_id, + step=step, + text=text, + pre_url=pre_url or (ctx.snapshot.url or ""), + typed_element=typed_element, + telemetry=search_submit_telemetry, + ) + if not submitted_without_retyping: + typed_ok = False + if self._looks_like_search_submission(step, typed_element): + typed_ok = await self._clear_and_type_search_input( + runtime=runtime, + input_element_id=element_id, + text=text, + ) + if not typed_ok: + await runtime.type(element_id, text, delay_ms=self.config.type_delay_ms) + await self._submit_type_and_submit( + runtime=runtime, + elements=elements, + input_element_id=element_id, + step=step, + text=text, + pre_url=pre_url or (ctx.snapshot.url or ""), + typed_element=typed_element, + telemetry=search_submit_telemetry, + ) if self.config.verbose: print(f" [ACTION] TYPE_AND_SUBMIT({element_id}, '{text}')", flush=True) - await runtime.stabilize() elif self.config.verbose: print(f" [ACTION] TYPE({element_id}, '{text[:30]}...')" if len(text) > 30 else f" [ACTION] TYPE({element_id}, '{text}')", 
flush=True) elif action_type == "TYPE_AND_SUBMIT" and element_id is not None: @@ -3393,16 +4678,46 @@ async def _execute_step( text = match.group(1).strip() # Type the text - await runtime.type(element_id, text) - - # Press Enter to submit (WebBench approach) - await runtime.press("Enter") + typed_element = None + for el in elements: + if getattr(el, "id", None) == element_id: + typed_element = el + break + submitted_without_retyping = False + if self._looks_like_search_submission(step, typed_element): + submitted_without_retyping = await self._submit_if_already_typed( + runtime=runtime, + elements=elements, + input_element_id=element_id, + step=step, + text=text, + pre_url=pre_url or (ctx.snapshot.url or ""), + typed_element=typed_element, + telemetry=search_submit_telemetry, + ) + if not submitted_without_retyping: + typed_ok = False + if self._looks_like_search_submission(step, typed_element): + typed_ok = await self._clear_and_type_search_input( + runtime=runtime, + input_element_id=element_id, + text=text, + ) + if not typed_ok: + await runtime.type(element_id, text, delay_ms=self.config.type_delay_ms) + await self._submit_type_and_submit( + runtime=runtime, + elements=elements, + input_element_id=element_id, + step=step, + text=text, + pre_url=pre_url or (ctx.snapshot.url or ""), + typed_element=typed_element, + telemetry=search_submit_telemetry, + ) if self.config.verbose: print(f" [ACTION] TYPE_AND_SUBMIT({element_id}, '{text}')", flush=True) - - # Wait for page to load after submit - await runtime.stabilize() elif action_type == "PRESS": key = "Enter" # Default await runtime.press(key) @@ -3423,8 +4738,16 @@ async def _execute_step( # No target URL - we're already at the page, just verify if self.config.verbose: print(f" [ACTION] NAVIGATE(skip - already at page)", flush=True) + elif action_type == "EXTRACT": + pass # Extraction already executed above elif action_type == "FINISH": pass # No action needed + elif action_type == "NONE": + # Executor couldn't 
find a suitable element (e.g., no search box) + # This triggers replanning to try an alternative approach + error = f"No suitable element found for step: {step.goal}" + if self.config.verbose: + print(f" [EXECUTOR] NONE - no suitable element found, will trigger replan", flush=True) elif action_type not in ("CLICK", "TYPE", "TYPE_AND_SUBMIT") or element_id is None: if action_type in ("CLICK", "TYPE", "TYPE_AND_SUBMIT"): error = f"No element ID for {action_type}" @@ -3436,7 +4759,14 @@ async def _execute_step( await runtime.record_action(action_taken) # Run verifications - if step.verify and error is None: + if action_type == "EXTRACT" and error is None: + verification_passed = extraction_succeeded + if self.config.verbose: + print( + f" [VERIFY] Using extraction result: {'PASS' if verification_passed else 'FAIL'}", + flush=True, + ) + elif step.verify and error is None: if self.config.verbose: print(f" [VERIFY] Running {len(step.verify)} verification predicates...", flush=True) verification_passed = await self._verify_step(runtime, step) @@ -3485,11 +4815,142 @@ async def _execute_step( if not verification_passed and original_action in ("TYPE_AND_SUBMIT", "CLICK"): current_url = await runtime.get_url() if hasattr(runtime, "get_url") else None if current_url and pre_url and current_url != pre_url: - # URL changed - the action likely achieved navigation - if self.config.verbose: - print(f" [VERIFY] Predicate failed but URL changed: {pre_url} -> {current_url}", flush=True) - print(f" [VERIFY] Accepting {original_action} as successful (URL change fallback)", flush=True) - verification_passed = True + # Check if this is a meaningful URL change (not just anchor change) + # Strip anchors (#...) 
before comparing + pre_url_base = pre_url.split("#")[0] + current_url_base = current_url.split("#")[0] + is_meaningful_change = pre_url_base != current_url_base + + fallback_ok = is_meaningful_change + if original_action == "TYPE_AND_SUBMIT": + typed_element = None + for el in (ctx.snapshot.elements or []): + if getattr(el, "id", None) == element_id: + typed_element = el + break + fallback_ok = self._type_and_submit_url_change_looks_valid( + pre_url=pre_url, + post_url=current_url, + step=step, + element=typed_element, + typed_text=executor_text or step.input or "", + ) + elif original_action == "CLICK" and is_meaningful_change: + # For CLICK actions, validate URL change matches step intent + # This prevents accepting wrong category navigations + url_matches_intent = self._url_change_matches_intent( + step=step, + pre_url=pre_url, + post_url=current_url, + ) + if not url_matches_intent: + fallback_ok = False + if self.config.verbose: + print(f" [VERIFY] URL changed but doesn't match step intent", flush=True) + print(f" [VERIFY] Step target: {step.target}, URL: {current_url}", flush=True) + + # Special handling for Add to Cart steps: if we were on search results + # and navigated to a product page, retry the step on the new page + # instead of accepting URL change as success + is_add_to_cart = self._is_add_to_cart_step(step) + was_on_search_results = self._is_search_results_url(pre_url) + now_on_product_page = not self._is_search_results_url(current_url) + + if is_add_to_cart and was_on_search_results and now_on_product_page and is_meaningful_change: + # We clicked a product link instead of Add to Cart + # Retry the step on the product page + if self.config.verbose: + print(f" [ADD-TO-CART] Navigated from search results to product page: {pre_url} -> {current_url}", flush=True) + print(f" [ADD-TO-CART] Retrying Add to Cart action on product page...", flush=True) + + # Get fresh snapshot on the product page + await asyncio.sleep(0.5) # Brief wait for page to load + 
try: + retry_ctx = await self._get_execution_context( + runtime, step, step_index + ) + # Build prompt for retry - looking for Add to Cart on product page + retry_prompt = self.build_executor_prompt( + goal=step.goal, + elements=retry_ctx.snapshot.elements or [], + intent=step.intent, + task_category=retry_ctx.task_category, + input_text=step.input, + ) + if self.config.verbose: + print(f" [ADD-TO-CART] Asking executor to find Add to Cart button...", flush=True) + + retry_resp = self.executor.generate( + retry_prompt["system"], + retry_prompt["user"], + max_tokens=self.config.executor_max_tokens, + ) + self._usage.record(role="executor", resp=retry_resp) + retry_action = retry_resp.content.strip() + if self.config.verbose: + print(f" [ADD-TO-CART] Retry executor output: {retry_action}", flush=True) + + # Parse and execute retry action + retry_match = re.match(r"CLICK\((\d+)\)", retry_action) + if retry_match: + retry_element_id = int(retry_match.group(1)) + await runtime.click(retry_element_id) + if self.config.verbose: + print(f" [ADD-TO-CART] Clicked element {retry_element_id}", flush=True) + + # Wait and verify + await asyncio.sleep(0.5) + verification_passed = await self._verify_step(runtime, step) + if verification_passed: + if self.config.verbose: + print(f" [ADD-TO-CART] Add to Cart successful after retry!", flush=True) + else: + # Check for DOM change (cart drawer/modal) + post_retry_snap = await runtime.snapshot(SnapshotOptions(limit=50)) + if post_retry_snap and hasattr(post_retry_snap, "elements"): + post_els = post_retry_snap.elements or [] + cart_indicators = ["cart", "bag", "basket", "checkout", "added", "item"] + has_cart_indicator = any( + any(ind in (getattr(el, "text", "") or "").lower() for ind in cart_indicators) + for el in post_els[:30] + ) + if has_cart_indicator: + if self.config.verbose: + print(f" [ADD-TO-CART] Cart indicator detected, accepting as success", flush=True) + verification_passed = True + except Exception as retry_err: + if 
self.config.verbose: + print(f" [ADD-TO-CART] Retry failed: {retry_err}", flush=True) + + # Skip the normal URL fallback since we handled Add to Cart specially + fallback_ok = False + + if fallback_ok: + if self.config.verbose: + print(f" [VERIFY] Predicate failed but URL changed: {pre_url} -> {current_url}", flush=True) + print(f" [VERIFY] Accepting {original_action} as successful (URL change fallback)", flush=True) + verification_passed = True + else: + if self.config.verbose: + print(f" [VERIFY] URL changed but does not match {original_action} intent: {pre_url} -> {current_url}", flush=True) + if original_action == "TYPE_AND_SUBMIT": + typed_element = None + for el in (ctx.snapshot.elements or []): + if getattr(el, "id", None) == element_id: + typed_element = el + break + if self._looks_like_search_submission(step, typed_element): + # Try retry submission with alternate method + verification_passed = await self._retry_search_widget_submission( + runtime=runtime, + elements=elements, + input_element_id=element_id, + step=step, + pre_url=pre_url, + text=executor_text or step.input or "", + typed_element=typed_element, + telemetry=search_submit_telemetry, + ) elif original_action == "CLICK" and error is None and element_id is not None: # For CLICK actions that don't change URL, check if DOM changed # (e.g., modal appeared, cart drawer opened) @@ -3582,6 +5043,7 @@ async def _execute_step( duration_ms=duration_ms, url_before=pre_url, url_after=post_url, + extracted_data=extracted_data, ) # Emit step_end trace @@ -3691,6 +5153,7 @@ async def run( start_url = automation_task.starting_url self._current_task = automation_task + self._cached_pruning_category = None # Reset category cache for new run self._run_id = run_id or automation_task.task_id self._replans_used = 0 self._vision_calls = 0 @@ -3719,9 +5182,24 @@ async def run( step_outcomes: list[StepOutcome] = [] error: str | None = None + # Optionally fetch page context (markdown) for better planning + page_context: 
str | None = None + if self.config.use_page_context: + try: + page_context = await runtime.read_markdown( + max_chars=self.config.page_context_max_chars + ) + if self.config.verbose and page_context: + print(f" [PAGE-CONTEXT] Extracted {len(page_context)} chars of markdown for planning", flush=True) + print("\n--- Page Context (Markdown) ---", flush=True) + print(page_context, flush=True) + print("--- End Page Context ---\n", flush=True) + except Exception: + pass # Fail silently - page context is optional + try: # Generate plan - plan = await self.plan(task_description, start_url=start_url) + plan = await self.plan(task_description, start_url=start_url, page_context=page_context) # Execute steps step_index = 0 @@ -3926,6 +5404,7 @@ async def run( continuation_task = self._build_checkout_continuation_task( task_description, page_type ) + # Note: page_context (markdown) is only extracted once during initial planning plan = await self.plan(continuation_task, start_url=None) step_index = 0 # Start from beginning of new plan self._replans_used += 1 @@ -4106,6 +5585,7 @@ async def run_stepwise( start_url = automation_task.starting_url self._current_task = automation_task + self._cached_pruning_category = None # Reset category cache for new run self._run_id = run_id or automation_task.task_id self._replans_used = 0 self._vision_calls = 0 diff --git a/predicate/models.py b/predicate/models.py index 2fbf8ea..9f1082a 100644 --- a/predicate/models.py +++ b/predicate/models.py @@ -1165,6 +1165,7 @@ class ExtractResult(BaseModel): data: Any | None = None raw: str | None = None error: str | None = None + llm_response: Any | None = None class TraceStats(BaseModel): diff --git a/predicate/overlay_dismissal.py b/predicate/overlay_dismissal.py new file mode 100644 index 0000000..2a209e8 --- /dev/null +++ b/predicate/overlay_dismissal.py @@ -0,0 +1,717 @@ +""" +Overlay Dismissal Utilities for proactive blocking overlay removal. 
+ +This module provides proactive overlay/modal dismissal to clear blocking elements +(cookie banners, newsletter popups, promotional overlays) BEFORE the agent starts +executing its plan. + +The SDK's built-in ModalDismissalConfig (in planner_executor_agent.py) triggers +AFTER DOM changes from actions, but many sites show blocking overlays immediately +on page load. This module handles those initial blocking overlays. + +Key features: +- Multiple overlay detection strategies (modal_detected, ARIA roles, z-index, class names) +- ESC key press as first attempt +- Scoring system for close/accept buttons +- Iterative clicking with verification +- Wall-clock timeout to avoid stalling + +Usage: + from predicate.overlay_dismissal import dismiss_overlays_before_agent + + # After page load, before agent run: + result = await dismiss_overlays_before_agent(runtime, browser) + print(f"Dismissed {result.overlays_before - result.overlays_after} overlays") +""" + +from __future__ import annotations + +import logging +import re +import time +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any +from urllib.parse import urlparse + +if TYPE_CHECKING: + from .agent_runtime import AgentRuntime + from .browser import AsyncPredicateBrowser + +logger = logging.getLogger(__name__) + +# Common class name patterns for overlays/modals +_OVERLAY_CLASS_PATTERNS = frozenset({ + "modal", + "overlay", + "popup", + "dialog", + "lightbox", + "subscribe", + "newsletter", + "paywall", + "interstitial", + "splash", + "promo", + "announcement", + "banner", + "toast", + "drawer", + "sheet", + "cookie", + "consent", + "gdpr", +}) + +# Z-index threshold for overlay detection +_OVERLAY_Z_INDEX_THRESHOLD = 1000 + + +@dataclass(frozen=True) +class OverlayDismissResult: + """Result of overlay dismissal attempt.""" + + actions: tuple[str, ...] 
+ overlays_before: int + overlays_after: int + status: str + + +def _norm(s: Any) -> str: + """Normalize string for comparison.""" + return str(s or "").strip().lower() + + +def _is_overlay_role(role: str) -> bool: + """Check if role indicates an overlay.""" + r = (role or "").strip().lower() + return r in {"dialog", "alertdialog"} + + +def _has_overlay_class(class_name: str | None) -> bool: + """Check if element has a class name suggesting it's an overlay.""" + if not class_name: + return False + cn = class_name.lower() + return any(p in cn for p in _OVERLAY_CLASS_PATTERNS) + + +def _count_overlays(snapshot: Any) -> int: + """ + Count overlays using multiple detection strategies. + + Detection methods: + 1. Gateway-provided modal_detected/modal_grids (z-index based) + 2. ARIA role-based detection (dialog, alertdialog) + 3. Class name-based detection (modal, overlay, popup, etc.) + 4. Z-index based detection (elements with z_index >= 1000) + 5. Dismiss button heuristic (presence of common dismiss buttons) + """ + try: + # Prefer gateway-provided modal detection if available + modal_detected = getattr(snapshot, "modal_detected", None) + modal_grids = getattr(snapshot, "modal_grids", None) + if modal_detected is True: + return max(1, len(modal_grids or [])) + if modal_detected is None and modal_grids is not None and len(modal_grids) > 0: + return len(modal_grids) + + # Heuristic: check for dismiss button patterns that indicate overlays + # Common overlay dismiss buttons + dismiss_indicators = ( + "close dialog", + "close modal", + "close popup", + "accept", + "decline", + "preferences", + "cookie", + "no thanks", + "not now", + "maybe later", + "dismiss", + ) + els = getattr(snapshot, "elements", None) or [] + dismiss_button_count = 0 + for el in els: + role = _norm(getattr(el, "role", "")) + if role != "button": + continue + text = _norm(getattr(el, "text", "") or "") + if any(p in text for p in dismiss_indicators): + dismiss_button_count += 1 + + # If we see 2+ 
dismiss-style buttons, there's likely an overlay + if dismiss_button_count >= 2: + return 1 + + # Fallback: scan elements for overlay patterns + els = getattr(snapshot, "elements", None) or [] + n = 0 + seen_high_z = False + + # Get viewport for size heuristics + vp_w = None + vp_h = None + try: + vp = getattr(snapshot, "viewport", None) + if vp is not None: + vp_w = getattr(vp, "width", None) + vp_h = getattr(vp, "height", None) + except Exception: + pass + + def _is_large_overlay(el: Any) -> bool: + """Check if element is large enough to be a blocking overlay.""" + try: + bbox = getattr(el, "bbox", None) + if bbox is None: + return False + w = float(getattr(bbox, "width", 0.0) or 0.0) + h = float(getattr(bbox, "height", 0.0) or 0.0) + if w <= 0.0 or h <= 0.0: + return False + area = w * h + # Minimum area threshold (~400x300) + if area < 120_000.0: + return False + # Check viewport coverage if available + if vp_w and vp_h: + try: + area_ratio = area / (float(vp_w) * float(vp_h)) + if area_ratio < 0.10: + return False + except Exception: + pass + return True + except Exception: + return False + + for el in els: + # Check ARIA role + if _is_overlay_role(getattr(el, "role", "")): + n += 1 + continue + # Check class name for overlay patterns + class_name = getattr(el, "class_name", None) or getattr(el, "className", None) + if _has_overlay_class(class_name): + n += 1 + continue + # Check z-index for high-z elements + z_index = getattr(el, "z_index", None) + if ( + z_index is not None + and z_index >= _OVERLAY_Z_INDEX_THRESHOLD + and not seen_high_z + and _is_large_overlay(el) + ): + role_l = _norm(getattr(el, "role", "")) + if role_l in {"button", "link"}: + # Only count if class strongly suggests overlay + if not _has_overlay_class(str(class_name or "")): + continue + seen_high_z = True + n += 1 + return n + except Exception: + return 0 + + +def _word_match(pattern: str, text: str) -> bool: + """Match pattern as word boundary (not substring of longer word).""" + if 
len(pattern) <= 2: + # For short patterns (icons), require exact match + return text == pattern or text.strip() == pattern + # For longer patterns, use word boundary + try: + return bool(re.search(r"\b" + re.escape(pattern) + r"\b", text)) + except Exception: + return pattern in text + + +# Button text patterns for scoring +_ACCEPT_PHRASES = ( + "accept all", + "accept", + "agree", + "i agree", + "allow all", + "allow", + "okay", + "got it", + "continue", + "i understand", +) +_ACCEPT_EXACT = ("ok", "yes") + +_CLOSE_PHRASES = ( + "close", + "dismiss", + "cancel", + "skip", + "no thanks", + "no, thanks", + "reject", + "decline", + "deny", + "not now", + "maybe later", + "not interested", + "no thank you", + "close dialog", + "close modal", + "close popup", + "close overlay", + "close banner", + "dismiss banner", + "dismiss dialog", +) +_CLOSE_ICONS = ("x", "\u00d7", "\u2715", "\u2716", "\u2717", "\u2573", "\u24e7") +_CLOSE_EXACT = ("later",) + +_AVOID_WORDS = ( + "learn more", + "more info", + "manage preferences", + "preferences", + "settings", + "customize", + "options", + "details", + "policy", + "sign up", + "sign in", + "login", + "log in", + "register", + "create account", + "subscribe", + "submit", + "join", + "get started", +) + + +def _is_clickable_control(el: Any) -> bool: + """Check if element is a clickable control.""" + try: + role = _norm(getattr(el, "role", "")) + if role in {"button", "link"}: + return True + vc = getattr(el, "visual_cues", None) + if vc is not None and bool(getattr(vc, "is_clickable", False)): + return True + except Exception: + pass + return False + + +def _label_variants(el: Any) -> list[str]: + """Get multiple label sources for an element.""" + out: list[str] = [] + try: + out.append(_norm(getattr(el, "text", None) or "")) + out.append(_norm(getattr(el, "name", None) or "")) + out.append( + _norm(getattr(el, "aria_label", None) or getattr(el, "ariaLabel", None) or "") + ) + out.append(_norm(getattr(el, "title", None) or "")) + 
except Exception: + pass + return [s for s in out if s] + + +def _collect_candidates( + elements: list[Any], + page_host: str, + overlay_bbox: tuple[float, float, float, float] | None, +) -> list[tuple[int, Any, str]]: + """Collect and score candidate dismiss buttons.""" + candidates: list[tuple[int, Any, str]] = [] + + for el in elements: + # Skip occluded elements + if bool(getattr(el, "is_occluded", False)): + continue + if not _is_clickable_control(el): + continue + + # If we have overlay bbox, filter to controls within it + if overlay_bbox is not None: + try: + bbox = getattr(el, "bbox", None) + if bbox is not None: + bx, by, bw, bh = overlay_bbox + ex = float(getattr(bbox, "x", 0.0) or 0.0) + ey = float(getattr(bbox, "y", 0.0) or 0.0) + ew = float(getattr(bbox, "width", 0.0) or 0.0) + eh = float(getattr(bbox, "height", 0.0) or 0.0) + cx, cy = ex + ew / 2, ey + eh / 2 + pad = 24.0 + if not ( + (bx - pad) <= cx <= (bx + bw + pad) + and (by - pad) <= cy <= (by + bh + pad) + ): + continue + except Exception: + pass + + labels = _label_variants(el) + if not labels: + continue + label = labels[0] + score = 0 + + # Penalize external links + try: + role = _norm(getattr(el, "role", "")) + href = str(getattr(el, "href", None) or "").strip() + if role == "link" and href: + href_host = "" + try: + href_host = (urlparse(href).hostname or "").lower() + except Exception: + pass + if href_host and page_host and href_host != page_host: + score -= 200 + except Exception: + pass + + # Score based on button text + for lbl in labels: + # Accept patterns + if any(_word_match(k, lbl) for k in _ACCEPT_PHRASES): + score += 100 + if lbl in _ACCEPT_EXACT: + score += 100 + # Close patterns + if lbl in _CLOSE_EXACT: + score += 80 + if any(_word_match(k, lbl) for k in _CLOSE_PHRASES): + score += 80 + if lbl in _CLOSE_ICONS: + score += 80 + # Cookie/consent bonus + if "cookie" in lbl or "consent" in lbl: + score += 20 + # Avoid patterns + if any(k in lbl for k in _AVOID_WORDS): + score -= 
50 + # Newsletter penalty for accept + if "newsletter" in lbl or "subscribe" in lbl: + if any(_word_match(k, lbl) for k in _ACCEPT_PHRASES): + score -= 10 + + if score > 0: + candidates.append((score, el, label)) + + candidates.sort(key=lambda t: t[0], reverse=True) + return candidates + + +def _best_overlay_bbox(snap: Any) -> tuple[float, float, float, float] | None: + """Find the best overlay container bounding box.""" + try: + els = getattr(snap, "elements", None) or [] + except Exception: + return None + + best = None + best_area = 0.0 + for el in els: + try: + z = getattr(el, "z_index", None) + if z is None or float(z) < float(_OVERLAY_Z_INDEX_THRESHOLD): + continue + role = _norm(getattr(el, "role", "")) + if role in {"button", "link"}: + continue + bbox = getattr(el, "bbox", None) + if bbox is None: + continue + w = float(getattr(bbox, "width", 0.0) or 0.0) + h = float(getattr(bbox, "height", 0.0) or 0.0) + if w <= 0.0 or h <= 0.0: + continue + area = w * h + if area > best_area: + best_area = area + best = ( + float(getattr(bbox, "x", 0.0) or 0.0), + float(getattr(bbox, "y", 0.0) or 0.0), + w, + h, + ) + except Exception: + continue + return best + + +async def dismiss_overlays( + runtime: "AgentRuntime", + browser: "AsyncPredicateBrowser", + *, + max_rounds: int = 2, + snapshot_limit: int = 100, + max_clicks_per_round: int = 3, + use_api: bool | None = None, + max_seconds: float = 8.0, + verbose: bool = False, +) -> OverlayDismissResult: + """ + Best-effort cross-site overlay dismissal (cookie banners, modals, popups). + + This function attempts to dismiss blocking overlays before the agent runs. + It uses multiple detection strategies and clicks dismiss/accept buttons. 
+ + Args: + runtime: AgentRuntime instance + browser: AsyncPredicateBrowser instance + max_rounds: Maximum dismissal rounds + snapshot_limit: Element limit for snapshots + max_clicks_per_round: Maximum clicks per round + use_api: Force API-based snapshots (None = auto) + max_seconds: Wall-clock timeout + verbose: Print debug info + + Returns: + OverlayDismissResult with actions taken and overlay counts + + Example: + from predicate import AgentRuntime, AsyncPredicateBrowser + from predicate.overlay_dismissal import dismiss_overlays + + async with AsyncPredicateBrowser() as browser: + await browser.goto("https://example.com") + runtime = AgentRuntime(backend=browser.backend) + + result = await dismiss_overlays(runtime, browser, verbose=True) + print(f"Status: {result.status}") + """ + actions: list[str] = [] + status = "unknown" + + # Snapshot options + snap_kwargs: dict[str, Any] = {"limit": snapshot_limit} + if use_api is not None: + snap_kwargs["use_api"] = use_api + + # Initial scan + snap0 = await runtime.snapshot(goal="overlay_scan", **snap_kwargs) + overlays_before = _count_overlays(snap0) + + if verbose: + modal_detected = getattr(snap0, "modal_detected", None) + modal_grids = getattr(snap0, "modal_grids", None) + logger.info( + f"[OVERLAY] Initial scan: overlays={overlays_before}, " + f"modal_detected={modal_detected}, modal_grids={len(modal_grids or [])}" + ) + + # If no overlays, return immediately + if overlays_before <= 0: + return OverlayDismissResult( + actions=tuple(actions), + overlays_before=0, + overlays_after=0, + status="none", + ) + + start_t = time.monotonic() + overlay_bbox = _best_overlay_bbox(snap0) + + # Get page host for external link detection + page_host = "" + try: + page_url = str(browser.page.url or "") + page_host = (urlparse(page_url).hostname or "").lower() + except Exception: + pass + + attempted_click_any = False + + for _round in range(max_rounds): + # Timeout check + if (time.monotonic() - start_t) > max_seconds and 
attempted_click_any: + status = "timeout" + if verbose: + logger.info("[OVERLAY] Timeout reached") + break + + if verbose: + logger.info(f"[OVERLAY] Round {_round + 1}/{max_rounds}") + + # Try ESC first + try: + await browser.page.keyboard.press("Escape") + actions.append('PRESS("Escape")') + if verbose: + logger.info("[OVERLAY] Pressed Escape") + except Exception: + pass + + # Re-scan + snap = await runtime.snapshot(goal="overlay_scan", **snap_kwargs) + + # Check if overlays are gone + if _count_overlays(snap) == 0: + status = "gone" + if verbose: + logger.info("[OVERLAY] Overlays dismissed by Escape") + break + + # Collect candidates + elements = getattr(snap, "elements", None) or [] + candidates = _collect_candidates(elements, page_host, overlay_bbox) + + if verbose: + logger.info(f"[OVERLAY] Found {len(candidates)} candidates") + for i, (sc, _el, lbl) in enumerate(candidates[:5]): + logger.info(f" [{i}] score={sc} label={lbl[:40]!r}") + + if not candidates: + status = "no_candidates" + if verbose: + logger.info("[OVERLAY] No dismiss candidates found") + break + + # Click candidates + clicks = 0 + clicked_labels: set[str] = set() + + while clicks < max_clicks_per_round: + if (time.monotonic() - start_t) > max_seconds and attempted_click_any: + status = "timeout" + break + + # Filter already-clicked + candidates = [c for c in candidates if c[2] not in clicked_labels] + if not candidates: + break + + _score, el, label = candidates[0] + bbox = getattr(el, "bbox", None) + if bbox is None: + clicked_labels.add(label) + candidates = candidates[1:] + continue + + try: + bbox_x = float(getattr(bbox, "x", 0.0)) + bbox_y = float(getattr(bbox, "y", 0.0)) + bbox_w = float(getattr(bbox, "width", 0.0)) + bbox_h = float(getattr(bbox, "height", 0.0)) + x = bbox_x + bbox_w / 2.0 + y = bbox_y + bbox_h / 2.0 + + if verbose: + logger.info(f"[OVERLAY] Clicking '{label[:30]}' at ({x:.0f}, {y:.0f})") + + # Click using backend + await runtime.backend.mouse_click(x, y) + 
actions.append(f'OVERLAY_CLICK("{label[:40]}")') + clicked_labels.add(label) + clicks += 1 + attempted_click_any = True + except Exception as e: + if verbose: + logger.warning(f"[OVERLAY] Click failed: {e}") + clicked_labels.add(label) + continue + + # Wait for UI to settle + try: + await browser.page.wait_for_timeout(350) + except Exception: + pass + + # Check if overlay is gone + snap_after = await runtime.snapshot(goal="overlay_verify", **snap_kwargs) + if _count_overlays(snap_after) == 0: + status = "gone" + if verbose: + logger.info("[OVERLAY] Overlay dismissed after click") + break + + # Update candidates from new snapshot + try: + overlay_bbox = _best_overlay_bbox(snap_after) + candidates = _collect_candidates( + getattr(snap_after, "elements", None) or [], + page_host, + overlay_bbox, + ) + except Exception: + pass + + if status == "gone": + break + + # Final count + final_snap = await runtime.snapshot(goal="overlay_final", **snap_kwargs) + overlays_after = _count_overlays(final_snap) + + if overlays_after == 0 and status not in ("gone", "none"): + status = "gone" + elif overlays_after > 0 and status not in ("timeout", "no_candidates"): + status = "partial" + + if verbose: + logger.info( + f"[OVERLAY] Done: status={status}, before={overlays_before}, after={overlays_after}" + ) + + return OverlayDismissResult( + actions=tuple(actions), + overlays_before=overlays_before, + overlays_after=overlays_after, + status=status, + ) + + +async def dismiss_overlays_before_agent( + runtime: "AgentRuntime", + browser: "AsyncPredicateBrowser", + *, + use_api: bool | None = None, + verbose: bool = False, +) -> OverlayDismissResult: + """ + Convenience wrapper to dismiss overlays before agent execution. + + This should be called after page load and before the agent's run() method. + It handles the common case of initial page overlays (cookie banners, popups). 
+ + Args: + runtime: AgentRuntime instance + browser: AsyncPredicateBrowser instance + use_api: Force API-based snapshots during overlay handling + verbose: Print debug info + + Returns: + OverlayDismissResult + + Example: + from predicate import AgentRuntime + from predicate.browser import AsyncPredicateBrowser + from predicate.overlay_dismissal import dismiss_overlays_before_agent + + async with AsyncPredicateBrowser() as browser: + await browser.goto(url) + runtime = AgentRuntime(backend=browser.backend) + + # Dismiss initial overlays + result = await dismiss_overlays_before_agent(runtime, browser, verbose=True) + + # Now run the agent + await agent.run(runtime, task) + """ + return await dismiss_overlays( + runtime, + browser, + max_rounds=3, # More rounds to handle multiple overlays + snapshot_limit=100, + max_clicks_per_round=4, # More clicks per round + use_api=use_api, + max_seconds=12.0, # More time for complex sites + verbose=verbose, + ) diff --git a/predicate/pruning/__init__.py b/predicate/pruning/__init__.py new file mode 100644 index 0000000..d8b2a24 --- /dev/null +++ b/predicate/pruning/__init__.py @@ -0,0 +1,32 @@ +""" +Category-specific snapshot pruning helpers. 
+ +Supports category-aware pruning with automatic over-pruning recovery: +- Rule-based category classification (no LLM needed) +- Deterministic allow/block policies per category +- Relaxation levels for recovery when pruning is too aggressive +""" + +from .classifier import classify_task_category +from .policies import get_pruning_policy, PruningPolicy +from .pruner import prune_snapshot_for_task, prune_with_recovery +from .serializer import serialize_pruned_snapshot +from .types import ( + CategoryDetectionResult, + PrunedSnapshotContext, + PruningTaskCategory, + SkeletonDomNode, +) + +__all__ = [ + "CategoryDetectionResult", + "PrunedSnapshotContext", + "PruningPolicy", + "PruningTaskCategory", + "SkeletonDomNode", + "classify_task_category", + "get_pruning_policy", + "prune_snapshot_for_task", + "prune_with_recovery", + "serialize_pruned_snapshot", +] diff --git a/predicate/pruning/classifier.py b/predicate/pruning/classifier.py new file mode 100644 index 0000000..f157251 --- /dev/null +++ b/predicate/pruning/classifier.py @@ -0,0 +1,71 @@ +""" +Rule-based task classifier for pruning categories. +""" + +from __future__ import annotations + +from collections.abc import Iterable +from typing import Any + +from .types import CategoryDetectionResult, PruningTaskCategory + + +def _normalize_hints(domain_hints: Iterable[str] | None) -> set[str]: + return {hint.strip().lower() for hint in domain_hints or () if str(hint).strip()} + + +def classify_task_category( + *, + task_text: str, + current_url: str | None = None, + domain_hints: Iterable[str] | None = None, + task_category: Any | None = None, +) -> CategoryDetectionResult: + """ + Classify a browser task into a pruning-oriented category using rules only. 
+ """ + + text = (task_text or "").lower() + url = (current_url or "").lower() + hints = _normalize_hints(domain_hints) + category_value = str(getattr(task_category, "value", task_category) or "").lower() + + if category_value == "form_fill": + return CategoryDetectionResult(PruningTaskCategory.FORM_FILLING, 0.90) + if category_value == "search": + return CategoryDetectionResult(PruningTaskCategory.SEARCH, 0.90) + if category_value == "extraction": + return CategoryDetectionResult(PruningTaskCategory.EXTRACTION, 0.90) + if category_value == "navigation": + return CategoryDetectionResult(PruningTaskCategory.NAVIGATION, 0.90) + if category_value == "verification": + return CategoryDetectionResult(PruningTaskCategory.VERIFICATION, 0.90) + + if any(keyword in text for keyword in ("add to cart", "add it to cart", "add to bag", "buy now", "purchase")): + return CategoryDetectionResult(PruningTaskCategory.SHOPPING, 0.95) + if "checkout" in text: + return CategoryDetectionResult(PruningTaskCategory.CHECKOUT, 0.95) + if any(keyword in text for keyword in ("fill out", "submit form", "contact form", "enter email", "type into")): + return CategoryDetectionResult(PruningTaskCategory.FORM_FILLING, 0.90) + if "sign in" in text or "login" in text or "password" in text: + return CategoryDetectionResult(PruningTaskCategory.AUTH, 0.90) + if any(keyword in text for keyword in ("extract", "list the", "count the", "scrape")): + return CategoryDetectionResult(PruningTaskCategory.EXTRACTION, 0.80) + if any(keyword in text for keyword in ("search for", "look up")) or "find" in text: + return CategoryDetectionResult(PruningTaskCategory.SEARCH, 0.85) + + if "ecommerce" in hints: + if category_value == "transaction": + return CategoryDetectionResult(PruningTaskCategory.SHOPPING, 0.85) + return CategoryDetectionResult(PruningTaskCategory.SHOPPING, 0.65) + + if hints & {"forms", "contact", "signup"}: + return CategoryDetectionResult(PruningTaskCategory.FORM_FILLING, 0.65) + + if hints & 
{"reference", "news", "recipes", "weather", "travel", "flights", "hotels", "social"}: + return CategoryDetectionResult(PruningTaskCategory.SEARCH, 0.65) + + if "checkout" in url or "cart" in url or "bag" in url: + return CategoryDetectionResult(PruningTaskCategory.CHECKOUT, 0.60) + + return CategoryDetectionResult(PruningTaskCategory.GENERIC, 0.0) diff --git a/predicate/pruning/policies.py b/predicate/pruning/policies.py new file mode 100644 index 0000000..d2eb42f --- /dev/null +++ b/predicate/pruning/policies.py @@ -0,0 +1,278 @@ +""" +Deterministic pruning policies for browser-agent snapshots. + +Supports relaxation levels for over-pruning recovery: +- Level 0: Strict category-specific pruning (default) +- Level 1: Relaxed - allow more interactive roles +- Level 2: Loose - include nearly all interactive elements +- Level 3+: Fallback to generic (minimal pruning) +""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Callable + +from .types import PruningTaskCategory + +NodePredicate = Callable[[Any, str], bool] + + +@dataclass(frozen=True) +class PruningPolicy: + """Simple allow/block policy for an initial pruning category.""" + + category: PruningTaskCategory + max_nodes: int + allow: NodePredicate + block: NodePredicate + relaxation_level: int = 0 + + def with_relaxation(self, level: int) -> "PruningPolicy": + """Return a new policy with the specified relaxation level.""" + return PruningPolicy( + category=self.category, + max_nodes=self.max_nodes + (level * 15), # Increase budget per level + allow=self.allow, + block=self.block, + relaxation_level=level, + ) + + +# --------------------------------------------------------------------------- +# Helper functions +# --------------------------------------------------------------------------- + +def _text(el: Any) -> str: + return str(getattr(el, "text", "") or "").lower() + + +def _href(el: Any) -> str: + return str(getattr(el, "href", "") or "").lower() + + +def 
_nearby_text(el: Any) -> str: + return str(getattr(el, "nearby_text", "") or "").lower() + + +def _role(el: Any) -> str: + return str(getattr(el, "role", "") or "").lower() + + +def _is_interactive(el: Any) -> bool: + """Check if element has an interactive role.""" + role = _role(el) + return role in { + "button", "link", "textbox", "searchbox", "combobox", + "checkbox", "radio", "slider", "tab", "menuitem", + "option", "switch", "cell", "input", "select", "textarea", + } + + +# --------------------------------------------------------------------------- +# Category-specific allow predicates +# --------------------------------------------------------------------------- + +def _allow_shopping(el: Any, goal: str) -> bool: + text = _text(el) + nearby = _nearby_text(el) + href = _href(el) + role = _role(el) + + if role in {"button", "link", "textbox", "searchbox", "combobox"} and ( + "add to cart" in text + or "add to bag" in text + or "buy now" in text + or "checkout" in text + or "cart" in text + ): + return True + if role == "link" and href and getattr(el, "in_dominant_group", False): + return True + if "$" in text or "price" in nearby: + return True + if role in {"textbox", "searchbox", "combobox"} and "search" in text: + return True + if getattr(el, "in_dominant_group", False) and len(text.strip()) >= 3: + return True + return False + + +def _allow_shopping_relaxed(el: Any, goal: str) -> bool: + """Level 1 relaxation: include more interactive elements.""" + if _allow_shopping(el, goal): + return True + role = _role(el) + # Add more buttons and links even outside dominant group + if role in {"button", "link"} and len(_text(el).strip()) >= 2: + return True + # Include quantity selectors, size selectors + if role in {"select", "combobox", "listbox"}: + return True + return False + + +def _allow_shopping_loose(el: Any, goal: str) -> bool: + """Level 2 relaxation: include nearly all interactive elements.""" + if _allow_shopping_relaxed(el, goal): + return True + return 
_is_interactive(el) + + +def _allow_form_filling(el: Any, goal: str) -> bool: + text = _text(el) + role = _role(el) + if role in {"textbox", "searchbox", "combobox", "checkbox", "radio", "textarea"}: + return True + if role == "button" and any(token in text for token in ("submit", "send", "continue", "sign up")): + return True + return False + + +def _allow_form_filling_relaxed(el: Any, goal: str) -> bool: + """Level 1 relaxation for form filling.""" + if _allow_form_filling(el, goal): + return True + role = _role(el) + # Include all buttons and selects + if role in {"button", "select", "listbox", "option"}: + return True + return False + + +def _allow_search(el: Any, goal: str) -> bool: + text = _text(el) + role = _role(el) + href = _href(el) + if role in {"searchbox", "textbox", "combobox"}: + return True + if role == "button" and "search" in text: + return True + if role == "link" and href: + return True + return False + + +def _allow_search_relaxed(el: Any, goal: str) -> bool: + """Level 1 relaxation for search.""" + if _allow_search(el, goal): + return True + role = _role(el) + if role in {"button", "tab", "menuitem"}: + return True + return False + + +def _allow_generic(el: Any, goal: str) -> bool: + role = _role(el) + return role in {"button", "link", "textbox", "searchbox", "combobox", "checkbox", "radio"} + + +def _allow_generic_relaxed(el: Any, goal: str) -> bool: + """Relaxed generic - include all interactive elements.""" + return _is_interactive(el) + + +def _block_common(el: Any, goal: str) -> bool: + text = _text(el) + href = _href(el) + return any(token in text for token in ("privacy policy", "terms", "cookie policy")) or any( + token in href for token in ("/privacy", "/terms", "/cookies") + ) + + +def _block_nothing(el: Any, goal: str) -> bool: + """At high relaxation levels, don't block anything.""" + return False + + +# --------------------------------------------------------------------------- +# Policy factory with relaxation support +# 
--------------------------------------------------------------------------- + +def get_pruning_policy( + category: PruningTaskCategory, + relaxation_level: int = 0, +) -> PruningPolicy: + """ + Return the deterministic policy for a category with optional relaxation. + + Args: + category: The task category + relaxation_level: 0=strict, 1=relaxed, 2=loose, 3+=fallback + + Returns: + PruningPolicy configured for the category and relaxation level + """ + # At level 3+, fall back to generic with no blocking + if relaxation_level >= 3: + return PruningPolicy( + category=category, + max_nodes=80, + allow=_allow_generic_relaxed, + block=_block_nothing, + relaxation_level=relaxation_level, + ) + + if category in {PruningTaskCategory.SHOPPING, PruningTaskCategory.CHECKOUT}: + if relaxation_level == 0: + allow_fn = _allow_shopping + max_nodes = 25 + elif relaxation_level == 1: + allow_fn = _allow_shopping_relaxed + max_nodes = 40 + else: # level 2 + allow_fn = _allow_shopping_loose + max_nodes = 60 + return PruningPolicy( + category=category, + max_nodes=max_nodes, + allow=allow_fn, + block=_block_common if relaxation_level < 2 else _block_nothing, + relaxation_level=relaxation_level, + ) + + if category == PruningTaskCategory.FORM_FILLING: + if relaxation_level == 0: + allow_fn = _allow_form_filling + max_nodes = 20 + else: + allow_fn = _allow_form_filling_relaxed + max_nodes = 35 + (relaxation_level * 10) + return PruningPolicy( + category=category, + max_nodes=max_nodes, + allow=allow_fn, + block=_block_common if relaxation_level == 0 else _block_nothing, + relaxation_level=relaxation_level, + ) + + if category == PruningTaskCategory.SEARCH: + if relaxation_level == 0: + allow_fn = _allow_search + max_nodes = 20 + else: + allow_fn = _allow_search_relaxed + max_nodes = 35 + (relaxation_level * 10) + return PruningPolicy( + category=category, + max_nodes=max_nodes, + allow=allow_fn, + block=_block_common if relaxation_level == 0 else _block_nothing, + 
relaxation_level=relaxation_level, + ) + + # Generic or other categories + if relaxation_level == 0: + allow_fn = _allow_generic + max_nodes = 20 + else: + allow_fn = _allow_generic_relaxed + max_nodes = 40 + (relaxation_level * 15) + return PruningPolicy( + category=category, + max_nodes=max_nodes, + allow=allow_fn, + block=_block_common if relaxation_level == 0 else _block_nothing, + relaxation_level=relaxation_level, + ) diff --git a/predicate/pruning/pruner.py b/predicate/pruning/pruner.py new file mode 100644 index 0000000..c216b2d --- /dev/null +++ b/predicate/pruning/pruner.py @@ -0,0 +1,180 @@ +""" +Deterministic snapshot pruning entry points. + +Supports over-pruning recovery via relaxation levels: +- If pruning leaves too few elements, increase relaxation level and re-prune +- Relaxation progressively loosens allow predicates and increases node budgets +""" + +from __future__ import annotations + +from typing import Any + +from .policies import get_pruning_policy +from .serializer import serialize_pruned_snapshot +from .types import PrunedSnapshotContext, PruningTaskCategory, SkeletonDomNode + +# Minimum elements threshold - if pruning leaves fewer, consider relaxation +MIN_PRUNED_ELEMENTS = 5 + + +def _node_score(el: Any, goal: str) -> float: + """Score an element for ranking within the pruned set.""" + text = str(getattr(el, "text", "") or "").lower() + goal_lower = (goal or "").lower() + + score = float(getattr(el, "importance", 0) or 0) + if bool(getattr(el, "in_viewport", True)): + score += 25.0 + if bool(getattr(el, "in_dominant_group", False)): + score += 20.0 + visual_cues = getattr(el, "visual_cues", None) + if visual_cues is not None and bool(getattr(visual_cues, "is_clickable", False)): + score += 15.0 + if text and goal_lower and any(token in text for token in goal_lower.split()): + score += 10.0 + if "$" in text: + score += 8.0 + return score + + +def _semantic_tags(el: Any) -> tuple[str, ...]: + """Derive semantic tags from element 
properties.""" + text = str(getattr(el, "text", "") or "").lower() + role = str(getattr(el, "role", "") or "").lower() + tags: list[str] = [] + if "$" in text: + tags.append("price") + if "add to cart" in text or "add to bag" in text: + tags.append("add_to_cart") + if "checkout" in text or "cart" in text: + tags.append("checkout") + if role in {"searchbox", "textbox", "combobox"} and "search" in text: + tags.append("search_input") + if role == "link" and len(text.strip()) >= 3: + tags.append("product_title") + return tuple(tags) + + +def prune_snapshot_for_task( + snapshot: Any, + *, + goal: str, + category: PruningTaskCategory, + relaxation_level: int = 0, +) -> PrunedSnapshotContext: + """ + Prune a snapshot deterministically for the given category. + + Args: + snapshot: The snapshot to prune + goal: The task goal for context-aware scoring + category: The detected task category + relaxation_level: 0=strict, 1=relaxed, 2=loose, 3+=fallback + + Returns: + PrunedSnapshotContext with the pruned nodes and metadata + """ + all_elements = getattr(snapshot, "elements", []) or [] + raw_count = len(all_elements) + + policy = get_pruning_policy(category, relaxation_level) + kept: list[Any] = [] + + for el in all_elements: + if policy.block(el, goal): + continue + if policy.allow(el, goal): + kept.append(el) + + kept.sort(key=lambda el: _node_score(el, goal), reverse=True) + selected = kept[: policy.max_nodes] + + nodes = tuple( + SkeletonDomNode( + id=int(getattr(el, "id")), + role=str(getattr(el, "role", "") or ""), + text=getattr(el, "text", None), + href=getattr(el, "href", None), + region=getattr(getattr(el, "layout", None), "region", None) or "unknown", + semantic_tags=_semantic_tags(el), + ) + for el in selected + ) + + ctx = PrunedSnapshotContext( + category=category, + url=str(getattr(snapshot, "url", "") or ""), + nodes=nodes, + prompt_block="", + relaxation_level=relaxation_level, + raw_element_count=raw_count, + pruned_element_count=len(nodes), + ) + + return 
PrunedSnapshotContext(
+        category=ctx.category,
+        url=ctx.url,
+        nodes=ctx.nodes,
+        prompt_block=serialize_pruned_snapshot(ctx),
+        relaxation_level=ctx.relaxation_level,
+        raw_element_count=ctx.raw_element_count,
+        pruned_element_count=ctx.pruned_element_count,
+    )
+
+
+def prune_with_recovery(
+    snapshot: Any,
+    *,
+    goal: str,
+    category: PruningTaskCategory,
+    max_relaxation: int = 3,
+    verbose: bool = False,
+) -> PrunedSnapshotContext:
+    """
+    Prune with automatic recovery via relaxation if over-pruning is detected.
+
+    This function progressively relaxes the pruning policy if the initial
+    pruning leaves too few elements.
+
+    Args:
+        snapshot: The snapshot to prune
+        goal: The task goal
+        category: The detected task category
+        max_relaxation: Maximum relaxation level to try (default 3)
+        verbose: Print relaxation info
+
+    Returns:
+        PrunedSnapshotContext with the best pruning result
+    """
+    for level in range(max_relaxation + 1):
+        ctx = prune_snapshot_for_task(
+            snapshot,
+            goal=goal,
+            category=category,
+            relaxation_level=level,
+        )
+
+        if verbose and level > 0:
+            print(
+                f"  [PRUNING] Relaxation level {level}: "
+                f"{ctx.raw_element_count} -> {ctx.pruned_element_count} elements",
+                flush=True,
+            )
+
+        # If we have enough elements, stop relaxing
+        if ctx.pruned_element_count >= MIN_PRUNED_ELEMENTS:
+            return ctx
+
+        # If we're at max relaxation, return whatever we have
+        if level == max_relaxation:
+            if verbose:
+                print(
+                    f"  [PRUNING] Max relaxation reached, "
+                    f"returning {ctx.pruned_element_count} elements",
+                    flush=True,
+                )
+            return ctx
+
+    # Shouldn't reach here, but return last result
+    return ctx
diff --git a/predicate/pruning/serializer.py b/predicate/pruning/serializer.py
new file mode 100644
index 0000000..a739ffc
--- /dev/null
+++ b/predicate/pruning/serializer.py
@@ -0,0 +1,27 @@
+"""
+Serializer for pruned snapshot contexts.
+""" + +from __future__ import annotations + +from .types import PrunedSnapshotContext + + +def serialize_pruned_snapshot(ctx: PrunedSnapshotContext) -> str: + """Serialize a pruned snapshot context into a compact prompt block.""" + + lines = [ + f"Category: {ctx.category.value}", + f"URL: {ctx.url}", + "Nodes:", + ] + for node in ctx.nodes: + line = f'[{node.id}] {node.role} text="{node.text or ""}"' + if node.semantic_tags: + line += f" tags={','.join(node.semantic_tags)}" + if node.region: + line += f" region={node.region}" + if node.href: + line += f" href={node.href}" + lines.append(line) + return "\n".join(lines) diff --git a/predicate/pruning/types.py b/predicate/pruning/types.py new file mode 100644 index 0000000..01b4229 --- /dev/null +++ b/predicate/pruning/types.py @@ -0,0 +1,63 @@ +""" +Core types for category-specific snapshot pruning. +""" + +from __future__ import annotations + +from dataclasses import dataclass +from enum import Enum +from typing import Literal + + +class PruningTaskCategory(str, Enum): + """Task categories used by the pruning pipeline.""" + + SHOPPING = "shopping" + FORM_FILLING = "form_filling" + SEARCH = "search" + EXTRACTION = "extraction" + NAVIGATION = "navigation" + AUTH = "auth" + CHECKOUT = "checkout" + VERIFICATION = "verification" + GENERIC = "generic" + + +@dataclass(frozen=True) +class CategoryDetectionResult: + """Result of pruning category detection.""" + + category: PruningTaskCategory + confidence: float + source: Literal["rule", "llm"] = "rule" + + +@dataclass(frozen=True) +class SkeletonDomNode: + """Compact node retained after deterministic pruning.""" + + id: int + role: str + text: str | None = None + href: str | None = None + region: str = "unknown" + semantic_tags: tuple[str, ...] = () + ordinal: int | None = None + + +@dataclass(frozen=True) +class PrunedSnapshotContext: + """Pruned snapshot plus compact prompt block.""" + + category: PruningTaskCategory + url: str + nodes: tuple[SkeletonDomNode, ...] 
+ prompt_block: str + relaxation_level: int = 0 + raw_element_count: int = 0 + pruned_element_count: int = 0 + + @property + def is_sparse(self) -> bool: + """Check if pruning left too few elements (potential over-pruning).""" + return len(self.nodes) < 5 diff --git a/predicate/read.py b/predicate/read.py index 376a8d1..4e676d3 100644 --- a/predicate/read.py +++ b/predicate/read.py @@ -493,14 +493,14 @@ def extract( raw = response.content.strip() if schema is None: - return ExtractResult(ok=True, data={"text": raw}, raw=raw) + return ExtractResult(ok=True, data={"text": raw}, raw=raw, llm_response=response) try: payload = _extract_json_payload(raw) validated = schema.model_validate(payload) - return ExtractResult(ok=True, data=validated, raw=raw) + return ExtractResult(ok=True, data=validated, raw=raw, llm_response=response) except (json.JSONDecodeError, ValidationError) as exc: - return ExtractResult(ok=False, error=str(exc), raw=raw) + return ExtractResult(ok=False, error=str(exc), raw=raw, llm_response=response) async def extract_async( @@ -527,11 +527,11 @@ async def extract_async( raw = response.content.strip() if schema is None: - return ExtractResult(ok=True, data={"text": raw}, raw=raw) + return ExtractResult(ok=True, data={"text": raw}, raw=raw, llm_response=response) try: payload = _extract_json_payload(raw) validated = schema.model_validate(payload) - return ExtractResult(ok=True, data=validated, raw=raw) + return ExtractResult(ok=True, data=validated, raw=raw, llm_response=response) except (json.JSONDecodeError, ValidationError) as exc: - return ExtractResult(ok=False, error=str(exc), raw=raw) + return ExtractResult(ok=False, error=str(exc), raw=raw, llm_response=response) diff --git a/tests/test_automation_task.py b/tests/test_automation_task.py index 4502361..49a5ed4 100644 --- a/tests/test_automation_task.py +++ b/tests/test_automation_task.py @@ -17,6 +17,7 @@ COMMON_HINTS, get_common_hint, ) +from predicate.pruning import PruningTaskCategory class 
TestAutomationTask: @@ -145,6 +146,30 @@ class MockWebBenchTask: assert automation_task.category == TaskCategory.EXTRACTION assert automation_task.extraction_spec is not None + def test_pruning_category_hint_maps_transaction_task_to_shopping(self): + """Transaction tasks with shopping intent should map to shopping pruning.""" + task = AutomationTask( + task_id="test-007", + starting_url="https://shop.com", + task="Find a jacket and add it to cart", + category=TaskCategory.TRANSACTION, + domain_hints=("ecommerce",), + ) + + assert task.pruning_category_hint() == PruningTaskCategory.SHOPPING + + def test_pruning_category_hint_maps_form_fill_task(self): + """Form-fill tasks should map to form-filling pruning.""" + task = AutomationTask( + task_id="test-008", + starting_url="https://example.com/contact", + task="Fill out the contact form", + category=TaskCategory.FORM_FILL, + domain_hints=("forms",), + ) + + assert task.pruning_category_hint() == PruningTaskCategory.FORM_FILLING + class TestHeuristicHint: """Tests for HeuristicHint model.""" diff --git a/tests/unit/test_category_classifier.py b/tests/unit/test_category_classifier.py new file mode 100644 index 0000000..d5f80ee --- /dev/null +++ b/tests/unit/test_category_classifier.py @@ -0,0 +1,62 @@ +""" +Unit tests for pruning category classification. 
+""" + +from predicate.pruning import ( + CategoryDetectionResult, + PruningTaskCategory, + classify_task_category, +) + + +class TestCategoryClassifier: + """Tests for rule-based pruning category detection.""" + + def test_classify_task_category_returns_shopping_for_add_to_cart(self) -> None: + """Shopping tasks should be classified without an LLM call.""" + result = classify_task_category( + task_text="Search for a hat and add it to cart", + current_url="https://www.amazon.com", + domain_hints=("ecommerce",), + ) + + assert isinstance(result, CategoryDetectionResult) + assert result.category == PruningTaskCategory.SHOPPING + assert result.source == "rule" + assert result.confidence >= 0.80 + + def test_classify_task_category_returns_form_filling_for_submit_form(self) -> None: + """Form tasks should be identified from task text.""" + result = classify_task_category( + task_text="Fill out the contact form and submit it", + current_url="https://example.com/contact", + domain_hints=("forms",), + ) + + assert result.category == PruningTaskCategory.FORM_FILLING + assert result.source == "rule" + assert result.confidence >= 0.80 + + def test_classify_task_category_returns_search_for_search_task(self) -> None: + """Search tasks should map to the search pruning category.""" + result = classify_task_category( + task_text="Search for flights from Seattle to New York", + current_url="https://www.google.com/travel/flights", + domain_hints=("travel", "flights"), + ) + + assert result.category == PruningTaskCategory.SEARCH + assert result.source == "rule" + assert result.confidence >= 0.80 + + def test_classify_task_category_falls_back_to_generic_when_no_rule_matches(self) -> None: + """Unknown tasks should fall back to GENERIC without raising.""" + result = classify_task_category( + task_text="Observe the page and think about what to do next", + current_url="https://example.com", + domain_hints=(), + ) + + assert result.category == PruningTaskCategory.GENERIC + assert 
result.source == "rule" + assert result.confidence == 0.0 diff --git a/tests/unit/test_planner_executor_agent.py b/tests/unit/test_planner_executor_agent.py index 33d31d1..7aee77f 100644 --- a/tests/unit/test_planner_executor_agent.py +++ b/tests/unit/test_planner_executor_agent.py @@ -28,11 +28,65 @@ RecoveryNavigationConfig, SnapshotEscalationConfig, build_executor_prompt, + build_planner_prompt, normalize_plan, validate_plan_smoothness, ) +# --------------------------------------------------------------------------- +# Test build_planner_prompt with page_context +# --------------------------------------------------------------------------- + + +class TestBuildPlannerPromptPageContext: + """Tests for build_planner_prompt with page_context parameter.""" + + def test_page_context_not_included_when_none(self) -> None: + sys_prompt, user_prompt = build_planner_prompt( + task="Buy a laptop", + start_url="https://example.com", + page_context=None, + ) + assert "Current Page Content" not in user_prompt + assert "markdown" not in user_prompt.lower() + + def test_page_context_included_when_provided(self) -> None: + markdown_content = "# Welcome to Example Store\n\n- Laptops\n- Phones\n- Tablets" + sys_prompt, user_prompt = build_planner_prompt( + task="Buy a laptop", + start_url="https://example.com", + page_context=markdown_content, + ) + assert "Current Page Content:" in user_prompt + assert "markdown representation" in user_prompt + assert "may be truncated" in user_prompt + assert "# Welcome to Example Store" in user_prompt + assert "Laptops" in user_prompt + + def test_page_context_helps_with_task_understanding(self) -> None: + # Page context should help planner understand what's on the page + markdown_content = """ +# Search Results for "gaming laptop" + +## Products +- ASUS ROG Gaming Laptop - $1299 +- MSI Raider - $1499 +- Alienware M15 - $1799 + +## Filters +- Price Range +- Brand +""" + sys_prompt, user_prompt = build_planner_prompt( + task="Add the ASUS gaming 
laptop to cart", + start_url="https://store.example.com/search?q=gaming+laptop", + page_context=markdown_content, + ) + assert "ASUS ROG Gaming Laptop" in user_prompt + assert "Search Results" in user_prompt + + # --------------------------------------------------------------------------- # Test build_executor_prompt # --------------------------------------------------------------------------- @@ -75,6 +129,7 @@ def test_includes_input_text_when_provided(self) -> None: intent=None, compact_context="167|searchbox|Search|100|1|0|-|0|", input_text="Logitech mouse", + action_type="TYPE_AND_SUBMIT", ) assert 'Text to type: "Logitech mouse"' in user_prompt @@ -95,10 +150,63 @@ def test_includes_both_intent_and_input(self) -> None: intent="search_box", compact_context="100|searchbox|Search|100|1|0|-|0|", input_text="laptop", + action_type="TYPE_AND_SUBMIT", ) assert "Intent: search_box" in user_prompt assert 'Text to type: "laptop"' in user_prompt + def test_text_matching_prompt_for_matching_intent(self) -> None: + """Should include text matching hints when intent mentions matching.""" + sys_prompt, user_prompt = build_executor_prompt( + goal="Click category for tablecloth", + intent="Click element with text matching tablecloth", + compact_context="100|link|Tablecloths|100|1|0|-|0|\n101|link|Kitchen|100|1|0|-|0|", + action_type="CLICK", + ) + # Should have matching instructions and allow NONE response + assert "matching" in sys_prompt.lower() + assert "NONE" in sys_prompt + assert "Intent: Click element with text matching tablecloth" in user_prompt + + def test_search_icon_prompt_for_search_intent(self) -> None: + """Should include search icon hints for search-related intents.""" + sys_prompt, user_prompt = build_executor_prompt( + goal="Open search", + intent="Click search icon", + compact_context="100|link|Search|100|1|0|-|0|\n101|button||100|1|0|-|0|", + ) + assert "SEARCH ICON HINTS" in sys_prompt + + def test_type_prompt_warns_about_email_fields(self) -> None: + 
"""Should warn about email/newsletter fields when typing.""" + sys_prompt, user_prompt = build_executor_prompt( + goal="Search for vinyl tablecloth", + intent="Type in search box", + compact_context="100|textbox|Search|100|1|0|-|0|\n101|textbox|Your email address|100|1|0|-|0|", + input_text="vinyl tablecloth", + action_type="TYPE_AND_SUBMIT", + ) + assert "email" in sys_prompt.lower() + assert "newsletter" in sys_prompt.lower() + assert "NONE" in sys_prompt # Should mention NONE as fallback + + def test_click_with_target_text(self) -> None: + """CLICK action with target text should use text matching prompt.""" + sys_prompt, user_prompt = build_executor_prompt( + goal="Click vinyl tablecloth product", + intent="product link", + compact_context="100|link|Kitchen Goods|100|1|0|-|0|\n101|link|Vinyl Tablecloth|100|1|0|-|0|", + input_text="Vinyl Tablecloth", + action_type="CLICK", + ) + # Should have target matching instructions, not "text to type" + assert "Target to find:" in user_prompt + assert 'Vinyl Tablecloth' in user_prompt + assert "Text to type:" not in user_prompt + # Should allow NONE response + assert "NONE" in sys_prompt + assert "matching" in sys_prompt.lower() + # --------------------------------------------------------------------------- # Test normalize_plan @@ -126,6 +234,7 @@ def test_normalizes_action_aliases(self) -> None: ("INPUT", "TYPE_AND_SUBMIT"), ("TYPE_TEXT", "TYPE_AND_SUBMIT"), ("ENTER_TEXT", "TYPE_AND_SUBMIT"), + ("EXTRACT_TEXT", "EXTRACT"), ("GOTO", "NAVIGATE"), ("GO_TO", "NAVIGATE"), ("OPEN", "NAVIGATE"), @@ -667,6 +776,19 @@ def test_custom_recovery_config(self) -> None: ) assert config.recovery.max_recovery_attempts == 3 + def test_use_page_context_default_disabled(self) -> None: + config = PlannerExecutorConfig() + assert config.use_page_context is False + assert config.page_context_max_chars == 8000 + + def test_use_page_context_can_be_enabled(self) -> None: + config = PlannerExecutorConfig(use_page_context=True) + assert 
config.use_page_context is True + + def test_page_context_max_chars_customizable(self) -> None: + config = PlannerExecutorConfig(use_page_context=True, page_context_max_chars=4000) + assert config.page_context_max_chars == 4000 + # --------------------------------------------------------------------------- # Test PlanStep with optional_substeps @@ -1245,6 +1367,754 @@ async def test_no_scroll_without_intent_heuristics(self) -> None: assert runtime.scroll_count == 0 +class TestExtractActionSupport: + """Tests for EXTRACT step execution in PlannerExecutorAgent.""" + + @pytest.mark.asyncio + async def test_execute_step_supports_extract_action(self) -> None: + """EXTRACT steps should use the SDK extraction path instead of failing.""" + from datetime import datetime + from types import SimpleNamespace + from unittest.mock import AsyncMock, MagicMock, patch + + from predicate.agents.planner_executor_agent import PlannerExecutorAgent, SnapshotContext, StepStatus + + agent = PlannerExecutorAgent( + planner=MockLLMProvider(), + executor=MockLLMProvider(), + config=PlannerExecutorConfig(), + ) + + runtime = MagicMock() + runtime.backend = SimpleNamespace(page=object()) + runtime.get_url = AsyncMock(return_value="https://news.ycombinator.com/") + runtime.record_action = AsyncMock() + runtime.goto = AsyncMock() + runtime.stabilize = AsyncMock() + # Mock read_markdown for text extraction tasks (used when keywords like "title" are detected) + runtime.read_markdown = AsyncMock(return_value="# Story 1\n## Story 2\n## Story 3") + + ctx = SnapshotContext( + snapshot=MockSnapshot([MockElement(1, "link", "Story 1")], url="https://news.ycombinator.com/"), + compact_representation="Category: extraction\nURL: https://news.ycombinator.com/\nNodes:", + screenshot_base64=None, + captured_at=datetime.now(), + limit_used=60, + ) + # Use a goal with "title" keyword to trigger markdown extraction path + step = PlanStep(id=1, goal="Identify the top 5 story titles", action="EXTRACT", verify=[]) + 
+ with patch.object(agent, "_snapshot_with_escalation", AsyncMock(return_value=ctx)): + outcome = await agent._execute_step(step, runtime, step_index=0) + + assert outcome.status == StepStatus.SUCCESS + assert outcome.verification_passed is True + assert outcome.action_taken == "EXTRACT" + # Verify read_markdown was called for text extraction task + runtime.read_markdown.assert_awaited_once() + + @pytest.mark.asyncio + async def test_extract_step_succeeds_even_if_planner_verify_is_brittle(self) -> None: + """EXTRACT success should be the primary success signal for extraction steps.""" + from datetime import datetime + from types import SimpleNamespace + from unittest.mock import AsyncMock, MagicMock, patch + + from predicate.agents.planner_executor_agent import PlannerExecutorAgent, SnapshotContext, StepStatus + + agent = PlannerExecutorAgent( + planner=MockLLMProvider(), + executor=MockLLMProvider(), + config=PlannerExecutorConfig(), + ) + + runtime = MagicMock() + runtime.backend = SimpleNamespace(page=object()) + runtime.get_url = AsyncMock(return_value="https://news.ycombinator.com/") + runtime.record_action = AsyncMock() + runtime.stabilize = AsyncMock() + # Mock read_markdown for text extraction tasks + runtime.read_markdown = AsyncMock(return_value="# Story 1\n## Story 2\n## Story 3") + + ctx = SnapshotContext( + snapshot=MockSnapshot([MockElement(1, "link", "Story 1")], url="https://news.ycombinator.com/"), + compact_representation="Category: extraction\nURL: https://news.ycombinator.com/\nNodes:", + screenshot_base64=None, + captured_at=datetime.now(), + limit_used=60, + ) + # Goal contains "title" keyword, which triggers markdown extraction + step = PlanStep( + id=1, + goal="Identify the top 5 story titles", + action="EXTRACT", + verify=[PredicateSpec(predicate="exists", args=[".storylink"])], + ) + + with patch.object(agent, "_snapshot_with_escalation", AsyncMock(return_value=ctx)): + with patch.object(agent, "_verify_step", 
AsyncMock(return_value=False)): + outcome = await agent._execute_step(step, runtime, step_index=0) + + assert outcome.status == StepStatus.SUCCESS + assert outcome.verification_passed is True + + +class TestMarkdownExtraction: + """Tests for markdown-based text extraction optimization.""" + + @pytest.mark.asyncio + async def test_text_extraction_uses_read_markdown_then_llm(self) -> None: + """EXTRACT with text extraction keywords should use read_markdown() + executor LLM.""" + from datetime import datetime + from types import SimpleNamespace + from unittest.mock import AsyncMock, MagicMock, patch + + from predicate.agents.planner_executor_agent import PlannerExecutorAgent, SnapshotContext, StepStatus + from predicate.llm_provider import LLMResponse + + # Create executor that returns extracted text + executor = MagicMock() + executor.generate = MagicMock( + return_value=LLMResponse( + content="Story Title Two", + prompt_tokens=100, + completion_tokens=10, + total_tokens=110, + model_name="test-model", + ) + ) + + agent = PlannerExecutorAgent( + planner=MockLLMProvider(), + executor=executor, + config=PlannerExecutorConfig(), + ) + + runtime = MagicMock() + runtime.backend = SimpleNamespace(page=object()) + runtime.get_url = AsyncMock(return_value="https://news.ycombinator.com/") + runtime.record_action = AsyncMock() + runtime.stabilize = AsyncMock() + # Mock read_markdown to return page content + markdown_content = "# Hacker News\n\n1. Story Title One\n2. Story Title Two\n3. 
Story Title Three" + runtime.read_markdown = AsyncMock(return_value=markdown_content) + + ctx = SnapshotContext( + snapshot=MockSnapshot([MockElement(1, "link", "Story 1")], url="https://news.ycombinator.com/"), + compact_representation="Category: extraction\nURL: https://news.ycombinator.com/\nNodes:", + screenshot_base64=None, + captured_at=datetime.now(), + limit_used=60, + ) + # Goal contains "extract" and "title" keywords - should trigger markdown extraction + step = PlanStep(id=1, goal="Extract the title of the second post", action="EXTRACT", verify=[]) + + with patch.object(agent, "_snapshot_with_escalation", AsyncMock(return_value=ctx)): + outcome = await agent._execute_step(step, runtime, step_index=0) + + assert outcome.status == StepStatus.SUCCESS + assert outcome.verification_passed is True + assert outcome.action_taken == "EXTRACT" + # Verify read_markdown was called first + runtime.read_markdown.assert_awaited_once_with(max_chars=8000) + # Verify executor LLM was called to extract specific text from markdown + executor.generate.assert_called_once() + # Check that the markdown content was passed to the executor + call_args = executor.generate.call_args + assert markdown_content in call_args[0][1] # markdown in user prompt + assert "Extract the title of the second post" in call_args[0][1] # query in prompt + + @pytest.mark.asyncio + async def test_complex_extraction_uses_llm(self) -> None: + """EXTRACT without text extraction keywords should use LLM-based extraction.""" + from datetime import datetime + from types import SimpleNamespace + from unittest.mock import AsyncMock, MagicMock, patch + + from predicate.agents.planner_executor_agent import PlannerExecutorAgent, SnapshotContext, StepStatus + + agent = PlannerExecutorAgent( + planner=MockLLMProvider(), + executor=MockLLMProvider(), + config=PlannerExecutorConfig(), + ) + + runtime = MagicMock() + runtime.backend = SimpleNamespace(page=object()) + runtime.get_url = 
AsyncMock(return_value="https://news.ycombinator.com/") + runtime.record_action = AsyncMock() + runtime.stabilize = AsyncMock() + runtime.read_markdown = AsyncMock(return_value="# Page Content") + + ctx = SnapshotContext( + snapshot=MockSnapshot([MockElement(1, "link", "Story 1")], url="https://news.ycombinator.com/"), + compact_representation="Category: extraction\nURL: https://news.ycombinator.com/\nNodes:", + screenshot_base64=None, + captured_at=datetime.now(), + limit_used=60, + ) + # Goal without extraction keywords - should use LLM extraction + step = PlanStep(id=1, goal="Analyze sentiment patterns from visible HTML", action="EXTRACT", verify=[]) + + extract_result = SimpleNamespace(ok=True, data={"sentiment": "positive"}, raw='{"sentiment": "positive"}') + + with patch.object(agent, "_snapshot_with_escalation", AsyncMock(return_value=ctx)): + with patch("predicate.read.extract_async", new=AsyncMock(return_value=extract_result)) as mock_extract: + outcome = await agent._execute_step(step, runtime, step_index=0) + + assert outcome.status == StepStatus.SUCCESS + # Verify LLM-based extraction was used, not read_markdown + mock_extract.assert_awaited_once() + runtime.read_markdown.assert_not_awaited() + + +class TestTextExtractionKeywords: + """Tests for the _is_text_extraction_task helper function.""" + + def test_extraction_keywords_match(self) -> None: + """Keywords like extract, read, parse should trigger markdown extraction.""" + from predicate.agents.planner_executor_agent import _is_text_extraction_task + + assert _is_text_extraction_task("Extract the title from the page") is True + assert _is_text_extraction_task("Read the article content") is True + assert _is_text_extraction_task("Parse the product names") is True + assert _is_text_extraction_task("Get the price of this item") is True + assert _is_text_extraction_task("What is the headline?") is True + + def test_question_keywords_match(self) -> None: + """Question words like 'what is' should trigger 
markdown extraction.""" + from predicate.agents.planner_executor_agent import _is_text_extraction_task + + assert _is_text_extraction_task("What is the title of the page?") is True + assert _is_text_extraction_task("What are the prices listed?") is True + assert _is_text_extraction_task("Show me the description") is True + assert _is_text_extraction_task("Tell me the author name") is True + + def test_plural_keywords_match(self) -> None: + """Plural forms of keywords should also match.""" + from predicate.agents.planner_executor_agent import _is_text_extraction_task + + assert _is_text_extraction_task("Get all the titles") is True + assert _is_text_extraction_task("List the prices") is True + assert _is_text_extraction_task("Find the headlines") is True + assert _is_text_extraction_task("Extract the items") is True + + def test_non_extraction_tasks_dont_match(self) -> None: + """Tasks without extraction keywords should not match.""" + from predicate.agents.planner_executor_agent import _is_text_extraction_task + + assert _is_text_extraction_task("Analyze sentiment patterns") is False + assert _is_text_extraction_task("Identify HTML structure") is False + assert _is_text_extraction_task("Determine page complexity") is False + + def test_word_boundary_prevents_false_positives(self) -> None: + """Keywords embedded in other words should not match.""" + from predicate.agents.planner_executor_agent import _is_text_extraction_task + + # "time" is a keyword but should not match inside "sentiment" + assert _is_text_extraction_task("Analyze sentiment") is False + # "get" is a keyword but should not match inside "together" + assert _is_text_extraction_task("Put together a report") is False + + +class TestExtractTokenAccounting: + """Tests for token accounting on EXTRACT steps.""" + + @pytest.mark.asyncio + async def test_extract_step_records_token_usage(self) -> None: + """EXTRACT steps should contribute to agent token stats (for LLM-based extraction).""" + from datetime 
import datetime + from types import SimpleNamespace + from unittest.mock import AsyncMock, MagicMock, patch + + from predicate.agents.planner_executor_agent import PlannerExecutorAgent, SnapshotContext + from predicate.llm_provider import LLMResponse + + agent = PlannerExecutorAgent( + planner=MockLLMProvider(), + executor=MockLLMProvider(), + config=PlannerExecutorConfig(), + ) + + runtime = MagicMock() + runtime.backend = SimpleNamespace(page=object()) + runtime.get_url = AsyncMock(return_value="https://news.ycombinator.com/") + runtime.record_action = AsyncMock() + runtime.stabilize = AsyncMock() + + ctx = SnapshotContext( + snapshot=MockSnapshot([MockElement(1, "link", "Story 1")], url="https://news.ycombinator.com/"), + compact_representation="Category: extraction\nURL: https://news.ycombinator.com/\nNodes:", + screenshot_base64=None, + captured_at=datetime.now(), + limit_used=60, + ) + # Use a goal without common extraction keywords to trigger LLM-based extraction + # (Avoids keywords like extract, read, parse, title, content, text, etc.) 
+ step = PlanStep(id=1, goal="Identify sentiment patterns from visible HTML", action="EXTRACT", verify=[]) + + extract_result = SimpleNamespace( + ok=True, + data={"text": '["Story 1"]'}, + raw='["Story 1"]', + llm_response=LLMResponse( + content='["Story 1"]', + prompt_tokens=111, + completion_tokens=22, + total_tokens=133, + model_name="gpt-4o-mini", + ), + ) + + with patch.object(agent, "_snapshot_with_escalation", AsyncMock(return_value=ctx)): + with patch("predicate.read.extract_async", new=AsyncMock(return_value=extract_result)): + await agent._execute_step(step, runtime, step_index=0) + + stats = agent.get_token_stats() + assert stats["total"]["calls"] == 1 + assert stats["total"]["total_tokens"] == 133 + assert stats["by_role"]["extract"]["calls"] == 1 + assert stats["by_role"]["extract"]["prompt_tokens"] == 111 + assert stats["by_role"]["extract"]["completion_tokens"] == 22 + assert stats["by_model"]["gpt-4o-mini"]["total_tokens"] == 133 + + +# --------------------------------------------------------------------------- +# Test TYPE_AND_SUBMIT typing delay +# --------------------------------------------------------------------------- + + +class TestTypeAndSubmitTypingDelay: + """Tests for humanized typing in TYPE_AND_SUBMIT execution.""" + + @pytest.mark.asyncio + async def test_execute_step_passes_typing_delay_to_runtime(self) -> None: + """TYPE_AND_SUBMIT should forward configured keystroke delay to runtime.type.""" + from datetime import datetime + from unittest.mock import AsyncMock, MagicMock + + from predicate.agents.planner_executor_agent import PlannerExecutorAgent, SnapshotContext + from predicate.llm_provider import LLMResponse + + agent = PlannerExecutorAgent( + planner=MockLLMProvider(), + executor=MagicMock( + generate=MagicMock( + return_value=LLMResponse(content='TYPE(1, "Thinkpad laptop")', model_name="stub") + ) + ), + config=PlannerExecutorConfig(type_delay_ms=17.0), + ) + + runtime = MagicMock() + runtime.type = AsyncMock() + 
runtime.press = AsyncMock() + runtime.stabilize = AsyncMock() + runtime.get_url = AsyncMock(return_value="https://www.amazon.com/") + + ctx = SnapshotContext( + snapshot=MockSnapshot([MockElement(1, "searchbox", "Search Amazon")], url="https://www.amazon.com/"), + compact_representation="1|searchbox|Search Amazon|100|1|0|-|0|", + screenshot_base64=None, + captured_at=datetime.now(), + limit_used=60, + ) + step = PlanStep(id=1, goal="Search for Thinkpad laptop", action="TYPE_AND_SUBMIT", input="Thinkpad laptop", verify=[]) + + agent._snapshot_with_escalation = AsyncMock(return_value=ctx) # type: ignore[method-assign] + + await agent._execute_step(step, runtime, step_index=0) + + runtime.type.assert_awaited_once_with(1, "Thinkpad laptop", delay_ms=17.0) + runtime.press.assert_awaited_once_with("Enter") + + @pytest.mark.asyncio + async def test_execute_search_type_and_submit_uses_enter_by_default(self) -> None: + """Search submissions should prefer Enter key over clicking submit button (more reliable). + + Many search boxes (e.g., lifeisgood.com) don't work correctly when clicking + the submit button - they may navigate to a category page instead of searching. + Pressing Enter is more reliable for search inputs, matching WebBench behavior. 
+ """ + from datetime import datetime + from unittest.mock import AsyncMock, MagicMock + + from predicate.agents.planner_executor_agent import PlannerExecutorAgent, SnapshotContext + from predicate.llm_provider import LLMResponse + + submit_button = MockElement(110, "button", "Submit Search") + submit_button.aria_label = "Submit Search" + + agent = PlannerExecutorAgent( + planner=MockLLMProvider(), + executor=MagicMock( + generate=MagicMock( + return_value=LLMResponse(content='TYPE(105, "Rainbow trout trucker hat")', model_name="stub") + ) + ), + config=PlannerExecutorConfig(), + ) + + runtime = MagicMock() + runtime.type = AsyncMock() + runtime.click = AsyncMock() + runtime.press = AsyncMock() + runtime.goto = AsyncMock() + runtime.stabilize = AsyncMock() + runtime.record_action = AsyncMock() + runtime.get_url = AsyncMock(return_value="https://lifeisgood.com/search?q=Rainbow+trout+trucker+hat") + + ctx = SnapshotContext( + snapshot=MockSnapshot( + [ + MockElement(105, "searchbox", "Search for cat tees"), + submit_button, + ], + url="https://lifeisgood.com/", + ), + compact_representation="\n".join( + [ + "105|searchbox|Search for cat tees|1177|0||0|||0|", + "110|button|Submit Search|308|0||0|||0|", + ] + ), + screenshot_base64=None, + captured_at=datetime.now(), + limit_used=60, + ) + step = PlanStep(id=1, goal="Search for the product", action="TYPE_AND_SUBMIT", input="Rainbow trout trucker hat", verify=[]) + + agent._snapshot_with_escalation = AsyncMock(return_value=ctx) # type: ignore[method-assign] + + await agent._execute_step(step, runtime, step_index=0) + + runtime.type.assert_awaited_once() + # Now we expect Enter to be pressed (not click), matching WebBench behavior + runtime.press.assert_awaited_once_with("Enter") + runtime.click.assert_not_awaited() + + @pytest.mark.asyncio + async def test_search_type_and_submit_does_not_accept_unrelated_url_change(self) -> None: + """Searchbox submissions should not treat arbitrary collection redirects as success.""" + 
from datetime import datetime + from unittest.mock import AsyncMock, MagicMock + + from predicate.agents.planner_executor_agent import PlannerExecutorAgent, SnapshotContext, StepStatus + from predicate.llm_provider import LLMResponse + + agent = PlannerExecutorAgent( + planner=MockLLMProvider(), + executor=MagicMock( + generate=MagicMock( + return_value=LLMResponse(content='TYPE(105, "Rainbow trout trucker hat")', model_name="stub") + ) + ), + config=PlannerExecutorConfig(), + ) + + runtime = MagicMock() + runtime.type = AsyncMock() + runtime.click = AsyncMock() + runtime.press = AsyncMock() + runtime.goto = AsyncMock() + runtime.stabilize = AsyncMock() + runtime.record_action = AsyncMock() + url_reads = [ + "https://lifeisgood.com/", + "https://lifeisgood.com/collections/hats", + "https://lifeisgood.com/collections/hats", + ] + + async def _next_url() -> str: + if url_reads: + return url_reads.pop(0) + return "https://lifeisgood.com/collections/hats" + + runtime.get_url = AsyncMock(side_effect=_next_url) + + ctx = SnapshotContext( + snapshot=MockSnapshot([MockElement(105, "searchbox", "Search for cat tees")], url="https://lifeisgood.com/"), + compact_representation="105|searchbox|Search for cat tees|1177|0||0|||0|", + screenshot_base64=None, + captured_at=datetime.now(), + limit_used=60, + ) + step = PlanStep( + id=1, + goal="Search for the product", + action="TYPE_AND_SUBMIT", + input="Rainbow trout trucker hat", + verify=[PredicateSpec(predicate="url_contains", args=["search"])], + ) + + agent._snapshot_with_escalation = AsyncMock(return_value=ctx) # type: ignore[method-assign] + agent._verify_step = AsyncMock(return_value=False) # type: ignore[method-assign] + + outcome = await agent._execute_step(step, runtime, step_index=0) + + assert outcome.status == StepStatus.FAILED + assert outcome.verification_passed is False + assert runtime.goto.await_count == 0 + + @pytest.mark.asyncio + async def test_search_type_and_submit_retries_with_alternate_submit_method(self) 
-> None: + """Searchbox submissions should retry once with the alternate submit method before failing. + + Since Enter is now the default, when Enter fails, retry should try clicking the submit button. + """ + from datetime import datetime + from unittest.mock import AsyncMock, MagicMock + + from predicate.agents.planner_executor_agent import PlannerExecutorAgent, SnapshotContext, StepStatus + from predicate.llm_provider import LLMResponse + + agent = PlannerExecutorAgent( + planner=MockLLMProvider(), + executor=MagicMock( + generate=MagicMock( + return_value=LLMResponse(content='TYPE(105, "Rainbow trout trucker hat")', model_name="stub") + ) + ), + config=PlannerExecutorConfig(), + ) + + runtime = MagicMock() + runtime.type = AsyncMock() + runtime.click = AsyncMock() + runtime.press = AsyncMock() + runtime.goto = AsyncMock() + runtime.stabilize = AsyncMock() + runtime.record_action = AsyncMock() + submit_button = MockElement(110, "button", "Submit Search") + submit_button.aria_label = "Submit Search" + url_reads = ["https://lifeisgood.com/"] + + async def _next_url() -> str: + if url_reads: + return url_reads.pop(0) + if runtime.press.await_count > 0 or runtime.click.await_count > 0: + return "https://lifeisgood.com/collections/hats" + return "https://lifeisgood.com/collections/hats" + + runtime.get_url = AsyncMock(side_effect=_next_url) + + ctx = SnapshotContext( + snapshot=MockSnapshot( + [MockElement(105, "searchbox", "Search for cat tees"), submit_button], + url="https://lifeisgood.com/", + ), + compact_representation="\n".join( + [ + "105|searchbox|Search for cat tees|1177|0||0|||0|", + "110|button|Submit Search|308|0||0|||0|", + ] + ), + screenshot_base64=None, + captured_at=datetime.now(), + limit_used=60, + ) + step = PlanStep( + id=1, + goal="Search for the product", + action="TYPE_AND_SUBMIT", + input="Rainbow trout trucker hat", + verify=[PredicateSpec(predicate="url_contains", args=["search"])], + ) + + agent._snapshot_with_escalation = 
AsyncMock(return_value=ctx) # type: ignore[method-assign] + agent._verify_step = AsyncMock(side_effect=[False, False]) # type: ignore[method-assign] + + outcome = await agent._execute_step(step, runtime, step_index=0) + + # Now: first attempt uses Enter, retry uses click + runtime.press.assert_any_await("Enter") + assert runtime.click.await_count >= 1 # Retry should try click + runtime.goto.assert_not_awaited() + assert outcome.status == StepStatus.FAILED + assert outcome.verification_passed is False + + @pytest.mark.asyncio + async def test_search_type_and_submit_retry_can_switch_from_enter_to_click(self) -> None: + """If Enter fails and a submit button becomes available, retry should switch to click. + + With the new behavior (Enter first), this test validates that when Enter fails, + the retry logic can find and use a submit button as a fallback. + """ + from datetime import datetime + from unittest.mock import AsyncMock, MagicMock + + from predicate.agents.planner_executor_agent import PlannerExecutorAgent, SnapshotContext, StepStatus + from predicate.llm_provider import LLMResponse + + agent = PlannerExecutorAgent( + planner=MockLLMProvider(), + executor=MagicMock( + generate=MagicMock( + return_value=LLMResponse(content='TYPE(105, "Rainbow trout trucker hat")', model_name="stub") + ) + ), + config=PlannerExecutorConfig(), + ) + + runtime = MagicMock() + runtime.type = AsyncMock() + runtime.click = AsyncMock() + runtime.press = AsyncMock() + runtime.goto = AsyncMock() + runtime.stabilize = AsyncMock() + runtime.record_action = AsyncMock() + url_reads = [ + "https://lifeisgood.com/", + "https://lifeisgood.com/collections/hats", + "https://lifeisgood.com/collections/hats", + ] + + async def _next_url() -> str: + if url_reads: + return url_reads.pop(0) + return "https://lifeisgood.com/collections/hats" + + runtime.get_url = AsyncMock(side_effect=_next_url) + + ctx = SnapshotContext( + snapshot=MockSnapshot([MockElement(105, "searchbox", "Search for cat tees")], 
url="https://lifeisgood.com/"), + compact_representation="105|searchbox|Search for cat tees|1177|0||0|||0|", + screenshot_base64=None, + captured_at=datetime.now(), + limit_used=60, + ) + step = PlanStep( + id=1, + goal="Search for the product", + action="TYPE_AND_SUBMIT", + input="Rainbow trout trucker hat", + verify=[PredicateSpec(predicate="url_contains", args=["search"])], + ) + + agent._snapshot_with_escalation = AsyncMock(return_value=ctx) # type: ignore[method-assign] + agent._verify_step = AsyncMock(side_effect=[False, False]) # type: ignore[method-assign] + # Submit button not found initially, then found on retry + agent._find_submit_button_for_type_and_submit = MagicMock(side_effect=[None, 110, 110]) # type: ignore[method-assign] + + outcome = await agent._execute_step(step, runtime, step_index=0) + + # First attempt: Enter (default) + runtime.press.assert_any_await("Enter") + # Retry: click the submit button that became available + runtime.click.assert_any_await(110) + runtime.goto.assert_not_awaited() + assert outcome.status == StepStatus.FAILED + assert outcome.verification_passed is False + + @pytest.mark.asyncio + async def test_search_type_and_submit_submits_without_retyping_if_value_already_matches(self) -> None: + """Search submissions should skip retyping when the focused input already contains the desired query.""" + from datetime import datetime + from unittest.mock import AsyncMock, MagicMock + + from predicate.agents.planner_executor_agent import PlannerExecutorAgent, SnapshotContext, StepStatus + from predicate.llm_provider import LLMResponse + + agent = PlannerExecutorAgent( + planner=MockLLMProvider(), + executor=MagicMock( + generate=MagicMock( + return_value=LLMResponse(content='TYPE(105, "Rainbow trout trucker hat")', model_name="stub") + ) + ), + config=PlannerExecutorConfig(), + ) + + runtime = MagicMock() + runtime.type = AsyncMock() + runtime.click = AsyncMock() + runtime.press = AsyncMock() + runtime.goto = AsyncMock() + 
runtime.stabilize = AsyncMock() + runtime.record_action = AsyncMock() + runtime.get_url = AsyncMock(return_value="https://lifeisgood.com/search?q=Rainbow+trout+trucker+hat") + runtime.backend = MagicMock(page=MagicMock()) + + ctx = SnapshotContext( + snapshot=MockSnapshot([MockElement(105, "searchbox", "Search for cat tees")], url="https://lifeisgood.com/"), + compact_representation="105|searchbox|Search for cat tees|1177|0||0|||0|", + screenshot_base64=None, + captured_at=datetime.now(), + limit_used=60, + ) + step = PlanStep( + id=1, + goal="Search for the product", + action="TYPE_AND_SUBMIT", + input="Rainbow trout trucker hat", + verify=[PredicateSpec(predicate="url_contains", args=["search"])], + ) + + agent._snapshot_with_escalation = AsyncMock(return_value=ctx) # type: ignore[method-assign] + agent._verify_step = AsyncMock(return_value=True) # type: ignore[method-assign] + agent._submit_if_already_typed = AsyncMock(return_value=True) # type: ignore[attr-defined,method-assign] + + outcome = await agent._execute_step(step, runtime, step_index=0) + + agent._submit_if_already_typed.assert_awaited_once() # type: ignore[attr-defined] + runtime.type.assert_not_awaited() + assert outcome.status == StepStatus.SUCCESS + assert outcome.verification_passed is True + + @pytest.mark.asyncio + async def test_search_retry_uses_browser_clear_and_type_helper_when_available(self) -> None: + """Search retry should prefer the browser-backed clear/type helper over generic select-all fallback. + + With the new behavior where Enter is the default, first_submit_method="enter" reflects + that the initial submission used Enter key. 
+ """ + from unittest.mock import AsyncMock, MagicMock + + from predicate.agents.planner_executor_agent import PlannerExecutorAgent, SearchSubmitTelemetry + + agent = PlannerExecutorAgent( + planner=MockLLMProvider(), + executor=MagicMock(), + config=PlannerExecutorConfig(), + ) + + runtime = MagicMock() + runtime.type = AsyncMock() + runtime.click = AsyncMock() + runtime.press = AsyncMock() + runtime.backend = MagicMock(page=MagicMock()) + + # Now Enter is the default first method + telemetry = SearchSubmitTelemetry(first_submit_method="enter") + step = PlanStep( + id=1, + goal="Search for the product", + action="TYPE_AND_SUBMIT", + input="Rainbow trout trucker hat", + verify=[PredicateSpec(predicate="url_contains", args=["search"])], + ) + typed_element = MockElement(105, "searchbox", "Search for cat tees") + + agent._clear_and_type_search_input = AsyncMock(return_value=True) # type: ignore[attr-defined,method-assign] + agent._submit_type_and_submit = AsyncMock() # type: ignore[method-assign] + agent._verify_step = AsyncMock(return_value=False) # type: ignore[method-assign] + + await agent._retry_search_widget_submission( + runtime=runtime, + elements=[typed_element, MockElement(110, "button", "Submit Search")], + input_element_id=105, + step=step, + text="Rainbow trout trucker hat", + pre_url="https://lifeisgood.com/", + typed_element=typed_element, + telemetry=telemetry, + ) + + agent._clear_and_type_search_input.assert_awaited_once() # type: ignore[attr-defined] + runtime.type.assert_not_awaited() + # Submit is handled by _submit_type_and_submit, which is mocked + # No direct press calls in this test path + # --------------------------------------------------------------------------- # Test AuthBoundaryConfig # --------------------------------------------------------------------------- diff --git a/tests/unit/test_planner_executor_checkout_continuation.py b/tests/unit/test_planner_executor_checkout_continuation.py index 34adfbf..b8d7aa4 100644 --- 
a/tests/unit/test_planner_executor_checkout_continuation.py +++ b/tests/unit/test_planner_executor_checkout_continuation.py @@ -261,6 +261,25 @@ def test_does_not_skip_for_non_checkout_elements(self) -> None: should_skip = any(pattern in text.lower() for pattern in checkout_patterns) assert should_skip is False, f"Incorrectly matched: {text}" + def test_does_not_treat_global_nav_cart_link_as_drawer_checkout(self) -> None: + """Top-nav Amazon cart links should not suppress drawer dismissal.""" + from types import SimpleNamespace + + mock_planner = MagicMock() + mock_executor = MagicMock() + agent = PlannerExecutorAgent(planner=mock_planner, executor=mock_executor) + + nav_cart = SimpleNamespace( + role="link", + text="Cart", + aria_label="0 items in cart", + href="https://www.amazon.com/gp/cart/view.html?ref_=nav_cart", + doc_y=24.0, + layout=SimpleNamespace(region="header"), + ) + + assert agent._is_global_nav_cart_link(nav_cart) is True + # --------------------------------------------------------------------------- # Test Build Executor Prompt Improvements @@ -279,6 +298,7 @@ def test_type_action_specifies_input_element(self) -> None: intent=None, compact_context="100|textbox|Search|500|1|0|-|0|\n200|button|Submit|300|1|0|-|0|", input_text="laptop", + action_type="TYPE_AND_SUBMIT", # Must specify action_type to get TYPE prompt ) # System prompt should mention INPUT element explicitly @@ -357,6 +377,36 @@ def test_handles_multiple_think_tags(self) -> None: assert action_type == "CLICK" assert args == [100] + def test_parses_none_response(self) -> None: + """Should parse NONE response when executor can't find suitable element.""" + from predicate.agents.planner_executor_agent import PlannerExecutorAgent + from predicate.llm_provider import LLMProvider + + mock_planner = MagicMock(spec=LLMProvider) + mock_executor = MagicMock(spec=LLMProvider) + agent = PlannerExecutorAgent(planner=mock_planner, executor=mock_executor) + + # Test NONE response + text = "NONE" + 
action_type, args = agent._parse_action(text) + assert action_type == "NONE" + assert args == [] + + def test_parses_none_with_explanation(self) -> None: + """Should parse NONE even with additional text.""" + from predicate.agents.planner_executor_agent import PlannerExecutorAgent + from predicate.llm_provider import LLMProvider + + mock_planner = MagicMock(spec=LLMProvider) + mock_executor = MagicMock(spec=LLMProvider) + agent = PlannerExecutorAgent(planner=mock_planner, executor=mock_executor) + + # Test NONE with trailing text + text = "NONE - no search box found" + action_type, args = agent._parse_action(text) + assert action_type == "NONE" + assert args == [] + # --------------------------------------------------------------------------- # Test Overlay Dismiss Intent Detection diff --git a/tests/unit/test_planner_executor_pruning.py b/tests/unit/test_planner_executor_pruning.py new file mode 100644 index 0000000..281c0e8 --- /dev/null +++ b/tests/unit/test_planner_executor_pruning.py @@ -0,0 +1,218 @@ +""" +Unit tests for PlannerExecutorAgent pruning integration. 
+""" + +from __future__ import annotations + +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from predicate.agents.automation_task import AutomationTask, TaskCategory +from predicate.agents.planner_executor_agent import ( + PlannerExecutorAgent, + PlannerExecutorConfig, + SnapshotEscalationConfig, +) +from predicate.models import BBox, Element, Snapshot, VisualCues +from predicate.pruning import ( + PruningTaskCategory, + prune_snapshot_for_task, + prune_with_recovery, +) + + +def make_element( + *, + id: int, + role: str, + text: str, + importance: int, + href: str | None = None, + in_dominant_group: bool | None = None, +) -> Element: + return Element( + id=id, + role=role, + text=text, + importance=importance, + href=href, + in_dominant_group=in_dominant_group, + bbox=BBox(x=0, y=0, width=100, height=20), + visual_cues=VisualCues(is_primary=False, is_clickable=role in {"button", "link"}), + ) + + +def make_snapshot(elements: list[Element]) -> Snapshot: + return Snapshot(status="success", url="https://shop.example.com", elements=elements) + + +class TestPlannerExecutorPruningIntegration: + """Tests for pruning-aware context formatting and escalation.""" + + def test_format_context_uses_pruned_snapshot_when_task_category_is_known(self) -> None: + """Pruned context should include category and prioritize relevant elements.""" + agent = PlannerExecutorAgent( + planner=MagicMock(), + executor=MagicMock(), + ) + agent._current_task = AutomationTask( + task_id="shopping-1", + starting_url="https://shop.example.com", + task="add the product to cart", + category=TaskCategory.TRANSACTION, + domain_hints=("ecommerce",), + ) + snap = make_snapshot( + [ + make_element(id=1, role="button", text="Add to Cart", importance=950, in_dominant_group=True), + make_element(id=2, role="link", text="Privacy Policy", importance=50, href="/privacy"), + ] + ) + + result = agent._format_context(snap, "add the product to cart") + + # Should include category and the Add to 
Cart button + assert "Category: shopping" in result + assert "[1] button" in result + assert "Add to Cart" in result + + def test_format_context_excludes_blocked_elements_at_strict_level(self) -> None: + """At relaxation level 0, common footer links should be blocked.""" + snap = make_snapshot( + [ + make_element(id=1, role="button", text="Add to Cart", importance=950, in_dominant_group=True), + make_element(id=2, role="link", text="Privacy Policy", importance=50, href="/privacy"), + make_element(id=3, role="link", text="Terms of Service", importance=50, href="/terms"), + ] + ) + + # Test strict pruning (level 0) + ctx = prune_snapshot_for_task( + snap, + goal="add the product to cart", + category=PruningTaskCategory.SHOPPING, + relaxation_level=0, + ) + + # Privacy Policy and Terms should be blocked at level 0 + node_texts = [n.text for n in ctx.nodes] + assert "Add to Cart" in node_texts + assert "Privacy Policy" not in node_texts + assert "Terms of Service" not in node_texts + + +class TestPruningRecovery: + """Tests for over-pruning recovery via relaxation levels.""" + + def test_relaxation_increases_node_count(self) -> None: + """Higher relaxation levels should allow more elements.""" + snap = make_snapshot( + [ + make_element(id=1, role="button", text="Add to Cart", importance=950, in_dominant_group=True), + make_element(id=2, role="link", text="Privacy Policy", importance=50, href="/privacy"), + make_element(id=3, role="link", text="Terms of Service", importance=50, href="/terms"), + make_element(id=4, role="button", text="Close", importance=100), + ] + ) + + ctx_strict = prune_snapshot_for_task( + snap, + goal="add the product to cart", + category=PruningTaskCategory.SHOPPING, + relaxation_level=0, + ) + + ctx_relaxed = prune_snapshot_for_task( + snap, + goal="add the product to cart", + category=PruningTaskCategory.SHOPPING, + relaxation_level=2, + ) + + # Relaxed should have more elements + assert len(ctx_relaxed.nodes) >= len(ctx_strict.nodes) + + def 
test_prune_with_recovery_auto_relaxes(self) -> None: + """prune_with_recovery should auto-relax if initial pruning is sparse.""" + # Create a snapshot with very few matching elements at level 0 + snap = make_snapshot( + [ + make_element(id=1, role="heading", text="Welcome", importance=100), + make_element(id=2, role="paragraph", text="Some text", importance=50), + make_element(id=3, role="button", text="OK", importance=200), + ] + ) + + ctx = prune_with_recovery( + snap, + goal="find something", + category=PruningTaskCategory.SHOPPING, + max_relaxation=3, + verbose=False, + ) + + # Should have relaxed to include the button + assert ctx.relaxation_level > 0 or len(ctx.nodes) >= 1 + + def test_pruned_context_includes_metadata(self) -> None: + """PrunedSnapshotContext should include element count metadata.""" + snap = make_snapshot( + [ + make_element(id=1, role="button", text="Add to Cart", importance=950, in_dominant_group=True), + make_element(id=2, role="link", text="Product A", importance=800, href="/a", in_dominant_group=True), + make_element(id=3, role="link", text="Product B", importance=700, href="/b", in_dominant_group=True), + ] + ) + + ctx = prune_snapshot_for_task( + snap, + goal="add to cart", + category=PruningTaskCategory.SHOPPING, + ) + + assert ctx.raw_element_count == 3 + assert ctx.pruned_element_count == len(ctx.nodes) + assert ctx.relaxation_level == 0 + + +class TestCategorySpecificExecutorHints: + """Tests for category-specific hints in executor prompts.""" + + def test_shopping_category_hints(self) -> None: + """Shopping category should provide relevant hints.""" + from predicate.agents.planner_executor_agent import _get_category_executor_hints + + hints = _get_category_executor_hints("shopping") + assert "Add to Cart" in hints + assert "Buy Now" in hints + + def test_form_filling_category_hints(self) -> None: + """Form filling category should provide relevant hints.""" + from predicate.agents.planner_executor_agent import 
_get_category_executor_hints + + hints = _get_category_executor_hints("form_filling") + assert "input" in hints.lower() + assert "submit" in hints.lower() + + def test_search_category_hints(self) -> None: + """Search category should provide relevant hints.""" + from predicate.agents.planner_executor_agent import _get_category_executor_hints + + hints = _get_category_executor_hints("search") + assert "search" in hints.lower() + assert "result" in hints.lower() + + def test_unknown_category_returns_empty(self) -> None: + """Unknown categories should return empty hints.""" + from predicate.agents.planner_executor_agent import _get_category_executor_hints + + hints = _get_category_executor_hints("unknown_category") + assert hints == "" + + def test_none_category_returns_empty(self) -> None: + """None category should return empty hints.""" + from predicate.agents.planner_executor_agent import _get_category_executor_hints + + hints = _get_category_executor_hints(None) + assert hints == "" diff --git a/tests/unit/test_pruning_policies.py b/tests/unit/test_pruning_policies.py new file mode 100644 index 0000000..6f3b097 --- /dev/null +++ b/tests/unit/test_pruning_policies.py @@ -0,0 +1,164 @@ +""" +Unit tests for deterministic pruning policies and serializer output. 
+""" + +from predicate.models import BBox, Element, Snapshot, VisualCues +from predicate.pruning import PruningTaskCategory +from predicate.pruning.pruner import prune_snapshot_for_task +from predicate.pruning.serializer import serialize_pruned_snapshot + + +def make_element( + *, + id: int, + role: str, + text: str, + importance: int, + doc_y: float = 0.0, + in_dominant_group: bool | None = None, + href: str | None = None, + nearby_text: str | None = None, +) -> Element: + return Element( + id=id, + role=role, + text=text, + importance=importance, + bbox=BBox(x=0, y=doc_y, width=100, height=24), + visual_cues=VisualCues(is_primary=False, is_clickable=role in {"button", "link", "textbox", "searchbox"}), + doc_y=doc_y, + in_dominant_group=in_dominant_group, + href=href, + nearby_text=nearby_text, + ) + + +def make_snapshot(elements: list[Element]) -> Snapshot: + return Snapshot(status="success", url="https://example.com", elements=elements) + + +class TestPruningPolicies: + """Tests for category-specific pruning behavior.""" + + def test_shopping_policy_keeps_price_and_add_to_cart(self) -> None: + snap = make_snapshot( + [ + make_element( + id=1, + role="link", + text="Rainbow Trout Trucker", + importance=900, + doc_y=100, + in_dominant_group=True, + href="/product/hat", + ), + make_element( + id=2, + role="text", + text="$32.50", + importance=850, + doc_y=120, + in_dominant_group=True, + nearby_text="Price", + ), + make_element( + id=3, + role="button", + text="Add to Cart", + importance=950, + doc_y=140, + in_dominant_group=True, + ), + make_element( + id=4, + role="link", + text="Privacy Policy", + importance=100, + doc_y=900, + href="/privacy", + ), + ] + ) + + ctx = prune_snapshot_for_task( + snap, + goal="add the product to cart", + category=PruningTaskCategory.SHOPPING, + ) + + kept_ids = {node.id for node in ctx.nodes} + assert {1, 2, 3}.issubset(kept_ids) + assert 4 not in kept_ids + + def test_form_filling_policy_keeps_inputs_and_submit(self) -> None: + 
snap = make_snapshot( + [ + make_element(id=10, role="textbox", text="Email", importance=850, doc_y=100), + make_element(id=11, role="textbox", text="Message", importance=800, doc_y=140), + make_element(id=12, role="button", text="Submit", importance=900, doc_y=180), + make_element(id=13, role="link", text="Company Blog", importance=200, doc_y=800, href="/blog"), + ] + ) + + ctx = prune_snapshot_for_task( + snap, + goal="fill out the contact form and submit it", + category=PruningTaskCategory.FORM_FILLING, + ) + + kept_ids = {node.id for node in ctx.nodes} + assert {10, 11, 12}.issubset(kept_ids) + assert 13 not in kept_ids + + def test_search_policy_keeps_search_box_and_results(self) -> None: + snap = make_snapshot( + [ + make_element(id=20, role="searchbox", text="Search", importance=950, doc_y=50), + make_element( + id=21, + role="link", + text="Best Trail Shoes", + importance=900, + doc_y=180, + in_dominant_group=True, + href="/trail-shoes", + ), + make_element(id=22, role="text", text="Footer links", importance=50, doc_y=1200), + ] + ) + + ctx = prune_snapshot_for_task( + snap, + goal="search for trail shoes and open the best result", + category=PruningTaskCategory.SEARCH, + ) + + kept_ids = {node.id for node in ctx.nodes} + assert 20 in kept_ids + assert 21 in kept_ids + assert 22 not in kept_ids + + def test_serializer_outputs_compact_skeleton_dom(self) -> None: + snap = make_snapshot( + [ + make_element( + id=30, + role="button", + text="Checkout", + importance=900, + doc_y=200, + in_dominant_group=True, + ) + ] + ) + + ctx = prune_snapshot_for_task( + snap, + goal="go to checkout", + category=PruningTaskCategory.CHECKOUT, + ) + result = serialize_pruned_snapshot(ctx) + + assert "Category: checkout" in result + assert "URL: https://example.com" in result + assert "[30] button text=\"Checkout\"" in result