Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/google/adk/models/google_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,6 +453,10 @@ async def connect(self, llm_request: LlmRequest) -> BaseLlmConnection:
' backend. Please use Vertex AI backend.'
)
llm_request.live_connect_config.tools = llm_request.config.tools
if llm_request.config.thinking_config is not None:
llm_request.live_connect_config.thinking_config = (
llm_request.config.thinking_config
)
logger.debug('Connecting to live with llm_request:%s', llm_request)
logger.debug('Live connect config: %s', llm_request.live_connect_config)
async with self._live_api_client.aio.live.connect(
Expand Down
32 changes: 32 additions & 0 deletions tests/unittests/models/test_google_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -852,6 +852,38 @@ async def __aexit__(self, *args):
)


@pytest.mark.asyncio
async def test_connect_copies_thinking_config_to_live_config(
gemini_llm, llm_request
):
"""Test that live connections preserve thinking_config from generate config."""
thinking_config = types.ThinkingConfig(
thinking_budget=10,
include_thoughts=True,
)
llm_request.config.thinking_config = thinking_config
llm_request.live_connect_config = types.LiveConnectConfig()

mock_live_session = mock.AsyncMock()

with mock.patch.object(gemini_llm, "_live_api_client") as mock_live_client:

class MockLiveConnect:

async def __aenter__(self):
return mock_live_session

async def __aexit__(self, *args):
pass

mock_live_client.aio.live.connect.return_value = MockLiveConnect()

async with gemini_llm.connect(llm_request):
mock_live_client.aio.live.connect.assert_called_once()
config_arg = mock_live_client.aio.live.connect.call_args.kwargs["config"]
assert config_arg.thinking_config == thinking_config


@pytest.mark.parametrize(
(
"api_backend, "
Expand Down