From bbd48a0f1e66b20945d998f40cb3be26183e1af1 Mon Sep 17 00:00:00 2001 From: Tongzhou Jiang Date: Mon, 1 Jun 2026 14:18:23 -0700 Subject: [PATCH] fix: Update authorization for streaming_agent_run_with_events PiperOrigin-RevId: 924908903 --- agentplatform/agent_engines/templates/adk.py | 7 ++ .../test_agent_engine_templates_adk.py | 66 +++++++++++++++++++ vertexai/agent_engines/templates/adk.py | 7 ++ .../reasoning_engines/templates/adk.py | 7 ++ 4 files changed, 87 insertions(+) diff --git a/agentplatform/agent_engines/templates/adk.py b/agentplatform/agent_engines/templates/adk.py index c68cc1d45d..e7f6f39814 100644 --- a/agentplatform/agent_engines/templates/adk.py +++ b/agentplatform/agent_engines/templates/adk.py @@ -1332,6 +1332,7 @@ async def streaming_agent_run_with_events(self, request_json: str): self.set_up() # Try to get the session, if it doesn't exist, create a new one. + state_delta = None if request.session_id: session_service = self._tmpl_attrs.get("session_service") artifact_service = self._tmpl_attrs.get("artifact_service") @@ -1349,6 +1350,11 @@ async def streaming_agent_run_with_events(self, request_json: str): artifact_service=artifact_service, request=request, ) + if request.authorizations: + state_delta = {} + for auth_id, auth in request.authorizations.items(): + auth = _Authorization(**auth) + state_delta[auth_id] = auth.access_token except ClientError: pass if not session: @@ -1380,6 +1386,7 @@ async def streaming_agent_run_with_events(self, request_json: str): user_id=request.user_id, session_id=session.id, new_message=message_for_agent, + state_delta=state_delta, ): converted_event = await self._convert_response_events( user_id=request.user_id, diff --git a/tests/unit/vertex_adk/test_agent_engine_templates_adk.py b/tests/unit/vertex_adk/test_agent_engine_templates_adk.py index ca4503a581..2e80315044 100644 --- a/tests/unit/vertex_adk/test_agent_engine_templates_adk.py +++ b/tests/unit/vertex_adk/test_agent_engine_templates_adk.py @@ -507,6 +507,72 @@ async def test_streaming_agent_run_with_events( events.append(event) assert len(events) == 1 + @pytest.mark.asyncio + async def test_streaming_agent_run_with_events_existing_session( + self, + default_instrumentor_builder_mock: mock.Mock, + get_project_id_mock: mock.Mock, + ): + app = agent_engines.AdkApp(agent=_TEST_AGENT) + app.set_up() + + # Pre-create a session in the real in-memory session service + await app.async_create_session( + user_id=_TEST_USER_ID, session_id="test_session_id" + ) + + # Mock the main runner + runner_mock = mock.Mock() + + # Define an async generator for run_async mock return value + async def mock_run_async(*args, **kwargs): + from google.adk.events import event + yield event.Event( + **{ + "author": "currency_exchange_agent", + "content": { + "parts": [{"text": "Sweden"}], + "role": "model", + }, + "id": "9aaItGK9", + "invocation_id": "e-6543c213-6417-484b-9551-b67915d1d5f7", + } + ) + + spy = mock.MagicMock(side_effect=mock_run_async) + runner_mock.run_async = spy + app._tmpl_attrs["runner"] = runner_mock + + request_json = json.dumps( + { + "authorizations": { + "test_user_id1": {"access_token": "test_access_token"}, + }, + "user_id": _TEST_USER_ID, + "session_id": "test_session_id", + "message": { + "parts": [{"text": "What is the exchange rate from USD to SEK?"}], + "role": "user", + }, + } + ) + + events = [] + async for event in app.streaming_agent_run_with_events( + request_json=request_json, + ): + events.append(event) + + assert len(events) == 1 + + # Assert that run_async was called with the expected state_delta! + spy.assert_called_once_with( + user_id=_TEST_USER_ID, + session_id="test_session_id", + new_message=mock.ANY, + state_delta={"test_user_id1": "test_access_token"}, + ) + @pytest.mark.asyncio @mock.patch.dict( os.environ, diff --git a/vertexai/agent_engines/templates/adk.py b/vertexai/agent_engines/templates/adk.py index 16b477c573..43bc996bef 100644 --- a/vertexai/agent_engines/templates/adk.py +++ b/vertexai/agent_engines/templates/adk.py @@ -1332,6 +1332,7 @@ async def streaming_agent_run_with_events(self, request_json: str): self.set_up() # Try to get the session, if it doesn't exist, create a new one. + state_delta = None if request.session_id: session_service = self._tmpl_attrs.get("session_service") artifact_service = self._tmpl_attrs.get("artifact_service") @@ -1349,6 +1350,11 @@ async def streaming_agent_run_with_events(self, request_json: str): artifact_service=artifact_service, request=request, ) + if request.authorizations: + state_delta = {} + for auth_id, auth in request.authorizations.items(): + auth = _Authorization(**auth) + state_delta[auth_id] = auth.access_token except ClientError: pass if not session: @@ -1380,6 +1386,7 @@ async def streaming_agent_run_with_events(self, request_json: str): user_id=request.user_id, session_id=session.id, new_message=message_for_agent, + state_delta=state_delta, ): converted_event = await self._convert_response_events( user_id=request.user_id, diff --git a/vertexai/preview/reasoning_engines/templates/adk.py b/vertexai/preview/reasoning_engines/templates/adk.py index 469ecf71cc..0537a8f6bc 100644 --- a/vertexai/preview/reasoning_engines/templates/adk.py +++ b/vertexai/preview/reasoning_engines/templates/adk.py @@ -1060,6 +1060,7 @@ async def _invoke_agent_async(): ): self.set_up() # Try to get the session, if it doesn't exist, create a new one. + state_delta = None if request.session_id: session_service = self._tmpl_attrs.get("session_service") artifact_service = self._tmpl_attrs.get("artifact_service") @@ -1077,6 +1078,11 @@ async def _invoke_agent_async(): artifact_service=artifact_service, request=request, ) + if request.authorizations: + state_delta = {} + for auth_id, auth in request.authorizations.items(): + auth = _Authorization(**auth) + state_delta[auth_id] = auth.access_token except ClientError: pass if not session: @@ -1107,6 +1113,7 @@ async def _invoke_agent_async(): user_id=request.user_id, session_id=session.id, new_message=message_for_agent, + state_delta=state_delta, ): converted_event = await self._convert_response_events( user_id=request.user_id,