Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions src/google/adk/auth/auth_credential.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from pydantic import BaseModel
from pydantic import ConfigDict
from pydantic import Field
from pydantic import model_validator


class BaseModelWithConfig(BaseModel):
Expand Down Expand Up @@ -149,6 +150,21 @@ class ServiceAccount(BaseModelWithConfig):
service_account_credential: Optional[ServiceAccountCredential] = None
scopes: List[str]
use_default_credential: Optional[bool] = False
token_kind: Literal["access_token", "id_token"] = "access_token"
audience: Optional[str] = None

@model_validator(mode="before")
@classmethod
def _validate_before(cls, data: Any) -> Any:
if isinstance(data, dict):
token_kind = data.get("token_kind", "access_token")
audience = data.get("audience")
if token_kind == "id_token" and not audience:
raise ValueError(
"service_account.audience is required when"
" service_account.token_kind='id_token'"
)
return data


class AuthCredentialTypes(str, Enum):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import google.auth
from google.auth.transport.requests import Request
from google.oauth2 import id_token as google_id_token
from google.oauth2 import service_account
import google.oauth2.credentials

Expand Down Expand Up @@ -73,27 +74,50 @@ def exchange_credential(
)

try:
if auth_credential.service_account.use_default_credential:
credentials, project_id = google.auth.default(
scopes=["https://www.googleapis.com/auth/cloud-platform"],
)
quota_project_id = (
getattr(credentials, "quota_project_id", None) or project_id
)
config = auth_credential.service_account
token_kind = getattr(config, "token_kind", "access_token")
request = Request()

quota_project_id = None
token = None

if token_kind == "id_token":
audience = getattr(config, "audience", None)
if config.use_default_credential:
token = google_id_token.fetch_id_token(request, audience)
else:
id_creds = (
service_account.IDTokenCredentials.from_service_account_info(
config.service_account_credential.model_dump(),
target_audience=audience,
)
)
id_creds.refresh(request)
token = id_creds.token
else:
config = auth_credential.service_account
credentials = service_account.Credentials.from_service_account_info(
config.service_account_credential.model_dump(), scopes=config.scopes
)
quota_project_id = None

credentials.refresh(Request())
if auth_credential.service_account.use_default_credential:
credentials, project_id = google.auth.default(
scopes=["https://www.googleapis.com/auth/cloud-platform"],
)
quota_project_id = (
getattr(credentials, "quota_project_id", None) or project_id
)
else:
config = auth_credential.service_account
credentials = service_account.Credentials.from_service_account_info(
config.service_account_credential.model_dump(),
scopes=config.scopes,
)
quota_project_id = None

credentials.refresh(Request())
token = credentials.token

