Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions src/google/adk/agents/live_request_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,15 @@ class LiveRequestQueue:

def __init__(self):
self._queue = asyncio.Queue()
self._closed = False

@property
def is_closed(self) -> bool:
"""Returns True if close() has been called on this queue."""
return self._closed

def close(self):
self._closed = True
self._queue.put_nowait(LiveRequest(close=True))

def send_content(self, content: types.Content):
Expand Down
26 changes: 26 additions & 0 deletions src/google/adk/flows/llm_flows/base_llm_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -617,6 +617,14 @@ async def run_live(
except asyncio.CancelledError:
pass
except (ConnectionClosed, ConnectionClosedOK) as e:
# An intentional close via LiveRequestQueue.close() may surface as a
# ConnectionClosed event. Do not reconnect in that case.
if invocation_context.live_request_queue.is_closed:
logger.info(
'Live session for agent %s closed by client request.',
invocation_context.agent.name,
)
return
# If we have a session resumption handle, we attempt to reconnect.
# This handle is updated dynamically during the session.
if invocation_context.live_session_resumption_handle:
Expand All @@ -630,6 +638,15 @@ async def run_live(
logger.error('Connection closed: %s.', e)
raise
except errors.APIError as e:
# Error code 1000 indicates a normal (intentional) closure. If the
# client called LiveRequestQueue.close(), do not treat this as an error
# and do not attempt to reconnect regardless of session handle state.
if e.code == 1000 and invocation_context.live_request_queue.is_closed:
logger.info(
'Live session for agent %s closed by client request.',
invocation_context.agent.name,
)
return
# Error code 1000 and 1006 indicates a recoverable connection drop.
# In that case, we attempt to reconnect with session handle if available.
if e.code in [1000, 1006]:
Expand All @@ -649,6 +666,15 @@ async def run_live(
)
raise

# If the client explicitly closed the queue and no exception was raised
# (e.g. the receive generator returned normally), do not reconnect.
if invocation_context.live_request_queue.is_closed:
logger.info(
'Live session for agent %s closed by client request.',
invocation_context.agent.name,
)
return

async def _send_to_model(
self,
llm_connection: BaseLlmConnection,
Expand Down
18 changes: 18 additions & 0 deletions tests/unittests/agents/test_live_request_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,24 @@ async def test_close_queue():
mock_put_nowait.assert_called_once_with(LiveRequest(close=True))


def test_is_closed_initially_false():
queue = LiveRequestQueue()
assert queue.is_closed is False


def test_is_closed_true_after_close():
queue = LiveRequestQueue()
queue.close()
assert queue.is_closed is True


def test_is_closed_not_affected_by_other_sends():
queue = LiveRequestQueue()
queue.send_content(MagicMock(spec=types.Content))
queue.send_realtime(MagicMock(spec=types.Blob))
assert queue.is_closed is False


def test_send_content():
queue = LiveRequestQueue()
content = MagicMock(spec=types.Content)
Expand Down
162 changes: 162 additions & 0 deletions tests/unittests/flows/llm_flows/test_base_llm_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -895,6 +895,168 @@ async def mock_receive():
assert mock_connect.call_count == DEFAULT_MAX_RECONNECT_ATTEMPTS + 2


@pytest.mark.asyncio
async def test_run_live_no_reconnect_after_queue_close_api_error_1000():
"""Test that run_live does not reconnect after LiveRequestQueue.close() (APIError 1000).

Calling LiveRequestQueue.close() signals an intentional client-side shutdown.
When the resulting APIError(1000) arrives, run_live must terminate instead of
reconnecting — even when a session resumption handle is present.
"""
from google.adk.agents.live_request_queue import LiveRequestQueue
from google.genai.errors import APIError

real_model = Gemini()
mock_connection = mock.AsyncMock()

async def mock_receive():
# Simulate receiving a session resumption handle from the server.
yield LlmResponse(
live_session_resumption_update=types.LiveServerSessionResumptionUpdate(
new_handle='test_handle'
)
)
# Simulate the normal-close APIError that arrives after llm_connection.close().
raise APIError(1000, {})

mock_connection.receive = mock.Mock(side_effect=mock_receive)

agent = Agent(name='test_agent', model=real_model)
invocation_context = await testing_utils.create_invocation_context(
agent=agent
)
invocation_context.live_request_queue = LiveRequestQueue()
# Simulate what live_request_queue.close() does before the error arrives.
invocation_context.live_request_queue.close()

flow = BaseLlmFlowForTesting()

with mock.patch.object(flow, '_send_to_model', new_callable=AsyncMock):
with mock.patch(
'google.adk.models.google_llm.Gemini.connect'
) as mock_connect:
mock_connect.return_value.__aenter__.return_value = mock_connection

events = []
async for event in flow.run_live(invocation_context):
events.append(event)

# run_live must terminate after the first connection — no reconnect.
assert mock_connect.call_count == 1


@pytest.mark.asyncio
async def test_run_live_no_reconnect_after_queue_close_connection_closed():
"""Test that run_live does not reconnect after LiveRequestQueue.close() (ConnectionClosed).

Same as the APIError(1000) case but the connection surfaces as ConnectionClosed,
which can happen depending on the websockets library version or transport layer.
"""
from google.adk.agents.live_request_queue import LiveRequestQueue
from websockets.exceptions import ConnectionClosed

real_model = Gemini()
mock_connection = mock.AsyncMock()

async def mock_receive():
yield LlmResponse(
live_session_resumption_update=types.LiveServerSessionResumptionUpdate(
new_handle='test_handle'
)
)
raise ConnectionClosed(None, None)

mock_connection.receive = mock.Mock(side_effect=mock_receive)

agent = Agent(name='test_agent', model=real_model)
invocation_context = await testing_utils.create_invocation_context(
agent=agent
)
invocation_context.live_request_queue = LiveRequestQueue()
invocation_context.live_request_queue.close()

flow = BaseLlmFlowForTesting()

with mock.patch.object(flow, '_send_to_model', new_callable=AsyncMock):
with mock.patch(
'google.adk.models.google_llm.Gemini.connect'
) as mock_connect:
mock_connect.return_value.__aenter__.return_value = mock_connection

events = []
async for event in flow.run_live(invocation_context):
events.append(event)

# run_live must terminate after the first connection — no reconnect.
assert mock_connect.call_count == 1


@pytest.mark.asyncio
async def test_run_live_still_reconnects_on_unintentional_drop_with_handle():
"""Test that session-resumption reconnection still works for genuine drops.

A genuine network drop (ConnectionClosed without queue.close()) with a session
resumption handle must still trigger reconnection. The queue.close() fix
must not break this existing behaviour.
"""
from google.adk.agents.live_request_queue import LiveRequestQueue
from websockets.exceptions import ConnectionClosed

real_model = Gemini()
mock_connection = mock.AsyncMock()

async def mock_receive():
yield LlmResponse(
live_session_resumption_update=types.LiveServerSessionResumptionUpdate(
new_handle='test_handle'
)
)
# Genuine network drop (queue was NOT closed).
raise ConnectionClosed(None, None)

mock_connection.receive = mock.Mock(side_effect=mock_receive)

agent = Agent(name='test_agent', model=real_model)
invocation_context = await testing_utils.create_invocation_context(
agent=agent
)
invocation_context.live_request_queue = LiveRequestQueue()
# Note: queue.close() is NOT called — this is an unintentional drop.

flow = BaseLlmFlowForTesting()

with mock.patch.object(flow, '_send_to_model', new_callable=AsyncMock):
mock_connection_2 = mock.AsyncMock()

class NonRetryableError(Exception):
pass

async def mock_receive_2():
if False:
yield
raise NonRetryableError('stop')

mock_connection_2.receive = mock.Mock(side_effect=mock_receive_2)

mock_aenter = mock.AsyncMock()
mock_aenter.side_effect = [mock_connection, mock_connection_2]

with mock.patch(
'google.adk.models.google_llm.Gemini.connect'
) as mock_connect:
mock_connect.return_value.__aenter__ = mock_aenter

try:
async for _ in flow.run_live(invocation_context):
pass
except NonRetryableError:
pass

# Reconnection must have been attempted (2 connections).
assert mock_connect.call_count == 2
assert invocation_context.live_session_resumption_handle == 'test_handle'


@pytest.mark.asyncio
async def test_postprocess_live_session_resumption_update():
"""Test that _postprocess_live yields live_session_resumption_update."""
Expand Down