diff --git a/src/databricks/sql/auth/auth_utils.py b/src/databricks/sql/auth/auth_utils.py index 439aabc51..a21ce843b 100644 --- a/src/databricks/sql/auth/auth_utils.py +++ b/src/databricks/sql/auth/auth_utils.py @@ -7,23 +7,6 @@ logger = logging.getLogger(__name__) -def parse_hostname(hostname: str) -> str: - """ - Normalize the hostname to include scheme and trailing slash. - - Args: - hostname: The hostname to normalize - - Returns: - Normalized hostname with scheme and trailing slash - """ - if not hostname.startswith("http://") and not hostname.startswith("https://"): - hostname = f"https://{hostname}" - if not hostname.endswith("/"): - hostname = f"{hostname}/" - return hostname - - def decode_token(access_token: str) -> Optional[Dict]: """ Decode a JWT token without verification to extract claims. diff --git a/src/databricks/sql/auth/token_federation.py b/src/databricks/sql/auth/token_federation.py index 7b62f6762..f75b904fb 100644 --- a/src/databricks/sql/auth/token_federation.py +++ b/src/databricks/sql/auth/token_federation.py @@ -6,10 +6,10 @@ from databricks.sql.auth.authenticators import AuthProvider from databricks.sql.auth.auth_utils import ( - parse_hostname, decode_token, is_same_host, ) +from databricks.sql.common.url_utils import normalize_host_with_protocol from databricks.sql.common.http import HttpMethod logger = logging.getLogger(__name__) @@ -99,7 +99,7 @@ def __init__( if not http_client: raise ValueError("http_client is required for TokenFederationProvider") - self.hostname = parse_hostname(hostname) + self.hostname = normalize_host_with_protocol(hostname) self.external_provider = external_provider self.http_client = http_client self.identity_federation_client_id = identity_federation_client_id @@ -164,7 +164,7 @@ def _should_exchange_token(self, access_token: str) -> bool: def _exchange_token(self, access_token: str) -> Token: """Exchange the external token for a Databricks token.""" - token_url = f"{self.hostname.rstrip('/')}{self.TOKEN_EXCHANGE_ENDPOINT}" + token_url = f"{self.hostname}{self.TOKEN_EXCHANGE_ENDPOINT}" data = { "grant_type": self.TOKEN_EXCHANGE_GRANT_TYPE, diff --git a/src/databricks/sql/backend/sea/utils/http_client.py b/src/databricks/sql/backend/sea/utils/http_client.py index b47f2add2..caefe9929 100644 --- a/src/databricks/sql/backend/sea/utils/http_client.py +++ b/src/databricks/sql/backend/sea/utils/http_client.py @@ -18,6 +18,7 @@ from databricks.sql.common.http_utils import ( detect_and_parse_proxy, ) +from databricks.sql.common.url_utils import normalize_host_with_protocol logger = logging.getLogger(__name__) @@ -66,8 +67,9 @@ def __init__( self.auth_provider = auth_provider self.ssl_options = ssl_options - # Build base URL - self.base_url = f"https://{server_hostname}:{self.port}" + # Build base URL using url_utils for consistent normalization + normalized_host = normalize_host_with_protocol(server_hostname) + self.base_url = f"{normalized_host}:{self.port}" # Parse URL for proxy handling parsed_url = urllib.parse.urlparse(self.base_url) diff --git a/src/databricks/sql/common/feature_flag.py b/src/databricks/sql/common/feature_flag.py index 032701f63..36e4b8a02 100644 --- a/src/databricks/sql/common/feature_flag.py +++ b/src/databricks/sql/common/feature_flag.py @@ -6,6 +6,7 @@ from typing import Dict, Optional, List, Any, TYPE_CHECKING from databricks.sql.common.http import HttpMethod +from databricks.sql.common.url_utils import normalize_host_with_protocol if TYPE_CHECKING: from databricks.sql.client import Connection @@ -67,7 +68,8 @@ def __init__( endpoint_suffix = FEATURE_FLAGS_ENDPOINT_SUFFIX_FORMAT.format(__version__) self._feature_flag_endpoint = ( - f"https://{self._connection.session.host}{endpoint_suffix}" + normalize_host_with_protocol(self._connection.session.host) + + endpoint_suffix ) # Use the provided HTTP client diff --git a/src/databricks/sql/common/url_utils.py b/src/databricks/sql/common/url_utils.py new file mode 100644 index 000000000..1c4d10369 --- /dev/null +++ b/src/databricks/sql/common/url_utils.py @@ -0,0 +1,45 @@ +""" +URL utility functions for the Databricks SQL connector. +""" + + +def normalize_host_with_protocol(host: str) -> str: + """ + Normalize a connection hostname by ensuring it has a protocol. + + This is useful for handling cases where users may provide hostnames with or without protocols + (common with dbt-databricks users copying URLs from their browser). + + Args: + host: Connection hostname which may or may not include a protocol prefix (https:// or http://) + and may or may not have a trailing slash + + Returns: + Normalized hostname with protocol prefix and no trailing slashes + + Examples: + normalize_host_with_protocol("myserver.com") -> "https://myserver.com" + normalize_host_with_protocol("https://myserver.com") -> "https://myserver.com" + normalize_host_with_protocol("HTTPS://myserver.com/") -> "https://myserver.com" + normalize_host_with_protocol("http://localhost:8080/") -> "http://localhost:8080" + + Raises: + ValueError: If host is None or empty string + """ + # Handle None or empty host + if not host or not host.strip(): + raise ValueError("Host cannot be None or empty") + + # Remove trailing slashes + host = host.rstrip("/") + + # Add protocol if not present (case-insensitive check) + host_lower = host.lower() + if not host_lower.startswith("https://") and not host_lower.startswith("http://"): + host = f"https://{host}" + elif host_lower.startswith("https://") or host_lower.startswith("http://"): + # Normalize protocol to lowercase + protocol_end = host.index("://") + 3 + host = host[:protocol_end].lower() + host[protocol_end:] + + return host diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py index 77d1a2f9c..bc1626a3d 100644 --- a/src/databricks/sql/telemetry/telemetry_client.py +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -47,6 +47,7 @@ TelemetryPushClient, CircuitBreakerTelemetryPushClient, ) +from databricks.sql.common.url_utils import normalize_host_with_protocol if TYPE_CHECKING: from databricks.sql.client import Connection @@ -278,7 +279,7 @@ def _send_telemetry(self, events): if self._auth_provider else self.TELEMETRY_UNAUTHENTICATED_PATH ) - url = f"https://{self._host_url}{path}" + url = normalize_host_with_protocol(self._host_url) + path headers = {"Accept": "application/json", "Content-Type": "application/json"} diff --git a/tests/e2e/test_circuit_breaker.py b/tests/e2e/test_circuit_breaker.py index 45c494d19..f9df0f377 100644 --- a/tests/e2e/test_circuit_breaker.py +++ b/tests/e2e/test_circuit_breaker.py @@ -23,6 +23,34 @@ from databricks.sql.telemetry.circuit_breaker_manager import CircuitBreakerManager +def wait_for_circuit_state(circuit_breaker, expected_states, timeout=5): + """ + Wait for circuit breaker to reach one of the expected states with polling. + + Args: + circuit_breaker: The circuit breaker instance to monitor + expected_states: List of acceptable states + (STATE_OPEN, STATE_CLOSED, STATE_HALF_OPEN) + timeout: Maximum time to wait in seconds + + Returns: + True if state reached, False if timeout + + Examples: + # Single state - pass list with one element + wait_for_circuit_state(cb, [STATE_OPEN]) + + # Multiple states + wait_for_circuit_state(cb, [STATE_CLOSED, STATE_HALF_OPEN]) + """ + start = time.time() + while time.time() - start < timeout: + if circuit_breaker.current_state in expected_states: + return True + time.sleep(0.1) # Poll every 100ms + return False + + @pytest.fixture(autouse=True) def aggressive_circuit_breaker_config(): """ @@ -65,12 +93,17 @@ def create_mock_response(self, status_code): }.get(status_code, b"Response") return response - @pytest.mark.parametrize("status_code,should_trigger", [ - (429, True), - (503, True), - (500, False), - ]) - def test_circuit_breaker_triggers_for_rate_limit_codes(self, status_code, should_trigger): + @pytest.mark.parametrize( + "status_code,should_trigger", + [ + (429, True), + (503, True), + (500, False), + ], + ) + def test_circuit_breaker_triggers_for_rate_limit_codes( + self, status_code, should_trigger + ): """ Verify circuit breaker opens for rate-limit codes (429/503) but not others (500). """ @@ -107,9 +140,14 @@ def mock_request(*args, **kwargs): time.sleep(0.5) if should_trigger: - # Circuit should be OPEN after 2 rate-limit failures + # Wait for circuit to open (async telemetry may take time) + assert wait_for_circuit_state( + circuit_breaker, [STATE_OPEN], timeout=5 + ), f"Circuit didn't open within 5s, state: {circuit_breaker.current_state}" + + # Circuit should be OPEN after rate-limit failures assert circuit_breaker.current_state == STATE_OPEN - assert circuit_breaker.fail_counter == 2 + assert circuit_breaker.fail_counter >= 2 # At least 2 failures # Track requests before another query requests_before = request_count["count"] @@ -197,7 +235,10 @@ def mock_conditional_request(*args, **kwargs): cursor.fetchone() time.sleep(2) - assert circuit_breaker.current_state == STATE_OPEN + # Wait for circuit to open + assert wait_for_circuit_state( + circuit_breaker, [STATE_OPEN], timeout=5 + ), f"Circuit didn't open, state: {circuit_breaker.current_state}" # Wait for reset timeout (5 seconds in test) time.sleep(6) @@ -208,24 +249,20 @@ def mock_conditional_request(*args, **kwargs): # Execute query to trigger HALF_OPEN state cursor.execute("SELECT 3") cursor.fetchone() - time.sleep(1) - # Circuit should be recovering - assert circuit_breaker.current_state in [ - STATE_HALF_OPEN, - STATE_CLOSED, - ], f"Circuit should be recovering, but is {circuit_breaker.current_state}" + # Wait for circuit to start recovering + assert wait_for_circuit_state( + circuit_breaker, [STATE_HALF_OPEN, STATE_CLOSED], timeout=5 + ), f"Circuit didn't recover, state: {circuit_breaker.current_state}" # Execute more queries to fully recover cursor.execute("SELECT 4") cursor.fetchone() - time.sleep(1) - current_state = circuit_breaker.current_state - assert current_state in [ - STATE_CLOSED, - STATE_HALF_OPEN, - ], f"Circuit should recover to CLOSED or HALF_OPEN, got {current_state}" + # Wait for full recovery + assert wait_for_circuit_state( + circuit_breaker, [STATE_CLOSED, STATE_HALF_OPEN], timeout=5 + ), f"Circuit didn't fully recover, state: {circuit_breaker.current_state}" if __name__ == "__main__": diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index 8f8a97eae..5b6991931 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -646,13 +646,31 @@ class TransactionTestSuite(unittest.TestCase): "access_token": "tok", } + def _setup_mock_session_with_http_client(self, mock_session): + """ + Helper to configure a mock session with HTTP client mocks. + This prevents feature flag network requests during Connection initialization. + """ + mock_session.host = "foo" + + # Mock HTTP client to prevent feature flag network requests + mock_http_client = Mock() + mock_session.http_client = mock_http_client + + # Mock feature flag response to prevent blocking HTTP calls + mock_ff_response = Mock() + mock_ff_response.status = 200 + mock_ff_response.data = b'{"flags": [], "ttl_seconds": 900}' + mock_http_client.request.return_value = mock_ff_response + def _create_mock_connection(self, mock_session_class): """Helper to create a mocked connection for transaction tests.""" - # Mock session mock_session = Mock() mock_session.is_open = True mock_session.guid_hex = "test-session-id" mock_session.get_autocommit.return_value = True + + self._setup_mock_session_with_http_client(mock_session) mock_session_class.return_value = mock_session # Create connection with ignore_transactions=False to test actual transaction functionality @@ -736,9 +754,7 @@ def test_autocommit_setter_preserves_exception_chain(self, mock_session_class): conn = self._create_mock_connection(mock_session_class) mock_cursor = Mock() - original_error = DatabaseError( - "Original error", host_url="test-host" - ) + original_error = DatabaseError("Original error", host_url="test-host") mock_cursor.execute.side_effect = original_error with patch.object(conn, "cursor", return_value=mock_cursor): @@ -927,6 +943,8 @@ def test_fetch_autocommit_from_server_queries_server(self, mock_session_class): mock_session = Mock() mock_session.is_open = True mock_session.guid_hex = "test-session-id" + + self._setup_mock_session_with_http_client(mock_session) mock_session_class.return_value = mock_session conn = client.Connection( @@ -959,6 +977,8 @@ def test_fetch_autocommit_from_server_handles_false_value(self, mock_session_cla mock_session = Mock() mock_session.is_open = True mock_session.guid_hex = "test-session-id" + + self._setup_mock_session_with_http_client(mock_session) mock_session_class.return_value = mock_session conn = client.Connection( @@ -986,6 +1006,8 @@ def test_fetch_autocommit_from_server_raises_on_no_result(self, mock_session_cla mock_session = Mock() mock_session.is_open = True mock_session.guid_hex = "test-session-id" + + self._setup_mock_session_with_http_client(mock_session) mock_session_class.return_value = mock_session conn = client.Connection( @@ -1015,6 +1037,8 @@ def test_commit_is_noop_when_ignore_transactions_true(self, mock_session_class): mock_session = Mock() mock_session.is_open = True mock_session.guid_hex = "test-session-id" + + self._setup_mock_session_with_http_client(mock_session) mock_session_class.return_value = mock_session # Create connection with ignore_transactions=True (default) @@ -1043,6 +1067,8 @@ def test_rollback_raises_not_supported_when_ignore_transactions_true( mock_session = Mock() mock_session.is_open = True mock_session.guid_hex = "test-session-id" + + self._setup_mock_session_with_http_client(mock_session) mock_session_class.return_value = mock_session # Create connection with ignore_transactions=True (default) @@ -1068,6 +1094,8 @@ def test_autocommit_setter_is_noop_when_ignore_transactions_true( mock_session = Mock() mock_session.is_open = True mock_session.guid_hex = "test-session-id" + + self._setup_mock_session_with_http_client(mock_session) mock_session_class.return_value = mock_session # Create connection with ignore_transactions=True (default) diff --git a/tests/unit/test_sea_http_client.py b/tests/unit/test_sea_http_client.py index 39ecb58a7..5100a5cb0 100644 --- a/tests/unit/test_sea_http_client.py +++ b/tests/unit/test_sea_http_client.py @@ -44,6 +44,40 @@ def sea_http_client(self, mock_auth_provider, ssl_options): client._pool = Mock() return client + @pytest.mark.parametrize( + "server_hostname,port,expected_base_url", + [ + # Basic hostname without protocol + ("myserver.com", 443, "https://myserver.com:443"), + # Hostname with trailing slash + ("myserver.com/", 443, "https://myserver.com:443"), + # Hostname with https:// protocol + ("https://myserver.com", 443, "https://myserver.com:443"), + # Hostname with http:// protocol (preserved as-is) + ("http://myserver.com", 443, "http://myserver.com:443"), + # Hostname with protocol and trailing slash + ("https://myserver.com/", 443, "https://myserver.com:443"), + # Custom port + ("myserver.com", 8080, "https://myserver.com:8080"), + # Protocol with custom port + ("https://myserver.com", 8080, "https://myserver.com:8080"), + ], + ) + def test_base_url_construction( + self, server_hostname, port, expected_base_url, mock_auth_provider, ssl_options + ): + """Test that base_url is constructed correctly from various hostname inputs.""" + with patch("databricks.sql.backend.sea.utils.http_client.HTTPSConnectionPool"): + client = SeaHttpClient( + server_hostname=server_hostname, + port=port, + http_path="/sql/1.0/warehouses/test", + http_headers=[], + auth_provider=mock_auth_provider, + ssl_options=ssl_options, + ) + assert client.base_url == expected_base_url + def test_get_command_type_from_path(self, sea_http_client): """Test the _get_command_type_from_path method with various paths and methods.""" # Test statement execution diff --git a/tests/unit/test_token_federation.py b/tests/unit/test_token_federation.py index 2e671c33e..9c209e894 100644 --- a/tests/unit/test_token_federation.py +++ b/tests/unit/test_token_federation.py @@ -6,10 +6,10 @@ from databricks.sql.auth.token_federation import TokenFederationProvider, Token from databricks.sql.auth.auth_utils import ( - parse_hostname, decode_token, is_same_host, ) +from databricks.sql.common.url_utils import normalize_host_with_protocol from databricks.sql.common.http import HttpMethod @@ -78,10 +78,10 @@ def test_init_requires_http_client(self, mock_external_provider): @pytest.mark.parametrize( "input_hostname,expected", [ - ("test.databricks.com", "https://test.databricks.com/"), - ("https://test.databricks.com", "https://test.databricks.com/"), - ("https://test.databricks.com/", "https://test.databricks.com/"), - ("test.databricks.com/", "https://test.databricks.com/"), + ("test.databricks.com", "https://test.databricks.com"), + ("https://test.databricks.com", "https://test.databricks.com"), + ("https://test.databricks.com/", "https://test.databricks.com"), + ("test.databricks.com/", "https://test.databricks.com"), ], ) def test_hostname_normalization( @@ -305,15 +305,15 @@ class TestUtilityFunctions: @pytest.mark.parametrize( "input_hostname,expected", [ - ("test.databricks.com", "https://test.databricks.com/"), - ("https://test.databricks.com", "https://test.databricks.com/"), - ("https://test.databricks.com/", "https://test.databricks.com/"), - ("test.databricks.com/", "https://test.databricks.com/"), + ("test.databricks.com", "https://test.databricks.com"), + ("https://test.databricks.com", "https://test.databricks.com"), + ("https://test.databricks.com/", "https://test.databricks.com"), + ("test.databricks.com/", "https://test.databricks.com"), ], ) - def test_parse_hostname(self, input_hostname, expected): - """Test hostname parsing.""" - assert parse_hostname(input_hostname) == expected + def test_normalize_hostname(self, input_hostname, expected): + """Test hostname normalization.""" + assert normalize_host_with_protocol(input_hostname) == expected @pytest.mark.parametrize( "url1,url2,expected", diff --git a/tests/unit/test_url_utils.py b/tests/unit/test_url_utils.py new file mode 100644 index 000000000..95f42408d --- /dev/null +++ b/tests/unit/test_url_utils.py @@ -0,0 +1,41 @@ +"""Tests for URL utility functions.""" +import pytest +from databricks.sql.common.url_utils import normalize_host_with_protocol + + +class TestNormalizeHostWithProtocol: + """Tests for normalize_host_with_protocol function.""" + + @pytest.mark.parametrize( + "input_url,expected_output", + [ + ("myserver.com", "https://myserver.com"), # Add https:// + ("https://myserver.com", "https://myserver.com"), # No duplicate + ("http://localhost:8080", "http://localhost:8080"), # Preserve http:// + ("myserver.com:443", "https://myserver.com:443"), # With port + ("myserver.com/", "https://myserver.com"), # Remove trailing slash + ("https://myserver.com///", "https://myserver.com"), # Multiple slashes + ("HTTPS://MyServer.COM", "https://MyServer.COM"), # Case handling + ], + ) + def test_normalize_host_with_protocol(self, input_url, expected_output): + """Test host normalization with various input formats.""" + result = normalize_host_with_protocol(input_url) + assert result == expected_output + + # Additional assertions + assert result.startswith("https://") or result.startswith("http://") + assert not result.endswith("/") + + @pytest.mark.parametrize( + "invalid_host", + [ + None, + "", + " ", # Whitespace only + ], + ) + def test_normalize_host_with_protocol_raises_on_invalid_input(self, invalid_host): + """Test that function raises ValueError for None or empty host.""" + with pytest.raises(ValueError, match="Host cannot be None or empty"): + normalize_host_with_protocol(invalid_host)