diff --git a/.github/workflows/phase1-ci-and-release.yml b/.github/workflows/phase1-ci-and-release.yml index 5c78aa2..886e2d5 100644 --- a/.github/workflows/phase1-ci-and-release.yml +++ b/.github/workflows/phase1-ci-and-release.yml @@ -26,7 +26,9 @@ jobs: python-version: "3.11" - name: Install dependencies - run: python -m pip install --upgrade pip pre-commit pytest + run: | + python -m pip install --upgrade pip pre-commit pytest + python -m pip install -e predicate_contracts -e predicate_authority - name: Verify package release order run: python scripts/verify_release_order.py diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index bc917e0..25d30f4 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -24,7 +24,9 @@ jobs: python-version: ${{ matrix.python-version }} - name: Install test dependencies - run: python -m pip install --upgrade pip pytest + run: | + python -m pip install --upgrade pip pytest + python -m pip install -e predicate_contracts -e predicate_authority - name: Run tests run: python -m pytest -q diff --git a/docs/authorityd-operations.md b/docs/authorityd-operations.md index f991ce9..462533f 100644 --- a/docs/authorityd-operations.md +++ b/docs/authorityd-operations.md @@ -63,6 +63,24 @@ PYTHONPATH=. predicate-authorityd \ --control-plane-fail-open ``` +### Signing key safety note (required until mandate `v2` claims) + +Until mandate `v2` introduces explicit `iss`/`aud` claims and asymmetric signing defaults, +each deployment instance must use a unique signing key to reduce cross-instance replay risk. + +Recommended startup pattern: + +```bash +export PREDICATE_AUTHORITY_SIGNING_KEY="" + +PYTHONPATH=. predicate-authorityd \ + --host 127.0.0.1 \ + --port 8787 \ + --mode local_only \ + --policy-file examples/authorityd/policy.json \ + --mandate-signing-key-env PREDICATE_AUTHORITY_SIGNING_KEY +``` + When enabled, daemon bootstrap auto-attaches `ControlPlaneTraceEmitter` so each authority decision pushes: diff --git a/examples/authority_client_local_policy.yaml b/examples/authority_client_local_policy.yaml new file mode 100644 index 0000000..38d2435 --- /dev/null +++ b/examples/authority_client_local_policy.yaml @@ -0,0 +1,9 @@ +rules: + - name: allow-orders-create + effect: allow + principals: + - agent:checkout + actions: + - http.post + resources: + - https://api.vendor.com/orders diff --git a/examples/authority_client_local_yaml.py b/examples/authority_client_local_yaml.py new file mode 100644 index 0000000..44d945a --- /dev/null +++ b/examples/authority_client_local_yaml.py @@ -0,0 +1,79 @@ +from __future__ import annotations + +import argparse +import json +import sys +from pathlib import Path + + +def _ensure_repo_root_on_syspath() -> None: + repo_root = Path(__file__).resolve().parents[1] + root = str(repo_root) + if root not in sys.path: + sys.path.insert(0, root) + + +def _build_request() -> object: + _ensure_repo_root_on_syspath() + from predicate_contracts import ( # pylint: disable=import-error + ActionRequest, + ActionSpec, + PrincipalRef, + StateEvidence, + VerificationEvidence, + ) + + return ActionRequest( + principal=PrincipalRef(principal_id="agent:checkout"), + action_spec=ActionSpec( + action="http.post", + resource="https://api.vendor.com/orders", + intent="submit customer order", + ), + state_evidence=StateEvidence(source="sdk-python", state_hash="sha256:example"), + verification_evidence=VerificationEvidence(), + ) + + +def run(policy_file: str, secret_key: str) -> dict[str, object]: + _ensure_repo_root_on_syspath() + from predicate_authority import AuthorityClient # pylint: disable=import-error + + context = AuthorityClient.from_policy_file( + policy_file=policy_file, + secret_key=secret_key, + ttl_seconds=120, + ) + client = context.client + decision = client.authorize(_build_request()) + token_verified = False + if decision.mandate is not None: + token_verified = client.verify_token(decision.mandate.token) is not None + return { + "policy_file": policy_file, + "allowed": decision.allowed, + "reason": decision.reason.value, + "token_issued": decision.mandate is not None, + "token_verified": token_verified, + } + + +def main() -> None: + parser = argparse.ArgumentParser(description="Local AuthorityClient example using YAML policy.") + parser.add_argument( + "--policy-file", + default="examples/authority_client_local_policy.yaml", + help="Path to local YAML policy file.", + ) + parser.add_argument( + "--secret-key", + default="dev-secret", + help="Signing key used for local mandates.", + ) + args = parser.parse_args() + payload = run(policy_file=args.policy_file, secret_key=args.secret_key) + print(json.dumps(payload, indent=2, sort_keys=True)) + + +if __name__ == "__main__": + main() diff --git a/examples/delegation/delegate.py b/examples/delegation/delegate.py new file mode 100644 index 0000000..5dc4841 --- /dev/null +++ b/examples/delegation/delegate.py @@ -0,0 +1,155 @@ +from __future__ import annotations + +import argparse +import importlib.util +import json +import sys +from collections.abc import Callable +from pathlib import Path +from typing import Any, cast + + +def _ensure_repo_root_on_syspath() -> None: + repo_root = Path(__file__).resolve().parents[2] + root = str(repo_root) + if root not in sys.path: + sys.path.insert(0, root) + + +def _build_request() -> object: + _ensure_repo_root_on_syspath() + from predicate_contracts import ( # pylint: disable=import-error + ActionRequest, + ActionSpec, + PrincipalRef, + StateEvidence, + VerificationEvidence, + ) + + return ActionRequest( + principal=PrincipalRef(principal_id="agent:root"), + action_spec=ActionSpec( + action="task.delegate", + resource="worker:queue/main", + intent="delegate processing to worker agent", + ), + state_evidence=StateEvidence(source="delegate.py", state_hash="sha256:delegate"), + verification_evidence=VerificationEvidence(), + ) + + +def _run_worker( + worker_script: str, + token: str, + secret_key: str, + revocation_file: str, + policy_file: str, +) -> dict[str, object]: + worker_run = _load_worker_runner(worker_script) + payload = worker_run( + token=token, + secret_key=secret_key, + revocation_file=revocation_file, + policy_file=policy_file, + ) + if not isinstance(payload, dict): + raise RuntimeError("worker payload must be an object") + return cast(dict[str, object], payload) + + +def _load_worker_runner(worker_script: str) -> Callable[..., Any]: + worker_path = Path(worker_script).resolve() + spec = importlib.util.spec_from_file_location("delegation_worker_runtime", worker_path) + if spec is None or spec.loader is None: + raise RuntimeError(f"Unable to load worker module from path: {worker_script}") + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + run_callable = getattr(module, "run", None) + if not callable(run_callable): + raise RuntimeError("Worker module must expose callable run(...) function.") + return cast(Callable[..., Any], run_callable) + + +def run( + policy_file: str, + worker_script: str, + revocation_file: str, + secret_key: str, +) -> dict[str, object]: + _ensure_repo_root_on_syspath() + from predicate_authority import AuthorityClient # pylint: disable=import-error + + context = AuthorityClient.from_policy_file( + policy_file=policy_file, + secret_key=secret_key, + ttl_seconds=120, + ) + client = context.client + + decision = client.authorize(_build_request()) + if not decision.allowed or decision.mandate is None: + return { + "root_allowed": False, + "worker_allowed_before_revoke": False, + "worker_allowed_after_revoke": False, + } + + token = decision.mandate.token + Path(revocation_file).write_text( + json.dumps({"revoked_principal_ids": []}, indent=2), + encoding="utf-8", + ) + before = _run_worker(worker_script, token, secret_key, revocation_file, policy_file) + + client.revoke_principal("agent:root") + Path(revocation_file).write_text( + json.dumps({"revoked_principal_ids": ["agent:root"]}, indent=2), + encoding="utf-8", + ) + after = _run_worker(worker_script, token, secret_key, revocation_file, policy_file) + + return { + "root_allowed": True, + "root_delegation_depth": decision.mandate.claims.delegation_depth, + "root_chain_hash": decision.mandate.claims.delegation_chain_hash, + "worker_allowed_before_revoke": bool(before.get("allowed", False)), + "worker_allowed_after_revoke": bool(after.get("allowed", False)), + "worker_delegation_depth_before_revoke": before.get("delegation_depth"), + "worker_chain_verified_before_revoke": bool(before.get("chain_verified", False)), + "before_reason": before.get("reason"), + "after_reason": after.get("reason"), + } + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Delegation simulation for local authority runtime." + ) + parser.add_argument( + "--policy-file", + default="examples/delegation/policy.yaml", + help="Path to policy file for the root delegating agent.", + ) + parser.add_argument( + "--worker-script", + default="examples/delegation/worker.py", + help="Path to worker.py.", + ) + parser.add_argument( + "--revocation-file", + default="examples/delegation/revocations.json", + help="Path to revocation state shared with worker.", + ) + parser.add_argument("--secret-key", default="dev-secret") + args = parser.parse_args() + payload = run( + policy_file=args.policy_file, + worker_script=args.worker_script, + revocation_file=args.revocation_file, + secret_key=args.secret_key, + ) + print(json.dumps(payload, indent=2, sort_keys=True)) + + +if __name__ == "__main__": + main() diff --git a/examples/delegation/policy.yaml b/examples/delegation/policy.yaml new file mode 100644 index 0000000..91d3bfa --- /dev/null +++ b/examples/delegation/policy.yaml @@ -0,0 +1,19 @@ +rules: + - name: allow-delegate-task + effect: allow + principals: + - agent:root + actions: + - task.delegate + resources: + - worker:queue/* + max_delegation_depth: 1 + - name: allow-worker-execute + effect: allow + principals: + - agent:worker + actions: + - job.execute + resources: + - queue://jobs/* + max_delegation_depth: 1 diff --git a/examples/delegation/worker.py b/examples/delegation/worker.py new file mode 100644 index 0000000..269d649 --- /dev/null +++ b/examples/delegation/worker.py @@ -0,0 +1,120 @@ +from __future__ import annotations + +import argparse +import json +import sys +from pathlib import Path + + +def _ensure_repo_root_on_syspath() -> None: + repo_root = Path(__file__).resolve().parents[2] + root = str(repo_root) + if root not in sys.path: + sys.path.insert(0, root) + + +def _load_revocations(path: str) -> list[str]: + file_path = Path(path) + if not file_path.exists(): + return [] + payload = json.loads(file_path.read_text(encoding="utf-8")) + if not isinstance(payload, dict): + return [] + revoked = payload.get("revoked_principal_ids", []) + if not isinstance(revoked, list): + return [] + return [str(item) for item in revoked] + + +def _build_worker_request() -> object: + _ensure_repo_root_on_syspath() + from predicate_contracts import ( # pylint: disable=import-error + ActionRequest, + ActionSpec, + PrincipalRef, + StateEvidence, + VerificationEvidence, + ) + + return ActionRequest( + principal=PrincipalRef(principal_id="agent:worker"), + action_spec=ActionSpec( + action="job.execute", + resource="queue://jobs/high-priority", + intent="execute delegated job", + ), + state_evidence=StateEvidence(source="worker.py", state_hash="sha256:worker"), + verification_evidence=VerificationEvidence(), + ) + + +def run( + token: str, + secret_key: str, + revocation_file: str, + policy_file: str, +) -> dict[str, object]: + _ensure_repo_root_on_syspath() + from predicate_authority import AuthorityClient # pylint: disable=import-error + + context = AuthorityClient.from_policy_file( + policy_file=policy_file, + secret_key=secret_key, + ttl_seconds=120, + ) + client = context.client + revoked_principal_ids = _load_revocations(revocation_file) + for principal_id in revoked_principal_ids: + client.revoke_principal(principal_id) + + parent_mandate = client.verify_token(token) + if parent_mandate is None: + if "agent:root" in revoked_principal_ids: + return {"allowed": False, "reason": "revoked_root_token"} + return {"allowed": False, "reason": "invalid_or_expired_token"} + + decision = client.authorize( + _build_worker_request(), + parent_mandate=parent_mandate, + ) + if not decision.allowed or decision.mandate is None: + denied_reason = ( + "revoked_root_token" + if parent_mandate.claims.principal_id in revoked_principal_ids + else "denied" + ) + return {"allowed": False, "reason": denied_reason} + + chain_ok = client.verify_delegation_chain( + token=decision.mandate.token, + parent_token=token, + ) + return { + "allowed": True, + "reason": "ok", + "principal_id": decision.mandate.claims.principal_id, + "delegated_by": decision.mandate.claims.delegated_by, + "delegation_depth": decision.mandate.claims.delegation_depth, + "chain_hash": decision.mandate.claims.delegation_chain_hash, + "chain_verified": chain_ok, + } + + +def main() -> None: + parser = argparse.ArgumentParser(description="Worker process for delegation simulation.") + parser.add_argument("--token", required=True) + parser.add_argument("--secret-key", default="dev-secret") + parser.add_argument("--revocation-file", required=True) + parser.add_argument("--policy-file", required=True) + args = parser.parse_args() + payload = run( + token=args.token, + secret_key=args.secret_key, + revocation_file=args.revocation_file, + policy_file=args.policy_file, + ) + print(json.dumps(payload, indent=2, sort_keys=True)) + + +if __name__ == "__main__": + main() diff --git a/predicate_authority/__init__.py b/predicate_authority/__init__.py index 64800e9..452c683 100644 --- a/predicate_authority/__init__.py +++ b/predicate_authority/__init__.py @@ -9,6 +9,7 @@ OIDCIdentityBridge, TokenExchangeResult, ) +from predicate_authority.client import AuthorityClient, LocalAuthorizationContext from predicate_authority.control_plane import ( AuditEventEnvelope, ControlPlaneClient, @@ -46,6 +47,7 @@ "ActionExecutionResult", "ActionGuard", "AuthorityMode", + "AuthorityClient", "AuthorizationDeniedError", "AuditEventEnvelope", "ControlPlaneClient", @@ -61,6 +63,7 @@ "LocalIdPBridge", "LocalIdPBridgeConfig", "LocalCredentialStore", + "LocalAuthorizationContext", "LocalIdentityRegistry", "LocalIdentityRegistryStats", "LocalLedgerQueueEmitter", diff --git a/predicate_authority/bridge.py b/predicate_authority/bridge.py index 0786666..b934561 100644 --- a/predicate_authority/bridge.py +++ b/predicate_authority/bridge.py @@ -112,7 +112,10 @@ def refresh_token( class EntraIdentityBridge(OIDCIdentityBridge): - """Microsoft Entra adapter built on generic OIDC behavior.""" + """Microsoft Entra adapter built on generic OIDC behavior. + + Phase 2 keeps this as a deterministic local stand-in for real IdP token exchange. + """ def __init__(self, config: EntraBridgeConfig) -> None: oidc_config = OIDCBridgeConfig( diff --git a/predicate_authority/client.py b/predicate_authority/client.py new file mode 100644 index 0000000..7909b16 --- /dev/null +++ b/predicate_authority/client.py @@ -0,0 +1,146 @@ +from __future__ import annotations + +import os +from dataclasses import dataclass + +from predicate_authority.guard import ActionGuard +from predicate_authority.mandate import LocalMandateSigner +from predicate_authority.policy import PolicyEngine +from predicate_authority.policy_source import PolicyFileSource +from predicate_authority.proof import InMemoryProofLedger +from predicate_authority.revocation import LocalRevocationCache +from predicate_contracts import ( + ActionRequest, + AuthorizationDecision, + AuthorizationReason, + SignedMandate, +) + + +@dataclass(frozen=True) +class LocalAuthorizationContext: + client: AuthorityClient + policy_file: str + + +class AuthorityClient: + """Lightweight local authority client for pre-action authorization flows.""" + + def __init__( + self, + action_guard: ActionGuard, + mandate_signer: LocalMandateSigner, + revocation_cache: LocalRevocationCache | None = None, + ) -> None: + self._action_guard = action_guard + self._mandate_signer = mandate_signer + self._revocation_cache = revocation_cache or LocalRevocationCache() + + @classmethod + def from_policy_file( + cls, + policy_file: str, + secret_key: str, + ttl_seconds: int = 300, + ) -> LocalAuthorizationContext: + rules, global_max_delegation_depth = PolicyFileSource(policy_file).load_policy() + policy_engine = PolicyEngine( + rules=rules, + global_max_delegation_depth=global_max_delegation_depth, + ) + proof_ledger = InMemoryProofLedger() + mandate_signer = LocalMandateSigner(secret_key=secret_key, ttl_seconds=ttl_seconds) + action_guard = ActionGuard( + policy_engine=policy_engine, + mandate_signer=mandate_signer, + proof_ledger=proof_ledger, + ) + return LocalAuthorizationContext( + client=cls( + action_guard=action_guard, + mandate_signer=mandate_signer, + revocation_cache=LocalRevocationCache(), + ), + policy_file=policy_file, + ) + + @classmethod + def from_env(cls) -> LocalAuthorizationContext: + policy_file = os.getenv("PREDICATE_AUTHORITY_POLICY_FILE") + secret_key = os.getenv("PREDICATE_AUTHORITY_SIGNING_KEY") + ttl_seconds_raw = os.getenv("PREDICATE_AUTHORITY_MANDATE_TTL_SECONDS", "300") + if policy_file is None or policy_file.strip() == "": + raise RuntimeError("PREDICATE_AUTHORITY_POLICY_FILE is required.") + if secret_key is None or secret_key.strip() == "": + raise RuntimeError("PREDICATE_AUTHORITY_SIGNING_KEY is required.") + try: + ttl_seconds = int(ttl_seconds_raw) + except ValueError as exc: + raise RuntimeError( + "PREDICATE_AUTHORITY_MANDATE_TTL_SECONDS must be an integer." + ) from exc + return cls.from_policy_file( + policy_file=policy_file, + secret_key=secret_key, + ttl_seconds=ttl_seconds, + ) + + def authorize( + self, + request: ActionRequest, + parent_mandate: SignedMandate | None = None, + ) -> AuthorizationDecision: + if self._revocation_cache.is_request_revoked(request): + return AuthorizationDecision( + allowed=False, + reason=AuthorizationReason.INVALID_MANDATE, + violated_rule="revocation_cache", + ) + if parent_mandate is not None and self._revocation_cache.is_mandate_revoked(parent_mandate): + return AuthorizationDecision( + allowed=False, + reason=AuthorizationReason.INVALID_MANDATE, + violated_rule="revocation_cache", + ) + decision = self._action_guard.authorize(request, parent_mandate=parent_mandate) + if ( + decision.allowed + and decision.mandate is not None + and self._revocation_cache.is_mandate_revoked(decision.mandate) + ): + return AuthorizationDecision( + allowed=False, + reason=AuthorizationReason.INVALID_MANDATE, + violated_rule="revocation_cache", + ) + return decision + + def verify_token(self, token: str) -> SignedMandate | None: + mandate = self._mandate_signer.verify(token) + if mandate is None: + return None + if self._revocation_cache.is_mandate_revoked(mandate): + return None + return mandate + + def verify_delegation_chain( + self, + token: str, + parent_token: str | None = None, + ) -> bool: + mandate = self.verify_token(token) + if mandate is None: + return False + parent_mandate = self.verify_token(parent_token) if parent_token is not None else None + if parent_token is not None and parent_mandate is None: + return False + return self._mandate_signer.verify_delegation( + mandate=mandate, + parent_mandate=parent_mandate, + ) + + def revoke_principal(self, principal_id: str) -> None: + self._revocation_cache.revoke_principal(principal_id) + + def revoke_mandate(self, mandate_id: str) -> None: + self._revocation_cache.revoke_mandate_id(mandate_id) diff --git a/predicate_authority/control_plane.py b/predicate_authority/control_plane.py index 9b04909..36d41ea 100644 --- a/predicate_authority/control_plane.py +++ b/predicate_authority/control_plane.py @@ -9,7 +9,7 @@ from datetime import datetime, timezone from urllib.parse import urlsplit -from predicate_contracts import ProofEvent, TraceEmitter +from predicate_contracts import ProofEvent @dataclass(frozen=True) @@ -138,7 +138,7 @@ def _new_connection(self) -> http.client.HTTPConnection: @dataclass -class ControlPlaneTraceEmitter(TraceEmitter): +class ControlPlaneTraceEmitter: client: ControlPlaneClient trace_id: str | None = None emit_usage_credits: bool = True @@ -183,7 +183,8 @@ def _send_audit_event(self, audit_event: AuditEventEnvelope) -> None: except Exception as exc: self.audit_push_failure_count += 1 self.last_push_error = str(exc) - raise + if not self.client.config.fail_open: + raise def _send_usage_record(self, usage: UsageCreditRecord) -> None: try: @@ -197,4 +198,5 @@ def _send_usage_record(self, usage: UsageCreditRecord) -> None: except Exception as exc: self.usage_push_failure_count += 1 self.last_push_error = str(exc) - raise + if not self.client.config.fail_open: + raise diff --git a/predicate_authority/daemon.py b/predicate_authority/daemon.py index 405146f..648a083 100644 --- a/predicate_authority/daemon.py +++ b/predicate_authority/daemon.py @@ -53,6 +53,7 @@ class DaemonConfig: host: str = "127.0.0.1" port: int = 8787 policy_poll_interval_s: float = 2.0 + max_request_body_bytes: int = 1_048_576 @dataclass(frozen=True) @@ -167,6 +168,7 @@ def do_POST(self) -> None: # noqa: N802 "/policy/reload": self._handle_policy_reload, "/revoke/principal": self._handle_revoke_principal, "/revoke/intent": self._handle_revoke_intent, + "/revoke/mandate": self._handle_revoke_mandate, "/identity/task": self._handle_identity_task, "/identity/revoke": self._handle_identity_revoke, "/ledger/flush-ack": self._handle_ledger_flush_ack, @@ -201,6 +203,15 @@ def _handle_revoke_intent(self) -> None: self.server.daemon_ref.revoke_intent(intent_hash.strip()) # type: ignore[attr-defined] self._send_json(200, {"ok": True, "intent_hash": intent_hash.strip()}) + def _handle_revoke_mandate(self) -> None: + payload = self._read_json_body() + mandate_id = payload.get("mandate_id") + if not isinstance(mandate_id, str) or mandate_id.strip() == "": + self._send_json(400, {"error": "mandate_id is required"}) + return + self.server.daemon_ref.revoke_mandate(mandate_id.strip()) # type: ignore[attr-defined] + self._send_json(200, {"ok": True, "mandate_id": mandate_id.strip()}) + def _handle_identity_task(self) -> None: payload = self._read_json_body() principal_id = payload.get("principal_id") @@ -273,6 +284,9 @@ def _read_json_body(self) -> dict[str, Any]: return {} if content_length <= 0: return {} + max_body = self.server.daemon_ref.max_request_body_bytes() # type: ignore[attr-defined] + if content_length > max_body: + return {} payload = self.rfile.read(content_length).decode("utf-8") try: loaded = json.loads(payload) @@ -388,6 +402,12 @@ def revoke_principal(self, principal_id: str) -> None: def revoke_intent(self, intent_hash: str) -> None: self._sidecar.revoke_intent_hash(intent_hash) + def revoke_mandate(self, mandate_id: str) -> None: + self._sidecar.revoke_mandate_id(mandate_id) + + def max_request_body_bytes(self) -> int: + return max(0, int(self._config.max_request_body_bytes)) + def issue_task_identity( self, principal_id: str, @@ -592,11 +612,16 @@ def _build_default_sidecar( control_plane_config: ControlPlaneBootstrapConfig | None = None, local_identity_config: LocalIdentityBootstrapConfig | None = None, identity_bridge: ExchangeTokenBridge | None = None, + mandate_signing_key: str | None = None, ) -> PredicateAuthoritySidecar: policy_rules: tuple[PolicyRule, ...] = () + global_max_delegation_depth: int | None = None if policy_file is not None and Path(policy_file).exists(): - policy_rules = PolicyFileSource(policy_file).load_rules() - policy_engine = PolicyEngine(rules=policy_rules) + policy_rules, global_max_delegation_depth = PolicyFileSource(policy_file).load_policy() + policy_engine = PolicyEngine( + rules=policy_rules, + global_max_delegation_depth=global_max_delegation_depth, + ) trace_emitters: list[TraceEmitter] = [] if ( @@ -643,7 +668,7 @@ def _build_default_sidecar( guard = ActionGuard( policy_engine=policy_engine, - mandate_signer=LocalMandateSigner(secret_key=secrets.token_hex(32)), + mandate_signer=LocalMandateSigner(secret_key=mandate_signing_key or secrets.token_hex(32)), proof_ledger=proof_ledger, ) return PredicateAuthoritySidecar( @@ -705,6 +730,22 @@ def _build_identity_bridge_from_args(args: argparse.Namespace) -> ExchangeTokenB raise SystemExit(f"Unsupported identity mode: {mode}") +def _resolve_mandate_signing_key( + signing_key_file: str | None, + signing_key_env: str, +) -> str: + if signing_key_file is not None and str(signing_key_file).strip() != "": + key_path = Path(signing_key_file) + if key_path.exists(): + loaded = key_path.read_text(encoding="utf-8").strip() + if loaded != "": + return loaded + env_value = os.getenv(signing_key_env) + if env_value is not None and env_value.strip() != "": + return env_value.strip() + return secrets.token_hex(32) + + def main() -> None: parser = argparse.ArgumentParser(description="predicate-authorityd sidecar daemon") parser.add_argument("--host", default="127.0.0.1") @@ -812,6 +853,16 @@ def main() -> None: ) parser.set_defaults(control_plane_fail_open=True) parser.add_argument("--control-plane-usage-credits-per-decision", type=int, default=1) + parser.add_argument( + "--mandate-signing-key-env", + default="PREDICATE_AUTHORITY_SIGNING_KEY", + help="Env var name for mandate signing key.", + ) + parser.add_argument( + "--mandate-signing-key-file", + default=None, + help="Optional file path containing mandate signing key.", + ) args = parser.parse_args() mode = AuthorityMode(args.mode) @@ -847,6 +898,10 @@ def main() -> None: default_ttl_seconds=max(1, int(args.local_identity_default_ttl_s)), ) identity_bridge = _build_identity_bridge_from_args(args) + mandate_signing_key = _resolve_mandate_signing_key( + signing_key_file=args.mandate_signing_key_file, + signing_key_env=args.mandate_signing_key_env, + ) sidecar = _build_default_sidecar( mode=mode, policy_file=args.policy_file, @@ -854,6 +909,7 @@ def main() -> None: control_plane_config=control_plane_bootstrap, local_identity_config=local_identity_bootstrap, identity_bridge=identity_bridge, + mandate_signing_key=mandate_signing_key, ) daemon = PredicateAuthorityDaemon( sidecar=sidecar, @@ -861,6 +917,7 @@ def main() -> None: host=args.host, port=args.port, policy_poll_interval_s=args.policy_poll_interval_s, + max_request_body_bytes=1_048_576, ), flush_worker=FlushWorkerConfig( enabled=bool(args.flush_worker_enabled), diff --git a/predicate_authority/guard.py b/predicate_authority/guard.py index 2331b0c..dea5171 100644 --- a/predicate_authority/guard.py +++ b/predicate_authority/guard.py @@ -36,8 +36,18 @@ def __init__( self._mandate_signer = mandate_signer self._proof_ledger = proof_ledger - def authorize(self, request: ActionRequest) -> AuthorizationDecision: - evaluation = self._policy_engine.evaluate(request) + def authorize( + self, + request: ActionRequest, + parent_mandate: SignedMandate | None = None, + ) -> AuthorizationDecision: + requested_delegation_depth = ( + parent_mandate.claims.delegation_depth + 1 if parent_mandate is not None else 0 + ) + evaluation = self._policy_engine.evaluate( + request, + delegation_depth=requested_delegation_depth, + ) if not evaluation.allowed: decision = AuthorizationDecision( allowed=False, @@ -48,7 +58,7 @@ def authorize(self, request: ActionRequest) -> AuthorizationDecision: self._proof_ledger.record(decision, request) return decision - mandate = self._mandate_signer.issue(request) + mandate = self._mandate_signer.issue(request, parent_mandate=parent_mandate) decision = AuthorizationDecision( allowed=True, reason=AuthorizationReason.ALLOWED, @@ -59,9 +69,12 @@ def authorize(self, request: ActionRequest) -> AuthorizationDecision: return decision def enforce( - self, action_callable: Callable[[], T], request: ActionRequest + self, + action_callable: Callable[[], T], + request: ActionRequest, + parent_mandate: SignedMandate | None = None, ) -> ActionExecutionResult[T]: - decision = self.authorize(request) + decision = self.authorize(request, parent_mandate=parent_mandate) if not decision.allowed or decision.mandate is None: raise AuthorizationDeniedError(decision) value = action_callable() diff --git a/predicate_authority/local_identity.py b/predicate_authority/local_identity.py index 86380c9..1c49310 100644 --- a/predicate_authority/local_identity.py +++ b/predicate_authority/local_identity.py @@ -259,11 +259,12 @@ def quarantine_queue_item(self, queue_item_id: str, reason: str) -> bool: return True def list_dead_letter_queue(self, limit: int | None = None) -> list[LedgerQueueItem]: - return self.list_flush_queue( + items = self.list_flush_queue( include_flushed=True, include_quarantined=True, limit=limit, ) + return [item for item in items if item.quarantined] def requeue_item(self, queue_item_id: str, reset_attempts: bool = True) -> bool: with self._lock: @@ -317,7 +318,10 @@ def _read_all_unlocked(self) -> dict[str, Any]: content = self._file_path.read_text(encoding="utf-8").strip() if content == "": return {"identities": {}, "flush_queue": {}} - loaded = json.loads(content) + try: + loaded = json.loads(content) + except json.JSONDecodeError: + return {"identities": {}, "flush_queue": {}} if isinstance(loaded, dict): if "identities" not in loaded: loaded["identities"] = {} @@ -356,7 +360,9 @@ def _parse_queue_item(self, raw: dict[str, Any]) -> LedgerQueueItem | None: return None def _write_all_unlocked(self, payload: dict[str, Any]) -> None: - self._file_path.write_text(json.dumps(payload, indent=2), encoding="utf-8") + tmp_path = self._file_path.with_name(f"{self._file_path.name}.{uuid.uuid4().hex}.tmp") + tmp_path.write_text(json.dumps(payload, indent=2), encoding="utf-8") + os.replace(tmp_path, self._file_path) self._chmod_file_safe() def _ensure_store_path(self) -> None: @@ -366,10 +372,12 @@ def _ensure_store_path(self) -> None: except OSError: pass if not self._file_path.exists(): - self._file_path.write_text( + tmp_path = self._file_path.with_name(f"{self._file_path.name}.{uuid.uuid4().hex}.tmp") + tmp_path.write_text( json.dumps({"identities": {}, "flush_queue": {}}, indent=2), encoding="utf-8", ) + os.replace(tmp_path, self._file_path) self._chmod_file_safe() def _chmod_file_safe(self) -> None: @@ -380,7 +388,7 @@ def _chmod_file_safe(self) -> None: @dataclass -class LocalLedgerQueueEmitter(TraceEmitter): +class LocalLedgerQueueEmitter: registry: LocalIdentityRegistry source: str = "predicate-authorityd" @@ -389,7 +397,7 @@ def emit(self, event: ProofEvent) -> None: @dataclass -class CompositeTraceEmitter(TraceEmitter): +class CompositeTraceEmitter: emitters: tuple[TraceEmitter, ...] def emit(self, event: ProofEvent) -> None: diff --git a/predicate_authority/mandate.py b/predicate_authority/mandate.py index 68346ba..ca854b4 100644 --- a/predicate_authority/mandate.py +++ b/predicate_authority/mandate.py @@ -17,7 +17,11 @@ def __init__(self, secret_key: str, ttl_seconds: int = 300) -> None: self._secret_key = secret_key.encode("utf-8") self._ttl_seconds = ttl_seconds - def issue(self, request: ActionRequest) -> SignedMandate: + def issue( + self, + request: ActionRequest, + parent_mandate: SignedMandate | None = None, + ) -> SignedMandate: issued_at = int(time.time()) expires_at = issued_at + self._ttl_seconds intent_hash = hashlib.sha256(request.action_spec.intent.encode("utf-8")).hexdigest() @@ -30,6 +34,18 @@ def issue(self, request: ActionRequest) -> SignedMandate: f"{issued_at}" ) mandate_id = hashlib.sha256(mandate_id_seed.encode("utf-8")).hexdigest()[:24] + delegated_by = parent_mandate.claims.principal_id if parent_mandate is not None else None + delegation_depth = ( + parent_mandate.claims.delegation_depth + 1 if parent_mandate is not None else 0 + ) + delegation_chain_hash = self._compute_delegation_chain_hash( + request=request, + mandate_id=mandate_id, + intent_hash=intent_hash, + delegated_by=delegated_by, + delegation_depth=delegation_depth, + parent_mandate=parent_mandate, + ) claims = MandateClaims( mandate_id=mandate_id, @@ -40,6 +56,9 @@ def issue(self, request: ActionRequest) -> SignedMandate: state_hash=request.state_evidence.state_hash, issued_at_epoch_s=issued_at, expires_at_epoch_s=expires_at, + delegated_by=delegated_by, + delegation_depth=delegation_depth, + delegation_chain_hash=delegation_chain_hash, ) token, signature = self._sign_claims(claims) return SignedMandate(token=token, claims=claims, signature=signature) @@ -66,8 +85,44 @@ def verify(self, token: str) -> SignedMandate | None: now_epoch = int(time.time()) if claims.expires_at_epoch_s < now_epoch: return None + if claims.delegation_depth < 0: + return None + if claims.delegation_depth == 0 and claims.delegated_by is not None: + return None + if claims.delegation_depth > 0 and claims.delegated_by is None: + return None + if claims.delegation_chain_hash is None: + return None return SignedMandate(token=token, claims=claims, signature=encoded_signature) + def verify_delegation( + self, + mandate: SignedMandate, + parent_mandate: SignedMandate | None = None, + ) -> bool: + claims = mandate.claims + if claims.delegation_chain_hash is None: + return False + if parent_mandate is None: + if claims.delegation_depth != 0 or claims.delegated_by is not None: + return False + expected_hash = self._compute_delegation_chain_hash_for_claims( + claims=claims, + parent_mandate=None, + ) + return hmac.compare_digest(expected_hash, claims.delegation_chain_hash) + + parent_claims = parent_mandate.claims + if claims.delegated_by != parent_claims.principal_id: + return False + if claims.delegation_depth != parent_claims.delegation_depth + 1: + return False + expected_hash = self._compute_delegation_chain_hash_for_claims( + claims=claims, + parent_mandate=parent_mandate, + ) + return hmac.compare_digest(expected_hash, claims.delegation_chain_hash) + def _sign_claims(self, claims: MandateClaims) -> tuple[str, str]: header_json = json.dumps( {"alg": "HS256", "typ": "JWT"}, separators=(",", ":"), sort_keys=True @@ -84,6 +139,61 @@ def _sign_claims(self, claims: MandateClaims) -> tuple[str, str]: def _hmac(self, payload: bytes) -> bytes: return hmac.new(self._secret_key, payload, hashlib.sha256).digest() + @staticmethod + def _compute_delegation_chain_hash( + request: ActionRequest, + mandate_id: str, + intent_hash: str, + delegated_by: str | None, + delegation_depth: int, + parent_mandate: SignedMandate | None, + ) -> str: + parent_chain = ( + parent_mandate.claims.delegation_chain_hash if parent_mandate is not None else "root" + ) + parent_mandate_id = ( + parent_mandate.claims.mandate_id if parent_mandate is not None else "none" + ) + chain_seed = ( + f"{parent_chain}|" + f"{parent_mandate_id}|" + f"{delegated_by or 'none'}|" + f"{delegation_depth}|" + f"{mandate_id}|" + f"{request.principal.principal_id}|" + f"{request.action_spec.action}|" + f"{request.action_spec.resource}|" + f"{intent_hash}|" + f"{request.state_evidence.state_hash}" + ) + return hashlib.sha256(chain_seed.encode("utf-8")).hexdigest() + + @classmethod + def _compute_delegation_chain_hash_for_claims( + cls, + claims: MandateClaims, + parent_mandate: SignedMandate | None, + ) -> str: + parent_chain = ( + parent_mandate.claims.delegation_chain_hash if parent_mandate is not None else "root" + ) + parent_mandate_id = ( + parent_mandate.claims.mandate_id if parent_mandate is not None else "none" + ) + chain_seed = ( + f"{parent_chain}|" + f"{parent_mandate_id}|" + f"{claims.delegated_by or 'none'}|" + f"{claims.delegation_depth}|" + f"{claims.mandate_id}|" + f"{claims.principal_id}|" + f"{claims.action}|" + f"{claims.resource}|" + f"{claims.intent_hash}|" + f"{claims.state_hash}" + ) + return hashlib.sha256(chain_seed.encode("utf-8")).hexdigest() + @staticmethod def _base64url_encode(value: bytes) -> str: return base64.urlsafe_b64encode(value).rstrip(b"=").decode("ascii") diff --git a/predicate_authority/policy.py b/predicate_authority/policy.py index 41ce33e..79685b4 100644 --- a/predicate_authority/policy.py +++ b/predicate_authority/policy.py @@ -2,6 +2,7 @@ from dataclasses import dataclass from fnmatch import fnmatch +from threading import Lock from predicate_contracts import ActionRequest, AuthorizationReason, PolicyEffect, PolicyRule @@ -15,14 +16,38 @@ class PolicyMatchResult: class PolicyEngine: - def __init__(self, rules: tuple[PolicyRule, ...]) -> None: + def __init__( + self, + rules: tuple[PolicyRule, ...], + global_max_delegation_depth: int | None = None, + ) -> None: self._rules = rules + self._global_max_delegation_depth = global_max_delegation_depth + self._lock = Lock() def replace_rules(self, rules: tuple[PolicyRule, ...]) -> None: - self._rules = rules + with self._lock: + self._rules = rules + + def set_global_max_delegation_depth(self, max_depth: int | None) -> None: + with self._lock: + self._global_max_delegation_depth = max_depth + + def replace_policy( + self, + rules: tuple[PolicyRule, ...], + global_max_delegation_depth: int | None, + ) -> None: + with self._lock: + self._rules = rules + self._global_max_delegation_depth = global_max_delegation_depth + + def evaluate(self, request: ActionRequest, delegation_depth: int = 0) -> PolicyMatchResult: + with self._lock: + rules = self._rules + global_max_delegation_depth = self._global_max_delegation_depth - def evaluate(self, request: ActionRequest) -> PolicyMatchResult: - matching_rules = [rule for rule in self._rules if self._matches_rule(rule, request)] + matching_rules = [rule for rule in rules if self._matches_rule(rule, request)] if not matching_rules: return PolicyMatchResult( allowed=False, @@ -37,22 +62,40 @@ def evaluate(self, request: ActionRequest) -> PolicyMatchResult: matched_rule=rule.name, ) + first_allow_failure: PolicyMatchResult | None = None for rule in matching_rules: if rule.effect != PolicyEffect.ALLOW: continue + effective_max_depth = self._effective_max_delegation_depth( + global_max_delegation_depth, + rule.max_delegation_depth, + ) + if effective_max_depth is not None and delegation_depth > effective_max_depth: + failure = PolicyMatchResult( + allowed=False, + reason=AuthorizationReason.MAX_DELEGATION_DEPTH_EXCEEDED, + matched_rule=rule.name, + ) + if first_allow_failure is None: + first_allow_failure = failure + continue + missing_labels = tuple( label for label in rule.required_labels if not request.verification_evidence.is_label_passed(label) ) if missing_labels: - return PolicyMatchResult( + failure = PolicyMatchResult( allowed=False, reason=AuthorizationReason.MISSING_REQUIRED_VERIFICATION, matched_rule=rule.name, missing_labels=missing_labels, ) + if first_allow_failure is None: + first_allow_failure = failure + continue return PolicyMatchResult( allowed=True, @@ -60,6 +103,9 @@ def evaluate(self, request: ActionRequest) -> PolicyMatchResult: matched_rule=rule.name, ) + if first_allow_failure is not None: + return first_allow_failure + return PolicyMatchResult( allowed=False, reason=AuthorizationReason.NO_MATCHING_POLICY, @@ -75,3 +121,14 @@ def _matches_rule(rule: PolicyRule, request: ActionRequest) -> bool: fnmatch(request.action_spec.resource, pattern) for pattern in rule.resources ) return principal_ok and action_ok and resource_ok + + @staticmethod + def _effective_max_delegation_depth( + global_max: int | None, + rule_max: int | None, + ) -> int | None: + if global_max is None: + return rule_max + if rule_max is None: + return global_max + return min(global_max, rule_max) diff --git a/predicate_authority/policy_source.py b/predicate_authority/policy_source.py index 69d1447..9fa2a94 100644 --- a/predicate_authority/policy_source.py +++ b/predicate_authority/policy_source.py @@ -3,6 +3,7 @@ import json from dataclasses import dataclass from pathlib import Path +from typing import Any from predicate_contracts import PolicyEffect, PolicyRule @@ -11,6 +12,7 @@ class PolicyReloadResult: changed: bool rules: tuple[PolicyRule, ...] + global_max_delegation_depth: int | None = None class PolicyFileSource: @@ -19,7 +21,11 @@ def __init__(self, policy_path: str) -> None: self._last_mtime_ns: int | None = None def load_rules(self) -> tuple[PolicyRule, ...]: - payload = json.loads(self._policy_path.read_text(encoding="utf-8")) + rules, _ = self.load_policy() + return rules + + def load_policy(self) -> tuple[tuple[PolicyRule, ...], int | None]: + payload = self._load_payload(self._policy_path.read_text(encoding="utf-8")) rules_payload = payload.get("rules", []) rules: list[PolicyRule] = [] for item in rules_payload: @@ -31,15 +37,46 @@ def load_rules(self) -> tuple[PolicyRule, ...]: actions=tuple(item["actions"]), resources=tuple(item["resources"]), required_labels=tuple(item.get("required_labels", [])), + max_delegation_depth=( + int(item["max_delegation_depth"]) + if item.get("max_delegation_depth") is not None + else None + ), ) ) stat = self._policy_path.stat() self._last_mtime_ns = stat.st_mtime_ns - return tuple(rules) + global_max_delegation_depth = ( + int(payload["global_max_delegation_depth"]) + if payload.get("global_max_delegation_depth") is not None + else None + ) + return tuple(rules), global_max_delegation_depth + + def _load_payload(self, raw: str) -> dict[str, Any]: + suffix = self._policy_path.suffix.lower() + if suffix in {".yaml", ".yml"}: + try: + import yaml # type: ignore[import-untyped] + except ImportError as exc: # pragma: no cover - env-dependent + raise RuntimeError( + "YAML policy files require PyYAML. Install with: pip install pyyaml" + ) from exc + loaded = yaml.safe_load(raw) + else: + loaded = json.loads(raw) + + if not isinstance(loaded, dict): + raise RuntimeError("Policy file must deserialize to an object.") + return loaded def reload_if_changed(self) -> PolicyReloadResult: stat = self._policy_path.stat() if self._last_mtime_ns is None or stat.st_mtime_ns != self._last_mtime_ns: - rules = self.load_rules() - return PolicyReloadResult(changed=True, rules=rules) - return PolicyReloadResult(changed=False, rules=()) + rules, global_max_delegation_depth = self.load_policy() + return PolicyReloadResult( + changed=True, + rules=rules, + global_max_delegation_depth=global_max_delegation_depth, + ) + return PolicyReloadResult(changed=False, rules=(), global_max_delegation_depth=None) diff --git a/predicate_authority/proof.py b/predicate_authority/proof.py index e78471e..e79c388 100644 --- a/predicate_authority/proof.py +++ b/predicate_authority/proof.py @@ -2,6 +2,7 @@ import time from dataclasses import dataclass, field +from threading import Lock from predicate_contracts import ActionRequest, AuthorizationDecision, ProofEvent, TraceEmitter @@ -10,6 +11,7 @@ class InMemoryProofLedger: trace_emitter: TraceEmitter | None = None events: list[ProofEvent] = field(default_factory=list) + _lock: Lock = field(default_factory=Lock) def record(self, decision: AuthorizationDecision, request: ActionRequest) -> ProofEvent: event = ProofEvent( @@ -22,7 +24,13 @@ def record(self, decision: AuthorizationDecision, request: ActionRequest) -> Pro mandate_id=decision.mandate.claims.mandate_id if decision.mandate else None, emitted_at_epoch_s=int(time.time()), ) - self.events.append(event) - if self.trace_emitter is not None: - self.trace_emitter.emit(event) + with self._lock: + self.events.append(event) + trace_emitter = self.trace_emitter + if trace_emitter is not None: + trace_emitter.emit(event) return event + + def event_count(self) -> int: + with self._lock: + return len(self.events) diff --git a/predicate_authority/pyproject.toml b/predicate_authority/pyproject.toml index bb0c123..17c9088 100644 --- a/predicate_authority/pyproject.toml +++ b/predicate_authority/pyproject.toml @@ -17,6 +17,7 @@ classifiers = [ ] dependencies = [ "predicate-contracts>=0.1.0,<0.2.0", + "pyyaml>=6.0", ] [project.scripts] diff --git a/predicate_authority/revocation.py b/predicate_authority/revocation.py index de25f16..b4bda6b 100644 --- a/predicate_authority/revocation.py +++ b/predicate_authority/revocation.py @@ -1,35 +1,68 @@ from __future__ import annotations import hashlib -from dataclasses import dataclass, field +from threading import Lock from predicate_contracts import ActionRequest, SignedMandate -@dataclass class LocalRevocationCache: - revoked_principal_ids: set[str] = field(default_factory=set) - revoked_intent_hashes: set[str] = field(default_factory=set) - revoked_mandate_ids: set[str] = field(default_factory=set) + def __init__(self) -> None: + self._revoked_principal_ids: set[str] = set() + self._revoked_intent_hashes: set[str] = set() + self._revoked_mandate_ids: set[str] = set() + self._lock = Lock() + + @property + def revoked_principal_ids(self) -> set[str]: + with self._lock: + return set(self._revoked_principal_ids) + + @property + def revoked_intent_hashes(self) -> set[str]: + with self._lock: + return set(self._revoked_intent_hashes) + + @property + def revoked_mandate_ids(self) -> set[str]: + with self._lock: + return set(self._revoked_mandate_ids) + + def revoked_principal_count(self) -> int: + with self._lock: + return len(self._revoked_principal_ids) + + def revoked_intent_count(self) -> int: + with self._lock: + return len(self._revoked_intent_hashes) + + def revoked_mandate_count(self) -> int: + with self._lock: + return len(self._revoked_mandate_ids) def revoke_principal(self, principal_id: str) -> None: - self.revoked_principal_ids.add(principal_id) + with self._lock: + self._revoked_principal_ids.add(principal_id) def revoke_intent_hash(self, intent_hash: str) -> None: - self.revoked_intent_hashes.add(intent_hash) + with self._lock: + self._revoked_intent_hashes.add(intent_hash) def revoke_mandate_id(self, mandate_id: str) -> None: - self.revoked_mandate_ids.add(mandate_id) + with self._lock: + self._revoked_mandate_ids.add(mandate_id) def is_request_revoked(self, request: ActionRequest) -> bool: - if request.principal.principal_id in self.revoked_principal_ids: - return True - intent_hash = hashlib.sha256(request.action_spec.intent.encode("utf-8")).hexdigest() - return intent_hash in self.revoked_intent_hashes + with self._lock: + if request.principal.principal_id in self._revoked_principal_ids: + return True + intent_hash = hashlib.sha256(request.action_spec.intent.encode("utf-8")).hexdigest() + return intent_hash in self._revoked_intent_hashes def is_mandate_revoked(self, mandate: SignedMandate) -> bool: - if mandate.claims.principal_id in self.revoked_principal_ids: - return True - if mandate.claims.intent_hash in self.revoked_intent_hashes: - return True - return mandate.claims.mandate_id in self.revoked_mandate_ids + with self._lock: + if mandate.claims.principal_id in self._revoked_principal_ids: + return True + if mandate.claims.intent_hash in self._revoked_intent_hashes: + return True + return mandate.claims.mandate_id in self._revoked_mandate_ids diff --git a/predicate_authority/sidecar.py b/predicate_authority/sidecar.py index b96c4fd..338faab 100644 --- a/predicate_authority/sidecar.py +++ b/predicate_authority/sidecar.py @@ -152,12 +152,18 @@ def revoke_by_invariant(self, principal_id: str) -> None: def revoke_intent_hash(self, intent_hash: str) -> None: self._revocation_cache.revoke_intent_hash(intent_hash) + def revoke_mandate_id(self, mandate_id: str) -> None: + self._revocation_cache.revoke_mandate_id(mandate_id) + def hot_reload_policy(self) -> bool: if self._policy_source is None: return False result = self._policy_source.reload_if_changed() if result.changed: - self._policy_engine.replace_rules(result.rules) + self._policy_engine.replace_policy( + rules=result.rules, + global_max_delegation_depth=result.global_max_delegation_depth, + ) return True return False @@ -174,10 +180,10 @@ def status(self) -> SidecarStatus: return SidecarStatus( mode=self._config.mode, policy_hot_reload_enabled=self._policy_source is not None, - revoked_principal_count=len(self._revocation_cache.revoked_principal_ids), - revoked_intent_count=len(self._revocation_cache.revoked_intent_hashes), - revoked_mandate_count=len(self._revocation_cache.revoked_mandate_ids), - proof_event_count=len(self._proof_ledger.events), + revoked_principal_count=self._revocation_cache.revoked_principal_count(), + revoked_intent_count=self._revocation_cache.revoked_intent_count(), + revoked_mandate_count=self._revocation_cache.revoked_mandate_count(), + proof_event_count=self._proof_ledger.event_count(), control_plane_emitter_attached=control_plane_attached, control_plane_audit_push_success_count=int( control_plane_payload.get("control_plane_audit_push_success_count", 0) diff --git a/predicate_authority/sidecar_store.py b/predicate_authority/sidecar_store.py index 745b18f..dc02265 100644 --- a/predicate_authority/sidecar_store.py +++ b/predicate_authority/sidecar_store.py @@ -2,8 +2,11 @@ import json import os +import time +import uuid from dataclasses import asdict, dataclass from pathlib import Path +from threading import Lock from typing import Any @@ -22,43 +25,55 @@ class LocalCredentialStore: def __init__(self, file_path: str) -> None: self._file_path = Path(file_path) + self._lock = Lock() self._ensure_store_path() def save(self, record: CredentialRecord) -> None: - payload = self._read_all() - payload[record.principal_id] = asdict(record) - self._file_path.write_text(json.dumps(payload, indent=2), encoding="utf-8") - self._chmod_file_safe() + with self._lock: + payload = self._read_all_unlocked() + payload[record.principal_id] = asdict(record) + self._write_all_unlocked(payload) def get(self, principal_id: str) -> CredentialRecord | None: - payload = self._read_all() - item = payload.get(principal_id) - if not isinstance(item, dict): - return None - item_principal = item.get("principal_id") - item_refresh = item.get("refresh_token") - item_expires = item.get("expires_at_epoch_s") - if not isinstance(item_principal, str) or not isinstance(item_refresh, str): - return None - if not isinstance(item_expires, (int, str)): - return None - return CredentialRecord( - principal_id=item_principal, - refresh_token=item_refresh, - expires_at_epoch_s=int(item_expires), - ) + with self._lock: + payload = self._read_all_unlocked() + item = payload.get(principal_id) + if not isinstance(item, dict): + return None + item_principal = item.get("principal_id") + item_refresh = item.get("refresh_token") + item_expires = item.get("expires_at_epoch_s") + if not isinstance(item_principal, str) or not isinstance(item_refresh, str): + return None + if not isinstance(item_expires, (int, str)): + return None + expires_at = int(item_expires) + if expires_at <= int(time.time()): + return None + return CredentialRecord( + principal_id=item_principal, + refresh_token=item_refresh, + expires_at_epoch_s=expires_at, + ) - def _read_all(self) -> dict[str, Any]: + def _read_all_unlocked(self) -> dict[str, Any]: if not self._file_path.exists(): return {} content = self._file_path.read_text(encoding="utf-8").strip() if content == "": return {} - loaded = json.loads(content) + try: + loaded = json.loads(content) + except json.JSONDecodeError: + return {} if isinstance(loaded, dict): return loaded return {} + def _write_all_unlocked(self, payload: dict[str, Any]) -> None: + self._atomic_write_json(payload) + self._chmod_file_safe() + def _ensure_store_path(self) -> None: self._file_path.parent.mkdir(parents=True, exist_ok=True) try: @@ -66,7 +81,7 @@ def _ensure_store_path(self) -> None: except OSError: pass if not self._file_path.exists(): - self._file_path.write_text("{}", encoding="utf-8") + self._atomic_write_json({}) self._chmod_file_safe() def _chmod_file_safe(self) -> None: @@ -74,3 +89,8 @@ def _chmod_file_safe(self) -> None: os.chmod(self._file_path, 0o600) except OSError: pass + + def _atomic_write_json(self, payload: dict[str, Any]) -> None: + tmp_path = self._file_path.with_name(f"{self._file_path.name}.{uuid.uuid4().hex}.tmp") + tmp_path.write_text(json.dumps(payload, indent=2), encoding="utf-8") + os.replace(tmp_path, self._file_path) diff --git a/predicate_authority/telemetry.py b/predicate_authority/telemetry.py index 741eb28..cb9d05e 100644 --- a/predicate_authority/telemetry.py +++ b/predicate_authority/telemetry.py @@ -3,7 +3,7 @@ from contextlib import AbstractContextManager from typing import Protocol, cast -from predicate_contracts import ProofEvent, TraceEmitter +from predicate_contracts import ProofEvent class SpanLike(Protocol): @@ -14,7 +14,7 @@ class TracerLike(Protocol): def start_as_current_span(self, name: str) -> AbstractContextManager[SpanLike]: ... -class OpenTelemetryTraceEmitter(TraceEmitter): +class OpenTelemetryTraceEmitter: """TraceEmitter backed by OpenTelemetry spans/events.""" def __init__(self, tracer: TracerLike | None = None) -> None: diff --git a/predicate_contracts/models.py b/predicate_contracts/models.py index 5c25505..d0a3b85 100644 --- a/predicate_contracts/models.py +++ b/predicate_contracts/models.py @@ -20,6 +20,7 @@ class AuthorizationReason(str, Enum): NO_MATCHING_POLICY = "no_matching_policy" EXPLICIT_DENY = "explicit_deny" MISSING_REQUIRED_VERIFICATION = "missing_required_verification" + MAX_DELEGATION_DEPTH_EXCEEDED = "max_delegation_depth_exceeded" INVALID_MANDATE = "invalid_mandate" @@ -80,6 +81,7 @@ class PolicyRule: actions: tuple[str, ...] resources: tuple[str, ...] required_labels: tuple[str, ...] = field(default_factory=tuple) + max_delegation_depth: int | None = None @dataclass(frozen=True) @@ -92,6 +94,9 @@ class MandateClaims: state_hash: str issued_at_epoch_s: int expires_at_epoch_s: int + delegated_by: str | None = None + delegation_depth: int = 0 + delegation_chain_hash: str | None = None @dataclass(frozen=True) diff --git a/tests/test_authority_client_local_yaml.py b/tests/test_authority_client_local_yaml.py new file mode 100644 index 0000000..78bcc59 --- /dev/null +++ b/tests/test_authority_client_local_yaml.py @@ -0,0 +1,121 @@ +from __future__ import annotations + +from pathlib import Path + +from pytest import MonkeyPatch + +# pylint: disable=import-error +from predicate_authority import AuthorityClient +from predicate_contracts import ( + ActionRequest, + ActionSpec, + AuthorizationReason, + PrincipalRef, + StateEvidence, + VerificationEvidence, +) + + +def test_authority_client_mint_and_verify_with_local_yaml_policy(tmp_path: Path) -> None: + policy = tmp_path / "policy.yaml" + policy.write_text( + "\n".join( + [ + "rules:", + " - name: allow-orders-create", + " effect: allow", + " principals:", + " - agent:checkout", + " actions:", + " - http.post", + " resources:", + " - https://api.vendor.com/orders", + ] + ), + encoding="utf-8", + ) + context = AuthorityClient.from_policy_file(str(policy), secret_key="local-test-secret") + client = context.client + + request = ActionRequest( + principal=PrincipalRef(principal_id="agent:checkout"), + action_spec=ActionSpec( + action="http.post", + resource="https://api.vendor.com/orders", + intent="submit order", + ), + state_evidence=StateEvidence(source="unit-test", state_hash="sha256:test"), + verification_evidence=VerificationEvidence(), + ) + decision = client.authorize(request) + + assert decision.allowed + assert decision.mandate is not None + verified = client.verify_token(decision.mandate.token) + assert verified is not None + assert verified.claims.principal_id == "agent:checkout" + + +def test_authority_client_global_max_depth_from_yaml_is_enforced(tmp_path: Path) -> None: + policy = tmp_path / "policy.yaml" + policy.write_text( + "\n".join( + [ + "global_max_delegation_depth: 0", + "rules:", + " - name: allow-orders-create", + " effect: allow", + " principals:", + " - agent:checkout", + " actions:", + " - http.post", + " resources:", + " - https://api.vendor.com/orders", + ] + ), + encoding="utf-8", + ) + context = AuthorityClient.from_policy_file(str(policy), secret_key="local-test-secret") + client = context.client + request = ActionRequest( + principal=PrincipalRef(principal_id="agent:checkout"), + action_spec=ActionSpec( + action="http.post", + resource="https://api.vendor.com/orders", + intent="submit order", + ), + state_evidence=StateEvidence(source="unit-test", state_hash="sha256:test"), + verification_evidence=VerificationEvidence(), + ) + root = client.authorize(request) + assert root.allowed is True + assert root.mandate is not None + + child = client.authorize(request, parent_mandate=root.mandate) + assert child.allowed is False + assert child.reason == AuthorizationReason.MAX_DELEGATION_DEPTH_EXCEEDED + + +def test_authority_client_from_env(tmp_path: Path, monkeypatch: MonkeyPatch) -> None: + policy = tmp_path / "policy.yaml" + policy.write_text( + "\n".join( + [ + "rules:", + " - name: allow-orders-create", + " effect: allow", + " principals:", + " - agent:checkout", + " actions:", + " - http.post", + " resources:", + " - https://api.vendor.com/orders", + ] + ), + encoding="utf-8", + ) + monkeypatch.setenv("PREDICATE_AUTHORITY_POLICY_FILE", str(policy)) + monkeypatch.setenv("PREDICATE_AUTHORITY_SIGNING_KEY", "env-test-secret") + monkeypatch.setenv("PREDICATE_AUTHORITY_MANDATE_TTL_SECONDS", "120") + context = AuthorityClient.from_env() + assert context.policy_file == str(policy) diff --git a/tests/test_daemon_phase2.py b/tests/test_daemon_phase2.py index 0e9ca12..5fc4c3b 100644 --- a/tests/test_daemon_phase2.py +++ b/tests/test_daemon_phase2.py @@ -11,6 +11,7 @@ from typing import Any from urllib.parse import urlsplit +# pylint: disable=import-error from predicate_authority import ( ActionGuard, AuthorityMode, @@ -42,8 +43,6 @@ VerificationEvidence, ) -# pylint: disable=import-error - def _build_sidecar(tmp_path: Path, policy_file: Path) -> PredicateAuthoritySidecar: policy_engine = PolicyEngine( @@ -191,6 +190,20 @@ def test_daemon_supports_policy_reload_and_revoke_endpoints(tmp_path: Path) -> N daemon.start() try: base_url = f"http://127.0.0.1:{daemon.bound_port}" + request = ActionRequest( + principal=PrincipalRef(principal_id="agent:test"), + action_spec=ActionSpec( + action="http.post", + resource="https://api.vendor.com/orders", + intent="create order", + ), + state_evidence=StateEvidence(source="test", state_hash="state-1"), + verification_evidence=VerificationEvidence(), + ) + decision = sidecar.issue_mandate(request) + assert decision.allowed is True + assert decision.mandate is not None + reloaded = _post_json(f"{base_url}/policy/reload") revoke_principal = _post_json( f"{base_url}/revoke/principal", {"principal_id": "agent:test-revoked"} @@ -199,13 +212,19 @@ def test_daemon_supports_policy_reload_and_revoke_endpoints(tmp_path: Path) -> N f"{base_url}/revoke/intent", {"intent_hash": "abc123-intent-hash"}, ) + revoke_mandate = _post_json( + f"{base_url}/revoke/mandate", + {"mandate_id": decision.mandate.claims.mandate_id}, + ) status = _fetch_json(f"{base_url}/status") assert "reloaded" in reloaded assert revoke_principal["ok"] is True assert revoke_intent["ok"] is True + assert revoke_mandate["ok"] is True assert int(status["revoked_principal_count"]) >= 1 assert int(status["revoked_intent_count"]) >= 1 + assert int(status["revoked_mandate_count"]) >= 1 finally: daemon.stop() @@ -227,7 +246,8 @@ def do_POST(self) -> None: # noqa: N802 self.end_headers() self.wfile.write(b"{}") - def log_message(self, format: str, *args: Any) -> None: # noqa: A003 + def log_message(self, fmt: str, *args: Any) -> None: # noqa: A003 + _ = fmt return @@ -259,7 +279,8 @@ def do_POST(self) -> None: # noqa: N802 self.end_headers() self.wfile.write(b'{"error":"temporary_failure"}') - def log_message(self, format: str, *args: Any) -> None: # noqa: A003 + def log_message(self, fmt: str, *args: Any) -> None: # noqa: A003 + _ = fmt return server = ThreadingHTTPServer(("127.0.0.1", 0), FailingHandler) diff --git a/tests/test_delegation_simulation.py b/tests/test_delegation_simulation.py new file mode 100644 index 0000000..bbf8e37 --- /dev/null +++ b/tests/test_delegation_simulation.py @@ -0,0 +1,64 @@ +from __future__ import annotations + +import json +import subprocess +import sys +from pathlib import Path + + +def test_delegate_worker_revocation_blocks_worker(tmp_path: Path) -> None: + policy_file = tmp_path / "policy.yaml" + policy_file.write_text( + "\n".join( + [ + "rules:", + " - name: allow-delegate-task", + " effect: allow", + " principals:", + " - agent:root", + " actions:", + " - task.delegate", + " resources:", + " - worker:queue/*", + " max_delegation_depth: 1", + " - name: allow-worker-execute", + " effect: allow", + " principals:", + " - agent:worker", + " actions:", + " - job.execute", + " resources:", + " - queue://jobs/*", + " max_delegation_depth: 1", + ] + ), + encoding="utf-8", + ) + revocation_file = tmp_path / "revocations.json" + + repo_root = Path(__file__).resolve().parents[1] + delegate_script = repo_root / "examples" / "delegation" / "delegate.py" + worker_script = repo_root / "examples" / "delegation" / "worker.py" + + command = [ + sys.executable, + str(delegate_script), + "--policy-file", + str(policy_file), + "--worker-script", + str(worker_script), + "--revocation-file", + str(revocation_file), + "--secret-key", + "delegation-test-secret", + ] + result = subprocess.run(command, check=True, capture_output=True, text=True) # noqa: S603 + payload = json.loads(result.stdout) + + assert payload["root_allowed"] is True + assert payload["root_delegation_depth"] == 0 + assert payload["worker_allowed_before_revoke"] is True + assert payload["worker_delegation_depth_before_revoke"] == 1 + assert payload["worker_chain_verified_before_revoke"] is True + assert payload["worker_allowed_after_revoke"] is False + assert payload["after_reason"] == "revoked_root_token" diff --git a/tests/test_mandate_signer.py b/tests/test_mandate_signer.py index 4ec7a47..1534a9b 100644 --- a/tests/test_mandate_signer.py +++ b/tests/test_mandate_signer.py @@ -1,5 +1,6 @@ from __future__ import annotations +# pylint: disable=import-error from predicate_authority import LocalMandateSigner from predicate_contracts import ( ActionRequest, @@ -27,6 +28,9 @@ def test_mandate_signature_verifies() -> None: assert verified is not None assert verified.claims.mandate_id == signed.claims.mandate_id assert verified.claims.intent_hash == signed.claims.intent_hash + assert verified.claims.delegation_depth == 0 + assert verified.claims.delegated_by is None + assert verified.claims.delegation_chain_hash is not None def test_mandate_tamper_is_rejected() -> None: @@ -43,3 +47,40 @@ def test_mandate_tamper_is_rejected() -> None: tampered = signed.token[:-1] + ("A" if signed.token[-1] != "A" else "B") assert signer.verify(tampered) is None + + +def test_multi_hop_delegation_claims_and_chain_verification() -> None: + signer = LocalMandateSigner(secret_key="test-key", ttl_seconds=60) + root_request = ActionRequest( + principal=PrincipalRef(principal_id="agent:root"), + action_spec=ActionSpec( + action="task.delegate", + resource="worker:queue/main", + intent="delegate job", + ), + state_evidence=StateEvidence(source="non-web", state_hash="state-root"), + verification_evidence=VerificationEvidence(), + ) + root_mandate = signer.issue(root_request) + + child_request = ActionRequest( + principal=PrincipalRef(principal_id="agent:worker"), + action_spec=ActionSpec( + action="job.execute", + resource="queue://jobs/high-priority", + intent="execute delegated job", + ), + state_evidence=StateEvidence(source="non-web", state_hash="state-worker"), + verification_evidence=VerificationEvidence(), + ) + child_mandate = signer.issue(child_request, parent_mandate=root_mandate) + + assert root_mandate.claims.delegation_depth == 0 + assert root_mandate.claims.delegated_by is None + assert child_mandate.claims.delegation_depth == 1 + assert child_mandate.claims.delegated_by == "agent:root" + assert child_mandate.claims.delegation_chain_hash is not None + + assert signer.verify_delegation(root_mandate, parent_mandate=None) is True + assert signer.verify_delegation(child_mandate, parent_mandate=root_mandate) is True + assert signer.verify_delegation(child_mandate, parent_mandate=None) is False diff --git a/tests/test_policy_and_guard.py b/tests/test_policy_and_guard.py index 4f61b1c..82e2cae 100644 --- a/tests/test_policy_and_guard.py +++ b/tests/test_policy_and_guard.py @@ -2,6 +2,7 @@ import pytest +# pylint: disable=import-error from predicate_authority import ( ActionGuard, AuthorizationDeniedError, @@ -102,3 +103,60 @@ def test_enforce_raises_when_denied() -> None: with pytest.raises(AuthorizationDeniedError): guard.enforce(lambda: "should-not-run", request) + + +def test_authorize_denies_when_global_delegation_depth_exceeded() -> None: + rules = ( + PolicyRule( + name="allow-checkout", + effect=PolicyEffect.ALLOW, + principals=("agent:*",), + actions=("http.*",), + resources=("https://api.vendor.com/*",), + ), + ) + guard = ActionGuard( + policy_engine=PolicyEngine(rules=rules, global_max_delegation_depth=0), + mandate_signer=LocalMandateSigner(secret_key="dev-secret", ttl_seconds=120), + proof_ledger=InMemoryProofLedger(), + ) + root_request = _build_request(with_verified_label=True) + root = guard.authorize(root_request) + assert root.allowed is True + assert root.mandate is not None + + child = guard.authorize(_build_request(with_verified_label=True), parent_mandate=root.mandate) + assert child.allowed is False + assert child.reason == AuthorizationReason.MAX_DELEGATION_DEPTH_EXCEEDED + + +def test_authorize_per_rule_depth_cap_overrides_higher_global() -> None: + rules = ( + PolicyRule( + name="allow-checkout", + effect=PolicyEffect.ALLOW, + principals=("agent:*",), + actions=("http.*",), + resources=("https://api.vendor.com/*",), + max_delegation_depth=1, + ), + ) + guard = ActionGuard( + policy_engine=PolicyEngine(rules=rules, global_max_delegation_depth=5), + mandate_signer=LocalMandateSigner(secret_key="dev-secret", ttl_seconds=120), + proof_ledger=InMemoryProofLedger(), + ) + root = guard.authorize(_build_request(with_verified_label=True)) + assert root.allowed is True + assert root.mandate is not None + + child = guard.authorize(_build_request(with_verified_label=True), parent_mandate=root.mandate) + assert child.allowed is True + assert child.mandate is not None + + grandchild = guard.authorize( + _build_request(with_verified_label=True), + parent_mandate=child.mandate, + ) + assert grandchild.allowed is False + assert grandchild.reason == AuthorizationReason.MAX_DELEGATION_DEPTH_EXCEEDED diff --git a/tests/test_sidecar_store.py b/tests/test_sidecar_store.py new file mode 100644 index 0000000..aeb3c43 --- /dev/null +++ b/tests/test_sidecar_store.py @@ -0,0 +1,26 @@ +from __future__ import annotations + +import time +from pathlib import Path + +# pylint: disable=import-error +from predicate_authority import CredentialRecord, LocalCredentialStore + + +def test_credential_store_get_returns_none_for_expired_record(tmp_path: Path) -> None: + store = LocalCredentialStore(str(tmp_path / "credentials.json")) + store.save( + CredentialRecord( + principal_id="agent:expired", + refresh_token="expired-token", + expires_at_epoch_s=int(time.time()) - 10, + ) + ) + assert store.get("agent:expired") is None + + +def test_credential_store_handles_corrupt_json_file(tmp_path: Path) -> None: + file_path = tmp_path / "credentials.json" + file_path.write_text("{this-is-not-json}", encoding="utf-8") + store = LocalCredentialStore(str(file_path)) + assert store.get("agent:any") is None