diff --git a/airflow-core/src/airflow/example_dags/example_asset_state.py b/airflow-core/src/airflow/example_dags/example_asset_state.py
new file mode 100644
index 0000000000000..3d18b9afaa609
--- /dev/null
+++ b/airflow-core/src/airflow/example_dags/example_asset_state.py
@@ -0,0 +1,104 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Example Dag that demonstrates using AIP-103 asset state to track a watermark across Dag runs.
+
+The producer reads the last watermark, processes only new records, then
+advances the watermark. The consumer is triggered by the asset event and
+reads asset state to understand what the producer just loaded.
+
+Asset state persists on the asset across runs — unlike task state which is
+scoped to a single task instance. This replaces the common pattern of
+storing watermarks in Airflow Variables, which have no asset-level scoping.
+"""
+
+from __future__ import annotations
+
+import json
+import random
+from datetime import datetime, timezone
+
+from airflow.sdk import DAG, Asset, task
+
+ORDERS = Asset(name="orders/daily", uri="s3://warehouse/orders/daily")
+
+
+def _fetch_records(since: str) -> list[dict]:
+    """Simulate fetching records newer than `since` (the stub ignores the value)."""
+    return [{"id": i} for i in range(random.randint(100, 5_000))]
+
+
+with DAG(
+    dag_id="example_asset_state_producer",
+    schedule=None,
+    start_date=datetime(2026, 1, 1),
+    catchup=False,
+    tags=["example", "asset-state"],
+    doc_md=__doc__,
+):
+
+    @task(inlets=[ORDERS], outlets=[ORDERS])
+    def load(asset_state=None):
+        """Load records newer than the stored watermark, then advance the watermark."""
+        state = asset_state[ORDERS]
+
+        # First run: watermark is None — fall back to the Dag's start date.
+        watermark = state.get("watermark") or "2026-01-01T00:00:00+00:00"
+        records = _fetch_records(since=watermark)
+        row_count = len(records)
+
+        now = datetime.now(tz=timezone.utc).isoformat()
+        state.set("watermark", now)
+        state.set("total_runs", (state.get("total_runs") or 0) + 1)
+        # Serialize the summary so it matches the json.loads() call in the consumer.
+        state.set(
+            "last_run_summary",
+            json.dumps(
+                {
+                    "rows_loaded": row_count,
+                    "prev_watermark": watermark,
+                    "completed_at": now,
+                }
+            ),
+        )
+
+        print(f"Loaded {row_count} records. Watermark advanced to {now}.")
+        return row_count
+
+    load()
+
+
+with DAG(
+    dag_id="example_asset_state_consumer",
+    schedule=[ORDERS],
+    start_date=datetime(2026, 1, 1),
+    catchup=False,
+    tags=["example", "asset-state"],
+):
+
+    @task(inlets=[ORDERS])
+    def consume(asset_state=None):
+        """Print what the producer last loaded, read from the asset state of ORDERS."""
+        state = asset_state[ORDERS]
+        summary = json.loads(state.get("last_run_summary") or "{}")
+        print(
+            f"Processing {summary.get('rows_loaded', '?')} rows "
+            f"up to watermark {state.get('watermark')}. "
+            f"Total runs so far: {state.get('total_runs')}."
+        )
+
+    consume()
diff --git a/airflow-core/src/airflow/example_dags/example_task_state.py b/airflow-core/src/airflow/example_dags/example_task_state.py
new file mode 100644
index 0000000000000..2689bfeea240d
--- /dev/null
+++ b/airflow-core/src/airflow/example_dags/example_task_state.py
@@ -0,0 +1,91 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Example Dag that demonstrates the canonical AIP-103 task state pattern: a task submits a
+long-running external job, stores the job handle in task state, and polls
+until completion.
+
+The first attempt always fails after submitting the job (simulating a
+worker crash / connection to external system being lost). The retry reads
+the job ID from task state and reattaches to the already-running job instead
+of submitting a duplicate.
+"""
+
+from __future__ import annotations
+
+import json
+import random
+import string
+import time
+from datetime import datetime, timedelta, timezone
+
+from airflow.sdk import DAG, task
+from airflow.sdk.execution_time.context import NEVER_EXPIRE
+
+
+def _submit_job() -> str:
+    """Simulate submitting an external job. Returns a job ID."""
+    time.sleep(1)
+    return "job-" + "".join(random.choices(string.ascii_lowercase + string.digits, k=8))
+
+
+def _poll_job(job_id: str) -> dict:
+    """Simulate polling an external job until complete."""
+    time.sleep(1)
+    return {"job_id": job_id, "status": "succeeded", "rows_written": random.randint(100, 10_000)}
+
+
+with DAG(
+    dag_id="example_task_state",
+    schedule=None,
+    start_date=datetime(2026, 1, 1),
+    catchup=False,
+    tags=["example", "task-state"],
+    doc_md=__doc__,
+):
+
+    @task(retries=2, retry_delay=timedelta(seconds=5))
+    def run_job(**context):
+        """Submit a job when no job ID is stored, otherwise reattach to the stored job and poll it."""
+        task_state = context["task_state"]
+        try_number = context["ti"].try_number
+
+        job_id = task_state.get("job_id")
+        if job_id:
+            print(f"Try {try_number}: reattaching to existing job: {job_id}")
+        else:
+            job_id = _submit_job()
+            # Store with NEVER_EXPIRE so the job ID survives across all retries.
+            task_state.set("job_id", job_id, retention=NEVER_EXPIRE)
+            task_state.set("submitted_at", datetime.now(tz=timezone.utc).isoformat())
+            print(f"Try {try_number}: submitted job: {job_id}")
+
+            # Simulate a crash after submission on the first attempt.
+            # The retry will reattach to the same job instead of submitting a duplicate.
+            raise RuntimeError(
+                f"Simulated failure after submitting {job_id}. The next retry will reattach to this job."
+            )
+
+        task_state.set("status", "running")
+        result = _poll_job(job_id)
+        task_state.set("status", "complete")
+        task_state.set("result", json.dumps(result))
+
+        print(f"Try {try_number}: job complete — {result['rows_written']} rows written")
+        return result["rows_written"]
+
+    run_job()