diff --git a/src/google/adk/flows/llm_flows/base_llm_flow.py b/src/google/adk/flows/llm_flows/base_llm_flow.py index 8ae256f44c..b6b61fffe2 100644 --- a/src/google/adk/flows/llm_flows/base_llm_flow.py +++ b/src/google/adk/flows/llm_flows/base_llm_flow.py @@ -50,8 +50,8 @@ from ...telemetry.tracing import tracer from ...tools.base_toolset import BaseToolset from ...tools.tool_context import ToolContext -from ...utils.context_utils import Aclosing from ...utils import model_name_utils +from ...utils.context_utils import Aclosing from .audio_cache_manager import AudioCacheManager from .functions import build_auth_request_event diff --git a/src/google/adk/flows/llm_flows/basic.py b/src/google/adk/flows/llm_flows/basic.py index da76cba189..d95c3013e1 100644 --- a/src/google/adk/flows/llm_flows/basic.py +++ b/src/google/adk/flows/llm_flows/basic.py @@ -83,10 +83,15 @@ def _build_basic_request( llm_request.live_connect_config.realtime_input_config = ( invocation_context.run_config.realtime_input_config ) - active_model_name = getattr(getattr(agent, 'canonical_live_model', None), 'model', None) or llm_request.model + active_model_name = ( + getattr(getattr(agent, 'canonical_live_model', None), 'model', None) + or llm_request.model + ) is_gemini_31 = model_name_utils.is_gemini_3_1_flash_live(active_model_name) llm_request.live_connect_config.enable_affective_dialog = ( - None if is_gemini_31 else invocation_context.run_config.enable_affective_dialog + None + if is_gemini_31 + else invocation_context.run_config.enable_affective_dialog ) llm_request.live_connect_config.proactivity = ( None if is_gemini_31 else invocation_context.run_config.proactivity diff --git a/src/google/adk/models/gemini_llm_connection.py b/src/google/adk/models/gemini_llm_connection.py index cf74a5b9d4..bdddbe7068 100644 --- a/src/google/adk/models/gemini_llm_connection.py +++ b/src/google/adk/models/gemini_llm_connection.py @@ -88,11 +88,15 @@ async def send_history(self, history: list[types.Content]): # protocol error (invalid role mid-session), we consolidate previous multi-turn # interactions into a unified contextual preamble on a single user role turn. if is_gemini_31 and self._api_backend != GoogleLLMVariant.GEMINI_API: - collapsed_text = "Previous conversation history:\n" + collapsed_text = 'Previous conversation history:\n' for c in contents: - text_parts = "".join(p.text for p in c.parts if p.text) + text_parts = ''.join(p.text for p in c.parts if p.text) collapsed_text += f'[{c.role}]: {text_parts}\n' - contents = [types.Content(role='user', parts=[types.Part.from_text(text=collapsed_text)])] + contents = [ + types.Content( + role='user', parts=[types.Part.from_text(text=collapsed_text)] + ) + ] logger.debug('Sending history to live connection: %s', contents) await self._gemini_session.send_client_content( @@ -281,7 +285,11 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]: is_thought = current_is_thought llm_response.partial = True # don't yield the merged text event when receiving audio data - if text and not any(p.text for p in content.parts) and not has_inline_data: + if ( + text + and not any(p.text for p in content.parts) + and not has_inline_data + ): yield self.__build_full_text_response(text, is_thought) text = '' is_thought = False diff --git a/tests/unittests/flows/llm_flows/test_base_llm_flow.py b/tests/unittests/flows/llm_flows/test_base_llm_flow.py index 6563e2db8b..9d58716687 100644 --- a/tests/unittests/flows/llm_flows/test_base_llm_flow.py +++ b/tests/unittests/flows/llm_flows/test_base_llm_flow.py @@ -24,7 +24,8 @@ from google.adk.events.event import Event from google.adk.flows.llm_flows.base_llm_flow import _handle_after_model_callback from google.adk.flows.llm_flows.base_llm_flow import BaseLlmFlow -from google.adk.models.google_llm import Gemini, GoogleLLMVariant +from google.adk.models.google_llm import Gemini +from google.adk.models.google_llm import GoogleLLMVariant from google.adk.models.llm_request import LlmRequest from google.adk.models.llm_response import LlmResponse from google.adk.plugins.base_plugin import BasePlugin @@ -1390,7 +1391,7 @@ async def mock_receive_2(): @pytest.mark.asyncio @pytest.mark.parametrize( - "api_backend", + 'api_backend', [ GoogleLLMVariant.GEMINI_API, GoogleLLMVariant.VERTEX_AI, @@ -1422,8 +1423,11 @@ async def mock_receive(): flow = BaseLlmFlowForTesting() with mock.patch.object(flow, '_send_to_model', new_callable=AsyncMock): + async def mock_preprocess(ctx, req): - req.contents = [types.Content(parts=[types.Part.from_text(text='history')])] + req.contents = [ + types.Content(parts=[types.Part.from_text(text='history')]) + ] yield Event(id=Event.new_id(), author='test') with mock.patch.object( @@ -1467,7 +1471,9 @@ async def test_run_live_respects_explicit_initial_history_in_client_content_fals ) invocation_context.live_request_queue = LiveRequestQueue() run_config = RunConfig( - history_config=types.HistoryConfig(initial_history_in_client_content=False) + history_config=types.HistoryConfig( + initial_history_in_client_content=False + ) ) invocation_context.run_config = run_config @@ -1476,6 +1482,7 @@ async def test_run_live_respects_explicit_initial_history_in_client_content_fals async def mock_preprocess(ctx, req): req.contents = [types.Content(parts=[types.Part.from_text(text='history')])] from google.adk.flows.llm_flows.basic import _build_basic_request + _build_basic_request(ctx, req) yield Event(id=Event.new_id(), author='test') @@ -1509,5 +1516,7 @@ async def mock_receive(): assert mock_connect.call_count == 1 call_req = mock_connect.call_args[0][0] assert call_req.live_connect_config.history_config is not None - assert call_req.live_connect_config.history_config.initial_history_in_client_content is False - + assert ( + call_req.live_connect_config.history_config.initial_history_in_client_content + is False + ) diff --git a/tests/unittests/models/test_gemini_llm_connection.py b/tests/unittests/models/test_gemini_llm_connection.py index fae561331d..95ae692dab 100644 --- a/tests/unittests/models/test_gemini_llm_connection.py +++ b/tests/unittests/models/test_gemini_llm_connection.py @@ -1543,7 +1543,9 @@ async def mock_receive_generator(): @pytest.mark.asyncio -async def test_receive_multiplexed_parts(gemini_connection, mock_gemini_session): +async def test_receive_multiplexed_parts( + gemini_connection, mock_gemini_session +): """Test receive with multiplexed inline data and text content.""" mock_content = types.Content( role='model', @@ -1588,6 +1590,7 @@ async def mock_receive_generator(): async def test_send_history_gemini_31_turn_complete(mock_gemini_session): """Verify Gemini 3.1 Live history seeding explicitly appends turn_complete=True.""" from google.adk.models.google_llm import GoogleLLMVariant + conn = GeminiLlmConnection( mock_gemini_session, api_backend=GoogleLLMVariant.GEMINI_API, @@ -1611,6 +1614,7 @@ async def test_send_history_gemini_31_turn_complete(mock_gemini_session): async def test_send_history_collapse_vertex_ai(mock_gemini_session): """Verify history prompt collapse when seeding Gemini 3.1 Live on Vertex AI backend.""" from google.adk.models.google_llm import GoogleLLMVariant + conn = GeminiLlmConnection( mock_gemini_session, api_backend=GoogleLLMVariant.VERTEX_AI, @@ -1625,10 +1629,15 @@ async def test_send_history_collapse_vertex_ai(mock_gemini_session): await conn.send_history(mock_contents) assert mock_gemini_session.send_client_content.call_count == 1 - called_turns = mock_gemini_session.send_client_content.call_args.kwargs['turns'] + called_turns = mock_gemini_session.send_client_content.call_args.kwargs[ + 'turns' + ] assert len(called_turns) == 1 assert called_turns[0].role == 'user' assert 'Previous conversation history:' in called_turns[0].parts[0].text assert '[user]: hi' in called_turns[0].parts[0].text assert '[model]: hello' in called_turns[0].parts[0].text - assert mock_gemini_session.send_client_content.call_args.kwargs['turn_complete'] is True + assert ( + mock_gemini_session.send_client_content.call_args.kwargs['turn_complete'] + is True + )