diff --git a/airflow-core/newsfragments/67672.improvement.rst b/airflow-core/newsfragments/67672.improvement.rst new file mode 100644 index 0000000000000..16ba1fe0d88f5 --- /dev/null +++ b/airflow-core/newsfragments/67672.improvement.rst @@ -0,0 +1 @@ +The scheduler no longer issues the trigger-rule upstream task-instance count query once per downstream task when the downstream is not inside a mapped task group. For such tasks that share the same upstreams within a scheduling pass, the count is computed once and reused, cutting database round-trips for DAGs where a mapped upstream feeds many downstream tasks. Tasks inside a mapped task group still run the query per task instance. diff --git a/airflow-core/src/airflow/models/dagrun.py b/airflow-core/src/airflow/models/dagrun.py index 36ed309feb0b7..03bc71b868495 100644 --- a/airflow-core/src/airflow/models/dagrun.py +++ b/airflow-core/src/airflow/models/dagrun.py @@ -1520,6 +1520,10 @@ def _expand_mapped_task_if_needed(ti: TI) -> Iterable[TI] | None: if new_tis is not None: additional_tis.extend(new_tis) expansion_happened = True + # Expansion changes a mapped task's instance count, which invalidates the + # trigger-rule upstream-count memo on this DepContext (a downstream evaluated + # later in this same pass must see the post-expansion count). 
+ dep_context.upstream_task_id_counts.clear() if new_tis is None and schedulable.state in SCHEDULEABLE_STATES: # It's enough to revise map index once per task id, # checking the map index for each mapped task significantly slows down scheduling diff --git a/airflow-core/src/airflow/ti_deps/dep_context.py b/airflow-core/src/airflow/ti_deps/dep_context.py index 1feafdd041ae1..fccb78332746f 100644 --- a/airflow-core/src/airflow/ti_deps/dep_context.py +++ b/airflow-core/src/airflow/ti_deps/dep_context.py @@ -85,6 +85,19 @@ class DepContext: have_changed_ti_states: bool = False """Have any of the TIs state's been changed as a result of evaluating dependencies""" + upstream_task_id_counts: dict[tuple[str, str, frozenset[str]], list[tuple[str, int]]] = attr.ib( + factory=dict, init=False + ) + """ + Per-pass memo of the trigger-rule upstream task-instance counts, keyed by + ``(dag_id, run_id, frozenset of direct-upstream task_ids)``. + + Shares the lifetime and snapshot semantics of ``finished_tis`` (one scheduling pass). Only + populated for the "simple" case where the count-query predicate is exactly + ``task_id IN (upstream_ids)`` and is therefore identical for every downstream sharing the same + direct upstreams; the mapped-task-group case uses per-ti map-index predicates and is not cached. + """ + def ensure_finished_tis(self, dag_run: DagRun, session: Session) -> list[TaskInstance]: """ Ensure finished_tis is populated if it's currently None, which allows running tasks without dag_run. 
diff --git a/airflow-core/src/airflow/ti_deps/deps/trigger_rule_dep.py b/airflow-core/src/airflow/ti_deps/deps/trigger_rule_dep.py index 4943913d3283a..044574d0f8902 100644 --- a/airflow-core/src/airflow/ti_deps/deps/trigger_rule_dep.py +++ b/airflow-core/src/airflow/ti_deps/deps/trigger_rule_dep.py @@ -20,7 +20,7 @@ import collections.abc import functools from collections import Counter -from collections.abc import Iterator, KeysView, Mapping, Sequence +from collections.abc import Iterator, KeysView, Mapping from typing import TYPE_CHECKING, NamedTuple from sqlalchemy import and_, func, or_, select @@ -31,7 +31,6 @@ from airflow.utils.state import TaskInstanceState if TYPE_CHECKING: - from sqlalchemy.engine import Row from sqlalchemy.orm import Session from sqlalchemy.sql import ColumnElement @@ -40,7 +39,6 @@ from airflow.serialization.definitions.taskgroup import SerializedMappedTaskGroup from airflow.ti_deps.dep_context import DepContext from airflow.ti_deps.deps.base_ti_dep import TIDepStatus - from airflow.typing_compat import Unpack class _UpstreamTIStates(NamedTuple): @@ -372,13 +370,34 @@ def _evaluate_direct_relatives() -> Iterator[TIDepStatus]: upstream = len(upstream_tasks) upstream_setup = sum(1 for x in upstream_tasks.values() if x.is_setup) else: - # The below type annotation is acceptable on SQLA2.1, but not on 2.0 - task_id_counts: Sequence[Row[Unpack[tuple[str, int]]]] = session.execute( # type: ignore[type-arg] - select(TaskInstance.task_id, func.count(TaskInstance.task_id)) - .where(TaskInstance.dag_id == ti.dag_id, TaskInstance.run_id == ti.run_id) - .where(or_(*_iter_upstream_conditions(relevant_tasks=upstream_tasks))) - .group_by(TaskInstance.task_id) - ).all() + # In the simple case, `_iter_upstream_conditions` emits exactly + # `task_id IN (upstream_task_ids)` (the matching `get_closest_mapped_task_group() + # is None` branch). 
That predicate, and therefore the resulting counts, are + # identical for every downstream that shares the same set of direct upstreams, so + # we memoize them on the DepContext and run the query once per pass instead of + # once per downstream. The mapped-task-group case uses per-ti map-index predicates + # and is left un-memoized. The cache shares finished_tis' per-pass snapshot + # semantics; it is cleared in DagRun._get_ready_tis when a mapped task expands and + # changes its instance count. + cache_key: tuple[str, str, frozenset[str]] | None = None + task_id_counts: list[tuple[str, int]] | None = None + if task.get_closest_mapped_task_group() is None: + cache_key = (ti.dag_id, ti.run_id, frozenset(upstream_tasks)) + task_id_counts = dep_context.upstream_task_id_counts.get(cache_key) + if task_id_counts is None: + task_id_counts = [ + (task_id, count) + for task_id, count in session.execute( + select(TaskInstance.task_id, func.count(TaskInstance.task_id)) + .where(TaskInstance.dag_id == ti.dag_id, TaskInstance.run_id == ti.run_id) + .where(or_(*_iter_upstream_conditions(relevant_tasks=upstream_tasks))) + .group_by(TaskInstance.task_id) + ) + ] + if cache_key is not None: + dep_context.upstream_task_id_counts[cache_key] = task_id_counts + # `task_id_counts` only contains task_ids matched by `task_id IN (upstream_tasks)`, + # so every key is present in `upstream_tasks`; is_setup is re-derived locally. 
upstream = sum(count for _, count in task_id_counts) upstream_setup = sum(c for t, c in task_id_counts if upstream_tasks[t].is_setup) diff --git a/airflow-core/tests/unit/ti_deps/deps/test_trigger_rule_dep.py b/airflow-core/tests/unit/ti_deps/deps/test_trigger_rule_dep.py index f0820eb5b4673..07b10dfe4692e 100644 --- a/airflow-core/tests/unit/ti_deps/deps/test_trigger_rule_dep.py +++ b/airflow-core/tests/unit/ti_deps/deps/test_trigger_rule_dep.py @@ -18,13 +18,16 @@ from __future__ import annotations from collections.abc import Iterator +from contextlib import contextmanager from datetime import datetime from typing import TYPE_CHECKING from unittest import mock from unittest.mock import Mock import pytest +from sqlalchemy import event +import airflow.settings from airflow.models.dag_version import DagVersion from airflow.models.taskinstance import TaskInstance from airflow.providers.standard.operators.empty import EmptyOperator @@ -1997,3 +2000,159 @@ def _test_trigger_rule( else: assert not dep_statuses assert ti.state == expected_ti_state + + +@contextmanager +def _count_upstream_count_queries(): + """ + Count only the trigger-rule upstream-count query. + + That query is ``SELECT task_instance.task_id, count(task_instance.task_id) ... + GROUP BY task_instance.task_id``; the filter below matches it and nothing else emitted while + evaluating the trigger rule for a plain (non-mapped-task-group) downstream. 
+ """ + counter = {"n": 0} + + def _on_execute(conn, cursor, statement, parameters, context, executemany): + sql = statement.lower() + if "count(" in sql and "group by" in sql and "task_id" in sql and "task_instance" in sql: + counter["n"] += 1 + + event.listen(airflow.settings.engine, "after_cursor_execute", _on_execute) + try: + yield counter + finally: + event.remove(airflow.settings.engine, "after_cursor_execute", _on_execute) + + +def _expand_mapped_task(dr, dag, task_id, states, session): + """ + Materialise ``len(states)`` instances of a mapped ``task_id`` with the given states. + + Handles both shapes: a single unexpanded ``map_index=-1`` placeholder (expand it), or a task + already pre-expanded at dagrun creation (just set states on the existing instances). + """ + tis = [ti for ti in dr.get_task_instances(session=session) if ti.task_id == task_id] + assert tis, f"no task instances found for {task_id!r}" + if len(tis) == 1 and tis[0].map_index == -1: + base = tis[0] + mapped_task = base.task + dag_version = DagVersion.get_latest_version(dag.dag_id) + if TYPE_CHECKING: + assert dag_version + base.map_index = 0 + base.state = states[0] + session.merge(base) + for map_index in range(1, len(states)): + ti = TaskInstance( + mapped_task, run_id=dr.run_id, map_index=map_index, dag_version_id=dag_version.id + ) + ti.state = states[map_index] + session.add(ti) + ti.dag_run = dr + else: + tis.sort(key=lambda ti: ti.map_index) + assert len(tis) == len(states), f"{task_id!r}: {len(tis)} instances but {len(states)} states given" + for ti, state in zip(tis, states): + ti.state = state + session.merge(ti) + session.flush() + + +class TestTriggerRuleUpstreamCountMemo: + """The upstream-count query is memoized per scheduling pass (one DepContext) in the simple case.""" + + def _make_dag( + self, dag_maker, session, *, n_downstreams, src_states, trigger_rule=TriggerRule.ALL_SUCCESS + ): + @task + def src(i): + return i + + @task(trigger_rule=trigger_rule) + def plain(): + 
return 1 + + with dag_maker(dag_id="trmemo_simple", session=session) as dag: + nums = src.expand(i=list(range(len(src_states)))) + for k in range(n_downstreams): + nums >> plain.override(task_id=f"p{k}")() + + dr = dag_maker.create_dagrun() + _expand_mapped_task(dr, dag, "src", src_states, session) + session.commit() + return dr + + def test_memoized_across_downstreams_sharing_upstream(self, dag_maker, session): + """N plain downstreams of the same mapped upstream issue the count query once per pass.""" + dr = self._make_dag(dag_maker, session, n_downstreams=4, src_states=[SUCCESS, SUCCESS, SUCCESS]) + dep_context = DepContext() + with _count_upstream_count_queries() as counter: + for k in range(4): + ti = dr.get_task_instance(f"p{k}", session=session) + statuses = list( + TriggerRuleDep()._evaluate_trigger_rule(ti=ti, dep_context=dep_context, session=session) + ) + # All three upstreams succeeded -> ALL_SUCCESS is met -> no failing status. + assert statuses == [] + assert counter["n"] == 1 + + def test_memoized_count_value_is_correct(self, dag_maker, session): + """ + Guards that the cached value is the real count, not just "present". + + Three upstream instances exist but only two are finished-success; ALL_SUCCESS must NOT be met + because ``upstream`` (3) > ``success`` (2). A wrongly-cached count of 2 would let it pass. 
+ """ + dr = self._make_dag( + dag_maker, + session, + n_downstreams=2, + src_states=[SUCCESS, SUCCESS, TaskInstanceState.RUNNING], + ) + dep_context = DepContext() + with _count_upstream_count_queries() as counter: + for k in range(2): + ti = dr.get_task_instance(f"p{k}", session=session) + statuses = list( + TriggerRuleDep()._evaluate_trigger_rule(ti=ti, dep_context=dep_context, session=session) + ) + assert len(statuses) == 1 + assert not statuses[0].passed + assert counter["n"] == 1 + + def test_distinct_upstream_sets_are_not_collapsed(self, dag_maker, session): + """Downstreams with different upstream sets get different cache keys -> one query each.""" + + @task + def src_a(i): + return i + + @task + def src_b(i): + return i + + @task + def plain(): + return 1 + + with dag_maker(dag_id="trmemo_keys", session=session) as dag: + a = src_a.expand(i=[0, 1]) + b = src_b.expand(i=[0, 1, 2]) + a >> plain.override(task_id="pa")() + b >> plain.override(task_id="pb")() + + dr = dag_maker.create_dagrun() + _expand_mapped_task(dr, dag, "src_a", [SUCCESS, SUCCESS], session) + _expand_mapped_task(dr, dag, "src_b", [SUCCESS, SUCCESS, SUCCESS], session) + session.commit() + + dep_context = DepContext() + with _count_upstream_count_queries() as counter: + for task_id in ("pa", "pb"): + ti = dr.get_task_instance(task_id, session=session) + statuses = list( + TriggerRuleDep()._evaluate_trigger_rule(ti=ti, dep_context=dep_context, session=session) + ) + assert statuses == [] + assert counter["n"] == 2