Skip to content

Commit 946d70a

Browse files
committed
feat: add dictionary_columns to scan API for memory-efficient string reads
Exposes `dictionary_columns: tuple[str, ...] | None = None` on `Table.scan()` and `DataScan`, threading it through to PyArrow's `ParquetFileFormat` so that named columns are read as `DictionaryArray` instead of plain `large_utf8`. This dramatically reduces memory usage for columns with large or frequently repeated JSON/string values (issue #3168) and addresses the general scan parameter extensibility request (issue #3170). Key implementation details: - ORC files are guarded — `dictionary_columns` is only passed for Parquet - `ArrowScan.to_table()` rebuilds the Arrow schema with dict types before the empty-table fast-path so the schema is consistent regardless of row count - `DataScan.to_arrow_batch_reader()` rebuilds `target_schema` with dict types to prevent `.cast()` from silently decoding DictionaryArray back to plain string - `DataScan.__init__` declares and stores the param so `TableScan.update()` (which uses `inspect.signature`) preserves it across scan copies Fixes #3168, closes #3170
1 parent 1a54e9c commit 946d70a

4 files changed

Lines changed: 286 additions & 3 deletions

File tree

pyiceberg/io/pyarrow.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1614,8 +1614,13 @@ def _task_to_record_batches(
16141614
partition_spec: PartitionSpec | None = None,
16151615
format_version: TableVersion = TableProperties.DEFAULT_FORMAT_VERSION,
16161616
downcast_ns_timestamp_to_us: bool | None = None,
1617+
dictionary_columns: tuple[str, ...] | None = None,
16171618
) -> Iterator[pa.RecordBatch]:
1618-
arrow_format = _get_file_format(task.file.file_format, pre_buffer=True, buffer_size=(ONE_MEGABYTE * 8))
1619+
# Only pass dictionary_columns for Parquet — ORC does not support this kwarg.
1620+
format_kwargs: dict[str, Any] = {"pre_buffer": True, "buffer_size": ONE_MEGABYTE * 8}
1621+
if dictionary_columns and task.file.file_format == FileFormat.PARQUET:
1622+
format_kwargs["dictionary_columns"] = dictionary_columns
1623+
arrow_format = _get_file_format(task.file.file_format, **format_kwargs)
16191624
with io.new_input(task.file.file_path).open() as fin:
16201625
fragment = arrow_format.make_fragment(fin)
16211626
physical_schema = fragment.physical_schema
@@ -1718,6 +1723,7 @@ class ArrowScan:
17181723
_case_sensitive: bool
17191724
_limit: int | None
17201725
_downcast_ns_timestamp_to_us: bool | None
1726+
_dictionary_columns: tuple[str, ...] | None
17211727
"""Scan the Iceberg Table and create an Arrow construct.
17221728
17231729
Attributes:
@@ -1737,6 +1743,8 @@ def __init__(
17371743
row_filter: BooleanExpression,
17381744
case_sensitive: bool = True,
17391745
limit: int | None = None,
1746+
*,
1747+
dictionary_columns: tuple[str, ...] | None = None,
17401748
) -> None:
17411749
self._table_metadata = table_metadata
17421750
self._io = io
@@ -1745,6 +1753,7 @@ def __init__(
17451753
self._case_sensitive = case_sensitive
17461754
self._limit = limit
17471755
self._downcast_ns_timestamp_to_us = Config().get_bool(DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE)
1756+
self._dictionary_columns = dictionary_columns
17481757

17491758
@property
17501759
def _projected_field_ids(self) -> set[int]:
@@ -1773,6 +1782,17 @@ def to_table(self, tasks: Iterable[FileScanTask]) -> pa.Table:
17731782
ValueError: When a field type in the file cannot be projected to the schema type
17741783
"""
17751784
arrow_schema = schema_to_pyarrow(self._projected_schema, include_field_ids=False)
1785+
if self._dictionary_columns:
1786+
dict_cols_set = set(self._dictionary_columns)
1787+
arrow_schema = pa.schema(
1788+
[
1789+
field.with_type(pa.dictionary(pa.int32(), field.type))
1790+
if field.name in dict_cols_set
1791+
else field
1792+
for field in arrow_schema
1793+
],
1794+
metadata=arrow_schema.metadata,
1795+
)
17761796

