Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 14 additions & 2 deletions providers/edge3/src/airflow/providers/edge3/cli/api_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,17 @@ def jwt_generator() -> JWTGenerator:
network_errors=ClientConnectionError,
timeouts=ServerTimeoutError,
)
async def _make_generic_request(method: str, rest_path: str, data: str | None = None) -> Any:
authorization = jwt_generator().generate({"method": rest_path})
async def _make_generic_request(
method: str, rest_path: str, data: str | None = None, team_name: str | None = None
) -> Any:
# The JWT carries both the request method (binding the token to a specific
# endpoint) and the worker's team_name (so the server can enforce team
# isolation without trusting a body field). Pre-team-claim workers omit
# team_name; the server falls back to the body for backwards compatibility.
claims: dict = {"method": rest_path}
if team_name is not None:
claims["team_name"] = team_name
authorization = jwt_generator().generate(claims)
api_url = conf.get("edge", "api_url")
content_type = {"Content-Type": "application/json"} if data else {}
headers = {
Expand Down Expand Up @@ -126,6 +135,7 @@ async def worker_register(
WorkerStateBody(
state=state, jobs_active=0, queues=queues, sysinfo=sysinfo, team_name=team_name
).model_dump_json(exclude_unset=True),
team_name=team_name,
)
except ClientResponseError as e:
if e.status == HTTPStatus.BAD_REQUEST:
Expand Down Expand Up @@ -161,6 +171,7 @@ async def worker_set_state(
maintenance_comments=maintenance_comments,
team_name=team_name,
).model_dump_json(exclude_unset=True),
team_name=team_name,
)
except ClientResponseError as e:
if e.status == HTTPStatus.BAD_REQUEST:
Expand All @@ -182,6 +193,7 @@ async def jobs_fetch(
WorkerQueuesBody(
queues=queues, free_concurrency=free_concurrency, team_name=team_name
).model_dump_json(exclude_unset=True),
team_name=team_name,
)
if result:
return EdgeJobFetched(**result)
Expand Down
52 changes: 44 additions & 8 deletions providers/edge3/src/airflow/providers/edge3/worker_api/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,8 @@ def _forbidden_response(message: str):
)


def jwt_token_authorization(method: str, authorization: str):
"""Check if the JWT token is correct."""
def jwt_token_authorization(method: str, authorization: str) -> dict:
"""Check if the JWT token is correct and return the validated payload."""
try:
payload = jwt_validate(authorization)
signed_method = payload.get("method")
Expand All @@ -71,6 +71,7 @@ def jwt_token_authorization(method: str, authorization: str):
f"signed method='{signed_method}' "
f"called method='{method}'",
)
return payload
except BadSignature:
_forbidden_response("Bad Signature. Please use only the tokens provided by the API.")
except InvalidAudienceError:
Expand All @@ -91,20 +92,55 @@ def jwt_token_authorization(method: str, authorization: str):
)
except Exception:
_forbidden_response("Unable to authenticate API via token.")
# _forbidden_response always raises; this is unreachable but keeps type checkers happy.
return {}


def jwt_token_authorization_rpc(
body: JsonRpcRequestBase, authorization: str = Header(description="JWT Authorization Token")
):
"""Check if the JWT token is correct for JSON PRC requests."""
jwt_token_authorization(body.method, authorization)
) -> dict:
"""Check if the JWT token is correct for JSON PRC requests; return the validated payload."""
return jwt_token_authorization(body.method, authorization)


def jwt_token_authorization_rest(
request: Request, authorization: str = Header(description="JWT Authorization Token")
):
"""Check if the JWT token is correct for REST API requests."""
) -> dict:
"""
Check the JWT for a REST request; return the validated payload.

Routes can capture the return value via ``Depends(jwt_token_authorization_rest)``
to read claims (e.g. ``team_name``) that the server must trust over any
matching field in the request body. The payload is returned as a plain dict.
"""
PREFIX = "/edge_worker/v1/"
path = request.url.path
method_path = path[path.find(PREFIX) + len(PREFIX) :] if PREFIX in path else path
jwt_token_authorization(method_path, authorization)
return jwt_token_authorization(method_path, authorization)


