From b2024e269edf4b86eebecf430b1c6a39f34d6384 Mon Sep 17 00:00:00 2001 From: Tony Lee Date: Thu, 9 Apr 2026 09:19:48 +0000 Subject: [PATCH] fix(live): prevent zombie WebSocket session after LiveRequestQueue.close() MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Calling `LiveRequestQueue.close()` is the documented way to shut down a live session from the client side. However, `run_live()`'s `while True:` reconnect loop had no awareness of this intentional shutdown: when the resulting APIError(1000) / ConnectionClosed event arrived it would either reconnect (if a session-resumption handle was present) or raise a spurious error (if no handle was present), in both cases creating a long-lived zombie WebSocket connection that Gemini eventually terminates after ~2 hours with a 1006 error. Fix --- * Add `is_closed: bool` property to `LiveRequestQueue` backed by a simple boolean flag that is set synchronously in `close()` *before* the sentinel is enqueued. The synchronous flag avoids any asyncio scheduling race: by the time any connection-close exception reaches `run_live()`'s handlers, the flag is already True. * In `run_live()`, check `live_request_queue.is_closed` in both the `ConnectionClosed` and `APIError(1000)` exception handlers. When the queue is closed, log an info message and `return` instead of reconnecting or raising. A trailing guard at the bottom of the loop body covers the (less common) case where the receive generator returns normally without raising. Behaviour after this fix ------------------------ | Scenario | Before | After | |--------------------------------------------|---------------|-----------| | `close()` called, no session handle | raises error | terminates cleanly | | `close()` called, session handle present | reconnects | terminates cleanly | | Network drop, session handle present | reconnects | reconnects (unchanged) | | Network drop, no session handle | raises | raises (unchanged) | Tests ----- * `test_is_closed_initially_false` — property starts False * `test_is_closed_true_after_close` — property becomes True after close() * `test_is_closed_not_affected_by_other_sends` — other sends don't set it * `test_run_live_no_reconnect_after_queue_close_api_error_1000` — APIError(1000) after close() → terminates, connect called once * `test_run_live_no_reconnect_after_queue_close_connection_closed` — same for ConnectionClosed variant * `test_run_live_still_reconnects_on_unintentional_drop_with_handle` — genuine network drop without close() still reconnects (regression guard) --- src/google/adk/agents/live_request_queue.py | 7 + .../adk/flows/llm_flows/base_llm_flow.py | 26 +++ .../agents/test_live_request_queue.py | 18 ++ .../flows/llm_flows/test_base_llm_flow.py | 162 ++++++++++++++++++ 4 files changed, 213 insertions(+) diff --git a/src/google/adk/agents/live_request_queue.py b/src/google/adk/agents/live_request_queue.py index 9b698c81d6..cf12b30f4e 100644 --- a/src/google/adk/agents/live_request_queue.py +++ b/src/google/adk/agents/live_request_queue.py @@ -62,8 +62,15 @@ class LiveRequestQueue: def __init__(self): self._queue = asyncio.Queue() + self._closed = False + + @property + def is_closed(self) -> bool: + """Returns True if close() has been called on this queue.""" + return self._closed def close(self): + self._closed = True self._queue.put_nowait(LiveRequest(close=True)) def send_content(self, content: types.Content): diff --git a/src/google/adk/flows/llm_flows/base_llm_flow.py b/src/google/adk/flows/llm_flows/base_llm_flow.py index 76c7b8e160..d216cf7dab 100644 --- a/src/google/adk/flows/llm_flows/base_llm_flow.py +++ b/src/google/adk/flows/llm_flows/base_llm_flow.py @@ -617,6 +617,14 @@ async def run_live( except asyncio.CancelledError: pass except (ConnectionClosed, ConnectionClosedOK) as e: + # An intentional close via LiveRequestQueue.close() may surface as a + # ConnectionClosed event. Do not reconnect in that case. + if invocation_context.live_request_queue.is_closed: + logger.info( + 'Live session for agent %s closed by client request.', + invocation_context.agent.name, + ) + return # If we have a session resumption handle, we attempt to reconnect. # This handle is updated dynamically during the session. if invocation_context.live_session_resumption_handle: @@ -630,6 +638,15 @@ async def run_live( logger.error('Connection closed: %s.', e) raise except errors.APIError as e: + # Error code 1000 indicates a normal (intentional) closure. If the + # client called LiveRequestQueue.close(), do not treat this as an error + # and do not attempt to reconnect regardless of session handle state. + if e.code == 1000 and invocation_context.live_request_queue.is_closed: + logger.info( + 'Live session for agent %s closed by client request.', + invocation_context.agent.name, + ) + return # Error code 1000 and 1006 indicates a recoverable connection drop. # In that case, we attempt to reconnect with session handle if available. if e.code in [1000, 1006]: @@ -649,6 +666,15 @@ async def run_live( ) raise + # If the client explicitly closed the queue and no exception was raised + # (e.g. the receive generator returned normally), do not reconnect. + if invocation_context.live_request_queue.is_closed: + logger.info( + 'Live session for agent %s closed by client request.', + invocation_context.agent.name, + ) + return + async def _send_to_model( self, llm_connection: BaseLlmConnection, diff --git a/tests/unittests/agents/test_live_request_queue.py b/tests/unittests/agents/test_live_request_queue.py index ab98894daf..1a17c5143e 100644 --- a/tests/unittests/agents/test_live_request_queue.py +++ b/tests/unittests/agents/test_live_request_queue.py @@ -17,6 +17,24 @@ async def test_close_queue(): mock_put_nowait.assert_called_once_with(LiveRequest(close=True)) +def test_is_closed_initially_false(): + queue = LiveRequestQueue() + assert queue.is_closed is False + + +def test_is_closed_true_after_close(): + queue = LiveRequestQueue() + queue.close() + assert queue.is_closed is True + + +def test_is_closed_not_affected_by_other_sends(): + queue = LiveRequestQueue() + queue.send_content(MagicMock(spec=types.Content)) + queue.send_realtime(MagicMock(spec=types.Blob)) + assert queue.is_closed is False + + def test_send_content(): queue = LiveRequestQueue() content = MagicMock(spec=types.Content) diff --git a/tests/unittests/flows/llm_flows/test_base_llm_flow.py b/tests/unittests/flows/llm_flows/test_base_llm_flow.py index bdcaa9af5a..97679eefb6 100644 --- a/tests/unittests/flows/llm_flows/test_base_llm_flow.py +++ b/tests/unittests/flows/llm_flows/test_base_llm_flow.py @@ -893,3 +893,165 @@ async def mock_receive(): # We expect 2 successful attempts + DEFAULT_MAX_RECONNECT_ATTEMPTS failed attempts # Total calls = 2 + 5 = 7 assert mock_connect.call_count == DEFAULT_MAX_RECONNECT_ATTEMPTS + 2 + + +@pytest.mark.asyncio +async def test_run_live_no_reconnect_after_queue_close_api_error_1000(): + """Test that run_live does not reconnect after LiveRequestQueue.close() (APIError 1000). + + Calling LiveRequestQueue.close() signals an intentional client-side shutdown. + When the resulting APIError(1000) arrives, run_live must terminate instead of + reconnecting — even when a session resumption handle is present. + """ + from google.adk.agents.live_request_queue import LiveRequestQueue + from google.genai.errors import APIError + + real_model = Gemini() + mock_connection = mock.AsyncMock() + + async def mock_receive(): + # Simulate receiving a session resumption handle from the server. + yield LlmResponse( + live_session_resumption_update=types.LiveServerSessionResumptionUpdate( + new_handle='test_handle' + ) + ) + # Simulate the normal-close APIError that arrives after llm_connection.close(). + raise APIError(1000, {}) + + mock_connection.receive = mock.Mock(side_effect=mock_receive) + + agent = Agent(name='test_agent', model=real_model) + invocation_context = await testing_utils.create_invocation_context( + agent=agent + ) + invocation_context.live_request_queue = LiveRequestQueue() + # Simulate what live_request_queue.close() does before the error arrives. + invocation_context.live_request_queue.close() + + flow = BaseLlmFlowForTesting() + + with mock.patch.object(flow, '_send_to_model', new_callable=AsyncMock): + with mock.patch( + 'google.adk.models.google_llm.Gemini.connect' + ) as mock_connect: + mock_connect.return_value.__aenter__.return_value = mock_connection + + events = [] + async for event in flow.run_live(invocation_context): + events.append(event) + + # run_live must terminate after the first connection — no reconnect. + assert mock_connect.call_count == 1 + + +@pytest.mark.asyncio +async def test_run_live_no_reconnect_after_queue_close_connection_closed(): + """Test that run_live does not reconnect after LiveRequestQueue.close() (ConnectionClosed). + + Same as the APIError(1000) case but the connection surfaces as ConnectionClosed, + which can happen depending on the websockets library version or transport layer. + """ + from google.adk.agents.live_request_queue import LiveRequestQueue + from websockets.exceptions import ConnectionClosed + + real_model = Gemini() + mock_connection = mock.AsyncMock() + + async def mock_receive(): + yield LlmResponse( + live_session_resumption_update=types.LiveServerSessionResumptionUpdate( + new_handle='test_handle' + ) + ) + raise ConnectionClosed(None, None) + + mock_connection.receive = mock.Mock(side_effect=mock_receive) + + agent = Agent(name='test_agent', model=real_model) + invocation_context = await testing_utils.create_invocation_context( + agent=agent + ) + invocation_context.live_request_queue = LiveRequestQueue() + invocation_context.live_request_queue.close() + + flow = BaseLlmFlowForTesting() + + with mock.patch.object(flow, '_send_to_model', new_callable=AsyncMock): + with mock.patch( + 'google.adk.models.google_llm.Gemini.connect' + ) as mock_connect: + mock_connect.return_value.__aenter__.return_value = mock_connection + + events = [] + async for event in flow.run_live(invocation_context): + events.append(event) + + # run_live must terminate after the first connection — no reconnect. + assert mock_connect.call_count == 1 + + +@pytest.mark.asyncio +async def test_run_live_still_reconnects_on_unintentional_drop_with_handle(): + """Test that session-resumption reconnection still works for genuine drops. + + A genuine network drop (ConnectionClosed without queue.close()) with a session + resumption handle must still trigger reconnection. The queue.close() fix + must not break this existing behaviour. + """ + from google.adk.agents.live_request_queue import LiveRequestQueue + from websockets.exceptions import ConnectionClosed + + real_model = Gemini() + mock_connection = mock.AsyncMock() + + async def mock_receive(): + yield LlmResponse( + live_session_resumption_update=types.LiveServerSessionResumptionUpdate( + new_handle='test_handle' + ) + ) + # Genuine network drop (queue was NOT closed). + raise ConnectionClosed(None, None) + + mock_connection.receive = mock.Mock(side_effect=mock_receive) + + agent = Agent(name='test_agent', model=real_model) + invocation_context = await testing_utils.create_invocation_context( + agent=agent + ) + invocation_context.live_request_queue = LiveRequestQueue() + # Note: queue.close() is NOT called — this is an unintentional drop. + + flow = BaseLlmFlowForTesting() + + with mock.patch.object(flow, '_send_to_model', new_callable=AsyncMock): + mock_connection_2 = mock.AsyncMock() + + class NonRetryableError(Exception): + pass + + async def mock_receive_2(): + if False: + yield + raise NonRetryableError('stop') + + mock_connection_2.receive = mock.Mock(side_effect=mock_receive_2) + + mock_aenter = mock.AsyncMock() + mock_aenter.side_effect = [mock_connection, mock_connection_2] + + with mock.patch( + 'google.adk.models.google_llm.Gemini.connect' + ) as mock_connect: + mock_connect.return_value.__aenter__ = mock_aenter + + try: + async for _ in flow.run_live(invocation_context): + pass + except NonRetryableError: + pass + + # Reconnection must have been attempted (2 connections). + assert mock_connect.call_count == 2 + assert invocation_context.live_session_resumption_handle == 'test_handle'