Skip to content
47 changes: 38 additions & 9 deletions src/google/adk/models/gemini_llm_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,13 +44,18 @@ def __init__(
gemini_session: live.AsyncSession,
api_backend: GoogleLLMVariant = GoogleLLMVariant.VERTEX_AI,
model_version: str | None = None,
live_config: types.LiveConnectConfig | None = None,
):
self._gemini_session = gemini_session
self._input_transcription_text: str = ''
self._output_transcription_text: str = ''
self._api_backend = api_backend
self._model_version = model_version

self._audio_active = False
if live_config and getattr(live_config, 'response_modalities', None):
self._audio_active = 'AUDIO' in live_config.response_modalities

async def send_history(self, history: list[types.Content]):
"""Sends the conversation history to the gemini model.

Expand Down Expand Up @@ -111,16 +116,23 @@ async def send_content(self, content: types.Content):
),
)
else:
logger.debug('Sending LLM new content %s', content)
is_gemini_31 = model_name_utils.is_gemini_3_1_flash_live(
self._model_version
)
if is_gemini_31 and len(content.parts) == 1 and content.parts[0].text:
logger.debug('Using send_realtime_input for Gemini 3.1 text input')
await self._gemini_session.send_realtime_input(
text=content.parts[0].text
is_gemini_api = self._api_backend == GoogleLLMVariant.GEMINI_API

# Route via send_realtime_input if audio is active OR if targeting 3.1 API
if (self._audio_active or (is_gemini_31 and is_gemini_api)) and all(
isinstance(part.text, str) for part in content.parts
):
logger.debug(
'Routing text via send_realtime_input %s',
content,
)
for part in content.parts:
await self._gemini_session.send_realtime_input(text=part.text)
else:
logger.debug('Sending LLM new content %s', content)
await self._gemini_session.send(
input=types.LiveClientContent(
turns=[content],
Expand All @@ -140,10 +152,18 @@ async def send_realtime(self, input: RealtimeInput):
is_gemini_31 = model_name_utils.is_gemini_3_1_flash_live(
self._model_version
)
if is_gemini_31:
if input.mime_type and input.mime_type.startswith('audio/'):
is_gemini_api = self._api_backend == GoogleLLMVariant.GEMINI_API

# As of now, Gemini 3.1 Flash Live is only available in Gemini API, not
# Vertex AI.
if is_gemini_31 and is_gemini_api:
if isinstance(input.mime_type, str) and input.mime_type.startswith(
'audio/'
):
await self._gemini_session.send_realtime_input(audio=input)
elif input.mime_type and input.mime_type.startswith('image/'):
elif isinstance(input.mime_type, str) and input.mime_type.startswith(
'image/'
):
await self._gemini_session.send_realtime_input(video=input)
else:
logger.warning(
Expand All @@ -152,7 +172,16 @@ async def send_realtime(self, input: RealtimeInput):
input.mime_type,
)
else:
await self._gemini_session.send_realtime_input(media=input)
if isinstance(input.mime_type, str) and input.mime_type.startswith(
'video/'
):
await self._gemini_session.send_realtime_input(video=input)
elif isinstance(input.mime_type, str) and input.mime_type.startswith(
'audio/'
):
await self._gemini_session.send_realtime_input(audio=input)
else:
await self._gemini_session.send_realtime_input(media=input)

elif isinstance(input, types.ActivityStart):
logger.debug('Sending LLM activity start signal.')
Expand Down
1 change: 1 addition & 0 deletions src/google/adk/models/google_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,6 +424,7 @@ async def connect(self, llm_request: LlmRequest) -> BaseLlmConnection:
live_session,
api_backend=self._api_backend,
model_version=llm_request.model,
live_config=llm_request.live_connect_config,
)

async def _adapt_computer_use_tool(self, llm_request: LlmRequest) -> None:
Expand Down
76 changes: 72 additions & 4 deletions tests/unittests/models/test_gemini_llm_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,16 +54,37 @@ def test_blob():
return types.Blob(data=b'\x00\xFF\x00\xFF', mime_type='audio/pcm')


@pytest.fixture
def test_fallback_blob():
"""Test blob for unknown media data."""
return types.Blob(data=b'\x01\x02', mime_type='application/pdf')


@pytest.mark.asyncio
async def test_send_realtime_default_behavior(
async def test_send_realtime_audio_routing(
gemini_connection, mock_gemini_session, test_blob
):
"""Test send_realtime with default automatic_activity_detection value (True)."""
"""Test send_realtime explicitly routing audio mimetypes to the audio parameter."""
await gemini_connection.send_realtime(test_blob)

# Should call send once
mock_gemini_session.send_realtime_input.assert_called_once_with(
media=test_blob
audio=test_blob
)
# Should not call .send function
mock_gemini_session.send.assert_not_called()


@pytest.mark.asyncio
async def test_send_realtime_media_fallback_routing(
gemini_connection, mock_gemini_session, test_fallback_blob
):
"""Test send_realtime falling back to media for non-audio/video mimetypes."""
await gemini_connection.send_realtime(test_fallback_blob)

# Should call send once
mock_gemini_session.send_realtime_input.assert_called_once_with(
media=test_fallback_blob
)
# Should not call .send function
mock_gemini_session.send.assert_not_called()
Expand All @@ -90,7 +111,12 @@ async def test_send_history(gemini_connection, mock_gemini_session):

@pytest.mark.asyncio
async def test_send_content_text(gemini_connection, mock_gemini_session):
"""Test send_content with text content."""
"""Test send_content with text content when audio is inactive.

Note: gemini_connection._audio_active is False by default.
"""
assert gemini_connection._audio_active is False

content = types.Content(
role='user', parts=[types.Part.from_text(text='Hello')]
)
Expand All @@ -104,6 +130,48 @@ async def test_send_content_text(gemini_connection, mock_gemini_session):
assert call_args['input'].turn_complete is True


@pytest.mark.asyncio
async def test_send_content_text_audio_active(
gemini_connection, mock_gemini_session
):
"""Test send_content routes to send_realtime_input when audio is active."""
gemini_connection._audio_active = True

content = types.Content(
role='user', parts=[types.Part.from_text(text='Hello')]
)

await gemini_connection.send_content(content)

mock_gemini_session.send_realtime_input.assert_called_once_with(text='Hello')
mock_gemini_session.send.assert_not_called()


@pytest.mark.asyncio
async def test_send_content_mixed_audio_active(
gemini_connection, mock_gemini_session, test_blob
):
"""Test send_content falls back to LiveClientContent for mixed modalities."""
gemini_connection._audio_active = True

content = types.Content(
role='user',
parts=[
types.Part.from_text(text='Hello'),
types.Part(inline_data=test_blob)
]
)

await gemini_connection.send_content(content)

mock_gemini_session.send.assert_called_once()
call_args = mock_gemini_session.send.call_args[1]
assert 'input' in call_args
assert call_args['input'].turns == [content]
assert call_args['input'].turn_complete is True
mock_gemini_session.send_realtime_input.assert_not_called()


@pytest.mark.asyncio
async def test_send_content_function_response(
gemini_connection, mock_gemini_session
Expand Down
1 change: 1 addition & 0 deletions tests/unittests/models/test_google_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -752,6 +752,7 @@ async def __aexit__(self, *args):
mock_live_session,
api_backend=gemini_llm._api_backend,
model_version=llm_request.model,
live_config=llm_request.live_connect_config,
)


Expand Down