Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,29 @@

from __future__ import annotations

import math

from pydantic import JsonValue, field_validator

from airflow.api_fastapi.core_api.base import StrictBaseModel


class AssetStateResponse(StrictBaseModel):
"""Asset state value returned to a worker."""

value: str
value: JsonValue


class AssetStatePutBody(StrictBaseModel):
"""Request body for setting an asset state value."""

value: str
value: JsonValue

@field_validator("value")
@classmethod
def value_is_json_representable(cls, v: JsonValue) -> JsonValue:
if v is None:
raise ValueError("value cannot be null")
if isinstance(v, float) and not math.isfinite(v):
raise ValueError("value must be a finite number; NaN and Inf are not JSON representable")
return v
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,31 @@

from __future__ import annotations

import math
from datetime import datetime

from pydantic import JsonValue, field_validator

from airflow.api_fastapi.core_api.base import StrictBaseModel


class TaskStateResponse(StrictBaseModel):
"""Task state value returned to a worker."""

value: str
value: JsonValue


Comment thread
amoghrajesh marked this conversation as resolved.
class TaskStatePutBody(StrictBaseModel):
"""Request body for setting a task state value."""

value: str
value: JsonValue
expires_at: datetime | None = None

@field_validator("value")
@classmethod
def value_is_json_representable(cls, v: JsonValue) -> JsonValue:
if v is None:
raise ValueError("value cannot be null")
if isinstance(v, float) and not math.isfinite(v):
raise ValueError("value must be a finite number; NaN and Inf are not JSON representable")
return v
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

from __future__ import annotations

import json
from typing import Annotated

from cadwyn import VersionedAPIRouter
Expand Down Expand Up @@ -93,7 +94,7 @@ def get_asset_state_by_name(
status_code=status.HTTP_404_NOT_FOUND,
detail={"reason": "not_found", "message": f"Asset state key {key!r} not found"},
)
return AssetStateResponse(value=value)
return AssetStateResponse(value=json.loads(value))


@router.put("/by-name/value", status_code=status.HTTP_204_NO_CONTENT)
Expand All @@ -105,7 +106,7 @@ def set_asset_state_by_name(
) -> None:
"""Set an asset state value by asset name."""
asset_id = _resolve_asset_id_by_name(name, session)
get_state_backend().set(AssetScope(asset_id=asset_id), key, body.value, session=session)
get_state_backend().set(AssetScope(asset_id=asset_id), key, json.dumps(body.value), session=session)
Comment thread
amoghrajesh marked this conversation as resolved.


@router.delete("/by-name/value", status_code=status.HTTP_204_NO_CONTENT)
Expand Down Expand Up @@ -143,7 +144,7 @@ def get_asset_state_by_uri(
status_code=status.HTTP_404_NOT_FOUND,
detail={"reason": "not_found", "message": f"Asset state key {key!r} not found"},
)
return AssetStateResponse(value=value)
return AssetStateResponse(value=json.loads(value))


@router.put("/by-uri/value", status_code=status.HTTP_204_NO_CONTENT)
Expand All @@ -155,7 +156,7 @@ def set_asset_state_by_uri(
) -> None:
"""Set an asset state value by asset URI."""
asset_id = _resolve_asset_id_by_uri(uri, session)
get_state_backend().set(AssetScope(asset_id=asset_id), key, body.value, session=session)
get_state_backend().set(AssetScope(asset_id=asset_id), key, json.dumps(body.value), session=session)


@router.delete("/by-uri/value", status_code=status.HTTP_204_NO_CONTENT)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# under the License.
from __future__ import annotations

import json
from typing import Annotated
from uuid import UUID

Expand Down Expand Up @@ -74,7 +75,7 @@ def get_task_state(
"message": f"Task state key {key!r} not found",
},
)
return TaskStateResponse(value=value)
return TaskStateResponse(value=json.loads(value))


@router.put("/{task_instance_id}/{key}", status_code=status.HTTP_204_NO_CONTENT)
Expand All @@ -86,7 +87,7 @@ def set_task_state(
) -> None:
"""Set a task state key, creating or updating the row."""
scope = _get_task_scope_for_ti(task_instance_id, session)
get_state_backend().set(scope, key, body.value, expires_at=body.expires_at, session=session)
get_state_backend().set(scope, key, json.dumps(body.value), expires_at=body.expires_at, session=session)


@router.delete("/{task_instance_id}/{key}", status_code=status.HTTP_204_NO_CONTENT)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ class AddAssetsByAliasEndpoint(VersionChange):


class AddStateEndpoints(VersionChange):
"""Add task state and asset state CRUD endpoints."""
"""Add task state and asset state API endpoints."""

description = __doc__

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# under the License.
from __future__ import annotations

import json
from typing import TYPE_CHECKING

import pytest
Expand Down Expand Up @@ -113,7 +114,37 @@ def test_put_creates_row(self, client: TestClient, asset: AssetModel, session: S
)
)
assert row is not None
assert row.value == "2026-04-29"
# DB stores JSON-encoded string
assert row.value == '"2026-04-29"'

def test_put_int_value_roundtrip(self, client: TestClient, asset: AssetModel):
response = client.put(
_BY_NAME_VALUE, params={"name": asset.name, "key": "total_runs"}, json={"value": 5}
)
assert response.status_code == 204
assert client.get(_BY_NAME_VALUE, params={"name": asset.name, "key": "total_runs"}).json() == {
"value": 5
}

