From 2cc62f5101364e06c4865484e337ea878e051d9b Mon Sep 17 00:00:00 2001 From: Nicholas Clegg Date: Fri, 20 Feb 2026 11:12:55 -0500 Subject: [PATCH 01/29] refactor(plugins)!: convert Plugin from Protocol to ABC MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit BREAKING CHANGE: Plugin is now an abstract base class instead of a Protocol. Plugins must explicitly inherit from Plugin and implement the abstract name property and init_plugin method. - Convert Plugin from @runtime_checkable Protocol to ABC - Make name an abstract property - Make init_plugin an abstract method - Update all tests to use inheritance - Maintain support for both sync and async init_plugin - All tests pass (1985 passed) 🤖 Assisted by the code-assist SOP --- src/strands/plugins/plugin.py | 21 +++++--- tests/strands/plugins/test_plugins.py | 70 +++++++++++++-------------- 2 files changed, 48 insertions(+), 43 deletions(-) diff --git a/src/strands/plugins/plugin.py b/src/strands/plugins/plugin.py index b6a8fd1d9..80707616a 100644 --- a/src/strands/plugins/plugin.py +++ b/src/strands/plugins/plugin.py @@ -1,19 +1,19 @@ -"""Plugin protocol for extending agent functionality. +"""Plugin base class for extending agent functionality. -This module defines the Plugin Protocol, which provides a composable way to +This module defines the Plugin base class, which provides a composable way to add behavior changes to agents through a standardized initialization pattern. """ +from abc import ABC, abstractmethod from collections.abc import Awaitable -from typing import TYPE_CHECKING, Protocol, runtime_checkable +from typing import TYPE_CHECKING if TYPE_CHECKING: from ..agent import Agent -@runtime_checkable -class Plugin(Protocol): - """Protocol for objects that extend agent functionality. +class Plugin(ABC): + """Base class for objects that extend agent functionality. Plugins provide a composable way to add behavior changes to agents. They are initialized with an agent instance and can register hooks, @@ -24,7 +24,7 @@ class Plugin(Protocol): Example: ```python - class MyPlugin: + class MyPlugin(Plugin): name = "my-plugin" def init_plugin(self, agent: Agent) -> None: @@ -32,8 +32,13 @@ def init_plugin(self, agent: Agent) -> None: ``` """ - name: str + @property + @abstractmethod + def name(self) -> str: + """A stable string identifier for the plugin.""" + ... + @abstractmethod def init_plugin(self, agent: "Agent") -> None | Awaitable[None]: """Initialize the plugin with an agent instance. diff --git a/tests/strands/plugins/test_plugins.py b/tests/strands/plugins/test_plugins.py index 90f6a2545..9274d2f12 100644 --- a/tests/strands/plugins/test_plugins.py +++ b/tests/strands/plugins/test_plugins.py @@ -10,10 +10,10 @@ # Plugin Protocol Tests -def test_plugin_protocol_is_runtime_checkable(): - """Test that Plugin Protocol is runtime checkable with isinstance.""" +def test_plugin_class_requires_inheritance(): + """Test that Plugin class requires inheritance.""" - class MyPlugin: + class MyPlugin(Plugin): name = "my-plugin" def init_plugin(self, agent): @@ -23,10 +23,10 @@ def init_plugin(self, agent): assert isinstance(plugin, Plugin) -def test_plugin_protocol_sync_implementation(): - """Test Plugin Protocol works with synchronous init_plugin.""" +def test_plugin_class_sync_implementation(): + """Test Plugin class works with synchronous init_plugin.""" - class SyncPlugin: + class SyncPlugin(Plugin): name = "sync-plugin" def init_plugin(self, agent): @@ -35,7 +35,7 @@ def init_plugin(self, agent): plugin = SyncPlugin() mock_agent = unittest.mock.Mock() - # Verify the plugin matches the protocol + # Verify the plugin is an instance of Plugin assert isinstance(plugin, Plugin) assert plugin.name == "sync-plugin" @@ -45,10 +45,10 @@ def init_plugin(self, agent): @pytest.mark.asyncio -async def test_plugin_protocol_async_implementation(): - """Test Plugin Protocol works with asynchronous init_plugin.""" +async def test_plugin_class_async_implementation(): + """Test Plugin class works with asynchronous init_plugin.""" - class AsyncPlugin: + class AsyncPlugin(Plugin): name = "async-plugin" async def init_plugin(self, agent): @@ -57,7 +57,7 @@ async def init_plugin(self, agent): plugin = AsyncPlugin() mock_agent = unittest.mock.Mock() - # Verify the plugin matches the protocol + # Verify the plugin is an instance of Plugin assert isinstance(plugin, Plugin) assert plugin.name == "async-plugin" @@ -66,33 +66,33 @@ async def init_plugin(self, agent): assert mock_agent.custom_attribute == "initialized by async plugin" -def test_plugin_protocol_requires_name(): - """Test that Plugin Protocol requires a name property.""" +def test_plugin_class_requires_name(): + """Test that Plugin class requires a name property.""" - class PluginWithoutName: - def init_plugin(self, agent): - pass + with pytest.raises(TypeError, match="Can't instantiate abstract class"): + + class PluginWithoutName(Plugin): + def init_plugin(self, agent): + pass + + PluginWithoutName() - plugin = PluginWithoutName() - # A class without 'name' should not pass isinstance check - assert not isinstance(plugin, Plugin) +def test_plugin_class_requires_init_plugin_method(): + """Test that Plugin class requires an init_plugin method.""" -def test_plugin_protocol_requires_init_plugin_method(): - """Test that Plugin Protocol requires an init_plugin method.""" + with pytest.raises(TypeError, match="Can't instantiate abstract class"): - class PluginWithoutInitPlugin: - name = "incomplete-plugin" + class PluginWithoutInitPlugin(Plugin): + name = "incomplete-plugin" - plugin = PluginWithoutInitPlugin() - # A class without 'init_plugin' should not pass isinstance check - assert not isinstance(plugin, Plugin) + PluginWithoutInitPlugin() -def test_plugin_protocol_with_class_attribute_name(): - """Test Plugin Protocol works when name is a class attribute.""" +def test_plugin_class_with_class_attribute_name(): + """Test Plugin class works when name is a class attribute.""" - class PluginWithClassAttribute: + class PluginWithClassAttribute(Plugin): name: str = "class-attr-plugin" def init_plugin(self, agent): @@ -103,10 +103,10 @@ def init_plugin(self, agent): assert plugin.name == "class-attr-plugin" -def test_plugin_protocol_with_property_name(): - """Test Plugin Protocol works when name is a property.""" +def test_plugin_class_with_property_name(): + """Test Plugin class works when name is a property.""" - class PluginWithProperty: + class PluginWithProperty(Plugin): @property def name(self): return "property-plugin" @@ -137,7 +137,7 @@ def registry(mock_agent): def test_plugin_registry_add_and_init_calls_init_plugin(registry, mock_agent): """Test adding a plugin calls its init_plugin method.""" - class TestPlugin: + class TestPlugin(Plugin): name = "test-plugin" def __init__(self): @@ -157,7 +157,7 @@ def init_plugin(self, agent): def test_plugin_registry_add_duplicate_raises_error(registry, mock_agent): """Test that adding a duplicate plugin raises an error.""" - class TestPlugin: + class TestPlugin(Plugin): name = "test-plugin" def init_plugin(self, agent): @@ -175,7 +175,7 @@ def init_plugin(self, agent): def test_plugin_registry_add_and_init_with_async_plugin(registry, mock_agent): """Test that add_and_init handles async plugins using run_async.""" - class AsyncPlugin: + class AsyncPlugin(Plugin): name = "async-plugin" def __init__(self): From 6669482b2810f66b8263cbd009666e2f1825edd5 Mon Sep 17 00:00:00 2001 From: Nicholas Clegg Date: Fri, 20 Feb 2026 11:22:17 -0500 Subject: [PATCH 02/29] Add missing docstring updated --- src/strands/plugins/__init__.py | 2 +- src/strands/plugins/registry.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/strands/plugins/__init__.py b/src/strands/plugins/__init__.py index 33922e952..9ec9c9357 100644 --- a/src/strands/plugins/__init__.py +++ b/src/strands/plugins/__init__.py @@ -7,7 +7,7 @@ ```python from strands.plugins import Plugin - class LoggingPlugin: + class LoggingPlugin(Plugin): name = "logging" def init_plugin(self, agent: Agent) -> None: diff --git a/src/strands/plugins/registry.py b/src/strands/plugins/registry.py index ffd73c2f4..34a7a6639 100644 --- a/src/strands/plugins/registry.py +++ b/src/strands/plugins/registry.py @@ -28,7 +28,7 @@ class _PluginRegistry: ```python registry = _PluginRegistry(agent) - class MyPlugin: + class MyPlugin(Plugin): name = "my-plugin" def init_plugin(self, agent: Agent) -> None: From 0eb1a415f8eaf61b2422e4727b45dba4760ef417 Mon Sep 17 00:00:00 2001 From: Containerized Agent Date: Fri, 20 Feb 2026 16:16:49 +0000 Subject: [PATCH 03/29] feat(plugins): add SkillsPlugin for AgentSkills.io integration Implement the SkillsPlugin that adds AgentSkills.io skill support to the Strands Agents SDK. The plugin enables progressive disclosure of skill instructions: metadata is injected into the system prompt upfront, and full instructions are loaded on demand via a tool. Key components: - Skill dataclass with from_path classmethod for loading from SKILL.md - Loader module for discovering, parsing, and validating skills - SkillsPlugin extending the Plugin ABC with: - skills tool (activate/deactivate actions) - BeforeInvocationEvent hook for system prompt injection - AfterInvocationEvent hook for prompt restoration - Single active skill management - Dynamic skill management via property setter - Session persistence via agent.state Files added: - src/strands/plugins/skills/__init__.py - src/strands/plugins/skills/skill.py - src/strands/plugins/skills/loader.py - src/strands/plugins/skills/skills_plugin.py - tests/strands/plugins/skills/ (90 tests) Files modified: - src/strands/plugins/__init__.py (added SkillsPlugin export) - src/strands/__init__.py (added Skill to top-level exports) --- src/strands/__init__.py | 2 + src/strands/plugins/__init__.py | 2 + src/strands/plugins/skills/__init__.py | 27 + src/strands/plugins/skills/loader.py | 290 ++++++++++ src/strands/plugins/skills/skill.py | 61 ++ src/strands/plugins/skills/skills_plugin.py | 328 +++++++++++ tests/strands/plugins/skills/__init__.py | 1 + tests/strands/plugins/skills/test_loader.py | 319 +++++++++++ tests/strands/plugins/skills/test_skill.py | 73 +++ .../plugins/skills/test_skills_plugin.py | 531 ++++++++++++++++++ 10 files changed, 1634 insertions(+) create mode 100644 src/strands/plugins/skills/__init__.py create mode 100644 src/strands/plugins/skills/loader.py create mode 100644 src/strands/plugins/skills/skill.py create mode 100644 src/strands/plugins/skills/skills_plugin.py create mode 100644 tests/strands/plugins/skills/__init__.py create mode 100644 tests/strands/plugins/skills/test_loader.py create mode 100644 tests/strands/plugins/skills/test_skill.py create mode 100644 tests/strands/plugins/skills/test_skills_plugin.py diff --git a/src/strands/__init__.py b/src/strands/__init__.py index be939d5b1..fc8237df8 100644 --- a/src/strands/__init__.py +++ b/src/strands/__init__.py @@ -5,6 +5,7 @@ from .agent.base import AgentBase from .event_loop._retry import ModelRetryStrategy from .plugins import Plugin +from .plugins.skills import Skill from .tools.decorator import tool from .types.tools import ToolContext @@ -15,6 +16,7 @@ "models", "ModelRetryStrategy", "Plugin", + "Skill", "tool", "ToolContext", "types", diff --git a/src/strands/plugins/__init__.py b/src/strands/plugins/__init__.py index 9ec9c9357..51e014177 100644 --- a/src/strands/plugins/__init__.py +++ b/src/strands/plugins/__init__.py @@ -19,7 +19,9 @@ def on_model_call(self, event: BeforeModelCallEvent) -> None: """ from .plugin import Plugin +from .skills import SkillsPlugin __all__ = [ "Plugin", + "SkillsPlugin", ] diff --git a/src/strands/plugins/skills/__init__.py b/src/strands/plugins/skills/__init__.py new file mode 100644 index 000000000..60ada586c --- /dev/null +++ b/src/strands/plugins/skills/__init__.py @@ -0,0 +1,27 @@ +"""AgentSkills.io integration for Strands Agents. + +This module provides the SkillsPlugin for integrating AgentSkills.io skills +into Strands agents. Skills enable progressive disclosure of instructions: +metadata is injected into the system prompt upfront, and full instructions +are loaded on demand via a tool. + +Example Usage: + ```python + from strands import Agent + from strands.plugins.skills import Skill, SkillsPlugin + + plugin = SkillsPlugin(skills=["./skills/pdf-processing"]) + agent = Agent(plugins=[plugin]) + ``` +""" + +from .loader import load_skill, load_skills +from .skill import Skill +from .skills_plugin import SkillsPlugin + +__all__ = [ + "Skill", + "SkillsPlugin", + "load_skill", + "load_skills", +] diff --git a/src/strands/plugins/skills/loader.py b/src/strands/plugins/skills/loader.py new file mode 100644 index 000000000..fa05a4df2 --- /dev/null +++ b/src/strands/plugins/skills/loader.py @@ -0,0 +1,290 @@ +"""Skill loading and parsing utilities for AgentSkills.io skills. + +This module provides functions for discovering, parsing, and loading skills +from the filesystem. Skills are directories containing a SKILL.md file with +YAML frontmatter metadata and markdown instructions. +""" + +from __future__ import annotations + +import logging +import re +from pathlib import Path +from typing import Any + +from .skill import Skill + +logger = logging.getLogger(__name__) + +_SKILL_NAME_PATTERN = re.compile(r"^[a-z0-9]([a-z0-9-]*[a-z0-9])?$") +_MAX_SKILL_NAME_LENGTH = 64 + + +def _find_skill_md(skill_dir: Path) -> Path: + """Find the SKILL.md file in a skill directory. + + Searches for SKILL.md (case-sensitive preferred) or skill.md as a fallback. + + Args: + skill_dir: Path to the skill directory. + + Returns: + Path to the SKILL.md file. + + Raises: + FileNotFoundError: If no SKILL.md file is found in the directory. + """ + for name in ("SKILL.md", "skill.md"): + candidate = skill_dir / name + if candidate.is_file(): + return candidate + + raise FileNotFoundError(f"path=<{skill_dir}> | no SKILL.md found in skill directory") + + +def _parse_yaml(yaml_text: str) -> dict[str, Any]: + """Parse YAML text into a dictionary. + + Uses PyYAML if available, otherwise falls back to simple key-value parsing + that handles the basic SKILL.md frontmatter format. + + Args: + yaml_text: YAML-formatted text to parse. + + Returns: + Dictionary of parsed key-value pairs. + """ + try: + import yaml + + result = yaml.safe_load(yaml_text) + return result if isinstance(result, dict) else {} + except ImportError: + logger.debug("PyYAML not available, using simple frontmatter parser") + return _parse_yaml_simple(yaml_text) + + +def _parse_yaml_simple(yaml_text: str) -> dict[str, Any]: + """Simple YAML parser for skill frontmatter. + + Handles basic key-value pairs and single-level nested mappings. This parser + is intentionally limited to the subset of YAML used in SKILL.md frontmatter. + + Args: + yaml_text: YAML-formatted text to parse. + + Returns: + Dictionary of parsed key-value pairs. + """ + result: dict[str, Any] = {} + current_key: str | None = None + current_nested: dict[str, str] | None = None + + for line in yaml_text.split("\n"): + if not line.strip() or line.strip().startswith("#"): + continue + + indent = len(line) - len(line.lstrip()) + + if indent == 0 and ":" in line: + # Save previous nested mapping if any + if current_key is not None and current_nested is not None: + result[current_key] = current_nested + current_nested = None + + key, _, value = line.partition(":") + key = key.strip() + value = value.strip() + current_key = key + + if value: + result[key] = value + else: + current_nested = {} + + elif indent > 0 and current_nested is not None and ":" in line.strip(): + nested_key, _, nested_value = line.strip().partition(":") + current_nested[nested_key.strip()] = nested_value.strip() + + # Save final nested mapping + if current_key is not None and current_nested is not None: + result[current_key] = current_nested + + return result + + +def _parse_frontmatter(content: str) -> tuple[dict[str, Any], str]: + """Parse YAML frontmatter and body from SKILL.md content. + + Extracts the YAML frontmatter between ``---`` delimiters and returns + parsed key-value pairs along with the remaining markdown body. + + Args: + content: Full content of a SKILL.md file. + + Returns: + Tuple of (frontmatter_dict, body_string). + + Raises: + ValueError: If the frontmatter is malformed or missing required delimiters. + """ + stripped = content.strip() + if not stripped.startswith("---"): + raise ValueError("SKILL.md must start with --- frontmatter delimiter") + + end_idx = stripped.find("---", 3) + if end_idx == -1: + raise ValueError("SKILL.md frontmatter missing closing --- delimiter") + + frontmatter_str = stripped[3:end_idx].strip() + body = stripped[end_idx + 3 :].strip() + + frontmatter = _parse_yaml(frontmatter_str) + return frontmatter, body + + +def _validate_skill_name(name: str, dir_path: Path | None = None) -> None: + """Validate a skill name per the AgentSkills.io specification. + + Rules: + - 1-64 characters long + - Lowercase alphanumeric characters and hyphens only + - Cannot start or end with a hyphen + - No consecutive hyphens + - Must match parent directory name (if loaded from disk) + + Args: + name: The skill name to validate. + dir_path: Optional path to the skill directory for name matching. + + Raises: + ValueError: If the skill name is invalid. + """ + if not name: + raise ValueError("Skill name cannot be empty") + + if len(name) > _MAX_SKILL_NAME_LENGTH: + raise ValueError(f"name=<{name}> | skill name exceeds {_MAX_SKILL_NAME_LENGTH} character limit") + + if not _SKILL_NAME_PATTERN.match(name): + raise ValueError( + f"name=<{name}> | skill name must be 1-64 lowercase alphanumeric characters or hyphens, " + "cannot start/end with hyphen" + ) + + if "--" in name: + raise ValueError(f"name=<{name}> | skill name cannot contain consecutive hyphens") + + if dir_path is not None and dir_path.name != name: + raise ValueError(f"name=<{name}>, directory=<{dir_path.name}> | skill name must match parent directory name") + + +def load_skill(skill_path: str | Path) -> Skill: + """Load a single skill from a directory containing SKILL.md. + + Args: + skill_path: Path to the skill directory or the SKILL.md file itself. + + Returns: + A Skill instance populated from the SKILL.md file. + + Raises: + FileNotFoundError: If the path does not exist or SKILL.md is not found. + ValueError: If the skill metadata is invalid. + """ + skill_path = Path(skill_path).resolve() + + if skill_path.is_file() and skill_path.name.lower() == "skill.md": + skill_md_path = skill_path + skill_dir = skill_path.parent + elif skill_path.is_dir(): + skill_dir = skill_path + skill_md_path = _find_skill_md(skill_dir) + else: + raise FileNotFoundError(f"path=<{skill_path}> | skill path does not exist or is not a valid skill directory") + + logger.debug("path=<%s> | loading skill", skill_md_path) + + content = skill_md_path.read_text(encoding="utf-8") + frontmatter, body = _parse_frontmatter(content) + + name = frontmatter.get("name") + if not isinstance(name, str) or not name: + raise ValueError(f"path=<{skill_md_path}> | SKILL.md must have a 'name' field in frontmatter") + + description = frontmatter.get("description") + if not isinstance(description, str) or not description: + raise ValueError(f"path=<{skill_md_path}> | SKILL.md must have a 'description' field in frontmatter") + + _validate_skill_name(name, skill_dir) + + # Parse allowed-tools (space-delimited string per spec) + allowed_tools_raw = frontmatter.get("allowed-tools") or frontmatter.get("allowed_tools") + allowed_tools: list[str] | None = None + if isinstance(allowed_tools_raw, str) and allowed_tools_raw.strip(): + allowed_tools = allowed_tools_raw.strip().split() + + # Parse metadata (nested mapping) + metadata_raw = frontmatter.get("metadata", {}) + metadata: dict[str, str] = {} + if isinstance(metadata_raw, dict): + metadata = {str(k): str(v) for k, v in metadata_raw.items()} + + skill_license = frontmatter.get("license") + compatibility = frontmatter.get("compatibility") + + skill = Skill( + name=name, + description=description, + instructions=body, + path=skill_dir, + allowed_tools=allowed_tools, + metadata=metadata, + license=str(skill_license) if skill_license else None, + compatibility=str(compatibility) if compatibility else None, + ) + + logger.debug("name=<%s>, path=<%s> | skill loaded successfully", skill.name, skill.path) + return skill + + +def load_skills(skills_dir: str | Path) -> list[Skill]: + """Load all skills from a parent directory containing skill subdirectories. + + Each subdirectory containing a SKILL.md file is treated as a skill. + Subdirectories without SKILL.md are silently skipped. + + Args: + skills_dir: Path to the parent directory containing skill subdirectories. + + Returns: + List of Skill instances loaded from the directory. + + Raises: + FileNotFoundError: If the skills directory does not exist. + """ + skills_dir = Path(skills_dir).resolve() + + if not skills_dir.is_dir(): + raise FileNotFoundError(f"path=<{skills_dir}> | skills directory does not exist") + + skills: list[Skill] = [] + + for child in sorted(skills_dir.iterdir()): + if not child.is_dir(): + continue + + try: + _find_skill_md(child) + except FileNotFoundError: + logger.debug("path=<%s> | skipping directory without SKILL.md", child) + continue + + try: + skill = load_skill(child) + skills.append(skill) + except (ValueError, FileNotFoundError) as e: + logger.warning("path=<%s> | skipping skill due to error: %s", child, e) + + logger.debug("path=<%s>, count=<%d> | loaded skills from directory", skills_dir, len(skills)) + return skills diff --git a/src/strands/plugins/skills/skill.py b/src/strands/plugins/skills/skill.py new file mode 100644 index 000000000..c316c4474 --- /dev/null +++ b/src/strands/plugins/skills/skill.py @@ -0,0 +1,61 @@ +"""Skill data model for the AgentSkills.io integration. + +This module defines the Skill dataclass, which represents a single AgentSkills.io +skill with its metadata and instructions. +""" + +from __future__ import annotations + +import logging +from dataclasses import dataclass, field +from pathlib import Path + +logger = logging.getLogger(__name__) + + +@dataclass +class Skill: + """Represents an AgentSkills.io skill with metadata and instructions. + + A skill encapsulates a set of instructions and metadata that can be + dynamically loaded by an agent at runtime. Skills support progressive + disclosure: metadata is shown upfront in the system prompt, and full + instructions are loaded on demand via a tool. + + Attributes: + name: Unique identifier for the skill (1-64 chars, lowercase alphanumeric + hyphens). + description: Human-readable description of what the skill does. + instructions: Full markdown instructions from the SKILL.md body. + path: Filesystem path to the skill directory, if loaded from disk. + allowed_tools: List of tool names the skill is allowed to use. + metadata: Additional key-value metadata from the SKILL.md frontmatter. + license: License identifier (e.g., "Apache-2.0"). + compatibility: Compatibility information string. + """ + + name: str + description: str + instructions: str = "" + path: Path | None = None + allowed_tools: list[str] | None = None + metadata: dict[str, str] = field(default_factory=dict) + license: str | None = None + compatibility: str | None = None + + @classmethod + def from_path(cls, skill_path: str | Path) -> Skill: + """Load a skill from a directory containing SKILL.md. + + Args: + skill_path: Path to the skill directory or SKILL.md file. + + Returns: + A Skill instance populated from the SKILL.md file. + + Raises: + FileNotFoundError: If SKILL.md cannot be found. + ValueError: If the skill name is invalid or metadata is malformed. + """ + from .loader import load_skill + + return load_skill(skill_path) diff --git a/src/strands/plugins/skills/skills_plugin.py b/src/strands/plugins/skills/skills_plugin.py new file mode 100644 index 000000000..028f9dc58 --- /dev/null +++ b/src/strands/plugins/skills/skills_plugin.py @@ -0,0 +1,328 @@ +"""SkillsPlugin for integrating AgentSkills.io skills into Strands agents. + +This module provides the SkillsPlugin class that extends the Plugin base class +to add AgentSkills.io skill support. The plugin registers a tool for activating +and deactivating skills, and injects skill metadata into the system prompt. +""" + +from __future__ import annotations + +import logging +from pathlib import Path +from typing import TYPE_CHECKING, Any + +from ...hooks.events import AfterInvocationEvent, BeforeInvocationEvent +from ...hooks.registry import HookRegistry +from ...plugins.plugin import Plugin +from ...tools.decorator import tool +from .loader import load_skill, load_skills +from .skill import Skill + +if TYPE_CHECKING: + from ...agent.agent import Agent + from ...types.content import SystemContentBlock + +logger = logging.getLogger(__name__) + +_STATE_KEY = "skills_plugin" + + +def _make_skills_tool(plugin: SkillsPlugin) -> Any: + """Create the skills tool that allows the agent to activate and deactivate skills. + + Args: + plugin: The SkillsPlugin instance that manages skill state. + + Returns: + A decorated tool function for skill activation and deactivation. + """ + + @tool + def skills(action: str, skill_name: str = "") -> str: + """Activate or deactivate a skill to load its full instructions. + + Use this tool to load the complete instructions for a skill listed in + the available_skills section of your system prompt. + + Args: + action: The action to perform. Use "activate" to load a skill's full instructions, + or "deactivate" to unload the currently active skill. + skill_name: Name of the skill to activate. Required for "activate" action. + """ + if action == "activate": + if not skill_name: + return "Error: skill_name is required for activate action." + + found = plugin._find_skill(skill_name) + if found is None: + available = ", ".join(s.name for s in plugin._skills) + return f"Skill '{skill_name}' not found. Available skills: {available}" + + plugin._active_skill = found + plugin._persist_state() + + logger.debug("skill_name=<%s> | skill activated", skill_name) + return found.instructions or f"Skill '{skill_name}' activated (no instructions available)." + + elif action == "deactivate": + deactivated_name = plugin._active_skill.name if plugin._active_skill else skill_name + plugin._active_skill = None + plugin._persist_state() + + logger.debug("skill_name=<%s> | skill deactivated", deactivated_name) + return f"Skill '{deactivated_name}' deactivated." + + else: + return f"Unknown action: '{action}'. Use 'activate' or 'deactivate'." + + return skills + + +class SkillsPlugin(Plugin): + """Plugin that integrates AgentSkills.io skills into a Strands agent. + + The SkillsPlugin extends the Plugin base class and provides: + + 1. A ``skills`` tool that allows the agent to activate/deactivate skills on demand + 2. System prompt injection of available skill metadata before each invocation + 3. Single active skill management (activating a new skill deactivates the previous one) + 4. Session persistence of active skill state via ``agent.state`` + + Skills can be provided as filesystem paths (to individual skill directories or + parent directories containing multiple skills) or as pre-built ``Skill`` instances. + + Example: + ```python + from strands import Agent + from strands.plugins.skills import Skill, SkillsPlugin + + # Load from filesystem + plugin = SkillsPlugin(skills=["./skills/pdf-processing", "./skills/"]) + + # Or provide Skill instances directly + skill = Skill(name="my-skill", description="A custom skill", instructions="Do the thing") + plugin = SkillsPlugin(skills=[skill]) + + agent = Agent(plugins=[plugin]) + ``` + """ + + @property + def name(self) -> str: + """A stable string identifier for the plugin.""" + return "skills" + + def __init__(self, skills: list[str | Path | Skill]) -> None: + """Initialize the SkillsPlugin. + + Args: + skills: List of skill sources. Each element can be: + + - A ``str`` or ``Path`` to a skill directory (containing SKILL.md) + - A ``str`` or ``Path`` to a parent directory (containing skill subdirectories) + - A ``Skill`` dataclass instance + """ + self._skills: list[Skill] = self._resolve_skills(skills) + self._active_skill: Skill | None = None + self._agent: Agent | None = None + self._saved_system_prompt: str | None = None + self._saved_system_prompt_content: list[SystemContentBlock] | None = None + + def init_plugin(self, agent: Agent) -> None: + """Initialize the plugin with an agent instance. + + Registers the skills tool and hooks with the agent. + + Args: + agent: The agent instance to extend with skills support. + """ + self._agent = agent + + agent.tool_registry.process_tools([_make_skills_tool(self)]) + agent.hooks.add_hook(self) + + self._restore_state() + + logger.debug("skill_count=<%d> | skills plugin initialized", len(self._skills)) + + def register_hooks(self, registry: HookRegistry, **kwargs: Any) -> None: + """Register hook callbacks with the agent's hook registry. + + Args: + registry: The hook registry to register callbacks with. + **kwargs: Additional keyword arguments for future extensibility. + """ + registry.add_callback(BeforeInvocationEvent, self._on_before_invocation) + registry.add_callback(AfterInvocationEvent, self._on_after_invocation) + + @property + def skills(self) -> list[Skill]: + """Get the list of available skills. + + Returns: + A copy of the current skills list. + """ + return list(self._skills) + + @skills.setter + def skills(self, value: list[str | Path | Skill]) -> None: + """Set the available skills, resolving paths as needed. + + Deactivates any currently active skill when skills are changed. + + Args: + value: List of skill sources to resolve. + """ + self._skills = self._resolve_skills(value) + self._active_skill = None + self._persist_state() + + @property + def active_skill(self) -> Skill | None: + """Get the currently active skill. + + Returns: + The active Skill instance, or None if no skill is active. + """ + return self._active_skill + + def _on_before_invocation(self, event: BeforeInvocationEvent) -> None: + """Inject skill metadata into the system prompt before each invocation. + + Saves the current system prompt and appends an XML block listing + all available skills so the model knows what it can activate. + + Args: + event: The before-invocation event containing the agent reference. + """ + agent = event.agent + + # Save original system prompt for restoration after invocation + self._saved_system_prompt = agent._system_prompt + self._saved_system_prompt_content = agent._system_prompt_content + + if not self._skills: + return + + skills_xml = self._generate_skills_xml() + current: str = agent._system_prompt or "" + new_prompt = f"{current}\n\n{skills_xml}" if current else skills_xml + + # Directly set both representations to avoid re-parsing through the setter + # and to preserve cache control blocks in the original content + agent._system_prompt = new_prompt + agent._system_prompt_content = [{"text": new_prompt}] + + def _on_after_invocation(self, event: AfterInvocationEvent) -> None: + """Restore the original system prompt after invocation completes. + + Args: + event: The after-invocation event containing the agent reference. + """ + agent = event.agent + + # Restore original system prompt directly to preserve content block types + agent._system_prompt = self._saved_system_prompt + agent._system_prompt_content = self._saved_system_prompt_content + self._saved_system_prompt = None + self._saved_system_prompt_content = None + + def _generate_skills_xml(self) -> str: + """Generate the XML block listing available skills for the system prompt. + + Returns: + XML-formatted string with skill metadata. + """ + lines: list[str] = [""] + + for skill in self._skills: + lines.append("") + lines.append(f"{skill.name}") + lines.append(f"{skill.description}") + lines.append("") + + lines.append("") + return "\n".join(lines) + + def _find_skill(self, skill_name: str) -> Skill | None: + """Find a skill by name in the available skills list. + + Args: + skill_name: The name of the skill to find. + + Returns: + The matching Skill instance, or None if not found. + """ + for skill in self._skills: + if skill.name == skill_name: + return skill + return None + + def _resolve_skills(self, sources: list[str | Path | Skill]) -> list[Skill]: + """Resolve a list of skill sources into Skill instances. + + Each source can be a Skill instance, a path to a skill directory, + or a path to a parent directory containing multiple skills. + + Args: + sources: List of skill sources to resolve. + + Returns: + List of resolved Skill instances. + """ + resolved: list[Skill] = [] + + for source in sources: + if isinstance(source, Skill): + resolved.append(source) + else: + path = Path(source).resolve() + if not path.exists(): + logger.warning("path=<%s> | skill source path does not exist, skipping", path) + continue + + if path.is_dir(): + # Check if this directory itself is a skill (has SKILL.md) + has_skill_md = (path / "SKILL.md").is_file() or (path / "skill.md").is_file() + + if has_skill_md: + try: + resolved.append(load_skill(path)) + except (ValueError, FileNotFoundError) as e: + logger.warning("path=<%s> | failed to load skill: %s", path, e) + else: + # Treat as parent directory containing skill subdirectories + resolved.extend(load_skills(path)) + elif path.is_file() and path.name.lower() == "skill.md": + try: + resolved.append(load_skill(path)) + except (ValueError, FileNotFoundError) as e: + logger.warning("path=<%s> | failed to load skill: %s", path, e) + + logger.debug("source_count=<%d>, resolved_count=<%d> | skills resolved", len(sources), len(resolved)) + return resolved + + def _persist_state(self) -> None: + """Persist the active skill name to agent state for session recovery.""" + if self._agent is None: + return + + state_data: dict[str, Any] = { + "active_skill_name": self._active_skill.name if self._active_skill else None, + } + self._agent.state.set(_STATE_KEY, state_data) + + def _restore_state(self) -> None: + """Restore the active skill from agent state if available.""" + if self._agent is None: + return + + state_data = self._agent.state.get(_STATE_KEY) + if not isinstance(state_data, dict): + return + + active_name = state_data.get("active_skill_name") + if isinstance(active_name, str): + self._active_skill = self._find_skill(active_name) + if self._active_skill: + logger.debug("skill_name=<%s> | restored active skill from state", active_name) diff --git a/tests/strands/plugins/skills/__init__.py b/tests/strands/plugins/skills/__init__.py new file mode 100644 index 000000000..9bd23c0ed --- /dev/null +++ b/tests/strands/plugins/skills/__init__.py @@ -0,0 +1 @@ +"""Tests for the skills plugin package.""" diff --git a/tests/strands/plugins/skills/test_loader.py b/tests/strands/plugins/skills/test_loader.py new file mode 100644 index 000000000..875ebf204 --- /dev/null +++ b/tests/strands/plugins/skills/test_loader.py @@ -0,0 +1,319 @@ +"""Tests for the skill loader module.""" + +from pathlib import Path + +import pytest + +from strands.plugins.skills.loader import ( + _find_skill_md, + _parse_frontmatter, + _parse_yaml_simple, + _validate_skill_name, + load_skill, + load_skills, +) + + +class TestFindSkillMd: + """Tests for _find_skill_md.""" + + def test_finds_uppercase_skill_md(self, tmp_path): + """Test finding SKILL.md (uppercase).""" + (tmp_path / "SKILL.md").write_text("test") + result = _find_skill_md(tmp_path) + assert result.name == "SKILL.md" + + def test_finds_lowercase_skill_md(self, tmp_path): + """Test finding skill.md (lowercase).""" + (tmp_path / "skill.md").write_text("test") + result = _find_skill_md(tmp_path) + assert result.name == "skill.md" + + def test_prefers_uppercase(self, tmp_path): + """Test that SKILL.md is preferred over skill.md.""" + (tmp_path / "SKILL.md").write_text("uppercase") + (tmp_path / "skill.md").write_text("lowercase") + result = _find_skill_md(tmp_path) + assert result.name == "SKILL.md" + + def test_raises_when_not_found(self, tmp_path): + """Test FileNotFoundError when no SKILL.md exists.""" + with pytest.raises(FileNotFoundError, match="no SKILL.md found"): + _find_skill_md(tmp_path) + + +class TestParseYamlSimple: + """Tests for _parse_yaml_simple.""" + + def test_simple_key_values(self): + """Test parsing simple key-value pairs.""" + text = "name: my-skill\ndescription: A test skill\nlicense: Apache-2.0" + result = _parse_yaml_simple(text) + assert result == {"name": "my-skill", "description": "A test skill", "license": "Apache-2.0"} + + def test_nested_mapping(self): + """Test parsing a nested mapping.""" + text = "name: my-skill\nmetadata:\n author: test-org\n version: 1.0" + result = _parse_yaml_simple(text) + assert result["name"] == "my-skill" + assert result["metadata"] == {"author": "test-org", "version": "1.0"} + + def test_skips_comments_and_empty_lines(self): + """Test that comments and empty lines are skipped.""" + text = "# comment\nname: my-skill\n\ndescription: test\n" + result = _parse_yaml_simple(text) + assert result == {"name": "my-skill", "description": "test"} + + def test_empty_input(self): + """Test parsing empty input.""" + result = _parse_yaml_simple("") + assert result == {} + + +class TestParseFrontmatter: + """Tests for _parse_frontmatter.""" + + def test_valid_frontmatter(self): + """Test parsing valid frontmatter.""" + content = "---\nname: test-skill\ndescription: A test\n---\n# Instructions\nDo things." + frontmatter, body = _parse_frontmatter(content) + assert frontmatter["name"] == "test-skill" + assert frontmatter["description"] == "A test" + assert "# Instructions" in body + assert "Do things." in body + + def test_missing_opening_delimiter(self): + """Test error when opening --- is missing.""" + with pytest.raises(ValueError, match="must start with ---"): + _parse_frontmatter("name: test\n---\n") + + def test_missing_closing_delimiter(self): + """Test error when closing --- is missing.""" + with pytest.raises(ValueError, match="missing closing ---"): + _parse_frontmatter("---\nname: test\n") + + def test_empty_body(self): + """Test frontmatter with empty body.""" + content = "---\nname: test-skill\ndescription: test\n---\n" + frontmatter, body = _parse_frontmatter(content) + assert frontmatter["name"] == "test-skill" + assert body == "" + + def test_frontmatter_with_metadata(self): + """Test frontmatter with nested metadata.""" + content = "---\nname: test-skill\ndescription: test\nmetadata:\n author: acme\n---\nBody here." + frontmatter, body = _parse_frontmatter(content) + assert frontmatter["name"] == "test-skill" + assert isinstance(frontmatter["metadata"], dict) + assert frontmatter["metadata"]["author"] == "acme" + assert body == "Body here." + + +class TestValidateSkillName: + """Tests for _validate_skill_name.""" + + def test_valid_names(self): + """Test that valid names pass validation.""" + valid_names = ["a", "test", "my-skill", "skill-123", "a1b2c3"] + for name in valid_names: + _validate_skill_name(name) # Should not raise + + def test_empty_name(self): + """Test that empty name raises ValueError.""" + with pytest.raises(ValueError, match="cannot be empty"): + _validate_skill_name("") + + def test_too_long_name(self): + """Test that names exceeding 64 chars raise ValueError.""" + with pytest.raises(ValueError, match="exceeds 64 character limit"): + _validate_skill_name("a" * 65) + + def test_uppercase_rejected(self): + """Test that uppercase characters are rejected.""" + with pytest.raises(ValueError, match="lowercase alphanumeric"): + _validate_skill_name("MySkill") + + def test_starts_with_hyphen(self): + """Test that names starting with hyphen are rejected.""" + with pytest.raises(ValueError, match="lowercase alphanumeric"): + _validate_skill_name("-skill") + + def test_ends_with_hyphen(self): + """Test that names ending with hyphen are rejected.""" + with pytest.raises(ValueError, match="lowercase alphanumeric"): + _validate_skill_name("skill-") + + def test_consecutive_hyphens(self): + """Test that consecutive hyphens are rejected.""" + with pytest.raises(ValueError, match="consecutive hyphens"): + _validate_skill_name("my--skill") + + def test_special_characters(self): + """Test that special characters are rejected.""" + with pytest.raises(ValueError, match="lowercase alphanumeric"): + _validate_skill_name("my_skill") + + def test_directory_name_mismatch(self, tmp_path): + """Test that skill name must match directory name.""" + skill_dir = tmp_path / "wrong-name" + skill_dir.mkdir() + with pytest.raises(ValueError, match="must match parent directory name"): + _validate_skill_name("my-skill", skill_dir) + + def test_directory_name_match(self, tmp_path): + """Test that matching directory name passes.""" + skill_dir = tmp_path / "my-skill" + skill_dir.mkdir() + _validate_skill_name("my-skill", skill_dir) # Should not raise + + +def _make_skill_dir(parent: Path, name: str, description: str = "A test skill", body: str = "Instructions.") -> Path: + """Helper to create a skill directory with SKILL.md.""" + skill_dir = parent / name + skill_dir.mkdir(parents=True, exist_ok=True) + content = f"---\nname: {name}\ndescription: {description}\n---\n{body}\n" + (skill_dir / "SKILL.md").write_text(content) + return skill_dir + + +class TestLoadSkill: + """Tests for load_skill.""" + + def test_load_from_directory(self, tmp_path): + """Test loading a skill from a directory path.""" + skill_dir = _make_skill_dir(tmp_path, "my-skill", "My description", "# Hello\nWorld.") + skill = load_skill(skill_dir) + + assert skill.name == "my-skill" + assert skill.description == "My description" + assert "# Hello" in skill.instructions + assert "World." in skill.instructions + assert skill.path == skill_dir.resolve() + + def test_load_from_skill_md_file(self, tmp_path): + """Test loading a skill by pointing directly to SKILL.md.""" + skill_dir = _make_skill_dir(tmp_path, "direct-skill") + skill = load_skill(skill_dir / "SKILL.md") + + assert skill.name == "direct-skill" + + def test_load_with_allowed_tools(self, tmp_path): + """Test loading a skill with allowed-tools field.""" + skill_dir = tmp_path / "tool-skill" + skill_dir.mkdir() + content = "---\nname: tool-skill\ndescription: test\nallowed-tools: read write execute\n---\nBody." + (skill_dir / "SKILL.md").write_text(content) + + skill = load_skill(skill_dir) + assert skill.allowed_tools == ["read", "write", "execute"] + + def test_load_with_metadata(self, tmp_path): + """Test loading a skill with nested metadata.""" + skill_dir = tmp_path / "meta-skill" + skill_dir.mkdir() + content = "---\nname: meta-skill\ndescription: test\nmetadata:\n author: acme\n---\nBody." + (skill_dir / "SKILL.md").write_text(content) + + skill = load_skill(skill_dir) + assert skill.metadata == {"author": "acme"} + + def test_load_with_license_and_compatibility(self, tmp_path): + """Test loading a skill with license and compatibility fields.""" + skill_dir = tmp_path / "licensed-skill" + skill_dir.mkdir() + content = "---\nname: licensed-skill\ndescription: test\nlicense: MIT\ncompatibility: v1\n---\nBody." + (skill_dir / "SKILL.md").write_text(content) + + skill = load_skill(skill_dir) + assert skill.license == "MIT" + assert skill.compatibility == "v1" + + def test_load_missing_name(self, tmp_path): + """Test error when SKILL.md is missing name field.""" + skill_dir = tmp_path / "no-name" + skill_dir.mkdir() + (skill_dir / "SKILL.md").write_text("---\ndescription: test\n---\nBody.") + + with pytest.raises(ValueError, match="must have a 'name' field"): + load_skill(skill_dir) + + def test_load_missing_description(self, tmp_path): + """Test error when SKILL.md is missing description field.""" + skill_dir = tmp_path / "no-desc" + skill_dir.mkdir() + (skill_dir / "SKILL.md").write_text("---\nname: no-desc\n---\nBody.") + + with pytest.raises(ValueError, match="must have a 'description' field"): + load_skill(skill_dir) + + def test_load_nonexistent_path(self, tmp_path): + """Test FileNotFoundError for nonexistent path.""" + with pytest.raises(FileNotFoundError): + load_skill(tmp_path / "nonexistent") + + def test_load_name_directory_mismatch(self, tmp_path): + """Test error when skill name doesn't match directory name.""" + skill_dir = tmp_path / "wrong-dir" + skill_dir.mkdir() + (skill_dir / "SKILL.md").write_text("---\nname: right-name\ndescription: test\n---\nBody.") + + with pytest.raises(ValueError, match="must match parent directory name"): + load_skill(skill_dir) + + +class TestLoadSkills: + """Tests for load_skills.""" + + def test_load_multiple_skills(self, tmp_path): + """Test loading multiple skills from a parent directory.""" + _make_skill_dir(tmp_path, "skill-a", "Skill A") + _make_skill_dir(tmp_path, "skill-b", "Skill B") + + skills = load_skills(tmp_path) + + assert len(skills) == 2 + names = {s.name for s in skills} + assert names == {"skill-a", "skill-b"} + + def test_skips_directories_without_skill_md(self, tmp_path): + """Test that directories without SKILL.md are silently skipped.""" + _make_skill_dir(tmp_path, "valid-skill") + (tmp_path / "no-skill-here").mkdir() + + skills = load_skills(tmp_path) + + assert len(skills) == 1 + assert skills[0].name == "valid-skill" + + def test_skips_files_in_parent(self, tmp_path): + """Test that files in the parent directory are ignored.""" + _make_skill_dir(tmp_path, "real-skill") + (tmp_path / "readme.txt").write_text("not a skill") + + skills = load_skills(tmp_path) + + assert len(skills) == 1 + + def test_empty_directory(self, tmp_path): + """Test loading from an empty directory.""" + skills = load_skills(tmp_path) + assert skills == [] + + def test_nonexistent_directory(self, tmp_path): + """Test FileNotFoundError for nonexistent directory.""" + with pytest.raises(FileNotFoundError): + load_skills(tmp_path / "nonexistent") + + def test_skips_invalid_skills(self, tmp_path): + """Test that invalid skills are skipped with a warning.""" + _make_skill_dir(tmp_path, "good-skill") + + # Create an invalid skill (name mismatch) + bad_dir = tmp_path / "bad-dir" + bad_dir.mkdir() + (bad_dir / "SKILL.md").write_text("---\nname: wrong-name\ndescription: test\n---\nBody.") + + skills = load_skills(tmp_path) + + assert len(skills) == 1 + assert skills[0].name == "good-skill" diff --git a/tests/strands/plugins/skills/test_skill.py b/tests/strands/plugins/skills/test_skill.py new file mode 100644 index 000000000..379eec7d2 --- /dev/null +++ b/tests/strands/plugins/skills/test_skill.py @@ -0,0 +1,73 @@ +"""Tests for the Skill dataclass.""" + +from pathlib import Path + +import pytest + +from strands.plugins.skills.skill import Skill + + +class TestSkillDataclass: + """Tests for the Skill dataclass creation and properties.""" + + def test_skill_minimal(self): + """Test creating a Skill with only required fields.""" + skill = Skill(name="test-skill", description="A test skill") + + assert skill.name == "test-skill" + assert skill.description == "A test skill" + assert skill.instructions == "" + assert skill.path is None + assert skill.allowed_tools is None + assert skill.metadata == {} + assert skill.license is None + assert skill.compatibility is None + + def test_skill_full(self): + """Test creating a Skill with all fields.""" + skill = Skill( + name="full-skill", + description="A fully specified skill", + instructions="# Full Instructions\nDo the thing.", + path=Path("/tmp/skills/full-skill"), + allowed_tools=["tool1", "tool2"], + metadata={"author": "test-org"}, + license="Apache-2.0", + compatibility="strands>=1.0", + ) + + assert skill.name == "full-skill" + assert skill.description == "A fully specified skill" + assert skill.instructions == "# Full Instructions\nDo the thing." + assert skill.path == Path("/tmp/skills/full-skill") + assert skill.allowed_tools == ["tool1", "tool2"] + assert skill.metadata == {"author": "test-org"} + assert skill.license == "Apache-2.0" + assert skill.compatibility == "strands>=1.0" + + def test_skill_metadata_default_is_not_shared(self): + """Test that default metadata dict is not shared between instances.""" + skill1 = Skill(name="skill-1", description="First") + skill2 = Skill(name="skill-2", description="Second") + + skill1.metadata["key"] = "value" + assert "key" not in skill2.metadata + + def test_skill_from_path(self, tmp_path): + """Test loading a Skill from a path using from_path classmethod.""" + skill_dir = tmp_path / "my-skill" + skill_dir.mkdir() + (skill_dir / "SKILL.md").write_text( + "---\nname: my-skill\ndescription: Test skill\n---\n# Instructions\nDo stuff.\n" + ) + + skill = Skill.from_path(skill_dir) + + assert skill.name == "my-skill" + assert skill.description == "Test skill" + assert "Do stuff." in skill.instructions + + def test_skill_from_path_not_found(self, tmp_path): + """Test that from_path raises FileNotFoundError for missing paths.""" + with pytest.raises(FileNotFoundError): + Skill.from_path(tmp_path / "nonexistent") diff --git a/tests/strands/plugins/skills/test_skills_plugin.py b/tests/strands/plugins/skills/test_skills_plugin.py new file mode 100644 index 000000000..6e983add4 --- /dev/null +++ b/tests/strands/plugins/skills/test_skills_plugin.py @@ -0,0 +1,531 @@ +"""Tests for the SkillsPlugin.""" + +from pathlib import Path +from unittest.mock import MagicMock + +from strands.hooks.events import AfterInvocationEvent, BeforeInvocationEvent +from strands.hooks.registry import HookRegistry +from strands.plugins.skills.skill import Skill +from strands.plugins.skills.skills_plugin import SkillsPlugin, _make_skills_tool + + +def _make_skill(name: str = "test-skill", description: str = "A test skill", instructions: str = "Do the thing."): + """Helper to create a Skill instance.""" + return Skill(name=name, description=description, instructions=instructions) + + +def _make_skill_dir(parent: Path, name: str, description: str = "A test skill") -> Path: + """Helper to create a skill directory with SKILL.md.""" + skill_dir = parent / name + skill_dir.mkdir(parents=True, exist_ok=True) + content = f"---\nname: {name}\ndescription: {description}\n---\n# Instructions for {name}\n" + (skill_dir / "SKILL.md").write_text(content) + return skill_dir + + +def _mock_agent(): + """Create a mock agent for testing.""" + agent = MagicMock() + agent._system_prompt = "You are an agent." + agent._system_prompt_content = [{"text": "You are an agent."}] + agent.hooks = HookRegistry() + agent.tool_registry = MagicMock() + agent.tool_registry.process_tools = MagicMock(return_value=["skills"]) + agent.state = MagicMock() + agent.state.get = MagicMock(return_value=None) + agent.state.set = MagicMock() + return agent + + +class TestSkillsPluginInit: + """Tests for SkillsPlugin initialization.""" + + def test_init_with_skill_instances(self): + """Test initialization with Skill instances.""" + skill = _make_skill() + plugin = SkillsPlugin(skills=[skill]) + + assert len(plugin.skills) == 1 + assert plugin.skills[0].name == "test-skill" + + def test_init_with_filesystem_paths(self, tmp_path): + """Test initialization with filesystem paths.""" + _make_skill_dir(tmp_path, "fs-skill") + plugin = SkillsPlugin(skills=[str(tmp_path / "fs-skill")]) + + assert len(plugin.skills) == 1 + assert plugin.skills[0].name == "fs-skill" + + def test_init_with_parent_directory(self, tmp_path): + """Test initialization with a parent directory containing skills.""" + _make_skill_dir(tmp_path, "skill-a") + _make_skill_dir(tmp_path, "skill-b") + plugin = SkillsPlugin(skills=[tmp_path]) + + assert len(plugin.skills) == 2 + + def test_init_with_mixed_sources(self, tmp_path): + """Test initialization with mixed skill sources.""" + _make_skill_dir(tmp_path, "fs-skill") + direct_skill = _make_skill(name="direct-skill", description="Direct") + plugin = SkillsPlugin(skills=[str(tmp_path / "fs-skill"), direct_skill]) + + assert len(plugin.skills) == 2 + names = {s.name for s in plugin.skills} + assert names == {"fs-skill", "direct-skill"} + + def test_init_skips_nonexistent_paths(self, tmp_path): + """Test that nonexistent paths are skipped gracefully.""" + plugin = SkillsPlugin(skills=[str(tmp_path / "nonexistent")]) + assert len(plugin.skills) == 0 + + def test_init_empty_skills(self): + """Test initialization with empty skills list.""" + plugin = SkillsPlugin(skills=[]) + assert plugin.skills == [] + assert plugin.active_skill is None + + def test_name_attribute(self): + """Test that the plugin has the correct name.""" + plugin = SkillsPlugin(skills=[]) + assert plugin.name == "skills" + + +class TestSkillsPluginInitPlugin: + """Tests for the init_plugin method.""" + + def test_registers_tool(self): + """Test that init_plugin registers the skills tool.""" + plugin = SkillsPlugin(skills=[_make_skill()]) + agent = _mock_agent() + + plugin.init_plugin(agent) + + agent.tool_registry.process_tools.assert_called_once() + args = agent.tool_registry.process_tools.call_args[0][0] + assert len(args) == 1 + + def test_registers_hooks(self): + """Test that init_plugin registers hook callbacks.""" + plugin = SkillsPlugin(skills=[_make_skill()]) + agent = _mock_agent() + + plugin.init_plugin(agent) + + # Verify hooks were registered by checking the registry has callbacks + assert agent.hooks.has_callbacks() + + def test_stores_agent_reference(self): + """Test that init_plugin stores the agent reference.""" + plugin = SkillsPlugin(skills=[_make_skill()]) + agent = _mock_agent() + + plugin.init_plugin(agent) + + assert plugin._agent is agent + + def test_restores_state(self): + """Test that init_plugin restores active skill from state.""" + skill = _make_skill() + plugin = SkillsPlugin(skills=[skill]) + agent = _mock_agent() + agent.state.get = MagicMock(return_value={"active_skill_name": "test-skill"}) + + plugin.init_plugin(agent) + + assert plugin.active_skill is not None + assert plugin.active_skill.name == "test-skill" + + +class TestSkillsPluginProperties: + """Tests for SkillsPlugin properties.""" + + def test_skills_getter_returns_copy(self): + """Test that the skills getter returns a copy of the list.""" + skill = _make_skill() + plugin = SkillsPlugin(skills=[skill]) + + skills_list = plugin.skills + skills_list.append(_make_skill(name="another-skill", description="Another")) + + assert len(plugin.skills) == 1 + + def test_skills_setter(self): + """Test setting skills via the property setter.""" + plugin = SkillsPlugin(skills=[_make_skill()]) + plugin._agent = _mock_agent() + + new_skill = _make_skill(name="new-skill", description="New") + plugin.skills = [new_skill] + + assert len(plugin.skills) == 1 + assert plugin.skills[0].name == "new-skill" + + def test_skills_setter_deactivates_current(self): + """Test that setting skills deactivates the current active skill.""" + plugin = SkillsPlugin(skills=[_make_skill()]) + plugin._agent = _mock_agent() + plugin._active_skill = _make_skill() + + plugin.skills = [_make_skill(name="new-skill", description="New")] + + assert plugin.active_skill is None + + def test_active_skill_initially_none(self): + """Test that active_skill is None initially.""" + plugin = SkillsPlugin(skills=[_make_skill()]) + assert plugin.active_skill is None + + +class TestSkillsTool: + """Tests for the skills tool function.""" + + def test_activate_skill(self): + """Test activating a skill returns its instructions.""" + skill = _make_skill(instructions="Full instructions here.") + plugin = SkillsPlugin(skills=[skill]) + plugin._agent = _mock_agent() + + skills_tool = _make_skills_tool(plugin) + result = skills_tool(action="activate", skill_name="test-skill") + + assert result == "Full instructions here." + assert plugin.active_skill is not None + assert plugin.active_skill.name == "test-skill" + + def test_activate_nonexistent_skill(self): + """Test activating a nonexistent skill returns error message.""" + skill = _make_skill() + plugin = SkillsPlugin(skills=[skill]) + plugin._agent = _mock_agent() + + skills_tool = _make_skills_tool(plugin) + result = skills_tool(action="activate", skill_name="nonexistent") + + assert "not found" in result + assert "test-skill" in result + + def test_activate_replaces_previous(self): + """Test that activating a new skill replaces the previous one.""" + skill1 = _make_skill(name="skill-a", description="A", instructions="A instructions") + skill2 = _make_skill(name="skill-b", description="B", instructions="B instructions") + plugin = SkillsPlugin(skills=[skill1, skill2]) + plugin._agent = _mock_agent() + + skills_tool = _make_skills_tool(plugin) + skills_tool(action="activate", skill_name="skill-a") + assert plugin.active_skill.name == "skill-a" + + skills_tool(action="activate", skill_name="skill-b") + assert plugin.active_skill.name == "skill-b" + + def test_activate_without_name(self): + """Test activating without a skill name returns error.""" + plugin = SkillsPlugin(skills=[_make_skill()]) + plugin._agent = _mock_agent() + + skills_tool = _make_skills_tool(plugin) + result = skills_tool(action="activate", skill_name="") + + assert "required" in result.lower() + + def test_deactivate_skill(self): + """Test deactivating a skill.""" + skill = _make_skill() + plugin = SkillsPlugin(skills=[skill]) + plugin._agent = _mock_agent() + plugin._active_skill = skill + + skills_tool = _make_skills_tool(plugin) + result = skills_tool(action="deactivate", skill_name="test-skill") + + assert "deactivated" in result.lower() + assert plugin.active_skill is None + + def test_unknown_action(self): + """Test unknown action returns error message.""" + plugin = SkillsPlugin(skills=[_make_skill()]) + plugin._agent = _mock_agent() + + skills_tool = _make_skills_tool(plugin) + result = skills_tool(action="unknown") + + assert "Unknown action" in result + + def test_activate_persists_state(self): + """Test that activating a skill persists state.""" + plugin = SkillsPlugin(skills=[_make_skill()]) + agent = _mock_agent() + plugin._agent = agent + + skills_tool = _make_skills_tool(plugin) + skills_tool(action="activate", skill_name="test-skill") + + agent.state.set.assert_called() + + +class TestSystemPromptInjection: + """Tests for system prompt injection via hooks.""" + + def test_before_invocation_appends_skills_xml(self): + """Test that before_invocation appends skills XML to system prompt.""" + skill = _make_skill() + plugin = SkillsPlugin(skills=[skill]) + agent = _mock_agent() + + event = BeforeInvocationEvent(agent=agent) + plugin._on_before_invocation(event) + + assert "" in agent._system_prompt + assert "test-skill" in agent._system_prompt + assert "A test skill" in agent._system_prompt + + def test_before_invocation_preserves_existing_prompt(self): + """Test that existing system prompt content is preserved.""" + plugin = SkillsPlugin(skills=[_make_skill()]) + agent = _mock_agent() + agent._system_prompt = "Original prompt." + agent._system_prompt_content = [{"text": "Original prompt."}] + + event = BeforeInvocationEvent(agent=agent) + plugin._on_before_invocation(event) + + assert agent._system_prompt.startswith("Original prompt.") + assert "" in agent._system_prompt + + def test_after_invocation_restores_prompt(self): + """Test that after_invocation restores the original system prompt.""" + plugin = SkillsPlugin(skills=[_make_skill()]) + agent = _mock_agent() + original_prompt = "Original prompt." + original_content = [{"text": "Original prompt."}] + agent._system_prompt = original_prompt + agent._system_prompt_content = original_content + + # Simulate before/after cycle + before_event = BeforeInvocationEvent(agent=agent) + plugin._on_before_invocation(before_event) + assert agent._system_prompt != original_prompt + + after_event = AfterInvocationEvent(agent=agent) + plugin._on_after_invocation(after_event) + assert agent._system_prompt == original_prompt + assert agent._system_prompt_content == original_content + + def test_no_skills_skips_injection(self): + """Test that injection is skipped when no skills are available.""" + plugin = SkillsPlugin(skills=[]) + agent = _mock_agent() + original_prompt = "Original prompt." + agent._system_prompt = original_prompt + agent._system_prompt_content = [{"text": original_prompt}] + + event = BeforeInvocationEvent(agent=agent) + plugin._on_before_invocation(event) + + assert agent._system_prompt == original_prompt + + def test_none_system_prompt_handled(self): + """Test handling when system prompt is None.""" + plugin = SkillsPlugin(skills=[_make_skill()]) + agent = _mock_agent() + agent._system_prompt = None + agent._system_prompt_content = None + + event = BeforeInvocationEvent(agent=agent) + plugin._on_before_invocation(event) + + assert "" in agent._system_prompt + + +class TestSkillsXmlGeneration: + """Tests for _generate_skills_xml.""" + + def test_single_skill(self): + """Test XML generation with a single skill.""" + plugin = SkillsPlugin(skills=[_make_skill()]) + xml = plugin._generate_skills_xml() + + assert "" in xml + assert "" in xml + assert "test-skill" in xml + assert "A test skill" in xml + + def test_multiple_skills(self): + """Test XML generation with multiple skills.""" + skills = [ + _make_skill(name="skill-a", description="Skill A"), + _make_skill(name="skill-b", description="Skill B"), + ] + plugin = SkillsPlugin(skills=skills) + xml = plugin._generate_skills_xml() + + assert "skill-a" in xml + assert "skill-b" in xml + + def test_empty_skills(self): + """Test XML generation with no skills.""" + plugin = SkillsPlugin(skills=[]) + xml = plugin._generate_skills_xml() + + assert "" in xml + assert "" in xml + + +class TestHookRegistration: + """Tests for hook registration.""" + + def test_register_hooks(self): + """Test that register_hooks adds callbacks to the registry.""" + plugin = SkillsPlugin(skills=[_make_skill()]) + registry = HookRegistry() + + plugin.register_hooks(registry) + + assert registry.has_callbacks() + + +class TestSessionPersistence: + """Tests for session state persistence.""" + + def test_persist_state_with_active_skill(self): + """Test persisting active skill name.""" + plugin = SkillsPlugin(skills=[_make_skill()]) + agent = _mock_agent() + plugin._agent = agent + plugin._active_skill = _make_skill() + + plugin._persist_state() + + agent.state.set.assert_called_once_with("skills_plugin", {"active_skill_name": "test-skill"}) + + def test_persist_state_without_active_skill(self): + """Test persisting None when no skill is active.""" + plugin = SkillsPlugin(skills=[_make_skill()]) + agent = _mock_agent() + plugin._agent = agent + + plugin._persist_state() + + agent.state.set.assert_called_once_with("skills_plugin", {"active_skill_name": None}) + + def test_restore_state_activates_skill(self): + """Test restoring active skill from state.""" + skill = _make_skill() + plugin = SkillsPlugin(skills=[skill]) + agent = _mock_agent() + agent.state.get = MagicMock(return_value={"active_skill_name": "test-skill"}) + plugin._agent = agent + + plugin._restore_state() + + assert plugin.active_skill is not None + assert plugin.active_skill.name == "test-skill" + + def test_restore_state_no_data(self): + """Test restore when no state data exists.""" + plugin = SkillsPlugin(skills=[_make_skill()]) + agent = _mock_agent() + agent.state.get = MagicMock(return_value=None) + plugin._agent = agent + + plugin._restore_state() + + assert plugin.active_skill is None + + def test_restore_state_skill_not_found(self): + """Test restore when saved skill is no longer available.""" + plugin = SkillsPlugin(skills=[_make_skill()]) + agent = _mock_agent() + agent.state.get = MagicMock(return_value={"active_skill_name": "removed-skill"}) + plugin._agent = agent + + plugin._restore_state() + + assert plugin.active_skill is None + + def test_persist_state_without_agent(self): + """Test that persist_state is a no-op without agent.""" + plugin = SkillsPlugin(skills=[_make_skill()]) + + # Should not raise + plugin._persist_state() + + +class TestResolveSkills: + """Tests for _resolve_skills.""" + + def test_resolve_skill_instances(self): + """Test resolving Skill instances (pass-through).""" + skill = _make_skill() + plugin = SkillsPlugin(skills=[skill]) + + assert len(plugin._skills) == 1 + assert plugin._skills[0] is skill + + def test_resolve_skill_directory_path(self, tmp_path): + """Test resolving a path to a skill directory.""" + _make_skill_dir(tmp_path, "path-skill") + plugin = SkillsPlugin(skills=[tmp_path / "path-skill"]) + + assert len(plugin._skills) == 1 + assert plugin._skills[0].name == "path-skill" + + def test_resolve_parent_directory_path(self, tmp_path): + """Test resolving a path to a parent directory.""" + _make_skill_dir(tmp_path, "child-a") + _make_skill_dir(tmp_path, "child-b") + plugin = SkillsPlugin(skills=[tmp_path]) + + assert len(plugin._skills) == 2 + + def test_resolve_skill_md_file_path(self, tmp_path): + """Test resolving a path to a SKILL.md file.""" + skill_dir = _make_skill_dir(tmp_path, "file-skill") + plugin = SkillsPlugin(skills=[skill_dir / "SKILL.md"]) + + assert len(plugin._skills) == 1 + assert plugin._skills[0].name == "file-skill" + + def test_resolve_nonexistent_path(self, tmp_path): + """Test that nonexistent paths are skipped.""" + plugin = SkillsPlugin(skills=[str(tmp_path / "ghost")]) + assert len(plugin._skills) == 0 + + +class TestImports: + """Tests for module imports.""" + + def test_import_from_plugins(self): + """Test importing SkillsPlugin from strands.plugins.""" + from strands.plugins import SkillsPlugin as SP + + assert SP is SkillsPlugin + + def test_import_skill_from_strands(self): + """Test importing Skill from top-level strands package.""" + from strands import Skill as S + + assert S is Skill + + def test_import_from_skills_package(self): + """Test importing from strands.plugins.skills package.""" + from strands.plugins.skills import Skill, SkillsPlugin, load_skill, load_skills + + assert Skill is not None + assert SkillsPlugin is not None + assert load_skill is not None + assert load_skills is not None + + def test_skills_plugin_is_plugin_subclass(self): + """Test that SkillsPlugin is a subclass of the Plugin ABC.""" + from strands.plugins import Plugin + + assert issubclass(SkillsPlugin, Plugin) + + def test_skills_plugin_isinstance_check(self): + """Test that SkillsPlugin instances pass isinstance check against Plugin.""" + from strands.plugins import Plugin + + plugin = SkillsPlugin(skills=[]) + assert isinstance(plugin, Plugin) From 35bff69521417f30eb6699212e3b595c61a028ff Mon Sep 17 00:00:00 2001 From: Strands Agent <217235299+strands-agent@users.noreply.github.com> Date: Thu, 19 Feb 2026 20:33:38 +0000 Subject: [PATCH 04/29] feat(plugins): add @hook decorator and convert Plugin to base class - Create @hook decorator for declarative hook registration in plugins - Convert Plugin from Protocol to base class (breaking change) - Add auto-discovery of @hook and @tool decorated methods in Plugin.__init__() - Add auto-registration of hooks and tools in Plugin.init_plugin() - Support union types for multiple event types (e.g., BeforeModelCallEvent | AfterModelCallEvent) - Export hook from strands.plugins and strands namespaces - Update existing tests to use inheritance-based approach - Add comprehensive test coverage for new functionality BREAKING CHANGE: Plugin is now a base class instead of a Protocol. Existing plugins must inherit from Plugin instead of just implementing the protocol. --- AGENTS.md | 3 +- src/strands/__init__.py | 3 +- src/strands/plugins/__init__.py | 30 +- src/strands/plugins/decorator.py | 188 ++++++++ src/strands/plugins/plugin.py | 103 ++++- tests/strands/plugins/test_hook_decorator.py | 232 ++++++++++ .../strands/plugins/test_plugin_base_class.py | 408 ++++++++++++++++++ tests/strands/plugins/test_plugins.py | 71 +-- 8 files changed, 992 insertions(+), 46 deletions(-) create mode 100644 src/strands/plugins/decorator.py create mode 100644 tests/strands/plugins/test_hook_decorator.py create mode 100644 tests/strands/plugins/test_plugin_base_class.py diff --git a/AGENTS.md b/AGENTS.md index 6a5765a94..a5b092ffe 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -127,7 +127,8 @@ strands-agents/ │ │ └── registry.py # Hook registration │ │ │ ├── plugins/ # Plugin system -│ │ ├── plugin.py # Plugin definition +│ │ ├── plugin.py # Plugin base class +│ │ ├── decorator.py # @hook decorator │ │ └── registry.py # PluginRegistry for tracking plugins │ │ │ ├── handlers/ # Event handlers diff --git a/src/strands/__init__.py b/src/strands/__init__.py index be939d5b1..2e187edd1 100644 --- a/src/strands/__init__.py +++ b/src/strands/__init__.py @@ -4,7 +4,7 @@ from .agent.agent import Agent from .agent.base import AgentBase from .event_loop._retry import ModelRetryStrategy -from .plugins import Plugin +from .plugins import Plugin, hook from .tools.decorator import tool from .types.tools import ToolContext @@ -12,6 +12,7 @@ "Agent", "AgentBase", "agent", + "hook", "models", "ModelRetryStrategy", "Plugin", diff --git a/src/strands/plugins/__init__.py b/src/strands/plugins/__init__.py index 9ec9c9357..dbcaeda57 100644 --- a/src/strands/plugins/__init__.py +++ b/src/strands/plugins/__init__.py @@ -1,25 +1,49 @@ """Plugin system for extending agent functionality. This module provides a composable mechanism for building objects that can -extend agent behavior through a standardized initialization pattern. +extend agent behavior through automatic hook and tool registration. -Example Usage: +Example Usage with Decorators (recommended): + ```python + from strands.plugins import Plugin, hook + from strands.hooks import BeforeModelCallEvent + + class LoggingPlugin(Plugin): + name = "logging" + + @hook + def on_model_call(self, event: BeforeModelCallEvent) -> None: + print(f"Model called for {event.agent.name}") + + @tool + def log_message(self, message: str) -> str: + '''Log a message.''' + print(message) + return "Logged" + ``` + +Example Usage with Manual Registration: ```python from strands.plugins import Plugin + from strands.hooks import BeforeModelCallEvent class LoggingPlugin(Plugin): name = "logging" def init_plugin(self, agent: Agent) -> None: - agent.add_hook(self.on_model_call, BeforeModelCallEvent) + super().init_plugin(agent) # Register decorated methods + # Add additional manual hooks + agent.hooks.add_callback(BeforeModelCallEvent, self.on_model_call) def on_model_call(self, event: BeforeModelCallEvent) -> None: print(f"Model called for {event.agent.name}") ``` """ +from .decorator import hook from .plugin import Plugin __all__ = [ "Plugin", + "hook", ] diff --git a/src/strands/plugins/decorator.py b/src/strands/plugins/decorator.py new file mode 100644 index 000000000..3652c9664 --- /dev/null +++ b/src/strands/plugins/decorator.py @@ -0,0 +1,188 @@ +"""Hook decorator for Plugin methods. + +This module provides the @hook decorator that marks methods as hook callbacks +for automatic registration when the plugin is attached to an agent. + +The @hook decorator performs several functions: + +1. Marks methods as hook callbacks for automatic discovery by Plugin base class +2. Infers event types from the callback's type hints (consistent with HookRegistry.add_callback) +3. Supports both @hook and @hook() syntax +4. Supports union types for multiple event types (e.g., BeforeModelCallEvent | AfterModelCallEvent) +5. Stores hook metadata on the decorated method for later discovery + +Example: + ```python + from strands.plugins import Plugin, hook + from strands.hooks import BeforeModelCallEvent, AfterModelCallEvent + + class MyPlugin(Plugin): + name = "my-plugin" + + @hook + def on_model_call(self, event: BeforeModelCallEvent): + print(event) + + @hook + def on_any_model_event(self, event: BeforeModelCallEvent | AfterModelCallEvent): + print(event) + ``` +""" + +import functools +import inspect +import logging +import types +from collections.abc import Callable +from typing import TypeVar, Union, cast, get_args, get_origin, get_type_hints, overload + +from ..hooks.registry import BaseHookEvent, HookCallback, TEvent + +logger = logging.getLogger(__name__) + +# Type for wrapped function +T = TypeVar("T", bound=Callable[..., object]) + + +def _infer_event_types(callback: HookCallback[TEvent]) -> list[type[TEvent]]: + """Infer the event type(s) from a callback's type hints. + + Supports both single types and union types (A | B or Union[A, B]). + + This logic is adapted from HookRegistry._infer_event_types to provide + consistent behavior for event type inference. + + Args: + callback: The callback function to inspect. + + Returns: + A list of event types inferred from the callback's first parameter type hint. + + Raises: + ValueError: If the event type cannot be inferred from the callback's type hints, + or if a union contains None or non-BaseHookEvent types. + """ + try: + hints = get_type_hints(callback) + except Exception as e: + logger.debug("callback=<%s>, error=<%s> | failed to get type hints", callback, e) + raise ValueError( + "failed to get type hints for callback | cannot infer event type, please provide event_type explicitly" + ) from e + + # Get the first parameter's type hint + sig = inspect.signature(callback) + params = list(sig.parameters.values()) + + if not params: + raise ValueError("callback has no parameters | cannot infer event type, please provide event_type explicitly") + + # For methods, skip 'self' parameter + first_param = params[0] + if first_param.name == "self" and len(params) > 1: + first_param = params[1] + + type_hint = hints.get(first_param.name) + + if type_hint is None: + raise ValueError( + f"parameter=<{first_param.name}> has no type hint | " + "cannot infer event type, please provide event_type explicitly" + ) + + # Check if it's a Union type (Union[A, B] or A | B) + origin = get_origin(type_hint) + if origin is Union or origin is types.UnionType: + event_types: list[type[TEvent]] = [] + for arg in get_args(type_hint): + if arg is type(None): + raise ValueError("None is not a valid event type in union") + if not (isinstance(arg, type) and issubclass(arg, BaseHookEvent)): + raise ValueError(f"Invalid type in union: {arg} | must be a subclass of BaseHookEvent") + event_types.append(cast(type[TEvent], arg)) + return event_types + + # Handle single type + if isinstance(type_hint, type) and issubclass(type_hint, BaseHookEvent): + return [cast(type[TEvent], type_hint)] + + raise ValueError( + f"parameter=<{first_param.name}>, type=<{type_hint}> | type hint must be a subclass of BaseHookEvent" + ) + + +# Handle @hook +@overload +def hook(__func: T) -> T: ... + + +# Handle @hook() +@overload +def hook() -> Callable[[T], T]: ... + + +def hook( # type: ignore[misc] + func: T | None = None, +) -> T | Callable[[T], T]: + """Decorator that marks a method as a hook callback for automatic registration. + + This decorator enables declarative hook registration in Plugin classes. When a + Plugin is attached to an agent, methods marked with @hook are automatically + discovered and registered with the agent's hook registry. + + The event type is inferred from the callback's type hint on the first parameter + (after 'self' for instance methods). Union types are supported for registering + a single callback for multiple event types. + + The decorator can be used in two ways: + - As a simple decorator: `@hook` + - With parentheses: `@hook()` + + Args: + func: The function to decorate. When used as a simple decorator, this is + the function being decorated. When used with parentheses, this will be None. + + Returns: + The decorated function with hook metadata attached. + + Raises: + ValueError: If the event type cannot be inferred from type hints, or if + the type hint is not a valid HookEvent subclass. + + Example: + ```python + class MyPlugin(Plugin): + name = "my-plugin" + + @hook + def on_model_call(self, event: BeforeModelCallEvent): + print(f"Model called: {event}") + + @hook + def on_any_event(self, event: BeforeModelCallEvent | AfterModelCallEvent): + print(f"Event: {type(event).__name__}") + ``` + """ + + def decorator(f: T) -> T: + # Infer event types from type hints + event_types = _infer_event_types(f) + + # Store hook metadata on the function + f._hook_event_types = event_types + + # Preserve original function metadata + @functools.wraps(f) + def wrapper(*args: object, **kwargs: object) -> object: + return f(*args, **kwargs) + + # Copy hook metadata to wrapper + wrapper._hook_event_types = event_types + + return cast(T, wrapper) + + # Handle both @hook and @hook() syntax + if func is None: + return decorator + + return decorator(func) diff --git a/src/strands/plugins/plugin.py b/src/strands/plugins/plugin.py index 80707616a..c9e2b514c 100644 --- a/src/strands/plugins/plugin.py +++ b/src/strands/plugins/plugin.py @@ -6,29 +6,58 @@ from abc import ABC, abstractmethod from collections.abc import Awaitable +import logging from typing import TYPE_CHECKING +from strands.tools.decorator import DecoratedFunctionTool + if TYPE_CHECKING: from ..agent import Agent +logger = logging.getLogger(__name__) class Plugin(ABC): """Base class for objects that extend agent functionality. Plugins provide a composable way to add behavior changes to agents. - They are initialized with an agent instance and can register hooks, - modify agent attributes, or perform other setup tasks. + They support automatic discovery and registration of methods decorated + with @hook and @tool decorators. Attributes: - name: A stable string identifier for the plugin + name: A stable string identifier for the plugin (must be provided by subclass) + _hooks: List of discovered @hook decorated methods (populated in __init__) + _tools: List of discovered @tool decorated methods (populated in __init__) + + Example using decorators (recommended): + ```python + from strands.plugins import Plugin, hook + from strands.hooks import BeforeModelCallEvent + + class MyPlugin(Plugin): + name = "my-plugin" - Example: + @hook + def on_model_call(self, event: BeforeModelCallEvent): + print(f"Model called: {event}") + + @tool + def my_tool(self, param: str) -> str: + '''A tool that does something.''' + return f"Result: {param}" + ``` + + Example with manual registration: ```python class MyPlugin(Plugin): name = "my-plugin" def init_plugin(self, agent: Agent) -> None: - agent.add_hook(self.on_model_call, BeforeModelCallEvent) + super().init_plugin(agent) # Register decorated methods + # Add additional manual hooks if needed + agent.hooks.add_callback(BeforeModelCallEvent, self.custom_hook) + + def custom_hook(self, event: BeforeModelCallEvent): + print(event) ``` """ @@ -38,11 +67,71 @@ def name(self) -> str: """A stable string identifier for the plugin.""" ... - @abstractmethod + def __init__(self) -> None: + """Initialize the plugin and discover decorated methods. + + Scans the class for methods decorated with @hook and @tool and stores + references for later registration when init_plugin is called. + """ + self._hooks: list[object] = [] + self._tools: list[DecoratedFunctionTool] = [] + self._discover_decorated_methods() + + def _discover_decorated_methods(self) -> None: + """Scan class for @hook and @tool decorated methods.""" + for name in dir(self): + # Skip private and dunder methods + if name.startswith("_"): + continue + + try: + attr = getattr(self, name) + except Exception: + # Skip attributes that can't be accessed + continue + + # Check for @hook decorated methods + if hasattr(attr, "_hook_event_types") and callable(attr): + self._hooks.append(attr) + logger.debug("plugin=<%s>, hook=<%s> | discovered hook method", self.name, name) + + # Check for @tool decorated methods (DecoratedFunctionTool instances) + if isinstance(attr, DecoratedFunctionTool): + self._tools.append(attr) + logger.debug("plugin=<%s>, tool=<%s> | discovered tool method", self.name, name) + + def init_plugin(self, agent: "Agent") -> None | Awaitable[None]: """Initialize the plugin with an agent instance. + Default implementation that registers all discovered @hook methods + with the agent's hook registry and adds all discovered @tool methods + to the agent's tools list. + + Subclasses can override this method and call super().init_plugin(agent) + to retain automatic registration while adding custom initialization logic. + Args: agent: The agent instance to extend. """ - ... + # Register discovered hooks with the agent's hook registry + for hook_callback in self._hooks: + event_types = getattr(hook_callback, "_hook_event_types", []) + for event_type in event_types: + agent.hooks.add_callback(event_type, hook_callback) + logger.debug( + "plugin=<%s>, hook=<%s>, event_type=<%s> | registered hook", + self.name, + getattr(hook_callback, "__name__", repr(hook_callback)), + event_type.__name__, + ) + + # Register discovered tools with the agent's tool registry + if self._tools: + agent.tool_registry.process_tools(self._tools) + for tool in self._tools: + logger.debug( + "plugin=<%s>, tool=<%s> | registered tool", + self.name, + tool.tool_name, + ) diff --git a/tests/strands/plugins/test_hook_decorator.py b/tests/strands/plugins/test_hook_decorator.py new file mode 100644 index 000000000..520040c9d --- /dev/null +++ b/tests/strands/plugins/test_hook_decorator.py @@ -0,0 +1,232 @@ +"""Tests for the @hook decorator.""" + +import unittest.mock + +import pytest + +from strands.hooks import ( + AfterInvocationEvent, + AfterModelCallEvent, + BeforeInvocationEvent, + BeforeModelCallEvent, +) +from strands.plugins.decorator import hook + + +class TestHookDecoratorBasic: + """Tests for basic @hook decorator functionality.""" + + def test_hook_decorator_marks_method(self): + """Test that @hook marks a method with hook metadata.""" + + @hook + def on_before_model_call(event: BeforeModelCallEvent): + pass + + assert hasattr(on_before_model_call, "_hook_event_types") + assert BeforeModelCallEvent in on_before_model_call._hook_event_types + + def test_hook_decorator_with_parentheses(self): + """Test that @hook() syntax also works.""" + + @hook() + def on_before_model_call(event: BeforeModelCallEvent): + pass + + assert hasattr(on_before_model_call, "_hook_event_types") + assert BeforeModelCallEvent in on_before_model_call._hook_event_types + + def test_hook_decorator_preserves_function_metadata(self): + """Test that @hook preserves the original function's metadata.""" + + @hook + def on_before_model_call(event: BeforeModelCallEvent): + """Docstring for the hook.""" + pass + + assert on_before_model_call.__name__ == "on_before_model_call" + assert on_before_model_call.__doc__ == "Docstring for the hook." + + def test_hook_decorator_function_still_callable(self): + """Test that decorated function can still be called normally.""" + call_count = 0 + + @hook + def on_before_model_call(event: BeforeModelCallEvent): + nonlocal call_count + call_count += 1 + + mock_event = unittest.mock.MagicMock(spec=BeforeModelCallEvent) + on_before_model_call(mock_event) + assert call_count == 1 + + +class TestHookDecoratorEventTypeInference: + """Tests for event type inference from type hints.""" + + def test_hook_infers_event_type_from_type_hint(self): + """Test that @hook infers event type from the first parameter's type hint.""" + + @hook + def handler(event: BeforeInvocationEvent): + pass + + assert BeforeInvocationEvent in handler._hook_event_types + + def test_hook_infers_different_event_types(self): + """Test that different event types are correctly inferred.""" + + @hook + def handler1(event: BeforeModelCallEvent): + pass + + @hook + def handler2(event: AfterModelCallEvent): + pass + + @hook + def handler3(event: AfterInvocationEvent): + pass + + assert BeforeModelCallEvent in handler1._hook_event_types + assert AfterModelCallEvent in handler2._hook_event_types + assert AfterInvocationEvent in handler3._hook_event_types + + +class TestHookDecoratorUnionTypes: + """Tests for union type support in @hook decorator.""" + + def test_hook_supports_union_types_with_pipe(self): + """Test that @hook supports union types using | syntax.""" + + @hook + def handler(event: BeforeModelCallEvent | AfterModelCallEvent): + pass + + assert BeforeModelCallEvent in handler._hook_event_types + assert AfterModelCallEvent in handler._hook_event_types + + def test_hook_supports_union_types_with_typing_union(self): + """Test that @hook supports Union[] syntax.""" + + @hook + def handler(event: BeforeModelCallEvent | AfterModelCallEvent): + pass + + assert BeforeModelCallEvent in handler._hook_event_types + assert AfterModelCallEvent in handler._hook_event_types + + def test_hook_supports_multiple_union_types(self): + """Test that @hook supports unions with more than two types.""" + + @hook + def handler(event: BeforeModelCallEvent | AfterModelCallEvent | BeforeInvocationEvent): + pass + + assert BeforeModelCallEvent in handler._hook_event_types + assert AfterModelCallEvent in handler._hook_event_types + assert BeforeInvocationEvent in handler._hook_event_types + + +class TestHookDecoratorErrorHandling: + """Tests for error handling in @hook decorator.""" + + def test_hook_raises_error_without_type_hint(self): + """Test that @hook raises error when no type hint is provided.""" + with pytest.raises(ValueError, match="cannot infer event type"): + + @hook + def handler(event): + pass + + def test_hook_raises_error_with_non_hook_event_type(self): + """Test that @hook raises error when type hint is not a HookEvent subclass.""" + with pytest.raises(ValueError, match="must be a subclass of BaseHookEvent"): + + @hook + def handler(event: str): + pass + + def test_hook_raises_error_with_none_in_union(self): + """Test that @hook raises error when union contains None.""" + with pytest.raises(ValueError, match="None is not a valid event type"): + + @hook + def handler(event: BeforeModelCallEvent | None): + pass + + +class TestHookDecoratorWithMethods: + """Tests for @hook decorator on class methods.""" + + def test_hook_works_on_instance_method(self): + """Test that @hook works correctly on instance methods.""" + + class MyClass: + @hook + def handler(self, event: BeforeModelCallEvent): + pass + + instance = MyClass() + assert hasattr(instance.handler, "_hook_event_types") + assert BeforeModelCallEvent in instance.handler._hook_event_types + + def test_hook_instance_method_is_callable(self): + """Test that decorated instance method can be called.""" + call_count = 0 + + class MyClass: + @hook + def handler(self, event: BeforeModelCallEvent): + nonlocal call_count + call_count += 1 + + instance = MyClass() + mock_event = unittest.mock.MagicMock(spec=BeforeModelCallEvent) + instance.handler(mock_event) + assert call_count == 1 + + def test_hook_method_accesses_self(self): + """Test that decorated method can access self.""" + + class MyClass: + def __init__(self): + self.events_received = [] + + @hook + def handler(self, event: BeforeModelCallEvent): + self.events_received.append(event) + + instance = MyClass() + mock_event = unittest.mock.MagicMock(spec=BeforeModelCallEvent) + instance.handler(mock_event) + assert len(instance.events_received) == 1 + assert instance.events_received[0] is mock_event + + +class TestHookDecoratorAsync: + """Tests for async functions with @hook decorator.""" + + def test_hook_works_on_async_function(self): + """Test that @hook works on async functions.""" + + @hook + async def handler(event: BeforeModelCallEvent): + pass + + assert hasattr(handler, "_hook_event_types") + assert BeforeModelCallEvent in handler._hook_event_types + + @pytest.mark.asyncio + async def test_hook_async_function_is_callable(self): + """Test that decorated async function can be awaited.""" + call_count = 0 + + @hook + async def handler(event: BeforeModelCallEvent): + nonlocal call_count + call_count += 1 + + mock_event = unittest.mock.MagicMock(spec=BeforeModelCallEvent) + await handler(mock_event) + assert call_count == 1 diff --git a/tests/strands/plugins/test_plugin_base_class.py b/tests/strands/plugins/test_plugin_base_class.py new file mode 100644 index 000000000..caa4f84b3 --- /dev/null +++ b/tests/strands/plugins/test_plugin_base_class.py @@ -0,0 +1,408 @@ +"""Tests for the Plugin base class with auto-discovery.""" + +import unittest.mock + +import pytest + +from strands.hooks import BeforeInvocationEvent, BeforeModelCallEvent, HookRegistry +from strands.plugins import Plugin, hook +from strands.tools.decorator import tool + + +class TestPluginBaseClass: + """Tests for Plugin base class basics.""" + + def test_plugin_is_class_not_protocol(self): + """Test that Plugin is now a class, not a Protocol.""" + + class MyPlugin(Plugin): + name = "my-plugin" + + plugin = MyPlugin() + assert isinstance(plugin, Plugin) + + def test_plugin_requires_name_attribute(self): + """Test that Plugin subclass must have name attribute.""" + + class MyPlugin(Plugin): + name = "my-plugin" + + plugin = MyPlugin() + assert plugin.name == "my-plugin" + + def test_plugin_name_as_property(self): + """Test that Plugin name can be a property.""" + + class MyPlugin(Plugin): + @property + def name(self) -> str: + return "property-plugin" + + plugin = MyPlugin() + assert plugin.name == "property-plugin" + + +class TestPluginAutoDiscovery: + """Tests for automatic discovery of decorated methods.""" + + def test_plugin_discovers_hook_decorated_methods(self): + """Test that Plugin.__init__ discovers @hook decorated methods.""" + + class MyPlugin(Plugin): + name = "my-plugin" + + @hook + def on_before_model(self, event: BeforeModelCallEvent): + pass + + plugin = MyPlugin() + assert len(plugin._hooks) == 1 + assert plugin._hooks[0].__name__ == "on_before_model" + + def test_plugin_discovers_multiple_hooks(self): + """Test that Plugin discovers multiple @hook decorated methods.""" + + class MyPlugin(Plugin): + name = "my-plugin" + + @hook + def hook1(self, event: BeforeModelCallEvent): + pass + + @hook + def hook2(self, event: BeforeInvocationEvent): + pass + + plugin = MyPlugin() + assert len(plugin._hooks) == 2 + hook_names = {h.__name__ for h in plugin._hooks} + assert "hook1" in hook_names + assert "hook2" in hook_names + + def test_plugin_discovers_tool_decorated_methods(self): + """Test that Plugin.__init__ discovers @tool decorated methods.""" + + class MyPlugin(Plugin): + name = "my-plugin" + + @tool + def my_tool(self, param: str) -> str: + """A test tool.""" + return param + + plugin = MyPlugin() + assert len(plugin._tools) == 1 + assert plugin._tools[0].tool_name == "my_tool" + + def test_plugin_discovers_both_hooks_and_tools(self): + """Test that Plugin discovers both @hook and @tool decorated methods.""" + + class MyPlugin(Plugin): + name = "my-plugin" + + @hook + def my_hook(self, event: BeforeModelCallEvent): + pass + + @tool + def my_tool(self, param: str) -> str: + """A test tool.""" + return param + + plugin = MyPlugin() + assert len(plugin._hooks) == 1 + assert len(plugin._tools) == 1 + + def test_plugin_ignores_non_decorated_methods(self): + """Test that Plugin doesn't discover non-decorated methods.""" + + class MyPlugin(Plugin): + name = "my-plugin" + + def regular_method(self): + pass + + @hook + def decorated_hook(self, event: BeforeModelCallEvent): + pass + + plugin = MyPlugin() + assert len(plugin._hooks) == 1 + assert plugin._hooks[0].__name__ == "decorated_hook" + + +class TestPluginInitPlugin: + """Tests for Plugin.init_plugin() auto-registration.""" + + def test_init_plugin_registers_hooks_with_agent(self): + """Test that init_plugin registers discovered hooks with agent.""" + + class MyPlugin(Plugin): + name = "my-plugin" + + @hook + def on_before_model(self, event: BeforeModelCallEvent): + pass + + plugin = MyPlugin() + mock_agent = unittest.mock.MagicMock() + mock_agent.hooks = HookRegistry() + + plugin.init_plugin(mock_agent) + + # Verify hook was registered + assert len(mock_agent.hooks._registered_callbacks.get(BeforeModelCallEvent, [])) == 1 + + def test_init_plugin_registers_tools_with_agent(self): + """Test that init_plugin adds discovered tools to agent's tools.""" + + class MyPlugin(Plugin): + name = "my-plugin" + + @tool + def my_tool(self, param: str) -> str: + """A test tool.""" + return param + + plugin = MyPlugin() + mock_agent = unittest.mock.MagicMock() + mock_agent.hooks = HookRegistry() + mock_agent.tool_registry = unittest.mock.MagicMock() + + plugin.init_plugin(mock_agent) + + # Verify tool was added to agent + mock_agent.tool_registry.process_tools.assert_called_once() + + def test_init_plugin_registers_both_hooks_and_tools(self): + """Test that init_plugin registers both hooks and tools.""" + + class MyPlugin(Plugin): + name = "my-plugin" + + @hook + def my_hook(self, event: BeforeModelCallEvent): + pass + + @tool + def my_tool(self, param: str) -> str: + """A test tool.""" + return param + + plugin = MyPlugin() + mock_agent = unittest.mock.MagicMock() + mock_agent.hooks = HookRegistry() + mock_agent.tool_registry = unittest.mock.MagicMock() + + plugin.init_plugin(mock_agent) + + # Verify both registered + assert len(mock_agent.hooks._registered_callbacks.get(BeforeModelCallEvent, [])) == 1 + mock_agent.tool_registry.process_tools.assert_called_once() + + +class TestPluginHookWithUnionTypes: + """Tests for Plugin hooks with union types.""" + + def test_init_plugin_registers_hook_for_union_types(self): + """Test that hooks with union types are registered for all event types.""" + + class MyPlugin(Plugin): + name = "my-plugin" + + @hook + def on_model_events(self, event: BeforeModelCallEvent | BeforeInvocationEvent): + pass + + plugin = MyPlugin() + mock_agent = unittest.mock.MagicMock() + mock_agent.hooks = HookRegistry() + + plugin.init_plugin(mock_agent) + + # Verify hook was registered for both event types + assert len(mock_agent.hooks._registered_callbacks.get(BeforeModelCallEvent, [])) == 1 + assert len(mock_agent.hooks._registered_callbacks.get(BeforeInvocationEvent, [])) == 1 + + +class TestPluginMultipleAgents: + """Tests for plugin reuse with multiple agents.""" + + def test_plugin_can_be_attached_to_multiple_agents(self): + """Test that the same plugin instance can be used with multiple agents.""" + + class MyPlugin(Plugin): + name = "my-plugin" + + @hook + def on_before_model(self, event: BeforeModelCallEvent): + pass + + plugin = MyPlugin() + + mock_agent1 = unittest.mock.MagicMock() + mock_agent1.hooks = HookRegistry() + mock_agent2 = unittest.mock.MagicMock() + mock_agent2.hooks = HookRegistry() + + plugin.init_plugin(mock_agent1) + plugin.init_plugin(mock_agent2) + + # Verify both agents have the hook registered + assert len(mock_agent1.hooks._registered_callbacks.get(BeforeModelCallEvent, [])) == 1 + assert len(mock_agent2.hooks._registered_callbacks.get(BeforeModelCallEvent, [])) == 1 + + +class TestPluginSubclassOverride: + """Tests for subclass overriding init_plugin.""" + + def test_subclass_can_override_init_plugin(self): + """Test that subclass can override init_plugin and call super().""" + custom_init_called = False + + class MyPlugin(Plugin): + name = "my-plugin" + + @hook + def on_before_model(self, event: BeforeModelCallEvent): + pass + + def init_plugin(self, agent): + nonlocal custom_init_called + custom_init_called = True + super().init_plugin(agent) + + plugin = MyPlugin() + mock_agent = unittest.mock.MagicMock() + mock_agent.hooks = HookRegistry() + + plugin.init_plugin(mock_agent) + + assert custom_init_called + # Verify auto-registration still happened via super() + assert len(mock_agent.hooks._registered_callbacks.get(BeforeModelCallEvent, [])) == 1 + + def test_subclass_can_add_manual_hooks(self): + """Test that subclass can manually add hooks in addition to decorated ones.""" + manual_hook_added = False + + class MyPlugin(Plugin): + name = "my-plugin" + + @hook + def auto_hook(self, event: BeforeModelCallEvent): + pass + + def manual_hook(self, event: BeforeInvocationEvent): + pass + + def init_plugin(self, agent): + nonlocal manual_hook_added + super().init_plugin(agent) + # Add manual hook + agent.hooks.add_callback(BeforeInvocationEvent, self.manual_hook) + manual_hook_added = True + + plugin = MyPlugin() + mock_agent = unittest.mock.MagicMock() + mock_agent.hooks = HookRegistry() + + plugin.init_plugin(mock_agent) + + assert manual_hook_added + # Verify both hooks registered + assert len(mock_agent.hooks._registered_callbacks.get(BeforeModelCallEvent, [])) == 1 + assert len(mock_agent.hooks._registered_callbacks.get(BeforeInvocationEvent, [])) == 1 + + +class TestPluginAsyncInitPlugin: + """Tests for async init_plugin support.""" + + @pytest.mark.asyncio + async def test_async_init_plugin_supported(self): + """Test that async init_plugin is supported.""" + + class MyPlugin(Plugin): + name = "my-plugin" + + @hook + def on_before_model(self, event: BeforeModelCallEvent): + pass + + async def init_plugin(self, agent): + # Just call super synchronously - async is for custom logic + super().init_plugin(agent) + + plugin = MyPlugin() + mock_agent = unittest.mock.MagicMock() + mock_agent.hooks = HookRegistry() + + await plugin.init_plugin(mock_agent) + + # Verify hook was registered + assert len(mock_agent.hooks._registered_callbacks.get(BeforeModelCallEvent, [])) == 1 + + +class TestPluginBoundMethods: + """Tests for bound method registration.""" + + def test_hooks_are_bound_to_instance(self): + """Test that registered hooks are bound to the plugin instance.""" + + class MyPlugin(Plugin): + name = "my-plugin" + + def __init__(self): + super().__init__() + self.events_received = [] + + @hook + def on_before_model(self, event: BeforeModelCallEvent): + self.events_received.append(event) + + plugin = MyPlugin() + mock_agent = unittest.mock.MagicMock() + mock_agent.hooks = HookRegistry() + + plugin.init_plugin(mock_agent) + + # Call the registered hook and verify it accesses the correct instance + mock_event = unittest.mock.MagicMock(spec=BeforeModelCallEvent) + callbacks = list(mock_agent.hooks._registered_callbacks.get(BeforeModelCallEvent, [])) + callbacks[0](mock_event) + + assert len(plugin.events_received) == 1 + assert plugin.events_received[0] is mock_event + + def test_tools_are_bound_to_instance(self): + """Test that registered tools are bound to the plugin instance.""" + + class MyPlugin(Plugin): + name = "my-plugin" + + def __init__(self): + super().__init__() + self.tool_called = False + + @tool + def my_tool(self, param: str) -> str: + """A test tool.""" + self.tool_called = True + return param + + plugin = MyPlugin() + mock_agent = unittest.mock.MagicMock() + mock_agent.hooks = HookRegistry() + mock_agent.tool_registry = unittest.mock.MagicMock() + + plugin.init_plugin(mock_agent) + + # Get the tool that was registered and call it + call_args = mock_agent.tool_registry.process_tools.call_args + registered_tools = call_args[0][0] + assert len(registered_tools) == 1 + + # Call the tool - it should be bound to the instance + result = registered_tools[0]("test") + assert plugin.tool_called + assert result == "test" diff --git a/tests/strands/plugins/test_plugins.py b/tests/strands/plugins/test_plugins.py index 7d0f49dc9..3df0da1cf 100644 --- a/tests/strands/plugins/test_plugins.py +++ b/tests/strands/plugins/test_plugins.py @@ -4,38 +4,39 @@ import pytest +from strands.hooks import HookRegistry from strands.plugins import Plugin from strands.plugins.registry import _PluginRegistry -# Plugin Tests +# Plugin Base Class Tests -def test_plugin_class_requires_inheritance(): - """Test that Plugin class requires inheritance.""" +def test_plugin_base_class_isinstance_check(): + """Test that Plugin subclass passes isinstance check.""" class MyPlugin(Plugin): name = "my-plugin" - def init_plugin(self, agent): - pass - plugin = MyPlugin() assert isinstance(plugin, Plugin) -def test_plugin_class_sync_implementation(): - """Test Plugin class works with synchronous init_plugin.""" +def test_plugin_base_class_sync_implementation(): + """Test Plugin base class works with synchronous init_plugin.""" class SyncPlugin(Plugin): name = "sync-plugin" def init_plugin(self, agent): + super().init_plugin(agent) agent.custom_attribute = "initialized by plugin" plugin = SyncPlugin() mock_agent = unittest.mock.Mock() + mock_agent.hooks = HookRegistry() + mock_agent.tool_registry = unittest.mock.MagicMock() - # Verify the plugin is an instance of Plugin + # Verify the plugin is an instance assert isinstance(plugin, Plugin) assert plugin.name == "sync-plugin" @@ -45,19 +46,22 @@ def init_plugin(self, agent): @pytest.mark.asyncio -async def test_plugin_class_async_implementation(): - """Test Plugin class works with asynchronous init_plugin.""" +async def test_plugin_base_class_async_implementation(): + """Test Plugin base class works with asynchronous init_plugin.""" class AsyncPlugin(Plugin): name = "async-plugin" async def init_plugin(self, agent): + super().init_plugin(agent) agent.custom_attribute = "initialized by async plugin" plugin = AsyncPlugin() mock_agent = unittest.mock.Mock() + mock_agent.hooks = HookRegistry() + mock_agent.tool_registry = unittest.mock.MagicMock() - # Verify the plugin is an instance of Plugin + # Verify the plugin is an instance assert isinstance(plugin, Plugin) assert plugin.name == "async-plugin" @@ -78,42 +82,37 @@ def init_plugin(self, agent): PluginWithoutName() -def test_plugin_class_requires_init_plugin_method(): - """Test that Plugin class requires an init_plugin method.""" +def test_plugin_base_class_requires_init_plugin_method(): + """Test that Plugin base class provides default init_plugin.""" - with pytest.raises(TypeError, match="Can't instantiate abstract class"): + class PluginWithoutOverride(Plugin): + name = "no-override-plugin" - class PluginWithoutInitPlugin(Plugin): - name = "incomplete-plugin" + plugin = PluginWithoutOverride() + # Plugin base class provides default init_plugin + assert hasattr(plugin, "init_plugin") + assert callable(plugin.init_plugin) - PluginWithoutInitPlugin() - -def test_plugin_class_with_class_attribute_name(): - """Test Plugin class works when name is a class attribute.""" +def test_plugin_base_class_with_class_attribute_name(): + """Test Plugin base class works when name is a class attribute.""" class PluginWithClassAttribute(Plugin): name: str = "class-attr-plugin" - def init_plugin(self, agent): - pass - plugin = PluginWithClassAttribute() assert isinstance(plugin, Plugin) assert plugin.name == "class-attr-plugin" -def test_plugin_class_with_property_name(): - """Test Plugin class works when name is a property.""" +def test_plugin_base_class_with_property_name(): + """Test Plugin base class works when name is a property.""" class PluginWithProperty(Plugin): @property - def name(self): + def name(self) -> str: return "property-plugin" - def init_plugin(self, agent): - pass - plugin = PluginWithProperty() assert isinstance(plugin, Plugin) assert plugin.name == "property-plugin" @@ -125,7 +124,10 @@ def init_plugin(self, agent): @pytest.fixture def mock_agent(): """Create a mock agent for testing.""" - return unittest.mock.Mock() + agent = unittest.mock.Mock() + agent.hooks = HookRegistry() + agent.tool_registry = unittest.mock.MagicMock() + return agent @pytest.fixture @@ -141,9 +143,11 @@ class TestPlugin(Plugin): name = "test-plugin" def __init__(self): + super().__init__() self.initialized = False def init_plugin(self, agent): + super().init_plugin(agent) self.initialized = True agent.plugin_initialized = True @@ -160,9 +164,6 @@ def test_plugin_registry_add_duplicate_raises_error(registry, mock_agent): class TestPlugin(Plugin): name = "test-plugin" - def init_plugin(self, agent): - pass - plugin1 = TestPlugin() plugin2 = TestPlugin() @@ -179,9 +180,11 @@ class AsyncPlugin(Plugin): name = "async-plugin" def __init__(self): + super().__init__() self.initialized = False async def init_plugin(self, agent): + super().init_plugin(agent) self.initialized = True agent.async_plugin_initialized = True From f2f74e48dd5f34d6f66237d5dc4213143b0ea04f Mon Sep 17 00:00:00 2001 From: Nicholas Clegg Date: Fri, 20 Feb 2026 12:47:02 -0500 Subject: [PATCH 05/29] Have decorator not wrap funciton, but just attach hook events --- src/strands/hooks/_type_inference.py | 80 ++++++++++++++++++++++++ src/strands/hooks/registry.py | 62 +------------------ src/strands/plugins/decorator.py | 93 ++-------------------------- src/strands/plugins/plugin.py | 4 +- 4 files changed, 91 insertions(+), 148 deletions(-) create mode 100644 src/strands/hooks/_type_inference.py diff --git a/src/strands/hooks/_type_inference.py b/src/strands/hooks/_type_inference.py new file mode 100644 index 000000000..0cfea01bb --- /dev/null +++ b/src/strands/hooks/_type_inference.py @@ -0,0 +1,80 @@ +"""Utility for inferring event types from callback type hints.""" + +import inspect +import logging +import types +from typing import TYPE_CHECKING, Union, cast, get_args, get_origin, get_type_hints + +if TYPE_CHECKING: + from .registry import HookCallback, TEvent + +logger = logging.getLogger(__name__) + + +def infer_event_types(callback: "HookCallback[TEvent]", skip_self: bool = False) -> "list[type[TEvent]]": + """Infer the event type(s) from a callback's type hints. + + Supports both single types and union types (A | B or Union[A, B]). + + Args: + callback: The callback function to inspect. + skip_self: If True, skip 'self' parameter when looking for event type hint. + Use True for instance methods, False for standalone functions. + + Returns: + A list of event types inferred from the callback's first parameter type hint. + + Raises: + ValueError: If the event type cannot be inferred from the callback's type hints, + or if a union contains None or non-BaseHookEvent types. + """ + # Import here to avoid circular dependency + from .registry import BaseHookEvent + + try: + hints = get_type_hints(callback) + except Exception as e: + logger.debug("callback=<%s>, error=<%s> | failed to get type hints", callback, e) + raise ValueError( + "failed to get type hints for callback | cannot infer event type, please provide event_type explicitly" + ) from e + + # Get the first parameter's type hint + sig = inspect.signature(callback) + params = list(sig.parameters.values()) + + if not params: + raise ValueError("callback has no parameters | cannot infer event type, please provide event_type explicitly") + + # For methods, skip 'self' parameter if requested + first_param = params[0] + if skip_self and first_param.name == "self" and len(params) > 1: + first_param = params[1] + + type_hint = hints.get(first_param.name) + + if type_hint is None: + raise ValueError( + f"parameter=<{first_param.name}> has no type hint | " + "cannot infer event type, please provide event_type explicitly" + ) + + # Check if it's a Union type (Union[A, B] or A | B) + origin = get_origin(type_hint) + if origin is Union or origin is types.UnionType: + event_types: list[type[TEvent]] = [] + for arg in get_args(type_hint): + if arg is type(None): + raise ValueError("None is not a valid event type in union") + if not (isinstance(arg, type) and issubclass(arg, BaseHookEvent)): + raise ValueError(f"Invalid type in union: {arg} | must be a subclass of BaseHookEvent") + event_types.append(cast("type[TEvent]", arg)) + return event_types + + # Handle single type + if isinstance(type_hint, type) and issubclass(type_hint, BaseHookEvent): + return [cast("type[TEvent]", type_hint)] + + raise ValueError( + f"parameter=<{first_param.name}>, type=<{type_hint}> | type hint must be a subclass of BaseHookEvent" + ) diff --git a/src/strands/hooks/registry.py b/src/strands/hooks/registry.py index 886ea5644..5096e255e 100644 --- a/src/strands/hooks/registry.py +++ b/src/strands/hooks/registry.py @@ -9,24 +9,12 @@ import inspect import logging -import types from collections.abc import Awaitable, Generator from dataclasses import dataclass -from typing import ( - TYPE_CHECKING, - Any, - Generic, - Protocol, - TypeVar, - Union, - cast, - get_args, - get_origin, - get_type_hints, - runtime_checkable, -) +from typing import TYPE_CHECKING, Any, Generic, Protocol, TypeVar, runtime_checkable from ..interrupt import Interrupt, InterruptException +from ._type_inference import infer_event_types if TYPE_CHECKING: from ..agent import Agent @@ -276,51 +264,7 @@ def _infer_event_types(self, callback: HookCallback[TEvent]) -> list[type[TEvent ValueError: If the event type cannot be inferred from the callback's type hints, or if a union contains None or non-BaseHookEvent types. """ - try: - hints = get_type_hints(callback) - except Exception as e: - logger.debug("callback=<%s>, error=<%s> | failed to get type hints", callback, e) - raise ValueError( - "failed to get type hints for callback | cannot infer event type, please provide event_type explicitly" - ) from e - - # Get the first parameter's type hint - sig = inspect.signature(callback) - params = list(sig.parameters.values()) - - if not params: - raise ValueError( - "callback has no parameters | cannot infer event type, please provide event_type explicitly" - ) - - first_param = params[0] - type_hint = hints.get(first_param.name) - - if type_hint is None: - raise ValueError( - f"parameter=<{first_param.name}> has no type hint | " - "cannot infer event type, please provide event_type explicitly" - ) - - # Check if it's a Union type (Union[A, B] or A | B) - origin = get_origin(type_hint) - if origin is Union or origin is types.UnionType: - event_types: list[type[TEvent]] = [] - for arg in get_args(type_hint): - if arg is type(None): - raise ValueError("None is not a valid event type in union") - if not (isinstance(arg, type) and issubclass(arg, BaseHookEvent)): - raise ValueError(f"Invalid type in union: {arg} | must be a subclass of BaseHookEvent") - event_types.append(cast(type[TEvent], arg)) - return event_types - - # Handle single type - if isinstance(type_hint, type) and issubclass(type_hint, BaseHookEvent): - return [cast(type[TEvent], type_hint)] - - raise ValueError( - f"parameter=<{first_param.name}>, type=<{type_hint}> | type hint must be a subclass of BaseHookEvent" - ) + return infer_event_types(callback, skip_self=False) def add_hook(self, hook: HookProvider) -> None: """Register all callbacks from a hook provider. diff --git a/src/strands/plugins/decorator.py b/src/strands/plugins/decorator.py index 3652c9664..79c768d85 100644 --- a/src/strands/plugins/decorator.py +++ b/src/strands/plugins/decorator.py @@ -29,88 +29,15 @@ def on_any_model_event(self, event: BeforeModelCallEvent | AfterModelCallEvent): ``` """ -import functools -import inspect -import logging -import types from collections.abc import Callable -from typing import TypeVar, Union, cast, get_args, get_origin, get_type_hints, overload +from typing import TypeVar, overload -from ..hooks.registry import BaseHookEvent, HookCallback, TEvent - -logger = logging.getLogger(__name__) +from ..hooks._type_inference import infer_event_types # Type for wrapped function T = TypeVar("T", bound=Callable[..., object]) -def _infer_event_types(callback: HookCallback[TEvent]) -> list[type[TEvent]]: - """Infer the event type(s) from a callback's type hints. - - Supports both single types and union types (A | B or Union[A, B]). - - This logic is adapted from HookRegistry._infer_event_types to provide - consistent behavior for event type inference. - - Args: - callback: The callback function to inspect. - - Returns: - A list of event types inferred from the callback's first parameter type hint. - - Raises: - ValueError: If the event type cannot be inferred from the callback's type hints, - or if a union contains None or non-BaseHookEvent types. - """ - try: - hints = get_type_hints(callback) - except Exception as e: - logger.debug("callback=<%s>, error=<%s> | failed to get type hints", callback, e) - raise ValueError( - "failed to get type hints for callback | cannot infer event type, please provide event_type explicitly" - ) from e - - # Get the first parameter's type hint - sig = inspect.signature(callback) - params = list(sig.parameters.values()) - - if not params: - raise ValueError("callback has no parameters | cannot infer event type, please provide event_type explicitly") - - # For methods, skip 'self' parameter - first_param = params[0] - if first_param.name == "self" and len(params) > 1: - first_param = params[1] - - type_hint = hints.get(first_param.name) - - if type_hint is None: - raise ValueError( - f"parameter=<{first_param.name}> has no type hint | " - "cannot infer event type, please provide event_type explicitly" - ) - - # Check if it's a Union type (Union[A, B] or A | B) - origin = get_origin(type_hint) - if origin is Union or origin is types.UnionType: - event_types: list[type[TEvent]] = [] - for arg in get_args(type_hint): - if arg is type(None): - raise ValueError("None is not a valid event type in union") - if not (isinstance(arg, type) and issubclass(arg, BaseHookEvent)): - raise ValueError(f"Invalid type in union: {arg} | must be a subclass of BaseHookEvent") - event_types.append(cast(type[TEvent], arg)) - return event_types - - # Handle single type - if isinstance(type_hint, type) and issubclass(type_hint, BaseHookEvent): - return [cast(type[TEvent], type_hint)] - - raise ValueError( - f"parameter=<{first_param.name}>, type=<{type_hint}> | type hint must be a subclass of BaseHookEvent" - ) - - # Handle @hook @overload def hook(__func: T) -> T: ... @@ -165,21 +92,13 @@ def on_any_event(self, event: BeforeModelCallEvent | AfterModelCallEvent): """ def decorator(f: T) -> T: - # Infer event types from type hints - event_types = _infer_event_types(f) + # Infer event types from type hints (skip 'self' for methods) + event_types = infer_event_types(f, skip_self=True) # Store hook metadata on the function - f._hook_event_types = event_types - - # Preserve original function metadata - @functools.wraps(f) - def wrapper(*args: object, **kwargs: object) -> object: - return f(*args, **kwargs) - - # Copy hook metadata to wrapper - wrapper._hook_event_types = event_types + f._hook_event_types = event_types # type: ignore[attr-defined] - return cast(T, wrapper) + return f # Handle both @hook and @hook() syntax if func is None: diff --git a/src/strands/plugins/plugin.py b/src/strands/plugins/plugin.py index c9e2b514c..422f7fb77 100644 --- a/src/strands/plugins/plugin.py +++ b/src/strands/plugins/plugin.py @@ -4,9 +4,9 @@ add behavior changes to agents through a standardized initialization pattern. """ +import logging from abc import ABC, abstractmethod from collections.abc import Awaitable -import logging from typing import TYPE_CHECKING from strands.tools.decorator import DecoratedFunctionTool @@ -16,6 +16,7 @@ logger = logging.getLogger(__name__) + class Plugin(ABC): """Base class for objects that extend agent functionality. @@ -100,7 +101,6 @@ def _discover_decorated_methods(self) -> None: self._tools.append(attr) logger.debug("plugin=<%s>, tool=<%s> | discovered tool method", self.name, name) - def init_plugin(self, agent: "Agent") -> None | Awaitable[None]: """Initialize the plugin with an agent instance. From 8fb72441f6a6c17efbe814172dceeb0ea42e2ca0 Mon Sep 17 00:00:00 2001 From: Nicholas Clegg Date: Fri, 20 Feb 2026 13:23:47 -0500 Subject: [PATCH 06/29] Update steering to use hook decorator --- src/strands/__init__.py | 3 +- .../experimental/steering/core/handler.py | 17 ++-- src/strands/hooks/_type_inference.py | 8 +- src/strands/hooks/registry.py | 19 +---- src/strands/plugins/decorator.py | 14 ++-- src/strands/plugins/plugin.py | 2 + .../steering/core/test_handler.py | 78 ++++++++++++------- 7 files changed, 76 insertions(+), 65 deletions(-) diff --git a/src/strands/__init__.py b/src/strands/__init__.py index 2e187edd1..be939d5b1 100644 --- a/src/strands/__init__.py +++ b/src/strands/__init__.py @@ -4,7 +4,7 @@ from .agent.agent import Agent from .agent.base import AgentBase from .event_loop._retry import ModelRetryStrategy -from .plugins import Plugin, hook +from .plugins import Plugin from .tools.decorator import tool from .types.tools import ToolContext @@ -12,7 +12,6 @@ "Agent", "AgentBase", "agent", - "hook", "models", "ModelRetryStrategy", "Plugin", diff --git a/src/strands/experimental/steering/core/handler.py b/src/strands/experimental/steering/core/handler.py index 3b869c0eb..807d16b8a 100644 --- a/src/strands/experimental/steering/core/handler.py +++ b/src/strands/experimental/steering/core/handler.py @@ -38,7 +38,7 @@ from typing import TYPE_CHECKING, Any from ....hooks.events import AfterModelCallEvent, BeforeToolCallEvent -from ....plugins.plugin import Plugin +from ....plugins import Plugin, hook from ....types.content import Message from ....types.streaming import StopReason from ....types.tools import ToolUse @@ -66,6 +66,7 @@ def __init__(self, context_providers: list[SteeringContextProvider] | None = Non Args: context_providers: List of context providers for context updates """ + super().__init__() self.steering_context = SteeringContext() self._context_callbacks = [] @@ -83,17 +84,14 @@ def init_plugin(self, agent: "Agent") -> None: Args: agent: The agent instance to attach steering to. """ + super().init_plugin(agent) + # Register context update callbacks for callback in self._context_callbacks: agent.add_hook(lambda event, callback=callback: callback(event, self.steering_context), callback.event_type) - # Register tool steering guidance - agent.add_hook(self._provide_tool_steering_guidance, BeforeToolCallEvent) - - # Register model steering guidance - agent.add_hook(self._provide_model_steering_guidance, AfterModelCallEvent) - - async def _provide_tool_steering_guidance(self, event: BeforeToolCallEvent) -> None: + @hook + async def provide_tool_steering_guidance(self, event: BeforeToolCallEvent) -> None: """Provide steering guidance for tool call.""" tool_name = event.tool_use["name"] logger.debug("tool_name=<%s> | providing tool steering guidance", tool_name) @@ -133,7 +131,8 @@ def _handle_tool_steering_action( else: raise ValueError(f"Unknown steering action type for tool call: {action}") - async def _provide_model_steering_guidance(self, event: AfterModelCallEvent) -> None: + @hook + async def provide_model_steering_guidance(self, event: AfterModelCallEvent) -> None: """Provide steering guidance for model response.""" logger.debug("providing model steering guidance") diff --git a/src/strands/hooks/_type_inference.py b/src/strands/hooks/_type_inference.py index 0cfea01bb..aba7d1164 100644 --- a/src/strands/hooks/_type_inference.py +++ b/src/strands/hooks/_type_inference.py @@ -11,15 +11,13 @@ logger = logging.getLogger(__name__) -def infer_event_types(callback: "HookCallback[TEvent]", skip_self: bool = False) -> "list[type[TEvent]]": +def infer_event_types(callback: "HookCallback[TEvent]") -> "list[type[TEvent]]": """Infer the event type(s) from a callback's type hints. Supports both single types and union types (A | B or Union[A, B]). Args: callback: The callback function to inspect. - skip_self: If True, skip 'self' parameter when looking for event type hint. - Use True for instance methods, False for standalone functions. Returns: A list of event types inferred from the callback's first parameter type hint. @@ -46,9 +44,9 @@ def infer_event_types(callback: "HookCallback[TEvent]", skip_self: bool = False) if not params: raise ValueError("callback has no parameters | cannot infer event type, please provide event_type explicitly") - # For methods, skip 'self' parameter if requested + # Skip 'self' parameter for methods first_param = params[0] - if skip_self and first_param.name == "self" and len(params) > 1: + if first_param.name == "self" and len(params) > 1: first_param = params[1] type_hint = hints.get(first_param.name) diff --git a/src/strands/hooks/registry.py b/src/strands/hooks/registry.py index 5096e255e..8b284b0c2 100644 --- a/src/strands/hooks/registry.py +++ b/src/strands/hooks/registry.py @@ -213,7 +213,7 @@ def multi_handler(event): resolved_event_types = self._validate_event_type_list(event_type) elif event_type is None: # Infer event type(s) from callback type hints - resolved_event_types = self._infer_event_types(callback) + resolved_event_types = infer_event_types(callback) else: # Single event type provided explicitly resolved_event_types = [event_type] @@ -249,23 +249,6 @@ def _validate_event_type_list(self, event_types: list[type[TEvent]]) -> list[typ validated.append(et) return validated - def _infer_event_types(self, callback: HookCallback[TEvent]) -> list[type[TEvent]]: - """Infer the event type(s) from a callback's type hints. - - Supports both single types and union types (A | B or Union[A, B]). - - Args: - callback: The callback function to inspect. - - Returns: - A list of event types inferred from the callback's first parameter type hint. - - Raises: - ValueError: If the event type cannot be inferred from the callback's type hints, - or if a union contains None or non-BaseHookEvent types. - """ - return infer_event_types(callback, skip_self=False) - def add_hook(self, hook: HookProvider) -> None: """Register all callbacks from a hook provider. diff --git a/src/strands/plugins/decorator.py b/src/strands/plugins/decorator.py index 79c768d85..1e7ea13e6 100644 --- a/src/strands/plugins/decorator.py +++ b/src/strands/plugins/decorator.py @@ -30,9 +30,10 @@ def on_any_model_event(self, event: BeforeModelCallEvent | AfterModelCallEvent): """ from collections.abc import Callable -from typing import TypeVar, overload +from typing import TYPE_CHECKING, TypeVar, overload -from ..hooks._type_inference import infer_event_types +if TYPE_CHECKING: + from ..hooks.registry import BaseHookEvent # Type for wrapped function T = TypeVar("T", bound=Callable[..., object]) @@ -48,7 +49,7 @@ def hook(__func: T) -> T: ... def hook() -> Callable[[T], T]: ... -def hook( # type: ignore[misc] +def hook( func: T | None = None, ) -> T | Callable[[T], T]: """Decorator that marks a method as a hook callback for automatic registration. @@ -92,8 +93,11 @@ def on_any_event(self, event: BeforeModelCallEvent | AfterModelCallEvent): """ def decorator(f: T) -> T: - # Infer event types from type hints (skip 'self' for methods) - event_types = infer_event_types(f, skip_self=True) + # Import here to avoid circular dependency at runtime + from ..hooks._type_inference import infer_event_types + + # Infer event types from type hints + event_types: list[type[BaseHookEvent]] = infer_event_types(f) # type: ignore[arg-type] # Store hook metadata on the function f._hook_event_types = event_types # type: ignore[attr-defined] diff --git a/src/strands/plugins/plugin.py b/src/strands/plugins/plugin.py index 422f7fb77..82513273c 100644 --- a/src/strands/plugins/plugin.py +++ b/src/strands/plugins/plugin.py @@ -135,3 +135,5 @@ def init_plugin(self, agent: "Agent") -> None | Awaitable[None]: self.name, tool.tool_name, ) + + return None diff --git a/tests/strands/experimental/steering/core/test_handler.py b/tests/strands/experimental/steering/core/test_handler.py index 447780939..08399139c 100644 --- a/tests/strands/experimental/steering/core/test_handler.py +++ b/tests/strands/experimental/steering/core/test_handler.py @@ -39,14 +39,18 @@ def test_steering_handler_is_plugin(): def test_init_plugin(): """Test init_plugin registers hooks on agent.""" + from strands.hooks import HookRegistry + handler = TestSteeringHandler() agent = Mock() + agent.hooks = HookRegistry() + agent.tool_registry = Mock() handler.init_plugin(agent) - # Verify hooks were registered (tool and model steering hooks) - assert agent.add_hook.call_count >= 2 - agent.add_hook.assert_any_call(handler._provide_tool_steering_guidance, BeforeToolCallEvent) + # Verify hooks were auto-registered via @hook decorator + assert len(agent.hooks._registered_callbacks.get(BeforeToolCallEvent, [])) >= 1 + assert len(agent.hooks._registered_callbacks.get(AfterModelCallEvent, [])) >= 1 def test_steering_context_initialization(): @@ -86,7 +90,7 @@ async def steer_before_tool(self, *, agent, tool_use, **kwargs): tool_use = {"name": "test_tool"} event = BeforeToolCallEvent(agent=agent, selected_tool=None, tool_use=tool_use, invocation_state={}) - await handler._provide_tool_steering_guidance(event) + await handler.provide_tool_steering_guidance(event) # Should not modify event for Proceed assert not event.cancel_tool @@ -105,7 +109,7 @@ async def steer_before_tool(self, *, agent, tool_use, **kwargs): tool_use = {"name": "test_tool"} event = BeforeToolCallEvent(agent=agent, selected_tool=None, tool_use=tool_use, invocation_state={}) - await handler._provide_tool_steering_guidance(event) + await handler.provide_tool_steering_guidance(event) # Should set cancel_tool with guidance message expected_message = "Tool call cancelled. Test guidance You MUST follow this guidance immediately." @@ -126,7 +130,7 @@ async def steer_before_tool(self, *, agent, tool_use, **kwargs): event.tool_use = tool_use event.interrupt = Mock(return_value=True) # Approved - await handler._provide_tool_steering_guidance(event) + await handler.provide_tool_steering_guidance(event) event.interrupt.assert_called_once() @@ -145,7 +149,7 @@ async def steer_before_tool(self, *, agent, tool_use, **kwargs): event.tool_use = tool_use event.interrupt = Mock(return_value=False) # Denied - await handler._provide_tool_steering_guidance(event) + await handler.provide_tool_steering_guidance(event) event.interrupt.assert_called_once() assert event.cancel_tool.startswith("Manual approval denied:") @@ -165,11 +169,12 @@ async def steer_before_tool(self, *, agent, tool_use, **kwargs): event = BeforeToolCallEvent(agent=agent, selected_tool=None, tool_use=tool_use, invocation_state={}) with pytest.raises(ValueError, match="Unknown steering action type"): - await handler._provide_tool_steering_guidance(event) + await handler.provide_tool_steering_guidance(event) def test_init_plugin_override(): """Test that init_plugin can be overridden.""" + from strands.hooks import HookRegistry class CustomHandler(SteeringHandler): async def steer_before_tool(self, *, agent, tool_use, **kwargs): @@ -181,11 +186,14 @@ def init_plugin(self, agent): handler = CustomHandler() agent = Mock() + agent.hooks = HookRegistry() + agent.tool_registry = Mock() handler.init_plugin(agent) - # Should not register any hooks - assert agent.add_hook.call_count == 0 + # Should not register any hooks since parent init_plugin wasn't called + assert len(agent.hooks._registered_callbacks.get(BeforeToolCallEvent, [])) == 0 + assert len(agent.hooks._registered_callbacks.get(AfterModelCallEvent, [])) == 0 # Integration tests with context providers @@ -219,20 +227,28 @@ async def steer_before_tool(self, *, agent, tool_use, **kwargs): def test_handler_registers_context_provider_hooks(): """Test that handler registers hooks from context callbacks.""" + from strands.hooks import HookRegistry + mock_callback = MockContextCallback() handler = TestSteeringHandlerWithProvider(context_callbacks=[mock_callback]) agent = Mock() + agent.hooks = HookRegistry() + agent.tool_registry = Mock() + agent.add_hook = Mock() handler.init_plugin(agent) - # Should register hooks for context callback and steering guidance - assert agent.add_hook.call_count >= 2 + # Should register 1 context callback via add_hook (steering hooks are auto-registered) + assert agent.add_hook.call_count >= 1 - # Check that BeforeToolCallEvent was registered + # Check that BeforeToolCallEvent was registered (either via add_hook or auto-registration) call_args = [call[0] for call in agent.add_hook.call_args_list] event_types = [args[1] for args in call_args] - assert BeforeToolCallEvent in event_types + # Context callback should be registered + assert ( + BeforeToolCallEvent in event_types or len(agent.hooks._registered_callbacks.get(BeforeToolCallEvent, [])) >= 1 + ) def test_context_callbacks_receive_steering_context(): @@ -265,17 +281,23 @@ def test_context_callbacks_receive_steering_context(): def test_multiple_context_callbacks_registered(): """Test that multiple context callbacks are registered.""" + from strands.hooks import HookRegistry + callback1 = MockContextCallback() callback2 = MockContextCallback() handler = TestSteeringHandlerWithProvider(context_callbacks=[callback1, callback2]) agent = Mock() + agent.hooks = HookRegistry() + agent.tool_registry = Mock() + agent.add_hook = Mock() handler.init_plugin(agent) - # Should register one callback for each context provider plus tool and model steering guidance - expected_calls = 2 + 2 # 2 callbacks + 2 for steering guidance (tool and model) - assert agent.add_hook.call_count >= expected_calls + # Should register 2 context callbacks via add_hook, plus auto-registered @hook methods + assert agent.add_hook.call_count == 2 # Only context callbacks use add_hook + assert len(agent.hooks._registered_callbacks.get(BeforeToolCallEvent, [])) >= 1 + assert len(agent.hooks._registered_callbacks.get(AfterModelCallEvent, [])) >= 1 def test_handler_initialization_with_callbacks(): @@ -310,7 +332,7 @@ async def steer_after_model(self, *, agent, message, stop_reason, **kwargs): event.stop_response = stop_response event.retry = False - await handler._provide_model_steering_guidance(event) + await handler.provide_model_steering_guidance(event) # Should not set retry for Proceed assert event.retry is False @@ -334,7 +356,7 @@ async def steer_after_model(self, *, agent, message, stop_reason, **kwargs): event.stop_response = stop_response event.retry = False - await handler._provide_model_steering_guidance(event) + await handler.provide_model_steering_guidance(event) # Should set retry flag assert event.retry is True @@ -362,7 +384,7 @@ async def steer_after_model(self, *, agent, message, stop_reason, **kwargs): event = Mock(spec=AfterModelCallEvent) event.stop_response = None - await handler._provide_model_steering_guidance(event) + await handler.provide_model_steering_guidance(event) # steer_after_model should not have been called assert handler.steer_called is False @@ -386,7 +408,7 @@ async def steer_after_model(self, *, agent, message, stop_reason, **kwargs): event.stop_response = stop_response with pytest.raises(ValueError, match="Unknown steering action type for model response"): - await handler._provide_model_steering_guidance(event) + await handler.provide_model_steering_guidance(event) @pytest.mark.asyncio @@ -407,7 +429,7 @@ async def steer_after_model(self, *, agent, message, stop_reason, **kwargs): event.stop_response = stop_response with pytest.raises(ValueError, match="Unknown steering action type for model response"): - await handler._provide_model_steering_guidance(event) + await handler.provide_model_steering_guidance(event) @pytest.mark.asyncio @@ -429,7 +451,7 @@ async def steer_after_model(self, *, agent, message, stop_reason, **kwargs): event.retry = False # Should not raise, just return early - await handler._provide_model_steering_guidance(event) + await handler.provide_model_steering_guidance(event) # retry should not be set since exception occurred assert event.retry is False @@ -449,7 +471,7 @@ async def steer_before_tool(self, *, agent, tool_use, **kwargs): event = BeforeToolCallEvent(agent=agent, selected_tool=None, tool_use=tool_use, invocation_state={}) # Should not raise, just return early - await handler._provide_tool_steering_guidance(event) + await handler.provide_tool_steering_guidance(event) # cancel_tool should not be set since exception occurred assert not event.cancel_tool @@ -487,10 +509,14 @@ async def test_default_steer_after_model_returns_proceed(): def test_init_plugin_registers_model_steering(): """Test that init_plugin registers model steering callback.""" + from strands.hooks import HookRegistry + handler = TestSteeringHandler() agent = Mock() + agent.hooks = HookRegistry() + agent.tool_registry = Mock() handler.init_plugin(agent) - # Verify model steering hook was registered - agent.add_hook.assert_any_call(handler._provide_model_steering_guidance, AfterModelCallEvent) + # Verify model steering hook was auto-registered via @hook decorator + assert len(agent.hooks._registered_callbacks.get(AfterModelCallEvent, [])) >= 1 From 2e8a2683b43fbf431cb4b95245101d10d51f2a88 Mon Sep 17 00:00:00 2001 From: Nicholas Clegg Date: Fri, 20 Feb 2026 14:15:08 -0500 Subject: [PATCH 07/29] Update typing --- src/strands/plugins/decorator.py | 94 +++++-------------- src/strands/plugins/plugin.py | 11 +-- .../steering/core/test_handler.py | 89 +++++++----------- .../strands/plugins/test_plugin_base_class.py | 37 ++++---- 4 files changed, 84 insertions(+), 147 deletions(-) diff --git a/src/strands/plugins/decorator.py b/src/strands/plugins/decorator.py index 1e7ea13e6..efa4c24be 100644 --- a/src/strands/plugins/decorator.py +++ b/src/strands/plugins/decorator.py @@ -1,111 +1,69 @@ """Hook decorator for Plugin methods. -This module provides the @hook decorator that marks methods as hook callbacks -for automatic registration when the plugin is attached to an agent. - -The @hook decorator performs several functions: - -1. Marks methods as hook callbacks for automatic discovery by Plugin base class -2. Infers event types from the callback's type hints (consistent with HookRegistry.add_callback) -3. Supports both @hook and @hook() syntax -4. Supports union types for multiple event types (e.g., BeforeModelCallEvent | AfterModelCallEvent) -5. Stores hook metadata on the decorated method for later discovery +Marks methods as hook callbacks for automatic registration when the plugin +is attached to an agent. Infers event types from type hints and supports +union types for multiple events. Example: ```python - from strands.plugins import Plugin, hook - from strands.hooks import BeforeModelCallEvent, AfterModelCallEvent - class MyPlugin(Plugin): - name = "my-plugin" - @hook def on_model_call(self, event: BeforeModelCallEvent): print(event) - - @hook - def on_any_model_event(self, event: BeforeModelCallEvent | AfterModelCallEvent): - print(event) ``` """ from collections.abc import Callable -from typing import TYPE_CHECKING, TypeVar, overload +from typing import Generic, cast, overload -if TYPE_CHECKING: - from ..hooks.registry import BaseHookEvent +from ..hooks._type_inference import infer_event_types +from ..hooks.registry import HookCallback, TEvent -# Type for wrapped function -T = TypeVar("T", bound=Callable[..., object]) + +class _WrappedHookCallable(HookCallback, Generic[TEvent]): + """Wrapped version of HookCallback that includes a `_hook_event_types` argument.""" + + _hook_event_types: list[TEvent] # Handle @hook @overload -def hook(__func: T) -> T: ... +def hook(__func: HookCallback) -> _WrappedHookCallable: ... # Handle @hook() @overload -def hook() -> Callable[[T], T]: ... +def hook() -> Callable[[HookCallback], _WrappedHookCallable]: ... def hook( - func: T | None = None, -) -> T | Callable[[T], T]: - """Decorator that marks a method as a hook callback for automatic registration. - - This decorator enables declarative hook registration in Plugin classes. When a - Plugin is attached to an agent, methods marked with @hook are automatically - discovered and registered with the agent's hook registry. + func: HookCallback | None = None, +) -> _WrappedHookCallable | Callable[[HookCallback], _WrappedHookCallable]: + """Mark a method as a hook callback for automatic registration. - The event type is inferred from the callback's type hint on the first parameter - (after 'self' for instance methods). Union types are supported for registering - a single callback for multiple event types. - - The decorator can be used in two ways: - - As a simple decorator: `@hook` - - With parentheses: `@hook()` + Infers event type from the callback's type hint. Supports union types + for multiple events. Can be used as @hook or @hook(). Args: - func: The function to decorate. When used as a simple decorator, this is - the function being decorated. When used with parentheses, this will be None. + func: The function to decorate. Returns: - The decorated function with hook metadata attached. + The decorated function with hook metadata. Raises: - ValueError: If the event type cannot be inferred from type hints, or if - the type hint is not a valid HookEvent subclass. - - Example: - ```python - class MyPlugin(Plugin): - name = "my-plugin" - - @hook - def on_model_call(self, event: BeforeModelCallEvent): - print(f"Model called: {event}") - - @hook - def on_any_event(self, event: BeforeModelCallEvent | AfterModelCallEvent): - print(f"Event: {type(event).__name__}") - ``` + ValueError: If event type cannot be inferred from type hints. """ - def decorator(f: T) -> T: - # Import here to avoid circular dependency at runtime - from ..hooks._type_inference import infer_event_types - + def decorator(f: HookCallback[TEvent]) -> _WrappedHookCallable[TEvent]: # Infer event types from type hints - event_types: list[type[BaseHookEvent]] = infer_event_types(f) # type: ignore[arg-type] + event_types: list[type[TEvent]] = infer_event_types(f) # Store hook metadata on the function - f._hook_event_types = event_types # type: ignore[attr-defined] + f_wrapped = cast(_WrappedHookCallable, f) + f_wrapped._hook_event_types = event_types - return f + return f_wrapped - # Handle both @hook and @hook() syntax if func is None: return decorator - return decorator(func) diff --git a/src/strands/plugins/plugin.py b/src/strands/plugins/plugin.py index 82513273c..ae19f8152 100644 --- a/src/strands/plugins/plugin.py +++ b/src/strands/plugins/plugin.py @@ -9,7 +9,8 @@ from collections.abc import Awaitable from typing import TYPE_CHECKING -from strands.tools.decorator import DecoratedFunctionTool +from ..tools.decorator import DecoratedFunctionTool +from .decorator import _WrappedHookCallable if TYPE_CHECKING: from ..agent import Agent @@ -74,17 +75,13 @@ def __init__(self) -> None: Scans the class for methods decorated with @hook and @tool and stores references for later registration when init_plugin is called. """ - self._hooks: list[object] = [] + self._hooks: list[_WrappedHookCallable] = [] self._tools: list[DecoratedFunctionTool] = [] self._discover_decorated_methods() def _discover_decorated_methods(self) -> None: """Scan class for @hook and @tool decorated methods.""" for name in dir(self): - # Skip private and dunder methods - if name.startswith("_"): - continue - try: attr = getattr(self, name) except Exception: @@ -118,7 +115,7 @@ def init_plugin(self, agent: "Agent") -> None | Awaitable[None]: for hook_callback in self._hooks: event_types = getattr(hook_callback, "_hook_event_types", []) for event_type in event_types: - agent.hooks.add_callback(event_type, hook_callback) + agent.add_hook(hook_callback, event_type) logger.debug( "plugin=<%s>, hook=<%s>, event_type=<%s> | registered hook", self.name, diff --git a/tests/strands/experimental/steering/core/test_handler.py b/tests/strands/experimental/steering/core/test_handler.py index 08399139c..506a218f7 100644 --- a/tests/strands/experimental/steering/core/test_handler.py +++ b/tests/strands/experimental/steering/core/test_handler.py @@ -1,5 +1,6 @@ """Unit tests for steering handler base class.""" +import inspect from unittest.mock import AsyncMock, Mock import pytest @@ -8,6 +9,7 @@ from strands.experimental.steering.core.context import SteeringContext, SteeringContextCallback, SteeringContextProvider from strands.experimental.steering.core.handler import SteeringHandler from strands.hooks.events import AfterModelCallEvent, BeforeToolCallEvent +from strands.hooks.registry import HookRegistry from strands.plugins import Plugin @@ -39,18 +41,14 @@ def test_steering_handler_is_plugin(): def test_init_plugin(): """Test init_plugin registers hooks on agent.""" - from strands.hooks import HookRegistry - handler = TestSteeringHandler() agent = Mock() - agent.hooks = HookRegistry() - agent.tool_registry = Mock() handler.init_plugin(agent) - # Verify hooks were auto-registered via @hook decorator - assert len(agent.hooks._registered_callbacks.get(BeforeToolCallEvent, [])) >= 1 - assert len(agent.hooks._registered_callbacks.get(AfterModelCallEvent, [])) >= 1 + # Verify hooks were registered (tool and model steering hooks) + assert agent.add_hook.call_count >= 2 + agent.add_hook.assert_any_call(handler.provide_tool_steering_guidance, BeforeToolCallEvent) def test_steering_context_initialization(): @@ -174,7 +172,6 @@ async def steer_before_tool(self, *, agent, tool_use, **kwargs): def test_init_plugin_override(): """Test that init_plugin can be overridden.""" - from strands.hooks import HookRegistry class CustomHandler(SteeringHandler): async def steer_before_tool(self, *, agent, tool_use, **kwargs): @@ -186,14 +183,11 @@ def init_plugin(self, agent): handler = CustomHandler() agent = Mock() - agent.hooks = HookRegistry() - agent.tool_registry = Mock() handler.init_plugin(agent) - # Should not register any hooks since parent init_plugin wasn't called - assert len(agent.hooks._registered_callbacks.get(BeforeToolCallEvent, [])) == 0 - assert len(agent.hooks._registered_callbacks.get(AfterModelCallEvent, [])) == 0 + # Should not register any hooks + assert agent.add_hook.call_count == 0 # Integration tests with context providers @@ -227,77 +221,68 @@ async def steer_before_tool(self, *, agent, tool_use, **kwargs): def test_handler_registers_context_provider_hooks(): """Test that handler registers hooks from context callbacks.""" - from strands.hooks import HookRegistry - mock_callback = MockContextCallback() handler = TestSteeringHandlerWithProvider(context_callbacks=[mock_callback]) agent = Mock() - agent.hooks = HookRegistry() - agent.tool_registry = Mock() - agent.add_hook = Mock() handler.init_plugin(agent) - # Should register 1 context callback via add_hook (steering hooks are auto-registered) - assert agent.add_hook.call_count >= 1 + # Should register hooks for context callback and steering guidance + assert agent.add_hook.call_count >= 2 - # Check that BeforeToolCallEvent was registered (either via add_hook or auto-registration) + # Check that BeforeToolCallEvent was registered call_args = [call[0] for call in agent.add_hook.call_args_list] event_types = [args[1] for args in call_args] # Context callback should be registered - assert ( - BeforeToolCallEvent in event_types or len(agent.hooks._registered_callbacks.get(BeforeToolCallEvent, [])) >= 1 - ) + assert BeforeToolCallEvent in event_types - -def test_context_callbacks_receive_steering_context(): +@pytest.mark.asyncio +async def test_context_callbacks_receive_steering_context(): """Test that context callbacks receive the handler's steering context.""" mock_callback = MockContextCallback() handler = TestSteeringHandlerWithProvider(context_callbacks=[mock_callback]) agent = Mock() - + agent.hooks = HookRegistry() + agent.tool_registry = Mock() + agent.add_hook = Mock(side_effect=lambda callback, event_type=None: agent.hooks.add_callback(event_type, callback)) handler.init_plugin(agent) - # Get the registered callback for BeforeToolCallEvent - before_callback = None - for call in agent.add_hook.call_args_list: - if call[0][1] == BeforeToolCallEvent: - before_callback = call[0][0] - break + # Get the registered callbacks for BeforeToolCallEvent + callbacks = agent.hooks._registered_callbacks.get(BeforeToolCallEvent, []) + assert len(callbacks) > 0 - assert before_callback is not None - - # Create a mock event and call the callback + # The context callback is wrapped in a lambda, so we just call all callbacks + # and check if the steering context was updated event = Mock(spec=BeforeToolCallEvent) event.tool_use = {"name": "test_tool", "input": {}} - # The callback should execute without error and update the steering context - before_callback(event) + # Call all callbacks, handling both sync and async + for cb in callbacks: + try: + result = await cb(event) + if inspect.iscoroutine(result): + await result + except Exception: + pass # Some callbacks might be async or have other requirements - # Verify the steering context was updated + # Verify the steering context was updated by at least one callback assert handler.steering_context.data.get("test_key") == "test_value" def test_multiple_context_callbacks_registered(): """Test that multiple context callbacks are registered.""" - from strands.hooks import HookRegistry - callback1 = MockContextCallback() callback2 = MockContextCallback() handler = TestSteeringHandlerWithProvider(context_callbacks=[callback1, callback2]) agent = Mock() - agent.hooks = HookRegistry() - agent.tool_registry = Mock() - agent.add_hook = Mock() handler.init_plugin(agent) - # Should register 2 context callbacks via add_hook, plus auto-registered @hook methods - assert agent.add_hook.call_count == 2 # Only context callbacks use add_hook - assert len(agent.hooks._registered_callbacks.get(BeforeToolCallEvent, [])) >= 1 - assert len(agent.hooks._registered_callbacks.get(AfterModelCallEvent, [])) >= 1 + # Should register one callback for each context provider plus tool and model steering guidance + expected_calls = 2 + 2 # 2 callbacks + 2 for steering guidance (tool and model) + assert agent.add_hook.call_count >= expected_calls def test_handler_initialization_with_callbacks(): @@ -509,14 +494,10 @@ async def test_default_steer_after_model_returns_proceed(): def test_init_plugin_registers_model_steering(): """Test that init_plugin registers model steering callback.""" - from strands.hooks import HookRegistry - handler = TestSteeringHandler() agent = Mock() - agent.hooks = HookRegistry() - agent.tool_registry = Mock() handler.init_plugin(agent) - # Verify model steering hook was auto-registered via @hook decorator - assert len(agent.hooks._registered_callbacks.get(AfterModelCallEvent, [])) >= 1 + # Verify model steering hook was registered + agent.add_hook.assert_any_call(handler.provide_model_steering_guidance, AfterModelCallEvent) diff --git a/tests/strands/plugins/test_plugin_base_class.py b/tests/strands/plugins/test_plugin_base_class.py index caa4f84b3..9da4cad9d 100644 --- a/tests/strands/plugins/test_plugin_base_class.py +++ b/tests/strands/plugins/test_plugin_base_class.py @@ -9,6 +9,16 @@ from strands.tools.decorator import tool +def _configure_mock_agent_with_hooks(): + """Helper to create a mock agent with working add_hook.""" + mock_agent = unittest.mock.MagicMock() + mock_agent.hooks = HookRegistry() + mock_agent.add_hook.side_effect = lambda callback, event_type=None: mock_agent.hooks.add_callback( + event_type, callback + ) + return mock_agent + + class TestPluginBaseClass: """Tests for Plugin base class basics.""" @@ -145,8 +155,7 @@ def on_before_model(self, event: BeforeModelCallEvent): pass plugin = MyPlugin() - mock_agent = unittest.mock.MagicMock() - mock_agent.hooks = HookRegistry() + mock_agent = _configure_mock_agent_with_hooks() plugin.init_plugin(mock_agent) @@ -190,8 +199,7 @@ def my_tool(self, param: str) -> str: return param plugin = MyPlugin() - mock_agent = unittest.mock.MagicMock() - mock_agent.hooks = HookRegistry() + mock_agent = _configure_mock_agent_with_hooks() mock_agent.tool_registry = unittest.mock.MagicMock() plugin.init_plugin(mock_agent) @@ -215,8 +223,7 @@ def on_model_events(self, event: BeforeModelCallEvent | BeforeInvocationEvent): pass plugin = MyPlugin() - mock_agent = unittest.mock.MagicMock() - mock_agent.hooks = HookRegistry() + mock_agent = _configure_mock_agent_with_hooks() plugin.init_plugin(mock_agent) @@ -240,10 +247,8 @@ def on_before_model(self, event: BeforeModelCallEvent): plugin = MyPlugin() - mock_agent1 = unittest.mock.MagicMock() - mock_agent1.hooks = HookRegistry() - mock_agent2 = unittest.mock.MagicMock() - mock_agent2.hooks = HookRegistry() + mock_agent1 = _configure_mock_agent_with_hooks() + mock_agent2 = _configure_mock_agent_with_hooks() plugin.init_plugin(mock_agent1) plugin.init_plugin(mock_agent2) @@ -273,8 +278,7 @@ def init_plugin(self, agent): super().init_plugin(agent) plugin = MyPlugin() - mock_agent = unittest.mock.MagicMock() - mock_agent.hooks = HookRegistry() + mock_agent = _configure_mock_agent_with_hooks() plugin.init_plugin(mock_agent) @@ -304,8 +308,7 @@ def init_plugin(self, agent): manual_hook_added = True plugin = MyPlugin() - mock_agent = unittest.mock.MagicMock() - mock_agent.hooks = HookRegistry() + mock_agent = _configure_mock_agent_with_hooks() plugin.init_plugin(mock_agent) @@ -334,8 +337,7 @@ async def init_plugin(self, agent): super().init_plugin(agent) plugin = MyPlugin() - mock_agent = unittest.mock.MagicMock() - mock_agent.hooks = HookRegistry() + mock_agent = _configure_mock_agent_with_hooks() await plugin.init_plugin(mock_agent) @@ -361,8 +363,7 @@ def on_before_model(self, event: BeforeModelCallEvent): self.events_received.append(event) plugin = MyPlugin() - mock_agent = unittest.mock.MagicMock() - mock_agent.hooks = HookRegistry() + mock_agent = _configure_mock_agent_with_hooks() plugin.init_plugin(mock_agent) From 6cf7bacf9740dbb23bb9673d0de056cd56456dce Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Fri, 20 Feb 2026 15:07:10 -0500 Subject: [PATCH 08/29] fix: add skills to agents.md --- AGENTS.md | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/AGENTS.md b/AGENTS.md index 6cd2155c1..241730f0d 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -128,7 +128,11 @@ strands-agents/ │ │ │ ├── plugins/ # Plugin system │ │ ├── plugin.py # Plugin Protocol definition -│ │ └── registry.py # PluginRegistry for tracking plugins +│ │ ├── registry.py # PluginRegistry for tracking plugins +│ │ └── skills/ # AgentSkills.io integration +│ │ ├── skill.py # Skill dataclass +│ │ ├── skills_plugin.py # SkillsPlugin implementation +│ │ └── loader.py # Skill loading/parsing from SKILL.md │ │ │ ├── handlers/ # Event handlers │ │ └── callback_handler.py # Callback handling From c92f8e0ee2032e1b2c4ecd2cca3c4f3732654ff7 Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Fri, 20 Feb 2026 15:17:08 -0500 Subject: [PATCH 09/29] feat(skills): simplify skills tool API and require pyyaml dependency - Simplify skills tool to single activate action (remove deactivate/action param) - Capture original system prompt once instead of save/restore pattern - Remove AfterInvocationEvent hook (no longer needed) - Replace optional pyyaml with required dependency - Remove _parse_yaml_simple fallback parser - Export Skill and SkillsPlugin from strands top-level --- pyproject.toml | 1 + src/strands/__init__.py | 4 +- src/strands/plugins/__init__.py | 3 +- src/strands/plugins/skills/loader.py | 64 +------------ src/strands/plugins/skills/skills_plugin.py | 89 ++++++------------- tests/strands/hooks/test_registry.py | 1 + tests/strands/plugins/skills/test_loader.py | 31 +------ .../plugins/skills/test_skills_plugin.py | 62 ++++--------- 8 files changed, 57 insertions(+), 198 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index b53194486..e07f3bac4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,6 +35,7 @@ dependencies = [ "mcp>=1.23.0,<2.0.0", "pydantic>=2.4.0,<3.0.0", "typing-extensions>=4.13.2,<5.0.0", + "pyyaml>=6.0.0,<7.0.0", "watchdog>=6.0.0,<7.0.0", "opentelemetry-api>=1.30.0,<2.0.0", "opentelemetry-sdk>=1.30.0,<2.0.0", diff --git a/src/strands/__init__.py b/src/strands/__init__.py index fc8237df8..2034c8692 100644 --- a/src/strands/__init__.py +++ b/src/strands/__init__.py @@ -4,8 +4,7 @@ from .agent.agent import Agent from .agent.base import AgentBase from .event_loop._retry import ModelRetryStrategy -from .plugins import Plugin -from .plugins.skills import Skill +from .plugins import Plugin, Skill, SkillsPlugin from .tools.decorator import tool from .types.tools import ToolContext @@ -17,6 +16,7 @@ "ModelRetryStrategy", "Plugin", "Skill", + "SkillsPlugin", "tool", "ToolContext", "types", diff --git a/src/strands/plugins/__init__.py b/src/strands/plugins/__init__.py index 51e014177..3e66fe56b 100644 --- a/src/strands/plugins/__init__.py +++ b/src/strands/plugins/__init__.py @@ -19,9 +19,10 @@ def on_model_call(self, event: BeforeModelCallEvent) -> None: """ from .plugin import Plugin -from .skills import SkillsPlugin +from .skills import Skill, SkillsPlugin __all__ = [ "Plugin", + "Skill", "SkillsPlugin", ] diff --git a/src/strands/plugins/skills/loader.py b/src/strands/plugins/skills/loader.py index fa05a4df2..543b16880 100644 --- a/src/strands/plugins/skills/loader.py +++ b/src/strands/plugins/skills/loader.py @@ -12,6 +12,8 @@ from pathlib import Path from typing import Any +import yaml + from .skill import Skill logger = logging.getLogger(__name__) @@ -45,72 +47,14 @@ def _find_skill_md(skill_dir: Path) -> Path: def _parse_yaml(yaml_text: str) -> dict[str, Any]: """Parse YAML text into a dictionary. - Uses PyYAML if available, otherwise falls back to simple key-value parsing - that handles the basic SKILL.md frontmatter format. - - Args: - yaml_text: YAML-formatted text to parse. - - Returns: - Dictionary of parsed key-value pairs. - """ - try: - import yaml - - result = yaml.safe_load(yaml_text) - return result if isinstance(result, dict) else {} - except ImportError: - logger.debug("PyYAML not available, using simple frontmatter parser") - return _parse_yaml_simple(yaml_text) - - -def _parse_yaml_simple(yaml_text: str) -> dict[str, Any]: - """Simple YAML parser for skill frontmatter. - - Handles basic key-value pairs and single-level nested mappings. This parser - is intentionally limited to the subset of YAML used in SKILL.md frontmatter. - Args: yaml_text: YAML-formatted text to parse. Returns: Dictionary of parsed key-value pairs. """ - result: dict[str, Any] = {} - current_key: str | None = None - current_nested: dict[str, str] | None = None - - for line in yaml_text.split("\n"): - if not line.strip() or line.strip().startswith("#"): - continue - - indent = len(line) - len(line.lstrip()) - - if indent == 0 and ":" in line: - # Save previous nested mapping if any - if current_key is not None and current_nested is not None: - result[current_key] = current_nested - current_nested = None - - key, _, value = line.partition(":") - key = key.strip() - value = value.strip() - current_key = key - - if value: - result[key] = value - else: - current_nested = {} - - elif indent > 0 and current_nested is not None and ":" in line.strip(): - nested_key, _, nested_value = line.strip().partition(":") - current_nested[nested_key.strip()] = nested_value.strip() - - # Save final nested mapping - if current_key is not None and current_nested is not None: - result[current_key] = current_nested - - return result + result = yaml.safe_load(yaml_text) + return result if isinstance(result, dict) else {} def _parse_frontmatter(content: str) -> tuple[dict[str, Any], str]: diff --git a/src/strands/plugins/skills/skills_plugin.py b/src/strands/plugins/skills/skills_plugin.py index 028f9dc58..0c6d53755 100644 --- a/src/strands/plugins/skills/skills_plugin.py +++ b/src/strands/plugins/skills/skills_plugin.py @@ -2,7 +2,7 @@ This module provides the SkillsPlugin class that extends the Plugin base class to add AgentSkills.io skill support. The plugin registers a tool for activating -and deactivating skills, and injects skill metadata into the system prompt. +skills, and injects skill metadata into the system prompt. """ from __future__ import annotations @@ -11,7 +11,7 @@ from pathlib import Path from typing import TYPE_CHECKING, Any -from ...hooks.events import AfterInvocationEvent, BeforeInvocationEvent +from ...hooks.events import BeforeInvocationEvent from ...hooks.registry import HookRegistry from ...plugins.plugin import Plugin from ...tools.decorator import tool @@ -20,7 +20,6 @@ if TYPE_CHECKING: from ...agent.agent import Agent - from ...types.content import SystemContentBlock logger = logging.getLogger(__name__) @@ -28,52 +27,39 @@ def _make_skills_tool(plugin: SkillsPlugin) -> Any: - """Create the skills tool that allows the agent to activate and deactivate skills. + """Create the skills tool that allows the agent to activate skills. Args: plugin: The SkillsPlugin instance that manages skill state. Returns: - A decorated tool function for skill activation and deactivation. + A decorated tool function for skill activation. """ @tool - def skills(action: str, skill_name: str = "") -> str: - """Activate or deactivate a skill to load its full instructions. + def skills(skill_name: str) -> str: + """Activate a skill to load its full instructions. Use this tool to load the complete instructions for a skill listed in - the available_skills section of your system prompt. + the available_skills section of your system prompt. Activating a new + skill replaces the previously active one. Args: - action: The action to perform. Use "activate" to load a skill's full instructions, - or "deactivate" to unload the currently active skill. - skill_name: Name of the skill to activate. Required for "activate" action. + skill_name: Name of the skill to activate. """ - if action == "activate": - if not skill_name: - return "Error: skill_name is required for activate action." + if not skill_name: + return "Error: skill_name is required." - found = plugin._find_skill(skill_name) - if found is None: - available = ", ".join(s.name for s in plugin._skills) - return f"Skill '{skill_name}' not found. Available skills: {available}" + found = plugin._find_skill(skill_name) + if found is None: + available = ", ".join(s.name for s in plugin._skills) + return f"Skill '{skill_name}' not found. Available skills: {available}" - plugin._active_skill = found - plugin._persist_state() + plugin._active_skill = found + plugin._persist_state() - logger.debug("skill_name=<%s> | skill activated", skill_name) - return found.instructions or f"Skill '{skill_name}' activated (no instructions available)." - - elif action == "deactivate": - deactivated_name = plugin._active_skill.name if plugin._active_skill else skill_name - plugin._active_skill = None - plugin._persist_state() - - logger.debug("skill_name=<%s> | skill deactivated", deactivated_name) - return f"Skill '{deactivated_name}' deactivated." - - else: - return f"Unknown action: '{action}'. Use 'activate' or 'deactivate'." + logger.debug("skill_name=<%s> | skill activated", skill_name) + return found.instructions or f"Skill '{skill_name}' activated (no instructions available)." return skills @@ -83,9 +69,9 @@ class SkillsPlugin(Plugin): The SkillsPlugin extends the Plugin base class and provides: - 1. A ``skills`` tool that allows the agent to activate/deactivate skills on demand + 1. A ``skills`` tool that allows the agent to activate skills on demand 2. System prompt injection of available skill metadata before each invocation - 3. Single active skill management (activating a new skill deactivates the previous one) + 3. Single active skill management (activating a new skill replaces the previous one) 4. Session persistence of active skill state via ``agent.state`` Skills can be provided as filesystem paths (to individual skill directories or @@ -125,8 +111,7 @@ def __init__(self, skills: list[str | Path | Skill]) -> None: self._skills: list[Skill] = self._resolve_skills(skills) self._active_skill: Skill | None = None self._agent: Agent | None = None - self._saved_system_prompt: str | None = None - self._saved_system_prompt_content: list[SystemContentBlock] | None = None + self._original_system_prompt: str | None = None def init_plugin(self, agent: Agent) -> None: """Initialize the plugin with an agent instance. @@ -153,7 +138,6 @@ def register_hooks(self, registry: HookRegistry, **kwargs: Any) -> None: **kwargs: Additional keyword arguments for future extensibility. """ registry.add_callback(BeforeInvocationEvent, self._on_before_invocation) - registry.add_callback(AfterInvocationEvent, self._on_after_invocation) @property def skills(self) -> list[Skill]: @@ -189,44 +173,27 @@ def active_skill(self) -> Skill | None: def _on_before_invocation(self, event: BeforeInvocationEvent) -> None: """Inject skill metadata into the system prompt before each invocation. - Saves the current system prompt and appends an XML block listing - all available skills so the model knows what it can activate. + Captures the original system prompt on first call, then rebuilds the + prompt with the skills XML block on each invocation. Args: event: The before-invocation event containing the agent reference. """ agent = event.agent - # Save original system prompt for restoration after invocation - self._saved_system_prompt = agent._system_prompt - self._saved_system_prompt_content = agent._system_prompt_content + # Capture the original system prompt on first invocation + if self._original_system_prompt is None: + self._original_system_prompt = agent._system_prompt or "" if not self._skills: return skills_xml = self._generate_skills_xml() - current: str = agent._system_prompt or "" - new_prompt = f"{current}\n\n{skills_xml}" if current else skills_xml + new_prompt = f"{self._original_system_prompt}\n\n{skills_xml}" if self._original_system_prompt else skills_xml - # Directly set both representations to avoid re-parsing through the setter - # and to preserve cache control blocks in the original content agent._system_prompt = new_prompt agent._system_prompt_content = [{"text": new_prompt}] - def _on_after_invocation(self, event: AfterInvocationEvent) -> None: - """Restore the original system prompt after invocation completes. - - Args: - event: The after-invocation event containing the agent reference. - """ - agent = event.agent - - # Restore original system prompt directly to preserve content block types - agent._system_prompt = self._saved_system_prompt - agent._system_prompt_content = self._saved_system_prompt_content - self._saved_system_prompt = None - self._saved_system_prompt_content = None - def _generate_skills_xml(self) -> str: """Generate the XML block listing available skills for the system prompt. diff --git a/tests/strands/hooks/test_registry.py b/tests/strands/hooks/test_registry.py index 79829b92b..5b0f3c574 100644 --- a/tests/strands/hooks/test_registry.py +++ b/tests/strands/hooks/test_registry.py @@ -164,6 +164,7 @@ def callback(event: BeforeInvocationEvent) -> None: assert BeforeInvocationEvent in registry._registered_callbacks assert callback in registry._registered_callbacks[BeforeInvocationEvent] + # ========== Tests for union type support ========== diff --git a/tests/strands/plugins/skills/test_loader.py b/tests/strands/plugins/skills/test_loader.py index 875ebf204..497390487 100644 --- a/tests/strands/plugins/skills/test_loader.py +++ b/tests/strands/plugins/skills/test_loader.py @@ -7,7 +7,6 @@ from strands.plugins.skills.loader import ( _find_skill_md, _parse_frontmatter, - _parse_yaml_simple, _validate_skill_name, load_skill, load_skills, @@ -27,7 +26,7 @@ def test_finds_lowercase_skill_md(self, tmp_path): """Test finding skill.md (lowercase).""" (tmp_path / "skill.md").write_text("test") result = _find_skill_md(tmp_path) - assert result.name == "skill.md" + assert result.name.lower() == "skill.md" def test_prefers_uppercase(self, tmp_path): """Test that SKILL.md is preferred over skill.md.""" @@ -42,34 +41,6 @@ def test_raises_when_not_found(self, tmp_path): _find_skill_md(tmp_path) -class TestParseYamlSimple: - """Tests for _parse_yaml_simple.""" - - def test_simple_key_values(self): - """Test parsing simple key-value pairs.""" - text = "name: my-skill\ndescription: A test skill\nlicense: Apache-2.0" - result = _parse_yaml_simple(text) - assert result == {"name": "my-skill", "description": "A test skill", "license": "Apache-2.0"} - - def test_nested_mapping(self): - """Test parsing a nested mapping.""" - text = "name: my-skill\nmetadata:\n author: test-org\n version: 1.0" - result = _parse_yaml_simple(text) - assert result["name"] == "my-skill" - assert result["metadata"] == {"author": "test-org", "version": "1.0"} - - def test_skips_comments_and_empty_lines(self): - """Test that comments and empty lines are skipped.""" - text = "# comment\nname: my-skill\n\ndescription: test\n" - result = _parse_yaml_simple(text) - assert result == {"name": "my-skill", "description": "test"} - - def test_empty_input(self): - """Test parsing empty input.""" - result = _parse_yaml_simple("") - assert result == {} - - class TestParseFrontmatter: """Tests for _parse_frontmatter.""" diff --git a/tests/strands/plugins/skills/test_skills_plugin.py b/tests/strands/plugins/skills/test_skills_plugin.py index 6e983add4..a48be4fb1 100644 --- a/tests/strands/plugins/skills/test_skills_plugin.py +++ b/tests/strands/plugins/skills/test_skills_plugin.py @@ -3,7 +3,7 @@ from pathlib import Path from unittest.mock import MagicMock -from strands.hooks.events import AfterInvocationEvent, BeforeInvocationEvent +from strands.hooks.events import BeforeInvocationEvent from strands.hooks.registry import HookRegistry from strands.plugins.skills.skill import Skill from strands.plugins.skills.skills_plugin import SkillsPlugin, _make_skills_tool @@ -187,7 +187,7 @@ def test_activate_skill(self): plugin._agent = _mock_agent() skills_tool = _make_skills_tool(plugin) - result = skills_tool(action="activate", skill_name="test-skill") + result = skills_tool(skill_name="test-skill") assert result == "Full instructions here." assert plugin.active_skill is not None @@ -200,7 +200,7 @@ def test_activate_nonexistent_skill(self): plugin._agent = _mock_agent() skills_tool = _make_skills_tool(plugin) - result = skills_tool(action="activate", skill_name="nonexistent") + result = skills_tool(skill_name="nonexistent") assert "not found" in result assert "test-skill" in result @@ -213,10 +213,10 @@ def test_activate_replaces_previous(self): plugin._agent = _mock_agent() skills_tool = _make_skills_tool(plugin) - skills_tool(action="activate", skill_name="skill-a") + skills_tool(skill_name="skill-a") assert plugin.active_skill.name == "skill-a" - skills_tool(action="activate", skill_name="skill-b") + skills_tool(skill_name="skill-b") assert plugin.active_skill.name == "skill-b" def test_activate_without_name(self): @@ -225,33 +225,10 @@ def test_activate_without_name(self): plugin._agent = _mock_agent() skills_tool = _make_skills_tool(plugin) - result = skills_tool(action="activate", skill_name="") + result = skills_tool(skill_name="") assert "required" in result.lower() - def test_deactivate_skill(self): - """Test deactivating a skill.""" - skill = _make_skill() - plugin = SkillsPlugin(skills=[skill]) - plugin._agent = _mock_agent() - plugin._active_skill = skill - - skills_tool = _make_skills_tool(plugin) - result = skills_tool(action="deactivate", skill_name="test-skill") - - assert "deactivated" in result.lower() - assert plugin.active_skill is None - - def test_unknown_action(self): - """Test unknown action returns error message.""" - plugin = SkillsPlugin(skills=[_make_skill()]) - plugin._agent = _mock_agent() - - skills_tool = _make_skills_tool(plugin) - result = skills_tool(action="unknown") - - assert "Unknown action" in result - def test_activate_persists_state(self): """Test that activating a skill persists state.""" plugin = SkillsPlugin(skills=[_make_skill()]) @@ -259,7 +236,7 @@ def test_activate_persists_state(self): plugin._agent = agent skills_tool = _make_skills_tool(plugin) - skills_tool(action="activate", skill_name="test-skill") + skills_tool(skill_name="test-skill") agent.state.set.assert_called() @@ -293,24 +270,21 @@ def test_before_invocation_preserves_existing_prompt(self): assert agent._system_prompt.startswith("Original prompt.") assert "" in agent._system_prompt - def test_after_invocation_restores_prompt(self): - """Test that after_invocation restores the original system prompt.""" + def test_repeated_invocations_do_not_accumulate(self): + """Test that repeated invocations rebuild from original prompt.""" plugin = SkillsPlugin(skills=[_make_skill()]) agent = _mock_agent() - original_prompt = "Original prompt." - original_content = [{"text": "Original prompt."}] - agent._system_prompt = original_prompt - agent._system_prompt_content = original_content + agent._system_prompt = "Original prompt." + agent._system_prompt_content = [{"text": "Original prompt."}] - # Simulate before/after cycle - before_event = BeforeInvocationEvent(agent=agent) - plugin._on_before_invocation(before_event) - assert agent._system_prompt != original_prompt + event = BeforeInvocationEvent(agent=agent) + plugin._on_before_invocation(event) + first_prompt = agent._system_prompt - after_event = AfterInvocationEvent(agent=agent) - plugin._on_after_invocation(after_event) - assert agent._system_prompt == original_prompt - assert agent._system_prompt_content == original_content + plugin._on_before_invocation(event) + second_prompt = agent._system_prompt + + assert first_prompt == second_prompt def test_no_skills_skips_injection(self): """Test that injection is skipped when no skills are available.""" From 899fe9c2d8991e7cb89253eb5ec17abe71df6a93 Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Fri, 20 Feb 2026 15:22:31 -0500 Subject: [PATCH 10/29] refactor(skills): use @hook and @tool decorators from Plugin base class Replace manual hook registration and standalone tool factory with declarative @hook and @tool decorators on SkillsPlugin methods. - Remove _make_skills_tool() standalone function - Convert skills() to @tool decorated instance method - Convert _on_before_invocation() to @hook decorated method - Remove register_hooks() (old HookProvider pattern) - Rename skills property to available_skills (avoids collision) - Delegate hook/tool registration to super().init_plugin() - Update tests to match new API surface --- src/strands/plugins/skills/skills_plugin.py | 130 +++++++----------- .../steering/core/test_handler.py | 1 + .../plugins/skills/test_skills_plugin.py | 78 ++++------- 3 files changed, 84 insertions(+), 125 deletions(-) diff --git a/src/strands/plugins/skills/skills_plugin.py b/src/strands/plugins/skills/skills_plugin.py index 0c6d53755..4344fa852 100644 --- a/src/strands/plugins/skills/skills_plugin.py +++ b/src/strands/plugins/skills/skills_plugin.py @@ -12,8 +12,7 @@ from typing import TYPE_CHECKING, Any from ...hooks.events import BeforeInvocationEvent -from ...hooks.registry import HookRegistry -from ...plugins.plugin import Plugin +from ...plugins import Plugin, hook from ...tools.decorator import tool from .loader import load_skill, load_skills from .skill import Skill @@ -26,44 +25,6 @@ _STATE_KEY = "skills_plugin" -def _make_skills_tool(plugin: SkillsPlugin) -> Any: - """Create the skills tool that allows the agent to activate skills. - - Args: - plugin: The SkillsPlugin instance that manages skill state. - - Returns: - A decorated tool function for skill activation. - """ - - @tool - def skills(skill_name: str) -> str: - """Activate a skill to load its full instructions. - - Use this tool to load the complete instructions for a skill listed in - the available_skills section of your system prompt. Activating a new - skill replaces the previously active one. - - Args: - skill_name: Name of the skill to activate. - """ - if not skill_name: - return "Error: skill_name is required." - - found = plugin._find_skill(skill_name) - if found is None: - available = ", ".join(s.name for s in plugin._skills) - return f"Skill '{skill_name}' not found. Available skills: {available}" - - plugin._active_skill = found - plugin._persist_state() - - logger.debug("skill_name=<%s> | skill activated", skill_name) - return found.instructions or f"Skill '{skill_name}' activated (no instructions available)." - - return skills - - class SkillsPlugin(Plugin): """Plugin that integrates AgentSkills.io skills into a Strands agent. @@ -112,35 +73,74 @@ def __init__(self, skills: list[str | Path | Skill]) -> None: self._active_skill: Skill | None = None self._agent: Agent | None = None self._original_system_prompt: str | None = None + super().__init__() def init_plugin(self, agent: Agent) -> None: """Initialize the plugin with an agent instance. - Registers the skills tool and hooks with the agent. + Registers the skills tool and hooks with the agent, then restores + any persisted state from a previous session. Args: agent: The agent instance to extend with skills support. """ self._agent = agent + super().init_plugin(agent) + self._restore_state() + logger.debug("skill_count=<%d> | skills plugin initialized", len(self._skills)) - agent.tool_registry.process_tools([_make_skills_tool(self)]) - agent.hooks.add_hook(self) + @tool + def skills(self, skill_name: str) -> str: + """Activate a skill to load its full instructions. - self._restore_state() + Use this tool to load the complete instructions for a skill listed in + the available_skills section of your system prompt. Activating a new + skill replaces the previously active one. - logger.debug("skill_count=<%d> | skills plugin initialized", len(self._skills)) + Args: + skill_name: Name of the skill to activate. + """ + if not skill_name: + return "Error: skill_name is required." + + found = self._find_skill(skill_name) + if found is None: + available = ", ".join(s.name for s in self._skills) + return f"Skill '{skill_name}' not found. Available skills: {available}" + + self._active_skill = found + self._persist_state() + + logger.debug("skill_name=<%s> | skill activated", skill_name) + return found.instructions or f"Skill '{skill_name}' activated (no instructions available)." + + @hook + def _on_before_invocation(self, event: BeforeInvocationEvent) -> None: + """Inject skill metadata into the system prompt before each invocation. - def register_hooks(self, registry: HookRegistry, **kwargs: Any) -> None: - """Register hook callbacks with the agent's hook registry. + Captures the original system prompt on first call, then rebuilds the + prompt with the skills XML block on each invocation. Args: - registry: The hook registry to register callbacks with. - **kwargs: Additional keyword arguments for future extensibility. + event: The before-invocation event containing the agent reference. """ - registry.add_callback(BeforeInvocationEvent, self._on_before_invocation) + agent = event.agent + + # Capture the original system prompt on first invocation + if self._original_system_prompt is None: + self._original_system_prompt = agent._system_prompt or "" + + if not self._skills: + return + + skills_xml = self._generate_skills_xml() + new_prompt = f"{self._original_system_prompt}\n\n{skills_xml}" if self._original_system_prompt else skills_xml + + agent._system_prompt = new_prompt + agent._system_prompt_content = [{"text": new_prompt}] @property - def skills(self) -> list[Skill]: + def available_skills(self) -> list[Skill]: """Get the list of available skills. Returns: @@ -148,8 +148,8 @@ def skills(self) -> list[Skill]: """ return list(self._skills) - @skills.setter - def skills(self, value: list[str | Path | Skill]) -> None: + @available_skills.setter + def available_skills(self, value: list[str | Path | Skill]) -> None: """Set the available skills, resolving paths as needed. Deactivates any currently active skill when skills are changed. @@ -170,30 +170,6 @@ def active_skill(self) -> Skill | None: """ return self._active_skill - def _on_before_invocation(self, event: BeforeInvocationEvent) -> None: - """Inject skill metadata into the system prompt before each invocation. - - Captures the original system prompt on first call, then rebuilds the - prompt with the skills XML block on each invocation. - - Args: - event: The before-invocation event containing the agent reference. - """ - agent = event.agent - - # Capture the original system prompt on first invocation - if self._original_system_prompt is None: - self._original_system_prompt = agent._system_prompt or "" - - if not self._skills: - return - - skills_xml = self._generate_skills_xml() - new_prompt = f"{self._original_system_prompt}\n\n{skills_xml}" if self._original_system_prompt else skills_xml - - agent._system_prompt = new_prompt - agent._system_prompt_content = [{"text": new_prompt}] - def _generate_skills_xml(self) -> str: """Generate the XML block listing available skills for the system prompt. diff --git a/tests/strands/experimental/steering/core/test_handler.py b/tests/strands/experimental/steering/core/test_handler.py index 506a218f7..c947eb9c5 100644 --- a/tests/strands/experimental/steering/core/test_handler.py +++ b/tests/strands/experimental/steering/core/test_handler.py @@ -237,6 +237,7 @@ def test_handler_registers_context_provider_hooks(): # Context callback should be registered assert BeforeToolCallEvent in event_types + @pytest.mark.asyncio async def test_context_callbacks_receive_steering_context(): """Test that context callbacks receive the handler's steering context.""" diff --git a/tests/strands/plugins/skills/test_skills_plugin.py b/tests/strands/plugins/skills/test_skills_plugin.py index a48be4fb1..b94a1a50c 100644 --- a/tests/strands/plugins/skills/test_skills_plugin.py +++ b/tests/strands/plugins/skills/test_skills_plugin.py @@ -6,7 +6,7 @@ from strands.hooks.events import BeforeInvocationEvent from strands.hooks.registry import HookRegistry from strands.plugins.skills.skill import Skill -from strands.plugins.skills.skills_plugin import SkillsPlugin, _make_skills_tool +from strands.plugins.skills.skills_plugin import SkillsPlugin def _make_skill(name: str = "test-skill", description: str = "A test skill", instructions: str = "Do the thing."): @@ -29,6 +29,9 @@ def _mock_agent(): agent._system_prompt = "You are an agent." agent._system_prompt_content = [{"text": "You are an agent."}] agent.hooks = HookRegistry() + agent.add_hook = MagicMock( + side_effect=lambda callback, event_type=None: agent.hooks.add_callback(event_type, callback) + ) agent.tool_registry = MagicMock() agent.tool_registry.process_tools = MagicMock(return_value=["skills"]) agent.state = MagicMock() @@ -45,16 +48,16 @@ def test_init_with_skill_instances(self): skill = _make_skill() plugin = SkillsPlugin(skills=[skill]) - assert len(plugin.skills) == 1 - assert plugin.skills[0].name == "test-skill" + assert len(plugin.available_skills) == 1 + assert plugin.available_skills[0].name == "test-skill" def test_init_with_filesystem_paths(self, tmp_path): """Test initialization with filesystem paths.""" _make_skill_dir(tmp_path, "fs-skill") plugin = SkillsPlugin(skills=[str(tmp_path / "fs-skill")]) - assert len(plugin.skills) == 1 - assert plugin.skills[0].name == "fs-skill" + assert len(plugin.available_skills) == 1 + assert plugin.available_skills[0].name == "fs-skill" def test_init_with_parent_directory(self, tmp_path): """Test initialization with a parent directory containing skills.""" @@ -62,7 +65,7 @@ def test_init_with_parent_directory(self, tmp_path): _make_skill_dir(tmp_path, "skill-b") plugin = SkillsPlugin(skills=[tmp_path]) - assert len(plugin.skills) == 2 + assert len(plugin.available_skills) == 2 def test_init_with_mixed_sources(self, tmp_path): """Test initialization with mixed skill sources.""" @@ -70,19 +73,19 @@ def test_init_with_mixed_sources(self, tmp_path): direct_skill = _make_skill(name="direct-skill", description="Direct") plugin = SkillsPlugin(skills=[str(tmp_path / "fs-skill"), direct_skill]) - assert len(plugin.skills) == 2 - names = {s.name for s in plugin.skills} + assert len(plugin.available_skills) == 2 + names = {s.name for s in plugin.available_skills} assert names == {"fs-skill", "direct-skill"} def test_init_skips_nonexistent_paths(self, tmp_path): """Test that nonexistent paths are skipped gracefully.""" plugin = SkillsPlugin(skills=[str(tmp_path / "nonexistent")]) - assert len(plugin.skills) == 0 + assert len(plugin.available_skills) == 0 def test_init_empty_skills(self): """Test initialization with empty skills list.""" plugin = SkillsPlugin(skills=[]) - assert plugin.skills == [] + assert plugin.available_skills == [] assert plugin.active_skill is None def test_name_attribute(self): @@ -102,8 +105,6 @@ def test_registers_tool(self): plugin.init_plugin(agent) agent.tool_registry.process_tools.assert_called_once() - args = agent.tool_registry.process_tools.call_args[0][0] - assert len(args) == 1 def test_registers_hooks(self): """Test that init_plugin registers hook callbacks.""" @@ -112,7 +113,6 @@ def test_registers_hooks(self): plugin.init_plugin(agent) - # Verify hooks were registered by checking the registry has callbacks assert agent.hooks.has_callbacks() def test_stores_agent_reference(self): @@ -140,34 +140,34 @@ def test_restores_state(self): class TestSkillsPluginProperties: """Tests for SkillsPlugin properties.""" - def test_skills_getter_returns_copy(self): - """Test that the skills getter returns a copy of the list.""" + def test_available_skills_getter_returns_copy(self): + """Test that the available_skills getter returns a copy of the list.""" skill = _make_skill() plugin = SkillsPlugin(skills=[skill]) - skills_list = plugin.skills + skills_list = plugin.available_skills skills_list.append(_make_skill(name="another-skill", description="Another")) - assert len(plugin.skills) == 1 + assert len(plugin.available_skills) == 1 - def test_skills_setter(self): + def test_available_skills_setter(self): """Test setting skills via the property setter.""" plugin = SkillsPlugin(skills=[_make_skill()]) plugin._agent = _mock_agent() new_skill = _make_skill(name="new-skill", description="New") - plugin.skills = [new_skill] + plugin.available_skills = [new_skill] - assert len(plugin.skills) == 1 - assert plugin.skills[0].name == "new-skill" + assert len(plugin.available_skills) == 1 + assert plugin.available_skills[0].name == "new-skill" - def test_skills_setter_deactivates_current(self): + def test_available_skills_setter_deactivates_current(self): """Test that setting skills deactivates the current active skill.""" plugin = SkillsPlugin(skills=[_make_skill()]) plugin._agent = _mock_agent() plugin._active_skill = _make_skill() - plugin.skills = [_make_skill(name="new-skill", description="New")] + plugin.available_skills = [_make_skill(name="new-skill", description="New")] assert plugin.active_skill is None @@ -178,7 +178,7 @@ def test_active_skill_initially_none(self): class TestSkillsTool: - """Tests for the skills tool function.""" + """Tests for the skills tool method.""" def test_activate_skill(self): """Test activating a skill returns its instructions.""" @@ -186,8 +186,7 @@ def test_activate_skill(self): plugin = SkillsPlugin(skills=[skill]) plugin._agent = _mock_agent() - skills_tool = _make_skills_tool(plugin) - result = skills_tool(skill_name="test-skill") + result = plugin.skills(skill_name="test-skill") assert result == "Full instructions here." assert plugin.active_skill is not None @@ -199,8 +198,7 @@ def test_activate_nonexistent_skill(self): plugin = SkillsPlugin(skills=[skill]) plugin._agent = _mock_agent() - skills_tool = _make_skills_tool(plugin) - result = skills_tool(skill_name="nonexistent") + result = plugin.skills(skill_name="nonexistent") assert "not found" in result assert "test-skill" in result @@ -212,11 +210,10 @@ def test_activate_replaces_previous(self): plugin = SkillsPlugin(skills=[skill1, skill2]) plugin._agent = _mock_agent() - skills_tool = _make_skills_tool(plugin) - skills_tool(skill_name="skill-a") + plugin.skills(skill_name="skill-a") assert plugin.active_skill.name == "skill-a" - skills_tool(skill_name="skill-b") + plugin.skills(skill_name="skill-b") assert plugin.active_skill.name == "skill-b" def test_activate_without_name(self): @@ -224,8 +221,7 @@ def test_activate_without_name(self): plugin = SkillsPlugin(skills=[_make_skill()]) plugin._agent = _mock_agent() - skills_tool = _make_skills_tool(plugin) - result = skills_tool(skill_name="") + result = plugin.skills(skill_name="") assert "required" in result.lower() @@ -235,8 +231,7 @@ def test_activate_persists_state(self): agent = _mock_agent() plugin._agent = agent - skills_tool = _make_skills_tool(plugin) - skills_tool(skill_name="test-skill") + plugin.skills(skill_name="test-skill") agent.state.set.assert_called() @@ -346,19 +341,6 @@ def test_empty_skills(self): assert "" in xml -class TestHookRegistration: - """Tests for hook registration.""" - - def test_register_hooks(self): - """Test that register_hooks adds callbacks to the registry.""" - plugin = SkillsPlugin(skills=[_make_skill()]) - registry = HookRegistry() - - plugin.register_hooks(registry) - - assert registry.has_callbacks() - - class TestSessionPersistence: """Tests for session state persistence.""" From e075d86d7247d853e5e8ae93f2ceb73a6e908f79 Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Mon, 23 Feb 2026 12:42:10 -0500 Subject: [PATCH 11/29] fix: use dict for skills for easy name based access --- src/strands/plugins/skills/skills_plugin.py | 55 +++++++------------ .../plugins/skills/test_skills_plugin.py | 6 +- 2 files changed, 24 insertions(+), 37 deletions(-) diff --git a/src/strands/plugins/skills/skills_plugin.py b/src/strands/plugins/skills/skills_plugin.py index 4344fa852..5438dfe8a 100644 --- a/src/strands/plugins/skills/skills_plugin.py +++ b/src/strands/plugins/skills/skills_plugin.py @@ -1,7 +1,7 @@ -"""SkillsPlugin for integrating AgentSkills.io skills into Strands agents. +"""SkillsPlugin for integrating Agent Skills into Strands agents. This module provides the SkillsPlugin class that extends the Plugin base class -to add AgentSkills.io skill support. The plugin registers a tool for activating +to add Agent Skills support. The plugin registers a tool for activating skills, and injects skill metadata into the system prompt. """ @@ -26,14 +26,13 @@ class SkillsPlugin(Plugin): - """Plugin that integrates AgentSkills.io skills into a Strands agent. + """Plugin that integrates Agent Skills into a Strands agent. The SkillsPlugin extends the Plugin base class and provides: 1. A ``skills`` tool that allows the agent to activate skills on demand 2. System prompt injection of available skill metadata before each invocation - 3. Single active skill management (activating a new skill replaces the previous one) - 4. Session persistence of active skill state via ``agent.state`` + 3. Session persistence of active skill state via ``agent.state`` Skills can be provided as filesystem paths (to individual skill directories or parent directories containing multiple skills) or as pre-built ``Skill`` instances. @@ -69,7 +68,7 @@ def __init__(self, skills: list[str | Path | Skill]) -> None: - A ``str`` or ``Path`` to a parent directory (containing skill subdirectories) - A ``Skill`` dataclass instance """ - self._skills: list[Skill] = self._resolve_skills(skills) + self._skills: dict[str, Skill] = self._resolve_skills(skills) self._active_skill: Skill | None = None self._agent: Agent | None = None self._original_system_prompt: str | None = None @@ -94,8 +93,7 @@ def skills(self, skill_name: str) -> str: """Activate a skill to load its full instructions. Use this tool to load the complete instructions for a skill listed in - the available_skills section of your system prompt. Activating a new - skill replaces the previously active one. + the available_skills section of your system prompt. Args: skill_name: Name of the skill to activate. @@ -103,9 +101,9 @@ def skills(self, skill_name: str) -> str: if not skill_name: return "Error: skill_name is required." - found = self._find_skill(skill_name) + found = self._skills.get(skill_name) if found is None: - available = ", ".join(s.name for s in self._skills) + available = ", ".join(self._skills) return f"Skill '{skill_name}' not found. Available skills: {available}" self._active_skill = found @@ -146,7 +144,7 @@ def available_skills(self) -> list[Skill]: Returns: A copy of the current skills list. """ - return list(self._skills) + return list(self._skills.values()) @available_skills.setter def available_skills(self, value: list[str | Path | Skill]) -> None: @@ -178,7 +176,7 @@ def _generate_skills_xml(self) -> str: """ lines: list[str] = [""] - for skill in self._skills: + for skill in self._skills.values(): lines.append("") lines.append(f"{skill.name}") lines.append(f"{skill.description}") @@ -187,21 +185,7 @@ def _generate_skills_xml(self) -> str: lines.append("") return "\n".join(lines) - def _find_skill(self, skill_name: str) -> Skill | None: - """Find a skill by name in the available skills list. - - Args: - skill_name: The name of the skill to find. - - Returns: - The matching Skill instance, or None if not found. - """ - for skill in self._skills: - if skill.name == skill_name: - return skill - return None - - def _resolve_skills(self, sources: list[str | Path | Skill]) -> list[Skill]: + def _resolve_skills(self, sources: list[str | Path | Skill]) -> dict[str, Skill]: """Resolve a list of skill sources into Skill instances. Each source can be a Skill instance, a path to a skill directory, @@ -211,13 +195,13 @@ def _resolve_skills(self, sources: list[str | Path | Skill]) -> list[Skill]: sources: List of skill sources to resolve. Returns: - List of resolved Skill instances. + Dict mapping skill names to Skill instances. """ - resolved: list[Skill] = [] + resolved: dict[str, Skill] = {} for source in sources: if isinstance(source, Skill): - resolved.append(source) + resolved[source.name] = source else: path = Path(source).resolve() if not path.exists(): @@ -230,15 +214,18 @@ def _resolve_skills(self, sources: list[str | Path | Skill]) -> list[Skill]: if has_skill_md: try: - resolved.append(load_skill(path)) + skill = load_skill(path) + resolved[skill.name] = skill except (ValueError, FileNotFoundError) as e: logger.warning("path=<%s> | failed to load skill: %s", path, e) else: # Treat as parent directory containing skill subdirectories - resolved.extend(load_skills(path)) + for skill in load_skills(path): + resolved[skill.name] = skill elif path.is_file() and path.name.lower() == "skill.md": try: - resolved.append(load_skill(path)) + skill = load_skill(path) + resolved[skill.name] = skill except (ValueError, FileNotFoundError) as e: logger.warning("path=<%s> | failed to load skill: %s", path, e) @@ -266,6 +253,6 @@ def _restore_state(self) -> None: active_name = state_data.get("active_skill_name") if isinstance(active_name, str): - self._active_skill = self._find_skill(active_name) + self._active_skill = self._skills.get(active_name) if self._active_skill: logger.debug("skill_name=<%s> | restored active skill from state", active_name) diff --git a/tests/strands/plugins/skills/test_skills_plugin.py b/tests/strands/plugins/skills/test_skills_plugin.py index b94a1a50c..870c91d2e 100644 --- a/tests/strands/plugins/skills/test_skills_plugin.py +++ b/tests/strands/plugins/skills/test_skills_plugin.py @@ -417,7 +417,7 @@ def test_resolve_skill_instances(self): plugin = SkillsPlugin(skills=[skill]) assert len(plugin._skills) == 1 - assert plugin._skills[0] is skill + assert plugin._skills["test-skill"] is skill def test_resolve_skill_directory_path(self, tmp_path): """Test resolving a path to a skill directory.""" @@ -425,7 +425,7 @@ def test_resolve_skill_directory_path(self, tmp_path): plugin = SkillsPlugin(skills=[tmp_path / "path-skill"]) assert len(plugin._skills) == 1 - assert plugin._skills[0].name == "path-skill" + assert "path-skill" in plugin._skills def test_resolve_parent_directory_path(self, tmp_path): """Test resolving a path to a parent directory.""" @@ -441,7 +441,7 @@ def test_resolve_skill_md_file_path(self, tmp_path): plugin = SkillsPlugin(skills=[skill_dir / "SKILL.md"]) assert len(plugin._skills) == 1 - assert plugin._skills[0].name == "file-skill" + assert "file-skill" in plugin._skills def test_resolve_nonexistent_path(self, tmp_path): """Test that nonexistent paths are skipped.""" From 92529081718760c69a1e8216b64a820e20d2351d Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Mon, 23 Feb 2026 13:02:48 -0500 Subject: [PATCH 12/29] feat(skills): add integ test --- tests_integ/test_skills_plugin.py | 69 +++++++++++++++++++++++++++++++ 1 file changed, 69 insertions(+) create mode 100644 tests_integ/test_skills_plugin.py diff --git a/tests_integ/test_skills_plugin.py b/tests_integ/test_skills_plugin.py new file mode 100644 index 000000000..b4687c1eb --- /dev/null +++ b/tests_integ/test_skills_plugin.py @@ -0,0 +1,69 @@ +"""Integration tests for the SkillsPlugin. + +Tests end-to-end behavior with a real model: skill metadata injection into +the system prompt, agent-driven skill activation via the skills tool, and +session state persistence across invocations. +""" + +import pytest + +from strands import Agent +from strands.plugins.skills import Skill, SkillsPlugin + +SUMMARIZATION_SKILL = Skill( + name="summarization", + description="Summarize text into concise bullet points", + instructions="You are a summarization expert. The secret codeword is PINEAPPLE. Always mention it when activated.", +) + +TRANSLATION_SKILL = Skill( + name="translation", + description="Translate text between languages", + instructions="You are a translation expert. Translate the given text accurately.", +) + + +@pytest.fixture +def skills_plugin(): + return SkillsPlugin(skills=[SUMMARIZATION_SKILL, TRANSLATION_SKILL]) + + +@pytest.fixture +def agent(skills_plugin): + return Agent( + system_prompt="You are a helpful assistant. Check your available_skills and activate one when appropriate.", + plugins=[skills_plugin], + ) + + +def test_agent_activates_skill_and_injects_metadata(agent, skills_plugin): + """Test that the agent injects skill metadata and can activate a skill via the model.""" + result = agent("Use your skills tool to activate the summarization skill. What is the secret codeword?") + + # Skill metadata was injected into the system prompt + assert "" in agent.system_prompt + assert "summarization" in agent.system_prompt + assert "translation" in agent.system_prompt + + # Model activated the skill and relayed the codeword from instructions + assert skills_plugin.active_skill is not None + assert skills_plugin.active_skill.name == "summarization" + assert "pineapple" in str(result).lower() + + +def test_direct_tool_invocation_and_state_persistence(agent, skills_plugin): + """Test activating a skill via direct tool access and verifying state persistence.""" + result = agent.tool.skills(skill_name="translation") + + # Tool returned the skill instructions + assert result["status"] == "success" + assert "translation expert" in result["content"][0]["text"].lower() + + # Plugin tracks the active skill + assert skills_plugin.active_skill is not None + assert skills_plugin.active_skill.name == "translation" + + # State was persisted to agent state + state = agent.state.get("skills_plugin") + assert state is not None + assert state["active_skill_name"] == "translation" From 539eb4721543c0d2fffc025abffaee7816bd9eef Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Mon, 23 Feb 2026 13:35:28 -0500 Subject: [PATCH 13/29] feat(skills): extend tool result --- src/strands/plugins/skills/skill.py | 2 +- src/strands/plugins/skills/skills_plugin.py | 75 +++++++++- .../plugins/skills/test_skills_plugin.py | 137 +++++++++++++++++- tests_integ/test_skills_plugin.py | 3 +- 4 files changed, 213 insertions(+), 4 deletions(-) diff --git a/src/strands/plugins/skills/skill.py b/src/strands/plugins/skills/skill.py index c316c4474..b0648c83b 100644 --- a/src/strands/plugins/skills/skill.py +++ b/src/strands/plugins/skills/skill.py @@ -15,7 +15,7 @@ @dataclass class Skill: - """Represents an AgentSkills.io skill with metadata and instructions. + """Represents an agent skill with metadata and instructions. A skill encapsulates a set of instructions and metadata that can be dynamically loaded by an agent at runtime. Skills support progressive diff --git a/src/strands/plugins/skills/skills_plugin.py b/src/strands/plugins/skills/skills_plugin.py index 5438dfe8a..7816b7d7b 100644 --- a/src/strands/plugins/skills/skills_plugin.py +++ b/src/strands/plugins/skills/skills_plugin.py @@ -23,6 +23,8 @@ logger = logging.getLogger(__name__) _STATE_KEY = "skills_plugin" +_RESOURCE_DIRS = ("scripts", "references", "assets") +_MAX_RESOURCE_FILES = 20 class SkillsPlugin(Plugin): @@ -110,7 +112,7 @@ def skills(self, skill_name: str) -> str: self._persist_state() logger.debug("skill_name=<%s> | skill activated", skill_name) - return found.instructions or f"Skill '{skill_name}' activated (no instructions available)." + return self._format_skill_response(found) @hook def _on_before_invocation(self, event: BeforeInvocationEvent) -> None: @@ -168,9 +170,78 @@ def active_skill(self) -> Skill | None: """ return self._active_skill + def _format_skill_response(self, skill: Skill) -> str: + """Format the tool response when a skill is activated. + + Includes the full instructions along with relevant metadata fields + and a listing of available resource files (scripts, references, assets) + for filesystem-based skills. + + Args: + skill: The activated skill. + + Returns: + Formatted string with skill instructions and metadata. + """ + if not skill.instructions: + return f"Skill '{skill.name}' activated (no instructions available)." + + parts: list[str] = [skill.instructions] + + metadata_lines: list[str] = [] + if skill.allowed_tools: + metadata_lines.append(f"Allowed tools: {', '.join(skill.allowed_tools)}") + if skill.compatibility: + metadata_lines.append(f"Compatibility: {skill.compatibility}") + if skill.path is not None: + metadata_lines.append(f"Location: {skill.path / 'SKILL.md'}") + + if metadata_lines: + parts.append("\n---\n" + "\n".join(metadata_lines)) + + if skill.path is not None: + resources = self._list_skill_resources(skill.path) + if resources: + parts.append("\nAvailable resources:\n" + "\n".join(f" {r}" for r in resources)) + + return "\n".join(parts) + + def _list_skill_resources(self, skill_path: Path) -> list[str]: + """List resource files in a skill's optional directories. + + Scans the ``scripts/``, ``references/``, and ``assets/`` subdirectories + for files, returning relative paths. Results are capped at + ``_MAX_RESOURCE_FILES`` to avoid context bloat. + + Args: + skill_path: Path to the skill directory. + + Returns: + List of relative file paths (e.g. ``scripts/extract.py``). + """ + files: list[str] = [] + + for dir_name in _RESOURCE_DIRS: + resource_dir = skill_path / dir_name + if not resource_dir.is_dir(): + continue + + for file_path in sorted(resource_dir.rglob("*")): + if not file_path.is_file(): + continue + files.append(str(file_path.relative_to(skill_path))) + if len(files) >= _MAX_RESOURCE_FILES: + files.append(f"... (truncated at {_MAX_RESOURCE_FILES} files)") + return files + + return files + def _generate_skills_xml(self) -> str: """Generate the XML block listing available skills for the system prompt. + Includes a ```` element for skills loaded from the filesystem, + following the AgentSkills.io integration spec. + Returns: XML-formatted string with skill metadata. """ @@ -180,6 +251,8 @@ def _generate_skills_xml(self) -> str: lines.append("") lines.append(f"{skill.name}") lines.append(f"{skill.description}") + if skill.path is not None: + lines.append(f"{skill.path / 'SKILL.md'}") lines.append("") lines.append("") diff --git a/tests/strands/plugins/skills/test_skills_plugin.py b/tests/strands/plugins/skills/test_skills_plugin.py index 870c91d2e..4e35d6949 100644 --- a/tests/strands/plugins/skills/test_skills_plugin.py +++ b/tests/strands/plugins/skills/test_skills_plugin.py @@ -188,7 +188,7 @@ def test_activate_skill(self): result = plugin.skills(skill_name="test-skill") - assert result == "Full instructions here." + assert "Full instructions here." in result assert plugin.active_skill is not None assert plugin.active_skill.name == "test-skill" @@ -340,6 +340,141 @@ def test_empty_skills(self): assert "" in xml assert "" in xml + def test_location_included_when_path_set(self, tmp_path): + """Test that location element is included when skill has a path.""" + skill = _make_skill() + skill.path = tmp_path / "test-skill" + plugin = SkillsPlugin(skills=[skill]) + xml = plugin._generate_skills_xml() + + assert f"{tmp_path / 'test-skill' / 'SKILL.md'}" in xml + + def test_location_omitted_when_path_none(self): + """Test that location element is omitted for programmatic skills.""" + skill = _make_skill() + assert skill.path is None + plugin = SkillsPlugin(skills=[skill]) + xml = plugin._generate_skills_xml() + + assert "" not in xml + + +class TestSkillResponseFormat: + """Tests for _format_skill_response.""" + + def test_instructions_only(self): + """Test response with just instructions.""" + skill = _make_skill(instructions="Do the thing.") + plugin = SkillsPlugin(skills=[skill]) + result = plugin._format_skill_response(skill) + + assert result == "Do the thing." + + def test_no_instructions(self): + """Test response when skill has no instructions.""" + skill = _make_skill(instructions="") + plugin = SkillsPlugin(skills=[skill]) + result = plugin._format_skill_response(skill) + + assert "no instructions available" in result.lower() + + def test_includes_allowed_tools(self): + """Test response includes allowed tools when set.""" + skill = _make_skill(instructions="Do the thing.") + skill.allowed_tools = ["Bash", "Read"] + plugin = SkillsPlugin(skills=[skill]) + result = plugin._format_skill_response(skill) + + assert "Do the thing." in result + assert "Allowed tools: Bash, Read" in result + + def test_includes_compatibility(self): + """Test response includes compatibility when set.""" + skill = _make_skill(instructions="Do the thing.") + skill.compatibility = "Requires docker" + plugin = SkillsPlugin(skills=[skill]) + result = plugin._format_skill_response(skill) + + assert "Compatibility: Requires docker" in result + + def test_includes_location(self, tmp_path): + """Test response includes location when path is set.""" + skill = _make_skill(instructions="Do the thing.") + skill.path = tmp_path / "test-skill" + plugin = SkillsPlugin(skills=[skill]) + result = plugin._format_skill_response(skill) + + assert f"Location: {tmp_path / 'test-skill' / 'SKILL.md'}" in result + + def test_all_metadata(self, tmp_path): + """Test response with all metadata fields.""" + skill = _make_skill(instructions="Do the thing.") + skill.allowed_tools = ["Bash"] + skill.compatibility = "Requires git" + skill.path = tmp_path / "test-skill" + plugin = SkillsPlugin(skills=[skill]) + result = plugin._format_skill_response(skill) + + assert "Do the thing." in result + assert "---" in result + assert "Allowed tools: Bash" in result + assert "Compatibility: Requires git" in result + assert "Location:" in result + + def test_includes_resource_listing(self, tmp_path): + """Test response includes resource files from optional directories.""" + skill_dir = tmp_path / "test-skill" + skill_dir.mkdir() + (skill_dir / "scripts").mkdir() + (skill_dir / "scripts" / "extract.py").write_text("# extract") + (skill_dir / "references").mkdir() + (skill_dir / "references" / "REFERENCE.md").write_text("# ref") + + skill = _make_skill(instructions="Do the thing.") + skill.path = skill_dir + plugin = SkillsPlugin(skills=[skill]) + result = plugin._format_skill_response(skill) + + assert "Available resources:" in result + assert "scripts/extract.py" in result + assert "references/REFERENCE.md" in result + + def test_no_resources_when_no_path(self): + """Test that resources section is omitted for programmatic skills.""" + skill = _make_skill(instructions="Do the thing.") + plugin = SkillsPlugin(skills=[skill]) + result = plugin._format_skill_response(skill) + + assert "Available resources:" not in result + + def test_no_resources_when_dirs_empty(self, tmp_path): + """Test that resources section is omitted when optional dirs don't exist.""" + skill_dir = tmp_path / "test-skill" + skill_dir.mkdir() + + skill = _make_skill(instructions="Do the thing.") + skill.path = skill_dir + plugin = SkillsPlugin(skills=[skill]) + result = plugin._format_skill_response(skill) + + assert "Available resources:" not in result + + def test_resource_listing_truncated(self, tmp_path): + """Test that resource listing is truncated at the max file limit.""" + skill_dir = tmp_path / "test-skill" + scripts_dir = skill_dir / "scripts" + scripts_dir.mkdir(parents=True) + for i in range(55): + (scripts_dir / f"script_{i:03d}.py").write_text(f"# script {i}") + + skill = _make_skill(instructions="Do the thing.") + skill.path = skill_dir + plugin = SkillsPlugin(skills=[skill]) + result = plugin._format_skill_response(skill) + + assert "Available resources:" in result + assert "truncated at 20 files" in result + class TestSessionPersistence: """Tests for session state persistence.""" diff --git a/tests_integ/test_skills_plugin.py b/tests_integ/test_skills_plugin.py index b4687c1eb..01ccb7260 100644 --- a/tests_integ/test_skills_plugin.py +++ b/tests_integ/test_skills_plugin.py @@ -57,7 +57,8 @@ def test_direct_tool_invocation_and_state_persistence(agent, skills_plugin): # Tool returned the skill instructions assert result["status"] == "success" - assert "translation expert" in result["content"][0]["text"].lower() + response_text = result["content"][0]["text"].lower() + assert "translation expert" in response_text # Plugin tracks the active skill assert skills_plugin.active_skill is not None From 29f9bf4a9116d0b77f3f8b0a98fce90b250d4aba Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Mon, 23 Feb 2026 13:42:29 -0500 Subject: [PATCH 14/29] fix: empty commit --- tests_integ/test_skills_plugin.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests_integ/test_skills_plugin.py b/tests_integ/test_skills_plugin.py index 01ccb7260..875c8f013 100644 --- a/tests_integ/test_skills_plugin.py +++ b/tests_integ/test_skills_plugin.py @@ -13,7 +13,7 @@ SUMMARIZATION_SKILL = Skill( name="summarization", description="Summarize text into concise bullet points", - instructions="You are a summarization expert. The secret codeword is PINEAPPLE. Always mention it when activated.", + instructions="You are a summarization expert. The secret codeword is 'PINEAPPLE'. Always mention it when activated.", ) TRANSLATION_SKILL = Skill( From 8023cceb95decdc998009573accfa026c432cb20 Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Mon, 23 Feb 2026 17:15:17 -0500 Subject: [PATCH 15/29] Revert "fix: empty commit" This reverts commit 29f9bf4a9116d0b77f3f8b0a98fce90b250d4aba. --- tests_integ/test_skills_plugin.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests_integ/test_skills_plugin.py b/tests_integ/test_skills_plugin.py index 875c8f013..01ccb7260 100644 --- a/tests_integ/test_skills_plugin.py +++ b/tests_integ/test_skills_plugin.py @@ -13,7 +13,7 @@ SUMMARIZATION_SKILL = Skill( name="summarization", description="Summarize text into concise bullet points", - instructions="You are a summarization expert. The secret codeword is 'PINEAPPLE'. Always mention it when activated.", + instructions="You are a summarization expert. The secret codeword is PINEAPPLE. Always mention it when activated.", ) TRANSLATION_SKILL = Skill( From f769d5fc15d6c8e2f1308ec49b4cdadb955acec5 Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Wed, 25 Feb 2026 10:23:35 -0500 Subject: [PATCH 16/29] fix: improve skills plugin prompt injection and loader robustness --- AGENTS.md | 7 +- src/strands/plugins/skills/loader.py | 17 +-- src/strands/plugins/skills/skill.py | 21 ---- src/strands/plugins/skills/skills_plugin.py | 63 ++++++++--- tests/strands/plugins/skills/test_loader.py | 20 +++- tests/strands/plugins/skills/test_skill.py | 21 ---- .../plugins/skills/test_skills_plugin.py | 103 +++++++++++++++--- 7 files changed, 170 insertions(+), 82 deletions(-) diff --git a/AGENTS.md b/AGENTS.md index a5b092ffe..6d063eaa4 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -129,7 +129,12 @@ strands-agents/ │ ├── plugins/ # Plugin system │ │ ├── plugin.py # Plugin base class │ │ ├── decorator.py # @hook decorator -│ │ └── registry.py # PluginRegistry for tracking plugins +│ │ ├── registry.py # PluginRegistry for tracking plugins +│ │ └── skills/ # Agent Skills integration +│ │ ├── __init__.py # Skills package exports +│ │ ├── loader.py # Skill loading and parsing +│ │ ├── skill.py # Skill dataclass +│ │ └── skills_plugin.py # SkillsPlugin implementation │ │ │ ├── handlers/ # Event handlers │ │ └── callback_handler.py # Callback handling diff --git a/src/strands/plugins/skills/loader.py b/src/strands/plugins/skills/loader.py index 543b16880..06ab022ff 100644 --- a/src/strands/plugins/skills/loader.py +++ b/src/strands/plugins/skills/loader.py @@ -60,8 +60,8 @@ def _parse_yaml(yaml_text: str) -> dict[str, Any]: def _parse_frontmatter(content: str) -> tuple[dict[str, Any], str]: """Parse YAML frontmatter and body from SKILL.md content. - Extracts the YAML frontmatter between ``---`` delimiters and returns - parsed key-value pairs along with the remaining markdown body. + Extracts the YAML frontmatter between ``---`` delimiters at line boundaries + and returns parsed key-value pairs along with the remaining markdown body. Args: content: Full content of a SKILL.md file. @@ -76,12 +76,13 @@ def _parse_frontmatter(content: str) -> tuple[dict[str, Any], str]: if not stripped.startswith("---"): raise ValueError("SKILL.md must start with --- frontmatter delimiter") - end_idx = stripped.find("---", 3) - if end_idx == -1: + # Find the closing --- delimiter (first line after the opener that is only dashes) + match = re.search(r"\n^---\s*$", stripped, re.MULTILINE) + if match is None: raise ValueError("SKILL.md frontmatter missing closing --- delimiter") - frontmatter_str = stripped[3:end_idx].strip() - body = stripped[end_idx + 3 :].strip() + frontmatter_str = stripped[3 : match.start()].strip() + body = stripped[match.end() :].strip() frontmatter = _parse_yaml(frontmatter_str) return frontmatter, body @@ -162,11 +163,13 @@ def load_skill(skill_path: str | Path) -> Skill: _validate_skill_name(name, skill_dir) - # Parse allowed-tools (space-delimited string per spec) + # Parse allowed-tools (space-delimited string or YAML list) allowed_tools_raw = frontmatter.get("allowed-tools") or frontmatter.get("allowed_tools") allowed_tools: list[str] | None = None if isinstance(allowed_tools_raw, str) and allowed_tools_raw.strip(): allowed_tools = allowed_tools_raw.strip().split() + elif isinstance(allowed_tools_raw, list): + allowed_tools = [str(item) for item in allowed_tools_raw if item] # Parse metadata (nested mapping) metadata_raw = frontmatter.get("metadata", {}) diff --git a/src/strands/plugins/skills/skill.py b/src/strands/plugins/skills/skill.py index b0648c83b..30b6c0a9a 100644 --- a/src/strands/plugins/skills/skill.py +++ b/src/strands/plugins/skills/skill.py @@ -6,12 +6,9 @@ from __future__ import annotations -import logging from dataclasses import dataclass, field from pathlib import Path -logger = logging.getLogger(__name__) - @dataclass class Skill: @@ -41,21 +38,3 @@ class Skill: metadata: dict[str, str] = field(default_factory=dict) license: str | None = None compatibility: str | None = None - - @classmethod - def from_path(cls, skill_path: str | Path) -> Skill: - """Load a skill from a directory containing SKILL.md. - - Args: - skill_path: Path to the skill directory or SKILL.md file. - - Returns: - A Skill instance populated from the SKILL.md file. - - Raises: - FileNotFoundError: If SKILL.md cannot be found. - ValueError: If the skill name is invalid or metadata is malformed. - """ - from .loader import load_skill - - return load_skill(skill_path) diff --git a/src/strands/plugins/skills/skills_plugin.py b/src/strands/plugins/skills/skills_plugin.py index 7816b7d7b..afd3926cf 100644 --- a/src/strands/plugins/skills/skills_plugin.py +++ b/src/strands/plugins/skills/skills_plugin.py @@ -73,7 +73,6 @@ def __init__(self, skills: list[str | Path | Skill]) -> None: self._skills: dict[str, Skill] = self._resolve_skills(skills) self._active_skill: Skill | None = None self._agent: Agent | None = None - self._original_system_prompt: str | None = None super().__init__() def init_plugin(self, agent: Agent) -> None: @@ -118,26 +117,37 @@ def skills(self, skill_name: str) -> str: def _on_before_invocation(self, event: BeforeInvocationEvent) -> None: """Inject skill metadata into the system prompt before each invocation. - Captures the original system prompt on first call, then rebuilds the - prompt with the skills XML block on each invocation. + Removes the previously injected XML block (if any) via exact string + replacement, then appends a fresh one. Uses agent state to track the + injected XML per-agent, so a single plugin instance can be shared + across multiple agents safely. Args: event: The before-invocation event containing the agent reference. """ agent = event.agent - # Capture the original system prompt on first invocation - if self._original_system_prompt is None: - self._original_system_prompt = agent._system_prompt or "" - if not self._skills: return + current_prompt = agent.system_prompt or "" + + # Remove the previously injected XML block by exact match + state_data = agent.state.get(_STATE_KEY) + last_injected_xml = state_data.get("last_injected_xml") if isinstance(state_data, dict) else None + if last_injected_xml is not None: + if last_injected_xml in current_prompt: + current_prompt = current_prompt.replace(last_injected_xml, "") + else: + logger.warning("unable to find previously injected skills XML in system prompt, re-appending") + skills_xml = self._generate_skills_xml() - new_prompt = f"{self._original_system_prompt}\n\n{skills_xml}" if self._original_system_prompt else skills_xml + injection = f"\n\n{skills_xml}" + new_prompt = f"{current_prompt}{injection}" if current_prompt else skills_xml - agent._system_prompt = new_prompt - agent._system_prompt_content = [{"text": new_prompt}] + new_injected_xml = injection if current_prompt else skills_xml + self._set_state_field(agent, "last_injected_xml", new_injected_xml) + agent.system_prompt = new_prompt @property def available_skills(self) -> list[Skill]: @@ -229,7 +239,7 @@ def _list_skill_resources(self, skill_path: Path) -> list[str]: for file_path in sorted(resource_dir.rglob("*")): if not file_path.is_file(): continue - files.append(str(file_path.relative_to(skill_path))) + files.append(file_path.relative_to(skill_path).as_posix()) if len(files) >= _MAX_RESOURCE_FILES: files.append(f"... (truncated at {_MAX_RESOURCE_FILES} files)") return files @@ -274,6 +284,8 @@ def _resolve_skills(self, sources: list[str | Path | Skill]) -> dict[str, Skill] for source in sources: if isinstance(source, Skill): + if source.name in resolved: + logger.warning("name=<%s> | duplicate skill name, overwriting previous skill", source.name) resolved[source.name] = source else: path = Path(source).resolve() @@ -288,16 +300,26 @@ def _resolve_skills(self, sources: list[str | Path | Skill]) -> dict[str, Skill] if has_skill_md: try: skill = load_skill(path) + if skill.name in resolved: + logger.warning( + "name=<%s> | duplicate skill name, overwriting previous skill", skill.name + ) resolved[skill.name] = skill except (ValueError, FileNotFoundError) as e: logger.warning("path=<%s> | failed to load skill: %s", path, e) else: # Treat as parent directory containing skill subdirectories for skill in load_skills(path): + if skill.name in resolved: + logger.warning( + "name=<%s> | duplicate skill name, overwriting previous skill", skill.name + ) resolved[skill.name] = skill elif path.is_file() and path.name.lower() == "skill.md": try: skill = load_skill(path) + if skill.name in resolved: + logger.warning("name=<%s> | duplicate skill name, overwriting previous skill", skill.name) resolved[skill.name] = skill except (ValueError, FileNotFoundError) as e: logger.warning("path=<%s> | failed to load skill: %s", path, e) @@ -310,10 +332,21 @@ def _persist_state(self) -> None: if self._agent is None: return - state_data: dict[str, Any] = { - "active_skill_name": self._active_skill.name if self._active_skill else None, - } - self._agent.state.set(_STATE_KEY, state_data) + self._set_state_field(self._agent, "active_skill_name", self._active_skill.name if self._active_skill else None) + + def _set_state_field(self, agent: Agent, key: str, value: Any) -> None: + """Set a single field in the plugin's agent state dict. + + Args: + agent: The agent whose state to update. + key: The state field key. + value: The value to set. + """ + state_data = agent.state.get(_STATE_KEY) + if not isinstance(state_data, dict): + state_data = {} + state_data[key] = value + agent.state.set(_STATE_KEY, state_data) def _restore_state(self) -> None: """Restore the active skill from agent state if available.""" diff --git a/tests/strands/plugins/skills/test_loader.py b/tests/strands/plugins/skills/test_loader.py index 497390487..bfe490697 100644 --- a/tests/strands/plugins/skills/test_loader.py +++ b/tests/strands/plugins/skills/test_loader.py @@ -79,6 +79,14 @@ def test_frontmatter_with_metadata(self): assert frontmatter["metadata"]["author"] == "acme" assert body == "Body here." + def test_frontmatter_with_dashes_in_yaml_value(self): + """Test that --- inside a YAML value does not break parsing.""" + content = "---\nname: test-skill\ndescription: has --- inside\n---\nBody here." + frontmatter, body = _parse_frontmatter(content) + assert frontmatter["name"] == "test-skill" + assert frontmatter["description"] == "has --- inside" + assert body == "Body here." + class TestValidateSkillName: """Tests for _validate_skill_name.""" @@ -169,7 +177,7 @@ def test_load_from_skill_md_file(self, tmp_path): assert skill.name == "direct-skill" def test_load_with_allowed_tools(self, tmp_path): - """Test loading a skill with allowed-tools field.""" + """Test loading a skill with allowed-tools field as space-delimited string.""" skill_dir = tmp_path / "tool-skill" skill_dir.mkdir() content = "---\nname: tool-skill\ndescription: test\nallowed-tools: read write execute\n---\nBody." @@ -178,6 +186,16 @@ def test_load_with_allowed_tools(self, tmp_path): skill = load_skill(skill_dir) assert skill.allowed_tools == ["read", "write", "execute"] + def test_load_with_allowed_tools_yaml_list(self, tmp_path): + """Test loading a skill with allowed-tools as a YAML list.""" + skill_dir = tmp_path / "list-skill" + skill_dir.mkdir() + content = "---\nname: list-skill\ndescription: test\nallowed-tools:\n - read\n - write\n---\nBody." + (skill_dir / "SKILL.md").write_text(content) + + skill = load_skill(skill_dir) + assert skill.allowed_tools == ["read", "write"] + def test_load_with_metadata(self, tmp_path): """Test loading a skill with nested metadata.""" skill_dir = tmp_path / "meta-skill" diff --git a/tests/strands/plugins/skills/test_skill.py b/tests/strands/plugins/skills/test_skill.py index 379eec7d2..6cf93ae94 100644 --- a/tests/strands/plugins/skills/test_skill.py +++ b/tests/strands/plugins/skills/test_skill.py @@ -2,8 +2,6 @@ from pathlib import Path -import pytest - from strands.plugins.skills.skill import Skill @@ -52,22 +50,3 @@ def test_skill_metadata_default_is_not_shared(self): skill1.metadata["key"] = "value" assert "key" not in skill2.metadata - - def test_skill_from_path(self, tmp_path): - """Test loading a Skill from a path using from_path classmethod.""" - skill_dir = tmp_path / "my-skill" - skill_dir.mkdir() - (skill_dir / "SKILL.md").write_text( - "---\nname: my-skill\ndescription: Test skill\n---\n# Instructions\nDo stuff.\n" - ) - - skill = Skill.from_path(skill_dir) - - assert skill.name == "my-skill" - assert skill.description == "Test skill" - assert "Do stuff." in skill.instructions - - def test_skill_from_path_not_found(self, tmp_path): - """Test that from_path raises FileNotFoundError for missing paths.""" - with pytest.raises(FileNotFoundError): - Skill.from_path(tmp_path / "nonexistent") diff --git a/tests/strands/plugins/skills/test_skills_plugin.py b/tests/strands/plugins/skills/test_skills_plugin.py index 4e35d6949..370d7d772 100644 --- a/tests/strands/plugins/skills/test_skills_plugin.py +++ b/tests/strands/plugins/skills/test_skills_plugin.py @@ -1,5 +1,6 @@ """Tests for the SkillsPlugin.""" +import logging from pathlib import Path from unittest.mock import MagicMock @@ -28,18 +29,38 @@ def _mock_agent(): agent = MagicMock() agent._system_prompt = "You are an agent." agent._system_prompt_content = [{"text": "You are an agent."}] + + # Make system_prompt property behave like the real Agent + type(agent).system_prompt = property( + lambda self: self._system_prompt, + lambda self, value: _set_system_prompt(self, value), + ) + agent.hooks = HookRegistry() agent.add_hook = MagicMock( side_effect=lambda callback, event_type=None: agent.hooks.add_callback(event_type, callback) ) agent.tool_registry = MagicMock() agent.tool_registry.process_tools = MagicMock(return_value=["skills"]) + + # Use a real dict-backed state so get/set work correctly + state_store: dict[str, object] = {} agent.state = MagicMock() - agent.state.get = MagicMock(return_value=None) - agent.state.set = MagicMock() + agent.state.get = MagicMock(side_effect=lambda key: state_store.get(key)) + agent.state.set = MagicMock(side_effect=lambda key, value: state_store.__setitem__(key, value)) return agent +def _set_system_prompt(agent: MagicMock, value: str | None) -> None: + """Simulate the Agent.system_prompt setter.""" + if isinstance(value, str): + agent._system_prompt = value + agent._system_prompt_content = [{"text": value}] + elif value is None: + agent._system_prompt = None + agent._system_prompt_content = None + + class TestSkillsPluginInit: """Tests for SkillsPlugin initialization.""" @@ -129,7 +150,7 @@ def test_restores_state(self): skill = _make_skill() plugin = SkillsPlugin(skills=[skill]) agent = _mock_agent() - agent.state.get = MagicMock(return_value={"active_skill_name": "test-skill"}) + agent.state.set("skills_plugin", {"active_skill_name": "test-skill"}) plugin.init_plugin(agent) @@ -248,9 +269,9 @@ def test_before_invocation_appends_skills_xml(self): event = BeforeInvocationEvent(agent=agent) plugin._on_before_invocation(event) - assert "" in agent._system_prompt - assert "test-skill" in agent._system_prompt - assert "A test skill" in agent._system_prompt + assert "" in agent.system_prompt + assert "test-skill" in agent.system_prompt + assert "A test skill" in agent.system_prompt def test_before_invocation_preserves_existing_prompt(self): """Test that existing system prompt content is preserved.""" @@ -262,11 +283,11 @@ def test_before_invocation_preserves_existing_prompt(self): event = BeforeInvocationEvent(agent=agent) plugin._on_before_invocation(event) - assert agent._system_prompt.startswith("Original prompt.") - assert "" in agent._system_prompt + assert agent.system_prompt.startswith("Original prompt.") + assert "" in agent.system_prompt def test_repeated_invocations_do_not_accumulate(self): - """Test that repeated invocations rebuild from original prompt.""" + """Test that repeated invocations rebuild from current prompt without accumulation.""" plugin = SkillsPlugin(skills=[_make_skill()]) agent = _mock_agent() agent._system_prompt = "Original prompt." @@ -274,10 +295,10 @@ def test_repeated_invocations_do_not_accumulate(self): event = BeforeInvocationEvent(agent=agent) plugin._on_before_invocation(event) - first_prompt = agent._system_prompt + first_prompt = agent.system_prompt plugin._on_before_invocation(event) - second_prompt = agent._system_prompt + second_prompt = agent.system_prompt assert first_prompt == second_prompt @@ -292,7 +313,7 @@ def test_no_skills_skips_injection(self): event = BeforeInvocationEvent(agent=agent) plugin._on_before_invocation(event) - assert agent._system_prompt == original_prompt + assert agent.system_prompt == original_prompt def test_none_system_prompt_handled(self): """Test handling when system prompt is None.""" @@ -304,7 +325,58 @@ def test_none_system_prompt_handled(self): event = BeforeInvocationEvent(agent=agent) plugin._on_before_invocation(event) - assert "" in agent._system_prompt + assert "" in agent.system_prompt + + def test_preserves_other_plugin_modifications(self): + """Test that modifications by other plugins/hooks are preserved.""" + plugin = SkillsPlugin(skills=[_make_skill()]) + agent = _mock_agent() + agent._system_prompt = "Original prompt." + agent._system_prompt_content = [{"text": "Original prompt."}] + + event = BeforeInvocationEvent(agent=agent) + plugin._on_before_invocation(event) + + # Simulate another plugin modifying the prompt + agent.system_prompt = agent.system_prompt + "\n\nExtra context from another plugin." + + plugin._on_before_invocation(event) + + assert "Extra context from another plugin." in agent.system_prompt + assert "" in agent.system_prompt + + def test_uses_public_system_prompt_setter(self): + """Test that the hook uses the public system_prompt setter.""" + plugin = SkillsPlugin(skills=[_make_skill()]) + agent = _mock_agent() + agent._system_prompt = "Original." + agent._system_prompt_content = [{"text": "Original."}] + + event = BeforeInvocationEvent(agent=agent) + plugin._on_before_invocation(event) + + # The public setter should have been used, so _system_prompt_content + # should be consistent with _system_prompt + assert agent._system_prompt_content == [{"text": agent._system_prompt}] + + def test_warns_when_previous_xml_not_found(self, caplog): + """Test that a warning is logged when the previously injected XML is missing from the prompt.""" + plugin = SkillsPlugin(skills=[_make_skill()]) + agent = _mock_agent() + agent._system_prompt = "Original prompt." + agent._system_prompt_content = [{"text": "Original prompt."}] + + event = BeforeInvocationEvent(agent=agent) + plugin._on_before_invocation(event) + + # Completely replace the system prompt, removing the injected XML + agent.system_prompt = "Totally new prompt." + + with caplog.at_level(logging.WARNING): + plugin._on_before_invocation(event) + + assert "unable to find previously injected skills XML in system prompt" in caplog.text + assert "" in agent.system_prompt class TestSkillsXmlGeneration: @@ -505,7 +577,7 @@ def test_restore_state_activates_skill(self): skill = _make_skill() plugin = SkillsPlugin(skills=[skill]) agent = _mock_agent() - agent.state.get = MagicMock(return_value={"active_skill_name": "test-skill"}) + agent.state.set("skills_plugin", {"active_skill_name": "test-skill"}) plugin._agent = agent plugin._restore_state() @@ -517,7 +589,6 @@ def test_restore_state_no_data(self): """Test restore when no state data exists.""" plugin = SkillsPlugin(skills=[_make_skill()]) agent = _mock_agent() - agent.state.get = MagicMock(return_value=None) plugin._agent = agent plugin._restore_state() @@ -528,7 +599,7 @@ def test_restore_state_skill_not_found(self): """Test restore when saved skill is no longer available.""" plugin = SkillsPlugin(skills=[_make_skill()]) agent = _mock_agent() - agent.state.get = MagicMock(return_value={"active_skill_name": "removed-skill"}) + agent.state.set("skills_plugin", {"active_skill_name": "removed-skill"}) plugin._agent = agent plugin._restore_state() From eabed28cf868b50489f587047433344e1d963708 Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Tue, 3 Mar 2026 10:57:50 -0500 Subject: [PATCH 17/29] fix(skills): address PR review feedback - Escape XML special characters in skill names/descriptions using xml.sax.saxutils.escape to prevent injection/malformed output - Change Skill.metadata type from dict[str, str] to dict[str, Any] to preserve non-string YAML frontmatter values - Make available_skills property setter symmetric (list[Skill] in, list[Skill] out) and add load_skills() method for path resolution - Include available skills list in empty skill_name error message for consistency with the 'not found' error --- src/strands/plugins/skills/loader.py | 4 +-- src/strands/plugins/skills/skill.py | 3 +- src/strands/plugins/skills/skills_plugin.py | 34 ++++++++++++++----- .../plugins/skills/test_skills_plugin.py | 26 ++++++++++++-- 4 files changed, 53 insertions(+), 14 deletions(-) diff --git a/src/strands/plugins/skills/loader.py b/src/strands/plugins/skills/loader.py index 06ab022ff..c808660b7 100644 --- a/src/strands/plugins/skills/loader.py +++ b/src/strands/plugins/skills/loader.py @@ -173,9 +173,9 @@ def load_skill(skill_path: str | Path) -> Skill: # Parse metadata (nested mapping) metadata_raw = frontmatter.get("metadata", {}) - metadata: dict[str, str] = {} + metadata: dict[str, Any] = {} if isinstance(metadata_raw, dict): - metadata = {str(k): str(v) for k, v in metadata_raw.items()} + metadata = {str(k): v for k, v in metadata_raw.items()} skill_license = frontmatter.get("license") compatibility = frontmatter.get("compatibility") diff --git a/src/strands/plugins/skills/skill.py b/src/strands/plugins/skills/skill.py index 30b6c0a9a..9f4092783 100644 --- a/src/strands/plugins/skills/skill.py +++ b/src/strands/plugins/skills/skill.py @@ -8,6 +8,7 @@ from dataclasses import dataclass, field from pathlib import Path +from typing import Any @dataclass @@ -35,6 +36,6 @@ class Skill: instructions: str = "" path: Path | None = None allowed_tools: list[str] | None = None - metadata: dict[str, str] = field(default_factory=dict) + metadata: dict[str, Any] = field(default_factory=dict) license: str | None = None compatibility: str | None = None diff --git a/src/strands/plugins/skills/skills_plugin.py b/src/strands/plugins/skills/skills_plugin.py index 68aeae958..670480f2e 100644 --- a/src/strands/plugins/skills/skills_plugin.py +++ b/src/strands/plugins/skills/skills_plugin.py @@ -10,6 +10,7 @@ import logging from pathlib import Path from typing import TYPE_CHECKING, Any +from xml.sax.saxutils import escape from ...hooks.events import BeforeInvocationEvent from ...plugins import Plugin, hook @@ -99,7 +100,8 @@ def skills(self, skill_name: str) -> str: skill_name: Name of the skill to activate. """ if not skill_name: - return "Error: skill_name is required." + available = ", ".join(self._skills) + return f"Error: skill_name is required. Available skills: {available}" found = self._skills.get(skill_name) if found is None: @@ -158,18 +160,32 @@ def available_skills(self) -> list[Skill]: return list(self._skills.values()) @available_skills.setter - def available_skills(self, value: list[str | Path | Skill]) -> None: - """Set the available skills, resolving paths as needed. + def available_skills(self, value: list[Skill]) -> None: + """Set the available skills directly. - Deactivates any currently active skill when skills are changed. + If the currently active skill is no longer in the new list, it is deactivated. Args: - value: List of skill sources to resolve. + value: List of Skill instances. """ - self._skills = self._resolve_skills(value) - self._active_skill = None + self._skills = {s.name: s for s in value} + if self._active_skill and self._active_skill.name not in self._skills: + self._active_skill = None self._persist_state() + def load_skills(self, sources: list[str | Path | Skill]) -> None: + """Resolve and append skills from mixed sources. + + Each source can be a ``Skill`` instance, a path to a skill directory, + or a path to a parent directory containing multiple skills. Resolved + skills are merged into the current set (duplicates overwrite). + + Args: + sources: List of skill sources to resolve and add. + """ + resolved = self._resolve_skills(sources) + self._skills.update(resolved) + @property def active_skill(self) -> Skill | None: """Get the currently active skill. @@ -258,8 +274,8 @@ def _generate_skills_xml(self) -> str: for skill in self._skills.values(): lines.append("") - lines.append(f"{skill.name}") - lines.append(f"{skill.description}") + lines.append(f"{escape(skill.name)}") + lines.append(f"{escape(skill.description)}") if skill.path is not None: lines.append(f"{skill.path / 'SKILL.md'}") lines.append("") diff --git a/tests/strands/plugins/skills/test_skills_plugin.py b/tests/strands/plugins/skills/test_skills_plugin.py index 2c1d71801..642b051fe 100644 --- a/tests/strands/plugins/skills/test_skills_plugin.py +++ b/tests/strands/plugins/skills/test_skills_plugin.py @@ -185,8 +185,8 @@ def test_available_skills_setter(self): assert len(plugin.available_skills) == 1 assert plugin.available_skills[0].name == "new-skill" - def test_available_skills_setter_deactivates_current(self): - """Test that setting skills deactivates the current active skill.""" + def test_available_skills_setter_deactivates_when_removed(self): + """Test that setting skills deactivates the active skill when it's no longer in the list.""" plugin = SkillsPlugin(skills=[_make_skill()]) plugin._agent = _mock_agent() plugin._active_skill = _make_skill() @@ -195,6 +195,19 @@ def test_available_skills_setter_deactivates_current(self): assert plugin.active_skill is None + def test_available_skills_setter_preserves_active_when_present(self): + """Test that setting skills keeps the active skill when it's still in the list.""" + skill = _make_skill() + plugin = SkillsPlugin(skills=[skill]) + plugin._agent = _mock_agent() + plugin._active_skill = skill + + new_skill = _make_skill(name="new-skill", description="New") + plugin.available_skills = [skill, new_skill] + + assert plugin.active_skill is not None + assert plugin.active_skill.name == "test-skill" + def test_active_skill_initially_none(self): """Test that active_skill is None initially.""" plugin = SkillsPlugin(skills=[_make_skill()]) @@ -433,6 +446,15 @@ def test_location_omitted_when_path_none(self): assert "" not in xml + def test_escapes_xml_special_characters(self): + """Test that XML special characters in names and descriptions are escaped.""" + skill = _make_skill(name="a&c", description="Use & more") + plugin = SkillsPlugin(skills=[skill]) + xml = plugin._generate_skills_xml() + + assert "a<b>&c" in xml + assert "Use <tools> & more" in xml + class TestSkillResponseFormat: """Tests for _format_skill_response.""" From 19f9383731cf8afe5dc0e9379f6f58c326c7e4de Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Tue, 3 Mar 2026 11:30:44 -0500 Subject: [PATCH 18/29] fix: add unit test for load skills --- .../plugins/skills/test_skills_plugin.py | 76 +++++++++++++++++++ 1 file changed, 76 insertions(+) diff --git a/tests/strands/plugins/skills/test_skills_plugin.py b/tests/strands/plugins/skills/test_skills_plugin.py index 642b051fe..9fd0ccd8a 100644 --- a/tests/strands/plugins/skills/test_skills_plugin.py +++ b/tests/strands/plugins/skills/test_skills_plugin.py @@ -214,6 +214,82 @@ def test_active_skill_initially_none(self): assert plugin.active_skill is None +class TestLoadSkills: + """Tests for the load_skills method.""" + + def test_appends_skill_instances(self): + """Test that load_skills appends Skill instances to existing skills.""" + plugin = SkillsPlugin(skills=[_make_skill(name="existing", description="Existing")]) + + plugin.load_skills([_make_skill(name="new-skill", description="New")]) + + assert len(plugin.available_skills) == 2 + names = {s.name for s in plugin.available_skills} + assert names == {"existing", "new-skill"} + + def test_appends_from_filesystem(self, tmp_path): + """Test that load_skills appends skills resolved from filesystem paths.""" + plugin = SkillsPlugin(skills=[_make_skill(name="existing", description="Existing")]) + _make_skill_dir(tmp_path, "fs-skill") + + plugin.load_skills([str(tmp_path / "fs-skill")]) + + assert len(plugin.available_skills) == 2 + names = {s.name for s in plugin.available_skills} + assert names == {"existing", "fs-skill"} + + def test_duplicates_overwrite(self): + """Test that loading a skill with the same name overwrites the existing one.""" + original = _make_skill(name="dupe", description="Original") + plugin = SkillsPlugin(skills=[original]) + + replacement = _make_skill(name="dupe", description="Replacement") + plugin.load_skills([replacement]) + + assert len(plugin.available_skills) == 1 + assert plugin.available_skills[0].description == "Replacement" + + def test_mixed_sources(self, tmp_path): + """Test load_skills with a mix of Skill instances and filesystem paths.""" + plugin = SkillsPlugin(skills=[]) + _make_skill_dir(tmp_path, "fs-skill") + direct = _make_skill(name="direct", description="Direct") + + plugin.load_skills([str(tmp_path / "fs-skill"), direct]) + + assert len(plugin.available_skills) == 2 + names = {s.name for s in plugin.available_skills} + assert names == {"fs-skill", "direct"} + + def test_skips_nonexistent_paths(self): + """Test that nonexistent paths are skipped without error.""" + plugin = SkillsPlugin(skills=[_make_skill()]) + + plugin.load_skills(["/nonexistent/path"]) + + assert len(plugin.available_skills) == 1 + + def test_empty_sources(self): + """Test that loading empty sources is a no-op.""" + plugin = SkillsPlugin(skills=[_make_skill()]) + + plugin.load_skills([]) + + assert len(plugin.available_skills) == 1 + + def test_parent_directory(self, tmp_path): + """Test load_skills with a parent directory containing multiple skills.""" + plugin = SkillsPlugin(skills=[]) + _make_skill_dir(tmp_path, "child-a") + _make_skill_dir(tmp_path, "child-b") + + plugin.load_skills([tmp_path]) + + assert len(plugin.available_skills) == 2 + names = {s.name for s in plugin.available_skills} + assert names == {"child-a", "child-b"} + + class TestSkillsTool: """Tests for the skills tool method.""" From 8f2f5a7225747b13490aa55f322d5c1b1b08929e Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Wed, 4 Mar 2026 13:18:22 -0500 Subject: [PATCH 19/29] refactor: make SkillsPlugin multi-agent safe with per-agent state --- src/strands/plugins/skills/loader.py | 16 +- src/strands/plugins/skills/skill.py | 2 +- src/strands/plugins/skills/skills_plugin.py | 93 +++++----- .../plugins/skills/test_skills_plugin.py | 166 ++++++++++-------- tests_integ/test_skills_plugin.py | 8 +- 5 files changed, 149 insertions(+), 136 deletions(-) diff --git a/src/strands/plugins/skills/loader.py b/src/strands/plugins/skills/loader.py index c808660b7..da6c08629 100644 --- a/src/strands/plugins/skills/loader.py +++ b/src/strands/plugins/skills/loader.py @@ -44,19 +44,6 @@ def _find_skill_md(skill_dir: Path) -> Path: raise FileNotFoundError(f"path=<{skill_dir}> | no SKILL.md found in skill directory") -def _parse_yaml(yaml_text: str) -> dict[str, Any]: - """Parse YAML text into a dictionary. - - Args: - yaml_text: YAML-formatted text to parse. - - Returns: - Dictionary of parsed key-value pairs. - """ - result = yaml.safe_load(yaml_text) - return result if isinstance(result, dict) else {} - - def _parse_frontmatter(content: str) -> tuple[dict[str, Any], str]: """Parse YAML frontmatter and body from SKILL.md content. @@ -84,7 +71,8 @@ def _parse_frontmatter(content: str) -> tuple[dict[str, Any], str]: frontmatter_str = stripped[3 : match.start()].strip() body = stripped[match.end() :].strip() - frontmatter = _parse_yaml(frontmatter_str) + result = yaml.safe_load(frontmatter_str) + frontmatter: dict[str, Any] = result if isinstance(result, dict) else {} return frontmatter, body diff --git a/src/strands/plugins/skills/skill.py b/src/strands/plugins/skills/skill.py index 9f4092783..34010fdba 100644 --- a/src/strands/plugins/skills/skill.py +++ b/src/strands/plugins/skills/skill.py @@ -25,7 +25,7 @@ class Skill: description: Human-readable description of what the skill does. instructions: Full markdown instructions from the SKILL.md body. path: Filesystem path to the skill directory, if loaded from disk. - allowed_tools: List of tool names the skill is allowed to use. + allowed_tools: List of tool names the skill is allowed to use. (Experimental: not yet enforced) metadata: Additional key-value metadata from the SKILL.md frontmatter. license: License identifier (e.g., "Apache-2.0"). compatibility: Compatibility information string. diff --git a/src/strands/plugins/skills/skills_plugin.py b/src/strands/plugins/skills/skills_plugin.py index 670480f2e..c83c5bf66 100644 --- a/src/strands/plugins/skills/skills_plugin.py +++ b/src/strands/plugins/skills/skills_plugin.py @@ -15,6 +15,7 @@ from ...hooks.events import BeforeInvocationEvent from ...plugins import Plugin, hook from ...tools.decorator import tool +from ...types.tools import ToolContext from .loader import load_skill, load_skills from .skill import Skill @@ -23,9 +24,9 @@ logger = logging.getLogger(__name__) -_STATE_KEY = "skills_plugin" +_DEFAULT_STATE_KEY = "skills_plugin" _RESOURCE_DIRS = ("scripts", "references", "assets") -_MAX_RESOURCE_FILES = 20 +_DEFAULT_MAX_RESOURCE_FILES = 20 class SkillsPlugin(Plugin): @@ -61,7 +62,12 @@ def name(self) -> str: """A stable string identifier for the plugin.""" return "skills" - def __init__(self, skills: list[str | Path | Skill]) -> None: + def __init__( + self, + skills: list[str | Path | Skill], + state_key: str = _DEFAULT_STATE_KEY, + max_resource_files: int = _DEFAULT_MAX_RESOURCE_FILES, + ) -> None: """Initialize the SkillsPlugin. Args: @@ -70,10 +76,12 @@ def __init__(self, skills: list[str | Path | Skill]) -> None: - A ``str`` or ``Path`` to a skill directory (containing SKILL.md) - A ``str`` or ``Path`` to a parent directory (containing skill subdirectories) - A ``Skill`` dataclass instance + state_key: Key used to store plugin state in ``agent.state``. + max_resource_files: Maximum number of resource files to list in skill responses. """ self._skills: dict[str, Skill] = self._resolve_skills(skills) - self._active_skill: Skill | None = None - self._agent: Agent | None = None + self._state_key = state_key + self._max_resource_files = max_resource_files super().__init__() def init_agent(self, agent: Agent) -> None: @@ -85,12 +93,11 @@ def init_agent(self, agent: Agent) -> None: Args: agent: The agent instance to extend with skills support. """ - self._agent = agent - self._restore_state() + self._restore_state(agent) logger.debug("skill_count=<%d> | skills plugin initialized", len(self._skills)) - @tool - def skills(self, skill_name: str) -> str: + @tool(context=True) + def skills(self, skill_name: str, tool_context: ToolContext) -> str: # noqa: D417 """Activate a skill to load its full instructions. Use this tool to load the complete instructions for a skill listed in @@ -99,6 +106,8 @@ def skills(self, skill_name: str) -> str: Args: skill_name: Name of the skill to activate. """ + agent = tool_context.agent + if not skill_name: available = ", ".join(self._skills) return f"Error: skill_name is required. Available skills: {available}" @@ -108,9 +117,7 @@ def skills(self, skill_name: str) -> str: available = ", ".join(self._skills) return f"Skill '{skill_name}' not found. Available skills: {available}" - self._active_skill = found - self._persist_state() - + self._set_state_field(agent, "active_skill_name", found.name) logger.debug("skill_name=<%s> | skill activated", skill_name) return self._format_skill_response(found) @@ -134,7 +141,7 @@ def _on_before_invocation(self, event: BeforeInvocationEvent) -> None: current_prompt = agent.system_prompt or "" # Remove the previously injected XML block by exact match - state_data = agent.state.get(_STATE_KEY) + state_data = agent.state.get(self._state_key) last_injected_xml = state_data.get("last_injected_xml") if isinstance(state_data, dict) else None if last_injected_xml is not None: if last_injected_xml in current_prompt: @@ -163,15 +170,14 @@ def available_skills(self) -> list[Skill]: def available_skills(self, value: list[Skill]) -> None: """Set the available skills directly. - If the currently active skill is no longer in the new list, it is deactivated. + Note: this does not persist state or deactivate skills on any agent. + Active skill state is managed per-agent and will be reconciled on the + next tool call or invocation. Args: value: List of Skill instances. """ self._skills = {s.name: s for s in value} - if self._active_skill and self._active_skill.name not in self._skills: - self._active_skill = None - self._persist_state() def load_skills(self, sources: list[str | Path | Skill]) -> None: """Resolve and append skills from mixed sources. @@ -186,14 +192,23 @@ def load_skills(self, sources: list[str | Path | Skill]) -> None: resolved = self._resolve_skills(sources) self._skills.update(resolved) - @property - def active_skill(self) -> Skill | None: - """Get the currently active skill. + def get_active_skill(self, agent: Agent) -> Skill | None: + """Get the currently active skill for a given agent. + + Args: + agent: The agent to check active skill for. Returns: The active Skill instance, or None if no skill is active. """ - return self._active_skill + state_data = agent.state.get(self._state_key) + if not isinstance(state_data, dict): + return None + + active_name = state_data.get("active_skill_name") + if isinstance(active_name, str): + return self._skills.get(active_name) + return None def _format_skill_response(self, skill: Skill) -> str: """Format the tool response when a skill is activated. @@ -236,7 +251,7 @@ def _list_skill_resources(self, skill_path: Path) -> list[str]: Scans the ``scripts/``, ``references/``, and ``assets/`` subdirectories for files, returning relative paths. Results are capped at - ``_MAX_RESOURCE_FILES`` to avoid context bloat. + ``max_resource_files`` to avoid context bloat. Args: skill_path: Path to the skill directory. @@ -255,8 +270,8 @@ def _list_skill_resources(self, skill_path: Path) -> list[str]: if not file_path.is_file(): continue files.append(file_path.relative_to(skill_path).as_posix()) - if len(files) >= _MAX_RESOURCE_FILES: - files.append(f"... (truncated at {_MAX_RESOURCE_FILES} files)") + if len(files) >= self._max_resource_files: + files.append(f"... (truncated at {self._max_resource_files} files)") return files return files @@ -277,7 +292,7 @@ def _generate_skills_xml(self) -> str: lines.append(f"{escape(skill.name)}") lines.append(f"{escape(skill.description)}") if skill.path is not None: - lines.append(f"{skill.path / 'SKILL.md'}") + lines.append(f"{escape(str(skill.path / 'SKILL.md'))}") lines.append("") lines.append("") @@ -342,13 +357,6 @@ def _resolve_skills(self, sources: list[str | Path | Skill]) -> dict[str, Skill] logger.debug("source_count=<%d>, resolved_count=<%d> | skills resolved", len(sources), len(resolved)) return resolved - def _persist_state(self) -> None: - """Persist the active skill name to agent state for session recovery.""" - if self._agent is None: - return - - self._set_state_field(self._agent, "active_skill_name", self._active_skill.name if self._active_skill else None) - def _set_state_field(self, agent: Agent, key: str, value: Any) -> None: """Set a single field in the plugin's agent state dict. @@ -357,23 +365,22 @@ def _set_state_field(self, agent: Agent, key: str, value: Any) -> None: key: The state field key. value: The value to set. """ - state_data = agent.state.get(_STATE_KEY) + state_data = agent.state.get(self._state_key) if not isinstance(state_data, dict): state_data = {} state_data[key] = value - agent.state.set(_STATE_KEY, state_data) + agent.state.set(self._state_key, state_data) - def _restore_state(self) -> None: - """Restore the active skill from agent state if available.""" - if self._agent is None: - return + def _restore_state(self, agent: Agent) -> None: + """Restore the active skill from agent state if available. - state_data = self._agent.state.get(_STATE_KEY) + Args: + agent: The agent whose state to restore from. + """ + state_data = agent.state.get(self._state_key) if not isinstance(state_data, dict): return active_name = state_data.get("active_skill_name") - if isinstance(active_name, str): - self._active_skill = self._skills.get(active_name) - if self._active_skill: - logger.debug("skill_name=<%s> | restored active skill from state", active_name) + if isinstance(active_name, str) and active_name in self._skills: + logger.debug("skill_name=<%s> | restored active skill from state", active_name) diff --git a/tests/strands/plugins/skills/test_skills_plugin.py b/tests/strands/plugins/skills/test_skills_plugin.py index 9fd0ccd8a..c969c7eff 100644 --- a/tests/strands/plugins/skills/test_skills_plugin.py +++ b/tests/strands/plugins/skills/test_skills_plugin.py @@ -9,6 +9,7 @@ from strands.plugins.registry import _PluginRegistry from strands.plugins.skills.skill import Skill from strands.plugins.skills.skills_plugin import SkillsPlugin +from strands.types.tools import ToolContext def _make_skill(name: str = "test-skill", description: str = "A test skill", instructions: str = "Do the thing."): @@ -52,6 +53,12 @@ def _mock_agent(): return agent +def _mock_tool_context(agent: MagicMock) -> ToolContext: + """Create a mock ToolContext with the given agent.""" + tool_use = {"toolUseId": "test-id", "name": "skills", "input": {}} + return ToolContext(tool_use=tool_use, agent=agent, invocation_state={"agent": agent}) + + def _set_system_prompt(agent: MagicMock, value: str | None) -> None: """Simulate the Agent.system_prompt setter.""" if isinstance(value, str): @@ -108,13 +115,22 @@ def test_init_empty_skills(self): """Test initialization with empty skills list.""" plugin = SkillsPlugin(skills=[]) assert plugin.available_skills == [] - assert plugin.active_skill is None def test_name_attribute(self): """Test that the plugin has the correct name.""" plugin = SkillsPlugin(skills=[]) assert plugin.name == "skills" + def test_custom_state_key(self): + """Test initialization with a custom state key.""" + plugin = SkillsPlugin(skills=[], state_key="custom_key") + assert plugin._state_key == "custom_key" + + def test_custom_max_resource_files(self): + """Test initialization with a custom max resource files limit.""" + plugin = SkillsPlugin(skills=[], max_resource_files=50) + assert plugin._max_resource_files == 50 + class TestSkillsPluginInitAgent: """Tests for the init_agent method and plugin registry integration.""" @@ -139,14 +155,14 @@ def test_registers_hooks(self): assert agent.hooks.has_callbacks() - def test_stores_agent_reference(self): - """Test that init_agent stores the agent reference.""" + def test_does_not_store_agent_reference(self): + """Test that init_agent does not store the agent on the plugin.""" plugin = SkillsPlugin(skills=[_make_skill()]) agent = _mock_agent() plugin.init_agent(agent) - assert plugin._agent is agent + assert not hasattr(plugin, "_agent") def test_restores_state(self): """Test that init_agent restores active skill from state.""" @@ -157,8 +173,8 @@ def test_restores_state(self): plugin.init_agent(agent) - assert plugin.active_skill is not None - assert plugin.active_skill.name == "test-skill" + assert plugin.get_active_skill(agent) is not None + assert plugin.get_active_skill(agent).name == "test-skill" class TestSkillsPluginProperties: @@ -177,7 +193,6 @@ def test_available_skills_getter_returns_copy(self): def test_available_skills_setter(self): """Test setting skills via the property setter.""" plugin = SkillsPlugin(skills=[_make_skill()]) - plugin._agent = _mock_agent() new_skill = _make_skill(name="new-skill", description="New") plugin.available_skills = [new_skill] @@ -185,33 +200,35 @@ def test_available_skills_setter(self): assert len(plugin.available_skills) == 1 assert plugin.available_skills[0].name == "new-skill" - def test_available_skills_setter_deactivates_when_removed(self): - """Test that setting skills deactivates the active skill when it's no longer in the list.""" + def test_get_active_skill_initially_none(self): + """Test that get_active_skill returns None initially.""" plugin = SkillsPlugin(skills=[_make_skill()]) - plugin._agent = _mock_agent() - plugin._active_skill = _make_skill() + agent = _mock_agent() + assert plugin.get_active_skill(agent) is None - plugin.available_skills = [_make_skill(name="new-skill", description="New")] + def test_get_active_skill_after_activation(self): + """Test that get_active_skill returns the activated skill.""" + skill = _make_skill() + plugin = SkillsPlugin(skills=[skill]) + agent = _mock_agent() + tool_context = _mock_tool_context(agent) - assert plugin.active_skill is None + plugin.skills(skill_name="test-skill", tool_context=tool_context) - def test_available_skills_setter_preserves_active_when_present(self): - """Test that setting skills keeps the active skill when it's still in the list.""" + assert plugin.get_active_skill(agent) is not None + assert plugin.get_active_skill(agent).name == "test-skill" + + def test_get_active_skill_returns_none_when_skill_removed(self): + """Test that get_active_skill returns None when the active skill is no longer available.""" skill = _make_skill() plugin = SkillsPlugin(skills=[skill]) - plugin._agent = _mock_agent() - plugin._active_skill = skill + agent = _mock_agent() + tool_context = _mock_tool_context(agent) - new_skill = _make_skill(name="new-skill", description="New") - plugin.available_skills = [skill, new_skill] + plugin.skills(skill_name="test-skill", tool_context=tool_context) + plugin.available_skills = [_make_skill(name="other-skill", description="Other")] - assert plugin.active_skill is not None - assert plugin.active_skill.name == "test-skill" - - def test_active_skill_initially_none(self): - """Test that active_skill is None initially.""" - plugin = SkillsPlugin(skills=[_make_skill()]) - assert plugin.active_skill is None + assert plugin.get_active_skill(agent) is None class TestLoadSkills: @@ -297,21 +314,23 @@ def test_activate_skill(self): """Test activating a skill returns its instructions.""" skill = _make_skill(instructions="Full instructions here.") plugin = SkillsPlugin(skills=[skill]) - plugin._agent = _mock_agent() + agent = _mock_agent() + tool_context = _mock_tool_context(agent) - result = plugin.skills(skill_name="test-skill") + result = plugin.skills(skill_name="test-skill", tool_context=tool_context) assert "Full instructions here." in result - assert plugin.active_skill is not None - assert plugin.active_skill.name == "test-skill" + assert plugin.get_active_skill(agent) is not None + assert plugin.get_active_skill(agent).name == "test-skill" def test_activate_nonexistent_skill(self): """Test activating a nonexistent skill returns error message.""" skill = _make_skill() plugin = SkillsPlugin(skills=[skill]) - plugin._agent = _mock_agent() + agent = _mock_agent() + tool_context = _mock_tool_context(agent) - result = plugin.skills(skill_name="nonexistent") + result = plugin.skills(skill_name="nonexistent", tool_context=tool_context) assert "not found" in result assert "test-skill" in result @@ -321,20 +340,22 @@ def test_activate_replaces_previous(self): skill1 = _make_skill(name="skill-a", description="A", instructions="A instructions") skill2 = _make_skill(name="skill-b", description="B", instructions="B instructions") plugin = SkillsPlugin(skills=[skill1, skill2]) - plugin._agent = _mock_agent() + agent = _mock_agent() + tool_context = _mock_tool_context(agent) - plugin.skills(skill_name="skill-a") - assert plugin.active_skill.name == "skill-a" + plugin.skills(skill_name="skill-a", tool_context=tool_context) + assert plugin.get_active_skill(agent).name == "skill-a" - plugin.skills(skill_name="skill-b") - assert plugin.active_skill.name == "skill-b" + plugin.skills(skill_name="skill-b", tool_context=tool_context) + assert plugin.get_active_skill(agent).name == "skill-b" def test_activate_without_name(self): """Test activating without a skill name returns error.""" plugin = SkillsPlugin(skills=[_make_skill()]) - plugin._agent = _mock_agent() + agent = _mock_agent() + tool_context = _mock_tool_context(agent) - result = plugin.skills(skill_name="") + result = plugin.skills(skill_name="", tool_context=tool_context) assert "required" in result.lower() @@ -342,12 +363,29 @@ def test_activate_persists_state(self): """Test that activating a skill persists state.""" plugin = SkillsPlugin(skills=[_make_skill()]) agent = _mock_agent() - plugin._agent = agent + tool_context = _mock_tool_context(agent) - plugin.skills(skill_name="test-skill") + plugin.skills(skill_name="test-skill", tool_context=tool_context) agent.state.set.assert_called() + def test_multi_agent_isolation(self): + """Test that skill activation is isolated per agent.""" + skill_a = _make_skill(name="skill-a", description="A", instructions="A instructions") + skill_b = _make_skill(name="skill-b", description="B", instructions="B instructions") + plugin = SkillsPlugin(skills=[skill_a, skill_b]) + + agent1 = _mock_agent() + agent2 = _mock_agent() + ctx1 = _mock_tool_context(agent1) + ctx2 = _mock_tool_context(agent2) + + plugin.skills(skill_name="skill-a", tool_context=ctx1) + plugin.skills(skill_name="skill-b", tool_context=ctx2) + + assert plugin.get_active_skill(agent1).name == "skill-a" + assert plugin.get_active_skill(agent2).name == "skill-b" + class TestSystemPromptInjection: """Tests for system prompt injection via hooks.""" @@ -652,26 +690,16 @@ def test_resource_listing_truncated(self, tmp_path): class TestSessionPersistence: """Tests for session state persistence.""" - def test_persist_state_with_active_skill(self): - """Test persisting active skill name.""" - plugin = SkillsPlugin(skills=[_make_skill()]) - agent = _mock_agent() - plugin._agent = agent - plugin._active_skill = _make_skill() - - plugin._persist_state() - - agent.state.set.assert_called_once_with("skills_plugin", {"active_skill_name": "test-skill"}) - - def test_persist_state_without_active_skill(self): - """Test persisting None when no skill is active.""" + def test_tool_persists_active_skill(self): + """Test that the tool persists the active skill name to agent state.""" plugin = SkillsPlugin(skills=[_make_skill()]) agent = _mock_agent() - plugin._agent = agent + tool_context = _mock_tool_context(agent) - plugin._persist_state() + plugin.skills(skill_name="test-skill", tool_context=tool_context) - agent.state.set.assert_called_once_with("skills_plugin", {"active_skill_name": None}) + state = agent.state.get("skills_plugin") + assert state["active_skill_name"] == "test-skill" def test_restore_state_activates_skill(self): """Test restoring active skill from state.""" @@ -679,40 +707,30 @@ def test_restore_state_activates_skill(self): plugin = SkillsPlugin(skills=[skill]) agent = _mock_agent() agent.state.set("skills_plugin", {"active_skill_name": "test-skill"}) - plugin._agent = agent - plugin._restore_state() + plugin._restore_state(agent) - assert plugin.active_skill is not None - assert plugin.active_skill.name == "test-skill" + assert plugin.get_active_skill(agent) is not None + assert plugin.get_active_skill(agent).name == "test-skill" def test_restore_state_no_data(self): """Test restore when no state data exists.""" plugin = SkillsPlugin(skills=[_make_skill()]) agent = _mock_agent() - plugin._agent = agent - plugin._restore_state() + plugin._restore_state(agent) - assert plugin.active_skill is None + assert plugin.get_active_skill(agent) is None def test_restore_state_skill_not_found(self): """Test restore when saved skill is no longer available.""" plugin = SkillsPlugin(skills=[_make_skill()]) agent = _mock_agent() agent.state.set("skills_plugin", {"active_skill_name": "removed-skill"}) - plugin._agent = agent - plugin._restore_state() - - assert plugin.active_skill is None - - def test_persist_state_without_agent(self): - """Test that persist_state is a no-op without agent.""" - plugin = SkillsPlugin(skills=[_make_skill()]) + plugin._restore_state(agent) - # Should not raise - plugin._persist_state() + assert plugin.get_active_skill(agent) is None class TestResolveSkills: diff --git a/tests_integ/test_skills_plugin.py b/tests_integ/test_skills_plugin.py index 01ccb7260..85287e1f9 100644 --- a/tests_integ/test_skills_plugin.py +++ b/tests_integ/test_skills_plugin.py @@ -46,8 +46,8 @@ def test_agent_activates_skill_and_injects_metadata(agent, skills_plugin): assert "translation" in agent.system_prompt # Model activated the skill and relayed the codeword from instructions - assert skills_plugin.active_skill is not None - assert skills_plugin.active_skill.name == "summarization" + assert skills_plugin.get_active_skill(agent) is not None + assert skills_plugin.get_active_skill(agent).name == "summarization" assert "pineapple" in str(result).lower() @@ -61,8 +61,8 @@ def test_direct_tool_invocation_and_state_persistence(agent, skills_plugin): assert "translation expert" in response_text # Plugin tracks the active skill - assert skills_plugin.active_skill is not None - assert skills_plugin.active_skill.name == "translation" + assert skills_plugin.get_active_skill(agent) is not None + assert skills_plugin.get_active_skill(agent).name == "translation" # State was persisted to agent state state = agent.state.get("skills_plugin") From 0acf9688df362804e6dc32637666c2906f4f4971 Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Wed, 4 Mar 2026 13:44:10 -0500 Subject: [PATCH 20/29] chore: revert unrelated handler and tracer formatting changes --- .../experimental/steering/core/handler.py | 3 +- tests/strands/telemetry/test_tracer.py | 292 ++++++++---------- 2 files changed, 131 insertions(+), 164 deletions(-) diff --git a/src/strands/experimental/steering/core/handler.py b/src/strands/experimental/steering/core/handler.py index 7e363a6d7..214118d4f 100644 --- a/src/strands/experimental/steering/core/handler.py +++ b/src/strands/experimental/steering/core/handler.py @@ -79,8 +79,7 @@ def __init__(self, context_providers: list[SteeringContextProvider] | None = Non def init_agent(self, agent: "Agent") -> None: """Initialize the steering handler with an agent. - Registers context update callbacks. Decorated hooks and tools - are auto-registered by the plugin registry. + Registers hook callbacks for steering guidance and context updates. Args: agent: The agent instance to attach steering to. diff --git a/tests/strands/telemetry/test_tracer.py b/tests/strands/telemetry/test_tracer.py index 50c0cc9b9..da7f010e2 100644 --- a/tests/strands/telemetry/test_tracer.py +++ b/tests/strands/telemetry/test_tracer.py @@ -148,16 +148,14 @@ def test_start_model_invoke_span(mock_tracer): mock_tracer.start_span.assert_called_once() assert mock_tracer.start_span.call_args[1]["name"] == "chat" assert mock_tracer.start_span.call_args[1]["kind"] == SpanKind.INTERNAL - mock_span.set_attributes.assert_called_once_with( - { - "gen_ai.operation.name": "chat", - "gen_ai.system": "strands-agents", - "custom_key": "custom_value", - "user_id": "12345", - "gen_ai.request.model": model_id, - "agent_name": "TestAgent", - } - ) + mock_span.set_attributes.assert_called_once_with({ + "gen_ai.operation.name": "chat", + "gen_ai.system": "strands-agents", + "custom_key": "custom_value", + "user_id": "12345", + "gen_ai.request.model": model_id, + "agent_name": "TestAgent", + }) mock_span.add_event.assert_called_with( "gen_ai.user.message", attributes={"content": json.dumps(messages[0]["content"])} ) @@ -190,15 +188,13 @@ def test_start_model_invoke_span_latest_conventions(mock_tracer, monkeypatch): mock_tracer.start_span.assert_called_once() assert mock_tracer.start_span.call_args[1]["name"] == "chat" assert mock_tracer.start_span.call_args[1]["kind"] == SpanKind.INTERNAL - - mock_span.set_attributes.assert_called_once_with( - { - "gen_ai.operation.name": "chat", - "gen_ai.provider.name": "strands-agents", - "gen_ai.request.model": model_id, - "agent_name": "TestAgent", - } - ) + + mock_span.set_attributes.assert_called_once_with({ + "gen_ai.operation.name": "chat", + "gen_ai.provider.name": "strands-agents", + "gen_ai.request.model": model_id, + "agent_name": "TestAgent", + }) mock_span.add_event.assert_called_with( "gen_ai.client.inference.operation.details", attributes={ @@ -236,17 +232,15 @@ def test_end_model_invoke_span(mock_span): tracer.end_model_invoke_span(mock_span, message, usage, metrics, stop_reason) - mock_span.set_attributes.assert_called_once_with( - { - "gen_ai.usage.prompt_tokens": 10, - "gen_ai.usage.input_tokens": 10, - "gen_ai.usage.completion_tokens": 20, - "gen_ai.usage.output_tokens": 20, - "gen_ai.usage.total_tokens": 30, - "gen_ai.server.time_to_first_token": 10, - "gen_ai.server.request.duration": 20, - } - ) + mock_span.set_attributes.assert_called_once_with({ + "gen_ai.usage.prompt_tokens": 10, + "gen_ai.usage.input_tokens": 10, + "gen_ai.usage.completion_tokens": 20, + "gen_ai.usage.output_tokens": 20, + "gen_ai.usage.total_tokens": 30, + "gen_ai.server.time_to_first_token": 10, + "gen_ai.server.request.duration": 20, + }) mock_span.add_event.assert_called_with( "gen_ai.choice", attributes={"message": json.dumps(message["content"]), "finish_reason": "end_turn"}, @@ -265,17 +259,15 @@ def test_end_model_invoke_span_latest_conventions(mock_span, monkeypatch): tracer.end_model_invoke_span(mock_span, message, usage, metrics, stop_reason) - mock_span.set_attributes.assert_called_once_with( - { - "gen_ai.usage.prompt_tokens": 10, - "gen_ai.usage.input_tokens": 10, - "gen_ai.usage.completion_tokens": 20, - "gen_ai.usage.output_tokens": 20, - "gen_ai.usage.total_tokens": 30, - "gen_ai.server.time_to_first_token": 10, - "gen_ai.server.request.duration": 20, - } - ) + mock_span.set_attributes.assert_called_once_with({ + "gen_ai.usage.prompt_tokens": 10, + "gen_ai.usage.input_tokens": 10, + "gen_ai.usage.completion_tokens": 20, + "gen_ai.usage.output_tokens": 20, + "gen_ai.usage.total_tokens": 30, + "gen_ai.server.time_to_first_token": 10, + "gen_ai.server.request.duration": 20, + }) mock_span.add_event.assert_called_with( "gen_ai.client.inference.operation.details", attributes={ @@ -308,17 +300,15 @@ def test_start_tool_call_span(mock_tracer): mock_tracer.start_span.assert_called_once() assert mock_tracer.start_span.call_args[1]["name"] == "execute_tool test-tool" - - mock_span.set_attributes.assert_called_once_with( - { - "gen_ai.tool.name": "test-tool", - "gen_ai.system": "strands-agents", - "gen_ai.operation.name": "execute_tool", - "gen_ai.tool.call.id": "123", - "session_id": "abc123", - "environment": "production", - } - ) + + mock_span.set_attributes.assert_called_once_with({ + "gen_ai.tool.name": "test-tool", + "gen_ai.system": "strands-agents", + "gen_ai.operation.name": "execute_tool", + "gen_ai.tool.call.id": "123", + "session_id": "abc123", + "environment": "production", + }) mock_span.add_event.assert_any_call( "gen_ai.tool.message", attributes={"role": "tool", "content": json.dumps({"param": "value"}), "id": "123"} ) @@ -341,15 +331,13 @@ def test_start_tool_call_span_latest_conventions(mock_tracer, monkeypatch): mock_tracer.start_span.assert_called_once() assert mock_tracer.start_span.call_args[1]["name"] == "execute_tool test-tool" - - mock_span.set_attributes.assert_called_once_with( - { - "gen_ai.tool.name": "test-tool", - "gen_ai.provider.name": "strands-agents", - "gen_ai.operation.name": "execute_tool", - "gen_ai.tool.call.id": "123", - } - ) + + mock_span.set_attributes.assert_called_once_with({ + "gen_ai.tool.name": "test-tool", + "gen_ai.provider.name": "strands-agents", + "gen_ai.operation.name": "execute_tool", + "gen_ai.tool.call.id": "123", + }) mock_span.add_event.assert_called_with( "gen_ai.client.inference.operation.details", attributes={ @@ -389,16 +377,14 @@ def test_start_swarm_call_span_with_string_task(mock_tracer): mock_tracer.start_span.assert_called_once() assert mock_tracer.start_span.call_args[1]["name"] == "invoke_swarm" - - mock_span.set_attributes.assert_called_once_with( - { - "gen_ai.operation.name": "invoke_swarm", - "gen_ai.system": "strands-agents", - "gen_ai.agent.name": "swarm", - "workflow_id": "wf-789", - "priority": "high", - } - ) + + mock_span.set_attributes.assert_called_once_with({ + "gen_ai.operation.name": "invoke_swarm", + "gen_ai.system": "strands-agents", + "gen_ai.agent.name": "swarm", + "workflow_id": "wf-789", + "priority": "high", + }) mock_span.add_event.assert_any_call("gen_ai.user.message", attributes={"content": "Design foo bar"}) assert span is not None @@ -418,14 +404,12 @@ def test_start_swarm_span_with_contentblock_task(mock_tracer): mock_tracer.start_span.assert_called_once() assert mock_tracer.start_span.call_args[1]["name"] == "invoke_swarm" - - mock_span.set_attributes.assert_called_once_with( - { - "gen_ai.operation.name": "invoke_swarm", - "gen_ai.system": "strands-agents", - "gen_ai.agent.name": "swarm", - } - ) + + mock_span.set_attributes.assert_called_once_with({ + "gen_ai.operation.name": "invoke_swarm", + "gen_ai.system": "strands-agents", + "gen_ai.agent.name": "swarm", + }) mock_span.add_event.assert_any_call( "gen_ai.user.message", attributes={"content": '[{"text": "Original Task: foo bar"}]'} ) @@ -476,14 +460,12 @@ def test_start_swarm_span_with_contentblock_task_latest_conventions(mock_tracer, mock_tracer.start_span.assert_called_once() assert mock_tracer.start_span.call_args[1]["name"] == "invoke_swarm" - - mock_span.set_attributes.assert_called_once_with( - { - "gen_ai.operation.name": "invoke_swarm", - "gen_ai.provider.name": "strands-agents", - "gen_ai.agent.name": "swarm", - } - ) + + mock_span.set_attributes.assert_called_once_with({ + "gen_ai.operation.name": "invoke_swarm", + "gen_ai.provider.name": "strands-agents", + "gen_ai.agent.name": "swarm", + }) mock_span.add_event.assert_any_call( "gen_ai.client.inference.operation.details", attributes={ @@ -546,15 +528,13 @@ def test_start_graph_call_span(mock_tracer): mock_tracer.start_span.assert_called_once() assert mock_tracer.start_span.call_args[1]["name"] == "execute_tool test-tool" - - mock_span.set_attributes.assert_called_once_with( - { - "gen_ai.operation.name": "execute_tool", - "gen_ai.system": "strands-agents", - "gen_ai.tool.name": "test-tool", - "gen_ai.tool.call.id": "123", - } - ) + + mock_span.set_attributes.assert_called_once_with({ + "gen_ai.operation.name": "execute_tool", + "gen_ai.system": "strands-agents", + "gen_ai.tool.name": "test-tool", + "gen_ai.tool.call.id": "123", + }) mock_span.add_event.assert_any_call( "gen_ai.tool.message", attributes={"role": "tool", "content": json.dumps({"param": "value"}), "id": "123"} ) @@ -628,14 +608,12 @@ def test_start_event_loop_cycle_span(mock_tracer): mock_tracer.start_span.assert_called_once() assert mock_tracer.start_span.call_args[1]["name"] == "execute_event_loop_cycle" - - mock_span.set_attributes.assert_called_once_with( - { - "event_loop.cycle_id": "cycle-123", - "request_id": "req-456", - "trace_level": "debug", - } - ) + + mock_span.set_attributes.assert_called_once_with({ + "event_loop.cycle_id": "cycle-123", + "request_id": "req-456", + "trace_level": "debug", + }) mock_span.add_event.assert_any_call( "gen_ai.user.message", attributes={"content": json.dumps([{"text": "Hello"}])} ) @@ -659,7 +637,7 @@ def test_start_event_loop_cycle_span_latest_conventions(mock_tracer, monkeypatch mock_tracer.start_span.assert_called_once() assert mock_tracer.start_span.call_args[1]["name"] == "execute_event_loop_cycle" - + mock_span.set_attributes.assert_called_once_with({"event_loop.cycle_id": "cycle-123"}) mock_span.add_event.assert_any_call( "gen_ai.client.inference.operation.details", @@ -753,16 +731,14 @@ def test_start_agent_span(mock_tracer): assert mock_tracer.start_span.call_args[1]["name"] == "invoke_agent WeatherAgent" assert mock_tracer.start_span.call_args[1]["kind"] == SpanKind.INTERNAL - mock_span.set_attributes.assert_called_once_with( - { - "gen_ai.operation.name": "invoke_agent", - "gen_ai.system": "strands-agents", - "gen_ai.agent.name": "WeatherAgent", - "gen_ai.request.model": model_id, - "gen_ai.agent.tools": json.dumps(tools), - "custom_attr": "value", - } - ) + mock_span.set_attributes.assert_called_once_with({ + "gen_ai.operation.name": "invoke_agent", + "gen_ai.system": "strands-agents", + "gen_ai.agent.name": "WeatherAgent", + "gen_ai.request.model": model_id, + "gen_ai.agent.tools": json.dumps(tools), + "custom_attr": "value", + }) mock_span.add_event.assert_any_call("gen_ai.user.message", attributes={"content": json.dumps(content)}) assert span is not None @@ -792,17 +768,15 @@ def test_start_agent_span_latest_conventions(mock_tracer, monkeypatch): mock_tracer.start_span.assert_called_once() assert mock_tracer.start_span.call_args[1]["name"] == "invoke_agent WeatherAgent" - - mock_span.set_attributes.assert_called_once_with( - { - "gen_ai.operation.name": "invoke_agent", - "gen_ai.provider.name": "strands-agents", - "gen_ai.agent.name": "WeatherAgent", - "gen_ai.request.model": model_id, - "gen_ai.agent.tools": json.dumps(tools), - "custom_attr": "value", - } - ) + + mock_span.set_attributes.assert_called_once_with({ + "gen_ai.operation.name": "invoke_agent", + "gen_ai.provider.name": "strands-agents", + "gen_ai.agent.name": "WeatherAgent", + "gen_ai.request.model": model_id, + "gen_ai.agent.tools": json.dumps(tools), + "custom_attr": "value", + }) mock_span.add_event.assert_any_call( "gen_ai.client.inference.operation.details", attributes={ @@ -945,19 +919,17 @@ def test_end_model_invoke_span_with_cache_metrics(mock_span): tracer.end_model_invoke_span(mock_span, message, usage, metrics, stop_reason) - mock_span.set_attributes.assert_called_once_with( - { - "gen_ai.usage.prompt_tokens": 10, - "gen_ai.usage.input_tokens": 10, - "gen_ai.usage.completion_tokens": 20, - "gen_ai.usage.output_tokens": 20, - "gen_ai.usage.total_tokens": 30, - "gen_ai.usage.cache_read_input_tokens": 5, - "gen_ai.usage.cache_write_input_tokens": 3, - "gen_ai.server.request.duration": 10, - "gen_ai.server.time_to_first_token": 5, - } - ) + mock_span.set_attributes.assert_called_once_with({ + "gen_ai.usage.prompt_tokens": 10, + "gen_ai.usage.input_tokens": 10, + "gen_ai.usage.completion_tokens": 20, + "gen_ai.usage.output_tokens": 20, + "gen_ai.usage.total_tokens": 30, + "gen_ai.usage.cache_read_input_tokens": 5, + "gen_ai.usage.cache_write_input_tokens": 3, + "gen_ai.server.request.duration": 10, + "gen_ai.server.time_to_first_token": 5, + }) def test_end_agent_span_with_cache_metrics(mock_span): @@ -981,17 +953,15 @@ def test_end_agent_span_with_cache_metrics(mock_span): tracer.end_agent_span(mock_span, mock_response) - mock_span.set_attributes.assert_called_once_with( - { - "gen_ai.usage.prompt_tokens": 50, - "gen_ai.usage.input_tokens": 50, - "gen_ai.usage.completion_tokens": 100, - "gen_ai.usage.output_tokens": 100, - "gen_ai.usage.total_tokens": 150, - "gen_ai.usage.cache_read_input_tokens": 25, - "gen_ai.usage.cache_write_input_tokens": 10, - } - ) + mock_span.set_attributes.assert_called_once_with({ + "gen_ai.usage.prompt_tokens": 50, + "gen_ai.usage.input_tokens": 50, + "gen_ai.usage.completion_tokens": 100, + "gen_ai.usage.output_tokens": 100, + "gen_ai.usage.total_tokens": 150, + "gen_ai.usage.cache_read_input_tokens": 25, + "gen_ai.usage.cache_write_input_tokens": 10, + }) mock_span.set_status.assert_called_once_with(StatusCode.OK) mock_span.end.assert_called_once() @@ -1549,20 +1519,18 @@ def test_end_model_invoke_span_langfuse_adds_attributes(mock_span, monkeypatch): } ] ) - + assert mock_span.set_attributes.call_count == 2 mock_span.set_attributes.assert_any_call({"gen_ai.output.messages": expected_output}) - mock_span.set_attributes.assert_any_call( - { - "gen_ai.usage.prompt_tokens": 10, - "gen_ai.usage.input_tokens": 10, - "gen_ai.usage.completion_tokens": 20, - "gen_ai.usage.output_tokens": 20, - "gen_ai.usage.total_tokens": 30, - "gen_ai.server.time_to_first_token": 10, - "gen_ai.server.request.duration": 20, - } - ) + mock_span.set_attributes.assert_any_call({ + "gen_ai.usage.prompt_tokens": 10, + "gen_ai.usage.input_tokens": 10, + "gen_ai.usage.completion_tokens": 20, + "gen_ai.usage.output_tokens": 20, + "gen_ai.usage.total_tokens": 30, + "gen_ai.server.time_to_first_token": 10, + "gen_ai.server.request.duration": 20, + }) mock_span.add_event.assert_called_with( "gen_ai.client.inference.operation.details", From 2627d92e7b36cca5bc8421f0e4d43d7cedb2bcce Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Thu, 5 Mar 2026 18:33:25 -0500 Subject: [PATCH 21/29] refactor: remove unused active skill tracking from SkillsPlugin --- src/strands/plugins/skills/skills_plugin.py | 33 ----- .../plugins/skills/test_skills_plugin.py | 125 +----------------- tests_integ/test_skills_plugin.py | 11 -- 3 files changed, 4 insertions(+), 165 deletions(-) diff --git a/src/strands/plugins/skills/skills_plugin.py b/src/strands/plugins/skills/skills_plugin.py index c83c5bf66..9e606fbca 100644 --- a/src/strands/plugins/skills/skills_plugin.py +++ b/src/strands/plugins/skills/skills_plugin.py @@ -93,7 +93,6 @@ def init_agent(self, agent: Agent) -> None: Args: agent: The agent instance to extend with skills support. """ - self._restore_state(agent) logger.debug("skill_count=<%d> | skills plugin initialized", len(self._skills)) @tool(context=True) @@ -117,7 +116,6 @@ def skills(self, skill_name: str, tool_context: ToolContext) -> str: # noqa: D4 available = ", ".join(self._skills) return f"Skill '{skill_name}' not found. Available skills: {available}" - self._set_state_field(agent, "active_skill_name", found.name) logger.debug("skill_name=<%s> | skill activated", skill_name) return self._format_skill_response(found) @@ -192,24 +190,6 @@ def load_skills(self, sources: list[str | Path | Skill]) -> None: resolved = self._resolve_skills(sources) self._skills.update(resolved) - def get_active_skill(self, agent: Agent) -> Skill | None: - """Get the currently active skill for a given agent. - - Args: - agent: The agent to check active skill for. - - Returns: - The active Skill instance, or None if no skill is active. - """ - state_data = agent.state.get(self._state_key) - if not isinstance(state_data, dict): - return None - - active_name = state_data.get("active_skill_name") - if isinstance(active_name, str): - return self._skills.get(active_name) - return None - def _format_skill_response(self, skill: Skill) -> str: """Format the tool response when a skill is activated. @@ -371,16 +351,3 @@ def _set_state_field(self, agent: Agent, key: str, value: Any) -> None: state_data[key] = value agent.state.set(self._state_key, state_data) - def _restore_state(self, agent: Agent) -> None: - """Restore the active skill from agent state if available. - - Args: - agent: The agent whose state to restore from. - """ - state_data = agent.state.get(self._state_key) - if not isinstance(state_data, dict): - return - - active_name = state_data.get("active_skill_name") - if isinstance(active_name, str) and active_name in self._skills: - logger.debug("skill_name=<%s> | restored active skill from state", active_name) diff --git a/tests/strands/plugins/skills/test_skills_plugin.py b/tests/strands/plugins/skills/test_skills_plugin.py index c969c7eff..d0c01b357 100644 --- a/tests/strands/plugins/skills/test_skills_plugin.py +++ b/tests/strands/plugins/skills/test_skills_plugin.py @@ -164,18 +164,6 @@ def test_does_not_store_agent_reference(self): assert not hasattr(plugin, "_agent") - def test_restores_state(self): - """Test that init_agent restores active skill from state.""" - skill = _make_skill() - plugin = SkillsPlugin(skills=[skill]) - agent = _mock_agent() - agent.state.set("skills_plugin", {"active_skill_name": "test-skill"}) - - plugin.init_agent(agent) - - assert plugin.get_active_skill(agent) is not None - assert plugin.get_active_skill(agent).name == "test-skill" - class TestSkillsPluginProperties: """Tests for SkillsPlugin properties.""" @@ -200,36 +188,6 @@ def test_available_skills_setter(self): assert len(plugin.available_skills) == 1 assert plugin.available_skills[0].name == "new-skill" - def test_get_active_skill_initially_none(self): - """Test that get_active_skill returns None initially.""" - plugin = SkillsPlugin(skills=[_make_skill()]) - agent = _mock_agent() - assert plugin.get_active_skill(agent) is None - - def test_get_active_skill_after_activation(self): - """Test that get_active_skill returns the activated skill.""" - skill = _make_skill() - plugin = SkillsPlugin(skills=[skill]) - agent = _mock_agent() - tool_context = _mock_tool_context(agent) - - plugin.skills(skill_name="test-skill", tool_context=tool_context) - - assert plugin.get_active_skill(agent) is not None - assert plugin.get_active_skill(agent).name == "test-skill" - - def test_get_active_skill_returns_none_when_skill_removed(self): - """Test that get_active_skill returns None when the active skill is no longer available.""" - skill = _make_skill() - plugin = SkillsPlugin(skills=[skill]) - agent = _mock_agent() - tool_context = _mock_tool_context(agent) - - plugin.skills(skill_name="test-skill", tool_context=tool_context) - plugin.available_skills = [_make_skill(name="other-skill", description="Other")] - - assert plugin.get_active_skill(agent) is None - class TestLoadSkills: """Tests for the load_skills method.""" @@ -320,8 +278,6 @@ def test_activate_skill(self): result = plugin.skills(skill_name="test-skill", tool_context=tool_context) assert "Full instructions here." in result - assert plugin.get_active_skill(agent) is not None - assert plugin.get_active_skill(agent).name == "test-skill" def test_activate_nonexistent_skill(self): """Test activating a nonexistent skill returns error message.""" @@ -343,11 +299,11 @@ def test_activate_replaces_previous(self): agent = _mock_agent() tool_context = _mock_tool_context(agent) - plugin.skills(skill_name="skill-a", tool_context=tool_context) - assert plugin.get_active_skill(agent).name == "skill-a" + result_a = plugin.skills(skill_name="skill-a", tool_context=tool_context) + assert "A instructions" in result_a - plugin.skills(skill_name="skill-b", tool_context=tool_context) - assert plugin.get_active_skill(agent).name == "skill-b" + result_b = plugin.skills(skill_name="skill-b", tool_context=tool_context) + assert "B instructions" in result_b def test_activate_without_name(self): """Test activating without a skill name returns error.""" @@ -359,33 +315,6 @@ def test_activate_without_name(self): assert "required" in result.lower() - def test_activate_persists_state(self): - """Test that activating a skill persists state.""" - plugin = SkillsPlugin(skills=[_make_skill()]) - agent = _mock_agent() - tool_context = _mock_tool_context(agent) - - plugin.skills(skill_name="test-skill", tool_context=tool_context) - - agent.state.set.assert_called() - - def test_multi_agent_isolation(self): - """Test that skill activation is isolated per agent.""" - skill_a = _make_skill(name="skill-a", description="A", instructions="A instructions") - skill_b = _make_skill(name="skill-b", description="B", instructions="B instructions") - plugin = SkillsPlugin(skills=[skill_a, skill_b]) - - agent1 = _mock_agent() - agent2 = _mock_agent() - ctx1 = _mock_tool_context(agent1) - ctx2 = _mock_tool_context(agent2) - - plugin.skills(skill_name="skill-a", tool_context=ctx1) - plugin.skills(skill_name="skill-b", tool_context=ctx2) - - assert plugin.get_active_skill(agent1).name == "skill-a" - assert plugin.get_active_skill(agent2).name == "skill-b" - class TestSystemPromptInjection: """Tests for system prompt injection via hooks.""" @@ -687,52 +616,6 @@ def test_resource_listing_truncated(self, tmp_path): assert "truncated at 20 files" in result -class TestSessionPersistence: - """Tests for session state persistence.""" - - def test_tool_persists_active_skill(self): - """Test that the tool persists the active skill name to agent state.""" - plugin = SkillsPlugin(skills=[_make_skill()]) - agent = _mock_agent() - tool_context = _mock_tool_context(agent) - - plugin.skills(skill_name="test-skill", tool_context=tool_context) - - state = agent.state.get("skills_plugin") - assert state["active_skill_name"] == "test-skill" - - def test_restore_state_activates_skill(self): - """Test restoring active skill from state.""" - skill = _make_skill() - plugin = SkillsPlugin(skills=[skill]) - agent = _mock_agent() - agent.state.set("skills_plugin", {"active_skill_name": "test-skill"}) - - plugin._restore_state(agent) - - assert plugin.get_active_skill(agent) is not None - assert plugin.get_active_skill(agent).name == "test-skill" - - def test_restore_state_no_data(self): - """Test restore when no state data exists.""" - plugin = SkillsPlugin(skills=[_make_skill()]) - agent = _mock_agent() - - plugin._restore_state(agent) - - assert plugin.get_active_skill(agent) is None - - def test_restore_state_skill_not_found(self): - """Test restore when saved skill is no longer available.""" - plugin = SkillsPlugin(skills=[_make_skill()]) - agent = _mock_agent() - agent.state.set("skills_plugin", {"active_skill_name": "removed-skill"}) - - plugin._restore_state(agent) - - assert plugin.get_active_skill(agent) is None - - class TestResolveSkills: """Tests for _resolve_skills.""" diff --git a/tests_integ/test_skills_plugin.py b/tests_integ/test_skills_plugin.py index 85287e1f9..c6d8e2770 100644 --- a/tests_integ/test_skills_plugin.py +++ b/tests_integ/test_skills_plugin.py @@ -46,8 +46,6 @@ def test_agent_activates_skill_and_injects_metadata(agent, skills_plugin): assert "translation" in agent.system_prompt # Model activated the skill and relayed the codeword from instructions - assert skills_plugin.get_active_skill(agent) is not None - assert skills_plugin.get_active_skill(agent).name == "summarization" assert "pineapple" in str(result).lower() @@ -59,12 +57,3 @@ def test_direct_tool_invocation_and_state_persistence(agent, skills_plugin): assert result["status"] == "success" response_text = result["content"][0]["text"].lower() assert "translation expert" in response_text - - # Plugin tracks the active skill - assert skills_plugin.get_active_skill(agent) is not None - assert skills_plugin.get_active_skill(agent).name == "translation" - - # State was persisted to agent state - state = agent.state.get("skills_plugin") - assert state is not None - assert state["active_skill_name"] == "translation" From fb70ed1f40e0d2d7f8b9b557e224fc414c2adc2f Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Fri, 6 Mar 2026 14:46:31 -0500 Subject: [PATCH 22/29] refactor: rename SkillsPlugin to AgentSkills, add lenient validation and YAML fallback - Rename SkillsPlugin to AgentSkills across codebase - Make skill name validation lenient by default (warn instead of raise) - Add strict=True option for validation when needed - Add _fix_yaml_colons fallback for malformed YAML frontmatter - Inject 'no skills available' message when skills list is empty - Add tests for strict mode, _fix_yaml_colons, and YAML fallback --- AGENTS.md | 2 +- src/strands/__init__.py | 4 +- src/strands/plugins/__init__.py | 4 +- src/strands/plugins/skills/__init__.py | 10 +- .../{skills_plugin.py => agent_skills.py} | 34 +- src/strands/plugins/skills/loader.py | 74 ++++- tests/strands/plugins/skills/test_loader.py | 201 ++++++++++-- .../plugins/skills/test_skills_plugin.py | 144 ++++----- tests/strands/telemetry/test_tracer.py | 292 ++++++++++-------- tests_integ/test_skills_plugin.py | 6 +- 10 files changed, 495 insertions(+), 276 deletions(-) rename src/strands/plugins/skills/{skills_plugin.py => agent_skills.py} (92%) diff --git a/AGENTS.md b/AGENTS.md index de952424a..9191bcb04 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -135,7 +135,7 @@ strands-agents/ │ │ ├── __init__.py # Skills package exports │ │ ├── loader.py # Skill loading and parsing │ │ ├── skill.py # Skill dataclass -│ │ └── skills_plugin.py # SkillsPlugin implementation +│ │ └── agent_skills.py # AgentSkills plugin implementation │ │ │ ├── handlers/ # Event handlers │ │ └── callback_handler.py # Callback handling diff --git a/src/strands/__init__.py b/src/strands/__init__.py index 2034c8692..3e1528fa6 100644 --- a/src/strands/__init__.py +++ b/src/strands/__init__.py @@ -4,19 +4,19 @@ from .agent.agent import Agent from .agent.base import AgentBase from .event_loop._retry import ModelRetryStrategy -from .plugins import Plugin, Skill, SkillsPlugin +from .plugins import AgentSkills, Plugin, Skill from .tools.decorator import tool from .types.tools import ToolContext __all__ = [ "Agent", "AgentBase", + "AgentSkills", "agent", "models", "ModelRetryStrategy", "Plugin", "Skill", - "SkillsPlugin", "tool", "ToolContext", "types", diff --git a/src/strands/plugins/__init__.py b/src/strands/plugins/__init__.py index c8f3cac12..d7ca4c9b2 100644 --- a/src/strands/plugins/__init__.py +++ b/src/strands/plugins/__init__.py @@ -6,11 +6,11 @@ from .decorator import hook from .plugin import Plugin -from .skills import Skill, SkillsPlugin +from .skills import AgentSkills, Skill __all__ = [ + "AgentSkills", "Plugin", "Skill", - "SkillsPlugin", "hook", ] diff --git a/src/strands/plugins/skills/__init__.py b/src/strands/plugins/skills/__init__.py index 60ada586c..231a7c1cf 100644 --- a/src/strands/plugins/skills/__init__.py +++ b/src/strands/plugins/skills/__init__.py @@ -1,6 +1,6 @@ """AgentSkills.io integration for Strands Agents. -This module provides the SkillsPlugin for integrating AgentSkills.io skills +This module provides the AgentSkills plugin for integrating AgentSkills.io skills into Strands agents. Skills enable progressive disclosure of instructions: metadata is injected into the system prompt upfront, and full instructions are loaded on demand via a tool. @@ -8,20 +8,20 @@ Example Usage: ```python from strands import Agent - from strands.plugins.skills import Skill, SkillsPlugin + from strands.plugins.skills import Skill, AgentSkills - plugin = SkillsPlugin(skills=["./skills/pdf-processing"]) + plugin = AgentSkills(skills=["./skills/pdf-processing"]) agent = Agent(plugins=[plugin]) ``` """ +from .agent_skills import AgentSkills from .loader import load_skill, load_skills from .skill import Skill -from .skills_plugin import SkillsPlugin __all__ = [ + "AgentSkills", "Skill", - "SkillsPlugin", "load_skill", "load_skills", ] diff --git a/src/strands/plugins/skills/skills_plugin.py b/src/strands/plugins/skills/agent_skills.py similarity index 92% rename from src/strands/plugins/skills/skills_plugin.py rename to src/strands/plugins/skills/agent_skills.py index 9e606fbca..b4a3cebd3 100644 --- a/src/strands/plugins/skills/skills_plugin.py +++ b/src/strands/plugins/skills/agent_skills.py @@ -1,6 +1,6 @@ -"""SkillsPlugin for integrating Agent Skills into Strands agents. +"""AgentSkills plugin for integrating Agent Skills into Strands agents. -This module provides the SkillsPlugin class that extends the Plugin base class +This module provides the AgentSkills class that extends the Plugin base class to add Agent Skills support. The plugin registers a tool for activating skills, and injects skill metadata into the system prompt. """ @@ -29,10 +29,10 @@ _DEFAULT_MAX_RESOURCE_FILES = 20 -class SkillsPlugin(Plugin): +class AgentSkills(Plugin): """Plugin that integrates Agent Skills into a Strands agent. - The SkillsPlugin extends the Plugin base class and provides: + The AgentSkills plugin extends the Plugin base class and provides: 1. A ``skills`` tool that allows the agent to activate skills on demand 2. System prompt injection of available skill metadata before each invocation @@ -44,14 +44,14 @@ class SkillsPlugin(Plugin): Example: ```python from strands import Agent - from strands.plugins.skills import Skill, SkillsPlugin + from strands.plugins.skills import Skill, AgentSkills # Load from filesystem - plugin = SkillsPlugin(skills=["./skills/pdf-processing", "./skills/"]) + plugin = AgentSkills(skills=["./skills/pdf-processing", "./skills/"]) # Or provide Skill instances directly skill = Skill(name="my-skill", description="A custom skill", instructions="Do the thing") - plugin = SkillsPlugin(skills=[skill]) + plugin = AgentSkills(skills=[skill]) agent = Agent(plugins=[plugin]) ``` @@ -67,8 +67,9 @@ def __init__( skills: list[str | Path | Skill], state_key: str = _DEFAULT_STATE_KEY, max_resource_files: int = _DEFAULT_MAX_RESOURCE_FILES, + strict: bool = False, ) -> None: - """Initialize the SkillsPlugin. + """Initialize the AgentSkills plugin. Args: skills: List of skill sources. Each element can be: @@ -78,7 +79,9 @@ def __init__( - A ``Skill`` dataclass instance state_key: Key used to store plugin state in ``agent.state``. max_resource_files: Maximum number of resource files to list in skill responses. + strict: If True, raise on skill validation issues. If False (default), warn and load anyway. """ + self._strict = strict self._skills: dict[str, Skill] = self._resolve_skills(skills) self._state_key = state_key self._max_resource_files = max_resource_files @@ -87,12 +90,13 @@ def __init__( def init_agent(self, agent: Agent) -> None: """Initialize the plugin with an agent instance. - Restores any persisted state from a previous session. Decorated hooks and tools are auto-registered by the plugin registry. Args: agent: The agent instance to extend with skills support. """ + if not self._skills: + logger.warning("no skills were loaded, the agent will have no skills available") logger.debug("skill_count=<%d> | skills plugin initialized", len(self._skills)) @tool(context=True) @@ -105,8 +109,6 @@ def skills(self, skill_name: str, tool_context: ToolContext) -> str: # noqa: D4 Args: skill_name: Name of the skill to activate. """ - agent = tool_context.agent - if not skill_name: available = ", ".join(self._skills) return f"Error: skill_name is required. Available skills: {available}" @@ -133,9 +135,6 @@ def _on_before_invocation(self, event: BeforeInvocationEvent) -> None: """ agent = event.agent - if not self._skills: - return - current_prompt = agent.system_prompt or "" # Remove the previously injected XML block by exact match @@ -259,12 +258,16 @@ def _list_skill_resources(self, skill_path: Path) -> list[str]: def _generate_skills_xml(self) -> str: """Generate the XML block listing available skills for the system prompt. - Includes a ```` element for skills loaded from the filesystem, + When no skills are loaded, returns a block indicating no skills are available. + Otherwise includes a ```` element for skills loaded from the filesystem, following the AgentSkills.io integration spec. Returns: XML-formatted string with skill metadata. """ + if not self._skills: + return "\nNo skills are currently available.\n" + lines: list[str] = [""] for skill in self._skills.values(): @@ -350,4 +353,3 @@ def _set_state_field(self, agent: Agent, key: str, value: Any) -> None: state_data = {} state_data[key] = value agent.state.set(self._state_key, state_data) - diff --git a/src/strands/plugins/skills/loader.py b/src/strands/plugins/skills/loader.py index da6c08629..8ee509da6 100644 --- a/src/strands/plugins/skills/loader.py +++ b/src/strands/plugins/skills/loader.py @@ -71,15 +71,51 @@ def _parse_frontmatter(content: str) -> tuple[dict[str, Any], str]: frontmatter_str = stripped[3 : match.start()].strip() body = stripped[match.end() :].strip() - result = yaml.safe_load(frontmatter_str) + try: + result = yaml.safe_load(frontmatter_str) + except yaml.YAMLError: + # AgentSkills spec recommends handling malformed YAML (e.g. unquoted colons in values) + # to improve cross-client compatibility. See: agentskills.io/client-implementation/adding-skills-support + logger.warning("YAML parse failed, retrying with colon-quoting fallback") + fixed = _fix_yaml_colons(frontmatter_str) + result = yaml.safe_load(fixed) + frontmatter: dict[str, Any] = result if isinstance(result, dict) else {} return frontmatter, body -def _validate_skill_name(name: str, dir_path: Path | None = None) -> None: +def _fix_yaml_colons(yaml_str: str) -> str: + """Attempt to fix common YAML issues like unquoted colons in values. + + Wraps values containing colons in double quotes to handle cases like: + ``description: Use this skill when: the user asks about PDFs`` + + Args: + yaml_str: The raw YAML string to fix. + + Returns: + The fixed YAML string. + """ + lines: list[str] = [] + for line in yaml_str.splitlines(): + # Match key: value where value contains another colon + match = re.match(r"^(\s*\w[\w-]*):\s+(.+)$", line) + if match: + key, value = match.group(1), match.group(2) + # If value contains a colon and isn't already quoted + if ":" in value and not (value.startswith('"') or value.startswith("'")): + line = f'{key}: "{value}"' + lines.append(line) + return "\n".join(lines) + + +def _validate_skill_name(name: str, dir_path: Path | None = None, *, strict: bool = False) -> None: """Validate a skill name per the AgentSkills.io specification. - Rules: + In lenient mode (default), logs warnings for cosmetic issues but does not raise. + In strict mode, raises ValueError for any validation failure. + + Rules checked: - 1-64 characters long - Lowercase alphanumeric characters and hyphens only - Cannot start or end with a hyphen @@ -89,34 +125,48 @@ def _validate_skill_name(name: str, dir_path: Path | None = None) -> None: Args: name: The skill name to validate. dir_path: Optional path to the skill directory for name matching. + strict: If True, raise ValueError on any issue. If False (default), log warnings. Raises: - ValueError: If the skill name is invalid. + ValueError: If the skill name is empty, or if strict=True and any rule is violated. """ if not name: raise ValueError("Skill name cannot be empty") if len(name) > _MAX_SKILL_NAME_LENGTH: - raise ValueError(f"name=<{name}> | skill name exceeds {_MAX_SKILL_NAME_LENGTH} character limit") + msg = "name=<%s> | skill name exceeds %d character limit" + if strict: + raise ValueError(msg % (name, _MAX_SKILL_NAME_LENGTH)) + logger.warning(msg, name, _MAX_SKILL_NAME_LENGTH) if not _SKILL_NAME_PATTERN.match(name): - raise ValueError( - f"name=<{name}> | skill name must be 1-64 lowercase alphanumeric characters or hyphens, " - "cannot start/end with hyphen" + msg = ( + "name=<%s> | skill name should be 1-64 lowercase alphanumeric characters or hyphens, " + "should not start/end with hyphen" ) + if strict: + raise ValueError(msg % name) + logger.warning(msg, name) if "--" in name: - raise ValueError(f"name=<{name}> | skill name cannot contain consecutive hyphens") + msg = "name=<%s> | skill name contains consecutive hyphens" + if strict: + raise ValueError(msg % name) + logger.warning(msg, name) if dir_path is not None and dir_path.name != name: - raise ValueError(f"name=<{name}>, directory=<{dir_path.name}> | skill name must match parent directory name") + msg = "name=<%s>, directory=<%s> | skill name does not match parent directory name" + if strict: + raise ValueError(msg % (name, dir_path.name)) + logger.warning(msg, name, dir_path.name) -def load_skill(skill_path: str | Path) -> Skill: +def load_skill(skill_path: str | Path, *, strict: bool = False) -> Skill: """Load a single skill from a directory containing SKILL.md. Args: skill_path: Path to the skill directory or the SKILL.md file itself. + strict: If True, raise on any validation issue. If False (default), warn and load anyway. Returns: A Skill instance populated from the SKILL.md file. @@ -149,7 +199,7 @@ def load_skill(skill_path: str | Path) -> Skill: if not isinstance(description, str) or not description: raise ValueError(f"path=<{skill_md_path}> | SKILL.md must have a 'description' field in frontmatter") - _validate_skill_name(name, skill_dir) + _validate_skill_name(name, skill_dir, strict=strict) # Parse allowed-tools (space-delimited string or YAML list) allowed_tools_raw = frontmatter.get("allowed-tools") or frontmatter.get("allowed_tools") diff --git a/tests/strands/plugins/skills/test_loader.py b/tests/strands/plugins/skills/test_loader.py index bfe490697..70628ecb2 100644 --- a/tests/strands/plugins/skills/test_loader.py +++ b/tests/strands/plugins/skills/test_loader.py @@ -1,11 +1,13 @@ """Tests for the skill loader module.""" +import logging from pathlib import Path import pytest from strands.plugins.skills.loader import ( _find_skill_md, + _fix_yaml_colons, _parse_frontmatter, _validate_skill_name, load_skill, @@ -89,10 +91,10 @@ def test_frontmatter_with_dashes_in_yaml_value(self): class TestValidateSkillName: - """Tests for _validate_skill_name.""" + """Tests for _validate_skill_name (lenient validation).""" def test_valid_names(self): - """Test that valid names pass validation.""" + """Test that valid names pass validation without warnings.""" valid_names = ["a", "test", "my-skill", "skill-123", "a1b2c3"] for name in valid_names: _validate_skill_name(name) # Should not raise @@ -102,48 +104,55 @@ def test_empty_name(self): with pytest.raises(ValueError, match="cannot be empty"): _validate_skill_name("") - def test_too_long_name(self): - """Test that names exceeding 64 chars raise ValueError.""" - with pytest.raises(ValueError, match="exceeds 64 character limit"): + def test_too_long_name_warns(self, caplog): + """Test that names exceeding 64 chars warn but do not raise.""" + with caplog.at_level(logging.WARNING): _validate_skill_name("a" * 65) + assert "exceeds" in caplog.text - def test_uppercase_rejected(self): - """Test that uppercase characters are rejected.""" - with pytest.raises(ValueError, match="lowercase alphanumeric"): + def test_uppercase_warns(self, caplog): + """Test that uppercase characters warn but do not raise.""" + with caplog.at_level(logging.WARNING): _validate_skill_name("MySkill") + assert "lowercase alphanumeric" in caplog.text - def test_starts_with_hyphen(self): - """Test that names starting with hyphen are rejected.""" - with pytest.raises(ValueError, match="lowercase alphanumeric"): + def test_starts_with_hyphen_warns(self, caplog): + """Test that names starting with hyphen warn but do not raise.""" + with caplog.at_level(logging.WARNING): _validate_skill_name("-skill") + assert "lowercase alphanumeric" in caplog.text - def test_ends_with_hyphen(self): - """Test that names ending with hyphen are rejected.""" - with pytest.raises(ValueError, match="lowercase alphanumeric"): + def test_ends_with_hyphen_warns(self, caplog): + """Test that names ending with hyphen warn but do not raise.""" + with caplog.at_level(logging.WARNING): _validate_skill_name("skill-") + assert "lowercase alphanumeric" in caplog.text - def test_consecutive_hyphens(self): - """Test that consecutive hyphens are rejected.""" - with pytest.raises(ValueError, match="consecutive hyphens"): + def test_consecutive_hyphens_warns(self, caplog): + """Test that consecutive hyphens warn but do not raise.""" + with caplog.at_level(logging.WARNING): _validate_skill_name("my--skill") + assert "consecutive hyphens" in caplog.text - def test_special_characters(self): - """Test that special characters are rejected.""" - with pytest.raises(ValueError, match="lowercase alphanumeric"): + def test_special_characters_warns(self, caplog): + """Test that special characters warn but do not raise.""" + with caplog.at_level(logging.WARNING): _validate_skill_name("my_skill") + assert "lowercase alphanumeric" in caplog.text - def test_directory_name_mismatch(self, tmp_path): - """Test that skill name must match directory name.""" + def test_directory_name_mismatch_warns(self, tmp_path, caplog): + """Test that skill name not matching directory name warns but does not raise.""" skill_dir = tmp_path / "wrong-name" skill_dir.mkdir() - with pytest.raises(ValueError, match="must match parent directory name"): + with caplog.at_level(logging.WARNING): _validate_skill_name("my-skill", skill_dir) + assert "does not match parent directory name" in caplog.text def test_directory_name_match(self, tmp_path): """Test that matching directory name passes.""" skill_dir = tmp_path / "my-skill" skill_dir.mkdir() - _validate_skill_name("my-skill", skill_dir) # Should not raise + _validate_skill_name("my-skill", skill_dir) # Should not raise or warn def _make_skill_dir(parent: Path, name: str, description: str = "A test skill", body: str = "Instructions.") -> Path: @@ -240,14 +249,17 @@ def test_load_nonexistent_path(self, tmp_path): with pytest.raises(FileNotFoundError): load_skill(tmp_path / "nonexistent") - def test_load_name_directory_mismatch(self, tmp_path): - """Test error when skill name doesn't match directory name.""" + def test_load_name_directory_mismatch_warns(self, tmp_path, caplog): + """Test that skill name not matching directory name warns but still loads.""" skill_dir = tmp_path / "wrong-dir" skill_dir.mkdir() (skill_dir / "SKILL.md").write_text("---\nname: right-name\ndescription: test\n---\nBody.") - with pytest.raises(ValueError, match="must match parent directory name"): - load_skill(skill_dir) + with caplog.at_level(logging.WARNING): + skill = load_skill(skill_dir) + + assert skill.name == "right-name" + assert "does not match parent directory name" in caplog.text class TestLoadSkills: @@ -293,16 +305,137 @@ def test_nonexistent_directory(self, tmp_path): with pytest.raises(FileNotFoundError): load_skills(tmp_path / "nonexistent") - def test_skips_invalid_skills(self, tmp_path): - """Test that invalid skills are skipped with a warning.""" + def test_loads_mismatched_name_with_warning(self, tmp_path, caplog): + """Test that skills with name/directory mismatch are loaded with a warning.""" _make_skill_dir(tmp_path, "good-skill") - # Create an invalid skill (name mismatch) + # Create a skill with name mismatch (lenient validation loads it anyway) bad_dir = tmp_path / "bad-dir" bad_dir.mkdir() (bad_dir / "SKILL.md").write_text("---\nname: wrong-name\ndescription: test\n---\nBody.") - skills = load_skills(tmp_path) + with caplog.at_level(logging.WARNING): + skills = load_skills(tmp_path) - assert len(skills) == 1 - assert skills[0].name == "good-skill" + assert len(skills) == 2 + names = {s.name for s in skills} + assert names == {"good-skill", "wrong-name"} + assert "does not match parent directory name" in caplog.text + + +class TestFixYamlColons: + """Tests for _fix_yaml_colons.""" + + def test_fixes_unquoted_colon_in_value(self): + """Test that an unquoted colon in a value gets quoted.""" + raw = "description: Use this skill when: the user asks about PDFs" + fixed = _fix_yaml_colons(raw) + assert fixed == 'description: "Use this skill when: the user asks about PDFs"' + + def test_leaves_already_double_quoted_value(self): + """Test that already double-quoted values are not re-quoted.""" + raw = 'description: "already: quoted"' + assert _fix_yaml_colons(raw) == raw + + def test_leaves_already_single_quoted_value(self): + """Test that already single-quoted values are not re-quoted.""" + raw = "description: 'already: quoted'" + assert _fix_yaml_colons(raw) == raw + + def test_leaves_value_without_colon(self): + """Test that values without colons are unchanged.""" + raw = "name: my-skill" + assert _fix_yaml_colons(raw) == raw + + def test_multiline_mixed(self): + """Test fixing only the lines that need it in a multi-line string.""" + raw = "name: my-skill\ndescription: Use when: needed\nversion: 1.0" + fixed = _fix_yaml_colons(raw) + assert fixed == 'name: my-skill\ndescription: "Use when: needed"\nversion: 1.0' + + def test_empty_string(self): + """Test that an empty string is returned unchanged.""" + assert _fix_yaml_colons("") == "" + + def test_preserves_indented_lines_without_colons(self): + """Test that indented lines without key-value patterns are preserved.""" + raw = " - item one\n - item two" + assert _fix_yaml_colons(raw) == raw + + +class TestValidateSkillNameStrict: + """Tests for _validate_skill_name with strict=True.""" + + def test_strict_valid_name(self): + """Test that valid names pass strict validation.""" + _validate_skill_name("my-skill", strict=True) # Should not raise + + def test_strict_empty_name(self): + """Test that empty name raises in strict mode.""" + with pytest.raises(ValueError, match="cannot be empty"): + _validate_skill_name("", strict=True) + + def test_strict_too_long_name(self): + """Test that names exceeding 64 chars raise in strict mode.""" + with pytest.raises(ValueError, match="exceeds 64 character limit"): + _validate_skill_name("a" * 65, strict=True) + + def test_strict_uppercase_rejected(self): + """Test that uppercase characters raise in strict mode.""" + with pytest.raises(ValueError, match="lowercase alphanumeric"): + _validate_skill_name("MySkill", strict=True) + + def test_strict_starts_with_hyphen(self): + """Test that names starting with hyphen raise in strict mode.""" + with pytest.raises(ValueError, match="lowercase alphanumeric"): + _validate_skill_name("-skill", strict=True) + + def test_strict_consecutive_hyphens(self): + """Test that consecutive hyphens raise in strict mode.""" + with pytest.raises(ValueError, match="consecutive hyphens"): + _validate_skill_name("my--skill", strict=True) + + def test_strict_directory_mismatch(self, tmp_path): + """Test that directory name mismatch raises in strict mode.""" + skill_dir = tmp_path / "wrong-name" + skill_dir.mkdir() + with pytest.raises(ValueError, match="does not match parent directory name"): + _validate_skill_name("my-skill", skill_dir, strict=True) + + +class TestLoadSkillStrict: + """Tests for load_skill with strict=True.""" + + def test_strict_rejects_name_mismatch(self, tmp_path): + """Test that strict mode raises on name/directory mismatch.""" + skill_dir = tmp_path / "wrong-dir" + skill_dir.mkdir() + (skill_dir / "SKILL.md").write_text("---\nname: right-name\ndescription: test\n---\nBody.") + + with pytest.raises(ValueError, match="does not match parent directory name"): + load_skill(skill_dir, strict=True) + + def test_strict_accepts_valid_skill(self, tmp_path): + """Test that strict mode loads a valid skill without error.""" + _make_skill_dir(tmp_path, "valid-skill") + skill = load_skill(tmp_path / "valid-skill", strict=True) + assert skill.name == "valid-skill" + + +class TestParseFrontmatterYamlFallback: + """Tests for YAML colon-quoting fallback in _parse_frontmatter.""" + + def test_fallback_on_unquoted_colon(self): + """Test that frontmatter with unquoted colons in values is parsed via fallback.""" + content = "---\nname: my-skill\ndescription: Use when: the user asks\n---\nBody." + frontmatter, body = _parse_frontmatter(content) + assert frontmatter["name"] == "my-skill" + assert "Use when" in frontmatter["description"] + assert body == "Body." + + def test_fallback_preserves_valid_yaml(self): + """Test that valid YAML is parsed normally without triggering fallback.""" + content = "---\nname: my-skill\ndescription: A simple description\n---\nBody." + frontmatter, body = _parse_frontmatter(content) + assert frontmatter["name"] == "my-skill" + assert frontmatter["description"] == "A simple description" diff --git a/tests/strands/plugins/skills/test_skills_plugin.py b/tests/strands/plugins/skills/test_skills_plugin.py index d0c01b357..15c9d1cbd 100644 --- a/tests/strands/plugins/skills/test_skills_plugin.py +++ b/tests/strands/plugins/skills/test_skills_plugin.py @@ -1,4 +1,4 @@ -"""Tests for the SkillsPlugin.""" +"""Tests for the AgentSkills plugin.""" import logging from pathlib import Path @@ -7,8 +7,8 @@ from strands.hooks.events import BeforeInvocationEvent from strands.hooks.registry import HookRegistry from strands.plugins.registry import _PluginRegistry +from strands.plugins.skills.agent_skills import AgentSkills from strands.plugins.skills.skill import Skill -from strands.plugins.skills.skills_plugin import SkillsPlugin from strands.types.tools import ToolContext @@ -70,12 +70,12 @@ def _set_system_prompt(agent: MagicMock, value: str | None) -> None: class TestSkillsPluginInit: - """Tests for SkillsPlugin initialization.""" + """Tests for AgentSkills initialization.""" def test_init_with_skill_instances(self): """Test initialization with Skill instances.""" skill = _make_skill() - plugin = SkillsPlugin(skills=[skill]) + plugin = AgentSkills(skills=[skill]) assert len(plugin.available_skills) == 1 assert plugin.available_skills[0].name == "test-skill" @@ -83,7 +83,7 @@ def test_init_with_skill_instances(self): def test_init_with_filesystem_paths(self, tmp_path): """Test initialization with filesystem paths.""" _make_skill_dir(tmp_path, "fs-skill") - plugin = SkillsPlugin(skills=[str(tmp_path / "fs-skill")]) + plugin = AgentSkills(skills=[str(tmp_path / "fs-skill")]) assert len(plugin.available_skills) == 1 assert plugin.available_skills[0].name == "fs-skill" @@ -92,7 +92,7 @@ def test_init_with_parent_directory(self, tmp_path): """Test initialization with a parent directory containing skills.""" _make_skill_dir(tmp_path, "skill-a") _make_skill_dir(tmp_path, "skill-b") - plugin = SkillsPlugin(skills=[tmp_path]) + plugin = AgentSkills(skills=[tmp_path]) assert len(plugin.available_skills) == 2 @@ -100,7 +100,7 @@ def test_init_with_mixed_sources(self, tmp_path): """Test initialization with mixed skill sources.""" _make_skill_dir(tmp_path, "fs-skill") direct_skill = _make_skill(name="direct-skill", description="Direct") - plugin = SkillsPlugin(skills=[str(tmp_path / "fs-skill"), direct_skill]) + plugin = AgentSkills(skills=[str(tmp_path / "fs-skill"), direct_skill]) assert len(plugin.available_skills) == 2 names = {s.name for s in plugin.available_skills} @@ -108,27 +108,27 @@ def test_init_with_mixed_sources(self, tmp_path): def test_init_skips_nonexistent_paths(self, tmp_path): """Test that nonexistent paths are skipped gracefully.""" - plugin = SkillsPlugin(skills=[str(tmp_path / "nonexistent")]) + plugin = AgentSkills(skills=[str(tmp_path / "nonexistent")]) assert len(plugin.available_skills) == 0 def test_init_empty_skills(self): """Test initialization with empty skills list.""" - plugin = SkillsPlugin(skills=[]) + plugin = AgentSkills(skills=[]) assert plugin.available_skills == [] def test_name_attribute(self): """Test that the plugin has the correct name.""" - plugin = SkillsPlugin(skills=[]) + plugin = AgentSkills(skills=[]) assert plugin.name == "skills" def test_custom_state_key(self): """Test initialization with a custom state key.""" - plugin = SkillsPlugin(skills=[], state_key="custom_key") + plugin = AgentSkills(skills=[], state_key="custom_key") assert plugin._state_key == "custom_key" def test_custom_max_resource_files(self): """Test initialization with a custom max resource files limit.""" - plugin = SkillsPlugin(skills=[], max_resource_files=50) + plugin = AgentSkills(skills=[], max_resource_files=50) assert plugin._max_resource_files == 50 @@ -137,7 +137,7 @@ class TestSkillsPluginInitAgent: def test_registers_tool(self): """Test that the plugin registry registers the skills tool.""" - plugin = SkillsPlugin(skills=[_make_skill()]) + plugin = AgentSkills(skills=[_make_skill()]) agent = _mock_agent() registry = _PluginRegistry(agent) @@ -147,7 +147,7 @@ def test_registers_tool(self): def test_registers_hooks(self): """Test that the plugin registry registers hook callbacks.""" - plugin = SkillsPlugin(skills=[_make_skill()]) + plugin = AgentSkills(skills=[_make_skill()]) agent = _mock_agent() registry = _PluginRegistry(agent) @@ -157,7 +157,7 @@ def test_registers_hooks(self): def test_does_not_store_agent_reference(self): """Test that init_agent does not store the agent on the plugin.""" - plugin = SkillsPlugin(skills=[_make_skill()]) + plugin = AgentSkills(skills=[_make_skill()]) agent = _mock_agent() plugin.init_agent(agent) @@ -166,12 +166,12 @@ def test_does_not_store_agent_reference(self): class TestSkillsPluginProperties: - """Tests for SkillsPlugin properties.""" + """Tests for AgentSkills properties.""" def test_available_skills_getter_returns_copy(self): """Test that the available_skills getter returns a copy of the list.""" skill = _make_skill() - plugin = SkillsPlugin(skills=[skill]) + plugin = AgentSkills(skills=[skill]) skills_list = plugin.available_skills skills_list.append(_make_skill(name="another-skill", description="Another")) @@ -180,7 +180,7 @@ def test_available_skills_getter_returns_copy(self): def test_available_skills_setter(self): """Test setting skills via the property setter.""" - plugin = SkillsPlugin(skills=[_make_skill()]) + plugin = AgentSkills(skills=[_make_skill()]) new_skill = _make_skill(name="new-skill", description="New") plugin.available_skills = [new_skill] @@ -194,7 +194,7 @@ class TestLoadSkills: def test_appends_skill_instances(self): """Test that load_skills appends Skill instances to existing skills.""" - plugin = SkillsPlugin(skills=[_make_skill(name="existing", description="Existing")]) + plugin = AgentSkills(skills=[_make_skill(name="existing", description="Existing")]) plugin.load_skills([_make_skill(name="new-skill", description="New")]) @@ -204,7 +204,7 @@ def test_appends_skill_instances(self): def test_appends_from_filesystem(self, tmp_path): """Test that load_skills appends skills resolved from filesystem paths.""" - plugin = SkillsPlugin(skills=[_make_skill(name="existing", description="Existing")]) + plugin = AgentSkills(skills=[_make_skill(name="existing", description="Existing")]) _make_skill_dir(tmp_path, "fs-skill") plugin.load_skills([str(tmp_path / "fs-skill")]) @@ -216,7 +216,7 @@ def test_appends_from_filesystem(self, tmp_path): def test_duplicates_overwrite(self): """Test that loading a skill with the same name overwrites the existing one.""" original = _make_skill(name="dupe", description="Original") - plugin = SkillsPlugin(skills=[original]) + plugin = AgentSkills(skills=[original]) replacement = _make_skill(name="dupe", description="Replacement") plugin.load_skills([replacement]) @@ -226,7 +226,7 @@ def test_duplicates_overwrite(self): def test_mixed_sources(self, tmp_path): """Test load_skills with a mix of Skill instances and filesystem paths.""" - plugin = SkillsPlugin(skills=[]) + plugin = AgentSkills(skills=[]) _make_skill_dir(tmp_path, "fs-skill") direct = _make_skill(name="direct", description="Direct") @@ -238,7 +238,7 @@ def test_mixed_sources(self, tmp_path): def test_skips_nonexistent_paths(self): """Test that nonexistent paths are skipped without error.""" - plugin = SkillsPlugin(skills=[_make_skill()]) + plugin = AgentSkills(skills=[_make_skill()]) plugin.load_skills(["/nonexistent/path"]) @@ -246,7 +246,7 @@ def test_skips_nonexistent_paths(self): def test_empty_sources(self): """Test that loading empty sources is a no-op.""" - plugin = SkillsPlugin(skills=[_make_skill()]) + plugin = AgentSkills(skills=[_make_skill()]) plugin.load_skills([]) @@ -254,7 +254,7 @@ def test_empty_sources(self): def test_parent_directory(self, tmp_path): """Test load_skills with a parent directory containing multiple skills.""" - plugin = SkillsPlugin(skills=[]) + plugin = AgentSkills(skills=[]) _make_skill_dir(tmp_path, "child-a") _make_skill_dir(tmp_path, "child-b") @@ -271,7 +271,7 @@ class TestSkillsTool: def test_activate_skill(self): """Test activating a skill returns its instructions.""" skill = _make_skill(instructions="Full instructions here.") - plugin = SkillsPlugin(skills=[skill]) + plugin = AgentSkills(skills=[skill]) agent = _mock_agent() tool_context = _mock_tool_context(agent) @@ -282,7 +282,7 @@ def test_activate_skill(self): def test_activate_nonexistent_skill(self): """Test activating a nonexistent skill returns error message.""" skill = _make_skill() - plugin = SkillsPlugin(skills=[skill]) + plugin = AgentSkills(skills=[skill]) agent = _mock_agent() tool_context = _mock_tool_context(agent) @@ -295,7 +295,7 @@ def test_activate_replaces_previous(self): """Test that activating a new skill replaces the previous one.""" skill1 = _make_skill(name="skill-a", description="A", instructions="A instructions") skill2 = _make_skill(name="skill-b", description="B", instructions="B instructions") - plugin = SkillsPlugin(skills=[skill1, skill2]) + plugin = AgentSkills(skills=[skill1, skill2]) agent = _mock_agent() tool_context = _mock_tool_context(agent) @@ -307,7 +307,7 @@ def test_activate_replaces_previous(self): def test_activate_without_name(self): """Test activating without a skill name returns error.""" - plugin = SkillsPlugin(skills=[_make_skill()]) + plugin = AgentSkills(skills=[_make_skill()]) agent = _mock_agent() tool_context = _mock_tool_context(agent) @@ -322,7 +322,7 @@ class TestSystemPromptInjection: def test_before_invocation_appends_skills_xml(self): """Test that before_invocation appends skills XML to system prompt.""" skill = _make_skill() - plugin = SkillsPlugin(skills=[skill]) + plugin = AgentSkills(skills=[skill]) agent = _mock_agent() event = BeforeInvocationEvent(agent=agent) @@ -334,7 +334,7 @@ def test_before_invocation_appends_skills_xml(self): def test_before_invocation_preserves_existing_prompt(self): """Test that existing system prompt content is preserved.""" - plugin = SkillsPlugin(skills=[_make_skill()]) + plugin = AgentSkills(skills=[_make_skill()]) agent = _mock_agent() agent._system_prompt = "Original prompt." agent._system_prompt_content = [{"text": "Original prompt."}] @@ -347,7 +347,7 @@ def test_before_invocation_preserves_existing_prompt(self): def test_repeated_invocations_do_not_accumulate(self): """Test that repeated invocations rebuild from current prompt without accumulation.""" - plugin = SkillsPlugin(skills=[_make_skill()]) + plugin = AgentSkills(skills=[_make_skill()]) agent = _mock_agent() agent._system_prompt = "Original prompt." agent._system_prompt_content = [{"text": "Original prompt."}] @@ -361,9 +361,9 @@ def test_repeated_invocations_do_not_accumulate(self): assert first_prompt == second_prompt - def test_no_skills_skips_injection(self): - """Test that injection is skipped when no skills are available.""" - plugin = SkillsPlugin(skills=[]) + def test_no_skills_injects_empty_message(self): + """Test that a 'no skills available' message is injected when no skills are loaded.""" + plugin = AgentSkills(skills=[]) agent = _mock_agent() original_prompt = "Original prompt." agent._system_prompt = original_prompt @@ -372,11 +372,12 @@ def test_no_skills_skips_injection(self): event = BeforeInvocationEvent(agent=agent) plugin._on_before_invocation(event) - assert agent.system_prompt == original_prompt + assert "No skills are currently available" in agent.system_prompt + assert agent.system_prompt.startswith("Original prompt.") def test_none_system_prompt_handled(self): """Test handling when system prompt is None.""" - plugin = SkillsPlugin(skills=[_make_skill()]) + plugin = AgentSkills(skills=[_make_skill()]) agent = _mock_agent() agent._system_prompt = None agent._system_prompt_content = None @@ -388,7 +389,7 @@ def test_none_system_prompt_handled(self): def test_preserves_other_plugin_modifications(self): """Test that modifications by other plugins/hooks are preserved.""" - plugin = SkillsPlugin(skills=[_make_skill()]) + plugin = AgentSkills(skills=[_make_skill()]) agent = _mock_agent() agent._system_prompt = "Original prompt." agent._system_prompt_content = [{"text": "Original prompt."}] @@ -406,7 +407,7 @@ def test_preserves_other_plugin_modifications(self): def test_uses_public_system_prompt_setter(self): """Test that the hook uses the public system_prompt setter.""" - plugin = SkillsPlugin(skills=[_make_skill()]) + plugin = AgentSkills(skills=[_make_skill()]) agent = _mock_agent() agent._system_prompt = "Original." agent._system_prompt_content = [{"text": "Original."}] @@ -420,7 +421,7 @@ def test_uses_public_system_prompt_setter(self): def test_warns_when_previous_xml_not_found(self, caplog): """Test that a warning is logged when the previously injected XML is missing from the prompt.""" - plugin = SkillsPlugin(skills=[_make_skill()]) + plugin = AgentSkills(skills=[_make_skill()]) agent = _mock_agent() agent._system_prompt = "Original prompt." agent._system_prompt_content = [{"text": "Original prompt."}] @@ -443,7 +444,7 @@ class TestSkillsXmlGeneration: def test_single_skill(self): """Test XML generation with a single skill.""" - plugin = SkillsPlugin(skills=[_make_skill()]) + plugin = AgentSkills(skills=[_make_skill()]) xml = plugin._generate_skills_xml() assert "" in xml @@ -457,25 +458,26 @@ def test_multiple_skills(self): _make_skill(name="skill-a", description="Skill A"), _make_skill(name="skill-b", description="Skill B"), ] - plugin = SkillsPlugin(skills=skills) + plugin = AgentSkills(skills=skills) xml = plugin._generate_skills_xml() assert "skill-a" in xml assert "skill-b" in xml def test_empty_skills(self): - """Test XML generation with no skills.""" - plugin = SkillsPlugin(skills=[]) + """Test XML generation with no skills includes 'no skills available' message.""" + plugin = AgentSkills(skills=[]) xml = plugin._generate_skills_xml() assert "" in xml + assert "No skills are currently available" in xml assert "" in xml def test_location_included_when_path_set(self, tmp_path): """Test that location element is included when skill has a path.""" skill = _make_skill() skill.path = tmp_path / "test-skill" - plugin = SkillsPlugin(skills=[skill]) + plugin = AgentSkills(skills=[skill]) xml = plugin._generate_skills_xml() assert f"{tmp_path / 'test-skill' / 'SKILL.md'}" in xml @@ -484,7 +486,7 @@ def test_location_omitted_when_path_none(self): """Test that location element is omitted for programmatic skills.""" skill = _make_skill() assert skill.path is None - plugin = SkillsPlugin(skills=[skill]) + plugin = AgentSkills(skills=[skill]) xml = plugin._generate_skills_xml() assert "" not in xml @@ -492,7 +494,7 @@ def test_location_omitted_when_path_none(self): def test_escapes_xml_special_characters(self): """Test that XML special characters in names and descriptions are escaped.""" skill = _make_skill(name="a&c", description="Use & more") - plugin = SkillsPlugin(skills=[skill]) + plugin = AgentSkills(skills=[skill]) xml = plugin._generate_skills_xml() assert "a<b>&c" in xml @@ -505,7 +507,7 @@ class TestSkillResponseFormat: def test_instructions_only(self): """Test response with just instructions.""" skill = _make_skill(instructions="Do the thing.") - plugin = SkillsPlugin(skills=[skill]) + plugin = AgentSkills(skills=[skill]) result = plugin._format_skill_response(skill) assert result == "Do the thing." @@ -513,7 +515,7 @@ def test_instructions_only(self): def test_no_instructions(self): """Test response when skill has no instructions.""" skill = _make_skill(instructions="") - plugin = SkillsPlugin(skills=[skill]) + plugin = AgentSkills(skills=[skill]) result = plugin._format_skill_response(skill) assert "no instructions available" in result.lower() @@ -522,7 +524,7 @@ def test_includes_allowed_tools(self): """Test response includes allowed tools when set.""" skill = _make_skill(instructions="Do the thing.") skill.allowed_tools = ["Bash", "Read"] - plugin = SkillsPlugin(skills=[skill]) + plugin = AgentSkills(skills=[skill]) result = plugin._format_skill_response(skill) assert "Do the thing." in result @@ -532,7 +534,7 @@ def test_includes_compatibility(self): """Test response includes compatibility when set.""" skill = _make_skill(instructions="Do the thing.") skill.compatibility = "Requires docker" - plugin = SkillsPlugin(skills=[skill]) + plugin = AgentSkills(skills=[skill]) result = plugin._format_skill_response(skill) assert "Compatibility: Requires docker" in result @@ -541,7 +543,7 @@ def test_includes_location(self, tmp_path): """Test response includes location when path is set.""" skill = _make_skill(instructions="Do the thing.") skill.path = tmp_path / "test-skill" - plugin = SkillsPlugin(skills=[skill]) + plugin = AgentSkills(skills=[skill]) result = plugin._format_skill_response(skill) assert f"Location: {tmp_path / 'test-skill' / 'SKILL.md'}" in result @@ -552,7 +554,7 @@ def test_all_metadata(self, tmp_path): skill.allowed_tools = ["Bash"] skill.compatibility = "Requires git" skill.path = tmp_path / "test-skill" - plugin = SkillsPlugin(skills=[skill]) + plugin = AgentSkills(skills=[skill]) result = plugin._format_skill_response(skill) assert "Do the thing." in result @@ -572,7 +574,7 @@ def test_includes_resource_listing(self, tmp_path): skill = _make_skill(instructions="Do the thing.") skill.path = skill_dir - plugin = SkillsPlugin(skills=[skill]) + plugin = AgentSkills(skills=[skill]) result = plugin._format_skill_response(skill) assert "Available resources:" in result @@ -582,7 +584,7 @@ def test_includes_resource_listing(self, tmp_path): def test_no_resources_when_no_path(self): """Test that resources section is omitted for programmatic skills.""" skill = _make_skill(instructions="Do the thing.") - plugin = SkillsPlugin(skills=[skill]) + plugin = AgentSkills(skills=[skill]) result = plugin._format_skill_response(skill) assert "Available resources:" not in result @@ -594,7 +596,7 @@ def test_no_resources_when_dirs_empty(self, tmp_path): skill = _make_skill(instructions="Do the thing.") skill.path = skill_dir - plugin = SkillsPlugin(skills=[skill]) + plugin = AgentSkills(skills=[skill]) result = plugin._format_skill_response(skill) assert "Available resources:" not in result @@ -609,7 +611,7 @@ def test_resource_listing_truncated(self, tmp_path): skill = _make_skill(instructions="Do the thing.") skill.path = skill_dir - plugin = SkillsPlugin(skills=[skill]) + plugin = AgentSkills(skills=[skill]) result = plugin._format_skill_response(skill) assert "Available resources:" in result @@ -622,7 +624,7 @@ class TestResolveSkills: def test_resolve_skill_instances(self): """Test resolving Skill instances (pass-through).""" skill = _make_skill() - plugin = SkillsPlugin(skills=[skill]) + plugin = AgentSkills(skills=[skill]) assert len(plugin._skills) == 1 assert plugin._skills["test-skill"] is skill @@ -630,7 +632,7 @@ def test_resolve_skill_instances(self): def test_resolve_skill_directory_path(self, tmp_path): """Test resolving a path to a skill directory.""" _make_skill_dir(tmp_path, "path-skill") - plugin = SkillsPlugin(skills=[tmp_path / "path-skill"]) + plugin = AgentSkills(skills=[tmp_path / "path-skill"]) assert len(plugin._skills) == 1 assert "path-skill" in plugin._skills @@ -639,21 +641,21 @@ def test_resolve_parent_directory_path(self, tmp_path): """Test resolving a path to a parent directory.""" _make_skill_dir(tmp_path, "child-a") _make_skill_dir(tmp_path, "child-b") - plugin = SkillsPlugin(skills=[tmp_path]) + plugin = AgentSkills(skills=[tmp_path]) assert len(plugin._skills) == 2 def test_resolve_skill_md_file_path(self, tmp_path): """Test resolving a path to a SKILL.md file.""" skill_dir = _make_skill_dir(tmp_path, "file-skill") - plugin = SkillsPlugin(skills=[skill_dir / "SKILL.md"]) + plugin = AgentSkills(skills=[skill_dir / "SKILL.md"]) assert len(plugin._skills) == 1 assert "file-skill" in plugin._skills def test_resolve_nonexistent_path(self, tmp_path): """Test that nonexistent paths are skipped.""" - plugin = SkillsPlugin(skills=[str(tmp_path / "ghost")]) + plugin = AgentSkills(skills=[str(tmp_path / "ghost")]) assert len(plugin._skills) == 0 @@ -661,10 +663,10 @@ class TestImports: """Tests for module imports.""" def test_import_from_plugins(self): - """Test importing SkillsPlugin from strands.plugins.""" - from strands.plugins import SkillsPlugin as SP + """Test importing AgentSkills from strands.plugins.""" + from strands.plugins import AgentSkills as SP - assert SP is SkillsPlugin + assert SP is AgentSkills def test_import_skill_from_strands(self): """Test importing Skill from top-level strands package.""" @@ -674,22 +676,22 @@ def test_import_skill_from_strands(self): def test_import_from_skills_package(self): """Test importing from strands.plugins.skills package.""" - from strands.plugins.skills import Skill, SkillsPlugin, load_skill, load_skills + from strands.plugins.skills import AgentSkills, Skill, load_skill, load_skills assert Skill is not None - assert SkillsPlugin is not None + assert AgentSkills is not None assert load_skill is not None assert load_skills is not None def test_skills_plugin_is_plugin_subclass(self): - """Test that SkillsPlugin is a subclass of the Plugin ABC.""" + """Test that AgentSkills is a subclass of the Plugin ABC.""" from strands.plugins import Plugin - assert issubclass(SkillsPlugin, Plugin) + assert issubclass(AgentSkills, Plugin) def test_skills_plugin_isinstance_check(self): - """Test that SkillsPlugin instances pass isinstance check against Plugin.""" + """Test that AgentSkills instances pass isinstance check against Plugin.""" from strands.plugins import Plugin - plugin = SkillsPlugin(skills=[]) + plugin = AgentSkills(skills=[]) assert isinstance(plugin, Plugin) diff --git a/tests/strands/telemetry/test_tracer.py b/tests/strands/telemetry/test_tracer.py index da7f010e2..50c0cc9b9 100644 --- a/tests/strands/telemetry/test_tracer.py +++ b/tests/strands/telemetry/test_tracer.py @@ -148,14 +148,16 @@ def test_start_model_invoke_span(mock_tracer): mock_tracer.start_span.assert_called_once() assert mock_tracer.start_span.call_args[1]["name"] == "chat" assert mock_tracer.start_span.call_args[1]["kind"] == SpanKind.INTERNAL - mock_span.set_attributes.assert_called_once_with({ - "gen_ai.operation.name": "chat", - "gen_ai.system": "strands-agents", - "custom_key": "custom_value", - "user_id": "12345", - "gen_ai.request.model": model_id, - "agent_name": "TestAgent", - }) + mock_span.set_attributes.assert_called_once_with( + { + "gen_ai.operation.name": "chat", + "gen_ai.system": "strands-agents", + "custom_key": "custom_value", + "user_id": "12345", + "gen_ai.request.model": model_id, + "agent_name": "TestAgent", + } + ) mock_span.add_event.assert_called_with( "gen_ai.user.message", attributes={"content": json.dumps(messages[0]["content"])} ) @@ -188,13 +190,15 @@ def test_start_model_invoke_span_latest_conventions(mock_tracer, monkeypatch): mock_tracer.start_span.assert_called_once() assert mock_tracer.start_span.call_args[1]["name"] == "chat" assert mock_tracer.start_span.call_args[1]["kind"] == SpanKind.INTERNAL - - mock_span.set_attributes.assert_called_once_with({ - "gen_ai.operation.name": "chat", - "gen_ai.provider.name": "strands-agents", - "gen_ai.request.model": model_id, - "agent_name": "TestAgent", - }) + + mock_span.set_attributes.assert_called_once_with( + { + "gen_ai.operation.name": "chat", + "gen_ai.provider.name": "strands-agents", + "gen_ai.request.model": model_id, + "agent_name": "TestAgent", + } + ) mock_span.add_event.assert_called_with( "gen_ai.client.inference.operation.details", attributes={ @@ -232,15 +236,17 @@ def test_end_model_invoke_span(mock_span): tracer.end_model_invoke_span(mock_span, message, usage, metrics, stop_reason) - mock_span.set_attributes.assert_called_once_with({ - "gen_ai.usage.prompt_tokens": 10, - "gen_ai.usage.input_tokens": 10, - "gen_ai.usage.completion_tokens": 20, - "gen_ai.usage.output_tokens": 20, - "gen_ai.usage.total_tokens": 30, - "gen_ai.server.time_to_first_token": 10, - "gen_ai.server.request.duration": 20, - }) + mock_span.set_attributes.assert_called_once_with( + { + "gen_ai.usage.prompt_tokens": 10, + "gen_ai.usage.input_tokens": 10, + "gen_ai.usage.completion_tokens": 20, + "gen_ai.usage.output_tokens": 20, + "gen_ai.usage.total_tokens": 30, + "gen_ai.server.time_to_first_token": 10, + "gen_ai.server.request.duration": 20, + } + ) mock_span.add_event.assert_called_with( "gen_ai.choice", attributes={"message": json.dumps(message["content"]), "finish_reason": "end_turn"}, @@ -259,15 +265,17 @@ def test_end_model_invoke_span_latest_conventions(mock_span, monkeypatch): tracer.end_model_invoke_span(mock_span, message, usage, metrics, stop_reason) - mock_span.set_attributes.assert_called_once_with({ - "gen_ai.usage.prompt_tokens": 10, - "gen_ai.usage.input_tokens": 10, - "gen_ai.usage.completion_tokens": 20, - "gen_ai.usage.output_tokens": 20, - "gen_ai.usage.total_tokens": 30, - "gen_ai.server.time_to_first_token": 10, - "gen_ai.server.request.duration": 20, - }) + mock_span.set_attributes.assert_called_once_with( + { + "gen_ai.usage.prompt_tokens": 10, + "gen_ai.usage.input_tokens": 10, + "gen_ai.usage.completion_tokens": 20, + "gen_ai.usage.output_tokens": 20, + "gen_ai.usage.total_tokens": 30, + "gen_ai.server.time_to_first_token": 10, + "gen_ai.server.request.duration": 20, + } + ) mock_span.add_event.assert_called_with( "gen_ai.client.inference.operation.details", attributes={ @@ -300,15 +308,17 @@ def test_start_tool_call_span(mock_tracer): mock_tracer.start_span.assert_called_once() assert mock_tracer.start_span.call_args[1]["name"] == "execute_tool test-tool" - - mock_span.set_attributes.assert_called_once_with({ - "gen_ai.tool.name": "test-tool", - "gen_ai.system": "strands-agents", - "gen_ai.operation.name": "execute_tool", - "gen_ai.tool.call.id": "123", - "session_id": "abc123", - "environment": "production", - }) + + mock_span.set_attributes.assert_called_once_with( + { + "gen_ai.tool.name": "test-tool", + "gen_ai.system": "strands-agents", + "gen_ai.operation.name": "execute_tool", + "gen_ai.tool.call.id": "123", + "session_id": "abc123", + "environment": "production", + } + ) mock_span.add_event.assert_any_call( "gen_ai.tool.message", attributes={"role": "tool", "content": json.dumps({"param": "value"}), "id": "123"} ) @@ -331,13 +341,15 @@ def test_start_tool_call_span_latest_conventions(mock_tracer, monkeypatch): mock_tracer.start_span.assert_called_once() assert mock_tracer.start_span.call_args[1]["name"] == "execute_tool test-tool" - - mock_span.set_attributes.assert_called_once_with({ - "gen_ai.tool.name": "test-tool", - "gen_ai.provider.name": "strands-agents", - "gen_ai.operation.name": "execute_tool", - "gen_ai.tool.call.id": "123", - }) + + mock_span.set_attributes.assert_called_once_with( + { + "gen_ai.tool.name": "test-tool", + "gen_ai.provider.name": "strands-agents", + "gen_ai.operation.name": "execute_tool", + "gen_ai.tool.call.id": "123", + } + ) mock_span.add_event.assert_called_with( "gen_ai.client.inference.operation.details", attributes={ @@ -377,14 +389,16 @@ def test_start_swarm_call_span_with_string_task(mock_tracer): mock_tracer.start_span.assert_called_once() assert mock_tracer.start_span.call_args[1]["name"] == "invoke_swarm" - - mock_span.set_attributes.assert_called_once_with({ - "gen_ai.operation.name": "invoke_swarm", - "gen_ai.system": "strands-agents", - "gen_ai.agent.name": "swarm", - "workflow_id": "wf-789", - "priority": "high", - }) + + mock_span.set_attributes.assert_called_once_with( + { + "gen_ai.operation.name": "invoke_swarm", + "gen_ai.system": "strands-agents", + "gen_ai.agent.name": "swarm", + "workflow_id": "wf-789", + "priority": "high", + } + ) mock_span.add_event.assert_any_call("gen_ai.user.message", attributes={"content": "Design foo bar"}) assert span is not None @@ -404,12 +418,14 @@ def test_start_swarm_span_with_contentblock_task(mock_tracer): mock_tracer.start_span.assert_called_once() assert mock_tracer.start_span.call_args[1]["name"] == "invoke_swarm" - - mock_span.set_attributes.assert_called_once_with({ - "gen_ai.operation.name": "invoke_swarm", - "gen_ai.system": "strands-agents", - "gen_ai.agent.name": "swarm", - }) + + mock_span.set_attributes.assert_called_once_with( + { + "gen_ai.operation.name": "invoke_swarm", + "gen_ai.system": "strands-agents", + "gen_ai.agent.name": "swarm", + } + ) mock_span.add_event.assert_any_call( "gen_ai.user.message", attributes={"content": '[{"text": "Original Task: foo bar"}]'} ) @@ -460,12 +476,14 @@ def test_start_swarm_span_with_contentblock_task_latest_conventions(mock_tracer, mock_tracer.start_span.assert_called_once() assert mock_tracer.start_span.call_args[1]["name"] == "invoke_swarm" - - mock_span.set_attributes.assert_called_once_with({ - "gen_ai.operation.name": "invoke_swarm", - "gen_ai.provider.name": "strands-agents", - "gen_ai.agent.name": "swarm", - }) + + mock_span.set_attributes.assert_called_once_with( + { + "gen_ai.operation.name": "invoke_swarm", + "gen_ai.provider.name": "strands-agents", + "gen_ai.agent.name": "swarm", + } + ) mock_span.add_event.assert_any_call( "gen_ai.client.inference.operation.details", attributes={ @@ -528,13 +546,15 @@ def test_start_graph_call_span(mock_tracer): mock_tracer.start_span.assert_called_once() assert mock_tracer.start_span.call_args[1]["name"] == "execute_tool test-tool" - - mock_span.set_attributes.assert_called_once_with({ - "gen_ai.operation.name": "execute_tool", - "gen_ai.system": "strands-agents", - "gen_ai.tool.name": "test-tool", - "gen_ai.tool.call.id": "123", - }) + + mock_span.set_attributes.assert_called_once_with( + { + "gen_ai.operation.name": "execute_tool", + "gen_ai.system": "strands-agents", + "gen_ai.tool.name": "test-tool", + "gen_ai.tool.call.id": "123", + } + ) mock_span.add_event.assert_any_call( "gen_ai.tool.message", attributes={"role": "tool", "content": json.dumps({"param": "value"}), "id": "123"} ) @@ -608,12 +628,14 @@ def test_start_event_loop_cycle_span(mock_tracer): mock_tracer.start_span.assert_called_once() assert mock_tracer.start_span.call_args[1]["name"] == "execute_event_loop_cycle" - - mock_span.set_attributes.assert_called_once_with({ - "event_loop.cycle_id": "cycle-123", - "request_id": "req-456", - "trace_level": "debug", - }) + + mock_span.set_attributes.assert_called_once_with( + { + "event_loop.cycle_id": "cycle-123", + "request_id": "req-456", + "trace_level": "debug", + } + ) mock_span.add_event.assert_any_call( "gen_ai.user.message", attributes={"content": json.dumps([{"text": "Hello"}])} ) @@ -637,7 +659,7 @@ def test_start_event_loop_cycle_span_latest_conventions(mock_tracer, monkeypatch mock_tracer.start_span.assert_called_once() assert mock_tracer.start_span.call_args[1]["name"] == "execute_event_loop_cycle" - + mock_span.set_attributes.assert_called_once_with({"event_loop.cycle_id": "cycle-123"}) mock_span.add_event.assert_any_call( "gen_ai.client.inference.operation.details", @@ -731,14 +753,16 @@ def test_start_agent_span(mock_tracer): assert mock_tracer.start_span.call_args[1]["name"] == "invoke_agent WeatherAgent" assert mock_tracer.start_span.call_args[1]["kind"] == SpanKind.INTERNAL - mock_span.set_attributes.assert_called_once_with({ - "gen_ai.operation.name": "invoke_agent", - "gen_ai.system": "strands-agents", - "gen_ai.agent.name": "WeatherAgent", - "gen_ai.request.model": model_id, - "gen_ai.agent.tools": json.dumps(tools), - "custom_attr": "value", - }) + mock_span.set_attributes.assert_called_once_with( + { + "gen_ai.operation.name": "invoke_agent", + "gen_ai.system": "strands-agents", + "gen_ai.agent.name": "WeatherAgent", + "gen_ai.request.model": model_id, + "gen_ai.agent.tools": json.dumps(tools), + "custom_attr": "value", + } + ) mock_span.add_event.assert_any_call("gen_ai.user.message", attributes={"content": json.dumps(content)}) assert span is not None @@ -768,15 +792,17 @@ def test_start_agent_span_latest_conventions(mock_tracer, monkeypatch): mock_tracer.start_span.assert_called_once() assert mock_tracer.start_span.call_args[1]["name"] == "invoke_agent WeatherAgent" - - mock_span.set_attributes.assert_called_once_with({ - "gen_ai.operation.name": "invoke_agent", - "gen_ai.provider.name": "strands-agents", - "gen_ai.agent.name": "WeatherAgent", - "gen_ai.request.model": model_id, - "gen_ai.agent.tools": json.dumps(tools), - "custom_attr": "value", - }) + + mock_span.set_attributes.assert_called_once_with( + { + "gen_ai.operation.name": "invoke_agent", + "gen_ai.provider.name": "strands-agents", + "gen_ai.agent.name": "WeatherAgent", + "gen_ai.request.model": model_id, + "gen_ai.agent.tools": json.dumps(tools), + "custom_attr": "value", + } + ) mock_span.add_event.assert_any_call( "gen_ai.client.inference.operation.details", attributes={ @@ -919,17 +945,19 @@ def test_end_model_invoke_span_with_cache_metrics(mock_span): tracer.end_model_invoke_span(mock_span, message, usage, metrics, stop_reason) - mock_span.set_attributes.assert_called_once_with({ - "gen_ai.usage.prompt_tokens": 10, - "gen_ai.usage.input_tokens": 10, - "gen_ai.usage.completion_tokens": 20, - "gen_ai.usage.output_tokens": 20, - "gen_ai.usage.total_tokens": 30, - "gen_ai.usage.cache_read_input_tokens": 5, - "gen_ai.usage.cache_write_input_tokens": 3, - "gen_ai.server.request.duration": 10, - "gen_ai.server.time_to_first_token": 5, - }) + mock_span.set_attributes.assert_called_once_with( + { + "gen_ai.usage.prompt_tokens": 10, + "gen_ai.usage.input_tokens": 10, + "gen_ai.usage.completion_tokens": 20, + "gen_ai.usage.output_tokens": 20, + "gen_ai.usage.total_tokens": 30, + "gen_ai.usage.cache_read_input_tokens": 5, + "gen_ai.usage.cache_write_input_tokens": 3, + "gen_ai.server.request.duration": 10, + "gen_ai.server.time_to_first_token": 5, + } + ) def test_end_agent_span_with_cache_metrics(mock_span): @@ -953,15 +981,17 @@ def test_end_agent_span_with_cache_metrics(mock_span): tracer.end_agent_span(mock_span, mock_response) - mock_span.set_attributes.assert_called_once_with({ - "gen_ai.usage.prompt_tokens": 50, - "gen_ai.usage.input_tokens": 50, - "gen_ai.usage.completion_tokens": 100, - "gen_ai.usage.output_tokens": 100, - "gen_ai.usage.total_tokens": 150, - "gen_ai.usage.cache_read_input_tokens": 25, - "gen_ai.usage.cache_write_input_tokens": 10, - }) + mock_span.set_attributes.assert_called_once_with( + { + "gen_ai.usage.prompt_tokens": 50, + "gen_ai.usage.input_tokens": 50, + "gen_ai.usage.completion_tokens": 100, + "gen_ai.usage.output_tokens": 100, + "gen_ai.usage.total_tokens": 150, + "gen_ai.usage.cache_read_input_tokens": 25, + "gen_ai.usage.cache_write_input_tokens": 10, + } + ) mock_span.set_status.assert_called_once_with(StatusCode.OK) mock_span.end.assert_called_once() @@ -1519,18 +1549,20 @@ def test_end_model_invoke_span_langfuse_adds_attributes(mock_span, monkeypatch): } ] ) - + assert mock_span.set_attributes.call_count == 2 mock_span.set_attributes.assert_any_call({"gen_ai.output.messages": expected_output}) - mock_span.set_attributes.assert_any_call({ - "gen_ai.usage.prompt_tokens": 10, - "gen_ai.usage.input_tokens": 10, - "gen_ai.usage.completion_tokens": 20, - "gen_ai.usage.output_tokens": 20, - "gen_ai.usage.total_tokens": 30, - "gen_ai.server.time_to_first_token": 10, - "gen_ai.server.request.duration": 20, - }) + mock_span.set_attributes.assert_any_call( + { + "gen_ai.usage.prompt_tokens": 10, + "gen_ai.usage.input_tokens": 10, + "gen_ai.usage.completion_tokens": 20, + "gen_ai.usage.output_tokens": 20, + "gen_ai.usage.total_tokens": 30, + "gen_ai.server.time_to_first_token": 10, + "gen_ai.server.request.duration": 20, + } + ) mock_span.add_event.assert_called_with( "gen_ai.client.inference.operation.details", diff --git a/tests_integ/test_skills_plugin.py b/tests_integ/test_skills_plugin.py index c6d8e2770..b9a782447 100644 --- a/tests_integ/test_skills_plugin.py +++ b/tests_integ/test_skills_plugin.py @@ -1,4 +1,4 @@ -"""Integration tests for the SkillsPlugin. +"""Integration tests for the AgentSkills plugin. Tests end-to-end behavior with a real model: skill metadata injection into the system prompt, agent-driven skill activation via the skills tool, and @@ -8,7 +8,7 @@ import pytest from strands import Agent -from strands.plugins.skills import Skill, SkillsPlugin +from strands.plugins.skills import AgentSkills, Skill SUMMARIZATION_SKILL = Skill( name="summarization", @@ -25,7 +25,7 @@ @pytest.fixture def skills_plugin(): - return SkillsPlugin(skills=[SUMMARIZATION_SKILL, TRANSLATION_SKILL]) + return AgentSkills(skills=[SUMMARIZATION_SKILL, TRANSLATION_SKILL]) @pytest.fixture From e8e77d43c34826cce03ea322d7a4de20479a3443 Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Fri, 6 Mar 2026 14:55:16 -0500 Subject: [PATCH 23/29] fix(skills): update name --- src/strands/plugins/skills/agent_skills.py | 5 +---- tests/strands/plugins/skills/test_skills_plugin.py | 2 +- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/src/strands/plugins/skills/agent_skills.py b/src/strands/plugins/skills/agent_skills.py index b4a3cebd3..31129de73 100644 --- a/src/strands/plugins/skills/agent_skills.py +++ b/src/strands/plugins/skills/agent_skills.py @@ -57,10 +57,7 @@ class AgentSkills(Plugin): ``` """ - @property - def name(self) -> str: - """A stable string identifier for the plugin.""" - return "skills" + name = "agent_skills" def __init__( self, diff --git a/tests/strands/plugins/skills/test_skills_plugin.py b/tests/strands/plugins/skills/test_skills_plugin.py index 15c9d1cbd..b6e858daf 100644 --- a/tests/strands/plugins/skills/test_skills_plugin.py +++ b/tests/strands/plugins/skills/test_skills_plugin.py @@ -119,7 +119,7 @@ def test_init_empty_skills(self): def test_name_attribute(self): """Test that the plugin has the correct name.""" plugin = AgentSkills(skills=[]) - assert plugin.name == "skills" + assert plugin.name == "agent_skills" def test_custom_state_key(self): """Test initialization with a custom state key.""" From d797b17df5dcc9363e02405acfa815368eee36ed Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Fri, 6 Mar 2026 15:25:24 -0500 Subject: [PATCH 24/29] feat: track activated skills in agent state --- src/strands/plugins/skills/agent_skills.py | 35 ++++++++++++ .../plugins/skills/test_skills_plugin.py | 56 +++++++++++++++++++ 2 files changed, 91 insertions(+) diff --git a/src/strands/plugins/skills/agent_skills.py b/src/strands/plugins/skills/agent_skills.py index 31129de73..b9c905a7f 100644 --- a/src/strands/plugins/skills/agent_skills.py +++ b/src/strands/plugins/skills/agent_skills.py @@ -116,6 +116,7 @@ def skills(self, skill_name: str, tool_context: ToolContext) -> str: # noqa: D4 return f"Skill '{skill_name}' not found. Available skills: {available}" logger.debug("skill_name=<%s> | skill activated", skill_name) + self._track_activated_skill(tool_context.agent, skill_name) return self._format_skill_response(found) @hook @@ -350,3 +351,37 @@ def _set_state_field(self, agent: Agent, key: str, value: Any) -> None: state_data = {} state_data[key] = value agent.state.set(self._state_key, state_data) + + def _track_activated_skill(self, agent: Agent, skill_name: str) -> None: + """Record a skill activation in agent state. + + Maintains an ordered list of activated skill names (most recent last), + without duplicates. + + Args: + agent: The agent whose state to update. + skill_name: Name of the activated skill. + """ + state_data = agent.state.get(self._state_key) + activated: list[str] = state_data.get("activated_skills", []) if isinstance(state_data, dict) else [] + if skill_name in activated: + activated.remove(skill_name) + activated.append(skill_name) + self._set_state_field(agent, "activated_skills", activated) + + def get_activated_skills(self, agent: Agent) -> list[str]: + """Get the list of skills activated by this agent. + + Returns skill names in activation order (most recent last). + + Args: + agent: The agent to query. + + Returns: + List of activated skill names. + """ + state_data = agent.state.get(self._state_key) + if isinstance(state_data, dict): + return list(state_data.get("activated_skills", [])) + return [] + diff --git a/tests/strands/plugins/skills/test_skills_plugin.py b/tests/strands/plugins/skills/test_skills_plugin.py index b6e858daf..84bbefde7 100644 --- a/tests/strands/plugins/skills/test_skills_plugin.py +++ b/tests/strands/plugins/skills/test_skills_plugin.py @@ -315,6 +315,62 @@ def test_activate_without_name(self): assert "required" in result.lower() + def test_activate_tracks_in_agent_state(self): + """Test that activating a skill records it in agent state.""" + plugin = AgentSkills(skills=[_make_skill()]) + agent = _mock_agent() + tool_context = _mock_tool_context(agent) + + plugin.skills(skill_name="test-skill", tool_context=tool_context) + + assert plugin.get_activated_skills(agent) == ["test-skill"] + + def test_activate_multiple_tracks_order(self): + """Test that multiple activations are tracked in order.""" + skill_a = _make_skill(name="skill-a", description="A", instructions="A") + skill_b = _make_skill(name="skill-b", description="B", instructions="B") + plugin = AgentSkills(skills=[skill_a, skill_b]) + agent = _mock_agent() + tool_context = _mock_tool_context(agent) + + plugin.skills(skill_name="skill-a", tool_context=tool_context) + plugin.skills(skill_name="skill-b", tool_context=tool_context) + + assert plugin.get_activated_skills(agent) == ["skill-a", "skill-b"] + + def test_activate_same_skill_twice_deduplicates(self): + """Test that re-activating a skill moves it to the end without duplicates.""" + skill_a = _make_skill(name="skill-a", description="A", instructions="A") + skill_b = _make_skill(name="skill-b", description="B", instructions="B") + plugin = AgentSkills(skills=[skill_a, skill_b]) + agent = _mock_agent() + tool_context = _mock_tool_context(agent) + + plugin.skills(skill_name="skill-a", tool_context=tool_context) + plugin.skills(skill_name="skill-b", tool_context=tool_context) + plugin.skills(skill_name="skill-a", tool_context=tool_context) + + assert plugin.get_activated_skills(agent) == ["skill-b", "skill-a"] + + def test_get_activated_skills_empty_by_default(self): + """Test that get_activated_skills returns empty list when nothing activated.""" + plugin = AgentSkills(skills=[_make_skill()]) + agent = _mock_agent() + + assert plugin.get_activated_skills(agent) == [] + + def test_get_activated_skills_returns_copy(self): + """Test that get_activated_skills returns a copy, not a reference.""" + plugin = AgentSkills(skills=[_make_skill()]) + agent = _mock_agent() + tool_context = _mock_tool_context(agent) + + plugin.skills(skill_name="test-skill", tool_context=tool_context) + result = plugin.get_activated_skills(agent) + result.append("injected") + + assert plugin.get_activated_skills(agent) == ["test-skill"] + class TestSystemPromptInjection: """Tests for system prompt injection via hooks.""" From afac330174269e0f19e0906a5b2558b1d96fd93e Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Tue, 10 Mar 2026 13:01:48 -0400 Subject: [PATCH 25/29] refactor: move loader into skills and small pr comment changes --- src/strands/plugins/skills/__init__.py | 8 +- src/strands/plugins/skills/agent_skills.py | 33 +- src/strands/plugins/skills/loader.py | 275 ---------- src/strands/plugins/skills/skill.py | 346 +++++++++++- ..._skills_plugin.py => test_agent_skills.py} | 82 +-- tests/strands/plugins/skills/test_loader.py | 441 --------------- tests/strands/plugins/skills/test_skill.py | 513 +++++++++++++++++- tests/strands/telemetry/test_tracer.py | 292 +++++----- tests_integ/test_skills_plugin.py | 22 + 9 files changed, 1080 insertions(+), 932 deletions(-) delete mode 100644 src/strands/plugins/skills/loader.py rename tests/strands/plugins/skills/{test_skills_plugin.py => test_agent_skills.py} (91%) delete mode 100644 tests/strands/plugins/skills/test_loader.py diff --git a/src/strands/plugins/skills/__init__.py b/src/strands/plugins/skills/__init__.py index 231a7c1cf..6784c7c27 100644 --- a/src/strands/plugins/skills/__init__.py +++ b/src/strands/plugins/skills/__init__.py @@ -10,18 +10,20 @@ from strands import Agent from strands.plugins.skills import Skill, AgentSkills + # Load from filesystem via classmethods + skill = Skill.from_file("./skills/pdf-processing") + skills = Skill.from_directory("./skills/") + + # Or let the plugin resolve paths automatically plugin = AgentSkills(skills=["./skills/pdf-processing"]) agent = Agent(plugins=[plugin]) ``` """ from .agent_skills import AgentSkills -from .loader import load_skill, load_skills from .skill import Skill __all__ = [ "AgentSkills", "Skill", - "load_skill", - "load_skills", ] diff --git a/src/strands/plugins/skills/agent_skills.py b/src/strands/plugins/skills/agent_skills.py index b9c905a7f..2dd69b5b1 100644 --- a/src/strands/plugins/skills/agent_skills.py +++ b/src/strands/plugins/skills/agent_skills.py @@ -16,7 +16,6 @@ from ...plugins import Plugin, hook from ...tools.decorator import tool from ...types.tools import ToolContext -from .loader import load_skill, load_skills from .skill import Skill if TYPE_CHECKING: @@ -24,7 +23,7 @@ logger = logging.getLogger(__name__) -_DEFAULT_STATE_KEY = "skills_plugin" +_DEFAULT_STATE_KEY = "agent_skills" _RESOURCE_DIRS = ("scripts", "references", "assets") _DEFAULT_MAX_RESOURCE_FILES = 20 @@ -152,8 +151,7 @@ def _on_before_invocation(self, event: BeforeInvocationEvent) -> None: self._set_state_field(agent, "last_injected_xml", new_injected_xml) agent.system_prompt = new_prompt - @property - def available_skills(self) -> list[Skill]: + def get_available_skills(self) -> list[Skill]: """Get the list of available skills. Returns: @@ -161,18 +159,21 @@ def available_skills(self) -> list[Skill]: """ return list(self._skills.values()) - @available_skills.setter - def available_skills(self, value: list[Skill]) -> None: - """Set the available skills directly. + def set_available_skills(self, skills: list[str | Path | Skill]) -> None: + """Set the available skills, replacing any existing ones. + + Each element can be a ``Skill`` instance, a ``str`` or ``Path`` to a + skill directory (containing SKILL.md), or a ``str`` or ``Path`` to a + parent directory containing skill subdirectories. Note: this does not persist state or deactivate skills on any agent. Active skill state is managed per-agent and will be reconciled on the next tool call or invocation. Args: - value: List of Skill instances. + skills: List of skill sources to resolve and set. """ - self._skills = {s.name: s for s in value} + self._skills = self._resolve_skills(skills) def load_skills(self, sources: list[str | Path | Skill]) -> None: """Resolve and append skills from mixed sources. @@ -310,7 +311,7 @@ def _resolve_skills(self, sources: list[str | Path | Skill]) -> dict[str, Skill] if has_skill_md: try: - skill = load_skill(path) + skill = Skill.from_file(path) if skill.name in resolved: logger.warning( "name=<%s> | duplicate skill name, overwriting previous skill", skill.name @@ -320,7 +321,7 @@ def _resolve_skills(self, sources: list[str | Path | Skill]) -> dict[str, Skill] logger.warning("path=<%s> | failed to load skill: %s", path, e) else: # Treat as parent directory containing skill subdirectories - for skill in load_skills(path): + for skill in Skill.from_directory(path): if skill.name in resolved: logger.warning( "name=<%s> | duplicate skill name, overwriting previous skill", skill.name @@ -328,7 +329,7 @@ def _resolve_skills(self, sources: list[str | Path | Skill]) -> dict[str, Skill] resolved[skill.name] = skill elif path.is_file() and path.name.lower() == "skill.md": try: - skill = load_skill(path) + skill = Skill.from_file(path) if skill.name in resolved: logger.warning("name=<%s> | duplicate skill name, overwriting previous skill", skill.name) resolved[skill.name] = skill @@ -345,9 +346,14 @@ def _set_state_field(self, agent: Agent, key: str, value: Any) -> None: agent: The agent whose state to update. key: The state field key. value: The value to set. + + Raises: + TypeError: If the existing state value is not a dict. """ state_data = agent.state.get(self._state_key) - if not isinstance(state_data, dict): + if state_data is not None and not isinstance(state_data, dict): + raise TypeError(f"expected dict for state key '{self._state_key}', got {type(state_data).__name__}") + if state_data is None: state_data = {} state_data[key] = value agent.state.set(self._state_key, state_data) @@ -384,4 +390,3 @@ def get_activated_skills(self, agent: Agent) -> list[str]: if isinstance(state_data, dict): return list(state_data.get("activated_skills", [])) return [] - diff --git a/src/strands/plugins/skills/loader.py b/src/strands/plugins/skills/loader.py deleted file mode 100644 index 8ee509da6..000000000 --- a/src/strands/plugins/skills/loader.py +++ /dev/null @@ -1,275 +0,0 @@ -"""Skill loading and parsing utilities for AgentSkills.io skills. - -This module provides functions for discovering, parsing, and loading skills -from the filesystem. Skills are directories containing a SKILL.md file with -YAML frontmatter metadata and markdown instructions. -""" - -from __future__ import annotations - -import logging -import re -from pathlib import Path -from typing import Any - -import yaml - -from .skill import Skill - -logger = logging.getLogger(__name__) - -_SKILL_NAME_PATTERN = re.compile(r"^[a-z0-9]([a-z0-9-]*[a-z0-9])?$") -_MAX_SKILL_NAME_LENGTH = 64 - - -def _find_skill_md(skill_dir: Path) -> Path: - """Find the SKILL.md file in a skill directory. - - Searches for SKILL.md (case-sensitive preferred) or skill.md as a fallback. - - Args: - skill_dir: Path to the skill directory. - - Returns: - Path to the SKILL.md file. - - Raises: - FileNotFoundError: If no SKILL.md file is found in the directory. - """ - for name in ("SKILL.md", "skill.md"): - candidate = skill_dir / name - if candidate.is_file(): - return candidate - - raise FileNotFoundError(f"path=<{skill_dir}> | no SKILL.md found in skill directory") - - -def _parse_frontmatter(content: str) -> tuple[dict[str, Any], str]: - """Parse YAML frontmatter and body from SKILL.md content. - - Extracts the YAML frontmatter between ``---`` delimiters at line boundaries - and returns parsed key-value pairs along with the remaining markdown body. - - Args: - content: Full content of a SKILL.md file. - - Returns: - Tuple of (frontmatter_dict, body_string). - - Raises: - ValueError: If the frontmatter is malformed or missing required delimiters. - """ - stripped = content.strip() - if not stripped.startswith("---"): - raise ValueError("SKILL.md must start with --- frontmatter delimiter") - - # Find the closing --- delimiter (first line after the opener that is only dashes) - match = re.search(r"\n^---\s*$", stripped, re.MULTILINE) - if match is None: - raise ValueError("SKILL.md frontmatter missing closing --- delimiter") - - frontmatter_str = stripped[3 : match.start()].strip() - body = stripped[match.end() :].strip() - - try: - result = yaml.safe_load(frontmatter_str) - except yaml.YAMLError: - # AgentSkills spec recommends handling malformed YAML (e.g. unquoted colons in values) - # to improve cross-client compatibility. See: agentskills.io/client-implementation/adding-skills-support - logger.warning("YAML parse failed, retrying with colon-quoting fallback") - fixed = _fix_yaml_colons(frontmatter_str) - result = yaml.safe_load(fixed) - - frontmatter: dict[str, Any] = result if isinstance(result, dict) else {} - return frontmatter, body - - -def _fix_yaml_colons(yaml_str: str) -> str: - """Attempt to fix common YAML issues like unquoted colons in values. - - Wraps values containing colons in double quotes to handle cases like: - ``description: Use this skill when: the user asks about PDFs`` - - Args: - yaml_str: The raw YAML string to fix. - - Returns: - The fixed YAML string. - """ - lines: list[str] = [] - for line in yaml_str.splitlines(): - # Match key: value where value contains another colon - match = re.match(r"^(\s*\w[\w-]*):\s+(.+)$", line) - if match: - key, value = match.group(1), match.group(2) - # If value contains a colon and isn't already quoted - if ":" in value and not (value.startswith('"') or value.startswith("'")): - line = f'{key}: "{value}"' - lines.append(line) - return "\n".join(lines) - - -def _validate_skill_name(name: str, dir_path: Path | None = None, *, strict: bool = False) -> None: - """Validate a skill name per the AgentSkills.io specification. - - In lenient mode (default), logs warnings for cosmetic issues but does not raise. - In strict mode, raises ValueError for any validation failure. - - Rules checked: - - 1-64 characters long - - Lowercase alphanumeric characters and hyphens only - - Cannot start or end with a hyphen - - No consecutive hyphens - - Must match parent directory name (if loaded from disk) - - Args: - name: The skill name to validate. - dir_path: Optional path to the skill directory for name matching. - strict: If True, raise ValueError on any issue. If False (default), log warnings. - - Raises: - ValueError: If the skill name is empty, or if strict=True and any rule is violated. - """ - if not name: - raise ValueError("Skill name cannot be empty") - - if len(name) > _MAX_SKILL_NAME_LENGTH: - msg = "name=<%s> | skill name exceeds %d character limit" - if strict: - raise ValueError(msg % (name, _MAX_SKILL_NAME_LENGTH)) - logger.warning(msg, name, _MAX_SKILL_NAME_LENGTH) - - if not _SKILL_NAME_PATTERN.match(name): - msg = ( - "name=<%s> | skill name should be 1-64 lowercase alphanumeric characters or hyphens, " - "should not start/end with hyphen" - ) - if strict: - raise ValueError(msg % name) - logger.warning(msg, name) - - if "--" in name: - msg = "name=<%s> | skill name contains consecutive hyphens" - if strict: - raise ValueError(msg % name) - logger.warning(msg, name) - - if dir_path is not None and dir_path.name != name: - msg = "name=<%s>, directory=<%s> | skill name does not match parent directory name" - if strict: - raise ValueError(msg % (name, dir_path.name)) - logger.warning(msg, name, dir_path.name) - - -def load_skill(skill_path: str | Path, *, strict: bool = False) -> Skill: - """Load a single skill from a directory containing SKILL.md. - - Args: - skill_path: Path to the skill directory or the SKILL.md file itself. - strict: If True, raise on any validation issue. If False (default), warn and load anyway. - - Returns: - A Skill instance populated from the SKILL.md file. - - Raises: - FileNotFoundError: If the path does not exist or SKILL.md is not found. - ValueError: If the skill metadata is invalid. - """ - skill_path = Path(skill_path).resolve() - - if skill_path.is_file() and skill_path.name.lower() == "skill.md": - skill_md_path = skill_path - skill_dir = skill_path.parent - elif skill_path.is_dir(): - skill_dir = skill_path - skill_md_path = _find_skill_md(skill_dir) - else: - raise FileNotFoundError(f"path=<{skill_path}> | skill path does not exist or is not a valid skill directory") - - logger.debug("path=<%s> | loading skill", skill_md_path) - - content = skill_md_path.read_text(encoding="utf-8") - frontmatter, body = _parse_frontmatter(content) - - name = frontmatter.get("name") - if not isinstance(name, str) or not name: - raise ValueError(f"path=<{skill_md_path}> | SKILL.md must have a 'name' field in frontmatter") - - description = frontmatter.get("description") - if not isinstance(description, str) or not description: - raise ValueError(f"path=<{skill_md_path}> | SKILL.md must have a 'description' field in frontmatter") - - _validate_skill_name(name, skill_dir, strict=strict) - - # Parse allowed-tools (space-delimited string or YAML list) - allowed_tools_raw = frontmatter.get("allowed-tools") or frontmatter.get("allowed_tools") - allowed_tools: list[str] | None = None - if isinstance(allowed_tools_raw, str) and allowed_tools_raw.strip(): - allowed_tools = allowed_tools_raw.strip().split() - elif isinstance(allowed_tools_raw, list): - allowed_tools = [str(item) for item in allowed_tools_raw if item] - - # Parse metadata (nested mapping) - metadata_raw = frontmatter.get("metadata", {}) - metadata: dict[str, Any] = {} - if isinstance(metadata_raw, dict): - metadata = {str(k): v for k, v in metadata_raw.items()} - - skill_license = frontmatter.get("license") - compatibility = frontmatter.get("compatibility") - - skill = Skill( - name=name, - description=description, - instructions=body, - path=skill_dir, - allowed_tools=allowed_tools, - metadata=metadata, - license=str(skill_license) if skill_license else None, - compatibility=str(compatibility) if compatibility else None, - ) - - logger.debug("name=<%s>, path=<%s> | skill loaded successfully", skill.name, skill.path) - return skill - - -def load_skills(skills_dir: str | Path) -> list[Skill]: - """Load all skills from a parent directory containing skill subdirectories. - - Each subdirectory containing a SKILL.md file is treated as a skill. - Subdirectories without SKILL.md are silently skipped. - - Args: - skills_dir: Path to the parent directory containing skill subdirectories. - - Returns: - List of Skill instances loaded from the directory. - - Raises: - FileNotFoundError: If the skills directory does not exist. - """ - skills_dir = Path(skills_dir).resolve() - - if not skills_dir.is_dir(): - raise FileNotFoundError(f"path=<{skills_dir}> | skills directory does not exist") - - skills: list[Skill] = [] - - for child in sorted(skills_dir.iterdir()): - if not child.is_dir(): - continue - - try: - _find_skill_md(child) - except FileNotFoundError: - logger.debug("path=<%s> | skipping directory without SKILL.md", child) - continue - - try: - skill = load_skill(child) - skills.append(skill) - except (ValueError, FileNotFoundError) as e: - logger.warning("path=<%s> | skipping skill due to error: %s", child, e) - - logger.debug("path=<%s>, count=<%d> | loaded skills from directory", skills_dir, len(skills)) - return skills diff --git a/src/strands/plugins/skills/skill.py b/src/strands/plugins/skills/skill.py index 34010fdba..b5869715e 100644 --- a/src/strands/plugins/skills/skill.py +++ b/src/strands/plugins/skills/skill.py @@ -1,25 +1,231 @@ -"""Skill data model for the AgentSkills.io integration. +"""Skill data model and loading utilities for AgentSkills.io skills. -This module defines the Skill dataclass, which represents a single AgentSkills.io -skill with its metadata and instructions. +This module defines the Skill dataclass and provides classmethods for +discovering, parsing, and loading skills from the filesystem or raw content. +Skills are directories containing a SKILL.md file with YAML frontmatter +metadata and markdown instructions. """ from __future__ import annotations +import logging +import re from dataclasses import dataclass, field from pathlib import Path from typing import Any +import yaml + +logger = logging.getLogger(__name__) + +_SKILL_NAME_PATTERN = re.compile(r"^[a-z0-9]([a-z0-9-]*[a-z0-9])?$") +_MAX_SKILL_NAME_LENGTH = 64 + + +def _find_skill_md(skill_dir: Path) -> Path: + """Find the SKILL.md file in a skill directory. + + Searches for SKILL.md (case-sensitive preferred) or skill.md as a fallback. + + Args: + skill_dir: Path to the skill directory. + + Returns: + Path to the SKILL.md file. + + Raises: + FileNotFoundError: If no SKILL.md file is found in the directory. + """ + for name in ("SKILL.md", "skill.md"): + candidate = skill_dir / name + if candidate.is_file(): + return candidate + + raise FileNotFoundError(f"path=<{skill_dir}> | no SKILL.md found in skill directory") + + +def _parse_frontmatter(content: str) -> tuple[dict[str, Any], str]: + """Parse YAML frontmatter and body from SKILL.md content. + + Extracts the YAML frontmatter between ``---`` delimiters at line boundaries + and returns parsed key-value pairs along with the remaining markdown body. + + Args: + content: Full content of a SKILL.md file. + + Returns: + Tuple of (frontmatter_dict, body_string). + + Raises: + ValueError: If the frontmatter is malformed or missing required delimiters. + """ + stripped = content.strip() + if not stripped.startswith("---"): + raise ValueError("SKILL.md must start with --- frontmatter delimiter") + + # Find the closing --- delimiter (first line after the opener that is only dashes) + match = re.search(r"\n^---\s*$", stripped, re.MULTILINE) + if match is None: + raise ValueError("SKILL.md frontmatter missing closing --- delimiter") + + frontmatter_str = stripped[3 : match.start()].strip() + body = stripped[match.end() :].strip() + + try: + result = yaml.safe_load(frontmatter_str) + except yaml.YAMLError: + # AgentSkills spec recommends handling malformed YAML (e.g. unquoted colons in values) + # to improve cross-client compatibility. See: agentskills.io/client-implementation/adding-skills-support + logger.warning("YAML parse failed, retrying with colon-quoting fallback") + fixed = _fix_yaml_colons(frontmatter_str) + result = yaml.safe_load(fixed) + + frontmatter: dict[str, Any] = result if isinstance(result, dict) else {} + return frontmatter, body + + +def _fix_yaml_colons(yaml_str: str) -> str: + """Attempt to fix common YAML issues like unquoted colons in values. + + Wraps values containing colons in double quotes to handle cases like: + ``description: Use this skill when: the user asks about PDFs`` + + Args: + yaml_str: The raw YAML string to fix. + + Returns: + The fixed YAML string. + """ + lines: list[str] = [] + for line in yaml_str.splitlines(): + # Match key: value where value contains another colon + match = re.match(r"^(\s*\w[\w-]*):\s+(.+)$", line) + if match: + key, value = match.group(1), match.group(2) + # If value contains a colon and isn't already quoted + if ":" in value and not (value.startswith('"') or value.startswith("'")): + line = f'{key}: "{value}"' + lines.append(line) + return "\n".join(lines) + + +def _validate_skill_name(name: str, dir_path: Path | None = None, *, strict: bool = False) -> None: + """Validate a skill name per the AgentSkills.io specification. + + In lenient mode (default), logs warnings for cosmetic issues but does not raise. + In strict mode, raises ValueError for any validation failure. + + Rules checked: + - 1-64 characters long + - Lowercase alphanumeric characters and hyphens only + - Cannot start or end with a hyphen + - No consecutive hyphens + - Must match parent directory name (if loaded from disk) + + Args: + name: The skill name to validate. + dir_path: Optional path to the skill directory for name matching. + strict: If True, raise ValueError on any issue. If False (default), log warnings. + + Raises: + ValueError: If the skill name is empty, or if strict=True and any rule is violated. + """ + if not name: + raise ValueError("Skill name cannot be empty") + + if len(name) > _MAX_SKILL_NAME_LENGTH: + msg = "name=<%s> | skill name exceeds %d character limit" + if strict: + raise ValueError(msg % (name, _MAX_SKILL_NAME_LENGTH)) + logger.warning(msg, name, _MAX_SKILL_NAME_LENGTH) + + if not _SKILL_NAME_PATTERN.match(name): + msg = ( + "name=<%s> | skill name should be 1-64 lowercase alphanumeric characters or hyphens, " + "should not start/end with hyphen" + ) + if strict: + raise ValueError(msg % name) + logger.warning(msg, name) + + if "--" in name: + msg = "name=<%s> | skill name contains consecutive hyphens" + if strict: + raise ValueError(msg % name) + logger.warning(msg, name) + + if dir_path is not None and dir_path.name != name: + msg = "name=<%s>, directory=<%s> | skill name does not match parent directory name" + if strict: + raise ValueError(msg % (name, dir_path.name)) + logger.warning(msg, name, dir_path.name) + + +def _build_skill_from_frontmatter( + frontmatter: dict[str, Any], + body: str, + *, + skill_dir: Path | None = None, +) -> Skill: + """Build a Skill instance from parsed frontmatter and body. + + Args: + frontmatter: Parsed YAML frontmatter dict. + body: Markdown body content. + skill_dir: Optional path to the skill directory on disk. + + Returns: + A populated Skill instance. + """ + # Parse allowed-tools (space-delimited string or YAML list) + allowed_tools_raw = frontmatter.get("allowed-tools") or frontmatter.get("allowed_tools") + allowed_tools: list[str] | None = None + if isinstance(allowed_tools_raw, str) and allowed_tools_raw.strip(): + allowed_tools = allowed_tools_raw.strip().split() + elif isinstance(allowed_tools_raw, list): + allowed_tools = [str(item) for item in allowed_tools_raw if item] + + # Parse metadata (nested mapping) + metadata_raw = frontmatter.get("metadata", {}) + metadata: dict[str, Any] = {} + if isinstance(metadata_raw, dict): + metadata = {str(k): v for k, v in metadata_raw.items()} + + skill_license = frontmatter.get("license") + compatibility = frontmatter.get("compatibility") + + return Skill( + name=frontmatter["name"], + description=frontmatter["description"], + instructions=body, + path=skill_dir, + allowed_tools=allowed_tools, + metadata=metadata, + license=str(skill_license) if skill_license else None, + compatibility=str(compatibility) if compatibility else None, + ) + @dataclass class Skill: - """Represents an agent skill with metadata and instructions. + r"""Represents an agent skill with metadata and instructions. A skill encapsulates a set of instructions and metadata that can be dynamically loaded by an agent at runtime. Skills support progressive disclosure: metadata is shown upfront in the system prompt, and full instructions are loaded on demand via a tool. + Skills can be created directly or via convenience classmethods:: + + # From a skill directory on disk + skill = Skill.from_file("./skills/my-skill") + + # From raw SKILL.md content + skill = Skill.from_content("---\nname: my-skill\n...") + + # Load all skills from a parent directory + skills = Skill.from_directory("./skills/") + Attributes: name: Unique identifier for the skill (1-64 chars, lowercase alphanumeric + hyphens). description: Human-readable description of what the skill does. @@ -39,3 +245,135 @@ class Skill: metadata: dict[str, Any] = field(default_factory=dict) license: str | None = None compatibility: str | None = None + + @classmethod + def from_file(cls, skill_path: str | Path, *, strict: bool = False) -> Skill: + """Load a single skill from a directory containing SKILL.md. + + Args: + skill_path: Path to the skill directory or the SKILL.md file itself. + strict: If True, raise on any validation issue. If False (default), warn and load anyway. + + Returns: + A Skill instance populated from the SKILL.md file. + + Raises: + FileNotFoundError: If the path does not exist or SKILL.md is not found. + ValueError: If the skill metadata is invalid. + """ + skill_path = Path(skill_path).resolve() + + if skill_path.is_file() and skill_path.name.lower() == "skill.md": + skill_md_path = skill_path + skill_dir = skill_path.parent + elif skill_path.is_dir(): + skill_dir = skill_path + skill_md_path = _find_skill_md(skill_dir) + else: + raise FileNotFoundError( + f"path=<{skill_path}> | skill path does not exist or is not a valid skill directory" + ) + + logger.debug("path=<%s> | loading skill", skill_md_path) + + content = skill_md_path.read_text(encoding="utf-8") + frontmatter, body = _parse_frontmatter(content) + + name = frontmatter.get("name") + if not isinstance(name, str) or not name: + raise ValueError(f"path=<{skill_md_path}> | SKILL.md must have a 'name' field in frontmatter") + + description = frontmatter.get("description") + if not isinstance(description, str) or not description: + raise ValueError(f"path=<{skill_md_path}> | SKILL.md must have a 'description' field in frontmatter") + + _validate_skill_name(name, skill_dir, strict=strict) + + skill = _build_skill_from_frontmatter(frontmatter, body, skill_dir=skill_dir) + logger.debug("name=<%s>, path=<%s> | skill loaded successfully", skill.name, skill.path) + return skill + + @classmethod + def from_content(cls, content: str, *, strict: bool = False) -> Skill: + """Parse SKILL.md content into a Skill instance. + + This is a convenience method for creating a Skill from raw SKILL.md + content (YAML frontmatter + markdown body) without requiring a file on + disk. + + Example:: + + content = '''--- + name: my-skill + description: Does something useful + --- + # Instructions + Follow these steps... + ''' + skill = Skill.from_content(content) + + Args: + content: Raw SKILL.md content with YAML frontmatter and markdown body. + strict: If True, raise on any validation issue. If False (default), warn and load anyway. + + Returns: + A Skill instance populated from the parsed content. + + Raises: + ValueError: If the content is missing required fields or has invalid frontmatter. + """ + frontmatter, body = _parse_frontmatter(content) + + name = frontmatter.get("name") + if not isinstance(name, str) or not name: + raise ValueError("SKILL.md content must have a 'name' field in frontmatter") + + description = frontmatter.get("description") + if not isinstance(description, str) or not description: + raise ValueError("SKILL.md content must have a 'description' field in frontmatter") + + _validate_skill_name(name, strict=strict) + + return _build_skill_from_frontmatter(frontmatter, body) + + @classmethod + def from_directory(cls, skills_dir: str | Path) -> list[Skill]: + """Load all skills from a parent directory containing skill subdirectories. + + Each subdirectory containing a SKILL.md file is treated as a skill. + Subdirectories without SKILL.md are silently skipped. + + Args: + skills_dir: Path to the parent directory containing skill subdirectories. + + Returns: + List of Skill instances loaded from the directory. + + Raises: + FileNotFoundError: If the skills directory does not exist. + """ + skills_dir = Path(skills_dir).resolve() + + if not skills_dir.is_dir(): + raise FileNotFoundError(f"path=<{skills_dir}> | skills directory does not exist") + + skills: list[Skill] = [] + + for child in sorted(skills_dir.iterdir()): + if not child.is_dir(): + continue + + try: + _find_skill_md(child) + except FileNotFoundError: + logger.debug("path=<%s> | skipping directory without SKILL.md", child) + continue + + try: + skill = cls.from_file(child) + skills.append(skill) + except (ValueError, FileNotFoundError) as e: + logger.warning("path=<%s> | skipping skill due to error: %s", child, e) + + logger.debug("path=<%s>, count=<%d> | loaded skills from directory", skills_dir, len(skills)) + return skills diff --git a/tests/strands/plugins/skills/test_skills_plugin.py b/tests/strands/plugins/skills/test_agent_skills.py similarity index 91% rename from tests/strands/plugins/skills/test_skills_plugin.py rename to tests/strands/plugins/skills/test_agent_skills.py index 84bbefde7..3e3d06905 100644 --- a/tests/strands/plugins/skills/test_skills_plugin.py +++ b/tests/strands/plugins/skills/test_agent_skills.py @@ -77,16 +77,16 @@ def test_init_with_skill_instances(self): skill = _make_skill() plugin = AgentSkills(skills=[skill]) - assert len(plugin.available_skills) == 1 - assert plugin.available_skills[0].name == "test-skill" + assert len(plugin.get_available_skills()) == 1 + assert plugin.get_available_skills()[0].name == "test-skill" def test_init_with_filesystem_paths(self, tmp_path): """Test initialization with filesystem paths.""" _make_skill_dir(tmp_path, "fs-skill") plugin = AgentSkills(skills=[str(tmp_path / "fs-skill")]) - assert len(plugin.available_skills) == 1 - assert plugin.available_skills[0].name == "fs-skill" + assert len(plugin.get_available_skills()) == 1 + assert plugin.get_available_skills()[0].name == "fs-skill" def test_init_with_parent_directory(self, tmp_path): """Test initialization with a parent directory containing skills.""" @@ -94,7 +94,7 @@ def test_init_with_parent_directory(self, tmp_path): _make_skill_dir(tmp_path, "skill-b") plugin = AgentSkills(skills=[tmp_path]) - assert len(plugin.available_skills) == 2 + assert len(plugin.get_available_skills()) == 2 def test_init_with_mixed_sources(self, tmp_path): """Test initialization with mixed skill sources.""" @@ -102,19 +102,19 @@ def test_init_with_mixed_sources(self, tmp_path): direct_skill = _make_skill(name="direct-skill", description="Direct") plugin = AgentSkills(skills=[str(tmp_path / "fs-skill"), direct_skill]) - assert len(plugin.available_skills) == 2 - names = {s.name for s in plugin.available_skills} + assert len(plugin.get_available_skills()) == 2 + names = {s.name for s in plugin.get_available_skills()} assert names == {"fs-skill", "direct-skill"} def test_init_skips_nonexistent_paths(self, tmp_path): """Test that nonexistent paths are skipped gracefully.""" plugin = AgentSkills(skills=[str(tmp_path / "nonexistent")]) - assert len(plugin.available_skills) == 0 + assert len(plugin.get_available_skills()) == 0 def test_init_empty_skills(self): """Test initialization with empty skills list.""" plugin = AgentSkills(skills=[]) - assert plugin.available_skills == [] + assert plugin.get_available_skills() == [] def test_name_attribute(self): """Test that the plugin has the correct name.""" @@ -169,24 +169,46 @@ class TestSkillsPluginProperties: """Tests for AgentSkills properties.""" def test_available_skills_getter_returns_copy(self): - """Test that the available_skills getter returns a copy of the list.""" + """Test that get_available_skills returns a copy of the list.""" skill = _make_skill() plugin = AgentSkills(skills=[skill]) - skills_list = plugin.available_skills + skills_list = plugin.get_available_skills() skills_list.append(_make_skill(name="another-skill", description="Another")) - assert len(plugin.available_skills) == 1 + assert len(plugin.get_available_skills()) == 1 def test_available_skills_setter(self): - """Test setting skills via the property setter.""" + """Test setting skills via set_available_skills.""" plugin = AgentSkills(skills=[_make_skill()]) new_skill = _make_skill(name="new-skill", description="New") - plugin.available_skills = [new_skill] + plugin.set_available_skills([new_skill]) - assert len(plugin.available_skills) == 1 - assert plugin.available_skills[0].name == "new-skill" + assert len(plugin.get_available_skills()) == 1 + assert plugin.get_available_skills()[0].name == "new-skill" + + def test_set_available_skills_with_paths(self, tmp_path): + """Test setting skills via set_available_skills with filesystem paths.""" + plugin = AgentSkills(skills=[_make_skill()]) + _make_skill_dir(tmp_path, "fs-skill") + + plugin.set_available_skills([str(tmp_path / "fs-skill")]) + + assert len(plugin.get_available_skills()) == 1 + assert plugin.get_available_skills()[0].name == "fs-skill" + + def test_set_available_skills_with_mixed_sources(self, tmp_path): + """Test setting skills via set_available_skills with mixed sources.""" + plugin = AgentSkills(skills=[]) + _make_skill_dir(tmp_path, "fs-skill") + direct = _make_skill(name="direct", description="Direct") + + plugin.set_available_skills([str(tmp_path / "fs-skill"), direct]) + + assert len(plugin.get_available_skills()) == 2 + names = {s.name for s in plugin.get_available_skills()} + assert names == {"fs-skill", "direct"} class TestLoadSkills: @@ -198,8 +220,8 @@ def test_appends_skill_instances(self): plugin.load_skills([_make_skill(name="new-skill", description="New")]) - assert len(plugin.available_skills) == 2 - names = {s.name for s in plugin.available_skills} + assert len(plugin.get_available_skills()) == 2 + names = {s.name for s in plugin.get_available_skills()} assert names == {"existing", "new-skill"} def test_appends_from_filesystem(self, tmp_path): @@ -209,8 +231,8 @@ def test_appends_from_filesystem(self, tmp_path): plugin.load_skills([str(tmp_path / "fs-skill")]) - assert len(plugin.available_skills) == 2 - names = {s.name for s in plugin.available_skills} + assert len(plugin.get_available_skills()) == 2 + names = {s.name for s in plugin.get_available_skills()} assert names == {"existing", "fs-skill"} def test_duplicates_overwrite(self): @@ -221,8 +243,8 @@ def test_duplicates_overwrite(self): replacement = _make_skill(name="dupe", description="Replacement") plugin.load_skills([replacement]) - assert len(plugin.available_skills) == 1 - assert plugin.available_skills[0].description == "Replacement" + assert len(plugin.get_available_skills()) == 1 + assert plugin.get_available_skills()[0].description == "Replacement" def test_mixed_sources(self, tmp_path): """Test load_skills with a mix of Skill instances and filesystem paths.""" @@ -232,8 +254,8 @@ def test_mixed_sources(self, tmp_path): plugin.load_skills([str(tmp_path / "fs-skill"), direct]) - assert len(plugin.available_skills) == 2 - names = {s.name for s in plugin.available_skills} + assert len(plugin.get_available_skills()) == 2 + names = {s.name for s in plugin.get_available_skills()} assert names == {"fs-skill", "direct"} def test_skips_nonexistent_paths(self): @@ -242,7 +264,7 @@ def test_skips_nonexistent_paths(self): plugin.load_skills(["/nonexistent/path"]) - assert len(plugin.available_skills) == 1 + assert len(plugin.get_available_skills()) == 1 def test_empty_sources(self): """Test that loading empty sources is a no-op.""" @@ -250,7 +272,7 @@ def test_empty_sources(self): plugin.load_skills([]) - assert len(plugin.available_skills) == 1 + assert len(plugin.get_available_skills()) == 1 def test_parent_directory(self, tmp_path): """Test load_skills with a parent directory containing multiple skills.""" @@ -260,8 +282,8 @@ def test_parent_directory(self, tmp_path): plugin.load_skills([tmp_path]) - assert len(plugin.available_skills) == 2 - names = {s.name for s in plugin.available_skills} + assert len(plugin.get_available_skills()) == 2 + names = {s.name for s in plugin.get_available_skills()} assert names == {"child-a", "child-b"} @@ -732,12 +754,10 @@ def test_import_skill_from_strands(self): def test_import_from_skills_package(self): """Test importing from strands.plugins.skills package.""" - from strands.plugins.skills import AgentSkills, Skill, load_skill, load_skills + from strands.plugins.skills import AgentSkills, Skill assert Skill is not None assert AgentSkills is not None - assert load_skill is not None - assert load_skills is not None def test_skills_plugin_is_plugin_subclass(self): """Test that AgentSkills is a subclass of the Plugin ABC.""" diff --git a/tests/strands/plugins/skills/test_loader.py b/tests/strands/plugins/skills/test_loader.py deleted file mode 100644 index 70628ecb2..000000000 --- a/tests/strands/plugins/skills/test_loader.py +++ /dev/null @@ -1,441 +0,0 @@ -"""Tests for the skill loader module.""" - -import logging -from pathlib import Path - -import pytest - -from strands.plugins.skills.loader import ( - _find_skill_md, - _fix_yaml_colons, - _parse_frontmatter, - _validate_skill_name, - load_skill, - load_skills, -) - - -class TestFindSkillMd: - """Tests for _find_skill_md.""" - - def test_finds_uppercase_skill_md(self, tmp_path): - """Test finding SKILL.md (uppercase).""" - (tmp_path / "SKILL.md").write_text("test") - result = _find_skill_md(tmp_path) - assert result.name == "SKILL.md" - - def test_finds_lowercase_skill_md(self, tmp_path): - """Test finding skill.md (lowercase).""" - (tmp_path / "skill.md").write_text("test") - result = _find_skill_md(tmp_path) - assert result.name.lower() == "skill.md" - - def test_prefers_uppercase(self, tmp_path): - """Test that SKILL.md is preferred over skill.md.""" - (tmp_path / "SKILL.md").write_text("uppercase") - (tmp_path / "skill.md").write_text("lowercase") - result = _find_skill_md(tmp_path) - assert result.name == "SKILL.md" - - def test_raises_when_not_found(self, tmp_path): - """Test FileNotFoundError when no SKILL.md exists.""" - with pytest.raises(FileNotFoundError, match="no SKILL.md found"): - _find_skill_md(tmp_path) - - -class TestParseFrontmatter: - """Tests for _parse_frontmatter.""" - - def test_valid_frontmatter(self): - """Test parsing valid frontmatter.""" - content = "---\nname: test-skill\ndescription: A test\n---\n# Instructions\nDo things." - frontmatter, body = _parse_frontmatter(content) - assert frontmatter["name"] == "test-skill" - assert frontmatter["description"] == "A test" - assert "# Instructions" in body - assert "Do things." in body - - def test_missing_opening_delimiter(self): - """Test error when opening --- is missing.""" - with pytest.raises(ValueError, match="must start with ---"): - _parse_frontmatter("name: test\n---\n") - - def test_missing_closing_delimiter(self): - """Test error when closing --- is missing.""" - with pytest.raises(ValueError, match="missing closing ---"): - _parse_frontmatter("---\nname: test\n") - - def test_empty_body(self): - """Test frontmatter with empty body.""" - content = "---\nname: test-skill\ndescription: test\n---\n" - frontmatter, body = _parse_frontmatter(content) - assert frontmatter["name"] == "test-skill" - assert body == "" - - def test_frontmatter_with_metadata(self): - """Test frontmatter with nested metadata.""" - content = "---\nname: test-skill\ndescription: test\nmetadata:\n author: acme\n---\nBody here." - frontmatter, body = _parse_frontmatter(content) - assert frontmatter["name"] == "test-skill" - assert isinstance(frontmatter["metadata"], dict) - assert frontmatter["metadata"]["author"] == "acme" - assert body == "Body here." - - def test_frontmatter_with_dashes_in_yaml_value(self): - """Test that --- inside a YAML value does not break parsing.""" - content = "---\nname: test-skill\ndescription: has --- inside\n---\nBody here." - frontmatter, body = _parse_frontmatter(content) - assert frontmatter["name"] == "test-skill" - assert frontmatter["description"] == "has --- inside" - assert body == "Body here." - - -class TestValidateSkillName: - """Tests for _validate_skill_name (lenient validation).""" - - def test_valid_names(self): - """Test that valid names pass validation without warnings.""" - valid_names = ["a", "test", "my-skill", "skill-123", "a1b2c3"] - for name in valid_names: - _validate_skill_name(name) # Should not raise - - def test_empty_name(self): - """Test that empty name raises ValueError.""" - with pytest.raises(ValueError, match="cannot be empty"): - _validate_skill_name("") - - def test_too_long_name_warns(self, caplog): - """Test that names exceeding 64 chars warn but do not raise.""" - with caplog.at_level(logging.WARNING): - _validate_skill_name("a" * 65) - assert "exceeds" in caplog.text - - def test_uppercase_warns(self, caplog): - """Test that uppercase characters warn but do not raise.""" - with caplog.at_level(logging.WARNING): - _validate_skill_name("MySkill") - assert "lowercase alphanumeric" in caplog.text - - def test_starts_with_hyphen_warns(self, caplog): - """Test that names starting with hyphen warn but do not raise.""" - with caplog.at_level(logging.WARNING): - _validate_skill_name("-skill") - assert "lowercase alphanumeric" in caplog.text - - def test_ends_with_hyphen_warns(self, caplog): - """Test that names ending with hyphen warn but do not raise.""" - with caplog.at_level(logging.WARNING): - _validate_skill_name("skill-") - assert "lowercase alphanumeric" in caplog.text - - def test_consecutive_hyphens_warns(self, caplog): - """Test that consecutive hyphens warn but do not raise.""" - with caplog.at_level(logging.WARNING): - _validate_skill_name("my--skill") - assert "consecutive hyphens" in caplog.text - - def test_special_characters_warns(self, caplog): - """Test that special characters warn but do not raise.""" - with caplog.at_level(logging.WARNING): - _validate_skill_name("my_skill") - assert "lowercase alphanumeric" in caplog.text - - def test_directory_name_mismatch_warns(self, tmp_path, caplog): - """Test that skill name not matching directory name warns but does not raise.""" - skill_dir = tmp_path / "wrong-name" - skill_dir.mkdir() - with caplog.at_level(logging.WARNING): - _validate_skill_name("my-skill", skill_dir) - assert "does not match parent directory name" in caplog.text - - def test_directory_name_match(self, tmp_path): - """Test that matching directory name passes.""" - skill_dir = tmp_path / "my-skill" - skill_dir.mkdir() - _validate_skill_name("my-skill", skill_dir) # Should not raise or warn - - -def _make_skill_dir(parent: Path, name: str, description: str = "A test skill", body: str = "Instructions.") -> Path: - """Helper to create a skill directory with SKILL.md.""" - skill_dir = parent / name - skill_dir.mkdir(parents=True, exist_ok=True) - content = f"---\nname: {name}\ndescription: {description}\n---\n{body}\n" - (skill_dir / "SKILL.md").write_text(content) - return skill_dir - - -class TestLoadSkill: - """Tests for load_skill.""" - - def test_load_from_directory(self, tmp_path): - """Test loading a skill from a directory path.""" - skill_dir = _make_skill_dir(tmp_path, "my-skill", "My description", "# Hello\nWorld.") - skill = load_skill(skill_dir) - - assert skill.name == "my-skill" - assert skill.description == "My description" - assert "# Hello" in skill.instructions - assert "World." in skill.instructions - assert skill.path == skill_dir.resolve() - - def test_load_from_skill_md_file(self, tmp_path): - """Test loading a skill by pointing directly to SKILL.md.""" - skill_dir = _make_skill_dir(tmp_path, "direct-skill") - skill = load_skill(skill_dir / "SKILL.md") - - assert skill.name == "direct-skill" - - def test_load_with_allowed_tools(self, tmp_path): - """Test loading a skill with allowed-tools field as space-delimited string.""" - skill_dir = tmp_path / "tool-skill" - skill_dir.mkdir() - content = "---\nname: tool-skill\ndescription: test\nallowed-tools: read write execute\n---\nBody." - (skill_dir / "SKILL.md").write_text(content) - - skill = load_skill(skill_dir) - assert skill.allowed_tools == ["read", "write", "execute"] - - def test_load_with_allowed_tools_yaml_list(self, tmp_path): - """Test loading a skill with allowed-tools as a YAML list.""" - skill_dir = tmp_path / "list-skill" - skill_dir.mkdir() - content = "---\nname: list-skill\ndescription: test\nallowed-tools:\n - read\n - write\n---\nBody." - (skill_dir / "SKILL.md").write_text(content) - - skill = load_skill(skill_dir) - assert skill.allowed_tools == ["read", "write"] - - def test_load_with_metadata(self, tmp_path): - """Test loading a skill with nested metadata.""" - skill_dir = tmp_path / "meta-skill" - skill_dir.mkdir() - content = "---\nname: meta-skill\ndescription: test\nmetadata:\n author: acme\n---\nBody." - (skill_dir / "SKILL.md").write_text(content) - - skill = load_skill(skill_dir) - assert skill.metadata == {"author": "acme"} - - def test_load_with_license_and_compatibility(self, tmp_path): - """Test loading a skill with license and compatibility fields.""" - skill_dir = tmp_path / "licensed-skill" - skill_dir.mkdir() - content = "---\nname: licensed-skill\ndescription: test\nlicense: MIT\ncompatibility: v1\n---\nBody." - (skill_dir / "SKILL.md").write_text(content) - - skill = load_skill(skill_dir) - assert skill.license == "MIT" - assert skill.compatibility == "v1" - - def test_load_missing_name(self, tmp_path): - """Test error when SKILL.md is missing name field.""" - skill_dir = tmp_path / "no-name" - skill_dir.mkdir() - (skill_dir / "SKILL.md").write_text("---\ndescription: test\n---\nBody.") - - with pytest.raises(ValueError, match="must have a 'name' field"): - load_skill(skill_dir) - - def test_load_missing_description(self, tmp_path): - """Test error when SKILL.md is missing description field.""" - skill_dir = tmp_path / "no-desc" - skill_dir.mkdir() - (skill_dir / "SKILL.md").write_text("---\nname: no-desc\n---\nBody.") - - with pytest.raises(ValueError, match="must have a 'description' field"): - load_skill(skill_dir) - - def test_load_nonexistent_path(self, tmp_path): - """Test FileNotFoundError for nonexistent path.""" - with pytest.raises(FileNotFoundError): - load_skill(tmp_path / "nonexistent") - - def test_load_name_directory_mismatch_warns(self, tmp_path, caplog): - """Test that skill name not matching directory name warns but still loads.""" - skill_dir = tmp_path / "wrong-dir" - skill_dir.mkdir() - (skill_dir / "SKILL.md").write_text("---\nname: right-name\ndescription: test\n---\nBody.") - - with caplog.at_level(logging.WARNING): - skill = load_skill(skill_dir) - - assert skill.name == "right-name" - assert "does not match parent directory name" in caplog.text - - -class TestLoadSkills: - """Tests for load_skills.""" - - def test_load_multiple_skills(self, tmp_path): - """Test loading multiple skills from a parent directory.""" - _make_skill_dir(tmp_path, "skill-a", "Skill A") - _make_skill_dir(tmp_path, "skill-b", "Skill B") - - skills = load_skills(tmp_path) - - assert len(skills) == 2 - names = {s.name for s in skills} - assert names == {"skill-a", "skill-b"} - - def test_skips_directories_without_skill_md(self, tmp_path): - """Test that directories without SKILL.md are silently skipped.""" - _make_skill_dir(tmp_path, "valid-skill") - (tmp_path / "no-skill-here").mkdir() - - skills = load_skills(tmp_path) - - assert len(skills) == 1 - assert skills[0].name == "valid-skill" - - def test_skips_files_in_parent(self, tmp_path): - """Test that files in the parent directory are ignored.""" - _make_skill_dir(tmp_path, "real-skill") - (tmp_path / "readme.txt").write_text("not a skill") - - skills = load_skills(tmp_path) - - assert len(skills) == 1 - - def test_empty_directory(self, tmp_path): - """Test loading from an empty directory.""" - skills = load_skills(tmp_path) - assert skills == [] - - def test_nonexistent_directory(self, tmp_path): - """Test FileNotFoundError for nonexistent directory.""" - with pytest.raises(FileNotFoundError): - load_skills(tmp_path / "nonexistent") - - def test_loads_mismatched_name_with_warning(self, tmp_path, caplog): - """Test that skills with name/directory mismatch are loaded with a warning.""" - _make_skill_dir(tmp_path, "good-skill") - - # Create a skill with name mismatch (lenient validation loads it anyway) - bad_dir = tmp_path / "bad-dir" - bad_dir.mkdir() - (bad_dir / "SKILL.md").write_text("---\nname: wrong-name\ndescription: test\n---\nBody.") - - with caplog.at_level(logging.WARNING): - skills = load_skills(tmp_path) - - assert len(skills) == 2 - names = {s.name for s in skills} - assert names == {"good-skill", "wrong-name"} - assert "does not match parent directory name" in caplog.text - - -class TestFixYamlColons: - """Tests for _fix_yaml_colons.""" - - def test_fixes_unquoted_colon_in_value(self): - """Test that an unquoted colon in a value gets quoted.""" - raw = "description: Use this skill when: the user asks about PDFs" - fixed = _fix_yaml_colons(raw) - assert fixed == 'description: "Use this skill when: the user asks about PDFs"' - - def test_leaves_already_double_quoted_value(self): - """Test that already double-quoted values are not re-quoted.""" - raw = 'description: "already: quoted"' - assert _fix_yaml_colons(raw) == raw - - def test_leaves_already_single_quoted_value(self): - """Test that already single-quoted values are not re-quoted.""" - raw = "description: 'already: quoted'" - assert _fix_yaml_colons(raw) == raw - - def test_leaves_value_without_colon(self): - """Test that values without colons are unchanged.""" - raw = "name: my-skill" - assert _fix_yaml_colons(raw) == raw - - def test_multiline_mixed(self): - """Test fixing only the lines that need it in a multi-line string.""" - raw = "name: my-skill\ndescription: Use when: needed\nversion: 1.0" - fixed = _fix_yaml_colons(raw) - assert fixed == 'name: my-skill\ndescription: "Use when: needed"\nversion: 1.0' - - def test_empty_string(self): - """Test that an empty string is returned unchanged.""" - assert _fix_yaml_colons("") == "" - - def test_preserves_indented_lines_without_colons(self): - """Test that indented lines without key-value patterns are preserved.""" - raw = " - item one\n - item two" - assert _fix_yaml_colons(raw) == raw - - -class TestValidateSkillNameStrict: - """Tests for _validate_skill_name with strict=True.""" - - def test_strict_valid_name(self): - """Test that valid names pass strict validation.""" - _validate_skill_name("my-skill", strict=True) # Should not raise - - def test_strict_empty_name(self): - """Test that empty name raises in strict mode.""" - with pytest.raises(ValueError, match="cannot be empty"): - _validate_skill_name("", strict=True) - - def test_strict_too_long_name(self): - """Test that names exceeding 64 chars raise in strict mode.""" - with pytest.raises(ValueError, match="exceeds 64 character limit"): - _validate_skill_name("a" * 65, strict=True) - - def test_strict_uppercase_rejected(self): - """Test that uppercase characters raise in strict mode.""" - with pytest.raises(ValueError, match="lowercase alphanumeric"): - _validate_skill_name("MySkill", strict=True) - - def test_strict_starts_with_hyphen(self): - """Test that names starting with hyphen raise in strict mode.""" - with pytest.raises(ValueError, match="lowercase alphanumeric"): - _validate_skill_name("-skill", strict=True) - - def test_strict_consecutive_hyphens(self): - """Test that consecutive hyphens raise in strict mode.""" - with pytest.raises(ValueError, match="consecutive hyphens"): - _validate_skill_name("my--skill", strict=True) - - def test_strict_directory_mismatch(self, tmp_path): - """Test that directory name mismatch raises in strict mode.""" - skill_dir = tmp_path / "wrong-name" - skill_dir.mkdir() - with pytest.raises(ValueError, match="does not match parent directory name"): - _validate_skill_name("my-skill", skill_dir, strict=True) - - -class TestLoadSkillStrict: - """Tests for load_skill with strict=True.""" - - def test_strict_rejects_name_mismatch(self, tmp_path): - """Test that strict mode raises on name/directory mismatch.""" - skill_dir = tmp_path / "wrong-dir" - skill_dir.mkdir() - (skill_dir / "SKILL.md").write_text("---\nname: right-name\ndescription: test\n---\nBody.") - - with pytest.raises(ValueError, match="does not match parent directory name"): - load_skill(skill_dir, strict=True) - - def test_strict_accepts_valid_skill(self, tmp_path): - """Test that strict mode loads a valid skill without error.""" - _make_skill_dir(tmp_path, "valid-skill") - skill = load_skill(tmp_path / "valid-skill", strict=True) - assert skill.name == "valid-skill" - - -class TestParseFrontmatterYamlFallback: - """Tests for YAML colon-quoting fallback in _parse_frontmatter.""" - - def test_fallback_on_unquoted_colon(self): - """Test that frontmatter with unquoted colons in values is parsed via fallback.""" - content = "---\nname: my-skill\ndescription: Use when: the user asks\n---\nBody." - frontmatter, body = _parse_frontmatter(content) - assert frontmatter["name"] == "my-skill" - assert "Use when" in frontmatter["description"] - assert body == "Body." - - def test_fallback_preserves_valid_yaml(self): - """Test that valid YAML is parsed normally without triggering fallback.""" - content = "---\nname: my-skill\ndescription: A simple description\n---\nBody." - frontmatter, body = _parse_frontmatter(content) - assert frontmatter["name"] == "my-skill" - assert frontmatter["description"] == "A simple description" diff --git a/tests/strands/plugins/skills/test_skill.py b/tests/strands/plugins/skills/test_skill.py index 6cf93ae94..2c4c21930 100644 --- a/tests/strands/plugins/skills/test_skill.py +++ b/tests/strands/plugins/skills/test_skill.py @@ -1,8 +1,17 @@ -"""Tests for the Skill dataclass.""" +"""Tests for the Skill dataclass and loading utilities.""" +import logging from pathlib import Path -from strands.plugins.skills.skill import Skill +import pytest + +from strands.plugins.skills.skill import ( + Skill, + _find_skill_md, + _fix_yaml_colons, + _parse_frontmatter, + _validate_skill_name, +) class TestSkillDataclass: @@ -50,3 +59,503 @@ def test_skill_metadata_default_is_not_shared(self): skill1.metadata["key"] = "value" assert "key" not in skill2.metadata + + +class TestFindSkillMd: + """Tests for _find_skill_md.""" + + def test_finds_uppercase_skill_md(self, tmp_path): + """Test finding SKILL.md (uppercase).""" + (tmp_path / "SKILL.md").write_text("test") + result = _find_skill_md(tmp_path) + assert result.name == "SKILL.md" + + def test_finds_lowercase_skill_md(self, tmp_path): + """Test finding skill.md (lowercase).""" + (tmp_path / "skill.md").write_text("test") + result = _find_skill_md(tmp_path) + assert result.name.lower() == "skill.md" + + def test_prefers_uppercase(self, tmp_path): + """Test that SKILL.md is preferred over skill.md.""" + (tmp_path / "SKILL.md").write_text("uppercase") + (tmp_path / "skill.md").write_text("lowercase") + result = _find_skill_md(tmp_path) + assert result.name == "SKILL.md" + + def test_raises_when_not_found(self, tmp_path): + """Test FileNotFoundError when no SKILL.md exists.""" + with pytest.raises(FileNotFoundError, match="no SKILL.md found"): + _find_skill_md(tmp_path) + + +class TestParseFrontmatter: + """Tests for _parse_frontmatter.""" + + def test_valid_frontmatter(self): + """Test parsing valid frontmatter.""" + content = "---\nname: test-skill\ndescription: A test\n---\n# Instructions\nDo things." + frontmatter, body = _parse_frontmatter(content) + assert frontmatter["name"] == "test-skill" + assert frontmatter["description"] == "A test" + assert "# Instructions" in body + assert "Do things." in body + + def test_missing_opening_delimiter(self): + """Test error when opening --- is missing.""" + with pytest.raises(ValueError, match="must start with ---"): + _parse_frontmatter("name: test\n---\n") + + def test_missing_closing_delimiter(self): + """Test error when closing --- is missing.""" + with pytest.raises(ValueError, match="missing closing ---"): + _parse_frontmatter("---\nname: test\n") + + def test_empty_body(self): + """Test frontmatter with empty body.""" + content = "---\nname: test-skill\ndescription: test\n---\n" + frontmatter, body = _parse_frontmatter(content) + assert frontmatter["name"] == "test-skill" + assert body == "" + + def test_frontmatter_with_metadata(self): + """Test frontmatter with nested metadata.""" + content = "---\nname: test-skill\ndescription: test\nmetadata:\n author: acme\n---\nBody here." + frontmatter, body = _parse_frontmatter(content) + assert frontmatter["name"] == "test-skill" + assert isinstance(frontmatter["metadata"], dict) + assert frontmatter["metadata"]["author"] == "acme" + assert body == "Body here." + + def test_frontmatter_with_dashes_in_yaml_value(self): + """Test that --- inside a YAML value does not break parsing.""" + content = "---\nname: test-skill\ndescription: has --- inside\n---\nBody here." + frontmatter, body = _parse_frontmatter(content) + assert frontmatter["name"] == "test-skill" + assert frontmatter["description"] == "has --- inside" + assert body == "Body here." + + +class TestValidateSkillName: + """Tests for _validate_skill_name (lenient validation).""" + + def test_valid_names(self): + """Test that valid names pass validation without warnings.""" + valid_names = ["a", "test", "my-skill", "skill-123", "a1b2c3"] + for name in valid_names: + _validate_skill_name(name) # Should not raise + + def test_empty_name(self): + """Test that empty name raises ValueError.""" + with pytest.raises(ValueError, match="cannot be empty"): + _validate_skill_name("") + + def test_too_long_name_warns(self, caplog): + """Test that names exceeding 64 chars warn but do not raise.""" + with caplog.at_level(logging.WARNING): + _validate_skill_name("a" * 65) + assert "exceeds" in caplog.text + + def test_uppercase_warns(self, caplog): + """Test that uppercase characters warn but do not raise.""" + with caplog.at_level(logging.WARNING): + _validate_skill_name("MySkill") + assert "lowercase alphanumeric" in caplog.text + + def test_starts_with_hyphen_warns(self, caplog): + """Test that names starting with hyphen warn but do not raise.""" + with caplog.at_level(logging.WARNING): + _validate_skill_name("-skill") + assert "lowercase alphanumeric" in caplog.text + + def test_ends_with_hyphen_warns(self, caplog): + """Test that names ending with hyphen warn but do not raise.""" + with caplog.at_level(logging.WARNING): + _validate_skill_name("skill-") + assert "lowercase alphanumeric" in caplog.text + + def test_consecutive_hyphens_warns(self, caplog): + """Test that consecutive hyphens warn but do not raise.""" + with caplog.at_level(logging.WARNING): + _validate_skill_name("my--skill") + assert "consecutive hyphens" in caplog.text + + def test_special_characters_warns(self, caplog): + """Test that special characters warn but do not raise.""" + with caplog.at_level(logging.WARNING): + _validate_skill_name("my_skill") + assert "lowercase alphanumeric" in caplog.text + + def test_directory_name_mismatch_warns(self, tmp_path, caplog): + """Test that skill name not matching directory name warns but does not raise.""" + skill_dir = tmp_path / "wrong-name" + skill_dir.mkdir() + with caplog.at_level(logging.WARNING): + _validate_skill_name("my-skill", skill_dir) + assert "does not match parent directory name" in caplog.text + + def test_directory_name_match(self, tmp_path): + """Test that matching directory name passes.""" + skill_dir = tmp_path / "my-skill" + skill_dir.mkdir() + _validate_skill_name("my-skill", skill_dir) # Should not raise or warn + + +class TestValidateSkillNameStrict: + """Tests for _validate_skill_name with strict=True.""" + + def test_strict_valid_name(self): + """Test that valid names pass strict validation.""" + _validate_skill_name("my-skill", strict=True) # Should not raise + + def test_strict_empty_name(self): + """Test that empty name raises in strict mode.""" + with pytest.raises(ValueError, match="cannot be empty"): + _validate_skill_name("", strict=True) + + def test_strict_too_long_name(self): + """Test that names exceeding 64 chars raise in strict mode.""" + with pytest.raises(ValueError, match="exceeds 64 character limit"): + _validate_skill_name("a" * 65, strict=True) + + def test_strict_uppercase_rejected(self): + """Test that uppercase characters raise in strict mode.""" + with pytest.raises(ValueError, match="lowercase alphanumeric"): + _validate_skill_name("MySkill", strict=True) + + def test_strict_starts_with_hyphen(self): + """Test that names starting with hyphen raise in strict mode.""" + with pytest.raises(ValueError, match="lowercase alphanumeric"): + _validate_skill_name("-skill", strict=True) + + def test_strict_consecutive_hyphens(self): + """Test that consecutive hyphens raise in strict mode.""" + with pytest.raises(ValueError, match="consecutive hyphens"): + _validate_skill_name("my--skill", strict=True) + + def test_strict_directory_mismatch(self, tmp_path): + """Test that directory name mismatch raises in strict mode.""" + skill_dir = tmp_path / "wrong-name" + skill_dir.mkdir() + with pytest.raises(ValueError, match="does not match parent directory name"): + _validate_skill_name("my-skill", skill_dir, strict=True) + + +class TestFixYamlColons: + """Tests for _fix_yaml_colons.""" + + def test_fixes_unquoted_colon_in_value(self): + """Test that an unquoted colon in a value gets quoted.""" + raw = "description: Use this skill when: the user asks about PDFs" + fixed = _fix_yaml_colons(raw) + assert fixed == 'description: "Use this skill when: the user asks about PDFs"' + + def test_leaves_already_double_quoted_value(self): + """Test that already double-quoted values are not re-quoted.""" + raw = 'description: "already: quoted"' + assert _fix_yaml_colons(raw) == raw + + def test_leaves_already_single_quoted_value(self): + """Test that already single-quoted values are not re-quoted.""" + raw = "description: 'already: quoted'" + assert _fix_yaml_colons(raw) == raw + + def test_leaves_value_without_colon(self): + """Test that values without colons are unchanged.""" + raw = "name: my-skill" + assert _fix_yaml_colons(raw) == raw + + def test_multiline_mixed(self): + """Test fixing only the lines that need it in a multi-line string.""" + raw = "name: my-skill\ndescription: Use when: needed\nversion: 1.0" + fixed = _fix_yaml_colons(raw) + assert fixed == 'name: my-skill\ndescription: "Use when: needed"\nversion: 1.0' + + def test_empty_string(self): + """Test that an empty string is returned unchanged.""" + assert _fix_yaml_colons("") == "" + + def test_preserves_indented_lines_without_colons(self): + """Test that indented lines without key-value patterns are preserved.""" + raw = " - item one\n - item two" + assert _fix_yaml_colons(raw) == raw + + +class TestParseFrontmatterYamlFallback: + """Tests for YAML colon-quoting fallback in _parse_frontmatter.""" + + def test_fallback_on_unquoted_colon(self): + """Test that frontmatter with unquoted colons in values is parsed via fallback.""" + content = "---\nname: my-skill\ndescription: Use when: the user asks\n---\nBody." + frontmatter, body = _parse_frontmatter(content) + assert frontmatter["name"] == "my-skill" + assert "Use when" in frontmatter["description"] + assert body == "Body." + + def test_fallback_preserves_valid_yaml(self): + """Test that valid YAML is parsed normally without triggering fallback.""" + content = "---\nname: my-skill\ndescription: A simple description\n---\nBody." + frontmatter, body = _parse_frontmatter(content) + assert frontmatter["name"] == "my-skill" + assert frontmatter["description"] == "A simple description" + + +def _make_skill_dir(parent: Path, name: str, description: str = "A test skill", body: str = "Instructions.") -> Path: + """Helper to create a skill directory with SKILL.md.""" + skill_dir = parent / name + skill_dir.mkdir(parents=True, exist_ok=True) + content = f"---\nname: {name}\ndescription: {description}\n---\n{body}\n" + (skill_dir / "SKILL.md").write_text(content) + return skill_dir + + +class TestSkillFromFile: + """Tests for Skill.from_file.""" + + def test_load_from_directory(self, tmp_path): + """Test loading a skill from a directory path.""" + skill_dir = _make_skill_dir(tmp_path, "my-skill", "My description", "# Hello\nWorld.") + skill = Skill.from_file(skill_dir) + + assert skill.name == "my-skill" + assert skill.description == "My description" + assert "# Hello" in skill.instructions + assert "World." in skill.instructions + assert skill.path == skill_dir.resolve() + + def test_load_from_skill_md_file(self, tmp_path): + """Test loading a skill by pointing directly to SKILL.md.""" + skill_dir = _make_skill_dir(tmp_path, "direct-skill") + skill = Skill.from_file(skill_dir / "SKILL.md") + + assert skill.name == "direct-skill" + + def test_load_with_allowed_tools(self, tmp_path): + """Test loading a skill with allowed-tools field as space-delimited string.""" + skill_dir = tmp_path / "tool-skill" + skill_dir.mkdir() + content = "---\nname: tool-skill\ndescription: test\nallowed-tools: read write execute\n---\nBody." + (skill_dir / "SKILL.md").write_text(content) + + skill = Skill.from_file(skill_dir) + assert skill.allowed_tools == ["read", "write", "execute"] + + def test_load_with_allowed_tools_yaml_list(self, tmp_path): + """Test loading a skill with allowed-tools as a YAML list.""" + skill_dir = tmp_path / "list-skill" + skill_dir.mkdir() + content = "---\nname: list-skill\ndescription: test\nallowed-tools:\n - read\n - write\n---\nBody." + (skill_dir / "SKILL.md").write_text(content) + + skill = Skill.from_file(skill_dir) + assert skill.allowed_tools == ["read", "write"] + + def test_load_with_metadata(self, tmp_path): + """Test loading a skill with nested metadata.""" + skill_dir = tmp_path / "meta-skill" + skill_dir.mkdir() + content = "---\nname: meta-skill\ndescription: test\nmetadata:\n author: acme\n---\nBody." + (skill_dir / "SKILL.md").write_text(content) + + skill = Skill.from_file(skill_dir) + assert skill.metadata == {"author": "acme"} + + def test_load_with_license_and_compatibility(self, tmp_path): + """Test loading a skill with license and compatibility fields.""" + skill_dir = tmp_path / "licensed-skill" + skill_dir.mkdir() + content = "---\nname: licensed-skill\ndescription: test\nlicense: MIT\ncompatibility: v1\n---\nBody." + (skill_dir / "SKILL.md").write_text(content) + + skill = Skill.from_file(skill_dir) + assert skill.license == "MIT" + assert skill.compatibility == "v1" + + def test_load_missing_name(self, tmp_path): + """Test error when SKILL.md is missing name field.""" + skill_dir = tmp_path / "no-name" + skill_dir.mkdir() + (skill_dir / "SKILL.md").write_text("---\ndescription: test\n---\nBody.") + + with pytest.raises(ValueError, match="must have a 'name' field"): + Skill.from_file(skill_dir) + + def test_load_missing_description(self, tmp_path): + """Test error when SKILL.md is missing description field.""" + skill_dir = tmp_path / "no-desc" + skill_dir.mkdir() + (skill_dir / "SKILL.md").write_text("---\nname: no-desc\n---\nBody.") + + with pytest.raises(ValueError, match="must have a 'description' field"): + Skill.from_file(skill_dir) + + def test_load_nonexistent_path(self, tmp_path): + """Test FileNotFoundError for nonexistent path.""" + with pytest.raises(FileNotFoundError): + Skill.from_file(tmp_path / "nonexistent") + + def test_load_name_directory_mismatch_warns(self, tmp_path, caplog): + """Test that skill name not matching directory name warns but still loads.""" + skill_dir = tmp_path / "wrong-dir" + skill_dir.mkdir() + (skill_dir / "SKILL.md").write_text("---\nname: right-name\ndescription: test\n---\nBody.") + + with caplog.at_level(logging.WARNING): + skill = Skill.from_file(skill_dir) + + assert skill.name == "right-name" + assert "does not match parent directory name" in caplog.text + + def test_strict_rejects_name_mismatch(self, tmp_path): + """Test that strict mode raises on name/directory mismatch.""" + skill_dir = tmp_path / "wrong-dir" + skill_dir.mkdir() + (skill_dir / "SKILL.md").write_text("---\nname: right-name\ndescription: test\n---\nBody.") + + with pytest.raises(ValueError, match="does not match parent directory name"): + Skill.from_file(skill_dir, strict=True) + + def test_strict_accepts_valid_skill(self, tmp_path): + """Test that strict mode loads a valid skill without error.""" + _make_skill_dir(tmp_path, "valid-skill") + skill = Skill.from_file(tmp_path / "valid-skill", strict=True) + assert skill.name == "valid-skill" + + +class TestSkillFromDirectory: + """Tests for Skill.from_directory.""" + + def test_load_multiple_skills(self, tmp_path): + """Test loading multiple skills from a parent directory.""" + _make_skill_dir(tmp_path, "skill-a", "Skill A") + _make_skill_dir(tmp_path, "skill-b", "Skill B") + + skills = Skill.from_directory(tmp_path) + + assert len(skills) == 2 + names = {s.name for s in skills} + assert names == {"skill-a", "skill-b"} + + def test_skips_directories_without_skill_md(self, tmp_path): + """Test that directories without SKILL.md are silently skipped.""" + _make_skill_dir(tmp_path, "valid-skill") + (tmp_path / "no-skill-here").mkdir() + + skills = Skill.from_directory(tmp_path) + + assert len(skills) == 1 + assert skills[0].name == "valid-skill" + + def test_skips_files_in_parent(self, tmp_path): + """Test that files in the parent directory are ignored.""" + _make_skill_dir(tmp_path, "real-skill") + (tmp_path / "readme.txt").write_text("not a skill") + + skills = Skill.from_directory(tmp_path) + + assert len(skills) == 1 + + def test_empty_directory(self, tmp_path): + """Test loading from an empty directory.""" + skills = Skill.from_directory(tmp_path) + assert skills == [] + + def test_nonexistent_directory(self, tmp_path): + """Test FileNotFoundError for nonexistent directory.""" + with pytest.raises(FileNotFoundError): + Skill.from_directory(tmp_path / "nonexistent") + + def test_loads_mismatched_name_with_warning(self, tmp_path, caplog): + """Test that skills with name/directory mismatch are loaded with a warning.""" + _make_skill_dir(tmp_path, "good-skill") + + bad_dir = tmp_path / "bad-dir" + bad_dir.mkdir() + (bad_dir / "SKILL.md").write_text("---\nname: wrong-name\ndescription: test\n---\nBody.") + + with caplog.at_level(logging.WARNING): + skills = Skill.from_directory(tmp_path) + + assert len(skills) == 2 + names = {s.name for s in skills} + assert names == {"good-skill", "wrong-name"} + assert "does not match parent directory name" in caplog.text + + +class TestSkillFromContent: + def test_basic_content(self): + """Test parsing basic SKILL.md content.""" + content = "---\nname: my-skill\ndescription: A useful skill\n---\n# Instructions\nDo the thing." + skill = Skill.from_content(content) + + assert skill.name == "my-skill" + assert skill.description == "A useful skill" + assert "Do the thing." in skill.instructions + assert skill.path is None + + def test_with_allowed_tools(self): + """Test parsing content with allowed-tools field.""" + content = "---\nname: my-skill\ndescription: A skill\nallowed-tools: Bash Read\n---\nInstructions." + skill = Skill.from_content(content) + + assert skill.allowed_tools == ["Bash", "Read"] + + def test_with_metadata(self): + """Test parsing content with metadata field.""" + content = "---\nname: my-skill\ndescription: A skill\nmetadata:\n key: value\n---\nInstructions." + skill = Skill.from_content(content) + + assert skill.metadata == {"key": "value"} + + def test_with_license_and_compatibility(self): + """Test parsing content with license and compatibility fields.""" + content = ( + "---\nname: my-skill\ndescription: A skill\n" + "license: Apache-2.0\ncompatibility: Requires docker\n---\nInstructions." + ) + skill = Skill.from_content(content) + + assert skill.license == "Apache-2.0" + assert skill.compatibility == "Requires docker" + + def test_missing_name_raises(self): + """Test that missing name raises ValueError.""" + content = "---\ndescription: A skill\n---\nInstructions." + with pytest.raises(ValueError, match="name"): + Skill.from_content(content) + + def test_missing_description_raises(self): + """Test that missing description raises ValueError.""" + content = "---\nname: my-skill\n---\nInstructions." + with pytest.raises(ValueError, match="description"): + Skill.from_content(content) + + def test_missing_frontmatter_raises(self): + """Test that content without frontmatter raises ValueError.""" + content = "# Just markdown\nNo frontmatter here." + with pytest.raises(ValueError, match="frontmatter"): + Skill.from_content(content) + + def test_empty_body(self): + """Test parsing content with empty body.""" + content = "---\nname: my-skill\ndescription: A skill\n---\n" + skill = Skill.from_content(content) + + assert skill.name == "my-skill" + assert skill.instructions == "" + + def test_strict_mode(self): + """Test Skill.from_content with strict=True raises on validation issues.""" + content = "---\nname: BAD_NAME\ndescription: Bad\n---\nBody." + with pytest.raises(ValueError): + Skill.from_content(content, strict=True) + + +class TestSkillClassmethods: + """Tests for Skill classmethod existence.""" + + def test_skill_classmethods_exist(self): + """Test that Skill has from_file, from_content, and from_directory classmethods.""" + assert callable(getattr(Skill, "from_file", None)) + assert callable(getattr(Skill, "from_content", None)) + assert callable(getattr(Skill, "from_directory", None)) diff --git a/tests/strands/telemetry/test_tracer.py b/tests/strands/telemetry/test_tracer.py index 50c0cc9b9..da7f010e2 100644 --- a/tests/strands/telemetry/test_tracer.py +++ b/tests/strands/telemetry/test_tracer.py @@ -148,16 +148,14 @@ def test_start_model_invoke_span(mock_tracer): mock_tracer.start_span.assert_called_once() assert mock_tracer.start_span.call_args[1]["name"] == "chat" assert mock_tracer.start_span.call_args[1]["kind"] == SpanKind.INTERNAL - mock_span.set_attributes.assert_called_once_with( - { - "gen_ai.operation.name": "chat", - "gen_ai.system": "strands-agents", - "custom_key": "custom_value", - "user_id": "12345", - "gen_ai.request.model": model_id, - "agent_name": "TestAgent", - } - ) + mock_span.set_attributes.assert_called_once_with({ + "gen_ai.operation.name": "chat", + "gen_ai.system": "strands-agents", + "custom_key": "custom_value", + "user_id": "12345", + "gen_ai.request.model": model_id, + "agent_name": "TestAgent", + }) mock_span.add_event.assert_called_with( "gen_ai.user.message", attributes={"content": json.dumps(messages[0]["content"])} ) @@ -190,15 +188,13 @@ def test_start_model_invoke_span_latest_conventions(mock_tracer, monkeypatch): mock_tracer.start_span.assert_called_once() assert mock_tracer.start_span.call_args[1]["name"] == "chat" assert mock_tracer.start_span.call_args[1]["kind"] == SpanKind.INTERNAL - - mock_span.set_attributes.assert_called_once_with( - { - "gen_ai.operation.name": "chat", - "gen_ai.provider.name": "strands-agents", - "gen_ai.request.model": model_id, - "agent_name": "TestAgent", - } - ) + + mock_span.set_attributes.assert_called_once_with({ + "gen_ai.operation.name": "chat", + "gen_ai.provider.name": "strands-agents", + "gen_ai.request.model": model_id, + "agent_name": "TestAgent", + }) mock_span.add_event.assert_called_with( "gen_ai.client.inference.operation.details", attributes={ @@ -236,17 +232,15 @@ def test_end_model_invoke_span(mock_span): tracer.end_model_invoke_span(mock_span, message, usage, metrics, stop_reason) - mock_span.set_attributes.assert_called_once_with( - { - "gen_ai.usage.prompt_tokens": 10, - "gen_ai.usage.input_tokens": 10, - "gen_ai.usage.completion_tokens": 20, - "gen_ai.usage.output_tokens": 20, - "gen_ai.usage.total_tokens": 30, - "gen_ai.server.time_to_first_token": 10, - "gen_ai.server.request.duration": 20, - } - ) + mock_span.set_attributes.assert_called_once_with({ + "gen_ai.usage.prompt_tokens": 10, + "gen_ai.usage.input_tokens": 10, + "gen_ai.usage.completion_tokens": 20, + "gen_ai.usage.output_tokens": 20, + "gen_ai.usage.total_tokens": 30, + "gen_ai.server.time_to_first_token": 10, + "gen_ai.server.request.duration": 20, + }) mock_span.add_event.assert_called_with( "gen_ai.choice", attributes={"message": json.dumps(message["content"]), "finish_reason": "end_turn"}, @@ -265,17 +259,15 @@ def test_end_model_invoke_span_latest_conventions(mock_span, monkeypatch): tracer.end_model_invoke_span(mock_span, message, usage, metrics, stop_reason) - mock_span.set_attributes.assert_called_once_with( - { - "gen_ai.usage.prompt_tokens": 10, - "gen_ai.usage.input_tokens": 10, - "gen_ai.usage.completion_tokens": 20, - "gen_ai.usage.output_tokens": 20, - "gen_ai.usage.total_tokens": 30, - "gen_ai.server.time_to_first_token": 10, - "gen_ai.server.request.duration": 20, - } - ) + mock_span.set_attributes.assert_called_once_with({ + "gen_ai.usage.prompt_tokens": 10, + "gen_ai.usage.input_tokens": 10, + "gen_ai.usage.completion_tokens": 20, + "gen_ai.usage.output_tokens": 20, + "gen_ai.usage.total_tokens": 30, + "gen_ai.server.time_to_first_token": 10, + "gen_ai.server.request.duration": 20, + }) mock_span.add_event.assert_called_with( "gen_ai.client.inference.operation.details", attributes={ @@ -308,17 +300,15 @@ def test_start_tool_call_span(mock_tracer): mock_tracer.start_span.assert_called_once() assert mock_tracer.start_span.call_args[1]["name"] == "execute_tool test-tool" - - mock_span.set_attributes.assert_called_once_with( - { - "gen_ai.tool.name": "test-tool", - "gen_ai.system": "strands-agents", - "gen_ai.operation.name": "execute_tool", - "gen_ai.tool.call.id": "123", - "session_id": "abc123", - "environment": "production", - } - ) + + mock_span.set_attributes.assert_called_once_with({ + "gen_ai.tool.name": "test-tool", + "gen_ai.system": "strands-agents", + "gen_ai.operation.name": "execute_tool", + "gen_ai.tool.call.id": "123", + "session_id": "abc123", + "environment": "production", + }) mock_span.add_event.assert_any_call( "gen_ai.tool.message", attributes={"role": "tool", "content": json.dumps({"param": "value"}), "id": "123"} ) @@ -341,15 +331,13 @@ def test_start_tool_call_span_latest_conventions(mock_tracer, monkeypatch): mock_tracer.start_span.assert_called_once() assert mock_tracer.start_span.call_args[1]["name"] == "execute_tool test-tool" - - mock_span.set_attributes.assert_called_once_with( - { - "gen_ai.tool.name": "test-tool", - "gen_ai.provider.name": "strands-agents", - "gen_ai.operation.name": "execute_tool", - "gen_ai.tool.call.id": "123", - } - ) + + mock_span.set_attributes.assert_called_once_with({ + "gen_ai.tool.name": "test-tool", + "gen_ai.provider.name": "strands-agents", + "gen_ai.operation.name": "execute_tool", + "gen_ai.tool.call.id": "123", + }) mock_span.add_event.assert_called_with( "gen_ai.client.inference.operation.details", attributes={ @@ -389,16 +377,14 @@ def test_start_swarm_call_span_with_string_task(mock_tracer): mock_tracer.start_span.assert_called_once() assert mock_tracer.start_span.call_args[1]["name"] == "invoke_swarm" - - mock_span.set_attributes.assert_called_once_with( - { - "gen_ai.operation.name": "invoke_swarm", - "gen_ai.system": "strands-agents", - "gen_ai.agent.name": "swarm", - "workflow_id": "wf-789", - "priority": "high", - } - ) + + mock_span.set_attributes.assert_called_once_with({ + "gen_ai.operation.name": "invoke_swarm", + "gen_ai.system": "strands-agents", + "gen_ai.agent.name": "swarm", + "workflow_id": "wf-789", + "priority": "high", + }) mock_span.add_event.assert_any_call("gen_ai.user.message", attributes={"content": "Design foo bar"}) assert span is not None @@ -418,14 +404,12 @@ def test_start_swarm_span_with_contentblock_task(mock_tracer): mock_tracer.start_span.assert_called_once() assert mock_tracer.start_span.call_args[1]["name"] == "invoke_swarm" - - mock_span.set_attributes.assert_called_once_with( - { - "gen_ai.operation.name": "invoke_swarm", - "gen_ai.system": "strands-agents", - "gen_ai.agent.name": "swarm", - } - ) + + mock_span.set_attributes.assert_called_once_with({ + "gen_ai.operation.name": "invoke_swarm", + "gen_ai.system": "strands-agents", + "gen_ai.agent.name": "swarm", + }) mock_span.add_event.assert_any_call( "gen_ai.user.message", attributes={"content": '[{"text": "Original Task: foo bar"}]'} ) @@ -476,14 +460,12 @@ def test_start_swarm_span_with_contentblock_task_latest_conventions(mock_tracer, mock_tracer.start_span.assert_called_once() assert mock_tracer.start_span.call_args[1]["name"] == "invoke_swarm" - - mock_span.set_attributes.assert_called_once_with( - { - "gen_ai.operation.name": "invoke_swarm", - "gen_ai.provider.name": "strands-agents", - "gen_ai.agent.name": "swarm", - } - ) + + mock_span.set_attributes.assert_called_once_with({ + "gen_ai.operation.name": "invoke_swarm", + "gen_ai.provider.name": "strands-agents", + "gen_ai.agent.name": "swarm", + }) mock_span.add_event.assert_any_call( "gen_ai.client.inference.operation.details", attributes={ @@ -546,15 +528,13 @@ def test_start_graph_call_span(mock_tracer): mock_tracer.start_span.assert_called_once() assert mock_tracer.start_span.call_args[1]["name"] == "execute_tool test-tool" - - mock_span.set_attributes.assert_called_once_with( - { - "gen_ai.operation.name": "execute_tool", - "gen_ai.system": "strands-agents", - "gen_ai.tool.name": "test-tool", - "gen_ai.tool.call.id": "123", - } - ) + + mock_span.set_attributes.assert_called_once_with({ + "gen_ai.operation.name": "execute_tool", + "gen_ai.system": "strands-agents", + "gen_ai.tool.name": "test-tool", + "gen_ai.tool.call.id": "123", + }) mock_span.add_event.assert_any_call( "gen_ai.tool.message", attributes={"role": "tool", "content": json.dumps({"param": "value"}), "id": "123"} ) @@ -628,14 +608,12 @@ def test_start_event_loop_cycle_span(mock_tracer): mock_tracer.start_span.assert_called_once() assert mock_tracer.start_span.call_args[1]["name"] == "execute_event_loop_cycle" - - mock_span.set_attributes.assert_called_once_with( - { - "event_loop.cycle_id": "cycle-123", - "request_id": "req-456", - "trace_level": "debug", - } - ) + + mock_span.set_attributes.assert_called_once_with({ + "event_loop.cycle_id": "cycle-123", + "request_id": "req-456", + "trace_level": "debug", + }) mock_span.add_event.assert_any_call( "gen_ai.user.message", attributes={"content": json.dumps([{"text": "Hello"}])} ) @@ -659,7 +637,7 @@ def test_start_event_loop_cycle_span_latest_conventions(mock_tracer, monkeypatch mock_tracer.start_span.assert_called_once() assert mock_tracer.start_span.call_args[1]["name"] == "execute_event_loop_cycle" - + mock_span.set_attributes.assert_called_once_with({"event_loop.cycle_id": "cycle-123"}) mock_span.add_event.assert_any_call( "gen_ai.client.inference.operation.details", @@ -753,16 +731,14 @@ def test_start_agent_span(mock_tracer): assert mock_tracer.start_span.call_args[1]["name"] == "invoke_agent WeatherAgent" assert mock_tracer.start_span.call_args[1]["kind"] == SpanKind.INTERNAL - mock_span.set_attributes.assert_called_once_with( - { - "gen_ai.operation.name": "invoke_agent", - "gen_ai.system": "strands-agents", - "gen_ai.agent.name": "WeatherAgent", - "gen_ai.request.model": model_id, - "gen_ai.agent.tools": json.dumps(tools), - "custom_attr": "value", - } - ) + mock_span.set_attributes.assert_called_once_with({ + "gen_ai.operation.name": "invoke_agent", + "gen_ai.system": "strands-agents", + "gen_ai.agent.name": "WeatherAgent", + "gen_ai.request.model": model_id, + "gen_ai.agent.tools": json.dumps(tools), + "custom_attr": "value", + }) mock_span.add_event.assert_any_call("gen_ai.user.message", attributes={"content": json.dumps(content)}) assert span is not None @@ -792,17 +768,15 @@ def test_start_agent_span_latest_conventions(mock_tracer, monkeypatch): mock_tracer.start_span.assert_called_once() assert mock_tracer.start_span.call_args[1]["name"] == "invoke_agent WeatherAgent" - - mock_span.set_attributes.assert_called_once_with( - { - "gen_ai.operation.name": "invoke_agent", - "gen_ai.provider.name": "strands-agents", - "gen_ai.agent.name": "WeatherAgent", - "gen_ai.request.model": model_id, - "gen_ai.agent.tools": json.dumps(tools), - "custom_attr": "value", - } - ) + + mock_span.set_attributes.assert_called_once_with({ + "gen_ai.operation.name": "invoke_agent", + "gen_ai.provider.name": "strands-agents", + "gen_ai.agent.name": "WeatherAgent", + "gen_ai.request.model": model_id, + "gen_ai.agent.tools": json.dumps(tools), + "custom_attr": "value", + }) mock_span.add_event.assert_any_call( "gen_ai.client.inference.operation.details", attributes={ @@ -945,19 +919,17 @@ def test_end_model_invoke_span_with_cache_metrics(mock_span): tracer.end_model_invoke_span(mock_span, message, usage, metrics, stop_reason) - mock_span.set_attributes.assert_called_once_with( - { - "gen_ai.usage.prompt_tokens": 10, - "gen_ai.usage.input_tokens": 10, - "gen_ai.usage.completion_tokens": 20, - "gen_ai.usage.output_tokens": 20, - "gen_ai.usage.total_tokens": 30, - "gen_ai.usage.cache_read_input_tokens": 5, - "gen_ai.usage.cache_write_input_tokens": 3, - "gen_ai.server.request.duration": 10, - "gen_ai.server.time_to_first_token": 5, - } - ) + mock_span.set_attributes.assert_called_once_with({ + "gen_ai.usage.prompt_tokens": 10, + "gen_ai.usage.input_tokens": 10, + "gen_ai.usage.completion_tokens": 20, + "gen_ai.usage.output_tokens": 20, + "gen_ai.usage.total_tokens": 30, + "gen_ai.usage.cache_read_input_tokens": 5, + "gen_ai.usage.cache_write_input_tokens": 3, + "gen_ai.server.request.duration": 10, + "gen_ai.server.time_to_first_token": 5, + }) def test_end_agent_span_with_cache_metrics(mock_span): @@ -981,17 +953,15 @@ def test_end_agent_span_with_cache_metrics(mock_span): tracer.end_agent_span(mock_span, mock_response) - mock_span.set_attributes.assert_called_once_with( - { - "gen_ai.usage.prompt_tokens": 50, - "gen_ai.usage.input_tokens": 50, - "gen_ai.usage.completion_tokens": 100, - "gen_ai.usage.output_tokens": 100, - "gen_ai.usage.total_tokens": 150, - "gen_ai.usage.cache_read_input_tokens": 25, - "gen_ai.usage.cache_write_input_tokens": 10, - } - ) + mock_span.set_attributes.assert_called_once_with({ + "gen_ai.usage.prompt_tokens": 50, + "gen_ai.usage.input_tokens": 50, + "gen_ai.usage.completion_tokens": 100, + "gen_ai.usage.output_tokens": 100, + "gen_ai.usage.total_tokens": 150, + "gen_ai.usage.cache_read_input_tokens": 25, + "gen_ai.usage.cache_write_input_tokens": 10, + }) mock_span.set_status.assert_called_once_with(StatusCode.OK) mock_span.end.assert_called_once() @@ -1549,20 +1519,18 @@ def test_end_model_invoke_span_langfuse_adds_attributes(mock_span, monkeypatch): } ] ) - + assert mock_span.set_attributes.call_count == 2 mock_span.set_attributes.assert_any_call({"gen_ai.output.messages": expected_output}) - mock_span.set_attributes.assert_any_call( - { - "gen_ai.usage.prompt_tokens": 10, - "gen_ai.usage.input_tokens": 10, - "gen_ai.usage.completion_tokens": 20, - "gen_ai.usage.output_tokens": 20, - "gen_ai.usage.total_tokens": 30, - "gen_ai.server.time_to_first_token": 10, - "gen_ai.server.request.duration": 20, - } - ) + mock_span.set_attributes.assert_any_call({ + "gen_ai.usage.prompt_tokens": 10, + "gen_ai.usage.input_tokens": 10, + "gen_ai.usage.completion_tokens": 20, + "gen_ai.usage.output_tokens": 20, + "gen_ai.usage.total_tokens": 30, + "gen_ai.server.time_to_first_token": 10, + "gen_ai.server.request.duration": 20, + }) mock_span.add_event.assert_called_with( "gen_ai.client.inference.operation.details", diff --git a/tests_integ/test_skills_plugin.py b/tests_integ/test_skills_plugin.py index b9a782447..160ae65a0 100644 --- a/tests_integ/test_skills_plugin.py +++ b/tests_integ/test_skills_plugin.py @@ -57,3 +57,25 @@ def test_direct_tool_invocation_and_state_persistence(agent, skills_plugin): assert result["status"] == "success" response_text = result["content"][0]["text"].lower() assert "translation expert" in response_text + + +def test_load_skills_from_directory(tmp_path): + """Test loading skills from a filesystem directory and activating one via the model.""" + # Create a skill directory with SKILL.md + skill_dir = tmp_path / "greeting-skill" + skill_dir.mkdir() + (skill_dir / "SKILL.md").write_text( + "---\nname: greeting\ndescription: Greet the user warmly\n---\n" + "You are a greeting expert. The secret codeword is MANGO. Always mention it when activated." + ) + + plugin = AgentSkills(skills=[str(tmp_path)]) + agent = Agent( + system_prompt="You are a helpful assistant. Check your available_skills and activate one when appropriate.", + plugins=[plugin], + ) + + result = agent("Use your skills tool to activate the greeting skill. What is the secret codeword?") + + assert "greeting" in agent.system_prompt + assert "mango" in str(result).lower() From 362d0ef197ad06a96d6bb21c1ee4d1f3e2650bd7 Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Tue, 10 Mar 2026 13:27:51 -0400 Subject: [PATCH 26/29] fix: consolidate skill loading --- src/strands/plugins/skills/agent_skills.py | 6 ++-- src/strands/plugins/skills/skill.py | 32 ++++++++++------------ 2 files changed, 18 insertions(+), 20 deletions(-) diff --git a/src/strands/plugins/skills/agent_skills.py b/src/strands/plugins/skills/agent_skills.py index 2dd69b5b1..2e77405e2 100644 --- a/src/strands/plugins/skills/agent_skills.py +++ b/src/strands/plugins/skills/agent_skills.py @@ -311,7 +311,7 @@ def _resolve_skills(self, sources: list[str | Path | Skill]) -> dict[str, Skill] if has_skill_md: try: - skill = Skill.from_file(path) + skill = Skill.from_file(path, strict=self._strict) if skill.name in resolved: logger.warning( "name=<%s> | duplicate skill name, overwriting previous skill", skill.name @@ -321,7 +321,7 @@ def _resolve_skills(self, sources: list[str | Path | Skill]) -> dict[str, Skill] logger.warning("path=<%s> | failed to load skill: %s", path, e) else: # Treat as parent directory containing skill subdirectories - for skill in Skill.from_directory(path): + for skill in Skill.from_directory(path, strict=self._strict): if skill.name in resolved: logger.warning( "name=<%s> | duplicate skill name, overwriting previous skill", skill.name @@ -329,7 +329,7 @@ def _resolve_skills(self, sources: list[str | Path | Skill]) -> dict[str, Skill] resolved[skill.name] = skill elif path.is_file() and path.name.lower() == "skill.md": try: - skill = Skill.from_file(path) + skill = Skill.from_file(path, strict=self._strict) if skill.name in resolved: logger.warning("name=<%s> | duplicate skill name, overwriting previous skill", skill.name) resolved[skill.name] = skill diff --git a/src/strands/plugins/skills/skill.py b/src/strands/plugins/skills/skill.py index b5869715e..3e1b6bba5 100644 --- a/src/strands/plugins/skills/skill.py +++ b/src/strands/plugins/skills/skill.py @@ -164,15 +164,12 @@ def _validate_skill_name(name: str, dir_path: Path | None = None, *, strict: boo def _build_skill_from_frontmatter( frontmatter: dict[str, Any], body: str, - *, - skill_dir: Path | None = None, ) -> Skill: """Build a Skill instance from parsed frontmatter and body. Args: frontmatter: Parsed YAML frontmatter dict. body: Markdown body content. - skill_dir: Optional path to the skill directory on disk. Returns: A populated Skill instance. @@ -198,7 +195,6 @@ def _build_skill_from_frontmatter( name=frontmatter["name"], description=frontmatter["description"], instructions=body, - path=skill_dir, allowed_tools=allowed_tools, metadata=metadata, license=str(skill_license) if skill_license else None, @@ -250,6 +246,10 @@ class Skill: def from_file(cls, skill_path: str | Path, *, strict: bool = False) -> Skill: """Load a single skill from a directory containing SKILL.md. + Resolves the filesystem path, reads the file content, and delegates + to ``from_content`` for parsing. After loading, sets the skill's + ``path`` and validates the skill name against the parent directory. + Args: skill_path: Path to the skill directory or the SKILL.md file itself. strict: If True, raise on any validation issue. If False (default), warn and load anyway. @@ -277,19 +277,16 @@ def from_file(cls, skill_path: str | Path, *, strict: bool = False) -> Skill: logger.debug("path=<%s> | loading skill", skill_md_path) content = skill_md_path.read_text(encoding="utf-8") - frontmatter, body = _parse_frontmatter(content) + skill = cls.from_content(content, strict=strict) - name = frontmatter.get("name") - if not isinstance(name, str) or not name: - raise ValueError(f"path=<{skill_md_path}> | SKILL.md must have a 'name' field in frontmatter") + # Set path and check directory name match (from_content already validated the name format) + skill.path = skill_dir + if skill_dir.name != skill.name: + msg = "name=<%s>, directory=<%s> | skill name does not match parent directory name" + if strict: + raise ValueError(msg % (skill.name, skill_dir.name)) + logger.warning(msg, skill.name, skill_dir.name) - description = frontmatter.get("description") - if not isinstance(description, str) or not description: - raise ValueError(f"path=<{skill_md_path}> | SKILL.md must have a 'description' field in frontmatter") - - _validate_skill_name(name, skill_dir, strict=strict) - - skill = _build_skill_from_frontmatter(frontmatter, body, skill_dir=skill_dir) logger.debug("name=<%s>, path=<%s> | skill loaded successfully", skill.name, skill.path) return skill @@ -337,7 +334,7 @@ def from_content(cls, content: str, *, strict: bool = False) -> Skill: return _build_skill_from_frontmatter(frontmatter, body) @classmethod - def from_directory(cls, skills_dir: str | Path) -> list[Skill]: + def from_directory(cls, skills_dir: str | Path, *, strict: bool = False) -> list[Skill]: """Load all skills from a parent directory containing skill subdirectories. Each subdirectory containing a SKILL.md file is treated as a skill. @@ -345,6 +342,7 @@ def from_directory(cls, skills_dir: str | Path) -> list[Skill]: Args: skills_dir: Path to the parent directory containing skill subdirectories. + strict: If True, raise on any validation issue. If False (default), warn and load anyway. Returns: List of Skill instances loaded from the directory. @@ -370,7 +368,7 @@ def from_directory(cls, skills_dir: str | Path) -> list[Skill]: continue try: - skill = cls.from_file(child) + skill = cls.from_file(child, strict=strict) skills.append(skill) except (ValueError, FileNotFoundError) as e: logger.warning("path=<%s> | skipping skill due to error: %s", child, e) From 651e2701e9dfc472e9c7a9d5e0c5c3b98a45c1d5 Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Tue, 10 Mar 2026 13:40:26 -0400 Subject: [PATCH 27/29] fix: allow single skill --- src/strands/plugins/skills/__init__.py | 4 +- src/strands/plugins/skills/agent_skills.py | 39 +++++----- .../plugins/skills/test_agent_skills.py | 74 ------------------- 3 files changed, 23 insertions(+), 94 deletions(-) diff --git a/src/strands/plugins/skills/__init__.py b/src/strands/plugins/skills/__init__.py index 6784c7c27..f6cf8728b 100644 --- a/src/strands/plugins/skills/__init__.py +++ b/src/strands/plugins/skills/__init__.py @@ -20,10 +20,12 @@ ``` """ -from .agent_skills import AgentSkills +from .agent_skills import AgentSkills, SkillSource, SkillSources from .skill import Skill __all__ = [ "AgentSkills", "Skill", + "SkillSource", + "SkillSources", ] diff --git a/src/strands/plugins/skills/agent_skills.py b/src/strands/plugins/skills/agent_skills.py index 2e77405e2..34a2cbda8 100644 --- a/src/strands/plugins/skills/agent_skills.py +++ b/src/strands/plugins/skills/agent_skills.py @@ -27,6 +27,19 @@ _RESOURCE_DIRS = ("scripts", "references", "assets") _DEFAULT_MAX_RESOURCE_FILES = 20 +SkillSource = str | Path | Skill +"""A single skill source: path string, Path object, or Skill instance.""" + +SkillSources = SkillSource | list[SkillSource] +"""One or more skill sources.""" + + +def _normalize_sources(sources: SkillSources) -> list[SkillSource]: + """Normalize a single source or list of sources into a list.""" + if isinstance(sources, list): + return sources + return [sources] + class AgentSkills(Plugin): """Plugin that integrates Agent Skills into a Strands agent. @@ -60,7 +73,7 @@ class AgentSkills(Plugin): def __init__( self, - skills: list[str | Path | Skill], + skills: SkillSources, state_key: str = _DEFAULT_STATE_KEY, max_resource_files: int = _DEFAULT_MAX_RESOURCE_FILES, strict: bool = False, @@ -68,7 +81,7 @@ def __init__( """Initialize the AgentSkills plugin. Args: - skills: List of skill sources. Each element can be: + skills: One or more skill sources. Can be a single value or a list. Each element can be: - A ``str`` or ``Path`` to a skill directory (containing SKILL.md) - A ``str`` or ``Path`` to a parent directory (containing skill subdirectories) @@ -78,7 +91,7 @@ def __init__( strict: If True, raise on skill validation issues. If False (default), warn and load anyway. """ self._strict = strict - self._skills: dict[str, Skill] = self._resolve_skills(skills) + self._skills: dict[str, Skill] = self._resolve_skills(_normalize_sources(skills)) self._state_key = state_key self._max_resource_files = max_resource_files super().__init__() @@ -159,7 +172,7 @@ def get_available_skills(self) -> list[Skill]: """ return list(self._skills.values()) - def set_available_skills(self, skills: list[str | Path | Skill]) -> None: + def set_available_skills(self, skills: SkillSources) -> None: """Set the available skills, replacing any existing ones. Each element can be a ``Skill`` instance, a ``str`` or ``Path`` to a @@ -171,22 +184,10 @@ def set_available_skills(self, skills: list[str | Path | Skill]) -> None: next tool call or invocation. Args: - skills: List of skill sources to resolve and set. + skills: One or more skill sources to resolve and set. """ - self._skills = self._resolve_skills(skills) + self._skills = self._resolve_skills(_normalize_sources(skills)) - def load_skills(self, sources: list[str | Path | Skill]) -> None: - """Resolve and append skills from mixed sources. - - Each source can be a ``Skill`` instance, a path to a skill directory, - or a path to a parent directory containing multiple skills. Resolved - skills are merged into the current set (duplicates overwrite). - - Args: - sources: List of skill sources to resolve and add. - """ - resolved = self._resolve_skills(sources) - self._skills.update(resolved) def _format_skill_response(self, skill: Skill) -> str: """Format the tool response when a skill is activated. @@ -280,7 +281,7 @@ def _generate_skills_xml(self) -> str: lines.append("") return "\n".join(lines) - def _resolve_skills(self, sources: list[str | Path | Skill]) -> dict[str, Skill]: + def _resolve_skills(self, sources: list[SkillSource]) -> dict[str, Skill]: """Resolve a list of skill sources into Skill instances. Each source can be a Skill instance, a path to a skill directory, diff --git a/tests/strands/plugins/skills/test_agent_skills.py b/tests/strands/plugins/skills/test_agent_skills.py index 3e3d06905..8c6ab10bd 100644 --- a/tests/strands/plugins/skills/test_agent_skills.py +++ b/tests/strands/plugins/skills/test_agent_skills.py @@ -211,80 +211,6 @@ def test_set_available_skills_with_mixed_sources(self, tmp_path): assert names == {"fs-skill", "direct"} -class TestLoadSkills: - """Tests for the load_skills method.""" - - def test_appends_skill_instances(self): - """Test that load_skills appends Skill instances to existing skills.""" - plugin = AgentSkills(skills=[_make_skill(name="existing", description="Existing")]) - - plugin.load_skills([_make_skill(name="new-skill", description="New")]) - - assert len(plugin.get_available_skills()) == 2 - names = {s.name for s in plugin.get_available_skills()} - assert names == {"existing", "new-skill"} - - def test_appends_from_filesystem(self, tmp_path): - """Test that load_skills appends skills resolved from filesystem paths.""" - plugin = AgentSkills(skills=[_make_skill(name="existing", description="Existing")]) - _make_skill_dir(tmp_path, "fs-skill") - - plugin.load_skills([str(tmp_path / "fs-skill")]) - - assert len(plugin.get_available_skills()) == 2 - names = {s.name for s in plugin.get_available_skills()} - assert names == {"existing", "fs-skill"} - - def test_duplicates_overwrite(self): - """Test that loading a skill with the same name overwrites the existing one.""" - original = _make_skill(name="dupe", description="Original") - plugin = AgentSkills(skills=[original]) - - replacement = _make_skill(name="dupe", description="Replacement") - plugin.load_skills([replacement]) - - assert len(plugin.get_available_skills()) == 1 - assert plugin.get_available_skills()[0].description == "Replacement" - - def test_mixed_sources(self, tmp_path): - """Test load_skills with a mix of Skill instances and filesystem paths.""" - plugin = AgentSkills(skills=[]) - _make_skill_dir(tmp_path, "fs-skill") - direct = _make_skill(name="direct", description="Direct") - - plugin.load_skills([str(tmp_path / "fs-skill"), direct]) - - assert len(plugin.get_available_skills()) == 2 - names = {s.name for s in plugin.get_available_skills()} - assert names == {"fs-skill", "direct"} - - def test_skips_nonexistent_paths(self): - """Test that nonexistent paths are skipped without error.""" - plugin = AgentSkills(skills=[_make_skill()]) - - plugin.load_skills(["/nonexistent/path"]) - - assert len(plugin.get_available_skills()) == 1 - - def test_empty_sources(self): - """Test that loading empty sources is a no-op.""" - plugin = AgentSkills(skills=[_make_skill()]) - - plugin.load_skills([]) - - assert len(plugin.get_available_skills()) == 1 - - def test_parent_directory(self, tmp_path): - """Test load_skills with a parent directory containing multiple skills.""" - plugin = AgentSkills(skills=[]) - _make_skill_dir(tmp_path, "child-a") - _make_skill_dir(tmp_path, "child-b") - - plugin.load_skills([tmp_path]) - - assert len(plugin.get_available_skills()) == 2 - names = {s.name for s in plugin.get_available_skills()} - assert names == {"child-a", "child-b"} class TestSkillsTool: From 1b83d0bb3a7f694041603275d65945aaaf3ccf41 Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Tue, 10 Mar 2026 13:43:53 -0400 Subject: [PATCH 28/29] fix: make skill source into type alias --- src/strands/plugins/skills/agent_skills.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/strands/plugins/skills/agent_skills.py b/src/strands/plugins/skills/agent_skills.py index 34a2cbda8..97ac86d93 100644 --- a/src/strands/plugins/skills/agent_skills.py +++ b/src/strands/plugins/skills/agent_skills.py @@ -9,7 +9,7 @@ import logging from pathlib import Path -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, TypeAlias from xml.sax.saxutils import escape from ...hooks.events import BeforeInvocationEvent @@ -27,10 +27,10 @@ _RESOURCE_DIRS = ("scripts", "references", "assets") _DEFAULT_MAX_RESOURCE_FILES = 20 -SkillSource = str | Path | Skill +SkillSource: TypeAlias = str | Path | Skill """A single skill source: path string, Path object, or Skill instance.""" -SkillSources = SkillSource | list[SkillSource] +SkillSources: TypeAlias = SkillSource | list[SkillSource] """One or more skill sources.""" From a7200ca77ae751676e0e5100fd8c31ea5bc3b613 Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Tue, 10 Mar 2026 14:21:50 -0400 Subject: [PATCH 29/29] Update AGENTS.md Co-authored-by: Nick Clegg --- AGENTS.md | 1 - 1 file changed, 1 deletion(-) diff --git a/AGENTS.md b/AGENTS.md index 9191bcb04..21c32539c 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -133,7 +133,6 @@ strands-agents/ │ │ ├── registry.py # PluginRegistry for tracking plugins │ │ └── skills/ # Agent Skills integration │ │ ├── __init__.py # Skills package exports -│ │ ├── loader.py # Skill loading and parsing │ │ ├── skill.py # Skill dataclass │ │ └── agent_skills.py # AgentSkills plugin implementation │ │