diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py index 8f22261f5d..49d287f8a6 100644 --- a/pyiceberg/io/pyarrow.py +++ b/pyiceberg/io/pyarrow.py @@ -1614,8 +1614,13 @@ def _task_to_record_batches( partition_spec: PartitionSpec | None = None, format_version: TableVersion = TableProperties.DEFAULT_FORMAT_VERSION, downcast_ns_timestamp_to_us: bool | None = None, + dictionary_columns: tuple[str, ...] | None = None, ) -> Iterator[pa.RecordBatch]: - arrow_format = _get_file_format(task.file.file_format, pre_buffer=True, buffer_size=(ONE_MEGABYTE * 8)) + # Only pass dictionary_columns for Parquet — ORC does not support this kwarg. + format_kwargs: dict[str, Any] = {"pre_buffer": True, "buffer_size": ONE_MEGABYTE * 8} + if dictionary_columns and task.file.file_format == FileFormat.PARQUET: + format_kwargs["dictionary_columns"] = dictionary_columns + arrow_format = _get_file_format(task.file.file_format, **format_kwargs) with io.new_input(task.file.file_path).open() as fin: fragment = arrow_format.make_fragment(fin) physical_schema = fragment.physical_schema @@ -1718,6 +1723,7 @@ class ArrowScan: _case_sensitive: bool _limit: int | None _downcast_ns_timestamp_to_us: bool | None + _dictionary_columns: tuple[str, ...] | None """Scan the Iceberg Table and create an Arrow construct. Attributes: @@ -1737,6 +1743,8 @@ def __init__( row_filter: BooleanExpression, case_sensitive: bool = True, limit: int | None = None, + *, + dictionary_columns: tuple[str, ...] | None = None, ) -> None: self._table_metadata = table_metadata self._io = io @@ -1745,6 +1753,7 @@ def __init__( self._case_sensitive = case_sensitive self._limit = limit self._downcast_ns_timestamp_to_us = Config().get_bool(DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE) + self._dictionary_columns = dictionary_columns @property def _projected_field_ids(self) -> set[int]: @@ -1773,6 +1782,15 @@ def to_table(self, tasks: Iterable[FileScanTask]) -> pa.Table: ValueError: When a field type in the file cannot be projected to the schema type """ arrow_schema = schema_to_pyarrow(self._projected_schema, include_field_ids=False) + if self._dictionary_columns: + dict_cols_set = set(self._dictionary_columns) + arrow_schema = pa.schema( + [ + field.with_type(pa.dictionary(pa.int32(), field.type)) if field.name in dict_cols_set else field + for field in arrow_schema + ], + metadata=arrow_schema.metadata, + ) batches = self.to_record_batches(tasks) try: @@ -1855,6 +1873,7 @@ def _record_batches_from_scan_tasks_and_deletes( self._table_metadata.specs().get(task.file.spec_id), self._table_metadata.format_version, self._downcast_ns_timestamp_to_us, + dictionary_columns=self._dictionary_columns, ) for batch in batches: if self._limit is not None: diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index bb8765b651..2381ea8492 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -1121,6 +1121,7 @@ def scan( snapshot_id: int | None = None, options: Properties = EMPTY_DICT, limit: int | None = None, + dictionary_columns: tuple[str, ...] | None = None, ) -> DataScan: """Fetch a DataScan based on the table's current metadata. @@ -1147,6 +1148,13 @@ def scan( An integer representing the number of rows to return in the scan result. If None, fetches all matching rows. + dictionary_columns: + A tuple of column names that PyArrow should read as + dictionary-encoded (DictionaryArray). Reduces memory + usage for columns with large or repeated string values + (e.g. large JSON blobs). Only applies to Parquet files; + silently ignored for ORC. Columns absent from the file + are silently skipped. Default is None (no dictionary encoding). Returns: A DataScan based on the table's current metadata. @@ -1162,6 +1170,7 @@ def scan( limit=limit, catalog=self.catalog, table_identifier=self._identifier, + dictionary_columns=dictionary_columns, ) @property @@ -1664,6 +1673,7 @@ def scan( snapshot_id: int | None = None, options: Properties = EMPTY_DICT, limit: int | None = None, + dictionary_columns: tuple[str, ...] | None = None, ) -> DataScan: raise ValueError("Cannot scan a staged table") @@ -1916,6 +1926,36 @@ def _min_sequence_number(manifests: list[ManifestFile]) -> int: class DataScan(TableScan): + dictionary_columns: tuple[str, ...] | None + + def __init__( + self, + table_metadata: TableMetadata, + io: FileIO, + row_filter: str | BooleanExpression = ALWAYS_TRUE, + selected_fields: tuple[str, ...] = ("*",), + case_sensitive: bool = True, + snapshot_id: int | None = None, + options: Properties = EMPTY_DICT, + limit: int | None = None, + catalog: Catalog | None = None, + table_identifier: Identifier | None = None, + dictionary_columns: tuple[str, ...] | None = None, + ) -> None: + super().__init__( + table_metadata=table_metadata, + io=io, + row_filter=row_filter, + selected_fields=selected_fields, + case_sensitive=case_sensitive, + snapshot_id=snapshot_id, + options=options, + limit=limit, + catalog=catalog, + table_identifier=table_identifier, + ) + self.dictionary_columns = dictionary_columns + def _build_partition_projection(self, spec_id: int) -> BooleanExpression: project = inclusive_projection(self.table_metadata.schema(), self.table_metadata.specs()[spec_id], self.case_sensitive) return project(self.row_filter) @@ -2113,7 +2153,13 @@ def to_arrow(self) -> pa.Table: from pyiceberg.io.pyarrow import ArrowScan return ArrowScan( - self.table_metadata, self.io, self.projection(), self.row_filter, self.case_sensitive, self.limit + self.table_metadata, + self.io, + self.projection(), + self.row_filter, + self.case_sensitive, + self.limit, + dictionary_columns=self.dictionary_columns, ).to_table(self.plan_files()) def to_arrow_batch_reader(self) -> pa.RecordBatchReader: @@ -2132,8 +2178,29 @@ def to_arrow_batch_reader(self) -> pa.RecordBatchReader: from pyiceberg.io.pyarrow import ArrowScan, schema_to_pyarrow target_schema = schema_to_pyarrow(self.projection()) + + # When dictionary_columns is set, PyArrow returns DictionaryArray for those columns. + # target_schema uses plain string types, so .cast(target_schema) would silently decode + # them back to plain strings. Rebuild target_schema with dictionary types for the listed + # columns so from_batches and cast both preserve the encoding. + if self.dictionary_columns: + dict_cols_set = set(self.dictionary_columns) + target_schema = pa.schema( + [ + field.with_type(pa.dictionary(pa.int32(), field.type)) if field.name in dict_cols_set else field + for field in target_schema + ], + metadata=target_schema.metadata, + ) + batches = ArrowScan( - self.table_metadata, self.io, self.projection(), self.row_filter, self.case_sensitive, self.limit + self.table_metadata, + self.io, + self.projection(), + self.row_filter, + self.case_sensitive, + self.limit, + dictionary_columns=self.dictionary_columns, ).to_record_batches(self.plan_files()) return pa.RecordBatchReader.from_batches( diff --git a/tests/io/test_pyarrow.py b/tests/io/test_pyarrow.py index 2170741bdd..0361460513 100644 --- a/tests/io/test_pyarrow.py +++ b/tests/io/test_pyarrow.py @@ -3152,6 +3152,168 @@ def _expected_batch(unit: str) -> pa.RecordBatch: assert _expected_batch("ns" if format_version > 2 else "us").equals(actual_result) +def test_task_to_record_batches_dictionary_columns(tmpdir: str) -> None: + """dictionary_columns causes the column to be read as DictionaryArray, saving memory.""" + arrow_table = pa.table( + {"json_col": pa.array(["large-json-1", "large-json-2", "large-json-1"], type=pa.string())}, + schema=pa.schema([pa.field("json_col", pa.string(), nullable=True, metadata={PYARROW_PARQUET_FIELD_ID_KEY: "1"})]), + ) + data_file = _write_table_to_data_file(f"{tmpdir}/test_dictionary_columns.parquet", arrow_table.schema, arrow_table) + table_schema = Schema(NestedField(1, "json_col", StringType(), required=False)) + + batches = list( + _task_to_record_batches( + PyArrowFileIO(), + FileScanTask(data_file), + bound_row_filter=AlwaysTrue(), + projected_schema=table_schema, + table_schema=table_schema, + projected_field_ids={1}, + positional_deletes=None, + case_sensitive=True, + dictionary_columns=("json_col",), + ) + ) + + assert len(batches) == 1, "Expected exactly one record batch" + col = batches[0].column("json_col") + assert pa.types.is_dictionary(col.type), ( + f"Expected DictionaryArray for 'json_col' when dictionary_columns is set, got {col.type}" + ) + + +def test_task_to_record_batches_no_dictionary_columns_by_default(tmpdir: str) -> None: + """Without dictionary_columns, string columns are returned as plain StringArray — default unchanged.""" + arrow_table = pa.table( + {"json_col": pa.array(["a", "b", "c"], type=pa.string())}, + schema=pa.schema([pa.field("json_col", pa.string(), nullable=True, metadata={PYARROW_PARQUET_FIELD_ID_KEY: "1"})]), + ) + data_file = _write_table_to_data_file(f"{tmpdir}/test_no_dictionary_default.parquet", arrow_table.schema, arrow_table) + table_schema = Schema(NestedField(1, "json_col", StringType(), required=False)) + + batches = list( + _task_to_record_batches( + PyArrowFileIO(), + FileScanTask(data_file), + bound_row_filter=AlwaysTrue(), + projected_schema=table_schema, + table_schema=table_schema, + projected_field_ids={1}, + positional_deletes=None, + case_sensitive=True, + # dictionary_columns intentionally omitted — must not change behavior + ) + ) + + assert len(batches) == 1, "Expected exactly one record batch" + col = batches[0].column("json_col") + assert not pa.types.is_dictionary(col.type), f"Expected plain StringArray by default, got {col.type}" + + +def test_arrow_scan_to_table_with_dictionary_columns(tmpdir: str) -> None: + """ArrowScan.to_table() with dictionary_columns: named column is DictionaryArray, others are not.""" + import pyarrow.parquet as pq + + arrow_schema = pa.schema( + [ + pa.field("id", pa.int32(), metadata={PYARROW_PARQUET_FIELD_ID_KEY: "1"}), + pa.field("json_col", pa.string(), nullable=True, metadata={PYARROW_PARQUET_FIELD_ID_KEY: "2"}), + ] + ) + arrow_table = pa.table( + { + "id": pa.array([1, 2, 3], type=pa.int32()), + "json_col": pa.array(['{"x": 1}', '{"x": 2}', '{"x": 1}'], type=pa.string()), + }, + schema=arrow_schema, + ) + filepath = f"{tmpdir}/test_e2e_dictionary.parquet" + with pq.ParquetWriter(filepath, arrow_schema) as writer: + writer.write_table(arrow_table) + + iceberg_schema = Schema( + NestedField(1, "id", IntegerType(), required=False), + NestedField(2, "json_col", StringType(), required=False), + ) + data_file = DataFile.from_args( + content=DataFileContent.DATA, + file_path=filepath, + file_format=FileFormat.PARQUET, + partition={}, + record_count=3, + file_size_in_bytes=100, + ) + data_file.spec_id = 0 + + result = ArrowScan( + TableMetadataV2( + location="file://a/b/", + last_column_id=2, + format_version=2, + schemas=[iceberg_schema], + partition_specs=[PartitionSpec()], + ), + PyArrowFileIO(), + iceberg_schema, + AlwaysTrue(), + dictionary_columns=("json_col",), + ).to_table(tasks=[FileScanTask(data_file)]) + + assert pa.types.is_dictionary(result.schema.field("json_col").type), ( + f"Expected DictionaryArray for 'json_col', got {result.schema.field('json_col').type}" + ) + assert not pa.types.is_dictionary(result.schema.field("id").type), "Non-listed column 'id' should NOT be dictionary-encoded" + + +def test_arrow_scan_to_record_batches_preserves_dictionary_encoding(tmpdir: str) -> None: + """ArrowScan.to_record_batches() must preserve DictionaryArray — not decode back to plain string.""" + import pyarrow.parquet as pq + + arrow_schema = pa.schema( + [ + pa.field("json_col", pa.string(), nullable=True, metadata={PYARROW_PARQUET_FIELD_ID_KEY: "1"}), + ] + ) + arrow_table = pa.table( + {"json_col": pa.array(['{"a": 1}', '{"b": 2}'], type=pa.string())}, + schema=arrow_schema, + ) + filepath = f"{tmpdir}/test_batch_reader_dict.parquet" + with pq.ParquetWriter(filepath, arrow_schema) as writer: + writer.write_table(arrow_table) + + iceberg_schema = Schema(NestedField(1, "json_col", StringType(), required=False)) + data_file = DataFile.from_args( + content=DataFileContent.DATA, + file_path=filepath, + file_format=FileFormat.PARQUET, + partition={}, + record_count=2, + file_size_in_bytes=100, + ) + data_file.spec_id = 0 + + batches = list( + ArrowScan( + TableMetadataV2( + location="file://a/b/", + last_column_id=1, + format_version=2, + schemas=[iceberg_schema], + partition_specs=[PartitionSpec()], + ), + PyArrowFileIO(), + iceberg_schema, + AlwaysTrue(), + dictionary_columns=("json_col",), + ).to_record_batches(tasks=[FileScanTask(data_file)]) + ) + + assert len(batches) >= 1, "Expected at least one record batch" + col = batches[0].column("json_col") + assert pa.types.is_dictionary(col.type), f"DictionaryArray must be preserved through to_record_batches, got {col.type}" + + def test_parse_location_defaults() -> None: """Test that parse_location uses defaults.""" diff --git a/tests/table/test_init.py b/tests/table/test_init.py index 30c4a3a45a..f3da6ad16e 100644 --- a/tests/table/test_init.py +++ b/tests/table/test_init.py @@ -274,6 +274,25 @@ def test_table_scan_select(table_fixture: Table) -> None: assert scan.select("a", "c").select("a").selected_fields == ("a",) +def test_table_scan_dictionary_columns_default(table_v2: Table) -> None: + scan = table_v2.scan() + assert scan.dictionary_columns is None, "dictionary_columns should default to None" + + +def test_table_scan_dictionary_columns_set(table_v2: Table) -> None: + scan = table_v2.scan(dictionary_columns=("json_col", "other_col")) + assert scan.dictionary_columns == ("json_col", "other_col"), "dictionary_columns should be stored on the scan" + + +def test_table_scan_dictionary_columns_preserved_on_update(table_v2: Table) -> None: + scan = table_v2.scan(dictionary_columns=("json_col",)) + updated = scan.update(limit=10) + assert updated.dictionary_columns == ("json_col",), ( + "dictionary_columns must survive .update() — TableScan.update() uses inspect.signature " + "so DataScan.__init__ must declare and store it" + ) + + def test_table_scan_row_filter(table_v2: Table) -> None: scan = table_v2.scan() assert scan.row_filter == AlwaysTrue()