Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions airflow-core/newsfragments/67672.improvement.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
The scheduler no longer issues the trigger-rule upstream task-instance count query once per downstream task. For tasks that share the same upstreams within a scheduling pass, the count is computed once and reused, cutting database round-trips for DAGs where a mapped upstream feeds many downstream tasks.
4 changes: 4 additions & 0 deletions airflow-core/src/airflow/models/dagrun.py
Original file line number Diff line number Diff line change
Expand Up @@ -1520,6 +1520,10 @@ def _expand_mapped_task_if_needed(ti: TI) -> Iterable[TI] | None:
if new_tis is not None:
additional_tis.extend(new_tis)
expansion_happened = True
# Expansion changes a mapped task's instance count, which invalidates the
# trigger-rule upstream-count memo on this DepContext (a downstream evaluated
# later in this same pass must see the post-expansion count).
dep_context.upstream_task_id_counts.clear()
if new_tis is None and schedulable.state in SCHEDULEABLE_STATES:
# It's enough to revise map index once per task id,
# checking the map index for each mapped task significantly slows down scheduling
Expand Down
13 changes: 13 additions & 0 deletions airflow-core/src/airflow/ti_deps/dep_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,19 @@ class DepContext:
have_changed_ti_states: bool = False
"""Have any of the TIs state's been changed as a result of evaluating dependencies"""

upstream_task_id_counts: dict[tuple[str, str, frozenset[str]], list[tuple[str, int]]] = attr.ib(
factory=dict, init=False
)
"""
Per-pass memo of the trigger-rule upstream task-instance counts, keyed by
``(dag_id, run_id, frozenset of direct-upstream task_ids)``.

Shares the lifetime and snapshot semantics of ``finished_tis`` (one scheduling pass). Only
populated for the "simple" case where the count-query predicate is exactly
``task_id IN (upstream_ids)`` and is therefore identical for every downstream sharing the same
direct upstreams; the mapped-task-group case uses per-ti map-index predicates and is not cached.
"""

