diff --git a/docs/Guidance-for-Higher-Level-SDKs-to-Consume-MSAL.md b/docs/Guidance-for-Higher-Level-SDKs-to-Consume-MSAL.md new file mode 100644 index 00000000..e9f3f306 --- /dev/null +++ b/docs/Guidance-for-Higher-Level-SDKs-to-Consume-MSAL.md @@ -0,0 +1,351 @@ +# SDK Integration Guide — Building on MSAL Python mTLS PoP + +## Purpose + +This document explains how higher-level SDKs (e.g., `azure-identity`, `azure-sdk-for-python`) +can consume MSAL Python's MSI v2 mTLS Proof-of-Possession API to provide seamless mTLS +token-bound authentication to end users. + +## What MSAL Python PR #931 Exposes + +### Public API Surface + +```python +import msal +import requests + +# 1. Create ManagedIdentityClient (existing API) +client = msal.ManagedIdentityClient( + msal.SystemAssignedManagedIdentity(), + http_client=requests.Session(), +) + +# 2. Acquire mTLS-bound token (new parameters) +result = client.acquire_token_for_client( + resource="https://vault.azure.net", + mtls_proof_of_possession=True, + with_attestation_support=True, +) +``` + +### Return Value Contract + +```python +{ + "access_token": "eyJ0eXAi...", # mTLS-bound PoP token + "token_type": "mtls_pop", # Token type (use in auth header) + "expires_in": 86399, # Seconds until expiry + "binding_certificate": WindowsCertificate(...), # Opaque cert+key handle + "cert_thumbprint_sha256": "buc7x...", # Base64url SHA-256 thumbprint + "cert_pem": "-----BEGIN CERTIFICATE-----\n...", # Public cert (PEM) + "cert_der_b64": "MIIC...", # Public cert (Base64 DER) +} +``` + +### Key Objects + +#### `WindowsCertificate` + +Python equivalent of .NET's `X509Certificate2`. Wraps a non-exportable private key +(NCRYPT_KEY_HANDLE) and the associated X.509 certificate. + +```python +from msal import WindowsCertificate + +cert: WindowsCertificate = result["binding_certificate"] + +# Properties +cert.thumbprint_sha256 # str: base64url SHA-256 thumbprint +cert.pem # str: PEM-encoded public certificate +cert.subject # str: certificate subject + +# Methods +cert.create_cert_context() # -> PCCERT_CONTEXT (for WinHTTP/SChannel) +cert.close() # Free native handles (or use as context manager) +``` + +#### `MsiV2Error` + +```python +from msal import MsiV2Error + +try: + result = client.acquire_token_for_client(...) +except MsiV2Error as e: + # MSI v2 specific error (attestation failure, IMDS error, etc.) + pass +``` + +--- + +## Integration Pattern for Azure SDK + +### Architecture + +``` +┌──────────────────────────────────────────────────────────────┐ +│ Your SDK (e.g., azure-identity) │ +│ │ +│ ManagedIdentityCredential │ +│ ├── calls MSAL acquire_token_for_client(mtls_pop=True) │ +│ ├── receives binding_certificate + access_token │ +│ └── returns AccessToken + metadata to pipeline │ +├──────────────────────────────────────────────────────────────┤ +│ Transport Layer (e.g., azure-core) │ +│ │ +│ MtlsTransport (new, SChannel-based) │ +│ ├── receives WindowsCertificate from credential │ +│ ├── calls cert.create_cert_context() for TLS handshake │ +│ └── presents cert during client-auth in mTLS │ +├──────────────────────────────────────────────────────────────┤ +│ Resource SDK (e.g., azure-keyvault-secrets) │ +│ └── unaware of mTLS — just uses credential + transport │ +└──────────────────────────────────────────────────────────────┘ +``` + +### Step-by-Step Integration + +#### Step 1: Token Acquisition (in your Credential class) + +```python +# Inside azure-identity ManagedIdentityCredential +import msal + +class ManagedIdentityCredential: + def __init__(self, *, mtls_pop: bool = False, **kwargs): + self._mtls_pop = mtls_pop + self._msal_client = msal.ManagedIdentityClient( + msal.SystemAssignedManagedIdentity(), + http_client=self._http_client, + ) + self._binding_certificate = None + + def get_token(self, *scopes, **kwargs) -> AccessToken: + if self._mtls_pop: + result = self._msal_client.acquire_token_for_client( + resource=scopes[0].rstrip("/.default"), + mtls_proof_of_possession=True, + with_attestation_support=True, + ) + # Store the binding certificate for the transport layer + self._binding_certificate = result["binding_certificate"] + + return AccessToken( + token=result["access_token"], + expires_on=int(time.time()) + result["expires_in"], + token_type=result["token_type"], # "mtls_pop" + ) + else: + # Standard MSI v1 path + ... +``` + +#### Step 2: Authorization Header Construction + +The token type dictates the auth header format: + +```python +# In your BearerTokenPolicy or equivalent: +def on_request(self, request): + token = self._credential.get_token(self._scopes) + + # token_type comes from MSAL — use it directly + request.headers["Authorization"] = f"{token.token_type} {token.token}" + + # For mTLS PoP to token-binding-aware services: + if token.token_type == "mtls_pop": + request.headers["x-ms-tokenboundauth"] = "true" +``` + +#### Step 3: mTLS Transport (presenting the certificate) + +Standard Python `requests` **cannot** use non-exportable keys. You need a +SChannel/WinHTTP transport: + +```python +# Option A: Use msal-schannel-transport directly +from msal_schannel_transport import SchannelSession + +session = SchannelSession() +response = session.get( + url, + headers=headers, + client_certificate=credential._binding_certificate, +) +``` + +```python +# Option B: Build your own azure-core HttpTransport +from azure.core.pipeline.transport import HttpTransport + +class SchannelTransport(HttpTransport): + """WinHTTP/SChannel transport for mTLS with non-exportable keys.""" + + def __init__(self): + self._winhttp = None # Lazy-load WinHTTP bindings + + def send(self, request, *, binding_certificate=None, **kwargs): + if binding_certificate: + # Use WinHTTP with the certificate's CERT_CONTEXT + cert_ctx = binding_certificate.create_cert_context() + # ... WinHTTP call with WINHTTP_OPTION_CLIENT_CERT_CONTEXT ... + else: + # Fall back to standard requests transport + ... +``` + +#### Step 4: End-User Experience (the goal) + +```python +from azure.identity import ManagedIdentityCredential +from azure.keyvault.secrets import SecretClient + +# One line to enable mTLS PoP +credential = ManagedIdentityCredential(mtls_pop=True) + +# Standard SDK usage — no mTLS awareness needed by the developer +client = SecretClient( + vault_url="https://tokenbinding.vault.azure.net", + credential=credential, +) +secret = client.get_secret("boundsecret") +print(secret.value) # "secretme" +``` + +--- + +## Key Design Decisions + +### 1. MSAL Only Acquires Tokens + +MSAL never makes downstream API calls. It returns the `binding_certificate` object +and the access token. The higher-level SDK is responsible for: +- Constructing the authorization header +- Presenting the certificate during TLS handshake +- Managing the transport layer + +### 2. `WindowsCertificate` is the Bridge + +The `WindowsCertificate` object is the contract between MSAL (token acquisition) +and the transport layer (mTLS presentation). It: +- Holds the live NCRYPT_KEY_HANDLE (non-exportable private key) +- Creates thread-safe CERT_CONTEXT instances for concurrent requests +- Manages handle lifecycle (reference counting, cleanup) + +### 3. Transport is Pluggable + +Higher-level SDKs can choose: +- **`msal-schannel-transport`** — ready-to-use WinHTTP session (ships with this PR) +- **Custom transport** — build your own using `cert.create_cert_context()` +- **Future: OpenSSL provider** — when available, standard `requests` will work + +### 4. Certificate Lifecycle + +```python +# WindowsCertificate is valid for the lifetime of the issued certificate +# (typically 8 hours for IMDS-issued certs). MSAL caches internally. + +# Pattern: acquire once, reuse for multiple requests +result = client.acquire_token_for_client(...) +cert = result["binding_certificate"] + +# Make many requests with the same cert +for url in urls: + session.get(url, client_certificate=cert) + +# When done (optional — GC handles this too): +cert.close() +``` + +--- + +## Why Standard `requests` Doesn't Work + +```python +# This FAILS with non-exportable keys: +requests.get(url, cert=("cert.pem", "key.pem")) +# ^^^^^^^^ +# KeyGuard/TPM keys cannot be exported to a file! +``` + +Python's `ssl` module → OpenSSL → requires raw private key bytes. +KeyGuard/VBS keys are hardware-isolated and never leave the security boundary. + +**Solution:** Use WinHTTP/SChannel which integrates natively with Windows +certificate stores and NCRYPT_KEY_HANDLEs. + +--- + +## Comparison with .NET + +| Concept | .NET | Python (this PR) | +|---------|------|-------------------| +| Token acquisition | `MsalMtlsPopProvider.GetTokenAsync()` | `client.acquire_token_for_client(mtls_pop=True)` | +| Certificate object | `X509Certificate2` | `WindowsCertificate` | +| Downstream mTLS | `HttpClientHandler.ClientCertificates` | `SchannelSession` or custom transport | +| Key isolation | Automatic (SChannel) | Automatic (WinHTTP/SChannel) | +| Auth header | `$"{tokenType} {token}"` | `f"{result['token_type']} {result['access_token']}"` | + +--- + +## Minimum Integration Example + +For SDK authors who want the fastest path to integration: + +```python +"""Minimal integration — 20 lines to full mTLS PoP.""" +import msal +from msal_schannel_transport import SchannelSession + +# Acquire token +client = msal.ManagedIdentityClient( + msal.SystemAssignedManagedIdentity(), + http_client=__import__("requests").Session(), +) +result = client.acquire_token_for_client( + resource="https://vault.azure.net", + mtls_proof_of_possession=True, + with_attestation_support=True, +) + +# Downstream call +session = SchannelSession() +response = session.get( + "https://tokenbinding.vault.azure.net/secrets/boundsecret/?api-version=2015-06-01", + headers={ + "Authorization": f"{result['token_type']} {result['access_token']}", + "x-ms-tokenboundauth": "true", + }, + client_certificate=result["binding_certificate"], +) +print(response.status_code, response.text) +``` + +--- + +## Package Dependencies + +| Package | Role | Required? | +|---------|------|-----------| +| `msal` (with PR #931) | Token acquisition + `WindowsCertificate` | Yes | +| `msal-key-attestation` | MAA attestation (KeyGuard proof) | Yes for `with_attestation_support=True` | +| `msal-schannel-transport` | WinHTTP-based downstream mTLS | Yes (or build your own) | +| `requests` | HTTP client for MSAL's IMDS calls | Yes | + +--- + +## Future: OpenSSL 3 CNG Provider (Strategic Path) + +When a Microsoft-supported OpenSSL 3 CNG Provider becomes available, the +transport layer simplifies to: + +```python +from azure_identity_mtls import create_mtls_context +import requests + +ctx = create_mtls_context(thumbprint=result["cert_thumbprint_sha256"]) +response = requests.get(url, headers=headers) # standard requests, no WinHTTP needed +``` + +This is a future investment. The current WinHTTP approach is production-ready +and provides identical security guarantees. diff --git a/docs/MSI_V2_API.md b/docs/MSI_V2_API.md new file mode 100644 index 00000000..281bfa5a --- /dev/null +++ b/docs/MSI_V2_API.md @@ -0,0 +1,202 @@ +# MSI v2 — mTLS Proof-of-Possession API Reference + +## Overview + +MSI v2 enables Managed Identity token acquisition with mTLS Proof-of-Possession +on Windows Azure VMs with Credential Guard (KeyGuard). Tokens are +cryptographically bound to a per-boot hardware-isolated key. + +## Requirements + +- Windows Azure VM with Credential Guard / KeyGuard enabled +- `AttestationClientLib.dll` available on the system +- Python 3.9+ + +## Installation + +```bash +pip install msal msal-key-attestation requests +``` + +## Public API + +### Token Acquisition + +```python +import msal +import requests + +client = msal.ManagedIdentityClient( + msal.SystemAssignedManagedIdentity(), + http_client=requests.Session(), +) + +result = client.acquire_token_for_client( + resource="https://graph.microsoft.com", + mtls_proof_of_possession=True, # Enable MSI v2 mTLS PoP + with_attestation_support=True, # Require msal-key-attestation +) +``` + +#### Parameters + +| Parameter | Type | Required | Description | +|-----------|------|----------|-------------| +| `resource` | str | Yes | Resource URI (e.g., `https://graph.microsoft.com`) | +| `mtls_proof_of_possession` | bool | No | Enable mTLS PoP token binding | +| `with_attestation_support` | bool | No | Require MAA attestation (needs `msal-key-attestation`) | + +#### Return Value + +On success, returns a dict: + +```python +{ + "access_token": "eyJ0eXAi...", # The mTLS-bound access token + "expires_in": 3599, # Token lifetime in seconds + "token_type": "mtls_pop", # Always "mtls_pop" for this flow + "binding_certificate": WindowsCertificate(...), # Use with SchannelSession + "cert_der_b64": "MIIC...", # Base64-encoded DER certificate + "cert_pem": "-----BEGIN CERT...", # PEM-encoded public certificate only + "cert_thumbprint_sha256": "buc7x...", # Base64url SHA-256 thumbprint +} +``` + +On failure, raises `MsiV2Error`. + +### User-Assigned Managed Identity + +```python +# By client_id +client = msal.ManagedIdentityClient( + msal.UserAssignedManagedIdentity(client_id="11111111-..."), + http_client=requests.Session(), +) + +# By object_id +client = msal.ManagedIdentityClient( + msal.UserAssignedManagedIdentity(object_id="22222222-..."), + http_client=requests.Session(), +) + +# By resource_id +client = msal.ManagedIdentityClient( + msal.UserAssignedManagedIdentity( + resource_id="/subscriptions/.../providers/Microsoft.ManagedIdentity/..."), + http_client=requests.Session(), +) +``` + +### Downstream mTLS Call + +After acquiring the token, use the returned `binding_certificate` for +downstream mTLS API calls. The `cert_pem` in the result is the public +certificate bound to the token's `cnf.x5t#S256` claim, but it is not enough +to perform downstream mTLS by itself because the private key remains +platform-backed and non-exportable. + +> **Note:** Standard Python `requests` cannot present a KeyGuard-bound +> certificate (the private key is non-exportable). Use WinHTTP/SChannel +> for production mTLS calls. Use `SchannelSession` with the returned +> `binding_certificate` to present the certificate during the TLS handshake. + +```python +# Authorization header format: +headers = { + "Authorization": f"mtls_pop {result['access_token']}" +} +``` + +### Helper Functions + +```python +from msal.msi_v2 import get_cert_thumbprint_sha256, verify_cnf_binding + +# Compute base64url SHA-256 thumbprint of a PEM certificate +thumbprint = get_cert_thumbprint_sha256(cert_pem) + +# Verify that a JWT's cnf.x5t#S256 matches the certificate +is_bound = verify_cnf_binding(access_token, cert_pem) +``` + +## Error Handling + +```python +from msal import MsiV2Error, ManagedIdentityError + +try: + result = client.acquire_token_for_client(...) +except MsiV2Error as e: + # MSI v2 specific failure (no v1 fallback) + print(f"MSI v2 failed: {e}") +except ManagedIdentityError as e: + # General managed identity error + print(f"MI error: {e}") +``` + +### Error Hierarchy + +``` +ValueError +└── ManagedIdentityError + └── MsiV2Error +``` + +### Common Errors + +| Error Message | Cause | +|---------------|-------| +| `KeyGuard + mTLS PoP is Windows-only` | Running on non-Windows | +| `with_attestation_support=True requires...` | `msal-key-attestation` not installed | +| `attestation_requires_pop` | `with_attestation_support=True` without `mtls_proof_of_possession=True` | +| `getplatformmetadata missing required fields` | VM doesn't support IMDS v2 | +| `attestationEndpoint missing` | VM doesn't have MAA configured | + +## msal-key-attestation Package + +Separate pip package providing Windows `AttestationClientLib.dll` bindings. + +### API + +```python +from msal_key_attestation import create_attestation_provider + +provider = create_attestation_provider() +# provider(endpoint, key_handle, client_id, cache_key) -> str (JWT) +``` + +The provider is automatically discovered by MSAL when +`with_attestation_support=True` is set. + +### Environment Variables + +| Variable | Description | +|----------|-------------| +| `ATTESTATION_CLIENTLIB_PATH` | Absolute path to `AttestationClientLib.dll` | +| `MSAL_MSI_V2_ATTESTATION_CACHE` | Set to `"0"` to disable MAA token cache | + +## Environment Variables (Core) + +| Variable | Description | +|----------|-------------| +| `AZURE_POD_IDENTITY_AUTHORITY_HOST` | Override IMDS base URL (default: `http://169.254.169.254`) | +| `MSAL_MSI_V2_KEY_NAME` | Override per-boot key name | + +## Flow Diagram + +``` +1. NCrypt → Create/open KeyGuard RSA key (VBS-isolated, per-boot) +2. IMDS → GET /getplatformmetadata → clientId, tenantId, cuId, attestationEndpoint +3. CSR → Build PKCS#10 (RSA-PSS/SHA256 + cuId OID attribute) +4. Attestation → AttestationClientLib.dll → MAA JWT (proves key is KeyGuard-protected) +5. IMDS → POST /issuecredential {csr, attestation_token} → X.509 certificate +6. Crypt32 → Bind certificate to NCrypt key handle (SChannel-ready) +7. WinHTTP → POST /oauth2/v2.0/token via mTLS → mtls_pop access token +``` + +## Caching + +- **Certificate cache**: In-memory, process-local. Evicts when remaining + lifetime < 1 hour. Keyed by managed identity + attestation mode. +- **MAA token cache** (in `msal-key-attestation`): Refreshes at 90% of + JWT lifetime, 10-second absolute guard before expiry. diff --git a/docs/mTLS-PoP-Architecture-Decision.md b/docs/mTLS-PoP-Architecture-Decision.md new file mode 100644 index 00000000..5b1def08 --- /dev/null +++ b/docs/mTLS-PoP-Architecture-Decision.md @@ -0,0 +1,131 @@ +# MSAL Python — mTLS PoP with Non-Exportable Keys: Architecture Decision + +## Problem + +Python applications on Azure VMs with Credential Guard / KeyGuard need to: +1. Acquire an `mtls_pop` token bound to a non-exportable private key +2. Make downstream API calls presenting that same key via mTLS (TLS client certificate) + +The private key lives inside Windows VBS (Virtualization-Based Security) and **cannot be exported as raw bytes**. This is by design — it's the entire security model of KeyGuard. + +## Why Python Has a Gap + +Every mainstream Python HTTP library uses OpenSSL for TLS: + +| Library | TLS Backend | Can use non-exportable keys? | +|---------|-------------|------------------------------| +| `requests` / `urllib3` | OpenSSL | ❌ Requires `key_bytes` | +| `httpx` | OpenSSL | ❌ Requires `key_bytes` | +| `aiohttp` | OpenSSL | ❌ Requires `key_bytes` | +| `ssl` stdlib | OpenSSL | ❌ `SSLContext.load_cert_chain(keyfile=...)` needs a file/PEM | + +OpenSSL needs the private key as raw bytes (PEM/DER) to perform the TLS `CertificateVerify` signature. A non-exportable CNG key handle (`NCRYPT_KEY_HANDLE`) cannot provide this. + +**.NET does not have this problem** — `HttpClientHandler` uses Windows SChannel natively, which accepts `X509Certificate2` objects backed by non-exportable keys. The OS TLS stack performs the signature internally. + +## Research: Could We Use an OpenSSL CNG Provider? + +### RTI OpenSSL CNG Engine (2021) + +- **Repo:** https://github.com/rticommunity/openssl-cng-engine +- **Docs:** https://openssl-cng-engine.readthedocs.io/en/latest/about.html +- **Author:** RTI (Real-Time Innovations), a DDS middleware company +- **License:** Apache 2.0 + +RTI built an OpenSSL **ENGINE** that bridges OpenSSL → Windows CNG. It provides: +- `engine-bcrypt.dll` — Crypto primitives (AES, RSA, ECDSA) +- `engine-ncrypt.dll` — Key Store access (reads certs/keys from Windows cert store) + +### Why It Doesn't Work for Us + +| Issue | Detail | +|-------|--------| +| **OpenSSL 1.1.1 only** | Python 3.12+ ships OpenSSL 3.x. The ENGINE API was deprecated. | +| **Not ported to OpenSSL 3.x PROVIDER** | OpenSSL 3.x replaced ENGINEs with PROVIDERs (completely different API). No port exists. | +| **No KeyGuard/VBS testing** | Built for regular CNG keys, not VBS-isolated keys. | +| **Python `ssl` module** | Doesn't expose `ENGINE_load()` or `OSSL_PROVIDER_load()` to user code. | +| **RSA-PSS limitation** | ENGINE API can't support RSA-PSS (needed for modern TLS 1.3). | +| **Unmaintained** | Last meaningful activity ~2021. | + +### OpenSSL Mailing List Confirmation (July 2021) + +From David von Oheimb (OpenSSL project): +> "Porting this to the new OpenSSL crypto **provider** interface will likely lift the limitation regarding RSA-PSS support, which lacks just due to the engine interface." + +Source: https://mta.openssl.org/pipermail/openssl-users/2021-July/013944.html + +**No one has built this provider.** Not RTI, not Microsoft, not the OpenSSL community. + +## Our Solution: WinHTTP/SChannel via `msal-schannel-transport` + +Instead of trying to make OpenSSL work with non-exportable keys, we bypass OpenSSL entirely and use Windows' native TLS stack (SChannel) via WinHTTP — the same approach .NET uses internally. + +### Architecture + +``` +┌─────────────────────────────────────────────────────────────────┐ +│ App Developer Code │ +│ │ +│ result = client.acquire_token_for_client( │ +│ resource="https://vault.azure.net", │ +│ mtls_proof_of_possession=True, │ +│ with_attestation_support=True, │ +│ ) │ +│ │ +│ binding_cert = result["binding_certificate"] # WindowsCert │ +│ auth_header = f"{result['token_type']} {result['access_token']}│ +│ │ +│ with SchannelSession(client_certificate=binding_cert) as s: │ +│ response = s.get(url, headers={"Authorization": auth_header})│ +└──────────────┬──────────────────────────────────┬────────────────┘ + │ │ + ┌──────────▼──────────┐ ┌───────────▼────────────┐ + │ msal (pip package) │ │ msal-schannel-transport │ + │ │ │ (pip package) │ + │ • Token acquisition │ │ • WinHTTP/SChannel │ + │ • KeyGuard key mgmt │ │ • Uses NCRYPT_KEY_HANDLE│ + │ • Attestation (MAA) │ │ • No OpenSSL dependency │ + │ • Returns cert handle│ │ • mTLS with non-export │ + └──────────────────────┘ └─────────────────────────┘ +``` + +### Why Separate Packages? + +| Principle | Implementation | +|-----------|---------------| +| MSAL only acquires tokens | `msal` never makes downstream calls | +| App developer owns HTTP | `msal-schannel-transport` is a helper, not MSAL | +| Same pattern as .NET | .NET: MSAL → HttpClient. Python: MSAL → SchannelSession | +| Clear dependency boundary | Teams that don't need downstream mTLS don't install it | + +### Comparison with Other Languages + +| Language | Token Library | Downstream mTLS Transport | Notes | +|----------|--------------|---------------------------|-------| +| **.NET** | MSAL.NET | `HttpClientHandler` (built-in) | SChannel native, just works | +| **Go** | azure-sdk-for-go | `crypto/tls` + CNG bridge | Go has pluggable TLS | +| **Rust** | azure-sdk-for-rust | `schannel` crate | Native Windows TLS | +| **Python** | MSAL Python | `msal-schannel-transport` (ours) | OpenSSL can't, so we provide it | + +## E2E Proof (June 2026) + +Tested on `MSIV2` Azure VM (Windows, Credential Guard enabled): + +``` +✓ Token acquired: mtls_pop, 86399s expiry +✓ Binding certificate: WindowsCertificate (non-exportable, KeyGuard-backed) +✓ cnf.x5t#S256 binding: MATCH +✓ Downstream mTLS call: HTTP 200 from tokenbinding.vault.azure.net + Secret value returned: "secretme" +``` + +## Future Alternatives (if landscape changes) + +| If this happens... | We could... | +|---|---| +| Microsoft ships OpenSSL 3.x CNG Provider | Use `requests` normally with provider loaded | +| Python `ssl` exposes `OSSL_PROVIDER_load()` | Load CNG provider from Python | +| Python adds native SChannel support | Use stdlib directly | +| RTI ports to OpenSSL 3.x + supports KeyGuard | Evaluate as alternative | + +**None of these exist today.** Our WinHTTP approach is the only working solution for Python + non-exportable keys + mTLS. diff --git a/msal-key-attestation/LICENSE b/msal-key-attestation/LICENSE new file mode 100644 index 00000000..22aed37e --- /dev/null +++ b/msal-key-attestation/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) Microsoft Corporation. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/msal-key-attestation/MANIFEST.in b/msal-key-attestation/MANIFEST.in new file mode 100644 index 00000000..ece055d8 --- /dev/null +++ b/msal-key-attestation/MANIFEST.in @@ -0,0 +1,3 @@ +include LICENSE +include README.md +recursive-exclude tests * diff --git a/msal-key-attestation/msal_key_attestation/attestation.py b/msal-key-attestation/msal_key_attestation/attestation.py new file mode 100644 index 00000000..d0f1c280 --- /dev/null +++ b/msal-key-attestation/msal_key_attestation/attestation.py @@ -0,0 +1,398 @@ +# Copyright (c) Microsoft Corporation. +# All rights reserved. +# +# This code is licensed under the MIT License. +""" +Windows attestation for MSI v2 KeyGuard keys using AttestationClientLib.dll. + +This module calls into AttestationClientLib.dll to mint an attestation JWT for +a KeyGuard key handle. It also provides a small in-memory cache to reuse the +attestation JWT until ~90% of its lifetime. + +Caching notes: + - Cache is process-local (in-memory). Does not persist across process + restarts. + - Cache is keyed by (attestation_endpoint, client_id, cache_key). + - Provide a stable cache_key (e.g., the named per-boot key name) to + maximize hits. + - If the token cannot be parsed or has no ``exp`` claim, it is not cached. + +Env vars: + - ATTESTATION_CLIENTLIB_PATH: absolute path to AttestationClientLib.dll + - MSAL_MSI_V2_ATTESTATION_CACHE: "0" disables caching (default enabled) +""" + +from __future__ import annotations + +import base64 +import ctypes +import json +import logging +import os +import sys +import threading +import time +from ctypes import POINTER, Structure, c_char_p, c_int, c_void_p +from dataclasses import dataclass +from typing import Callable, Optional, Tuple + +logger = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# Native callback type — prevent GC of the delegate +# --------------------------------------------------------------------------- + +_NATIVE_LOG_CB = None + +# void LogFunc(void* ctx, const char* tag, int lvl, const char* func, +# int line, const char* msg); +_LogFunc = ctypes.CFUNCTYPE( + None, c_void_p, c_char_p, c_int, c_char_p, c_int, c_char_p) + + +class AttestationLogInfo(Structure): + _fields_ = [("Log", c_void_p), ("Ctx", c_void_p)] + + +def _default_logger(ctx, tag, lvl, func, line, msg): + try: + tag_s = tag.decode("utf-8", errors="replace") if tag else "" + func_s = func.decode("utf-8", errors="replace") if func else "" + msg_s = msg.decode("utf-8", errors="replace") if msg else "" + logger.debug("[Native:%s:%s] %s:%s - %s", + tag_s, lvl, func_s, line, msg_s) + except Exception: + pass + + +# --------------------------------------------------------------------------- +# Env helpers +# --------------------------------------------------------------------------- + +def _truthy_env(name: str, default: str = "1") -> bool: + val = os.getenv(name, default) + return (val or "").strip().lower() in ("1", "true", "yes", "y", "on") + + +def _maybe_add_dll_dirs(): + """Make DLL resolution more reliable (especially for packaged apps). + + Only adds the Python executable directory and the package directory. + Does NOT add os.getcwd() to avoid DLL preloading/hijacking risk. + Use ATTESTATION_CLIENTLIB_PATH env var for custom locations. + """ + if sys.platform != "win32": + return + add_dir = getattr(os, "add_dll_directory", None) + if not add_dir: + return + for p in (os.path.dirname(sys.executable), + os.path.dirname(__file__)): + try: + if p and os.path.isdir(p): + add_dir(p) + except Exception: + pass + + +# --------------------------------------------------------------------------- +# DLL loading +# --------------------------------------------------------------------------- + +def _load_lib(): + if sys.platform != "win32": + raise RuntimeError( + "[msal_key_attestation] AttestationClientLib is Windows-only.") + + _maybe_add_dll_dirs() + + explicit = os.getenv("ATTESTATION_CLIENTLIB_PATH") + try: + if explicit: + return ctypes.CDLL(explicit) + # Try bundled DLL next to this module first + bundled = os.path.join( + os.path.dirname(os.path.abspath(__file__)), + "AttestationClientLib.dll") + if os.path.isfile(bundled): + return ctypes.CDLL(bundled) + return ctypes.CDLL("AttestationClientLib.dll") + except OSError as exc: + raise RuntimeError( + "[msal_key_attestation] Unable to load AttestationClientLib.dll. " + "Install msal-key-attestation package or set ATTESTATION_CLIENTLIB_PATH." + ) from exc + + +# --------------------------------------------------------------------------- +# JWT parsing (for cache lifetime) +# --------------------------------------------------------------------------- + +def _b64url_decode(s: str) -> bytes: + s = (s or "").strip() + s += "=" * ((4 - len(s) % 4) % 4) + return base64.urlsafe_b64decode(s.encode("ascii")) + + +def _try_extract_exp_iat(jwt: str) -> Tuple[Optional[int], Optional[int]]: + """Extract exp and iat (Unix seconds) from a JWT without validation.""" + try: + parts = jwt.split(".") + if len(parts) < 2: + return None, None + payload = json.loads( + _b64url_decode(parts[1]).decode("utf-8", errors="replace")) + if not isinstance(payload, dict): + return None, None + + def _to_int(v): + if isinstance(v, bool): + return None + if isinstance(v, int): + return v + if isinstance(v, float): + return int(v) + if isinstance(v, str) and v.strip().isdigit(): + return int(v.strip()) + return None + + return _to_int(payload.get("exp")), _to_int(payload.get("iat")) + except Exception: + return None, None + + +# --------------------------------------------------------------------------- +# MAA token cache (in-memory, process-local) +# --------------------------------------------------------------------------- + +@dataclass(frozen=True) +class _CacheKey: + attestation_endpoint: str + client_id: str + cache_key: str + auth_token: str + client_payload: str + + +@dataclass +class _CacheEntry: + jwt: str + exp: int + refresh_after: float # epoch seconds + + +_CACHE_LOCK = threading.Lock() +_CACHE: dict = {} + + +def _cache_lookup(key: _CacheKey) -> Optional[str]: + if not _truthy_env("MSAL_MSI_V2_ATTESTATION_CACHE", "1"): + return None + now = time.time() + with _CACHE_LOCK: + entry = _CACHE.get(key) + if not entry: + return None + if now >= entry.refresh_after or now >= entry.exp - 5: + return None + logger.debug("[msal_key_attestation] MAA cache HIT") + return entry.jwt + + +def _cache_store(key: _CacheKey, jwt: str) -> None: + if not _truthy_env("MSAL_MSI_V2_ATTESTATION_CACHE", "1"): + return + exp, iat = _try_extract_exp_iat(jwt) + if exp is None: + return + now = int(time.time()) + issued_at = iat if iat is not None else now + lifetime = exp - issued_at + if lifetime <= 0: + return + # Refresh at 90% of lifetime with small absolute guard + refresh_after = issued_at + (0.90 * lifetime) + refresh_after = min(refresh_after, exp - 10) + with _CACHE_LOCK: + _CACHE[key] = _CacheEntry( + jwt=jwt, exp=exp, refresh_after=float(refresh_after)) + logger.debug("[msal_key_attestation] MAA cache SET") + + +def _cache_clear() -> None: + """Clear cache (for testing).""" + with _CACHE_LOCK: + _CACHE.clear() + + +# --------------------------------------------------------------------------- +# Core attestation call +# --------------------------------------------------------------------------- + +def get_attestation_jwt( + *, + attestation_endpoint: str, + client_id: str, + key_handle: int, + auth_token: str = "", + client_payload: str = "{}", + cache_key: Optional[str] = None, +) -> str: + """ + Get attestation JWT from AttestationClientLib.dll for a KeyGuard key. + + Args: + attestation_endpoint: MAA endpoint URL. + client_id: Client ID for attestation. + key_handle: NCrypt key handle (integer). + auth_token: Optional auth token for attestation service. + client_payload: Optional JSON payload. + cache_key: Stable identifier for caching (recommended: key name). + + Returns: + Attestation JWT string. + + Raises: + RuntimeError: on DLL load or attestation failure. + """ + if not attestation_endpoint: + raise ValueError( + "[msal_key_attestation] attestation_endpoint must be non-empty") + if not client_id: + raise ValueError( + "[msal_key_attestation] client_id must be non-empty") + if not key_handle: + raise ValueError( + "[msal_key_attestation] key_handle must be non-zero") + + stable = cache_key if cache_key is not None else f"handle:{key_handle}" + ck = _CacheKey( + attestation_endpoint=str(attestation_endpoint), + client_id=str(client_id), + cache_key=str(stable), + auth_token=str(auth_token or ""), + client_payload=str(client_payload or "{}"), + ) + + cached = _cache_lookup(ck) + if cached: + return cached + + lib = _load_lib() + + lib.InitAttestationLib.argtypes = [POINTER(AttestationLogInfo)] + lib.InitAttestationLib.restype = c_int + + lib.AttestKeyGuardImportKey.argtypes = [ + c_char_p, # endpoint + c_char_p, # authToken + c_char_p, # clientPayload + c_void_p, # keyHandle (NCRYPT_KEY_HANDLE) + POINTER(c_void_p), # out token (char*) + c_char_p, # clientId + ] + lib.AttestKeyGuardImportKey.restype = c_int + + lib.FreeAttestationToken.argtypes = [c_void_p] + lib.FreeAttestationToken.restype = None + + lib.UninitAttestationLib.argtypes = [] + lib.UninitAttestationLib.restype = None + + global _NATIVE_LOG_CB # pylint: disable=global-statement + _NATIVE_LOG_CB = _LogFunc(_default_logger) + + info = AttestationLogInfo() + info.Log = ctypes.cast(_NATIVE_LOG_CB, c_void_p).value + info.Ctx = c_void_p(0) + + rc = lib.InitAttestationLib(ctypes.byref(info)) + if rc != 0: + raise RuntimeError( + f"[msal_key_attestation] InitAttestationLib failed: {rc}") + + token_ptr = c_void_p() + try: + rc = lib.AttestKeyGuardImportKey( + attestation_endpoint.encode("utf-8"), + (auth_token or "").encode("utf-8"), + (client_payload or "{}").encode("utf-8"), + c_void_p(int(key_handle)), + ctypes.byref(token_ptr), + client_id.encode("utf-8"), + ) + if rc != 0: + raise RuntimeError( + f"[msal_key_attestation] AttestKeyGuardImportKey failed: {rc}") + if not token_ptr.value: + raise RuntimeError( + "[msal_key_attestation] Attestation token pointer is NULL") + + token = ctypes.string_at(token_ptr.value).decode( + "utf-8", errors="replace") + if not token or "." not in token: + raise RuntimeError( + "[msal_key_attestation] Attestation token looks malformed") + + _cache_store(ck, token) + return token + finally: + try: + if token_ptr.value: + lib.FreeAttestationToken(token_ptr) + finally: + try: + lib.UninitAttestationLib() + except Exception: + pass + + +# --------------------------------------------------------------------------- +# Public factory — matches the callback signature MSAL expects: +# (endpoint: str, key_handle: int, client_id: str, cache_key: str) -> str +# --------------------------------------------------------------------------- + +def create_attestation_provider() -> Callable[[str, int, str, str], str]: + """ + Create an attestation token provider callable for MSAL MSI v2. + + The returned callable has signature:: + + provider(attestation_endpoint: str, key_handle: int, + client_id: str, cache_key: str) -> str + + ``cache_key`` should be the stable per-boot key name. Using the key + name (rather than the numeric handle) maximizes MAA-token cache hits + across key re-opens. + + It wraps :func:`get_attestation_jwt` with caching support. + + Usage:: + + from msal_key_attestation import create_attestation_provider + provider = create_attestation_provider() + + # MSAL auto-discovers this when with_attestation_support=True. + # Or pass explicitly: + from msal.msi_v2 import obtain_token + result = obtain_token( + http_client, managed_identity, resource, + attestation_token_provider=provider, + ) + + Returns: + Callable suitable for ``attestation_token_provider`` parameter. + """ + def _provider( + attestation_endpoint: str, + key_handle: int, + client_id: str, + cache_key: str = "", + ) -> str: + return get_attestation_jwt( + attestation_endpoint=attestation_endpoint, + client_id=client_id, + key_handle=key_handle, + cache_key=cache_key or None, + ) + return _provider diff --git a/msal-key-attestation/pyproject.toml b/msal-key-attestation/pyproject.toml new file mode 100644 index 00000000..1d129a9b --- /dev/null +++ b/msal-key-attestation/pyproject.toml @@ -0,0 +1,53 @@ +[build-system] +requires = ["setuptools>=61.0", "wheel"] +build-backend = "setuptools.build_meta" + +[project] +name = "msal-key-attestation" +dynamic = ["version"] +description = "KeyGuard attestation support for MSAL Python MSI v2 (mTLS PoP). Provides Python bindings for AttestationClientLib.dll (Windows Credential Guard key attestation)." +readme = "README.md" +license = "MIT" +requires-python = ">=3.9" +authors = [ + {name = "Microsoft Corporation", email = "nugetaad@microsoft.com"}, +] +classifiers = [ + "Development Status :: 3 - Alpha", + "Programming Language :: Python :: 3 :: Only", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", + "Programming Language :: Python :: 3.14", + "License :: OSI Approved :: MIT License", + "Operating System :: Microsoft :: Windows", +] +dependencies = [ + "msal>=1.32.0", +] + +[project.urls] +Homepage = "https://github.com/AzureAD/microsoft-authentication-library-for-python" +Repository = "https://github.com/AzureAD/microsoft-authentication-library-for-python" + +[tool.setuptools.dynamic] +version = {attr = "msal_key_attestation.__version__"} + +[tool.setuptools.packages.find] +exclude = ["tests", "tests.*"] + +# Native dependency (not a pip package): +# AttestationClientLib.dll is sourced from NuGet: +# Microsoft.Azure.Security.KeyGuardAttestation v1.1.5 +# Path: runtimes/win-x64/native/AttestationClientLib.dll +# +# For development/testing: +# nuget install Microsoft.Azure.Security.KeyGuardAttestation -Version 1.1.5 +# Copy runtimes/win-x64/native/AttestationClientLib.dll to msal_key_attestation/ +# +# For production packaging: +# The DLL should be bundled into a platform wheel (msal_key_attestation-*-win_amd64.whl) +# built by CI/CD from the NuGet source. diff --git a/msal-schannel-transport/msal_schannel_transport/__init__.py b/msal-schannel-transport/msal_schannel_transport/__init__.py new file mode 100644 index 00000000..b300beda --- /dev/null +++ b/msal-schannel-transport/msal_schannel_transport/__init__.py @@ -0,0 +1,39 @@ +# Copyright (c) Microsoft Corporation. +# All rights reserved. +# +# This code is licensed under the MIT License. +""" +msal-schannel-transport — Windows SChannel/WinHTTP-backed HTTP transport +for downstream mTLS calls using platform-backed certificates. + +This package provides an HTTP session that uses WinHTTP + SChannel for TLS, +allowing app developers to make downstream API calls with certificates that +have non-exportable private keys (TPM/KeyGuard/VBS). + +This is the Python equivalent of using HttpClient + X509Certificate2 in .NET +for downstream mTLS API calls. + +Usage: + from msal import WindowsCertificate + from msal_schannel_transport import SchannelSession + + # Get certificate reference from MSAL auth result + cert = result["binding_certificate"] # WindowsCertificate object + + # Make downstream mTLS call + with SchannelSession(client_certificate=cert) as session: + response = session.get( + "https://my-vault.vault.azure.net/secrets/foo?api-version=7.5", + headers={ + "Authorization": ( + f"{result['token_type']} {result['access_token']}" + ), + }, + ) + print(response.status_code, response.json()) +""" + +from .session import SchannelSession, SchannelResponse + +__all__ = ["SchannelSession", "SchannelResponse"] +__version__ = "0.1.0" diff --git a/msal-schannel-transport/msal_schannel_transport/session.py b/msal-schannel-transport/msal_schannel_transport/session.py new file mode 100644 index 00000000..02c1c548 --- /dev/null +++ b/msal-schannel-transport/msal_schannel_transport/session.py @@ -0,0 +1,431 @@ +# Copyright (c) Microsoft Corporation. +# All rights reserved. +# +# This code is licensed under the MIT License. +""" +SchannelSession — WinHTTP/SChannel-backed HTTP client for mTLS. + +This module provides a requests-like HTTP session that uses Windows native +WinHTTP for TLS, enabling mTLS with non-exportable private keys. + +Design: + - App developer creates SchannelSession with a WindowsCertificate + - session.get() / session.post() use WinHTTP under the hood + - The private key never leaves Windows CNG — SChannel performs the + TLS CertificateVerify signature via the NCRYPT_KEY_HANDLE + - This is NOT part of MSAL — it's a separate transport for downstream calls +""" + +from __future__ import annotations + +import json +import sys +from typing import Any, Dict, Optional +from urllib.parse import urlparse, urlencode + + +class SchannelResponse: + """Response from a WinHTTP request (requests.Response-like interface).""" + + def __init__(self, status_code: int, body: bytes, headers: Optional[Dict[str, str]] = None): + self.status_code = status_code + self._body = body + self.headers = headers or {} + + @property + def content(self) -> bytes: + """Raw response body bytes.""" + return self._body + + @property + def text(self) -> str: + """Response body as UTF-8 text.""" + return self._body.decode("utf-8", errors="replace") + + def json(self) -> Any: + """Parse response body as JSON.""" + return json.loads(self._body) + + def raise_for_status(self) -> None: + """Raise an exception for 4xx/5xx status codes.""" + if 400 <= self.status_code < 600: + raise SchannelHttpError( + f"HTTP {self.status_code}: {self.text[:200]}", + status_code=self.status_code, + body=self._body, + ) + + def __repr__(self) -> str: + return f"" + + +class SchannelHttpError(Exception): + """HTTP error from a SChannel transport call.""" + + def __init__(self, message: str, status_code: int = 0, body: bytes = b""): + super().__init__(message) + self.status_code = status_code + self.body = body + + +class SchannelSession: + """ + WinHTTP/SChannel-backed HTTP session for mTLS downstream calls. + + This is the Python equivalent of .NET HttpClient configured with an + X509Certificate2 for client certificate authentication. + + Args: + client_certificate: A WindowsCertificate object (from MSAL auth result + or WindowsCertificate.from_store()). The certificate's private key + handle is used by SChannel for the TLS handshake. + + Example: + from msal_schannel_transport import SchannelSession + + with SchannelSession(client_certificate=cert) as session: + resp = session.get( + "https://api.example.com/resource", + headers={"Authorization": "MTLS_POP eyJ..."} + ) + print(resp.json()) + """ + + def __init__(self, client_certificate: Any): + """ + Create a session for mTLS downstream calls. + + Args: + client_certificate: A WindowsCertificate object. Must remain open + (not closed) for the lifetime of this session. + + Raises: + OSError: if not on Windows. + TypeError: if client_certificate lacks required interface. + ValueError: if certificate has no private key. + """ + if sys.platform != "win32": + raise OSError("SchannelSession is only supported on Windows") + + self._certificate = client_certificate + self._cert_ctx: Any = None + self._win32: Optional[dict] = None + self._closed = False + + if not hasattr(client_certificate, "create_cert_context"): + raise TypeError( + "client_certificate must be a WindowsCertificate instance " + "(or compatible object with create_cert_context method)") + + if not client_certificate.has_private_key: + raise ValueError( + "client_certificate has no private key — cannot perform mTLS") + + self._setup() + + def _setup(self) -> None: + """Initialize WinHTTP bindings and create CERT_CONTEXT.""" + self._win32 = self._load_winhttp() + self._cert_ctx = self._certificate.create_cert_context() + + @staticmethod + def _load_winhttp() -> dict: + """Load minimal WinHTTP bindings (independent of MSAL internals).""" + import ctypes + from ctypes import wintypes + + winhttp = ctypes.WinDLL("winhttp.dll", use_last_error=True) + crypt32 = ctypes.WinDLL("crypt32.dll", use_last_error=True) + + class CERT_CONTEXT(ctypes.Structure): + _fields_ = [ + ("dwCertEncodingType", wintypes.DWORD), + ("pbCertEncoded", ctypes.POINTER(ctypes.c_ubyte)), + ("cbCertEncoded", wintypes.DWORD), + ("pCertInfo", ctypes.c_void_p), + ("hCertStore", ctypes.c_void_p), + ] + + PCCERT_CONTEXT = ctypes.POINTER(CERT_CONTEXT) + + winhttp.WinHttpOpen.argtypes = [ + ctypes.c_wchar_p, wintypes.DWORD, ctypes.c_wchar_p, + ctypes.c_wchar_p, wintypes.DWORD] + winhttp.WinHttpOpen.restype = ctypes.c_void_p + + winhttp.WinHttpConnect.argtypes = [ + ctypes.c_void_p, ctypes.c_wchar_p, wintypes.WORD, wintypes.DWORD] + winhttp.WinHttpConnect.restype = ctypes.c_void_p + + winhttp.WinHttpOpenRequest.argtypes = [ + ctypes.c_void_p, ctypes.c_wchar_p, ctypes.c_wchar_p, + ctypes.c_wchar_p, ctypes.c_wchar_p, ctypes.c_void_p, wintypes.DWORD] + winhttp.WinHttpOpenRequest.restype = ctypes.c_void_p + + winhttp.WinHttpSetOption.argtypes = [ + ctypes.c_void_p, wintypes.DWORD, ctypes.c_void_p, wintypes.DWORD] + winhttp.WinHttpSetOption.restype = wintypes.BOOL + + winhttp.WinHttpSendRequest.argtypes = [ + ctypes.c_void_p, ctypes.c_wchar_p, wintypes.DWORD, ctypes.c_void_p, + wintypes.DWORD, wintypes.DWORD, ctypes.c_size_t] + winhttp.WinHttpSendRequest.restype = wintypes.BOOL + + winhttp.WinHttpReceiveResponse.argtypes = [ + ctypes.c_void_p, ctypes.c_void_p] + winhttp.WinHttpReceiveResponse.restype = wintypes.BOOL + + winhttp.WinHttpQueryHeaders.argtypes = [ + ctypes.c_void_p, wintypes.DWORD, ctypes.c_wchar_p, ctypes.c_void_p, + ctypes.POINTER(wintypes.DWORD), ctypes.POINTER(wintypes.DWORD)] + winhttp.WinHttpQueryHeaders.restype = wintypes.BOOL + + winhttp.WinHttpQueryDataAvailable.argtypes = [ + ctypes.c_void_p, ctypes.POINTER(wintypes.DWORD)] + winhttp.WinHttpQueryDataAvailable.restype = wintypes.BOOL + + winhttp.WinHttpReadData.argtypes = [ + ctypes.c_void_p, ctypes.c_void_p, wintypes.DWORD, + ctypes.POINTER(wintypes.DWORD)] + winhttp.WinHttpReadData.restype = wintypes.BOOL + + winhttp.WinHttpCloseHandle.argtypes = [ctypes.c_void_p] + winhttp.WinHttpCloseHandle.restype = wintypes.BOOL + + crypt32.CertFreeCertificateContext.argtypes = [PCCERT_CONTEXT] + crypt32.CertFreeCertificateContext.restype = wintypes.BOOL + + return { + "ctypes": ctypes, "wintypes": wintypes, + "winhttp": winhttp, "crypt32": crypt32, + "CERT_CONTEXT": CERT_CONTEXT, + "WINHTTP_ACCESS_TYPE_DEFAULT_PROXY": 0, + "WINHTTP_FLAG_SECURE": 0x00800000, + "WINHTTP_OPTION_CLIENT_CERT_CONTEXT": 47, + "WINHTTP_OPTION_ENABLE_HTTP2_PLUS_CLIENT_CERT": 161, + "WINHTTP_QUERY_STATUS_CODE": 19, + "WINHTTP_QUERY_FLAG_NUMBER": 0x20000000, + } + + def get( + self, + url: str, + *, + headers: Optional[Dict[str, str]] = None, + params: Optional[Dict[str, str]] = None, + ) -> SchannelResponse: + """HTTP GET with mTLS client certificate.""" + if params: + sep = "&" if "?" in url else "?" + url = url + sep + urlencode(params) + return self._request("GET", url, headers=headers or {}) + + def post( + self, + url: str, + *, + headers: Optional[Dict[str, str]] = None, + data: Optional[bytes] = None, + json_body: Optional[Any] = None, + params: Optional[Dict[str, str]] = None, + ) -> SchannelResponse: + """HTTP POST with mTLS client certificate.""" + if params: + sep = "&" if "?" in url else "?" + url = url + sep + urlencode(params) + + if json_body is not None: + data = json.dumps(json_body).encode("utf-8") + headers = dict(headers or {}) + headers.setdefault("Content-Type", "application/json") + + return self._request("POST", url, headers=headers or {}, body=data) + + def _request( + self, method: str, url: str, headers: Dict[str, str], + body: Optional[bytes] = None, + ) -> SchannelResponse: + """Execute an HTTP request over WinHTTP with mTLS.""" + if self._closed: + raise RuntimeError("SchannelSession has been closed") + + # Validate headers against CRLF injection + for k, v in headers.items(): + if "\r" in k or "\n" in k or ":" in k: + raise ValueError( + f"Invalid header name (contains CR/LF/colon): {k!r}") + if "\r" in v or "\n" in v: + raise ValueError( + f"Invalid header value for '{k}' (contains CR/LF)") + + ctypes_mod = self._win32["ctypes"] + wintypes = self._win32["wintypes"] + winhttp = self._win32["winhttp"] + + u = urlparse(url) + if u.scheme.lower() != "https": + raise ValueError(f"SchannelSession requires https, got: {url!r}") + if not u.hostname: + raise ValueError(f"Invalid URL: {url!r}") + + host = u.hostname + port = u.port or 443 + path = u.path or "/" + if u.query: + path += "?" + u.query + + h_session = winhttp.WinHttpOpen( + "msal-schannel-transport/0.1", + self._win32["WINHTTP_ACCESS_TYPE_DEFAULT_PROXY"], + None, None, 0) + if not h_session: + self._raise_last_error("WinHttpOpen failed") + + h_connect = None + h_request = None + try: + # Best-effort HTTP/2 + client cert + enable = wintypes.DWORD(1) + try: + winhttp.WinHttpSetOption( + h_session, + self._win32["WINHTTP_OPTION_ENABLE_HTTP2_PLUS_CLIENT_CERT"], + ctypes_mod.byref(enable), ctypes_mod.sizeof(enable)) + except Exception: + pass + + h_connect = winhttp.WinHttpConnect(h_session, host, int(port), 0) + if not h_connect: + self._raise_last_error("WinHttpConnect failed") + + h_request = winhttp.WinHttpOpenRequest( + h_connect, method, path, None, None, None, + self._win32["WINHTTP_FLAG_SECURE"]) + if not h_request: + self._raise_last_error("WinHttpOpenRequest failed") + + # Attach client certificate for mTLS + ok = winhttp.WinHttpSetOption( + h_request, + self._win32["WINHTTP_OPTION_CLIENT_CERT_CONTEXT"], + self._cert_ctx, + ctypes_mod.sizeof(self._win32["CERT_CONTEXT"])) + if not ok: + self._raise_last_error("WinHttpSetOption(CLIENT_CERT) failed") + + # Format headers + header_lines = "".join( + f"{k}: {v}\r\n" for k, v in headers.items()) + + # Send request + if body: + body_buf = ctypes_mod.create_string_buffer(body) + ok = winhttp.WinHttpSendRequest( + h_request, header_lines, 0xFFFFFFFF, + body_buf, len(body), len(body), 0) + else: + ok = winhttp.WinHttpSendRequest( + h_request, header_lines, 0xFFFFFFFF, + None, 0, 0, 0) + if not ok: + self._raise_last_error("WinHttpSendRequest failed") + + ok = winhttp.WinHttpReceiveResponse(h_request, None) + if not ok: + self._raise_last_error("WinHttpReceiveResponse failed") + + # Read status code + status = wintypes.DWORD(0) + status_size = wintypes.DWORD(ctypes_mod.sizeof(status)) + index = wintypes.DWORD(0) + + ok = winhttp.WinHttpQueryHeaders( + h_request, + self._win32["WINHTTP_QUERY_STATUS_CODE"] + | self._win32["WINHTTP_QUERY_FLAG_NUMBER"], + None, ctypes_mod.byref(status), + ctypes_mod.byref(status_size), ctypes_mod.byref(index)) + if not ok: + self._raise_last_error("WinHttpQueryHeaders(STATUS_CODE) failed") + + # Read response body + chunks = [] + while True: + avail = wintypes.DWORD(0) + ok = winhttp.WinHttpQueryDataAvailable( + h_request, ctypes_mod.byref(avail)) + if not ok: + self._raise_last_error("WinHttpQueryDataAvailable failed") + if avail.value == 0: + break + buf = (ctypes_mod.c_ubyte * avail.value)() + read = wintypes.DWORD(0) + ok = winhttp.WinHttpReadData( + h_request, buf, avail.value, ctypes_mod.byref(read)) + if not ok: + self._raise_last_error("WinHttpReadData failed") + if read.value: + chunks.append(bytes(buf[:read.value])) + if read.value == 0: + break + + return SchannelResponse( + status_code=int(status.value), + body=b"".join(chunks), + ) + finally: + self._close_handle(h_request) + self._close_handle(h_connect) + self._close_handle(h_session) + + def _close_handle(self, h: Any) -> None: + """Close a WinHTTP handle safely.""" + try: + if h: + self._win32["winhttp"].WinHttpCloseHandle(h) + except Exception: + pass + + def _raise_last_error(self, context: str) -> None: + """Raise with Win32 last error.""" + ctypes_mod = self._win32["ctypes"] + err = ctypes_mod.get_last_error() + detail = "" + try: + detail = ctypes_mod.FormatError(err).strip() + except Exception: + pass + raise OSError( + f"[SchannelSession] {context} (winerror={err} {detail})" + if detail else + f"[SchannelSession] {context} (winerror={err})") + + def close(self) -> None: + """Release the CERT_CONTEXT. Safe to call multiple times.""" + if self._closed: + return + self._closed = True + + if self._cert_ctx and self._win32: + try: + self._win32["crypt32"].CertFreeCertificateContext( + self._cert_ctx) + except Exception: + pass + self._cert_ctx = None + + def __enter__(self) -> "SchannelSession": + return self + + def __exit__(self, *_: Any) -> None: + self.close() + + def __del__(self) -> None: + try: + self.close() + except Exception: + pass + + def __repr__(self) -> str: + state = "closed" if self._closed else "open" + return f"" diff --git a/msal-schannel-transport/pyproject.toml b/msal-schannel-transport/pyproject.toml new file mode 100644 index 00000000..cb6f1bfa --- /dev/null +++ b/msal-schannel-transport/pyproject.toml @@ -0,0 +1,34 @@ +[build-system] +requires = ["setuptools>=68.0", "wheel"] +build-backend = "setuptools.build_meta" + +[project] +name = "msal-schannel-transport" +version = "0.1.0" +description = "Windows SChannel/WinHTTP HTTP transport for mTLS with non-exportable keys" +readme = "README.md" +license = {text = "MIT"} +requires-python = ">=3.9" +authors = [ + {name = "Microsoft", email = "opencode@microsoft.com"}, +] +classifiers = [ + "Development Status :: 3 - Alpha", + "License :: OSI Approved :: MIT License", + "Operating System :: Microsoft :: Windows", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Topic :: Security", + "Topic :: Internet :: WWW/HTTP", +] +dependencies = [] + +[project.urls] +Homepage = "https://github.com/AzureAD/microsoft-authentication-library-for-python" +Issues = "https://github.com/AzureAD/microsoft-authentication-library-for-python/issues" + +[tool.setuptools.packages.find] +include = ["msal_schannel_transport*"] diff --git a/msal/__init__.py b/msal/__init__.py index ea681317..cca82ea7 100644 --- a/msal/__init__.py +++ b/msal/__init__.py @@ -39,11 +39,12 @@ SystemAssignedManagedIdentity, UserAssignedManagedIdentity, ManagedIdentityClient, ManagedIdentityError, + MsiV2Error, ArcPlatformNotSupportedError, ) +from .windows_certificate import WindowsCertificate # Putting module-level exceptions into the package namespace, to make them # 1. officially part of the MSAL public API, and # 2. can still be caught by the user code even if we change the module structure. from .oauth2cli.oauth2 import BrowserInteractionTimeoutError - diff --git a/msal/managed_identity.py b/msal/managed_identity.py index b2fc446c..0f06978f 100644 --- a/msal/managed_identity.py +++ b/msal/managed_identity.py @@ -26,6 +26,11 @@ class ManagedIdentityError(ValueError): pass +class MsiV2Error(ManagedIdentityError): + """Raised when MSI v2 (mTLS PoP) flow fails.""" + pass + + class ManagedIdentity(UserDict): """Feed an instance of this class to :class:`msal.ManagedIdentityClient` to acquire token for the specified managed identity. @@ -261,12 +266,17 @@ def acquire_token_for_client( *, resource: str, # If/when we support scope, resource will become optional claims_challenge: Optional[str] = None, + mtls_proof_of_possession: bool = False, + with_attestation_support: bool = False, ): """Acquire token for the managed identity. - The result will be automatically cached. + For the standard (MSI v1) flow, the result will be automatically cached. Subsequent calls will automatically search from cache first. + For the MSI v2 (mTLS PoP) flow, the certificate is cached internally + but each call acquires a fresh token from ESTS via mTLS. + :param resource: The resource for which the token is acquired. :param claims_challenge: @@ -280,6 +290,25 @@ def acquire_token_for_client( even if the app developer did not opt in for the "CP1" client capability. Upon receiving a `claims_challenge`, MSAL will attempt to acquire a new token. + :param bool mtls_proof_of_possession: (optional) + When True **and** ``with_attestation_support`` is also True, + use the MSI v2 (mTLS Proof-of-Possession) flow to acquire an + ``mtls_pop`` token bound to a short-lived mTLS certificate issued + by the IMDS ``/issuecredential`` endpoint. + + Requires Windows with Credential Guard / KeyGuard active. + Without ``with_attestation_support``, this flag alone falls + through to the legacy IMDS v1 flow. Defaults to False. + + :param bool with_attestation_support: (optional) + When True (and ``mtls_proof_of_possession`` is also True), + perform KeyGuard / platform attestation before credential + issuance. This requires the **msal-key-attestation** package + (``pip install msal-key-attestation``). + + Setting this to True without ``mtls_proof_of_possession`` + raises :class:`ManagedIdentityError`. Defaults to False. + .. note:: Known issue: When an Azure VM has only one user-assigned managed identity, @@ -294,6 +323,46 @@ def acquire_token_for_client( client_id_in_cache = self._managed_identity.get( ManagedIdentity.ID, "SYSTEM_ASSIGNED_MANAGED_IDENTITY") now = time.time() + + # --- MSI v2 gate --- + # MSI v2 is opt-in: both mtls_proof_of_possession AND + # with_attestation_support must be True. + # No auto-fallback: if v2 fails, MsiV2Error is raised. + use_msi_v2 = bool(mtls_proof_of_possession and with_attestation_support) + + if with_attestation_support and not mtls_proof_of_possession: + raise ManagedIdentityError( + "attestation_requires_pop: with_attestation_support=True " + "requires mtls_proof_of_possession=True (mTLS PoP).") + + if use_msi_v2: + # Auto-discover attestation provider from msal-key-attestation + attestation_token_provider = None + try: + from msal_key_attestation import create_attestation_provider + attestation_token_provider = create_attestation_provider() + except ImportError as exc: + raise MsiV2Error( + "[msi_v2] with_attestation_support=True requires the " + "msal-key-attestation package. " + "Install it with: pip install msal-key-attestation") from exc + + from .msi_v2 import obtain_token as _obtain_token_v2 + try: + result = _obtain_token_v2( + self._http_client, self._managed_identity, resource, + attestation_enabled=True, + attestation_token_provider=attestation_token_provider, + ) + except MsiV2Error: + raise + except Exception as exc: + raise MsiV2Error( + f"[msi_v2] Unexpected failure: {exc}") from exc + if "access_token" in result and "error" not in result: + result[self._TOKEN_SOURCE] = self._TOKEN_SOURCE_IDP + return result + if True: # Attempt cache search even if receiving claims_challenge, # because we want to locate the existing token (if any) and refresh it matches = self._token_cache.search( diff --git a/msal/msi_v2.py b/msal/msi_v2.py new file mode 100644 index 00000000..eda8ced6 --- /dev/null +++ b/msal/msi_v2.py @@ -0,0 +1,1501 @@ +# Copyright (c) Microsoft Corporation. +# All rights reserved. +# +# This code is licensed under the MIT License. +""" +MSI v2 (IMDSv2) Managed Identity flow — Windows KeyGuard + SChannel mTLS PoP. + +This module implements the MSI v2 token acquisition path using Windows native APIs +via ctypes: + - CNG/NCrypt: create/open a KeyGuard-protected per-boot RSA key (non-exportable) + - Minimal DER/PKCS#10: build a CSR signed with RSA-PSS/SHA256 + - IMDS: call getplatformmetadata + issuecredential + - Crypt32: bind the issued certificate to the CNG private key + - WinHTTP/SChannel: acquire access token over mTLS (token_type=mtls_pop) + +Key behavior: + - Uses a *named per-boot key*: opens the key if it already exists for this boot; + otherwise creates it. + - No MSI v1 fallback: any MSI v2 failure raises MsiV2Error. + - Production-ready handle management: all WinHTTP / Crypt32 / NCrypt handles are + released in finally blocks. + - Certificate cache: in-memory with lifetime-based eviction (like .NET + InMemoryCertificateCache). + - Returns certificate with token for mTLS with resource. + +Environment variables (optional): + - AZURE_POD_IDENTITY_AUTHORITY_HOST: override IMDS base URL + (default http://169.254.169.254) + - MSAL_MSI_V2_KEY_NAME: override the per-boot key name (otherwise derived from + metadata clientId) +""" + +from __future__ import annotations + +import base64 +import hashlib +import json +import logging +import os +import struct +import sys +import threading +import time +import uuid +from typing import Any, Callable, Dict, List, Optional, Tuple +from urllib.parse import urlparse, urlencode + +logger = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# IMDS constants +# --------------------------------------------------------------------------- + +_IMDS_DEFAULT_BASE = "http://169.254.169.254" +_IMDS_BASE_ENVVAR = "AZURE_POD_IDENTITY_AUTHORITY_HOST" + +_API_VERSION_QUERY_PARAM = "cred-api-version" +_IMDS_V2_API_VERSION = "2.0" + +_CSR_METADATA_PATH = "/metadata/identity/getplatformmetadata" +_ISSUE_CREDENTIAL_PATH = "/metadata/identity/issuecredential" +_ACQUIRE_ENTRA_TOKEN_PATH = "/oauth2/v2.0/token" + +_CU_ID_OID_STR = "1.3.6.1.4.1.311.90.2.10" + +# --------------------------------------------------------------------------- +# NCrypt/CNG flags +# --------------------------------------------------------------------------- + +_NCRYPT_USE_VIRTUAL_ISOLATION_FLAG = 0x00020000 +_NCRYPT_USE_PER_BOOT_KEY_FLAG = 0x00040000 + +_RSA_KEY_SIZE = 2048 + +_AT_SIGNATURE = 2 + +_NCRYPT_SILENT_FLAG = 0x40 + +_KEY_NAME_ENVVAR = "MSAL_MSI_V2_KEY_NAME" + +# NCrypt "not found" status codes +_NTE_BAD_KEYSET = 0x80090016 +_NTE_NO_KEY = 0x8009000D +_NTE_NOT_FOUND = 0x80090011 +_NTE_KEY_DOES_NOT_EXIST = 0x8009003A # KeyGuard/VBS provider uses this +_NTE_EXISTS = 0x8009000F + +# Lazy-loaded Win32 API cache +_WIN32: Optional[Dict[str, Any]] = None + + +# --------------------------------------------------------------------------- +# Certificate cache (in-memory, process-local) +# --------------------------------------------------------------------------- + +class _CertCacheEntry: + """Cached mTLS certificate + metadata.""" + __slots__ = ("cert_der", "cert_pem", "token_endpoint", "client_id", + "not_after", "created_at") + + # Minimum remaining cert lifetime to consider cache valid. + # IMDS certs are typically ~8 hours; evict when <1h remains. + MIN_REMAINING_LIFETIME_SEC = 1 * 3600 + + def __init__(self, cert_der: bytes, cert_pem: str, + token_endpoint: str, client_id: str, + not_after: float): + self.cert_der = cert_der + self.cert_pem = cert_pem + self.token_endpoint = token_endpoint + self.client_id = client_id + self.not_after = not_after + self.created_at = time.time() + + def is_expired(self, now: Optional[float] = None) -> bool: + now = now or time.time() + return now >= self.not_after - self.MIN_REMAINING_LIFETIME_SEC + + +_CERT_CACHE_LOCK = threading.Lock() +_CERT_CACHE: Dict[str, _CertCacheEntry] = {} + + +def _cert_cache_key(managed_identity: Optional[Any], + attested: bool) -> str: + """Build a cache key from managed identity + identifier type + attestation flag.""" + mi_id_type = "SYSTEM_ASSIGNED" + mi_id = "SYSTEM_ASSIGNED" + getter = getattr(managed_identity, "get", None) + if callable(getter): + mi_id_type = str(getter("ManagedIdentityIdType") or "SYSTEM_ASSIGNED") + mi_id = str(getter("Id") or "SYSTEM_ASSIGNED") + tag = "#att=1" if attested else "#att=0" + return mi_id_type + ":" + mi_id + tag + + +def _cert_cache_get(key: str) -> Optional[_CertCacheEntry]: + """Return cached entry or None if missing/expired.""" + now = time.time() + with _CERT_CACHE_LOCK: + entry = _CERT_CACHE.get(key) + if entry is None: + return None + if entry.is_expired(now): + del _CERT_CACHE[key] + logger.debug("[msi_v2] Cert cache EVICT (expired) key=%s", key[:20]) + return None + logger.debug("[msi_v2] Cert cache HIT key=%s", key[:20]) + return entry + + +def _cert_cache_set(key: str, entry: _CertCacheEntry) -> None: + """Store entry if it has sufficient remaining lifetime.""" + now = time.time() + if entry.not_after <= now + _CertCacheEntry.MIN_REMAINING_LIFETIME_SEC: + logger.debug("[msi_v2] Cert cache SKIP (insufficient lifetime) key=%s", + key[:20]) + return + with _CERT_CACHE_LOCK: + _CERT_CACHE[key] = entry + logger.debug("[msi_v2] Cert cache SET key=%s", key[:20]) + + +def _cert_cache_remove(key: str) -> None: + """Remove entry (e.g., on SChannel failure).""" + with _CERT_CACHE_LOCK: + _CERT_CACHE.pop(key, None) + + +def _cert_cache_clear() -> None: + """Clear all entries (for testing).""" + with _CERT_CACHE_LOCK: + _CERT_CACHE.clear() + + +# --------------------------------------------------------------------------- +# Compatibility helpers (tests + cross-language parity) +# --------------------------------------------------------------------------- + +def get_cert_thumbprint_sha256(cert_pem: str) -> str: + """ + Return base64url(SHA256(der(cert))) without padding, for cnf.x5t#S256 + comparisons. Accepts a PEM certificate string. + """ + try: + from cryptography import x509 + from cryptography.hazmat.backends import default_backend + from cryptography.hazmat.primitives import serialization + + cert = x509.load_pem_x509_certificate( + cert_pem.encode("utf-8"), default_backend()) + der = cert.public_bytes(serialization.Encoding.DER) + digest = hashlib.sha256(der).digest() + return base64.urlsafe_b64encode(digest).rstrip(b"=").decode("ascii") + except Exception: + return "" + + +def verify_cnf_binding(token: str, cert_pem: str) -> bool: + """ + Verify that JWT payload contains cnf.x5t#S256 matching the cert + thumbprint. + """ + try: + parts = token.split(".") + if len(parts) != 3: + return False + + payload_b64 = parts[1] + payload_b64 += "=" * ((4 - len(payload_b64) % 4) % 4) + claims = json.loads( + base64.urlsafe_b64decode(payload_b64.encode("ascii"))) + + cnf = claims.get("cnf", {}) if isinstance(claims, dict) else {} + if not isinstance(cnf, dict): + return False + token_x5t = cnf.get("x5t#S256") + if not token_x5t: + return False + + cert_x5t = get_cert_thumbprint_sha256(cert_pem) + if not cert_x5t: + return False + + return token_x5t == cert_x5t + except Exception: + return False + + +def _der_to_pem(der_bytes: bytes) -> str: + """Convert DER certificate bytes to PEM string format.""" + b64 = base64.b64encode(der_bytes).decode("ascii") + lines = [b64[i:i + 64] for i in range(0, len(b64), 64)] + return ("-----BEGIN CERTIFICATE-----\n" + + "\n".join(lines) + + "\n-----END CERTIFICATE-----") + + +def _try_parse_cert_not_after(der_bytes: bytes) -> float: + """ + Best-effort extraction of notAfter from a DER X.509 certificate. + Returns epoch seconds. Falls back to now + 8 hours on any failure. + """ + try: + from cryptography import x509 + from cryptography.hazmat.backends import default_backend + cert = x509.load_der_x509_certificate(der_bytes, default_backend()) + na = cert.not_valid_after_utc if hasattr( + cert, "not_valid_after_utc") else cert.not_valid_after + if na.tzinfo is None: + import calendar + return float(calendar.timegm(na.timetuple())) + return na.timestamp() + except Exception: + # Default: assume 8-hour cert lifetime (IMDS typical) + return time.time() + 8 * 3600 + + +# --------------------------------------------------------------------------- +# IMDS helpers +# --------------------------------------------------------------------------- + +def _imds_base() -> str: + base = os.getenv(_IMDS_BASE_ENVVAR) + if base is None: + return _IMDS_DEFAULT_BASE.rstrip("/") + base = base.strip().rstrip("/") + return base or _IMDS_DEFAULT_BASE.rstrip("/") + + +def _new_correlation_id() -> str: + return str(uuid.uuid4()) + + +def _imds_headers(correlation_id: Optional[str] = None) -> Dict[str, str]: + return { + "Metadata": "true", + "x-ms-client-request-id": correlation_id or _new_correlation_id(), + } + + +def _resource_to_scope(resource_or_scope: str) -> str: + """Normalize resource to scope format (append /.default if needed).""" + s = (resource_or_scope or "").strip() + if not s: + raise ValueError("resource must be non-empty") + if s.endswith("/.default"): + return s + return s.rstrip("/") + "/.default" + + +def _der_utf8string(value: str) -> bytes: + """DER UTF8String encoder (tag 0x0C).""" + raw = value.encode("utf-8") + n = len(raw) + if n < 0x80: + len_bytes = bytes([n]) + else: + tmp = bytearray() + m = n + while m > 0: + tmp.insert(0, m & 0xFF) + m >>= 8 + len_bytes = bytes([0x80 | len(tmp)]) + bytes(tmp) + return bytes([0x0C]) + len_bytes + raw + + +def _json_loads(text: str, what: str) -> Dict[str, Any]: + """Parse JSON with error context.""" + from .managed_identity import MsiV2Error + try: + obj = json.loads(text) + if not isinstance(obj, dict): + raise TypeError("expected JSON object") + return obj + except Exception as exc: + raise MsiV2Error( + f"[msi_v2] Invalid JSON from {what}: {text!r}") from exc + + +def _get_first(obj: Dict[str, Any], *names: str) -> Optional[str]: + """Get first non-empty value from object by multiple name variants.""" + for n in names: + if n in obj and obj[n] is not None and str(obj[n]).strip() != "": + return str(obj[n]) + lower = {str(k).lower(): k for k in obj.keys()} + for n in names: + k = lower.get(n.lower()) + if k and obj[k] is not None and str(obj[k]).strip() != "": + return str(obj[k]) + return None + + +def _mi_query_params( + managed_identity: Optional[Any], +) -> Dict[str, str]: + """Build IMDS query params: cred-api-version=2.0 + optional UAMI selector.""" + params: Dict[str, str] = {_API_VERSION_QUERY_PARAM: _IMDS_V2_API_VERSION} + getter = getattr(managed_identity, "get", None) + if not callable(getter): + return params + id_type = getter("ManagedIdentityIdType") + identifier = getter("Id") + mapping = {"ClientId": "client_id", "ObjectId": "object_id", + "ResourceId": "msi_res_id"} + wire = mapping.get(id_type) + if wire and identifier: + params[wire] = str(identifier) + return params + + +def _imds_get_json( + http_client, url: str, params: Dict[str, str], + headers: Dict[str, str], +) -> Dict[str, Any]: + """GET request to IMDS with server header verification.""" + from .managed_identity import MsiV2Error + resp = http_client.get(url, params=params, headers=headers) + server = (resp.headers or {}).get("server", "") + if "imds" not in str(server).lower(): + raise MsiV2Error( + f"[msi_v2] IMDS server header check failed. " + f"server={server!r} url={url}") + if resp.status_code != 200: + raise MsiV2Error( + f"[msi_v2] IMDSv2 GET {url} failed: " + f"HTTP {resp.status_code}: {resp.text}") + return _json_loads(resp.text, f"GET {url}") + + +def _imds_post_json( + http_client, url: str, params: Dict[str, str], + headers: Dict[str, str], body: Dict[str, Any], +) -> Dict[str, Any]: + """POST request to IMDS with server header verification.""" + from .managed_identity import MsiV2Error + resp = http_client.post( + url, params=params, headers=headers, + data=json.dumps(body, separators=(",", ":"))) + server = (resp.headers or {}).get("server", "") + if "imds" not in str(server).lower(): + raise MsiV2Error( + f"[msi_v2] IMDS server header check failed. " + f"server={server!r} url={url}") + if resp.status_code != 200: + raise MsiV2Error( + f"[msi_v2] IMDSv2 POST {url} failed: " + f"HTTP {resp.status_code}: {resp.text}") + return _json_loads(resp.text, f"POST {url}") + + +def _token_endpoint_from_credential(cred: Dict[str, Any]) -> str: + """ + Extract token endpoint from issuecredential response. + Prefers explicit token_endpoint, falls back to + mtls_authentication_endpoint + tenant_id. + """ + token_endpoint = _get_first(cred, "token_endpoint", "tokenEndpoint") + if token_endpoint: + return token_endpoint + + mtls_auth = _get_first( + cred, "mtls_authentication_endpoint", + "mtlsAuthenticationEndpoint", "mtls_authenticationEndpoint") + tenant_id = _get_first(cred, "tenant_id", "tenantId") + if not mtls_auth or not tenant_id: + from .managed_identity import MsiV2Error + raise MsiV2Error( + f"[msi_v2] issuecredential missing " + f"mtls_authentication_endpoint/tenant_id: {cred}") + + base = mtls_auth.rstrip("/") + "/" + tenant_id.strip("/") + return base + _ACQUIRE_ENTRA_TOKEN_PATH + + +# --------------------------------------------------------------------------- +# Win32 primitives (ctypes) — lazy loaded +# --------------------------------------------------------------------------- + +def _load_win32() -> Dict[str, Any]: + """Lazy-load Win32 APIs via ctypes (safe to import on non-Windows).""" + global _WIN32 + from .managed_identity import MsiV2Error + + if _WIN32 is not None: + return _WIN32 + if sys.platform != "win32": + raise MsiV2Error("[msi_v2] KeyGuard + mTLS PoP is Windows-only.") + + import ctypes + from ctypes import wintypes + + ncrypt = ctypes.WinDLL("ncrypt.dll") + crypt32 = ctypes.WinDLL("crypt32.dll", use_last_error=True) + winhttp = ctypes.WinDLL("winhttp.dll", use_last_error=True) + + NCRYPT_PROV_HANDLE = ctypes.c_void_p + NCRYPT_KEY_HANDLE = ctypes.c_void_p + SECURITY_STATUS = ctypes.c_long + + class CERT_CONTEXT(ctypes.Structure): + _fields_ = [ + ("dwCertEncodingType", wintypes.DWORD), + ("pbCertEncoded", ctypes.POINTER(ctypes.c_ubyte)), + ("cbCertEncoded", wintypes.DWORD), + ("pCertInfo", ctypes.c_void_p), + ("hCertStore", ctypes.c_void_p), + ] + + PCCERT_CONTEXT = ctypes.POINTER(CERT_CONTEXT) + + class BCRYPT_PSS_PADDING_INFO(ctypes.Structure): + _fields_ = [ + ("pszAlgId", ctypes.c_wchar_p), + ("cbSalt", wintypes.ULONG), + ] + + # NCrypt prototypes + ncrypt.NCryptOpenStorageProvider.argtypes = [ + ctypes.POINTER(NCRYPT_PROV_HANDLE), ctypes.c_wchar_p, wintypes.DWORD] + ncrypt.NCryptOpenStorageProvider.restype = SECURITY_STATUS + + ncrypt.NCryptOpenKey.argtypes = [ + NCRYPT_PROV_HANDLE, ctypes.POINTER(NCRYPT_KEY_HANDLE), + ctypes.c_wchar_p, wintypes.DWORD, wintypes.DWORD] + ncrypt.NCryptOpenKey.restype = SECURITY_STATUS + + ncrypt.NCryptCreatePersistedKey.argtypes = [ + NCRYPT_PROV_HANDLE, ctypes.POINTER(NCRYPT_KEY_HANDLE), + ctypes.c_wchar_p, ctypes.c_wchar_p, wintypes.DWORD, wintypes.DWORD] + ncrypt.NCryptCreatePersistedKey.restype = SECURITY_STATUS + + ncrypt.NCryptSetProperty.argtypes = [ + ctypes.c_void_p, ctypes.c_wchar_p, ctypes.c_void_p, + wintypes.DWORD, wintypes.DWORD] + ncrypt.NCryptSetProperty.restype = SECURITY_STATUS + + ncrypt.NCryptFinalizeKey.argtypes = [NCRYPT_KEY_HANDLE, wintypes.DWORD] + ncrypt.NCryptFinalizeKey.restype = SECURITY_STATUS + + ncrypt.NCryptGetProperty.argtypes = [ + ctypes.c_void_p, ctypes.c_wchar_p, ctypes.c_void_p, wintypes.DWORD, + ctypes.POINTER(wintypes.DWORD), wintypes.DWORD] + ncrypt.NCryptGetProperty.restype = SECURITY_STATUS + + ncrypt.NCryptExportKey.argtypes = [ + NCRYPT_KEY_HANDLE, NCRYPT_KEY_HANDLE, ctypes.c_wchar_p, + ctypes.c_void_p, ctypes.c_void_p, wintypes.DWORD, + ctypes.POINTER(wintypes.DWORD), wintypes.DWORD] + ncrypt.NCryptExportKey.restype = SECURITY_STATUS + + ncrypt.NCryptSignHash.argtypes = [ + NCRYPT_KEY_HANDLE, ctypes.c_void_p, ctypes.c_void_p, wintypes.DWORD, + ctypes.c_void_p, wintypes.DWORD, ctypes.POINTER(wintypes.DWORD), + wintypes.DWORD] + ncrypt.NCryptSignHash.restype = SECURITY_STATUS + + ncrypt.NCryptFreeObject.argtypes = [ctypes.c_void_p] + ncrypt.NCryptFreeObject.restype = SECURITY_STATUS + + # Crypt32 prototypes + crypt32.CertCreateCertificateContext.argtypes = [ + wintypes.DWORD, ctypes.c_void_p, wintypes.DWORD] + crypt32.CertCreateCertificateContext.restype = PCCERT_CONTEXT + + crypt32.CertSetCertificateContextProperty.argtypes = [ + PCCERT_CONTEXT, wintypes.DWORD, wintypes.DWORD, ctypes.c_void_p] + crypt32.CertSetCertificateContextProperty.restype = wintypes.BOOL + + crypt32.CertFreeCertificateContext.argtypes = [PCCERT_CONTEXT] + crypt32.CertFreeCertificateContext.restype = wintypes.BOOL + + # Crypt32 — certificate store APIs (for WindowsCertificate.from_store) + crypt32.CertOpenStore.argtypes = [ + ctypes.c_void_p, wintypes.DWORD, ctypes.c_void_p, + wintypes.DWORD, ctypes.c_void_p] + crypt32.CertOpenStore.restype = ctypes.c_void_p + + crypt32.CertCloseStore.argtypes = [ctypes.c_void_p, wintypes.DWORD] + crypt32.CertCloseStore.restype = wintypes.BOOL + + crypt32.CertFindCertificateInStore.argtypes = [ + ctypes.c_void_p, wintypes.DWORD, wintypes.DWORD, + wintypes.DWORD, ctypes.c_void_p, ctypes.c_void_p] + crypt32.CertFindCertificateInStore.restype = PCCERT_CONTEXT + + crypt32.CryptAcquireCertificatePrivateKey.argtypes = [ + PCCERT_CONTEXT, wintypes.DWORD, ctypes.c_void_p, + ctypes.POINTER(ctypes.c_void_p), ctypes.POINTER(ctypes.c_ulong), + ctypes.POINTER(ctypes.c_int)] + crypt32.CryptAcquireCertificatePrivateKey.restype = wintypes.BOOL + + class CRYPT_HASH_BLOB(ctypes.Structure): + _fields_ = [ + ("cbData", wintypes.DWORD), + ("pbData", ctypes.c_void_p), + ] + + # WinHTTP prototypes + winhttp.WinHttpOpen.argtypes = [ + ctypes.c_wchar_p, wintypes.DWORD, ctypes.c_wchar_p, + ctypes.c_wchar_p, wintypes.DWORD] + winhttp.WinHttpOpen.restype = ctypes.c_void_p + + winhttp.WinHttpConnect.argtypes = [ + ctypes.c_void_p, ctypes.c_wchar_p, wintypes.WORD, wintypes.DWORD] + winhttp.WinHttpConnect.restype = ctypes.c_void_p + + winhttp.WinHttpOpenRequest.argtypes = [ + ctypes.c_void_p, ctypes.c_wchar_p, ctypes.c_wchar_p, + ctypes.c_wchar_p, ctypes.c_wchar_p, ctypes.c_void_p, wintypes.DWORD] + winhttp.WinHttpOpenRequest.restype = ctypes.c_void_p + + winhttp.WinHttpSetOption.argtypes = [ + ctypes.c_void_p, wintypes.DWORD, ctypes.c_void_p, wintypes.DWORD] + winhttp.WinHttpSetOption.restype = wintypes.BOOL + + winhttp.WinHttpSendRequest.argtypes = [ + ctypes.c_void_p, ctypes.c_wchar_p, wintypes.DWORD, ctypes.c_void_p, + wintypes.DWORD, wintypes.DWORD, ctypes.c_size_t] + winhttp.WinHttpSendRequest.restype = wintypes.BOOL + + winhttp.WinHttpReceiveResponse.argtypes = [ + ctypes.c_void_p, ctypes.c_void_p] + winhttp.WinHttpReceiveResponse.restype = wintypes.BOOL + + winhttp.WinHttpQueryHeaders.argtypes = [ + ctypes.c_void_p, wintypes.DWORD, ctypes.c_wchar_p, ctypes.c_void_p, + ctypes.POINTER(wintypes.DWORD), ctypes.POINTER(wintypes.DWORD)] + winhttp.WinHttpQueryHeaders.restype = wintypes.BOOL + + winhttp.WinHttpQueryDataAvailable.argtypes = [ + ctypes.c_void_p, ctypes.POINTER(wintypes.DWORD)] + winhttp.WinHttpQueryDataAvailable.restype = wintypes.BOOL + + winhttp.WinHttpReadData.argtypes = [ + ctypes.c_void_p, ctypes.c_void_p, wintypes.DWORD, + ctypes.POINTER(wintypes.DWORD)] + winhttp.WinHttpReadData.restype = wintypes.BOOL + + winhttp.WinHttpCloseHandle.argtypes = [ctypes.c_void_p] + winhttp.WinHttpCloseHandle.restype = wintypes.BOOL + + _WIN32 = { + "ctypes": ctypes, "wintypes": wintypes, + "ncrypt": ncrypt, "crypt32": crypt32, "winhttp": winhttp, + "NCRYPT_PROV_HANDLE": NCRYPT_PROV_HANDLE, + "NCRYPT_KEY_HANDLE": NCRYPT_KEY_HANDLE, + "SECURITY_STATUS": SECURITY_STATUS, + "CERT_CONTEXT": CERT_CONTEXT, + "PCCERT_CONTEXT": PCCERT_CONTEXT, + "BCRYPT_PSS_PADDING_INFO": BCRYPT_PSS_PADDING_INFO, + "CRYPT_HASH_BLOB": CRYPT_HASH_BLOB, + "ERROR_SUCCESS": 0, + "NCRYPT_OVERWRITE_KEY_FLAG": 0x00000080, + "NCRYPT_LENGTH_PROPERTY": "Length", + "NCRYPT_EXPORT_POLICY_PROPERTY": "Export Policy", + "NCRYPT_KEY_USAGE_PROPERTY": "Key Usage", + "NCRYPT_ALLOW_SIGNING_FLAG": 0x00000002, + "NCRYPT_ALLOW_DECRYPT_FLAG": 0x00000001, + "BCRYPT_PAD_PSS": 0x00000008, + "BCRYPT_SHA256_ALGORITHM": "SHA256", + "BCRYPT_RSA_ALGORITHM": "RSA", + "BCRYPT_RSAPUBLIC_BLOB": "RSAPUBLICBLOB", + "BCRYPT_RSAPUBLIC_MAGIC": 0x31415352, + "X509_ASN_ENCODING": 0x00000001, + "PKCS_7_ASN_ENCODING": 0x00010000, + "CERT_NCRYPT_KEY_HANDLE_PROP_ID": 78, + "CERT_SET_PROPERTY_INHIBIT_PERSIST_FLAG": 0x40000000, + "WINHTTP_ACCESS_TYPE_DEFAULT_PROXY": 0, + "WINHTTP_FLAG_SECURE": 0x00800000, + "WINHTTP_OPTION_CLIENT_CERT_CONTEXT": 47, + "WINHTTP_OPTION_ENABLE_HTTP2_PLUS_CLIENT_CERT": 161, + "WINHTTP_QUERY_STATUS_CODE": 19, + "WINHTTP_QUERY_FLAG_NUMBER": 0x20000000, + } + return _WIN32 + + +# --------------------------------------------------------------------------- +# Win32 error helpers +# --------------------------------------------------------------------------- + +def _raise_win32_last_error(msg: str) -> None: + from .managed_identity import MsiV2Error + win32 = _load_win32() + ctypes_mod = win32["ctypes"] + err = ctypes_mod.get_last_error() + detail = "" + try: + detail = ctypes_mod.FormatError(err).strip() + except Exception: + pass + raise MsiV2Error(f"{msg} (winerror={err} {detail})" if detail + else f"{msg} (winerror={err})") + + +def _check_security_status(status: int, what: str) -> None: + from .managed_identity import MsiV2Error + if int(status) != 0: + code_u32 = int(status) & 0xFFFFFFFF + raise MsiV2Error(f"[msi_v2] {what} failed: status=0x{code_u32:08X}") + + +def _status_u32(status: int) -> int: + return int(status) & 0xFFFFFFFF + + +def _is_key_not_found(status: int) -> bool: + return _status_u32(status) in ( + _NTE_BAD_KEYSET, _NTE_NO_KEY, _NTE_NOT_FOUND, _NTE_KEY_DOES_NOT_EXIST) + + +# --------------------------------------------------------------------------- +# DER helpers (minimal PKCS#10 CSR builder) +# --------------------------------------------------------------------------- + +def _der_len(n: int) -> bytes: + if n < 0: + raise ValueError("DER length cannot be negative") + if n < 0x80: + return bytes([n]) + out = bytearray() + m = n + while m > 0: + out.insert(0, m & 0xFF) + m >>= 8 + return bytes([0x80 | len(out)]) + bytes(out) + + +def _der(tag: int, content: bytes) -> bytes: + return bytes([tag]) + _der_len(len(content)) + content + + +def _der_null() -> bytes: + return b"\x05\x00" + + +def _der_integer(value: int) -> bytes: + if value < 0: + raise ValueError("Only non-negative INTEGER supported") + if value == 0: + raw = b"\x00" + else: + raw = value.to_bytes((value.bit_length() + 7) // 8, "big") + if raw[0] & 0x80: + raw = b"\x00" + raw + return _der(0x02, raw) + + +def _der_oid(oid: str) -> bytes: + parts = [int(x) for x in oid.split(".")] + if len(parts) < 2 or parts[0] > 2 or (parts[0] < 2 and parts[1] >= 40): + raise ValueError(f"Invalid OID: {oid}") + first = 40 * parts[0] + parts[1] + out = bytearray([first]) + for p in parts[2:]: + if p < 0: + raise ValueError(f"Invalid OID component: {oid}") + stack = bytearray() + if p == 0: + stack.append(0) + else: + m = p + while m > 0: + stack.insert(0, m & 0x7F) + m >>= 7 + for i in range(len(stack) - 1): + stack[i] |= 0x80 + out.extend(stack) + return _der(0x06, bytes(out)) + + +def _der_sequence(*items: bytes) -> bytes: + return _der(0x30, b"".join(items)) + + +def _der_set(*items: bytes) -> bytes: + enc = sorted(items) + return _der(0x31, b"".join(enc)) + + +def _der_bitstring(data: bytes) -> bytes: + return _der(0x03, b"\x00" + data) + + +def _der_ia5string(value: str) -> bytes: + return _der(0x16, value.encode("ascii")) + + +def _der_context_explicit(tagnum: int, inner: bytes) -> bytes: + return _der(0xA0 + tagnum, inner) + + +def _der_context_implicit_constructed(tagnum: int, inner_content: bytes) -> bytes: + return _der(0xA0 + tagnum, inner_content) + + +def _der_name_cn_dc(cn: str, dc: str) -> bytes: + cn_atv = _der_sequence(_der_oid("2.5.4.3"), _der_utf8string(cn)) + cn_rdn = _der_set(cn_atv) + try: + dc_value = _der_ia5string(dc) + except Exception: + dc_value = _der_utf8string(dc) + dc_atv = _der_sequence( + _der_oid("0.9.2342.19200300.100.1.25"), dc_value) + dc_rdn = _der_set(dc_atv) + return _der_sequence(cn_rdn, dc_rdn) + + +def _der_subject_public_key_info_rsa(modulus: int, exponent: int) -> bytes: + rsa_pub = _der_sequence(_der_integer(modulus), _der_integer(exponent)) + alg = _der_sequence( + _der_oid("1.2.840.113549.1.1.1"), _der_null()) # rsaEncryption + return _der_sequence(alg, _der_bitstring(rsa_pub)) + + +def _der_algid_rsapss_sha256() -> bytes: + """AlgorithmIdentifier for RSASSA-PSS with SHA-256, MGF1(SHA-256), + saltLength=32. trailerField omitted (DEFAULT=1, per .NET).""" + sha256 = _der_sequence( + _der_oid("2.16.840.1.101.3.4.2.1"), _der_null()) + mgf1 = _der_sequence(_der_oid("1.2.840.113549.1.1.8"), sha256) + salt_len = _der_integer(32) + params = _der_sequence( + _der_context_explicit(0, sha256), + _der_context_explicit(1, mgf1), + _der_context_explicit(2, salt_len), + # trailerField [3] omitted — DEFAULT trailerFieldBC(1) + ) + return _der_sequence(_der_oid("1.2.840.113549.1.1.10"), params) + + +# --------------------------------------------------------------------------- +# CNG/NCrypt wrappers +# --------------------------------------------------------------------------- + +def _ncrypt_get_property(win32: Dict[str, Any], h: Any, name: str) -> bytes: + ctypes_mod = win32["ctypes"] + wintypes = win32["wintypes"] + ncrypt = win32["ncrypt"] + cb = wintypes.DWORD(0) + status = ncrypt.NCryptGetProperty(h, name, None, 0, + ctypes_mod.byref(cb), 0) + if int(status) != 0 and cb.value == 0: + _check_security_status(status, f"NCryptGetProperty({name})") + if cb.value == 0: + return b"" + buf = (ctypes_mod.c_ubyte * cb.value)() + status = ncrypt.NCryptGetProperty(h, name, buf, cb.value, + ctypes_mod.byref(cb), 0) + _check_security_status(status, f"NCryptGetProperty({name})") + return bytes(buf[:cb.value]) + + +def _stable_key_name(client_id: str) -> str: + base = (client_id or "").strip() + safe = [] + for ch in base: + if ch.isalnum() or ch in ("-", "_"): + safe.append(ch) + else: + safe.append("_") + return "MsalMsiV2Key_" + "".join(safe)[:90] + + +def _open_or_create_keyguard_rsa_key( + win32: Dict[str, Any], *, key_name: str, +) -> Tuple[Any, Any, str, bool]: + """ + Open a named per-boot KeyGuard RSA key if it exists; otherwise create it. + Returns: (prov_handle, key_handle, key_name, opened_existing) + """ + ctypes_mod = win32["ctypes"] + wintypes = win32["wintypes"] + ncrypt = win32["ncrypt"] + + prov = win32["NCRYPT_PROV_HANDLE"]() + status = ncrypt.NCryptOpenStorageProvider( + ctypes_mod.byref(prov), + "Microsoft Software Key Storage Provider", 0) + _check_security_status(status, "NCryptOpenStorageProvider") + + key = win32["NCRYPT_KEY_HANDLE"]() + + # 1) Try open first + status = ncrypt.NCryptOpenKey(prov, ctypes_mod.byref(key), + str(key_name), _AT_SIGNATURE, 0) + if int(status) == 0: + vi = _ncrypt_get_property(win32, key, "Virtual Iso") + if not vi or len(vi) < 4: + from .managed_identity import MsiV2Error + raise MsiV2Error( + "[msi_v2] Virtual Iso property missing/invalid; " + "Credential Guard likely not active.") + return prov, key, str(key_name), True + + if not _is_key_not_found(status): + _check_security_status(status, f"NCryptOpenKey({key_name})") + + # 2) Create if missing + flags = (win32["NCRYPT_OVERWRITE_KEY_FLAG"] + | _NCRYPT_USE_VIRTUAL_ISOLATION_FLAG + | _NCRYPT_USE_PER_BOOT_KEY_FLAG) + + status = ncrypt.NCryptCreatePersistedKey( + prov, ctypes_mod.byref(key), win32["BCRYPT_RSA_ALGORITHM"], + str(key_name), _AT_SIGNATURE, flags) + + if _status_u32(status) == _NTE_EXISTS: + # Race: another thread/process created it + status2 = ncrypt.NCryptOpenKey(prov, ctypes_mod.byref(key), + str(key_name), _AT_SIGNATURE, 0) + _check_security_status(status2, + f"NCryptOpenKey({key_name}) after exists") + return prov, key, str(key_name), True + + _check_security_status(status, "NCryptCreatePersistedKey") + + # Set key properties + length = wintypes.DWORD(int(_RSA_KEY_SIZE)) + status = ncrypt.NCryptSetProperty( + key, win32["NCRYPT_LENGTH_PROPERTY"], + ctypes_mod.byref(length), ctypes_mod.sizeof(length), 0) + _check_security_status(status, "NCryptSetProperty(Length)") + + usage = wintypes.DWORD( + win32["NCRYPT_ALLOW_SIGNING_FLAG"] + | win32["NCRYPT_ALLOW_DECRYPT_FLAG"]) + status = ncrypt.NCryptSetProperty( + key, win32["NCRYPT_KEY_USAGE_PROPERTY"], + ctypes_mod.byref(usage), ctypes_mod.sizeof(usage), 0) + _check_security_status(status, "NCryptSetProperty(Key Usage)") + + export_policy = wintypes.DWORD(0) # non-exportable + status = ncrypt.NCryptSetProperty( + key, win32["NCRYPT_EXPORT_POLICY_PROPERTY"], + ctypes_mod.byref(export_policy), ctypes_mod.sizeof(export_policy), 0) + _check_security_status(status, "NCryptSetProperty(Export Policy)") + + status = ncrypt.NCryptFinalizeKey(key, 0) + _check_security_status(status, "NCryptFinalizeKey") + + vi = _ncrypt_get_property(win32, key, "Virtual Iso") + if not vi or len(vi) < 4: + from .managed_identity import MsiV2Error + raise MsiV2Error( + "[msi_v2] Virtual Iso property not available; " + "Credential Guard likely not active.") + + return prov, key, str(key_name), False + + +def _ncrypt_export_rsa_public( + win32: Dict[str, Any], key: Any, +) -> Tuple[int, int]: + """Export RSA public key (modulus, exponent) from an NCrypt key handle.""" + from .managed_identity import MsiV2Error + ctypes_mod = win32["ctypes"] + wintypes = win32["wintypes"] + ncrypt = win32["ncrypt"] + + cb = wintypes.DWORD(0) + status = ncrypt.NCryptExportKey( + key, None, win32["BCRYPT_RSAPUBLIC_BLOB"], None, None, 0, + ctypes_mod.byref(cb), 0) + if int(status) != 0 and cb.value == 0: + _check_security_status(status, "NCryptExportKey(size)") + if cb.value == 0: + raise MsiV2Error("[msi_v2] NCryptExportKey returned empty blob size") + + buf = (ctypes_mod.c_ubyte * cb.value)() + status = ncrypt.NCryptExportKey( + key, None, win32["BCRYPT_RSAPUBLIC_BLOB"], None, + buf, cb.value, ctypes_mod.byref(cb), 0) + _check_security_status(status, "NCryptExportKey(RSAPUBLICBLOB)") + blob = bytes(buf[:cb.value]) + + if len(blob) < 24: + raise MsiV2Error("[msi_v2] RSAPUBLICBLOB too small") + + magic, bitlen, cb_exp, cb_mod, cb_p1, cb_p2 = struct.unpack( + "<6I", blob[:24]) + if magic != win32["BCRYPT_RSAPUBLIC_MAGIC"]: + raise MsiV2Error( + f"[msi_v2] RSAPUBLICBLOB magic mismatch: 0x{magic:08X}") + + offset = 24 + if len(blob) < offset + cb_exp + cb_mod: + raise MsiV2Error("[msi_v2] RSAPUBLICBLOB truncated") + + exp_bytes = blob[offset:offset + cb_exp] + offset += cb_exp + mod_bytes = blob[offset:offset + cb_mod] + + exponent = int.from_bytes(exp_bytes, "big") + modulus = int.from_bytes(mod_bytes, "big") + return modulus, exponent + + +def _ncrypt_sign_pss_sha256( + win32: Dict[str, Any], key: Any, digest: bytes, +) -> bytes: + """Sign a SHA-256 digest using RSA-PSS via NCryptSignHash.""" + from .managed_identity import MsiV2Error + if len(digest) != 32: + raise MsiV2Error("[msi_v2] Expected SHA-256 digest (32 bytes)") + + ctypes_mod = win32["ctypes"] + wintypes = win32["wintypes"] + ncrypt = win32["ncrypt"] + + pad = win32["BCRYPT_PSS_PADDING_INFO"]( + win32["BCRYPT_SHA256_ALGORITHM"], 32) + hash_buf = (ctypes_mod.c_ubyte * len(digest)).from_buffer_copy(digest) + + cb_sig = wintypes.DWORD(0) + status = ncrypt.NCryptSignHash( + key, ctypes_mod.byref(pad), hash_buf, len(digest), + None, 0, ctypes_mod.byref(cb_sig), win32["BCRYPT_PAD_PSS"]) + if int(status) != 0 and cb_sig.value == 0: + _check_security_status(status, "NCryptSignHash(size)") + if cb_sig.value == 0: + raise MsiV2Error("[msi_v2] NCryptSignHash returned empty sig size") + + sig_buf = (ctypes_mod.c_ubyte * cb_sig.value)() + status = ncrypt.NCryptSignHash( + key, ctypes_mod.byref(pad), hash_buf, len(digest), + sig_buf, cb_sig.value, ctypes_mod.byref(cb_sig), + win32["BCRYPT_PAD_PSS"]) + _check_security_status(status, "NCryptSignHash") + return bytes(sig_buf[:cb_sig.value]) + + +# --------------------------------------------------------------------------- +# CSR builder +# --------------------------------------------------------------------------- + +def _build_csr_b64( + win32: Dict[str, Any], key: Any, + client_id: str, tenant_id: str, cu_id: Any, +) -> str: + """Build CSR signed by KeyGuard key (RSA-PSS SHA256), with cuId OID + attribute.""" + modulus, exponent = _ncrypt_export_rsa_public(win32, key) + subject = _der_name_cn_dc(client_id, tenant_id) + spki = _der_subject_public_key_info_rsa(modulus, exponent) + + cuid_json = json.dumps(cu_id, separators=(",", ":"), ensure_ascii=False) + cuid_val = _der_utf8string(cuid_json) + + attr = _der_sequence(_der_oid(_CU_ID_OID_STR), _der_set(cuid_val)) + attrs_content = b"".join(sorted([attr])) + attrs = _der_context_implicit_constructed(0, attrs_content) + + cri = _der_sequence(_der_integer(0), subject, spki, attrs) + digest = hashlib.sha256(cri).digest() + signature = _ncrypt_sign_pss_sha256(win32, key, digest) + + csr = _der_sequence(cri, _der_algid_rsapss_sha256(), + _der_bitstring(signature)) + return base64.b64encode(csr).decode("ascii") + + +# --------------------------------------------------------------------------- +# Certificate binding + WinHTTP mTLS +# --------------------------------------------------------------------------- + +def _create_cert_context_with_key( + win32: Dict[str, Any], cert_der: bytes, key: Any, key_name: str, + *, ksp_name: str = "Microsoft Software Key Storage Provider", +) -> Tuple[Any, Any, Tuple[Any, ...]]: + """Create a CERT_CONTEXT from DER bytes and associate it with a CNG + private key via multiple properties for SChannel compatibility.""" + ctypes_mod = win32["ctypes"] + wintypes = win32["wintypes"] + crypt32 = win32["crypt32"] + + enc = win32["X509_ASN_ENCODING"] | win32["PKCS_7_ASN_ENCODING"] + buf = ctypes_mod.create_string_buffer(cert_der) + ctx = crypt32.CertCreateCertificateContext(enc, buf, len(cert_der)) + if not ctx: + _raise_win32_last_error( + "[msi_v2] CertCreateCertificateContext failed") + + keepalive: List[Any] = [buf] + + try: + # (A) Direct NCrypt key handle + key_handle = ctypes_mod.c_void_p(int(key.value)) + keepalive.append(key_handle) + + ok = crypt32.CertSetCertificateContextProperty( + ctx, win32["CERT_NCRYPT_KEY_HANDLE_PROP_ID"], + win32["CERT_SET_PROPERTY_INHIBIT_PERSIST_FLAG"], + ctypes_mod.byref(key_handle)) + if not ok: + _raise_win32_last_error( + "[msi_v2] CertSetCertificateContextProperty" + "(CERT_NCRYPT_KEY_HANDLE_PROP_ID) failed") + + # (B) CERT_KEY_CONTEXT_PROP_ID (best-effort) + CERT_KEY_CONTEXT_PROP_ID = 5 + CERT_NCRYPT_KEY_SPEC = 0xFFFFFFFF + + class CERT_KEY_CONTEXT(ctypes_mod.Structure): + _fields_ = [ + ("cbSize", wintypes.DWORD), + ("hCryptProvOrNCryptKey", ctypes_mod.c_void_p), + ("dwKeySpec", wintypes.DWORD), + ] + + key_ctx = CERT_KEY_CONTEXT( + ctypes_mod.sizeof(CERT_KEY_CONTEXT), key_handle, + wintypes.DWORD(CERT_NCRYPT_KEY_SPEC)) + keepalive.append(key_ctx) + + ok = crypt32.CertSetCertificateContextProperty( + ctx, CERT_KEY_CONTEXT_PROP_ID, + win32["CERT_SET_PROPERTY_INHIBIT_PERSIST_FLAG"], + ctypes_mod.byref(key_ctx)) + if not ok: + logger.debug("[msi_v2] Failed to set CERT_KEY_CONTEXT_PROP_ID " + "(last_error=%s)", ctypes_mod.get_last_error()) + + # (C) CERT_KEY_PROV_INFO_PROP_ID (for SChannel reopen by name) + CERT_KEY_PROV_INFO_PROP_ID = 2 + + class CRYPT_KEY_PROV_INFO(ctypes_mod.Structure): + _fields_ = [ + ("pwszContainerName", wintypes.LPWSTR), + ("pwszProvName", wintypes.LPWSTR), + ("dwProvType", wintypes.DWORD), + ("dwFlags", wintypes.DWORD), + ("cProvParam", wintypes.DWORD), + ("rgProvParam", ctypes_mod.c_void_p), + ("dwKeySpec", wintypes.DWORD), + ] + + container_buf = ctypes_mod.create_unicode_buffer(str(key_name)) + provider_buf = ctypes_mod.create_unicode_buffer(str(ksp_name)) + keepalive.extend([container_buf, provider_buf]) + + prov_info = CRYPT_KEY_PROV_INFO( + ctypes_mod.cast(container_buf, wintypes.LPWSTR), + ctypes_mod.cast(provider_buf, wintypes.LPWSTR), + wintypes.DWORD(0), # CNG/KSP + wintypes.DWORD(_NCRYPT_SILENT_FLAG), + wintypes.DWORD(0), None, + wintypes.DWORD(_AT_SIGNATURE)) + keepalive.append(prov_info) + + ok = crypt32.CertSetCertificateContextProperty( + ctx, CERT_KEY_PROV_INFO_PROP_ID, + win32["CERT_SET_PROPERTY_INHIBIT_PERSIST_FLAG"], + ctypes_mod.byref(prov_info)) + if not ok: + logger.debug("[msi_v2] Failed to set CERT_KEY_PROV_INFO_PROP_ID " + "(last_error=%s)", ctypes_mod.get_last_error()) + + return ctx, buf, tuple(keepalive) + + except Exception: + try: + crypt32.CertFreeCertificateContext(ctx) + except Exception: + pass + raise + + +def _winhttp_close(win32: Dict[str, Any], h: Any) -> None: + try: + if h: + win32["winhttp"].WinHttpCloseHandle(h) + except Exception: + pass + + +def _winhttp_post( + win32: Dict[str, Any], url: str, cert_ctx: Any, + body: bytes, headers: Dict[str, str], +) -> Tuple[int, bytes]: + """POST to https URL using WinHTTP + SChannel with client cert.""" + from .managed_identity import MsiV2Error + ctypes_mod = win32["ctypes"] + wintypes = win32["wintypes"] + winhttp = win32["winhttp"] + + u = urlparse(url) + if u.scheme.lower() != "https": + raise MsiV2Error( + f"[msi_v2] Token endpoint must be https, got: {url!r}") + if not u.hostname: + raise MsiV2Error(f"[msi_v2] Invalid token endpoint: {url!r}") + + host = u.hostname + port = u.port or 443 + path = u.path or "/" + if u.query: + path += "?" + u.query + + h_session = winhttp.WinHttpOpen( + "msal-python-msi-v2", win32["WINHTTP_ACCESS_TYPE_DEFAULT_PROXY"], + None, None, 0) + if not h_session: + _raise_win32_last_error("[msi_v2] WinHttpOpen failed") + + h_connect = None + h_request = None + try: + # Best-effort: HTTP/2 + client cert + enable = wintypes.DWORD(1) + try: + winhttp.WinHttpSetOption( + h_session, + win32["WINHTTP_OPTION_ENABLE_HTTP2_PLUS_CLIENT_CERT"], + ctypes_mod.byref(enable), ctypes_mod.sizeof(enable)) + except Exception: + pass + + h_connect = winhttp.WinHttpConnect(h_session, host, int(port), 0) + if not h_connect: + _raise_win32_last_error("[msi_v2] WinHttpConnect failed") + + h_request = winhttp.WinHttpOpenRequest( + h_connect, "POST", path, None, None, None, + win32["WINHTTP_FLAG_SECURE"]) + if not h_request: + _raise_win32_last_error("[msi_v2] WinHttpOpenRequest failed") + + # Attach cert for mTLS + ok = winhttp.WinHttpSetOption( + h_request, win32["WINHTTP_OPTION_CLIENT_CERT_CONTEXT"], + cert_ctx, ctypes_mod.sizeof(win32["CERT_CONTEXT"])) + if not ok: + _raise_win32_last_error( + "[msi_v2] WinHttpSetOption(CLIENT_CERT) failed") + + header_lines = "".join(f"{k}: {v}\r\n" for k, v in headers.items()) + body_buf = ctypes_mod.create_string_buffer(body) + + ok = winhttp.WinHttpSendRequest( + h_request, header_lines, 0xFFFFFFFF, + body_buf, len(body), len(body), 0) + if not ok: + _raise_win32_last_error("[msi_v2] WinHttpSendRequest failed") + + ok = winhttp.WinHttpReceiveResponse(h_request, None) + if not ok: + _raise_win32_last_error("[msi_v2] WinHttpReceiveResponse failed") + + # Read status code + status = wintypes.DWORD(0) + status_size = wintypes.DWORD(ctypes_mod.sizeof(status)) + index = wintypes.DWORD(0) + + ok = winhttp.WinHttpQueryHeaders( + h_request, + win32["WINHTTP_QUERY_STATUS_CODE"] + | win32["WINHTTP_QUERY_FLAG_NUMBER"], + None, ctypes_mod.byref(status), + ctypes_mod.byref(status_size), ctypes_mod.byref(index)) + if not ok: + _raise_win32_last_error( + "[msi_v2] WinHttpQueryHeaders(STATUS_CODE) failed") + + # Read body + chunks: List[bytes] = [] + while True: + avail = wintypes.DWORD(0) + ok = winhttp.WinHttpQueryDataAvailable( + h_request, ctypes_mod.byref(avail)) + if not ok: + _raise_win32_last_error( + "[msi_v2] WinHttpQueryDataAvailable failed") + if avail.value == 0: + break + buf = (ctypes_mod.c_ubyte * avail.value)() + read = wintypes.DWORD(0) + ok = winhttp.WinHttpReadData( + h_request, buf, avail.value, ctypes_mod.byref(read)) + if not ok: + _raise_win32_last_error("[msi_v2] WinHttpReadData failed") + if read.value: + chunks.append(bytes(buf[:read.value])) + if read.value == 0: + break + + return int(status.value), b"".join(chunks) + finally: + _winhttp_close(win32, h_request) + _winhttp_close(win32, h_connect) + _winhttp_close(win32, h_session) + + +def _acquire_token_mtls_schannel( + win32: Dict[str, Any], token_endpoint: str, cert_ctx: Any, + client_id: str, scope: str, +) -> Dict[str, Any]: + """Acquire an mtls_pop token from ESTS using WinHTTP/SChannel.""" + from .managed_identity import MsiV2Error + + form = urlencode({ + "grant_type": "client_credentials", + "client_id": client_id, + "scope": scope, + "token_type": "mtls_pop", + }).encode("utf-8") + + status, resp_body = _winhttp_post( + win32, token_endpoint, cert_ctx, form, + headers={ + "Content-Type": "application/x-www-form-urlencoded", + "Accept": "application/json", + }) + + text = resp_body.decode("utf-8", errors="replace") + if status < 200 or status >= 300: + raise MsiV2Error( + f"[msi_v2] ESTS token request failed: " + f"HTTP {status} Body={text!r}") + return _json_loads(text, "ESTS token") + + +# --------------------------------------------------------------------------- +# Public API +# --------------------------------------------------------------------------- + +# Type alias for attestation provider callback. +# Signature: (endpoint, key_handle, client_id, cache_key) -> JWT string. +# cache_key is the stable per-boot key name for optimal caching. +AttestationTokenProvider = Callable[[str, int, str, str], str] + + +def obtain_token( + http_client, + managed_identity: Dict[str, Any], + resource: str, + *, + attestation_enabled: bool = True, + attestation_token_provider: Optional[AttestationTokenProvider] = None, +) -> Dict[str, Any]: + """ + Acquire mtls_pop token using Windows KeyGuard + optional MAA attestation. + + Flow: + 1. getplatformmetadata → client_id, tenant_id, cu_id, attestationEndpoint + 2. Open/create named per-boot KeyGuard RSA key (non-exportable) + 3. Build PKCS#10 CSR with cuId attribute, sign with RSA-PSS/SHA256 + 4. Get attestation JWT from MAA (if attestation_token_provider given) + 5. issuecredential → X.509 cert + 6. Create CERT_CONTEXT, bind to KeyGuard private key + 7. POST /oauth2/v2.0/token via WinHTTP/SChannel with mTLS + + Args: + http_client: HTTP client (e.g., requests.Session()) + managed_identity: MSAL managed identity dict + resource: Resource URI for token acquisition + attestation_enabled: Whether attestation is enabled + attestation_token_provider: Callback (endpoint, key_handle, + client_id, cache_key) -> JWT string. Provided by + msal-key-attestation package. cache_key is the stable + per-boot key name for optimal caching. None means + non-attested flow. + + Returns: + Token response dict with access_token, expires_in, token_type, + cert_pem, cert_der_b64, cert_thumbprint_sha256. + + Raises: + MsiV2Error: on any failure (no fallback to MSI v1) + """ + from .managed_identity import MsiV2Error + + win32 = _load_win32() + ncrypt = win32["ncrypt"] + crypt32 = win32["crypt32"] + + base = _imds_base() + params = _mi_query_params(managed_identity) + corr = _new_correlation_id() + + # Check certificate cache first. The cache key must reflect the + # effective attestation mode so a non-attested certificate is never + # reused as if it were attested (or vice versa). + attested = attestation_enabled and attestation_token_provider is not None + cache_key = _cert_cache_key(managed_identity, attested) + cached = _cert_cache_get(cache_key) + + prov = None + key = None + cert_ctx = None + cert_der = None + + try: + # 1) getplatformmetadata + meta_url = base + _CSR_METADATA_PATH + meta = _imds_get_json(http_client, meta_url, params, + _imds_headers(corr)) + + client_id = _get_first(meta, "clientId", "client_id") + tenant_id = _get_first(meta, "tenantId", "tenant_id") + cu_id = meta.get("cuId") if "cuId" in meta else meta.get("cu_id") + attestation_endpoint = _get_first( + meta, "attestationEndpoint", "attestation_endpoint") + + if not client_id or not tenant_id or cu_id is None: + raise MsiV2Error( + f"[msi_v2] getplatformmetadata missing required fields: " + f"{meta}") + + # 2) Open-or-create KeyGuard RSA key + key_name = (os.getenv(_KEY_NAME_ENVVAR) + or _stable_key_name(str(client_id))) + prov, key, key_name, opened = _open_or_create_keyguard_rsa_key( + win32, key_name=key_name) + logger.debug("[msi_v2] KeyGuard key=%s opened_existing=%s", + key_name, opened) + + # Use cached cert if available + if cached is not None: + cert_der = cached.cert_der + token_endpoint = cached.token_endpoint + canonical_client_id = cached.client_id + logger.debug("[msi_v2] Using cached certificate") + else: + # 3) Build CSR + csr_b64 = _build_csr_b64( + win32, key, str(client_id), str(tenant_id), cu_id) + + # 4) Attestation (if provider given) + att_jwt = "" + if attestation_enabled and attestation_token_provider is not None: + if not attestation_endpoint: + raise MsiV2Error( + "[msi_v2] attestationEndpoint missing from metadata.") + try: + att_jwt = attestation_token_provider( + str(attestation_endpoint), + int(key.value), + str(client_id), + str(key_name)) + except MsiV2Error: + raise + except Exception as exc: + raise MsiV2Error( + f"[msi_v2] Attestation provider failed: {exc}" + ) from exc + if not att_jwt or not str(att_jwt).strip(): + raise MsiV2Error( + "[msi_v2] Attestation provider returned empty JWT.") + + # 5) issuecredential + issue_url = base + _ISSUE_CREDENTIAL_PATH + issue_headers = _imds_headers(corr) + issue_headers["Content-Type"] = "application/json" + + cred = _imds_post_json( + http_client, issue_url, params, issue_headers, + {"csr": csr_b64, "attestation_token": att_jwt}) + + cert_b64 = _get_first(cred, "certificate", "Certificate") + if not cert_b64: + raise MsiV2Error( + f"[msi_v2] issuecredential missing certificate: {cred}") + + try: + cert_der = base64.b64decode(cert_b64) + except Exception as exc: + raise MsiV2Error( + "[msi_v2] issuecredential returned invalid base64 " + "certificate") from exc + + canonical_client_id = (_get_first(cred, "client_id", "clientId") + or str(client_id)) + token_endpoint = _token_endpoint_from_credential(cred) + + # Cache the cert + not_after = _try_parse_cert_not_after(cert_der) + _cert_cache_set(cache_key, _CertCacheEntry( + cert_der=cert_der, + cert_pem=_der_to_pem(cert_der), + token_endpoint=token_endpoint, + client_id=canonical_client_id, + not_after=not_after or (time.time() + 8 * 3600), + )) + + # 6) Create CERT_CONTEXT, bind to KeyGuard private key + cert_ctx, _, _ = _create_cert_context_with_key( + win32, cert_der, key, key_name) + scope = _resource_to_scope(resource) + + # 7) POST token via WinHTTP/SChannel mTLS + token_json = _acquire_token_mtls_schannel( + win32, token_endpoint, cert_ctx, canonical_client_id, scope) + + if token_json.get("access_token") and token_json.get("expires_in"): + cert_pem = _der_to_pem(cert_der) + cert_thumbprint = get_cert_thumbprint_sha256(cert_pem) + + token_type = token_json.get("token_type") or "mtls_pop" + access_token = token_json["access_token"] + + result = { + "access_token": access_token, + "expires_in": int(token_json["expires_in"]), + "token_type": token_type, + "resource": token_json.get("resource"), + # Legacy fields (kept for backward compat) + "cert_pem": cert_pem, + "cert_der_b64": base64.b64encode( + cert_der).decode("ascii"), + "cert_thumbprint_sha256": cert_thumbprint, + } + + # binding_certificate is only present for mTLS PoP tokens + # on Windows. For non-mTLS or non-Windows flows it is None. + if (sys.platform == "win32" + and token_type.lower() in ("mtls_pop", "pop")): + from .windows_certificate import WindowsCertificate + + # Create WindowsCertificate — transfers key/prov ownership + binding_cert = WindowsCertificate._from_handles( + win32, cert_der, key, prov, key_name) + # Ownership transferred — don't free in finally + key = None + prov = None + + result["binding_certificate"] = binding_cert + result["binding_certificate_metadata"] = ( + binding_cert.to_metadata_dict()) + else: + result["binding_certificate"] = None + result["binding_certificate_metadata"] = None + + return result + return token_json + + except Exception: + # On failure, evict cached cert (may be stale/bad) + _cert_cache_remove(cache_key) + raise + + finally: + try: + if cert_ctx: + crypt32.CertFreeCertificateContext(cert_ctx) + except Exception: + pass + # Only free if ownership was NOT transferred to WindowsCertificate + try: + if key: + ncrypt.NCryptFreeObject(key) + except Exception: + pass + try: + if prov: + ncrypt.NCryptFreeObject(prov) + except Exception: + pass diff --git a/msal/windows_certificate.py b/msal/windows_certificate.py new file mode 100644 index 00000000..ceb81455 --- /dev/null +++ b/msal/windows_certificate.py @@ -0,0 +1,435 @@ +# Copyright (c) Microsoft Corporation. +# All rights reserved. +# +# This code is licensed under the MIT License. +""" +WindowsCertificate — Python equivalent of .NET X509Certificate2 for +platform-backed certificates with non-exportable private keys. + +This class holds a reference to a certificate in the Windows certificate +store (or an in-memory CERT_CONTEXT) and the associated CNG key handle. +The private key is NEVER exported — all signing/TLS operations are +delegated to Windows CNG/SChannel via the native handle. + +Usage: + # From store (for app developers) + cert = WindowsCertificate.from_store( + store_path="CurrentUser/My", + thumbprint="7C0F1A2B3C4D5E6F7890ABCDEF1234567890ABCDE", + ) + + # From internal MSAL flow (returned in auth result) + cert = WindowsCertificate._from_handles(win32, cert_der, key_handle, ...) + + # App developer uses cert for downstream mTLS via compatible transport + session = SomeSchannelTransport(client_certificate=cert) + session.get(url, headers={"Authorization": f"{result['token_type']} {result['access_token']}"}) +""" + +from __future__ import annotations + +import base64 +import hashlib +import logging +import sys +import threading +from typing import Any, Optional + +logger = logging.getLogger(__name__) + + +class WindowsCertificate: + """ + Platform-backed certificate reference (Python equivalent of X509Certificate2). + + This object: + - Holds the public certificate metadata (thumbprint, subject, issuer, DER bytes) + - Maintains a live reference to the CNG private key handle + - Can create a CERT_CONTEXT bound to the key for SChannel/WinHTTP usage + - NEVER exposes private key bytes + + The object is safe to pass to an mTLS transport for downstream API calls. + It implements context manager protocol for deterministic handle cleanup. + """ + + def __init__(self): + """Use factory methods (from_store, _from_handles) instead.""" + self._cert_der: bytes = b"" + self._key_handle: Any = None # NCRYPT_KEY_HANDLE (c_void_p) + self._prov_handle: Any = None # NCRYPT_PROV_HANDLE (c_void_p) + self._key_name: str = "" + self._store_path: str = "" + self._win32: Optional[dict] = None + self._closed = False + self._owns_key_handle = True # Whether we should free the key handle + self._lock = threading.Lock() + + @classmethod + def from_store( + cls, + store_path: str = "CurrentUser/My", + *, + thumbprint: Optional[str] = None, + subject_name: Optional[str] = None, + ) -> "WindowsCertificate": + """ + Open a certificate from the Windows certificate store. + + Args: + store_path: Certificate store location (e.g., "CurrentUser/My", + "LocalMachine/My"). + thumbprint: SHA-1 thumbprint (hex, case-insensitive) to select + the certificate. Preferred selector. + subject_name: Subject CN to match. Less precise than thumbprint. + + Returns: + WindowsCertificate with the key handle opened. + + Raises: + OSError: if certificate not found or key not accessible. + ValueError: if arguments are invalid. + """ + if sys.platform != "win32": + raise OSError( + "WindowsCertificate.from_store() is only supported on Windows") + + if not thumbprint and not subject_name: + raise ValueError( + "Either 'thumbprint' or 'subject_name' must be provided") + + from .msi_v2 import _load_win32, _raise_win32_last_error + + win32 = _load_win32() + ctypes_mod = win32["ctypes"] + crypt32 = win32["crypt32"] + + # Parse store_path: "CurrentUser/My" -> (location_flag, "My") + parts = store_path.replace("\\", "/").split("/", 1) + if len(parts) != 2: + raise ValueError( + f"store_path must be 'Location/StoreName', got: {store_path!r}") + + location_str, store_name = parts + location_map = { + "currentuser": 0x00010000, # CERT_SYSTEM_STORE_CURRENT_USER + "localmachine": 0x00020000, # CERT_SYSTEM_STORE_LOCAL_MACHINE + } + location_flag = location_map.get(location_str.lower()) + if location_flag is None: + raise ValueError( + f"Unsupported store location: {location_str!r}. " + f"Use 'CurrentUser' or 'LocalMachine'.") + + # Use CERT_STORE_PROV_SYSTEM_W (numeric 10) with CERT_STORE_READONLY_FLAG + CERT_STORE_PROV_SYSTEM_W = ctypes_mod.c_void_p(10) + CERT_STORE_READONLY_FLAG = 0x00008000 + + h_store = crypt32.CertOpenStore( + CERT_STORE_PROV_SYSTEM_W, + 0, + None, + location_flag | CERT_STORE_READONLY_FLAG, + ctypes_mod.c_wchar_p(store_name), + ) + if not h_store: + _raise_win32_last_error( + f"[WindowsCertificate] CertOpenStore failed for {store_path}") + + cert_ctx = None + try: + cert_ctx = cls._find_cert_in_store( + win32, h_store, thumbprint=thumbprint, + subject_name=subject_name) + if not cert_ctx: + selector = thumbprint or subject_name + raise OSError( + f"[WindowsCertificate] Certificate not found in " + f"{store_path} with selector: {selector}") + + # Extract DER from context + cert_info = ctypes_mod.cast( + cert_ctx, + ctypes_mod.POINTER(win32["CERT_CONTEXT"]) + ).contents + cert_der = ctypes_mod.string_at( + cert_info.pbCertEncoded, cert_info.cbCertEncoded) + + # Acquire private key handle + key_handle = ctypes_mod.c_void_p() + key_spec = ctypes_mod.c_ulong() + caller_must_free = ctypes_mod.c_int() + + CRYPT_ACQUIRE_ONLY_NCRYPT_KEY_FLAG = 0x00040000 + ok = crypt32.CryptAcquireCertificatePrivateKey( + cert_ctx, + CRYPT_ACQUIRE_ONLY_NCRYPT_KEY_FLAG, + None, + ctypes_mod.byref(key_handle), + ctypes_mod.byref(key_spec), + ctypes_mod.byref(caller_must_free), + ) + if not ok or not key_handle.value: + _raise_win32_last_error( + "[WindowsCertificate] CryptAcquireCertificatePrivateKey " + "failed — private key may not be accessible") + + # Verify we got an NCrypt key (CERT_NCRYPT_KEY_SPEC = 0xFFFFFFFF) + CERT_NCRYPT_KEY_SPEC = 0xFFFFFFFF + if key_spec.value != CERT_NCRYPT_KEY_SPEC: + # Got a legacy CryptoAPI handle, not CNG — not supported + raise OSError( + f"[WindowsCertificate] Certificate has a legacy CryptoAPI " + f"key (spec={key_spec.value}), CNG key required") + + # Build the object + obj = cls() + obj._cert_der = bytes(cert_der) + obj._key_handle = key_handle + obj._prov_handle = None + obj._key_name = "" + obj._store_path = store_path + obj._win32 = win32 + # Track whether we own the key handle + obj._owns_key_handle = bool(caller_must_free.value) + return obj + + finally: + # Free the CERT_CONTEXT from the store search (we extracted DER) + if cert_ctx: + try: + crypt32.CertFreeCertificateContext(cert_ctx) + except Exception: + pass + crypt32.CertCloseStore(h_store, 0) + + @classmethod + def _from_handles( + cls, + win32: dict, + cert_der: bytes, + key_handle: Any, + prov_handle: Any, + key_name: str, + ) -> "WindowsCertificate": + """ + Internal: create from existing NCrypt handles (used by obtain_token). + + The caller transfers ownership of key_handle and prov_handle to this + object. They will be freed when the WindowsCertificate is closed. + """ + obj = cls() + obj._cert_der = bytes(cert_der) + obj._key_handle = key_handle + obj._prov_handle = prov_handle + obj._key_name = key_name + obj._store_path = "" + obj._win32 = win32 + return obj + + # ------------------------------------------------------------------ + # Public properties (safe to access, no private key exposure) + # ------------------------------------------------------------------ + + @property + def thumbprint_sha1(self) -> str: + """SHA-1 thumbprint of the certificate (hex uppercase).""" + return hashlib.sha1(self._cert_der).hexdigest().upper() + + @property + def thumbprint_sha256(self) -> str: + """SHA-256 thumbprint of the certificate (hex uppercase).""" + return hashlib.sha256(self._cert_der).hexdigest().upper() + + @property + def x5t_s256(self) -> str: + """Base64url-encoded SHA-256 thumbprint (for cnf claim matching).""" + digest = hashlib.sha256(self._cert_der).digest() + return base64.urlsafe_b64encode(digest).rstrip(b"=").decode("ascii") + + @property + def public_certificate_der(self) -> bytes: + """DER-encoded public certificate bytes.""" + return self._cert_der + + @property + def public_certificate_pem(self) -> str: + """PEM-encoded public certificate.""" + b64 = base64.b64encode(self._cert_der).decode("ascii") + lines = [b64[i:i+64] for i in range(0, len(b64), 64)] + return ( + "-----BEGIN CERTIFICATE-----\n" + + "\n".join(lines) + + "\n-----END CERTIFICATE-----\n" + ) + + @property + def has_private_key(self) -> bool: + """Whether a private key handle is available.""" + return self._key_handle is not None and not self._closed + + @property + def store_path(self) -> str: + """The store path this certificate was loaded from (if any).""" + return self._store_path + + @property + def key_name(self) -> str: + """The CNG key name (if known).""" + return self._key_name + + # ------------------------------------------------------------------ + # Native handle access (for transports that need CERT_CONTEXT) + # ------------------------------------------------------------------ + + def create_cert_context(self) -> Any: + """ + Create a new CERT_CONTEXT bound to this certificate's private key. + + The caller is responsible for freeing the returned CERT_CONTEXT + via CertFreeCertificateContext when done. + + The CERT_CONTEXT references (but does NOT own) the private key handle. + The WindowsCertificate must remain open while the CERT_CONTEXT is in use. + + Returns: + PCCERT_CONTEXT (ctypes pointer) with private key bound. + + Raises: + RuntimeError: if the certificate has been closed. + """ + with self._lock: + if self._closed: + raise RuntimeError( + "WindowsCertificate has been closed — cannot create " + "CERT_CONTEXT") + if not self._key_handle: + raise RuntimeError( + "WindowsCertificate has no private key handle") + + from .msi_v2 import _create_cert_context_with_key + + cert_ctx, _, _ = _create_cert_context_with_key( + self._win32, self._cert_der, self._key_handle, self._key_name) + return cert_ctx + + @property + def _native_key_handle(self) -> Any: + """Internal: raw NCRYPT_KEY_HANDLE for signing operations.""" + if self._closed: + raise RuntimeError("WindowsCertificate has been closed") + return self._key_handle + + # ------------------------------------------------------------------ + # Lifecycle management + # ------------------------------------------------------------------ + + def close(self) -> None: + """Release native handles. Safe to call multiple times.""" + with self._lock: + if self._closed: + return + self._closed = True + + if self._win32: + ncrypt = self._win32.get("ncrypt") + if ncrypt: + if self._key_handle and self._owns_key_handle: + try: + ncrypt.NCryptFreeObject(self._key_handle) + except Exception: + pass + if self._prov_handle: + try: + ncrypt.NCryptFreeObject(self._prov_handle) + except Exception: + pass + + self._key_handle = None + self._prov_handle = None + + def __enter__(self) -> "WindowsCertificate": + return self + + def __exit__(self, *_: Any) -> None: + self.close() + + def __del__(self) -> None: + try: + self.close() + except Exception: + pass + + def __repr__(self) -> str: + state = "closed" if self._closed else "open" + tp = self.thumbprint_sha1[:16] + "..." if self._cert_der else "empty" + return f"" + + # ------------------------------------------------------------------ + # Serialization helpers (for auth result metadata — no private key) + # ------------------------------------------------------------------ + + def to_metadata_dict(self) -> dict: + """ + Return JSON-safe metadata about this certificate. + Safe for logging and cross-process diagnostics. + """ + return { + "store_path": self._store_path, + "thumbprint_sha1": self.thumbprint_sha1, + "thumbprint_sha256": self.thumbprint_sha256, + "x5t#S256": self.x5t_s256, + "has_private_key": self.has_private_key, + "key_name": self._key_name or None, + } + + # ------------------------------------------------------------------ + # Internal helpers + # ------------------------------------------------------------------ + + @staticmethod + def _find_cert_in_store( + win32: dict, h_store: Any, *, + thumbprint: Optional[str] = None, + subject_name: Optional[str] = None, + ) -> Any: + """Find a certificate in an open store by thumbprint or subject.""" + ctypes_mod = win32["ctypes"] + crypt32 = win32["crypt32"] + + if thumbprint: + # Normalize: remove spaces, colons, dashes + normalized = thumbprint.replace(" ", "").replace(":", "").replace("-", "") + if len(normalized) != 40: + raise ValueError( + f"thumbprint must be a 40-character SHA-1 hex string, " + f"got {len(normalized)} characters after normalization") + thumb_bytes = bytes.fromhex(normalized) + blob = win32["CRYPT_HASH_BLOB"]() + buf = ctypes_mod.create_string_buffer(thumb_bytes) + blob.cbData = len(thumb_bytes) + blob.pbData = ctypes_mod.cast(buf, ctypes_mod.c_void_p) + + CERT_FIND_HASH = 0x10000 # CERT_FIND_SHA1_HASH + ctx = crypt32.CertFindCertificateInStore( + h_store, + win32["X509_ASN_ENCODING"] | win32["PKCS_7_ASN_ENCODING"], + 0, + CERT_FIND_HASH, + ctypes_mod.byref(blob), + None, + ) + return ctx + + if subject_name: + CERT_FIND_SUBJECT_STR = 0x00080007 + ctx = crypt32.CertFindCertificateInStore( + h_store, + win32["X509_ASN_ENCODING"] | win32["PKCS_7_ASN_ENCODING"], + 0, + CERT_FIND_SUBJECT_STR, + ctypes_mod.c_wchar_p(subject_name), + None, + ) + return ctx + + return None diff --git a/sample/MSI_V2_GUIDE.md b/sample/MSI_V2_GUIDE.md new file mode 100644 index 00000000..c726153e --- /dev/null +++ b/sample/MSI_V2_GUIDE.md @@ -0,0 +1,182 @@ +# MSI v2 (mTLS Proof-of-Possession) — Setup & Usage Guide + +## Overview + +MSI v2 enables Managed Identity token acquisition using mTLS Proof-of-Possession +on Windows Azure VMs with Credential Guard / KeyGuard. + +The implementation is split into two packages mirroring the MSAL .NET architecture: + +| Package | .NET Equivalent | What | +|---|---|---| +| `msal` | `Microsoft.Identity.Client` | Core mTLS PoP flow (KeyGuard key, CSR, IMDS, WinHTTP) | +| `msal-key-attestation` | `Microsoft.Identity.Client.KeyAttestation` | AttestationClientLib.dll native bindings | + +## Prerequisites + +1. **Windows Azure VM** with: + - Credential Guard / KeyGuard enabled (VBS) + - System-assigned or user-assigned managed identity + - Network access to IMDS (169.254.169.254) + +2. **AttestationClientLib.dll** — place in one of: + - Current working directory + - Same directory as `python.exe` + - Directory of the `msal_key_attestation` package + - Path specified by `ATTESTATION_CLIENTLIB_PATH` env var + +## Installation + +```bash +# Core MSAL package (includes MSI v2 flow) +pip install msal + +# Attestation support (loads AttestationClientLib.dll) +pip install msal-key-attestation +``` + +For development (from this repo): + +```bash +# Install msal in editable mode +pip install -e . + +# Install msal-key-attestation in editable mode +pip install -e msal-key-attestation/ +``` + +## Quick Start + +```python +import msal +import requests + +client = msal.ManagedIdentityClient( + msal.SystemAssignedManagedIdentity(), + http_client=requests.Session(), +) + +result = client.acquire_token_for_client( + resource="https://graph.microsoft.com", + mtls_proof_of_possession=True, + with_attestation_support=True, +) + +if "access_token" in result: + print(f"Token type: {result['token_type']}") # mtls_pop + print(f"Expires in: {result['expires_in']}s") + print(f"Thumbprint: {result['cert_thumbprint_sha256']}") +else: + print(f"Error: {result}") +``` + +## API Reference + +### `acquire_token_for_client()` — New Parameters + +| Parameter | Type | Default | Description | +|---|---|---|---| +| `mtls_proof_of_possession` | `bool` | `False` | Enable MSI v2 mTLS PoP flow | +| `with_attestation_support` | `bool` | `False` | Enable KeyGuard attestation (requires `msal-key-attestation`) | + +**Behavior matrix:** + +| `mtls_proof_of_possession` | `with_attestation_support` | Result | +|---|---|---| +| `False` | `False` | MSI v1 (default, unchanged) | +| `True` | `False` | MSI v1 fallthrough (PoP alone = no-op) | +| `False` | `True` | Raises `ManagedIdentityError` | +| `True` | `True` | **MSI v2** — KeyGuard + attestation + mTLS PoP | + +### Response + +When MSI v2 succeeds, the response dict includes extra fields: + +```python +{ + "access_token": "eyJ...", + "expires_in": 3600, + "token_type": "mtls_pop", + "resource": "https://graph.microsoft.com", + # Additional MSI v2 fields: + "cert_pem": "-----BEGIN CERTIFICATE-----\n...", + "cert_der_b64": "MIID...", + "cert_thumbprint_sha256": "abc123...", +} +``` + +### Errors + +| Error Class | When | +|---|---| +| `ManagedIdentityError` | `with_attestation_support=True` without `mtls_proof_of_possession` | +| `MsiV2Error` | Any MSI v2 flow failure (no fallback to v1) | +| `MsiV2Error` | `msal-key-attestation` package not installed | + +### Verification + +```python +from msal.msi_v2 import verify_cnf_binding + +bound = verify_cnf_binding(result["access_token"], result["cert_pem"]) +assert bound, "Token is not bound to the certificate" +``` + +## Flow Diagram + +``` +App MSAL Python IMDS ESTS (mTLS) + | | | | + |-- acquire_token ------>| | | + | (mtls_pop=True, | | | + | attestation=True) | | | + | | | | + | [1] NCrypt: KeyGuard key | | + | [2] GET /getplatformmetadata ->| | + | |<-- clientId, tenantId,| | + | | cuId, attestEP | | + | [3] Build CSR (RSA-PSS/SHA256) | | + | [4] AttestationClientLib.dll | | + | |--- MAA attest -------->| | + | |<-- attestation JWT | | + | [5] POST /issuecredential ---->| | + | |<-- certificate, endpoint | + | [6] Crypt32: bind cert to key | | + | [7] WinHTTP: POST /token ------|--------->| | + | | | | | + | |<-- mtls_pop token ----|---------| | + |<-- result -------------| | | +``` + +## Environment Variables + +| Variable | Description | +|---|---| +| `AZURE_POD_IDENTITY_AUTHORITY_HOST` | Override IMDS base URL | +| `MSAL_MSI_V2_KEY_NAME` | Override per-boot key name | +| `ATTESTATION_CLIENTLIB_PATH` | Full path to AttestationClientLib.dll | +| `MSAL_MSI_V2_ATTESTATION_CACHE` | `"0"` to disable MAA JWT caching | + +## Running the Sample + +```bash +cd sample/ +python msi_v2_sample.py + +# With verbose logging: +MSI_V2_VERBOSE=1 python msi_v2_sample.py + +# Custom resource: +RESOURCE=https://vault.azure.net python msi_v2_sample.py +``` + +## Running Tests + +```bash +# Core MSI v2 tests (no Windows/KeyGuard dependency) +pytest tests/test_msi_v2.py -v + +# Attestation package tests +cd msal-key-attestation/ +pytest tests/test_attestation.py -v +``` diff --git a/sample/devapp_msi_v2_mtls/app.py b/sample/devapp_msi_v2_mtls/app.py new file mode 100644 index 00000000..b6dd989d --- /dev/null +++ b/sample/devapp_msi_v2_mtls/app.py @@ -0,0 +1,224 @@ +#!/usr/bin/env python +""" +MSI v2 mTLS PoP Dev App +======================== +Demonstrates the full mTLS PoP flow using ONLY public MSAL APIs: + + 1. ManagedIdentityClient.acquire_token_for_client() → token + binding_certificate + 2. App developer builds Authorization header from token_type + access_token + 3. App developer uses SchannelSession (separate package) for downstream mTLS call + +MSAL hands out the token and cert. The app owns the downstream call. +This is the Python equivalent of the .NET MsiV2DemoApp. + +Prerequisites: + - Windows Azure VM with Credential Guard / KeyGuard + - pip install requests + - msal-key-attestation package with AttestationClientLib.dll + - msal-schannel-transport package + +Usage: + python app.py + +Environment variables (optional): + RESOURCE - Token audience (default: https://graph.microsoft.com) + DOWNSTREAM_URL - URL to call over mTLS (default: Graph mTLS endpoint) + UAMI_CLIENT_ID - User-assigned MI client ID (default: system-assigned) +""" + +import json +import os +import sys +import base64 + +# Ensure local packages (msal/, msal-key-attestation/, msal-schannel-transport/) +# take precedence over system-installed versions. +_REPO_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +for _subpkg in [ + _REPO_ROOT, + os.path.join(_REPO_ROOT, "msal-key-attestation"), + os.path.join(_REPO_ROOT, "msal-schannel-transport"), +]: + if _subpkg not in sys.path: + sys.path.insert(0, _subpkg) + + +def main(): + print("=" * 60) + print(" MSI v2 mTLS PoP — Dev App") + print(" (uses ONLY public MSAL APIs)") + print("=" * 60) + + if sys.platform != "win32": + print("ERROR: Requires Windows with KeyGuard.") + return 1 + + # ── Import MSAL public API only ────────────────────────── + import msal + import requests + + # ── Configuration ──────────────────────────────────────── + resource = os.environ.get("RESOURCE", "https://vault.azure.net") + downstream_url = os.environ.get( + "DOWNSTREAM_URL", + "https://tokenbinding.vault.azure.net/secrets/boundsecret/?api-version=2015-06-01") + uami_client_id = os.environ.get("UAMI_CLIENT_ID", "") + + print(f"\n Resource: {resource}") + print(f" Downstream: {downstream_url}") + if uami_client_id: + print(f" UAMI Client ID: {uami_client_id}") + else: + print(f" Identity: System-Assigned") + + # ── Step 1: Create MSAL client (public API) ───────────── + print(f"\n{'─' * 60}") + print(" Step 1: Create ManagedIdentityClient") + print(f"{'─' * 60}") + + if uami_client_id: + mi = msal.UserAssignedManagedIdentity(client_id=uami_client_id) + else: + mi = msal.SystemAssignedManagedIdentity() + + http_client = requests.Session() + + client = msal.ManagedIdentityClient( + mi, + http_client=http_client, + ) + print(" ✓ Client created") + + # ── Step 2: Acquire token (public API) ─────────────────── + print(f"\n{'─' * 60}") + print(" Step 2: acquire_token_for_client(mtls_proof_of_possession=True)") + print(f"{'─' * 60}") + + try: + result = client.acquire_token_for_client( + resource=resource, + mtls_proof_of_possession=True, + with_attestation_support=True, + ) + except Exception as e: + print(f"\n ✗ FAILED: {type(e).__name__}: {e}") + return 1 + + if "error" in result: + print(f" ✗ Error: {result['error']}") + print(f" {result.get('error_description', '')}") + return 1 + + # ── Step 3: Inspect auth result ────────────────────────── + print(f"\n{'─' * 60}") + print(" Step 3: Auth result") + print(f"{'─' * 60}") + + access_token = result["access_token"] + token_type = result.get("token_type", "unknown") + expires_in = result.get("expires_in", 0) + + print(f" ✓ access_token: {access_token[:30]}...") + print(f" ✓ token_type: {token_type}") + print(f" ✓ expires_in: {expires_in}s") + + # Get binding_certificate — the WindowsCertificate object + # This is what MSAL hands out. The app developer uses it for downstream. + binding_cert = result.get("binding_certificate") + + if binding_cert is None: + print(f" ✗ binding_certificate is None!") + return 1 + + print(f" ✓ binding_certificate: {binding_cert}") + print(f" has_private_key: {binding_cert.has_private_key}") + print(f" thumbprint_sha1: {binding_cert.thumbprint_sha1}") + print(f" thumbprint_sha256: {binding_cert.thumbprint_sha256}") + print(f" x5t#S256: {binding_cert.x5t_s256}") + + # Verify token binding (cnf.x5t#S256 matches cert) + try: + parts = access_token.split(".") + payload_b64 = parts[1] + "=" * ((4 - len(parts[1]) % 4) % 4) + payload = json.loads(base64.urlsafe_b64decode(payload_b64)) + cnf = payload.get("cnf", {}) + token_x5t = cnf.get("x5t#S256", "NOT FOUND") + cert_x5t = binding_cert.x5t_s256 + match = "✓ MATCH" if token_x5t == cert_x5t else "✗ MISMATCH" + print(f"\n Token cnf.x5t#S256: {token_x5t}") + print(f" Cert x5t#S256: {cert_x5t}") + print(f" Binding: {match}") + if token_x5t != cert_x5t: + print(" ERROR: Token is NOT bound to the certificate!") + return 1 + except Exception as e: + print(f" ⚠ Could not verify binding: {e}") + + # ── Step 4: Downstream mTLS call (app developer's job) ─── + print(f"\n{'─' * 60}") + print(" Step 4: Downstream mTLS call (using SchannelSession)") + print(f"{'─' * 60}") + + try: + from msal_schannel_transport import SchannelSession + except ImportError: + print(" ✗ msal-schannel-transport not installed") + print(" pip install msal-schannel-transport") + binding_cert.close() + return 1 + + # App developer builds the auth header from token_type + access_token + auth_header = f"{token_type} {access_token}" + + print(f" URL: {downstream_url}") + print(f" Authorization: {token_type} ") + print(f" Client cert: {binding_cert.thumbprint_sha1[:16]}...") + + try: + with SchannelSession(client_certificate=binding_cert) as session: + response = session.get( + downstream_url, + headers={ + "Authorization": auth_header, + "x-ms-tokenboundauth": "true", + }, + ) + + print(f"\n Response: HTTP {response.status_code}") + + if response.status_code == 200: + body = response.json() + print(f" ✓ Success!") + for k in list(body.keys())[:5]: + v = body[k] + if isinstance(v, str) and len(v) > 80: + v = v[:80] + "..." + elif isinstance(v, list): + v = f"[{len(v)} items]" + print(f" {k}: {v}") + elif response.status_code == 403: + print(f" ⚠ 403 Forbidden — identity may lack permissions") + print(f" {response.text[:300]}") + elif response.status_code == 401: + print(f" ✗ 401 — mTLS binding may have failed") + print(f" {response.text[:300]}") + else: + print(f" ? HTTP {response.status_code}") + print(f" {response.text[:300]}") + + except Exception as e: + print(f"\n ✗ Downstream call FAILED: {type(e).__name__}: {e}") + import traceback + traceback.print_exc() + + # ── Cleanup ────────────────────────────────────────────── + binding_cert.close() + + print(f"\n{'─' * 60}") + print(" Done.") + print(f"{'─' * 60}") + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/sample/msi_v2_sample.py b/sample/msi_v2_sample.py new file mode 100644 index 00000000..f290281f --- /dev/null +++ b/sample/msi_v2_sample.py @@ -0,0 +1,136 @@ +#!/usr/bin/env python3 +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +""" +MSI v2 (mTLS PoP) Sample — Managed Identity with KeyGuard Attestation. + +This sample demonstrates acquiring an mTLS Proof-of-Possession token using +MSAL Python's MSI v2 flow on a Windows Azure VM with Credential Guard. + +Prerequisites: + - Windows Azure VM with Credential Guard / KeyGuard enabled + - AttestationClientLib.dll accessible (next to script or via env var) + - pip install msal msal-key-attestation requests + +Usage: + python msi_v2_sample.py + +Environment variables (optional): + RESOURCE - Resource URI (default: https://graph.microsoft.com) + RESOURCE_URL - URL to call with the token (default: Graph /applications) + MSI_V2_VERBOSE - Set to "1" for verbose logging +""" + +import logging +import os +import sys + +import requests +import msal + +# Optional: enable verbose logging +if os.getenv("MSI_V2_VERBOSE", "").strip() in ("1", "true"): + logging.basicConfig(level=logging.DEBUG) +else: + logging.basicConfig(level=logging.INFO) + +logger = logging.getLogger(__name__) + + +def main(): + # --- Configuration --- + resource = os.getenv( + "RESOURCE", "https://graph.microsoft.com") + resource_url = os.getenv( + "RESOURCE_URL", + "https://mtlstb.graph.microsoft.com/v1.0/applications?$top=5") + + logger.info("=" * 60) + logger.info("MSI v2 (mTLS PoP) Sample") + logger.info("=" * 60) + logger.info("Resource: %s", resource) + logger.info("Resource URL: %s", resource_url) + + # --- Create client --- + http_session = requests.Session() + + # Optionally add retry + from requests.adapters import HTTPAdapter + try: + from urllib3.util.retry import Retry + retries = Retry(total=3, backoff_factor=0.5, + status_forcelist=[429, 500, 502, 503, 504]) + http_session.mount("https://", HTTPAdapter(max_retries=retries)) + http_session.mount("http://", HTTPAdapter(max_retries=retries)) + except ImportError: + pass + + client = msal.ManagedIdentityClient( + msal.SystemAssignedManagedIdentity(), + http_client=http_session, + ) + + # --- Acquire token --- + logger.info("Acquiring mTLS PoP token...") + result = client.acquire_token_for_client( + resource=resource, + mtls_proof_of_possession=True, + with_attestation_support=True, + ) + + if "access_token" not in result: + print("ERROR: Token acquisition failed. Check logs for details.") + sys.exit(1) + + token_type = result.get("token_type", "unknown") + expires_in = result.get("expires_in", 0) + thumbprint = result.get("cert_thumbprint_sha256", "") + + print("Token acquired successfully!") + print(f" token_type: {token_type}") + print(f" expires_in: {expires_in} seconds") + print(f" thumbprint: {thumbprint[:16]}..." if thumbprint else " thumbprint: (none)") + + if token_type != "mtls_pop": + print(f"WARNING: Expected token_type='mtls_pop' but got '{token_type}'.") + + # --- Verify binding --- + from msal.msi_v2 import verify_cnf_binding + cert_pem = result.get("cert_pem", "") + if cert_pem: + bound = verify_cnf_binding(result["access_token"], cert_pem) + print(f" cnf binding: {'VERIFIED' if bound else 'FAILED'}") + if not bound: + logger.error("Token is NOT bound to the certificate!") + sys.exit(1) + + # --- Call resource over mTLS (optional) --- + if resource_url: + logger.info("Calling resource: %s", resource_url) + + # Note: mTLS resource calls require presenting the same cert via + # WinHTTP/SChannel. The requests library cannot present a KeyGuard- + # bound cert. This demonstrates the auth header format only. + headers = { + "Authorization": f"{token_type} {result['access_token']}", + "Accept": "application/json", + } + + try: + resp = http_session.get(resource_url, headers=headers) + print(f" Response: HTTP {resp.status_code}") + if not resp.ok: + logger.warning(" Request failed with status %d", + resp.status_code) + except Exception as exc: + logger.warning(" Resource call failed: %s", type(exc).__name__) + print( + " Note: mTLS resource calls require WinHTTP/SChannel; " + "the requests library cannot present the mTLS cert.") + + logger.info("=" * 60) + logger.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/tests/test_e2e_mtls_pop.py b/tests/test_e2e_mtls_pop.py new file mode 100644 index 00000000..b505e6a5 --- /dev/null +++ b/tests/test_e2e_mtls_pop.py @@ -0,0 +1,267 @@ +# Copyright (c) Microsoft Corporation. +# All rights reserved. +# +# This code is licensed under the MIT License. +""" +E2E test: Acquire mTLS PoP token via MSAL → call Azure Key Vault over mTLS. + +This test demonstrates the correct architecture: + 1. MSAL acquires the token and returns: + - access_token + token_type + - app developer constructs the authorization header from token_type + + access_token + - binding_certificate (WindowsCertificate — platform key handle) + 2. App developer uses a SEPARATE mTLS transport (SchannelSession) + to call the downstream API with the binding_certificate. + +Requirements: + - Must run on a KeyGuard-enabled Azure VM (Windows) + - Must have MSI v2 configured with attestation + - msal-key-attestation package installed (for attestation) + - msal-schannel-transport package installed (for downstream call) + +Usage: + set RUN_E2E_TESTS=1 && python -m pytest tests/test_e2e_mtls_pop.py -v + OR + python tests/test_e2e_mtls_pop.py (standalone) +""" + +from __future__ import annotations + +import json +import os +import sys +import unittest + +# Skip entirely on non-Windows +if sys.platform != "win32": + raise unittest.SkipTest("E2E mTLS PoP tests require Windows") + + +# --------------------------------------------------------------------------- +# Configuration — from environment or defaults (KeyGuard Azure VM) +# --------------------------------------------------------------------------- + +# These come from the MSI v2 platform metadata — no manual config needed +# for system-assigned identity. For user-assigned, set these: +E2E_CLIENT_ID = os.environ.get("MSI_V2_CLIENT_ID", "") +E2E_RESOURCE = os.environ.get("MSI_V2_RESOURCE", "https://vault.azure.net") + +# Key Vault to test against +E2E_VAULT_URL = os.environ.get( + "MSI_V2_VAULT_URL", "https://msidlabvault.vault.azure.net") +E2E_SECRET_NAME = os.environ.get("MSI_V2_SECRET_NAME", "test-secret") + + +class TestMtlsPopE2E(unittest.TestCase): + """ + End-to-end: MSAL token acquisition → downstream mTLS call to Key Vault. + + Architecture: + MSAL (obtain_token) → auth result with binding_certificate + SchannelSession (separate package) → downstream mTLS GET to AKV + """ + + @unittest.skipUnless( + os.environ.get("RUN_E2E_TESTS"), + "Set RUN_E2E_TESTS=1 to run E2E tests (requires KeyGuard VM)") + def test_acquire_token_and_call_akv(self): + """ + Full flow: + 1. obtain_token() → gets mTLS PoP token + WindowsCertificate + 2. Verify auth result has binding_certificate + 3. Create SchannelSession with binding_certificate + 4. Call AKV GET /secrets/{name}?api-version=7.5 + 5. Verify 200 OK + """ + import requests + from msal.msi_v2 import obtain_token + from msal.managed_identity import SystemAssignedManagedIdentity + from msal.windows_certificate import WindowsCertificate + + # 1. Acquire token + http_client = requests.Session() + mi = SystemAssignedManagedIdentity() + if E2E_CLIENT_ID: + from msal.managed_identity import UserAssignedManagedIdentity + mi = UserAssignedManagedIdentity(client_id=E2E_CLIENT_ID) + + # Try with attestation if available + attestation_provider = None + try: + from msal_key_attestation import create_attestation_provider + attestation_provider = create_attestation_provider() + except ImportError: + pass + + result = obtain_token( + http_client=http_client, + managed_identity=mi, + resource=E2E_RESOURCE, + attestation_enabled=True, + attestation_token_provider=attestation_provider, + ) + + # 2. Verify auth result structure + self.assertIn("access_token", result) + self.assertIn("token_type", result) + self.assertIn("binding_certificate", result) + self.assertIn("binding_certificate_metadata", result) + + binding_cert = result["binding_certificate"] + self.assertIsInstance(binding_cert, WindowsCertificate) + self.assertTrue(binding_cert.has_private_key) + self.assertTrue(len(binding_cert.thumbprint_sha256) == 64) + + token_type = result["token_type"] + self.assertIn( + token_type.lower(), ("mtls_pop", "pop"), + f"Expected mtls_pop token type, got: {token_type}") + + print(f"\n✓ Token acquired successfully") + print(f" token_type: {token_type}") + print(f" cert thumbprint: {binding_cert.thumbprint_sha256[:16]}...") + print(f" x5t#S256: {binding_cert.x5t_s256}") + + # 3. Use SEPARATE transport for downstream mTLS call + from msal_schannel_transport import SchannelSession + + # App developer constructs auth header from token_type + access_token + auth_header = f"{token_type} {result['access_token']}" + + with SchannelSession(client_certificate=binding_cert) as session: + # 4. Call AKV + url = (f"{E2E_VAULT_URL}/secrets/{E2E_SECRET_NAME}" + f"?api-version=7.5") + + response = session.get( + url, + headers={"Authorization": auth_header}, + ) + + # 5. Verify response + print(f"\n AKV response: HTTP {response.status_code}") + if response.status_code == 200: + body = response.json() + print(f" Secret retrieved: {E2E_SECRET_NAME}") + self.assertIn("value", body) + elif response.status_code == 403: + # Expected if MSI doesn't have AKV access + print(f" 403 Forbidden — MSI may not have AKV access policy") + print(f" Body: {response.text[:200]}") + elif response.status_code == 401: + print(f" 401 Unauthorized — token binding mismatch?") + print(f" Body: {response.text[:200]}") + self.fail("401 from AKV — mTLS binding may have failed") + else: + print(f" Unexpected: {response.text[:200]}") + + # 6. Cleanup — WindowsCertificate freed by context manager + binding_cert.close() + + @unittest.skipUnless( + os.environ.get("RUN_E2E_TESTS"), + "Set RUN_E2E_TESTS=1 to run E2E tests (requires KeyGuard VM)") + def test_binding_certificate_matches_token(self): + """ + Verify that the binding_certificate's x5t#S256 matches the + token's cnf claim — the fundamental mTLS PoP security property. + """ + import base64 + import requests + from msal.msi_v2 import obtain_token + from msal.managed_identity import SystemAssignedManagedIdentity + + http_client = requests.Session() + mi = SystemAssignedManagedIdentity() + + attestation_provider = None + try: + from msal_key_attestation import create_attestation_provider + attestation_provider = create_attestation_provider() + except ImportError: + pass + + result = obtain_token( + http_client=http_client, + managed_identity=mi, + resource=E2E_RESOURCE, + attestation_enabled=True, + attestation_token_provider=attestation_provider, + ) + + self.assertIn("binding_certificate", result) + binding_cert = result["binding_certificate"] + + try: + # Decode JWT payload (without validation — we just check cnf) + token = result["access_token"] + parts = token.split(".") + self.assertEqual(len(parts), 3, "Token should be a JWT (3 parts)") + + # Pad base64url + payload_b64 = parts[1] + payload_b64 += "=" * ((4 - len(payload_b64) % 4) % 4) + payload = json.loads(base64.urlsafe_b64decode(payload_b64)) + + # Check cnf claim + self.assertIn("cnf", payload, + "mTLS PoP token must have 'cnf' claim") + cnf = payload["cnf"] + self.assertIn("x5t#S256", cnf, + "cnf claim must have 'x5t#S256'") + + token_thumbprint = cnf["x5t#S256"] + cert_thumbprint = binding_cert.x5t_s256 + + self.assertEqual( + token_thumbprint, cert_thumbprint, + f"Token binding mismatch!\n" + f" Token cnf.x5t#S256: {token_thumbprint}\n" + f" Cert x5t#S256: {cert_thumbprint}") + + print(f"\n✓ Token binding verified: x5t#S256 = {cert_thumbprint}") + finally: + binding_cert.close() + + @unittest.skipUnless( + os.environ.get("RUN_E2E_TESTS"), + "Set RUN_E2E_TESTS=1 to run E2E tests (requires KeyGuard VM)") + def test_certificate_from_store(self): + """ + Test WindowsCertificate.from_store() — loads cert by thumbprint + from the Windows cert store (separate from MSAL token flow). + """ + from msal.windows_certificate import WindowsCertificate + + # This test requires a cert thumbprint in env + thumbprint = os.environ.get("MSI_V2_CERT_THUMBPRINT") + if not thumbprint: + self.skipTest( + "Set MSI_V2_CERT_THUMBPRINT to test from_store()") + + cert = WindowsCertificate.from_store( + store_path="CurrentUser/My", + thumbprint=thumbprint, + ) + + try: + self.assertTrue(cert.has_private_key) + self.assertEqual( + cert.thumbprint_sha1.upper(), + thumbprint.upper()) + self.assertTrue(len(cert.public_certificate_der) > 100) + self.assertTrue( + cert.public_certificate_pem.startswith( + "-----BEGIN CERTIFICATE-----")) + print(f"\n✓ Cert loaded from store: {cert}") + print(f" SHA-256: {cert.thumbprint_sha256}") + print(f" x5t#S256: {cert.x5t_s256}") + finally: + cert.close() + + +if __name__ == "__main__": + # Allow running standalone + os.environ["RUN_E2E_TESTS"] = "1" + unittest.main(verbosity=2) diff --git a/tests/test_msi_v2.py b/tests/test_msi_v2.py new file mode 100644 index 00000000..c0c1168f --- /dev/null +++ b/tests/test_msi_v2.py @@ -0,0 +1,563 @@ +# Copyright (c) Microsoft Corporation. +# All rights reserved. +# +# This code is licensed under the MIT License. +"""Tests for MSI v2 (mTLS PoP) implementation. + +Goals: +- Provide strong unit coverage without depending on KeyGuard / real IMDS. +- Validate: + * x5t#S256 helper correctness (local) + * verify_cnf_binding behavior (msal.msi_v2) + * Certificate cache behavior + * ManagedIdentityClient strict gating behavior + * IMDS wire-contract helpers +""" + +import base64 +import datetime +import hashlib +import json +import os +import time +import unittest + +try: + from unittest.mock import patch, MagicMock +except ImportError: + from mock import patch, MagicMock + +from cryptography import x509 +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives import hashes, serialization +from cryptography.hazmat.primitives.asymmetric import rsa +from cryptography.x509.oid import NameOID + +import msal +from msal import MsiV2Error +from msal.msi_v2 import ( + verify_cnf_binding, + _cert_cache_clear, + _cert_cache_get, + _cert_cache_set, + _cert_cache_key, + _cert_cache_remove, + _CertCacheEntry, + _mi_query_params, + _resource_to_scope, + _token_endpoint_from_credential, + _der_to_pem, +) + +from tests.http_client import MinimalResponse + + +# --------------------------------------------------------------------------- +# Local helpers +# --------------------------------------------------------------------------- + +def _make_self_signed_cert(private_key, common_name="test"): + subject = issuer = x509.Name( + [x509.NameAttribute(NameOID.COMMON_NAME, common_name)]) + now = datetime.datetime.now(datetime.timezone.utc) + cert = ( + x509.CertificateBuilder() + .subject_name(subject) + .issuer_name(issuer) + .public_key(private_key.public_key()) + .serial_number(x509.random_serial_number()) + .not_valid_before(now) + .not_valid_after(now + datetime.timedelta(days=1)) + .sign(private_key, hashes.SHA256(), default_backend()) + ) + return cert.public_bytes(serialization.Encoding.PEM).decode("utf-8") + + +def get_cert_thumbprint_sha256(cert_pem: str) -> str: + cert = x509.load_pem_x509_certificate( + cert_pem.encode("utf-8"), default_backend()) + cert_der = cert.public_bytes(serialization.Encoding.DER) + return (base64.urlsafe_b64encode(hashlib.sha256(cert_der).digest()) + .rstrip(b"=").decode("ascii")) + + +def _b64url(s: bytes) -> str: + return base64.urlsafe_b64encode(s).rstrip(b"=").decode("ascii") + + +def _make_jwt(payload_obj, header_obj=None) -> str: + header_obj = header_obj or {"alg": "RS256", "typ": "JWT"} + header = _b64url( + json.dumps(header_obj, separators=(",", ":")).encode("utf-8")) + payload = _b64url( + json.dumps(payload_obj, separators=(",", ":")).encode("utf-8")) + sig = _b64url(b"sig") + return f"{header}.{payload}.{sig}" + + +# --------------------------------------------------------------------------- +# Thumbprint helper +# --------------------------------------------------------------------------- + +class TestThumbprintHelper(unittest.TestCase): + def setUp(self): + self.key = rsa.generate_private_key( + public_exponent=65537, key_size=2048) + self.cert_pem = _make_self_signed_cert(self.key, "thumbprint-test") + + def test_returns_base64url_no_padding(self): + thumb = get_cert_thumbprint_sha256(self.cert_pem) + self.assertIsInstance(thumb, str) + self.assertNotIn("=", thumb) + decoded = base64.urlsafe_b64decode(thumb + "==") + self.assertEqual(len(decoded), 32) + + def test_same_cert_same_thumbprint(self): + t1 = get_cert_thumbprint_sha256(self.cert_pem) + t2 = get_cert_thumbprint_sha256(self.cert_pem) + self.assertEqual(t1, t2) + + def test_different_certs_different_thumbprints(self): + key2 = rsa.generate_private_key( + public_exponent=65537, key_size=2048) + cert2_pem = _make_self_signed_cert(key2, "thumbprint-test-2") + self.assertNotEqual( + get_cert_thumbprint_sha256(self.cert_pem), + get_cert_thumbprint_sha256(cert2_pem)) + + def test_matches_manual_sha256_der(self): + cert = x509.load_pem_x509_certificate( + self.cert_pem.encode("utf-8"), default_backend()) + cert_der = cert.public_bytes(serialization.Encoding.DER) + expected = (base64.urlsafe_b64encode( + hashlib.sha256(cert_der).digest()) + .rstrip(b"=").decode("ascii")) + self.assertEqual(get_cert_thumbprint_sha256(self.cert_pem), expected) + + +# --------------------------------------------------------------------------- +# verify_cnf_binding +# --------------------------------------------------------------------------- + +class TestVerifyCnfBinding(unittest.TestCase): + def setUp(self): + self.key = rsa.generate_private_key( + public_exponent=65537, key_size=2048) + self.cert_pem = _make_self_signed_cert(self.key, "cnf-test") + self.thumbprint = get_cert_thumbprint_sha256(self.cert_pem) + + def test_valid_binding_true(self): + token = _make_jwt({"cnf": {"x5t#S256": self.thumbprint}}) + self.assertTrue(verify_cnf_binding(token, self.cert_pem)) + + def test_wrong_thumbprint_false(self): + token = _make_jwt({"cnf": {"x5t#S256": "wrong"}}) + self.assertFalse(verify_cnf_binding(token, self.cert_pem)) + + def test_missing_cnf_false(self): + token = _make_jwt({"sub": "nobody"}) + self.assertFalse(verify_cnf_binding(token, self.cert_pem)) + + def test_missing_x5t_false(self): + token = _make_jwt({"cnf": {}}) + self.assertFalse(verify_cnf_binding(token, self.cert_pem)) + + def test_cnf_not_object_false(self): + token = _make_jwt({"cnf": "not-an-object"}) + self.assertFalse(verify_cnf_binding(token, self.cert_pem)) + + def test_not_a_jwt_false(self): + self.assertFalse(verify_cnf_binding("notajwt", self.cert_pem)) + + def test_two_part_jwt_false(self): + self.assertFalse(verify_cnf_binding("a.b", self.cert_pem)) + + def test_four_part_jwt_false(self): + self.assertFalse(verify_cnf_binding("a.b.c.d", self.cert_pem)) + + def test_malformed_payload_base64_false(self): + self.assertFalse(verify_cnf_binding("header.!!!.sig", self.cert_pem)) + + def test_payload_not_json_false(self): + header = _b64url(b'{"alg":"none"}') + payload = _b64url(b"not-json") + self.assertFalse( + verify_cnf_binding(f"{header}.{payload}.sig", self.cert_pem)) + + def test_payload_with_padding_works(self): + header = base64.urlsafe_b64encode( + b'{"alg":"RS256"}').decode("ascii") + payload = base64.urlsafe_b64encode(json.dumps( + {"cnf": {"x5t#S256": self.thumbprint}}).encode("utf-8") + ).decode("ascii") + self.assertTrue( + verify_cnf_binding(f"{header}.{payload}.sig", self.cert_pem)) + + def test_unicode_in_payload(self): + token = _make_jwt({ + "cnf": {"x5t#S256": self.thumbprint}, "msg": "こんにちは"}) + self.assertTrue(verify_cnf_binding(token, self.cert_pem)) + + +# --------------------------------------------------------------------------- +# Certificate cache +# --------------------------------------------------------------------------- + +class TestCertificateCache(unittest.TestCase): + def setUp(self): + _cert_cache_clear() + + def tearDown(self): + _cert_cache_clear() + + def _make_entry(self, *, not_after=None): + return _CertCacheEntry( + cert_der=b"fake-der", + cert_pem="-----BEGIN CERTIFICATE-----\nfake\n" + "-----END CERTIFICATE-----", + token_endpoint="https://login.microsoftonline.com/t/oauth2/v2.0/token", + client_id="test-client-id", + not_after=not_after or (time.time() + 48 * 3600), + ) + + def test_set_and_get(self): + entry = self._make_entry() + _cert_cache_set("k1", entry) + got = _cert_cache_get("k1") + self.assertIsNotNone(got) + self.assertEqual(got.cert_der, b"fake-der") + self.assertEqual(got.client_id, "test-client-id") + + def test_miss_returns_none(self): + self.assertIsNone(_cert_cache_get("no-such-key")) + + def test_expired_entry_evicted(self): + entry = self._make_entry(not_after=time.time() + 100) + _cert_cache_set("k2", entry) + # Force it to look expired + entry.not_after = time.time() - 1 + self.assertIsNone(_cert_cache_get("k2")) + + def test_insufficient_lifetime_not_cached(self): + # Not enough remaining lifetime (< MIN_REMAINING_LIFETIME_SEC) + from msal.msi_v2 import _CertCacheEntry + half_threshold = _CertCacheEntry.MIN_REMAINING_LIFETIME_SEC // 2 + entry = self._make_entry(not_after=time.time() + half_threshold) + _cert_cache_set("k3", entry) + self.assertIsNone(_cert_cache_get("k3")) + + def test_remove(self): + entry = self._make_entry() + _cert_cache_set("k4", entry) + _cert_cache_remove("k4") + self.assertIsNone(_cert_cache_get("k4")) + + def test_clear(self): + _cert_cache_set("k5", self._make_entry()) + _cert_cache_set("k6", self._make_entry()) + _cert_cache_clear() + self.assertIsNone(_cert_cache_get("k5")) + self.assertIsNone(_cert_cache_get("k6")) + + def test_cache_key_generation(self): + mi_sys = {"ManagedIdentityIdType": "SystemAssigned", "Id": None} + mi_user = {"ManagedIdentityIdType": "ClientId", "Id": "abc"} + mi_obj = {"ManagedIdentityIdType": "ObjectId", "Id": "abc"} + k1 = _cert_cache_key(mi_sys, True) + k2 = _cert_cache_key(mi_sys, False) + k3 = _cert_cache_key(mi_user, True) + k4 = _cert_cache_key(mi_obj, True) + self.assertNotEqual(k1, k2) + self.assertNotEqual(k1, k3) + # Same Id but different IdType must produce different keys + self.assertNotEqual(k3, k4) + self.assertIn("#att=1", k1) + self.assertIn("#att=0", k2) + self.assertIn("ClientId:", k3) + self.assertIn("ObjectId:", k4) + + +# --------------------------------------------------------------------------- +# IMDS wire-contract helpers +# --------------------------------------------------------------------------- + +class TestImdsHelpers(unittest.TestCase): + def test_mi_query_params_system_assigned(self): + p = _mi_query_params( + {"ManagedIdentityIdType": "SystemAssigned", "Id": None}) + self.assertEqual(p["cred-api-version"], "2.0") + self.assertNotIn("client_id", p) + + def test_mi_query_params_client_id(self): + p = _mi_query_params( + {"ManagedIdentityIdType": "ClientId", "Id": "abc"}) + self.assertEqual(p["client_id"], "abc") + + def test_mi_query_params_object_id(self): + p = _mi_query_params( + {"ManagedIdentityIdType": "ObjectId", "Id": "oid"}) + self.assertEqual(p["object_id"], "oid") + + def test_mi_query_params_resource_id(self): + p = _mi_query_params( + {"ManagedIdentityIdType": "ResourceId", "Id": "/sub/..."}) + self.assertEqual(p["msi_res_id"], "/sub/...") + + def test_resource_to_scope_appends_default(self): + self.assertEqual( + _resource_to_scope("https://graph.microsoft.com"), + "https://graph.microsoft.com/.default") + + def test_resource_to_scope_preserves_existing(self): + self.assertEqual( + _resource_to_scope("https://graph.microsoft.com/.default"), + "https://graph.microsoft.com/.default") + + def test_resource_to_scope_strips_trailing_slash(self): + self.assertEqual( + _resource_to_scope("https://graph.microsoft.com/"), + "https://graph.microsoft.com/.default") + + def test_resource_to_scope_raises_on_empty(self): + with self.assertRaises(ValueError): + _resource_to_scope("") + + def test_token_endpoint_prefers_explicit(self): + cred = {"token_endpoint": "https://explicit.com/token", + "mtls_authentication_endpoint": "https://other"} + self.assertEqual( + _token_endpoint_from_credential(cred), + "https://explicit.com/token") + + def test_token_endpoint_falls_back_to_mtls_auth(self): + cred = { + "mtls_authentication_endpoint": "https://login.example.com", + "tenant_id": "tid", + } + self.assertEqual( + _token_endpoint_from_credential(cred), + "https://login.example.com/tid/oauth2/v2.0/token") + + def test_token_endpoint_raises_on_missing(self): + with self.assertRaises(MsiV2Error): + _token_endpoint_from_credential({}) + + +# --------------------------------------------------------------------------- +# ManagedIdentityClient gating +# --------------------------------------------------------------------------- + +class TestManagedIdentityClientStrictGating(unittest.TestCase): + def _make_client(self): + import requests + return msal.ManagedIdentityClient( + msal.SystemAssignedManagedIdentity(), + http_client=requests.Session(), + ) + + def test_error_is_exported(self): + self.assertIs(msal.MsiV2Error, MsiV2Error) + + def test_error_is_subclass(self): + self.assertTrue(issubclass(MsiV2Error, msal.ManagedIdentityError)) + + @patch("msal.managed_identity._obtain_token") + def test_default_path_calls_v1(self, mock_v1): + mock_v1.return_value = { + "access_token": "V1", "expires_in": 3600, "token_type": "Bearer"} + client = self._make_client() + res = client.acquire_token_for_client(resource="R") + self.assertEqual(res["access_token"], "V1") + mock_v1.assert_called_once() + + def test_attestation_requires_pop(self): + client = self._make_client() + with self.assertRaises(msal.ManagedIdentityError): + client.acquire_token_for_client( + resource="R", + mtls_proof_of_possession=False, + with_attestation_support=True) + + @patch("msal.msi_v2.obtain_token") + @patch("msal.managed_identity._obtain_token") + def test_pop_without_attestation_does_not_call_v2( + self, mock_v1, mock_v2): + mock_v1.return_value = { + "access_token": "V1", "expires_in": 3600, "token_type": "Bearer"} + client = self._make_client() + res = client.acquire_token_for_client( + resource="R", + mtls_proof_of_possession=True, + with_attestation_support=False) + self.assertEqual(res["token_type"], "Bearer") + mock_v2.assert_not_called() + mock_v1.assert_called_once() + + @patch("msal.managed_identity.create_attestation_provider", + create=True) + @patch("msal.msi_v2.obtain_token") + def test_v2_called_when_both_flags_true(self, mock_v2, _): + mock_v2.return_value = { + "access_token": "V2", "expires_in": 3600, + "token_type": "mtls_pop"} + client = self._make_client() + + with patch.dict("sys.modules", { + "msal_key_attestation": MagicMock( + create_attestation_provider=MagicMock( + return_value=lambda ep, kh, ci, ck="": "fake.jwt")) + }): + res = client.acquire_token_for_client( + resource="https://mtlstb.graph.microsoft.com", + mtls_proof_of_possession=True, + with_attestation_support=True) + + self.assertEqual(res["token_type"], "mtls_pop") + mock_v2.assert_called_once() + args, kwargs = mock_v2.call_args + self.assertTrue(len(args) >= 3) + self.assertEqual(args[2], "https://mtlstb.graph.microsoft.com") + self.assertTrue(kwargs["attestation_enabled"]) + + @patch("msal.msi_v2.obtain_token", side_effect=MsiV2Error("boom")) + @patch("msal.managed_identity._obtain_token") + def test_strict_v2_failure_raises_no_v1_fallback( + self, mock_v1, mock_v2): + client = self._make_client() + with patch.dict("sys.modules", { + "msal_key_attestation": MagicMock( + create_attestation_provider=MagicMock( + return_value=lambda ep, kh, ci, ck="": "fake.jwt")) + }): + with self.assertRaises(MsiV2Error): + client.acquire_token_for_client( + resource="https://mtlstb.graph.microsoft.com", + mtls_proof_of_possession=True, + with_attestation_support=True) + mock_v1.assert_not_called() + + @patch("msal.msi_v2.obtain_token", + side_effect=RuntimeError("DLL load failed")) + @patch("msal.managed_identity._obtain_token") + def test_runtime_error_wrapped_as_msi_v2_error( + self, mock_v1, mock_v2): + """RuntimeError from provider/DLL must surface as MsiV2Error.""" + client = self._make_client() + with patch.dict("sys.modules", { + "msal_key_attestation": MagicMock( + create_attestation_provider=MagicMock( + return_value=lambda ep, kh, ci, ck="": "fake.jwt")) + }): + with self.assertRaises(MsiV2Error) as ctx: + client.acquire_token_for_client( + resource="R", + mtls_proof_of_possession=True, + with_attestation_support=True) + self.assertIn("DLL load failed", str(ctx.exception)) + mock_v1.assert_not_called() + + def test_missing_attestation_package_raises_clear_error(self): + client = self._make_client() + with patch.dict("sys.modules", {"msal_key_attestation": None}): + with self.assertRaises(MsiV2Error) as ctx: + client.acquire_token_for_client( + resource="R", + mtls_proof_of_possession=True, + with_attestation_support=True) + self.assertIn("pip install msal-key-attestation", + str(ctx.exception)) + + +# --------------------------------------------------------------------------- +# DER helpers +# --------------------------------------------------------------------------- + +class TestDerHelpers(unittest.TestCase): + def test_der_to_pem_roundtrip(self): + key = rsa.generate_private_key( + public_exponent=65537, key_size=2048) + cert_pem = _make_self_signed_cert(key, "der-test") + cert = x509.load_pem_x509_certificate( + cert_pem.encode("utf-8"), default_backend()) + cert_der = cert.public_bytes(serialization.Encoding.DER) + + pem_out = _der_to_pem(cert_der) + self.assertIn("-----BEGIN CERTIFICATE-----", pem_out) + self.assertIn("-----END CERTIFICATE-----", pem_out) + + # Verify the PEM round-trips back to same DER + cert2 = x509.load_pem_x509_certificate( + pem_out.encode("utf-8"), default_backend()) + self.assertEqual( + cert2.public_bytes(serialization.Encoding.DER), cert_der) + + +# --------------------------------------------------------------------------- +# Tests with real ManagedIdentity objects (not plain dicts) +# --------------------------------------------------------------------------- + +class TestRealManagedIdentityObjects(unittest.TestCase): + """Verify that helpers work with actual ManagedIdentity (UserDict) instances, + not just plain dicts — fixing the isinstance(dict) bug.""" + + def test_mi_query_params_system_assigned_obj(self): + mi = msal.SystemAssignedManagedIdentity() + p = _mi_query_params(mi) + self.assertEqual(p["cred-api-version"], "2.0") + self.assertNotIn("client_id", p) + self.assertNotIn("object_id", p) + self.assertNotIn("msi_res_id", p) + + def test_mi_query_params_client_id_obj(self): + mi = msal.UserAssignedManagedIdentity(client_id="abc-123") + p = _mi_query_params(mi) + self.assertEqual(p["client_id"], "abc-123") + self.assertNotIn("object_id", p) + self.assertNotIn("msi_res_id", p) + + def test_mi_query_params_object_id_obj(self): + mi = msal.UserAssignedManagedIdentity(object_id="oid-456") + p = _mi_query_params(mi) + self.assertEqual(p["object_id"], "oid-456") + self.assertNotIn("client_id", p) + self.assertNotIn("msi_res_id", p) + + def test_mi_query_params_resource_id_obj(self): + mi = msal.UserAssignedManagedIdentity( + resource_id="/subscriptions/sub/resourceGroups/rg/providers/...") + p = _mi_query_params(mi) + self.assertEqual( + p["msi_res_id"], + "/subscriptions/sub/resourceGroups/rg/providers/...") + self.assertNotIn("client_id", p) + self.assertNotIn("object_id", p) + + def test_cert_cache_key_system_assigned_obj(self): + mi = msal.SystemAssignedManagedIdentity() + key = _cert_cache_key(mi, attested=True) + self.assertIn("SYSTEM_ASSIGNED", key) + self.assertIn("#att=1", key) + + def test_cert_cache_key_client_id_obj(self): + mi = msal.UserAssignedManagedIdentity(client_id="abc-123") + key = _cert_cache_key(mi, attested=False) + self.assertIn("ClientId", key) + self.assertIn("abc-123", key) + self.assertIn("#att=0", key) + + def test_cert_cache_key_object_id_obj(self): + mi = msal.UserAssignedManagedIdentity(object_id="oid-456") + key = _cert_cache_key(mi, attested=True) + self.assertIn("ObjectId", key) + self.assertIn("oid-456", key) + self.assertIn("#att=1", key) + + def test_cert_cache_key_attested_vs_nonattested_differ(self): + mi = msal.SystemAssignedManagedIdentity() + key_att = _cert_cache_key(mi, attested=True) + key_noatt = _cert_cache_key(mi, attested=False) + self.assertNotEqual(key_att, key_noatt) + + +if __name__ == "__main__": + unittest.main()