|
12 | 12 | from collections.abc import AsyncGenerator, Awaitable, Callable |
13 | 13 | from dataclasses import dataclass, field |
14 | 14 | from typing import Any, Protocol |
15 | | -from urllib.parse import quote, urlencode, urljoin, urlparse |
| 15 | +from urllib.parse import parse_qsl, quote, urlencode, urljoin, urlparse, urlunparse |
16 | 16 |
|
17 | 17 | import anyio |
18 | 18 | import httpx |
@@ -69,6 +69,85 @@ def generate(cls) -> "PKCEParameters": |
69 | 69 | return cls(code_verifier=code_verifier, code_challenge=code_challenge) |
70 | 70 |
|
71 | 71 |
|
| 72 | +@dataclass(frozen=True) |
| 73 | +class OAuthAuthorizationRedirect: |
| 74 | + """Resumable OAuth authorization redirect state. |
| 75 | +
|
| 76 | + Proxy and server-side callers can persist this value, send the authorization |
| 77 | + URL to a user, and later resume token exchange with the returned code plus |
| 78 | + the stored state and code verifier. |
| 79 | + """ |
| 80 | + |
| 81 | + authorization_url: str |
| 82 | + state: str |
| 83 | + code_verifier: str = field(repr=False) |
| 84 | + |
| 85 | + |
| 86 | +def build_authorization_redirect( |
| 87 | + *, |
| 88 | + authorization_endpoint: str, |
| 89 | + client_info: OAuthClientInformationFull, |
| 90 | + client_metadata: OAuthClientMetadata, |
| 91 | + pkce_params: PKCEParameters | None = None, |
| 92 | + state: str | None = None, |
| 93 | + resource_url: str | None = None, |
| 94 | +) -> OAuthAuthorizationRedirect: |
| 95 | + """Build an OAuth authorization URL and resumable state. |
| 96 | +
|
| 97 | + Args: |
| 98 | + authorization_endpoint: Authorization endpoint URL. |
| 99 | + client_info: Registered OAuth client information. |
| 100 | + client_metadata: Client metadata containing redirect URIs and scopes. |
| 101 | + pkce_params: Optional PKCE parameters. Generated when omitted. |
| 102 | + state: Optional OAuth state value. Generated when omitted. |
| 103 | + resource_url: Optional RFC 8707 resource value. |
| 104 | +
|
| 105 | + Returns: |
| 106 | + Authorization URL plus the state and code verifier needed to resume. |
| 107 | +
|
| 108 | + Raises: |
| 109 | + OAuthFlowError: If no client ID or redirect URI is available. |
| 110 | + """ |
| 111 | + if client_info.client_id is None: |
| 112 | + raise OAuthFlowError("No client ID provided for authorization code grant") |
| 113 | + |
| 114 | + if client_metadata.redirect_uris is None: |
| 115 | + raise OAuthFlowError("No redirect URIs provided for authorization code grant") |
| 116 | + |
| 117 | + pkce_params = pkce_params or PKCEParameters.generate() |
| 118 | + state = state or secrets.token_urlsafe(32) |
| 119 | + |
| 120 | + auth_params = { |
| 121 | + "response_type": "code", |
| 122 | + "client_id": client_info.client_id, |
| 123 | + "redirect_uri": str(client_metadata.redirect_uris[0]), |
| 124 | + "state": state, |
| 125 | + "code_challenge": pkce_params.code_challenge, |
| 126 | + "code_challenge_method": "S256", |
| 127 | + } |
| 128 | + |
| 129 | + if resource_url: |
| 130 | + auth_params["resource"] = resource_url |
| 131 | + |
| 132 | + if client_metadata.scope: |
| 133 | + auth_params["scope"] = client_metadata.scope |
| 134 | + |
| 135 | + # OIDC requires prompt=consent when offline_access is requested |
| 136 | + # https://openid.net/specs/openid-connect-core-1_0.html#OfflineAccess |
| 137 | + if "offline_access" in client_metadata.scope.split(): |
| 138 | + auth_params["prompt"] = "consent" |
| 139 | + |
| 140 | + parsed_endpoint = urlparse(authorization_endpoint) |
| 141 | + query_params = parse_qsl(parsed_endpoint.query, keep_blank_values=True) |
| 142 | + query_params.extend(auth_params.items()) |
| 143 | + authorization_url = urlunparse(parsed_endpoint._replace(query=urlencode(query_params))) |
| 144 | + return OAuthAuthorizationRedirect( |
| 145 | + authorization_url=authorization_url, |
| 146 | + state=state, |
| 147 | + code_verifier=pkce_params.code_verifier, |
| 148 | + ) |
| 149 | + |
| 150 | + |
72 | 151 | class TokenStorage(Protocol): |
73 | 152 | """Protocol for token storage implementations.""" |
74 | 153 |
|
@@ -327,45 +406,29 @@ async def _perform_authorization_code_grant(self) -> tuple[str, str]: |
327 | 406 | if not self.context.client_info: |
328 | 407 | raise OAuthFlowError("No client info available for authorization") # pragma: no cover |
329 | 408 |
|
330 | | - # Generate PKCE parameters |
331 | | - pkce_params = PKCEParameters.generate() |
332 | | - state = secrets.token_urlsafe(32) |
333 | | - |
334 | | - auth_params = { |
335 | | - "response_type": "code", |
336 | | - "client_id": self.context.client_info.client_id, |
337 | | - "redirect_uri": str(self.context.client_metadata.redirect_uris[0]), |
338 | | - "state": state, |
339 | | - "code_challenge": pkce_params.code_challenge, |
340 | | - "code_challenge_method": "S256", |
341 | | - } |
342 | | - |
343 | | - # Only include resource param if conditions are met |
| 409 | + resource_url = None |
344 | 410 | if self.context.should_include_resource_param(self.context.protocol_version): |
345 | | - auth_params["resource"] = self.context.get_resource_url() # RFC 8707 |
| 411 | + resource_url = self.context.get_resource_url() # RFC 8707 |
346 | 412 |
|
347 | | - if self.context.client_metadata.scope: # pragma: no branch |
348 | | - auth_params["scope"] = self.context.client_metadata.scope |
349 | | - |
350 | | - # OIDC requires prompt=consent when offline_access is requested |
351 | | - # https://openid.net/specs/openid-connect-core-1_0.html#OfflineAccess |
352 | | - if "offline_access" in self.context.client_metadata.scope.split(): |
353 | | - auth_params["prompt"] = "consent" |
354 | | - |
355 | | - authorization_url = f"{auth_endpoint}?{urlencode(auth_params)}" |
356 | | - await self.context.redirect_handler(authorization_url) |
| 413 | + redirect = build_authorization_redirect( |
| 414 | + authorization_endpoint=auth_endpoint, |
| 415 | + client_info=self.context.client_info, |
| 416 | + client_metadata=self.context.client_metadata, |
| 417 | + resource_url=resource_url, |
| 418 | + ) |
| 419 | + await self.context.redirect_handler(redirect.authorization_url) |
357 | 420 |
|
358 | 421 | # Wait for callback |
359 | 422 | auth_code, returned_state = await self.context.callback_handler() |
360 | 423 |
|
361 | | - if returned_state is None or not secrets.compare_digest(returned_state, state): |
362 | | - raise OAuthFlowError(f"State parameter mismatch: {returned_state} != {state}") |
| 424 | + if returned_state is None or not secrets.compare_digest(returned_state, redirect.state): |
| 425 | + raise OAuthFlowError(f"State parameter mismatch: {returned_state} != {redirect.state}") |
363 | 426 |
|
364 | 427 | if not auth_code: |
365 | 428 | raise OAuthFlowError("No authorization code received") |
366 | 429 |
|
367 | 430 | # Return auth code and code verifier for token exchange |
368 | | - return auth_code, pkce_params.code_verifier |
| 431 | + return auth_code, redirect.code_verifier |
369 | 432 |
|
370 | 433 | def _get_token_endpoint(self) -> str: |
371 | 434 | if self.context.oauth_metadata and self.context.oauth_metadata.token_endpoint: |
|
0 commit comments