diff --git a/src/google/adk/flows/llm_flows/base_llm_flow.py b/src/google/adk/flows/llm_flows/base_llm_flow.py
index f1c1cce813..ce2ae72c0a 100644
--- a/src/google/adk/flows/llm_flows/base_llm_flow.py
+++ b/src/google/adk/flows/llm_flows/base_llm_flow.py
@@ -749,9 +749,16 @@ async def _postprocess_live(
         and not llm_response.input_transcription
         and not llm_response.output_transcription
         and not llm_response.usage_metadata
+        and not llm_response.setup_complete
     ):
       return
 
+    # Handle setup complete events
+    if llm_response.setup_complete:
+      model_response_event.setup_complete = llm_response.setup_complete
+      yield model_response_event
+      return
+
     # Handle transcription events ONCE per llm_response, outside the event loop
     if llm_response.input_transcription:
       model_response_event.input_transcription = (
diff --git a/src/google/adk/models/gemini_llm_connection.py b/src/google/adk/models/gemini_llm_connection.py
index 63606b21b0..d798a43376 100644
--- a/src/google/adk/models/gemini_llm_connection.py
+++ b/src/google/adk/models/gemini_llm_connection.py
@@ -171,6 +171,8 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]:
     # partial content and emit responses as needed.
     async for message in agen:
       logger.debug('Got LLM Live message: %s', message)
+      if message.setup_complete:
+        yield LlmResponse(setup_complete=True)
       if message.usage_metadata:
         # Tracks token usage data per model.
         yield LlmResponse(
diff --git a/src/google/adk/models/llm_response.py b/src/google/adk/models/llm_response.py
index 754e5abcfb..27266da730 100644
--- a/src/google/adk/models/llm_response.py
+++ b/src/google/adk/models/llm_response.py
@@ -46,6 +46,8 @@ class LlmResponse(BaseModel):
     output_transcription: Audio transcription of model output.
     avg_logprobs: Average log probability of the generated tokens.
     logprobs_result: Detailed log probabilities for chosen and top candidate tokens.
+    setup_complete: Indicates whether the initial model setup is complete.
+      Only used for Gemini Live streaming mode.
   """
 
   model_config = ConfigDict(
@@ -80,6 +82,12 @@
 
   Only used for streaming mode.
   """
 
+  setup_complete: Optional[bool] = None
+  """Indicates whether the initial model setup is complete.
+
+  Only used for Gemini Live streaming mode.
+  """
+
   finish_reason: Optional[types.FinishReason] = None
   """The finish reason of the response."""
diff --git a/tests/unittests/models/test_gemini_llm_connection.py b/tests/unittests/models/test_gemini_llm_connection.py
index d065661c69..5f06ef8322 100644
--- a/tests/unittests/models/test_gemini_llm_connection.py
+++ b/tests/unittests/models/test_gemini_llm_connection.py
@@ -252,6 +252,7 @@ async def test_receive_transcript_finished_on_interrupt(
 
   message1 = mock.Mock()
   message1.usage_metadata = None
+  message1.setup_complete = None
   message1.server_content = mock.Mock()
   message1.server_content.model_turn = None
   message1.server_content.interrupted = False
@@ -266,6 +267,7 @@ async def test_receive_transcript_finished_on_interrupt(
 
   message2 = mock.Mock()
   message2.usage_metadata = None
+  message2.setup_complete = None
   message2.server_content = mock.Mock()
   message2.server_content.model_turn = None
   message2.server_content.interrupted = False
@@ -280,6 +282,7 @@ async def test_receive_transcript_finished_on_interrupt(
 
   message3 = mock.Mock()
   message3.usage_metadata = None
+  message3.setup_complete = None
   message3.server_content = mock.Mock()
   message3.server_content.model_turn = None
   message3.server_content.interrupted = True
@@ -326,6 +329,7 @@ async def test_receive_transcript_finished_on_generation_complete(
 
   message1 = mock.Mock()
   message1.usage_metadata = None
+  message1.setup_complete = None
   message1.server_content = mock.Mock()
   message1.server_content.model_turn = None
   message1.server_content.interrupted = False
@@ -340,6 +344,7 @@ async def test_receive_transcript_finished_on_generation_complete(
 
   message2 = mock.Mock()
   message2.usage_metadata = None
+  message2.setup_complete = None
   message2.server_content = mock.Mock()
   message2.server_content.model_turn = None
   message2.server_content.interrupted = False
@@ -354,6 +359,7 @@ async def test_receive_transcript_finished_on_generation_complete(
 
   message3 = mock.Mock()
   message3.usage_metadata = None
+  message3.setup_complete = None
   message3.server_content = mock.Mock()
   message3.server_content.model_turn = None
   message3.server_content.interrupted = False
@@ -399,6 +405,7 @@ async def test_receive_transcript_finished_on_turn_complete(
 
   message1 = mock.Mock()
   message1.usage_metadata = None
+  message1.setup_complete = None
   message1.server_content = mock.Mock()
   message1.server_content.model_turn = None
   message1.server_content.interrupted = False
@@ -413,6 +420,7 @@ async def test_receive_transcript_finished_on_turn_complete(
 
   message2 = mock.Mock()
   message2.usage_metadata = None
+  message2.setup_complete = None
   message2.server_content = mock.Mock()
   message2.server_content.model_turn = None
   message2.server_content.interrupted = False
@@ -427,6 +435,7 @@ async def test_receive_transcript_finished_on_turn_complete(
 
   message3 = mock.Mock()
   message3.usage_metadata = None
+  message3.setup_complete = None
   message3.server_content = mock.Mock()
   message3.server_content.model_turn = None
   message3.server_content.interrupted = False
@@ -471,6 +480,7 @@ async def test_receive_handles_input_transcription_fragments(
   """Test receive handles input transcription fragments correctly."""
   message1 = mock.Mock()
   message1.usage_metadata = None
+  message1.setup_complete = None
   message1.server_content = mock.Mock()
   message1.server_content.model_turn = None
   message1.server_content.interrupted = False
@@ -485,6 +495,7 @@ async def test_receive_handles_input_transcription_fragments(
 
   message2 = mock.Mock()
   message2.usage_metadata = None
+  message2.setup_complete = None
   message2.server_content = mock.Mock()
   message2.server_content.model_turn = None
   message2.server_content.interrupted = False
@@ -499,6 +510,7 @@ async def test_receive_handles_input_transcription_fragments(
 
   message3 = mock.Mock()
   message3.usage_metadata = None
+  message3.setup_complete = None
   message3.server_content = mock.Mock()
   message3.server_content.model_turn = None
   message3.server_content.interrupted = False
@@ -540,6 +552,7 @@ async def test_receive_handles_output_transcription_fragments(
   """Test receive handles output transcription fragments correctly."""
   message1 = mock.Mock()
   message1.usage_metadata = None
+  message1.setup_complete = None
   message1.server_content = mock.Mock()
   message1.server_content.model_turn = None
   message1.server_content.interrupted = False
@@ -554,6 +567,7 @@ async def test_receive_handles_output_transcription_fragments(
 
   message2 = mock.Mock()
   message2.usage_metadata = None
+  message2.setup_complete = None
   message2.server_content = mock.Mock()
   message2.server_content.model_turn = None
   message2.server_content.interrupted = False
@@ -568,6 +582,7 @@ async def test_receive_handles_output_transcription_fragments(
 
   message3 = mock.Mock()
   message3.usage_metadata = None
+  message3.setup_complete = None
   message3.server_content = mock.Mock()
   message3.server_content.model_turn = None
   message3.server_content.interrupted = False
@@ -774,3 +789,27 @@ async def test_send_history_filters_various_audio_mime_types(
 
   # No content should be sent since the only part is audio
   mock_gemini_session.send.assert_not_called()
+
+
+@pytest.mark.asyncio
+async def test_receive_setup_complete(gemini_connection, mock_gemini_session):
+  """Test receive handles setup_complete signal."""
+
+  # Create a mock message simulating BidiGenerateContentSetupComplete
+  message = mock.Mock()
+  message.setup_complete = True
+  message.usage_metadata = None
+  message.server_content = None
+  message.tool_call = None
+  message.session_resumption_update = None
+
+  async def mock_receive_generator():
+    yield message
+
+  receive_mock = mock.Mock(return_value=mock_receive_generator())
+  mock_gemini_session.receive = receive_mock
+
+  responses = [resp async for resp in gemini_connection.receive()]
+
+  assert len(responses) == 1
+  assert responses[0].setup_complete is True
diff --git a/tests/unittests/streaming/test_streaming.py b/tests/unittests/streaming/test_streaming.py
index 8c54502e96..134f0504a7 100644
--- a/tests/unittests/streaming/test_streaming.py
+++ b/tests/unittests/streaming/test_streaming.py
@@ -54,6 +54,41 @@ def test_streaming():
   ), 'Expected at least one response, but got an empty list.'
 
 
+def test_live_streaming_setup_complete():
+  """Test live streaming with setup complete event."""
+  # Create LLM responses: setup complete followed by turn completion
+  response1 = LlmResponse(
+      setup_complete=True,
+  )
+  response2 = LlmResponse(
+      turn_complete=True,
+  )
+
+  mock_model = testing_utils.MockModel.create([response1, response2])
+
+  root_agent = Agent(
+      name='root_agent',
+      model=mock_model,
+      tools=[],
+  )
+
+  runner = testing_utils.InMemoryRunner(root_agent=root_agent)
+  live_request_queue = LiveRequestQueue()
+  res_events = runner.run_live(live_request_queue)
+
+  assert res_events is not None, 'Expected a list of events, got None.'
+  assert len(res_events) >= 1, 'Expected at least one event.'
+
+  # Check that we got a setup complete event
+  setup_complete_found = False
+  for event in res_events:
+    if event.setup_complete:
+      setup_complete_found = True
+      break
+
+  assert setup_complete_found, 'Expected a setup complete event.'
+
+
 def test_live_streaming_function_call_single():
   """Test live streaming with a single function call response."""
   # Create a function call response
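Usage note (not part of the patch): a minimal sketch of how a consumer of the live response stream might gate on the new flag, assuming this change is applied. The helpers `wait_for_setup`, `fake_stream`, and `_demo` are hypothetical names used only for illustration; only `LlmResponse` and its `setup_complete` / `turn_complete` fields come from the diff above.

import asyncio

from google.adk.models.llm_response import LlmResponse


async def wait_for_setup(responses):
  """Drains LlmResponse objects (e.g. from GeminiLlmConnection.receive())
  until the Live session reports that its initial setup is complete."""
  async for response in responses:
    if response.setup_complete:  # New field introduced by this patch.
      return True
  return False


async def _demo():
  # Hypothetical stand-in for a real Live stream: the setup handshake
  # arrives first, followed by a completed turn.
  async def fake_stream():
    yield LlmResponse(setup_complete=True)
    yield LlmResponse(turn_complete=True)

  print('setup complete:', await wait_for_setup(fake_stream()))


if __name__ == '__main__':
  asyncio.run(_demo())

The early yield-and-return added to _postprocess_live follows the same idea: a setup-complete response carries no content, so it is surfaced as its own event instead of falling through to the transcription and content handling.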