Skip to content

Commit 6c277e7

Browse files
TonyLee-AI and claude committed
fix(live): prevent zombie WebSocket session after LiveRequestQueue.close()
Calling `LiveRequestQueue.close()` is the documented way to shut down a live session from the client side. However, `run_live()`'s `while True:` reconnect loop had no awareness of this intentional shutdown: when the resulting APIError(1000) / ConnectionClosed event arrived it would either reconnect (if a session-resumption handle was present) or raise a spurious error (if no handle was present), in both cases creating a long-lived zombie WebSocket connection that Gemini eventually terminates after ~2 hours with a 1006 error.

Fix
---

* Add `is_closed: bool` property to `LiveRequestQueue` backed by a simple boolean flag that is set synchronously in `close()` *before* the sentinel is enqueued. The synchronous flag avoids any asyncio scheduling race: by the time any connection-close exception reaches `run_live()`'s handlers, the flag is already True.
* In `run_live()`, check `live_request_queue.is_closed` in both the `ConnectionClosed` and `APIError(1000)` exception handlers. When the queue is closed, log an info message and `return` instead of reconnecting or raising. A trailing guard at the bottom of the loop body covers the (less common) case where the receive generator returns normally without raising.
Behaviour after this fix
------------------------

| Scenario                                   | Before        | After     |
|--------------------------------------------|---------------|-----------|
| `close()` called, no session handle        | raises error  | terminates cleanly |
| `close()` called, session handle present   | reconnects    | terminates cleanly |
| Network drop, session handle present       | reconnects    | reconnects (unchanged) |
| Network drop, no session handle            | raises        | raises (unchanged) |

Tests
-----

* `test_is_closed_initially_false` — property starts False
* `test_is_closed_true_after_close` — property becomes True after close()
* `test_is_closed_not_affected_by_other_sends` — other sends don't set it
* `test_run_live_no_reconnect_after_queue_close_api_error_1000` — APIError(1000) after close() → terminates, connect called once
* `test_run_live_no_reconnect_after_queue_close_connection_closed` — same for ConnectionClosed variant
* `test_run_live_still_reconnects_on_unintentional_drop_with_handle` — genuine network drop without close() still reconnects (regression guard)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent dcc485b commit 6c277e7

File tree

4 files changed

+213
-0
lines changed

4 files changed

+213
-0
lines changed

src/google/adk/agents/live_request_queue.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,15 @@ class LiveRequestQueue:
6262

6363
def __init__(self):
6464
self._queue = asyncio.Queue()
65+
self._closed = False
66+
67+
@property
68+
def is_closed(self) -> bool:
69+
"""Returns True if close() has been called on this queue."""
70+
return self._closed
6571

6672
def close(self):
73+
self._closed = True
6774
self._queue.put_nowait(LiveRequest(close=True))
6875

6976
def send_content(self, content: types.Content):

