Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
226 changes: 205 additions & 21 deletions src/datajoint/diagram.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,17 @@ def cascade(cls, table_expr, part_integrity="enforce"):
# Propagate downstream
result._propagate_restrictions(node, mode="cascade", part_integrity=part_integrity)

# part_integrity="cascade" may pull in nodes that aren't descendants of
# the seed (e.g. the master of a seed Part, plus the master's other
# Parts). Expand nodes_to_show to include any restricted node and the
# descendants of any newly-restricted ancestor. See #1429.
restricted_nodes = set(result._cascade_restrictions)
expanded = set(result.nodes_to_show) | restricted_nodes
for n in restricted_nodes - result.nodes_to_show:
expanded.update(nx.descendants(result, n))
result.nodes_to_show = expanded & set(result.nodes())
result._expanded_nodes = set(result.nodes_to_show)

# Trim graph to cascade subgraph: only restricted tables
# (seed + descendants) plus alias nodes connecting them.
keep = set(result._cascade_restrictions)
Expand Down Expand Up @@ -443,7 +454,6 @@ def _propagate_restrictions(self, start_node, mode, part_integrity="enforce"):
propagation rules at each edge. Only processes descendants of
start_node to avoid duplicate propagation when chaining.
"""
from .table import FreeTable

sorted_nodes = topo_sort(self)
# Only propagate through descendants of start_node
Expand All @@ -453,6 +463,18 @@ def _propagate_restrictions(self, start_node, mode, part_integrity="enforce"):

restrictions = self._cascade_restrictions if mode == "cascade" else self._restrict_conditions

# Seed-is-Part case: when the seed itself is a Part and part_integrity="cascade",
# the main loop's part_integrity block (which fires inside `out_edges`)
# cannot trigger from the seed because a leaf Part has no out-edges.
# Trigger the upward propagation explicitly for the seed. See #1429.
if part_integrity == "cascade" and mode == "cascade":
seed_master = extract_master(start_node)
if seed_master and seed_master in self.nodes() and seed_master not in visited_masters:
visited_masters.add(seed_master)
if self._propagate_part_to_master(start_node, seed_master, mode, restrictions):
allowed_nodes.add(seed_master)
allowed_nodes.update(nx.descendants(self, seed_master))

# Multiple passes to handle part_integrity="cascade" upward propagation.
# When a part table triggers its master to join the cascade, the master's
# other descendants need processing in a subsequent pass. The loop
Expand Down Expand Up @@ -512,29 +534,19 @@ def _propagate_restrictions(self, start_node, mode, part_integrity="enforce"):
any_new = True

# part_integrity="cascade": propagate up from part to master
# via the actual FK graph path, applying upward propagation
# rules at each edge. Handles Part-of-Part chains and
# renamed FKs (via .proj()), unlike the prior implementation
# which assumed shared PK attribute names. See #1429.
if part_integrity == "cascade" and mode == "cascade":
master_name = extract_master(target)
if (
master_name
and master_name in self.nodes()
and master_name not in restrictions
and master_name not in visited_masters
):
if master_name and master_name in self.nodes() and master_name not in visited_masters:
visited_masters.add(master_name)
child_ft = self._restricted_table(target)
master_ft = FreeTable(self._connection, master_name)
from .condition import make_condition

master_restr = make_condition(
master_ft,
(master_ft.proj() & child_ft.proj()).to_arrays(),
master_ft.restriction_attributes,
)
restrictions[master_name] = [master_restr]
self._restriction_attrs[master_name] = set()
allowed_nodes.add(master_name)
allowed_nodes.update(nx.descendants(self, master_name))
any_new = True
propagated = self._propagate_part_to_master(target, master_name, mode, restrictions)
if propagated:
allowed_nodes.add(master_name)
allowed_nodes.update(nx.descendants(self, master_name))
any_new = True

def _apply_propagation_rule(
self,
Expand Down Expand Up @@ -590,6 +602,178 @@ def _apply_propagation_rule(

self._restriction_attrs.setdefault(child_node, set()).update(child_attrs)

def _apply_propagation_rule_upward(self, child_ft, child_attrs, parent_node, attr_map, aliased, mode, restrictions):
"""
Apply the symmetric (upward) propagation rule to a parent←child edge.

Inverts `_apply_propagation_rule`: derives a restriction on the parent
from a restriction on the child, following the FK chain in reverse.
Used by part_integrity="cascade" to propagate a Part's restriction up
to its Master, transparently handling renamed FKs (via .proj()) and
Part-of-Part chains. See #1429.

Edge metadata convention (matches `_apply_propagation_rule`):
- `attr_map`: dict mapping child column → parent (referenced) column.
- `aliased`: True iff any column was renamed across the FK.

Rules (symmetric to the forward rules in `_apply_propagation_rule`):

