Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 60 additions & 1 deletion airflow-core/src/airflow/jobs/triggerer_job_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,13 +59,20 @@
from airflow.observability.metrics import stats_utils
from airflow.sdk.api.datamodels._generated import HITLDetailResponse
from airflow.sdk.execution_time.comms import (
AssetStateResult,
ClearAssetStateByName,
ClearAssetStateByUri,
CommsDecoder,
ConnectionResult,
DagRunStateResult,
DeleteAssetStateByName,
DeleteAssetStateByUri,
DeleteVariable,
DeleteXCom,
DRCount,
ErrorResponse,
GetAssetStateByName,
GetAssetStateByUri,
GetConnection,
GetDagRunState,
GetDRCount,
Expand All @@ -79,6 +86,8 @@
MaskSecret,
OKResponse,
PutVariable,
SetAssetStateByName,
SetAssetStateByUri,
SetXCom,
TaskStatesResult,
TICount,
Expand Down Expand Up @@ -313,6 +322,7 @@ def from_api_response(cls, response: HITLDetailResponse) -> HITLDetailResponseRe
ToTriggerRunner = Annotated[
messages.StartTriggerer
| messages.TriggerStateSync
| AssetStateResult
| ConnectionResult
| VariableResult
| VariableKeysResult
Expand Down Expand Up @@ -349,7 +359,15 @@ def from_api_response(cls, response: HITLDetailResponse) -> HITLDetailResponseRe
| GetPreviousTI
| GetHITLDetailResponse
| UpdateHITLDetail
| MaskSecret,
| MaskSecret
| GetAssetStateByName
| GetAssetStateByUri
| SetAssetStateByName
| SetAssetStateByUri
| DeleteAssetStateByName
| DeleteAssetStateByUri
| ClearAssetStateByName
| ClearAssetStateByUri,
Field(discriminator="type"),
]
"""
Expand Down Expand Up @@ -579,6 +597,47 @@ def _handle_request(self, msg: ToTriggerSupervisor, log: FilteringBoundLogger, r
resp = HITLDetailResponseResult.from_api_response(response=api_resp)
elif isinstance(msg, MaskSecret):
handle_mask_secret(msg)

elif isinstance(msg, GetAssetStateByName):
asset_state = self.client.asset_state.get(msg.key, name=msg.name)
resp = (
asset_state
if isinstance(asset_state, ErrorResponse)
else AssetStateResult.from_asset_state_response(asset_state)
)

elif isinstance(msg, GetAssetStateByUri):
asset_state = self.client.asset_state.get(msg.key, uri=msg.uri)
resp = (
asset_state
if isinstance(asset_state, ErrorResponse)
else AssetStateResult.from_asset_state_response(asset_state)
)

elif isinstance(msg, SetAssetStateByName):
self.client.asset_state.set(msg.key, msg.value, name=msg.name)
resp = OKResponse(ok=True)

elif isinstance(msg, SetAssetStateByUri):
self.client.asset_state.set(msg.key, msg.value, uri=msg.uri)
resp = OKResponse(ok=True)

elif isinstance(msg, DeleteAssetStateByName):
self.client.asset_state.delete(msg.key, name=msg.name)
resp = OKResponse(ok=True)

elif isinstance(msg, DeleteAssetStateByUri):
self.client.asset_state.delete(msg.key, uri=msg.uri)
resp = OKResponse(ok=True)

elif isinstance(msg, ClearAssetStateByName):
self.client.asset_state.clear(name=msg.name)
resp = OKResponse(ok=True)

elif isinstance(msg, ClearAssetStateByUri):
self.client.asset_state.clear(uri=msg.uri)
resp = OKResponse(ok=True)

else:
raise ValueError(f"Unknown message type {type(msg)}")

Expand Down
155 changes: 155 additions & 0 deletions airflow-core/tests/unit/jobs/test_triggerer_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -635,6 +635,161 @@ def test_trigger_logger_fd_closed_when_upload_to_remote_raises(jobless_superviso
assert 42 not in jobless_supervisor.running_triggers


class TestTriggerSupervisorAssetState:
    """
    Verify that ``TriggerRunnerSupervisor._handle_request`` dispatches every asset-state comms message.

    Each test builds one request message, feeds it to ``_handle_request`` and then checks two things:
    the call made on the mocked ``client.asset_state`` API, and the response handed to ``send_msg``.
    """

    # Identifiers shared by all tests; the ``*ByName`` messages use the name, the ``*ByUri`` ones the URI.
    asset_name: str = "my_asset"
    asset_uri: str = "s3://bucket/key"

    @pytest.fixture
    def supervisor(self, jobless_supervisor, mocker):
        """Return ``jobless_supervisor`` with its ``client`` and ``send_msg`` replaced by mocks."""
        # ``client`` is patched on the type with a PropertyMock, so every ``supervisor.client`` access
        # yields the same ``Mock`` and calls on ``client.asset_state`` can be inspected afterwards.
        mocker.patch.object(
            type(jobless_supervisor), "client", new_callable=mocker.PropertyMock, return_value=mocker.Mock()
        )
        # A plain ``Mock`` stored on the class is not a descriptor, so it is not bound to the instance:
        # ``call_args.args[0]`` is therefore the response message itself, not ``self``.
        mocker.patch.object(type(jobless_supervisor), "send_msg", mocker.Mock())
        return jobless_supervisor

    def test_get_by_name(self, supervisor):
        """
        Validate the lookup-by-name call chain.

        ``GetAssetStateByName`` must call ``client.asset_state.get(key, name=...)`` and respond with an
        ``AssetStateResult`` carrying the returned value, tagged with the originating request id.
        """
        from airflow.sdk.api.datamodels._generated import AssetStateResponse
        from airflow.sdk.execution_time.comms import AssetStateResult, GetAssetStateByName

        supervisor.client.asset_state.get.return_value = AssetStateResponse(value="2026-04-30")

        msg = GetAssetStateByName(name=self.asset_name, key="watermark")
        supervisor._handle_request(msg, log=MagicMock(spec=FilteringBoundLogger), req_id=1)

        supervisor.client.asset_state.get.assert_called_once_with("watermark", name=self.asset_name)
        sent = supervisor.send_msg.call_args
        assert isinstance(sent.args[0], AssetStateResult)
        assert sent.args[0].value == "2026-04-30"
        assert sent.kwargs["request_id"] == 1

    def test_get_by_uri(self, supervisor):
        """
        Validate the lookup-by-URI call chain.

        ``GetAssetStateByUri`` must call ``client.asset_state.get(key, uri=...)`` and respond with an
        ``AssetStateResult`` carrying the returned value, tagged with the originating request id.
        """
        from airflow.sdk.api.datamodels._generated import AssetStateResponse
        from airflow.sdk.execution_time.comms import AssetStateResult, GetAssetStateByUri

        supervisor.client.asset_state.get.return_value = AssetStateResponse(value="2026-04-30")

        msg = GetAssetStateByUri(uri=self.asset_uri, key="watermark")
        supervisor._handle_request(msg, log=MagicMock(spec=FilteringBoundLogger), req_id=2)

        supervisor.client.asset_state.get.assert_called_once_with("watermark", uri=self.asset_uri)
        # Same assertions as the by-name variant, so the URI branch cannot silently drop the value
        # or answer under the wrong request id.
        sent = supervisor.send_msg.call_args
        assert isinstance(sent.args[0], AssetStateResult)
        assert sent.args[0].value == "2026-04-30"
        assert sent.kwargs["request_id"] == 2

    def test_get_by_name_propagates_error_response(self, supervisor):
        """Validate that an ``ErrorResponse`` returned by the client for a name lookup is forwarded as-is."""
        from airflow.sdk.exceptions import ErrorType
        from airflow.sdk.execution_time.comms import ErrorResponse, GetAssetStateByName

        error = ErrorResponse(error=ErrorType.ASSET_STATE_NOT_FOUND, detail={"key": "missing_key"})
        supervisor.client.asset_state.get.return_value = error

        msg = GetAssetStateByName(name=self.asset_name, key="missing_key")
        supervisor._handle_request(msg, log=MagicMock(spec=FilteringBoundLogger), req_id=9)

        # Identity check: the very same object must be sent, not a converted or re-wrapped copy.
        assert supervisor.send_msg.call_args.args[0] is error

    def test_get_by_uri_propagates_error_response(self, supervisor):
        """Validate the same error-forwarding contract for the URI lookup variant."""
        from airflow.sdk.exceptions import ErrorType
        from airflow.sdk.execution_time.comms import ErrorResponse, GetAssetStateByUri

        error = ErrorResponse(error=ErrorType.ASSET_STATE_NOT_FOUND, detail={"key": "missing_key"})
        supervisor.client.asset_state.get.return_value = error

        msg = GetAssetStateByUri(uri=self.asset_uri, key="missing_key")
        supervisor._handle_request(msg, log=MagicMock(spec=FilteringBoundLogger), req_id=10)

        # Identity check: the very same object must be sent, not a converted or re-wrapped copy.
        assert supervisor.send_msg.call_args.args[0] is error

    def test_set_by_name(self, supervisor):
        """
        Validate that ``SetAssetStateByName`` calls ``client.asset_state.set(key, value, name=...)``.

        The supervisor must respond with ``OKResponse``.
        """
        from airflow.sdk.execution_time.comms import OKResponse, SetAssetStateByName

        msg = SetAssetStateByName(name=self.asset_name, key="watermark", value="2026-04-30")
        supervisor._handle_request(msg, log=MagicMock(spec=FilteringBoundLogger), req_id=3)

        supervisor.client.asset_state.set.assert_called_once_with(
            "watermark", "2026-04-30", name=self.asset_name
        )
        assert isinstance(supervisor.send_msg.call_args.args[0], OKResponse)

    def test_set_by_uri(self, supervisor):
        """
        Validate that ``SetAssetStateByUri`` calls ``client.asset_state.set(key, value, uri=...)``.

        The supervisor must respond with ``OKResponse``.
        """
        from airflow.sdk.execution_time.comms import OKResponse, SetAssetStateByUri

        msg = SetAssetStateByUri(uri=self.asset_uri, key="watermark", value="2026-04-30")
        supervisor._handle_request(msg, log=MagicMock(spec=FilteringBoundLogger), req_id=4)

        supervisor.client.asset_state.set.assert_called_once_with(
            "watermark", "2026-04-30", uri=self.asset_uri
        )
        assert isinstance(supervisor.send_msg.call_args.args[0], OKResponse)

    def test_delete_by_name(self, supervisor):
        """
        Validate that ``DeleteAssetStateByName`` calls ``client.asset_state.delete(key, name=...)``.

        The supervisor must respond with ``OKResponse``.
        """
        from airflow.sdk.execution_time.comms import DeleteAssetStateByName, OKResponse

        msg = DeleteAssetStateByName(name=self.asset_name, key="watermark")
        supervisor._handle_request(msg, log=MagicMock(spec=FilteringBoundLogger), req_id=5)

        supervisor.client.asset_state.delete.assert_called_once_with("watermark", name=self.asset_name)
        assert isinstance(supervisor.send_msg.call_args.args[0], OKResponse)

    def test_delete_by_uri(self, supervisor):
        """
        Validate that ``DeleteAssetStateByUri`` calls ``client.asset_state.delete(key, uri=...)``.

        The supervisor must respond with ``OKResponse``.
        """
        from airflow.sdk.execution_time.comms import DeleteAssetStateByUri, OKResponse

        msg = DeleteAssetStateByUri(uri=self.asset_uri, key="watermark")
        supervisor._handle_request(msg, log=MagicMock(spec=FilteringBoundLogger), req_id=6)

        supervisor.client.asset_state.delete.assert_called_once_with("watermark", uri=self.asset_uri)
        assert isinstance(supervisor.send_msg.call_args.args[0], OKResponse)

    def test_clear_by_name(self, supervisor):
        """
        Validate that ``ClearAssetStateByName`` calls ``client.asset_state.clear(name=...)`` without a key.

        The supervisor must respond with ``OKResponse``.
        """
        from airflow.sdk.execution_time.comms import ClearAssetStateByName, OKResponse

        msg = ClearAssetStateByName(name=self.asset_name)
        supervisor._handle_request(msg, log=MagicMock(spec=FilteringBoundLogger), req_id=7)

        supervisor.client.asset_state.clear.assert_called_once_with(name=self.asset_name)
        assert isinstance(supervisor.send_msg.call_args.args[0], OKResponse)

    def test_clear_by_uri(self, supervisor):
        """
        Validate that ``ClearAssetStateByUri`` calls ``client.asset_state.clear(uri=...)`` without a key.

        The supervisor must respond with ``OKResponse``.
        """
        from airflow.sdk.execution_time.comms import ClearAssetStateByUri, OKResponse

        msg = ClearAssetStateByUri(uri=self.asset_uri)
        supervisor._handle_request(msg, log=MagicMock(spec=FilteringBoundLogger), req_id=8)

        supervisor.client.asset_state.clear.assert_called_once_with(uri=self.asset_uri)
        assert isinstance(supervisor.send_msg.call_args.args[0], OKResponse)


class TestTriggerRunner:
def test_blocked_main_thread_warning_threshold_decode(self) -> None:
with conf_vars({("triggerer", "blocked_main_thread_warning_threshold"): "0.5"}):
Expand Down
3 changes: 3 additions & 0 deletions task-sdk/src/airflow/sdk/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
"AssetAll",
"AssetAny",
"AssetOrTimeSchedule",
"AssetState",
"AssetWatcher",
"AsyncCallback",
"BaseAsyncOperator",
Expand Down Expand Up @@ -131,6 +132,7 @@
)
from airflow.sdk.definitions.asset.decorators import asset
from airflow.sdk.definitions.asset.metadata import Metadata
from airflow.sdk.definitions.asset.state import AssetState
from airflow.sdk.definitions.callback import AsyncCallback, SyncCallback
from airflow.sdk.definitions.connection import Connection
from airflow.sdk.definitions.context import Context, get_current_context, get_parsing_context
Expand Down Expand Up @@ -195,6 +197,7 @@
"AssetAll": ".definitions.asset",
"AssetAny": ".definitions.asset",
"AssetOrTimeSchedule": ".definitions.timetables.assets",
"AssetState": ".definitions.asset.state",
"AssetWatcher": ".definitions.asset",
"AsyncCallback": ".definitions.callback",
"BaseAsyncOperator": ".bases.operator",
Expand Down
2 changes: 2 additions & 0 deletions task-sdk/src/airflow/sdk/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ from airflow.sdk.definitions.asset import (
from airflow.sdk.definitions.asset.access_control import AssetAccessControl as AssetAccessControl
from airflow.sdk.definitions.asset.decorators import asset as asset
from airflow.sdk.definitions.asset.metadata import Metadata as Metadata
from airflow.sdk.definitions.asset.state import AssetState as AssetState
from airflow.sdk.definitions.connection import Connection as Connection
from airflow.sdk.definitions.context import (
Context as Context,
Expand Down Expand Up @@ -118,6 +119,7 @@ __all__ = [
"AssetAll",
"AssetAny",
"AssetOrTimeSchedule",
"AssetState",
"AssetWatcher",
"BaseAsyncOperator",
"BaseBranchOperator",
Expand Down
42 changes: 42 additions & 0 deletions task-sdk/src/airflow/sdk/definitions/asset/state.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import annotations

from airflow.sdk.execution_time.context import AssetStateAccessor


class AssetState(AssetStateAccessor):
    """
    Handle onto the state store of a single asset, usable wherever SUPERVISOR_COMMS is available.

    Inside a task the same store is reached by subscripting ``context['asset_state'][asset]``.
    This class offers equivalent access from places where no ``context`` is bound, such as
    callbacks and triggers - most notably from inside a :class:`BaseEventTrigger`.

    The asset is identified by ``name`` or by ``uri``; exactly one of the two is expected.
    Both arguments are forwarded unchanged to :class:`AssetStateAccessor`, so any validation
    of that rule happens in the base class rather than here.

    Example::

        from airflow.sdk import AssetState

        asset_state = AssetState(name="my_asset")
        watermark = asset_state.get("watermark")
        asset_state.set("watermark", "2026-01-01")

    :param name: Name of the asset whose state should be accessed.
    :param uri: URI of the asset whose state should be accessed.
    """

    def __init__(self, *, name: str | None = None, uri: str | None = None) -> None:
        # Keyword-only on purpose: callers must spell out which identifier they are passing.
        super().__init__(name=name, uri=uri)
Loading
Loading