src/google/adk/flows/llm_flows/base_llm_flow.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -617,6 +617,14 @@ async def run_live(
617617
except asyncio.CancelledError:
618618
pass
619619
except (ConnectionClosed, ConnectionClosedOK) as e:
620+
# An intentional close via LiveRequestQueue.close() may surface as a
621+
# ConnectionClosed event. Do not reconnect in that case.
622+
if invocation_context.live_request_queue.is_closed:
623+
logger.info(
624+
'Live session for agent %s closed by client request.',
625+
invocation_context.agent.name,
626+
)
627+
return
620628
# If we have a session resumption handle, we attempt to reconnect.
621629
# This handle is updated dynamically during the session.
622630
if invocation_context.live_session_resumption_handle:
@@ -630,6 +638,15 @@ async def run_live(
630638
logger.error('Connection closed: %s.', e)
631639
raise
632640
except errors.APIError as e:
641+
# Error code 1000 indicates a normal (intentional) closure. If the
642+
# client called LiveRequestQueue.close(), do not treat this as an error
643+
# and do not attempt to reconnect regardless of session handle state.
644+
if e.code == 1000 and invocation_context.live_request_queue.is_closed:
645+
logger.info(
646+
'Live session for agent %s closed by client request.',
647+
invocation_context.agent.name,
648+
)
649+
return
633650
# Error code 1000 and 1006 indicates a recoverable connection drop.
634651
# In that case, we attempt to reconnect with session handle if available.
635652
if e.code in [1000, 1006]:
@@ -649,6 +666,15 @@ async def run_live(
649666
)
650667
raise
651668

669+
# If the client explicitly closed the queue and no exception was raised
670+
# (e.g. the receive generator returned normally), do not reconnect.
671+
if invocation_context.live_request_queue.is_closed:
672+
logger.info(
673+
'Live session for agent %s closed by client request.',
674+
invocation_context.agent.name,
675+
)
676+
return
677+
652678
async def _send_to_model(
653679
self,
654680
llm_connection: BaseLlmConnection,

tests/unittests/agents/test_live_request_queue.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,24 @@ async def test_close_queue():
1717
mock_put_nowait.assert_called_once_with(LiveRequest(close=True))
1818

1919

20+
def test_is_closed_initially_false():
21+
queue = LiveRequestQueue()
22+
assert queue.is_closed is False
23+
24+
25+
def test_is_closed_true_after_close():
26+
queue = LiveRequestQueue()
27+
queue.close()
28+
assert queue.is_closed is True
29+
30+
31+
def test_is_closed_not_affected_by_other_sends():
32+
queue = LiveRequestQueue()
33+
queue.send_content(MagicMock(spec=types.Content))
34+
queue.send_realtime(MagicMock(spec=types.Blob))
35+
assert queue.is_closed is False
36+
37+
2038
def test_send_content():
2139
queue = LiveRequestQueue()
2240
content = MagicMock(spec=types.Content)

tests/unittests/flows/llm_flows/test_base_llm_flow.py

Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -893,3 +893,165 @@ async def mock_receive():
893893
# We expect 2 successful attempts + DEFAULT_MAX_RECONNECT_ATTEMPTS failed attempts
894894
# Total calls = 2 + 5 = 7
895895
assert mock_connect.call_count == DEFAULT_MAX_RECONNECT_ATTEMPTS + 2
896+
897+
898+
@pytest.mark.asyncio
899+
async def test_run_live_no_reconnect_after_queue_close_api_error_1000():
900+
"""Test that run_live does not reconnect after LiveRequestQueue.close() (APIError 1000).
901+
902+
Calling LiveRequestQueue.close() signals an intentional client-side shutdown.
903+
When the resulting APIError(1000) arrives, run_live must terminate instead of
904+
reconnecting — even when a session resumption handle is present.
905+
"""
906+
from google.adk.agents.live_request_queue import LiveRequestQueue
907+
from google.genai.errors import APIError
908+
909+
real_model = Gemini()
910+
mock_connection = mock.AsyncMock()
911+
912+
async def mock_receive():
913+
# Simulate receiving a session resumption handle from the server.
914+
yield LlmResponse(
915+
live_session_resumption_update=types.LiveServerSessionResumptionUpdate(
916+
new_handle='test_handle'
917+
)
918+
)
919+
# Simulate the normal-close APIError that arrives after llm_connection.close().
920+
raise APIError(1000, {})
921+
922+
mock_connection.receive = mock.Mock(side_effect=mock_receive)
923+
924+
agent = Agent(name='test_agent', model=real_model)
925+
invocation_context = await testing_utils.create_invocation_context(
926+
agent=agent
927+
)
928+
invocation_context.live_request_queue = LiveRequestQueue()
929+
# Simulate what live_request_queue.close() does before the error arrives.
930+
invocation_context.live_request_queue.close()
931+
932+
flow = BaseLlmFlowForTesting()
933+
934+
with mock.patch.object(flow, '_send_to_model', new_callable=AsyncMock):
935+
with mock.patch(
936+
'google.adk.models.google_llm.Gemini.connect'
937+
) as mock_connect:
938+
mock_connect.return_value.__aenter__.return_value = mock_connection
939+
940+
events = []
941+
async for event in flow.run_live(invocation_context):
942+
events.append(event)
943+
944+
# run_live must terminate after the first connection — no reconnect.
945+
assert mock_connect.call_count == 1
946+
947+
948+
@pytest.mark.asyncio
949+
async def test_run_live_no_reconnect_after_queue_close_connection_closed():
950+
"""Test that run_live does not reconnect after LiveRequestQueue.close() (ConnectionClosed).
951+
952+
Same as the APIError(1000) case but the connection surfaces as ConnectionClosed,
953+
which can happen depending on the websockets library version or transport layer.
954+
"""
955+
from google.adk.agents.live_request_queue import LiveRequestQueue
956+
from websockets.exceptions import ConnectionClosed
957+
958+
real_model = Gemini()
959+
mock_connection = mock.AsyncMock()
960+
961+
async def mock_receive():
962+
yield LlmResponse(
963+
live_session_resumption_update=types.LiveServerSessionResumptionUpdate(
964+
new_handle='test_handle'
965+
)
966+
)
967+
raise ConnectionClosed(None, None)
968+
969+
mock_connection.receive = mock.Mock(side_effect=mock_receive)
970+
971+
agent = Agent(name='test_agent', model=real_model)
972+
invocation_context = await testing_utils.create_invocation_context(
973+
agent=agent
974+
)
975+
invocation_context.live_request_queue = LiveRequestQueue()
976+
invocation_context.live_request_queue.close()
977+
978+
flow = BaseLlmFlowForTesting()
979+
980+
with mock.patch.object(flow, '_send_to_model', new_callable=AsyncMock):
981+
with mock.patch(
982+
'google.adk.models.google_llm.Gemini.connect'
983+
) as mock_connect:
984+
mock_connect.return_value.__aenter__.return_value = mock_connection
985+
986+
events = []
987+
async for event in flow.run_live(invocation_context):
988+
events.append(event)
989+
990+
# run_live must terminate after the first connection — no reconnect.
991+
assert mock_connect.call_count == 1
992+
993+
994+
@pytest.mark.asyncio
995+
async def test_run_live_still_reconnects_on_unintentional_drop_with_handle():
996+
"""Test that session-resumption reconnection still works for genuine drops.
997+
998+
A genuine network drop (ConnectionClosed without queue.close()) with a session
999+
resumption handle must still trigger reconnection. The queue.close() fix
1000+
must not break this existing behaviour.
1001+
"""
1002+
from google.adk.agents.live_request_queue import LiveRequestQueue
1003+
from websockets.exceptions import ConnectionClosed
1004+
1005+
real_model = Gemini()
1006+
mock_connection = mock.AsyncMock()
1007+
1008+
async def mock_receive():
1009+
yield LlmResponse(
1010+
live_session_resumption_update=types.LiveServerSessionResumptionUpdate(
1011+
new_handle='test_handle'
1012+
)
1013+
)
1014+
# Genuine network drop (queue was NOT closed).
1015+
raise ConnectionClosed(None, None)
1016+
1017+
mock_connection.receive = mock.Mock(side_effect=mock_receive)
1018+
1019+
agent = Agent(name='test_agent', model=real_model)
1020+
invocation_context = await testing_utils.create_invocation_context(
1021+
agent=agent
1022+
)
1023+
invocation_context.live_request_queue = LiveRequestQueue()
1024+
# Note: queue.close() is NOT called — this is an unintentional drop.
1025+
1026+
flow = BaseLlmFlowForTesting()
1027+
1028+
with mock.patch.object(flow, '_send_to_model', new_callable=AsyncMock):
1029+
mock_connection_2 = mock.AsyncMock()
1030+
1031+
class NonRetryableError(Exception):
1032+
pass
1033+
1034+
async def mock_receive_2():
1035+
if False:
1036+
yield
1037+
raise NonRetryableError('stop')
1038+
1039+
mock_connection_2.receive = mock.Mock(side_effect=mock_receive_2)
1040+
1041+
mock_aenter = mock.AsyncMock()
1042+
mock_aenter.side_effect = [mock_connection, mock_connection_2]
1043+
1044+
with mock.patch(
1045+
'google.adk.models.google_llm.Gemini.connect'
1046+
) as mock_connect:
1047+
mock_connect.return_value.__aenter__ = mock_aenter
1048+
1049+
try:
1050+
async for _ in flow.run_live(invocation_context):
1051+
pass
1052+
except NonRetryableError:
1053+
pass
1054+
1055+
# Reconnection must have been attempted (2 connections).
1056+
assert mock_connect.call_count == 2
1057+
assert invocation_context.live_session_resumption_handle == 'test_handle'

0 commit comments

Comments
 (0)