Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/google/adk/models/google_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,6 +453,10 @@ async def connect(self, llm_request: LlmRequest) -> BaseLlmConnection:
' backend. Please use Vertex AI backend.'
)
llm_request.live_connect_config.tools = llm_request.config.tools
if llm_request.config.thinking_config is not None:
llm_request.live_connect_config.thinking_config = (
llm_request.config.thinking_config
)
logger.debug('Connecting to live with llm_request:%s', llm_request)
logger.debug('Live connect config: %s', llm_request.live_connect_config)
async with self._live_api_client.aio.live.connect(
Expand Down
33 changes: 31 additions & 2 deletions tests/unittests/models/test_google_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
# limitations under the License.

import logging
import os
import sys
from typing import Optional
from unittest import mock
Expand Down Expand Up @@ -832,7 +831,7 @@ async def __aexit__(self, *args):
with mock.patch(
"google.adk.models.google_llm.GeminiLlmConnection"
) as MockGeminiLlmConnection:
async with gemini_llm.connect(llm_request) as connection:
async with gemini_llm.connect(llm_request):
# Verify that the connect method was called with the right config
mock_live_client.aio.live.connect.assert_called_once()
call_args = mock_live_client.aio.live.connect.call_args
Expand All @@ -852,6 +851,36 @@ async def __aexit__(self, *args):
)


@pytest.mark.asyncio
async def test_connect_forwards_thinking_config(gemini_llm, llm_request):
"""Test that live sessions keep the request thinking_config."""
thinking_config = types.ThinkingConfig(thinking_budget=128)
llm_request.config.thinking_config = thinking_config
llm_request.live_connect_config = types.LiveConnectConfig()

mock_live_session = mock.AsyncMock()

with mock.patch.object(gemini_llm, "_live_api_client") as mock_live_client:

class MockLiveConnect:

async def __aenter__(self):
return mock_live_session

async def __aexit__(self, *args):
pass

mock_live_client.aio.live.connect.return_value = MockLiveConnect()

async with gemini_llm.connect(llm_request) as connection:
mock_live_client.aio.live.connect.assert_called_once()
call_args = mock_live_client.aio.live.connect.call_args
config_arg = call_args.kwargs["config"]

assert config_arg.thinking_config == thinking_config
assert isinstance(connection, GeminiLlmConnection)


@pytest.mark.parametrize(
(
"api_backend, "
Expand Down