From 6d1bd766cd0155761a6f647e86fabf342d087526 Mon Sep 17 00:00:00 2001 From: Jake Roach <116606359+jroachgolf84@users.noreply.github.com> Date: Wed, 20 May 2026 11:35:09 -0400 Subject: [PATCH 1/2] feature/issue-67200: Adding AssetState Task SDK mechanism --- .../src/airflow/jobs/triggerer_job_runner.py | 61 ++++++++++- .../tests/unit/jobs/test_triggerer_job.py | 100 ++++++++++++++++++ task-sdk/src/airflow/sdk/__init__.py | 3 + task-sdk/src/airflow/sdk/__init__.pyi | 2 + .../airflow/sdk/definitions/asset/state.py | 42 ++++++++ .../task_sdk/definitions/test_asset_state.py | 77 ++++++++++++++ 6 files changed, 284 insertions(+), 1 deletion(-) create mode 100644 task-sdk/src/airflow/sdk/definitions/asset/state.py create mode 100644 task-sdk/tests/task_sdk/definitions/test_asset_state.py diff --git a/airflow-core/src/airflow/jobs/triggerer_job_runner.py b/airflow-core/src/airflow/jobs/triggerer_job_runner.py index 2b4db481c266e..6d7212522d057 100644 --- a/airflow-core/src/airflow/jobs/triggerer_job_runner.py +++ b/airflow-core/src/airflow/jobs/triggerer_job_runner.py @@ -59,13 +59,20 @@ from airflow.observability.metrics import stats_utils from airflow.sdk.api.datamodels._generated import HITLDetailResponse from airflow.sdk.execution_time.comms import ( + AssetStateResult, + ClearAssetStateByName, + ClearAssetStateByUri, CommsDecoder, ConnectionResult, DagRunStateResult, + DeleteAssetStateByName, + DeleteAssetStateByUri, DeleteVariable, DeleteXCom, DRCount, ErrorResponse, + GetAssetStateByName, + GetAssetStateByUri, GetConnection, GetDagRunState, GetDRCount, @@ -79,6 +86,8 @@ MaskSecret, OKResponse, PutVariable, + SetAssetStateByName, + SetAssetStateByUri, SetXCom, TaskStatesResult, TICount, @@ -303,6 +312,7 @@ def from_api_response(cls, response: HITLDetailResponse) -> HITLDetailResponseRe ToTriggerRunner = Annotated[ messages.StartTriggerer | messages.TriggerStateSync + | AssetStateResult | ConnectionResult | VariableResult | VariableKeysResult @@ -339,7 +349,15 @@ def from_api_response(cls, response: HITLDetailResponse) -> HITLDetailResponseRe | GetPreviousTI | GetHITLDetailResponse | UpdateHITLDetail - | MaskSecret, + | MaskSecret + | GetAssetStateByName + | GetAssetStateByUri + | SetAssetStateByName + | SetAssetStateByUri + | DeleteAssetStateByName + | DeleteAssetStateByUri + | ClearAssetStateByName + | ClearAssetStateByUri, Field(discriminator="type"), ] """ @@ -620,6 +638,47 @@ def _handle_request(self, msg: ToTriggerSupervisor, log: FilteringBoundLogger, r resp = HITLDetailResponseResult.from_api_response(response=api_resp) elif isinstance(msg, MaskSecret): handle_mask_secret(msg) + + elif isinstance(msg, GetAssetStateByName): + asset_state = self.client.asset_state.get(msg.key, name=msg.name) + resp = ( + asset_state + if isinstance(asset_state, ErrorResponse) + else AssetStateResult.from_asset_state_response(asset_state) + ) + + elif isinstance(msg, GetAssetStateByUri): + asset_state = self.client.asset_state.get(msg.key, uri=msg.uri) + resp = ( + asset_state + if isinstance(asset_state, ErrorResponse) + else AssetStateResult.from_asset_state_response(asset_state) + ) + + elif isinstance(msg, SetAssetStateByName): + self.client.asset_state.set(msg.key, msg.value, name=msg.name) + resp = OKResponse(ok=True) + + elif isinstance(msg, SetAssetStateByUri): + self.client.asset_state.set(msg.key, msg.value, uri=msg.uri) + resp = OKResponse(ok=True) + + elif isinstance(msg, DeleteAssetStateByName): + self.client.asset_state.delete(msg.key, name=msg.name) + resp = OKResponse(ok=True) + + elif isinstance(msg, DeleteAssetStateByUri): + self.client.asset_state.delete(msg.key, uri=msg.uri) + resp = OKResponse(ok=True) + + elif isinstance(msg, ClearAssetStateByName): + self.client.asset_state.clear(name=msg.name) + resp = OKResponse(ok=True) + + elif isinstance(msg, ClearAssetStateByUri): + self.client.asset_state.clear(uri=msg.uri) + resp = OKResponse(ok=True) + else: raise ValueError(f"Unknown message type {type(msg)}") diff --git a/airflow-core/tests/unit/jobs/test_triggerer_job.py b/airflow-core/tests/unit/jobs/test_triggerer_job.py index 0501783b992d2..f51581d834ac6 100644 --- a/airflow-core/tests/unit/jobs/test_triggerer_job.py +++ b/airflow-core/tests/unit/jobs/test_triggerer_job.py @@ -635,6 +635,106 @@ def test_trigger_logger_fd_closed_when_upload_to_remote_raises(jobless_superviso assert 42 not in jobless_supervisor.running_triggers +class TestTriggerSupervisorAssetState: + """Verify the trigger supervisor handles the asset-state comms messages it now accepts.""" + + asset_name: str = "my_asset" + asset_uri: str = "s3://bucket/key" + + @pytest.fixture + def supervisor(self, jobless_supervisor, mocker): + mocker.patch.object( + type(jobless_supervisor), "client", new_callable=mocker.PropertyMock, return_value=mocker.Mock() + ) + mocker.patch.object(type(jobless_supervisor), "send_msg", mocker.Mock()) + return jobless_supervisor + + def test_get_by_name(self, supervisor): + from airflow.sdk.api.datamodels._generated import AssetStateResponse + from airflow.sdk.execution_time.comms import AssetStateResult, GetAssetStateByName + + supervisor.client.asset_state.get.return_value = AssetStateResponse(value="2026-04-30") + + msg = GetAssetStateByName(name=self.asset_name, key="watermark") + supervisor._handle_request(msg, log=MagicMock(spec=FilteringBoundLogger), req_id=1) + + supervisor.client.asset_state.get.assert_called_once_with("watermark", name=self.asset_name) + sent = supervisor.send_msg.call_args + assert isinstance(sent.args[0], AssetStateResult) + assert sent.args[0].value == "2026-04-30" + assert sent.kwargs["request_id"] == 1 + + def test_get_by_uri(self, supervisor): + from airflow.sdk.api.datamodels._generated import AssetStateResponse + from airflow.sdk.execution_time.comms import AssetStateResult, GetAssetStateByUri + + supervisor.client.asset_state.get.return_value = AssetStateResponse(value="2026-04-30") + + msg = GetAssetStateByUri(uri=self.asset_uri, key="watermark") + supervisor._handle_request(msg, log=MagicMock(spec=FilteringBoundLogger), req_id=2) + + supervisor.client.asset_state.get.assert_called_once_with("watermark", uri=self.asset_uri) + assert isinstance(supervisor.send_msg.call_args.args[0], AssetStateResult) + + def test_set_by_name(self, supervisor): + from airflow.sdk.execution_time.comms import OKResponse, SetAssetStateByName + + msg = SetAssetStateByName(name=self.asset_name, key="watermark", value="2026-04-30") + supervisor._handle_request(msg, log=MagicMock(spec=FilteringBoundLogger), req_id=3) + + supervisor.client.asset_state.set.assert_called_once_with( + "watermark", "2026-04-30", name=self.asset_name + ) + assert isinstance(supervisor.send_msg.call_args.args[0], OKResponse) + + def test_set_by_uri(self, supervisor): + from airflow.sdk.execution_time.comms import OKResponse, SetAssetStateByUri + + msg = SetAssetStateByUri(uri=self.asset_uri, key="watermark", value="2026-04-30") + supervisor._handle_request(msg, log=MagicMock(spec=FilteringBoundLogger), req_id=4) + + supervisor.client.asset_state.set.assert_called_once_with( + "watermark", "2026-04-30", uri=self.asset_uri + ) + assert isinstance(supervisor.send_msg.call_args.args[0], OKResponse) + + def test_delete_by_name(self, supervisor): + from airflow.sdk.execution_time.comms import DeleteAssetStateByName, OKResponse + + msg = DeleteAssetStateByName(name=self.asset_name, key="watermark") + supervisor._handle_request(msg, log=MagicMock(spec=FilteringBoundLogger), req_id=5) + + supervisor.client.asset_state.delete.assert_called_once_with("watermark", name=self.asset_name) + assert isinstance(supervisor.send_msg.call_args.args[0], OKResponse) + + def test_delete_by_uri(self, supervisor): + from airflow.sdk.execution_time.comms import DeleteAssetStateByUri, OKResponse + + msg = DeleteAssetStateByUri(uri=self.asset_uri, key="watermark") + supervisor._handle_request(msg, log=MagicMock(spec=FilteringBoundLogger), req_id=6) + + supervisor.client.asset_state.delete.assert_called_once_with("watermark", uri=self.asset_uri) + assert isinstance(supervisor.send_msg.call_args.args[0], OKResponse) + + def test_clear_by_name(self, supervisor): + from airflow.sdk.execution_time.comms import ClearAssetStateByName, OKResponse + + msg = ClearAssetStateByName(name=self.asset_name) + supervisor._handle_request(msg, log=MagicMock(spec=FilteringBoundLogger), req_id=7) + + supervisor.client.asset_state.clear.assert_called_once_with(name=self.asset_name) + assert isinstance(supervisor.send_msg.call_args.args[0], OKResponse) + + def test_clear_by_uri(self, supervisor): + from airflow.sdk.execution_time.comms import ClearAssetStateByUri, OKResponse + + msg = ClearAssetStateByUri(uri=self.asset_uri) + supervisor._handle_request(msg, log=MagicMock(spec=FilteringBoundLogger), req_id=8) + + supervisor.client.asset_state.clear.assert_called_once_with(uri=self.asset_uri) + assert isinstance(supervisor.send_msg.call_args.args[0], OKResponse) + + class TestTriggerRunner: def test_blocked_main_thread_warning_threshold_decode(self) -> None: with conf_vars({("triggerer", "blocked_main_thread_warning_threshold"): "0.5"}): diff --git a/task-sdk/src/airflow/sdk/__init__.py b/task-sdk/src/airflow/sdk/__init__.py index 05ececc29562d..8bce2fd95ee22 100644 --- a/task-sdk/src/airflow/sdk/__init__.py +++ b/task-sdk/src/airflow/sdk/__init__.py @@ -26,6 +26,7 @@ "AssetAll", "AssetAny", "AssetOrTimeSchedule", + "AssetState", "AssetWatcher", "AsyncCallback", "BaseAsyncOperator", @@ -122,6 +123,7 @@ from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetAll, AssetAny, AssetWatcher from airflow.sdk.definitions.asset.decorators import asset from airflow.sdk.definitions.asset.metadata import Metadata + from airflow.sdk.definitions.asset.state import AssetState from airflow.sdk.definitions.callback import AsyncCallback, SyncCallback from airflow.sdk.definitions.connection import Connection from airflow.sdk.definitions.context import Context, get_current_context, get_parsing_context @@ -184,6 +186,7 @@ "AssetAll": ".definitions.asset", "AssetAny": ".definitions.asset", "AssetOrTimeSchedule": ".definitions.timetables.assets", + "AssetState": ".definitions.asset.state", "AssetWatcher": ".definitions.asset", "AsyncCallback": ".definitions.callback", "BaseAsyncOperator": ".bases.operator", diff --git a/task-sdk/src/airflow/sdk/__init__.pyi b/task-sdk/src/airflow/sdk/__init__.pyi index 7e6d211674eba..b6299c858300a 100644 --- a/task-sdk/src/airflow/sdk/__init__.pyi +++ b/task-sdk/src/airflow/sdk/__init__.pyi @@ -51,6 +51,7 @@ from airflow.sdk.definitions.asset import ( ) from airflow.sdk.definitions.asset.decorators import asset as asset from airflow.sdk.definitions.asset.metadata import Metadata as Metadata +from airflow.sdk.definitions.asset.state import AssetState as AssetState from airflow.sdk.definitions.connection import Connection as Connection from airflow.sdk.definitions.context import ( Context as Context, @@ -115,6 +116,7 @@ __all__ = [ "AssetAll", "AssetAny", "AssetOrTimeSchedule", + "AssetState", "AssetWatcher", "BaseAsyncOperator", "BaseBranchOperator", diff --git a/task-sdk/src/airflow/sdk/definitions/asset/state.py b/task-sdk/src/airflow/sdk/definitions/asset/state.py new file mode 100644 index 0000000000000..9647c4f575e8e --- /dev/null +++ b/task-sdk/src/airflow/sdk/definitions/asset/state.py @@ -0,0 +1,42 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +from airflow.sdk.execution_time.context import AssetStateAccessor + + +class AssetState(AssetStateAccessor): + """ + Access the state store for a single asset from anywhere a SUPERVISOR_COMMS + channel is available (task, callback, or trigger). + + This is the equivalent of subscripting ``context['asset_state'][asset]`` + inside a task, but usable from contexts where ``context`` is not bound - + most notably from inside a :class:`BaseEventTrigger`. + + Identify the asset by either ``name`` or ``uri`` (exactly one is required):: + + from airflow.sdk import AssetState + + asset_state = AssetState(name="my_asset") + watermark = asset_state.get("watermark") + asset_state.set("watermark", "2026-01-01") + """ + + def __init__(self, *, name: str | None = None, uri: str | None = None) -> None: + super().__init__(name=name, uri=uri) diff --git a/task-sdk/tests/task_sdk/definitions/test_asset_state.py b/task-sdk/tests/task_sdk/definitions/test_asset_state.py new file mode 100644 index 0000000000000..53e2e838a3841 --- /dev/null +++ b/task-sdk/tests/task_sdk/definitions/test_asset_state.py @@ -0,0 +1,77 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +import pytest + +from airflow.sdk import AssetState +from airflow.sdk.definitions.asset.state import AssetState as DirectAssetState +from airflow.sdk.execution_time.comms import ( + AssetStateResult, + GetAssetStateByName, + GetAssetStateByUri, + OKResponse, + SetAssetStateByName, +) +from airflow.sdk.execution_time.context import AssetStateAccessor + + +class TestAssetState: + """Validate the AssetState SDK interface.""" + asset_name: str = "my_asset" + asset_uri: str = "s3://bucket/key" + + def test_lazy_import_from_airflow_sdk(self): + assert AssetState is DirectAssetState + + def test_is_asset_state_accessor_subclass(self): + assert issubclass(AssetState, AssetStateAccessor) + assert isinstance(AssetState(name=self.asset_name), AssetStateAccessor) + + def test_requires_name_or_uri(self): + with pytest.raises(ValueError, match="Either `name` or `uri` must be provided"): + AssetState() + + def test_get_by_name_sends_supervisor_message(self, mock_supervisor_comms): + mock_supervisor_comms.send.return_value = AssetStateResult(value="2026-04-30T00:00:00Z") + + result = AssetState(name=self.asset_name).get("watermark") + + assert result == "2026-04-30T00:00:00Z" + mock_supervisor_comms.send.assert_called_once_with( + GetAssetStateByName(name=self.asset_name, key="watermark") + ) + + def test_get_by_uri_sends_supervisor_message(self, mock_supervisor_comms): + mock_supervisor_comms.send.return_value = AssetStateResult(value="2026-04-30T00:00:00Z") + + result = AssetState(uri=self.asset_uri).get("watermark") + + assert result == "2026-04-30T00:00:00Z" + mock_supervisor_comms.send.assert_called_once_with( + GetAssetStateByUri(uri=self.asset_uri, key="watermark") + ) + + def test_set_by_name_sends_supervisor_message(self, mock_supervisor_comms): + mock_supervisor_comms.send.return_value = OKResponse(ok=True) + + AssetState(name=self.asset_name).set("watermark", "2026-04-30T00:00:00Z") + + mock_supervisor_comms.send.assert_called_once_with( + SetAssetStateByName(name=self.asset_name, key="watermark", value="2026-04-30T00:00:00Z") + ) From 5414d32bc0250da226effd9caca9b36efac512d3 Mon Sep 17 00:00:00 2001 From: Jake Roach <116606359+jroachgolf84@users.noreply.github.com> Date: Wed, 20 May 2026 12:55:19 -0400 Subject: [PATCH 2/2] feature/issue-67200: Enhancing unit test --- .../tests/unit/jobs/test_triggerer_job.py | 57 ++++++++++- .../task_sdk/definitions/test_asset_state.py | 99 ++++++++++++++++++- 2 files changed, 154 insertions(+), 2 deletions(-) diff --git a/airflow-core/tests/unit/jobs/test_triggerer_job.py b/airflow-core/tests/unit/jobs/test_triggerer_job.py index f51581d834ac6..2503308867491 100644 --- a/airflow-core/tests/unit/jobs/test_triggerer_job.py +++ b/airflow-core/tests/unit/jobs/test_triggerer_job.py @@ -636,7 +636,7 @@ def test_trigger_logger_fd_closed_when_upload_to_remote_raises(jobless_superviso class TestTriggerSupervisorAssetState: - """Verify the trigger supervisor handles the asset-state comms messages it now accepts.""" + """TriggerRunnerSupervisor._handle_request dispatches all asset-state comms messages added in this PR.""" asset_name: str = "my_asset" asset_uri: str = "s3://bucket/key" @@ -650,6 +650,10 @@ def supervisor(self, jobless_supervisor, mocker): return jobless_supervisor def test_get_by_name(self, supervisor): + """ + Validate that GetAssetStateByName calls client.asset_state.get(key, name=...) and respond with an + AssetStateResult carrying the returned value + """ from airflow.sdk.api.datamodels._generated import AssetStateResponse from airflow.sdk.execution_time.comms import AssetStateResult, GetAssetStateByName @@ -665,6 +669,7 @@ def test_get_by_name(self, supervisor): assert sent.kwargs["request_id"] == 1 def test_get_by_uri(self, supervisor): + """Validate call chain when retrieving using uri""" from airflow.sdk.api.datamodels._generated import AssetStateResponse from airflow.sdk.execution_time.comms import AssetStateResult, GetAssetStateByUri @@ -676,7 +681,37 @@ def test_get_by_uri(self, supervisor): supervisor.client.asset_state.get.assert_called_once_with("watermark", uri=self.asset_uri) assert isinstance(supervisor.send_msg.call_args.args[0], AssetStateResult) + def test_get_by_name_propagates_error_response(self, supervisor): + """Validate that retrieving using a missing name should propagate error through call chain""" + from airflow.sdk.exceptions import ErrorType + from airflow.sdk.execution_time.comms import ErrorResponse, GetAssetStateByName + + error = ErrorResponse(error=ErrorType.ASSET_STATE_NOT_FOUND, detail={"key": "missing_key"}) + supervisor.client.asset_state.get.return_value = error + + msg = GetAssetStateByName(name=self.asset_name, key="missing_key") + supervisor._handle_request(msg, log=MagicMock(spec=FilteringBoundLogger), req_id=9) + + assert supervisor.send_msg.call_args.args[0] is error + + def test_get_by_uri_propagates_error_response(self, supervisor): + """Same error-forwarding contract, but for the URI lookup variant""" + from airflow.sdk.exceptions import ErrorType + from airflow.sdk.execution_time.comms import ErrorResponse, GetAssetStateByUri + + error = ErrorResponse(error=ErrorType.ASSET_STATE_NOT_FOUND, detail={"key": "missing_key"}) + supervisor.client.asset_state.get.return_value = error + + msg = GetAssetStateByUri(uri=self.asset_uri, key="missing_key") + supervisor._handle_request(msg, log=MagicMock(spec=FilteringBoundLogger), req_id=10) + + assert supervisor.send_msg.call_args.args[0] is error + def test_set_by_name(self, supervisor): + """ + Validate that SetAssetStateByName calls client.asset_state.set(key, value, name=...), responds with + OKResponse + """ from airflow.sdk.execution_time.comms import OKResponse, SetAssetStateByName msg = SetAssetStateByName(name=self.asset_name, key="watermark", value="2026-04-30") @@ -688,6 +723,10 @@ def test_set_by_name(self, supervisor): assert isinstance(supervisor.send_msg.call_args.args[0], OKResponse) def test_set_by_uri(self, supervisor): + """ + Validate that SetAssetStateByUri calls client.asset_state.set(key, value, uri=...), responds with + OKResponse + """ from airflow.sdk.execution_time.comms import OKResponse, SetAssetStateByUri msg = SetAssetStateByUri(uri=self.asset_uri, key="watermark", value="2026-04-30") @@ -699,6 +738,10 @@ def test_set_by_uri(self, supervisor): assert isinstance(supervisor.send_msg.call_args.args[0], OKResponse) def test_delete_by_name(self, supervisor): + """ + Validate that DeleteAssetStateByName calls client.asset_state.delete(key, name=...) responds with + OKResponse + """ from airflow.sdk.execution_time.comms import DeleteAssetStateByName, OKResponse msg = DeleteAssetStateByName(name=self.asset_name, key="watermark") @@ -708,6 +751,10 @@ def test_delete_by_name(self, supervisor): assert isinstance(supervisor.send_msg.call_args.args[0], OKResponse) def test_delete_by_uri(self, supervisor): + """ + Validate that DeleteAssetStateByUri calls client.asset_state.delete(key, uri=...) responds with + OKResponse + """ from airflow.sdk.execution_time.comms import DeleteAssetStateByUri, OKResponse msg = DeleteAssetStateByUri(uri=self.asset_uri, key="watermark") @@ -717,6 +764,10 @@ def test_delete_by_uri(self, supervisor): assert isinstance(supervisor.send_msg.call_args.args[0], OKResponse) def test_clear_by_name(self, supervisor): + """ + Validate that ClearAssetStateByName calls client.asset_state.clear(name=...) with no key argument, + responds with OKResponse + """ from airflow.sdk.execution_time.comms import ClearAssetStateByName, OKResponse msg = ClearAssetStateByName(name=self.asset_name) @@ -726,6 +777,10 @@ def test_clear_by_name(self, supervisor): assert isinstance(supervisor.send_msg.call_args.args[0], OKResponse) def test_clear_by_uri(self, supervisor): + """ + Validate that ClearAssetStateByUri calls client.asset_state.clear(uri=...) with no key argument, + responds with OKResponse + """ from airflow.sdk.execution_time.comms import ClearAssetStateByUri, OKResponse msg = ClearAssetStateByUri(uri=self.asset_uri) diff --git a/task-sdk/tests/task_sdk/definitions/test_asset_state.py b/task-sdk/tests/task_sdk/definitions/test_asset_state.py index 53e2e838a3841..f5a8822461823 100644 --- a/task-sdk/tests/task_sdk/definitions/test_asset_state.py +++ b/task-sdk/tests/task_sdk/definitions/test_asset_state.py @@ -18,36 +18,63 @@ from __future__ import annotations import pytest +from pydantic import ValidationError from airflow.sdk import AssetState from airflow.sdk.definitions.asset.state import AssetState as DirectAssetState +from airflow.sdk.exceptions import AirflowRuntimeError, ErrorType from airflow.sdk.execution_time.comms import ( AssetStateResult, + ClearAssetStateByName, + ClearAssetStateByUri, + DeleteAssetStateByName, + DeleteAssetStateByUri, + ErrorResponse, GetAssetStateByName, GetAssetStateByUri, OKResponse, SetAssetStateByName, + SetAssetStateByUri, ) from airflow.sdk.execution_time.context import AssetStateAccessor class TestAssetState: - """Validate the AssetState SDK interface.""" + """Validate the public AssetState SDK interface.""" + asset_name: str = "my_asset" asset_uri: str = "s3://bucket/key" def test_lazy_import_from_airflow_sdk(self): + """Validate the lazy __init__.py alias resolves to the real class""" assert AssetState is DirectAssetState def test_is_asset_state_accessor_subclass(self): + """Validate that AssetState inherits all AssetStateAccess logic""" assert issubclass(AssetState, AssetStateAccessor) assert isinstance(AssetState(name=self.asset_name), AssetStateAccessor) def test_requires_name_or_uri(self): + """Validate that constructing without either identifier must fail fast at init time""" with pytest.raises(ValueError, match="Either `name` or `uri` must be provided"): AssetState() + def test_set_fails_on_non_string_key(self, mock_supervisor_comms): + """Validate that set(key, value) where isinstance(key, str) false raises a ValidationError""" + with pytest.raises(ValidationError): + AssetState(name=self.asset_name).set(123, "some_value") # type: ignore[arg-type] + + mock_supervisor_comms.send.assert_not_called() + + def test_set_fails_on_non_string_value(self, mock_supervisor_comms): + """Validate that set(key, value) where isinstance(value, str) false raises a ValidationError""" + with pytest.raises(ValidationError): + AssetState(name=self.asset_name).set("watermark", 12345) # type: ignore[arg-type] + + mock_supervisor_comms.send.assert_not_called() + def test_get_by_name_sends_supervisor_message(self, mock_supervisor_comms): + """Validate that get() with name=... dispatches GetAssetStateByName and unwraps the value""" mock_supervisor_comms.send.return_value = AssetStateResult(value="2026-04-30T00:00:00Z") result = AssetState(name=self.asset_name).get("watermark") @@ -58,6 +85,7 @@ def test_get_by_name_sends_supervisor_message(self, mock_supervisor_comms): ) def test_get_by_uri_sends_supervisor_message(self, mock_supervisor_comms): + """Validate that get() with uri=... dispatches GetAssetStateByUri and unwraps the value""" mock_supervisor_comms.send.return_value = AssetStateResult(value="2026-04-30T00:00:00Z") result = AssetState(uri=self.asset_uri).get("watermark") @@ -67,7 +95,30 @@ def test_get_by_uri_sends_supervisor_message(self, mock_supervisor_comms): GetAssetStateByUri(uri=self.asset_uri, key="watermark") ) + def test_get_returns_none_when_key_not_found(self, mock_supervisor_comms): + """ + Validate that a 404-style ASSET_STATE_NOT_FOUND response must silently return None rather than + raising, matching the contract documented in AssetStateAccessor + """ + mock_supervisor_comms.send.return_value = ErrorResponse( + error=ErrorType.ASSET_STATE_NOT_FOUND, detail={"key": "missing_key"} + ) + + result = AssetState(name=self.asset_name).get("missing_key") + + assert result is None + + def test_get_raises_on_generic_error(self, mock_supervisor_comms): + """Validate that any error other than ASSET_STATE_NOT_FOUND must propagate as AirflowRuntimeError""" + mock_supervisor_comms.send.return_value = ErrorResponse( + error=ErrorType.GENERIC_ERROR, detail={"message": "server error"} + ) + + with pytest.raises(AirflowRuntimeError): + AssetState(name=self.asset_name).get("some_key") + def test_set_by_name_sends_supervisor_message(self, mock_supervisor_comms): + """Validate that set() with name= dispatches SetAssetStateByName""" mock_supervisor_comms.send.return_value = OKResponse(ok=True) AssetState(name=self.asset_name).set("watermark", "2026-04-30T00:00:00Z") @@ -75,3 +126,49 @@ def test_set_by_name_sends_supervisor_message(self, mock_supervisor_comms): mock_supervisor_comms.send.assert_called_once_with( SetAssetStateByName(name=self.asset_name, key="watermark", value="2026-04-30T00:00:00Z") ) + + def test_set_by_uri_sends_supervisor_message(self, mock_supervisor_comms): + """Validate that set() with uri= dispatches SetAssetStateByUri""" + mock_supervisor_comms.send.return_value = OKResponse(ok=True) + + AssetState(uri=self.asset_uri).set("watermark", "2026-04-30T00:00:00Z") + + mock_supervisor_comms.send.assert_called_once_with( + SetAssetStateByUri(uri=self.asset_uri, key="watermark", value="2026-04-30T00:00:00Z") + ) + + def test_delete_by_name_sends_supervisor_message(self, mock_supervisor_comms): + """Validate that delete() with name= dispatches DeleteAssetStateByName""" + mock_supervisor_comms.send.return_value = OKResponse(ok=True) + + AssetState(name=self.asset_name).delete("watermark") + + mock_supervisor_comms.send.assert_called_once_with( + DeleteAssetStateByName(name=self.asset_name, key="watermark") + ) + + def test_delete_by_uri_sends_supervisor_message(self, mock_supervisor_comms): + """Validate that delete() with uri= dispatches DeleteAssetStateByUri""" + mock_supervisor_comms.send.return_value = OKResponse(ok=True) + + AssetState(uri=self.asset_uri).delete("watermark") + + mock_supervisor_comms.send.assert_called_once_with( + DeleteAssetStateByUri(uri=self.asset_uri, key="watermark") + ) + + def test_clear_by_name_sends_supervisor_message(self, mock_supervisor_comms): + """Validate that clear() with name= dispatches ClearAssetStateByName (no key argument)""" + mock_supervisor_comms.send.return_value = OKResponse(ok=True) + + AssetState(name=self.asset_name).clear() + + mock_supervisor_comms.send.assert_called_once_with(ClearAssetStateByName(name=self.asset_name)) + + def test_clear_by_uri_sends_supervisor_message(self, mock_supervisor_comms): + """Validate that clear() with uri= dispatches ClearAssetStateByUri (no key argument)""" + mock_supervisor_comms.send.return_value = OKResponse(ok=True) + + AssetState(uri=self.asset_uri).clear() + + mock_supervisor_comms.send.assert_called_once_with(ClearAssetStateByUri(uri=self.asset_uri))