Skip to content

Commit b33d982

Browse files
committed
tests: address review feedback on test_pandas_compatibility
- Refactor _make_result_set to use ThriftResultSet via normal constructor with mocked args (matches test_client.py pattern), instead of bypassing __init__ via object.__new__ - Add coverage for additional Arrow types: decimal128, date32/date64, timestamp, binary, large_string, list_, struct, map_
1 parent 9d36bc3 commit b33d982

1 file changed

Lines changed: 207 additions & 41 deletions

File tree

tests/unit/test_pandas_compatibility.py

Lines changed: 207 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,13 @@
22
33
These tests verify that _convert_arrow_table correctly converts Arrow tables
44
to Row objects using pandas as an intermediary, covering various data types
5-
including nullable integers, floats, booleans, and strings.
5+
including nullable integers, floats, booleans, strings, decimal, dates,
6+
timestamps, binary, and nested types.
67
"""
78

9+
import datetime
810
import unittest
11+
from decimal import Decimal
912
from unittest.mock import Mock
1013

1114
import pandas
@@ -16,43 +19,38 @@
1619
except ImportError:
1720
pa = None
1821

19-
from databricks.sql.result_set import ResultSet
20-
from databricks.sql.types import Row
22+
from databricks.sql.result_set import ThriftResultSet
2123

2224

23-
class _ConcreteResultSet(ResultSet):
24-
"""Minimal concrete subclass of ResultSet for testing _convert_arrow_table."""
25+
def _make_result_set(description, disable_pandas=False):
26+
"""Create a ThriftResultSet with mocked dependencies for testing _convert_arrow_table.
2527
26-
def fetchone(self):
27-
pass
28+
Mirrors the construction pattern used in test_client.py: pass mocked
29+
connection/execute_response/thrift_client to the normal constructor.
30+
"""
31+
mock_connection = Mock()
32+
mock_connection.disable_pandas = disable_pandas
2833

29-
def fetchmany(self, size):
30-
pass
34+
mock_execute_response = Mock()
35+
mock_execute_response.description = description
36+
# t_row_set defaults to None, so result_format being None keeps the queue-build branch off.
37+
mock_execute_response.result_format = None
3138

32-
def fetchall(self):
33-
pass
39+
mock_backend = Mock()
40+
# _fill_results_buffer() expects (results, has_more_rows, result_links_count).
41+
mock_backend.fetch_results.return_value = (Mock(), False, 0)
3442

35-
def fetchmany_arrow(self, size):
36-
pass
37-
38-
def fetchall_arrow(self):
39-
pass
43+
return ThriftResultSet(
44+
connection=mock_connection,
45+
execute_response=mock_execute_response,
46+
thrift_client=mock_backend,
47+
)
4048

4149

4250
@pytest.mark.skipif(pa is None, reason="PyArrow is not installed")
4351
class TestConvertArrowTablePandasCompat(unittest.TestCase):
4452
"""Test _convert_arrow_table with various Arrow types under current pandas version."""
4553

46-
def _make_result_set(self, description):
47-
"""Create a minimal ResultSet instance for testing _convert_arrow_table."""
48-
mock_connection = Mock()
49-
mock_connection.disable_pandas = False
50-
51-
rs = object.__new__(_ConcreteResultSet)
52-
rs.connection = mock_connection
53-
rs.description = description
54-
return rs
55-
5654
def test_integer_types(self):
5755
table = pa.table(
5856
{
@@ -69,7 +67,7 @@ def test_integer_types(self):
6967
("int64_col", "bigint", None, None, None, None, None),
7068
]
7169

72-
rs = self._make_result_set(description)
70+
rs = _make_result_set(description)
7371
rows = rs._convert_arrow_table(table)
7472

7573
self.assertEqual(len(rows), 3)
@@ -94,7 +92,7 @@ def test_unsigned_integer_types(self):
9492
("uint64_col", "bigint", None, None, None, None, None),
9593
]
9694

97-
rs = self._make_result_set(description)
95+
rs = _make_result_set(description)
9896
rows = rs._convert_arrow_table(table)
9997

10098
self.assertEqual(len(rows), 2)
@@ -115,7 +113,7 @@ def test_float_types(self):
115113
("float64_col", "double", None, None, None, None, None),
116114
]
117115

118-
rs = self._make_result_set(description)
116+
rs = _make_result_set(description)
119117
rows = rs._convert_arrow_table(table)
120118

121119
self.assertEqual(len(rows), 3)
@@ -134,7 +132,7 @@ def test_boolean_type(self):
134132
("bool_col", "boolean", None, None, None, None, None),
135133
]
136134

137-
rs = self._make_result_set(description)
135+
rs = _make_result_set(description)
138136
rows = rs._convert_arrow_table(table)
139137

140138
self.assertEqual(len(rows), 3)
@@ -158,14 +156,186 @@ def test_string_type(self):
158156
("str_col", "string", None, None, None, None, None),
159157
]
160158

161-
rs = self._make_result_set(description)
159+
rs = _make_result_set(description)
162160
rows = rs._convert_arrow_table(table)
163161

164162
self.assertEqual(len(rows), 3)
165163
self.assertEqual(rows[0].str_col, "hello")
166164
self.assertEqual(rows[1].str_col, "world")
167165
self.assertIsNone(rows[2].str_col)
168166

167+
def test_large_string_type(self):
168+
"""large_string is not in dtype_mapping → default Arrow→pandas conversion."""
169+
table = pa.table(
170+
{
171+
"lstr_col": pa.array(["foo", "bar", None], type=pa.large_string()),
172+
}
173+
)
174+
description = [
175+
("lstr_col", "string", None, None, None, None, None),
176+
]
177+
178+
rs = _make_result_set(description)
179+
rows = rs._convert_arrow_table(table)
180+
181+
self.assertEqual(len(rows), 3)
182+
self.assertEqual(rows[0].lstr_col, "foo")
183+
self.assertEqual(rows[1].lstr_col, "bar")
184+
self.assertIsNone(rows[2].lstr_col)
185+
186+
def test_binary_type(self):
187+
table = pa.table(
188+
{
189+
"bin_col": pa.array([b"hello", None, b"world"], type=pa.binary()),
190+
}
191+
)
192+
description = [
193+
("bin_col", "binary", None, None, None, None, None),
194+
]
195+
196+
rs = _make_result_set(description)
197+
rows = rs._convert_arrow_table(table)
198+
199+
self.assertEqual(len(rows), 3)
200+
self.assertEqual(rows[0].bin_col, b"hello")
201+
self.assertIsNone(rows[1].bin_col)
202+
self.assertEqual(rows[2].bin_col, b"world")
203+
204+
def test_decimal_type(self):
205+
table = pa.table(
206+
{
207+
"dec_col": pa.array(
208+
[Decimal("1.23"), None, Decimal("99.99")],
209+
type=pa.decimal128(5, 2),
210+
),
211+
}
212+
)
213+
description = [
214+
("dec_col", "decimal", None, None, None, None, None),
215+
]
216+
217+
rs = _make_result_set(description)
218+
rows = rs._convert_arrow_table(table)
219+
220+
self.assertEqual(len(rows), 3)
221+
self.assertEqual(rows[0].dec_col, Decimal("1.23"))
222+
self.assertIsNone(rows[1].dec_col)
223+
self.assertEqual(rows[2].dec_col, Decimal("99.99"))
224+
225+
def test_date_types(self):
226+
"""date32 and date64 → datetime.date objects via date_as_object=True."""
227+
table = pa.table(
228+
{
229+
"date32_col": pa.array(
230+
[datetime.date(2024, 1, 1), None, datetime.date(2026, 5, 19)],
231+
type=pa.date32(),
232+
),
233+
"date64_col": pa.array(
234+
[None, datetime.date(2024, 12, 31), datetime.date(2026, 1, 1)],
235+
type=pa.date64(),
236+
),
237+
}
238+
)
239+
description = [
240+
("date32_col", "date", None, None, None, None, None),
241+
("date64_col", "date", None, None, None, None, None),
242+
]
243+
244+
rs = _make_result_set(description)
245+
rows = rs._convert_arrow_table(table)
246+
247+
self.assertEqual(len(rows), 3)
248+
self.assertEqual(rows[0].date32_col, datetime.date(2024, 1, 1))
249+
self.assertIsNone(rows[0].date64_col)
250+
self.assertIsNone(rows[1].date32_col)
251+
self.assertEqual(rows[1].date64_col, datetime.date(2024, 12, 31))
252+
self.assertEqual(rows[2].date32_col, datetime.date(2026, 5, 19))
253+
self.assertEqual(rows[2].date64_col, datetime.date(2026, 1, 1))
254+
255+
def test_timestamp_type(self):
256+
"""timestamp → datetime.datetime objects via timestamp_as_object=True."""
257+
ts1 = datetime.datetime(2024, 1, 1, 12, 30, 45)
258+
ts2 = datetime.datetime(2026, 5, 19, 9, 15, 0)
259+
table = pa.table(
260+
{
261+
"ts_col": pa.array([ts1, None, ts2], type=pa.timestamp("us")),
262+
}
263+
)
264+
description = [
265+
("ts_col", "timestamp", None, None, None, None, None),
266+
]
267+
268+
rs = _make_result_set(description)
269+
rows = rs._convert_arrow_table(table)
270+
271+
self.assertEqual(len(rows), 3)
272+
self.assertEqual(rows[0].ts_col, ts1)
273+
self.assertIsNone(rows[1].ts_col)
274+
self.assertEqual(rows[2].ts_col, ts2)
275+
276+
def test_list_type(self):
277+
table = pa.table(
278+
{
279+
"list_col": pa.array(
280+
[[1, 2, 3], None, [4, 5]],
281+
type=pa.list_(pa.int64()),
282+
),
283+
}
284+
)
285+
description = [
286+
("list_col", "array", None, None, None, None, None),
287+
]
288+
289+
rs = _make_result_set(description)
290+
rows = rs._convert_arrow_table(table)
291+
292+
self.assertEqual(len(rows), 3)
293+
self.assertEqual(list(rows[0].list_col), [1, 2, 3])
294+
self.assertIsNone(rows[1].list_col)
295+
self.assertEqual(list(rows[2].list_col), [4, 5])
296+
297+
def test_struct_type(self):
298+
table = pa.table(
299+
{
300+
"struct_col": pa.array(
301+
[{"x": 1, "y": "a"}, None, {"x": 3, "y": "c"}],
302+
type=pa.struct([("x", pa.int64()), ("y", pa.string())]),
303+
),
304+
}
305+
)
306+
description = [
307+
("struct_col", "struct", None, None, None, None, None),
308+
]
309+
310+
rs = _make_result_set(description)
311+
rows = rs._convert_arrow_table(table)
312+
313+
self.assertEqual(len(rows), 3)
314+
self.assertEqual(rows[0].struct_col, {"x": 1, "y": "a"})
315+
self.assertIsNone(rows[1].struct_col)
316+
self.assertEqual(rows[2].struct_col, {"x": 3, "y": "c"})
317+
318+
def test_map_type(self):
319+
table = pa.table(
320+
{
321+
"map_col": pa.array(
322+
[[("k1", 1), ("k2", 2)], None, [("k3", 3)]],
323+
type=pa.map_(pa.string(), pa.int64()),
324+
),
325+
}
326+
)
327+
description = [
328+
("map_col", "map", None, None, None, None, None),
329+
]
330+
331+
rs = _make_result_set(description)
332+
rows = rs._convert_arrow_table(table)
333+
334+
self.assertEqual(len(rows), 3)
335+
self.assertEqual(list(rows[0].map_col), [("k1", 1), ("k2", 2)])
336+
self.assertIsNone(rows[1].map_col)
337+
self.assertEqual(list(rows[2].map_col), [("k3", 3)])
338+
169339
def test_mixed_types(self):
170340
"""Test a table with a mix of types, similar to real query results."""
171341
table = pa.table(
@@ -183,7 +353,7 @@ def test_mixed_types(self):
183353
("active", "boolean", None, None, None, None, None),
184354
]
185355

186-
rs = self._make_result_set(description)
356+
rs = _make_result_set(description)
187357
rows = rs._convert_arrow_table(table)
188358

189359
self.assertEqual(len(rows), 3)
@@ -220,7 +390,7 @@ def test_duplicate_column_names(self):
220390
("col", "int", None, None, None, None, None),
221391
]
222392

223-
rs = self._make_result_set(description)
393+
rs = _make_result_set(description)
224394
rows = rs._convert_arrow_table(table)
225395

226396
self.assertEqual(len(rows), 2)
@@ -238,7 +408,7 @@ def test_empty_table(self):
238408
("col", "int", None, None, None, None, None),
239409
]
240410

241-
rs = self._make_result_set(description)
411+
rs = _make_result_set(description)
242412
rows = rs._convert_arrow_table(table)
243413

244414
self.assertEqual(len(rows), 0)
@@ -255,7 +425,7 @@ def test_all_nulls(self):
255425
("str_col", "string", None, None, None, None, None),
256426
]
257427

258-
rs = self._make_result_set(description)
428+
rs = _make_result_set(description)
259429
rows = rs._convert_arrow_table(table)
260430

261431
self.assertEqual(len(rows), 2)
@@ -266,15 +436,11 @@ def test_all_nulls(self):
266436

267437
def test_disable_pandas_path(self):
268438
"""Verify the non-pandas code path still works."""
269-
mock_connection = Mock()
270-
mock_connection.disable_pandas = True
271-
272-
rs = object.__new__(_ConcreteResultSet)
273-
rs.connection = mock_connection
274-
rs.description = [
439+
description = [
275440
("id", "bigint", None, None, None, None, None),
276441
("name", "string", None, None, None, None, None),
277442
]
443+
rs = _make_result_set(description, disable_pandas=True)
278444

279445
table = pa.table(
280446
{

0 commit comments

Comments
 (0)