From 67eb1f8003e83d117cd780465312670b17af749b Mon Sep 17 00:00:00 2001 From: Kaxil Naik Date: Fri, 29 May 2026 01:46:25 +0100 Subject: [PATCH 1/2] Avoid duplicate trigger-rule upstream-count queries per scheduling pass TriggerRuleDep runs a `SELECT task_id, count(*) ... GROUP BY task_id` per downstream task to size its upstream set, but only when an upstream is mapped. When many downstreams share the same mapped upstream, each issues an identical query within the same scheduling pass. Memoize the result on DepContext (one scheduling pass, same lifetime as finished_tis), keyed by (dag_id, run_id, frozenset of direct-upstream task_ids). Only the simple case is cached, where the predicate is exactly `task_id IN (upstream_ids)`; downstreams inside a mapped task group keep their own per-instance map-index query. The cache is cleared in _get_ready_tis when a mapped task expands and changes its instance count. --- airflow-core/src/airflow/models/dagrun.py | 4 + .../src/airflow/ti_deps/dep_context.py | 13 ++ .../airflow/ti_deps/deps/trigger_rule_dep.py | 39 +++-- .../ti_deps/deps/test_trigger_rule_dep.py | 159 ++++++++++++++++++ 4 files changed, 205 insertions(+), 10 deletions(-) diff --git a/airflow-core/src/airflow/models/dagrun.py b/airflow-core/src/airflow/models/dagrun.py index 36ed309feb0b7..03bc71b868495 100644 --- a/airflow-core/src/airflow/models/dagrun.py +++ b/airflow-core/src/airflow/models/dagrun.py @@ -1520,6 +1520,10 @@ def _expand_mapped_task_if_needed(ti: TI) -> Iterable[TI] | None: if new_tis is not None: additional_tis.extend(new_tis) expansion_happened = True + # Expansion changes a mapped task's instance count, which invalidates the + # trigger-rule upstream-count memo on this DepContext (a downstream evaluated + # later in this same pass must see the post-expansion count). + dep_context.upstream_task_id_counts.clear() if new_tis is None and schedulable.state in SCHEDULEABLE_STATES: # It's enough to revise map index once per task id, # checking the map index for each mapped task significantly slows down scheduling diff --git a/airflow-core/src/airflow/ti_deps/dep_context.py b/airflow-core/src/airflow/ti_deps/dep_context.py index 1feafdd041ae1..fccb78332746f 100644 --- a/airflow-core/src/airflow/ti_deps/dep_context.py +++ b/airflow-core/src/airflow/ti_deps/dep_context.py @@ -85,6 +85,19 @@ class DepContext: have_changed_ti_states: bool = False """Have any of the TIs state's been changed as a result of evaluating dependencies""" + upstream_task_id_counts: dict[tuple[str, str, frozenset[str]], list[tuple[str, int]]] = attr.ib( + factory=dict, init=False + ) + """ + Per-pass memo of the trigger-rule upstream task-instance counts, keyed by + ``(dag_id, run_id, frozenset of direct-upstream task_ids)``. + + Shares the lifetime and snapshot semantics of ``finished_tis`` (one scheduling pass). Only + populated for the "simple" case where the count-query predicate is exactly + ``task_id IN (upstream_ids)`` and is therefore identical for every downstream sharing the same + direct upstreams; the mapped-task-group case uses per-ti map-index predicates and is not cached. + """ + def ensure_finished_tis(self, dag_run: DagRun, session: Session) -> list[TaskInstance]: """ Ensure finished_tis is populated if it's currently None, which allows running tasks without dag_run. diff --git a/airflow-core/src/airflow/ti_deps/deps/trigger_rule_dep.py b/airflow-core/src/airflow/ti_deps/deps/trigger_rule_dep.py index 4943913d3283a..044574d0f8902 100644 --- a/airflow-core/src/airflow/ti_deps/deps/trigger_rule_dep.py +++ b/airflow-core/src/airflow/ti_deps/deps/trigger_rule_dep.py @@ -20,7 +20,7 @@ import collections.abc import functools from collections import Counter -from collections.abc import Iterator, KeysView, Mapping, Sequence +from collections.abc import Iterator, KeysView, Mapping from typing import TYPE_CHECKING, NamedTuple from sqlalchemy import and_, func, or_, select @@ -31,7 +31,6 @@ from airflow.utils.state import TaskInstanceState if TYPE_CHECKING: - from sqlalchemy.engine import Row from sqlalchemy.orm import Session from sqlalchemy.sql import ColumnElement @@ -40,7 +39,6 @@ from airflow.serialization.definitions.taskgroup import SerializedMappedTaskGroup from airflow.ti_deps.dep_context import DepContext from airflow.ti_deps.deps.base_ti_dep import TIDepStatus - from airflow.typing_compat import Unpack class _UpstreamTIStates(NamedTuple): @@ -372,13 +370,34 @@ def _evaluate_direct_relatives() -> Iterator[TIDepStatus]: upstream = len(upstream_tasks) upstream_setup = sum(1 for x in upstream_tasks.values() if x.is_setup) else: - # The below type annotation is acceptable on SQLA2.1, but not on 2.0 - task_id_counts: Sequence[Row[Unpack[tuple[str, int]]]] = session.execute( # type: ignore[type-arg] - select(TaskInstance.task_id, func.count(TaskInstance.task_id)) - .where(TaskInstance.dag_id == ti.dag_id, TaskInstance.run_id == ti.run_id) - .where(or_(*_iter_upstream_conditions(relevant_tasks=upstream_tasks))) - .group_by(TaskInstance.task_id) - ).all() + # In the simple case, `_iter_upstream_conditions` emits exactly + # `task_id IN (upstream_task_ids)` (the matching `get_closest_mapped_task_group() + # is None` branch). That predicate, and therefore the resulting counts, are + # identical for every downstream that shares the same set of direct upstreams, so + # we memoize them on the DepContext and run the query once per pass instead of + # once per downstream. The mapped-task-group case uses per-ti map-index predicates + # and is left un-memoized. The cache shares finished_tis' per-pass snapshot + # semantics; it is cleared in DagRun._get_ready_tis when a mapped task expands and + # changes its instance count. + cache_key: tuple[str, str, frozenset[str]] | None = None + task_id_counts: list[tuple[str, int]] | None = None + if task.get_closest_mapped_task_group() is None: + cache_key = (ti.dag_id, ti.run_id, frozenset(upstream_tasks)) + task_id_counts = dep_context.upstream_task_id_counts.get(cache_key) + if task_id_counts is None: + task_id_counts = [ + (task_id, count) + for task_id, count in session.execute( + select(TaskInstance.task_id, func.count(TaskInstance.task_id)) + .where(TaskInstance.dag_id == ti.dag_id, TaskInstance.run_id == ti.run_id) + .where(or_(*_iter_upstream_conditions(relevant_tasks=upstream_tasks))) + .group_by(TaskInstance.task_id) + ) + ] + if cache_key is not None: + dep_context.upstream_task_id_counts[cache_key] = task_id_counts + # `task_id_counts` only contains task_ids matched by `task_id IN (upstream_tasks)`, + # so every key is present in `upstream_tasks`; is_setup is re-derived locally. upstream = sum(count for _, count in task_id_counts) upstream_setup = sum(c for t, c in task_id_counts if upstream_tasks[t].is_setup) diff --git a/airflow-core/tests/unit/ti_deps/deps/test_trigger_rule_dep.py b/airflow-core/tests/unit/ti_deps/deps/test_trigger_rule_dep.py index f0820eb5b4673..07b10dfe4692e 100644 --- a/airflow-core/tests/unit/ti_deps/deps/test_trigger_rule_dep.py +++ b/airflow-core/tests/unit/ti_deps/deps/test_trigger_rule_dep.py @@ -18,13 +18,16 @@ from __future__ import annotations from collections.abc import Iterator +from contextlib import contextmanager from datetime import datetime from typing import TYPE_CHECKING from unittest import mock from unittest.mock import Mock import pytest +from sqlalchemy import event +import airflow.settings from airflow.models.dag_version import DagVersion from airflow.models.taskinstance import TaskInstance from airflow.providers.standard.operators.empty import EmptyOperator @@ -1997,3 +2000,159 @@ def _test_trigger_rule( else: assert not dep_statuses assert ti.state == expected_ti_state + + +@contextmanager +def _count_upstream_count_queries(): + """ + Count only the trigger-rule upstream-count query. + + That query is ``SELECT task_instance.task_id, count(task_instance.task_id) ... + GROUP BY task_instance.task_id``; the filter below matches it and nothing else emitted while + evaluating the trigger rule for a plain (non-mapped-task-group) downstream. + """ + counter = {"n": 0} + + def _on_execute(conn, cursor, statement, parameters, context, executemany): + sql = statement.lower() + if "count(" in sql and "group by" in sql and "task_id" in sql and "task_instance" in sql: + counter["n"] += 1 + + event.listen(airflow.settings.engine, "after_cursor_execute", _on_execute) + try: + yield counter + finally: + event.remove(airflow.settings.engine, "after_cursor_execute", _on_execute) + + +def _expand_mapped_task(dr, dag, task_id, states, session): + """ + Materialise ``len(states)`` instances of a mapped ``task_id`` with the given states. + + Handles both shapes: a single unexpanded ``map_index=-1`` placeholder (expand it), or a task + already pre-expanded at dagrun creation (just set states on the existing instances). + """ + tis = [ti for ti in dr.get_task_instances(session=session) if ti.task_id == task_id] + assert tis, f"no task instances found for {task_id!r}" + if len(tis) == 1 and tis[0].map_index == -1: + base = tis[0] + mapped_task = base.task + dag_version = DagVersion.get_latest_version(dag.dag_id) + if TYPE_CHECKING: + assert dag_version + base.map_index = 0 + base.state = states[0] + session.merge(base) + for map_index in range(1, len(states)): + ti = TaskInstance( + mapped_task, run_id=dr.run_id, map_index=map_index, dag_version_id=dag_version.id + ) + ti.state = states[map_index] + session.add(ti) + ti.dag_run = dr + else: + tis.sort(key=lambda ti: ti.map_index) + assert len(tis) == len(states), f"{task_id!r}: {len(tis)} instances but {len(states)} states given" + for ti, state in zip(tis, states): + ti.state = state + session.merge(ti) + session.flush() + + +class TestTriggerRuleUpstreamCountMemo: + """The upstream-count query is memoized per scheduling pass (one DepContext) in the simple case.""" + + def _make_dag( + self, dag_maker, session, *, n_downstreams, src_states, trigger_rule=TriggerRule.ALL_SUCCESS + ): + @task + def src(i): + return i + + @task(trigger_rule=trigger_rule) + def plain(): + return 1 + + with dag_maker(dag_id="trmemo_simple", session=session) as dag: + nums = src.expand(i=list(range(len(src_states)))) + for k in range(n_downstreams): + nums >> plain.override(task_id=f"p{k}")() + + dr = dag_maker.create_dagrun() + _expand_mapped_task(dr, dag, "src", src_states, session) + session.commit() + return dr + + def test_memoized_across_downstreams_sharing_upstream(self, dag_maker, session): + """N plain downstreams of the same mapped upstream issue the count query once per pass.""" + dr = self._make_dag(dag_maker, session, n_downstreams=4, src_states=[SUCCESS, SUCCESS, SUCCESS]) + dep_context = DepContext() + with _count_upstream_count_queries() as counter: + for k in range(4): + ti = dr.get_task_instance(f"p{k}", session=session) + statuses = list( + TriggerRuleDep()._evaluate_trigger_rule(ti=ti, dep_context=dep_context, session=session) + ) + # All three upstreams succeeded -> ALL_SUCCESS is met -> no failing status. + assert statuses == [] + assert counter["n"] == 1 + + def test_memoized_count_value_is_correct(self, dag_maker, session): + """ + Guards that the cached value is the real count, not just "present". + + Three upstream instances exist but only two are finished-success; ALL_SUCCESS must NOT be met + because ``upstream`` (3) > ``success`` (2). A wrongly-cached count of 2 would let it pass. + """ + dr = self._make_dag( + dag_maker, + session, + n_downstreams=2, + src_states=[SUCCESS, SUCCESS, TaskInstanceState.RUNNING], + ) + dep_context = DepContext() + with _count_upstream_count_queries() as counter: + for k in range(2): + ti = dr.get_task_instance(f"p{k}", session=session) + statuses = list( + TriggerRuleDep()._evaluate_trigger_rule(ti=ti, dep_context=dep_context, session=session) + ) + assert len(statuses) == 1 + assert not statuses[0].passed + assert counter["n"] == 1 + + def test_distinct_upstream_sets_are_not_collapsed(self, dag_maker, session): + """Downstreams with different upstream sets get different cache keys -> one query each.""" + + @task + def src_a(i): + return i + + @task + def src_b(i): + return i + + @task + def plain(): + return 1 + + with dag_maker(dag_id="trmemo_keys", session=session) as dag: + a = src_a.expand(i=[0, 1]) + b = src_b.expand(i=[0, 1, 2]) + a >> plain.override(task_id="pa")() + b >> plain.override(task_id="pb")() + + dr = dag_maker.create_dagrun() + _expand_mapped_task(dr, dag, "src_a", [SUCCESS, SUCCESS], session) + _expand_mapped_task(dr, dag, "src_b", [SUCCESS, SUCCESS, SUCCESS], session) + session.commit() + + dep_context = DepContext() + with _count_upstream_count_queries() as counter: + for task_id in ("pa", "pb"): + ti = dr.get_task_instance(task_id, session=session) + statuses = list( + TriggerRuleDep()._evaluate_trigger_rule(ti=ti, dep_context=dep_context, session=session) + ) + assert statuses == [] + assert counter["n"] == 2 From 0690b69144a5535831c4987136a1fd45e5968528 Mon Sep 17 00:00:00 2001 From: Kaxil Naik Date: Fri, 29 May 2026 01:59:11 +0100 Subject: [PATCH 2/2] Add newsfragment for trigger-rule upstream-count memoization --- airflow-core/newsfragments/67672.improvement.rst | 1 + 1 file changed, 1 insertion(+) create mode 100644 airflow-core/newsfragments/67672.improvement.rst diff --git a/airflow-core/newsfragments/67672.improvement.rst b/airflow-core/newsfragments/67672.improvement.rst new file mode 100644 index 0000000000000..16ba1fe0d88f5 --- /dev/null +++ b/airflow-core/newsfragments/67672.improvement.rst @@ -0,0 +1 @@ +The scheduler no longer issues the trigger-rule upstream task-instance count query once per downstream task. For tasks that share the same upstreams within a scheduling pass, the count is computed once and reused, cutting database round-trips for DAGs where a mapped upstream feeds many downstream tasks.