
feature/issue-67200: Adding AssetState Task SDK mechanism #67248

Open
jroachgolf84 wants to merge 5 commits into apache:main from jroachgolf84:feature/issue-67200

@jroachgolf84 (Collaborator) commented May 20, 2026

Description

This PR adds the AssetState mechanism, which makes the foundations put in place for AIP-103 usable from within a Trigger or Task. AssetState can be used like this:

from airflow.sdk import AssetState

asset_state = AssetState(name="my_asset")  # Pass in the name of the Asset (or uri)

asset_state.set("watermark", "1")         # Store a value
watermark = asset_state.get("watermark")  # Retrieve the value

print(watermark)  # Should output "1"
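
As a rough illustration of why a stored watermark is useful, the sketch below selects only the rows newer than the last stored value. `InMemoryAssetState` is a hypothetical stand-in with the same `get`/`set` shape as the snippet above, not the class added in this PR, and the assumption that a missing key returns `None` is not confirmed by the PR.

```python
from __future__ import annotations


class InMemoryAssetState:
    """Hypothetical stand-in with the same get/set shape as the snippet above."""

    # Shared across instances so two handles to the same asset see the same values.
    _store: dict[tuple[str, str], str] = {}

    def __init__(self, name: str) -> None:
        self.name = name

    def set(self, key: str, value: str) -> None:
        self._store[(self.name, key)] = value

    def get(self, key: str) -> str | None:
        # Assumption: a missing key yields None (not confirmed by the PR).
        return self._store.get((self.name, key))


def select_new_rows(rows: list[dict], state: InMemoryAssetState) -> list[dict]:
    """Return rows newer than the stored watermark, then advance the watermark."""
    last_seen = int(state.get("watermark") or 0)
    new_rows = [row for row in rows if row["id"] > last_seen]
    if new_rows:
        # Values are stored as strings, mirroring the example above.
        state.set("watermark", str(max(row["id"] for row in new_rows)))
    return new_rows
```

Calling `select_new_rows` twice with the same input returns every row the first time and an empty list the second time, because the watermark persists between calls.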

Note: Documentation has not been created for this functionality. That will be done as part of #65782.

closes: #67200
related: #65782

Testing

These changes were unit-tested, as well as tested E2E. See below for more information.

Unit Tests

Existing unit tests were updated and new tests were added to validate the changes made in this branch. Note that additional unit tests still need to be added.

# Updated existing unit tests
breeze testing core-tests airflow-core/tests/unit/jobs/test_triggerer_job.py

# New unit tests
breeze testing task-sdk-tests task-sdk/tests/task_sdk/definitions/test_asset_state.py

E2E Testing

To test this E2E, the DAG below was used. It uses AssetState in both a Task and a Trigger (check out the Trigger below). Note that I was able to retrieve and set state in the GenericEventTrigger, as well as in the downstream_task.

from airflow.sdk import Asset, AssetState, AssetWatcher, DAG, task
from datetime import datetime
from triggers.event_triggers import GenericEventTrigger


generic_asset_watcher = AssetWatcher(
    name="generic_asset_watcher",
    trigger=GenericEventTrigger(
        random_number=1,
        waiter_delay=15,
        asset_name="generic_asset"
    )
)

generic_asset = Asset(
    name="generic_asset",
    watchers=[generic_asset_watcher],
)


with DAG(
    dag_id="issue-67200",
    start_date=datetime(2026, 1, 1),
    schedule=[generic_asset]
) as dag:

    @task
    def downstream_task():
        asset_state = AssetState(name="generic_asset")
        result = asset_state.get("result")
        print(f"***** result: {result}")

        asset_state.set("task_result", str(result))
        task_result = asset_state.get("task_result")
        print(f"***** task_result: {task_result}")

    downstream_task()

This GenericEventTrigger was used for E2E testing. Note that AssetState is used in the run method of the Trigger. This code properly stores and retrieves the generated number, and logs the output accordingly.

from airflow.sdk import AssetState
from airflow.triggers.base import BaseEventTrigger, TriggerEvent
from collections.abc import AsyncIterator
from typing import Any

import asyncio
import logging
import random


class GenericEventTrigger(BaseEventTrigger):
    def __init__(
        self,
        random_number,
        waiter_delay,
        asset_name,
        **kwargs
    ):
        super().__init__(**kwargs)

        self.random_number = random_number
        self.waiter_delay = waiter_delay
        self.asset_name = asset_name

    def serialize(self) -> tuple[str, dict[str, Any]]:
        """Serialize the Trigger, including random_number, waiter_delay, and asset_name."""
        return (
            self.__class__.__module__ + "." + self.__class__.__qualname__,
            {
                "random_number": self.random_number,
                "waiter_delay": self.waiter_delay,
                "asset_name": self.asset_name,
            },
        )

    async def run(self) -> AsyncIterator[TriggerEvent]:
        """Logic that fires a TriggerEvent."""
        # Here's where the AssetState is actually being used
        asset_state = AssetState(name=self.asset_name)
        logging.info(f"***** asset_state: {asset_state}")

        while True:
            result = random.randint(0, 5)
            logging.info(f"result: {result}")

            asset_state.set("result", str(result))
            get_result = asset_state.get("result")
            logging.info(f"get_result: {get_result}")

            if result == self.random_number:
                logging.info("yield'ing TriggerEvent")
                yield TriggerEvent({"status": "success", "result": result})
                break

            logging.info(f"Sleeping for {self.waiter_delay} seconds")
            await asyncio.sleep(self.waiter_delay)

Outstanding Items

The following items still need to be completed:

  • Validate that AssetState works when called within a Task.
  • Ensure complete test coverage of the changes that were made.

Was generative AI tooling used to co-author this PR?

No, generative AI was not used to generate this PR.

@jroachgolf84 jroachgolf84 marked this pull request as ready for review May 20, 2026 20:21
@vincbeck
Contributor

From AIP-103, I understood that this new task state management could be used across multiple use cases, such as intra-task progress checkpointing. If that's the case, I am not sure about the name AssetState: it feels like something specific to assets, although it should be agnostic. What do you think?

@jroachgolf84
Collaborator Author

AIP-103 addresses both Task and Asset state. Here are some of the PRs that have added Asset State.

cc: @amoghrajesh

@jroachgolf84
Collaborator Author

@amoghrajesh - when you get a chance, can you look at this? I'm going to work on getting these checks green.

@vincbeck
Contributor

The APIs look good to me; I am only questioning the way the state is accessed. In your example you do asset_state = AssetState(name="generic_asset"), but in reality the task instance/asset state is scoped to its task instance/asset. It feels weird to me to be able to specify the asset name. If I do asset_state = AssetState(name="another_asset"), would I get the state from this other asset even though my code is not scoped to another_asset? Would it be protected by a JWT? In my mental model, I was expecting a getter and setter like set_state and get_state that would use the runtime context (task instance ID/asset name) to retrieve the state.
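
To make the context-scoped getter/setter idea concrete, here is a minimal sketch that resolves the scope from a runtime context instead of a caller-supplied name. Everything in it is hypothetical: `set_state`, `get_state`, and `run_in_scope` follow the names floated in this comment but are not part of the PR or of Airflow, and a plain dict stands in for the real backend.

```python
from __future__ import annotations

from contextvars import ContextVar
from typing import Callable, TypeVar

T = TypeVar("T")

# Hypothetical runtime context: the framework, not the author, sets the scope.
_current_scope: ContextVar[str] = ContextVar("current_scope")

# Plain dict standing in for the real state backend.
_store: dict[tuple[str, str], str] = {}


def set_state(key: str, value: str) -> None:
    """Store a value under the scope taken from the runtime context."""
    scope = _current_scope.get()  # raises LookupError when no scope is active
    _store[(scope, key)] = value


def get_state(key: str) -> str | None:
    """Read a value from the scope taken from the runtime context."""
    scope = _current_scope.get()
    return _store.get((scope, key))


def run_in_scope(scope: str, fn: Callable[[], T]) -> T:
    """What a framework might do: bind the scope, run user code, then unbind."""
    token = _current_scope.set(scope)
    try:
        return fn()
    finally:
        _current_scope.reset(token)
```

With this shape, user code never names an asset, so it cannot read or write state outside the scope it was started in. Whether the real implementation would enforce this on the client side or on the server side is left open here.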

@vincbeck
Contributor

I just looked at #67376, and I meant exactly this same mechanism, but for assets. Can't we automatically plumb the state of an asset through to its related triggers? That would avoid forcing the Dag author to manually get the asset state with AssetState(name=self.asset_name).

@jroachgolf84
Collaborator Author

@vincbeck - I think that makes sense. The only caveat is that a BaseEventTrigger can have multiple Assets "plumbed" through to it (as we discussed here: #66595). That would make the pattern for accessing the state of a specific Asset a bit more difficult. I guess the syntax would be something more like this:

# Contains the asset states
self.asset_states = ...

asset_a_state = self.asset_states.get("asset_a")
asset_b_state = self.asset_states.get("asset_b")

Thoughts? I'm kinda caught in the middle on this one.

@jroachgolf84
Collaborator Author

@cmarteepants, @vikramkoka

@amoghrajesh left a comment
Contributor

Hey @jroachgolf84, I think the plumbing fix here will be correct.

AssetState(name=self.asset_name) works but it is inconsistent with how tasks access asset state as Vincent also mentioned, and nothing prevents a trigger from doing AssetState(name="some_other_asset") even if it's not associated with that asset.

Looking at BaseTrigger, the triggerer already injects self.task_instance before calling run(). We could use the same pattern here and the triggerer could populate self.asset_states on BaseEventTrigger before run(), keyed by asset name, using the assets associated with the TI. The trigger author can then do:

async def run(self):
    watermark = self.asset_states["orders"].get("watermark")
    self.asset_states["orders"].set("watermark", new_watermark)
    yield TriggerEvent(...)

This will also provide some benefits like:

  • Make the triggers scoped automatically (only the trigger's associated assets are present)
  • Should also work cleanly for multi-asset triggers (just use the asset name while accessing)
  • The framework handles scoping, not the author

The triggerer already knows which assets are associated with the TI at the point it sets trigger.task_instance = ti, so populating asset_states there is probably straightforward.
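
The injection idea could be sketched roughly as follows. This is an illustration only: `FakeAssetState`, `SketchEventTrigger`, and `run_with_injected_state` are hypothetical names that do not exist in Airflow, a plain dict stands in for the state backend, and the real triggerer code path may look quite different.

```python
from __future__ import annotations

import asyncio
from collections.abc import AsyncIterator


class FakeAssetState:
    """Hypothetical per-asset handle backed by a shared dict."""

    def __init__(self, name: str, backend: dict[tuple[str, str], str]) -> None:
        self.name = name
        self._backend = backend

    def get(self, key: str) -> str | None:
        return self._backend.get((self.name, key))

    def set(self, key: str, value: str) -> None:
        self._backend[(self.name, key)] = value


class SketchEventTrigger:
    """Hypothetical trigger that relies on asset_states being set before run()."""

    asset_states: dict[str, FakeAssetState]

    async def run(self) -> AsyncIterator[dict]:
        state = self.asset_states["orders"]
        previous = state.get("watermark")
        state.set("watermark", "42")
        yield {"previous": previous, "current": state.get("watermark")}


async def run_with_injected_state(
    trigger: SketchEventTrigger,
    associated_assets: list[str],
    backend: dict[tuple[str, str], str],
) -> list[dict]:
    """Stand-in for the triggerer: inject scoped state, then drive run()."""
    # Only the assets associated with this trigger are exposed, keyed by name.
    trigger.asset_states = {
        name: FakeAssetState(name, backend) for name in associated_assets
    }
    return [event async for event in trigger.run()]
```

Because the mapping is built from the associated assets only, a lookup such as `self.asset_states["some_other_asset"]` raises `KeyError` in this sketch rather than silently reaching into another asset's state.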

@vincbeck
Contributor

I really like that!

Development

Successfully merging this pull request may close these issues.

AIP-103: Add Task SDK support for retrieving Asset State by name/uri

3 participants