diff --git a/src/google/adk/sessions/vertex_ai_session_service.py b/src/google/adk/sessions/vertex_ai_session_service.py index c653456a25..4bdaba398a 100644 --- a/src/google/adk/sessions/vertex_ai_session_service.py +++ b/src/google/adk/sessions/vertex_ai_session_service.py @@ -48,6 +48,12 @@ _USAGE_METADATA_CUSTOM_METADATA_KEY = '_usage_metadata' +def _quote_filter_literal(value: str) -> str: + """Quotes filter values so embedded metacharacters stay inside the literal.""" + escaped_value = value.replace('\\', '\\\\').replace('"', '\\"') + return f'"{escaped_value}"' + + def _set_internal_custom_metadata( metadata_dict: dict[str, Any], *, key: str, value: dict[str, Any] ) -> None: @@ -228,7 +234,7 @@ async def list_sessions( sessions = [] config = {} if user_id is not None: - config['filter'] = f'user_id="{user_id}"' + config['filter'] = f'user_id={_quote_filter_literal(user_id)}' sessions_iterator = await api_client.agent_engines.sessions.list( name=f'reasoningEngines/{reasoning_engine_id}', config=config, diff --git a/tests/unittests/sessions/test_vertex_ai_session_service.py b/tests/unittests/sessions/test_vertex_ai_session_service.py index 80a7c9c537..5f77f46a50 100644 --- a/tests/unittests/sessions/test_vertex_ai_session_service.py +++ b/tests/unittests/sessions/test_vertex_ai_session_service.py @@ -374,6 +374,7 @@ def __init__(self) -> None: self.agent_engines.sessions.events.list.side_effect = self._list_events self.agent_engines.sessions.events.append.side_effect = self._append_event self.last_create_session_config: dict[str, Any] = {} + self.last_list_sessions_config: dict[str, Any] = {} async def __aenter__(self): """Enters the asynchronous context.""" @@ -390,8 +391,9 @@ async def _get_session(self, name: str): raise api_core_exceptions.NotFound(f'Session not found: {session_id}') async def _list_sessions(self, name: str, config: dict[str, Any]): + self.last_list_sessions_config = config filter_val = config.get('filter', '') - user_id_match = re.search(r'user_id="([^"]+)"', filter_val) + user_id_match = re.search(r'user_id="((?:\\.|[^"])*)"', filter_val) if user_id_match: user_id = user_id_match.group(1) if user_id == 'user_with_pages': @@ -876,6 +878,29 @@ async def test_list_sessions_all_users(): } +@pytest.mark.asyncio +@pytest.mark.usefixtures('mock_get_api_client') +@pytest.mark.parametrize( + ('payload', 'expected_filter'), + [ + ('attacker" OR user_id!=""', 'user_id="attacker\\" OR user_id!=\\"\\""'), + ('\\', 'user_id="\\\\"'), + ('', 'user_id=""'), + ], +) +async def test_list_sessions_quotes_user_id_filter( + mock_api_client_instance, payload, expected_filter +): + session_service = mock_vertex_ai_session_service() + + sessions = await session_service.list_sessions(app_name='123', user_id=payload) + + assert sessions.sessions == [] + assert mock_api_client_instance.last_list_sessions_config == { + 'filter': expected_filter + } + + @pytest.mark.asyncio @pytest.mark.usefixtures('mock_get_api_client') async def test_create_session():