def assert_jwt_team_matches_body(jwt_payload: dict, body_team_name: str | None) -> str | None:
"""
Return the JWT-bound team_name; reject if the body claims a different team.

The JWT is the source of truth for team membership: it was issued by the
central site to a specific worker for a specific team. Any team_name in the
request body is at most an echo of the local config; if it disagrees with
the JWT we treat it as an attempt to cross team boundaries and reject the
request with HTTP 403.

Returns the JWT-bound team_name (which may be ``None`` for backwards-compat
workers whose JWTs predate the team_name claim — in that case the body's
team_name is used so older workers keep working).
"""
jwt_team = jwt_payload.get("team_name")
if jwt_team is None:
# Backwards-compat: pre-team-claim JWTs. Fall back to body's value so an
# in-flight upgrade does not lock workers out.
return body_team_name
if body_team_name is not None and body_team_name != jwt_team:
_forbidden_response(
"team_name in request body does not match JWT claim. "
f"jwt_team='{jwt_team}' body_team='{body_team_name}'",
)
return jwt_team
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,10 @@
from airflow.executors.workloads import ExecuteTask
from airflow.providers.common.compat.sdk import Stats, timezone
from airflow.providers.edge3.models.edge_job import EdgeJobModel
from airflow.providers.edge3.worker_api.auth import jwt_token_authorization_rest
from airflow.providers.edge3.worker_api.auth import (
assert_jwt_team_matches_body,
jwt_token_authorization_rest,
)
from airflow.providers.edge3.worker_api.datamodels import (
EdgeJobFetched,
WorkerApiDocs,
Expand All @@ -45,7 +48,6 @@ def parse_command(command: str) -> ExecuteTask:

@jobs_router.post(
"/fetch/{worker_name}",
dependencies=[Depends(jwt_token_authorization_rest)],
responses=create_openapi_http_exception_doc(
[
status.HTTP_400_BAD_REQUEST,
Expand All @@ -63,8 +65,13 @@ def fetch(
),
],
session: SessionDep,
jwt_payload: Annotated[dict, Depends(jwt_token_authorization_rest)],
) -> EdgeJobFetched | None:
"""Fetch a job to execute on the edge worker."""
# Trust the JWT-bound team_name over any value in the body. Reject if the
# body's team_name disagrees with the JWT — that would be a cross-team
# request from a worker authenticated only for a different team.
team_name = assert_jwt_team_matches_body(jwt_payload, body.team_name)
query = (
select(EdgeJobModel)
.where(
Expand All @@ -75,7 +82,7 @@ def fetch(
)
if body.queues:
query = query.where(EdgeJobModel.queue.in_(body.queues))
query = query.where(EdgeJobModel.team_name == body.team_name)
query = query.where(EdgeJobModel.team_name == team_name)
query = query.limit(1)
query = query.with_for_update(skip_locked=True)
job: EdgeJobModel | None = session.scalar(query)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,10 @@
from airflow.providers.common.compat.sdk import Stats, conf, timezone
from airflow.providers.edge3 import __version__ as edge_provider_version
from airflow.providers.edge3.models.edge_worker import EdgeWorkerModel, EdgeWorkerState, set_metrics
from airflow.providers.edge3.worker_api.auth import jwt_token_authorization_rest
from airflow.providers.edge3.worker_api.auth import (
assert_jwt_team_matches_body,
jwt_token_authorization_rest,
)
from airflow.providers.edge3.worker_api.datamodels import (
WorkerQueueUpdateBody,
WorkerRegistrationReturn,
Expand Down Expand Up @@ -186,19 +189,22 @@ def redefine_maintenance_comments(
return worker_maintenance_comment


@worker_router.post("/{worker_name}", dependencies=[Depends(jwt_token_authorization_rest)])
@worker_router.post("/{worker_name}")
def register(
worker_name: Annotated[str, _worker_name_doc],
body: Annotated[WorkerStateBody, _worker_state_doc],
session: SessionDep,
jwt_payload: Annotated[dict, Depends(jwt_token_authorization_rest)],
) -> WorkerRegistrationReturn:
"""Register a new worker to the backend."""
versions_match = _assert_version(body.sysinfo)
# Trust the JWT-bound team_name; reject if the body claims a different team.
team_name = assert_jwt_team_matches_body(jwt_payload, body.team_name)
query = select(EdgeWorkerModel).where(EdgeWorkerModel.worker_name == worker_name)
worker: EdgeWorkerModel | None = session.scalar(query)
if not worker:
worker = EdgeWorkerModel(
worker_name=worker_name, state=body.state, queues=body.queues, team_name=body.team_name
worker_name=worker_name, state=body.state, queues=body.queues, team_name=team_name
)
else:
# Prevent duplicate workers unless the existing worker is in offline or unknown state
Expand All @@ -220,18 +226,23 @@ def register(
worker.queues = body.queues
worker.sysinfo = body.sysinfo
worker.last_update = timezone.utcnow()
worker.team_name = body.team_name
worker.team_name = team_name
session.add(worker)
return WorkerRegistrationReturn(last_update=worker.last_update, versions_match=versions_match)


@worker_router.patch("/{worker_name}", dependencies=[Depends(jwt_token_authorization_rest)])
@worker_router.patch("/{worker_name}")
def set_state(
worker_name: Annotated[str, _worker_name_doc],
body: Annotated[WorkerStateBody, _worker_state_doc],
session: SessionDep,
jwt_payload: Annotated[dict, Depends(jwt_token_authorization_rest)],
) -> WorkerSetStateReturn:
"""Set state of worker and returns the current assigned queues."""
# Trust the JWT-bound team_name; reject if the body claims a different team.
# The worker may be heart-beating with a stale local team_name config; the
# JWT is the authoritative declaration.
assert_jwt_team_matches_body(jwt_payload, body.team_name)
query = select(EdgeWorkerModel).where(EdgeWorkerModel.worker_name == worker_name)
worker: EdgeWorkerModel | None = session.scalar(query)
if not worker:
Expand Down
45 changes: 41 additions & 4 deletions providers/edge3/tests/unit/edge3/worker_api/routes/test_jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,12 @@ def test_fetch_filters_by_team_name(self, session: Session):
session.commit()

body = WorkerQueuesBody(free_concurrency=1, queues=[QUEUE], team_name="team_a")
result = fetch("worker1", body, session)
result = fetch(
"worker1",
body,
session,
jwt_payload={"method": "jobs/fetch/worker1", "team_name": "team_a"},
)
assert result is not None
assert result.dag_id == "dag_a"
assert result.task_id == "task_a"
Expand Down Expand Up @@ -188,13 +193,45 @@ def test_fetch_without_team_name_returns_any_team(self, session: Session):
session.add_all([job_team_a, job_no_team])
session.commit()

# New JWT-bound team check: a worker authenticated for team_a cannot
# request a team_b job by setting body.team_name. The server rejects
# the cross-team body with 403. Backwards-compat path (legacy worker
# whose JWT predates the team_name claim) still trusts the body.
from fastapi import HTTPException

body_cross = WorkerQueuesBody(free_concurrency=2, queues=[QUEUE], team_name="team_a")
with pytest.raises(HTTPException) as exc:
fetch(
"worker1",
body_cross,
session,
jwt_payload={"method": "jobs/fetch/worker1", "team_name": "team_b"},
)
assert exc.value.status_code == 403

body1 = WorkerQueuesBody(free_concurrency=2, queues=[QUEUE], team_name="team_a")
result1 = fetch("worker1", body1, session)
result1 = fetch(
"worker1",
body1,
session,
jwt_payload={"method": "jobs/fetch/worker1", "team_name": "team_a"},
)
assert result1 is not None
body2 = WorkerQueuesBody(free_concurrency=2, queues=[QUEUE], team_name=None)
result2 = fetch("worker1", body2, session)
# No team in body or JWT — legacy worker, body wins via backcompat path.
result2 = fetch(
"worker1",
body2,
session,
jwt_payload={"method": "jobs/fetch/worker1"},
)
assert result2 is not None
result3 = fetch("worker1", body2, session)
result3 = fetch(
"worker1",
body2,
session,
jwt_payload={"method": "jobs/fetch/worker1"},
)
assert result3 is None
fetched_dag_ids = {result1.dag_id, result2.dag_id}
assert fetched_dag_ids == {"dag_a", "dag_b"}
Loading
Loading