From b5837844767372af2e7893f19f611ca8467d2d7c Mon Sep 17 00:00:00 2001 From: SentienceDEV Date: Mon, 16 Feb 2026 21:34:25 -0800 Subject: [PATCH 1/2] local IdP --- README.md | 47 +++++ docs/authorityd-operations.md | 128 ++++++++++++++ predicate_authority/README.md | 3 +- predicate_authority/__init__.py | 16 ++ predicate_authority/bridge.py | 94 ++++++++++ predicate_authority/control_plane.py | 197 +++++++++++++++++++++ predicate_authority/daemon.py | 218 +++++++++++++++++++++++- predicate_authority/sidecar.py | 31 ++++ tests/test_control_plane_integration.py | 169 ++++++++++++++++++ tests/test_daemon_phase2.py | 148 +++++++++++++++- tests/test_identity_bridge_phase2.py | 42 +++++ 11 files changed, 1085 insertions(+), 8 deletions(-) create mode 100644 docs/authorityd-operations.md create mode 100644 predicate_authority/control_plane.py create mode 100644 tests/test_control_plane_integration.py diff --git a/README.md b/README.md index 445c4fd..ba3325c 100644 --- a/README.md +++ b/README.md @@ -150,6 +150,53 @@ predicate-authority revoke intent --host 127.0.0.1 --port 8787 --hash `/v1/audit/events:batch` +- usage credits -> `/v1/metering/usage:batch` + +Expected startup output: + +```text +predicate-authorityd listening on http://127.0.0.1:8787 (mode=local_only) +``` + +## 3) Endpoint checks + +### Health + +```bash +curl -s http://127.0.0.1:8787/health | jq +``` + +Example response: + +```json +{ + "status": "ok", + "mode": "local_only", + "uptime_s": 12 +} +``` + +### Status + +```bash +curl -s http://127.0.0.1:8787/status | jq +``` + +Example response: + +```json +{ + "mode": "local_only", + "policy_hot_reload_enabled": true, + "revoked_principal_count": 0, + "revoked_intent_count": 0, + "revoked_mandate_count": 0, + "proof_event_count": 0, + "daemon_running": true, + "policy_reload_count": 1, + "policy_poll_error_count": 0, + "last_policy_reload_epoch_s": 1700000000.0, + "last_policy_poll_error": null +} +``` + +## 4) Verify policy hot-reload + +1. Update `examples/authorityd/policy.json`. +2. Wait for at most `--policy-poll-interval-s`. +3. Check `/status` and confirm `policy_reload_count` increases. + +## 5) Stop daemon + +Press `Ctrl+C` in the daemon terminal. diff --git a/predicate_authority/README.md b/predicate_authority/README.md index e933530..ebc5797 100644 --- a/predicate_authority/README.md +++ b/predicate_authority/README.md @@ -8,4 +8,5 @@ Core pieces: - `ActionGuard` for pre-action `authorize` / `enforce`, - `LocalMandateSigner` for signed short-lived mandates, - `InMemoryProofLedger` and optional `OpenTelemetryTraceEmitter`, -- typed integration adapters (including `sdk-python` mapping helpers). +- typed integration adapters (including `sdk-python` mapping helpers), +- control-plane client primitives for shipping proof and usage batches to hosted APIs. diff --git a/predicate_authority/__init__.py b/predicate_authority/__init__.py index b431480..1e569ac 100644 --- a/predicate_authority/__init__.py +++ b/predicate_authority/__init__.py @@ -3,10 +3,19 @@ EntraIdentityBridge, IdentityBridge, IdentityProviderType, + LocalIdPBridge, + LocalIdPBridgeConfig, OIDCBridgeConfig, OIDCIdentityBridge, TokenExchangeResult, ) +from predicate_authority.control_plane import ( + AuditEventEnvelope, + ControlPlaneClient, + ControlPlaneClientConfig, + ControlPlaneTraceEmitter, + UsageCreditRecord, +) from predicate_authority.daemon import DaemonConfig, PredicateAuthorityDaemon from predicate_authority.errors import AuthorizationDeniedError from predicate_authority.guard import ActionExecutionResult, ActionGuard @@ -30,6 +39,10 @@ "ActionGuard", "AuthorityMode", "AuthorizationDeniedError", + "AuditEventEnvelope", + "ControlPlaneClient", + "ControlPlaneClientConfig", + "ControlPlaneTraceEmitter", "CredentialRecord", "DaemonConfig", "EntraBridgeConfig", @@ -37,6 +50,8 @@ "IdentityBridge", "IdentityProviderType", "InMemoryProofLedger", + "LocalIdPBridge", + "LocalIdPBridgeConfig", "LocalCredentialStore", "LocalMandateSigner", "LocalRevocationCache", @@ -53,4 +68,5 @@ "SidecarError", "SidecarStatus", "TokenExchangeResult", + "UsageCreditRecord", ] diff --git a/predicate_authority/bridge.py b/predicate_authority/bridge.py index 8460e2a..0786666 100644 --- a/predicate_authority/bridge.py +++ b/predicate_authority/bridge.py @@ -1,7 +1,11 @@ from __future__ import annotations +import base64 import hashlib +import hmac +import json import time +from collections.abc import Mapping from dataclasses import dataclass from enum import Enum @@ -10,6 +14,7 @@ class IdentityProviderType(str, Enum): LOCAL = "local" + LOCAL_IDP = "local_idp" OIDC = "oidc" ENTRA = "entra" OKTA = "okta" @@ -39,6 +44,14 @@ class EntraBridgeConfig: token_ttl_seconds: int = 300 +@dataclass(frozen=True) +class LocalIdPBridgeConfig: + issuer: str = "http://localhost/predicate-local-idp" + audience: str = "api://predicate-authority" + signing_key: str = "predicate-local-idp-dev-key" + token_ttl_seconds: int = 300 + + class IdentityBridge: """Local bridge implementation for development/local-only mode.""" @@ -120,3 +133,84 @@ def exchange_token( token_type=result.token_type, provider=IdentityProviderType.ENTRA, ) + + +class LocalIdPBridge: + """Local IdP emulator for dev/offline/air-gapped workflows.""" + + def __init__(self, config: LocalIdPBridgeConfig) -> None: + self._config = config + + def exchange_token( + self, subject: PrincipalRef, state_evidence: StateEvidence + ) -> TokenExchangeResult: + expires_at = int(time.time()) + self._config.token_ttl_seconds + token = self._mint_token( + subject=subject, + state_evidence=state_evidence, + expires_at_epoch_s=expires_at, + grant_kind="access", + refresh_token=None, + ) + return TokenExchangeResult( + access_token=token, + expires_at_epoch_s=expires_at, + provider=IdentityProviderType.LOCAL_IDP, + ) + + def refresh_token( + self, refresh_token: str, subject: PrincipalRef, state_evidence: StateEvidence + ) -> TokenExchangeResult: + expires_at = int(time.time()) + self._config.token_ttl_seconds + token = self._mint_token( + subject=subject, + state_evidence=state_evidence, + expires_at_epoch_s=expires_at, + grant_kind="refresh_access", + refresh_token=refresh_token, + ) + return TokenExchangeResult( + access_token=token, + expires_at_epoch_s=expires_at, + provider=IdentityProviderType.LOCAL_IDP, + ) + + def _mint_token( + self, + subject: PrincipalRef, + state_evidence: StateEvidence, + expires_at_epoch_s: int, + grant_kind: str, + refresh_token: str | None, + ) -> str: + header = {"alg": "HS256", "typ": "JWT", "kid": "predicate-local-idp-dev"} + payload: dict[str, str | int | None] = { + "iss": self._config.issuer, + "aud": self._config.audience, + "sub": subject.principal_id, + "state_hash": state_evidence.state_hash, + "state_source": state_evidence.source, + "token_kind": grant_kind, + "exp": expires_at_epoch_s, + "iat": int(time.time()), + "tenant_id": subject.tenant_id, + "session_id": subject.session_id, + "refresh_token_hash": ( + hashlib.sha256(refresh_token.encode("utf-8")).hexdigest() + if refresh_token is not None + else None + ), + } + header_b64 = _b64url_json(header) + payload_b64 = _b64url_json(payload) + signing_input = f"{header_b64}.{payload_b64}".encode() + signature = hmac.new( + self._config.signing_key.encode("utf-8"), signing_input, hashlib.sha256 + ).digest() + signature_b64 = base64.urlsafe_b64encode(signature).rstrip(b"=").decode("utf-8") + return f"{header_b64}.{payload_b64}.{signature_b64}" + + +def _b64url_json(value: Mapping[str, str | int | None]) -> str: + encoded = json.dumps(value, separators=(",", ":")).encode("utf-8") + return base64.urlsafe_b64encode(encoded).rstrip(b"=").decode("utf-8") diff --git a/predicate_authority/control_plane.py b/predicate_authority/control_plane.py new file mode 100644 index 0000000..21abd4c --- /dev/null +++ b/predicate_authority/control_plane.py @@ -0,0 +1,197 @@ +from __future__ import annotations + +import hashlib +import http.client +import json +import time +from collections.abc import Mapping +from dataclasses import asdict, dataclass +from datetime import datetime, timezone +from urllib.parse import urlsplit + +from predicate_contracts import ProofEvent, TraceEmitter + + +@dataclass(frozen=True) +class ControlPlaneClientConfig: + base_url: str + tenant_id: str + project_id: str + auth_token: str | None = None + timeout_s: float = 2.0 + max_retries: int = 2 + backoff_initial_s: float = 0.2 + fail_open: bool = True + + +@dataclass(frozen=True) +class AuditEventEnvelope: + event_id: str + tenant_id: str + principal_id: str + action: str + resource: str + allowed: bool + reason: str + mandate_id: str | None = None + timestamp: str = "" + trace_id: str | None = None + + @staticmethod + def from_proof_event( + event: ProofEvent, tenant_id: str, trace_id: str | None = None + ) -> AuditEventEnvelope: + timestamp = datetime.fromtimestamp(event.emitted_at_epoch_s, tz=timezone.utc).isoformat() + event_id_seed = ( + f"{event.principal_id}|{event.action}|{event.resource}|" + f"{event.emitted_at_epoch_s}|{event.allowed}|{event.reason.value}" + ) + event_id = "evt_" + hashlib.sha256(event_id_seed.encode("utf-8")).hexdigest()[:16] + return AuditEventEnvelope( + event_id=event_id, + tenant_id=tenant_id, + principal_id=event.principal_id, + action=event.action, + resource=event.resource, + allowed=event.allowed, + reason=event.reason.value, + mandate_id=event.mandate_id, + timestamp=timestamp, + trace_id=trace_id, + ) + + +@dataclass(frozen=True) +class UsageCreditRecord: + tenant_id: str + project_id: str + action_type: str + credits: int + timestamp: str + + @staticmethod + def authority_check(tenant_id: str, project_id: str, credits: int = 1) -> UsageCreditRecord: + return UsageCreditRecord( + tenant_id=tenant_id, + project_id=project_id, + action_type="authority_check", + credits=credits, + timestamp=datetime.now(tz=timezone.utc).isoformat(), + ) + + +class ControlPlaneClient: + def __init__(self, config: ControlPlaneClientConfig) -> None: + self.config = config + self._base = urlsplit(config.base_url) + if self._base.scheme not in {"http", "https"}: + raise ValueError("base_url must use http or https scheme") + if self._base.netloc == "": + raise ValueError("base_url must include host:port") + + def send_audit_events(self, events: tuple[AuditEventEnvelope, ...]) -> bool: + payload = {"events": [asdict(event) for event in events]} + return self._post_json("/v1/audit/events:batch", payload) + + def send_usage_records(self, records: tuple[UsageCreditRecord, ...]) -> bool: + payload = {"records": [asdict(record) for record in records]} + return self._post_json("/v1/metering/usage:batch", payload) + + def _post_json(self, path: str, payload: Mapping[str, object]) -> bool: + attempts = self.config.max_retries + 1 + for attempt in range(attempts): + try: + self._post_json_once(path, payload) + return True + except Exception as exc: + is_last_attempt = attempt == attempts - 1 + if is_last_attempt: + if self.config.fail_open: + return False + raise RuntimeError(f"control-plane request failed: {path}") from exc + time.sleep(self.config.backoff_initial_s * (2**attempt)) + return False + + def _post_json_once(self, path: str, payload: Mapping[str, object]) -> None: + target_path = path if path.startswith("/") else f"/{path}" + connection = self._new_connection() + headers = {"Content-Type": "application/json"} + if self.config.auth_token: + headers["Authorization"] = f"Bearer {self.config.auth_token}" + body = json.dumps(payload) + try: + connection.request("POST", target_path, body=body, headers=headers) + response = connection.getresponse() + content = response.read().decode("utf-8") + finally: + connection.close() + if response.status >= 400: + raise RuntimeError(f"HTTP {response.status}: {content}") + + def _new_connection(self) -> http.client.HTTPConnection: + if self._base.scheme == "https": + return http.client.HTTPSConnection(self._base.netloc, timeout=self.config.timeout_s) + return http.client.HTTPConnection(self._base.netloc, timeout=self.config.timeout_s) + + +@dataclass +class ControlPlaneTraceEmitter(TraceEmitter): + client: ControlPlaneClient + trace_id: str | None = None + emit_usage_credits: bool = True + usage_credits_per_decision: int = 1 + audit_push_success_count: int = 0 + audit_push_failure_count: int = 0 + usage_push_success_count: int = 0 + usage_push_failure_count: int = 0 + last_push_error: str | None = None + + def emit(self, event: ProofEvent) -> None: + audit_event = AuditEventEnvelope.from_proof_event( + event=event, tenant_id=self.client.config.tenant_id, trace_id=self.trace_id + ) + self._send_audit_event(audit_event) + if self.emit_usage_credits: + usage = UsageCreditRecord.authority_check( + tenant_id=self.client.config.tenant_id, + project_id=self.client.config.project_id, + credits=self.usage_credits_per_decision, + ) + self._send_usage_record(usage) + + def status_payload(self) -> dict[str, int | str | None]: + return { + "control_plane_audit_push_success_count": self.audit_push_success_count, + "control_plane_audit_push_failure_count": self.audit_push_failure_count, + "control_plane_usage_push_success_count": self.usage_push_success_count, + "control_plane_usage_push_failure_count": self.usage_push_failure_count, + "control_plane_last_push_error": self.last_push_error, + } + + def _send_audit_event(self, audit_event: AuditEventEnvelope) -> None: + try: + sent = self.client.send_audit_events((audit_event,)) + if sent: + self.audit_push_success_count += 1 + self.last_push_error = None + else: + self.audit_push_failure_count += 1 + self.last_push_error = "audit_push_failed" + except Exception as exc: + self.audit_push_failure_count += 1 + self.last_push_error = str(exc) + raise + + def _send_usage_record(self, usage: UsageCreditRecord) -> None: + try: + sent = self.client.send_usage_records((usage,)) + if sent: + self.usage_push_success_count += 1 + self.last_push_error = None + else: + self.usage_push_failure_count += 1 + self.last_push_error = "usage_push_failed" + except Exception as exc: + self.usage_push_failure_count += 1 + self.last_push_error = str(exc) + raise diff --git a/predicate_authority/daemon.py b/predicate_authority/daemon.py index 4bea636..fab4f56 100644 --- a/predicate_authority/daemon.py +++ b/predicate_authority/daemon.py @@ -2,6 +2,7 @@ import argparse import json +import os import secrets import threading import time @@ -11,14 +12,32 @@ from typing import Any from urllib.parse import urlparse -from predicate_authority.bridge import IdentityBridge +from predicate_authority.bridge import ( + EntraBridgeConfig, + EntraIdentityBridge, + IdentityBridge, + LocalIdPBridge, + LocalIdPBridgeConfig, + OIDCBridgeConfig, + OIDCIdentityBridge, +) +from predicate_authority.control_plane import ( + ControlPlaneClient, + ControlPlaneClientConfig, + ControlPlaneTraceEmitter, +) from predicate_authority.guard import ActionGuard from predicate_authority.mandate import LocalMandateSigner from predicate_authority.policy import PolicyEngine from predicate_authority.policy_source import PolicyFileSource from predicate_authority.proof import InMemoryProofLedger from predicate_authority.revocation import LocalRevocationCache -from predicate_authority.sidecar import AuthorityMode, PredicateAuthoritySidecar, SidecarConfig +from predicate_authority.sidecar import ( + AuthorityMode, + ExchangeTokenBridge, + PredicateAuthoritySidecar, + SidecarConfig, +) from predicate_authority.sidecar_store import LocalCredentialStore from predicate_contracts import PolicyRule @@ -30,6 +49,20 @@ class DaemonConfig: policy_poll_interval_s: float = 2.0 +@dataclass(frozen=True) +class ControlPlaneBootstrapConfig: + enabled: bool = False + base_url: str | None = None + tenant_id: str = "dev-tenant" + project_id: str = "dev-project" + auth_token: str | None = None + timeout_s: float = 2.0 + max_retries: int = 2 + backoff_initial_s: float = 0.2 + fail_open: bool = True + usage_credits_per_decision: int = 1 + + @dataclass class DaemonRuntime: started_at_epoch_s: float @@ -213,13 +246,42 @@ def _policy_poll_loop(self) -> None: def _build_default_sidecar( - mode: AuthorityMode, policy_file: str | None, credential_store_file: str + mode: AuthorityMode, + policy_file: str | None, + credential_store_file: str, + control_plane_config: ControlPlaneBootstrapConfig | None = None, + identity_bridge: ExchangeTokenBridge | None = None, ) -> PredicateAuthoritySidecar: policy_rules: tuple[PolicyRule, ...] = () if policy_file is not None and Path(policy_file).exists(): policy_rules = PolicyFileSource(policy_file).load_rules() policy_engine = PolicyEngine(rules=policy_rules) - proof_ledger = InMemoryProofLedger() + + trace_emitter = None + if ( + control_plane_config is not None + and control_plane_config.enabled + and control_plane_config.base_url is not None + ): + control_plane_client = ControlPlaneClient( + config=ControlPlaneClientConfig( + base_url=control_plane_config.base_url, + tenant_id=control_plane_config.tenant_id, + project_id=control_plane_config.project_id, + auth_token=control_plane_config.auth_token, + timeout_s=control_plane_config.timeout_s, + max_retries=control_plane_config.max_retries, + backoff_initial_s=control_plane_config.backoff_initial_s, + fail_open=control_plane_config.fail_open, + ) + ) + trace_emitter = ControlPlaneTraceEmitter( + client=control_plane_client, + emit_usage_credits=True, + usage_credits_per_decision=control_plane_config.usage_credits_per_decision, + ) + proof_ledger = InMemoryProofLedger(trace_emitter=trace_emitter) + guard = ActionGuard( policy_engine=policy_engine, mandate_signer=LocalMandateSigner(secret_key=secrets.token_hex(32)), @@ -229,13 +291,60 @@ def _build_default_sidecar( config=SidecarConfig(mode=mode, policy_file_path=policy_file), action_guard=guard, proof_ledger=proof_ledger, - identity_bridge=IdentityBridge(), + identity_bridge=identity_bridge or IdentityBridge(), credential_store=LocalCredentialStore(credential_store_file), revocation_cache=LocalRevocationCache(), policy_engine=policy_engine, ) +def _build_identity_bridge_from_args(args: argparse.Namespace) -> ExchangeTokenBridge: + mode = str(args.identity_mode) + if mode == "local": + return IdentityBridge(token_ttl_seconds=int(args.idp_token_ttl_s)) + if mode == "local-idp": + signing_key = os.getenv(args.local_idp_signing_key_env, "predicate-local-idp-dev-key") + return LocalIdPBridge( + LocalIdPBridgeConfig( + issuer=str(args.local_idp_issuer), + audience=str(args.local_idp_audience), + signing_key=signing_key, + token_ttl_seconds=int(args.idp_token_ttl_s), + ) + ) + if mode == "oidc": + if args.oidc_issuer is None or args.oidc_client_id is None or args.oidc_audience is None: + raise SystemExit( + "identity-mode=oidc requires --oidc-issuer, --oidc-client-id, and --oidc-audience." + ) + return OIDCIdentityBridge( + OIDCBridgeConfig( + issuer=str(args.oidc_issuer), + client_id=str(args.oidc_client_id), + audience=str(args.oidc_audience), + token_ttl_seconds=int(args.idp_token_ttl_s), + ) + ) + if mode == "entra": + if ( + args.entra_tenant_id is None + or args.entra_client_id is None + or args.entra_audience is None + ): + raise SystemExit( + "identity-mode=entra requires --entra-tenant-id, --entra-client-id, and --entra-audience." + ) + return EntraIdentityBridge( + EntraBridgeConfig( + tenant_id=str(args.entra_tenant_id), + client_id=str(args.entra_client_id), + audience=str(args.entra_audience), + token_ttl_seconds=int(args.idp_token_ttl_s), + ) + ) + raise SystemExit(f"Unsupported identity mode: {mode}") + + def main() -> None: parser = argparse.ArgumentParser(description="predicate-authorityd sidecar daemon") parser.add_argument("--host", default="127.0.0.1") @@ -251,13 +360,109 @@ def main() -> None: "--credential-store-file", default=str(Path.home() / ".predicate-authorityd" / "credentials.json"), ) + parser.add_argument( + "--identity-mode", + choices=["local", "local-idp", "oidc", "entra"], + default="local", + help="Identity source for token exchange: local, local-idp, oidc, or entra.", + ) + parser.add_argument("--idp-token-ttl-s", type=int, default=300) + parser.add_argument( + "--local-idp-issuer", + default=os.getenv("LOCAL_IDP_ISSUER", "http://localhost/predicate-local-idp"), + ) + parser.add_argument( + "--local-idp-audience", + default=os.getenv("LOCAL_IDP_AUDIENCE", "api://predicate-authority"), + ) + parser.add_argument( + "--local-idp-signing-key-env", + default="LOCAL_IDP_SIGNING_KEY", + help="Env var name for Local IdP signing key.", + ) + parser.add_argument("--oidc-issuer", default=os.getenv("OIDC_ISSUER")) + parser.add_argument("--oidc-client-id", default=os.getenv("OIDC_CLIENT_ID")) + parser.add_argument("--oidc-audience", default=os.getenv("OIDC_AUDIENCE")) + parser.add_argument("--entra-tenant-id", default=os.getenv("ENTRA_TENANT_ID")) + parser.add_argument("--entra-client-id", default=os.getenv("ENTRA_CLIENT_ID")) + parser.add_argument("--entra-audience", default=os.getenv("ENTRA_AUDIENCE")) + parser.add_argument( + "--control-plane-enabled", + action="store_true", + help="Enable control-plane audit/usage shipping via trace emitter.", + ) + parser.add_argument( + "--control-plane-url", + default=None, + help="Control plane base URL (e.g. https://authority.example.com).", + ) + parser.add_argument( + "--control-plane-tenant-id", + default=None, + help="Tenant ID for emitted audit/usage records.", + ) + parser.add_argument( + "--control-plane-project-id", + default=None, + help="Project ID for emitted usage records.", + ) + parser.add_argument( + "--control-plane-auth-token-env", + default="CONTROL_PLANE_AUTH_TOKEN", + help="Env var name that stores Bearer token for control-plane APIs.", + ) + parser.add_argument("--control-plane-timeout-s", type=float, default=2.0) + parser.add_argument("--control-plane-max-retries", type=int, default=2) + parser.add_argument("--control-plane-backoff-initial-s", type=float, default=0.2) + parser.add_argument( + "--control-plane-fail-open", + action="store_true", + help="If true, local authorization continues when control-plane push fails.", + ) + parser.add_argument( + "--control-plane-fail-closed", + dest="control_plane_fail_open", + action="store_false", + help="If set, control-plane push failures become hard errors.", + ) + parser.set_defaults(control_plane_fail_open=True) + parser.add_argument("--control-plane-usage-credits-per-decision", type=int, default=1) args = parser.parse_args() mode = AuthorityMode(args.mode) + control_plane_auth_token = os.getenv(args.control_plane_auth_token_env) + control_plane_url = args.control_plane_url or os.getenv("CONTROL_PLANE_URL") + control_plane_tenant = args.control_plane_tenant_id or os.getenv( + "CONTROL_PLANE_TENANT_ID", "dev-tenant" + ) + control_plane_project = args.control_plane_project_id or os.getenv( + "CONTROL_PLANE_PROJECT_ID", "dev-project" + ) + control_plane_enabled = bool(args.control_plane_enabled) + if control_plane_enabled and (control_plane_url is None or control_plane_url.strip() == ""): + raise SystemExit( + "control-plane is enabled but no URL provided. " + "Set --control-plane-url or CONTROL_PLANE_URL." + ) + control_plane_bootstrap = ControlPlaneBootstrapConfig( + enabled=control_plane_enabled, + base_url=control_plane_url, + tenant_id=control_plane_tenant, + project_id=control_plane_project, + auth_token=control_plane_auth_token, + timeout_s=args.control_plane_timeout_s, + max_retries=args.control_plane_max_retries, + backoff_initial_s=args.control_plane_backoff_initial_s, + fail_open=bool(args.control_plane_fail_open), + usage_credits_per_decision=max(0, int(args.control_plane_usage_credits_per_decision)), + ) + identity_bridge = _build_identity_bridge_from_args(args) sidecar = _build_default_sidecar( mode=mode, policy_file=args.policy_file, credential_store_file=args.credential_store_file, + control_plane_config=control_plane_bootstrap, + identity_bridge=identity_bridge, ) daemon = PredicateAuthorityDaemon( sidecar=sidecar, @@ -270,7 +475,8 @@ def main() -> None: daemon.start() print( f"predicate-authorityd listening on http://{args.host}:{daemon.bound_port} " - f"(mode={mode.value})" + f"(mode={mode.value}, identity_mode={args.identity_mode}, " + f"control_plane_enabled={control_plane_enabled})" ) try: while True: diff --git a/predicate_authority/sidecar.py b/predicate_authority/sidecar.py index a1fa77a..020cec7 100644 --- a/predicate_authority/sidecar.py +++ b/predicate_authority/sidecar.py @@ -5,6 +5,7 @@ from typing import Protocol, cast from predicate_authority.bridge import TokenExchangeResult +from predicate_authority.control_plane import ControlPlaneTraceEmitter from predicate_authority.guard import ActionGuard from predicate_authority.policy import PolicyEngine from predicate_authority.policy_source import PolicyFileSource @@ -39,6 +40,12 @@ class SidecarStatus: revoked_intent_count: int revoked_mandate_count: int proof_event_count: int + control_plane_emitter_attached: bool + control_plane_audit_push_success_count: int = 0 + control_plane_audit_push_failure_count: int = 0 + control_plane_usage_push_success_count: int = 0 + control_plane_usage_push_failure_count: int = 0 + control_plane_last_push_error: str | None = None class SidecarError(RuntimeError): @@ -144,6 +151,12 @@ def hot_reload_policy(self) -> bool: return False def status(self) -> SidecarStatus: + trace_emitter = self._proof_ledger.trace_emitter + control_plane_payload: dict[str, int | str | None] = {} + control_plane_attached = False + if isinstance(trace_emitter, ControlPlaneTraceEmitter): + control_plane_attached = True + control_plane_payload = trace_emitter.status_payload() return SidecarStatus( mode=self._config.mode, policy_hot_reload_enabled=self._policy_source is not None, @@ -151,4 +164,22 @@ def status(self) -> SidecarStatus: revoked_intent_count=len(self._revocation_cache.revoked_intent_hashes), revoked_mandate_count=len(self._revocation_cache.revoked_mandate_ids), proof_event_count=len(self._proof_ledger.events), + control_plane_emitter_attached=control_plane_attached, + control_plane_audit_push_success_count=int( + control_plane_payload.get("control_plane_audit_push_success_count", 0) + ), + control_plane_audit_push_failure_count=int( + control_plane_payload.get("control_plane_audit_push_failure_count", 0) + ), + control_plane_usage_push_success_count=int( + control_plane_payload.get("control_plane_usage_push_success_count", 0) + ), + control_plane_usage_push_failure_count=int( + control_plane_payload.get("control_plane_usage_push_failure_count", 0) + ), + control_plane_last_push_error=( + str(control_plane_payload["control_plane_last_push_error"]) + if control_plane_payload.get("control_plane_last_push_error") is not None + else None + ), ) diff --git a/tests/test_control_plane_integration.py b/tests/test_control_plane_integration.py new file mode 100644 index 0000000..44fe588 --- /dev/null +++ b/tests/test_control_plane_integration.py @@ -0,0 +1,169 @@ +from __future__ import annotations + +import json +import threading +from dataclasses import dataclass, field +from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer +from typing import Any + +from predicate_authority.control_plane import ( + AuditEventEnvelope, + ControlPlaneClient, + ControlPlaneClientConfig, + ControlPlaneTraceEmitter, + UsageCreditRecord, +) +from predicate_contracts import AuthorizationReason, ProofEvent + + +@dataclass +class Recorder: + paths: list[str] = field(default_factory=list) + payloads: list[dict[str, Any]] = field(default_factory=list) + headers: list[dict[str, str]] = field(default_factory=list) + + +class _Handler(BaseHTTPRequestHandler): + recorder: Recorder + + def do_POST(self) -> None: # noqa: N802 + raw_length = self.headers.get("Content-Length", "0") + content_length = int(raw_length) if raw_length.isdigit() else 0 + content = self.rfile.read(content_length).decode("utf-8") if content_length > 0 else "{}" + payload = json.loads(content) + assert isinstance(payload, dict) + self.recorder.paths.append(self.path) + self.recorder.payloads.append(payload) + self.recorder.headers.append(dict(self.headers.items())) + + self.send_response(200) + self.send_header("Content-Type", "application/json") + self.end_headers() + self.wfile.write(b"{}") + + def log_message(self, format: str, *args: Any) -> None: # noqa: A003 + return + + +def _start_server(recorder: Recorder) -> tuple[ThreadingHTTPServer, threading.Thread]: + class BoundHandler(_Handler): + pass + + BoundHandler.recorder = recorder + server = ThreadingHTTPServer(("127.0.0.1", 0), BoundHandler) + thread = threading.Thread(target=server.serve_forever, daemon=True) + thread.start() + return server, thread + + +def test_control_plane_client_posts_audit_and_usage() -> None: + recorder = Recorder() + server, _ = _start_server(recorder) + try: + base_url = f"http://127.0.0.1:{server.server_port}" + client = ControlPlaneClient( + ControlPlaneClientConfig( + base_url=base_url, + tenant_id="tenant-a", + project_id="project-a", + auth_token="token-123", + fail_open=False, + ) + ) + sent_audit = client.send_audit_events( + ( + AuditEventEnvelope( + event_id="evt_1", + tenant_id="tenant-a", + principal_id="agent:orders-1", + action="http.post", + resource="https://api.vendor.com/orders", + allowed=True, + reason="allowed", + timestamp="2026-01-01T00:00:00+00:00", + ), + ) + ) + sent_usage = client.send_usage_records( + ( + UsageCreditRecord( + tenant_id="tenant-a", + project_id="project-a", + action_type="authority_check", + credits=1, + timestamp="2026-01-01T00:00:00+00:00", + ), + ) + ) + assert sent_audit is True + assert sent_usage is True + assert "/v1/audit/events:batch" in recorder.paths + assert "/v1/metering/usage:batch" in recorder.paths + assert any( + headers.get("Authorization") == "Bearer token-123" for headers in recorder.headers + ) + finally: + server.shutdown() + server.server_close() + + +def test_control_plane_trace_emitter_sends_from_proof_event() -> None: + recorder = Recorder() + server, _ = _start_server(recorder) + try: + client = ControlPlaneClient( + ControlPlaneClientConfig( + base_url=f"http://127.0.0.1:{server.server_port}", + tenant_id="tenant-z", + project_id="project-z", + fail_open=False, + ) + ) + emitter = ControlPlaneTraceEmitter(client=client, trace_id="trace-1") + event = ProofEvent( + event_type="authority.decision", + principal_id="agent:test", + action="http.post", + resource="https://api.vendor.com/orders", + reason=AuthorizationReason.ALLOWED, + allowed=True, + mandate_id="mandate-1", + emitted_at_epoch_s=1_700_000_000, + ) + emitter.emit(event) + assert len(recorder.paths) == 2 + assert recorder.paths[0] == "/v1/audit/events:batch" + assert recorder.paths[1] == "/v1/metering/usage:batch" + events_payload = recorder.payloads[0]["events"] + assert isinstance(events_payload, list) + assert events_payload[0]["tenant_id"] == "tenant-z" + finally: + server.shutdown() + server.server_close() + + +def test_control_plane_client_fail_open_returns_false() -> None: + client = ControlPlaneClient( + ControlPlaneClientConfig( + base_url="http://127.0.0.1:65531", + tenant_id="tenant-a", + project_id="project-a", + max_retries=0, + fail_open=True, + ) + ) + result = client.send_audit_events( + ( + AuditEventEnvelope( + event_id="evt_1", + tenant_id="tenant-a", + principal_id="agent:1", + action="http.post", + resource="https://api.vendor.com/orders", + allowed=True, + reason="allowed", + timestamp="2026-01-01T00:00:00+00:00", + ), + ) + ) + assert result is False diff --git a/tests/test_daemon_phase2.py b/tests/test_daemon_phase2.py index e815529..153593d 100644 --- a/tests/test_daemon_phase2.py +++ b/tests/test_daemon_phase2.py @@ -2,8 +2,13 @@ import http.client import json +import os +import threading import time +from argparse import Namespace +from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer from pathlib import Path +from typing import Any from urllib.parse import urlsplit from predicate_authority import ( @@ -20,7 +25,20 @@ PredicateAuthoritySidecar, SidecarConfig, ) -from predicate_contracts import PolicyEffect, PolicyRule +from predicate_authority.daemon import ( + ControlPlaneBootstrapConfig, + _build_default_sidecar, + _build_identity_bridge_from_args, +) +from predicate_contracts import ( + ActionRequest, + ActionSpec, + PolicyEffect, + PolicyRule, + PrincipalRef, + StateEvidence, + VerificationEvidence, +) # pylint: disable=import-error @@ -188,3 +206,131 @@ def test_daemon_supports_policy_reload_and_revoke_endpoints(tmp_path: Path) -> N assert int(status["revoked_intent_count"]) >= 1 finally: daemon.stop() + + +class _ControlPlaneHandler(BaseHTTPRequestHandler): + requests: list[tuple[str, dict[str, object], dict[str, str]]] + + def do_POST(self) -> None: # noqa: N802 + raw_length = self.headers.get("Content-Length", "0") + content_length = int(raw_length) if raw_length.isdigit() else 0 + payload_raw = ( + self.rfile.read(content_length).decode("utf-8") if content_length > 0 else "{}" + ) + loaded = json.loads(payload_raw) + assert isinstance(loaded, dict) + self.requests.append((self.path, loaded, dict(self.headers.items()))) + self.send_response(200) + self.send_header("Content-Type", "application/json") + self.end_headers() + self.wfile.write(b"{}") + + def log_message(self, format: str, *args: Any) -> None: # noqa: A003 + return + + +def _start_control_plane_server() -> tuple[ThreadingHTTPServer, threading.Thread]: + class BoundHandler(_ControlPlaneHandler): + requests = [] + + server = ThreadingHTTPServer(("127.0.0.1", 0), BoundHandler) + thread = threading.Thread(target=server.serve_forever, daemon=True) + thread.start() + return server, thread + + +def test_daemon_bootstrap_wires_control_plane_emitter(tmp_path: Path) -> None: + policy_file = tmp_path / "policy.json" + policy_file.write_text( + json.dumps( + { + "rules": [ + { + "name": "allow-any-http", + "effect": "allow", + "principals": ["agent:*"], + "actions": ["http.*"], + "resources": ["https://*/*"], + } + ] + } + ), + encoding="utf-8", + ) + server, _ = _start_control_plane_server() + daemon: PredicateAuthorityDaemon | None = None + try: + base_url = f"http://127.0.0.1:{server.server_port}" + sidecar = _build_default_sidecar( + mode=AuthorityMode.LOCAL_ONLY, + policy_file=str(policy_file), + credential_store_file=str(tmp_path / "credentials.json"), + control_plane_config=ControlPlaneBootstrapConfig( + enabled=True, + base_url=base_url, + tenant_id="tenant-a", + project_id="project-a", + auth_token="test-token", + fail_open=False, + ), + ) + daemon = PredicateAuthorityDaemon( + sidecar=sidecar, + config=DaemonConfig(host="127.0.0.1", port=0, policy_poll_interval_s=0.2), + ) + daemon.start() + request = ActionRequest( + principal=PrincipalRef(principal_id="agent:test"), + action_spec=ActionSpec( + action="http.post", + resource="https://api.vendor.com/orders", + intent="create order", + ), + state_evidence=StateEvidence(source="test", state_hash="abc123"), + verification_evidence=VerificationEvidence(), + ) + decision = sidecar.issue_mandate(request) + assert decision.allowed is True + # Emitter sends both audit and usage payloads. + handler_cls = server.RequestHandlerClass + requests = getattr(handler_cls, "requests") + assert isinstance(requests, list) + paths = [item[0] for item in requests] + assert "/v1/audit/events:batch" in paths + assert "/v1/metering/usage:batch" in paths + assert any(item[2].get("Authorization") == "Bearer test-token" for item in requests) + daemon_status = _fetch_json(f"http://127.0.0.1:{daemon.bound_port}/status") + assert daemon_status["control_plane_emitter_attached"] is True + assert int(daemon_status["control_plane_audit_push_success_count"]) >= 1 + assert int(daemon_status["control_plane_usage_push_success_count"]) >= 1 + assert int(daemon_status["control_plane_audit_push_failure_count"]) == 0 + assert int(daemon_status["control_plane_usage_push_failure_count"]) == 0 + finally: + if daemon is not None: + daemon.stop() + server.shutdown() + server.server_close() + + +def test_daemon_identity_mode_local_idp_builder() -> None: + os.environ["LOCAL_IDP_SIGNING_KEY"] = "daemon-local-idp-key" + args = Namespace( + identity_mode="local-idp", + idp_token_ttl_s=120, + local_idp_issuer="http://localhost/local-idp", + local_idp_audience="api://predicate-authority", + local_idp_signing_key_env="LOCAL_IDP_SIGNING_KEY", + oidc_issuer=None, + oidc_client_id=None, + oidc_audience=None, + entra_tenant_id=None, + entra_client_id=None, + entra_audience=None, + ) + bridge = _build_identity_bridge_from_args(args) + token = bridge.exchange_token( + PrincipalRef(principal_id="agent:test"), + StateEvidence(source="test", state_hash="state-1"), + ) + assert token.provider.value == "local_idp" + assert len(token.access_token.split(".")) == 3 diff --git a/tests/test_identity_bridge_phase2.py b/tests/test_identity_bridge_phase2.py index 5192ada..fa4d11c 100644 --- a/tests/test_identity_bridge_phase2.py +++ b/tests/test_identity_bridge_phase2.py @@ -1,8 +1,13 @@ from __future__ import annotations +import base64 +import json + from predicate_authority import ( EntraBridgeConfig, EntraIdentityBridge, + LocalIdPBridge, + LocalIdPBridgeConfig, OIDCBridgeConfig, OIDCIdentityBridge, ) @@ -44,3 +49,40 @@ def test_entra_bridge_marks_provider() -> None: result = bridge.exchange_token(subject, state) assert result.provider.value == "entra" + + +def test_local_idp_bridge_issues_jwt_like_token() -> None: + bridge = LocalIdPBridge( + LocalIdPBridgeConfig( + issuer="http://localhost/local-idp", + audience="api://predicate-authority", + signing_key="dev-signing-key", + token_ttl_seconds=120, + ) + ) + subject = PrincipalRef(principal_id="agent:local", tenant_id="tenant-a") + state = StateEvidence(source="backend", state_hash="state-abc") + + token_result = bridge.exchange_token(subject, state) + token = token_result.access_token + segments = token.split(".") + assert len(segments) == 3 + payload = _decode_jwt_payload(segments[1]) + assert payload["iss"] == "http://localhost/local-idp" + assert payload["aud"] == "api://predicate-authority" + assert payload["sub"] == "agent:local" + assert payload["state_hash"] == "state-abc" + assert token_result.provider.value == "local_idp" + + refreshed = bridge.refresh_token("refresh-123", subject, state) + refreshed_payload = _decode_jwt_payload(refreshed.access_token.split(".")[1]) + assert refreshed_payload["token_kind"] == "refresh_access" + + +def _decode_jwt_payload(payload_segment: str) -> dict[str, object]: + # Pad URL-safe base64 to standard length. + padding = "=" * (-len(payload_segment) % 4) + decoded = base64.urlsafe_b64decode((payload_segment + padding).encode("utf-8")) + loaded = json.loads(decoded.decode("utf-8")) + assert isinstance(loaded, dict) + return loaded From 347f22459eb7559c66b7d14940446ff5bdcdc22e Mon Sep 17 00:00:00 2001 From: SentienceDEV Date: Mon, 16 Feb 2026 22:08:14 -0800 Subject: [PATCH 2/2] local identity --- README.md | 40 +++ docs/authorityd-operations.md | 63 ++++ predicate_authority/README.md | 3 +- predicate_authority/__init__.py | 14 + predicate_authority/control_plane.py | 3 + predicate_authority/daemon.py | 451 ++++++++++++++++++++++++-- predicate_authority/local_identity.py | 397 +++++++++++++++++++++++ predicate_authority/sidecar.py | 31 ++ tests/test_daemon_phase2.py | 345 +++++++++++++++++++- tests/test_local_identity_registry.py | 63 ++++ 10 files changed, 1381 insertions(+), 29 deletions(-) create mode 100644 predicate_authority/local_identity.py create mode 100644 tests/test_local_identity_registry.py diff --git a/README.md b/README.md index ba3325c..4479e40 100644 --- a/README.md +++ b/README.md @@ -171,6 +171,46 @@ predicate-authorityd \ --local-idp-audience "api://predicate-authority" ``` +### Local identity registry (ephemeral + TTL + flush queue) + +Enable sidecar-managed local task identities and local ledger queue: + +```bash +PYTHONPATH=. predicate-authorityd \ + --host 127.0.0.1 \ + --port 8787 \ + --mode local_only \ + --policy-file examples/authorityd/policy.json \ + --identity-mode local-idp \ + --local-identity-enabled \ + --local-identity-registry-file ./.predicate-authorityd/local-identities.json \ + --local-identity-default-ttl-s 900 \ + --flush-worker-enabled \ + --flush-worker-interval-s 2.0 \ + --flush-worker-max-batch-size 50 \ + --flush-worker-dead-letter-max-attempts 5 +``` + +Runtime endpoints: + +- `POST /identity/task` (issue ephemeral task identity) +- `GET /identity/list` (list identities) +- `POST /identity/revoke` (revoke identity) +- `GET /ledger/flush-queue` (inspect pending local ledger queue) +- `GET /ledger/dead-letter` (list quarantined queue items only) +- `POST /ledger/flush-ack` (mark queue item as flushed) +- `POST /ledger/flush-now` (manually trigger immediate queue flush) +- `POST /ledger/requeue` (requeue quarantined item for retry) + +Background flush worker status fields: + +- `flush_cycle_count` +- `flush_sent_count` +- `flush_failed_count` +- `flush_quarantined_count` +- `last_flush_epoch_s` +- `last_flush_error` + ### How to run with control-plane shipping (out-of-the-box) ```bash diff --git a/docs/authorityd-operations.md b/docs/authorityd-operations.md index 6d1d149..f991ce9 100644 --- a/docs/authorityd-operations.md +++ b/docs/authorityd-operations.md @@ -69,6 +69,69 @@ authority decision pushes: - audit events -> `/v1/audit/events:batch` - usage credits -> `/v1/metering/usage:batch` +## 3b) Optional local identity registry (ephemeral task identities) + +Enable local identity support: + +```bash +PYTHONPATH=. predicate-authorityd \ + --host 127.0.0.1 \ + --port 8787 \ + --mode local_only \ + --policy-file examples/authorityd/policy.json \ + --identity-mode local-idp \ + --local-identity-enabled \ + --local-identity-registry-file ./.predicate-authorityd/local-identities.json \ + --local-identity-default-ttl-s 900 \ + --flush-worker-enabled \ + --flush-worker-interval-s 2.0 \ + --flush-worker-max-batch-size 50 \ + --flush-worker-dead-letter-max-attempts 5 +``` + +Issue an ephemeral identity: + +```bash +curl -s -X POST http://127.0.0.1:8787/identity/task \ + -H "Content-Type: application/json" \ + -d '{"principal_id":"agent:backend","task_id":"refactor-pr-102","ttl_seconds":120}' +``` + +Inspect pending local ledger flush queue: + +```bash +curl -s http://127.0.0.1:8787/ledger/flush-queue | jq +``` + +List quarantined dead-letter items only: + +```bash +curl -s http://127.0.0.1:8787/ledger/dead-letter | jq +``` + +Manually trigger an immediate flush cycle: + +```bash +curl -s -X POST http://127.0.0.1:8787/ledger/flush-now \ + -H "Content-Type: application/json" \ + -d '{"max_items":50}' | jq +``` + +Requeue a quarantined item for retry: + +```bash +curl -s -X POST http://127.0.0.1:8787/ledger/requeue \ + -H "Content-Type: application/json" \ + -d '{"queue_item_id":"q_abc123"}' | jq +``` + +Flush worker behavior: + +- reuses control-plane client retry policy (`--control-plane-max-retries`, `--control-plane-backoff-initial-s`), +- drains up to `--flush-worker-max-batch-size` queue items per cycle, +- quarantines entries after `--flush-worker-dead-letter-max-attempts` failed sends, +- sleeps `--flush-worker-interval-s` between flush cycles. + Expected startup output: ```text diff --git a/predicate_authority/README.md b/predicate_authority/README.md index ebc5797..d7ee07b 100644 --- a/predicate_authority/README.md +++ b/predicate_authority/README.md @@ -9,4 +9,5 @@ Core pieces: - `LocalMandateSigner` for signed short-lived mandates, - `InMemoryProofLedger` and optional `OpenTelemetryTraceEmitter`, - typed integration adapters (including `sdk-python` mapping helpers), -- control-plane client primitives for shipping proof and usage batches to hosted APIs. +- control-plane client primitives for shipping proof and usage batches to hosted APIs, +- local identity registry primitives (ephemeral task identities + local flush queue). diff --git a/predicate_authority/__init__.py b/predicate_authority/__init__.py index 1e569ac..64800e9 100644 --- a/predicate_authority/__init__.py +++ b/predicate_authority/__init__.py @@ -19,6 +19,14 @@ from predicate_authority.daemon import DaemonConfig, PredicateAuthorityDaemon from predicate_authority.errors import AuthorizationDeniedError from predicate_authority.guard import ActionExecutionResult, ActionGuard +from predicate_authority.local_identity import ( + CompositeTraceEmitter, + LedgerQueueItem, + LocalIdentityRegistry, + LocalIdentityRegistryStats, + LocalLedgerQueueEmitter, + TaskIdentityRecord, +) from predicate_authority.mandate import LocalMandateSigner from predicate_authority.policy import PolicyEngine, PolicyMatchResult from predicate_authority.policy_source import PolicyFileSource, PolicyReloadResult @@ -53,6 +61,9 @@ "LocalIdPBridge", "LocalIdPBridgeConfig", "LocalCredentialStore", + "LocalIdentityRegistry", + "LocalIdentityRegistryStats", + "LocalLedgerQueueEmitter", "LocalMandateSigner", "LocalRevocationCache", "OIDCBridgeConfig", @@ -68,5 +79,8 @@ "SidecarError", "SidecarStatus", "TokenExchangeResult", + "CompositeTraceEmitter", + "LedgerQueueItem", + "TaskIdentityRecord", "UsageCreditRecord", ] diff --git a/predicate_authority/control_plane.py b/predicate_authority/control_plane.py index 21abd4c..9b04909 100644 --- a/predicate_authority/control_plane.py +++ b/predicate_authority/control_plane.py @@ -97,6 +97,9 @@ def send_usage_records(self, records: tuple[UsageCreditRecord, ...]) -> bool: payload = {"records": [asdict(record) for record in records]} return self._post_json("/v1/metering/usage:batch", payload) + def send_audit_payload(self, payload: Mapping[str, object]) -> bool: + return self._post_json("/v1/audit/events:batch", payload) + def _post_json(self, path: str, payload: Mapping[str, object]) -> bool: attempts = self.config.max_retries + 1 for attempt in range(attempts): diff --git a/predicate_authority/daemon.py b/predicate_authority/daemon.py index fab4f56..405146f 100644 --- a/predicate_authority/daemon.py +++ b/predicate_authority/daemon.py @@ -27,6 +27,12 @@ ControlPlaneTraceEmitter, ) from predicate_authority.guard import ActionGuard +from predicate_authority.local_identity import ( + CompositeTraceEmitter, + LedgerQueueItem, + LocalIdentityRegistry, + LocalLedgerQueueEmitter, +) from predicate_authority.mandate import LocalMandateSigner from predicate_authority.policy import PolicyEngine from predicate_authority.policy_source import PolicyFileSource @@ -39,7 +45,7 @@ SidecarConfig, ) from predicate_authority.sidecar_store import LocalCredentialStore -from predicate_contracts import PolicyRule +from predicate_contracts import PolicyRule, TraceEmitter @dataclass(frozen=True) @@ -63,6 +69,21 @@ class ControlPlaneBootstrapConfig: usage_credits_per_decision: int = 1 +@dataclass(frozen=True) +class LocalIdentityBootstrapConfig: + enabled: bool = False + registry_file_path: str | None = None + default_ttl_seconds: int = 900 + + +@dataclass(frozen=True) +class FlushWorkerConfig: + enabled: bool = True + interval_s: float = 2.0 + max_batch_size: int = 50 + dead_letter_max_attempts: int = 5 + + @dataclass class DaemonRuntime: started_at_epoch_s: float @@ -71,6 +92,20 @@ class DaemonRuntime: policy_poll_error_count: int = 0 last_policy_reload_epoch_s: float | None = None last_policy_poll_error: str | None = None + flush_cycle_count: int = 0 + flush_sent_count: int = 0 + flush_failed_count: int = 0 + flush_quarantined_count: int = 0 + last_flush_epoch_s: float | None = None + last_flush_error: str | None = None + + +@dataclass(frozen=True) +class FlushCycleResult: + scanned_count: int = 0 + sent_count: int = 0 + failed_count: int = 0 + quarantined_count: int = 0 class _DaemonHTTPServer(ThreadingHTTPServer): @@ -95,33 +130,136 @@ def do_GET(self) -> None: # noqa: N802 if parsed.path == "/status": self._send_json(200, self.server.daemon_ref.status_payload()) # type: ignore[attr-defined] return + if parsed.path == "/identity/list": + active_only = True + query = urlparse(self.path).query + if "active_only=false" in query: + active_only = False + payload = self.server.daemon_ref.list_task_identities(active_only=active_only) # type: ignore[attr-defined] + self._send_json(200, {"items": payload, "active_only": active_only}) + return + if parsed.path == "/ledger/flush-queue": + query = parsed.query + include_flushed = "include_flushed=true" in query + include_quarantined = "include_quarantined=true" in query + payload = self.server.daemon_ref.list_flush_queue( # type: ignore[attr-defined] + include_flushed=include_flushed, + include_quarantined=include_quarantined, + ) + self._send_json( + 200, + { + "items": payload, + "include_flushed": include_flushed, + "include_quarantined": include_quarantined, + }, + ) + return + if parsed.path == "/ledger/dead-letter": + payload = self.server.daemon_ref.list_dead_letter_queue() # type: ignore[attr-defined] + self._send_json(200, {"items": payload}) + return self._send_json(404, {"error": "not_found"}) def do_POST(self) -> None: # noqa: N802 parsed = urlparse(self.path) - if parsed.path == "/policy/reload": - reloaded = self.server.daemon_ref.reload_policy_now() # type: ignore[attr-defined] - self._send_json(200, {"reloaded": reloaded}) + handlers: dict[str, Any] = { + "/policy/reload": self._handle_policy_reload, + "/revoke/principal": self._handle_revoke_principal, + "/revoke/intent": self._handle_revoke_intent, + "/identity/task": self._handle_identity_task, + "/identity/revoke": self._handle_identity_revoke, + "/ledger/flush-ack": self._handle_ledger_flush_ack, + "/ledger/flush-now": self._handle_ledger_flush_now, + "/ledger/requeue": self._handle_ledger_requeue, + } + handler = handlers.get(parsed.path) + if handler is None: + self._send_json(404, {"error": "not_found"}) return - if parsed.path == "/revoke/principal": - payload = self._read_json_body() - principal_id = payload.get("principal_id") - if not isinstance(principal_id, str) or principal_id.strip() == "": - self._send_json(400, {"error": "principal_id is required"}) - return - self.server.daemon_ref.revoke_principal(principal_id.strip()) # type: ignore[attr-defined] - self._send_json(200, {"ok": True, "principal_id": principal_id.strip()}) + handler() + + def _handle_policy_reload(self) -> None: + reloaded = self.server.daemon_ref.reload_policy_now() # type: ignore[attr-defined] + self._send_json(200, {"reloaded": reloaded}) + + def _handle_revoke_principal(self) -> None: + payload = self._read_json_body() + principal_id = payload.get("principal_id") + if not isinstance(principal_id, str) or principal_id.strip() == "": + self._send_json(400, {"error": "principal_id is required"}) return - if parsed.path == "/revoke/intent": - payload = self._read_json_body() - intent_hash = payload.get("intent_hash") - if not isinstance(intent_hash, str) or intent_hash.strip() == "": - self._send_json(400, {"error": "intent_hash is required"}) - return - self.server.daemon_ref.revoke_intent(intent_hash.strip()) # type: ignore[attr-defined] - self._send_json(200, {"ok": True, "intent_hash": intent_hash.strip()}) + self.server.daemon_ref.revoke_principal(principal_id.strip()) # type: ignore[attr-defined] + self._send_json(200, {"ok": True, "principal_id": principal_id.strip()}) + + def _handle_revoke_intent(self) -> None: + payload = self._read_json_body() + intent_hash = payload.get("intent_hash") + if not isinstance(intent_hash, str) or intent_hash.strip() == "": + self._send_json(400, {"error": "intent_hash is required"}) return - self._send_json(404, {"error": "not_found"}) + self.server.daemon_ref.revoke_intent(intent_hash.strip()) # type: ignore[attr-defined] + self._send_json(200, {"ok": True, "intent_hash": intent_hash.strip()}) + + def _handle_identity_task(self) -> None: + payload = self._read_json_body() + principal_id = payload.get("principal_id") + task_id = payload.get("task_id") + ttl = payload.get("ttl_seconds") + metadata = payload.get("metadata") + if not isinstance(principal_id, str) or principal_id.strip() == "": + self._send_json(400, {"error": "principal_id is required"}) + return + if not isinstance(task_id, str) or task_id.strip() == "": + self._send_json(400, {"error": "task_id is required"}) + return + ttl_value = int(ttl) if isinstance(ttl, (int, str)) else None + metadata_dict = metadata if isinstance(metadata, dict) else None + try: + created = self.server.daemon_ref.issue_task_identity( # type: ignore[attr-defined] + principal_id=principal_id.strip(), + task_id=task_id.strip(), + ttl_seconds=ttl_value, + metadata=metadata_dict, + ) + except RuntimeError as exc: + self._send_json(400, {"error": str(exc)}) + return + self._send_json(200, created) + + def _handle_identity_revoke(self) -> None: + payload = self._read_json_body() + identity_id = payload.get("identity_id") + if not isinstance(identity_id, str) or identity_id.strip() == "": + self._send_json(400, {"error": "identity_id is required"}) + return + ok = self.server.daemon_ref.revoke_task_identity(identity_id.strip()) # type: ignore[attr-defined] + self._send_json(200, {"ok": ok, "identity_id": identity_id.strip()}) + + def _handle_ledger_flush_ack(self) -> None: + payload = self._read_json_body() + queue_item_id = payload.get("queue_item_id") + if not isinstance(queue_item_id, str) or queue_item_id.strip() == "": + self._send_json(400, {"error": "queue_item_id is required"}) + return + ok = self.server.daemon_ref.ack_flush_queue_item(queue_item_id.strip()) # type: ignore[attr-defined] + self._send_json(200, {"ok": ok, "queue_item_id": queue_item_id.strip()}) + + def _handle_ledger_flush_now(self) -> None: + payload = self._read_json_body() + max_items_raw = payload.get("max_items") + max_items = int(max_items_raw) if isinstance(max_items_raw, (int, str)) else None + result = self.server.daemon_ref.flush_queue_now(max_items=max_items) # type: ignore[attr-defined] + self._send_json(200, result) + + def _handle_ledger_requeue(self) -> None: + payload = self._read_json_body() + queue_item_id = payload.get("queue_item_id") + if not isinstance(queue_item_id, str) or queue_item_id.strip() == "": + self._send_json(400, {"error": "queue_item_id is required"}) + return + ok = self.server.daemon_ref.requeue_dead_letter_item(queue_item_id.strip()) # type: ignore[attr-defined] + self._send_json(200, {"ok": ok, "queue_item_id": queue_item_id.strip()}) def log_message(self, format: str, *args: Any) -> None: # noqa: A003 # Keep daemon output deterministic and quiet by default. @@ -154,14 +292,21 @@ def _send_json(self, code: int, payload: dict[str, Any]) -> None: class PredicateAuthorityDaemon: - def __init__(self, sidecar: PredicateAuthoritySidecar, config: DaemonConfig) -> None: + def __init__( + self, + sidecar: PredicateAuthoritySidecar, + config: DaemonConfig, + flush_worker: FlushWorkerConfig | None = None, + ) -> None: self._sidecar = sidecar self._config = config + self._flush_worker = flush_worker or FlushWorkerConfig() self._runtime = DaemonRuntime(started_at_epoch_s=time.time()) self._stop_event = threading.Event() self._http_server: _DaemonHTTPServer | None = None self._server_thread: threading.Thread | None = None self._poll_thread: threading.Thread | None = None + self._flush_thread: threading.Thread | None = None @property def bound_port(self) -> int: @@ -180,8 +325,10 @@ def start(self) -> None: ) self._server_thread = threading.Thread(target=self._http_server.serve_forever, daemon=True) self._poll_thread = threading.Thread(target=self._policy_poll_loop, daemon=True) + self._flush_thread = threading.Thread(target=self._flush_queue_loop, daemon=True) self._server_thread.start() self._poll_thread.start() + self._flush_thread.start() def stop(self) -> None: if not self._runtime.is_running: @@ -195,6 +342,8 @@ def stop(self) -> None: self._server_thread.join(timeout=3.0) if self._poll_thread is not None: self._poll_thread.join(timeout=3.0) + if self._flush_thread is not None: + self._flush_thread.join(timeout=3.0) def health_payload(self) -> dict[str, Any]: uptime_s = int(max(0, time.time() - self._runtime.started_at_epoch_s)) @@ -215,6 +364,13 @@ def status_payload(self) -> dict[str, Any]: "policy_poll_error_count": self._runtime.policy_poll_error_count, "last_policy_reload_epoch_s": self._runtime.last_policy_reload_epoch_s, "last_policy_poll_error": self._runtime.last_policy_poll_error, + "flush_cycle_count": self._runtime.flush_cycle_count, + "flush_sent_count": self._runtime.flush_sent_count, + "flush_failed_count": self._runtime.flush_failed_count, + "flush_quarantined_count": self._runtime.flush_quarantined_count, + "last_flush_epoch_s": self._runtime.last_flush_epoch_s, + "last_flush_error": self._runtime.last_flush_error, + "dead_letter_max_attempts": self._flush_worker.dead_letter_max_attempts, } ) return payload @@ -232,6 +388,80 @@ def revoke_principal(self, principal_id: str) -> None: def revoke_intent(self, intent_hash: str) -> None: self._sidecar.revoke_intent_hash(intent_hash) + def issue_task_identity( + self, + principal_id: str, + task_id: str, + ttl_seconds: int | None = None, + metadata: dict[str, str] | None = None, + ) -> dict[str, object]: + registry = self._sidecar.local_identity_registry() + if registry is None: + raise RuntimeError("local identity registry is not enabled") + issued = registry.issue_task_identity( + principal_id=principal_id, + task_id=task_id, + ttl_seconds=ttl_seconds, + metadata=metadata, + ) + return asdict(issued) + + def revoke_task_identity(self, identity_id: str) -> bool: + registry = self._sidecar.local_identity_registry() + if registry is None: + return False + return registry.revoke_identity(identity_id) + + def list_task_identities(self, active_only: bool = True) -> list[dict[str, object]]: + registry = self._sidecar.local_identity_registry() + if registry is None: + return [] + return [asdict(item) for item in registry.list_identities(active_only=active_only)] + + def list_flush_queue( + self, include_flushed: bool = False, include_quarantined: bool = False + ) -> list[dict[str, object]]: + registry = self._sidecar.local_identity_registry() + if registry is None: + return [] + return [ + asdict(item) + for item in registry.list_flush_queue( + include_flushed=include_flushed, + include_quarantined=include_quarantined, + ) + ] + + def ack_flush_queue_item(self, queue_item_id: str) -> bool: + registry = self._sidecar.local_identity_registry() + if registry is None: + return False + return registry.mark_flush_ack(queue_item_id) + + def list_dead_letter_queue(self) -> list[dict[str, object]]: + registry = self._sidecar.local_identity_registry() + if registry is None: + return [] + return [asdict(item) for item in registry.list_dead_letter_queue() if item.quarantined] + + def requeue_dead_letter_item(self, queue_item_id: str) -> bool: + registry = self._sidecar.local_identity_registry() + if registry is None: + return False + return registry.requeue_item(queue_item_id=queue_item_id, reset_attempts=True) + + def flush_queue_now(self, max_items: int | None = None) -> dict[str, object]: + cycle = self._flush_once(max_items=max_items, force=True) + return { + "ok": True, + "scanned_count": cycle.scanned_count, + "sent_count": cycle.sent_count, + "failed_count": cycle.failed_count, + "quarantined_count": cycle.quarantined_count, + "dead_letter_max_attempts": self._flush_worker.dead_letter_max_attempts, + "last_flush_error": self._runtime.last_flush_error, + } + def _policy_poll_loop(self) -> None: while not self._stop_event.is_set(): try: @@ -244,12 +474,123 @@ def _policy_poll_loop(self) -> None: self._runtime.last_policy_poll_error = str(exc) self._stop_event.wait(timeout=self._config.policy_poll_interval_s) + def _flush_queue_loop(self) -> None: + while not self._stop_event.is_set(): + try: + self._flush_once() + except Exception as exc: # noqa: BLE001 + self._runtime.flush_failed_count += 1 + self._runtime.last_flush_error = str(exc) + self._stop_event.wait(timeout=self._flush_worker.interval_s) + + def _flush_once( + self, + max_items: int | None = None, + force: bool = False, + ) -> FlushCycleResult: + result = FlushCycleResult() + if not self._flush_worker.enabled and not force: + return result + registry = self._sidecar.local_identity_registry() + if registry is None: + return result + client = self._resolve_control_plane_client() + if client is None: + return result + batch_size = self._flush_worker.max_batch_size if max_items is None else max(0, max_items) + queue_items = registry.list_flush_queue(limit=max(0, batch_size)) + if len(queue_items) == 0: + return result + self._runtime.flush_cycle_count += 1 + self._runtime.last_flush_epoch_s = time.time() + scanned_count = 0 + sent_count = 0 + failed_count = 0 + quarantined_count = 0 + for item in queue_items: + scanned_count += 1 + if item.flush_attempts >= self._flush_worker.dead_letter_max_attempts: + registry.quarantine_queue_item( + item.queue_item_id, + "dead_letter_max_attempts_exceeded", + ) + quarantined_count += 1 + self._runtime.flush_quarantined_count += 1 + self._runtime.last_flush_error = "dead_letter_max_attempts_exceeded" + continue + sent = self._send_queue_item_to_control_plane(item=item, client=client) + if sent: + registry.mark_flush_ack(item.queue_item_id) + self._runtime.flush_sent_count += 1 + sent_count += 1 + self._runtime.last_flush_error = None + else: + registry.mark_flush_failed(item.queue_item_id, "control_plane_flush_failed") + self._runtime.flush_failed_count += 1 + failed_count += 1 + self._runtime.last_flush_error = "control_plane_flush_failed" + if item.flush_attempts + 1 >= self._flush_worker.dead_letter_max_attempts: + registry.quarantine_queue_item( + item.queue_item_id, + "dead_letter_max_attempts_exceeded", + ) + quarantined_count += 1 + self._runtime.flush_quarantined_count += 1 + self._runtime.last_flush_error = "dead_letter_max_attempts_exceeded" + return FlushCycleResult( + scanned_count=scanned_count, + sent_count=sent_count, + failed_count=failed_count, + quarantined_count=quarantined_count, + ) + + def _send_queue_item_to_control_plane( + self, item: LedgerQueueItem, client: ControlPlaneClient + ) -> bool: + payload = item.payload if isinstance(item.payload, dict) else {} + principal_id = str(payload.get("principal_id", "unknown-principal")) + action = str(payload.get("action", "unknown-action")) + resource = str(payload.get("resource", "unknown-resource")) + reason = str(payload.get("reason", "unknown")) + allowed = bool(payload.get("allowed", False)) + mandate_id_raw = payload.get("mandate_id") + mandate_id = str(mandate_id_raw) if isinstance(mandate_id_raw, str) else None + emitted_at_raw = payload.get("emitted_at_epoch_s") + emitted_at = ( + int(emitted_at_raw) if isinstance(emitted_at_raw, (int, str)) else int(time.time()) + ) + timestamp = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime(emitted_at)) + audit_envelope = { + "event_id": f"qevt_{item.queue_item_id}", + "tenant_id": client.config.tenant_id, + "principal_id": principal_id, + "action": action, + "resource": resource, + "allowed": allowed, + "reason": reason, + "mandate_id": mandate_id, + "timestamp": timestamp, + "trace_id": None, + } + return client.send_audit_payload({"events": [audit_envelope]}) + + def _resolve_control_plane_client(self) -> ControlPlaneClient | None: + trace_emitter = self._sidecar.trace_emitter() + if isinstance(trace_emitter, ControlPlaneTraceEmitter): + return trace_emitter.client + if isinstance(trace_emitter, CompositeTraceEmitter): + for emitter in trace_emitter.emitters: + if isinstance(emitter, ControlPlaneTraceEmitter): + return emitter.client + return None + def _build_default_sidecar( mode: AuthorityMode, policy_file: str | None, credential_store_file: str, control_plane_config: ControlPlaneBootstrapConfig | None = None, + local_identity_config: LocalIdentityBootstrapConfig | None = None, identity_bridge: ExchangeTokenBridge | None = None, ) -> PredicateAuthoritySidecar: policy_rules: tuple[PolicyRule, ...] = () @@ -257,7 +598,7 @@ def _build_default_sidecar( policy_rules = PolicyFileSource(policy_file).load_rules() policy_engine = PolicyEngine(rules=policy_rules) - trace_emitter = None + trace_emitters: list[TraceEmitter] = [] if ( control_plane_config is not None and control_plane_config.enabled @@ -275,11 +616,29 @@ def _build_default_sidecar( fail_open=control_plane_config.fail_open, ) ) - trace_emitter = ControlPlaneTraceEmitter( - client=control_plane_client, - emit_usage_credits=True, - usage_credits_per_decision=control_plane_config.usage_credits_per_decision, + trace_emitters.append( + ControlPlaneTraceEmitter( + client=control_plane_client, + emit_usage_credits=True, + usage_credits_per_decision=control_plane_config.usage_credits_per_decision, + ) ) + local_identity_registry: LocalIdentityRegistry | None = None + if ( + local_identity_config is not None + and local_identity_config.enabled + and local_identity_config.registry_file_path is not None + ): + local_identity_registry = LocalIdentityRegistry( + file_path=local_identity_config.registry_file_path, + default_ttl_seconds=local_identity_config.default_ttl_seconds, + ) + trace_emitters.append(LocalLedgerQueueEmitter(registry=local_identity_registry)) + trace_emitter = ( + CompositeTraceEmitter(tuple(trace_emitters)) + if len(trace_emitters) > 1 + else (trace_emitters[0] if len(trace_emitters) == 1 else None) + ) proof_ledger = InMemoryProofLedger(trace_emitter=trace_emitter) guard = ActionGuard( @@ -295,6 +654,7 @@ def _build_default_sidecar( credential_store=LocalCredentialStore(credential_store_file), revocation_cache=LocalRevocationCache(), policy_engine=policy_engine, + local_identity_registry=local_identity_registry, ) @@ -360,6 +720,16 @@ def main() -> None: "--credential-store-file", default=str(Path.home() / ".predicate-authorityd" / "credentials.json"), ) + parser.add_argument( + "--local-identity-enabled", + action="store_true", + help="Enable local ephemeral task identity registry and flush queue.", + ) + parser.add_argument( + "--local-identity-registry-file", + default=str(Path.home() / ".predicate-authorityd" / "local-identities.json"), + ) + parser.add_argument("--local-identity-default-ttl-s", type=int, default=900) parser.add_argument( "--identity-mode", choices=["local", "local-idp", "oidc", "entra"], @@ -414,6 +784,21 @@ def main() -> None: parser.add_argument("--control-plane-timeout-s", type=float, default=2.0) parser.add_argument("--control-plane-max-retries", type=int, default=2) parser.add_argument("--control-plane-backoff-initial-s", type=float, default=0.2) + parser.add_argument( + "--flush-worker-enabled", + action="store_true", + help="Enable background local queue flush worker.", + ) + parser.add_argument( + "--flush-worker-disabled", + dest="flush_worker_enabled", + action="store_false", + help="Disable background local queue flush worker.", + ) + parser.set_defaults(flush_worker_enabled=True) + parser.add_argument("--flush-worker-interval-s", type=float, default=2.0) + parser.add_argument("--flush-worker-max-batch-size", type=int, default=50) + parser.add_argument("--flush-worker-dead-letter-max-attempts", type=int, default=5) parser.add_argument( "--control-plane-fail-open", action="store_true", @@ -456,12 +841,18 @@ def main() -> None: fail_open=bool(args.control_plane_fail_open), usage_credits_per_decision=max(0, int(args.control_plane_usage_credits_per_decision)), ) + local_identity_bootstrap = LocalIdentityBootstrapConfig( + enabled=bool(args.local_identity_enabled), + registry_file_path=str(args.local_identity_registry_file), + default_ttl_seconds=max(1, int(args.local_identity_default_ttl_s)), + ) identity_bridge = _build_identity_bridge_from_args(args) sidecar = _build_default_sidecar( mode=mode, policy_file=args.policy_file, credential_store_file=args.credential_store_file, control_plane_config=control_plane_bootstrap, + local_identity_config=local_identity_bootstrap, identity_bridge=identity_bridge, ) daemon = PredicateAuthorityDaemon( @@ -471,6 +862,12 @@ def main() -> None: port=args.port, policy_poll_interval_s=args.policy_poll_interval_s, ), + flush_worker=FlushWorkerConfig( + enabled=bool(args.flush_worker_enabled), + interval_s=max(0.1, float(args.flush_worker_interval_s)), + max_batch_size=max(1, int(args.flush_worker_max_batch_size)), + dead_letter_max_attempts=max(1, int(args.flush_worker_dead_letter_max_attempts)), + ), ) daemon.start() print( diff --git a/predicate_authority/local_identity.py b/predicate_authority/local_identity.py new file mode 100644 index 0000000..86380c9 --- /dev/null +++ b/predicate_authority/local_identity.py @@ -0,0 +1,397 @@ +from __future__ import annotations + +import json +import os +import time +import uuid +from dataclasses import asdict, dataclass, field +from pathlib import Path +from threading import Lock +from typing import Any + +from predicate_contracts import ProofEvent, TraceEmitter + + +@dataclass(frozen=True) +class TaskIdentityRecord: + identity_id: str + principal_id: str + task_id: str + issued_at_epoch_s: int + expires_at_epoch_s: int + revoked: bool = False + metadata: dict[str, str] = field(default_factory=dict) + + +@dataclass(frozen=True) +class LedgerQueueItem: + queue_item_id: str + enqueued_at_epoch_s: int + payload: dict[str, object] + flushed: bool = False + flush_attempts: int = 0 + last_error: str | None = None + flushed_at_epoch_s: int | None = None + quarantined: bool = False + quarantine_reason: str | None = None + quarantined_at_epoch_s: int | None = None + + +@dataclass(frozen=True) +class LocalIdentityRegistryStats: + total_identity_count: int + active_identity_count: int + pending_flush_queue_count: int + flushed_queue_count: int + failed_queue_count: int + quarantined_queue_count: int + + +class LocalIdentityRegistry: + def __init__(self, file_path: str, default_ttl_seconds: int = 900) -> None: + if default_ttl_seconds <= 0: + raise ValueError("default_ttl_seconds must be > 0") + self._file_path = Path(file_path) + self._default_ttl_seconds = default_ttl_seconds + self._lock = Lock() + self._ensure_store_path() + + def issue_task_identity( + self, + principal_id: str, + task_id: str, + ttl_seconds: int | None = None, + metadata: dict[str, str] | None = None, + ) -> TaskIdentityRecord: + ttl = ttl_seconds if ttl_seconds is not None else self._default_ttl_seconds + if ttl <= 0: + raise ValueError("ttl_seconds must be > 0") + now = int(time.time()) + record = TaskIdentityRecord( + identity_id="lid_" + uuid.uuid4().hex[:16], + principal_id=principal_id, + task_id=task_id, + issued_at_epoch_s=now, + expires_at_epoch_s=now + ttl, + revoked=False, + metadata=metadata or {}, + ) + with self._lock: + payload = self._read_all_unlocked() + identities = payload.setdefault("identities", {}) + identities[record.identity_id] = asdict(record) + self._write_all_unlocked(payload) + return record + + def revoke_identity(self, identity_id: str) -> bool: + with self._lock: + payload = self._read_all_unlocked() + identities = payload.setdefault("identities", {}) + item = identities.get(identity_id) + if not isinstance(item, dict): + return False + item["revoked"] = True + identities[identity_id] = item + self._write_all_unlocked(payload) + return True + + def is_identity_active(self, identity_id: str, now_epoch_s: int | None = None) -> bool: + now = now_epoch_s if now_epoch_s is not None else int(time.time()) + record = self.get_identity(identity_id) + if record is None: + return False + if record.revoked: + return False + return now < record.expires_at_epoch_s + + def get_identity(self, identity_id: str) -> TaskIdentityRecord | None: + with self._lock: + payload = self._read_all_unlocked() + identities = payload.setdefault("identities", {}) + raw = identities.get(identity_id) + if not isinstance(raw, dict): + return None + try: + return TaskIdentityRecord( + identity_id=str(raw["identity_id"]), + principal_id=str(raw["principal_id"]), + task_id=str(raw["task_id"]), + issued_at_epoch_s=int(raw["issued_at_epoch_s"]), + expires_at_epoch_s=int(raw["expires_at_epoch_s"]), + revoked=bool(raw.get("revoked", False)), + metadata={ + str(k): str(v) + for k, v in dict(raw.get("metadata", {})).items() + if isinstance(k, str) + }, + ) + except Exception: + return None + + def list_identities(self, active_only: bool = True) -> list[TaskIdentityRecord]: + with self._lock: + payload = self._read_all_unlocked() + identities = payload.setdefault("identities", {}) + raw_items = list(identities.values()) + result: list[TaskIdentityRecord] = [] + for raw in raw_items: + if not isinstance(raw, dict): + continue + record = self.get_identity(str(raw.get("identity_id", ""))) + if record is None: + continue + if active_only and not self.is_identity_active(record.identity_id): + continue + result.append(record) + return sorted(result, key=lambda item: item.issued_at_epoch_s, reverse=True) + + def expire_identities(self, now_epoch_s: int | None = None) -> int: + now = now_epoch_s if now_epoch_s is not None else int(time.time()) + expired_count = 0 + with self._lock: + payload = self._read_all_unlocked() + identities = payload.setdefault("identities", {}) + for identity_id, raw in list(identities.items()): + if not isinstance(raw, dict): + continue + expires_at = int(raw.get("expires_at_epoch_s", now + 1)) + revoked = bool(raw.get("revoked", False)) + if not revoked and expires_at <= now: + raw["revoked"] = True + identities[identity_id] = raw + expired_count += 1 + if expired_count > 0: + self._write_all_unlocked(payload) + return expired_count + + def enqueue_proof_event( + self, event: ProofEvent, source: str = "predicate-authorityd" + ) -> LedgerQueueItem: + item = LedgerQueueItem( + queue_item_id="q_" + uuid.uuid4().hex[:16], + enqueued_at_epoch_s=int(time.time()), + payload={ + "source": source, + "event_type": event.event_type, + "principal_id": event.principal_id, + "action": event.action, + "resource": event.resource, + "reason": event.reason.value, + "allowed": event.allowed, + "mandate_id": event.mandate_id, + "emitted_at_epoch_s": event.emitted_at_epoch_s, + }, + ) + with self._lock: + payload = self._read_all_unlocked() + queue = payload.setdefault("flush_queue", {}) + queue[item.queue_item_id] = asdict(item) + self._write_all_unlocked(payload) + return item + + def list_flush_queue( + self, + include_flushed: bool = False, + include_quarantined: bool = False, + limit: int | None = None, + ) -> list[LedgerQueueItem]: + with self._lock: + payload = self._read_all_unlocked() + queue = payload.setdefault("flush_queue", {}) + raw_items = list(queue.values()) + result: list[LedgerQueueItem] = [] + for raw in raw_items: + if not isinstance(raw, dict): + continue + item = self._parse_queue_item(raw) + if item is None: + continue + if not include_flushed and item.flushed: + continue + if not include_quarantined and item.quarantined: + continue + result.append(item) + result = sorted(result, key=lambda item: item.enqueued_at_epoch_s) + if limit is not None and limit >= 0: + return result[:limit] + return result + + def mark_flush_ack(self, queue_item_id: str) -> bool: + with self._lock: + payload = self._read_all_unlocked() + queue = payload.setdefault("flush_queue", {}) + raw = queue.get(queue_item_id) + if not isinstance(raw, dict): + return False + raw["flushed"] = True + raw["flush_attempts"] = int(raw.get("flush_attempts", 0)) + 1 + raw["last_error"] = None + raw["flushed_at_epoch_s"] = int(time.time()) + queue[queue_item_id] = raw + self._write_all_unlocked(payload) + return True + + def mark_flush_failed(self, queue_item_id: str, error: str) -> bool: + with self._lock: + payload = self._read_all_unlocked() + queue = payload.setdefault("flush_queue", {}) + raw = queue.get(queue_item_id) + if not isinstance(raw, dict): + return False + raw["flush_attempts"] = int(raw.get("flush_attempts", 0)) + 1 + raw["last_error"] = error + queue[queue_item_id] = raw + self._write_all_unlocked(payload) + return True + + def quarantine_queue_item(self, queue_item_id: str, reason: str) -> bool: + with self._lock: + payload = self._read_all_unlocked() + queue = payload.setdefault("flush_queue", {}) + raw = queue.get(queue_item_id) + if not isinstance(raw, dict): + return False + raw["quarantined"] = True + raw["quarantine_reason"] = reason + raw["quarantined_at_epoch_s"] = int(time.time()) + queue[queue_item_id] = raw + self._write_all_unlocked(payload) + return True + + def list_dead_letter_queue(self, limit: int | None = None) -> list[LedgerQueueItem]: + return self.list_flush_queue( + include_flushed=True, + include_quarantined=True, + limit=limit, + ) + + def requeue_item(self, queue_item_id: str, reset_attempts: bool = True) -> bool: + with self._lock: + payload = self._read_all_unlocked() + queue = payload.setdefault("flush_queue", {}) + raw = queue.get(queue_item_id) + if not isinstance(raw, dict): + return False + if not bool(raw.get("quarantined", False)): + return False + raw["quarantined"] = False + raw["quarantine_reason"] = None + raw["quarantined_at_epoch_s"] = None + raw["flushed"] = False + raw["flushed_at_epoch_s"] = None + raw["last_error"] = None + if reset_attempts: + raw["flush_attempts"] = 0 + queue[queue_item_id] = raw + self._write_all_unlocked(payload) + return True + + def stats(self) -> LocalIdentityRegistryStats: + active = len(self.list_identities(active_only=True)) + all_identities = len(self.list_identities(active_only=False)) + queue_all = self.list_flush_queue(include_flushed=True, include_quarantined=True) + queue_pending = len( + [item for item in queue_all if not item.flushed and not item.quarantined] + ) + queue_flushed = len([item for item in queue_all if item.flushed]) + queue_failed = len( + [ + item + for item in queue_all + if item.last_error is not None and not item.flushed and not item.quarantined + ] + ) + queue_quarantined = len([item for item in queue_all if item.quarantined]) + return LocalIdentityRegistryStats( + total_identity_count=all_identities, + active_identity_count=active, + pending_flush_queue_count=queue_pending, + flushed_queue_count=queue_flushed, + failed_queue_count=queue_failed, + quarantined_queue_count=queue_quarantined, + ) + + def _read_all_unlocked(self) -> dict[str, Any]: + if not self._file_path.exists(): + return {"identities": {}, "flush_queue": {}} + content = self._file_path.read_text(encoding="utf-8").strip() + if content == "": + return {"identities": {}, "flush_queue": {}} + loaded = json.loads(content) + if isinstance(loaded, dict): + if "identities" not in loaded: + loaded["identities"] = {} + if "flush_queue" not in loaded: + loaded["flush_queue"] = {} + return loaded + return {"identities": {}, "flush_queue": {}} + + def _parse_queue_item(self, raw: dict[str, Any]) -> LedgerQueueItem | None: + try: + return LedgerQueueItem( + queue_item_id=str(raw["queue_item_id"]), + enqueued_at_epoch_s=int(raw["enqueued_at_epoch_s"]), + payload=dict(raw.get("payload", {})), + flushed=bool(raw.get("flushed", False)), + flush_attempts=int(raw.get("flush_attempts", 0)), + last_error=(str(raw["last_error"]) if raw.get("last_error") is not None else None), + flushed_at_epoch_s=( + int(raw["flushed_at_epoch_s"]) + if raw.get("flushed_at_epoch_s") is not None + else None + ), + quarantined=bool(raw.get("quarantined", False)), + quarantine_reason=( + str(raw["quarantine_reason"]) + if raw.get("quarantine_reason") is not None + else None + ), + quarantined_at_epoch_s=( + int(raw["quarantined_at_epoch_s"]) + if raw.get("quarantined_at_epoch_s") is not None + else None + ), + ) + except Exception: + return None + + def _write_all_unlocked(self, payload: dict[str, Any]) -> None: + self._file_path.write_text(json.dumps(payload, indent=2), encoding="utf-8") + self._chmod_file_safe() + + def _ensure_store_path(self) -> None: + self._file_path.parent.mkdir(parents=True, exist_ok=True) + try: + os.chmod(self._file_path.parent, 0o700) + except OSError: + pass + if not self._file_path.exists(): + self._file_path.write_text( + json.dumps({"identities": {}, "flush_queue": {}}, indent=2), + encoding="utf-8", + ) + self._chmod_file_safe() + + def _chmod_file_safe(self) -> None: + try: + os.chmod(self._file_path, 0o600) + except OSError: + pass + + +@dataclass +class LocalLedgerQueueEmitter(TraceEmitter): + registry: LocalIdentityRegistry + source: str = "predicate-authorityd" + + def emit(self, event: ProofEvent) -> None: + self.registry.enqueue_proof_event(event, source=self.source) + + +@dataclass +class CompositeTraceEmitter(TraceEmitter): + emitters: tuple[TraceEmitter, ...] + + def emit(self, event: ProofEvent) -> None: + for emitter in self.emitters: + emitter.emit(event) diff --git a/predicate_authority/sidecar.py b/predicate_authority/sidecar.py index 020cec7..b96c4fd 100644 --- a/predicate_authority/sidecar.py +++ b/predicate_authority/sidecar.py @@ -7,6 +7,7 @@ from predicate_authority.bridge import TokenExchangeResult from predicate_authority.control_plane import ControlPlaneTraceEmitter from predicate_authority.guard import ActionGuard +from predicate_authority.local_identity import LocalIdentityRegistry from predicate_authority.policy import PolicyEngine from predicate_authority.policy_source import PolicyFileSource from predicate_authority.proof import InMemoryProofLedger @@ -18,6 +19,7 @@ AuthorizationReason, PrincipalRef, StateEvidence, + TraceEmitter, ) @@ -46,6 +48,13 @@ class SidecarStatus: control_plane_usage_push_success_count: int = 0 control_plane_usage_push_failure_count: int = 0 control_plane_last_push_error: str | None = None + local_identity_registry_enabled: bool = False + local_identity_total_count: int = 0 + local_identity_active_count: int = 0 + local_flush_queue_pending_count: int = 0 + local_flush_queue_flushed_count: int = 0 + local_flush_queue_failed_count: int = 0 + local_flush_queue_quarantined_count: int = 0 class SidecarError(RuntimeError): @@ -74,6 +83,7 @@ def __init__( credential_store: LocalCredentialStore, revocation_cache: LocalRevocationCache, policy_engine: PolicyEngine, + local_identity_registry: LocalIdentityRegistry | None = None, ) -> None: self._config = config self._action_guard = action_guard @@ -82,6 +92,7 @@ def __init__( self._credential_store = credential_store self._revocation_cache = revocation_cache self._policy_engine = policy_engine + self._local_identity_registry = local_identity_registry self._policy_source = ( PolicyFileSource(config.policy_file_path) if config.policy_file_path is not None @@ -157,6 +168,9 @@ def status(self) -> SidecarStatus: if isinstance(trace_emitter, ControlPlaneTraceEmitter): control_plane_attached = True control_plane_payload = trace_emitter.status_payload() + local_stats = ( + self._local_identity_registry.stats() if self._local_identity_registry else None + ) return SidecarStatus( mode=self._config.mode, policy_hot_reload_enabled=self._policy_source is not None, @@ -182,4 +196,21 @@ def status(self) -> SidecarStatus: if control_plane_payload.get("control_plane_last_push_error") is not None else None ), + local_identity_registry_enabled=self._local_identity_registry is not None, + local_identity_total_count=(local_stats.total_identity_count if local_stats else 0), + local_identity_active_count=(local_stats.active_identity_count if local_stats else 0), + local_flush_queue_pending_count=( + local_stats.pending_flush_queue_count if local_stats else 0 + ), + local_flush_queue_flushed_count=(local_stats.flushed_queue_count if local_stats else 0), + local_flush_queue_failed_count=(local_stats.failed_queue_count if local_stats else 0), + local_flush_queue_quarantined_count=( + local_stats.quarantined_queue_count if local_stats else 0 + ), ) + + def local_identity_registry(self) -> LocalIdentityRegistry | None: + return self._local_identity_registry + + def trace_emitter(self) -> TraceEmitter | None: + return self._proof_ledger.trace_emitter diff --git a/tests/test_daemon_phase2.py b/tests/test_daemon_phase2.py index 153593d..0e9ca12 100644 --- a/tests/test_daemon_phase2.py +++ b/tests/test_daemon_phase2.py @@ -27,6 +27,8 @@ ) from predicate_authority.daemon import ( ControlPlaneBootstrapConfig, + FlushWorkerConfig, + LocalIdentityBootstrapConfig, _build_default_sidecar, _build_identity_bridge_from_args, ) @@ -94,7 +96,7 @@ def _fetch_json(url: str) -> dict[str, object]: return loaded -def _post_json(url: str, body: dict[str, str] | None = None) -> dict[str, object]: +def _post_json(url: str, body: dict[str, object] | None = None) -> dict[str, object]: parsed = urlsplit(url) path = parsed.path or "/" payload = json.dumps(body or {}) @@ -239,6 +241,33 @@ class BoundHandler(_ControlPlaneHandler): return server, thread +def _start_failing_control_plane_server() -> tuple[ThreadingHTTPServer, threading.Thread]: + class FailingHandler(BaseHTTPRequestHandler): + requests: list[tuple[str, dict[str, object], dict[str, str]]] = [] + + def do_POST(self) -> None: # noqa: N802 + raw_length = self.headers.get("Content-Length", "0") + content_length = int(raw_length) if raw_length.isdigit() else 0 + payload_raw = ( + self.rfile.read(content_length).decode("utf-8") if content_length > 0 else "{}" + ) + loaded = json.loads(payload_raw) + assert isinstance(loaded, dict) + self.requests.append((self.path, loaded, dict(self.headers.items()))) + self.send_response(500) + self.send_header("Content-Type", "application/json") + self.end_headers() + self.wfile.write(b'{"error":"temporary_failure"}') + + def log_message(self, format: str, *args: Any) -> None: # noqa: A003 + return + + server = ThreadingHTTPServer(("127.0.0.1", 0), FailingHandler) + thread = threading.Thread(target=server.serve_forever, daemon=True) + thread.start() + return server, thread + + def test_daemon_bootstrap_wires_control_plane_emitter(tmp_path: Path) -> None: policy_file = tmp_path / "policy.json" policy_file.write_text( @@ -334,3 +363,317 @@ def test_daemon_identity_mode_local_idp_builder() -> None: ) assert token.provider.value == "local_idp" assert len(token.access_token.split(".")) == 3 + + +def test_daemon_local_identity_registry_endpoints(tmp_path: Path) -> None: + policy_file = tmp_path / "policy.json" + policy_file.write_text(json.dumps({"rules": []}), encoding="utf-8") + sidecar = _build_default_sidecar( + mode=AuthorityMode.LOCAL_ONLY, + policy_file=str(policy_file), + credential_store_file=str(tmp_path / "credentials.json"), + local_identity_config=LocalIdentityBootstrapConfig( + enabled=True, + registry_file_path=str(tmp_path / "local-identities.json"), + default_ttl_seconds=60, + ), + ) + daemon = PredicateAuthorityDaemon( + sidecar=sidecar, + config=DaemonConfig(host="127.0.0.1", port=0, policy_poll_interval_s=10.0), + ) + daemon.start() + try: + base_url = f"http://127.0.0.1:{daemon.bound_port}" + created = _post_json( + f"{base_url}/identity/task", + {"principal_id": "agent:local", "task_id": "task-abc", "ttl_seconds": "60"}, + ) + listed = _fetch_json(f"{base_url}/identity/list") + status = _fetch_json(f"{base_url}/status") + identity_id = str(created["identity_id"]) + revoked = _post_json(f"{base_url}/identity/revoke", {"identity_id": identity_id}) + + assert created["principal_id"] == "agent:local" + assert isinstance(listed.get("items"), list) + assert status["local_identity_registry_enabled"] is True + assert int(status["local_identity_total_count"]) >= 1 + assert revoked["ok"] is True + finally: + daemon.stop() + + +def test_daemon_background_flush_worker_drains_local_queue(tmp_path: Path) -> None: + policy_file = tmp_path / "policy.json" + policy_file.write_text( + json.dumps( + { + "rules": [ + { + "name": "allow-any-http", + "effect": "allow", + "principals": ["agent:*"], + "actions": ["http.*"], + "resources": ["https://*/*"], + } + ] + } + ), + encoding="utf-8", + ) + server, _ = _start_control_plane_server() + daemon: PredicateAuthorityDaemon | None = None + try: + sidecar = _build_default_sidecar( + mode=AuthorityMode.LOCAL_ONLY, + policy_file=str(policy_file), + credential_store_file=str(tmp_path / "credentials.json"), + control_plane_config=ControlPlaneBootstrapConfig( + enabled=True, + base_url=f"http://127.0.0.1:{server.server_port}", + tenant_id="tenant-a", + project_id="project-a", + auth_token="token-a", + fail_open=False, + ), + local_identity_config=LocalIdentityBootstrapConfig( + enabled=True, + registry_file_path=str(tmp_path / "local-identities.json"), + default_ttl_seconds=60, + ), + ) + daemon = PredicateAuthorityDaemon( + sidecar=sidecar, + config=DaemonConfig(host="127.0.0.1", port=0, policy_poll_interval_s=0.1), + flush_worker=FlushWorkerConfig(enabled=True, interval_s=0.1, max_batch_size=20), + ) + daemon.start() + request = ActionRequest( + principal=PrincipalRef(principal_id="agent:flush"), + action_spec=ActionSpec( + action="http.post", + resource="https://api.vendor.com/orders", + intent="create order", + ), + state_evidence=StateEvidence(source="test", state_hash="flush-state"), + verification_evidence=VerificationEvidence(), + ) + decision = sidecar.issue_mandate(request) + assert decision.allowed is True + + base_url = f"http://127.0.0.1:{daemon.bound_port}" + deadline = time.time() + 3.0 + while time.time() < deadline: + status = _fetch_json(f"{base_url}/status") + if int(status["local_flush_queue_flushed_count"]) >= 1: + break + time.sleep(0.05) + else: + raise AssertionError("Flush worker did not flush local queue within timeout.") + + status = _fetch_json(f"{base_url}/status") + assert int(status["flush_sent_count"]) >= 1 + assert int(status["local_flush_queue_pending_count"]) == 0 + assert int(status["local_flush_queue_flushed_count"]) >= 1 + # Immediate control-plane push + queue-flush push should both hit audit endpoint. + handler_cls = server.RequestHandlerClass + requests = getattr(handler_cls, "requests") + assert isinstance(requests, list) + audit_posts = [item for item in requests if item[0] == "/v1/audit/events:batch"] + assert len(audit_posts) >= 2 + finally: + if daemon is not None: + daemon.stop() + server.shutdown() + server.server_close() + + +def test_daemon_manual_flush_endpoint_drains_queue(tmp_path: Path) -> None: + policy_file = tmp_path / "policy.json" + policy_file.write_text( + json.dumps( + { + "rules": [ + { + "name": "allow-any-http", + "effect": "allow", + "principals": ["agent:*"], + "actions": ["http.*"], + "resources": ["https://*/*"], + } + ] + } + ), + encoding="utf-8", + ) + server, _ = _start_control_plane_server() + daemon: PredicateAuthorityDaemon | None = None + try: + sidecar = _build_default_sidecar( + mode=AuthorityMode.LOCAL_ONLY, + policy_file=str(policy_file), + credential_store_file=str(tmp_path / "credentials.json"), + control_plane_config=ControlPlaneBootstrapConfig( + enabled=True, + base_url=f"http://127.0.0.1:{server.server_port}", + tenant_id="tenant-a", + project_id="project-a", + auth_token="token-a", + fail_open=False, + ), + local_identity_config=LocalIdentityBootstrapConfig( + enabled=True, + registry_file_path=str(tmp_path / "local-identities.json"), + default_ttl_seconds=60, + ), + ) + daemon = PredicateAuthorityDaemon( + sidecar=sidecar, + config=DaemonConfig(host="127.0.0.1", port=0, policy_poll_interval_s=10.0), + flush_worker=FlushWorkerConfig(enabled=False, interval_s=5.0, max_batch_size=20), + ) + daemon.start() + request = ActionRequest( + principal=PrincipalRef(principal_id="agent:manual-flush"), + action_spec=ActionSpec( + action="http.post", + resource="https://api.vendor.com/orders", + intent="create order", + ), + state_evidence=StateEvidence(source="test", state_hash="manual-flush"), + verification_evidence=VerificationEvidence(), + ) + decision = sidecar.issue_mandate(request) + assert decision.allowed is True + + base_url = f"http://127.0.0.1:{daemon.bound_port}" + before = _fetch_json(f"{base_url}/ledger/flush-queue") + assert len(before.get("items", [])) == 1 + + result = _post_json(f"{base_url}/ledger/flush-now", {"max_items": 5}) + assert result["ok"] is True + assert int(result["sent_count"]) >= 1 + + after = _fetch_json(f"{base_url}/ledger/flush-queue") + assert len(after.get("items", [])) == 0 + finally: + if daemon is not None: + daemon.stop() + server.shutdown() + server.server_close() + + +def test_dead_letter_threshold_quarantines_queue_items(tmp_path: Path) -> None: + policy_file = tmp_path / "policy.json" + policy_file.write_text( + json.dumps( + { + "rules": [ + { + "name": "allow-any-http", + "effect": "allow", + "principals": ["agent:*"], + "actions": ["http.*"], + "resources": ["https://*/*"], + } + ] + } + ), + encoding="utf-8", + ) + server, _ = _start_failing_control_plane_server() + daemon: PredicateAuthorityDaemon | None = None + try: + sidecar = _build_default_sidecar( + mode=AuthorityMode.LOCAL_ONLY, + policy_file=str(policy_file), + credential_store_file=str(tmp_path / "credentials.json"), + control_plane_config=ControlPlaneBootstrapConfig( + enabled=True, + base_url=f"http://127.0.0.1:{server.server_port}", + tenant_id="tenant-a", + project_id="project-a", + auth_token="token-a", + fail_open=True, + ), + local_identity_config=LocalIdentityBootstrapConfig( + enabled=True, + registry_file_path=str(tmp_path / "local-identities.json"), + default_ttl_seconds=60, + ), + ) + daemon = PredicateAuthorityDaemon( + sidecar=sidecar, + config=DaemonConfig(host="127.0.0.1", port=0, policy_poll_interval_s=10.0), + flush_worker=FlushWorkerConfig( + enabled=False, + interval_s=5.0, + max_batch_size=20, + dead_letter_max_attempts=1, + ), + ) + daemon.start() + request = ActionRequest( + principal=PrincipalRef(principal_id="agent:dead-letter"), + action_spec=ActionSpec( + action="http.post", + resource="https://api.vendor.com/orders", + intent="create order", + ), + state_evidence=StateEvidence(source="test", state_hash="dead-letter-state"), + verification_evidence=VerificationEvidence(), + ) + decision = sidecar.issue_mandate(request) + assert decision.allowed is True + + base_url = f"http://127.0.0.1:{daemon.bound_port}" + flush_result = _post_json(f"{base_url}/ledger/flush-now", {"max_items": 10}) + assert int(flush_result["failed_count"]) >= 1 + assert int(flush_result["quarantined_count"]) >= 1 + + queue_default = _fetch_json(f"{base_url}/ledger/flush-queue") + assert len(queue_default.get("items", [])) == 0 + queue_with_quarantine = _fetch_json( + f"{base_url}/ledger/flush-queue?include_quarantined=true" + ) + items = queue_with_quarantine.get("items", []) + assert isinstance(items, list) + assert len(items) >= 1 + first_item = items[0] + assert isinstance(first_item, dict) + assert first_item.get("quarantined") is True + queue_item_id = str(first_item["queue_item_id"]) + + dead_letter = _fetch_json(f"{base_url}/ledger/dead-letter") + dead_letter_items = dead_letter.get("items", []) + assert isinstance(dead_letter_items, list) + assert len(dead_letter_items) >= 1 + assert all( + bool(item.get("quarantined", False)) + for item in dead_letter_items + if isinstance(item, dict) + ) + status_before_requeue = _fetch_json(f"{base_url}/status") + assert int(status_before_requeue["flush_quarantined_count"]) >= 1 + assert int(status_before_requeue["local_flush_queue_quarantined_count"]) >= 1 + + requeued = _post_json(f"{base_url}/ledger/requeue", {"queue_item_id": queue_item_id}) + assert requeued["ok"] is True + + dead_letter_after = _fetch_json(f"{base_url}/ledger/dead-letter") + dead_letter_after_items = dead_letter_after.get("items", []) + assert isinstance(dead_letter_after_items, list) + assert len(dead_letter_after_items) == 0 + pending_after = _fetch_json(f"{base_url}/ledger/flush-queue") + pending_items_after = pending_after.get("items", []) + assert isinstance(pending_items_after, list) + assert len(pending_items_after) >= 1 + + status = _fetch_json(f"{base_url}/status") + assert int(status["flush_quarantined_count"]) >= 1 + assert int(status["local_flush_queue_quarantined_count"]) == 0 + finally: + if daemon is not None: + daemon.stop() + server.shutdown() + server.server_close() diff --git a/tests/test_local_identity_registry.py b/tests/test_local_identity_registry.py new file mode 100644 index 0000000..baa582d --- /dev/null +++ b/tests/test_local_identity_registry.py @@ -0,0 +1,63 @@ +from __future__ import annotations + +from pathlib import Path + +from predicate_authority import LocalIdentityRegistry +from predicate_contracts import AuthorizationReason, ProofEvent + + +def test_local_identity_registry_issue_revoke_and_expire(tmp_path: Path) -> None: + registry = LocalIdentityRegistry(str(tmp_path / "local-identities.json"), default_ttl_seconds=2) + issued = registry.issue_task_identity( + principal_id="agent:test", + task_id="task-123", + ttl_seconds=1, + metadata={"kind": "codegen"}, + ) + assert registry.is_identity_active(issued.identity_id, now_epoch_s=issued.issued_at_epoch_s) + assert issued.metadata["kind"] == "codegen" + + expired = registry.expire_identities(now_epoch_s=issued.expires_at_epoch_s) + assert expired == 1 + assert ( + registry.is_identity_active(issued.identity_id, now_epoch_s=issued.expires_at_epoch_s) + is False + ) + + issued_2 = registry.issue_task_identity(principal_id="agent:test", task_id="task-456") + assert registry.revoke_identity(issued_2.identity_id) is True + assert registry.is_identity_active(issued_2.identity_id) is False + + +def test_local_identity_registry_flush_queue_lifecycle(tmp_path: Path) -> None: + registry = LocalIdentityRegistry(str(tmp_path / "local-identities.json")) + event = ProofEvent( + event_type="authority.decision", + principal_id="agent:test", + action="http.post", + resource="https://api.vendor.com/orders", + reason=AuthorizationReason.ALLOWED, + allowed=True, + mandate_id="mandate-1", + emitted_at_epoch_s=1_700_000_000, + ) + item = registry.enqueue_proof_event(event) + pending = registry.list_flush_queue() + assert len(pending) == 1 + assert pending[0].queue_item_id == item.queue_item_id + + assert registry.mark_flush_failed(item.queue_item_id, "temporary outage") is True + pending_after_fail = registry.list_flush_queue() + assert pending_after_fail[0].last_error == "temporary outage" + + assert registry.mark_flush_ack(item.queue_item_id) is True + pending_after_ack = registry.list_flush_queue() + assert len(pending_after_ack) == 0 + all_items = registry.list_flush_queue(include_flushed=True) + assert len(all_items) == 1 + assert all_items[0].flushed is True + + stats = registry.stats() + assert stats.total_identity_count == 0 + assert stats.pending_flush_queue_count == 0 + assert stats.flushed_queue_count == 1