17771797
batches = self.to_record_batches(tasks)
17781798
try:
@@ -1855,6 +1875,7 @@ def _record_batches_from_scan_tasks_and_deletes(
18551875
self._table_metadata.specs().get(task.file.spec_id),
18561876
self._table_metadata.format_version,
18571877
self._downcast_ns_timestamp_to_us,
1878+
dictionary_columns=self._dictionary_columns,
18581879
)
18591880
for batch in batches:
18601881
if self._limit is not None:

pyiceberg/table/__init__.py

Lines changed: 71 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1121,6 +1121,7 @@ def scan(
11211121
snapshot_id: int | None = None,
11221122
options: Properties = EMPTY_DICT,
11231123
limit: int | None = None,
1124+
dictionary_columns: tuple[str, ...] | None = None,
11241125
) -> DataScan:
11251126
"""Fetch a DataScan based on the table's current metadata.
11261127
@@ -1147,6 +1148,13 @@ def scan(
11471148
An integer representing the number of rows to
11481149
return in the scan result. If None, fetches all
11491150
matching rows.
1151+
dictionary_columns:
1152+
A tuple of column names that PyArrow should read as
1153+
dictionary-encoded (DictionaryArray). Reduces memory
1154+
usage for columns with large or repeated string values
1155+
(e.g. large JSON blobs). Only applies to Parquet files;
1156+
silently ignored for ORC. Columns absent from the file
1157+
are silently skipped. Default is None (no dictionary encoding).
11501158
11511159
Returns:
11521160
A DataScan based on the table's current metadata.
@@ -1162,6 +1170,7 @@ def scan(
11621170
limit=limit,
11631171
catalog=self.catalog,
11641172
table_identifier=self._identifier,
1173+
dictionary_columns=dictionary_columns,
11651174
)
11661175

11671176
@property
@@ -1664,6 +1673,7 @@ def scan(
16641673
snapshot_id: int | None = None,
16651674
options: Properties = EMPTY_DICT,
16661675
limit: int | None = None,
1676+
dictionary_columns: tuple[str, ...] | None = None,
16671677
) -> DataScan:
16681678
raise ValueError("Cannot scan a staged table")
16691679

@@ -1916,6 +1926,36 @@ def _min_sequence_number(manifests: list[ManifestFile]) -> int:
19161926

19171927

19181928
class DataScan(TableScan):
1929+
dictionary_columns: tuple[str, ...] | None
1930+
1931+
def __init__(
1932+
self,
1933+
table_metadata: TableMetadata,
1934+
io: FileIO,
1935+
row_filter: str | BooleanExpression = ALWAYS_TRUE,
1936+
selected_fields: tuple[str, ...] = ("*",),
1937+
case_sensitive: bool = True,
1938+
snapshot_id: int | None = None,
1939+
options: Properties = EMPTY_DICT,
1940+
limit: int | None = None,
1941+
catalog: Catalog | None = None,
1942+
table_identifier: Identifier | None = None,
1943+
dictionary_columns: tuple[str, ...] | None = None,
1944+
) -> None:
1945+
super().__init__(
1946+
table_metadata=table_metadata,
1947+
io=io,
1948+
row_filter=row_filter,
1949+
selected_fields=selected_fields,
1950+
case_sensitive=case_sensitive,
1951+
snapshot_id=snapshot_id,
1952+
options=options,
1953+
limit=limit,
1954+
catalog=catalog,
1955+
table_identifier=table_identifier,
1956+
)
1957+
self.dictionary_columns = dictionary_columns
1958+
19191959
def _build_partition_projection(self, spec_id: int) -> BooleanExpression:
19201960
project = inclusive_projection(self.table_metadata.schema(), self.table_metadata.specs()[spec_id], self.case_sensitive)
19211961
return project(self.row_filter)
@@ -2113,7 +2153,13 @@ def to_arrow(self) -> pa.Table:
21132153
from pyiceberg.io.pyarrow import ArrowScan
21142154

21152155
return ArrowScan(
2116-
self.table_metadata, self.io, self.projection(), self.row_filter, self.case_sensitive, self.limit
2156+
self.table_metadata,
2157+
self.io,
2158+
self.projection(),
2159+
self.row_filter,
2160+
self.case_sensitive,
2161+
self.limit,
2162+
dictionary_columns=self.dictionary_columns,
21172163
).to_table(self.plan_files())
21182164

21192165
def to_arrow_batch_reader(self) -> pa.RecordBatchReader:
@@ -2132,8 +2178,31 @@ def to_arrow_batch_reader(self) -> pa.RecordBatchReader:
21322178
from pyiceberg.io.pyarrow import ArrowScan, schema_to_pyarrow
21332179

21342180
target_schema = schema_to_pyarrow(self.projection())
2181+
2182+
# When dictionary_columns is set, PyArrow returns DictionaryArray for those columns.
2183+
# target_schema uses plain string types, so .cast(target_schema) would silently decode
2184+
# them back to plain strings. Rebuild target_schema with dictionary types for the listed
2185+
# columns so from_batches and cast both preserve the encoding.
2186+
if self.dictionary_columns:
2187+
dict_cols_set = set(self.dictionary_columns)
2188+
target_schema = pa.schema(
2189+
[
2190+
field.with_type(pa.dictionary(pa.int32(), field.type))
2191+
if field.name in dict_cols_set
2192+
else field
2193+
for field in target_schema
2194+
],
2195+
metadata=target_schema.metadata,
2196+
)
2197+
21352198
batches = ArrowScan(
2136-
self.table_metadata, self.io, self.projection(), self.row_filter, self.case_sensitive, self.limit
2199+
self.table_metadata,
2200+
self.io,
2201+
self.projection(),
2202+
self.row_filter,
2203+
self.case_sensitive,
2204+
self.limit,
2205+
dictionary_columns=self.dictionary_columns,
21372206
).to_record_batches(self.plan_files())
21382207

21392208
return pa.RecordBatchReader.from_batches(

tests/io/test_pyarrow.py

Lines changed: 172 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3152,6 +3152,178 @@ def _expected_batch(unit: str) -> pa.RecordBatch:
31523152
assert _expected_batch("ns" if format_version > 2 else "us").equals(actual_result)
31533153

31543154

3155+
def test_task_to_record_batches_dictionary_columns(tmpdir: str) -> None:
3156+
"""dictionary_columns causes the column to be read as DictionaryArray, saving memory."""
3157+
arrow_table = pa.table(
3158+
{"json_col": pa.array(["large-json-1", "large-json-2", "large-json-1"], type=pa.string())},
3159+
schema=pa.schema(
3160+
[pa.field("json_col", pa.string(), nullable=True, metadata={PYARROW_PARQUET_FIELD_ID_KEY: "1"})]
3161+
),
3162+
)
3163+
data_file = _write_table_to_data_file(
3164+
f"{tmpdir}/test_dictionary_columns.parquet", arrow_table.schema, arrow_table
3165+
)
3166+
table_schema = Schema(NestedField(1, "json_col", StringType(), required=False))
3167+
3168+
batches = list(
3169+
_task_to_record_batches(
3170+
PyArrowFileIO(),
3171+
FileScanTask(data_file),
3172+
bound_row_filter=AlwaysTrue(),
3173+
projected_schema=table_schema,
3174+
table_schema=table_schema,
3175+
projected_field_ids={1},
3176+
positional_deletes=None,
3177+
case_sensitive=True,
3178+
dictionary_columns=("json_col",),
3179+
)
3180+
)
3181+
3182+
assert len(batches) == 1, "Expected exactly one record batch"
3183+
col = batches[0].column("json_col")
3184+
assert pa.types.is_dictionary(col.type), (
3185+
f"Expected DictionaryArray for 'json_col' when dictionary_columns is set, got {col.type}"
3186+
)
3187+
3188+
3189+
def test_task_to_record_batches_no_dictionary_columns_by_default(tmpdir: str) -> None:
3190+
"""Without dictionary_columns, string columns are returned as plain StringArray — default unchanged."""
3191+
arrow_table = pa.table(
3192+
{"json_col": pa.array(["a", "b", "c"], type=pa.string())},
3193+
schema=pa.schema(
3194+
[pa.field("json_col", pa.string(), nullable=True, metadata={PYARROW_PARQUET_FIELD_ID_KEY: "1"})]
3195+
),
3196+
)
3197+
data_file = _write_table_to_data_file(
3198+
f"{tmpdir}/test_no_dictionary_default.parquet", arrow_table.schema, arrow_table
3199+
)
3200+
table_schema = Schema(NestedField(1, "json_col", StringType(), required=False))
3201+
3202+
batches = list(
3203+
_task_to_record_batches(
3204+
PyArrowFileIO(),
3205+
FileScanTask(data_file),
3206+
bound_row_filter=AlwaysTrue(),
3207+
projected_schema=table_schema,
3208+
table_schema=table_schema,
3209+
projected_field_ids={1},
3210+
positional_deletes=None,
3211+
case_sensitive=True,
3212+
# dictionary_columns intentionally omitted — must not change behavior
3213+
)
3214+
)
3215+
3216+
assert len(batches) == 1, "Expected exactly one record batch"
3217+
col = batches[0].column("json_col")
3218+
assert not pa.types.is_dictionary(col.type), (
3219+
f"Expected plain StringArray by default, got {col.type}"
3220+
)
3221+
3222+
3223+
def test_arrow_scan_to_table_with_dictionary_columns(tmpdir: str) -> None:
3224+
"""ArrowScan.to_table() with dictionary_columns: named column is DictionaryArray, others are not."""
3225+
import pyarrow.parquet as pq
3226+
3227+
arrow_schema = pa.schema([
3228+
pa.field("id", pa.int32(), metadata={PYARROW_PARQUET_FIELD_ID_KEY: "1"}),
3229+
pa.field("json_col", pa.string(), nullable=True, metadata={PYARROW_PARQUET_FIELD_ID_KEY: "2"}),
3230+
])
3231+
arrow_table = pa.table(
3232+
{
3233+
"id": pa.array([1, 2, 3], type=pa.int32()),
3234+
"json_col": pa.array(['{"x": 1}', '{"x": 2}', '{"x": 1}'], type=pa.string()),
3235+
},
3236+
schema=arrow_schema,
3237+
)
3238+
filepath = f"{tmpdir}/test_e2e_dictionary.parquet"
3239+
with pq.ParquetWriter(filepath, arrow_schema) as writer:
3240+
writer.write_table(arrow_table)
3241+
3242+
iceberg_schema = Schema(
3243+
NestedField(1, "id", IntegerType(), required=False),
3244+
NestedField(2, "json_col", StringType(), required=False),
3245+
)
3246+
data_file = DataFile.from_args(
3247+
content=DataFileContent.DATA,
3248+
file_path=filepath,
3249+
file_format=FileFormat.PARQUET,
3250+
partition={},
3251+
record_count=3,
3252+
file_size_in_bytes=100,
3253+
)
3254+
data_file.spec_id = 0
3255+
3256+
result = ArrowScan(
3257+
TableMetadataV2(
3258+
location="file://a/b/",
3259+
last_column_id=2,
3260+
format_version=2,
3261+
schemas=[iceberg_schema],
3262+
partition_specs=[PartitionSpec()],
3263+
),
3264+
PyArrowFileIO(),
3265+
iceberg_schema,
3266+
AlwaysTrue(),
3267+
dictionary_columns=("json_col",),
3268+
).to_table(tasks=[FileScanTask(data_file)])
3269+
3270+
assert pa.types.is_dictionary(result.schema.field("json_col").type), (
3271+
f"Expected DictionaryArray for 'json_col', got {result.schema.field('json_col').type}"
3272+
)
3273+
assert not pa.types.is_dictionary(result.schema.field("id").type), (
3274+
"Non-listed column 'id' should NOT be dictionary-encoded"
3275+
)
3276+
3277+
3278+
def test_arrow_scan_to_record_batches_preserves_dictionary_encoding(tmpdir: str) -> None:
3279+
"""ArrowScan.to_record_batches() must preserve DictionaryArray — not decode back to plain string."""
3280+
import pyarrow.parquet as pq
3281+
3282+
arrow_schema = pa.schema([
3283+
pa.field("json_col", pa.string(), nullable=True, metadata={PYARROW_PARQUET_FIELD_ID_KEY: "1"}),
3284+
])
3285+
arrow_table = pa.table(
3286+
{"json_col": pa.array(['{"a": 1}', '{"b": 2}'], type=pa.string())},
3287+
schema=arrow_schema,
3288+
)
3289+
filepath = f"{tmpdir}/test_batch_reader_dict.parquet"
3290+
with pq.ParquetWriter(filepath, arrow_schema) as writer:
3291+
writer.write_table(arrow_table)
3292+
3293+
iceberg_schema = Schema(NestedField(1, "json_col", StringType(), required=False))
3294+
data_file = DataFile.from_args(
3295+
content=DataFileContent.DATA,
3296+
file_path=filepath,
3297+
file_format=FileFormat.PARQUET,
3298+
partition={},
3299+
record_count=2,
3300+
file_size_in_bytes=100,
3301+
)
3302+
data_file.spec_id = 0
3303+
3304+
batches = list(
3305+
ArrowScan(
3306+
TableMetadataV2(
3307+
location="file://a/b/",
3308+
last_column_id=1,
3309+
format_version=2,
3310+
schemas=[iceberg_schema],
3311+
partition_specs=[PartitionSpec()],
3312+
),
3313+
PyArrowFileIO(),
3314+
iceberg_schema,
3315+
AlwaysTrue(),
3316+
dictionary_columns=("json_col",),
3317+
).to_record_batches(tasks=[FileScanTask(data_file)])
3318+
)
3319+
3320+
assert len(batches) >= 1, "Expected at least one record batch"
3321+
col = batches[0].column("json_col")
3322+
assert pa.types.is_dictionary(col.type), (
3323+
f"DictionaryArray must be preserved through to_record_batches, got {col.type}"
3324+
)
3325+
3326+
31553327
def test_parse_location_defaults() -> None:
31563328
"""Test that parse_location uses defaults."""
31573329

tests/table/test_init.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -274,6 +274,27 @@ def test_table_scan_select(table_fixture: Table) -> None:
274274
assert scan.select("a", "c").select("a").selected_fields == ("a",)
275275

276276

277+
def test_table_scan_dictionary_columns_default(table_v2: Table) -> None:
278+
scan = table_v2.scan()
279+
assert scan.dictionary_columns is None, "dictionary_columns should default to None"
280+
281+
282+
def test_table_scan_dictionary_columns_set(table_v2: Table) -> None:
283+
scan = table_v2.scan(dictionary_columns=("json_col", "other_col"))
284+
assert scan.dictionary_columns == ("json_col", "other_col"), (
285+
"dictionary_columns should be stored on the scan"
286+
)
287+
288+
289+
def test_table_scan_dictionary_columns_preserved_on_update(table_v2: Table) -> None:
290+
scan = table_v2.scan(dictionary_columns=("json_col",))
291+
updated = scan.update(limit=10)
292+
assert updated.dictionary_columns == ("json_col",), (
293+
"dictionary_columns must survive .update() — TableScan.update() uses inspect.signature "
294+
"so DataScan.__init__ must declare and store it"
295+
)
296+
297+
277298
def test_table_scan_row_filter(table_v2: Table) -> None:
278299
scan = table_v2.scan()
279300
assert scan.row_filter == AlwaysTrue()

0 commit comments

Comments
 (0)