diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/asset_state.py b/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/asset_state.py index ec773201c7e2f..ab8b3aa2aec43 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/asset_state.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/asset_state.py @@ -17,16 +17,29 @@ from __future__ import annotations +import math + +from pydantic import JsonValue, field_validator + from airflow.api_fastapi.core_api.base import StrictBaseModel class AssetStateResponse(StrictBaseModel): """Asset state value returned to a worker.""" - value: str + value: JsonValue class AssetStatePutBody(StrictBaseModel): """Request body for setting an asset state value.""" - value: str + value: JsonValue + + @field_validator("value") + @classmethod + def value_is_json_representable(cls, v: JsonValue) -> JsonValue: + if v is None: + raise ValueError("value cannot be null") + if isinstance(v, float) and not math.isfinite(v): + raise ValueError("value must be a finite number; NaN and Inf are not JSON representable") + return v diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/task_state.py b/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/task_state.py index 20980b315c3d1..15fc44b726789 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/task_state.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/task_state.py @@ -17,19 +17,31 @@ from __future__ import annotations +import math from datetime import datetime +from pydantic import JsonValue, field_validator + from airflow.api_fastapi.core_api.base import StrictBaseModel class TaskStateResponse(StrictBaseModel): """Task state value returned to a worker.""" - value: str + value: JsonValue class TaskStatePutBody(StrictBaseModel): """Request body for setting a task state value.""" - value: str + value: JsonValue expires_at: datetime | None = None + + @field_validator("value") + @classmethod + def value_is_json_representable(cls, v: JsonValue) -> JsonValue: + if v is None: + raise ValueError("value cannot be null") + if isinstance(v, float) and not math.isfinite(v): + raise ValueError("value must be a finite number; NaN and Inf are not JSON representable") + return v diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/asset_state.py b/airflow-core/src/airflow/api_fastapi/execution_api/routes/asset_state.py index 2351caa6dfaf0..f7001c3158c88 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/asset_state.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/asset_state.py @@ -28,6 +28,7 @@ from __future__ import annotations +import json from typing import Annotated from cadwyn import VersionedAPIRouter @@ -93,7 +94,7 @@ def get_asset_state_by_name( status_code=status.HTTP_404_NOT_FOUND, detail={"reason": "not_found", "message": f"Asset state key {key!r} not found"}, ) - return AssetStateResponse(value=value) + return AssetStateResponse(value=json.loads(value)) @router.put("/by-name/value", status_code=status.HTTP_204_NO_CONTENT) @@ -105,7 +106,7 @@ def set_asset_state_by_name( ) -> None: """Set an asset state value by asset name.""" asset_id = _resolve_asset_id_by_name(name, session) - get_state_backend().set(AssetScope(asset_id=asset_id), key, body.value, session=session) + get_state_backend().set(AssetScope(asset_id=asset_id), key, json.dumps(body.value), session=session) @router.delete("/by-name/value", status_code=status.HTTP_204_NO_CONTENT) @@ -143,7 +144,7 @@ def get_asset_state_by_uri( status_code=status.HTTP_404_NOT_FOUND, detail={"reason": "not_found", "message": f"Asset state key {key!r} not found"}, ) - return AssetStateResponse(value=value) + return AssetStateResponse(value=json.loads(value)) @router.put("/by-uri/value", status_code=status.HTTP_204_NO_CONTENT) @@ -155,7 +156,7 @@ def set_asset_state_by_uri( ) -> None: """Set an asset state value by asset URI.""" asset_id = _resolve_asset_id_by_uri(uri, session) - get_state_backend().set(AssetScope(asset_id=asset_id), key, body.value, session=session) + get_state_backend().set(AssetScope(asset_id=asset_id), key, json.dumps(body.value), session=session) @router.delete("/by-uri/value", status_code=status.HTTP_204_NO_CONTENT) diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_state.py b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_state.py index 2f824e3ebb2f4..c59f2461e2aa2 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_state.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_state.py @@ -16,6 +16,7 @@ # under the License. from __future__ import annotations +import json from typing import Annotated from uuid import UUID @@ -74,7 +75,7 @@ def get_task_state( "message": f"Task state key {key!r} not found", }, ) - return TaskStateResponse(value=value) + return TaskStateResponse(value=json.loads(value)) @router.put("/{task_instance_id}/{key}", status_code=status.HTTP_204_NO_CONTENT) @@ -86,7 +87,7 @@ def set_task_state( ) -> None: """Set a task state key, creating or updating the row.""" scope = _get_task_scope_for_ti(task_instance_id, session) - get_state_backend().set(scope, key, body.value, expires_at=body.expires_at, session=session) + get_state_backend().set(scope, key, json.dumps(body.value), expires_at=body.expires_at, session=session) @router.delete("/{task_instance_id}/{key}", status_code=status.HTTP_204_NO_CONTENT) diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/versions/v2026_06_16.py b/airflow-core/src/airflow/api_fastapi/execution_api/versions/v2026_06_16.py index 779612bbde134..cd2a6861d1100 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/versions/v2026_06_16.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/versions/v2026_06_16.py @@ -60,7 +60,7 @@ class AddAssetsByAliasEndpoint(VersionChange): class AddStateEndpoints(VersionChange): - """Add task state and asset state CRUD endpoints.""" + """Add task state and asset state API endpoints.""" description = __doc__ diff --git a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_asset_state.py b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_asset_state.py index 6041d01e7f1a6..c91171aa05e6c 100644 --- a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_asset_state.py +++ b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_asset_state.py @@ -16,6 +16,7 @@ # under the License. from __future__ import annotations +import json from typing import TYPE_CHECKING import pytest @@ -113,7 +114,37 @@ def test_put_creates_row(self, client: TestClient, asset: AssetModel, session: S ) ) assert row is not None - assert row.value == "2026-04-29" + # DB stores JSON-encoded string + assert row.value == '"2026-04-29"' + + def test_put_int_value_roundtrip(self, client: TestClient, asset: AssetModel): + response = client.put( + _BY_NAME_VALUE, params={"name": asset.name, "key": "total_runs"}, json={"value": 5} + ) + assert response.status_code == 204 + assert client.get(_BY_NAME_VALUE, params={"name": asset.name, "key": "total_runs"}).json() == { + "value": 5 + } + + def test_put_dict_value_roundtrip(self, client: TestClient, asset: AssetModel): + response = client.put( + _BY_NAME_VALUE, + params={"name": asset.name, "key": "last_run"}, + json={"value": {"rows": 1234, "status": "ok"}}, + ) + assert response.status_code == 204 + assert client.get(_BY_NAME_VALUE, params={"name": asset.name, "key": "last_run"}).json() == { + "value": {"rows": 1234, "status": "ok"} + } + + def test_put_list_value_roundtrip(self, client: TestClient, asset: AssetModel): + response = client.put( + _BY_NAME_VALUE, params={"name": asset.name, "key": "ids"}, json={"value": [1, 2, 3]} + ) + assert response.status_code == 204 + assert client.get(_BY_NAME_VALUE, params={"name": asset.name, "key": "ids"}).json() == { + "value": [1, 2, 3] + } def test_put_overwrites_existing(self, client: TestClient, asset: AssetModel): client.put( @@ -134,6 +165,22 @@ def test_put_empty_body_returns_422(self, client: TestClient, asset: AssetModel) assert response.status_code == 422 + def test_put_null_value_returns_422(self, client: TestClient, asset: AssetModel): + response = client.put( + _BY_NAME_VALUE, params={"name": asset.name, "key": "watermark"}, json={"value": None} + ) + assert response.status_code == 422 + + @pytest.mark.parametrize("bad_float", [float("nan"), float("inf"), float("-inf")]) + def test_put_non_finite_float_returns_422(self, client: TestClient, asset: AssetModel, bad_float: float): + with pytest.raises(ValueError, match="Out of range float values are not JSON compliant"): + _ = client.put( + _BY_NAME_VALUE, + params={"name": asset.name, "key": "watermark"}, + content=json.dumps({"value": bad_float}, allow_nan=True).encode(), + headers={"Content-Type": "application/json"}, + ) + def test_put_unknown_asset_returns_404(self, client: TestClient): response = client.put( _BY_NAME_VALUE, params={"name": "nonexistent", "key": "watermark"}, json={"value": "x"} @@ -208,7 +255,7 @@ def test_put_creates_row(self, client: TestClient, asset: AssetModel, session: S ) ) assert row is not None - assert row.value == "2026-04-29" + assert row.value == '"2026-04-29"' def test_put_unknown_uri_returns_404(self, client: TestClient): response = client.put( diff --git a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_state.py b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_state.py index d83751050e7ec..97acc576a50a7 100644 --- a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_state.py +++ b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_state.py @@ -16,6 +16,7 @@ # under the License. from __future__ import annotations +import json from datetime import datetime from typing import TYPE_CHECKING from uuid import uuid4 @@ -95,7 +96,37 @@ def test_put_creates_row(self, client: TestClient, create_task_instance: CreateT ) ) assert row is not None - assert row.value == "spark_001" + # DB stores a json string + assert row.value == '"spark_001"' + + def test_put_int_value_roundtrip(self, client: TestClient, create_task_instance: CreateTaskInstance): + ti = create_task_instance() + + response = client.put(_api_url(ti.id, "retry_count"), json={"value": 3}) + + assert response.status_code == 204 + assert client.get(_api_url(ti.id, "retry_count")).json() == {"value": 3} + + def test_put_dict_value_roundtrip(self, client: TestClient, create_task_instance: CreateTaskInstance): + ti = create_task_instance() + + response = client.put( + _api_url(ti.id, "poll_result"), + json={"value": {"status": "succeeded", "rows": 1234}}, + ) + + assert response.status_code == 204 + assert client.get(_api_url(ti.id, "poll_result")).json() == { + "value": {"status": "succeeded", "rows": 1234} + } + + def test_put_list_value_roundtrip(self, client: TestClient, create_task_instance: CreateTaskInstance): + ti = create_task_instance() + + response = client.put(_api_url(ti.id, "checkpoints"), json={"value": [1, 2, 3]}) + + assert response.status_code == 204 + assert client.get(_api_url(ti.id, "checkpoints")).json() == {"value": [1, 2, 3]} def test_put_with_expires_at_creates_row( self, client: TestClient, create_task_instance: CreateTaskInstance, time_machine @@ -122,7 +153,7 @@ def test_put_with_expires_at_creates_row( ) ) assert row is not None - assert row.value == "spark_001" + assert row.value == '"spark_001"' assert row.expires_at == datetime(2026, 5, 15, 12, 0, 0, tzinfo=pendulum.UTC) def test_put_overwrites_existing(self, client: TestClient, create_task_instance: CreateTaskInstance): @@ -155,6 +186,18 @@ def test_put_null_value_returns_422(self, client: TestClient, create_task_instan assert response.status_code == 422 + @pytest.mark.parametrize("bad_float", [float("nan"), float("inf"), float("-inf")]) + def test_put_non_finite_float_returns_422( + self, client: TestClient, create_task_instance: CreateTaskInstance, bad_float: float + ): + ti = create_task_instance() + with pytest.raises(ValueError, match="Out of range float values are not JSON compliant"): + _ = client.put( + _api_url(ti.id, "job_id"), + content=json.dumps({"value": bad_float}, allow_nan=True).encode(), + headers={"Content-Type": "application/json"}, + ) + def test_put_missing_ti_returns_404(self, client: TestClient): response = client.put(_api_url(uuid4(), "job_id"), json={"value": "x"}) diff --git a/shared/state/pyproject.toml b/shared/state/pyproject.toml index 17fa0eaaac470..fac791405cb32 100644 --- a/shared/state/pyproject.toml +++ b/shared/state/pyproject.toml @@ -23,7 +23,9 @@ classifiers = [ "Private :: Do Not Upload", ] -dependencies = [] +dependencies = [ + "pydantic>=2.11.0", +] [dependency-groups] dev = [ diff --git a/shared/state/src/airflow_shared/state/__init__.py b/shared/state/src/airflow_shared/state/__init__.py index 7aa9fcba8372d..688cd6a630178 100644 --- a/shared/state/src/airflow_shared/state/__init__.py +++ b/shared/state/src/airflow_shared/state/__init__.py @@ -16,6 +16,7 @@ # under the License. from __future__ import annotations +import json from abc import ABC, abstractmethod from dataclasses import dataclass from typing import TYPE_CHECKING @@ -23,6 +24,7 @@ if TYPE_CHECKING: from datetime import datetime + from pydantic import JsonValue from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import Session @@ -96,9 +98,10 @@ class BaseStateBackend(ABC): @abstractmethod def get(self, scope: StateScope, key: str, *, session: Session | None = None) -> str | None: """ - Return the stored value, or None if the key does not exist. + Return the stored JSON encoded value string, or None if the key does not exist. - Must handle both ``TaskScope`` and ``AssetScope``. + Must handle both ``TaskScope`` and ``AssetScope``. The execution API calls + ``json.loads`` on the returned string from here, so it must be a valid JSON document. """ @abstractmethod @@ -112,9 +115,11 @@ def set( session: Session | None = None, ) -> None: """ - Write or overwrite the value for the given key. + Write or overwrite ``value`` for the given key. - Must handle both ``TaskScope`` and ``AssetScope``. + Must handle both ``TaskScope`` and ``AssetScope``. ``value`` is always a + JSON encoded string (the execution API calls ``json.dumps`` before passing it + here); store it verbatim so ``get`` can return it unchanged. ``expires_at`` is an absolute UTC datetime after which the row may be deleted. Pass ``None`` (default) for a key that should never expire — stored as ``NULL``, @@ -147,10 +152,10 @@ def clear( @abstractmethod async def aget(self, scope: StateScope, key: str, *, session: AsyncSession | None = None) -> str | None: """ - Async variant of get. Must handle both ``TaskScope`` and ``AssetScope``. + Async variant of ``get`` which returns a JSON encoded value string or None. - ``session`` is optional. If provided, implementations should use it directly. - If ``None``, implementations manage their own async session internally. + Must handle both ``TaskScope`` and ``AssetScope``. ``session`` is used directly + when provided; otherwise implementations manage their own session internally. """ @abstractmethod @@ -164,10 +169,10 @@ async def aset( session: AsyncSession | None = None, ) -> None: """ - Async variant of set. Must handle both ``TaskScope`` and ``AssetScope``. + Async variant of ``set``. ``value`` is always a JSON encoded string. - ``session`` is optional. If provided, implementations should use it directly. - If ``None``, implementations manage their own async session internally. + Must handle both ``TaskScope`` and ``AssetScope``. ``session`` is used directly + when provided; otherwise implementations manage their own session internally. """ @abstractmethod @@ -203,7 +208,7 @@ def cleanup(self) -> None: ``[state_store] default_retention_days``) and deciding what to delete. """ - def serialize_task_state_to_ref(self, *, value: str, key: str, ti_id: str) -> str: + def serialize_task_state_to_ref(self, *, value: JsonValue, key: str, ti_id: str) -> str: """ Serialize a task state value before it is sent to the execution API for db persistence. @@ -214,20 +219,21 @@ def serialize_task_state_to_ref(self, *, value: str, key: str, ti_id: str) -> st The returned reference must be deterministic — given the same ``ti_id`` and ``key`` it must always return the same string. Do not use timestamps or random UUIDs as part of the reference, otherwise ``delete()``/``clear()`` cannot reconstruct it and the external - object will be orphaned. + object will be orphaned. By default, it JSON dumps the value and returns a JSON string. """ - return value + return json.dumps(value) - def deserialize_task_state_from_ref(self, stored: str) -> str: + def deserialize_task_state_from_ref(self, stored: str) -> JsonValue: """ - Resolve a stored task state string back to the actual value. + Resolve a stored task state reference back to the actual value. Called by ``TaskStateAccessor.get()`` after the stored string is retrieved from - the execution API. Default: return ``stored`` unchanged. + the execution API. By default, it JSON decodes ``stored`` to reverse the default + ``serialize_task_state_to_ref`` encoding. """ - return stored + return json.loads(stored) - def serialize_asset_state_to_ref(self, *, value: str, key: str, asset_ref: str) -> str: + def serialize_asset_state_to_ref(self, *, value: JsonValue, key: str, asset_ref: str) -> str: """ Serialize an asset state value before it is sent to the Execution API for db persistence. @@ -241,15 +247,16 @@ def serialize_asset_state_to_ref(self, *, value: str, key: str, asset_ref: str) The returned reference must be deterministic — given the same ``asset_ref`` and ``key`` it must always return the same string. Do not use timestamps or random UUIDs as part of the reference, otherwise ``delete()``/``clear()`` cannot reconstruct it and the external - object will be orphaned. + object will be orphaned. By default, it JSON dumps the value and returns a JSON string. """ - return value + return json.dumps(value) - def deserialize_asset_state_from_ref(self, stored: str) -> str: + def deserialize_asset_state_from_ref(self, stored: str) -> JsonValue: """ - Resolve a stored asset state string back to the actual value. + Resolve a stored asset state reference back to the actual value. Called by ``AssetStateAccessor.get()`` after the stored string is retrieved from - the Execution API. Default: return ``stored`` unchanged. + the Execution API. By default, it JSON decodes ``stored`` to reverse the default + ``serialize_asset_state_to_ref`` encoding. """ - return stored + return json.loads(stored) diff --git a/shared/state/tests/state/test_state.py b/shared/state/tests/state/test_state.py index 1ea31194e2788..eb658ff8c74ca 100644 --- a/shared/state/tests/state/test_state.py +++ b/shared/state/tests/state/test_state.py @@ -92,6 +92,18 @@ def test_task_state_serialize_deserialize_round_trip(self, backend): deserialized = backend.deserialize_task_state_from_ref(serialized) assert deserialized == original + def test_task_state_serialize_deserialize_typed_values(self, backend): + """Default backend passes typed values through unchanged (custom backends handle storage).""" + assert ( + backend.deserialize_task_state_from_ref( + backend.serialize_task_state_to_ref(value=42, key="count", ti_id="abc-123") + ) + == 42 + ) + assert backend.deserialize_task_state_from_ref( + backend.serialize_task_state_to_ref(value={"status": "ok"}, key="result", ti_id="abc-123") + ) == {"status": "ok"} + def test_custom_backend_overrides_task_state_ser_deser(self): class MyBackend(BaseStateBackend): def get(self, scope, key): ... @@ -126,6 +138,17 @@ def test_asset_state_serialize_deserialize_round_trip(self, backend): deserialized = backend.deserialize_asset_state_from_ref(serialized) assert deserialized == original + def test_asset_state_serialize_deserialize_typed_values(self, backend): + assert ( + backend.deserialize_asset_state_from_ref( + backend.serialize_asset_state_to_ref(value=5, key="total_runs", asset_ref="my_asset") + ) + == 5 + ) + assert backend.deserialize_asset_state_from_ref( + backend.serialize_asset_state_to_ref(value={"rows": 1234}, key="last_run", asset_ref="my_asset") + ) == {"rows": 1234} + def test_custom_backend_overrides_asset_state_ser_deser(self): class MyBackend(BaseStateBackend): def get(self, scope, key): ... diff --git a/task-sdk/src/airflow/sdk/api/client.py b/task-sdk/src/airflow/sdk/api/client.py index 99b1aadb37f6d..1da539f29a3dc 100644 --- a/task-sdk/src/airflow/sdk/api/client.py +++ b/task-sdk/src/airflow/sdk/api/client.py @@ -32,7 +32,7 @@ import structlog from opentelemetry import trace from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator -from pydantic import BaseModel +from pydantic import BaseModel, JsonValue from tenacity import ( before_log, retry, @@ -721,7 +721,7 @@ def get(self, ti_id: uuid.UUID, key: str) -> TaskStateResponse | ErrorResponse: raise return TaskStateResponse.model_validate_json(resp.read()) - def set(self, ti_id: uuid.UUID, key: str, value: str, expires_at: datetime | None) -> OKResponse: + def set(self, ti_id: uuid.UUID, key: str, value: JsonValue, expires_at: datetime | None) -> OKResponse: """Set a task state value via the API server.""" body = TaskStatePutBody(value=value, expires_at=expires_at) self.client.put(f"state/ti/{ti_id}/{key}", content=body.model_dump_json()) @@ -774,7 +774,9 @@ def get( raise return AssetStateResponse.model_validate_json(resp.read()) - def set(self, key: str, value: str, *, name: str | None = None, uri: str | None = None) -> OKResponse: + def set( + self, key: str, value: JsonValue, *, name: str | None = None, uri: str | None = None + ) -> OKResponse: """Set an asset state value via the API server.""" endpoint, params = self._resolve_endpoint("value", key=key, name=name, uri=uri) self.client.put(endpoint, params=params, content=AssetStatePutBody(value=value).model_dump_json()) diff --git a/task-sdk/src/airflow/sdk/api/datamodels/_generated.py b/task-sdk/src/airflow/sdk/api/datamodels/_generated.py index fc966a7696957..62c43ac17d10b 100644 --- a/task-sdk/src/airflow/sdk/api/datamodels/_generated.py +++ b/task-sdk/src/airflow/sdk/api/datamodels/_generated.py @@ -63,28 +63,6 @@ class AssetProfile(BaseModel): type: Annotated[str, Field(title="Type")] -class AssetStatePutBody(BaseModel): - """ - Request body for setting an asset state value. - """ - - model_config = ConfigDict( - extra="forbid", - ) - value: Annotated[str, Field(title="Value")] - - -class AssetStateResponse(BaseModel): - """ - Asset state value returned to a worker. - """ - - model_config = ConfigDict( - extra="forbid", - ) - value: Annotated[str, Field(title="Value")] - - class ConnectionResponse(BaseModel): """ Connection schema for responses with fields that are needed for Runtime. @@ -375,7 +353,7 @@ class TaskStatePutBody(BaseModel): model_config = ConfigDict( extra="forbid", ) - value: Annotated[str, Field(title="Value")] + value: JsonValue expires_at: Annotated[AwareDatetime | None, Field(title="Expires At")] = None @@ -387,7 +365,7 @@ class TaskStateResponse(BaseModel): model_config = ConfigDict( extra="forbid", ) - value: Annotated[str, Field(title="Value")] + value: JsonValue class TaskStatesResponse(BaseModel): @@ -596,6 +574,28 @@ class AssetResponse(BaseModel): extra: Annotated[dict[str, JsonValue] | None, Field(title="Extra")] = None +class AssetStatePutBody(BaseModel): + """ + Request body for setting an asset state value. + """ + + model_config = ConfigDict( + extra="forbid", + ) + value: JsonValue + + +class AssetStateResponse(BaseModel): + """ + Asset state value returned to a worker. + """ + + model_config = ConfigDict( + extra="forbid", + ) + value: JsonValue + + class HITLDetailRequest(BaseModel): """ Schema for the request part of a Human-in-the-loop detail for a specific task instance. diff --git a/task-sdk/src/airflow/sdk/execution_time/comms.py b/task-sdk/src/airflow/sdk/execution_time/comms.py index 2364e942ed044..7b494cab83592 100644 --- a/task-sdk/src/airflow/sdk/execution_time/comms.py +++ b/task-sdk/src/airflow/sdk/execution_time/comms.py @@ -923,7 +923,7 @@ class GetTaskState(BaseModel): class SetTaskState(BaseModel): ti_id: UUID key: str - value: str + value: JsonValue expires_at: AwareDatetime | None type: Literal["SetTaskState"] = "SetTaskState" @@ -955,14 +955,14 @@ class GetAssetStateByUri(BaseModel): class SetAssetStateByName(BaseModel): name: str key: str - value: str + value: JsonValue type: Literal["SetAssetStateByName"] = "SetAssetStateByName" class SetAssetStateByUri(BaseModel): uri: str key: str - value: str + value: JsonValue type: Literal["SetAssetStateByUri"] = "SetAssetStateByUri" diff --git a/task-sdk/src/airflow/sdk/execution_time/context.py b/task-sdk/src/airflow/sdk/execution_time/context.py index 14922780da48a..cba613da85a4a 100644 --- a/task-sdk/src/airflow/sdk/execution_time/context.py +++ b/task-sdk/src/airflow/sdk/execution_time/context.py @@ -495,12 +495,19 @@ def __repr__(self) -> str: # is not implemented yet cos it's unclear whether task state values will be # used in templates. - def get(self, key: str) -> str | None: - """Return the stored value, or ``None`` if the key does not exist.""" + def get(self, key: str) -> JsonValue: + """ + Return the stored value, or ``None`` if the key does not exist. + + Supported types: ``str``, ``int``, ``float``, ``bool``, ``list``, ``dict``. + ``datetime`` is not JSON-serializable; store it as ``value.isoformat()`` and + parse it back with ``datetime.fromisoformat(result)``. + """ from airflow.sdk.execution_time.comms import ErrorResponse, GetTaskState, TaskStateResult from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS resp = SUPERVISOR_COMMS.send(GetTaskState(ti_id=self._ti_id, key=key)) + if isinstance(resp, ErrorResponse) and resp.error != ErrorType.TASK_STATE_NOT_FOUND: raise AirflowRuntimeError(resp) if isinstance(resp, TaskStateResult): @@ -508,10 +515,15 @@ def get(self, key: str) -> str | None: # if custom backend is configured, the stored value in DB is a reference, fetch the actual value from # custom backend using the reference backend = _get_worker_state_backend() - return backend.deserialize_task_state_from_ref(stored) if backend else stored + if backend is not None: + # serialize_task_state_to_ref always returns str by contract; stored contains the ref. + if TYPE_CHECKING: + assert isinstance(stored, str) + return backend.deserialize_task_state_from_ref(stored) + return stored return None - def set(self, key: str, value: str, *, retention: timedelta | None = None) -> None: + def set(self, key: str, value: JsonValue, *, retention: timedelta | None = None) -> None: """ Write or overwrite the value for the given key. @@ -614,7 +626,7 @@ def __repr__(self) -> str: return f"" return f"" - def get(self, key: str) -> str | None: + def get(self, key: str) -> JsonValue: """Return the stored value, or ``None`` if the key does not exist.""" from airflow.sdk.execution_time.comms import ( AssetStateResult, @@ -635,13 +647,16 @@ def get(self, key: str) -> str | None: raise AirflowRuntimeError(resp) if isinstance(resp, AssetStateResult): stored = resp.value - # if custom backend is configured, the stored value in DB is a reference, fetch the actual value from - # custom backend using the reference backend = _get_worker_state_backend() - return backend.deserialize_asset_state_from_ref(stored) if backend else stored + if backend is not None: + # serialize_asset_state_to_ref always returns str by contract; stored contains the ref. + if TYPE_CHECKING: + assert isinstance(stored, str) + return backend.deserialize_asset_state_from_ref(stored) + return stored return None - def set(self, key: str, value: str) -> None: + def set(self, key: str, value: JsonValue) -> None: """Write or overwrite the value for the given key.""" from airflow.sdk.execution_time.comms import SetAssetStateByName, SetAssetStateByUri, ToSupervisor from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS @@ -756,11 +771,11 @@ def _single_accessor(self) -> AssetStateAccessor: return next(iter(self._by_name.values())) return next(iter(self._by_uri.values())) - def get(self, key: str) -> str | None: + def get(self, key: str) -> JsonValue: """Return the stored value for the single-inlet task, or ``None`` if not found.""" return self._single_accessor().get(key) - def set(self, key: str, value: str) -> None: + def set(self, key: str, value: JsonValue) -> None: """Write or overwrite the value for the single-inlet task.""" self._single_accessor().set(key, value) diff --git a/task-sdk/src/airflow/sdk/execution_time/schema/schema.json b/task-sdk/src/airflow/sdk/execution_time/schema/schema.json index d4eb3d9c5a8b7..fec9596e49391 100644 --- a/task-sdk/src/airflow/sdk/execution_time/schema/schema.json +++ b/task-sdk/src/airflow/sdk/execution_time/schema/schema.json @@ -317,6 +317,9 @@ "type": "object" }, "AssetStateResult": { + "$defs": { + "JsonValue": {} + }, "additionalProperties": false, "description": "Response to GetAssetState; wraps the generated API response for supervisor to worker comms.", "properties": { @@ -327,8 +330,7 @@ "type": "string" }, "value": { - "title": "Value", - "type": "string" + "$ref": "#/$defs/JsonValue" } }, "required": [ @@ -4549,6 +4551,9 @@ "type": "object" }, "SetAssetStateByName": { + "$defs": { + "JsonValue": {} + }, "properties": { "key": { "title": "Key", @@ -4565,8 +4570,7 @@ "type": "string" }, "value": { - "title": "Value", - "type": "string" + "$ref": "#/$defs/JsonValue" } }, "required": [ @@ -4578,6 +4582,9 @@ "type": "object" }, "SetAssetStateByUri": { + "$defs": { + "JsonValue": {} + }, "properties": { "key": { "title": "Key", @@ -4594,8 +4601,7 @@ "type": "string" }, "value": { - "title": "Value", - "type": "string" + "$ref": "#/$defs/JsonValue" } }, "required": [ @@ -4653,6 +4659,9 @@ "type": "object" }, "SetTaskState": { + "$defs": { + "JsonValue": {} + }, "properties": { "expires_at": { "anyOf": [ @@ -4682,8 +4691,7 @@ "type": "string" }, "value": { - "title": "Value", - "type": "string" + "$ref": "#/$defs/JsonValue" } }, "required": [ @@ -5773,6 +5781,9 @@ "type": "object" }, "TaskStateResult": { + "$defs": { + "JsonValue": {} + }, "additionalProperties": false, "description": "Response to GetTaskState; wraps the generated API response for supervisor to worker comms.", "properties": { @@ -5783,8 +5794,7 @@ "type": "string" }, "value": { - "title": "Value", - "type": "string" + "$ref": "#/$defs/JsonValue" } }, "required": [ diff --git a/task-sdk/tests/task_sdk/execution_time/test_context.py b/task-sdk/tests/task_sdk/execution_time/test_context.py index 1763604e477e5..35c7427874849 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_context.py +++ b/task-sdk/tests/task_sdk/execution_time/test_context.py @@ -18,11 +18,13 @@ from __future__ import annotations from datetime import datetime, timedelta, timezone as dt_timezone +from typing import TYPE_CHECKING from unittest import mock from unittest.mock import MagicMock, patch from uuid import UUID import pytest +from pydantic import ValidationError from airflow.sdk import BaseOperator, get_current_context, timezone from airflow.sdk._shared.state import TaskScope @@ -102,6 +104,9 @@ from tests_common.test_utils.config import conf_vars +if TYPE_CHECKING: + from pydantic import JsonValue + def test_convert_connection_result_conn(): """Test that the ConnectionResult is converted to a Connection object.""" @@ -1210,6 +1215,16 @@ def test_clear_all_map_indices_sends_flag_true(self, mock_supervisor_comms): ClearTaskState(ti_id=self.TI_ID, all_map_indices=True) ) + def test_set_datetime_raises_validation_error(self, mock_supervisor_comms): + """datetime is not JSON-serializable; callers must use .isoformat() first.""" + with pytest.raises(ValidationError): + TaskStateAccessor(ti_id=self.TI_ID, scope=self.SCOPE).set( + "watermark", + datetime(2026, 5, 15, tzinfo=dt_timezone.utc), + ) + + mock_supervisor_comms.send.assert_not_called() + class TestAssetStateAccessor: ASSET_NAME = "debug_watcher_asset" @@ -1417,23 +1432,23 @@ def __init__(self): self._actual_key_value_store: dict[str, str] = {} # key -> actual value self.reference: dict[str, str] = {} # key -> stored ref (mem:// URI) - def serialize_task_state_to_ref(self, *, value: str, key: str, ti_id: str) -> str: + def serialize_task_state_to_ref(self, *, value, key: str, ti_id: str) -> str: ref = f"mem://{ti_id}/{key}" self._actual_key_value_store[key] = value self.reference[key] = ref return ref - def deserialize_task_state_from_ref(self, stored: str) -> str: + def deserialize_task_state_from_ref(self, stored: str) -> JsonValue: key = stored.rsplit("/", 1)[-1] return self._actual_key_value_store.get(key, stored) - def serialize_asset_state_to_ref(self, *, value: str, key: str, asset_ref: str) -> str: + def serialize_asset_state_to_ref(self, *, value, key: str, asset_ref: str) -> str: ref = f"mem://{asset_ref}/{key}" self._actual_key_value_store[key] = value self.reference[key] = ref return ref - def deserialize_asset_state_from_ref(self, stored: str) -> str: + def deserialize_asset_state_from_ref(self, stored: str) -> JsonValue: key = stored.rsplit("/", 1)[-1] return self._actual_key_value_store.get(key, stored) diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py index 56900fbadab26..4ba821e537e56 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py +++ b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py @@ -5263,6 +5263,40 @@ def execute(self, context): ) mock_supervisor_comms.send.assert_any_call(GetTaskState(ti_id=runtime_ti.id, key="job_id")) + def test_task_state_set_sends_typed_values(self, create_runtime_ti, mock_supervisor_comms, time_machine): + """set() accepts any JsonValue — dict, int, list — not just strings.""" + + class MyOperator(BaseOperator): + def execute(self, context): + ts = context["task_state"] + ts.set("retry_count", 3) + ts.set("poll_result", {"status": "succeeded", "rows": 1234}) + ts.set("checkpoints", [1, 2, 3]) + + frozen_dt = datetime(2026, 1, 1, 12, 0, 0, tzinfo=dt_timezone.utc) + time_machine.move_to(frozen_dt, tick=False) + task = MyOperator(task_id="t") + runtime_ti = create_runtime_ti(task=task) + + with conf_vars({("state_store", "default_retention_days"): "30"}): + run(runtime_ti, context=runtime_ti.get_template_context(), log=mock.MagicMock()) + + expires_at = frozen_dt + timedelta(days=30) + mock_supervisor_comms.send.assert_any_call( + SetTaskState(ti_id=runtime_ti.id, key="retry_count", value=3, expires_at=expires_at) + ) + mock_supervisor_comms.send.assert_any_call( + SetTaskState( + ti_id=runtime_ti.id, + key="poll_result", + value={"status": "succeeded", "rows": 1234}, + expires_at=expires_at, + ) + ) + mock_supervisor_comms.send.assert_any_call( + SetTaskState(ti_id=runtime_ti.id, key="checkpoints", value=[1, 2, 3], expires_at=expires_at) + ) + def test_task_can_set_state_with_retention(self, create_runtime_ti, mock_supervisor_comms, time_machine): class MyOperator(BaseOperator): def execute(self, context): diff --git a/uv.lock b/uv.lock index 7766a2d3cf466..d601b0aa04095 100644 --- a/uv.lock +++ b/uv.lock @@ -8524,6 +8524,9 @@ mypy = [{ name = "apache-airflow-devel-common", extras = ["mypy"], editable = "d name = "apache-airflow-shared-state" version = "0.0" source = { editable = "shared/state" } +dependencies = [ + { name = "pydantic" }, +] [package.dev-dependencies] dev = [ @@ -8534,6 +8537,7 @@ mypy = [ ] [package.metadata] +requires-dist = [{ name = "pydantic", specifier = ">=2.11.0" }] [package.metadata.requires-dev] dev = [{ name = "apache-airflow-devel-common", editable = "devel-common" }] @@ -14454,6 +14458,11 @@ dependencies = [ ] sdist = { url = "https://files.pythonhosted.org/packages/48/a2/5d27e81d24eef64668bf702bfe0e091cc48388b4666f36e025243eb9d827/jpype1-1.7.1.tar.gz", hash = "sha256:3cd88838dc3d2d546f7eaeadaaff864e590010c15f2b6a44b6f37e60796a14b2", size = 783791, upload-time = "2026-05-06T23:55:10.664Z" } wheels = [ + { url = "https://files.pythonhosted.org/packages/fa/c0/2c41dedfb65060fa05d152b3f57e7c3658c86257d92de365a3c1fcb80779/jpype1-1.7.1-1-cp310-cp310-macosx_11_0_universal2.whl", hash = "sha256:6590cbdb6208e4522fd99ae5f5f4bed5de707122385bc48446a1e7d7b56357ef", size = 571509, upload-time = "2026-05-19T20:19:30.416Z" }, + { url = "https://files.pythonhosted.org/packages/2f/5e/5611d50222d146a060dbf22e69c4017545341ea6b289a591d5a9bdaad718/jpype1-1.7.1-1-cp311-cp311-macosx_11_0_universal2.whl", hash = "sha256:4c81ee11aee5ed938d7415877cd9c7a0cc9cbf1dac87f7eab928e641323a385b", size = 571633, upload-time = "2026-05-19T20:19:33.536Z" }, + { url = "https://files.pythonhosted.org/packages/79/32/8b2279b12364f260111c7843bf9ede7dc442d5521d6d2ca728b3d522d445/jpype1-1.7.1-1-cp312-cp312-macosx_11_0_universal2.whl", hash = "sha256:b3ddd9f9099202212a34679dfb95dda590bcfbd23289559d104e24abec9120d1", size = 569654, upload-time = "2026-05-19T20:19:36.41Z" }, + { url = "https://files.pythonhosted.org/packages/b5/67/5caa0de30bcb1c8786cc988144a68908e0624de20cfed470a67b1dd1f60c/jpype1-1.7.1-1-cp313-cp313-macosx_11_0_universal2.whl", hash = "sha256:6d491a81281407f8a68552eb3c0e635e576e066c069268dc29a1ea27bb4778ae", size = 569800, upload-time = "2026-05-19T20:19:38.877Z" }, + { url = "https://files.pythonhosted.org/packages/5b/1d/9ee10b1aad9f01ea6ac6159981120eb5ace01962f9cfaa7de6b911de3eb8/jpype1-1.7.1-1-cp314-cp314-macosx_11_0_universal2.whl", hash = "sha256:ace0ba1a67561358fa5b57b8e93ed8bcf16f0a8d5cba79c875089c56827adf8e", size = 569753, upload-time = "2026-05-19T20:19:41.514Z" }, { url = "https://files.pythonhosted.org/packages/04/ff/44a6f285d4c07014cb64379b8863caaefad1cc976d36923073d097b1d461/jpype1-1.7.1-cp310-cp310-macosx_14_0_x86_64.whl", hash = "sha256:472b2f53002f5fdf118d2e6b8c6b5441d6e3ca3cf1b1bdb163442be76c8b2859", size = 375560, upload-time = "2026-05-06T23:53:48.669Z" }, { url = "https://files.pythonhosted.org/packages/42/c5/98c5ba221de29b341298341c07ad2221beae565886d18c2e6b821928db15/jpype1-1.7.1-cp310-cp310-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:80c4c8cbab99040b8b56f28ff834e0b089aefccaabe3b472b8b43bb1e4658b86", size = 408119, upload-time = "2026-05-06T23:53:51.382Z" }, { url = "https://files.pythonhosted.org/packages/37/3f/d3b7fd287d5bae63af0ae935b2f2c01291d18ea2e6cd706db8e4dda15354/jpype1-1.7.1-cp310-cp310-manylinux_2_24_i686.manylinux_2_28_i686.whl", hash = "sha256:9c9a08d06016afbe5391daaf843b9e76c79022181685bbb23b64cd3f9aaec30d", size = 454716, upload-time = "2026-05-06T23:53:53.937Z" },