From 1c6f136dcba8f14cbeba8478c416cffb1ae21138 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sun, 8 Mar 2026 12:43:08 +0000 Subject: [PATCH 1/2] Initial plan From 178d7b7c5abf1b140f7ff5b0e1a0bad89f10fb7c Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sun, 8 Mar 2026 13:25:12 +0000 Subject: [PATCH 2/2] Implement MSI v2 mTLS PoP with KeyGuard attestation for Java Managed Identity Co-authored-by: gladjohn <90415114+gladjohn@users.noreply.github.com> --- msal4j-sdk/src/main/cpp/MsalJNIBridge.cpp | 336 ++++++++++++++++ msal4j-sdk/src/main/cpp/MsalJNIBridge.h | 90 +++++ ...AcquireTokenByManagedIdentitySupplier.java | 38 ++ .../microsoft/aad/msal4j/CsrGenerator.java | 296 ++++++++++++++ .../com/microsoft/aad/msal4j/CsrMetadata.java | 63 +++ .../aad/msal4j/IssueCertificateRequest.java | 36 ++ .../aad/msal4j/IssueCertificateResponse.java | 70 ++++ .../aad/msal4j/ManagedIdentityParameters.java | 59 ++- .../com/microsoft/aad/msal4j/MsalError.java | 15 + .../aad/msal4j/MsalErrorMessage.java | 17 + .../java/com/microsoft/aad/msal4j/MsiV2.java | 369 +++++++++++++++++ .../microsoft/aad/msal4j/MsiV2Exception.java | 37 ++ .../aad/msal4j/WindowsKeyGuardJNI.java | 129 ++++++ .../com/microsoft/aad/msal4j/MsiV2Tests.java | 373 ++++++++++++++++++ 14 files changed, 1924 insertions(+), 4 deletions(-) create mode 100644 msal4j-sdk/src/main/cpp/MsalJNIBridge.cpp create mode 100644 msal4j-sdk/src/main/cpp/MsalJNIBridge.h create mode 100644 msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/CsrGenerator.java create mode 100644 msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/CsrMetadata.java create mode 100644 msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/IssueCertificateRequest.java create mode 100644 msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/IssueCertificateResponse.java create mode 100644 msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/MsiV2.java create mode 100644 msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/MsiV2Exception.java create mode 100644 msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/WindowsKeyGuardJNI.java create mode 100644 msal4j-sdk/src/test/java/com/microsoft/aad/msal4j/MsiV2Tests.java diff --git a/msal4j-sdk/src/main/cpp/MsalJNIBridge.cpp b/msal4j-sdk/src/main/cpp/MsalJNIBridge.cpp new file mode 100644 index 00000000..00e60772 --- /dev/null +++ b/msal4j-sdk/src/main/cpp/MsalJNIBridge.cpp @@ -0,0 +1,336 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +/** + * MsalJNIBridge.cpp - Native implementation for MSAL Java MSI v2 mTLS PoP support. + * + * This file implements the JNI bridge between MSAL Java and the Windows platform + * for MSI v2 (mTLS Proof-of-Possession with KeyGuard attestation). + * + * Implementation overview: + * - createKeyGuardRsaKeyNative: Creates a VBS-isolated per-boot RSA key via NCrypt. + * The key is non-exportable and protected by Virtualization Based Security (VBS). + * - getPublicKeyNative: Exports the public key in DER/SubjectPublicKeyInfo format. + * - signWithKeyGuardNative: Signs data using the hardware-protected key (RSA-PSS/SHA-256). + * - getAttestationTokenNative: Calls AttestationClientLib.dll to obtain an attestation JWT. + * - acquireMtlsTokenNative: Performs an mTLS HTTPS POST using the hardware key as client cert. + * - freeKeyHandleNative: Frees the native NCrypt key handle. + * + * Build Requirements: + * - Visual Studio 2019+ or MSVC toolchain + * - Windows SDK 10.0.19041.0+ + * - Links: ncrypt.lib, bcrypt.lib, winhttp.lib, crypt32.lib + * - JNI headers from the JDK include directory + * - AttestationClientLib.lib (for attestation support) + * + * NOTE: This file is a stub implementation. Full implementation requires + * the AttestationClientLib.dll and Windows VBS-enabled environment. + */ + +#include "MsalJNIBridge.h" +#include +#include +#include +#include +#include +#include +#include +#include + +// ============================================================================ +// Helper utilities +// ============================================================================ + +/** + * Deserializes the opaque key handle byte array back to NCRYPT_KEY_HANDLE. + */ +static NCRYPT_KEY_HANDLE deserializeKeyHandle(JNIEnv* env, jbyteArray keyHandle) { + jsize len = env->GetArrayLength(keyHandle); + if (len != sizeof(NCRYPT_KEY_HANDLE)) { + return 0; + } + NCRYPT_KEY_HANDLE handle = 0; + env->GetByteArrayRegion(keyHandle, 0, len, reinterpret_cast(&handle)); + return handle; +} + +/** + * Serializes a NCRYPT_KEY_HANDLE to a Java byte array for passing as opaque handle. + */ +static jbyteArray serializeKeyHandle(JNIEnv* env, NCRYPT_KEY_HANDLE handle) { + jbyteArray result = env->NewByteArray(sizeof(NCRYPT_KEY_HANDLE)); + env->SetByteArrayRegion(result, 0, sizeof(NCRYPT_KEY_HANDLE), + reinterpret_cast(&handle)); + return result; +} + +/** + * Throws a Java MsiV2Exception from native code. + */ +static void throwMsiV2Exception(JNIEnv* env, const char* message, const char* errorCode) { + jclass exClass = env->FindClass("com/microsoft/aad/msal4j/MsiV2Exception"); + if (exClass != nullptr) { + // Find constructor: MsiV2Exception(String message, String errorCode) + jmethodID ctor = env->GetMethodID(exClass, "", + "(Ljava/lang/String;Ljava/lang/String;)V"); + if (ctor != nullptr) { + jstring jMessage = env->NewStringUTF(message); + jstring jErrorCode = env->NewStringUTF(errorCode); + jobject ex = env->NewObject(exClass, ctor, jMessage, jErrorCode); + env->Throw(static_cast(ex)); + return; + } + } + // Fallback: throw RuntimeException + jclass rtClass = env->FindClass("java/lang/RuntimeException"); + env->ThrowNew(rtClass, message); +} + +// ============================================================================ +// JNI Implementations +// ============================================================================ + +JNIEXPORT jbyteArray JNICALL +Java_com_microsoft_aad_msal4j_WindowsKeyGuardJNI_createKeyGuardRsaKeyNative( + JNIEnv* env, jclass /*clazz*/, jstring keyName, jint keySizeBits) +{ + NCRYPT_PROV_HANDLE hProvider = 0; + NCRYPT_KEY_HANDLE hKey = 0; + SECURITY_STATUS status; + + // Open the Microsoft Software Key Storage Provider + status = NCryptOpenStorageProvider(&hProvider, MSAL_KEYGUARD_PROVIDER, 0); + if (FAILED(status)) { + throwMsiV2Exception(env, + "[MSI v2] Failed to open NCrypt storage provider. Ensure Windows CNG is available.", + "msi_v2_keyguard_unavailable"); + return nullptr; + } + + // Get the key name as a wide string + const jchar* keyNameChars = env->GetStringChars(keyName, nullptr); + std::wstring keyNameW(reinterpret_cast(keyNameChars)); + env->ReleaseStringChars(keyName, keyNameChars); + + // Delete any existing key with this name (per-boot key - always fresh). + // First, try to open an existing key with this name and delete it. + NCRYPT_KEY_HANDLE hExistingKey = 0; + SECURITY_STATUS deleteStatus = NCryptOpenKey(hProvider, &hExistingKey, + keyNameW.c_str(), AT_KEYEXCHANGE, 0); + if (SUCCEEDED(deleteStatus) && hExistingKey != 0) { + NCryptDeleteKey(hExistingKey, 0); // hExistingKey is freed by NCryptDeleteKey + } + + // Create a new persisted key + status = NCryptCreatePersistedKey(hProvider, &hKey, + NCRYPT_RSA_ALGORITHM_GROUP, + keyNameW.c_str(), + AT_KEYEXCHANGE, 0); + if (FAILED(status)) { + NCryptFreeObject(hProvider); + throwMsiV2Exception(env, + "[MSI v2] Failed to create NCrypt persisted key. VBS may not be available.", + "msi_v2_keyguard_unavailable"); + return nullptr; + } + + // Set key size + DWORD keySize = static_cast(keySizeBits); + status = NCryptSetProperty(hKey, NCRYPT_LENGTH_PROPERTY, + reinterpret_cast(&keySize), sizeof(DWORD), 0); + if (FAILED(status)) { + NCryptFreeObject(hKey); + NCryptFreeObject(hProvider); + throwMsiV2Exception(env, + "[MSI v2] Failed to set key size property.", + "msi_v2_error"); + return nullptr; + } + + // Set VBS (Virtual Isolation) and per-boot flags for KeyGuard protection + DWORD keyUsagePolicy = NCRYPT_USE_VIRTUAL_ISOLATION_FLAG | NCRYPT_USE_PER_BOOT_KEY_FLAG; + status = NCryptSetProperty(hKey, NCRYPT_KEY_USAGE_POLICY_PROPERTY, + reinterpret_cast(&keyUsagePolicy), sizeof(DWORD), 0); + if (FAILED(status)) { + NCryptFreeObject(hKey); + NCryptFreeObject(hProvider); + throwMsiV2Exception(env, + "[MSI v2] Failed to set VBS KeyGuard properties. " + "Virtualization Based Security (VBS) must be enabled.", + "msi_v2_keyguard_unavailable"); + return nullptr; + } + + // Finalize the key + status = NCryptFinalizeKey(hKey, NCRYPT_WRITE_KEY_TO_LEGACY_STORE_FLAG); + if (FAILED(status)) { + NCryptFreeObject(hKey); + NCryptFreeObject(hProvider); + throwMsiV2Exception(env, + "[MSI v2] Failed to finalize KeyGuard key. VBS attestation may not be available.", + "msi_v2_keyguard_unavailable"); + return nullptr; + } + + NCryptFreeObject(hProvider); + + // Return the key handle as an opaque byte array + return serializeKeyHandle(env, hKey); +} + +JNIEXPORT jbyteArray JNICALL +Java_com_microsoft_aad_msal4j_WindowsKeyGuardJNI_getPublicKeyNative( + JNIEnv* env, jclass /*clazz*/, jbyteArray keyHandle) +{ + NCRYPT_KEY_HANDLE hKey = deserializeKeyHandle(env, keyHandle); + if (hKey == 0) { + throwMsiV2Exception(env, "[MSI v2] Invalid key handle.", "msi_v2_error"); + return nullptr; + } + + // Export public key in BCRYPT_RSAPUBLIC_BLOB format + DWORD cbPublicKey = 0; + SECURITY_STATUS status = NCryptExportKey(hKey, 0, BCRYPT_RSAPUBLIC_BLOB, + nullptr, nullptr, 0, &cbPublicKey, 0); + if (FAILED(status) || cbPublicKey == 0) { + throwMsiV2Exception(env, "[MSI v2] Failed to get public key size.", "msi_v2_error"); + return nullptr; + } + + std::vector publicKeyBlob(cbPublicKey); + status = NCryptExportKey(hKey, 0, BCRYPT_RSAPUBLIC_BLOB, + nullptr, publicKeyBlob.data(), cbPublicKey, &cbPublicKey, 0); + if (FAILED(status)) { + throwMsiV2Exception(env, "[MSI v2] Failed to export public key.", "msi_v2_error"); + return nullptr; + } + + // Convert the BCRYPT_RSAKEY_BLOB to DER SubjectPublicKeyInfo (X.509 format). + // The blob contains the RSA key parameters; we need to construct a CERT_PUBLIC_KEY_INFO + // and encode it as DER ASN.1 using CryptEncodeObjectEx. + // + // NOTE: Full implementation requires constructing the CERT_PUBLIC_KEY_INFO from + // the BCRYPT_RSAKEY_BLOB, setting the algorithm OID to szOID_RSA_RSA + // (1.2.840.113549.1.1.1), and encoding with X509_PUBLIC_KEY_INFO. + // This stub returns an empty placeholder; the production implementation + // should use CryptImportPublicKeyInfoEx2 or BCryptExportKey with X509_PUBLIC_KEY_INFO. + + // TODO: Implement full SubjectPublicKeyInfo DER encoding for production use. + throwMsiV2Exception(env, + "[MSI v2] getPublicKeyNative: SubjectPublicKeyInfo encoding is not yet fully implemented.", + "msi_v2_error"); + NCryptFreeObject(hProvider); + return nullptr; +} + +JNIEXPORT jbyteArray JNICALL +Java_com_microsoft_aad_msal4j_WindowsKeyGuardJNI_signWithKeyGuardNative( + JNIEnv* env, jclass /*clazz*/, jbyteArray keyHandle, jbyteArray dataToSign) +{ + NCRYPT_KEY_HANDLE hKey = deserializeKeyHandle(env, keyHandle); + if (hKey == 0) { + throwMsiV2Exception(env, "[MSI v2] Invalid key handle for signing.", "msi_v2_error"); + return nullptr; + } + + // Get input data + jsize dataLen = env->GetArrayLength(dataToSign); + std::vector data(dataLen); + env->GetByteArrayRegion(dataToSign, 0, dataLen, reinterpret_cast(data.data())); + + // Hash the data with SHA-256 + BCRYPT_HASH_HANDLE hHash = nullptr; + BCRYPT_ALG_HANDLE hHashAlg = nullptr; + BCryptOpenAlgorithmProvider(&hHashAlg, BCRYPT_SHA256_ALGORITHM, nullptr, 0); + BCryptCreateHash(hHashAlg, &hHash, nullptr, 0, nullptr, 0, 0); + BCryptHashData(hHash, data.data(), static_cast(data.size()), 0); + std::vector digest(32); + BCryptFinishHash(hHash, digest.data(), 32, 0); + BCryptDestroyHash(hHash); + BCryptCloseAlgorithmProvider(hHashAlg, 0); + + // Sign with RSA-PSS/SHA-256 + BCRYPT_PSS_PADDING_INFO paddingInfo = {BCRYPT_SHA256_ALGORITHM, 32}; + DWORD cbSignature = 0; + SECURITY_STATUS status = NCryptSignHash(hKey, + &paddingInfo, + digest.data(), static_cast(digest.size()), + nullptr, 0, &cbSignature, + BCRYPT_PAD_PSS); + if (FAILED(status) || cbSignature == 0) { + throwMsiV2Exception(env, "[MSI v2] Failed to compute signature size.", "msi_v2_error"); + return nullptr; + } + + std::vector signature(cbSignature); + status = NCryptSignHash(hKey, + &paddingInfo, + digest.data(), static_cast(digest.size()), + signature.data(), cbSignature, &cbSignature, + BCRYPT_PAD_PSS); + if (FAILED(status)) { + throwMsiV2Exception(env, "[MSI v2] Failed to sign data with KeyGuard key.", "msi_v2_error"); + return nullptr; + } + + jbyteArray result = env->NewByteArray(static_cast(cbSignature)); + env->SetByteArrayRegion(result, 0, static_cast(cbSignature), + reinterpret_cast(signature.data())); + return result; +} + +JNIEXPORT jstring JNICALL +Java_com_microsoft_aad_msal4j_WindowsKeyGuardJNI_getAttestationTokenNative( + JNIEnv* env, jclass /*clazz*/, jstring attestationEndpoint, jbyteArray keyHandle) +{ + // NOTE: Full implementation requires AttestationClientLib.dll. + // This stub outlines the expected call pattern. + // + // NCRYPT_KEY_HANDLE hKey = deserializeKeyHandle(env, keyHandle); + // AttestationClient client; + // std::string jwt = client.GetAttestationToken( + // jstring_to_wstring(env, attestationEndpoint), hKey); + // return env->NewStringUTF(jwt.c_str()); + + throwMsiV2Exception(env, + "[MSI v2] AttestationClientLib.dll is not available. " + "This functionality requires the Windows Attestation Client Library.", + "msi_v2_error"); + return nullptr; +} + +JNIEXPORT jstring JNICALL +Java_com_microsoft_aad_msal4j_WindowsKeyGuardJNI_acquireMtlsTokenNative( + JNIEnv* env, jclass /*clazz*/, + jbyteArray keyHandle, jbyteArray certDer, + jstring tokenEndpointUrl, jstring requestBody) +{ + // NOTE: Full implementation uses WinHTTP with client certificate authentication. + // The KeyGuard private key is bound to the X.509 certificate for TLS client auth. + // + // NCRYPT_KEY_HANDLE hKey = deserializeKeyHandle(env, keyHandle); + // PCCERT_CONTEXT pCert = /* create from certDer bytes */; + // CertSetCertificateContextProperty(pCert, CERT_NCRYPT_KEY_HANDLE_PROP_ID, 0, &hKey); + // + // HINTERNET hSession = WinHttpOpen(...); + // HINTERNET hRequest = WinHttpOpenRequest(...); + // WinHttpSetOption(hRequest, WINHTTP_OPTION_CLIENT_CERT_CONTEXT, pCert, ...); + // WinHttpSendRequest(...); + // /* read response body */ + // return env->NewStringUTF(responseBody.c_str()); + + throwMsiV2Exception(env, + "[MSI v2] mTLS token acquisition is not yet implemented in native code.", + "msi_v2_error"); + return nullptr; +} + +JNIEXPORT void JNICALL +Java_com_microsoft_aad_msal4j_WindowsKeyGuardJNI_freeKeyHandleNative( + JNIEnv* env, jclass /*clazz*/, jbyteArray keyHandle) +{ + NCRYPT_KEY_HANDLE hKey = deserializeKeyHandle(env, keyHandle); + if (hKey != 0) { + NCryptFreeObject(hKey); + } +} diff --git a/msal4j-sdk/src/main/cpp/MsalJNIBridge.h b/msal4j-sdk/src/main/cpp/MsalJNIBridge.h new file mode 100644 index 00000000..d3838900 --- /dev/null +++ b/msal4j-sdk/src/main/cpp/MsalJNIBridge.h @@ -0,0 +1,90 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include +#include +#include + +/** + * MsalJNIBridge - Native bridge for MSAL Java MSI v2 mTLS PoP support. + * + * This module provides JNI implementations for: + * - Creating VBS-isolated, per-boot KeyGuard RSA keys via NCrypt + * - Signing data with the hardware-protected KeyGuard key (for PKCS#10 CSR) + * - Obtaining attestation JWTs via the Windows AttestationClientLib + * - Performing mTLS HTTPS requests using the hardware-bound private key + * + * Prerequisites: + * - Windows with Virtualization Based Security (VBS) enabled + * - AttestationClientLib.dll (Windows Attestation Client Library) + * - NCrypt.dll (Windows CNG key storage) + * + * Build: + * This file is intended to be compiled with Visual Studio or MSBuild for Windows x64. + * The output DLL (MsalJNIBridge.dll) must be placed in a location accessible by the JVM. + */ + +// NCrypt flags for KeyGuard keys (VBS-isolated, per-boot, non-exportable) +#ifndef NCRYPT_USE_VIRTUAL_ISOLATION_FLAG +#define NCRYPT_USE_VIRTUAL_ISOLATION_FLAG 0x00020000 +#endif +#ifndef NCRYPT_USE_PER_BOOT_KEY_FLAG +#define NCRYPT_USE_PER_BOOT_KEY_FLAG 0x00040000 +#endif + +#define MSAL_KEYGUARD_PROVIDER L"Microsoft Software Key Storage Provider" +#define MSAL_KEYGUARD_ALG BCRYPT_RSA_ALGORITHM +#define MSAL_KEYGUARD_KEY_SIZE 2048 + +/** + * Creates a VBS-isolated per-boot RSA key in the NCrypt key store. + * Returns an opaque key handle (serialized NCRYPT_KEY_HANDLE) as a byte array. + */ +JNIEXPORT jbyteArray JNICALL +Java_com_microsoft_aad_msal4j_WindowsKeyGuardJNI_createKeyGuardRsaKeyNative( + JNIEnv* env, jclass clazz, jstring keyName, jint keySizeBits); + +/** + * Returns the DER-encoded RSA SubjectPublicKeyInfo for the given key handle. + */ +JNIEXPORT jbyteArray JNICALL +Java_com_microsoft_aad_msal4j_WindowsKeyGuardJNI_getPublicKeyNative( + JNIEnv* env, jclass clazz, jbyteArray keyHandle); + +/** + * Signs the given data bytes with the KeyGuard RSA key using RSA-PSS/SHA-256. + * Returns the signature bytes. + */ +JNIEXPORT jbyteArray JNICALL +Java_com_microsoft_aad_msal4j_WindowsKeyGuardJNI_signWithKeyGuardNative( + JNIEnv* env, jclass clazz, jbyteArray keyHandle, jbyteArray dataToSign); + +/** + * Obtains an attestation JWT from the Windows AttestationClientLib. + * Returns the attestation JWT as a Java String. + */ +JNIEXPORT jstring JNICALL +Java_com_microsoft_aad_msal4j_WindowsKeyGuardJNI_getAttestationTokenNative( + JNIEnv* env, jclass clazz, jstring attestationEndpoint, jbyteArray keyHandle); + +/** + * Performs an mTLS HTTPS POST request to the token endpoint using the KeyGuard private key. + * The TLS client authentication uses the provided X.509 certificate and the hardware-bound key. + * Returns the HTTP response body as a Java String. + */ +JNIEXPORT jstring JNICALL +Java_com_microsoft_aad_msal4j_WindowsKeyGuardJNI_acquireMtlsTokenNative( + JNIEnv* env, jclass clazz, + jbyteArray keyHandle, jbyteArray certDer, + jstring tokenEndpointUrl, jstring requestBody); + +/** + * Frees the native NCrypt key handle and associated resources. + */ +JNIEXPORT void JNICALL +Java_com_microsoft_aad_msal4j_WindowsKeyGuardJNI_freeKeyHandleNative( + JNIEnv* env, jclass clazz, jbyteArray keyHandle); diff --git a/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/AcquireTokenByManagedIdentitySupplier.java b/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/AcquireTokenByManagedIdentitySupplier.java index c6545cf7..2195adf9 100644 --- a/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/AcquireTokenByManagedIdentitySupplier.java +++ b/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/AcquireTokenByManagedIdentitySupplier.java @@ -32,6 +32,27 @@ AuthenticationResult execute() throws Exception { MsalErrorMessage.SCOPES_REQUIRED); } + // Validate MSI v2 flag combination: attestation requires mTLS PoP + if (managedIdentityParameters.withAttestationSupport + && !managedIdentityParameters.mtlsProofOfPossession) { + throw new MsalClientException( + MsalErrorMessage.MSI_V2_ATTESTATION_REQUIRES_POP, + MsalError.MSI_V2_ATTESTATION_REQUIRES_POP); + } + + // MSI v2 flow: use when BOTH mtlsProofOfPossession AND withAttestationSupport are set. + // MSI v2 bypasses the token cache and never falls back to MSI v1 on failure. + if (managedIdentityParameters.mtlsProofOfPossession + && managedIdentityParameters.withAttestationSupport) { + LOG.debug("[MSI v2] Both mtlsProofOfPossession and withAttestationSupport are set. " + + "Using MSI v2 mTLS PoP flow."); + TokenRequestExecutor tokenRequestExecutor = new TokenRequestExecutor( + clientApplication.authenticationAuthority, + msalRequest, + clientApplication.serviceBundle()); + return fetchNewTokenMsiV2(tokenRequestExecutor); + } + TokenRequestExecutor tokenRequestExecutor = new TokenRequestExecutor( clientApplication.authenticationAuthority, msalRequest, @@ -121,6 +142,23 @@ private AuthenticationResult fetchNewAccessTokenAndSaveToCache(TokenRequestExecu return authenticationResult; } + /** + * Executes the MSI v2 mTLS PoP flow. Unlike MSI v1, the result is NOT cached + * since MSI v2 tokens are short-lived and hardware-bound. + * Any failure will propagate as a {@link MsiV2Exception} without fallback to MSI v1. + */ + private AuthenticationResult fetchNewTokenMsiV2(TokenRequestExecutor tokenRequestExecutor) { + ManagedIdentityResponse managedIdentityResponse = MsiV2.obtainToken( + msalRequest, + tokenRequestExecutor.getServiceBundle(), + managedIdentityParameters.resource); + + AuthenticationResult authenticationResult = createFromManagedIdentityResponse(managedIdentityResponse); + authenticationResult.metadata().tokenSource(TokenSource.IDENTITY_PROVIDER); + authenticationResult.metadata().cacheRefreshReason(CacheRefreshReason.NOT_APPLICABLE); + return authenticationResult; + } + private AuthenticationResult createFromManagedIdentityResponse(ManagedIdentityResponse managedIdentityResponse) { long expiresOn = getExpiresOnFromManagedIdentityTimestamp(managedIdentityResponse.expiresOn); long refreshOn = calculateRefreshOn(expiresOn); diff --git a/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/CsrGenerator.java b/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/CsrGenerator.java new file mode 100644 index 00000000..cf7d599b --- /dev/null +++ b/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/CsrGenerator.java @@ -0,0 +1,296 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package com.microsoft.aad.msal4j; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.Base64; + +/** + * Generates PKCS#10 Certificate Signing Requests (CSRs) for the MSI v2 mTLS PoP flow. + *

