Skip to content

Commit 65fb051

Browse files
committed
cloud-shell-auth
1 parent 177695b commit 65fb051

6 files changed

Lines changed: 144 additions & 111 deletions

File tree

src/azure-cli-core/azure/cli/core/_profile.py

Lines changed: 36 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -284,17 +284,16 @@ def login_with_managed_identity(self, identity_id=None, allow_no_subscriptions=N
284284

285285
def login_in_cloud_shell(self):
286286
import jwt
287-
from azure.cli.core.auth.adal_authentication import MSIAuthenticationWrapper
287+
from .auth.msal_credentials import CloudShellCredential
288288

289-
msi_creds = MSIAuthenticationWrapper(resource=self.cli_ctx.cloud.endpoints.active_directory_resource_id)
290-
token_entry = msi_creds.token
291-
token = token_entry['access_token']
292-
logger.info('MSI: token was retrieved. Now trying to initialize local accounts...')
289+
cred = CloudShellCredential()
290+
token = cred.get_token(*self._arm_scope).token
291+
logger.info('Cloud Shell token was retrieved. Now trying to initialize local accounts...')
293292
decode = jwt.decode(token, algorithms=['RS256'], options={"verify_signature": False})
294293
tenant = decode['tid']
295294

296295
subscription_finder = SubscriptionFinder(self.cli_ctx)
297-
subscriptions = subscription_finder.find_using_specific_tenant(tenant, msi_creds)
296+
subscriptions = subscription_finder.find_using_specific_tenant(tenant, cred)
298297
if not subscriptions:
299298
raise CLIError('No subscriptions were found in the cloud shell')
300299
user = decode.get('unique_name', 'N/A')
@@ -351,11 +350,19 @@ def get_login_credentials(self, resource=None, client_id=None, subscription_id=N
351350

352351
managed_identity_type, managed_identity_id = Profile._try_parse_msi_account_name(account)
353352

354-
# Cloud Shell is just a system assignment managed identity
355353
if in_cloud_console() and account[_USER_ENTITY].get(_CLOUD_SHELL_ID):
356-
managed_identity_type = MsiAccountTypes.system_assigned
354+
# Cloud Shell
355+
from .auth.msal_credentials import CloudShellCredential
356+
from azure.cli.core.auth.credential_adaptor import CredentialAdaptor
357+
cs_cred = CloudShellCredential()
358+
# The cloud shell credential must be wrapped by CredentialAdaptor so that it can work with Track 1 SDKs.
359+
cred = CredentialAdaptor(cs_cred, resource=resource)
357360

358-
if managed_identity_type is None:
361+
elif managed_identity_type:
362+
# managed identity
363+
cred = MsiAccountTypes.msi_auth_factory(managed_identity_type, managed_identity_id, resource)
364+
365+
else:
359366
# user and service principal
360367
external_tenants = []
361368
if aux_tenants:
@@ -375,9 +382,7 @@ def get_login_credentials(self, resource=None, client_id=None, subscription_id=N
375382
cred = CredentialAdaptor(credential,
376383
auxiliary_credentials=external_credentials,
377384
resource=resource)
378-
else:
379-
# managed identity
380-
cred = MsiAccountTypes.msi_auth_factory(managed_identity_type, managed_identity_id, resource)
385+
381386
return (cred,
382387
str(account[_SUBSCRIPTION_ID]),
383388
str(account[_TENANT_ID]))
@@ -397,27 +402,27 @@ def get_raw_token(self, resource=None, scopes=None, subscription=None, tenant=No
397402

398403
account = self.get_subscription(subscription)
399404

400-
identity_type, identity_id = Profile._try_parse_msi_account_name(account)
401-
if identity_type:
405+
managed_identity_type, managed_identity_id = Profile._try_parse_msi_account_name(account)
406+
407+
if in_cloud_console() and account[_USER_ENTITY].get(_CLOUD_SHELL_ID):
408+
# Cloud Shell
409+
if tenant:
410+
raise CLIError("Tenant shouldn't be specified for Cloud Shell account")
411+
from .auth.msal_credentials import CloudShellCredential
412+
cred = CloudShellCredential()
413+
414+
elif managed_identity_type:
402415
# managed identity
403416
if tenant:
404417
raise CLIError("Tenant shouldn't be specified for managed identity account")
405418
from .auth.util import scopes_to_resource
406-
msi_creds = MsiAccountTypes.msi_auth_factory(identity_type, identity_id,
407-
scopes_to_resource(scopes))
408-
sdk_token = msi_creds.get_token(*scopes)
409-
elif in_cloud_console() and account[_USER_ENTITY].get(_CLOUD_SHELL_ID):
410-
# Cloud Shell, which is just a system-assigned managed identity.
411-
if tenant:
412-
raise CLIError("Tenant shouldn't be specified for Cloud Shell account")
413-
from .auth.util import scopes_to_resource
414-
msi_creds = MsiAccountTypes.msi_auth_factory(MsiAccountTypes.system_assigned, identity_id,
415-
scopes_to_resource(scopes))
416-
sdk_token = msi_creds.get_token(*scopes)
419+
cred = MsiAccountTypes.msi_auth_factory(managed_identity_type, managed_identity_id,
420+
scopes_to_resource(scopes))
421+
417422
else:
418-
credential = self._create_credential(account, tenant)
419-
sdk_token = credential.get_token(*scopes)
423+
cred = self._create_credential(account, tenant)
420424

425+
sdk_token = cred.get_token(*scopes)
421426
# Convert epoch int 'expires_on' to datetime string 'expiresOn' for backward compatibility
422427
# WARNING: expiresOn is deprecated and will be removed in future release.
423428
import datetime
@@ -429,11 +434,11 @@ def get_raw_token(self, resource=None, scopes=None, subscription=None, tenant=No
429434
'expiresOn': expiresOn # datetime string, like "2020-11-12 13:50:47.114324"
430435
}
431436

432-
# (tokenType, accessToken, tokenEntry)
433-
creds = 'Bearer', sdk_token.token, token_entry
437+
# Build a tuple of (token_type, token, token_entry)
438+
token_tuple = 'Bearer', sdk_token.token, token_entry
434439

435-
# (cred, subscription, tenant)
436-
return (creds,
440+
# Return a tuple of (token_tuple, subscription, tenant)
441+
return (token_tuple,
437442
None if tenant else str(account[_SUBSCRIPTION_ID]),
438443
str(tenant if tenant else account[_TENANT_ID]))
439444

src/azure-cli-core/azure/cli/core/auth/adal_authentication.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument
2424
# Use MSAL to get VM SSH certificate
2525
import msal
2626
from .util import check_result, build_sdk_access_token
27-
from .identity import AZURE_CLI_CLIENT_ID
27+
from .constants import AZURE_CLI_CLIENT_ID
2828
app = msal.PublicClientApplication(
2929
AZURE_CLI_CLIENT_ID, # Use a real client_id, so that cache would work
3030
# TODO: This PoC does not currently maintain a token cache;
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
# --------------------------------------------------------------------------------------------
2+
# Copyright (c) Microsoft Corporation. All rights reserved.
3+
# Licensed under the MIT License. See License.txt in the project root for license information.
4+
# --------------------------------------------------------------------------------------------
5+
6+
AZURE_CLI_CLIENT_ID = '04b07795-8ddb-461a-bbee-02f9e1bf7b46'

src/azure-cli-core/azure/cli/core/auth/identity.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,11 @@
1313
from knack.util import CLIError
1414
from msal import PublicClientApplication, ConfidentialClientApplication
1515

16+
from .constants import AZURE_CLI_CLIENT_ID
1617
from .msal_credentials import UserCredential, ServicePrincipalCredential
1718
from .persistence import load_persisted_token_cache, file_extensions, load_secret_store
1819
from .util import check_result
1920

20-
AZURE_CLI_CLIENT_ID = '04b07795-8ddb-461a-bbee-02f9e1bf7b46'
21-
2221
# Service principal entry properties. Names are taken from OAuth 2.0 client credentials flow parameters:
2322
# https://learn.microsoft.com/en-us/entra/identity-platform/v2-oauth2-client-creds-grant-flow
2423
_TENANT = 'tenant'

src/azure-cli-core/azure/cli/core/auth/msal_credentials.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from knack.util import CLIError
2121
from msal import PublicClientApplication, ConfidentialClientApplication
2222

23+
from .constants import AZURE_CLI_CLIENT_ID
2324
from .util import check_result, build_sdk_access_token
2425

2526
logger = get_logger(__name__)
@@ -108,3 +109,25 @@ def get_token(self, *scopes, **kwargs):
108109
result = self._msal_app.acquire_token_for_client(list(scopes), **kwargs)
109110
check_result(result)
110111
return build_sdk_access_token(result)
112+
113+
114+
class CloudShellCredential: # pylint: disable=too-few-public-methods
115+
# Cloud Shell acts as a "broker" to obtain access token for the user account, so even though it uses
116+
# managed identity protocol, it returns a user token.
117+
# That's why MSAL uses acquire_token_interactive to retrieve an access token in Cloud Shell.
118+
# See https://github.com/Azure/azure-cli/pull/29637
119+
120+
def __init__(self):
121+
self._msal_app = PublicClientApplication(
122+
AZURE_CLI_CLIENT_ID, # Use a real client_id, so that cache would work
123+
# TODO: We currently don't maintain an MSAL token cache as Cloud Shell already has its own token cache.
124+
# Ideally we should also use an MSAL token cache.
125+
# token_cache=...
126+
)
127+
128+
def get_token(self, *scopes, **kwargs):
129+
logger.debug("CloudShellCredential.get_token: scopes=%r, kwargs=%r", scopes, kwargs)
130+
# kwargs is already sanitized by CredentialAdaptor, so it can be safely passed to MSAL
131+
result = self._msal_app.acquire_token_interactive(list(scopes), prompt="none", **kwargs)
132+
check_result(result, scopes=scopes)
133+
return build_sdk_access_token(result)

0 commit comments

Comments
 (0)