Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .github/workflows/code-quality-checks.yml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ jobs:
- name: Install Poetry
uses: snok/install-poetry@v1
with:
version: "2.2.1"
virtualenvs-create: true
virtualenvs-in-project: true
installer-parallel: true
Expand Down Expand Up @@ -82,6 +83,7 @@ jobs:
- name: Install Poetry
uses: snok/install-poetry@v1
with:
version: "2.2.1"
virtualenvs-create: true
virtualenvs-in-project: true
installer-parallel: true
Expand Down
1 change: 1 addition & 0 deletions .github/workflows/integration.yml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ jobs:
- name: Install Poetry
uses: snok/install-poetry@v1
with:
version: "2.2.1"
virtualenvs-create: true
virtualenvs-in-project: true
installer-parallel: true
Expand Down
1 change: 1 addition & 0 deletions .github/workflows/publish-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ jobs:
- name: Install Poetry
uses: snok/install-poetry@v1
with:
version: "2.2.1"
virtualenvs-create: true
virtualenvs-in-project: true
installer-parallel: true
Expand Down
1 change: 1 addition & 0 deletions .github/workflows/publish.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ jobs:
- name: Install Poetry
uses: snok/install-poetry@v1
with:
version: "2.2.1"
virtualenvs-create: true
virtualenvs-in-project: true
installer-parallel: true
Expand Down
53 changes: 53 additions & 0 deletions src/databricks/sqlalchemy/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,59 @@ def do_rollback(self, dbapi_connection):
# Databricks SQL Does not support transactions
pass

def is_disconnect(self, e, connection, cursor):
    """Classify an exception as a lost-connection error or not.

    Per the SQLAlchemy dialect contract described in the original
    documentation of this method, SQLAlchemy calls this hook after an
    exception occurs during query execution; a True result makes it
    invalidate the connection so the next operation gets a new one. The
    default do_ping() used with pool_pre_ping=True relies on the same hook
    to classify ping failures.

    Args:
        e: The exception that was raised.
        connection: The connection that raised the exception (may be None).
            Not consulted by this implementation.
        cursor: The cursor that raised the exception (may be None).
            Not consulted by this implementation.

    Returns:
        True if the error indicates a disconnect, False otherwise.
    """
    # Imported at call time, as in the original implementation.
    from databricks.sql.exc import (
        DatabaseError,
        Error,
        InterfaceError,
        RequestError,
    )

    # All substring checks below are case-insensitive.
    message = str(e).lower()

    if isinstance(e, InterfaceError):
        # Closed connection/cursor errors from client.py, all raised when
        # self.open is False.
        # NOTE(review): a reviewer asked whether InterfaceError could also be
        # raised for programming errors such as invalid params. The author
        # replied that with the current connector every InterfaceError is
        # raised when the connection is closed; the "closed" check still
        # keeps any other InterfaceError from being treated as a disconnect.
        verdict = "closed" in message
    elif isinstance(e, RequestError):
        # Transport/network-level errors: the connection is unusable.
        # RequestError derives from DatabaseError (via OperationalError), so
        # this branch must come before the DatabaseError branch.
        verdict = True
    elif isinstance(e, DatabaseError):
        # Server-side errors indicating the session/operation is gone.
        stale_handle = "invalid" in message and "handle" in message
        verdict = stale_handle or "unexpectedly closed server side" in message
    elif isinstance(e, Error):
        # Older connector versions raise the base Error (not InterfaceError)
        # for closed connection/cursor messages.
        verdict = "closed connection" in message or "closed cursor" in message
    else:
        # Anything that is not a connector exception is never a disconnect.
        verdict = False

    return verdict

