Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion src/a2a/server/request_handlers/jsonrpc_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,7 +373,11 @@ async def list_tasks(
response = await self.request_handler.on_list_tasks(
request, context
)
result = MessageToDict(response, preserving_proto_field_name=False)
result = MessageToDict(
response,
preserving_proto_field_name=False,
always_print_fields_with_no_presence=True,
)
return _build_success_response(request_id, result)
except A2AError as e:
return _build_error_response(request_id, e)
Expand Down
5 changes: 4 additions & 1 deletion src/a2a/server/tasks/database_task_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,18 +36,19 @@
from a2a.compat.v0_3 import types as types_v03
from a2a.server.context import ServerCallContext
from a2a.server.models import Base, TaskModel, create_task_model
from a2a.server.owner_resolver import OwnerResolver, resolve_user_scope
from a2a.server.tasks.task_store import TaskStore
from a2a.types import a2a_pb2
from a2a.types.a2a_pb2 import Task
from a2a.utils.constants import DEFAULT_LIST_TASKS_PAGE_SIZE
from a2a.utils.errors import InvalidParamsError
from a2a.utils.task import decode_page_token, encode_page_token


logger = logging.getLogger(__name__)


class DatabaseTaskStore(TaskStore):

Check notice on line 51 in src/a2a/server/tasks/database_task_store.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

Copy/pasted code

see src/a2a/server/tasks/inmemory_task_store.py (5-17)
"""SQLAlchemy-based implementation of TaskStore.

Stores task objects in a database supported by SQLAlchemy.
Expand Down Expand Up @@ -285,7 +286,9 @@
)
).scalar_one_or_none()
if not start_task:
raise ValueError(f'Invalid page token: {params.page_token}')
raise InvalidParamsError(
f'Invalid page token: {params.page_token}'
)

start_task_timestamp = start_task.last_updated
where_clauses = []
Expand Down
5 changes: 4 additions & 1 deletion src/a2a/server/tasks/inmemory_task_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,19 @@
import logging

from a2a.server.context import ServerCallContext
from a2a.server.owner_resolver import OwnerResolver, resolve_user_scope
from a2a.server.tasks.task_store import TaskStore
from a2a.types import a2a_pb2
from a2a.types.a2a_pb2 import Task
from a2a.utils.constants import DEFAULT_LIST_TASKS_PAGE_SIZE
from a2a.utils.errors import InvalidParamsError
from a2a.utils.task import decode_page_token, encode_page_token


logger = logging.getLogger(__name__)


class InMemoryTaskStore(TaskStore):

Check notice on line 17 in src/a2a/server/tasks/inmemory_task_store.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

Copy/pasted code

see src/a2a/server/tasks/database_task_store.py (39-51)
"""In-memory implementation of TaskStore.

Stores task objects in a nested dictionary in memory, keyed by owner then task_id.
Expand Down Expand Up @@ -135,7 +136,9 @@
valid_token = True
break
if not valid_token:
raise ValueError(f'Invalid page token: {params.page_token}')
raise InvalidParamsError(
f'Invalid page token: {params.page_token}'
)
page_size = params.page_size or DEFAULT_LIST_TASKS_PAGE_SIZE
end_idx = start_idx + page_size
next_page_token = (
Expand Down
5 changes: 3 additions & 2 deletions src/a2a/utils/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,9 +186,10 @@ def decode_page_token(page_token: str) -> str:
missing_padding = len(encoded_str) % 4
if missing_padding:
encoded_str += '=' * (4 - missing_padding)
print(f'input: {encoded_str}')
try:
decoded = b64decode(encoded_str.encode(_ENCODING)).decode(_ENCODING)
except (binascii.Error, UnicodeDecodeError) as e:
raise ValueError('Token is not a valid base64-encoded cursor.') from e
raise InvalidParamsError(
'Token is not a valid base64-encoded cursor.'
) from e
return decoded
25 changes: 25 additions & 0 deletions tests/server/request_handlers/test_jsonrpc_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,31 @@ async def test_on_list_tasks_error(self) -> None:
self.assertTrue(is_error_response(response))
self.assertEqual(response['error']['message'], 'DB down')

async def test_on_list_tasks_empty(self) -> None:
request_handler = AsyncMock(spec=DefaultRequestHandler)
handler = JSONRPCHandler(self.mock_agent_card, request_handler)

mock_result = ListTasksResponse(page_size=10)
request_handler.on_list_tasks.return_value = mock_result
from a2a.types.a2a_pb2 import ListTasksRequest

request = ListTasksRequest(page_size=10)
call_context = ServerCallContext(state={'foo': 'bar'})

response = await handler.list_tasks(request, call_context)

request_handler.on_list_tasks.assert_awaited_once()
self.assertIsInstance(response, dict)
self.assertTrue(is_success_response(response))
self.assertIn('tasks', response['result'])
self.assertEqual(len(response['result']['tasks']), 0)
self.assertIn('nextPageToken', response['result'])
self.assertEqual(response['result']['nextPageToken'], '')
self.assertIn('pageSize', response['result'])
self.assertEqual(response['result']['pageSize'], 10)
self.assertIn('totalSize', response['result'])
self.assertEqual(response['result']['totalSize'], 0)

async def test_on_cancel_task_success(self) -> None:
mock_agent_executor = AsyncMock(spec=AgentExecutor)
mock_task_store = AsyncMock(spec=TaskStore)
Expand Down
3 changes: 2 additions & 1 deletion tests/server/tasks/test_database_task_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from a2a.auth.user import User
from a2a.server.context import ServerCallContext
from a2a.utils.constants import DEFAULT_LIST_TASKS_PAGE_SIZE
from a2a.utils.errors import InvalidParamsError


class SampleUser(User):
Expand Down Expand Up @@ -380,7 +381,7 @@ async def test_list_tasks_fails(
for task in tasks_to_create:
await db_store_parameterized.save(task)

with pytest.raises(ValueError) as excinfo:
with pytest.raises(InvalidParamsError) as excinfo:
await db_store_parameterized.list(params)

assert expected_error_message in str(excinfo.value)
Expand Down
3 changes: 2 additions & 1 deletion tests/server/tasks/test_inmemory_task_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from a2a.server.tasks import InMemoryTaskStore
from a2a.types.a2a_pb2 import Task, TaskState, TaskStatus, ListTasksRequest
from a2a.utils.constants import DEFAULT_LIST_TASKS_PAGE_SIZE
from a2a.utils.errors import InvalidParamsError

from a2a.auth.user import User

Expand Down Expand Up @@ -239,7 +240,7 @@ async def test_list_tasks_fails(
for task in tasks_to_create:
await store.save(task)

with pytest.raises(ValueError) as excinfo:
with pytest.raises(InvalidParamsError) as excinfo:
await store.list(params)

assert expected_error_message in str(excinfo.value)
Expand Down
3 changes: 2 additions & 1 deletion tests/utils/test_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
encode_page_token,
new_task,
)
from a2a.utils.errors import InvalidParamsError


class TestTask(unittest.TestCase):
Expand Down Expand Up @@ -214,7 +215,7 @@ def test_decode_page_token_succeeds(self):
assert decode_page_token(self.encoded_page_token) == self.page_token

def test_decode_page_token_fails(self):
with pytest.raises(ValueError) as excinfo:
with pytest.raises(InvalidParamsError) as excinfo:
decode_page_token('invalid')

assert 'Token is not a valid base64-encoded cursor.' in str(
Expand Down
Loading