@@ -47,13 +47,15 @@
from airflow.models import Connection, DagModel, Pool, Variable
from airflow.models.dagbundle import DagBundleModel
from airflow.models.team import Team, dag_bundle_team_association_table
from airflow.typing_compat import Unpack
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.session import NEW_SESSION, provide_session

if TYPE_CHECKING:
from collections.abc import Sequence

from fastapi import FastAPI
from sqlalchemy import Row
from sqlalchemy.orm import Session

from airflow.api_fastapi.auth.managers.models.batch_apis import (
@@ -569,8 +571,9 @@ def get_authorized_dag_ids(
isouter=True,
)
)
rows = session.execute(stmt).all()
dags_by_team: dict[str | None, set[str]] = defaultdict(set)
# The below type annotation is acceptable on SQLA 2.1, but not on 2.0
rows: Sequence[Row[Unpack[tuple[str, str]]]] = session.execute(stmt).all() # type: ignore[type-arg]
dags_by_team: dict[str, set[str]] = defaultdict(set)
for dag_id, team_name in rows:
dags_by_team[team_name].add(dag_id)

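Since the Unpack-based annotation recurs in most files below, a minimal self-contained sketch of the pattern may help; the query here is hypothetical, and it assumes SQLAlchemy 2.0 at runtime while the variadic Row generics are only accepted by the 2.1 stubs (hence the targeted ignore):

    from __future__ import annotations

    from typing import TYPE_CHECKING

    from sqlalchemy import create_engine, select, text
    from sqlalchemy.orm import Session

    if TYPE_CHECKING:
        from collections.abc import Sequence

        from sqlalchemy import Row
        from typing_extensions import Unpack

    engine = create_engine("sqlite://")

    with Session(engine) as session:
        # Row[Unpack[tuple[str, str]]] is the 2.1 spelling of "a row of two
        # strings"; the 2.0 stubs expect a single tuple type argument, so the
        # annotation carries type: ignore[type-arg] to satisfy both versions.
        rows: Sequence[Row[Unpack[tuple[str, str]]]] = session.execute(  # type: ignore[type-arg]
            select(text("'a'"), text("'b'"))
        ).all()
        for first, second in rows:
            print(first, second)
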
@@ -74,10 +74,12 @@
AssetWatcherModel,
TaskOutletAssetReference,
)
from airflow.typing_compat import Unpack
from airflow.utils.state import DagRunState
from airflow.utils.types import DagRunTriggeredByType, DagRunType

if TYPE_CHECKING:
from sqlalchemy.engine import Result
from sqlalchemy.sql import Select

assets_router = AirflowRouter(tags=["Asset"])
@@ -179,7 +181,8 @@ def get_assets(
session=session,
)

assets_rows = session.execute(
# The below type annotation is acceptable on SQLA 2.1, but not on 2.0
assets_rows: Result[Unpack[tuple[AssetModel, int, datetime]]] = session.execute( # type: ignore[type-arg]
assets_select.options(
subqueryload(AssetModel.scheduled_dags),
subqueryload(AssetModel.producing_tasks),
@@ -17,7 +17,7 @@

from __future__ import annotations

from typing import Annotated
from typing import TYPE_CHECKING, Annotated

from fastapi import Depends, status

@@ -41,8 +41,12 @@
from airflow.api_fastapi.core_api.openapi.exceptions import create_openapi_http_exception_doc
from airflow.api_fastapi.core_api.security import ReadableDagRunsFilterDep, requires_access_dag
from airflow.models.dagrun import DagRun
from airflow.typing_compat import Unpack
from airflow.utils.state import DagRunState

if TYPE_CHECKING:
from sqlalchemy import Result

dag_stats_router = AirflowRouter(tags=["DagStats"], prefix="/dagStats")


@@ -71,7 +75,8 @@ def get_dag_stats(
session=session,
return_total_entries=False,
)
query_result = session.execute(dagruns_select)
# The below type annotation is acceptable on SQLA 2.1, but not on 2.0
query_result: Result[Unpack[tuple[str, str, str, int]]] = session.execute(dagruns_select) # type: ignore[type-arg]

result_dag_ids = []
dag_display_names: dict[str, str] = {}
@@ -17,6 +17,7 @@

from __future__ import annotations

from collections.abc import Sequence
from typing import Annotated

from fastapi import Depends
@@ -67,5 +68,5 @@ def get_dag_tags(
limit=limit,
session=session,
)
dag_tags = session.execute(dag_tags_select).scalars().all()
dag_tags: Sequence = session.execute(dag_tags_select).scalars().all()
return DAGTagCollectionResponse(tags=[x for x in dag_tags], total_entries=total_entries)
@@ -93,10 +93,11 @@ def get_xcom_entry(
# We use `BaseXCom.get_many` to fetch XComs directly from the database, bypassing the XCom Backend.
# This avoids deserialization via the backend (e.g., from a remote storage like S3) and instead
# retrieves the raw serialized value from the database.
result = session.scalars(xcom_query).first()
raw_result: tuple[XComModel] | None = session.scalars(xcom_query).first()

if result is None:
if raw_result is None:
raise HTTPException(status.HTTP_404_NOT_FOUND, f"XCom entry with key: `{xcom_key}` not found")
result = raw_result[0] if isinstance(raw_result, tuple) else raw_result

item = copy.copy(result)

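The unwrap-if-tuple guard above also appears in the execution API endpoints below; a rough standalone sketch, with a hypothetical Note model standing in for XComModel, of why it is needed when the result is annotated as a one-element tuple for the 2.1 stubs while scalars() keeps returning plain ORM objects at runtime:

    from __future__ import annotations

    from sqlalchemy import create_engine, select
    from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


    class Base(DeclarativeBase):
        pass


    class Note(Base):  # hypothetical stand-in for XComModel
        __tablename__ = "note"
        id: Mapped[int] = mapped_column(primary_key=True)
        value: Mapped[str]


    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)

    with Session(engine) as session:
        session.add(Note(value="hello"))
        session.flush()

        # Annotated as a one-element tuple to match the 2.1-style stubs, but
        # scalars().first() still returns the bare ORM object at runtime, so
        # unwrap defensively before touching attributes.
        raw: tuple[Note] | None = session.scalars(select(Note)).first()
        if raw is None:
            raise LookupError("entry not found")
        note = raw[0] if isinstance(raw, tuple) else raw
        print(note.value)  # hello
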
@@ -35,6 +35,8 @@
if TYPE_CHECKING:
from collections.abc import AsyncGenerator, Iterator

from sqlalchemy import ScalarResult


@attrs.define
class DagRunWaiter:
@@ -57,10 +59,12 @@ def _serialize_xcoms(self) -> dict[str, Any]:
task_ids=self.result_task_ids,
dag_ids=self.dag_id,
)
xcom_results = self.session.scalars(xcom_query.order_by(XComModel.task_id, XComModel.map_index))
xcom_results: ScalarResult[tuple[XComModel]] = self.session.scalars(
xcom_query.order_by(XComModel.task_id, XComModel.map_index)
)

def _group_xcoms(g: Iterator[XComModel]) -> Any:
entries = list(g)
def _group_xcoms(g: Iterator[XComModel | tuple[XComModel]]) -> Any:
entries = [row[0] if isinstance(row, tuple) else row for row in g]
if len(entries) == 1 and entries[0].map_index < 0: # Unpack non-mapped task xcom.
return entries[0].value
return [entry.value for entry in entries] # Task is mapped; return all xcoms in a list.
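For context, _group_xcoms is presumably fed per-task groups of the task_id-ordered results (e.g. via itertools.groupby); a simplified standalone sketch of that grouping, using a stand-in record type rather than the real XComModel:

    from __future__ import annotations

    from dataclasses import dataclass
    from itertools import groupby
    from typing import Any


    @dataclass
    class FakeXCom:  # simplified stand-in for XComModel rows
        task_id: str
        map_index: int
        value: Any


    def group_xcoms(rows: list[FakeXCom]) -> dict[str, Any]:
        out: dict[str, Any] = {}
        # Rows must already be ordered by (task_id, map_index) for groupby.
        for task_id, g in groupby(rows, key=lambda r: r.task_id):
            entries = list(g)
            if len(entries) == 1 and entries[0].map_index < 0:
                out[task_id] = entries[0].value  # non-mapped task: bare value
            else:
                out[task_id] = [e.value for e in entries]  # mapped: list
        return out


    rows = [
        FakeXCom("plain", -1, "x"),
        FakeXCom("mapped", 0, "a"),
        FakeXCom("mapped", 1, "b"),
    ]
    print(group_xcoms(rows))  # {'plain': 'x', 'mapped': ['a', 'b']}
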
@@ -190,7 +190,7 @@ def get_dagrun_state(
) -> DagRunStateResponse:
"""Get a Dag run State."""
try:
state = session.scalars(
state: DagRunState = session.scalars(
select(DagRunModel.state).where(DagRunModel.dag_id == dag_id, DagRunModel.run_id == run_id)
).one()
except NoResultFound:
@@ -113,6 +113,7 @@ def get_mapped_xcom_by_index(
else:
xcom_query = xcom_query.order_by(XComModel.map_index.desc()).offset(-1 - offset)

result: tuple[XComModel] | None
if (result := session.scalars(xcom_query).first()) is None:
message = (
f"XCom with {key=} {offset=} not found for task {task_id!r} in DAG run {run_id!r} of {dag_id!r}"
@@ -121,7 +122,7 @@
status_code=status.HTTP_404_NOT_FOUND,
detail={"reason": "not_found", "message": message},
)
return XComSequenceIndexResponse(result.value)
return XComSequenceIndexResponse((result[0] if isinstance(result, tuple) else result).value)


class GetXComSliceFilterParams(BaseModel):
@@ -291,8 +292,8 @@ def get_xcom(
# retrieves the raw serialized value from the database. By not relying on `XCom.get_many` or `XCom.get_one`
# (which automatically deserializes using the backend), we avoid potential
# performance hits from retrieving large data files into the API server.
result = session.scalars(xcom_query).first()
if result is None:
result: tuple[XComModel] | None
if (result := session.scalars(xcom_query).first()) is None:
if params.offset is None:
message = (
f"XCom with {key=} map_index={params.map_index} not found for "
@@ -308,7 +309,7 @@
detail={"reason": "not_found", "message": message},
)

return XComResponse(key=key, value=result.value)
return XComResponse(key=key, value=(result[0] if isinstance(result, tuple) else result).value)


# TODO: once we have JWT tokens, then remove dag_id/run_id/task_id from the URL and just use the info in
27 changes: 15 additions & 12 deletions airflow-core/src/airflow/dag_processing/collection.py
@@ -28,7 +28,7 @@
from __future__ import annotations

import traceback
from typing import TYPE_CHECKING, Any, NamedTuple, TypeVar, cast
from typing import TYPE_CHECKING, Any, NamedTuple, TypeVar

import structlog
from sqlalchemy import delete, func, insert, select, tuple_, update
@@ -76,7 +76,7 @@
from sqlalchemy.sql import Select

from airflow.models.dagwarning import DagWarning
from airflow.typing_compat import Self
from airflow.typing_compat import Self, Unpack

AssetT = TypeVar("AssetT", SerializedAsset, SerializedAssetAlias)

@@ -512,15 +512,18 @@ class DagModelOperation(NamedTuple):

def find_orm_dags(self, *, session: Session) -> dict[str, DagModel]:
"""Find existing DagModel objects from DAG objects."""
stmt = (
select(DagModel)
.options(joinedload(DagModel.tags, innerjoin=False))
.where(DagModel.dag_id.in_(self.dags))
.options(joinedload(DagModel.schedule_asset_references))
.options(joinedload(DagModel.schedule_asset_alias_references))
.options(joinedload(DagModel.task_outlet_asset_references))
stmt: Select[Unpack[tuple[DagModel]]] = with_row_locks(
(
select(DagModel)
.options(joinedload(DagModel.tags, innerjoin=False))
.where(DagModel.dag_id.in_(self.dags))
.options(joinedload(DagModel.schedule_asset_references))
.options(joinedload(DagModel.schedule_asset_alias_references))
.options(joinedload(DagModel.task_outlet_asset_references))
),
of=DagModel,
session=session,
)
stmt = cast("Select[tuple[DagModel]]", with_row_locks(stmt, of=DagModel, session=session))
return {dm.dag_id: dm for dm in session.scalars(stmt).unique()}

def add_dags(self, *, session: Session) -> dict[str, DagModel]:
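with_row_locks is, in essence, a dialect-aware wrapper around Select.with_for_update, which returns the statement with its element type intact, so the locked query can be annotated directly instead of going through the old cast; a rough plain-SQLAlchemy sketch with a hypothetical model:

    from __future__ import annotations

    from typing import TYPE_CHECKING

    from sqlalchemy import select
    from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column

    if TYPE_CHECKING:
        from sqlalchemy import Select


    class Base(DeclarativeBase):
        pass


    class Dag(Base):  # hypothetical stand-in for DagModel
        __tablename__ = "dag"
        dag_id: Mapped[str] = mapped_column(primary_key=True)


    # with_for_update returns Select[tuple[Dag]] unchanged, so the locked
    # statement carries the same annotation as the unlocked one, no cast.
    stmt: Select[tuple[Dag]] = (
        select(Dag)
        .where(Dag.dag_id.in_(["a", "b"]))
        .with_for_update(of=Dag, skip_locked=True)
    )
    print(stmt)
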
@@ -711,7 +714,7 @@ def _find_all_asset_aliases(dags: Iterable[LazyDeserializedDAG]) -> Iterator[Ser

def _find_active_assets(name_uri_assets: Iterable[tuple[str, str]], session: Session) -> set[tuple[str, str]]:
return {
tuple(row)
(str(row[0]), str(row[1]))
for row in session.execute(
select(AssetModel.name, AssetModel.uri).where(
tuple_(AssetModel.name, AssetModel.uri).in_(name_uri_assets),
@@ -906,7 +909,7 @@ def _add_dag_asset_references(
if not references:
return
orm_refs = {
tuple(row)
(str(row[0]), str(row[1]))
for row in session.execute(
select(model.dag_id, getattr(model, attr)).where(
model.dag_id.in_(dag_id for dag_id, _ in references)
20 changes: 11 additions & 9 deletions airflow-core/src/airflow/dag_processing/manager.py
@@ -75,7 +75,7 @@
from airflow.utils.sqlalchemy import prohibit_commit, with_row_locks

if TYPE_CHECKING:
from collections.abc import Callable, Iterable, Iterator
from collections.abc import Callable, Iterable, Iterator, Sequence
from socket import socket

from sqlalchemy.orm import Session
@@ -497,15 +497,17 @@ def _fetch_callbacks(
callback_queue: list[CallbackRequest] = []
with prohibit_commit(session) as guard:
bundle_names = [bundle.name for bundle in self._dag_bundles]
query: Select[tuple[DbCallbackRequest]] = select(DbCallbackRequest)
query = query.order_by(DbCallbackRequest.priority_weight.desc()).limit(
self.max_callbacks_per_loop
)
query = cast(
"Select[tuple[DbCallbackRequest]]",
with_row_locks(query, of=DbCallbackRequest, session=session, skip_locked=True),
query: Select[tuple[DbCallbackRequest]] = with_row_locks(
select(DbCallbackRequest)
.order_by(DbCallbackRequest.priority_weight.desc())
.limit(self.max_callbacks_per_loop),
of=DbCallbackRequest,
session=session,
skip_locked=True,
)
callbacks = session.scalars(query)
callbacks: Sequence[DbCallbackRequest] = [
cb[0] if isinstance(cb, tuple) else cb for cb in session.scalars(query)
]
for callback in callbacks:
req = callback.get_callback_request()
if req.bundle_name not in bundle_names:
5 changes: 2 additions & 3 deletions airflow-core/src/airflow/jobs/scheduler_job_runner.py
@@ -36,7 +36,6 @@
from sqlalchemy import (
and_,
delete,
desc,
exists,
func,
inspect,
@@ -2578,7 +2577,7 @@ def _get_num_times_stuck_in_queued(self, ti: TaskInstance, session: Session = NE
Log.try_number == ti.try_number,
Log.event == "running",
)
.order_by(desc(Log.dttm))
.order_by(Log.dttm.desc())
.limit(1)
)

@@ -2652,7 +2651,7 @@ def _emit_ti_metrics(self, session: Session = NEW_SESSION) -> None:
@provide_session
def _emit_running_dags_metric(self, session: Session = NEW_SESSION) -> None:
stmt = select(func.count()).select_from(DagRun).where(DagRun.state == DagRunState.RUNNING)
running_dags = float(session.scalar(stmt))
running_dags = float(session.scalar(stmt) or 0)
Stats.gauge("scheduler.dagruns.running", running_dags)

@provide_session
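A brief aside on the or-0 guard above: Session.scalar() is typed as Optional, so passing its result straight to float() fails type checking (and would fail at runtime if no row came back); a tiny sketch under those assumptions, with a hypothetical run table:

    from __future__ import annotations

    from sqlalchemy import create_engine, func, select, table, text
    from sqlalchemy.orm import Session

    engine = create_engine("sqlite://")
    with engine.begin() as conn:
        conn.execute(text("CREATE TABLE run (id INTEGER PRIMARY KEY)"))

    with Session(engine) as session:
        stmt = select(func.count()).select_from(table("run"))
        # scalar() returns Any | None per the stubs; coalesce before float()
        # so the metric emission cannot raise on a None result.
        running = float(session.scalar(stmt) or 0)
        print(running)  # 0.0
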
3 changes: 1 addition & 2 deletions airflow-core/src/airflow/models/backfill.py
@@ -35,7 +35,6 @@
Integer,
String,
UniqueConstraint,
desc,
func,
select,
)
@@ -229,7 +228,7 @@ def _get_latest_dag_run_row_query(*, dag_id: str, info: DagRunInfo, session: Ses
DagRun.logical_date == info.logical_date,
DagRun.dag_id == dag_id,
)
.order_by(nulls_first(desc(DagRun.start_date), session=session))
.order_by(nulls_first(DagRun.start_date.desc(), session=session))
.limit(1)
)

5 changes: 3 additions & 2 deletions airflow-core/src/airflow/models/deadline.py
@@ -18,6 +18,7 @@

import logging
from abc import ABC, abstractmethod
from collections.abc import Sequence
from dataclasses import dataclass
from datetime import datetime, timedelta
from typing import TYPE_CHECKING, Any, cast
@@ -185,7 +186,7 @@ def prune_deadlines(cls, *, session: Session, conditions: dict[Mapped, Any]) ->
dagruns_to_refresh = set()

for deadline, dagrun in deadline_dagrun_pairs:
if dagrun.end_date <= deadline.deadline_time:
if dagrun.end_date is not None and dagrun.end_date <= deadline.deadline_time:
# If the DagRun finished before the Deadline:
session.delete(deadline)
Stats.incr(
@@ -403,7 +404,7 @@ def _evaluate_with(self, *, session: Session, **kwargs: Any) -> datetime | None:
query = query.limit(self.max_runs)

# Get all durations and calculate average
durations = session.execute(query).scalars().all()
durations: Sequence = session.execute(query).scalars().all()

if len(durations) < cast("int", self.min_runs):
logger.info(
3 changes: 2 additions & 1 deletion airflow-core/src/airflow/models/pool.py
@@ -191,7 +191,8 @@ def slots_stats(
pools: dict[str, PoolStats] = {}
pool_includes_deferred: dict[str, bool] = {}

query: Select[Any] = select(Pool.pool, Pool.slots, Pool.include_deferred)
# The below type annotation is acceptable on SQLA 2.1, but not on 2.0
query: Select[str, int, bool] = select(Pool.pool, Pool.slots, Pool.include_deferred) # type: ignore[type-arg]

if lock_rows:
query = with_row_locks(query, session=session, nowait=True)
4 changes: 1 addition & 3 deletions airflow-core/src/airflow/models/taskreschedule.py
@@ -28,8 +28,6 @@
Index,
Integer,
String,
asc,
desc,
select,
)
from sqlalchemy.dialects import postgresql
@@ -94,4 +92,4 @@ def stmt_for_task_instance(
:param descending: If True then records are returned in descending order
:meta private:
"""
return select(cls).where(cls.ti_id == ti.id).order_by(desc(cls.id) if descending else asc(cls.id))
return select(cls).where(cls.ti_id == ti.id).order_by(cls.id.desc() if descending else cls.id.asc())
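The switch from the standalone desc()/asc() functions to the column-bound methods is repeated across several models here; both compile to the same SQL, but the bound methods need no extra imports and, presumably the motivation in this PR, type-check cleanly against the mapped column. A small sketch with a hypothetical model:

    from __future__ import annotations

    from sqlalchemy import select
    from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column


    class Base(DeclarativeBase):
        pass


    class Event(Base):  # hypothetical stand-in for Log / TaskReschedule
        __tablename__ = "event"
        id: Mapped[int] = mapped_column(primary_key=True)


    def ordered(descending: bool):
        # Column-bound .desc()/.asc() render identically to the standalone
        # desc()/asc() functions while staying tied to the mapped column.
        order = Event.id.desc() if descending else Event.id.asc()
        return select(Event).order_by(order)


    print(ordered(True))
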
@@ -32,6 +32,8 @@
from airflow.utils.sqlalchemy import get_dialect_name

if TYPE_CHECKING:
from collections.abc import Sequence

from sqlalchemy import ColumnElement
from sqlalchemy.orm import Session

@@ -210,7 +212,7 @@ def _evaluate_with(self, *, session: Session, **kwargs: Any) -> datetime | None:
.limit(self.max_runs)
)

durations = list(session.execute(query).scalars())
durations: Sequence = session.execute(query).scalars().all()

min_runs = self.min_runs or 0
if len(durations) < min_runs: