diff --git a/src/google/adk/auth/auth_credential.py b/src/google/adk/auth/auth_credential.py index 6e4f73351f..65592239b6 100644 --- a/src/google/adk/auth/auth_credential.py +++ b/src/google/adk/auth/auth_credential.py @@ -25,6 +25,7 @@ from pydantic import BaseModel from pydantic import ConfigDict from pydantic import Field +from pydantic import model_validator class BaseModelWithConfig(BaseModel): @@ -149,6 +150,21 @@ class ServiceAccount(BaseModelWithConfig): service_account_credential: Optional[ServiceAccountCredential] = None scopes: List[str] use_default_credential: Optional[bool] = False + token_kind: Literal["access_token", "id_token"] = "access_token" + audience: Optional[str] = None + + @model_validator(mode="before") + @classmethod + def _validate_before(cls, data: Any) -> Any: + if isinstance(data, dict): + token_kind = data.get("token_kind", "access_token") + audience = data.get("audience") + if token_kind == "id_token" and not audience: + raise ValueError( + "service_account.audience is required when" + " service_account.token_kind='id_token'" + ) + return data class AuthCredentialTypes(str, Enum): diff --git a/src/google/adk/tools/openapi_tool/auth/credential_exchangers/service_account_exchanger.py b/src/google/adk/tools/openapi_tool/auth/credential_exchangers/service_account_exchanger.py index 1dbe0fe46a..0ef9b2aeb9 100644 --- a/src/google/adk/tools/openapi_tool/auth/credential_exchangers/service_account_exchanger.py +++ b/src/google/adk/tools/openapi_tool/auth/credential_exchangers/service_account_exchanger.py @@ -20,6 +20,7 @@ import google.auth from google.auth.transport.requests import Request +from google.oauth2 import id_token as google_id_token from google.oauth2 import service_account import google.oauth2.credentials @@ -73,27 +74,50 @@ def exchange_credential( ) try: - if auth_credential.service_account.use_default_credential: - credentials, project_id = google.auth.default( - scopes=["https://www.googleapis.com/auth/cloud-platform"], - ) - quota_project_id = ( - getattr(credentials, "quota_project_id", None) or project_id - ) + config = auth_credential.service_account + token_kind = getattr(config, "token_kind", "access_token") + request = Request() + + quota_project_id = None + token = None + + if token_kind == "id_token": + audience = getattr(config, "audience", None) + if config.use_default_credential: + token = google_id_token.fetch_id_token(request, audience) + else: + id_creds = ( + service_account.IDTokenCredentials.from_service_account_info( + config.service_account_credential.model_dump(), + target_audience=audience, + ) + ) + id_creds.refresh(request) + token = id_creds.token else: - config = auth_credential.service_account - credentials = service_account.Credentials.from_service_account_info( - config.service_account_credential.model_dump(), scopes=config.scopes - ) - quota_project_id = None - - credentials.refresh(Request()) + if auth_credential.service_account.use_default_credential: + credentials, project_id = google.auth.default( + scopes=["https://www.googleapis.com/auth/cloud-platform"], + ) + quota_project_id = ( + getattr(credentials, "quota_project_id", None) or project_id + ) + else: + config = auth_credential.service_account + credentials = service_account.Credentials.from_service_account_info( + config.service_account_credential.model_dump(), + scopes=config.scopes, + ) + quota_project_id = None + + credentials.refresh(Request()) + token = credentials.token updated_credential = AuthCredential( auth_type=AuthCredentialTypes.HTTP, # Store as a bearer token http=HttpAuth( scheme="bearer", - credentials=HttpCredentials(token=credentials.token), + credentials=HttpCredentials(token=token), additional_headers={ "x-goog-user-project": quota_project_id, } diff --git a/tests/unittests/tools/openapi_tool/auth/credential_exchangers/test_service_account_exchanger.py b/tests/unittests/tools/openapi_tool/auth/credential_exchangers/test_service_account_exchanger.py index 0ca9944423..5e606e2183 100644 --- a/tests/unittests/tools/openapi_tool/auth/credential_exchangers/test_service_account_exchanger.py +++ b/tests/unittests/tools/openapi_tool/auth/credential_exchangers/test_service_account_exchanger.py @@ -25,6 +25,7 @@ from google.adk.tools.openapi_tool.auth.credential_exchangers.base_credential_exchanger import AuthCredentialMissingError from google.adk.tools.openapi_tool.auth.credential_exchangers.service_account_exchanger import ServiceAccountCredentialExchanger import google.auth +import google.oauth2.id_token import pytest @@ -218,3 +219,123 @@ def test_exchange_credential_exchange_failure( service_account_exchanger.exchange_credential(auth_scheme, auth_credential) assert "Failed to exchange service account token" in str(exc_info.value) mock_from_service_account_info.assert_called_once() + + +def test_exchange_credential_use_default_credential_id_token_success( + service_account_exchanger, auth_scheme, monkeypatch +): + """Test successful exchange using ADC with an ID token (OIDC) for a target audience.""" + mock_google_auth_default = MagicMock() + monkeypatch.setattr(google.auth, "default", mock_google_auth_default) + + mock_fetch_id_token = MagicMock(return_value="mock_id_token") + monkeypatch.setattr( + google.oauth2.id_token, + "fetch_id_token", + mock_fetch_id_token, + ) + monkeypatch.setattr( + "google.adk.tools.openapi_tool.auth.credential_exchangers.service_account_exchanger.google_id_token.fetch_id_token", + mock_fetch_id_token, + ) + + auth_credential = AuthCredential( + auth_type=AuthCredentialTypes.SERVICE_ACCOUNT, + service_account=ServiceAccount( + use_default_credential=True, + scopes=[ + "https://www.googleapis.com/auth/cloud-platform" + ], # unused in id_token mode, but required by model today + token_kind="id_token", + audience="https://my-service-abc.a.run.app", + ), + ) + + result = service_account_exchanger.exchange_credential( + auth_scheme, auth_credential + ) + + assert result.auth_type == AuthCredentialTypes.HTTP + assert result.http.scheme == "bearer" + assert result.http.credentials.token == "mock_id_token" + assert not result.http.additional_headers + + mock_fetch_id_token.assert_called_once() + # Can we test this? + # mock_fetch_id_token.assert_called_once_with(ANY_REQUEST_OBJECT, "https://my-service-abc.a.run.app") + mock_google_auth_default.assert_not_called() + + +def test_exchange_credential_service_account_id_token_success( + service_account_exchanger, auth_scheme, monkeypatch +): + """Test successful exchange using SA JSON key with an ID token (OIDC) for a target audience.""" + mock_id_creds = MagicMock() + mock_id_creds.token = "mock_id_token" + mock_id_creds.refresh = MagicMock() + + mock_from_info = MagicMock(return_value=mock_id_creds) + + # Patch IDTokenCredentials factory (NOT Credentials.from_service_account_info) + target_path = ( + "google.adk.tools.openapi_tool.auth.credential_exchangers." + "service_account_exchanger.service_account.IDTokenCredentials." + "from_service_account_info" + ) + monkeypatch.setattr(target_path, mock_from_info) + + auth_credential = AuthCredential( + auth_type=AuthCredentialTypes.SERVICE_ACCOUNT, + service_account=ServiceAccount( + service_account_credential=ServiceAccountCredential( + type_="service_account", + project_id="your_project_id", + private_key_id="your_private_key_id", + private_key="-----BEGIN PRIVATE KEY-----...", + client_email="...@....iam.gserviceaccount.com", + client_id="your_client_id", + auth_uri="https://accounts.google.com/o/oauth2/auth", + token_uri="https://oauth2.googleapis.com/token", + auth_provider_x509_cert_url=( + "https://www.googleapis.com/oauth2/v1/certs" + ), + client_x509_cert_url=( + "https://www.googleapis.com/robot/v1/metadata/x509/..." + ), + universe_domain="googleapis.com", + ), + scopes=[ + "https://www.googleapis.com/auth/cloud-platform" + ], # unused in id_token mode but required today + token_kind="id_token", + audience="https://my-service-abc.a.run.app", + ), + ) + + result = service_account_exchanger.exchange_credential( + auth_scheme, auth_credential + ) + + assert result.auth_type == AuthCredentialTypes.HTTP + assert result.http.scheme == "bearer" + assert result.http.credentials.token == "mock_id_token" + assert not result.http.additional_headers + + # Verify we used the IDTokenCredentials path with the correct target_audience + mock_from_info.assert_called_once() + _, kwargs = mock_from_info.call_args + assert kwargs["target_audience"] == "https://my-service-abc.a.run.app" + + mock_id_creds.refresh.assert_called_once() + + +def test_service_account_id_token_requires_audience(): + """ServiceAccount validation: id_token requires audience.""" + with pytest.raises(ValueError) as exc_info: + ServiceAccount( + use_default_credential=True, + scopes=["https://www.googleapis.com/auth/cloud-platform"], + token_kind="id_token", + audience=None, + ) + assert "audience" in str(exc_info.value)