diff --git a/core/src/stackit/core/auth_methods/key_auth.py b/core/src/stackit/core/auth_methods/key_auth.py index 575e76ac..8cd13547 100644 --- a/core/src/stackit/core/auth_methods/key_auth.py +++ b/core/src/stackit/core/auth_methods/key_auth.py @@ -42,6 +42,7 @@ class KeyAuth(AuthBase): DEFAULT_TOKEN_ENDPOINT = "https://service-account.api.stackit.cloud/token" # noqa S105 false positive TOKEN_EXPIRY_CHECK_INTERVAL = timedelta(seconds=60) EXPIRATION_LEEWAY = timedelta(minutes=5) + MAX_REFRESH_RETRIES = 3 timeout: Optional[int] = 30 initial_token: Optional[str] @@ -73,6 +74,8 @@ def __init__( def __call__(self, r: Request) -> Request: with self.lock: + if self.refresh_future is not None and self.refresh_future.done(): + self.refresh_future.result() if self.__is_token_expired(self.access_token): if self.refresh_future is None or self.refresh_future.done(): self.refresh_future = self.executor.submit(self.__refresh_token) @@ -108,13 +111,15 @@ def __fetch_token_from_endpoint(self) -> None: self.access_token = response_json["access_token"] self.refresh_token = response_json["refresh_token"] except requests.RequestException as e: - print(f"Initial token fetch failed: {e}") + raise requests.RequestException("Initial token fetch failed") from e def __start_token_refresh_task(self): def token_refresh_task(): while True: time.sleep(self.TOKEN_EXPIRY_CHECK_INTERVAL.total_seconds()) with self.lock: + if self.refresh_future is not None and self.refresh_future.done(): + self.refresh_future.result() if self.__is_token_expired(self.access_token) and ( self.refresh_future is None or self.refresh_future.done() ): @@ -135,16 +140,19 @@ def __refresh_token(self): "refresh_token": self.refresh_token, } - try: - response = requests.post(self.token_endpoint, data=body, timeout=self.timeout) - response.raise_for_status() - response_data = response.json() - new_token = response_data.get("access_token") - # with self.lock: - self.access_token = new_token - print("Token successfully refreshed!") - except requests.RequestException as e: - print(f"Token refresh failed: {e}") + last_exception = None + for _ in range(self.MAX_REFRESH_RETRIES): + try: + response = requests.post(self.token_endpoint, data=body, timeout=self.timeout) + response.raise_for_status() + response_data = response.json() + new_token = response_data.get("access_token") + self.access_token = new_token + return + except requests.RequestException as e: + last_exception = e + + raise requests.RequestException("Token refresh failed after retries") from last_exception def __is_token_expired(self, token: str) -> bool: try: diff --git a/core/tests/core/test_auth.py b/core/tests/core/test_auth.py index e6866f71..6557bdcd 100644 --- a/core/tests/core/test_auth.py +++ b/core/tests/core/test_auth.py @@ -2,11 +2,13 @@ import pytest import json +import jwt +import requests from unittest.mock import patch, mock_open, Mock from requests.auth import HTTPBasicAuth -from stackit.core.auth_methods.key_auth import KeyAuth +from stackit.core.auth_methods.key_auth import KeyAuth, ServiceAccountKey from stackit.core.auth_methods.token_auth import TokenAuth from stackit.core.authorization import Authorization from stackit.core.configuration import Configuration @@ -262,3 +264,32 @@ def test_service_account_keyfile_not_found_raises_exception(self): config = Configuration(service_account_key_path="/non/existent/path/to/file") with pytest.raises(FileNotFoundError): Authorization(config) + + def test_token_refresh_fails_after_retries(self, service_account_key_file_json): + service_account_key = ServiceAccountKey.model_validate_json(service_account_key_file_json) + service_account_key.credentials.private_key = "test-private-key" + + def set_initial_token(auth): + auth.initial_token = "test-initial-token" + + with patch.object(KeyAuth, "_KeyAuth__create_initial_token", new=set_initial_token), patch.object( + KeyAuth, "_KeyAuth__start_token_refresh_task", return_value=None + ), patch("requests.post") as mock_post: + init_response = Mock() + init_response.raise_for_status.return_value = None + init_response.json.return_value = { + "access_token": jwt.encode({"exp": 4102444800}, "secret", algorithm="HS256"), + "refresh_token": jwt.encode({"exp": 4102444800}, "secret", algorithm="HS256"), + } + mock_post.return_value = init_response + + auth = KeyAuth(service_account_key) + auth.refresh_token = jwt.encode({"exp": 4102444800}, "secret", algorithm="HS256") + + mock_post.reset_mock() + mock_post.side_effect = requests.RequestException("refresh failed") + + with pytest.raises(requests.RequestException): + auth._KeyAuth__refresh_token() + + assert mock_post.call_count == KeyAuth.MAX_REFRESH_RETRIES