diff --git a/paimon-python/pypaimon/ray/data_evolution_merge_into.py b/paimon-python/pypaimon/ray/data_evolution_merge_into.py index 5be68f301e05..379655ec7714 100644 --- a/paimon-python/pypaimon/ray/data_evolution_merge_into.py +++ b/paimon-python/pypaimon/ray/data_evolution_merge_into.py @@ -169,7 +169,19 @@ def _prepare(target, source, catalog_options, when_matched, when_not_matched, on for c in when_not_matched ] - source_ds = _normalize_source(source, catalog_options) + source_snapshot_id = None + if isinstance(source, str): + source_snapshot = ( + catalog.get_table(source) + .snapshot_manager() + .get_latest_snapshot() + ) + if source_snapshot is not None: + source_snapshot_id = source_snapshot.id + + source_ds = _normalize_source( + source, catalog_options, source_snapshot_id=source_snapshot_id, + ) _validate_source_on_cols(source_ds, source_on_cols) _validate_source_has_target_cols( source_ds, settable_field_names, on_map, @@ -438,14 +450,21 @@ def _normalize_set_spec( return {col: f"s.{on_map.get(col, col)}" for col in target_field_names} -def _normalize_source(source: Any, catalog_options: Dict[str, str]): +def _normalize_source( + source: Any, + catalog_options: Dict[str, str], + source_snapshot_id: Optional[int] = None, +): import ray.data if isinstance(source, ray.data.Dataset): return source if isinstance(source, str): from pypaimon.ray.ray_paimon import read_paimon - return read_paimon(source, catalog_options) + read_kwargs = {} + if source_snapshot_id is not None: + read_kwargs["snapshot_id"] = source_snapshot_id + return read_paimon(source, catalog_options, **read_kwargs) if isinstance(source, pa.Table): return ray.data.from_arrow(source) try: diff --git a/paimon-python/pypaimon/tests/ray_data_evolution_merge_into_test.py b/paimon-python/pypaimon/tests/ray_data_evolution_merge_into_test.py index 47981088f2d3..7be86683205f 100644 --- a/paimon-python/pypaimon/tests/ray_data_evolution_merge_into_test.py +++ b/paimon-python/pypaimon/tests/ray_data_evolution_merge_into_test.py @@ -21,6 +21,7 @@ import tempfile import unittest import uuid +from unittest.mock import Mock, patch import pyarrow as pa import ray @@ -108,6 +109,36 @@ def _snapshot_id(self, target): snap = table.snapshot_manager().get_latest_snapshot() return snap.id if snap is not None else None + def test_paimon_source_table_pins_snapshot(self): + from pypaimon.ray import data_evolution_merge_into as m + + target = self._create_table() + source = self._create_table() + self._write(source, self._source(ids=(1,))) + expected_snapshot_id = self._snapshot_id(source) + + fake_ds = Mock() + fake_ds.schema.return_value = pa.schema([ + ('id', pa.int32()), + ('name', pa.string()), + ('age', pa.int32()), + ]) + + with patch( + 'pypaimon.ray.ray_paimon.read_paimon', + return_value=fake_ds, + ) as mock_read_paimon: + m._prepare( + target, source, self.catalog_options, + [WhenMatched(update='*')], [], ['id'], + ) + + mock_read_paimon.assert_called_once_with( + source, + self.catalog_options, + snapshot_id=expected_snapshot_id, + ) + def test_no_clause_raises(self): target = self._create_table() with self.assertRaises(ValueError):