Skip to content

Commit a8b92bf

Browse files
committed
Address review comments
1 parent a0f8f38 commit a8b92bf

2 files changed

Lines changed: 79 additions & 3 deletions

File tree

src/google/adk/sessions/base_session_service.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -106,8 +106,11 @@ async def append_event(self, session: Session, event: Event) -> Event:
106106
"""Appends an event to a session object."""
107107
if event.partial:
108108
return event
109-
event = self._trim_temp_delta_state(event)
109+
# Update session state with ALL keys (including temp:) so they're accessible
110+
# during callbacks within the same invocation
110111
self._update_session_state(session, event)
112+
# Trim temp: keys from the event before persisting to avoid storing them
113+
event = self._trim_temp_delta_state(event)
111114
session.events.append(event)
112115
return event
113116

@@ -127,5 +130,4 @@ def _update_session_state(self, session: Session, event: Event) -> None:
127130
"""Updates the session state based on the event."""
128131
if not event.actions or not event.actions.state_delta:
129132
return
130-
for key, value in event.actions.state_delta.items():
131-
session.state.update({key: value})
133+
session.state.update(event.actions.state_delta)

tests/unittests/test_runners.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1322,5 +1322,79 @@ async def after_agent_callback(self, *, agent, callback_context):
13221322
assert "temp:test_key" not in event.actions.state_delta
13231323

13241324

1325+
@pytest.mark.asyncio
1326+
async def test_temp_state_from_state_delta_accessible_in_callbacks():
1327+
"""Tests that temp: state set via run_async state_delta parameter is
1328+
accessible during lifecycle callbacks but not persisted."""
1329+
1330+
# Track what state was seen during callbacks
1331+
state_seen_in_before_agent = {}
1332+
1333+
class StateAccessPlugin(BasePlugin):
1334+
"""Plugin that accesses state during callbacks."""
1335+
1336+
async def before_agent_callback(self, *, agent, callback_context):
1337+
# Check if temp state from state_delta is accessible
1338+
state_seen_in_before_agent["temp:from_run_async"] = (
1339+
callback_context.state.get("temp:from_run_async")
1340+
)
1341+
state_seen_in_before_agent["normal:from_run_async"] = (
1342+
callback_context.state.get("normal:from_run_async")
1343+
)
1344+
return None
1345+
1346+
# Setup
1347+
session_service = InMemorySessionService()
1348+
plugin = StateAccessPlugin(name="state_access")
1349+
1350+
agent = MockAgent(name="test_agent")
1351+
runner = Runner(
1352+
app_name=TEST_APP_ID,
1353+
agent=agent,
1354+
session_service=session_service,
1355+
plugins=[plugin],
1356+
auto_create_session=True,
1357+
)
1358+
1359+
# Run the agent with state_delta containing both temp and normal keys
1360+
events = []
1361+
async for event in runner.run_async(
1362+
user_id=TEST_USER_ID,
1363+
session_id=TEST_SESSION_ID,
1364+
new_message=types.Content(
1365+
role="user", parts=[types.Part(text="test message")]
1366+
),
1367+
state_delta={
1368+
"temp:from_run_async": "temp_value",
1369+
"normal:from_run_async": "normal_value",
1370+
},
1371+
):
1372+
events.append(event)
1373+
1374+
# Verify temp state from state_delta WAS accessible during callbacks
1375+
assert (
1376+
state_seen_in_before_agent["temp:from_run_async"] == "temp_value"
1377+
), "temp: state from state_delta should be accessible in callbacks"
1378+
assert state_seen_in_before_agent["normal:from_run_async"] == "normal_value"
1379+
1380+
# Verify temp state is NOT persisted in the session
1381+
session = await session_service.get_session(
1382+
app_name=TEST_APP_ID,
1383+
user_id=TEST_USER_ID,
1384+
session_id=TEST_SESSION_ID,
1385+
)
1386+
1387+
# Normal state should be persisted
1388+
assert session.state.get("normal:from_run_async") == "normal_value"
1389+
1390+
# Temp state should NOT be persisted
1391+
assert "temp:from_run_async" not in session.state
1392+
1393+
# Verify temp state is also not in any event's state_delta
1394+
for event in session.events:
1395+
if event.actions and event.actions.state_delta:
1396+
assert "temp:from_run_async" not in event.actions.state_delta
1397+
1398+
13251399
if __name__ == "__main__":
13261400
pytest.main([__file__])

0 commit comments

Comments
 (0)