def ensure_finished_tis(self, dag_run: DagRun, session: Session) -> list[TaskInstance]:
"""
Ensure finished_tis is populated if it's currently None, which allows running tasks without dag_run.
Expand Down
39 changes: 29 additions & 10 deletions airflow-core/src/airflow/ti_deps/deps/trigger_rule_dep.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import collections.abc
import functools
from collections import Counter
from collections.abc import Iterator, KeysView, Mapping, Sequence
from collections.abc import Iterator, KeysView, Mapping
from typing import TYPE_CHECKING, NamedTuple

from sqlalchemy import and_, func, or_, select
Expand All @@ -31,7 +31,6 @@
from airflow.utils.state import TaskInstanceState

if TYPE_CHECKING:
from sqlalchemy.engine import Row
from sqlalchemy.orm import Session
from sqlalchemy.sql import ColumnElement

Expand All @@ -40,7 +39,6 @@
from airflow.serialization.definitions.taskgroup import SerializedMappedTaskGroup
from airflow.ti_deps.dep_context import DepContext
from airflow.ti_deps.deps.base_ti_dep import TIDepStatus
from airflow.typing_compat import Unpack


class _UpstreamTIStates(NamedTuple):
Expand Down Expand Up @@ -372,13 +370,34 @@ def _evaluate_direct_relatives() -> Iterator[TIDepStatus]:
upstream = len(upstream_tasks)
upstream_setup = sum(1 for x in upstream_tasks.values() if x.is_setup)
else:
# The below type annotation is acceptable on SQLA2.1, but not on 2.0
task_id_counts: Sequence[Row[Unpack[tuple[str, int]]]] = session.execute( # type: ignore[type-arg]
select(TaskInstance.task_id, func.count(TaskInstance.task_id))
.where(TaskInstance.dag_id == ti.dag_id, TaskInstance.run_id == ti.run_id)
.where(or_(*_iter_upstream_conditions(relevant_tasks=upstream_tasks)))
.group_by(TaskInstance.task_id)
).all()
# In the simple case, `_iter_upstream_conditions` emits exactly
# `task_id IN (upstream_task_ids)` (the matching `get_closest_mapped_task_group()
# is None` branch). That predicate, and therefore the resulting counts, are
# identical for every downstream that shares the same set of direct upstreams, so
# we memoize them on the DepContext and run the query once per pass instead of
# once per downstream. The mapped-task-group case uses per-ti map-index predicates
# and is left un-memoized. The cache shares finished_tis' per-pass snapshot
# semantics; it is cleared in DagRun._get_ready_tis when a mapped task expands and
# changes its instance count.
cache_key: tuple[str, str, frozenset[str]] | None = None
task_id_counts: list[tuple[str, int]] | None = None
if task.get_closest_mapped_task_group() is None:
cache_key = (ti.dag_id, ti.run_id, frozenset(upstream_tasks))
task_id_counts = dep_context.upstream_task_id_counts.get(cache_key)
if task_id_counts is None:
task_id_counts = [
(task_id, count)
for task_id, count in session.execute(
select(TaskInstance.task_id, func.count(TaskInstance.task_id))
.where(TaskInstance.dag_id == ti.dag_id, TaskInstance.run_id == ti.run_id)
.where(or_(*_iter_upstream_conditions(relevant_tasks=upstream_tasks)))
.group_by(TaskInstance.task_id)
)
]
if cache_key is not None:
dep_context.upstream_task_id_counts[cache_key] = task_id_counts
# `task_id_counts` only contains task_ids matched by `task_id IN (upstream_tasks)`,
# so every key is present in `upstream_tasks`; is_setup is re-derived locally.
upstream = sum(count for _, count in task_id_counts)
upstream_setup = sum(c for t, c in task_id_counts if upstream_tasks[t].is_setup)

Expand Down
159 changes: 159 additions & 0 deletions airflow-core/tests/unit/ti_deps/deps/test_trigger_rule_dep.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,16 @@
from __future__ import annotations

from collections.abc import Iterator
from contextlib import contextmanager
from datetime import datetime
from typing import TYPE_CHECKING
from unittest import mock
from unittest.mock import Mock

import pytest
from sqlalchemy import event

import airflow.settings
from airflow.models.dag_version import DagVersion
from airflow.models.taskinstance import TaskInstance
from airflow.providers.standard.operators.empty import EmptyOperator
Expand Down Expand Up @@ -1997,3 +2000,159 @@ def _test_trigger_rule(
else:
assert not dep_statuses
assert ti.state == expected_ti_state


@contextmanager
def _count_upstream_count_queries():
"""
Count only the trigger-rule upstream-count query.

That query is ``SELECT task_instance.task_id, count(task_instance.task_id) ...
GROUP BY task_instance.task_id``; the filter below matches it and nothing else emitted while
evaluating the trigger rule for a plain (non-mapped-task-group) downstream.
"""
counter = {"n": 0}

def _on_execute(conn, cursor, statement, parameters, context, executemany):
sql = statement.lower()
if "count(" in sql and "group by" in sql and "task_id" in sql and "task_instance" in sql:
counter["n"] += 1

event.listen(airflow.settings.engine, "after_cursor_execute", _on_execute)
try:
yield counter
finally:
event.remove(airflow.settings.engine, "after_cursor_execute", _on_execute)


def _expand_mapped_task(dr, dag, task_id, states, session):
"""
Materialise ``len(states)`` instances of a mapped ``task_id`` with the given states.

