-
Notifications
You must be signed in to change notification settings - Fork 4
feat(sql_execution): Implement retry logic for userpod API #89
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
39baf42
aa99cc3
01be651
790be77
1f804b1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -592,9 +592,9 @@ def test_all_dataframes_serialize_to_parquet(self, key, df): | |
| class TestFederatedAuth(unittest.TestCase): | ||
| @mock.patch("deepnote_toolkit.sql.sql_execution.get_project_auth_headers") | ||
| @mock.patch("deepnote_toolkit.sql.sql_execution.get_absolute_userpod_api_url") | ||
| @mock.patch("deepnote_toolkit.sql.sql_execution.requests.post") | ||
| @mock.patch("deepnote_toolkit.sql.sql_execution._create_retry_session") | ||
| def test_get_federated_auth_credentials_returns_validated_response( | ||
| self, mock_post, mock_get_url, mock_get_headers | ||
| self, mock_create_session, mock_get_url, mock_get_headers | ||
| ): | ||
| """Test that _get_federated_auth_credentials properly validates and returns response data.""" | ||
| from deepnote_toolkit.sql.sql_execution import _get_federated_auth_credentials | ||
|
|
@@ -603,12 +603,14 @@ def test_get_federated_auth_credentials_returns_validated_response( | |
| mock_get_url.return_value = "https://api.example.com/integrations/federated-auth-token/test-integration-id" | ||
| mock_get_headers.return_value = {"Authorization": "Bearer project-token"} | ||
|
|
||
| mock_session = mock.Mock() | ||
| mock_response = mock.Mock() | ||
| mock_response.json.return_value = { | ||
| "integrationType": "trino", | ||
| "accessToken": "test-access-token-123", | ||
| } | ||
| mock_post.return_value = mock_response | ||
| mock_session.post.return_value = mock_response | ||
| mock_create_session.return_value = mock_session | ||
|
|
||
| # Call the function | ||
| result = _get_federated_auth_credentials( | ||
|
|
@@ -621,7 +623,7 @@ def test_get_federated_auth_credentials_returns_validated_response( | |
| ) | ||
|
|
||
| # Verify headers include both project auth and user pod auth context token | ||
| mock_post.assert_called_once_with( | ||
| mock_session.post.assert_called_once_with( | ||
| "https://api.example.com/integrations/federated-auth-token/test-integration-id", | ||
| timeout=10, | ||
| headers={ | ||
|
|
@@ -1019,3 +1021,241 @@ def test_databricks_connector_dialect_alias_is_registered(self): | |
|
|
||
| self.assertEqual(url.drivername, "databricks+connector") | ||
| self.assertIsNotNone(dialect_cls) | ||
|
|
||
|
|
||
| class TestCreateRetrySession(unittest.TestCase): | ||
| """Tests that exercise the real urllib3 retry loop by mocking at the | ||
| connection level (``HTTPConnectionPool._make_request``) rather than | ||
| replacing ``_create_retry_session``. This lets the ``Retry`` adapter | ||
| actually fire retries on 5xx responses. | ||
| """ | ||
|
|
||
| def test_create_retry_session_configuration(self): | ||
| """Verify the retry session is wired with the expected parameters.""" | ||
| from deepnote_toolkit.sql.sql_execution import _create_retry_session | ||
|
|
||
| session = _create_retry_session() | ||
|
|
||
| for prefix in ("http://", "https://"): | ||
| adapter = session.get_adapter(f"{prefix}example.com") | ||
| retry = adapter.max_retries | ||
|
|
||
| self.assertEqual(retry.total, 3) | ||
| self.assertEqual(retry.backoff_factor, 0.5) | ||
| self.assertEqual(set(retry.status_forcelist), {500, 502, 503, 504}) | ||
| self.assertIn("POST", retry.allowed_methods) | ||
|
|
||
| # -- _generate_temporary_credentials ------------------------------------ | ||
|
|
||
| @mock.patch("urllib3.util.retry.Retry.sleep", return_value=None) | ||
| @mock.patch("urllib3.connectionpool.HTTPConnectionPool._make_request") | ||
| @mock.patch("deepnote_toolkit.sql.sql_execution.get_project_auth_headers") | ||
| @mock.patch("deepnote_toolkit.sql.sql_execution.get_absolute_userpod_api_url") | ||
| def test_generate_credentials_retries_on_5xx_then_succeeds( | ||
| self, | ||
| mock_get_url, | ||
| mock_get_headers, | ||
| mock_make_request, | ||
| mock_retry_sleep, | ||
| ): | ||
| """Two 5xx failures followed by a 200 - the retry loop must | ||
| transparently retry and ultimately return valid credentials.""" | ||
| from urllib3 import HTTPResponse as Urllib3Response | ||
|
|
||
| from deepnote_toolkit.sql.sql_execution import _generate_temporary_credentials | ||
|
|
||
| mock_get_url.return_value = ( | ||
| "https://api.example.com/integrations/credentials/test-id" | ||
| ) | ||
| mock_get_headers.return_value = {"Authorization": "Bearer token"} | ||
|
|
||
| success_body = json.dumps({"username": "user", "password": "pass"}).encode() | ||
| mock_make_request.side_effect = [ | ||
| Urllib3Response( | ||
| body=io.BytesIO(b"Internal Server Error"), | ||
| status=500, | ||
| headers={}, | ||
| preload_content=False, | ||
| ), | ||
| Urllib3Response( | ||
| body=io.BytesIO(b"Bad Gateway"), | ||
| status=502, | ||
| headers={}, | ||
| preload_content=False, | ||
| ), | ||
| Urllib3Response( | ||
| body=io.BytesIO(success_body), | ||
| status=200, | ||
| headers={"Content-Type": "application/json"}, | ||
| preload_content=False, | ||
| ), | ||
| ] | ||
|
|
||
| result = _generate_temporary_credentials("test-id") | ||
|
|
||
| self.assertEqual(result, ("user", "pass")) | ||
| self.assertEqual(mock_make_request.call_count, 3) | ||
| self.assertEqual(mock_retry_sleep.call_count, 2) | ||
|
|
||
| @mock.patch("urllib3.util.retry.Retry.sleep", return_value=None) | ||
| @mock.patch("urllib3.connectionpool.HTTPConnectionPool._make_request") | ||
| @mock.patch("deepnote_toolkit.sql.sql_execution.get_project_auth_headers") | ||
| @mock.patch("deepnote_toolkit.sql.sql_execution.get_absolute_userpod_api_url") | ||
| def test_generate_credentials_exhausts_retries_on_persistent_5xx( | ||
| self, | ||
| mock_get_url, | ||
| mock_get_headers, | ||
| mock_make_request, | ||
| mock_retry_sleep, | ||
| ): | ||
| """All 4 attempts (1 original + 3 retries) return 500 - | ||
| must raise ``RetryError``.""" | ||
| import requests | ||
| from urllib3 import HTTPResponse as Urllib3Response | ||
|
|
||
| from deepnote_toolkit.sql.sql_execution import _generate_temporary_credentials | ||
|
|
||
| mock_get_url.return_value = ( | ||
| "https://api.example.com/integrations/credentials/test-id" | ||
| ) | ||
| mock_get_headers.return_value = {"Authorization": "Bearer token"} | ||
|
|
||
| mock_make_request.side_effect = [ | ||
| Urllib3Response( | ||
| body=io.BytesIO(b"Server Error"), | ||
| status=500, | ||
| headers={}, | ||
| preload_content=False, | ||
| ) | ||
| for _ in range(4) | ||
| ] | ||
|
|
||
| with self.assertRaises(requests.exceptions.RetryError): | ||
| _generate_temporary_credentials("test-id") | ||
|
|
||
| self.assertEqual(mock_make_request.call_count, 4) | ||
| self.assertEqual(mock_retry_sleep.call_count, 3) | ||
|
|
||
| @mock.patch("urllib3.util.retry.Retry.sleep", return_value=None) | ||
| @mock.patch("urllib3.connectionpool.HTTPConnectionPool._make_request") | ||
| @mock.patch("deepnote_toolkit.sql.sql_execution.get_project_auth_headers") | ||
| @mock.patch("deepnote_toolkit.sql.sql_execution.get_absolute_userpod_api_url") | ||
| def test_generate_credentials_no_retry_on_4xx( | ||
| self, | ||
| mock_get_url, | ||
| mock_get_headers, | ||
| mock_make_request, | ||
| mock_retry_sleep, | ||
| ): | ||
| """A 400 is not in the retry status list - must fail immediately | ||
| without retrying.""" | ||
| import requests | ||
| from urllib3 import HTTPResponse as Urllib3Response | ||
|
|
||
| from deepnote_toolkit.sql.sql_execution import _generate_temporary_credentials | ||
|
|
||
| mock_get_url.return_value = ( | ||
| "https://api.example.com/integrations/credentials/test-id" | ||
| ) | ||
| mock_get_headers.return_value = {"Authorization": "Bearer token"} | ||
|
|
||
| mock_make_request.side_effect = [ | ||
| Urllib3Response( | ||
| body=io.BytesIO(b"Bad Request"), | ||
| status=400, | ||
| headers={}, | ||
| preload_content=False, | ||
| ), | ||
| ] | ||
|
|
||
| with self.assertRaises(requests.exceptions.HTTPError): | ||
| _generate_temporary_credentials("test-id") | ||
|
|
||
| self.assertEqual(mock_make_request.call_count, 1) | ||
| mock_retry_sleep.assert_not_called() | ||
|
|
||
| # -- _get_federated_auth_credentials ------------------------------------ | ||
|
|
||
| @mock.patch("urllib3.util.retry.Retry.sleep", return_value=None) | ||
| @mock.patch("urllib3.connectionpool.HTTPConnectionPool._make_request") | ||
| @mock.patch("deepnote_toolkit.sql.sql_execution.get_project_auth_headers") | ||
| @mock.patch("deepnote_toolkit.sql.sql_execution.get_absolute_userpod_api_url") | ||
| def test_federated_auth_retries_on_5xx_then_succeeds( | ||
| self, | ||
| mock_get_url, | ||
| mock_get_headers, | ||
| mock_make_request, | ||
| mock_retry_sleep, | ||
| ): | ||
| """A 503 followed by a 200 - retry loop must recover and return | ||
| valid ``FederatedAuthResponseData``.""" | ||
| from urllib3 import HTTPResponse as Urllib3Response | ||
|
|
||
| from deepnote_toolkit.sql.sql_execution import _get_federated_auth_credentials | ||
|
|
||
| mock_get_url.return_value = ( | ||
| "https://api.example.com/integrations/federated-auth-token/test-id" | ||
| ) | ||
| mock_get_headers.return_value = {"Authorization": "Bearer token"} | ||
|
|
||
| success_body = json.dumps( | ||
| {"integrationType": "trino", "accessToken": "test-token"} | ||
| ).encode() | ||
| mock_make_request.side_effect = [ | ||
| Urllib3Response( | ||
| body=io.BytesIO(b"Service Unavailable"), | ||
| status=503, | ||
| headers={}, | ||
| preload_content=False, | ||
| ), | ||
| Urllib3Response( | ||
| body=io.BytesIO(success_body), | ||
| status=200, | ||
| headers={"Content-Type": "application/json"}, | ||
| preload_content=False, | ||
| ), | ||
| ] | ||
|
|
||
| result = _get_federated_auth_credentials("test-id", "auth-context-token") | ||
|
|
||
| self.assertEqual(result.integrationType, "trino") | ||
| self.assertEqual(result.accessToken, "test-token") | ||
| self.assertEqual(mock_make_request.call_count, 2) | ||
| self.assertEqual(mock_retry_sleep.call_count, 1) | ||
|
|
||
| @mock.patch("urllib3.util.retry.Retry.sleep", return_value=None) | ||
| @mock.patch("urllib3.connectionpool.HTTPConnectionPool._make_request") | ||
| @mock.patch("deepnote_toolkit.sql.sql_execution.get_project_auth_headers") | ||
| @mock.patch("deepnote_toolkit.sql.sql_execution.get_absolute_userpod_api_url") | ||
| def test_federated_auth_exhausts_retries_on_persistent_5xx( | ||
| self, | ||
| mock_get_url, | ||
| mock_get_headers, | ||
| mock_make_request, | ||
| mock_retry_sleep, | ||
| ): | ||
| """All 4 attempts return 504 - must raise ``RetryError``.""" | ||
| import requests | ||
| from urllib3 import HTTPResponse as Urllib3Response | ||
|
|
||
| from deepnote_toolkit.sql.sql_execution import _get_federated_auth_credentials | ||
|
|
||
| mock_get_url.return_value = ( | ||
| "https://api.example.com/integrations/federated-auth-token/test-id" | ||
| ) | ||
coderabbitai[bot] marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| mock_get_headers.return_value = {"Authorization": "Bearer token"} | ||
|
|
||
| mock_make_request.side_effect = [ | ||
| Urllib3Response( | ||
| body=io.BytesIO(b"Gateway Timeout"), | ||
| status=504, | ||
| headers={}, | ||
| preload_content=False, | ||
| ) | ||
| for _ in range(4) | ||
| ] | ||
|
|
||
| with self.assertRaises(requests.exceptions.RetryError): | ||
| _get_federated_auth_credentials("test-id", "auth-context-token") | ||
|
|
||
| self.assertEqual(mock_make_request.call_count, 4) | ||
|
Comment on lines
+1226
to
+1261
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Missing assertion on
🔧 Proposed fix with self.assertRaises(requests.exceptions.RetryError):
_get_federated_auth_credentials("test-id", "auth-context-token")
self.assertEqual(mock_make_request.call_count, 4)
+ self.assertEqual(mock_retry_sleep.call_count, 3)🧰 Tools🪛 Ruff (0.15.9)[warning] 1235-1235: Unused method argument: (ARG002) [warning] 1258-1258: Use Replace (PT027) [warning] 1261-1261: Use a regular Replace (PT009) 🤖 Prompt for AI Agents |
||
Uh oh!
There was an error while loading. Please reload this page.