From b186c8f602a3d269c84792e8fdf8452cd3599133 Mon Sep 17 00:00:00 2001
From: geruh
Date: Mon, 12 Jan 2026 17:24:23 -0800
Subject: [PATCH] feat: Add rollback_to_snapshot to ManageSnapshots API

---
 pyiceberg/table/update/snapshot.py            |  35 ++++++
 tests/integration/test_snapshot_operations.py | 108 ++++++++++++++++++
 2 files changed, 143 insertions(+)

diff --git a/pyiceberg/table/update/snapshot.py b/pyiceberg/table/update/snapshot.py
index bc05aab966..987200bf67 100644
--- a/pyiceberg/table/update/snapshot.py
+++ b/pyiceberg/table/update/snapshot.py
@@ -64,6 +64,7 @@
     Snapshot,
     SnapshotSummaryCollector,
     Summary,
+    ancestors_of,
     update_snapshot_summaries,
 )
 from pyiceberg.table.update import (
@@ -985,6 +986,40 @@ def set_current_snapshot(self, snapshot_id: int | None = None, ref_name: str | N
         self._transaction._stage(update, requirement)
         return self
 
+    def rollback_to_snapshot(self, snapshot_id: int) -> ManageSnapshots:
+        """Rollback the table to the given snapshot id.
+
+        The snapshot needs to be an ancestor of the current table state.
+
+        Args:
+            snapshot_id (int): rollback to this snapshot_id that used to be current.
+
+        Returns:
+            This for method chaining
+
+        Raises:
+            ValueError: If the snapshot does not exist or is not an ancestor of the current table state.
+        """
+        if not self._transaction.table_metadata.snapshot_by_id(snapshot_id):
+            raise ValueError(f"Cannot roll back to unknown snapshot id: {snapshot_id}")
+
+        if not self._is_current_ancestor(snapshot_id):
+            raise ValueError(f"Cannot roll back to snapshot, not an ancestor of the current state: {snapshot_id}")
+
+        return self.set_current_snapshot(snapshot_id=snapshot_id)
+
+    def _is_current_ancestor(self, snapshot_id: int) -> bool:
+        return snapshot_id in self._current_ancestors()
+
+    def _current_ancestors(self) -> set[int]:
+        return {
+            a.snapshot_id
+            for a in ancestors_of(
+                self._transaction.table_metadata.current_snapshot(),
+                self._transaction.table_metadata,
+            )
+        }
+
 
 class ExpireSnapshots(UpdateTableMetadata["ExpireSnapshots"]):
     """Expire snapshots by ID.
diff --git a/tests/integration/test_snapshot_operations.py b/tests/integration/test_snapshot_operations.py
index 2f0447ec52..8755e95fbb 100644
--- a/tests/integration/test_snapshot_operations.py
+++ b/tests/integration/test_snapshot_operations.py
@@ -14,12 +14,44 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
+import uuid
+from collections.abc import Generator
+
+import pyarrow as pa
 import pytest
 
 from pyiceberg.catalog import Catalog
+from pyiceberg.table import Table
 from pyiceberg.table.refs import SnapshotRef
 
 
+@pytest.fixture
+def table_with_snapshots(session_catalog: Catalog) -> Generator[Table, None, None]:
+    session_catalog.create_namespace_if_not_exists("default")
+    identifier = f"default.test_table_snapshot_ops_{uuid.uuid4().hex[:8]}"
+
+    arrow_schema = pa.schema(
+        [
+            pa.field("id", pa.int64(), nullable=False),
+            pa.field("data", pa.string(), nullable=True),
+        ]
+    )
+
+    tbl = session_catalog.create_table(identifier=identifier, schema=arrow_schema)
+
+    data1 = pa.Table.from_pylist([{"id": 1, "data": "a"}, {"id": 2, "data": "b"}], schema=arrow_schema)
+    tbl.append(data1)
+
+    data2 = pa.Table.from_pylist([{"id": 3, "data": "c"}, {"id": 4, "data": "d"}], schema=arrow_schema)
+    tbl.append(data2)
+
+    tbl = session_catalog.load_table(identifier)
+
+    yield tbl
+
+    session_catalog.drop_table(identifier)
+
+
 @pytest.mark.integration
 @pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog_hive"), pytest.lazy_fixture("session_catalog")])
 def test_create_tag(catalog: Catalog) -> None:
@@ -160,3 +192,79 @@ def test_set_current_snapshot_chained_with_create_tag(catalog: Catalog) -> None:
     tbl = catalog.load_table(identifier)
     tbl.manage_snapshots().remove_tag(tag_name=tag_name).commit()
     assert tbl.metadata.refs.get(tag_name, None) is None
+
+
+@pytest.mark.integration
+def test_rollback_to_snapshot(table_with_snapshots: Table) -> None:
+    history = table_with_snapshots.history()
+    assert len(history) >= 2
+
+    ancestor_snapshot_id = history[-2].snapshot_id
+
+    table_with_snapshots.manage_snapshots().rollback_to_snapshot(snapshot_id=ancestor_snapshot_id).commit()
+
+    updated = table_with_snapshots.current_snapshot()
+    assert updated is not None
+    assert updated.snapshot_id == ancestor_snapshot_id
+
+
+@pytest.mark.integration
+def test_rollback_to_current_snapshot(table_with_snapshots: Table) -> None:
+    current = table_with_snapshots.current_snapshot()
+    assert current is not None
+
+    table_with_snapshots.manage_snapshots().rollback_to_snapshot(snapshot_id=current.snapshot_id).commit()
+
+    updated = table_with_snapshots.current_snapshot()
+    assert updated is not None
+    assert updated.snapshot_id == current.snapshot_id
+
+
+@pytest.mark.integration
+def test_rollback_to_snapshot_chained_with_tag(table_with_snapshots: Table) -> None:
+    history = table_with_snapshots.history()
+    assert len(history) >= 2
+
+    ancestor_snapshot_id = history[-2].snapshot_id
+    tag_name = "my-tag"
+
+    (
+        table_with_snapshots.manage_snapshots()
+        .create_tag(snapshot_id=ancestor_snapshot_id, tag_name=tag_name)
+        .rollback_to_snapshot(snapshot_id=ancestor_snapshot_id)
+        .commit()
+    )
+
+    updated = table_with_snapshots.current_snapshot()
+    assert updated is not None
+    assert updated.snapshot_id == ancestor_snapshot_id
+    assert table_with_snapshots.metadata.refs[tag_name] == SnapshotRef(snapshot_id=ancestor_snapshot_id, snapshot_ref_type="tag")
+
+
+@pytest.mark.integration
+def test_rollback_to_snapshot_not_ancestor(table_with_snapshots: Table) -> None:
+    history = table_with_snapshots.history()
+    assert len(history) >= 2
+
+    snapshot_a = history[-2].snapshot_id
+
+    branch_name = "my-branch"
+    table_with_snapshots.manage_snapshots().create_branch(snapshot_id=snapshot_a, branch_name=branch_name).commit()
+
+    data = pa.Table.from_pylist([{"id": 5, "data": "e"}], schema=table_with_snapshots.schema().as_arrow())
+    table_with_snapshots.append(data, branch=branch_name)
+
+    snapshot_c = table_with_snapshots.metadata.snapshot_by_name(branch_name)
+    assert snapshot_c is not None
+    assert snapshot_c.snapshot_id != snapshot_a
+
+    with pytest.raises(ValueError, match="not an ancestor"):
+        table_with_snapshots.manage_snapshots().rollback_to_snapshot(snapshot_id=snapshot_c.snapshot_id).commit()
+
+
+@pytest.mark.integration
+def test_rollback_to_snapshot_unknown_id(table_with_snapshots: Table) -> None:
+    invalid_snapshot_id = 1234567890000
+
+    with pytest.raises(ValueError, match="Cannot roll back to unknown snapshot id"):
+        table_with_snapshots.manage_snapshots().rollback_to_snapshot(snapshot_id=invalid_snapshot_id).commit()