Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/google/adk/models/gemini_llm_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,8 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]:
if text:
yield self.__build_full_text_response(text)
text = ''
else:
# this condition prevents duplicate interruption signals
if not content or not content.parts:
yield LlmResponse(interrupted=message.server_content.interrupted)
if message.tool_call:
if text:
Expand Down
135 changes: 135 additions & 0 deletions tests/unittests/models/test_gemini_llm_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -774,3 +774,138 @@ async def test_send_history_filters_various_audio_mime_types(

# No content should be sent since the only part is audio
mock_gemini_session.send.assert_not_called()


@pytest.mark.asyncio
async def test_receive_interrupted_with_accumulated_text_and_empty_content(
gemini_connection, mock_gemini_session
):
"""Test interrupt with accumulated text and empty content yields both.

When interrupted signal arrives with accumulated partial text and no content,
the implementation should:
1. First yield the accumulated text as a full response
2. Then yield a separate interruption signal

This prevents losing accumulated text on interruption.
"""
message1 = mock.Mock()
message1.usage_metadata = None
message1.server_content = mock.Mock()
message1.server_content.model_turn = types.Content(
role='model', parts=[types.Part.from_text(text='Hello ')]
)
message1.server_content.interrupted = False
message1.server_content.input_transcription = None
message1.server_content.output_transcription = None
message1.server_content.turn_complete = False
message1.server_content.generation_complete = False
message1.tool_call = None
message1.session_resumption_update = None

message2 = mock.Mock()
message2.usage_metadata = None
message2.server_content = mock.Mock()
message2.server_content.model_turn = types.Content(
role='model', parts=[types.Part.from_text(text='world')]
)
message2.server_content.interrupted = False
message2.server_content.input_transcription = None
message2.server_content.output_transcription = None
message2.server_content.turn_complete = False
message2.server_content.generation_complete = False
message2.tool_call = None
message2.session_resumption_update = None

# Interruption with no content/parts (e.g., user interrupted during thinking)
message3 = mock.Mock()
message3.usage_metadata = None
message3.server_content = mock.Mock()
message3.server_content.model_turn = None # No content
message3.server_content.interrupted = True
message3.server_content.input_transcription = None
message3.server_content.output_transcription = None
message3.server_content.turn_complete = False
message3.server_content.generation_complete = False
message3.tool_call = None
message3.session_resumption_update = None

async def mock_receive_generator():
yield message1
yield message2
yield message3

receive_mock = mock.Mock(return_value=mock_receive_generator())
mock_gemini_session.receive = receive_mock

responses = [resp async for resp in gemini_connection.receive()]

# Should have:
# - 2 partial text responses (Hello, world)
# - 1 full text response (Hello world) when interrupted
# - 1 interruption signal
assert len(responses) == 4

# First two are partial text
assert responses[0].content.parts[0].text == 'Hello '
assert responses[0].partial is True
assert responses[1].content.parts[0].text == 'world'
assert responses[1].partial is True

# Third is the merged full text
assert responses[2].content.parts[0].text == 'Hello world'
assert not responses[2].partial
assert not responses[2].interrupted

# Fourth is the interruption signal
assert responses[3].interrupted is True
assert responses[3].content is None


@pytest.mark.asyncio
async def test_receive_interrupted_with_content(
gemini_connection, mock_gemini_session
):
"""Test interrupt with content yields partial then full response.

When interrupted signal arrives with content (no prior accumulated text),
the current implementation:
1. Yields the content as a partial response with interrupted flag
2. Yields the accumulated text as a full response

This is because the text gets accumulated before being flushed on interrupt.
"""
message = mock.Mock()
message.usage_metadata = None
message.server_content = mock.Mock()
message.server_content.model_turn = types.Content(
role='model', parts=[types.Part.from_text(text='Interrupted text')]
)
message.server_content.interrupted = True
message.server_content.input_transcription = None
message.server_content.output_transcription = None
message.server_content.turn_complete = False
message.server_content.generation_complete = False
message.tool_call = None
message.session_resumption_update = None

async def mock_receive_generator():
yield message

receive_mock = mock.Mock(return_value=mock_receive_generator())
mock_gemini_session.receive = receive_mock

responses = [resp async for resp in gemini_connection.receive()]

# Should have 2 responses:
# - 1 partial response with interrupted flag
# - 1 full text response (accumulated text flushed on interrupt)
assert len(responses) == 2
assert responses[0].content.parts[0].text == 'Interrupted text'
assert responses[0].interrupted is True
assert responses[0].partial is True

# Second response is the full accumulated text
assert responses[1].content.parts[0].text == 'Interrupted text'
assert responses[1].partial is not True # May be None or False
assert not responses[1].interrupted