Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 33 additions & 11 deletions tests/strands/agent/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import os
import textwrap
import threading
import time
import unittest.mock
import warnings
from typing import Any, AsyncGenerator
Expand Down Expand Up @@ -194,10 +193,20 @@ class User(BaseModel):


class SlowMockedModel(MockedModelProvider):
"""A model that has a configurable delay and optional signaling for concurrency tests."""

def __init__(self, responses, delay: float = 0.15, signal_event: threading.Event | None = None):
super().__init__(responses)
self._delay = delay
self._signal_event = signal_event

async def stream(
self, messages, tool_specs=None, system_prompt=None, tool_choice=None, **kwargs
) -> AsyncGenerator[Any, None]:
await asyncio.sleep(0.15) # Add async delay to ensure concurrency
# Signal that we've started (lock is held at this point)
if self._signal_event:
self._signal_event.set()
await asyncio.sleep(self._delay) # Hold the lock during this delay
async for event in super().stream(messages, tool_specs, system_prompt, tool_choice, **kwargs):
yield event

Expand Down Expand Up @@ -2212,32 +2221,39 @@ def test_agent_skips_fix_for_valid_conversation(mock_model, agenerator):

def test_agent_concurrent_call_raises_exception():
"""Test that concurrent __call__() calls raise ConcurrencyException."""
# Use an event to signal when the first thread has acquired the lock
lock_acquired = threading.Event()

model = SlowMockedModel(
[
{"role": "assistant", "content": [{"text": "hello"}]},
{"role": "assistant", "content": [{"text": "world"}]},
]
],
delay=0.2, # Long enough to ensure overlap
signal_event=lock_acquired,
)
agent = Agent(model=model)

results = []
errors = []
lock = threading.Lock()
results_lock = threading.Lock()

def invoke():
try:
result = agent("test")
with lock:
with results_lock:
results.append(result)
except ConcurrencyException as e:
with lock:
with results_lock:
errors.append(e)

# Create two threads that will try to invoke concurrently
t1 = threading.Thread(target=invoke)
t2 = threading.Thread(target=invoke)

t1.start()
# Wait for t1 to acquire the lock (signaled when model.stream starts)
lock_acquired.wait(timeout=2.0)
t2.start()
t1.join()
t2.join()
Expand All @@ -2254,33 +2270,39 @@ def test_agent_concurrent_structured_output_raises_exception():
Note: This test validates that the sync invocation path is protected.
The concurrent __call__() test already validates the core functionality.
"""
# Use an event to signal when the first thread has acquired the lock
lock_acquired = threading.Event()

model = SlowMockedModel(
[
{"role": "assistant", "content": [{"text": "response1"}]},
{"role": "assistant", "content": [{"text": "response2"}]},
]
],
delay=0.2, # Long enough to ensure overlap
signal_event=lock_acquired,
)
agent = Agent(model=model)

results = []
errors = []
lock = threading.Lock()
results_lock = threading.Lock()

def invoke():
try:
result = agent("test")
with lock:
with results_lock:
results.append(result)
except ConcurrencyException as e:
with lock:
with results_lock:
errors.append(e)

# Create two threads that will try to invoke concurrently
t1 = threading.Thread(target=invoke)
t2 = threading.Thread(target=invoke)

t1.start()
time.sleep(0.05) # Small delay to ensure first thread acquires lock
# Wait for t1 to acquire the lock (signaled when model.stream starts)
lock_acquired.wait(timeout=2.0)
t2.start()
t1.join()
t2.join()
Expand Down
Loading