Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 21 additions & 3 deletions src/a2a/server/apps/jsonrpc/fastapi_app.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
import logging

from collections.abc import Awaitable, Callable
from collections.abc import Awaitable, Callable, Sequence
from typing import TYPE_CHECKING, Any


if TYPE_CHECKING:
from fastapi import FastAPI
from fastapi.params import Depends

_package_fastapi_installed = True
else:
try:
from fastapi import FastAPI
from fastapi.params import Depends

_package_fastapi_installed = True
except ImportError:
Expand Down Expand Up @@ -121,6 +123,7 @@ def add_routes_to_app(
agent_card_url: str = AGENT_CARD_WELL_KNOWN_PATH,
rpc_url: str = DEFAULT_RPC_URL,
extended_agent_card_url: str = EXTENDED_AGENT_CARD_PATH,
dependencies: Sequence[Depends] | None = None,
) -> None:
"""Adds the routes to the FastAPI application.

Expand All @@ -129,7 +132,16 @@ def add_routes_to_app(
agent_card_url: The URL for the agent card endpoint.
rpc_url: The URL for the A2A JSON-RPC endpoint.
extended_agent_card_url: The URL for the authenticated extended agent card endpoint.
dependencies: Optional sequence of FastAPI dependencies (e.g.
`[Security(get_current_active_user, scopes=["a2a"])]`)
applied to the RPC endpoint and the authenticated extended
agent card endpoint. The public agent card endpoint is left
unprotected.
"""
route_deps: dict[str, Any] = {}
if dependencies:
route_deps['dependencies'] = list(dependencies)

app.post(
rpc_url,
openapi_extra={
Expand All @@ -145,6 +157,7 @@ def add_routes_to_app(
'description': 'A2ARequest',
}
},
**route_deps,
)(self._handle_requests)
app.get(agent_card_url)(self._handle_get_agent_card)

Expand All @@ -156,7 +169,7 @@ def add_routes_to_app(
)

if self.agent_card.supports_authenticated_extended_card:
app.get(extended_agent_card_url)(
app.get(extended_agent_card_url, **route_deps)(
self._handle_get_authenticated_extended_agent_card
)

Expand All @@ -165,6 +178,7 @@ def build(
agent_card_url: str = AGENT_CARD_WELL_KNOWN_PATH,
rpc_url: str = DEFAULT_RPC_URL,
extended_agent_card_url: str = EXTENDED_AGENT_CARD_PATH,
dependencies: Sequence[Depends] | None = None,
**kwargs: Any,
) -> FastAPI:
"""Builds and returns the FastAPI application instance.
Expand All @@ -173,6 +187,10 @@ def build(
agent_card_url: The URL for the agent card endpoint.
rpc_url: The URL for the A2A JSON-RPC endpoint.
extended_agent_card_url: The URL for the authenticated extended agent card endpoint.
dependencies: Optional sequence of FastAPI dependencies (e.g.
`[Security(get_current_active_user, scopes=["a2a"])]`)
applied to authenticated routes. See
:meth:`add_routes_to_app`.
**kwargs: Additional keyword arguments to pass to the FastAPI constructor.

Returns:
Expand All @@ -181,7 +199,7 @@ def build(
app = A2AFastAPI(**kwargs)

self.add_routes_to_app(
app, agent_card_url, rpc_url, extended_agent_card_url
app, agent_card_url, rpc_url, extended_agent_card_url, dependencies
)

return app
20 changes: 18 additions & 2 deletions src/a2a/server/apps/jsonrpc/starlette_app.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,26 @@
import logging

from collections.abc import Awaitable, Callable
from collections.abc import Awaitable, Callable, Sequence
from typing import TYPE_CHECKING, Any


if TYPE_CHECKING:
from starlette.applications import Starlette
from starlette.middleware import Middleware
from starlette.routing import Route

_package_starlette_installed = True

else:
try:
from starlette.applications import Starlette
from starlette.middleware import Middleware
from starlette.routing import Route

_package_starlette_installed = True
except ImportError:
Starlette = Any
Middleware = Any
Route = Any

_package_starlette_installed = False
Expand Down Expand Up @@ -102,23 +105,30 @@ def routes(
agent_card_url: str = AGENT_CARD_WELL_KNOWN_PATH,
rpc_url: str = DEFAULT_RPC_URL,
extended_agent_card_url: str = EXTENDED_AGENT_CARD_PATH,
middleware: Sequence[Middleware] | None = None,
) -> list[Route]:
"""Returns the Starlette Routes for handling A2A requests.

Args:
agent_card_url: The URL path for the agent card endpoint.
rpc_url: The URL path for the A2A JSON-RPC endpoint (POST requests).
extended_agent_card_url: The URL for the authenticated extended agent card endpoint.
middleware: Optional sequence of Starlette Middleware (e.g.
`[Middleware(AuthenticationMiddleware)]`) applied to the RPC
endpoint and the authenticated extended agent card endpoint.

Returns:
A list of Starlette Route objects.
"""
route_mw = list(middleware) if middleware else None

app_routes = [
Route(
rpc_url,
self._handle_requests,
methods=['POST'],
name='a2a_handler',
middleware=route_mw,
),
Route(
agent_card_url,
Expand Down Expand Up @@ -148,6 +158,7 @@ def routes(
self._handle_get_authenticated_extended_agent_card,
methods=['GET'],
name='authenticated_extended_agent_card',
middleware=route_mw,
)
)
return app_routes
Expand All @@ -158,6 +169,7 @@ def add_routes_to_app(
agent_card_url: str = AGENT_CARD_WELL_KNOWN_PATH,
rpc_url: str = DEFAULT_RPC_URL,
extended_agent_card_url: str = EXTENDED_AGENT_CARD_PATH,
middleware: Sequence[Middleware] | None = None,
) -> None:
"""Adds the routes to the Starlette application.

Expand All @@ -166,11 +178,13 @@ def add_routes_to_app(
agent_card_url: The URL path for the agent card endpoint.
rpc_url: The URL path for the A2A JSON-RPC endpoint (POST requests).
extended_agent_card_url: The URL for the authenticated extended agent card endpoint.
middleware: Optional sequence of Starlette Middleware.
"""
routes = self.routes(
agent_card_url=agent_card_url,
rpc_url=rpc_url,
extended_agent_card_url=extended_agent_card_url,
middleware=middleware,
)
app.routes.extend(routes)

Expand All @@ -179,6 +193,7 @@ def build(
agent_card_url: str = AGENT_CARD_WELL_KNOWN_PATH,
rpc_url: str = DEFAULT_RPC_URL,
extended_agent_card_url: str = EXTENDED_AGENT_CARD_PATH,
middleware: Sequence[Middleware] | None = None,
**kwargs: Any,
) -> Starlette:
"""Builds and returns the Starlette application instance.
Expand All @@ -187,6 +202,7 @@ def build(
agent_card_url: The URL path for the agent card endpoint.
rpc_url: The URL path for the A2A JSON-RPC endpoint (POST requests).
extended_agent_card_url: The URL for the authenticated extended agent card endpoint.
middleware: Optional sequence of Starlette Middleware applied to authenticated routes.
**kwargs: Additional keyword arguments to pass to the Starlette constructor.

Returns:
Expand All @@ -195,7 +211,7 @@ def build(
app = Starlette(**kwargs)

self.add_routes_to_app(
app, agent_card_url, rpc_url, extended_agent_card_url
app, agent_card_url, rpc_url, extended_agent_card_url, middleware
)

return app
21 changes: 19 additions & 2 deletions src/a2a/server/apps/rest/fastapi_app.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,19 @@
import logging

from collections.abc import Awaitable, Callable
from collections.abc import Awaitable, Callable, Sequence
from typing import TYPE_CHECKING, Any


if TYPE_CHECKING:
from fastapi import APIRouter, FastAPI, Request, Response
from fastapi.params import Depends
from fastapi.responses import JSONResponse

_package_fastapi_installed = True
else:
try:
from fastapi import APIRouter, FastAPI, Request, Response
from fastapi.params import Depends
from fastapi.responses import JSONResponse

_package_fastapi_installed = True
Expand Down Expand Up @@ -92,6 +94,8 @@ def build(
self,
agent_card_url: str = AGENT_CARD_WELL_KNOWN_PATH,
rpc_url: str = '',
extended_agent_card_url: str = '',
dependencies: Sequence[Depends] | None = None,
**kwargs: Any,
) -> FastAPI:
"""Builds and returns the FastAPI application instance.
Expand All @@ -100,16 +104,29 @@ def build(
agent_card_url: The URL for the agent card endpoint.
rpc_url: The URL for the A2A JSON-RPC endpoint.
extended_agent_card_url: The URL for the authenticated extended agent card endpoint.
dependencies: Optional sequence of FastAPI dependencies (e.g.
`[Security(get_current_active_user, scopes=["a2a"])]`)
applied to the RPC endpoint and the authenticated extended
agent card endpoint. The public agent card endpoint is left
unprotected.
**kwargs: Additional keyword arguments to pass to the FastAPI constructor.

Returns:
A configured FastAPI application instance.
"""
app = FastAPI(**kwargs)

route_deps: dict[str, Any] = {}
if dependencies:
route_deps['dependencies'] = list(dependencies)

router = APIRouter()
for route, callback in self._adapter.routes().items():
router.add_api_route(
f'{rpc_url}{route[0]}', callback, methods=[route[1]]
f'{rpc_url}{route[0]}',
callback,
methods=[route[1]],
**route_deps,
)

@router.get(f'{rpc_url}{agent_card_url}')
Expand Down
20 changes: 20 additions & 0 deletions tests/server/apps/jsonrpc/test_fastapi_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,26 @@ def test_create_a2a_fastapi_app_with_present_deps_succeeds(
' A2AFastAPIApplication instance should not raise ImportError'
)

def test_build_a2a_fastapi_app_with_dependencies_succeeds(
self, mock_app_params: dict
):
from fastapi import Depends

def mock_dependency():
return 'mock'

app = A2AFastAPIApplication(**mock_app_params)
fastapi_app = app.build(dependencies=[Depends(mock_dependency)])

from fastapi.routing import APIRoute

# Check that routes have the dependency
for route in fastapi_app.routes:
if getattr(route, 'path', '') in ['/v1/message:send', '/v1/card']:
assert isinstance(route, APIRoute)
assert len(route.dependencies) == 1
assert route.dependencies[0].dependency == mock_dependency

def test_create_a2a_fastapi_app_with_missing_deps_raises_importerror(
self,
mock_app_params: dict,
Expand Down
1 change: 1 addition & 0 deletions tests/server/apps/jsonrpc/test_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ def test_starlette_agent_card_with_api_key_scheme_alias(

try:
parsed_card = AgentCard.model_validate(response_data)
assert parsed_card.security_schemes is not None
parsed_scheme_wrapper = parsed_card.security_schemes['api_key_auth']
assert isinstance(parsed_scheme_wrapper.root, APIKeySecurityScheme)
assert parsed_scheme_wrapper.root.in_ == In.header
Expand Down
24 changes: 24 additions & 0 deletions tests/server/apps/jsonrpc/test_starlette_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,30 @@ def test_create_a2a_starlette_app_with_present_deps_succeeds(
' A2AStarletteApplication instance should not raise ImportError'
)

def test_build_a2a_starlette_app_with_middleware_succeeds(
self, mock_app_params: dict
):
from starlette.middleware import Middleware
from starlette.middleware.base import BaseHTTPMiddleware

class MockMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request, call_next):
return await call_next(request)

app = A2AStarletteApplication(**mock_app_params)
starlette_app = app.build(middleware=[Middleware(MockMiddleware)])

from starlette.routing import Route

# Check that routes have the middleware
for route in starlette_app.routes:
if getattr(route, 'path', '') in [
'/',
'/agent/authenticatedExtendedCard',
]:
assert isinstance(route, Route)
assert isinstance(route.app, MockMiddleware)

def test_create_a2a_starlette_app_with_missing_deps_raises_importerror(
self,
mock_app_params: dict,
Expand Down
26 changes: 26 additions & 0 deletions tests/server/apps/rest/test_rest_fastapi_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,32 @@ async def test_create_a2a_rest_fastapi_app_with_present_deps_succeeds(
)


@pytest.mark.anyio
async def test_build_a2a_rest_fastapi_app_with_dependencies_succeeds(
agent_card: AgentCard, request_handler: RequestHandler
):
from fastapi import Depends

def mock_dependency():
return 'mock'

app = A2ARESTFastAPIApplication(agent_card, request_handler)
fastapi_app = app.build(
agent_card_url='/well-known/agent.json',
rpc_url='',
dependencies=[Depends(mock_dependency)],
)

from fastapi.routing import APIRoute

# Check that routes have the dependency
for route in fastapi_app.routes:
if getattr(route, 'path', '') in ['/v1/message:send']:
assert isinstance(route, APIRoute)
assert len(route.dependencies) == 1
assert route.dependencies[0].dependency == mock_dependency


@pytest.mark.anyio
async def test_create_a2a_rest_fastapi_app_with_missing_deps_raises_importerror(
agent_card: AgentCard,
Expand Down
Loading