
Commit 24e9a5c
feat(backend/kernel): introduce dedicated use_kernel flag + substantive review fixes
Major change: route the kernel backend through a new ``use_kernel=True`` connection kwarg instead of repurposing ``use_sea=True``. ``use_sea=True`` once again routes to the native pure-Python SEA backend (no behaviour change); ``use_kernel=True`` routes to the Rust kernel via PyO3. The two flags are mutually exclusive.

This addresses the largest reviewer concern from the multi-agent review: silently hijacking a documented public flag broke OAuth / federation / parameter-binding callers on ``use_sea=True`` who had no opt-out. With the new flag, the kernel backend is fully opt-in and existing ``use_sea=True`` users continue to get the native SEA backend they signed up for.

Other substantive fixes:

- session.py: restore ``SeaDatabricksClient`` import + routing. Reject ``use_kernel=True`` + ``use_sea=True`` together with a clear ``ValueError``.
- client.py (kernel ``Cursor.columns``): update docstring to flag the ``catalog_name=None`` divergence — kernel requires a catalog, Thrift / native SEA do not (F13).
- conftest.py: drop the collection-time ``pytest_collection_modifyitems`` hook that was skipping ``extra_params={"use_sea": True}`` cases. With ``use_sea=True`` back on the native SEA backend, those cases run as they did before this PR (F8).
- kernel/client.py: ``get_tables`` now applies the ``table_types`` filter client-side using ``ResultSetFilter._filter_arrow_table`` (the same helper the native SEA backend uses), wrapped in a tiny ``_StaticArrowHandle`` that flows the filtered table back through the normal ``KernelResultSet`` path. Replaces the previous "log a warning and return unfiltered" behaviour (F4).
- kernel/client.py: guard ``_async_handles`` with ``threading.RLock`` so concurrent cursors on the same connection don't race on submit / close / close-session (F15).
- kernel/result_set.py: ``KernelResultSet.close()`` now drops the entry from ``backend._async_handles`` so async-submitted statements don't leave stale references behind (F5).
- kernel/{__init__,client,auth_bridge}.py, tests/e2e/test_kernel_backend.py: update docstrings, error messages, and the e2e fixture to refer to ``use_kernel=True`` instead of ``use_sea=True``.
- client.py (``Connection`` docstring): document the new ``use_kernel`` kwarg + its Phase-1 limitations.

New tests:

- tests/unit/test_kernel_client.py (38 cases): cover the 14-entry ``_CODE_TO_EXCEPTION`` table, ``_reraise_kernel_error`` attribute forwarding, the 6-entry ``_STATE_TO_COMMAND_STATE`` table, the no-open-session guards on every method, ``open_session`` double-open, ``parameters`` / ``query_tags`` rejection, ``get_columns``' catalog-required check, ``cancel_command`` / ``close_command`` no-handle tolerance, ``get_query_state`` sync-path SUCCEEDED, the Failed-state re-raise, the synthetic-command-id UUID shape, and ``close_session`` cleanup even when per-handle close errors fire. Uses a fake ``databricks_sql_kernel`` module installed into ``sys.modules`` so the test runs with no Rust extension dependency (F9).

77/77 kernel unit tests pass.

Co-authored-by: Isaac
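For a quick sense of the new surface, here is a minimal usage sketch of the routing described above. The hostname, HTTP path, and token are placeholders; the flag behaviour and the ``ValueError`` are from this commit.

```python
from databricks import sql

params = dict(
    server_hostname="example.cloud.databricks.com",  # placeholder
    http_path="/sql/1.0/warehouses/abc123",          # placeholder
    access_token="dapi-example",                     # placeholder PAT
)

# Default: Thrift backend, unchanged by this commit.
conn = sql.connect(**params)

# Native pure-Python SEA backend: same behaviour as before this PR.
conn = sql.connect(**params, use_sea=True)

# New opt-in: the Rust kernel via PyO3. Raises ImportError unless
# `pip install databricks-sql-kernel` has been run; PAT auth only today.
conn = sql.connect(**params, use_kernel=True)

# The two flags are mutually exclusive; session.py rejects the combination.
try:
    sql.connect(**params, use_sea=True, use_kernel=True)
except ValueError as exc:
    print(exc)
```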
1 parent 37fa544 · commit 24e9a5c

9 files changed: 550 additions & 80 deletions


conftest.py

Lines changed: 0 additions & 34 deletions
@@ -1,41 +1,7 @@
-import importlib.util
 import os
 import pytest
 
 
-def _kernel_wheel_available() -> bool:
-    """The ``use_sea=True`` code path now routes through the Rust
-    kernel via PyO3. The ``databricks_sql_kernel`` wheel is not
-    yet on PyPI (built from a separate repo); CI environments
-    without it should skip ``use_sea=True`` parametrized cases
-    rather than fail with a hard ImportError."""
-    return importlib.util.find_spec("databricks_sql_kernel") is not None
-
-
-def pytest_collection_modifyitems(config, items):
-    """Skip parametrized test cases that pass ``use_sea=True`` when
-    the kernel wheel isn't installed.
-
-    The existing e2e suite uses ``@pytest.mark.parametrize(
-    "extra_params", [{}, {"use_sea": True}])`` to exercise both
-    backends. When the kernel wheel is missing those cases die at
-    ``connect()`` time with our pointed ImportError; mark them
-    skipped at collection time so CI signal stays accurate.
-    """
-    if _kernel_wheel_available():
-        return
-    skip_marker = pytest.mark.skip(
-        reason="use_sea=True requires databricks-sql-kernel (not installed)"
-    )
-    for item in items:
-        params = getattr(item, "callspec", None)
-        if params is None:
-            continue
-        extra_params = params.params.get("extra_params")
-        if isinstance(extra_params, dict) and extra_params.get("use_sea") is True:
-            item.add_marker(skip_marker)
-
-
 @pytest.fixture(scope="session")
 def host():
     return os.getenv("DATABRICKS_SERVER_HOSTNAME")
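For context, the parametrized pattern the deleted hook used to intercept (quoted in its docstring) looks roughly like this. The test body and the extra connection kwargs are illustrative, not taken from the suite:

```python
import os

import pytest

from databricks import sql


@pytest.mark.parametrize("extra_params", [{}, {"use_sea": True}])
def test_select_one(host, extra_params):
    # With use_sea=True routed back to the native SEA backend, this
    # case no longer needs the kernel wheel and runs as before the PR.
    with sql.connect(
        server_hostname=host,
        http_path=os.getenv("DATABRICKS_HTTP_PATH"),  # assumed env var
        access_token=os.getenv("DATABRICKS_TOKEN"),   # assumed env var
        **extra_params,
    ) as connection:
        with connection.cursor() as cursor:
            cursor.execute("SELECT 1")
            assert cursor.fetchone()[0] == 1
```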

src/databricks/sql/backend/kernel/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -1,6 +1,6 @@
 """Backend that delegates to the Databricks SQL Kernel (Rust) via PyO3.
 
-Routed when ``use_sea=True`` is passed to ``databricks.sql.connect``.
+Routed when ``use_kernel=True`` is passed to ``databricks.sql.connect``.
 The module's identity is "delegates to the kernel" — not the wire
 protocol the kernel happens to use today (SEA REST). The kernel may
 switch its default transport (SEA REST → SEA gRPC → …) without
@@ -18,7 +18,7 @@
     from databricks.sql.backend.kernel.client import KernelDatabricksClient
 
 ``session.py::_create_backend`` already does this lazy import under
-the ``use_sea=True`` branch.
+the ``use_kernel=True`` branch.
 
 See ``docs/designs/pysql-kernel-integration.md`` in
 ``databricks-sql-kernel`` for the full integration design.
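The lazy-import branch this docstring points at, sketched from the commit message: the routing shape, ``KernelDatabricksClient``, ``SeaDatabricksClient``, and the ``ValueError`` are described there, while the SEA and Thrift import paths below are assumptions.

```python
# Sketch of session.py::_create_backend routing, not the literal source.
def _create_backend(use_sea: bool = False, use_kernel: bool = False, **kwargs):
    if use_sea and use_kernel:
        raise ValueError(
            "use_sea=True and use_kernel=True are mutually exclusive"
        )
    if use_kernel:
        # Lazy import keeps the Rust wheel optional for everyone else.
        from databricks.sql.backend.kernel.client import KernelDatabricksClient

        return KernelDatabricksClient(**kwargs)
    if use_sea:
        from databricks.sql.backend.sea.backend import SeaDatabricksClient  # assumed path

        return SeaDatabricksClient(**kwargs)
    from databricks.sql.backend.thrift_backend import ThriftDatabricksClient  # assumed path

    return ThriftDatabricksClient(**kwargs)
```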

src/databricks/sql/backend/kernel/auth_bridge.py

Lines changed: 3 additions & 3 deletions
@@ -105,7 +105,7 @@ def kernel_auth_kwargs(auth_provider: AuthProvider) -> Dict[str, Any]:
         return {"auth_type": "pat", "access_token": token}
 
     raise NotSupportedError(
-        f"The kernel backend (use_sea=True) currently only supports PAT auth, "
-        f"but got {type(auth_provider).__name__}. Use use_sea=False (Thrift) "
-        "for OAuth / federation / custom credential providers."
+        f"The kernel backend (use_kernel=True) currently only supports PAT auth, "
+        f"but got {type(auth_provider).__name__}. Use the Thrift backend "
+        "(default) for OAuth / federation / custom credential providers."
     )
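In caller terms, the bridge's contract after this change. Only ``kernel_auth_kwargs``' signature and its return shape come from the diff; ``AccessTokenAuthProvider`` and its import path are an assumed stand-in for the connector's PAT provider.

```python
from databricks.sql.auth.authenticators import AccessTokenAuthProvider  # assumed import
from databricks.sql.backend.kernel.auth_bridge import kernel_auth_kwargs

# PAT providers map to the kernel's auth kwargs.
pat = AccessTokenAuthProvider("dapi-example")  # placeholder token
assert kernel_auth_kwargs(pat) == {"auth_type": "pat", "access_token": "dapi-example"}

# Anything else (OAuth, federation, custom providers) raises
# NotSupportedError with the use_kernel=True wording from this hunk.
```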

src/databricks/sql/backend/kernel/client.py

Lines changed: 91 additions & 21 deletions
@@ -1,6 +1,6 @@
 """``DatabricksClient`` backed by the Rust kernel via PyO3.
 
-Routed when ``use_sea=True``. Constructor takes the connector's
+Routed when ``use_kernel=True``. Constructor takes the connector's
 already-built ``auth_provider`` and forwards everything else to the
 kernel's ``Session``. Every kernel call goes through this thin
 wrapper; this module is the single seam between the connector's
@@ -34,6 +34,7 @@
 from __future__ import annotations
 
 import logging
+import threading
 import uuid
 from typing import Any, Dict, List, Optional, TYPE_CHECKING, Union
 
@@ -71,7 +72,7 @@
     # (doing so breaks `poetry lock`). Once published the install
     # hint will move to `pip install 'databricks-sql-connector[kernel]'`.
     raise ImportError(
-        "use_sea=True requires the databricks-sql-kernel package. Install it with:\n"
+        "use_kernel=True requires the databricks-sql-kernel package. Install it with:\n"
         "  pip install databricks-sql-kernel\n"
         "or for local development from the kernel repo:\n"
         "  cd databricks-sql-kernel/pyo3 && maturin develop --release"
@@ -176,7 +177,10 @@ def __init__(
         self._session_id: Optional[SessionId] = None
         # Async-exec handles keyed by CommandId.guid. Populated by
        # ``execute_command(async_op=True)``; drained by ``close_command``.
+        # Guarded by ``_async_handles_lock`` so concurrent cursors on the
+        # same connection don't race on submit / close / close-session.
         self._async_handles: Dict[str, Any] = {}
+        self._async_handles_lock = threading.RLock()
 
     # ── Session lifecycle ──────────────────────────────────────────
 
@@ -226,14 +230,16 @@ def close_session(self, session_id: SessionId) -> None:
             return
         # Close any tracked async handles first so they fire their
         # server-side CloseStatement before the session goes away.
-        for handle in list(self._async_handles.values()):
+        with self._async_handles_lock:
+            handles_to_close = list(self._async_handles.values())
+            self._async_handles.clear()
+        for handle in handles_to_close:
             try:
                 handle.close()
             except _kernel.KernelError as exc:
                 logger.warning(
                     "Error closing async handle during session close: %s", exc
                 )
-        self._async_handles.clear()
         try:
             self._kernel_session.close()
         except _kernel.KernelError as exc:
@@ -280,7 +286,8 @@ def execute_command(
                 async_exec = stmt.submit()
                 command_id = CommandId.from_sea_statement_id(async_exec.statement_id)
                 cursor.active_command_id = command_id
-                self._async_handles[command_id.guid] = async_exec
+                with self._async_handles_lock:
+                    self._async_handles[command_id.guid] = async_exec
                 return None
             executed = stmt.execute()
         except _kernel.KernelError as exc:
@@ -300,7 +307,8 @@ def execute_command(
         return self._make_result_set(executed, cursor, command_id)
 
     def cancel_command(self, command_id: CommandId) -> None:
-        handle = self._async_handles.get(command_id.guid)
+        with self._async_handles_lock:
+            handle = self._async_handles.get(command_id.guid)
         if handle is None:
             # Sync-execute paths fully materialise the result before
             # ``execute_command`` returns, so by the time
@@ -314,7 +322,8 @@ def cancel_command(self, command_id: CommandId) -> None:
             raise _reraise_kernel_error(exc)
 
     def close_command(self, command_id: CommandId) -> None:
-        handle = self._async_handles.pop(command_id.guid, None)
+        with self._async_handles_lock:
+            handle = self._async_handles.pop(command_id.guid, None)
         if handle is None:
             logger.debug("close_command: no tracked handle for %s", command_id)
             return
@@ -324,7 +333,8 @@ def close_command(self, command_id: CommandId) -> None:
             raise _reraise_kernel_error(exc)
 
     def get_query_state(self, command_id: CommandId) -> CommandState:
-        handle = self._async_handles.get(command_id.guid)
+        with self._async_handles_lock:
+            handle = self._async_handles.get(command_id.guid)
         if handle is None:
             # No tracked async handle means execute_command ran
             # sync and the result was materialised before returning;
@@ -347,7 +357,8 @@ def get_execution_result(
         command_id: CommandId,
         cursor: "Cursor",
     ) -> "ResultSet":
-        handle = self._async_handles.get(command_id.guid)
+        with self._async_handles_lock:
+            handle = self._async_handles.get(command_id.guid)
         if handle is None:
             raise ProgrammingError(
                 "get_execution_result called for an unknown command_id; "
@@ -438,16 +449,6 @@ def get_tables(
     ) -> "ResultSet":
         if self._kernel_session is None:
             raise InterfaceError("get_tables requires an open session.")
-        if table_types:
-            # Documented gap: native SEA backend filters here, but
-            # its filter is keyed on SeaResultSet. Day-1 we surface
-            # the unfiltered result; a small follow-up ports the
-            # filter to operate on KernelResultSet.
-            logger.warning(
-                "get_tables: client-side table_types filter not yet implemented "
-                "on the kernel backend; returning unfiltered rows for %r",
-                table_types,
-            )
         try:
             stream = self._kernel_session.metadata().list_tables(
                 catalog=catalog_name,
@@ -457,7 +458,27 @@ def get_tables(
             )
         except _kernel.KernelError as exc:
             raise _reraise_kernel_error(exc)
-        return self._make_result_set(stream, cursor, self._synthetic_command_id())
+        if not table_types:
+            return self._make_result_set(stream, cursor, self._synthetic_command_id())
+        # The kernel today returns the unfiltered ``SHOW TABLES`` shape
+        # regardless of ``table_types``. Drain to a single Arrow table
+        # and apply the same client-side filter the native SEA backend
+        # uses (column index 5 is TABLE_TYPE, case-sensitive). Cheap
+        # because metadata result sets are small.
+        from databricks.sql.backend.sea.utils.filters import ResultSetFilter
+
+        full_table = _drain_kernel_handle(stream)
+        filtered_table = ResultSetFilter._filter_arrow_table(
+            full_table,
+            column_name=full_table.schema.field(5).name,
+            allowed_values=table_types,
+            case_sensitive=True,
+        )
+        return self._make_result_set(
+            _StaticArrowHandle(filtered_table),
+            cursor,
+            self._synthetic_command_id(),
+        )
 
     def get_columns(
         self,
@@ -496,7 +517,7 @@ def get_columns(
     def max_download_threads(self) -> int:
         # CloudFetch parallelism lives kernel-side. This property is
         # consulted by Thrift code paths that don't run for
-        # use_sea=True; return a non-zero default so anything that
+        # use_kernel=True; return a non-zero default so anything that
         # peeks at it does not divide by zero.
         return 10
 
@@ -509,3 +530,52 @@
     "Cancelled": CommandState.CANCELLED,
     "Closed": CommandState.CLOSED,
 }
+
+
+def _drain_kernel_handle(handle: Any) -> Any:
+    """Drain a kernel ResultStream / ExecutedStatement into a single
+    ``pyarrow.Table``. Used by ``get_tables`` to apply a client-side
+    ``table_types`` filter on a metadata result; cheap because
+    metadata streams are small."""
+    import pyarrow
+
+    schema = handle.arrow_schema()
+    batches = []
+    while True:
+        batch = handle.fetch_next_batch()
+        if batch is None:
+            break
+        if batch.num_rows > 0:
+            batches.append(batch)
+    try:
+        handle.close()
+    except _kernel.KernelError:
+        pass
+    return pyarrow.Table.from_batches(batches, schema=schema)
+
+
+class _StaticArrowHandle:
+    """Duck-typed kernel handle that replays a pre-built
+    ``pyarrow.Table`` through ``arrow_schema()`` /
+    ``fetch_next_batch()`` / ``close()``. Used to wrap a
+    post-processed table (e.g., the ``table_types``-filtered output
+    of ``get_tables``) so it flows back through the normal
+    ``KernelResultSet`` path."""
+
+    def __init__(self, table: Any) -> None:
+        self._schema = table.schema
+        self._batches = list(table.to_batches())
+        self._idx = 0
+
+    def arrow_schema(self) -> Any:
+        return self._schema
+
+    def fetch_next_batch(self) -> Optional[Any]:
+        if self._idx >= len(self._batches):
+            return None
+        batch = self._batches[self._idx]
+        self._idx += 1
+        return batch
+
+    def close(self) -> None:
+        self._batches = []
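Since ``_StaticArrowHandle`` exposes the same duck-typed surface ``_drain_kernel_handle`` consumes, the two round-trip cleanly. A minimal check, assuming ``databricks_sql_kernel`` is importable or stubbed into ``sys.modules`` the way the new unit tests do; the toy table contents are placeholders:

```python
import pyarrow as pa

from databricks.sql.backend.kernel.client import (
    _drain_kernel_handle,
    _StaticArrowHandle,
)

table = pa.table({"TABLE_NAME": ["t1", "t2", "t3"],
                  "TABLE_TYPE": ["TABLE", "VIEW", "TABLE"]})

handle = _StaticArrowHandle(table)
assert handle.arrow_schema() == table.schema

# Drains every batch, closes the handle, and rebuilds one Table.
drained = _drain_kernel_handle(handle)
assert drained.num_rows == 3
assert handle.fetch_next_batch() is None  # closed: batches emptied
```

This is the same shape ``get_tables`` feeds back into ``_make_result_set`` after filtering.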

src/databricks/sql/backend/kernel/result_set.py

Lines changed: 14 additions & 0 deletions
@@ -226,7 +226,21 @@ def close(self) -> None:
             # level; log and swallow so the cursor's __del__ /
             # connection close path stays clean.
             logger.warning("Error closing kernel handle: %s", exc)
+        # Drop the entry from the backend's async-handle map (if
+        # present) — for async-submitted statements the handle is
+        # tracked there and the base ``ResultSet.close`` path would
+        # otherwise leave a stale entry pointing at a closed handle.
+        # No-op for the sync-execute and metadata paths, which never
+        # register in ``_async_handles``.
+        guid = getattr(self.command_id, "guid", None)
+        if guid is not None:
+            self.backend._async_handles_lock.acquire()
+            try:
+                self.backend._async_handles.pop(guid, None)
+            finally:
+                self.backend._async_handles_lock.release()
         self._buffer.clear()
+        self._buffered_count = 0
         self._kernel_handle = None
         self._exhausted = True
         self.has_been_closed_server_side = True
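The cleanup contract the new block enforces, reduced to a stub-based sketch in the spirit of the fake-module unit tests; the ``SimpleNamespace`` stand-ins model only the attributes ``close()`` touches:

```python
import threading
from types import SimpleNamespace

backend = SimpleNamespace(
    _async_handles={"abc-123": object()},  # one tracked async handle
    _async_handles_lock=threading.RLock(),
)
command_id = SimpleNamespace(guid="abc-123")

# What the added close() block does: pop the guid under the lock.
with backend._async_handles_lock:
    backend._async_handles.pop(command_id.guid, None)
assert command_id.guid not in backend._async_handles

# A second close is a no-op thanks to pop(..., None); sync-execute and
# metadata paths (never registered) take the same harmless path.
with backend._async_handles_lock:
    backend._async_handles.pop(command_id.guid, None)
```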

src/databricks/sql/client.py

Lines changed: 17 additions & 1 deletion
@@ -115,7 +115,17 @@ def __init__(
 
         Parameters:
         :param use_sea: `bool`, optional (default is False)
-            Use the SEA backend instead of the Thrift backend.
+            Use the native pure-Python SEA backend instead of
+            the Thrift backend.
+        :param use_kernel: `bool`, optional (default is False)
+            Route the connection through the Rust kernel
+            (``databricks-sql-kernel`` via PyO3). Requires the
+            kernel wheel to be installed separately
+            (``pip install databricks-sql-kernel``); raises
+            ImportError otherwise. In active development —
+            PAT auth only today; OAuth / federation / external
+            credentials and native parameter binding land in
+            follow-ups. Mutually exclusive with ``use_sea``.
         :param use_hybrid_disposition: `bool`, optional (default is False)
             Use the hybrid disposition instead of the inline disposition.
         :param server_hostname: Databricks instance host name.
@@ -1575,6 +1585,12 @@ def columns(
         Get columns corresponding to the catalog_name, schema_name, table_name and column_name.
 
         Names can contain % wildcards.
+
+        Note: on ``use_kernel=True``, ``catalog_name`` is required —
+        the kernel's underlying ``SHOW COLUMNS`` cannot span catalogs.
+        Passing ``catalog_name=None`` raises ``ProgrammingError``. The
+        Thrift and native SEA backends accept ``catalog_name=None``.
+
         :returns self
         """
         self._check_not_closed()
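The documented divergence in caller terms; connection kwargs are placeholders and only the ``catalog_name`` behaviour comes from the diff:

```python
from databricks import sql

conn = sql.connect(
    server_hostname="example.cloud.databricks.com",  # placeholder
    http_path="/sql/1.0/warehouses/abc123",          # placeholder
    access_token="dapi-example",                     # placeholder
    use_kernel=True,
)
cursor = conn.cursor()

# Works on every backend: catalog pinned explicitly.
cursor.columns(catalog_name="main", schema_name="default", table_name="t%")

# Thrift and native SEA: spans catalogs. Kernel backend: raises
# ProgrammingError, since its SHOW COLUMNS needs a concrete catalog.
cursor.columns(schema_name="default", table_name="t%")
```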