updated_credential = AuthCredential(
auth_type=AuthCredentialTypes.HTTP, # Store as a bearer token
http=HttpAuth(
scheme="bearer",
credentials=HttpCredentials(token=credentials.token),
credentials=HttpCredentials(token=token),
additional_headers={
"x-goog-user-project": quota_project_id,
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from google.adk.tools.openapi_tool.auth.credential_exchangers.base_credential_exchanger import AuthCredentialMissingError
from google.adk.tools.openapi_tool.auth.credential_exchangers.service_account_exchanger import ServiceAccountCredentialExchanger
import google.auth
import google.oauth2.id_token
import pytest


Expand Down Expand Up @@ -218,3 +219,123 @@ def test_exchange_credential_exchange_failure(
service_account_exchanger.exchange_credential(auth_scheme, auth_credential)
assert "Failed to exchange service account token" in str(exc_info.value)
mock_from_service_account_info.assert_called_once()


def test_exchange_credential_use_default_credential_id_token_success(
service_account_exchanger, auth_scheme, monkeypatch
):
"""Test successful exchange using ADC with an ID token (OIDC) for a target audience."""
mock_google_auth_default = MagicMock()
monkeypatch.setattr(google.auth, "default", mock_google_auth_default)

mock_fetch_id_token = MagicMock(return_value="mock_id_token")
monkeypatch.setattr(
google.oauth2.id_token,
"fetch_id_token",
mock_fetch_id_token,
)
monkeypatch.setattr(
"google.adk.tools.openapi_tool.auth.credential_exchangers.service_account_exchanger.google_id_token.fetch_id_token",
mock_fetch_id_token,
)

auth_credential = AuthCredential(
auth_type=AuthCredentialTypes.SERVICE_ACCOUNT,
service_account=ServiceAccount(
use_default_credential=True,
scopes=[
"https://www.googleapis.com/auth/cloud-platform"
], # unused in id_token mode, but required by model today
token_kind="id_token",
audience="https://my-service-abc.a.run.app",
),
)

result = service_account_exchanger.exchange_credential(
auth_scheme, auth_credential
)

assert result.auth_type == AuthCredentialTypes.HTTP
assert result.http.scheme == "bearer"
assert result.http.credentials.token == "mock_id_token"
assert not result.http.additional_headers

mock_fetch_id_token.assert_called_once()
# Can we test this?
# mock_fetch_id_token.assert_called_once_with(ANY_REQUEST_OBJECT, "https://my-service-abc.a.run.app")
mock_google_auth_default.assert_not_called()


def test_exchange_credential_service_account_id_token_success(
service_account_exchanger, auth_scheme, monkeypatch
):
"""Test successful exchange using SA JSON key with an ID token (OIDC) for a target audience."""
mock_id_creds = MagicMock()
mock_id_creds.token = "mock_id_token"
mock_id_creds.refresh = MagicMock()

mock_from_info = MagicMock(return_value=mock_id_creds)

# Patch IDTokenCredentials factory (NOT Credentials.from_service_account_info)
target_path = (
"google.adk.tools.openapi_tool.auth.credential_exchangers."
"service_account_exchanger.service_account.IDTokenCredentials."
"from_service_account_info"
)
monkeypatch.setattr(target_path, mock_from_info)

auth_credential = AuthCredential(
auth_type=AuthCredentialTypes.SERVICE_ACCOUNT,
service_account=ServiceAccount(
service_account_credential=ServiceAccountCredential(
type_="service_account",
project_id="your_project_id",
private_key_id="your_private_key_id",
private_key="-----BEGIN PRIVATE KEY-----...",
client_email="...@....iam.gserviceaccount.com",
client_id="your_client_id",
auth_uri="https://accounts.google.com/o/oauth2/auth",
token_uri="https://oauth2.googleapis.com/token",
auth_provider_x509_cert_url=(
"https://www.googleapis.com/oauth2/v1/certs"
),
client_x509_cert_url=(
"https://www.googleapis.com/robot/v1/metadata/x509/..."
),
universe_domain="googleapis.com",
),
scopes=[
"https://www.googleapis.com/auth/cloud-platform"
], # unused in id_token mode but required today
token_kind="id_token",
audience="https://my-service-abc.a.run.app",
),
)

result = service_account_exchanger.exchange_credential(
auth_scheme, auth_credential
)

assert result.auth_type == AuthCredentialTypes.HTTP
assert result.http.scheme == "bearer"
assert result.http.credentials.token == "mock_id_token"
assert not result.http.additional_headers

# Verify we used the IDTokenCredentials path with the correct target_audience
mock_from_info.assert_called_once()
_, kwargs = mock_from_info.call_args
assert kwargs["target_audience"] == "https://my-service-abc.a.run.app"

mock_id_creds.refresh.assert_called_once()


def test_service_account_id_token_requires_audience():
"""ServiceAccount validation: id_token requires audience."""
with pytest.raises(ValueError) as exc_info:
ServiceAccount(
use_default_credential=True,
scopes=["https://www.googleapis.com/auth/cloud-platform"],
token_kind="id_token",
audience=None,
)
assert "audience" in str(exc_info.value)
Loading