Handles both shapes: a single unexpanded ``map_index=-1`` placeholder (expand it), or a task
already pre-expanded at dagrun creation (just set states on the existing instances).
"""
tis = [ti for ti in dr.get_task_instances(session=session) if ti.task_id == task_id]
assert tis, f"no task instances found for {task_id!r}"
if len(tis) == 1 and tis[0].map_index == -1:
base = tis[0]
mapped_task = base.task
dag_version = DagVersion.get_latest_version(dag.dag_id)
if TYPE_CHECKING:
assert dag_version
base.map_index = 0
base.state = states[0]
session.merge(base)
for map_index in range(1, len(states)):
ti = TaskInstance(
mapped_task, run_id=dr.run_id, map_index=map_index, dag_version_id=dag_version.id
)
ti.state = states[map_index]
session.add(ti)
ti.dag_run = dr
else:
tis.sort(key=lambda ti: ti.map_index)
assert len(tis) == len(states), f"{task_id!r}: {len(tis)} instances but {len(states)} states given"
for ti, state in zip(tis, states):
ti.state = state
session.merge(ti)
session.flush()


class TestTriggerRuleUpstreamCountMemo:
"""The upstream-count query is memoized per scheduling pass (one DepContext) in the simple case."""

def _make_dag(
self, dag_maker, session, *, n_downstreams, src_states, trigger_rule=TriggerRule.ALL_SUCCESS
):
@task
def src(i):
return i

@task(trigger_rule=trigger_rule)
def plain():
return 1

with dag_maker(dag_id="trmemo_simple", session=session) as dag:
nums = src.expand(i=list(range(len(src_states))))
for k in range(n_downstreams):
nums >> plain.override(task_id=f"p{k}")()

dr = dag_maker.create_dagrun()
_expand_mapped_task(dr, dag, "src", src_states, session)
session.commit()
return dr

def test_memoized_across_downstreams_sharing_upstream(self, dag_maker, session):
"""N plain downstreams of the same mapped upstream issue the count query once per pass."""
dr = self._make_dag(dag_maker, session, n_downstreams=4, src_states=[SUCCESS, SUCCESS, SUCCESS])
dep_context = DepContext()
with _count_upstream_count_queries() as counter:
for k in range(4):
ti = dr.get_task_instance(f"p{k}", session=session)
statuses = list(
TriggerRuleDep()._evaluate_trigger_rule(ti=ti, dep_context=dep_context, session=session)
)
# All three upstreams succeeded -> ALL_SUCCESS is met -> no failing status.
assert statuses == []
assert counter["n"] == 1

def test_memoized_count_value_is_correct(self, dag_maker, session):
"""
Guards that the cached value is the real count, not just "present".

Three upstream instances exist but only two are finished-success; ALL_SUCCESS must NOT be met
because ``upstream`` (3) > ``success`` (2). A wrongly-cached count of 2 would let it pass.
"""
dr = self._make_dag(
dag_maker,
session,
n_downstreams=2,
src_states=[SUCCESS, SUCCESS, TaskInstanceState.RUNNING],
)
dep_context = DepContext()
with _count_upstream_count_queries() as counter:
for k in range(2):
ti = dr.get_task_instance(f"p{k}", session=session)
statuses = list(
TriggerRuleDep()._evaluate_trigger_rule(ti=ti, dep_context=dep_context, session=session)
)
assert len(statuses) == 1
assert not statuses[0].passed
assert counter["n"] == 1

def test_distinct_upstream_sets_are_not_collapsed(self, dag_maker, session):
"""Downstreams with different upstream sets get different cache keys -> one query each."""

@task
def src_a(i):
return i

@task
def src_b(i):
return i

@task
def plain():
return 1

with dag_maker(dag_id="trmemo_keys", session=session) as dag:
a = src_a.expand(i=[0, 1])
b = src_b.expand(i=[0, 1, 2])
a >> plain.override(task_id="pa")()
b >> plain.override(task_id="pb")()

dr = dag_maker.create_dagrun()
_expand_mapped_task(dr, dag, "src_a", [SUCCESS, SUCCESS], session)
_expand_mapped_task(dr, dag, "src_b", [SUCCESS, SUCCESS, SUCCESS], session)
session.commit()

dep_context = DepContext()
with _count_upstream_count_queries() as counter:
for task_id in ("pa", "pb"):
ti = dr.get_task_instance(task_id, session=session)
statuses = list(
TriggerRuleDep()._evaluate_trigger_rule(ti=ti, dep_context=dep_context, session=session)
)
assert statuses == []
assert counter["n"] == 2
Loading