diff --git a/providers/edge3/src/airflow/providers/edge3/cli/api_client.py b/providers/edge3/src/airflow/providers/edge3/cli/api_client.py index b7ce2316bf61d..2ca54be34ef6e 100644 --- a/providers/edge3/src/airflow/providers/edge3/cli/api_client.py +++ b/providers/edge3/src/airflow/providers/edge3/cli/api_client.py @@ -94,8 +94,17 @@ def jwt_generator() -> JWTGenerator: network_errors=ClientConnectionError, timeouts=ServerTimeoutError, ) -async def _make_generic_request(method: str, rest_path: str, data: str | None = None) -> Any: - authorization = jwt_generator().generate({"method": rest_path}) +async def _make_generic_request( + method: str, rest_path: str, data: str | None = None, team_name: str | None = None +) -> Any: + # The JWT carries both the request method (binding the token to a specific + # endpoint) and the worker's team_name (so the server can enforce team + # isolation without trusting a body field). Pre-team-claim workers omit + # team_name; the server falls back to the body for backwards compatibility. 
+ claims: dict = {"method": rest_path} + if team_name is not None: + claims["team_name"] = team_name + authorization = jwt_generator().generate(claims) api_url = conf.get("edge", "api_url") content_type = {"Content-Type": "application/json"} if data else {} headers = { @@ -126,6 +135,7 @@ async def worker_register( WorkerStateBody( state=state, jobs_active=0, queues=queues, sysinfo=sysinfo, team_name=team_name ).model_dump_json(exclude_unset=True), + team_name=team_name, ) except ClientResponseError as e: if e.status == HTTPStatus.BAD_REQUEST: @@ -161,6 +171,7 @@ async def worker_set_state( maintenance_comments=maintenance_comments, team_name=team_name, ).model_dump_json(exclude_unset=True), + team_name=team_name, ) except ClientResponseError as e: if e.status == HTTPStatus.BAD_REQUEST: @@ -182,6 +193,7 @@ async def jobs_fetch( WorkerQueuesBody( queues=queues, free_concurrency=free_concurrency, team_name=team_name ).model_dump_json(exclude_unset=True), + team_name=team_name, ) if result: return EdgeJobFetched(**result) diff --git a/providers/edge3/src/airflow/providers/edge3/worker_api/auth.py b/providers/edge3/src/airflow/providers/edge3/worker_api/auth.py index 6e9ba46ba87cd..ef0e314005df4 100644 --- a/providers/edge3/src/airflow/providers/edge3/worker_api/auth.py +++ b/providers/edge3/src/airflow/providers/edge3/worker_api/auth.py @@ -60,8 +60,8 @@ def _forbidden_response(message: str): ) -def jwt_token_authorization(method: str, authorization: str): - """Check if the JWT token is correct.""" +def jwt_token_authorization(method: str, authorization: str) -> dict: + """Check if the JWT token is correct and return the validated payload.""" try: payload = jwt_validate(authorization) signed_method = payload.get("method") @@ -71,6 +71,7 @@ def jwt_token_authorization(method: str, authorization: str): f"signed method='{signed_method}' " f"called method='{method}'", ) + return payload except BadSignature: _forbidden_response("Bad Signature. 
Please use only the tokens provided by the API.") except InvalidAudienceError: @@ -91,20 +92,55 @@ def jwt_token_authorization(method: str, authorization: str): ) except Exception: _forbidden_response("Unable to authenticate API via token.") + # _forbidden_response always raises; this is unreachable but keeps type checkers happy. + return {} def jwt_token_authorization_rpc( body: JsonRpcRequestBase, authorization: str = Header(description="JWT Authorization Token") -): - """Check if the JWT token is correct for JSON PRC requests.""" - jwt_token_authorization(body.method, authorization) +) -> dict: + """Check if the JWT token is correct for JSON RPC requests; return the validated payload.""" + return jwt_token_authorization(body.method, authorization) def jwt_token_authorization_rest( request: Request, authorization: str = Header(description="JWT Authorization Token") -): - """Check if the JWT token is correct for REST API requests.""" +) -> dict: + """ + Check the JWT for a REST request; return the validated payload. + + Routes can capture the return value via ``Depends(jwt_token_authorization_rest)`` + to read claims (e.g. ``team_name``) that the server must trust over any + matching field in the request body. The payload is returned as a plain dict. + """ PREFIX = "/edge_worker/v1/" path = request.url.path method_path = path[path.find(PREFIX) + len(PREFIX) :] if PREFIX in path else path - jwt_token_authorization(method_path, authorization) + return jwt_token_authorization(method_path, authorization) + + +def assert_jwt_team_matches_body(jwt_payload: dict, body_team_name: str | None) -> str | None: + """ + Return the JWT-bound team_name; reject if the body claims a different team. + + The JWT is the source of truth for team membership: the worker signs its + team_name into the token, while the body is unsigned. 
Any team_name in the + request body is at most an echo of the local config; if it disagrees with + the JWT we treat it as an attempt to cross team boundaries and reject the + request with HTTP 403. + + Returns the JWT-bound team_name (which may be ``None`` for backwards-compat + workers whose JWTs predate the team_name claim — in that case the body's + team_name is used so older workers keep working). + """ + jwt_team = jwt_payload.get("team_name") + if jwt_team is None: + # Backwards-compat: pre-team-claim JWTs. Fall back to body's value so an + # in-flight upgrade does not lock workers out. + return body_team_name + if body_team_name is not None and body_team_name != jwt_team: + _forbidden_response( + "team_name in request body does not match JWT claim. " + f"jwt_team='{jwt_team}' body_team='{body_team_name}'", + ) + return jwt_team diff --git a/providers/edge3/src/airflow/providers/edge3/worker_api/routes/jobs.py b/providers/edge3/src/airflow/providers/edge3/worker_api/routes/jobs.py index 5b296c92f1d33..62bd326f1f7fd 100644 --- a/providers/edge3/src/airflow/providers/edge3/worker_api/routes/jobs.py +++ b/providers/edge3/src/airflow/providers/edge3/worker_api/routes/jobs.py @@ -28,7 +28,10 @@ from airflow.executors.workloads import ExecuteTask from airflow.providers.common.compat.sdk import Stats, timezone from airflow.providers.edge3.models.edge_job import EdgeJobModel -from airflow.providers.edge3.worker_api.auth import jwt_token_authorization_rest +from airflow.providers.edge3.worker_api.auth import ( + assert_jwt_team_matches_body, + jwt_token_authorization_rest, +) from airflow.providers.edge3.worker_api.datamodels import ( EdgeJobFetched, WorkerApiDocs, @@ -45,7 +48,6 @@ def parse_command(command: str) -> ExecuteTask: @jobs_router.post( "/fetch/{worker_name}", - dependencies=[Depends(jwt_token_authorization_rest)], responses=create_openapi_http_exception_doc( [ status.HTTP_400_BAD_REQUEST, @@ -63,8 +65,13 @@ def fetch( ), ], session: SessionDep, + 
jwt_payload: Annotated[dict, Depends(jwt_token_authorization_rest)], ) -> EdgeJobFetched | None: """Fetch a job to execute on the edge worker.""" + # Trust the JWT-bound team_name over any value in the body. Reject if the + # body's team_name disagrees with the JWT — that would be a cross-team + # request from a worker authenticated only for a different team. + team_name = assert_jwt_team_matches_body(jwt_payload, body.team_name) query = ( select(EdgeJobModel) .where( @@ -75,7 +82,7 @@ def fetch( ) if body.queues: query = query.where(EdgeJobModel.queue.in_(body.queues)) - query = query.where(EdgeJobModel.team_name == body.team_name) + query = query.where(EdgeJobModel.team_name == team_name) query = query.limit(1) query = query.with_for_update(skip_locked=True) job: EdgeJobModel | None = session.scalar(query) diff --git a/providers/edge3/src/airflow/providers/edge3/worker_api/routes/worker.py b/providers/edge3/src/airflow/providers/edge3/worker_api/routes/worker.py index 6ca8e794e9953..4f1acffd58a69 100644 --- a/providers/edge3/src/airflow/providers/edge3/worker_api/routes/worker.py +++ b/providers/edge3/src/airflow/providers/edge3/worker_api/routes/worker.py @@ -31,7 +31,10 @@ from airflow.providers.common.compat.sdk import Stats, conf, timezone from airflow.providers.edge3 import __version__ as edge_provider_version from airflow.providers.edge3.models.edge_worker import EdgeWorkerModel, EdgeWorkerState, set_metrics -from airflow.providers.edge3.worker_api.auth import jwt_token_authorization_rest +from airflow.providers.edge3.worker_api.auth import ( + assert_jwt_team_matches_body, + jwt_token_authorization_rest, +) from airflow.providers.edge3.worker_api.datamodels import ( WorkerQueueUpdateBody, WorkerRegistrationReturn, @@ -186,19 +189,22 @@ def redefine_maintenance_comments( return worker_maintenance_comment -@worker_router.post("/{worker_name}", dependencies=[Depends(jwt_token_authorization_rest)]) +@worker_router.post("/{worker_name}") def register( 
worker_name: Annotated[str, _worker_name_doc], body: Annotated[WorkerStateBody, _worker_state_doc], session: SessionDep, + jwt_payload: Annotated[dict, Depends(jwt_token_authorization_rest)], ) -> WorkerRegistrationReturn: """Register a new worker to the backend.""" versions_match = _assert_version(body.sysinfo) + # Trust the JWT-bound team_name; reject if the body claims a different team. + team_name = assert_jwt_team_matches_body(jwt_payload, body.team_name) query = select(EdgeWorkerModel).where(EdgeWorkerModel.worker_name == worker_name) worker: EdgeWorkerModel | None = session.scalar(query) if not worker: worker = EdgeWorkerModel( - worker_name=worker_name, state=body.state, queues=body.queues, team_name=body.team_name + worker_name=worker_name, state=body.state, queues=body.queues, team_name=team_name ) else: # Prevent duplicate workers unless the existing worker is in offline or unknown state @@ -220,18 +226,23 @@ def register( worker.queues = body.queues worker.sysinfo = body.sysinfo worker.last_update = timezone.utcnow() - worker.team_name = body.team_name + worker.team_name = team_name session.add(worker) return WorkerRegistrationReturn(last_update=worker.last_update, versions_match=versions_match) -@worker_router.patch("/{worker_name}", dependencies=[Depends(jwt_token_authorization_rest)]) +@worker_router.patch("/{worker_name}") def set_state( worker_name: Annotated[str, _worker_name_doc], body: Annotated[WorkerStateBody, _worker_state_doc], session: SessionDep, + jwt_payload: Annotated[dict, Depends(jwt_token_authorization_rest)], ) -> WorkerSetStateReturn: """Set state of worker and returns the current assigned queues.""" + # Trust the JWT-bound team_name; reject if the body claims a different team. + # The worker may be heart-beating with a stale local team_name config; the + # JWT is the authoritative declaration. 
+ assert_jwt_team_matches_body(jwt_payload, body.team_name) query = select(EdgeWorkerModel).where(EdgeWorkerModel.worker_name == worker_name) worker: EdgeWorkerModel | None = session.scalar(query) if not worker: diff --git a/providers/edge3/tests/unit/edge3/worker_api/routes/test_jobs.py b/providers/edge3/tests/unit/edge3/worker_api/routes/test_jobs.py index 22c63d12cead9..b49f447f0df06 100644 --- a/providers/edge3/tests/unit/edge3/worker_api/routes/test_jobs.py +++ b/providers/edge3/tests/unit/edge3/worker_api/routes/test_jobs.py @@ -153,7 +153,12 @@ def test_fetch_filters_by_team_name(self, session: Session): session.commit() body = WorkerQueuesBody(free_concurrency=1, queues=[QUEUE], team_name="team_a") - result = fetch("worker1", body, session) + result = fetch( + "worker1", + body, + session, + jwt_payload={"method": "jobs/fetch/worker1", "team_name": "team_a"}, + ) assert result is not None assert result.dag_id == "dag_a" assert result.task_id == "task_a" @@ -188,13 +193,45 @@ def test_fetch_without_team_name_returns_any_team(self, session: Session): session.add_all([job_team_a, job_no_team]) session.commit() + # New JWT-bound team check: a worker authenticated for team_b cannot + # request a team_a job by setting body.team_name. The server rejects + # the cross-team body with 403. Backwards-compat path (legacy worker + # whose JWT predates the team_name claim) still trusts the body. 
+ from fastapi import HTTPException + + body_cross = WorkerQueuesBody(free_concurrency=2, queues=[QUEUE], team_name="team_a") + with pytest.raises(HTTPException) as exc: + fetch( + "worker1", + body_cross, + session, + jwt_payload={"method": "jobs/fetch/worker1", "team_name": "team_b"}, + ) + assert exc.value.status_code == 403 + body1 = WorkerQueuesBody(free_concurrency=2, queues=[QUEUE], team_name="team_a") - result1 = fetch("worker1", body1, session) + result1 = fetch( + "worker1", + body1, + session, + jwt_payload={"method": "jobs/fetch/worker1", "team_name": "team_a"}, + ) assert result1 is not None body2 = WorkerQueuesBody(free_concurrency=2, queues=[QUEUE], team_name=None) - result2 = fetch("worker1", body2, session) + # No team in body or JWT — legacy worker, body wins via backcompat path. + result2 = fetch( + "worker1", + body2, + session, + jwt_payload={"method": "jobs/fetch/worker1"}, + ) assert result2 is not None - result3 = fetch("worker1", body2, session) + result3 = fetch( + "worker1", + body2, + session, + jwt_payload={"method": "jobs/fetch/worker1"}, + ) assert result3 is None fetched_dag_ids = {result1.dag_id, result2.dag_id} assert fetched_dag_ids == {"dag_a", "dag_b"} diff --git a/providers/edge3/tests/unit/edge3/worker_api/routes/test_worker.py b/providers/edge3/tests/unit/edge3/worker_api/routes/test_worker.py index 099ba7e882739..c49c258dff960 100644 --- a/providers/edge3/tests/unit/edge3/worker_api/routes/test_worker.py +++ b/providers/edge3/tests/unit/edge3/worker_api/routes/test_worker.py @@ -127,7 +127,12 @@ def test_register(self, session: Session, input_queues: list[str] | None, cli_wo queues=input_queues, sysinfo=self.MOCK_SYSINFO, ) - register("test_worker", body, session) + register( + "test_worker", + body, + session, + jwt_payload={"method": "worker/test_worker", "team_name": getattr(body, "team_name", None)}, + ) session.commit() worker: Sequence[EdgeWorkerModel] = session.scalars(select(EdgeWorkerModel)).all() @@ -147,7 +152,12 
@@ def test_register_with_team_name(self, session: Session, cli_worker: EdgeWorker) sysinfo=self.MOCK_SYSINFO, team_name="team_a", ) - register("test_worker", body, session) + register( + "test_worker", + body, + session, + jwt_payload={"method": "worker/test_worker", "team_name": getattr(body, "team_name", None)}, + ) session.commit() worker: Sequence[EdgeWorkerModel] = session.scalars(select(EdgeWorkerModel)).all() @@ -178,7 +188,12 @@ def test_register_same_name_different_team_rejects_when_active( team_name="team_b", ) with pytest.raises(HTTPException) as exc_info: - register("test_worker", body, session) + register( + "test_worker", + body, + session, + jwt_payload={"method": "worker/test_worker", "team_name": getattr(body, "team_name", None)}, + ) assert exc_info.value.status_code == 409 def test_register_same_name_different_team_reuses_when_offline( @@ -203,7 +218,12 @@ def test_register_same_name_different_team_reuses_when_offline( sysinfo=self.MOCK_SYSINFO, team_name="team_b", ) - register("test_worker", body, session) + register( + "test_worker", + body, + session, + jwt_payload={"method": "worker/test_worker", "team_name": getattr(body, "team_name", None)}, + ) session.commit() worker = session.execute( @@ -248,12 +268,25 @@ def test_register_duplicate_worker( if should_raise: with pytest.raises(HTTPException) as exc_info: - register("test_worker", body, session) + register( + "test_worker", + body, + session, + jwt_payload={ + "method": "worker/test_worker", + "team_name": getattr(body, "team_name", None), + }, + ) assert exc_info.value.status_code == 409 assert "already active" in str(exc_info.value.detail).lower() else: # Should succeed for offline/unknown states - register("test_worker", body, session) + register( + "test_worker", + body, + session, + jwt_payload={"method": "worker/test_worker", "team_name": getattr(body, "team_name", None)}, + ) session.commit() worker = session.execute( select(EdgeWorkerModel).where(EdgeWorkerModel.worker_name == 
"test_worker") @@ -354,7 +387,12 @@ def test_set_state(self, session: Session, cli_worker: EdgeWorker): queues=["default2"], sysinfo=self.MOCK_SYSINFO, ) - return_queues = set_state("test2_worker", body, session).queues + return_queues = set_state( + "test2_worker", + body, + session, + jwt_payload={"method": "worker/test2_worker", "team_name": getattr(body, "team_name", None)}, + ).queues worker: Sequence[EdgeWorkerModel] = session.scalars(select(EdgeWorkerModel)).all() assert len(worker) == 1 @@ -381,7 +419,12 @@ def test_set_state_returns_concurrency(self, session: Session, cli_worker: EdgeW queues=["default"], sysinfo=self.MOCK_SYSINFO, ) - result = set_state("test2_worker", body, session) + result = set_state( + "test2_worker", + body, + session, + jwt_payload={"method": "worker/test2_worker", "team_name": getattr(body, "team_name", None)}, + ) assert result.concurrency == 16 def test_set_state_returns_none_concurrency_when_not_overridden( @@ -403,7 +446,12 @@ def test_set_state_returns_none_concurrency_when_not_overridden( queues=["default"], sysinfo=self.MOCK_SYSINFO, ) - result = set_state("test2_worker", body, session) + result = set_state( + "test2_worker", + body, + session, + jwt_payload={"method": "worker/test2_worker", "team_name": getattr(body, "team_name", None)}, + ) assert result.concurrency is None def test_set_worker_concurrency(self, session: Session):