diff --git a/src/google/adk/models/gemini_llm_connection.py b/src/google/adk/models/gemini_llm_connection.py index 63606b21b0..80540064ab 100644 --- a/src/google/adk/models/gemini_llm_connection.py +++ b/src/google/adk/models/gemini_llm_connection.py @@ -281,7 +281,8 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]: if text: yield self.__build_full_text_response(text) text = '' - else: + # Emit only if content has no parts, to avoid duplicate interruptions. + if not content or not content.parts: yield LlmResponse(interrupted=message.server_content.interrupted) if message.tool_call: if text: diff --git a/tests/unittests/models/test_gemini_llm_connection.py b/tests/unittests/models/test_gemini_llm_connection.py index d065661c69..f56b486bed 100644 --- a/tests/unittests/models/test_gemini_llm_connection.py +++ b/tests/unittests/models/test_gemini_llm_connection.py @@ -774,3 +774,138 @@ async def test_send_history_filters_various_audio_mime_types( # No content should be sent since the only part is audio mock_gemini_session.send.assert_not_called() + + +@pytest.mark.asyncio +async def test_receive_interrupted_with_accumulated_text_and_empty_content( + gemini_connection, mock_gemini_session +): + """Test interrupt with accumulated text and empty content yields both. + + When an interrupted signal arrives with accumulated partial text and no + content, the implementation should: + 1. First yield the accumulated text as a full response + 2. Then yield a separate interruption signal + + This prevents losing accumulated text on interruption. 
+ """ + message1 = mock.Mock() + message1.usage_metadata = None + message1.server_content = mock.Mock() + message1.server_content.model_turn = types.Content( + role='model', parts=[types.Part.from_text(text='Hello ')] + ) + message1.server_content.interrupted = False + message1.server_content.input_transcription = None + message1.server_content.output_transcription = None + message1.server_content.turn_complete = False + message1.server_content.generation_complete = False + message1.tool_call = None + message1.session_resumption_update = None + + message2 = mock.Mock() + message2.usage_metadata = None + message2.server_content = mock.Mock() + message2.server_content.model_turn = types.Content( + role='model', parts=[types.Part.from_text(text='world')] + ) + message2.server_content.interrupted = False + message2.server_content.input_transcription = None + message2.server_content.output_transcription = None + message2.server_content.turn_complete = False + message2.server_content.generation_complete = False + message2.tool_call = None + message2.session_resumption_update = None + + # Interruption with no content/parts (e.g., user interrupted during thinking) + message3 = mock.Mock() + message3.usage_metadata = None + message3.server_content = mock.Mock() + message3.server_content.model_turn = None # No content + message3.server_content.interrupted = True + message3.server_content.input_transcription = None + message3.server_content.output_transcription = None + message3.server_content.turn_complete = False + message3.server_content.generation_complete = False + message3.tool_call = None + message3.session_resumption_update = None + + async def mock_receive_generator(): + yield message1 + yield message2 + yield message3 + + receive_mock = mock.Mock(return_value=mock_receive_generator()) + mock_gemini_session.receive = receive_mock + + responses = [resp async for resp in gemini_connection.receive()] + + # Should have: + # - 2 partial text responses (Hello, world) + # - 
1 full text response (Hello world) when interrupted + # - 1 interruption signal + assert len(responses) == 4 + + # First two are partial text + assert responses[0].content.parts[0].text == 'Hello ' + assert responses[0].partial is True + assert responses[1].content.parts[0].text == 'world' + assert responses[1].partial is True + + # Third is the merged full text + assert responses[2].content.parts[0].text == 'Hello world' + assert not responses[2].partial + assert not responses[2].interrupted + + # Fourth is the interruption signal + assert responses[3].interrupted is True + assert responses[3].content is None + + +@pytest.mark.asyncio +async def test_receive_interrupted_with_content( + gemini_connection, mock_gemini_session +): + """Test interrupt with content yields partial then full response. + + When an interrupted signal arrives with content (no prior accumulated text), + the current implementation: + 1. Yields the content as a partial response with interrupted flag + 2. Yields the accumulated text as a full response + + This is because the text gets accumulated before being flushed on interrupt. 
+ """ + message = mock.Mock() + message.usage_metadata = None + message.server_content = mock.Mock() + message.server_content.model_turn = types.Content( + role='model', parts=[types.Part.from_text(text='Interrupted text')] + ) + message.server_content.interrupted = True + message.server_content.input_transcription = None + message.server_content.output_transcription = None + message.server_content.turn_complete = False + message.server_content.generation_complete = False + message.tool_call = None + message.session_resumption_update = None + + async def mock_receive_generator(): + yield message + + receive_mock = mock.Mock(return_value=mock_receive_generator()) + mock_gemini_session.receive = receive_mock + + responses = [resp async for resp in gemini_connection.receive()] + + # Should have 2 responses: + # - 1 partial response with interrupted flag + # - 1 full text response (accumulated text flushed on interrupt) + assert len(responses) == 2 + assert responses[0].content.parts[0].text == 'Interrupted text' + assert responses[0].interrupted is True + assert responses[0].partial is True + + # Second response is the full accumulated text + assert responses[1].content.parts[0].text == 'Interrupted text' + assert responses[1].partial is not True # May be None or False + assert not responses[1].interrupted