+ * The CSR is signed with the KeyGuard RSA key (via JNI) and includes a Microsoft-specific + * OID attribute ({@code 1.3.6.1.4.1.311.90.2.10}) containing the compute unit ID (cuId) + * as a UTF8String JSON value. + *

+ * The generated CSR uses RSA-PSS with SHA-256 as the signature algorithm. + */ +class CsrGenerator { + + // Pre-encoded OID byte arrays (DER OID encoding: tag 0x06 + length + VLQ-encoded OID arcs) + // rsaEncryption: 1.2.840.113549.1.1.1 + private static final byte[] OID_RSA_ENCRYPTION = + {0x2A, (byte) 0x86, 0x48, (byte) 0x86, (byte) 0xF7, 0x0D, 0x01, 0x01, 0x01}; + // id-RSASSA-PSS: 1.2.840.113549.1.1.10 + private static final byte[] OID_RSASSA_PSS = + {0x2A, (byte) 0x86, 0x48, (byte) 0x86, (byte) 0xF7, 0x0D, 0x01, 0x01, 0x0A}; + // id-mgf1: 1.2.840.113549.1.1.8 + private static final byte[] OID_MGF1 = + {0x2A, (byte) 0x86, 0x48, (byte) 0x86, (byte) 0xF7, 0x0D, 0x01, 0x01, 0x08}; + // sha-256: 2.16.840.1.101.3.4.2.1 + private static final byte[] OID_SHA256 = + {0x60, (byte) 0x86, 0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x01}; + // commonName: 2.5.4.3 + private static final byte[] OID_COMMON_NAME = {0x55, 0x04, 0x03}; + // Microsoft MSI v2 cuId OID: 1.3.6.1.4.1.311.90.2.10 + private static final byte[] OID_MSI_V2_CU_ID = + {0x2B, 0x06, 0x01, 0x04, 0x01, (byte) 0x82, 0x37, 0x5A, 0x02, 0x0A}; + + // RSA-PSS AlgorithmIdentifier DER bytes (pre-computed for SHA-256, MGF1-SHA256, saltLen=32) + // SEQUENCE { OID rsaPSS, SEQUENCE { [0] SHA256-AlgId, [1] MGF1-SHA256-AlgId, [2] saltLen=32 } } + private static final byte[] ALG_ID_RSASSA_PSS = buildRsaPssAlgorithmIdentifier(); + + /** + * Generates a PKCS#10 CSR and returns it in PEM format. + * + * @param publicKeyDer DER-encoded RSA SubjectPublicKeyInfo bytes (from the KeyGuard key) + * @param cuId the compute unit ID (vmId or vmssId) from IMDS platform metadata + * @param keyHandle the native KeyGuard key handle used for signing via JNI + * @return PEM-encoded PKCS#10 CSR string + * @throws MsiV2Exception if CSR construction or signing fails + */ + static String generate(byte[] publicKeyDer, String cuId, byte[] keyHandle) { + try { + // Build the CertificationRequestInfo (TBS) structure + byte[] certRequestInfo = buildCertificationRequestInfo(publicKeyDer, cuId); + + // Sign the TBS bytes using the KeyGuard key via JNI + byte[] signature = WindowsKeyGuardJNI.signWithKeyGuardNative(keyHandle, certRequestInfo); + + // Assemble the final PKCS#10 structure + byte[] pkcs10Der = buildPkcs10(certRequestInfo, signature); + + // Encode to PEM + return toPem(pkcs10Der); + } catch (IOException e) { + throw new MsiV2Exception("[MSI v2] Failed to generate CSR: " + e.getMessage(), + MsalError.MSI_V2_ERROR, e); + } + } + + /** + * Builds the DER-encoded CertificationRequestInfo structure (the TBS part of PKCS#10). + */ + private static byte[] buildCertificationRequestInfo(byte[] publicKeyDer, String cuId) + throws IOException { + ByteArrayOutputStream out = new ByteArrayOutputStream(); + + // version INTEGER 0 + out.write(encodeInteger(0)); + + // subject: SET { SEQUENCE { OID commonName, UTF8String "managed-identity-csr" } } + out.write(encodeSubject("managed-identity-csr")); + + // subjectPublicKeyInfo: already DER-encoded RSA public key + out.write(publicKeyDer); + + // attributes [0]: Microsoft OID with cuId as UTF8String JSON + out.write(encodeAttributes(cuId)); + + // Wrap in SEQUENCE to get CertificationRequestInfo + return encodeSequence(out.toByteArray()); + } + + /** + * Builds the full PKCS#10 DER structure from CertificationRequestInfo and signature. + */ + private static byte[] buildPkcs10(byte[] certRequestInfo, byte[] signature) + throws IOException { + ByteArrayOutputStream out = new ByteArrayOutputStream(); + + // CertificationRequestInfo + out.write(certRequestInfo); + + // signatureAlgorithm: id-RSASSA-PSS with SHA-256/MGF1/saltLen=32 + out.write(ALG_ID_RSASSA_PSS); + + // signature: BIT STRING with 0 unused bits + out.write(encodeBitString(signature)); + + return encodeSequence(out.toByteArray()); + } + + /** + * Encodes the subject as a Distinguished Name: RDN { AttributeTypeAndValue { CN, value } }. + */ + private static byte[] encodeSubject(String commonName) throws IOException { + // UTF8String for the CN value + byte[] cnValue = encodeUtf8String(commonName); + + // AttributeTypeAndValue: SEQUENCE { OID commonName, UTF8String value } + ByteArrayOutputStream atv = new ByteArrayOutputStream(); + atv.write(encodeOid(OID_COMMON_NAME)); + atv.write(cnValue); + byte[] atvSeq = encodeSequence(atv.toByteArray()); + + // RelativeDistinguishedName: SET { AttributeTypeAndValue } + byte[] rdn = encodeSet(atvSeq); + + // Name: SEQUENCE { RDN } + return encodeSequence(rdn); + } + + /** + * Encodes the PKCS#10 attributes containing the MSI v2 cuId OID attribute. + * Structure: [0] IMPLICIT SET { SEQUENCE { OID, SET { UTF8String cuIdJson } } } + */ + private static byte[] encodeAttributes(String cuId) throws IOException { + // The cuId JSON representation (per the MSI v2 spec). + // Escape any special JSON characters in the cuId value to produce valid JSON. + String cuIdJson = "\"" + cuId.replace("\\", "\\\\").replace("\"", "\\\"") + "\""; + + // Attribute value: SET { UTF8String cuIdJson } + byte[] attrValue = encodeSet(encodeUtf8String(cuIdJson)); + + // Attribute: SEQUENCE { OID msiV2CuId, attrValue } + ByteArrayOutputStream attr = new ByteArrayOutputStream(); + attr.write(encodeOid(OID_MSI_V2_CU_ID)); + attr.write(attrValue); + byte[] attrSeq = encodeSequence(attr.toByteArray()); + + // attributes [0] IMPLICIT (context tag 0, constructed) + return encodeContextTag(0, attrSeq); + } + + // ------------------------------------------------------------------------- + // DER encoding helpers + // ------------------------------------------------------------------------- + + private static byte[] encodeSequence(byte[] content) throws IOException { + return encodeTlv(0x30, content); + } + + private static byte[] encodeSet(byte[] content) throws IOException { + return encodeTlv(0x31, content); + } + + private static byte[] encodeOid(byte[] oidBytes) throws IOException { + return encodeTlv(0x06, oidBytes); + } + + private static byte[] encodeInteger(int value) throws IOException { + byte[] valueBytes; + if (value == 0) { + valueBytes = new byte[]{0x00}; + } else { + // Minimal positive integer encoding + ByteArrayOutputStream b = new ByteArrayOutputStream(); + while (value > 0) { + b.write(value & 0xFF); + value >>= 8; + } + byte[] raw = b.toByteArray(); + // Reverse (we wrote LSB first) + for (int i = 0, j = raw.length - 1; i < j; i++, j--) { + byte tmp = raw[i]; + raw[i] = raw[j]; + raw[j] = tmp; + } + // Prepend 0x00 if high bit is set (to keep it positive) + if ((raw[0] & 0x80) != 0) { + byte[] padded = new byte[raw.length + 1]; + System.arraycopy(raw, 0, padded, 1, raw.length); + valueBytes = padded; + } else { + valueBytes = raw; + } + } + return encodeTlv(0x02, valueBytes); + } + + private static byte[] encodeUtf8String(String value) throws IOException { + return encodeTlv(0x0C, value.getBytes(StandardCharsets.UTF_8)); + } + + private static byte[] encodeBitString(byte[] content) throws IOException { + // BIT STRING: prepend 0x00 (unused bits count) + byte[] withUnused = new byte[content.length + 1]; + withUnused[0] = 0x00; // no unused bits + System.arraycopy(content, 0, withUnused, 1, content.length); + return encodeTlv(0x03, withUnused); + } + + private static byte[] encodeContextTag(int tagNumber, byte[] content) throws IOException { + // Constructed context tag: 0xA0 | tagNumber + return encodeTlv(0xA0 | tagNumber, content); + } + + private static byte[] encodeTlv(int tag, byte[] value) throws IOException { + ByteArrayOutputStream out = new ByteArrayOutputStream(); + out.write(tag); + writeLength(out, value.length); + out.write(value); + return out.toByteArray(); + } + + private static void writeLength(ByteArrayOutputStream out, int length) { + if (length < 128) { + out.write(length); + } else if (length < 256) { + out.write(0x81); + out.write(length); + } else { + out.write(0x82); + out.write((length >> 8) & 0xFF); + out.write(length & 0xFF); + } + } + + /** + * Builds the pre-encoded DER bytes for the RSA-PSS AlgorithmIdentifier with SHA-256, + * MGF1-SHA256, and saltLength=32. + */ + private static byte[] buildRsaPssAlgorithmIdentifier() { + try { + // sha-256 AlgorithmIdentifier: SEQUENCE { OID sha256 } (no NULL per RFC 4055) + byte[] sha256AlgId = encodeSequence(encodeOid(OID_SHA256)); + + // [0] hashAlgorithm + byte[] hashAlgField = encodeContextTag(0, sha256AlgId); + + // mgf1 AlgorithmIdentifier: SEQUENCE { OID mgf1, sha256AlgId } + ByteArrayOutputStream mgf1Inner = new ByteArrayOutputStream(); + mgf1Inner.write(encodeOid(OID_MGF1)); + mgf1Inner.write(sha256AlgId); + byte[] mgf1AlgId = encodeSequence(mgf1Inner.toByteArray()); + + // [1] maskGenAlgorithm + byte[] maskGenField = encodeContextTag(1, mgf1AlgId); + + // [2] saltLength: INTEGER 32 + byte[] saltLenField = encodeContextTag(2, encodeInteger(32)); + + // RSASSA-PSS-params SEQUENCE + ByteArrayOutputStream pssParams = new ByteArrayOutputStream(); + pssParams.write(hashAlgField); + pssParams.write(maskGenField); + pssParams.write(saltLenField); + byte[] pssParamsSeq = encodeSequence(pssParams.toByteArray()); + + // AlgorithmIdentifier: SEQUENCE { OID rsaPSS, pssParamsSeq } + ByteArrayOutputStream algId = new ByteArrayOutputStream(); + algId.write(encodeOid(OID_RSASSA_PSS)); + algId.write(pssParamsSeq); + return encodeSequence(algId.toByteArray()); + } catch (IOException e) { + // Should never happen with ByteArrayOutputStream + throw new RuntimeException("Failed to build RSA-PSS AlgorithmIdentifier", e); + } + } + + /** + * Converts DER-encoded PKCS#10 bytes to PEM format. + */ + private static String toPem(byte[] der) { + String base64 = Base64.getMimeEncoder(64, new byte[]{'\n'}).encodeToString(der); + return "-----BEGIN CERTIFICATE REQUEST-----\n" + + base64 + + "\n-----END CERTIFICATE REQUEST-----\n"; + } + + private CsrGenerator() { + // Utility class, not instantiable + } +} diff --git a/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/CsrMetadata.java b/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/CsrMetadata.java new file mode 100644 index 00000000..c65da2fa --- /dev/null +++ b/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/CsrMetadata.java @@ -0,0 +1,63 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package com.microsoft.aad.msal4j; + +import com.azure.json.JsonReader; +import com.azure.json.JsonSerializable; +import com.azure.json.JsonToken; +import com.azure.json.JsonWriter; + +import java.io.IOException; + +/** + * DTO containing platform metadata returned by the IMDS + * {@code GET /metadata/identity/getPlatformMetadata} endpoint. + * Used in the MSI v2 mTLS PoP flow. + */ +class CsrMetadata implements JsonSerializable { + + String clientId; + String tenantId; + String cuId; + String attestationEndpoint; + + public static CsrMetadata fromJson(JsonReader jsonReader) throws IOException { + CsrMetadata metadata = new CsrMetadata(); + return jsonReader.readObject(reader -> { + while (reader.nextToken() != JsonToken.END_OBJECT) { + String fieldName = reader.getFieldName(); + reader.nextToken(); + switch (fieldName) { + case "client_id": + metadata.clientId = reader.getString(); + break; + case "tenant_id": + metadata.tenantId = reader.getString(); + break; + case "cu_id": + metadata.cuId = reader.getString(); + break; + case "attestation_endpoint": + metadata.attestationEndpoint = reader.getString(); + break; + default: + reader.skipChildren(); + break; + } + } + return metadata; + }); + } + + @Override + public JsonWriter toJson(JsonWriter jsonWriter) throws IOException { + jsonWriter.writeStartObject(); + jsonWriter.writeStringField("client_id", clientId); + jsonWriter.writeStringField("tenant_id", tenantId); + jsonWriter.writeStringField("cu_id", cuId); + jsonWriter.writeStringField("attestation_endpoint", attestationEndpoint); + jsonWriter.writeEndObject(); + return jsonWriter; + } +} diff --git a/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/IssueCertificateRequest.java b/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/IssueCertificateRequest.java new file mode 100644 index 00000000..6965becf --- /dev/null +++ b/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/IssueCertificateRequest.java @@ -0,0 +1,36 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package com.microsoft.aad.msal4j; + +import com.azure.json.JsonWriter; +import com.azure.json.JsonSerializable; + +import java.io.IOException; + +/** + * Request body for the IMDS {@code POST /metadata/identity/issuecredential} endpoint. + * Used in Step 5 of the MSI v2 mTLS PoP flow to obtain a short-lived mTLS client certificate. + */ +class IssueCertificateRequest implements JsonSerializable { + + /** Base64-encoded PKCS#10 CSR signed with the KeyGuard RSA key. */ + String csr; + + /** JWT attestation token from the KeyGuard attestation service. */ + String attestationToken; + + IssueCertificateRequest(String csr, String attestationToken) { + this.csr = csr; + this.attestationToken = attestationToken; + } + + @Override + public JsonWriter toJson(JsonWriter jsonWriter) throws IOException { + jsonWriter.writeStartObject(); + jsonWriter.writeStringField("csr", csr); + jsonWriter.writeStringField("attestation_token", attestationToken); + jsonWriter.writeEndObject(); + return jsonWriter; + } +} diff --git a/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/IssueCertificateResponse.java b/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/IssueCertificateResponse.java new file mode 100644 index 00000000..0845e306 --- /dev/null +++ b/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/IssueCertificateResponse.java @@ -0,0 +1,70 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package com.microsoft.aad.msal4j; + +import com.azure.json.JsonReader; +import com.azure.json.JsonSerializable; +import com.azure.json.JsonToken; +import com.azure.json.JsonWriter; + +import java.io.IOException; + +/** + * Response DTO from the IMDS {@code POST /metadata/identity/issuecredential} endpoint. + * Contains the short-lived X.509 certificate to use for mTLS token acquisition. + * Used in Step 6 of the MSI v2 mTLS PoP flow. + */ +class IssueCertificateResponse implements JsonSerializable { + + /** Base64-encoded DER X.509 certificate issued by IMDS for mTLS. */ + String certificate; + + /** Regional ESTS mTLS endpoint URL to acquire the final PoP token from. */ + String mtlsAuthenticationEndpoint; + + /** Tenant ID associated with the managed identity. */ + String tenantId; + + /** Client ID of the managed identity. */ + String clientId; + + public static IssueCertificateResponse fromJson(JsonReader jsonReader) throws IOException { + IssueCertificateResponse response = new IssueCertificateResponse(); + return jsonReader.readObject(reader -> { + while (reader.nextToken() != JsonToken.END_OBJECT) { + String fieldName = reader.getFieldName(); + reader.nextToken(); + switch (fieldName) { + case "certificate": + response.certificate = reader.getString(); + break; + case "mtls_authentication_endpoint": + response.mtlsAuthenticationEndpoint = reader.getString(); + break; + case "tenant_id": + response.tenantId = reader.getString(); + break; + case "client_id": + response.clientId = reader.getString(); + break; + default: + reader.skipChildren(); + break; + } + } + return response; + }); + } + + @Override + public JsonWriter toJson(JsonWriter jsonWriter) throws IOException { + jsonWriter.writeStartObject(); + jsonWriter.writeStringField("certificate", certificate); + jsonWriter.writeStringField("mtls_authentication_endpoint", mtlsAuthenticationEndpoint); + jsonWriter.writeStringField("tenant_id", tenantId); + jsonWriter.writeStringField("client_id", clientId); + jsonWriter.writeEndObject(); + return jsonWriter; + } +} diff --git a/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/ManagedIdentityParameters.java b/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/ManagedIdentityParameters.java index 21335802..f31b4c21 100644 --- a/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/ManagedIdentityParameters.java +++ b/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/ManagedIdentityParameters.java @@ -16,11 +16,16 @@ public class ManagedIdentityParameters implements IAcquireTokenParameters { boolean forceRefresh; String claims; String revokedTokenHash; - - private ManagedIdentityParameters(String resource, boolean forceRefresh, String claims) { + boolean mtlsProofOfPossession; + boolean withAttestationSupport; + + private ManagedIdentityParameters(String resource, boolean forceRefresh, String claims, + boolean mtlsProofOfPossession, boolean withAttestationSupport) { this.resource = resource; this.forceRefresh = forceRefresh; this.claims = claims; + this.mtlsProofOfPossession = mtlsProofOfPossession; + this.withAttestationSupport = withAttestationSupport; } @Override @@ -83,10 +88,20 @@ public String revokedTokenHash() { return this.revokedTokenHash; } + public boolean mtlsProofOfPossession() { + return this.mtlsProofOfPossession; + } + + public boolean withAttestationSupport() { + return this.withAttestationSupport; + } + public static class ManagedIdentityParametersBuilder { private String resource; private boolean forceRefresh; private String claims; + private boolean mtlsProofOfPossession; + private boolean withAttestationSupport; ManagedIdentityParametersBuilder() { } @@ -118,12 +133,48 @@ public ManagedIdentityParametersBuilder claims(String claims) { return this; } + /** + * Requests an mTLS Proof-of-Possession (PoP) token instead of a standard Bearer token. + *

+ * When set to {@code true}, the MSI v2 flow will be used if {@code withAttestationSupport} + * is also set to {@code true}. This requires Virtualization Based Security (VBS) and + * KeyGuard to be available on the host (Windows only). + * + * @param mtlsProofOfPossession {@code true} to request an mTLS PoP token + * @return this builder instance + */ + public ManagedIdentityParametersBuilder mtlsProofOfPossession(boolean mtlsProofOfPossession) { + this.mtlsProofOfPossession = mtlsProofOfPossession; + return this; + } + + /** + * Enables KeyGuard attestation support for the MSI v2 mTLS PoP flow. + *

+ * When set to {@code true}, the SDK will use the Windows KeyGuard (VBS-backed) RSA key + * to obtain a hardware-attested certificate from IMDS and use it for mTLS token acquisition. + *

+ * Requires {@code mtlsProofOfPossession=true}. If this flag is set without + * {@code mtlsProofOfPossession=true}, an exception will be thrown. + * + * @param withAttestationSupport {@code true} to require KeyGuard attestation + * @return this builder instance + */ + public ManagedIdentityParametersBuilder withAttestationSupport(boolean withAttestationSupport) { + this.withAttestationSupport = withAttestationSupport; + return this; + } + public ManagedIdentityParameters build() { - return new ManagedIdentityParameters(this.resource, this.forceRefresh, this.claims); + return new ManagedIdentityParameters(this.resource, this.forceRefresh, this.claims, + this.mtlsProofOfPossession, this.withAttestationSupport); } public String toString() { - return "ManagedIdentityParameters.ManagedIdentityParametersBuilder(resource=" + this.resource + ", forceRefresh=" + this.forceRefresh + ")"; + return "ManagedIdentityParameters.ManagedIdentityParametersBuilder(resource=" + this.resource + + ", forceRefresh=" + this.forceRefresh + + ", mtlsProofOfPossession=" + this.mtlsProofOfPossession + + ", withAttestationSupport=" + this.withAttestationSupport + ")"; } } } diff --git a/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/MsalError.java b/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/MsalError.java index fe594e4d..7d598529 100644 --- a/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/MsalError.java +++ b/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/MsalError.java @@ -36,4 +36,19 @@ public class MsalError { public static final String MANAGED_IDENTITY_FILE_READ_ERROR = "managed_identity_file_read_error"; public static final String MANAGED_IDENTITY_RESPONSE_PARSE_FAILURE = "managed_identity_response_parse_failure"; + + /** + * MSI v2: withAttestationSupport=true requires mtlsProofOfPossession=true. + */ + public static final String MSI_V2_ATTESTATION_REQUIRES_POP = "msi_v2_attestation_requires_pop"; + + /** + * MSI v2 token acquisition failed (no silent fallback to MSI v1). + */ + public static final String MSI_V2_ERROR = "msi_v2_error"; + + /** + * KeyGuard (VBS-backed hardware key) is not available on this platform. + */ + public static final String MSI_V2_KEYGUARD_UNAVAILABLE = "msi_v2_keyguard_unavailable"; } diff --git a/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/MsalErrorMessage.java b/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/MsalErrorMessage.java index f168a336..3a8ebbce 100644 --- a/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/MsalErrorMessage.java +++ b/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/MsalErrorMessage.java @@ -28,4 +28,21 @@ class MsalErrorMessage { public static final String GATEWAY_ERROR = "[Managed Identity] Authentication unavailable. The request failed due to a gateway error."; public static final String MANAGED_IDENTITY_RESPONSE_PARSE_FAILURE = "[Managed Identity] MSI returned %s, but the response could not be parsed: %s"; + + public static final String MSI_V2_ATTESTATION_REQUIRES_POP = + "[MSI v2] withAttestationSupport=true requires mtlsProofOfPossession=true. " + + "Both flags must be set to use the MSI v2 mTLS PoP flow."; + + public static final String MSI_V2_KEYGUARD_UNAVAILABLE = + "[MSI v2] KeyGuard is unavailable. Virtualization Based Security (VBS) must be enabled " + + "on this host for attestation-backed MSI v2 token acquisition."; + + public static final String MSI_V2_PLATFORM_METADATA_FAILED = + "[MSI v2] Failed to retrieve platform metadata from IMDS."; + + public static final String MSI_V2_ISSUECREDENTIAL_FAILED = + "[MSI v2] Failed to issue mTLS credential from IMDS."; + + public static final String MSI_V2_TOKEN_ACQUISITION_FAILED = + "[MSI v2] Failed to acquire mTLS PoP token from the regional ESTS endpoint."; } diff --git a/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/MsiV2.java b/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/MsiV2.java new file mode 100644 index 00000000..c4010dc7 --- /dev/null +++ b/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/MsiV2.java @@ -0,0 +1,369 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package com.microsoft.aad.msal4j; + +import com.azure.json.JsonProviders; +import com.azure.json.JsonReader; +import com.azure.json.JsonToken; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.net.ssl.HttpsURLConnection; +import javax.net.ssl.KeyManager; +import javax.net.ssl.KeyManagerFactory; +import javax.net.ssl.SSLContext; +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.io.OutputStream; +import java.net.URL; +import java.nio.charset.StandardCharsets; +import java.security.KeyStore; +import java.security.PrivateKey; +import java.security.cert.CertificateFactory; +import java.security.cert.X509Certificate; +import java.util.Base64; +import java.util.HashMap; +import java.util.Map; + +/** + * Implements the MSI v2 mTLS Proof-of-Possession (PoP) token acquisition flow using + * Windows KeyGuard attestation. + *

+ * The flow consists of 7 steps: + *

    + *
  1. Get platform metadata from IMDS ({@code /metadata/identity/getPlatformMetadata})
  2. + *
  3. Create a VBS-protected KeyGuard RSA key via JNI
  4. + *
  5. Build a PKCS#10 CSR signed with the KeyGuard key, including the cuId OID attribute
  6. + *
  7. Obtain an attestation JWT from the Windows AttestationClientLib via JNI
  8. + *
  9. Issue a short-lived mTLS credential from IMDS ({@code /metadata/identity/issuecredential})
  10. + *
  11. Parse the issued X.509 certificate from the IMDS response
  12. + *
  13. Acquire an {@code mtls_pop} token from the regional ESTS endpoint using mTLS
  14. + *
+ *

+ * No silent fallback: if MSI v2 is explicitly requested and fails at any step, + * a {@link MsiV2Exception} is thrown and MSI v1 is NOT attempted as a fallback. + *

+ * Platform requirement: Windows with Virtualization Based Security (VBS) enabled. + */ +class MsiV2 { + + private static final Logger LOG = LoggerFactory.getLogger(MsiV2.class); + + static final String IMDS_BASE_URL = "http://169.254.169.254"; + static final String PLATFORM_METADATA_PATH = "/metadata/identity/getPlatformMetadata"; + static final String ISSUECREDENTIAL_PATH = "/metadata/identity/issuecredential"; + static final String IMDS_API_VERSION = "2018-02-01"; + static final String KEYGUARD_KEY_NAME = "MsalKeyGuardKey"; + static final int KEYGUARD_RSA_KEY_SIZE = 2048; + static final String MTLS_TOKEN_TYPE = "mtls_pop"; + static final String OAUTH2_GRANT_TYPE = "client_credentials"; + + /** + * Executes the full MSI v2 7-step flow to obtain an mTLS PoP token. + * + * @param msalRequest the current MSAL request (provides HTTP helper and request context) + * @param serviceBundle the service bundle with HTTP client + * @param resource the Azure resource URI to acquire a token for + * @return a {@link ManagedIdentityResponse} containing the acquired mTLS PoP token + * @throws MsiV2Exception if any step fails (no fallback to MSI v1) + */ + static ManagedIdentityResponse obtainToken(MsalRequest msalRequest, + ServiceBundle serviceBundle, + String resource) { + // Check native library availability + if (!WindowsKeyGuardJNI.isNativeLibraryLoaded()) { + throw new MsiV2Exception( + MsalErrorMessage.MSI_V2_KEYGUARD_UNAVAILABLE, + MsalError.MSI_V2_KEYGUARD_UNAVAILABLE); + } + + byte[] keyHandle = null; + try { + // Step 1: Get platform metadata from IMDS + LOG.debug("[MSI v2] Step 1: Retrieving platform metadata from IMDS."); + CsrMetadata metadata = getPlatformMetadata(msalRequest, serviceBundle); + + // Step 2: Create KeyGuard RSA key (VBS-isolated, per-boot, non-exportable) + LOG.debug("[MSI v2] Step 2: Creating KeyGuard RSA-{} key.", KEYGUARD_RSA_KEY_SIZE); + keyHandle = createKeyGuardKey(); + byte[] publicKeyDer = WindowsKeyGuardJNI.getPublicKeyNative(keyHandle); + + // Step 3: Build PKCS#10 CSR with Microsoft cuId OID attribute + LOG.debug("[MSI v2] Step 3: Generating PKCS#10 CSR with cuId attribute."); + String csrPem = CsrGenerator.generate(publicKeyDer, metadata.cuId, keyHandle); + String csrBase64 = extractBase64FromPem(csrPem); + + // Step 4: Obtain attestation JWT from AttestationClientLib + LOG.debug("[MSI v2] Step 4: Obtaining attestation JWT from {}.", + metadata.attestationEndpoint); + String attestationToken = getAttestationToken(metadata.attestationEndpoint, keyHandle); + + // Step 5: Issue mTLS credential from IMDS + LOG.debug("[MSI v2] Step 5: Issuing mTLS credential from IMDS."); + IssueCertificateResponse certResponse = issueCredential( + msalRequest, serviceBundle, csrBase64, attestationToken); + + // Step 6: Parse the issued X.509 certificate + LOG.debug("[MSI v2] Step 6: Parsing issued X.509 certificate."); + byte[] certDer = Base64.getDecoder().decode(certResponse.certificate); + + // Step 7: Acquire mTLS PoP token from regional ESTS endpoint + LOG.debug("[MSI v2] Step 7: Acquiring mTLS PoP token from {}.", + certResponse.mtlsAuthenticationEndpoint); + return acquireMtlsToken(keyHandle, certDer, certResponse, resource); + + } finally { + // Free the native key handle to avoid resource leaks + if (keyHandle != null) { + try { + WindowsKeyGuardJNI.freeKeyHandleNative(keyHandle); + } catch (Exception e) { + LOG.warn("[MSI v2] Failed to free native key handle: {}", e.getMessage()); + } + } + } + } + + // ------------------------------------------------------------------------- + // Step 1: Get platform metadata + // ------------------------------------------------------------------------- + + static CsrMetadata getPlatformMetadata(MsalRequest msalRequest, ServiceBundle serviceBundle) { + String url = IMDS_BASE_URL + PLATFORM_METADATA_PATH + "?api-version=" + IMDS_API_VERSION; + Map headers = new HashMap<>(); + headers.put("Metadata", "true"); + + HttpRequest request = new HttpRequest(HttpMethod.GET, url, headers); + IHttpResponse response; + try { + response = serviceBundle.getHttpHelper().executeHttpRequest( + request, msalRequest.requestContext(), serviceBundle); + } catch (Exception e) { + throw new MsiV2Exception( + MsalErrorMessage.MSI_V2_PLATFORM_METADATA_FAILED + " " + e.getMessage(), + MsalError.MSI_V2_ERROR, e); + } + + if (response.statusCode() != 200) { + throw new MsiV2Exception( + MsalErrorMessage.MSI_V2_PLATFORM_METADATA_FAILED + + " HTTP " + response.statusCode() + ": " + response.body(), + MsalError.MSI_V2_ERROR); + } + + try (JsonReader reader = JsonProviders.createReader(response.body())) { + return CsrMetadata.fromJson(reader); + } catch (IOException e) { + throw new MsiV2Exception( + MsalErrorMessage.MSI_V2_PLATFORM_METADATA_FAILED + " JSON parse error: " + e.getMessage(), + MsalError.MSI_V2_ERROR, e); + } + } + + // ------------------------------------------------------------------------- + // Step 2: Create KeyGuard RSA key + // ------------------------------------------------------------------------- + + private static byte[] createKeyGuardKey() { + try { + byte[] keyHandle = WindowsKeyGuardJNI.createKeyGuardRsaKeyNative( + KEYGUARD_KEY_NAME, KEYGUARD_RSA_KEY_SIZE); + if (keyHandle == null || keyHandle.length == 0) { + throw new MsiV2Exception( + MsalErrorMessage.MSI_V2_KEYGUARD_UNAVAILABLE, + MsalError.MSI_V2_KEYGUARD_UNAVAILABLE); + } + return keyHandle; + } catch (MsiV2Exception e) { + throw e; + } catch (Exception e) { + throw new MsiV2Exception( + MsalErrorMessage.MSI_V2_KEYGUARD_UNAVAILABLE + " " + e.getMessage(), + MsalError.MSI_V2_KEYGUARD_UNAVAILABLE, e); + } + } + + // ------------------------------------------------------------------------- + // Step 4: Get attestation token + // ------------------------------------------------------------------------- + + private static String getAttestationToken(String attestationEndpoint, byte[] keyHandle) { + try { + String token = WindowsKeyGuardJNI.getAttestationTokenNative(attestationEndpoint, keyHandle); + if (StringHelper.isNullOrBlank(token)) { + throw new MsiV2Exception( + "[MSI v2] Attestation service returned an empty token.", + MsalError.MSI_V2_ERROR); + } + return token; + } catch (MsiV2Exception e) { + throw e; + } catch (Exception e) { + throw new MsiV2Exception( + "[MSI v2] Attestation failed: " + e.getMessage(), + MsalError.MSI_V2_ERROR, e); + } + } + + // ------------------------------------------------------------------------- + // Step 5: Issue mTLS credential from IMDS + // ------------------------------------------------------------------------- + + static IssueCertificateResponse issueCredential(MsalRequest msalRequest, + ServiceBundle serviceBundle, + String csrBase64, + String attestationToken) { + String url = IMDS_BASE_URL + ISSUECREDENTIAL_PATH + "?api-version=" + IMDS_API_VERSION; + Map headers = new HashMap<>(); + headers.put("Metadata", "true"); + headers.put("Content-Type", "application/json"); + + // Build JSON request body + String body = buildIssueCertificateRequestBody(csrBase64, attestationToken); + + HttpRequest request = new HttpRequest(HttpMethod.POST, url, headers, body); + IHttpResponse response; + try { + response = serviceBundle.getHttpHelper().executeHttpRequest( + request, msalRequest.requestContext(), serviceBundle); + } catch (Exception e) { + throw new MsiV2Exception( + MsalErrorMessage.MSI_V2_ISSUECREDENTIAL_FAILED + " " + e.getMessage(), + MsalError.MSI_V2_ERROR, e); + } + + if (response.statusCode() != 200) { + throw new MsiV2Exception( + MsalErrorMessage.MSI_V2_ISSUECREDENTIAL_FAILED + + " HTTP " + response.statusCode() + ": " + response.body(), + MsalError.MSI_V2_ERROR); + } + + try (JsonReader reader = JsonProviders.createReader(response.body())) { + return IssueCertificateResponse.fromJson(reader); + } catch (IOException e) { + throw new MsiV2Exception( + MsalErrorMessage.MSI_V2_ISSUECREDENTIAL_FAILED + " JSON parse error: " + e.getMessage(), + MsalError.MSI_V2_ERROR, e); + } + } + + private static String buildIssueCertificateRequestBody(String csrBase64, String attestationToken) { + // Simple JSON serialization without external library + return "{\"csr\":\"" + escapeJson(csrBase64) + "\"," + + "\"attestation_token\":\"" + escapeJson(attestationToken) + "\"}"; + } + + // ------------------------------------------------------------------------- + // Step 7: Acquire mTLS PoP token + // ------------------------------------------------------------------------- + + private static ManagedIdentityResponse acquireMtlsToken(byte[] keyHandle, + byte[] certDer, + IssueCertificateResponse certResponse, + String resource) { + // Build the OAuth2 token request body + String scope = resource.endsWith("/.default") ? resource : resource + "/.default"; + String requestBody = "grant_type=" + OAUTH2_GRANT_TYPE + + "&client_id=" + urlEncode(certResponse.clientId) + + "&scope=" + urlEncode(scope) + + "&token_type=" + MTLS_TOKEN_TYPE; + + // Build the token endpoint URL: {mtlsEndpoint}/{tenantId}/oauth2/v2.0/token + String tokenEndpointUrl = buildTokenEndpointUrl( + certResponse.mtlsAuthenticationEndpoint, certResponse.tenantId); + + // Use the native mTLS method to acquire the token + String responseBody = WindowsKeyGuardJNI.acquireMtlsTokenNative( + keyHandle, certDer, tokenEndpointUrl, requestBody); + + return parseMtlsTokenResponse(responseBody, resource); + } + + static String buildTokenEndpointUrl(String mtlsEndpoint, String tenantId) { + String base = mtlsEndpoint; + if (base.endsWith("/")) { + base = base.substring(0, base.length() - 1); + } + return base + "/" + tenantId + "/oauth2/v2.0/token"; + } + + private static ManagedIdentityResponse parseMtlsTokenResponse(String responseBody, String resource) { + try (JsonReader reader = JsonProviders.createReader(responseBody)) { + ManagedIdentityResponse response = new ManagedIdentityResponse(); + reader.readObject(r -> { + while (r.nextToken() != JsonToken.END_OBJECT) { + String fieldName = r.getFieldName(); + r.nextToken(); + switch (fieldName) { + case "access_token": + response.accessToken = r.getString(); + break; + case "expires_in": + // Convert expires_in (seconds from now) to expires_on (epoch seconds) + long expiresIn = r.getLong(); + response.expiresOn = String.valueOf( + (System.currentTimeMillis() / 1000) + expiresIn); + break; + case "expires_on": + response.expiresOn = r.getString(); + break; + case "token_type": + response.tokenType = r.getString(); + break; + case "client_id": + response.clientId = r.getString(); + break; + default: + r.skipChildren(); + break; + } + } + return response; + }); + response.resource = resource; + if (response.tokenType == null) { + response.tokenType = MTLS_TOKEN_TYPE; + } + return response; + } catch (IOException e) { + throw new MsiV2Exception( + MsalErrorMessage.MSI_V2_TOKEN_ACQUISITION_FAILED + " JSON parse error: " + e.getMessage(), + MsalError.MSI_V2_ERROR, e); + } + } + + // ------------------------------------------------------------------------- + // Utility helpers + // ------------------------------------------------------------------------- + + private static String extractBase64FromPem(String pem) { + return pem + .replace("-----BEGIN CERTIFICATE REQUEST-----", "") + .replace("-----END CERTIFICATE REQUEST-----", "") + .replaceAll("\\s", ""); + } + + private static String escapeJson(String value) { + if (value == null) { + return ""; + } + return value.replace("\\", "\\\\").replace("\"", "\\\""); + } + + private static String urlEncode(String value) { + if (value == null) { + return ""; + } + try { + return java.net.URLEncoder.encode(value, StandardCharsets.UTF_8.name()); + } catch (Exception e) { + return value; + } + } + + private MsiV2() { + // Utility class, not instantiable + } +} diff --git a/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/MsiV2Exception.java b/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/MsiV2Exception.java new file mode 100644 index 00000000..df7ad97c --- /dev/null +++ b/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/MsiV2Exception.java @@ -0,0 +1,37 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package com.microsoft.aad.msal4j; + +/** + * Exception type thrown when an error occurs during MSI v2 (mTLS Proof-of-Possession with + * KeyGuard attestation) token acquisition. + *

+ * MSI v2 errors are not recoverable by falling back to MSI v1. When MSI v2 is explicitly + * requested (both {@code mtlsProofOfPossession=true} and {@code withAttestationSupport=true}), + * any failure in the MSI v2 flow will result in this exception rather than a silent fallback. + */ +public class MsiV2Exception extends MsalException { + + /** + * Initializes a new instance of the exception class with a message and error code. + * + * @param message the error message that explains the reason for the exception + * @param errorCode a simplified error code for references in documentation + */ + public MsiV2Exception(final String message, final String errorCode) { + super(message, errorCode); + } + + /** + * Initializes a new instance of the exception class with a message, error code, and cause. + * + * @param message the error message that explains the reason for the exception + * @param errorCode a simplified error code for references in documentation + * @param cause the inner exception that is the cause of the current exception + */ + public MsiV2Exception(final String message, final String errorCode, final Throwable cause) { + super(message, errorCode); + initCause(cause); + } +} diff --git a/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/WindowsKeyGuardJNI.java b/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/WindowsKeyGuardJNI.java new file mode 100644 index 00000000..ed2d89c5 --- /dev/null +++ b/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/WindowsKeyGuardJNI.java @@ -0,0 +1,129 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package com.microsoft.aad.msal4j; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * JNI bridge to the Windows KeyGuard native library ({@code MsalJNIBridge.dll}). + *

+ * Provides access to: + *

    + *
  • Creating hardware-protected RSA keys via NCrypt and VBS (Virtualization Based Security)
  • + *
  • Signing data with the KeyGuard key for PKCS#10 CSR generation
  • + *
  • Obtaining attestation JWTs via the Windows AttestationClientLib
  • + *
  • Performing mTLS HTTPS connections using the hardware-bound private key
  • + *
+ *

+ * Platform requirements: Windows with Virtualization Based Security (VBS) enabled. + * All native methods will throw {@link MsiV2Exception} with error code + * {@link MsalError#MSI_V2_KEYGUARD_UNAVAILABLE} when called on unsupported platforms. + */ +class WindowsKeyGuardJNI { + + private static final Logger LOG = LoggerFactory.getLogger(WindowsKeyGuardJNI.class); + + private static final boolean NATIVE_LIBRARY_LOADED; + + static { + boolean loaded = false; + try { + System.loadLibrary("MsalJNIBridge"); + loaded = true; + LOG.debug("[MSI v2] Native MsalJNIBridge library loaded successfully."); + } catch (UnsatisfiedLinkError e) { + LOG.debug("[MSI v2] Native MsalJNIBridge library is not available: {}", e.getMessage()); + } + NATIVE_LIBRARY_LOADED = loaded; + } + + /** + * Returns {@code true} if the native KeyGuard library is loaded and available. + * This does not guarantee VBS is enabled; call {@link #createKeyGuardRsaKeyNative} to determine + * actual availability at runtime. + * + * @return {@code true} if the native library was loaded + */ + static boolean isNativeLibraryLoaded() { + return NATIVE_LIBRARY_LOADED; + } + + /** + * Creates a VBS-isolated, per-boot RSA key in the Windows CNG key store via NCrypt. + * The key is non-exportable and hardware-protected by Virtualization Based Security (VBS). + * + * @param keyName the NCrypt key name (e.g., {@code "MsalKeyGuardKey"}) + * @param keySizeBits the RSA key size in bits (e.g., {@code 2048}) + * @return an opaque native key handle used for subsequent signing and attestation calls. + * The caller is responsible for freeing this handle via {@link #freeKeyHandleNative}. + * @throws MsiV2Exception if VBS is unavailable or key creation fails + */ + static native byte[] createKeyGuardRsaKeyNative(String keyName, int keySizeBits); + + /** + * Returns the DER-encoded RSA public key corresponding to the provided native key handle. + * + * @param keyHandle the native key handle returned by {@link #createKeyGuardRsaKeyNative} + * @return DER-encoded SubjectPublicKeyInfo (RSA public key) + * @throws MsiV2Exception if the key handle is invalid or the operation fails + */ + static native byte[] getPublicKeyNative(byte[] keyHandle); + + /** + * Signs the given data with the KeyGuard RSA private key using RSA-PSS/SHA-256. + * + * @param keyHandle the native key handle returned by {@link #createKeyGuardRsaKeyNative} + * @param dataToSign the raw bytes to sign (typically the DER-encoded CertificationRequestInfo TBS) + * @return the RSA-PSS/SHA-256 signature bytes + * @throws MsiV2Exception if signing fails + */ + static native byte[] signWithKeyGuardNative(byte[] keyHandle, byte[] dataToSign); + + /** + * Obtains a JWT attestation token from the Windows AttestationClientLib. + * The attestation proves that the key is VBS-protected and non-exportable. + *

+ * Note: This method requires {@code AttestationClientLib.dll} to be present. + * If the DLL is not available, a {@link MsiV2Exception} will be thrown. + * + * @param attestationEndpoint the URL of the attestation service (from IMDS platform metadata) + * @param keyHandle the native key handle to attest + * @return a JWT string representing the attestation token + * @throws MsiV2Exception if attestation fails or the attestation service is unavailable + */ + static native String getAttestationTokenNative(String attestationEndpoint, byte[] keyHandle); + + /** + * Performs an mTLS HTTPS POST request to the specified token endpoint using the KeyGuard + * private key as the client certificate private key. + *

+ * This method handles the TLS handshake natively, using the hardware-protected key for the + * client authentication step of the mTLS connection. + *

+ * Note: The native implementation requires Windows with VBS enabled and + * {@code MsalJNIBridge.dll} built and available in the library path. + * + * @param keyHandle the native key handle for the client certificate private key + * @param certDer the DER-encoded X.509 client certificate (from IMDS issuecredential) + * @param tokenEndpointUrl the mTLS token endpoint URL (e.g., regional ESTS endpoint) + * @param requestBody the URL-encoded OAuth2 token request body + * @return the HTTP response body as a UTF-8 string (JSON token response) + * @throws MsiV2Exception if the mTLS connection or token request fails + */ + static native String acquireMtlsTokenNative(byte[] keyHandle, byte[] certDer, + String tokenEndpointUrl, String requestBody); + + /** + * Frees the native memory and NCrypt key handle associated with a KeyGuard key. + * This should be called when the key handle is no longer needed to avoid resource leaks. + * + * @param keyHandle the native key handle to free + */ + static native void freeKeyHandleNative(byte[] keyHandle); + + private WindowsKeyGuardJNI() { + // Utility class, not instantiable + } +} diff --git a/msal4j-sdk/src/test/java/com/microsoft/aad/msal4j/MsiV2Tests.java b/msal4j-sdk/src/test/java/com/microsoft/aad/msal4j/MsiV2Tests.java new file mode 100644 index 00000000..91c28e67 --- /dev/null +++ b/msal4j-sdk/src/test/java/com/microsoft/aad/msal4j/MsiV2Tests.java @@ -0,0 +1,373 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package com.microsoft.aad.msal4j; + +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Nested; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestInstance; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.junit.jupiter.MockitoExtension; + +import java.util.Collections; +import java.util.List; +import java.util.concurrent.AbstractExecutorService; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.TimeUnit; + +import static org.junit.jupiter.api.Assertions.*; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +/** + * Unit tests for the MSI v2 mTLS PoP flow, covering: + * - Parameter validation (attestation requires mTLS PoP) + * - Gating logic (both flags required for MSI v2) + * - No silent fallback to MSI v1 on MSI v2 failure + * - URL construction utilities + */ +@ExtendWith(MockitoExtension.class) +@TestInstance(TestInstance.Lifecycle.PER_CLASS) +class MsiV2Tests { + + private static final String RESOURCE = ManagedIdentityTestConstants.RESOURCE; + + @BeforeAll + void setupRetryPolicies() { + // Set retry delays to 1ms for faster test execution + ManagedIdentityRetryPolicy.setRetryDelayMs(1); + IMDSRetryPolicy.setRetryDelayMs(1); + } + + @AfterAll + void resetRetryPolicies() { + ManagedIdentityRetryPolicy.resetToDefaults(); + IMDSRetryPolicy.resetToDefaults(); + } + + /** + * An executor that runs tasks synchronously on the calling thread. + * This ensures MockedStatic is active when the code executes. + */ + private static final ExecutorService CURRENT_THREAD_EXECUTOR = new AbstractExecutorService() { + @Override public void execute(Runnable command) { command.run(); } + @Override public void shutdown() {} + @Override public List shutdownNow() { return Collections.emptyList(); } + @Override public boolean isShutdown() { return false; } + @Override public boolean isTerminated() { return false; } + @Override public boolean awaitTermination(long timeout, TimeUnit unit) { return true; } + }; + + // ------------------------------------------------------------------------- + // 1. Parameter validation tests + // ------------------------------------------------------------------------- + + @Nested + class ParameterValidationTests { + + private ManagedIdentityApplication miApp; + private DefaultHttpClient httpClientMock; + + @BeforeEach + void setUp() throws Exception { + ManagedIdentityApplication.setEnvironmentVariables( + new EnvironmentVariablesHelper(ManagedIdentitySourceType.DEFAULT_TO_IMDS, null)); + // Mock HTTP client returns a generic 500 error to avoid real network calls + httpClientMock = mock(DefaultHttpClient.class); + HttpResponse errorResponse = new HttpResponse(); + errorResponse.statusCode(500); + errorResponse.body(ManagedIdentityTestConstants.MSI_ERROR_RESPONSE_500); + org.mockito.Mockito.lenient().when(httpClientMock.send(any())).thenReturn(errorResponse); + + miApp = ManagedIdentityApplication + .builder(ManagedIdentityId.systemAssigned()) + .httpClient(httpClientMock) + .executorService(CURRENT_THREAD_EXECUTOR) + .build(); + miApp.tokenCache().accessTokens.clear(); + } + + @Test + void attestationWithoutPop_throwsMsalClientException() throws Exception { + // withAttestationSupport=true requires mtlsProofOfPossession=true + ManagedIdentityParameters params = ManagedIdentityParameters.builder(RESOURCE) + .withAttestationSupport(true) + .mtlsProofOfPossession(false) + .build(); + + CompletableFuture future = + miApp.acquireTokenForManagedIdentity(params); + + ExecutionException ex = assertThrows(ExecutionException.class, future::get); + assertInstanceOf(MsalClientException.class, ex.getCause()); + MsalClientException msalEx = (MsalClientException) ex.getCause(); + assertEquals(MsalError.MSI_V2_ATTESTATION_REQUIRES_POP, msalEx.errorCode()); + assertTrue(msalEx.getMessage().contains("withAttestationSupport")); + } + + @Test + void attestationWithoutPopExplicitFalse_throwsMsalClientException() throws Exception { + ManagedIdentityParameters params = ManagedIdentityParameters.builder(RESOURCE) + .withAttestationSupport(true) + // mtlsProofOfPossession defaults to false + .build(); + + CompletableFuture future = + miApp.acquireTokenForManagedIdentity(params); + + ExecutionException ex = assertThrows(ExecutionException.class, future::get); + assertInstanceOf(MsalClientException.class, ex.getCause()); + assertEquals(MsalError.MSI_V2_ATTESTATION_REQUIRES_POP, + ((MsalClientException) ex.getCause()).errorCode()); + } + + @Test + void popOnlyWithoutAttestation_doesNotThrowAttestationError() throws Exception { + // mtlsProofOfPossession=true alone (without attestation) should fall through to MSI v1. + // It should fail because IMDS is not available in tests, but NOT with MSI v2 attestation error. + ManagedIdentityParameters params = ManagedIdentityParameters.builder(RESOURCE) + .mtlsProofOfPossession(true) + .withAttestationSupport(false) + .build(); + + CompletableFuture future = + miApp.acquireTokenForManagedIdentity(params); + + ExecutionException ex = assertThrows(ExecutionException.class, future::get); + // Must NOT be an attestation_requires_pop error + assertFalse(ex.getCause() instanceof MsalClientException + && MsalError.MSI_V2_ATTESTATION_REQUIRES_POP.equals( + ((MsalClientException) ex.getCause()).errorCode()), + "PoP-only flag should NOT trigger attestation-requires-PoP validation error"); + // Must NOT be a MsiV2Exception (MSI v2 should not be called without attestation) + assertFalse(ex.getCause() instanceof MsiV2Exception, + "PoP-only should not use the MSI v2 flow (no MsiV2Exception expected)"); + } + + @Test + void neitherFlagSet_doesNotThrowMsiV2Exception() throws Exception { + // No flags set - standard MSI v1 flow (no MSI v2) + ManagedIdentityParameters params = ManagedIdentityParameters.builder(RESOURCE) + .build(); + + CompletableFuture future = + miApp.acquireTokenForManagedIdentity(params); + + ExecutionException ex = assertThrows(ExecutionException.class, future::get); + // Standard MSI v1 error (network or IMDS not available), not a MsiV2Exception + assertFalse(ex.getCause() instanceof MsiV2Exception, + "Standard MSI v1 path should not throw MsiV2Exception"); + } + } + + // ------------------------------------------------------------------------- + // 2. MSI v2 gating tests (both flags set) + // ------------------------------------------------------------------------- + + @Nested + class MsiV2GatingTests { + + private ManagedIdentityApplication miApp; + private DefaultHttpClient httpClientMock; + + @BeforeEach + void setUp() throws Exception { + ManagedIdentityApplication.setEnvironmentVariables( + new EnvironmentVariablesHelper(ManagedIdentitySourceType.DEFAULT_TO_IMDS, null)); + // Mock HTTP client returns a generic 500 error for MSI v1 fallback tests + httpClientMock = mock(DefaultHttpClient.class); + HttpResponse errorResponse = new HttpResponse(); + errorResponse.statusCode(500); + errorResponse.body(ManagedIdentityTestConstants.MSI_ERROR_RESPONSE_500); + org.mockito.Mockito.lenient().when(httpClientMock.send(any())).thenReturn(errorResponse); + + miApp = ManagedIdentityApplication + .builder(ManagedIdentityId.systemAssigned()) + .httpClient(httpClientMock) + .executorService(CURRENT_THREAD_EXECUTOR) + .build(); + miApp.tokenCache().accessTokens.clear(); + } + + @Test + void bothFlagsSet_triggersMsiV2Path() throws Exception { + // When both flags are set, the MSI v2 path is taken. + // Since the native KeyGuard library is unavailable in the test environment, + // MsiV2.obtainToken() will throw MsiV2Exception with KEYGUARD_UNAVAILABLE. + // This proves the MSI v2 code path was entered. + ManagedIdentityParameters params = ManagedIdentityParameters.builder(RESOURCE) + .mtlsProofOfPossession(true) + .withAttestationSupport(true) + .build(); + + CompletableFuture future = + miApp.acquireTokenForManagedIdentity(params); + + ExecutionException ex = assertThrows(ExecutionException.class, future::get); + // The MSI v2 path was entered - confirmed by MsiV2Exception with KEYGUARD_UNAVAILABLE + assertInstanceOf(MsiV2Exception.class, ex.getCause(), + "Both flags set should enter MSI v2 path (evidenced by MsiV2Exception for unavailable KeyGuard)"); + assertEquals(MsalError.MSI_V2_KEYGUARD_UNAVAILABLE, + ((MsiV2Exception) ex.getCause()).errorCode()); + } + + @Test + void bothFlagsSet_msiV2Exception_propagatesWithoutFallback() throws Exception { + // Verify that a MsiV2Exception propagates without silently falling back to MSI v1. + // Since both flags are set and native library is unavailable, we expect MsiV2Exception. + ManagedIdentityParameters params = ManagedIdentityParameters.builder(RESOURCE) + .mtlsProofOfPossession(true) + .withAttestationSupport(true) + .build(); + + CompletableFuture future = + miApp.acquireTokenForManagedIdentity(params); + + ExecutionException ex = assertThrows(ExecutionException.class, future::get); + // Must propagate as MsiV2Exception (no fallback to MSI v1 which would produce MsalServiceException) + assertInstanceOf(MsiV2Exception.class, ex.getCause(), + "MsiV2Exception must propagate without silent fallback to MSI v1"); + } + + @Test + void onlyPopWithoutAttestation_doesNotTriggerMsiV2Path() throws Exception { + // With PoP only (no attestation), MSI v1 is used. The error should NOT be MsiV2Exception. + ManagedIdentityParameters params = ManagedIdentityParameters.builder(RESOURCE) + .mtlsProofOfPossession(true) + .withAttestationSupport(false) + .build(); + + CompletableFuture future = + miApp.acquireTokenForManagedIdentity(params); + + ExecutionException ex = assertThrows(ExecutionException.class, future::get); + assertFalse(ex.getCause() instanceof MsiV2Exception, + "PoP-only should use MSI v1 path (no MsiV2Exception)"); + } + } + + // ------------------------------------------------------------------------- + // 3. ManagedIdentityParameters builder tests + // ------------------------------------------------------------------------- + + @Nested + class ManagedIdentityParametersBuilderTests { + + @Test + void defaultValues_bothFlagsFalse() { + ManagedIdentityParameters params = ManagedIdentityParameters.builder(RESOURCE).build(); + assertFalse(params.mtlsProofOfPossession()); + assertFalse(params.withAttestationSupport()); + } + + @Test + void mtlsPopOnlyFlag_setCorrectly() { + ManagedIdentityParameters params = ManagedIdentityParameters.builder(RESOURCE) + .mtlsProofOfPossession(true) + .build(); + assertTrue(params.mtlsProofOfPossession()); + assertFalse(params.withAttestationSupport()); + } + + @Test + void bothFlags_setCorrectly() { + ManagedIdentityParameters params = ManagedIdentityParameters.builder(RESOURCE) + .mtlsProofOfPossession(true) + .withAttestationSupport(true) + .build(); + assertTrue(params.mtlsProofOfPossession()); + assertTrue(params.withAttestationSupport()); + } + + @Test + void builderToString_includesNewFields() { + String str = ManagedIdentityParameters.builder(RESOURCE) + .mtlsProofOfPossession(true) + .withAttestationSupport(true) + .toString(); + assertTrue(str.contains("mtlsProofOfPossession=true")); + assertTrue(str.contains("withAttestationSupport=true")); + } + } + + // ------------------------------------------------------------------------- + // 4. MsiV2 utility method tests + // ------------------------------------------------------------------------- + + @Nested + class MsiV2UtilityTests { + + @Test + void buildTokenEndpointUrl_noTrailingSlash() { + String result = MsiV2.buildTokenEndpointUrl( + "https://eastus.mtlsauth.microsoft.com", "tenant-id"); + assertEquals("https://eastus.mtlsauth.microsoft.com/tenant-id/oauth2/v2.0/token", + result); + } + + @Test + void buildTokenEndpointUrl_withTrailingSlash() { + String result = MsiV2.buildTokenEndpointUrl( + "https://eastus.mtlsauth.microsoft.com/", "tenant-id"); + assertEquals("https://eastus.mtlsauth.microsoft.com/tenant-id/oauth2/v2.0/token", + result); + } + + @Test + void buildTokenEndpointUrl_differentTenantId() { + String result = MsiV2.buildTokenEndpointUrl( + "https://westus2.mtlsauth.microsoft.com", "my-tenant-123"); + assertEquals("https://westus2.mtlsauth.microsoft.com/my-tenant-123/oauth2/v2.0/token", + result); + } + } + + // ------------------------------------------------------------------------- + // 5. MsiV2Exception tests + // ------------------------------------------------------------------------- + + @Nested + class MsiV2ExceptionTests { + + @Test + void msiV2Exception_hasCorrectMessageAndErrorCode() { + MsiV2Exception ex = new MsiV2Exception("test message", MsalError.MSI_V2_ERROR); + assertEquals("test message", ex.getMessage()); + assertEquals(MsalError.MSI_V2_ERROR, ex.errorCode()); + } + + @Test + void msiV2Exception_withCause_hasCause() { + RuntimeException cause = new RuntimeException("root cause"); + MsiV2Exception ex = new MsiV2Exception("msg", MsalError.MSI_V2_ERROR, cause); + assertEquals(cause, ex.getCause()); + } + + @Test + void msiV2Exception_isInstanceOfMsalException() { + MsiV2Exception ex = new MsiV2Exception("msg", MsalError.MSI_V2_ERROR); + assertInstanceOf(MsalException.class, ex); + } + } + + // ------------------------------------------------------------------------- + // 6. WindowsKeyGuardJNI availability test + // ------------------------------------------------------------------------- + + @Nested + class WindowsKeyGuardJNITests { + + @Test + void nativeLibraryNotLoaded_isNativeLibraryLoadedReturnsFalse() { + // On non-Windows or systems without the native library, it should be false + boolean result = WindowsKeyGuardJNI.isNativeLibraryLoaded(); + assertFalse(result, "Native library should not be loaded in test environment"); + } + } +}