@reflection.cache
def has_table(
self, connection, table_name, schema=None, catalog=None, **kwargs
Expand Down
124 changes: 124 additions & 0 deletions tests/test_local/e2e/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -541,3 +541,127 @@ def test_table_comment_reflection(self, inspector: Inspector, table: Table):
def test_column_comment(self, inspector: Inspector, table: Table):
    """The first reflected column of the table carries the expected comment."""
    first_column = inspector.get_columns(table.name)[0]
    assert first_column.get("comment") == "column comment"


def test_pool_pre_ping_with_closed_connection(connection_details):
    """Test that pool_pre_ping detects closed connections and creates new ones.

    This test verifies that when a connection is manually closed (simulating
    session expiration), pool_pre_ping detects it and automatically creates
    a new connection without raising an error to the user.

    The engine is disposed in a ``finally`` clause so that a failing
    assertion or connection error does not leak the engine and its pooled
    connection.
    """
    conn_string, connect_args = version_agnostic_connect_arguments(connection_details)

    # pool_size=1 / max_overflow=0 keep the pool to a single connection, so
    # the second checkout below has to go through the connection closed in
    # step 2.
    engine = create_engine(
        conn_string,
        connect_args=connect_args,
        pool_pre_ping=True,
        pool_size=1,
        max_overflow=0,
    )

    try:
        # Step 1: Create connection and get session ID
        with engine.connect() as conn:
            result = conn.execute(text("SELECT VERSION()")).scalar()
            assert result is not None

            # Get session ID of first connection
            raw_conn = conn.connection.dbapi_connection
            session_id_1 = raw_conn.get_session_id_hex()
            assert session_id_1 is not None

        # Step 2: Manually close the connection to simulate expiration.
        # This reaches into the pool's private queue to get the checked-in
        # connection record.
        pooled_conn = engine.pool._pool.queue[0]
        pooled_conn.driver_connection.close()

        # Verify connection is closed
        assert not pooled_conn.driver_connection.open

        # Step 3: Try to use the closed connection - pool_pre_ping should
        # detect and recycle it
        with engine.connect() as conn:
            result = conn.execute(text("SELECT VERSION()")).scalar()
            assert result is not None

            # Get session ID of new connection
            raw_conn = conn.connection.dbapi_connection
            session_id_2 = raw_conn.get_session_id_hex()
            assert session_id_2 is not None

        # Verify a NEW connection was created (different session ID)
        assert session_id_1 != session_id_2, (
            "pool_pre_ping should have detected the closed connection "
            "and created a new one with a different session ID"
        )
    finally:
        # Cleanup, even when an assertion above fails
        engine.dispose()


def test_is_disconnect_handles_runtime_errors(db_engine):
    """Test that is_disconnect() properly classifies disconnect errors during query execution.

    When a connection fails DURING a query, is_disconnect() should recognize the error
    and tell SQLAlchemy to invalidate the connection so the next query gets a fresh one.

    The engine is disposed in a ``finally`` clause so that it is released on
    assertion failure and on the skip path as well as on success.
    """
    from sqlalchemy import create_engine, text
    from sqlalchemy.exc import DBAPIError

    engine = create_engine(
        db_engine.url,
        pool_pre_ping=False,  # Disabled - we want to test is_disconnect, not do_ping
        pool_size=1,
        max_overflow=0,
    )

    try:
        # Step 1: Execute a successful query
        with engine.connect() as conn:
            result = conn.execute(text("SELECT VERSION()")).scalar()
            assert result is not None

            # Get session ID of working connection
            raw_conn = conn.connection.dbapi_connection
            session_id_1 = raw_conn.get_session_id_hex()
            assert session_id_1 is not None

        # Step 2: Manually close the connection to simulate server-side
        # session expiration. This reaches into the pool's private queue to
        # get the checked-in connection record.
        pooled_conn = engine.pool._pool.queue[0]
        pooled_conn.driver_connection.close()

        # Step 3: Try to execute query on closed connection
        # This should:
        # 1. Fail with an exception
        # 2. is_disconnect() gets called by SQLAlchemy
        # 3. Returns True (recognizes it as disconnect error)
        # 4. SQLAlchemy invalidates the connection
        # 5. Next operation gets a fresh connection

        # First query will fail because connection is closed
        try:
            with engine.connect() as conn:
                conn.execute(text("SELECT VERSION()")).scalar()
            # If we get here without exception, the connection wasn't actually closed
            pytest.skip("Connection wasn't properly closed - cannot test is_disconnect")
        except DBAPIError as e:
            # Expected - connection was closed
            # is_disconnect() should have been called and returned True
            # This causes SQLAlchemy to invalidate the connection
            error_text = str(e).lower()
            assert "closed" in error_text or "invalid" in error_text

        # Step 4: Next query should work because is_disconnect() invalidated
        # the bad connection
        with engine.connect() as conn:
            result = conn.execute(text("SELECT VERSION()")).scalar()
            assert result is not None

            # Verify we got a NEW connection
            raw_conn = conn.connection.dbapi_connection
            session_id_2 = raw_conn.get_session_id_hex()
            assert session_id_2 is not None

        # Different session ID proves connection was invalidated and recreated
        assert session_id_1 != session_id_2, (
            "is_disconnect() should have invalidated the bad connection, "
            "causing SQLAlchemy to create a new one with different session ID"
        )
    finally:
        # Release the engine on success, failure, and skip alike
        engine.dispose()
151 changes: 151 additions & 0 deletions tests/test_local/test_is_disconnect.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
"""Tests for DatabricksDialect.is_disconnect() method."""
import pytest
from databricks.sqlalchemy import DatabricksDialect
from databricks.sql.exc import (
Error,
InterfaceError,
DatabaseError,
OperationalError,
RequestError,
SessionAlreadyClosedError,
CursorAlreadyClosedError,
MaxRetryDurationError,
NonRecoverableNetworkError,
UnsafeToRetryError,
)


class TestIsDisconnect:
    """Unit tests for how DatabricksDialect.is_disconnect() classifies errors.

    Every test calls is_disconnect(error, None, None) on hand-built
    exceptions, so no connection or cursor is involved.
    """

    @pytest.fixture
    def dialect(self):
        """Provide a fresh dialect instance for each test."""
        return DatabricksDialect()

    # --- InterfaceError: closed connection/cursor (client.py) ---

    def test_interface_error_closed_connection(self, dialect):
        """Any InterfaceError whose message mentions 'closed' is a disconnect."""
        messages = (
            "Cannot create cursor from closed connection",
            "Cannot get autocommit on closed connection",
            "Cannot set autocommit on closed connection",
            "Cannot commit on closed connection",
            "Cannot rollback on closed connection",
            "Cannot get transaction isolation on closed connection",
            "Cannot set transaction isolation on closed connection",
            "Attempting operation on closed cursor",
        )
        for message in messages:
            exc = InterfaceError(message)
            assert dialect.is_disconnect(exc, None, None) is True

    def test_interface_error_without_closed_not_disconnect(self, dialect):
        """An InterfaceError that does not mention 'closed' is not a disconnect."""
        exc = InterfaceError("Some other interface error")
        assert dialect.is_disconnect(exc, None, None) is False

    # --- RequestError: transport/network-level errors ---

    def test_request_error_is_disconnect(self, dialect):
        """Every RequestError is a disconnect, whatever its message."""
        messages = (
            "HTTP client is closing or has been closed",
            "Connection pool not initialized",
            "HTTP request failed: max retries exceeded",
            "HTTP request error: connection reset",
        )
        for message in messages:
            exc = RequestError(message)
            assert dialect.is_disconnect(exc, None, None) is True

    def test_request_error_subclasses_are_disconnect(self, dialect):
        """Every RequestError subclass is a disconnect as well."""
        samples = (
            SessionAlreadyClosedError("Session already closed"),
            CursorAlreadyClosedError("Cursor already closed"),
            MaxRetryDurationError("Retry duration exceeded"),
            NonRecoverableNetworkError("HTTP 501"),
            UnsafeToRetryError("Unexpected HTTP error"),
        )
        for exc in samples:
            assert dialect.is_disconnect(exc, None, None) is True

    # --- DatabaseError: server-side session/operation errors ---

    def test_database_error_with_invalid_handle(self, dialect):
        """A DatabaseError mentioning an invalid handle is a disconnect."""
        messages = (
            "Invalid SessionHandle",
            "[Errno INVALID_HANDLE] Session does not exist",
            "INVALID HANDLE",
            "invalid handle",
        )
        for message in messages:
            exc = DatabaseError(message)
            assert dialect.is_disconnect(exc, None, None) is True

    def test_database_error_unexpectedly_closed_server_side(self, dialect):
        """A DatabaseError for an operation closed server-side is a disconnect."""
        messages = (
            "Command abc123 unexpectedly closed server side",
            "Command None unexpectedly closed server side",
        )
        for message in messages:
            exc = DatabaseError(message)
            assert dialect.is_disconnect(exc, None, None) is True

    def test_database_error_without_disconnect_indicators(self, dialect):
        """A DatabaseError with no disconnect indicator is not a disconnect."""
        messages = (
            "Syntax error in SQL",
            "Table not found",
            "Permission denied",
            "Catalog name is required for get_schemas",
            "Catalog name is required for get_columns",
        )
        for message in messages:
            exc = DatabaseError(message)
            assert dialect.is_disconnect(exc, None, None) is False

    # --- OperationalError (non-RequestError) ---

    def test_operational_error_not_disconnect(self, dialect):
        """A plain OperationalError with no disconnect indicator is not a disconnect."""
        messages = (
            "Timeout waiting for query",
            "Empty TColumn instance",
            "Unsupported TRowSet instance",
        )
        for message in messages:
            exc = OperationalError(message)
            assert dialect.is_disconnect(exc, None, None) is False

    # --- Base Error class: older connector versions (client.py:385) ---

    def test_base_error_closed_connection_is_disconnect(self, dialect):
        """A base Error mentioning a closed connection/cursor is a disconnect.

        Older released versions of databricks-sql-connector raise Error
        (not InterfaceError) for closed connection messages.
        """
        messages = (
            "Cannot create cursor from closed connection",
            "Cannot get autocommit on closed connection",
            "Attempting operation on closed cursor",
        )
        for message in messages:
            exc = Error(message)
            assert dialect.is_disconnect(exc, None, None) is True

    def test_base_error_without_closed_not_disconnect(self, dialect):
        """A base Error without 'closed connection/cursor' is not a disconnect."""
        messages = (
            "Some other error",
            "Connection timeout",
        )
        for message in messages:
            exc = Error(message)
            assert dialect.is_disconnect(exc, None, None) is False

    # --- Other exceptions ---

    def test_other_errors_not_disconnect(self, dialect):
        """Exceptions that are not connector types are never disconnects."""
        samples = (
            Exception("Some random error"),
            ValueError("Bad value"),
            RuntimeError("Runtime failure"),
        )
        for exc in samples:
            assert dialect.is_disconnect(exc, None, None) is False
Loading