Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion src/google/adk/sessions/vertex_ai_session_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,11 @@
_USAGE_METADATA_CUSTOM_METADATA_KEY = '_usage_metadata'


def _quote_filter_literal(value: str) -> str:
"""Quotes filter values so embedded metacharacters stay inside the literal."""
return json.dumps(value)


def _set_internal_custom_metadata(
metadata_dict: dict[str, Any], *, key: str, value: dict[str, Any]
) -> None:
Expand Down Expand Up @@ -228,7 +233,7 @@ async def list_sessions(
sessions = []
config = {}
if user_id is not None:
config['filter'] = f'user_id="{user_id}"'
config['filter'] = f'user_id={_quote_filter_literal(user_id)}'
sessions_iterator = await api_client.agent_engines.sessions.list(
name=f'reasoningEngines/{reasoning_engine_id}',
config=config,
Expand Down
17 changes: 17 additions & 0 deletions tests/unittests/sessions/test_vertex_ai_session_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.
import copy
import datetime
import json
import re
import types
from typing import Any
Expand Down Expand Up @@ -374,6 +375,7 @@ def __init__(self) -> None:
self.agent_engines.sessions.events.list.side_effect = self._list_events
self.agent_engines.sessions.events.append.side_effect = self._append_event
self.last_create_session_config: dict[str, Any] = {}
self.last_list_sessions_config: dict[str, Any] = {}

async def __aenter__(self):
"""Enters the asynchronous context."""
Expand All @@ -390,6 +392,7 @@ async def _get_session(self, name: str):
raise api_core_exceptions.NotFound(f'Session not found: {session_id}')

async def _list_sessions(self, name: str, config: dict[str, Any]):
self.last_list_sessions_config = config
filter_val = config.get('filter', '')
user_id_match = re.search(r'user_id="([^"]+)"', filter_val)
if user_id_match:
Expand Down Expand Up @@ -876,6 +879,20 @@ async def test_list_sessions_all_users():
}


@pytest.mark.asyncio
@pytest.mark.usefixtures('mock_get_api_client')
async def test_list_sessions_quotes_user_id_filter(mock_api_client_instance):
session_service = mock_vertex_ai_session_service()
payload = 'attacker" OR user_id!=""'

sessions = await session_service.list_sessions(app_name='123', user_id=payload)

assert sessions.sessions == []
assert mock_api_client_instance.last_list_sessions_config == {
'filter': f'user_id={json.dumps(payload)}'
}


@pytest.mark.asyncio
@pytest.mark.usefixtures('mock_get_api_client')
async def test_create_session():
Expand Down