7 changes: 7 additions & 0 deletions src/google/adk/flows/llm_flows/base_llm_flow.py
@@ -749,9 +749,16 @@ async def _postprocess_live(
         and not llm_response.input_transcription
         and not llm_response.output_transcription
         and not llm_response.usage_metadata
+        and not llm_response.setup_complete
     ):
       return
 
+    # Handle setup complete events
+    if llm_response.setup_complete:
+      model_response_event.setup_complete = llm_response.setup_complete
+      yield model_response_event
+      return
+
     # Handle transcription events ONCE per llm_response, outside the event loop
     if llm_response.input_transcription:
       model_response_event.input_transcription = (
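With this change, a live LlmResponse whose only payload is the setup signal is no longer swallowed by the early-return guard; it is forwarded as a dedicated event (and nothing else) before any transcription handling. A minimal sketch of how a consumer might wait for it, assuming an async run_live surface and a setup_complete event field as exercised by the streaming test added below in this PR:

    # Sketch: wait for the Live session handshake before sending input.
    # Assumes `runner` and `live_request_queue` are wired up as in the ADK
    # streaming tests; the exact run_live signature may differ.
    async for event in runner.run_live(live_request_queue):
      if event.setup_complete:
        # Setup acknowledged by the model; safe to start streaming audio/text.
        break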
2 changes: 2 additions & 0 deletions src/google/adk/models/gemini_llm_connection.py
@@ -171,6 +171,8 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]:
     # partial content and emit responses as needed.
     async for message in agen:
       logger.debug('Got LLM Live message: %s', message)
+      if message.setup_complete:
+        yield LlmResponse(setup_complete=True)
       if message.usage_metadata:
         # Tracks token usage data per model.
         yield LlmResponse(
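At the connection layer the change is a direct translation: when the Live API acknowledges session setup (surfaced on the message as setup_complete), receive() yields a bare LlmResponse carrying only that flag, ahead of any usage or content processing. A self-contained sketch of the same pattern, with the message shape assumed rather than imported from the google-genai SDK:

    from google.adk.models.llm_response import LlmResponse

    # Sketch: translating a live message stream into LlmResponse objects.
    # `messages` stands in for the SDK's async server-message generator.
    async def translate(messages):
      async for message in messages:
        if getattr(message, 'setup_complete', None):
          # Emit a dedicated response so downstream flows can react to the
          # handshake without waiting for model content.
          yield LlmResponse(setup_complete=True)
        # ... usage_metadata / server_content handling continues as before ...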
8 changes: 8 additions & 0 deletions src/google/adk/models/llm_response.py
@@ -46,6 +46,8 @@ class LlmResponse(BaseModel):
     output_transcription: Audio transcription of model output.
     avg_logprobs: Average log probability of the generated tokens.
     logprobs_result: Detailed log probabilities for chosen and top candidate tokens.
+    setup_complete: Indicates whether the initial model setup is complete.
+      Only used for Gemini Live streaming mode.
   """
 
   model_config = ConfigDict(
@@ -80,6 +82,12 @@ class LlmResponse(BaseModel):
   Only used for streaming mode.
   """
 
+  setup_complete: Optional[bool] = None
+  """Indicates whether the initial model setup is complete.
+
+  Only used for Gemini Live streaming mode.
+  """
+
   finish_reason: Optional[types.FinishReason] = None
   """The finish reason of the response."""
 
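Since the new field defaults to None, existing construction sites and serialized responses are unaffected, and a plain truthiness check (if llm_response.setup_complete:) covers both the None and False cases. A quick illustration of the field's behavior, assuming only the Pydantic model defined above:

    from google.adk.models.llm_response import LlmResponse

    resp = LlmResponse(setup_complete=True)
    assert resp.setup_complete is True

    default = LlmResponse()
    # Absent unless the Gemini Live connection reports setup completion.
    assert default.setup_complete is None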
39 changes: 39 additions & 0 deletions tests/unittests/models/test_gemini_llm_connection.py
@@ -252,6 +252,7 @@ async def test_receive_transcript_finished_on_interrupt(
 
   message1 = mock.Mock()
   message1.usage_metadata = None
+  message1.setup_complete = None
   message1.server_content = mock.Mock()
   message1.server_content.model_turn = None
   message1.server_content.interrupted = False
@@ -266,6 +267,7 @@ async def test_receive_transcript_finished_on_interrupt(
 
   message2 = mock.Mock()
   message2.usage_metadata = None
+  message2.setup_complete = None
   message2.server_content = mock.Mock()
   message2.server_content.model_turn = None
   message2.server_content.interrupted = False
@@ -280,6 +282,7 @@ async def test_receive_transcript_finished_on_interrupt(
 
   message3 = mock.Mock()
   message3.usage_metadata = None
+  message3.setup_complete = None
   message3.server_content = mock.Mock()
   message3.server_content.model_turn = None
   message3.server_content.interrupted = True
@@ -326,6 +329,7 @@ async def test_receive_transcript_finished_on_generation_complete(
 
   message1 = mock.Mock()
   message1.usage_metadata = None
+  message1.setup_complete = None
   message1.server_content = mock.Mock()
   message1.server_content.model_turn = None
   message1.server_content.interrupted = False
@@ -340,6 +344,7 @@ async def test_receive_transcript_finished_on_generation_complete(
 
   message2 = mock.Mock()
   message2.usage_metadata = None
+  message2.setup_complete = None
   message2.server_content = mock.Mock()
   message2.server_content.model_turn = None
   message2.server_content.interrupted = False
@@ -354,6 +359,7 @@ async def test_receive_transcript_finished_on_generation_complete(
 
   message3 = mock.Mock()
   message3.usage_metadata = None
+  message3.setup_complete = None
   message3.server_content = mock.Mock()
   message3.server_content.model_turn = None
   message3.server_content.interrupted = False
@@ -399,6 +405,7 @@ async def test_receive_transcript_finished_on_turn_complete(
 
   message1 = mock.Mock()
   message1.usage_metadata = None
+  message1.setup_complete = None
   message1.server_content = mock.Mock()
   message1.server_content.model_turn = None
   message1.server_content.interrupted = False
@@ -413,6 +420,7 @@ async def test_receive_transcript_finished_on_turn_complete(
 
   message2 = mock.Mock()
   message2.usage_metadata = None
+  message2.setup_complete = None
   message2.server_content = mock.Mock()
   message2.server_content.model_turn = None
   message2.server_content.interrupted = False
@@ -427,6 +435,7 @@ async def test_receive_transcript_finished_on_turn_complete(
 
   message3 = mock.Mock()
   message3.usage_metadata = None
+  message3.setup_complete = None
   message3.server_content = mock.Mock()
   message3.server_content.model_turn = None
   message3.server_content.interrupted = False
@@ -471,6 +480,7 @@ async def test_receive_handles_input_transcription_fragments(
   """Test receive handles input transcription fragments correctly."""
   message1 = mock.Mock()
   message1.usage_metadata = None
+  message1.setup_complete = None
   message1.server_content = mock.Mock()
   message1.server_content.model_turn = None
   message1.server_content.interrupted = False
@@ -485,6 +495,7 @@ async def test_receive_handles_input_transcription_fragments(
 
   message2 = mock.Mock()
   message2.usage_metadata = None
+  message2.setup_complete = None
   message2.server_content = mock.Mock()
   message2.server_content.model_turn = None
   message2.server_content.interrupted = False
@@ -499,6 +510,7 @@ async def test_receive_handles_input_transcription_fragments(
 
   message3 = mock.Mock()
   message3.usage_metadata = None
+  message3.setup_complete = None
   message3.server_content = mock.Mock()
   message3.server_content.model_turn = None
   message3.server_content.interrupted = False
@@ -540,6 +552,7 @@ async def test_receive_handles_output_transcription_fragments(
   """Test receive handles output transcription fragments correctly."""
   message1 = mock.Mock()
   message1.usage_metadata = None
+  message1.setup_complete = None
   message1.server_content = mock.Mock()
   message1.server_content.model_turn = None
   message1.server_content.interrupted = False
@@ -554,6 +567,7 @@ async def test_receive_handles_output_transcription_fragments(
 
   message2 = mock.Mock()
   message2.usage_metadata = None
+  message2.setup_complete = None
   message2.server_content = mock.Mock()
   message2.server_content.model_turn = None
   message2.server_content.interrupted = False
@@ -568,6 +582,7 @@ async def test_receive_handles_output_transcription_fragments(
 
   message3 = mock.Mock()
   message3.usage_metadata = None
+  message3.setup_complete = None
   message3.server_content = mock.Mock()
   message3.server_content.model_turn = None
   message3.server_content.interrupted = False
@@ -774,3 +789,27 @@ async def test_send_history_filters_various_audio_mime_types(
 
   # No content should be sent since the only part is audio
   mock_gemini_session.send.assert_not_called()
+
+
+@pytest.mark.asyncio
+async def test_receive_setup_complete(gemini_connection, mock_gemini_session):
+  """Test receive handles setup_complete signal."""
+
+  # Create a mock message simulating BidiGenerateContentSetupComplete
+  message = mock.Mock()
+  message.setup_complete = True
+  message.usage_metadata = None
+  message.server_content = None
+  message.tool_call = None
+  message.session_resumption_update = None
+
+  async def mock_receive_generator():
+    yield message
+
+  receive_mock = mock.Mock(return_value=mock_receive_generator())
+  mock_gemini_session.receive = receive_mock
+
+  responses = [resp async for resp in gemini_connection.receive()]
+
+  assert len(responses) == 1
+  assert responses[0].setup_complete is True
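
A note on why every pre-existing mock message gains setup_complete = None: a bare mock.Mock() fabricates a truthy child mock for any attribute access, so without pinning the field each old test message would trip the new "if message.setup_complete:" branch in receive() and yield a spurious response. A hypothetical factory (make_live_message does not exist in the codebase; it is a suggested refactor) would keep future field additions from requiring the same edit in fifteen places:

    from unittest import mock

    def make_live_message(**overrides):
      # Hypothetical helper: a live-API message mock with known top-level
      # fields pinned to None so `if message.<field>:` checks in receive()
      # stay false unless a test opts in via overrides.
      message = mock.Mock()
      for field in ('setup_complete', 'usage_metadata', 'server_content',
                    'tool_call', 'session_resumption_update'):
        setattr(message, field, None)
      for field, value in overrides.items():
        setattr(message, field, value)
      return message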
35 changes: 35 additions & 0 deletions tests/unittests/streaming/test_streaming.py
@@ -54,6 +54,41 @@ def test_streaming():
   ), 'Expected at least one response, but got an empty list.'
 
+
+def test_live_streaming_setup_complete():
+  """Test live streaming with setup complete event."""
+  # Create LLM responses: setup complete followed by turn completion
+  response1 = LlmResponse(
+      setup_complete=True,
+  )
+  response2 = LlmResponse(
+      turn_complete=True,
+  )
+
+  mock_model = testing_utils.MockModel.create([response1, response2])
+
+  root_agent = Agent(
+      name='root_agent',
+      model=mock_model,
+      tools=[],
+  )
+
+  runner = testing_utils.InMemoryRunner(root_agent=root_agent)
+  live_request_queue = LiveRequestQueue()
+  res_events = runner.run_live(live_request_queue)
+
+  assert res_events is not None, 'Expected a list of events, got None.'
+  assert len(res_events) >= 1, 'Expected at least one event.'
+
+  # Check that we got a setup complete event
+  setup_complete_found = False
+  for event in res_events:
+    if event.setup_complete:
+      setup_complete_found = True
+      break
+
+  assert setup_complete_found, 'Expected a setup complete event.'
+
 
 def test_live_streaming_function_call_single():
   """Test live streaming with a single function call response."""
   # Create a function call response