1. Non-aliased AND child restriction attrs ⊆ parent PK:
Copy child restriction directly (attrs are shared by name).
2. Aliased FK (attr_map renames columns):
``child.proj(**{parent: child for child, parent in attr_map.items()})``
— reverses the renaming so the result has parent's column names.
3. Non-aliased AND child restriction attrs ⊄ parent PK:
``child.proj()`` — project child to parent's PK columns.
"""
parent_pk = self.nodes[parent_node].get("primary_key", set())

if not aliased and child_attrs and child_attrs <= parent_pk:
# Backward Rule 1: copy child restriction directly
child_restr = restrictions.get(
child_ft.full_table_name,
[] if mode == "cascade" else AndList(),
)
if mode == "cascade":
restrictions.setdefault(parent_node, []).extend(child_restr)
else:
restrictions.setdefault(parent_node, AndList()).extend(child_restr)
parent_attrs = set(child_attrs)
elif aliased:
# Backward Rule 2: reverse rename
parent_item = child_ft.proj(**{pk: fk for fk, pk in attr_map.items()})
if mode == "cascade":
restrictions.setdefault(parent_node, []).append(parent_item)
else:
restrictions.setdefault(parent_node, AndList()).append(parent_item)
parent_attrs = set(attr_map.values()) # parent's PK column names
else:
# Backward Rule 3: project child to parent PK
parent_item = child_ft.proj()
if mode == "cascade":
restrictions.setdefault(parent_node, []).append(parent_item)
else:
restrictions.setdefault(parent_node, AndList()).append(parent_item)
parent_attrs = set(attr_map.values())

self._restriction_attrs.setdefault(parent_node, set()).update(parent_attrs)

def _propagate_part_to_master(self, part_node, master_name, mode, restrictions):
"""
Walk the FK graph from `part_node` up to `master_name`, applying
`_apply_propagation_rule_upward` at each real edge along the path.

Returns True if any propagation occurred. Handles Part-of-Part chains
by walking the full path (intermediate Parts get restricted too) and
renamed FKs via the upward rules.

Alias nodes (integer-named graph nodes inserted for aliased edges)
are transparent — both half-edges carry the same `attr_map` props,
so we read props from one and skip the alias node when walking.

After the walk, the master's restriction is **materialized** to a
literal value tuple via ``to_arrays()``. Without materialization, a
subsequent forward cascade from the master back down to its parts
would produce a self-referential subquery (MySQL error 1093, since
the master's restriction depends on the same Part being deleted).
Materializing converts the restriction into a static value set, so
the forward cascade generates ``WHERE ... IN (literal-list)`` rather
than ``WHERE ... IN (SELECT ... FROM <part>)``.

Limitations
-----------
- **Single FK path**: ``nx.shortest_path`` returns *one* path from
``master_name`` to ``part_node``. If a Part is reachable from its
Master through multiple distinct FK chains (e.g. references two
different intermediate Parts), restrictions through the
non-shortest paths are not applied. This pattern is unusual; if a
schema hits it, the user is responsible for restricting the
additional paths explicitly via ``part_integrity="ignore"`` plus
manual ``delete()`` calls.
- **Memory cost of materialization**: ``master_ft.proj().to_arrays()``
pulls the matching master primary keys into Python memory. Cost is
bounded by the count of *distinct* master rows referenced by the
matching parts — typically small for surgical cascades, but can
grow with bulk cascades on tables with many master rows. Cascade
*preview* (``Diagram.cascade(...).counts()``) pays the same cost.
"""
try:
path = nx.shortest_path(self, master_name, part_node)
except (nx.NetworkXNoPath, nx.NodeNotFound):
return False

# Strip alias nodes; what remains is the sequence of real tables.
real_path = [n for n in path if not (isinstance(n, str) and n.isdigit())]
if len(real_path) < 2 or real_path[-1] != part_node or real_path[0] != master_name:
return False

# Walk real_path in reverse (child → parent direction). For each
# adjacent (parent, child) pair, look up the edge props — direct
# edge if non-aliased, via alias node if aliased.
any_propagated = False
for i in range(len(real_path) - 1, 0, -1):
child = real_path[i]
parent = real_path[i - 1]
edge_props = self._find_real_edge_props(parent, child)
if edge_props is None:
return any_propagated # Path broken (shouldn't happen if shortest_path succeeded)

attr_map = edge_props.get("attr_map", {})
aliased = edge_props.get("aliased", False)
child_ft = self._restricted_table(child)
child_attrs = self._restriction_attrs.get(child, set())

self._apply_propagation_rule_upward(
child_ft,
child_attrs,
parent,
attr_map,
aliased,
mode,
restrictions,
)
any_propagated = True

# Materialize the master's restriction so subsequent forward cascade
# doesn't produce self-referential subqueries. Replace the master's
# accumulated query restrictions with a literal value tuple.
if any_propagated and master_name in restrictions:
from .condition import make_condition
from .table import FreeTable

master_ft = self._restricted_table(master_name)
master_pk_values = master_ft.proj().to_arrays()
if mode == "cascade":
bare_master = FreeTable(self._connection, master_name)
if len(master_pk_values) > 0:
materialized = make_condition(
bare_master,
master_pk_values,
bare_master.restriction_attributes,
)
restrictions[master_name] = [materialized]
else:
# No matching master rows — false restriction so master is
# included with zero matches in counts/iter.
restrictions[master_name] = [False]
self._restriction_attrs.setdefault(master_name, set())

return any_propagated

def _find_real_edge_props(self, parent, child):
"""
Return edge props for parent → child, transparently traversing the
integer-named alias node that the graph inserts for aliased FKs.
Returns None if no such edge or alias-mediated edge exists.
"""
if self.has_edge(parent, child):
return self.edges[parent, child]
for _, mid, _ in self.out_edges(parent, data=True):
if isinstance(mid, str) and mid.isdigit() and self.has_edge(mid, child):
# Both half-edges carry the same attr_map / aliased props
return self.edges[parent, mid]
return None

def counts(self):
"""
Return affected row counts per table without modifying data.
Expand Down
Loading
Loading