From f5421b5a3f3e0a6b6e14e2e6e7a1bd3ed5f5fdf7 Mon Sep 17 00:00:00 2001
From: Dev-iL <6509619+Dev-iL@users.noreply.github.com>
Date: Tue, 3 Feb 2026 18:22:46 +0200
Subject: [PATCH] Fix type hints

---
 .../auth/managers/base_auth_manager.py        |  7 ++-
 .../core_api/routes/public/assets.py          |  5 ++-
 .../core_api/routes/public/dag_stats.py       |  9 +++-
 .../core_api/routes/public/dag_tags.py        |  3 +-
 .../core_api/routes/public/xcom.py            |  5 ++-
 .../core_api/services/public/dag_run.py       | 10 +++--
 .../execution_api/routes/dag_runs.py          |  2 +-
 .../api_fastapi/execution_api/routes/xcoms.py |  9 ++--
 .../src/airflow/dag_processing/collection.py  | 27 +++++++-----
 .../src/airflow/dag_processing/manager.py     | 20 +++++----
 .../src/airflow/jobs/scheduler_job_runner.py  |  5 +--
 airflow-core/src/airflow/models/backfill.py   |  3 +-
 airflow-core/src/airflow/models/deadline.py   |  5 ++-
 airflow-core/src/airflow/models/pool.py       |  3 +-
 .../src/airflow/models/taskreschedule.py      |  4 +-
 .../serialization/definitions/deadline.py     |  4 +-
 .../airflow/serialization/definitions/node.py |  7 ++-
 .../airflow/ti_deps/deps/trigger_rule_dep.py  |  4 +-
 airflow-core/src/airflow/typing_compat.py     | 12 ++---
 airflow-core/src/airflow/utils/sqlalchemy.py  |  6 +--
 airflow-core/tests/unit/models/test_dag.py    |  3 +-
 .../tests/unit/models/test_taskinstance.py    |  4 +-
 .../_internal_client/secret_manager_client.py |  4 +-
 .../google/common/hooks/base_google.py        |  1 +
 .../src/airflow_shared/dagnode/node.py        | 44 +++++++++++++++++--
 .../observability/metrics/datadog_logger.py   |  4 +-
 .../observability/metrics/stats.py            |  1 +
 .../plugins_manager/plugins_manager.py        |  7 ++-
 .../providers_discovery.py                    | 10 ++++-
 .../airflow/sdk/definitions/_internal/node.py |  2 +-
 30 files changed, 152 insertions(+), 78 deletions(-)

diff --git a/airflow-core/src/airflow/api_fastapi/auth/managers/base_auth_manager.py b/airflow-core/src/airflow/api_fastapi/auth/managers/base_auth_manager.py
index 5e98dc86fb087..8f9f1dd8e0039 100644
--- a/airflow-core/src/airflow/api_fastapi/auth/managers/base_auth_manager.py
+++ b/airflow-core/src/airflow/api_fastapi/auth/managers/base_auth_manager.py
@@ -47,6 +47,7 @@
 from airflow.models import Connection, DagModel, Pool, Variable
 from airflow.models.dagbundle import DagBundleModel
 from airflow.models.team import Team, dag_bundle_team_association_table
+from airflow.typing_compat import Unpack
 from airflow.utils.log.logging_mixin import LoggingMixin
 from airflow.utils.session import NEW_SESSION, provide_session
 
@@ -54,6 +55,7 @@
     from collections.abc import Sequence
 
     from fastapi import FastAPI
+    from sqlalchemy import Row
     from sqlalchemy.orm import Session
 
     from airflow.api_fastapi.auth.managers.models.batch_apis import (
@@ -569,8 +571,9 @@ def get_authorized_dag_ids(
                 isouter=True,
             )
         )
-        rows = session.execute(stmt).all()
-        dags_by_team: dict[str | None, set[str]] = defaultdict(set)
+        # The below type annotation is acceptable on SQLA2.1, but not on 2.0
+        rows: Sequence[Row[Unpack[tuple[str, str]]]] = session.execute(stmt).all()  # type: ignore[type-arg]
+        dags_by_team: dict[str, set[str]] = defaultdict(set)
         for dag_id, team_name in rows:
             dags_by_team[team_name].add(dag_id)
 
diff --git a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/assets.py b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/assets.py
index 0c23cc93a78d7..3e5108f14472c 100644
--- a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/assets.py
+++ b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/assets.py
@@ -74,10 +74,12 @@
     AssetWatcherModel,
     TaskOutletAssetReference,
 )
+from airflow.typing_compat import Unpack
 from airflow.utils.state import DagRunState
 from airflow.utils.types import DagRunTriggeredByType, DagRunType
 
 if TYPE_CHECKING:
+    from sqlalchemy.engine import Result
     from sqlalchemy.sql import Select
 
 assets_router = AirflowRouter(tags=["Asset"])
@@ -179,7 +181,8 @@ def get_assets(
         session=session,
     )
 
