diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index 81ce65989..aef2ff060 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -5,7 +5,6 @@ import os import textwrap import threading -import time import unittest.mock import warnings from typing import Any, AsyncGenerator @@ -194,10 +193,20 @@ class User(BaseModel): class SlowMockedModel(MockedModelProvider): + """A model that has a configurable delay and optional signaling for concurrency tests.""" + + def __init__(self, responses, delay: float = 0.15, signal_event: threading.Event | None = None): + super().__init__(responses) + self._delay = delay + self._signal_event = signal_event + async def stream( self, messages, tool_specs=None, system_prompt=None, tool_choice=None, **kwargs ) -> AsyncGenerator[Any, None]: - await asyncio.sleep(0.15) # Add async delay to ensure concurrency + # Signal that we've started (lock is held at this point) + if self._signal_event: + self._signal_event.set() + await asyncio.sleep(self._delay) # Hold the lock during this delay async for event in super().stream(messages, tool_specs, system_prompt, tool_choice, **kwargs): yield event @@ -2212,25 +2221,30 @@ def test_agent_skips_fix_for_valid_conversation(mock_model, agenerator): def test_agent_concurrent_call_raises_exception(): """Test that concurrent __call__() calls raise ConcurrencyException.""" + # Use an event to signal when the first thread has acquired the lock + lock_acquired = threading.Event() + model = SlowMockedModel( [ {"role": "assistant", "content": [{"text": "hello"}]}, {"role": "assistant", "content": [{"text": "world"}]}, - ] + ], + delay=0.2, # Long enough to ensure overlap + signal_event=lock_acquired, ) agent = Agent(model=model) results = [] errors = [] - lock = threading.Lock() + results_lock = threading.Lock() def invoke(): try: result = agent("test") - with lock: + with results_lock: results.append(result) except ConcurrencyException as e: - with lock: + with results_lock: errors.append(e) # Create two threads that will try to invoke concurrently @@ -2238,6 +2252,8 @@ def invoke(): t2 = threading.Thread(target=invoke) t1.start() + # Wait for t1 to acquire the lock (signaled when model.stream starts) + lock_acquired.wait(timeout=2.0) t2.start() t1.join() t2.join() @@ -2254,25 +2270,30 @@ def test_agent_concurrent_structured_output_raises_exception(): Note: This test validates that the sync invocation path is protected. The concurrent __call__() test already validates the core functionality. """ + # Use an event to signal when the first thread has acquired the lock + lock_acquired = threading.Event() + model = SlowMockedModel( [ {"role": "assistant", "content": [{"text": "response1"}]}, {"role": "assistant", "content": [{"text": "response2"}]}, - ] + ], + delay=0.2, # Long enough to ensure overlap + signal_event=lock_acquired, ) agent = Agent(model=model) results = [] errors = [] - lock = threading.Lock() + results_lock = threading.Lock() def invoke(): try: result = agent("test") - with lock: + with results_lock: results.append(result) except ConcurrencyException as e: - with lock: + with results_lock: errors.append(e) # Create two threads that will try to invoke concurrently @@ -2280,7 +2301,8 @@ def invoke(): t2 = threading.Thread(target=invoke) t1.start() - time.sleep(0.05) # Small delay to ensure first thread acquires lock + # Wait for t1 to acquire the lock (signaled when model.stream starts) + lock_acquired.wait(timeout=2.0) t2.start() t1.join() t2.join()