def test_put_dict_value_roundtrip(self, client: TestClient, asset: AssetModel):
response = client.put(
_BY_NAME_VALUE,
params={"name": asset.name, "key": "last_run"},
json={"value": {"rows": 1234, "status": "ok"}},
)
assert response.status_code == 204
assert client.get(_BY_NAME_VALUE, params={"name": asset.name, "key": "last_run"}).json() == {
"value": {"rows": 1234, "status": "ok"}
}

def test_put_list_value_roundtrip(self, client: TestClient, asset: AssetModel):
response = client.put(
_BY_NAME_VALUE, params={"name": asset.name, "key": "ids"}, json={"value": [1, 2, 3]}
)
assert response.status_code == 204
assert client.get(_BY_NAME_VALUE, params={"name": asset.name, "key": "ids"}).json() == {
"value": [1, 2, 3]
}

def test_put_overwrites_existing(self, client: TestClient, asset: AssetModel):
client.put(
Expand All @@ -134,6 +165,22 @@ def test_put_empty_body_returns_422(self, client: TestClient, asset: AssetModel)

assert response.status_code == 422

def test_put_null_value_returns_422(self, client: TestClient, asset: AssetModel):
response = client.put(
_BY_NAME_VALUE, params={"name": asset.name, "key": "watermark"}, json={"value": None}
)
assert response.status_code == 422

@pytest.mark.parametrize("bad_float", [float("nan"), float("inf"), float("-inf")])
def test_put_non_finite_float_returns_422(self, client: TestClient, asset: AssetModel, bad_float: float):
with pytest.raises(ValueError, match="Out of range float values are not JSON compliant"):
_ = client.put(
_BY_NAME_VALUE,
params={"name": asset.name, "key": "watermark"},
content=json.dumps({"value": bad_float}, allow_nan=True).encode(),
headers={"Content-Type": "application/json"},
)

def test_put_unknown_asset_returns_404(self, client: TestClient):
response = client.put(
_BY_NAME_VALUE, params={"name": "nonexistent", "key": "watermark"}, json={"value": "x"}
Expand Down Expand Up @@ -208,7 +255,7 @@ def test_put_creates_row(self, client: TestClient, asset: AssetModel, session: S
)
)
assert row is not None
assert row.value == "2026-04-29"
assert row.value == '"2026-04-29"'

def test_put_unknown_uri_returns_404(self, client: TestClient):
response = client.put(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# under the License.
from __future__ import annotations

import json
from datetime import datetime
from typing import TYPE_CHECKING
from uuid import uuid4
Expand Down Expand Up @@ -95,7 +96,37 @@ def test_put_creates_row(self, client: TestClient, create_task_instance: CreateT
)
)
assert row is not None
assert row.value == "spark_001"
# DB stores a json string
assert row.value == '"spark_001"'

def test_put_int_value_roundtrip(self, client: TestClient, create_task_instance: CreateTaskInstance):
ti = create_task_instance()

response = client.put(_api_url(ti.id, "retry_count"), json={"value": 3})

assert response.status_code == 204
assert client.get(_api_url(ti.id, "retry_count")).json() == {"value": 3}

def test_put_dict_value_roundtrip(self, client: TestClient, create_task_instance: CreateTaskInstance):
ti = create_task_instance()

response = client.put(
_api_url(ti.id, "poll_result"),
json={"value": {"status": "succeeded", "rows": 1234}},
)

assert response.status_code == 204
assert client.get(_api_url(ti.id, "poll_result")).json() == {
"value": {"status": "succeeded", "rows": 1234}
}

def test_put_list_value_roundtrip(self, client: TestClient, create_task_instance: CreateTaskInstance):
ti = create_task_instance()

response = client.put(_api_url(ti.id, "checkpoints"), json={"value": [1, 2, 3]})

assert response.status_code == 204
assert client.get(_api_url(ti.id, "checkpoints")).json() == {"value": [1, 2, 3]}

def test_put_with_expires_at_creates_row(
self, client: TestClient, create_task_instance: CreateTaskInstance, time_machine
Expand All @@ -122,7 +153,7 @@ def test_put_with_expires_at_creates_row(
)
)
assert row is not None
assert row.value == "spark_001"
assert row.value == '"spark_001"'
assert row.expires_at == datetime(2026, 5, 15, 12, 0, 0, tzinfo=pendulum.UTC)

def test_put_overwrites_existing(self, client: TestClient, create_task_instance: CreateTaskInstance):
Expand Down Expand Up @@ -155,6 +186,18 @@ def test_put_null_value_returns_422(self, client: TestClient, create_task_instan

assert response.status_code == 422

@pytest.mark.parametrize("bad_float", [float("nan"), float("inf"), float("-inf")])
def test_put_non_finite_float_returns_422(
self, client: TestClient, create_task_instance: CreateTaskInstance, bad_float: float
):
ti = create_task_instance()
with pytest.raises(ValueError, match="Out of range float values are not JSON compliant"):
_ = client.put(
_api_url(ti.id, "job_id"),
content=json.dumps({"value": bad_float}, allow_nan=True).encode(),
headers={"Content-Type": "application/json"},
)

def test_put_missing_ti_returns_404(self, client: TestClient):
response = client.put(_api_url(uuid4(), "job_id"), json={"value": "x"})

Expand Down
4 changes: 3 additions & 1 deletion shared/state/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@ classifiers = [
"Private :: Do Not Upload",
]

dependencies = []
dependencies = [
"pydantic>=2.11.0",
]

[dependency-groups]
dev = [
Expand Down
Loading
Loading