diff --git a/airflow-core/src/airflow/jobs/triggerer_job_runner.py b/airflow-core/src/airflow/jobs/triggerer_job_runner.py index 6f8f7baae84ab..6b32210f3cfa1 100644 --- a/airflow-core/src/airflow/jobs/triggerer_job_runner.py +++ b/airflow-core/src/airflow/jobs/triggerer_job_runner.py @@ -59,13 +59,20 @@ from airflow.observability.metrics import stats_utils from airflow.sdk.api.datamodels._generated import HITLDetailResponse from airflow.sdk.execution_time.comms import ( + AssetStateResult, + ClearAssetStateByName, + ClearAssetStateByUri, CommsDecoder, ConnectionResult, DagRunStateResult, + DeleteAssetStateByName, + DeleteAssetStateByUri, DeleteVariable, DeleteXCom, DRCount, ErrorResponse, + GetAssetStateByName, + GetAssetStateByUri, GetConnection, GetDagRunState, GetDRCount, @@ -79,6 +86,8 @@ MaskSecret, OKResponse, PutVariable, + SetAssetStateByName, + SetAssetStateByUri, SetXCom, TaskStatesResult, TICount, @@ -313,6 +322,7 @@ def from_api_response(cls, response: HITLDetailResponse) -> HITLDetailResponseRe ToTriggerRunner = Annotated[ messages.StartTriggerer | messages.TriggerStateSync + | AssetStateResult | ConnectionResult | VariableResult | VariableKeysResult @@ -349,7 +359,15 @@ def from_api_response(cls, response: HITLDetailResponse) -> HITLDetailResponseRe | GetPreviousTI | GetHITLDetailResponse | UpdateHITLDetail - | MaskSecret, + | MaskSecret + | GetAssetStateByName + | GetAssetStateByUri + | SetAssetStateByName + | SetAssetStateByUri + | DeleteAssetStateByName + | DeleteAssetStateByUri + | ClearAssetStateByName + | ClearAssetStateByUri, Field(discriminator="type"), ] """ @@ -579,6 +597,47 @@ def _handle_request(self, msg: ToTriggerSupervisor, log: FilteringBoundLogger, r resp = HITLDetailResponseResult.from_api_response(response=api_resp) elif isinstance(msg, MaskSecret): handle_mask_secret(msg) + + elif isinstance(msg, GetAssetStateByName): + asset_state = self.client.asset_state.get(msg.key, name=msg.name) + resp = ( + asset_state + if isinstance(asset_state, ErrorResponse) + else AssetStateResult.from_asset_state_response(asset_state) + ) + + elif isinstance(msg, GetAssetStateByUri): + asset_state = self.client.asset_state.get(msg.key, uri=msg.uri) + resp = ( + asset_state + if isinstance(asset_state, ErrorResponse) + else AssetStateResult.from_asset_state_response(asset_state) + ) + + elif isinstance(msg, SetAssetStateByName): + self.client.asset_state.set(msg.key, msg.value, name=msg.name) + resp = OKResponse(ok=True) + + elif isinstance(msg, SetAssetStateByUri): + self.client.asset_state.set(msg.key, msg.value, uri=msg.uri) + resp = OKResponse(ok=True) + + elif isinstance(msg, DeleteAssetStateByName): + self.client.asset_state.delete(msg.key, name=msg.name) + resp = OKResponse(ok=True) + + elif isinstance(msg, DeleteAssetStateByUri): + self.client.asset_state.delete(msg.key, uri=msg.uri) + resp = OKResponse(ok=True) + + elif isinstance(msg, ClearAssetStateByName): + self.client.asset_state.clear(name=msg.name) + resp = OKResponse(ok=True) + + elif isinstance(msg, ClearAssetStateByUri): + self.client.asset_state.clear(uri=msg.uri) + resp = OKResponse(ok=True) + else: raise ValueError(f"Unknown message type {type(msg)}") diff --git a/airflow-core/tests/unit/jobs/test_triggerer_job.py b/airflow-core/tests/unit/jobs/test_triggerer_job.py index 0501783b992d2..2503308867491 100644 --- a/airflow-core/tests/unit/jobs/test_triggerer_job.py +++ b/airflow-core/tests/unit/jobs/test_triggerer_job.py @@ -635,6 +635,161 @@ def test_trigger_logger_fd_closed_when_upload_to_remote_raises(jobless_superviso assert 42 not in jobless_supervisor.running_triggers +class TestTriggerSupervisorAssetState: + """TriggerRunnerSupervisor._handle_request dispatches all asset-state comms messages added in this PR.""" + + asset_name: str = "my_asset" + asset_uri: str = "s3://bucket/key" + + @pytest.fixture + def supervisor(self, jobless_supervisor, mocker): + mocker.patch.object( + type(jobless_supervisor), "client", new_callable=mocker.PropertyMock, return_value=mocker.Mock() + ) + mocker.patch.object(type(jobless_supervisor), "send_msg", mocker.Mock()) + return jobless_supervisor + + def test_get_by_name(self, supervisor): + """ + Validate that GetAssetStateByName calls client.asset_state.get(key, name=...) and respond with an + AssetStateResult carrying the returned value + """ + from airflow.sdk.api.datamodels._generated import AssetStateResponse + from airflow.sdk.execution_time.comms import AssetStateResult, GetAssetStateByName + + supervisor.client.asset_state.get.return_value = AssetStateResponse(value="2026-04-30") + + msg = GetAssetStateByName(name=self.asset_name, key="watermark") + supervisor._handle_request(msg, log=MagicMock(spec=FilteringBoundLogger), req_id=1) + + supervisor.client.asset_state.get.assert_called_once_with("watermark", name=self.asset_name) + sent = supervisor.send_msg.call_args + assert isinstance(sent.args[0], AssetStateResult) + assert sent.args[0].value == "2026-04-30" + assert sent.kwargs["request_id"] == 1 + + def test_get_by_uri(self, supervisor): + """Validate call chain when retrieving using uri""" + from airflow.sdk.api.datamodels._generated import AssetStateResponse + from airflow.sdk.execution_time.comms import AssetStateResult, GetAssetStateByUri + + supervisor.client.asset_state.get.return_value = AssetStateResponse(value="2026-04-30") + + msg = GetAssetStateByUri(uri=self.asset_uri, key="watermark") + supervisor._handle_request(msg, log=MagicMock(spec=FilteringBoundLogger), req_id=2) + + supervisor.client.asset_state.get.assert_called_once_with("watermark", uri=self.asset_uri) + assert isinstance(supervisor.send_msg.call_args.args[0], AssetStateResult) + + def test_get_by_name_propagates_error_response(self, supervisor): + """Validate that retrieving using a missing name should propagate error through call chain""" + from airflow.sdk.exceptions import ErrorType + from airflow.sdk.execution_time.comms import ErrorResponse, GetAssetStateByName + + error = ErrorResponse(error=ErrorType.ASSET_STATE_NOT_FOUND, detail={"key": "missing_key"}) + supervisor.client.asset_state.get.return_value = error + + msg = GetAssetStateByName(name=self.asset_name, key="missing_key") + supervisor._handle_request(msg, log=MagicMock(spec=FilteringBoundLogger), req_id=9) + + assert supervisor.send_msg.call_args.args[0] is error + + def test_get_by_uri_propagates_error_response(self, supervisor): + """Same error-forwarding contract, but for the URI lookup variant""" + from airflow.sdk.exceptions import ErrorType + from airflow.sdk.execution_time.comms import ErrorResponse, GetAssetStateByUri + + error = ErrorResponse(error=ErrorType.ASSET_STATE_NOT_FOUND, detail={"key": "missing_key"}) + supervisor.client.asset_state.get.return_value = error + + msg = GetAssetStateByUri(uri=self.asset_uri, key="missing_key") + supervisor._handle_request(msg, log=MagicMock(spec=FilteringBoundLogger), req_id=10) + + assert supervisor.send_msg.call_args.args[0] is error + + def test_set_by_name(self, supervisor): + """ + Validate that SetAssetStateByName calls client.asset_state.set(key, value, name=...), responds with + OKResponse + """ + from airflow.sdk.execution_time.comms import OKResponse, SetAssetStateByName + + msg = SetAssetStateByName(name=self.asset_name, key="watermark", value="2026-04-30") + supervisor._handle_request(msg, log=MagicMock(spec=FilteringBoundLogger), req_id=3) + + supervisor.client.asset_state.set.assert_called_once_with( + "watermark", "2026-04-30", name=self.asset_name + ) + assert isinstance(supervisor.send_msg.call_args.args[0], OKResponse) + + def test_set_by_uri(self, supervisor): + """ + Validate that SetAssetStateByUri calls client.asset_state.set(key, value, uri=...), responds with + OKResponse + """ + from airflow.sdk.execution_time.comms import OKResponse, SetAssetStateByUri + + msg = SetAssetStateByUri(uri=self.asset_uri, key="watermark", value="2026-04-30") + supervisor._handle_request(msg, log=MagicMock(spec=FilteringBoundLogger), req_id=4) + + supervisor.client.asset_state.set.assert_called_once_with( + "watermark", "2026-04-30", uri=self.asset_uri + ) + assert isinstance(supervisor.send_msg.call_args.args[0], OKResponse) + + def test_delete_by_name(self, supervisor): + """ + Validate that DeleteAssetStateByName calls client.asset_state.delete(key, name=...) responds with + OKResponse + """ + from airflow.sdk.execution_time.comms import DeleteAssetStateByName, OKResponse + + msg = DeleteAssetStateByName(name=self.asset_name, key="watermark") + supervisor._handle_request(msg, log=MagicMock(spec=FilteringBoundLogger), req_id=5) + + supervisor.client.asset_state.delete.assert_called_once_with("watermark", name=self.asset_name) + assert isinstance(supervisor.send_msg.call_args.args[0], OKResponse) + + def test_delete_by_uri(self, supervisor): + """ + Validate that DeleteAssetStateByUri calls client.asset_state.delete(key, uri=...) responds with + OKResponse + """ + from airflow.sdk.execution_time.comms import DeleteAssetStateByUri, OKResponse + + msg = DeleteAssetStateByUri(uri=self.asset_uri, key="watermark") + supervisor._handle_request(msg, log=MagicMock(spec=FilteringBoundLogger), req_id=6) + + supervisor.client.asset_state.delete.assert_called_once_with("watermark", uri=self.asset_uri) + assert isinstance(supervisor.send_msg.call_args.args[0], OKResponse) + + def test_clear_by_name(self, supervisor): + """ + Validate that ClearAssetStateByName calls client.asset_state.clear(name=...) with no key argument, + responds with OKResponse + """ + from airflow.sdk.execution_time.comms import ClearAssetStateByName, OKResponse + + msg = ClearAssetStateByName(name=self.asset_name) + supervisor._handle_request(msg, log=MagicMock(spec=FilteringBoundLogger), req_id=7) + + supervisor.client.asset_state.clear.assert_called_once_with(name=self.asset_name) + assert isinstance(supervisor.send_msg.call_args.args[0], OKResponse) + + def test_clear_by_uri(self, supervisor): + """ + Validate that ClearAssetStateByUri calls client.asset_state.clear(uri=...) with no key argument, + responds with OKResponse + """ + from airflow.sdk.execution_time.comms import ClearAssetStateByUri, OKResponse + + msg = ClearAssetStateByUri(uri=self.asset_uri) + supervisor._handle_request(msg, log=MagicMock(spec=FilteringBoundLogger), req_id=8) + + supervisor.client.asset_state.clear.assert_called_once_with(uri=self.asset_uri) + assert isinstance(supervisor.send_msg.call_args.args[0], OKResponse) + + class TestTriggerRunner: def test_blocked_main_thread_warning_threshold_decode(self) -> None: with conf_vars({("triggerer", "blocked_main_thread_warning_threshold"): "0.5"}): diff --git a/task-sdk/src/airflow/sdk/__init__.py b/task-sdk/src/airflow/sdk/__init__.py index eeae86f1eb3d0..a1ed343ce6a2b 100644 --- a/task-sdk/src/airflow/sdk/__init__.py +++ b/task-sdk/src/airflow/sdk/__init__.py @@ -27,6 +27,7 @@ "AssetAll", "AssetAny", "AssetOrTimeSchedule", + "AssetState", "AssetWatcher", "AsyncCallback", "BaseAsyncOperator", @@ -131,6 +132,7 @@ ) from airflow.sdk.definitions.asset.decorators import asset from airflow.sdk.definitions.asset.metadata import Metadata + from airflow.sdk.definitions.asset.state import AssetState from airflow.sdk.definitions.callback import AsyncCallback, SyncCallback from airflow.sdk.definitions.connection import Connection from airflow.sdk.definitions.context import Context, get_current_context, get_parsing_context @@ -195,6 +197,7 @@ "AssetAll": ".definitions.asset", "AssetAny": ".definitions.asset", "AssetOrTimeSchedule": ".definitions.timetables.assets", + "AssetState": ".definitions.asset.state", "AssetWatcher": ".definitions.asset", "AsyncCallback": ".definitions.callback", "BaseAsyncOperator": ".bases.operator", diff --git a/task-sdk/src/airflow/sdk/__init__.pyi b/task-sdk/src/airflow/sdk/__init__.pyi index a947e3676df06..533a4af9fa2d4 100644 --- a/task-sdk/src/airflow/sdk/__init__.pyi +++ b/task-sdk/src/airflow/sdk/__init__.pyi @@ -52,6 +52,7 @@ from airflow.sdk.definitions.asset import ( from airflow.sdk.definitions.asset.access_control import AssetAccessControl as AssetAccessControl from airflow.sdk.definitions.asset.decorators import asset as asset from airflow.sdk.definitions.asset.metadata import Metadata as Metadata +from airflow.sdk.definitions.asset.state import AssetState as AssetState from airflow.sdk.definitions.connection import Connection as Connection from airflow.sdk.definitions.context import ( Context as Context, @@ -118,6 +119,7 @@ __all__ = [ "AssetAll", "AssetAny", "AssetOrTimeSchedule", + "AssetState", "AssetWatcher", "BaseAsyncOperator", "BaseBranchOperator", diff --git a/task-sdk/src/airflow/sdk/definitions/asset/state.py b/task-sdk/src/airflow/sdk/definitions/asset/state.py new file mode 100644 index 0000000000000..9647c4f575e8e --- /dev/null +++ b/task-sdk/src/airflow/sdk/definitions/asset/state.py @@ -0,0 +1,42 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +from airflow.sdk.execution_time.context import AssetStateAccessor + + +class AssetState(AssetStateAccessor): + """ + Access the state store for a single asset from anywhere a SUPERVISOR_COMMS + channel is available (task, callback, or trigger). + + This is the equivalent of subscripting ``context['asset_state'][asset]`` + inside a task, but usable from contexts where ``context`` is not bound - + most notably from inside a :class:`BaseEventTrigger`. + + Identify the asset by either ``name`` or ``uri`` (exactly one is required):: + + from airflow.sdk import AssetState + + asset_state = AssetState(name="my_asset") + watermark = asset_state.get("watermark") + asset_state.set("watermark", "2026-01-01") + """ + + def __init__(self, *, name: str | None = None, uri: str | None = None) -> None: + super().__init__(name=name, uri=uri) diff --git a/task-sdk/tests/task_sdk/definitions/test_asset_state.py b/task-sdk/tests/task_sdk/definitions/test_asset_state.py new file mode 100644 index 0000000000000..f5a8822461823 --- /dev/null +++ b/task-sdk/tests/task_sdk/definitions/test_asset_state.py @@ -0,0 +1,174 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +import pytest +from pydantic import ValidationError + +from airflow.sdk import AssetState +from airflow.sdk.definitions.asset.state import AssetState as DirectAssetState +from airflow.sdk.exceptions import AirflowRuntimeError, ErrorType +from airflow.sdk.execution_time.comms import ( + AssetStateResult, + ClearAssetStateByName, + ClearAssetStateByUri, + DeleteAssetStateByName, + DeleteAssetStateByUri, + ErrorResponse, + GetAssetStateByName, + GetAssetStateByUri, + OKResponse, + SetAssetStateByName, + SetAssetStateByUri, +) +from airflow.sdk.execution_time.context import AssetStateAccessor + + +class TestAssetState: + """Validate the public AssetState SDK interface.""" + + asset_name: str = "my_asset" + asset_uri: str = "s3://bucket/key" + + def test_lazy_import_from_airflow_sdk(self): + """Validate the lazy __init__.py alias resolves to the real class""" + assert AssetState is DirectAssetState + + def test_is_asset_state_accessor_subclass(self): + """Validate that AssetState inherits all AssetStateAccess logic""" + assert issubclass(AssetState, AssetStateAccessor) + assert isinstance(AssetState(name=self.asset_name), AssetStateAccessor) + + def test_requires_name_or_uri(self): + """Validate that constructing without either identifier must fail fast at init time""" + with pytest.raises(ValueError, match="Either `name` or `uri` must be provided"): + AssetState() + + def test_set_fails_on_non_string_key(self, mock_supervisor_comms): + """Validate that set(key, value) where isinstance(key, str) false raises a ValidationError""" + with pytest.raises(ValidationError): + AssetState(name=self.asset_name).set(123, "some_value") # type: ignore[arg-type] + + mock_supervisor_comms.send.assert_not_called() + + def test_set_fails_on_non_string_value(self, mock_supervisor_comms): + """Validate that set(key, value) where isinstance(value, str) false raises a ValidationError""" + with pytest.raises(ValidationError): + AssetState(name=self.asset_name).set("watermark", 12345) # type: ignore[arg-type] + + mock_supervisor_comms.send.assert_not_called() + + def test_get_by_name_sends_supervisor_message(self, mock_supervisor_comms): + """Validate that get() with name=... dispatches GetAssetStateByName and unwraps the value""" + mock_supervisor_comms.send.return_value = AssetStateResult(value="2026-04-30T00:00:00Z") + + result = AssetState(name=self.asset_name).get("watermark") + + assert result == "2026-04-30T00:00:00Z" + mock_supervisor_comms.send.assert_called_once_with( + GetAssetStateByName(name=self.asset_name, key="watermark") + ) + + def test_get_by_uri_sends_supervisor_message(self, mock_supervisor_comms): + """Validate that get() with uri=... dispatches GetAssetStateByUri and unwraps the value""" + mock_supervisor_comms.send.return_value = AssetStateResult(value="2026-04-30T00:00:00Z") + + result = AssetState(uri=self.asset_uri).get("watermark") + + assert result == "2026-04-30T00:00:00Z" + mock_supervisor_comms.send.assert_called_once_with( + GetAssetStateByUri(uri=self.asset_uri, key="watermark") + ) + + def test_get_returns_none_when_key_not_found(self, mock_supervisor_comms): + """ + Validate that a 404-style ASSET_STATE_NOT_FOUND response must silently return None rather than + raising, matching the contract documented in AssetStateAccessor + """ + mock_supervisor_comms.send.return_value = ErrorResponse( + error=ErrorType.ASSET_STATE_NOT_FOUND, detail={"key": "missing_key"} + ) + + result = AssetState(name=self.asset_name).get("missing_key") + + assert result is None + + def test_get_raises_on_generic_error(self, mock_supervisor_comms): + """Validate that any error other than ASSET_STATE_NOT_FOUND must propagate as AirflowRuntimeError""" + mock_supervisor_comms.send.return_value = ErrorResponse( + error=ErrorType.GENERIC_ERROR, detail={"message": "server error"} + ) + + with pytest.raises(AirflowRuntimeError): + AssetState(name=self.asset_name).get("some_key") + + def test_set_by_name_sends_supervisor_message(self, mock_supervisor_comms): + """Validate that set() with name= dispatches SetAssetStateByName""" + mock_supervisor_comms.send.return_value = OKResponse(ok=True) + + AssetState(name=self.asset_name).set("watermark", "2026-04-30T00:00:00Z") + + mock_supervisor_comms.send.assert_called_once_with( + SetAssetStateByName(name=self.asset_name, key="watermark", value="2026-04-30T00:00:00Z") + ) + + def test_set_by_uri_sends_supervisor_message(self, mock_supervisor_comms): + """Validate that set() with uri= dispatches SetAssetStateByUri""" + mock_supervisor_comms.send.return_value = OKResponse(ok=True) + + AssetState(uri=self.asset_uri).set("watermark", "2026-04-30T00:00:00Z") + + mock_supervisor_comms.send.assert_called_once_with( + SetAssetStateByUri(uri=self.asset_uri, key="watermark", value="2026-04-30T00:00:00Z") + ) + + def test_delete_by_name_sends_supervisor_message(self, mock_supervisor_comms): + """Validate that delete() with name= dispatches DeleteAssetStateByName""" + mock_supervisor_comms.send.return_value = OKResponse(ok=True) + + AssetState(name=self.asset_name).delete("watermark") + + mock_supervisor_comms.send.assert_called_once_with( + DeleteAssetStateByName(name=self.asset_name, key="watermark") + ) + + def test_delete_by_uri_sends_supervisor_message(self, mock_supervisor_comms): + """Validate that delete() with uri= dispatches DeleteAssetStateByUri""" + mock_supervisor_comms.send.return_value = OKResponse(ok=True) + + AssetState(uri=self.asset_uri).delete("watermark") + + mock_supervisor_comms.send.assert_called_once_with( + DeleteAssetStateByUri(uri=self.asset_uri, key="watermark") + ) + + def test_clear_by_name_sends_supervisor_message(self, mock_supervisor_comms): + """Validate that clear() with name= dispatches ClearAssetStateByName (no key argument)""" + mock_supervisor_comms.send.return_value = OKResponse(ok=True) + + AssetState(name=self.asset_name).clear() + + mock_supervisor_comms.send.assert_called_once_with(ClearAssetStateByName(name=self.asset_name)) + + def test_clear_by_uri_sends_supervisor_message(self, mock_supervisor_comms): + """Validate that clear() with uri= dispatches ClearAssetStateByUri (no key argument)""" + mock_supervisor_comms.send.return_value = OKResponse(ok=True) + + AssetState(uri=self.asset_uri).clear() + + mock_supervisor_comms.send.assert_called_once_with(ClearAssetStateByUri(uri=self.asset_uri))