diff --git a/src/strands/tools/watcher.py b/src/strands/tools/watcher.py index c7f50fccd..3a73b0237 100644 --- a/src/strands/tools/watcher.py +++ b/src/strands/tools/watcher.py @@ -49,8 +49,8 @@ def __init__(self, tool_registry: ToolRegistry) -> None: """ self.tool_registry = tool_registry - def on_modified(self, event: Any) -> None: - """Reload tool if file modification detected. + def _handle_tool_change(self, event: Any) -> None: + """Handle a tool file change event (created or modified). Args: event: The file system event that triggered this handler. @@ -66,6 +66,24 @@ def on_modified(self, event: Any) -> None: except Exception as e: logger.error("tool_name=<%s>, exception=<%s> | failed to reload tool", tool_name, str(e)) + def on_modified(self, event: Any) -> None: + """Reload tool if file modification detected. + + Args: + event: The file system event that triggered this handler. + """ + self._handle_tool_change(event) + + def on_created(self, event: Any) -> None: + """Reload tool if new file created. + + This enables hot-reload for the first tool added to an empty ./tools directory. + + Args: + event: The file system event that triggered this handler. + """ + self._handle_tool_change(event) + class MasterChangeHandler(FileSystemEventHandler): """Master handler that delegates to all registered handlers.""" @@ -77,11 +95,12 @@ def __init__(self, dir_path: str) -> None: """ self.dir_path = dir_path - def on_modified(self, event: Any) -> None: - """Delegate file modification events to all registered handlers. + def _delegate_event(self, event: Any, handler_method: str) -> None: + """Delegate file events to all registered handlers. Args: event: The file system event that triggered this handler. + handler_method: The method name to call on handlers ('on_modified' or 'on_created'). """ if event.src_path.endswith(".py"): tool_path = Path(event.src_path) @@ -91,10 +110,28 @@ def on_modified(self, event: Any) -> None: # Delegate to all registered handlers for this directory for handler in ToolWatcher._registry_handlers.get(self.dir_path, {}).values(): try: - handler.on_modified(event) + getattr(handler, handler_method)(event) except Exception as e: logger.error("exception=<%s> | handler error", str(e)) + def on_modified(self, event: Any) -> None: + """Delegate file modification events to all registered handlers. + + Args: + event: The file system event that triggered this handler. + """ + self._delegate_event(event, "on_modified") + + def on_created(self, event: Any) -> None: + """Delegate file creation events to all registered handlers. + + This enables hot-reload for the first tool added to an empty ./tools directory. + + Args: + event: The file system event that triggered this handler. + """ + self._delegate_event(event, "on_created") + def start(self) -> None: """Start watching all tools directories for changes.""" # Initialize shared observer if not already done diff --git a/tests/strands/tools/test_decorator_pep563.py b/tests/strands/tools/test_decorator_pep563.py index 07ec8f2ba..44d9a626a 100644 --- a/tests/strands/tools/test_decorator_pep563.py +++ b/tests/strands/tools/test_decorator_pep563.py @@ -10,10 +10,10 @@ from __future__ import annotations -from typing import Any +from typing import Any, Literal import pytest -from typing_extensions import Literal, TypedDict +from typing_extensions import TypedDict from strands import tool diff --git a/tests/strands/tools/test_watcher.py b/tests/strands/tools/test_watcher.py index 75a5616fe..b0ac09838 100644 --- a/tests/strands/tools/test_watcher.py +++ b/tests/strands/tools/test_watcher.py @@ -96,3 +96,104 @@ def test_on_modified_error_handling(mock_reload_tool): # Verify that reload_tool was called mock_reload_tool.assert_called_once_with("test_tool") + + +@pytest.mark.parametrize( + "test_case", + [ + # Regular Python file - should reload + { + "description": "Python file created", + "src_path": "/path/to/new_tool.py", + "is_directory": False, + "should_reload": True, + "expected_tool_name": "new_tool", + }, + # Non-Python file - should not reload + { + "description": "Non-Python file created", + "src_path": "/path/to/new_tool.txt", + "is_directory": False, + "should_reload": False, + }, + # __init__.py file - should not reload + { + "description": "Init file created", + "src_path": "/path/to/__init__.py", + "is_directory": False, + "should_reload": False, + }, + ], +) +@patch.object(ToolRegistry, "reload_tool") +def test_on_created_cases(mock_reload_tool, test_case): + """Test that on_created handles new tool file creation. + + This is critical for hot-reloading the first tool added to an empty ./tools directory. + """ + tool_registry = ToolRegistry() + watcher = ToolWatcher(tool_registry) + + # Create a mock event with the specified properties + event = MagicMock() + event.src_path = test_case["src_path"] + if "is_directory" in test_case: + event.is_directory = test_case["is_directory"] + + # Call the on_created method + watcher.tool_change_handler.on_created(event) + + # Verify the expected behavior + if test_case["should_reload"]: + mock_reload_tool.assert_called_once_with(test_case["expected_tool_name"]) + else: + mock_reload_tool.assert_not_called() + + +@patch.object(ToolRegistry, "reload_tool", side_effect=Exception("Test error")) +def test_on_created_error_handling(mock_reload_tool): + """Test that on_created handles errors during tool reloading.""" + tool_registry = ToolRegistry() + watcher = ToolWatcher(tool_registry) + + # Create a mock event with a Python file path + event = MagicMock() + event.src_path = "/path/to/new_tool.py" + + # Call the on_created method - should not raise an exception + watcher.tool_change_handler.on_created(event) + + # Verify that reload_tool was called + mock_reload_tool.assert_called_once_with("new_tool") + + +@patch.object(ToolRegistry, "reload_tool") +def test_master_handler_on_created_delegates_to_handlers(mock_reload_tool): + """Test that MasterChangeHandler.on_created delegates to all registered handlers. + + This ensures that when a new tool file is created in a watched directory, + all registered ToolChangeHandlers are notified. + """ + tool_registry = ToolRegistry() + watcher = ToolWatcher(tool_registry) + + # Get the master handler for the tools directory + tools_dirs = tool_registry.get_tools_dirs() + if tools_dirs: + dir_str = str(tools_dirs[0]) + master_handler = ToolWatcher.MasterChangeHandler(dir_str) + + # Manually register our handler (normally done in start()) + if dir_str not in ToolWatcher._registry_handlers: + ToolWatcher._registry_handlers[dir_str] = {} + ToolWatcher._registry_handlers[dir_str][id(tool_registry)] = watcher.tool_change_handler + + # Create a mock event + event = MagicMock() + event.src_path = f"{dir_str}/new_tool.py" + + # Call on_created on master handler + master_handler.on_created(event) + + # Verify that reload_tool was called via the delegated handler + mock_reload_tool.assert_called_once_with("new_tool")