-    assets_rows = session.execute(
+    # The below type annotation is acceptable on SQLA2.1, but not on 2.0
+    assets_rows: Result[Unpack[tuple[AssetModel, int, datetime]]] = session.execute(  # type: ignore[type-arg]
         assets_select.options(
             subqueryload(AssetModel.scheduled_dags),
             subqueryload(AssetModel.producing_tasks),
diff --git a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/dag_stats.py b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/dag_stats.py
index d2b4cc17bf1d8..b42607369d35d 100644
--- a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/dag_stats.py
+++ b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/dag_stats.py
@@ -17,7 +17,7 @@
 
 from __future__ import annotations
 
-from typing import Annotated
+from typing import TYPE_CHECKING, Annotated
 
 from fastapi import Depends, status
 
@@ -41,8 +41,12 @@
 from airflow.api_fastapi.core_api.openapi.exceptions import create_openapi_http_exception_doc
 from airflow.api_fastapi.core_api.security import ReadableDagRunsFilterDep, requires_access_dag
 from airflow.models.dagrun import DagRun
+from airflow.typing_compat import Unpack
 from airflow.utils.state import DagRunState
 
+if TYPE_CHECKING:
+    from sqlalchemy import Result
+
 dag_stats_router = AirflowRouter(tags=["DagStats"], prefix="/dagStats")
 
 
@@ -71,7 +75,8 @@ def get_dag_stats(
         session=session,
         return_total_entries=False,
     )
-    query_result = session.execute(dagruns_select)
+    # The below type annotation is acceptable on SQLA2.1, but not on 2.0
+    query_result: Result[Unpack[tuple[str, str, str, int]]] = session.execute(dagruns_select)  # type: ignore[type-arg]
 
     result_dag_ids = []
     dag_display_names: dict[str, str] = {}
diff --git a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/dag_tags.py b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/dag_tags.py
index b02b9be31ecbb..86ba73e69dae4 100644
--- a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/dag_tags.py
+++ b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/dag_tags.py
@@ -17,6 +17,7 @@
 
 from __future__ import annotations
 
+from collections.abc import Sequence
 from typing import Annotated
 
 from fastapi import Depends
@@ -67,5 +68,5 @@ def get_dag_tags(
         limit=limit,
         session=session,
     )
-    dag_tags = session.execute(dag_tags_select).scalars().all()
+    dag_tags: Sequence = session.execute(dag_tags_select).scalars().all()
     return DAGTagCollectionResponse(tags=[x for x in dag_tags], total_entries=total_entries)
diff --git a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/xcom.py b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/xcom.py
index 4ca9e3420377c..7bf64592aa640 100644
--- a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/xcom.py
+++ b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/xcom.py
@@ -93,10 +93,11 @@ def get_xcom_entry(
     # We use `BaseXCom.get_many` to fetch XComs directly from the database, bypassing the XCom Backend.
     # This avoids deserialization via the backend (e.g., from a remote storage like S3) and instead
     # retrieves the raw serialized value from the database.
-    result = session.scalars(xcom_query).first()
+    raw_result: tuple[XComModel] | None = session.scalars(xcom_query).first()
 
-    if result is None:
+    if raw_result is None:
         raise HTTPException(status.HTTP_404_NOT_FOUND, f"XCom entry with key: `{xcom_key}` not found")
+    result = raw_result[0] if isinstance(raw_result, tuple) else raw_result
 
     item = copy.copy(result)
 
diff --git a/airflow-core/src/airflow/api_fastapi/core_api/services/public/dag_run.py b/airflow-core/src/airflow/api_fastapi/core_api/services/public/dag_run.py
index 5a08ed1c3b065..110f34c780ead 100644
--- a/airflow-core/src/airflow/api_fastapi/core_api/services/public/dag_run.py
+++ b/airflow-core/src/airflow/api_fastapi/core_api/services/public/dag_run.py
@@ -35,6 +35,8 @@
 if TYPE_CHECKING:
     from collections.abc import AsyncGenerator, Iterator
 
+    from sqlalchemy import ScalarResult
+
 
 @attrs.define
 class DagRunWaiter:
@@ -57,10 +59,12 @@ def _serialize_xcoms(self) -> dict[str, Any]:
             task_ids=self.result_task_ids,
             dag_ids=self.dag_id,
         )
-        xcom_results = self.session.scalars(xcom_query.order_by(XComModel.task_id, XComModel.map_index))
+        xcom_results: ScalarResult[tuple[XComModel]] = self.session.scalars(
+            xcom_query.order_by(XComModel.task_id, XComModel.map_index)
+        )
 
-        def _group_xcoms(g: Iterator[XComModel]) -> Any:
-            entries = list(g)
+        def _group_xcoms(g: Iterator[XComModel | tuple[XComModel]]) -> Any:
+            entries = [row[0] if isinstance(row, tuple) else row for row in g]
             if len(entries) == 1 and entries[0].map_index < 0:  # Unpack non-mapped task xcom.
                 return entries[0].value
             return [entry.value for entry in entries]  # Task is mapped; return all xcoms in a list.
diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/dag_runs.py b/airflow-core/src/airflow/api_fastapi/execution_api/routes/dag_runs.py
index 7763850b5ee4e..b3fd1cff7eec4 100644
--- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/dag_runs.py
+++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/dag_runs.py
@@ -190,7 +190,7 @@
 ) -> DagRunStateResponse:
     """Get a Dag run State."""
     try:
-        state = session.scalars(
+        state: DagRunState = session.scalars(
             select(DagRunModel.state).where(DagRunModel.dag_id == dag_id, DagRunModel.run_id == run_id)
         ).one()
     except NoResultFound:
diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/xcoms.py b/airflow-core/src/airflow/api_fastapi/execution_api/routes/xcoms.py
index 3408513a8c8c3..ec77b64dc4496 100644
--- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/xcoms.py
+++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/xcoms.py
@@ -113,6 +113,7 @@ def get_mapped_xcom_by_index(
     else:
         xcom_query = xcom_query.order_by(XComModel.map_index.desc()).offset(-1 - offset)
 
+    result: tuple[XComModel] | None
     if (result := session.scalars(xcom_query).first()) is None:
         message = (
             f"XCom with {key=} {offset=} not found for task {task_id!r} in DAG run {run_id!r} of {dag_id!r}"
@@ -121,7 +122,7 @@
             status_code=status.HTTP_404_NOT_FOUND,
             detail={"reason": "not_found", "message": message},
         )
-    return XComSequenceIndexResponse(result.value)
+    return XComSequenceIndexResponse((result[0] if isinstance(result, tuple) else result).value)
 
 
 class GetXComSliceFilterParams(BaseModel):
@@ -291,8 +292,8 @@
     # retrieves the raw serialized value from the database. By not relying on `XCom.get_many` or `XCom.get_one`
     # (which automatically deserializes using the backend), we avoid potential
     # performance hits from retrieving large data files into the API server.
-    result = session.scalars(xcom_query).first()
-    if result is None:
+    result: tuple[XComModel] | None
+    if (result := session.scalars(xcom_query).first()) is None:
         if params.offset is None:
             message = (
                 f"XCom with {key=} map_index={params.map_index} not found for "
@@ -308,7 +309,7 @@
             detail={"reason": "not_found", "message": message},
         )
 
-    return XComResponse(key=key, value=result.value)
+    return XComResponse(key=key, value=(result[0] if isinstance(result, tuple) else result).value)
 
 
 # TODO: once we have JWT tokens, then remove dag_id/run_id/task_id from the URL and just use the info in
diff --git a/airflow-core/src/airflow/dag_processing/collection.py b/airflow-core/src/airflow/dag_processing/collection.py
index f5ad10196b48c..5abfbd8ae7525 100644
--- a/airflow-core/src/airflow/dag_processing/collection.py
+++ b/airflow-core/src/airflow/dag_processing/collection.py
@@ -28,7 +28,7 @@
 from __future__ import annotations
 
 import traceback
-from typing import TYPE_CHECKING, Any, NamedTuple, TypeVar, cast
+from typing import TYPE_CHECKING, Any, NamedTuple, TypeVar
 
 import structlog
 from sqlalchemy import delete, func, insert, select, tuple_, update
@@ -76,7 +76,7 @@
     from sqlalchemy.sql import Select
 
     from airflow.models.dagwarning import DagWarning
-    from airflow.typing_compat import Self
+    from airflow.typing_compat import Self, Unpack
 
 
 AssetT = TypeVar("AssetT", SerializedAsset, SerializedAssetAlias)
@@ -512,15 +512,18 @@ class DagModelOperation(NamedTuple):
 
     def find_orm_dags(self, *, session: Session) -> dict[str, DagModel]:
         """Find existing DagModel objects from DAG objects."""
-        stmt = (
-            select(DagModel)
-            .options(joinedload(DagModel.tags, innerjoin=False))
-            .where(DagModel.dag_id.in_(self.dags))
-            .options(joinedload(DagModel.schedule_asset_references))
-            .options(joinedload(DagModel.schedule_asset_alias_references))
-            .options(joinedload(DagModel.task_outlet_asset_references))
+        stmt: Select[Unpack[tuple[DagModel]]] = with_row_locks(
+            (
+                select(DagModel)
+                .options(joinedload(DagModel.tags, innerjoin=False))
+                .where(DagModel.dag_id.in_(self.dags))
+                .options(joinedload(DagModel.schedule_asset_references))
+                .options(joinedload(DagModel.schedule_asset_alias_references))
+                .options(joinedload(DagModel.task_outlet_asset_references))
+            ),
+            of=DagModel,
+            session=session,
         )
-        stmt = cast("Select[tuple[DagModel]]", with_row_locks(stmt, of=DagModel, session=session))
         return {dm.dag_id: dm for dm in session.scalars(stmt).unique()}
 
     def add_dags(self, *, session: Session) -> dict[str, DagModel]:
@@ -711,7 +714,7 @@ def _find_all_asset_aliases(dags: Iterable[LazyDeserializedDAG]) -> Iterator[Ser
 
 def _find_active_assets(name_uri_assets: Iterable[tuple[str, str]], session: Session) -> set[tuple[str, str]]:
     return {
-        tuple(row)
+        (str(row[0]), str(row[1]))
         for row in session.execute(
             select(AssetModel.name, AssetModel.uri).where(
                 tuple_(AssetModel.name, AssetModel.uri).in_(name_uri_assets),
@@ -906,7 +909,7 @@ def _add_dag_asset_references(
     if not references:
         return
     orm_refs = {
-        tuple(row)
+        (str(row[0]), str(row[1]))
        for row in session.execute(
             select(model.dag_id, getattr(model, attr)).where(
                 model.dag_id.in_(dag_id for dag_id, _ in references)
diff --git a/airflow-core/src/airflow/dag_processing/manager.py b/airflow-core/src/airflow/dag_processing/manager.py
index d77503a588519..81ae9afe577dc 100644
--- a/airflow-core/src/airflow/dag_processing/manager.py
+++ b/airflow-core/src/airflow/dag_processing/manager.py
@@ -75,7 +75,7 @@
 from airflow.utils.sqlalchemy import prohibit_commit, with_row_locks
 
 if TYPE_CHECKING:
-    from collections.abc import Callable, Iterable, Iterator
+    from collections.abc import Callable, Iterable, Iterator, Sequence
     from socket import socket
 
     from sqlalchemy.orm import Session
@@ -497,15 +497,17 @@ def _fetch_callbacks(
         callback_queue: list[CallbackRequest] = []
         with prohibit_commit(session) as guard:
             bundle_names = [bundle.name for bundle in self._dag_bundles]
-            query: Select[tuple[DbCallbackRequest]] = select(DbCallbackRequest)
-            query = query.order_by(DbCallbackRequest.priority_weight.desc()).limit(
-                self.max_callbacks_per_loop
-            )
-            query = cast(
-                "Select[tuple[DbCallbackRequest]]",
-                with_row_locks(query, of=DbCallbackRequest, session=session, skip_locked=True),
+            query: Select[tuple[DbCallbackRequest]] = with_row_locks(
+                select(DbCallbackRequest)
+                .order_by(DbCallbackRequest.priority_weight.desc())
+                .limit(self.max_callbacks_per_loop),
+                of=DbCallbackRequest,
+                session=session,
+                skip_locked=True,
             )
-            callbacks = session.scalars(query)
+            callbacks: Sequence[DbCallbackRequest] = [
+                cb[0] if isinstance(cb, tuple) else cb for cb in session.scalars(query)
+            ]
             for callback in callbacks:
                 req = callback.get_callback_request()
                 if req.bundle_name not in bundle_names:
diff --git a/airflow-core/src/airflow/jobs/scheduler_job_runner.py b/airflow-core/src/airflow/jobs/scheduler_job_runner.py
index 0dc18ca696ac2..96b865f75bbaa 100644
--- a/airflow-core/src/airflow/jobs/scheduler_job_runner.py
+++ b/airflow-core/src/airflow/jobs/scheduler_job_runner.py
@@ -36,7 +36,6 @@
 from sqlalchemy import (
     and_,
     delete,
-    desc,
     exists,
     func,
     inspect,
@@ -2578,7 +2577,7 @@ def _get_num_times_stuck_in_queued(self, ti: TaskInstance, session: Session = NE
                 Log.try_number == ti.try_number,
                 Log.event == "running",
             )
-            .order_by(desc(Log.dttm))
+            .order_by(Log.dttm.desc())
             .limit(1)
         )
 
@@ -2652,7 +2651,7 @@ def _emit_ti_metrics(self, session: Session = NEW_SESSION) -> None:
     @provide_session
     def _emit_running_dags_metric(self, session: Session = NEW_SESSION) -> None:
         stmt = select(func.count()).select_from(DagRun).where(DagRun.state == DagRunState.RUNNING)
-        running_dags = float(session.scalar(stmt))
+        running_dags = float(session.scalar(stmt) or 0)
         Stats.gauge("scheduler.dagruns.running", running_dags)
 
     @provide_session
diff --git a/airflow-core/src/airflow/models/backfill.py b/airflow-core/src/airflow/models/backfill.py
index 365e6c9b225b4..64828bdc10f3f 100644
--- a/airflow-core/src/airflow/models/backfill.py
+++ b/airflow-core/src/airflow/models/backfill.py
@@ -35,7 +35,6 @@
     Integer,
     String,
     UniqueConstraint,
-    desc,
     func,
     select,
 )
@@ -229,7 +228,7 @@ def _get_latest_dag_run_row_query(*, dag_id: str, info: DagRunInfo, session: Ses
             DagRun.logical_date == info.logical_date,
             DagRun.dag_id == dag_id,
         )
-        .order_by(nulls_first(desc(DagRun.start_date), session=session))
+        .order_by(nulls_first(DagRun.start_date.desc(), session=session))
         .limit(1)
     )
 
diff --git a/airflow-core/src/airflow/models/deadline.py b/airflow-core/src/airflow/models/deadline.py
index 11985a42c5aca..070304f30a728 100644
--- a/airflow-core/src/airflow/models/deadline.py
+++ b/airflow-core/src/airflow/models/deadline.py
@@ -18,6 +18,7 @@
 
 import logging
 from abc import ABC, abstractmethod
+from collections.abc import Sequence
 from dataclasses import dataclass
 from datetime import datetime, timedelta
 from typing import TYPE_CHECKING, Any, cast
@@ -185,7 +186,7 @@ def prune_deadlines(cls, *, session: Session, conditions: dict[Mapped, Any]) ->
         dagruns_to_refresh = set()
 
         for deadline, dagrun in deadline_dagrun_pairs:
-            if dagrun.end_date <= deadline.deadline_time:
+            if dagrun.end_date is not None and dagrun.end_date <= deadline.deadline_time:
                 # If the DagRun finished before the Deadline:
                 session.delete(deadline)
                 Stats.incr(
@@ -403,7 +404,7 @@ def _evaluate_with(self, *, session: Session, **kwargs: Any) -> datetime | None:
             query = query.limit(self.max_runs)
 
         # Get all durations and calculate average
-        durations = session.execute(query).scalars().all()
+        durations: Sequence = session.execute(query).scalars().all()
 
         if len(durations) < cast("int", self.min_runs):
             logger.info(
diff --git a/airflow-core/src/airflow/models/pool.py b/airflow-core/src/airflow/models/pool.py
index a00e8b90b5849..dab8862a12c47 100644
--- a/airflow-core/src/airflow/models/pool.py
+++ b/airflow-core/src/airflow/models/pool.py
@@ -191,7 +191,8 @@ def slots_stats(
         pools: dict[str, PoolStats] = {}
         pool_includes_deferred: dict[str, bool] = {}
 
-        query: Select[Any] = select(Pool.pool, Pool.slots, Pool.include_deferred)
+        # The below type annotation is acceptable on SQLA2.1, but not on 2.0
+        query: Select[str, int, bool] = select(Pool.pool, Pool.slots, Pool.include_deferred)  # type: ignore[type-arg]
         if lock_rows:
             query = with_row_locks(query, session=session, nowait=True)
 
diff --git a/airflow-core/src/airflow/models/taskreschedule.py b/airflow-core/src/airflow/models/taskreschedule.py
index 005b68468456e..88f87121fd717 100644
--- a/airflow-core/src/airflow/models/taskreschedule.py
+++ b/airflow-core/src/airflow/models/taskreschedule.py
@@ -28,8 +28,6 @@
     Index,
     Integer,
     String,
-    asc,
-    desc,
     select,
 )
 from sqlalchemy.dialects import postgresql
@@ -94,4 +92,4 @@ def stmt_for_task_instance(
         :param descending: If True then records are returned in descending order
         :meta private:
         """
-        return select(cls).where(cls.ti_id == ti.id).order_by(desc(cls.id) if descending else asc(cls.id))
+        return select(cls).where(cls.ti_id == ti.id).order_by(cls.id.desc() if descending else cls.id.asc())
diff --git a/airflow-core/src/airflow/serialization/definitions/deadline.py b/airflow-core/src/airflow/serialization/definitions/deadline.py
index 2fefbcb86afe9..78adc6b9a7666 100644
--- a/airflow-core/src/airflow/serialization/definitions/deadline.py
+++ b/airflow-core/src/airflow/serialization/definitions/deadline.py
@@ -32,6 +32,8 @@
 from airflow.utils.sqlalchemy import get_dialect_name
 
 if TYPE_CHECKING:
+    from collections.abc import Sequence
+
     from sqlalchemy import ColumnElement
     from sqlalchemy.orm import Session
 
@@ -210,7 +212,7 @@ def _evaluate_with(self, *, session: Session, **kwargs: Any) -> datetime | None:
             .limit(self.max_runs)
         )
 
-        durations = list(session.execute(query).scalars())
+        durations: Sequence = session.execute(query).scalars().all()
 
         min_runs = self.min_runs or 0
         if len(durations) < min_runs:
diff --git a/airflow-core/src/airflow/serialization/definitions/node.py b/airflow-core/src/airflow/serialization/definitions/node.py
index 06c61a54de528..2cbdc9db7714b 100644
--- a/airflow-core/src/airflow/serialization/definitions/node.py
+++ b/airflow-core/src/airflow/serialization/definitions/node.py
@@ -32,11 +32,16 @@
 __all__ = ["DAGNode"]
 
 
-class DAGNode(GenericDAGNode["SerializedDAG", "Operator", "SerializedTaskGroup"], metaclass=abc.ABCMeta):
+class DAGNode(GenericDAGNode["SerializedDAG", "Operator", "SerializedTaskGroup"], metaclass=abc.ABCMeta):  # type: ignore[type-var]
     """
     Base class for a node in the graph of a workflow.
 
     A node may be an operator or task group, either mapped or unmapped.
+
+    Note: type: ignore is used because SerializedBaseOperator and SerializedTaskGroup
+    don't have explicit type annotations for all attributes required by TaskProtocol
+    and TaskGroupProtocol (they inherit them from GenericDAGNode). This is acceptable
+    because they are implemented correctly at runtime.
     """
 
     @property
diff --git a/airflow-core/src/airflow/ti_deps/deps/trigger_rule_dep.py b/airflow-core/src/airflow/ti_deps/deps/trigger_rule_dep.py
index 893807fa59930..5d2b6955d75ba 100644
--- a/airflow-core/src/airflow/ti_deps/deps/trigger_rule_dep.py
+++ b/airflow-core/src/airflow/ti_deps/deps/trigger_rule_dep.py
@@ -40,6 +40,7 @@
     from airflow.serialization.definitions.taskgroup import SerializedMappedTaskGroup
     from airflow.ti_deps.dep_context import DepContext
     from airflow.ti_deps.deps.base_ti_dep import TIDepStatus
+    from airflow.typing_compat import Unpack
 
 
 class _UpstreamTIStates(NamedTuple):
@@ -371,7 +372,8 @@ def _evaluate_direct_relatives() -> Iterator[TIDepStatus]:
                 upstream = len(upstream_tasks)
                 upstream_setup = sum(1 for x in upstream_tasks.values() if x.is_setup)
             else:
-                task_id_counts: Sequence[Row[tuple[str, int]]] = session.execute(
+                # The below type annotation is acceptable on SQLA2.1, but not on 2.0
+                task_id_counts: Sequence[Row[Unpack[tuple[str, int]]]] = session.execute(  # type: ignore[type-arg]
                     select(TaskInstance.task_id, func.count(TaskInstance.task_id))
                     .where(TaskInstance.dag_id == ti.dag_id, TaskInstance.run_id == ti.run_id)
                     .where(or_(*_iter_upstream_conditions(relevant_tasks=upstream_tasks)))
diff --git a/airflow-core/src/airflow/typing_compat.py b/airflow-core/src/airflow/typing_compat.py
index 8a00ac06bd7f1..e1efb87067b34 100644
--- a/airflow-core/src/airflow/typing_compat.py
+++ b/airflow-core/src/airflow/typing_compat.py
@@ -19,13 +19,7 @@
 
 from __future__ import annotations
 
-__all__ = [
-    "Literal",
-    "ParamSpec",
-    "Self",
-    "TypeAlias",
-    "TypeGuard",
-]
+__all__ = ["Literal", "ParamSpec", "Self", "TypeAlias", "TypeGuard", "Unpack"]
 
 import sys
 
@@ -33,6 +27,6 @@
 from typing import Literal, ParamSpec, TypeAlias, TypeGuard
 
 if sys.version_info >= (3, 11):
-    from typing import Self
+    from typing import Self, Unpack
 else:
-    from typing_extensions import Self
+    from typing_extensions import Self, Unpack
diff --git a/airflow-core/src/airflow/utils/sqlalchemy.py b/airflow-core/src/airflow/utils/sqlalchemy.py
index 8d9e826bef794..266be08c3bb14 100644
--- a/airflow-core/src/airflow/utils/sqlalchemy.py
+++ b/airflow-core/src/airflow/utils/sqlalchemy.py
@@ -22,7 +22,7 @@
 import datetime
 import logging
 from collections.abc import Generator
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING
 
 from sqlalchemy import TIMESTAMP, PickleType, event, nullsfirst
 from sqlalchemy.dialects import mysql
@@ -319,14 +319,14 @@ def nulls_first(col: ColumnElement, session: Session) -> ColumnElement:
 
 
 def with_row_locks(
-    query: Select[Any],
+    query: Select,
     session: Session,
     *,
     nowait: bool = False,
     skip_locked: bool = False,
     key_share: bool = True,
     **kwargs,
-) -> Select[Any]:
+) -> Select:
     """
     Apply with_for_update to the SQLAlchemy query if row level locking is in use.
 
diff --git a/airflow-core/tests/unit/models/test_dag.py b/airflow-core/tests/unit/models/test_dag.py
index 2a45735847ecf..9683a500d9509 100644
--- a/airflow-core/tests/unit/models/test_dag.py
+++ b/airflow-core/tests/unit/models/test_dag.py
@@ -118,6 +118,7 @@
 )
 
 if TYPE_CHECKING:
+    from sqlalchemy.engine import ScalarResult
     from sqlalchemy.orm import Session
 
 pytestmark = pytest.mark.db_test
@@ -2361,7 +2362,7 @@ def test_asset_expression(self, session: Session, testing_dag_bundle) -> None:
         )
         SerializedDAG.bulk_write_to_db("testing", None, [dag], session=session)
 
-        expression = session.scalars(
+        expression: ScalarResult = session.scalars(
             select(DagModel.asset_expression).where(DagModel.dag_id == dag.dag_id)
         ).one()
         assert expression == {
diff --git a/airflow-core/tests/unit/models/test_taskinstance.py b/airflow-core/tests/unit/models/test_taskinstance.py
index 08dffc78dc566..1b6b3a01d1b42 100644
--- a/airflow-core/tests/unit/models/test_taskinstance.py
+++ b/airflow-core/tests/unit/models/test_taskinstance.py
@@ -1733,7 +1733,9 @@ def _write2_post_execute(context, _):
         for ti in dr.get_task_instances(session=session):
             run_task_instance(ti, dag_maker.dag.get_task(ti.task_id), session=session)
 
-        events = dict((tuple(row)) for row in session.execute(select(AssetEvent.source_task_id, AssetEvent)))
+        events: dict[str, AssetEvent] = dict(
+            (str(row[0]), row[1]) for row in session.execute(select(AssetEvent.source_task_id, AssetEvent))
+        )
         assert set(events) == {"write1", "write2"}
         assert events["write1"].source_dag_id == dr.dag_id
 
diff --git a/providers/google/src/airflow/providers/google/cloud/_internal_client/secret_manager_client.py b/providers/google/src/airflow/providers/google/cloud/_internal_client/secret_manager_client.py
index a78d0c7bdb237..0ddf01d99fb17 100644
--- a/providers/google/src/airflow/providers/google/cloud/_internal_client/secret_manager_client.py
+++ b/providers/google/src/airflow/providers/google/cloud/_internal_client/secret_manager_client.py
@@ -27,7 +27,7 @@
 from airflow.utils.log.logging_mixin import LoggingMixin
 
 if TYPE_CHECKING:
-    import google
+    from google.auth.credentials import Credentials
 
 SECRET_ID_PATTERN = r"^[a-zA-Z0-9-_]*$"
 
@@ -45,7 +45,7 @@ class _SecretManagerClient(LoggingMixin):
 
     def __init__(
         self,
-        credentials: google.auth.credentials.Credentials,
+        credentials: Credentials,
     ) -> None:
         super().__init__()
         self.credentials = credentials
diff --git a/providers/google/src/airflow/providers/google/common/hooks/base_google.py b/providers/google/src/airflow/providers/google/common/hooks/base_google.py
index b55dce384eb44..0a7a0805a4f0a 100644
--- a/providers/google/src/airflow/providers/google/common/hooks/base_google.py
+++ b/providers/google/src/airflow/providers/google/common/hooks/base_google.py
@@ -718,6 +718,7 @@ def __init__(
         super().__init__(session=cast("Session", session), scopes=_scopes)
         self.credentials = credentials
         self.project = project
+        self.acquiring: asyncio.Task[None] | None = None
 
     @classmethod
     async def from_hook(
diff --git a/shared/dagnode/src/airflow_shared/dagnode/node.py b/shared/dagnode/src/airflow_shared/dagnode/node.py
index 0fed0c97f7524..7d52ff1ea1f4d 100644
--- a/shared/dagnode/src/airflow_shared/dagnode/node.py
+++ b/shared/dagnode/src/airflow_shared/dagnode/node.py
@@ -17,18 +17,54 @@
 
 from __future__ import annotations
 
-from typing import TYPE_CHECKING, Generic, TypeVar
+from typing import TYPE_CHECKING, Any, Generic, Protocol, TypeVar
 
 import structlog
 
 if TYPE_CHECKING:
+    import sys
     from collections.abc import Collection, Iterable
 
+    # Replicate `airflow.typing_compat.Self` to avoid illegal imports
+    if sys.version_info >= (3, 11):
+        from typing import Self
+    else:
+        from typing_extensions import Self
+
     from ..logging.types import Logger
 
-Dag = TypeVar("Dag")
-Task = TypeVar("Task")
-TaskGroup = TypeVar("TaskGroup")
+
+class DagProtocol(Protocol):
+    """Protocol defining the minimum interface required for Dag generic type."""
+
+    dag_id: str
+    task_dict: dict[str, Any]
+
+    def get_task(self, tid: str) -> Any:
+        """Retrieve a task by its task ID."""
+        ...
+
+
+class TaskProtocol(Protocol):
+    """Protocol defining the minimum interface required for Task generic type."""
+
+    task_id: str
+    is_setup: bool
+    is_teardown: bool
+    downstream_list: Iterable[Self]
+    downstream_task_ids: set[str]
+
+
+class TaskGroupProtocol(Protocol):
+    """Protocol defining the minimum interface required for TaskGroup generic type."""
+
+    node_id: str
+    prefix_group_id: bool
+
+
+Dag = TypeVar("Dag", bound=DagProtocol)
+Task = TypeVar("Task", bound=TaskProtocol)
+TaskGroup = TypeVar("TaskGroup", bound=TaskGroupProtocol)
 
 
 class GenericDAGNode(Generic[Dag, Task, TaskGroup]):
diff --git a/shared/observability/src/airflow_shared/observability/metrics/datadog_logger.py b/shared/observability/src/airflow_shared/observability/metrics/datadog_logger.py
index 595e6c8a33f25..09ac6b15a7104 100644
--- a/shared/observability/src/airflow_shared/observability/metrics/datadog_logger.py
+++ b/shared/observability/src/airflow_shared/observability/metrics/datadog_logger.py
@@ -20,7 +20,7 @@
 import datetime
 import logging
 from collections.abc import Callable
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any
 
 from .protocols import Timer
 from .validators import (
@@ -176,7 +176,7 @@ def get_dogstatsd_logger(
         """Get DataDog StatsD logger."""
         from datadog import DogStatsd
 
-        dogstatsd_kwargs: dict[str, str | int | list[str]] = {
+        dogstatsd_kwargs: dict[str, Any] = {
             "constant_tags": cls.get_constant_tags(tags_in_string=tags_in_string),
         }
         if host is not None:
diff --git a/shared/observability/src/airflow_shared/observability/metrics/stats.py b/shared/observability/src/airflow_shared/observability/metrics/stats.py
index e2c3e63077dac..e477a751ec090 100644
--- a/shared/observability/src/airflow_shared/observability/metrics/stats.py
+++ b/shared/observability/src/airflow_shared/observability/metrics/stats.py
@@ -56,6 +56,7 @@ def __getattr__(cls, name: str) -> str:
     def initialize(cls, *, is_statsd_datadog_enabled: bool, is_statsd_on: bool, is_otel_on: bool) -> None:
         type.__setattr__(cls, "factory", None)
         type.__setattr__(cls, "instance", None)
+        factory: Callable
         if is_statsd_datadog_enabled:
             from airflow.observability.metrics import datadog_logger
 
diff --git a/shared/plugins_manager/src/airflow_shared/plugins_manager/plugins_manager.py b/shared/plugins_manager/src/airflow_shared/plugins_manager/plugins_manager.py
index 9ea497e5a10c5..8fcc5c9c808d5 100644
--- a/shared/plugins_manager/src/airflow_shared/plugins_manager/plugins_manager.py
+++ b/shared/plugins_manager/src/airflow_shared/plugins_manager/plugins_manager.py
@@ -19,6 +19,9 @@
 
 from __future__ import annotations
 
+import importlib
+import importlib.machinery
+import importlib.util
 import inspect
 import logging
 import os
@@ -208,8 +211,6 @@ def _load_plugins_from_plugin_directory(
     ignore_file_syntax: str = "glob",
 ) -> tuple[list[AirflowPlugin], dict[str, str]]:
     """Load and register Airflow Plugins from plugins directory."""
-    import importlib
-
     from ..module_loading import find_path_from_directory
 
     if not plugins_folder:
@@ -219,6 +220,8 @@
     plugin_search_locations: list[tuple[str, Generator[str, None, None]]] = [("", files)]
 
     if load_examples:
+        if not example_plugins_module:
+            raise ValueError("example_plugins_module is required when load_examples is True")
         log.debug("Note: Loading plugins from examples as well: %s", plugins_folder)
         example_plugins = importlib.import_module(example_plugins_module)
         example_plugins_folder = next(iter(example_plugins.__path__))
diff --git a/shared/providers_discovery/src/airflow_shared/providers_discovery/providers_discovery.py b/shared/providers_discovery/src/airflow_shared/providers_discovery/providers_discovery.py
index 4fc882d1b5d23..dcab0fe3034aa 100644
--- a/shared/providers_discovery/src/airflow_shared/providers_discovery/providers_discovery.py
+++ b/shared/providers_discovery/src/airflow_shared/providers_discovery/providers_discovery.py
@@ -27,7 +27,7 @@
 from functools import wraps
 from importlib.resources import files as resource_files
 from time import perf_counter
-from typing import Any, NamedTuple, ParamSpec
+from typing import Any, NamedTuple, ParamSpec, Protocol, cast
 
 import structlog
 from packaging.utils import canonicalize_name
@@ -43,6 +43,12 @@
 KNOWN_UNHANDLED_OPTIONAL_FEATURE_ERRORS = [("apache-airflow-providers-google", "No module named 'paramiko'")]
 
 
+class ProvidersManagerProtocol(Protocol):
+    """Protocol for ProvidersManager for type checking purposes."""
+
+    _initialized_cache: dict[str, bool]
+
+
 @dataclass
 class ProviderInfo:
     """
@@ -271,7 +277,7 @@ def provider_info_cache(cache_name: str) -> Callable[[Callable[PS, None]], Calla
     def provider_info_cache_decorator(func: Callable[PS, None]) -> Callable[PS, None]:
         @wraps(func)
        def wrapped_function(*args: PS.args, **kwargs: PS.kwargs) -> None:
-            instance = args[0]
+            instance = cast("ProvidersManagerProtocol", args[0])
             if cache_name in instance._initialized_cache:
                 return
 
diff --git a/task-sdk/src/airflow/sdk/definitions/_internal/node.py b/task-sdk/src/airflow/sdk/definitions/_internal/node.py
index b2cb651efe1a8..7a652e8c27943 100644
--- a/task-sdk/src/airflow/sdk/definitions/_internal/node.py
+++ b/task-sdk/src/airflow/sdk/definitions/_internal/node.py
@@ -64,7 +64,7 @@ def validate_group_key(k: str, max_length: int = 200):
         )
 
 
-class DAGNode(GenericDAGNode["DAG", "Operator", "TaskGroup"], DependencyMixin, metaclass=ABCMeta):
+class DAGNode(GenericDAGNode["DAG", "Operator", "TaskGroup"], DependencyMixin, metaclass=ABCMeta):  # type: ignore[type-var]
     """
     A base class for a node in the graph of a workflow.
 