diff --git a/agentex/src/api/authentication_cache.py b/agentex/src/api/authentication_cache.py index 51413ded..27838018 100644 --- a/agentex/src/api/authentication_cache.py +++ b/agentex/src/api/authentication_cache.py @@ -5,6 +5,7 @@ import json import time from collections import OrderedDict +from collections.abc import Iterable from typing import Any from src.utils.cache_metrics import record_cache_access, record_cache_eviction @@ -85,6 +86,22 @@ async def clear(self) -> None: async with self._lock: self.cache.clear() + async def delete(self, key: str) -> bool: + """Delete a single cache entry by exact key.""" + async with self._lock: + if key not in self.cache: + return False + del self.cache[key] + return True + + async def delete_by_prefix(self, prefix: str) -> int: + """Delete cache entries whose keys start with the given prefix.""" + async with self._lock: + keys = [key for key in self.cache if key.startswith(prefix)] + for key in keys: + del self.cache[key] + return len(keys) + async def remove_expired(self) -> None: """Remove all expired entries from cache.""" async with self._lock: @@ -241,7 +258,7 @@ async def set_auth_gateway_response( ) -> None: """Cache auth gateway response.""" if self._contains_api_key(principal_context): - logger.debug("Skipping auth gateway cache for API key principal") + logger.debug("Skipping auth gateway cache for API-key auth context") return cache_key = self._create_headers_cache_key(headers) await self.auth_gateway_cache.set(f"gateway:{cache_key}", principal_context) @@ -258,17 +275,20 @@ def _contains_api_key(principal_context: Any) -> bool: # Authorization Check Cache Methods (Async) @staticmethod - def _create_authorization_cache_key( + def _create_authorization_resource_cache_prefix( resource_type: str, resource_selector: str, - operation: str, - principal_context: Any, ) -> str: - """ - Create a cache key for authorization checks. + resource_key = AuthenticationCache._hash_dict( + { + "resource_type": resource_type, + "resource_selector": resource_selector, + } + ) + return f"authz:{resource_key}:" - Combines resource info, operation, and principal context into a unique key. - """ + @staticmethod + def _authorization_principal_key_data(principal_context: Any) -> dict[str, Any]: # Extract relevant fields from principal context for cache key principal_key_data = {} if principal_context: @@ -301,15 +321,48 @@ def _create_authorization_cache_key( "context_hash": AuthenticationCache._hash_dict(context_dict) } - # Create the cache key components - cache_data = { - "resource_type": resource_type, - "resource_selector": resource_selector, - "operation": operation, - "principal": principal_key_data, - } + return principal_key_data + + @staticmethod + def _create_authorization_principal_cache_key(principal_context: Any) -> str: + return AuthenticationCache._hash_dict( + AuthenticationCache._authorization_principal_key_data(principal_context) + ) + + @staticmethod + def _create_authorization_resource_principal_cache_prefix( + resource_type: str, + resource_selector: str, + principal_context: Any, + ) -> str: + return ( + AuthenticationCache._create_authorization_resource_cache_prefix( + resource_type, resource_selector + ) + + AuthenticationCache._create_authorization_principal_cache_key( + principal_context + ) + + ":" + ) + + @staticmethod + def _create_authorization_cache_key( + resource_type: str, + resource_selector: str, + operation: str, + principal_context: Any, + ) -> str: + """ + Create a cache key for authorization checks. - return f"authz:{AuthenticationCache._hash_dict(cache_data)}" + Combines resource info, operation, and principal context into a unique key. + """ + return ( + AuthenticationCache._create_authorization_resource_principal_cache_prefix( + resource_type, resource_selector, principal_context + ) + + AuthenticationCache._hash_dict({"operation": operation}) + ) async def get_authorization_check( self, @@ -319,6 +372,12 @@ async def get_authorization_check( principal_context: Any, ) -> bool | None: """Get cached authorization check result.""" + if self._contains_api_key(principal_context): + logger.debug( + "Skipping authorization check cache lookup for API-key auth context" + ) + return None + cache_key = self._create_authorization_cache_key( resource_type, resource_selector, operation, principal_context ) @@ -339,6 +398,12 @@ async def set_authorization_check( allowed: bool, ) -> None: """Cache authorization check result.""" + if self._contains_api_key(principal_context): + logger.debug( + "Skipping authorization check cache write for API-key auth context" + ) + return + cache_key = self._create_authorization_cache_key( resource_type, resource_selector, operation, principal_context ) @@ -358,6 +423,66 @@ async def clear_all(self) -> None: await self.authorization_check_cache.clear() logger.info("All authentication and authorization caches cleared") + async def clear_authorization_checks_for_resource_principal( + self, + resource_type: str, + resource_selector: str, + principal_context: Any, + ) -> None: + """Clear cached authorization check results for one resource/principal.""" + prefix = self._create_authorization_resource_principal_cache_prefix( + resource_type, resource_selector, principal_context + ) + deleted = await self.authorization_check_cache.delete_by_prefix(prefix) + logger.info( + "Authorization check cache cleared for %s:%s matched_entries=%d", + resource_type, + resource_selector, + deleted, + ) + + async def clear_authorization_checks_for_resource_principal_operations( + self, + resource_type: str, + resource_selector: str, + operations: Iterable[str], + principal_context: Any, + ) -> None: + """Clear cached authorization checks for selected operations.""" + operation_list = [str(operation) for operation in operations] + deleted = 0 + for operation in operation_list: + cache_key = self._create_authorization_cache_key( + resource_type, + resource_selector, + operation, + principal_context, + ) + deleted += int(await self.authorization_check_cache.delete(cache_key)) + + logger.info( + "Authorization check cache cleared for %s:%s operations=%s matched_entries=%d", + resource_type, + resource_selector, + ",".join(operation_list), + deleted, + ) + + async def clear_authorization_check_for_resource_principal_operation( + self, + resource_type: str, + resource_selector: str, + operation: str, + principal_context: Any, + ) -> None: + """Clear one cached authorization check for a resource/principal/operation.""" + await self.clear_authorization_checks_for_resource_principal_operations( + resource_type, + resource_selector, + [operation], + principal_context, + ) + async def cleanup_expired(self) -> None: """Remove expired entries from all caches.""" await self.agent_identity_cache.remove_expired() diff --git a/agentex/src/domain/services/authorization_service.py b/agentex/src/domain/services/authorization_service.py index 936492f2..82009d5d 100644 --- a/agentex/src/domain/services/authorization_service.py +++ b/agentex/src/domain/services/authorization_service.py @@ -39,6 +39,28 @@ def _bypass(self) -> bool: def is_enabled(self) -> bool: return self.enabled + async def _clear_authorization_cache( + self, + resource: AgentexResource, + principal_context, + operation: AuthorizedOperationType | None = None, + ) -> None: + auth_cache = await get_auth_cache() + if operation is None: + await auth_cache.clear_authorization_checks_for_resource_principal( + resource_type=str(resource.type), + resource_selector=resource.selector, + principal_context=principal_context, + ) + return + + await auth_cache.clear_authorization_check_for_resource_principal_operation( + resource_type=str(resource.type), + resource_selector=resource.selector, + operation=str(operation), + principal_context=principal_context, + ) + async def grant( self, resource: AgentexResource, *, commit: bool = True, principal_context=... ) -> None: @@ -54,11 +76,19 @@ async def grant( resource.type, resource.selector, ) - result = await self.gateway.grant( + effective_principal = ( principal_context if principal_context is not ... - else self.principal_context, + else self.principal_context + ) + result = await self.gateway.grant( + effective_principal, + resource, + AuthorizedOperationType.create, + ) + await self._clear_authorization_cache( resource, + effective_principal, AuthorizedOperationType.create, ) return result @@ -77,16 +107,24 @@ async def revoke( resource.selector, ) - result = await self.gateway.revoke( + effective_principal = ( principal_context if principal_context is not ... - else self.principal_context, + else self.principal_context + ) + result = await self.gateway.revoke( + effective_principal, resource, AuthorizedOperationType.delete, ) logger.info( f"Revoked {AuthorizedOperationType.delete} permission on {resource.type}:{resource.selector}" ) + await self._clear_authorization_cache( + resource, + effective_principal, + AuthorizedOperationType.delete, + ) return result async def check( @@ -214,6 +252,7 @@ async def register_resource( f"{parent.type}:{parent.selector}" if parent is not None else None, ) await self.gateway.register_resource(effective_principal, resource, parent) + await self._clear_authorization_cache(resource, effective_principal) async def deregister_resource( self, @@ -237,6 +276,7 @@ async def deregister_resource( resource.selector, ) await self.gateway.deregister_resource(effective_principal, resource) + await self._clear_authorization_cache(resource, effective_principal) DAuthorizationService = Annotated[AuthorizationService, Depends(AuthorizationService)] diff --git a/agentex/tests/unit/api/test_authentication_cache_metrics.py b/agentex/tests/unit/api/test_authentication_cache_metrics.py index 9397fe59..11fd7d93 100644 --- a/agentex/tests/unit/api/test_authentication_cache_metrics.py +++ b/agentex/tests/unit/api/test_authentication_cache_metrics.py @@ -102,3 +102,53 @@ async def test_auth_gateway_response_with_api_key_principal_is_not_cached(): await cache.set_auth_gateway_response(headers, principal) assert await cache.get_auth_gateway_response(headers) is None + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_authorization_check_with_api_key_principal_is_not_cached(): + cache = AuthenticationCache() + principal = {"user_id": "user-1", "account_id": "acct-1", "api_key": "secret-key"} + + await cache.set_authorization_check( + resource_type="agent", + resource_selector="agent-1", + operation="execute", + principal_context=principal, + allowed=True, + ) + + assert ( + await cache.get_authorization_check( + resource_type="agent", + resource_selector="agent-1", + operation="execute", + principal_context=principal, + ) + is None + ) + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_authorization_check_without_api_key_principal_is_cached(): + cache = AuthenticationCache() + principal = {"user_id": "user-1", "account_id": "acct-1"} + + await cache.set_authorization_check( + resource_type="agent", + resource_selector="agent-1", + operation="read", + principal_context=principal, + allowed=True, + ) + + assert ( + await cache.get_authorization_check( + resource_type="agent", + resource_selector="agent-1", + operation="read", + principal_context=principal, + ) + is True + ) diff --git a/agentex/tests/unit/services/test_authorization_service_cache.py b/agentex/tests/unit/services/test_authorization_service_cache.py new file mode 100644 index 00000000..aa79d877 --- /dev/null +++ b/agentex/tests/unit/services/test_authorization_service_cache.py @@ -0,0 +1,180 @@ +from types import SimpleNamespace +from unittest.mock import AsyncMock + +import pytest +from src.api.authentication_cache import reset_auth_cache +from src.api.schemas.authorization_types import AgentexResource, AuthorizedOperationType +from src.domain.services.authorization_service import AuthorizationService + + +def _request_with_principal(principal_context): + return SimpleNamespace( + state=SimpleNamespace( + principal_context=principal_context, + agent_identity=None, + ) + ) + + +def _service(principal_context, gateway): + return AuthorizationService( + enabled=True, + gateway=gateway, + request=_request_with_principal(principal_context), + ) + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_api_key_principal_authorization_check_calls_gateway_each_time(): + await reset_auth_cache() + try: + gateway = AsyncMock() + gateway.check.return_value = True + service = _service( + {"user_id": "user-1", "account_id": "acct-1", "api_key": "secret-key"}, + gateway, + ) + resource = AgentexResource.agent("agent-1") + + assert await service.check(resource, AuthorizedOperationType.execute) is True + assert await service.check(resource, AuthorizedOperationType.execute) is True + + assert gateway.check.await_count == 2 + finally: + await reset_auth_cache() + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_non_api_key_principal_authorization_check_uses_cache(): + await reset_auth_cache() + try: + gateway = AsyncMock() + gateway.check.return_value = True + service = _service({"user_id": "user-1", "account_id": "acct-1"}, gateway) + resource = AgentexResource.agent("agent-1") + + assert await service.check(resource, AuthorizedOperationType.read) is True + assert await service.check(resource, AuthorizedOperationType.read) is True + + assert gateway.check.await_count == 1 + finally: + await reset_auth_cache() + + +@pytest.mark.unit +@pytest.mark.asyncio +@pytest.mark.parametrize( + ("mutation", "operation"), + [ + ("grant", AuthorizedOperationType.create), + ("revoke", AuthorizedOperationType.delete), + ("register_resource", AuthorizedOperationType.read), + ("deregister_resource", AuthorizedOperationType.read), + ], +) +async def test_authorization_mutations_clear_cached_authorization_checks( + mutation, + operation, +): + await reset_auth_cache() + try: + gateway = AsyncMock() + gateway.check.return_value = True + service = _service({"user_id": "user-1", "account_id": "acct-1"}, gateway) + resource = AgentexResource.agent("agent-1") + + assert await service.check(resource, operation) is True + assert await service.check(resource, operation) is True + assert gateway.check.await_count == 1 + + await getattr(service, mutation)(resource) + + assert await service.check(resource, operation) is True + assert gateway.check.await_count == 2 + finally: + await reset_auth_cache() + + +@pytest.mark.unit +@pytest.mark.asyncio +@pytest.mark.parametrize( + ("mutation", "operation"), + [ + ("grant", AuthorizedOperationType.create), + ("revoke", AuthorizedOperationType.delete), + ("register_resource", AuthorizedOperationType.read), + ("deregister_resource", AuthorizedOperationType.read), + ], +) +async def test_authorization_mutations_only_clear_checks_for_mutated_principal( + mutation, + operation, +): + await reset_auth_cache() + try: + gateway = AsyncMock() + gateway.check.return_value = True + service = _service({"user_id": "user-1", "account_id": "acct-1"}, gateway) + unchanged_principal_service = _service( + {"user_id": "user-2", "account_id": "acct-1"}, + gateway, + ) + resource = AgentexResource.agent("agent-1") + + assert await service.check(resource, operation) is True + assert ( + await unchanged_principal_service.check( + resource, operation + ) + is True + ) + assert gateway.check.await_count == 2 + + await getattr(service, mutation)(resource) + + assert await service.check(resource, operation) is True + assert ( + await unchanged_principal_service.check( + resource, operation + ) + is True + ) + assert gateway.check.await_count == 3 + finally: + await reset_auth_cache() + + +@pytest.mark.unit +@pytest.mark.asyncio +@pytest.mark.parametrize( + ("mutation", "changed_operation", "unchanged_operation"), + [ + ("grant", AuthorizedOperationType.create, AuthorizedOperationType.read), + ("revoke", AuthorizedOperationType.delete, AuthorizedOperationType.read), + ], +) +async def test_grant_revoke_only_clear_changed_operation( + mutation, + changed_operation, + unchanged_operation, +): + await reset_auth_cache() + try: + gateway = AsyncMock() + gateway.check.return_value = True + service = _service({"user_id": "user-1", "account_id": "acct-1"}, gateway) + resource = AgentexResource.agent("agent-1") + + assert await service.check(resource, changed_operation) is True + assert await service.check(resource, unchanged_operation) is True + assert gateway.check.await_count == 2 + + await getattr(service, mutation)(resource) + + assert await service.check(resource, changed_operation) is True + assert await service.check(resource, unchanged_operation) is True + assert gateway.check.await_count == 3 + finally: + await reset_auth_cache()