Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 22 additions & 3 deletions paimon-python/pypaimon/ray/data_evolution_merge_into.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,19 @@ def _prepare(target, source, catalog_options, when_matched, when_not_matched, on
for c in when_not_matched
]

source_ds = _normalize_source(source, catalog_options)
source_snapshot_id = None
if isinstance(source, str):
source_snapshot = (
catalog.get_table(source)
.snapshot_manager()
.get_latest_snapshot()
)
if source_snapshot is not None:
source_snapshot_id = source_snapshot.id

source_ds = _normalize_source(
source, catalog_options, source_snapshot_id=source_snapshot_id,
)
_validate_source_on_cols(source_ds, source_on_cols)
_validate_source_has_target_cols(
source_ds, settable_field_names, on_map,
Expand Down Expand Up @@ -438,14 +450,21 @@ def _normalize_set_spec(
return {col: f"s.{on_map.get(col, col)}" for col in target_field_names}


def _normalize_source(source: Any, catalog_options: Dict[str, str]):
def _normalize_source(
source: Any,
catalog_options: Dict[str, str],
source_snapshot_id: Optional[int] = None,
):
import ray.data

if isinstance(source, ray.data.Dataset):
return source
if isinstance(source, str):
from pypaimon.ray.ray_paimon import read_paimon
return read_paimon(source, catalog_options)
read_kwargs = {}
if source_snapshot_id is not None:
read_kwargs["snapshot_id"] = source_snapshot_id
return read_paimon(source, catalog_options, **read_kwargs)
if isinstance(source, pa.Table):
return ray.data.from_arrow(source)
try:
Expand Down
31 changes: 31 additions & 0 deletions paimon-python/pypaimon/tests/ray_data_evolution_merge_into_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import tempfile
import unittest
import uuid
from unittest.mock import Mock, patch

import pyarrow as pa
import ray
Expand Down Expand Up @@ -108,6 +109,36 @@ def _snapshot_id(self, target):
snap = table.snapshot_manager().get_latest_snapshot()
return snap.id if snap is not None else None

def test_paimon_source_table_pins_snapshot(self):
    """Check that a table-name source is read with a pinned snapshot id.

    The source table's latest snapshot id is captured before calling
    ``_prepare``; ``read_paimon`` must then be invoked exactly once with
    that id passed as ``snapshot_id``.
    """
    from pypaimon.ray import data_evolution_merge_into as merge_module

    target = self._create_table()
    source = self._create_table()
    self._write(source, self._source(ids=(1,)))
    pinned_snapshot_id = self._snapshot_id(source)

    # Stand-in for the dataset returned by ``read_paimon``: only its
    # ``schema()`` call is stubbed, everything else is a plain Mock.
    source_schema = pa.schema([
        ('id', pa.int32()),
        ('name', pa.string()),
        ('age', pa.int32()),
    ])
    stub_dataset = Mock()
    stub_dataset.schema.return_value = source_schema

    # Patch the function in its defining module so the call made while
    # preparing the merge is recorded instead of reading the table.
    read_patch = patch(
        'pypaimon.ray.ray_paimon.read_paimon',
        return_value=stub_dataset,
    )
    with read_patch as mock_read_paimon:
        merge_module._prepare(
            target,
            source,
            self.catalog_options,
            [WhenMatched(update='*')],
            [],
            ['id'],
        )

    mock_read_paimon.assert_called_once_with(
        source,
        self.catalog_options,
        snapshot_id=pinned_snapshot_id,
    )

def test_no_clause_raises(self):
target = self._create_table()
with self.assertRaises(ValueError):
Expand Down
Loading