From 7cd7d47bba0562872af14dc859d5cd782bec328c Mon Sep 17 00:00:00 2001 From: Koen van der Veen Date: Fri, 12 Jun 2026 12:46:13 +0200 Subject: [PATCH 1/3] feat: add syft-verifuscate package for verify-then-obfuscate of enclave model code (cherry picked from commit 44c1a792a605be7bde8507afe2c0e527a8ab3309) --- packages/syft-verifuscate/README.md | 39 ++ packages/syft-verifuscate/pyproject.toml | 18 + .../src/syft_verifuscate/__init__.py | 29 ++ .../src/syft_verifuscate/errors.py | 24 + .../src/syft_verifuscate/obfuscator.py | 199 ++++++++ .../src/syft_verifuscate/policy.py | 151 +++++++ .../src/syft_verifuscate/runner.py | 80 ++++ .../src/syft_verifuscate/verifier.py | 427 ++++++++++++++++++ .../tests/fixtures/compliant_model.py | 41 ++ .../syft-verifuscate/tests/test_obfuscate.py | 64 +++ packages/syft-verifuscate/tests/test_run.py | 67 +++ .../syft-verifuscate/tests/test_verify.py | 90 ++++ pyproject.toml | 2 + uv.lock | 8 + 14 files changed, 1239 insertions(+) create mode 100644 packages/syft-verifuscate/README.md create mode 100644 packages/syft-verifuscate/pyproject.toml create mode 100644 packages/syft-verifuscate/src/syft_verifuscate/__init__.py create mode 100644 packages/syft-verifuscate/src/syft_verifuscate/errors.py create mode 100644 packages/syft-verifuscate/src/syft_verifuscate/obfuscator.py create mode 100644 packages/syft-verifuscate/src/syft_verifuscate/policy.py create mode 100644 packages/syft-verifuscate/src/syft_verifuscate/runner.py create mode 100644 packages/syft-verifuscate/src/syft_verifuscate/verifier.py create mode 100644 packages/syft-verifuscate/tests/fixtures/compliant_model.py create mode 100644 packages/syft-verifuscate/tests/test_obfuscate.py create mode 100644 packages/syft-verifuscate/tests/test_run.py create mode 100644 packages/syft-verifuscate/tests/test_verify.py diff --git a/packages/syft-verifuscate/README.md b/packages/syft-verifuscate/README.md new file mode 100644 index 00000000000..5212aeb9fea --- /dev/null +++ b/packages/syft-verifuscate/README.md @@ -0,0 +1,39 @@ +# syft-verifuscate + +Verify-then-obfuscate for JAX/Flax model-inference code that runs in an enclave on a second party's +private data. `verifuscate.run(...)`: + +1. **Verifies** (static, before running) that the _private_ model-definition lines only do trusted JAX/ + Flax math — no imports, no file/network, no dynamic-Python escape hatches, no named method calls on + opaque values. This is the no-data-theft guarantee. +2. **Obfuscates** those lines (rename identifiers, blank constants/einsum strings, strip comments) so + the model architecture stays secret. Every non-private line is copied byte-for-byte. + +The design and threat model are documented under `research/verifuscate/` in the repo root +(`SYNTHESIS.md`, `approach-A-ast-obfuscation.md`, `approach-B-whitelist-sandbox.md`). + +## Usage + +```python +import syft_verifuscate as verifuscate + +result = verifuscate.run( + "gemma_inference.py", + private=[[22, 93], [99, 280]], # 1-based inclusive line ranges to hide+verify + allow_functions="jax.*, flax.linen.*", # things callable BY NAME (path-resolved) + allow_methods="arithmetic, indexing, comparison, metadata", # operators allowed ON A VALUE +) +# On success: writes gemma_inference.obfuscated.py and returns result.certificate. +# On a policy violation: raises PolicyViolation naming each offending line (strict=True, the default). +``` + +Use `verifuscate.verify(...)` for the check alone (it returns violations instead of raising), or pass +`strict=False` to `run` to get a `RunResult` with `.ok` / `.violations` and no exception. + +## Scope and honest limits + +The static checker enforces the documented rule set with single-pass alias resolution (no deep +data-flow analysis). It does **not** address the output channel (a model encoding private data into its +legitimate result), timing/side channels, or the general undecidability of Python — see +`research/verifuscate/approach-B-whitelist-sandbox.md` §6. The obfuscated artifact is **display-only** +(the real, unobfuscated code is what runs in the enclave). diff --git a/packages/syft-verifuscate/pyproject.toml b/packages/syft-verifuscate/pyproject.toml new file mode 100644 index 00000000000..1b0f0c84f07 --- /dev/null +++ b/packages/syft-verifuscate/pyproject.toml @@ -0,0 +1,18 @@ +[project] +name = "syft-verifuscate" +version = "0.1.0" +description = "Verify + obfuscate: prove JAX/Flax inference code only does math (no data theft) while hiding the model architecture" +authors = [{ name = "OpenMined", email = "info@openmined.org" }] +license = { text = "Apache-2.0" } +readme = "README.md" +requires-python = ">=3.10" + +# Pure standard-library tool (ast, tokenize, hashlib, fnmatch) — no runtime dependencies. +dependencies = [] + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[tool.hatch.build.targets.wheel] +packages = ["src/syft_verifuscate"] diff --git a/packages/syft-verifuscate/src/syft_verifuscate/__init__.py b/packages/syft-verifuscate/src/syft_verifuscate/__init__.py new file mode 100644 index 00000000000..5abe81e9689 --- /dev/null +++ b/packages/syft-verifuscate/src/syft_verifuscate/__init__.py @@ -0,0 +1,29 @@ +"""syft-verifuscate — verify + obfuscate JAX/Flax inference code. + +`run` is the entry point: it statically proves the private model-definition lines only do trusted math +(no data theft), then obfuscates them so the model architecture stays secret. See README and the design +under `research/verifuscate/`. +""" + +from __future__ import annotations + +__version__ = "0.1.0" + +from .errors import PolicyViolation, VerifuscateError +from .obfuscator import obfuscate +from .policy import Policy +from .runner import RunResult, run +from .verifier import VerifyResult, Violation, verify + +__all__ = [ + "run", + "verify", + "obfuscate", + "Policy", + "RunResult", + "VerifyResult", + "Violation", + "PolicyViolation", + "VerifuscateError", + "__version__", +] diff --git a/packages/syft-verifuscate/src/syft_verifuscate/errors.py b/packages/syft-verifuscate/src/syft_verifuscate/errors.py new file mode 100644 index 00000000000..dd78b2536fa --- /dev/null +++ b/packages/syft-verifuscate/src/syft_verifuscate/errors.py @@ -0,0 +1,24 @@ +"""Exceptions raised by syft-verifuscate.""" + +from __future__ import annotations + + +class VerifuscateError(Exception): + """Base class for all verifuscate errors.""" + + +class PolicyViolation(VerifuscateError): + """Raised by ``run(..., strict=True)`` when the private region fails verification. + + The offending findings are attached as ``.violations`` (a list of ``Violation``). + """ + + def __init__(self, violations): + self.violations = list(violations) + lines = "\n".join( + f" line {v.line}: [{v.code}] {v.message}" for v in self.violations + ) + super().__init__( + f"verifuscate refused: {len(self.violations)} policy violation(s) in the private region:\n" + f"{lines}" + ) diff --git a/packages/syft-verifuscate/src/syft_verifuscate/obfuscator.py b/packages/syft-verifuscate/src/syft_verifuscate/obfuscator.py new file mode 100644 index 00000000000..f81cd39db0f --- /dev/null +++ b/packages/syft-verifuscate/src/syft_verifuscate/obfuscator.py @@ -0,0 +1,199 @@ +"""The display transform (research approach A). + +``obfuscate(source, private, scan)`` returns a copy of ``source`` in which only the lines inside the +private ranges are transformed — identifiers renamed to neutral placeholders, constant values and +einsum-equation strings blanked to ``■``, and comments/docstrings stripped. Every line outside the +private ranges is emitted byte-for-byte, so the data owner can diff it against the original glue. + +It is *display-only*: the obfuscated file is for reading, not running (the real, unobfuscated code is +what runs in the enclave). Renaming is deterministic — same input, same output. +""" + +from __future__ import annotations + +import ast +import io +import keyword +import tokenize + +from .policy import DEFAULT_KEEP, METADATA_ATTRS +from .verifier import FileScan, _dotted, _is_dunder, _normalize_ranges + +_BLANK = "■" # ■ + +# Builtins kept readable (they reveal nothing about the architecture). +_KEEP_BUILTINS = frozenset( + { + "int", + "float", + "bool", + "str", + "bytes", + "len", + "range", + "enumerate", + "zip", + "min", + "max", + "sum", + "abs", + "round", + "all", + "any", + "tuple", + "list", + "dict", + "set", + "sorted", + "reversed", + "isinstance", + "super", + "None", + "True", + "False", + } +) + + +def obfuscate(source: str, private, scan: FileScan) -> str: + ranges = _normalize_ranges(private) + tree = ast.parse(source) + value_map, attr_map = _build_maps(tree, ranges, scan) + + keep_values = ( + DEFAULT_KEEP + | set(scan.bindings) + | set(scan.visible_defs) + | _KEEP_BUILTINS + | set(keyword.kwlist) + | set(getattr(keyword, "softkwlist", [])) + ) + + edits: list[tuple[int, int, int, int, str]] = [] + tokens = list(tokenize.generate_tokens(io.StringIO(source).readline)) + prev_op_dot = False + for tok in tokens: + srow, scol = tok.start + erow, ecol = tok.end + if not _row_in_ranges(srow, ranges): + if tok.type not in ( + tokenize.NL, + tokenize.NEWLINE, + tokenize.INDENT, + tokenize.DEDENT, + ): + prev_op_dot = tok.type == tokenize.OP and tok.string == "." + continue + + if tok.type == tokenize.NAME: + is_attr = prev_op_dot + if is_attr: + new = attr_map.get(tok.string) + else: + new = None if tok.string in keep_values else value_map.get(tok.string) + if new is not None: + edits.append((srow, scol, erow, ecol, new)) + elif tok.type == tokenize.STRING: + edits.append((srow, scol, erow, ecol, f'"{_BLANK}"')) + elif tok.type == tokenize.NUMBER: + edits.append((srow, scol, erow, ecol, _BLANK)) + elif tok.type == tokenize.COMMENT: + edits.append( + (srow, scol, erow, ecol, "") + ) # drop comments (incl. commented-out configs) + + if tok.type not in ( + tokenize.NL, + tokenize.NEWLINE, + tokenize.INDENT, + tokenize.DEDENT, + ): + prev_op_dot = tok.type == tokenize.OP and tok.string == "." + + return _apply_edits(source, edits) + + +# ── build the deterministic rename maps from the AST ───────────────────────────────────── +def _build_maps(tree: ast.Module, ranges, scan: FileScan): + keep_attrs: set[str] = set(METADATA_ATTRS) + mangle_attr_names: set[str] = set() + value_occurrences: list[tuple[tuple[int, int], str]] = [] + private_classes = _names_of(tree, ast.ClassDef, ranges) + private_funcs = _names_of(tree, ast.FunctionDef, ranges) + + for node in ast.walk(tree): + if not _node_in_ranges(node, ranges): + continue + if isinstance(node, ast.Attribute) and not _is_dunder(node.attr): + root = (_dotted(node.value) or "").split(".")[0] + if root in scan.bindings: + keep_attrs.add( + node.attr + ) # public library attr (e.g. jnp.einsum) — stays readable + else: + mangle_attr_names.add(node.attr) + elif isinstance(node, ast.Name): + value_occurrences.append(((node.lineno, node.col_offset), node.id)) + elif isinstance(node, ast.arg): + value_occurrences.append(((node.lineno, node.col_offset), node.arg)) + + # attr placeholders, in sorted name order for determinism + attr_map: dict[str, str] = {} + for i, name in enumerate(sorted(mangle_attr_names - keep_attrs)): + attr_map[name] = f"░a{i}" + + # value placeholders, assigned in source order (first occurrence wins) + keep_values = ( + DEFAULT_KEEP + | set(scan.bindings) + | set(scan.visible_defs) + | _KEEP_BUILTINS + | set(keyword.kwlist) + | set(getattr(keyword, "softkwlist", [])) + ) + value_map: dict[str, str] = {} + counters = {"cls": 0, "fn": 0, "v": 0} + for _pos, name in sorted(value_occurrences): + if name in keep_values or name in value_map: + continue + if name in private_classes: + value_map[name] = f"░Cls{counters['cls']}" + counters["cls"] += 1 + elif name in private_funcs: + value_map[name] = f"░fn{counters['fn']}" + counters["fn"] += 1 + else: + value_map[name] = f"░v{counters['v']}" + counters["v"] += 1 + return value_map, attr_map + + +def _names_of(tree, node_type, ranges) -> set[str]: + return { + n.name + for n in ast.walk(tree) + if isinstance(n, node_type) and _node_in_ranges(n, ranges) + } + + +# ── apply position edits to the source, preserving non-private lines verbatim ───────────── +def _apply_edits(source: str, edits) -> str: + lines = source.splitlines(keepends=True) + # apply bottom-up so earlier edits don't shift later line indices / columns + for srow, scol, erow, ecol, new in sorted(edits, reverse=True): + if srow == erow: + line = lines[srow - 1] + lines[srow - 1] = line[:scol] + new + line[ecol:] + else: + merged = lines[srow - 1][:scol] + new + lines[erow - 1][ecol:] + lines[srow - 1 : erow] = [merged] + return "".join(lines) + + +def _row_in_ranges(row: int, ranges) -> bool: + return any(lo <= row <= hi for lo, hi in ranges) + + +def _node_in_ranges(node: ast.AST, ranges) -> bool: + line = getattr(node, "lineno", None) + return line is not None and _row_in_ranges(line, ranges) diff --git a/packages/syft-verifuscate/src/syft_verifuscate/policy.py b/packages/syft-verifuscate/src/syft_verifuscate/policy.py new file mode 100644 index 00000000000..f63bec9ad6c --- /dev/null +++ b/packages/syft-verifuscate/src/syft_verifuscate/policy.py @@ -0,0 +1,151 @@ +"""Policy: what the hidden region is allowed to call and do. + +Two channels, mirroring the two verification mechanisms (see research approach-B §3.6.5): + +- ``functions`` — dotted paths callable BY NAME (resolved exactly against the import bindings), + e.g. ``jax.*``, ``flax.linen.*``. Checked by glob match, with ``JAX_DENYLIST`` beating the allow. +- ``methods`` — operator *bundles* allowed ON A VALUE, e.g. ``arithmetic``, ``indexing``. These are + language-level operators (``__add__``, ``__getitem__``, …), never named library methods. No named + method may be called on an opaque value at all. +""" + +from __future__ import annotations + +import ast +import fnmatch +from dataclasses import dataclass, field + +# ── Operator bundles: bundle name -> the AST node types it enables on a value ────────────── +# These are generic, type-agnostic-safe operators (not named-method calls), so the format-string +# escape cannot hide among them (research approach-B §3.6.2). +OPERATOR_BUNDLES: dict[str, tuple[type[ast.AST], ...]] = { + "arithmetic": (ast.BinOp, ast.UnaryOp), + "comparison": (ast.Compare, ast.BoolOp), + "indexing": (ast.Subscript, ast.Slice), +} +# The metadata bundle is special: it allows a few pure *metadata reads* on a value (ints/dtype, +# no side effects). Transforms like `.T` are NOT here — they're library-specific and must be wrapped +# (research approach-B §3.6.2/§3.6.3). +METADATA_ATTRS: frozenset[str] = frozenset({"shape", "ndim", "dtype", "size"}) + +ALL_BUNDLES: frozenset[str] = frozenset(OPERATOR_BUNDLES) | {"metadata"} + +# ── Dangerous JAX / serialization surface — denylist BEATS the allow (approach-B §3.2/§3.3) ── +# Host-callback / IO / FFI / serialization escape hatches that can run host code or touch disk. +JAX_DENYLIST: tuple[str, ...] = ( + "jax.experimental.*", + "jax.debug.*", + "jax.pure_callback", + "*.io_callback", + "*.host_callback", + "*.host_callback.*", + "jax.profiler.*", + "jax.monitoring.*", + "jax.distributed.*", + "jax.dlpack.*", + "jax.ffi", + "jax.ffi.*", + "jax.extend.*", + # array <-> file on disk, even though jax.numpy.* is otherwise allowed + "jax.numpy.save", + "jax.numpy.savez", + "jax.numpy.savez_compressed", + "jax.numpy.load", + "jax.numpy.tofile", + "jax.numpy.fromfile", + "jax.numpy.memmap", + "jax.numpy.savetxt", + "jax.numpy.loadtxt", + "jax.numpy.genfromtxt", + "flax.serialization.*", + "flax.training.checkpoints.*", + "orbax.*", +) + +# ── Builtins that are dynamic-escape / IO hatches and may never be called (approach-B §2.2) ── +BANNED_NAMES: frozenset[str] = frozenset( + { + "eval", + "exec", + "compile", + "__import__", + "getattr", + "setattr", + "delattr", + "hasattr", + "vars", + "globals", + "locals", + "dir", + "open", + "input", + "breakpoint", + "memoryview", + } +) + +# Decorators allowed above a def/class in the hidden region (approach-B §3.4 / §3.5.1 #4). +ALLOWED_DECORATORS: frozenset[str] = frozenset( + {"nn.compact", "jax.jit", "jax.named_scope", "flax.linen.compact"} +) + +# The only dunder/hook methods a model class may *define* (approach-B §3.5.1 #6). +ALLOWED_DUNDER_DEFS: frozenset[str] = frozenset({"__call__", "setup", "__post_init__"}) + +# Names always preserved verbatim by the obfuscator and never treated as opaque values. +DEFAULT_KEEP: frozenset[str] = frozenset( + {"self", "cls", "nn", "Module", "setup", "__call__", "__post_init__"} +) + + +@dataclass +class Policy: + """Parsed allow-lists. ``reserved`` is filled in by the verifier from the file's imports.""" + + functions: list[str] = field(default_factory=list) + methods: set[str] = field(default_factory=set) + reserved: set[str] = field(default_factory=set) + + @classmethod + def parse(cls, allow_functions: str = "", allow_methods: str = "") -> "Policy": + functions = _split(allow_functions) + methods = set(_split(allow_methods)) + unknown = methods - ALL_BUNDLES + if unknown: + raise ValueError( + f"unknown method bundle(s): {sorted(unknown)}; allowed: {sorted(ALL_BUNDLES)}" + ) + return cls(functions=functions, methods=methods) + + # ── path matching ────────────────────────────────────────────────────────────────── + def function_allowed(self, dotted: str) -> bool: + """True iff a fully-qualified dotted path is allowed (and not denylisted).""" + if any(fnmatch.fnmatchcase(dotted, pat) for pat in JAX_DENYLIST): + return False + return any(_path_matches(dotted, pat) for pat in self.functions) + + def bundle_enabled(self, name: str) -> bool: + return name in self.methods + + def policy_id(self) -> str: + """A short, stable identifier for the policy (for the certificate).""" + import hashlib + + blob = "|".join(sorted(self.functions)) + "##" + "|".join(sorted(self.methods)) + return hashlib.sha256(blob.encode()).hexdigest()[:16] + + +def _split(spec: str) -> list[str]: + return [part.strip() for part in spec.split(",") if part.strip()] + + +def _path_matches(dotted: str, pattern: str) -> bool: + """Match a dotted path against an allow pattern. + + ``jax.*`` matches ``jax`` and anything beneath it (``jax.numpy.einsum``); an exact pattern + matches only itself. + """ + if pattern.endswith(".*"): + prefix = pattern[:-2] + return dotted == prefix or dotted.startswith(prefix + ".") + return fnmatch.fnmatchcase(dotted, pattern) diff --git a/packages/syft-verifuscate/src/syft_verifuscate/runner.py b/packages/syft-verifuscate/src/syft_verifuscate/runner.py new file mode 100644 index 00000000000..a5eaf99933a --- /dev/null +++ b/packages/syft-verifuscate/src/syft_verifuscate/runner.py @@ -0,0 +1,80 @@ +"""``run`` — orchestrate verify → obfuscate → certificate.""" + +from __future__ import annotations + +import ast +import hashlib +from dataclasses import dataclass, field +from pathlib import Path + +from .errors import PolicyViolation +from .obfuscator import obfuscate +from .policy import Policy +from .verifier import Violation, _normalize_ranges, _scan_file, verify + +__all__ = ["run", "RunResult"] + + +@dataclass +class RunResult: + ok: bool + violations: list[Violation] = field(default_factory=list) + obfuscated_path: str | None = None + certificate: dict | None = None + + +def run( + path: str | Path, + private, + allow_functions: str = "", + allow_methods: str = "", + out: str | Path | None = None, + strict: bool = True, +) -> RunResult: + """Verify the private region, then (on success) write an obfuscated copy. + + Args: + path: the inference source file. + private: list of ``[start, end]`` 1-based inclusive line ranges to hide + verify. + allow_functions: comma-separated dotted-path globs callable by name (e.g. ``"jax.*, flax.linen.*"``). + allow_methods: comma-separated operator bundles allowed on a value + (``arithmetic, indexing, comparison, metadata``). + out: where to write the obfuscated file (default ``.obfuscated.py`` next to the source). + strict: if True (default), raise ``PolicyViolation`` when verification fails; otherwise return + a ``RunResult`` with ``ok=False`` and no output written. + """ + path = Path(path) + source = path.read_text() + policy = Policy.parse(allow_functions, allow_methods) + + result = verify(source, private, policy) + if not result.ok: + if strict: + raise PolicyViolation(result.violations) + return RunResult(ok=False, violations=result.violations) + + scan = _scan_file(ast.parse(source), _normalize_ranges(private)) + obfuscated = obfuscate(source, private, scan) + + out_path = Path(out) if out is not None else path.with_suffix(".obfuscated.py") + out_path.write_text(obfuscated) + + certificate = { + "source_sha256": hashlib.sha256(source.encode()).hexdigest(), + "policy_id": policy.policy_id(), + "verifuscate_version": _version(), + "private_ranges": [list(r) for r in _normalize_ranges(private)], + "n_calls_checked": result.n_calls_checked, + } + return RunResult( + ok=True, + violations=[], + obfuscated_path=str(out_path), + certificate=certificate, + ) + + +def _version() -> str: + from . import __version__ + + return __version__ diff --git a/packages/syft-verifuscate/src/syft_verifuscate/verifier.py b/packages/syft-verifuscate/src/syft_verifuscate/verifier.py new file mode 100644 index 00000000000..06ce3409a12 --- /dev/null +++ b/packages/syft-verifuscate/src/syft_verifuscate/verifier.py @@ -0,0 +1,427 @@ +"""The static checker (research approach B). + +``verify(source, private, policy)`` parses the file, restricts attention to the *private* line ranges +(the hidden model definition), and walks those nodes default-deny: only explicitly-allowed node types, +calls, operators, and attribute reads pass. It never raises on a policy issue — it returns a +``VerifyResult`` with the violations so callers can inspect them. +""" + +from __future__ import annotations + +import ast +from dataclasses import dataclass, field + +from .policy import ( + ALLOWED_DECORATORS, + ALLOWED_DUNDER_DEFS, + BANNED_NAMES, + METADATA_ATTRS, + OPERATOR_BUNDLES, + Policy, +) + +# ── allowed AST node types in the hidden region (approach-B §2.1) ──────────────────────────── +_ALLOWED_NODES: tuple[type[ast.AST], ...] = ( + ast.Module, + ast.Expr, + ast.FunctionDef, + ast.ClassDef, + ast.arguments, + ast.arg, + ast.Return, + ast.Lambda, + ast.Name, + ast.Load, + ast.Store, + ast.Del, + ast.Constant, + ast.Call, + ast.keyword, + ast.Starred, + ast.Attribute, + ast.Subscript, + ast.Slice, + ast.BinOp, + ast.UnaryOp, + ast.BoolOp, + ast.Compare, + ast.List, + ast.Tuple, + ast.Dict, + ast.Set, + ast.ListComp, + ast.SetComp, + ast.DictComp, + ast.GeneratorExp, + ast.comprehension, + ast.If, + ast.For, + ast.While, + ast.Break, + ast.Continue, + ast.Pass, + ast.IfExp, + ast.Assign, + ast.AugAssign, + ast.AnnAssign, + ast.JoinedStr, + ast.FormattedValue, + # operator/cmpop/boolop/unaryop singletons are leaf nodes under the above; always fine. + ast.operator, + ast.cmpop, + ast.boolop, + ast.unaryop, + ast.expr_context, +) + +# Banned statement/expr node types (approach-B §2.2): present => violation. +_BANNED_NODES: tuple[type[ast.AST], ...] = ( + ast.Import, + ast.ImportFrom, + ast.With, + ast.Try, + ast.Raise, + ast.Global, + ast.Nonlocal, + ast.Delete, + ast.Assert, + ast.AsyncFunctionDef, + ast.AsyncFor, + ast.AsyncWith, + ast.Await, + ast.Yield, + ast.YieldFrom, +) + + +@dataclass(frozen=True) +class Violation: + line: int + code: str + message: str + + +@dataclass +class VerifyResult: + ok: bool + violations: list[Violation] = field(default_factory=list) + n_calls_checked: int = 0 + + +@dataclass +class FileScan: + """Names harvested from the whole file, used to classify calls in the hidden region.""" + + bindings: dict[str, str] # alias -> fully-qualified module path (jnp -> jax.numpy) + hidden_defs: set[str] # class/func names defined inside the private region + visible_defs: set[ + str + ] # function names defined in the visible region (the wrappers) + + +def verify(source: str, private, policy: Policy) -> VerifyResult: + ranges = _normalize_ranges(private) + tree = ast.parse(source) + scan = _scan_file(tree, ranges) + policy.reserved = set(scan.bindings) + checker = _Checker(policy, scan, ranges) + checker.visit(tree) + return VerifyResult( + ok=not checker.violations, + violations=checker.violations, + n_calls_checked=checker.n_calls, + ) + + +# ── file scan ──────────────────────────────────────────────────────────────────────────── +def _scan_file(tree: ast.Module, ranges) -> FileScan: + bindings: dict[str, str] = {} + hidden_defs: set[str] = set() + visible_defs: set[str] = set() + for node in ast.walk(tree): + if isinstance(node, ast.Import): + for alias in node.names: + bindings[alias.asname or alias.name.split(".")[0]] = alias.name + elif isinstance(node, ast.ImportFrom) and node.module: + for alias in node.names: + bindings[alias.asname or alias.name] = f"{node.module}.{alias.name}" + elif isinstance(node, (ast.FunctionDef, ast.ClassDef)): + if _in_ranges(node, ranges): + hidden_defs.add(node.name) + elif isinstance(node, ast.FunctionDef): + visible_defs.add(node.name) + return FileScan( + bindings=bindings, hidden_defs=hidden_defs, visible_defs=visible_defs + ) + + +# ── the checker ────────────────────────────────────────────────────────────────────────── +class _Checker: + def __init__(self, policy: Policy, scan: FileScan, ranges): + self.policy = policy + self.scan = scan + self.ranges = ranges + self.violations: list[Violation] = [] + self.n_calls = 0 + self._call_funcs: set[int] = ( + set() + ) # Attribute nodes already judged as a call func + + def add(self, node: ast.AST, code: str, message: str) -> None: + self.violations.append(Violation(getattr(node, "lineno", 0), code, message)) + + def visit(self, node: ast.AST) -> None: + """Walk the tree; enforce only on nodes inside the private ranges, recurse everywhere.""" + if _in_ranges(node, self.ranges): + self._enforce(node) + for child in ast.iter_child_nodes(node): + self.visit(child) + + # — per-node enforcement (recursion is handled by visit) — + def _enforce(self, node: ast.AST) -> None: + if isinstance(node, _BANNED_NODES): + self.add( + node, + "banned-construct", + f"{type(node).__name__} is not allowed in the hidden region", + ) + return + if not isinstance(node, _ALLOWED_NODES): + self.add( + node, + "node-type", + f"{type(node).__name__} is not on the node-type allow-list", + ) + return + + if isinstance(node, ast.FunctionDef): + self._check_def(node) + elif isinstance(node, ast.ClassDef): + self._check_class(node) + elif isinstance(node, ast.Call): + self._check_call(node) + elif isinstance(node, ast.Attribute): + self._check_attribute(node) + elif isinstance(node, (ast.BinOp, ast.UnaryOp)): + self._require_bundle(node, "arithmetic") + elif isinstance(node, (ast.Compare, ast.BoolOp)): + self._require_bundle(node, "comparison") + elif isinstance(node, (ast.Subscript, ast.Slice)): + self._require_bundle(node, "indexing") + elif isinstance(node, (ast.Assign, ast.AugAssign, ast.AnnAssign)): + self._check_assign_targets(node) + elif isinstance(node, ast.For): + self._check_reserved_target(node.target) + elif isinstance(node, ast.comprehension): + self._check_reserved_target(node.target) + elif isinstance(node, ast.arg): + self._check_reserved_name(node, node.arg) + + # — defs / classes — + def _check_def(self, node: ast.FunctionDef) -> None: + self._check_decorators(node) + if _is_dunder(node.name) and node.name not in ALLOWED_DUNDER_DEFS: + self.add( + node, + "dunder-def", + f"defining magic method {node.name!r} is not allowed", + ) + + def _check_class(self, node: ast.ClassDef) -> None: + self._check_decorators(node) + if node.keywords: + self.add( + node, + "class-keyword", + "class keyword arguments (e.g. metaclass=) are not allowed", + ) + for base in node.bases: + dotted = _dotted(base) + ok = (dotted and self._resolved_allowed(dotted)) or ( + isinstance(base, ast.Name) + and base.id in (self.scan.hidden_defs | {"object"}) + ) + if not ok: + self.add( + base, + "class-base", + f"base class {_describe(base)!r} is not allow-listed", + ) + + def _check_decorators(self, node) -> None: + for dec in node.decorator_list: + target = dec.func if isinstance(dec, ast.Call) else dec + dotted = _dotted(target) + resolved = self._resolve(dotted) if dotted else None + if not (resolved in ALLOWED_DECORATORS or dotted in ALLOWED_DECORATORS): + self.add( + dec, + "decorator", + f"decorator {_describe(target)!r} is not allow-listed", + ) + + # — calls — + def _check_call(self, node: ast.Call) -> None: + self.n_calls += 1 + func = node.func + if isinstance(func, ast.Name): + if func.id in BANNED_NAMES: + self.add(node, "banned-call", f"call to {func.id!r} is not allowed") + # Otherwise a bare-name call (local var / hidden or visible def / safe builtin) is allowed; + # nothing dangerous can reach a local name given the other rules. + return + if isinstance(func, ast.Attribute): + self._check_call_attribute(node, func) + return + # func is a Call / Subscript / etc.: calling a *value* (e.g. self.layer[i](...), Block(...)(x)). + # The value's provenance is checked elsewhere; calling it (its __call__) is allowed. + + def _check_call_attribute(self, call: ast.Call, func: ast.Attribute) -> None: + self._call_funcs.add( + id(func) + ) # so _check_attribute doesn't re-flag the same node + dotted = _dotted(func) + if dotted is not None: + root = dotted.split(".")[0] + if root in ("self", "cls"): + return # self.method(...) — receiver type is the module class, not opaque + if root in self.scan.bindings: + if not self._resolved_allowed(dotted): + self.add( + call, + "call-not-allowed", + f"call to {self._resolve(dotted)!r} is not allow-listed", + ) + return + # Attribute on an opaque value: this is a NAMED METHOD ON A VALUE — never allowed (§3.6). + self.add( + call, + "method-on-value", + f"named method {func.attr!r} called on a value whose type is unknown; " + f"route it through a visible wrapper function instead", + ) + + # — attribute reads (not the func of a call) — + def _check_attribute(self, node: ast.Attribute) -> None: + if id(node) in self._call_funcs: + return # already judged as a call's function position by _check_call_attribute + if _is_dunder(node.attr): + self.add( + node, + "dunder-attr", + f"access to dunder attribute {node.attr!r} is not allowed", + ) + return + dotted = _dotted(node) + if dotted is not None: + root = dotted.split(".")[0] + if root in ("self", "cls"): + return + if root in self.scan.bindings: + if not self._resolved_allowed(dotted): + self.add( + node, + "attr-not-allowed", + f"reference to {self._resolve(dotted)!r} is not allow-listed", + ) + return + # Attribute read on an opaque value: allowed only for the metadata bundle (.shape/.ndim/.dtype). + if node.attr in METADATA_ATTRS: + if not self.policy.bundle_enabled("metadata"): + self.add( + node, + "bundle-disabled", + f"attribute read {node.attr!r} needs the 'metadata' bundle", + ) + else: + self.add( + node, + "attr-on-value", + f"attribute {node.attr!r} on a value is not a metadata read; " + f"route it through a visible wrapper function instead", + ) + + # — operators — + def _require_bundle(self, node: ast.AST, bundle: str) -> None: + if not self.policy.bundle_enabled(bundle): + ops = "/".join(t.__name__ for t in OPERATOR_BUNDLES[bundle]) + self.add( + node, "bundle-disabled", f"{ops} needs the {bundle!r} method bundle" + ) + + # — assignment / reserved names — + def _check_assign_targets(self, node) -> None: + targets = node.targets if isinstance(node, ast.Assign) else [node.target] + for t in targets: + self._check_reserved_target(t) + + def _check_reserved_target(self, target: ast.AST) -> None: + for name_node in _iter_names(target): + if isinstance(name_node.ctx, ast.Store): + self._check_reserved_name(name_node, name_node.id) + + def _check_reserved_name(self, node: ast.AST, name: str) -> None: + if name in self.policy.reserved: + self.add( + node, + "reserved-name", + f"{name!r} is a reserved module alias and may not be rebound", + ) + elif name in self.scan.visible_defs: + self.add( + node, + "reserved-name", + f"{name!r} is a visible wrapper name and may not be rebound", + ) + + # — path resolution — + def _resolve(self, dotted: str) -> str: + root, _, rest = dotted.partition(".") + base = self.scan.bindings.get(root, root) + return f"{base}.{rest}" if rest else base + + def _resolved_allowed(self, dotted: str) -> bool: + return self.policy.function_allowed(self._resolve(dotted)) + + +# ── helpers ────────────────────────────────────────────────────────────────────────────── +def _normalize_ranges(private) -> list[tuple[int, int]]: + out = [] + for item in private: + lo, hi = item + out.append((int(lo), int(hi))) + return out + + +def _in_ranges(node: ast.AST, ranges) -> bool: + line = getattr(node, "lineno", None) + if line is None: + return False + return any(lo <= line <= hi for lo, hi in ranges) + + +def _dotted(node: ast.AST) -> str | None: + """Return the dotted path for a pure Name/Attribute chain, else None.""" + parts: list[str] = [] + cur = node + while isinstance(cur, ast.Attribute): + parts.append(cur.attr) + cur = cur.value + if isinstance(cur, ast.Name): + parts.append(cur.id) + return ".".join(reversed(parts)) + return None + + +def _describe(node: ast.AST) -> str: + return _dotted(node) or type(node).__name__ + + +def _is_dunder(name: str) -> bool: + return name.startswith("__") + + +def _iter_names(node: ast.AST): + for n in ast.walk(node): + if isinstance(n, ast.Name): + yield n diff --git a/packages/syft-verifuscate/tests/fixtures/compliant_model.py b/packages/syft-verifuscate/tests/fixtures/compliant_model.py new file mode 100644 index 00000000000..b999f1eb32f --- /dev/null +++ b/packages/syft-verifuscate/tests/fixtures/compliant_model.py @@ -0,0 +1,41 @@ +"""A tiny verifuscate-compliant Flax-style module (green-path test fixture). + +Uses only: allow-listed jax/flax calls by name, self.* calls, operator bundles +(arithmetic/indexing/comparison/metadata), comprehensions, and bare-name calls — and NO named methods +on opaque values, no .T, no .append, no dynamic features. So it passes `verify` cleanly. +""" + +import jax +import jax.numpy as jnp +from flax import linen as nn + +CONFIG = dict(dim=8, layers=2, eps=1e-6) + + +def scale_pattern(n): + base = (1.0,) * n + return base[:n] + + +class RMSNorm(nn.Module): + def setup(self): + self.weight = self.param("w", nn.initializers.ones, (CONFIG["dim"],)) + + def __call__(self, x): + sq = jnp.mean(jnp.square(x), axis=-1, keepdims=True) + return x * jax.lax.rsqrt(sq + CONFIG["eps"]) * (1.0 + self.weight) + + +class Block(nn.Module): + cfg: dict + + def setup(self): + self.norm = RMSNorm() + self.layers = [RMSNorm() for _ in range(self.cfg["layers"])] + + def __call__(self, x): + h = self.norm(x) + out = x + h + if x.shape[-1] > 0: + out = out * 0.5 + return jnp.sum(out, axis=-1) + out[..., 0] diff --git a/packages/syft-verifuscate/tests/test_obfuscate.py b/packages/syft-verifuscate/tests/test_obfuscate.py new file mode 100644 index 00000000000..bf483bc06c1 --- /dev/null +++ b/packages/syft-verifuscate/tests/test_obfuscate.py @@ -0,0 +1,64 @@ +"""Tests for the display transform (approach A), exercised through run().""" + +import shutil +from pathlib import Path + +from syft_verifuscate import run + +FIXTURES = Path(__file__).parent / "fixtures" +ALLOW_FUNCTIONS = "jax.*, flax.linen.*" +ALLOW_METHODS = "arithmetic, indexing, comparison, metadata" + + +def _private_from_config(source: str): + config_line = next( + i for i, ln in enumerate(source.splitlines(), 1) if ln.startswith("CONFIG") + ) + return [[config_line, len(source.splitlines())]], config_line + + +def _obfuscate_fixture(tmp_path: Path): + src_path = tmp_path / "model.py" + shutil.copy(FIXTURES / "compliant_model.py", src_path) + source = src_path.read_text() + private, config_line = _private_from_config(source) + result = run( + src_path, + private=private, + allow_functions=ALLOW_FUNCTIONS, + allow_methods=ALLOW_METHODS, + ) + obf = Path(result.obfuscated_path).read_text() + return source, obf, config_line + + +def test_nonprivate_lines_are_byte_for_byte(tmp_path): + source, obf, config_line = _obfuscate_fixture(tmp_path) + src_lines = source.splitlines() + obf_lines = obf.splitlines() + assert len(src_lines) == len(obf_lines) + for i in range(config_line - 1): # lines before CONFIG are non-private + assert src_lines[i] == obf_lines[i], f"line {i + 1} changed" + # the import lines specifically + assert "import jax.numpy as jnp" in obf + + +def test_private_region_is_mangled_and_blanked(tmp_path): + source, obf, config_line = _obfuscate_fixture(tmp_path) + private_text = "\n".join(obf.splitlines()[config_line - 1 :]) + assert "░" in private_text # identifiers renamed + assert "RMSNorm" not in private_text # a private class name is gone + assert "■" in private_text # numeric/string constants blanked + assert "dim=8" not in private_text # the architecture dim is hidden + # public library names stay readable + assert "jnp" in private_text and "nn.Module" in obf + + +def test_obfuscation_is_deterministic(tmp_path): + dir_a = tmp_path / "a" + dir_b = tmp_path / "b" + dir_a.mkdir() + dir_b.mkdir() + _, obf1, _ = _obfuscate_fixture(dir_a) + _, obf2, _ = _obfuscate_fixture(dir_b) + assert obf1 == obf2 diff --git a/packages/syft-verifuscate/tests/test_run.py b/packages/syft-verifuscate/tests/test_run.py new file mode 100644 index 00000000000..05969cccb5c --- /dev/null +++ b/packages/syft-verifuscate/tests/test_run.py @@ -0,0 +1,67 @@ +"""End-to-end tests for run().""" + +import shutil +from pathlib import Path + +import pytest + +from syft_verifuscate import PolicyViolation, run + +FIXTURES = Path(__file__).parent / "fixtures" +ALLOW_FUNCTIONS = "jax.*, flax.linen.*" +ALLOW_METHODS = "arithmetic, indexing, comparison, metadata" + + +def _private(source: str): + config_line = next( + i for i, ln in enumerate(source.splitlines(), 1) if ln.startswith("CONFIG") + ) + return [[config_line, len(source.splitlines())]] + + +def test_run_success_writes_obfuscated_and_certificate(tmp_path): + src = tmp_path / "model.py" + shutil.copy(FIXTURES / "compliant_model.py", src) + result = run( + src, + private=_private(src.read_text()), + allow_functions=ALLOW_FUNCTIONS, + allow_methods=ALLOW_METHODS, + ) + assert result.ok + out = Path(result.obfuscated_path) + assert out.exists() and out.name == "model.obfuscated.py" + assert result.certificate["source_sha256"] + assert result.certificate["policy_id"] + assert result.certificate["verifuscate_version"] + assert result.certificate["n_calls_checked"] > 0 + + +def test_run_strict_raises_and_writes_nothing(tmp_path): + src = tmp_path / "bad.py" + src.write_text("CONFIG = dict(dim=8)\nimport os\nleak = os.getcwd()\n") + with pytest.raises(PolicyViolation) as exc: + run( + src, + private=[[1, 3]], + allow_functions=ALLOW_FUNCTIONS, + allow_methods=ALLOW_METHODS, + ) + assert exc.value.violations + assert not (tmp_path / "bad.obfuscated.py").exists() + + +def test_run_nonstrict_returns_violations(tmp_path): + src = tmp_path / "bad.py" + src.write_text("CONFIG = dict(dim=8)\nleak = x.reshape(1)\n") + result = run( + src, + private=[[1, 2]], + allow_functions=ALLOW_FUNCTIONS, + allow_methods=ALLOW_METHODS, + strict=False, + ) + assert not result.ok + assert any(v.code == "method-on-value" for v in result.violations) + assert result.obfuscated_path is None + assert not (tmp_path / "bad.obfuscated.py").exists() diff --git a/packages/syft-verifuscate/tests/test_verify.py b/packages/syft-verifuscate/tests/test_verify.py new file mode 100644 index 00000000000..9c505a39e14 --- /dev/null +++ b/packages/syft-verifuscate/tests/test_verify.py @@ -0,0 +1,90 @@ +"""Tests for the static checker (approach B).""" + +from pathlib import Path + +import pytest + +from syft_verifuscate import Policy, verify + +FIXTURES = Path(__file__).parent / "fixtures" +REPO_ROOT = Path(__file__).parents[3] + +ALLOW_FUNCTIONS = "jax.*, flax.linen.*" +ALLOW_METHODS = "arithmetic, indexing, comparison, metadata" + + +def _policy(): + return Policy.parse(ALLOW_FUNCTIONS, ALLOW_METHODS) + + +def _verify_all(source: str): + n = len(source.splitlines()) + return verify(source, [[1, n]], _policy()) + + +def test_compliant_fixture_passes(): + source = (FIXTURES / "compliant_model.py").read_text() + # mark the model definition (everything from CONFIG onward) as private + config_line = next( + i for i, ln in enumerate(source.splitlines(), 1) if ln.startswith("CONFIG") + ) + result = verify(source, [[config_line, len(source.splitlines())]], _policy()) + assert result.ok, [f"L{v.line} {v.code}: {v.message}" for v in result.violations] + assert result.n_calls_checked > 0 + + +@pytest.mark.parametrize( + "code, snippet", + [ + ("banned-construct", "import os\n"), + ("banned-call", "y = eval('1 + 1')\n"), + ("banned-call", "z = getattr(obj, name)\n"), + ("method-on-value", "a = x.reshape(8, -1)\n"), + ("method-on-value", "b = '{0.__class__}'.format(payload)\n"), + ("decorator", "@evil\ndef f():\n return 1\n"), + ("dunder-attr", "c = obj.__class__\n"), + ], +) +def test_rejections(code, snippet): + result = _verify_all(snippet) + assert not result.ok + assert code in {v.code for v in result.violations}, [ + (v.code, v.message) for v in result.violations + ] + + +def test_reserved_module_alias_cannot_be_rebound(): + source = "import jax.numpy as jnp\njnp = make_evil()\n" + # only the rebind line is private (the import is visible glue) + result = verify(source, [[2, 2]], _policy()) + assert "reserved-name" in {v.code for v in result.violations} + + +def test_jax_denylist_beats_allow(): + # io_callback is under an allowed module (jax.*) but on the denylist + source = "import jax\nq = jax.experimental.io_callback(send, x)\n" + result = verify(source, [[2, 2]], _policy()) + assert "call-not-allowed" in {v.code for v in result.violations} + + +def test_operator_bundle_must_be_enabled(): + source = "r = a + b\n" + # arithmetic NOT enabled -> the BinOp is rejected + policy = Policy.parse(ALLOW_FUNCTIONS, "indexing") + result = verify(source, [[1, 1]], policy) + assert "bundle-disabled" in {v.code for v in result.violations} + + +def test_real_gemma_flags_named_methods_on_values(): + """Matches research approach-B §3.6.3: the real file still has method/attr-on-value spots.""" + gemma = REPO_ROOT / "koen" / "gemma_inference.py" + if not gemma.exists(): + pytest.skip("gemma_inference.py not present") + source = gemma.read_text() + result = verify(source, [[22, 280]], _policy()) + assert not result.ok + messages = " | ".join(v.message for v in result.violations) + # the documented spots: `module.variable(...).value` in _get, and `embed_table.T` + assert "'variable'" in messages + assert "'T'" in messages + assert {"method-on-value", "attr-on-value"} <= {v.code for v in result.violations} diff --git a/pyproject.toml b/pyproject.toml index a48f621b002..331f0b85d31 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -70,6 +70,7 @@ members = ["packages/*"] "syft-permissions" = { workspace = true } "syft-perms" = { workspace = true } "syft-enclave" = { workspace = true } +"syft-verifuscate" = { workspace = true } [project.urls] Homepage = "https://github.com/OpenMined/syft-client" @@ -176,5 +177,6 @@ test = [ "isort>=5.12.0", "coverage>=7.0.0", "syft-enclave", + "syft-verifuscate", "pyarrow", # only used by parquet test fixtures (tests/unit/utils.py) ] diff --git a/uv.lock b/uv.lock index 5b284bdf009..7ed515e8d74 100644 --- a/uv.lock +++ b/uv.lock @@ -27,6 +27,7 @@ members = [ "syft-notebook-ui", "syft-permissions", "syft-perms", + "syft-verifuscate", ] [[package]] @@ -3709,6 +3710,7 @@ test = [ { name = "pytest-rerunfailures" }, { name = "pytest-xdist" }, { name = "syft-enclave" }, + { name = "syft-verifuscate" }, ] [package.metadata] @@ -3756,6 +3758,7 @@ test = [ { name = "pytest-rerunfailures", specifier = ">=14.0" }, { name = "pytest-xdist", specifier = ">=3.0.0" }, { name = "syft-enclave", editable = "packages/syft-enclave" }, + { name = "syft-verifuscate", editable = "packages/syft-verifuscate" }, ] [[package]] @@ -3889,6 +3892,11 @@ dependencies = [ [package.metadata] requires-dist = [{ name = "syft-permissions", editable = "packages/syft-permissions" }] +[[package]] +name = "syft-verifuscate" +version = "0.1.0" +source = { editable = "packages/syft-verifuscate" } + [[package]] name = "terminado" version = "0.18.1" From b633fb49654dddae1810534f9fa0326ffe1d5974 Mon Sep 17 00:00:00 2001 From: Koen van der Veen Date: Mon, 15 Jun 2026 16:43:28 +0200 Subject: [PATCH 2/3] - --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index b55baaa5673..f6f99f17f64 100644 --- a/.gitignore +++ b/.gitignore @@ -13,6 +13,7 @@ # python stuff *__pycache__* *.pyc +*.egg-info/ *.swp *.swo *.ipynb_checkpoints* From b06a0754a54cdf0361242beabd0c062b197ce822 Mon Sep 17 00:00:00 2001 From: Koen van der Veen Date: Mon, 15 Jun 2026 16:57:57 +0200 Subject: [PATCH 3/3] example --- .../examples/gemma_inference.certificate.json | 20 + .../examples/gemma_inference.obfuscated.py | 414 ++++++++++++++++++ .../examples/gemma_inference.py | 414 ++++++++++++++++++ 3 files changed, 848 insertions(+) create mode 100644 packages/syft-verifuscate/examples/gemma_inference.certificate.json create mode 100644 packages/syft-verifuscate/examples/gemma_inference.obfuscated.py create mode 100644 packages/syft-verifuscate/examples/gemma_inference.py diff --git a/packages/syft-verifuscate/examples/gemma_inference.certificate.json b/packages/syft-verifuscate/examples/gemma_inference.certificate.json new file mode 100644 index 00000000000..7d597973666 --- /dev/null +++ b/packages/syft-verifuscate/examples/gemma_inference.certificate.json @@ -0,0 +1,20 @@ +{ + "source_sha256": "6ce403505cf035b65be57fdebc310fba6357ba930775adf40355f1948be94528", + "policy_id": "ee73d8485185eed3", + "verifuscate_version": "0.1.0", + "private_ranges": [ + [ + 22, + 93 + ], + [ + 99, + 130 + ], + [ + 151, + 292 + ] + ], + "n_calls_checked": 73 +} \ No newline at end of file diff --git a/packages/syft-verifuscate/examples/gemma_inference.obfuscated.py b/packages/syft-verifuscate/examples/gemma_inference.obfuscated.py new file mode 100644 index 00000000000..ad99a7d1903 --- /dev/null +++ b/packages/syft-verifuscate/examples/gemma_inference.obfuscated.py @@ -0,0 +1,414 @@ +"""Gemma 3 IT — Flax Inference Module + +Standalone inference engine for Gemma 3 instruction-tuned models using Flax. +Module hierarchy mirrors google-deepmind/gemma so checkpoint param names map +1:1 to Flax sub-module names. + +Supports: 270m, 1b, 4b, 12b, 27b. + +Adapted from: github.com/anthropics/beach-notebooks/gemma/local_PoC.py +""" + +import os +import time + +import jax +import jax.numpy as jnp +import orbax.checkpoint as ocp +import sentencepiece as spm +from flax import linen as nn + + + + +░v0 = dict( + ░v5=■, + embed_dim=■, + hidden_dim=■, + num_heads=■, + num_kv_heads=■, + head_dim=■, + ░v19=■, + kaggle_handle="■", + ckpt_subdir="■", +) + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +░v1 = ■ +░v2 = ■ +░v3 = ■ +░v4 = -■ + + +def ░fn1(░v5): + ░v6 = ("■",) * ■ + ("■",) + return (░v6 * ((░v5 + ■) // ■))[:░v5] + + +# ── Standalone helpers ──────────────────────────────────────────────────── + + +def ░fn0(░v7, ░v8, ░v9): + "■" + ░v10 = ░v7.shape[-■] // ■ + ░v11 = (■ / ░v7.shape[-■]) * jnp.arange(░v10, dtype=jnp.float32) + ░v12 = ░v9**░v11 + ░v13 = ░v8[..., None, None] / ░v12 + ░v14, ░v15 = jnp.sin(░v13), jnp.cos(░v13) + ░v16, ░v17 = ░v7[..., :░v10], ░v7[..., ░v10:] + return jnp.concatenate([░v16 * ░v15 - ░v17 * ░v14, ░v17 * ░v15 + ░v16 * ░v14], axis=-■) + + +def ░fn2(░v18, ░v19): + "■" + ░v20 = jnp.tril(jnp.ones((░v18, ░v18), dtype=jnp.bool_)) + ░v21 = jnp.triu( + jnp.ones((░v18, ░v18), dtype=jnp.bool_), ░v32=-(░v19 - ■) + ) + return { + "■": (░v20 & ░v21)[None, None], + "■": ░v20[None, None], + } + + +def ░fn3(░v22, ░v19): + "■" + ░v23 = ░v22 + ■ + ░v8 = jnp.arange(░v23) + return { + "■": (░v8 >= ░v22 - ░v19 + ■)[None, None, None, :], + "■": jnp.ones((■, ■, ■, ░v23), dtype=jnp.bool_), + } + + +# ── Flax modules ─────────────────────────────────────────────────────────── + + +def _get(module, name): + """Read a pre-loaded param without shape checking.""" + return module.variable("params", name, lambda: None).value + + +def transpose(a): + """Visible wrapper: matrix transpose (a value transform, not a metadata read).""" + return a.T + + +def append_to(lst, item): + """Visible wrapper: append to a Python list (a named method on a value).""" + lst.append(item) + return lst + + +class ░Cls0(nn.Module): + def setup(self): + self.░a20 = _get(self, "■") + + def __call__(self, ░v24, ░v7): + return jnp.einsum(░v24, ░v7, self.░a20) + + +class ░Cls1(nn.Module): + def setup(self): + self.░a19 = _get(self, "■") + + def __call__(self, ░v7): + ░v25 = jnp.mean(jnp.square(░v7), axis=-■, keepdims=True) + return ░v7 * jax.lax.rsqrt(░v25 + ■) * (■ + self.░a19) + + +class ░Cls2(nn.Module): + ░v26: dict + + def setup(self): + self.░a18 = ░Cls0() + self.░a10 = ░Cls0() + self.░a1 = ░Cls1() + self.░a0 = ░Cls1() + self.░a4 = ░Cls0() + + def __call__(self, ░v7, ░v8, ░v27, ░v28, ░v29=None): + ░v30 = self.░a18("■", ░v7) + ░v31 = self.░a10("■", ░v7) + ░v32, ░v33 = ░v31[■], ░v31[■] + + ░v30 = self.░a1(░v30) + ░v32 = self.░a0(░v32) + + ░v34 = ░v2 if ░v28 == "■" else ░v3 + ░v30 = ░fn0(░v30, ░v8, ░v34) + ░v32 = ░fn0(░v32, ░v8, ░v34) + + if ░v29 is not None: + ░v35, ░v36 = ░v29 + ░v32 = jnp.concatenate([░v35, ░v32], axis=■) + ░v33 = jnp.concatenate([░v36, ░v33], axis=■) + ░v37 = (░v32, ░v33) + + ░v30 = ░v30 * (self.░a5["■"] ** -■) + + ░v38 = self.░a5["■"] // self.░a5["■"] + ░v39 = jnp.repeat(░v32, ░v38, axis=■) + ░v40 = jnp.repeat(░v33, ░v38, axis=■) + + ░v41 = jnp.einsum("■", ░v30, ░v39) + ░v41 = jnp.where(░v27, ░v41, ░v4) + ░v42 = jax.nn.softmax(░v41, axis=-■) + + ░v43 = jnp.einsum("■", ░v42, ░v40) + return self.░a4("■", ░v43), ░v37 + + +class ░Cls3(nn.Module): + def setup(self): + self.░a8 = ░Cls0() + self.░a12 = ░Cls0() + + def __call__(self, ░v7): + ░v44 = self.░a8("■", ░v7) + ░v45 = jax.nn.gelu(░v44[:, :, ■, :]) * ░v44[:, :, ■, :] + return self.░a12("■", ░v45) + + +class ░Cls5(nn.Module): + ░v26: dict + ░v28: str = "■" + + def setup(self): + self.░a16 = ░Cls1() + self.░a2 = ░Cls2(░v26=self.░a5) + self.░a14 = ░Cls1() + self.░a17 = ░Cls1() + self.░a13 = ░Cls3() + self.░a15 = ░Cls1() + + def __call__(self, ░v7, ░v8, ░v27, ░v29=None): + ░v45 = self.░a16(░v7) + ░v45, ░v37 = self.░a2(░v45, ░v8, ░v27, self.░a3, ░v29) + ░v45 = self.░a14(░v45) + ░v7 = ░v7 + ░v45 + ░v45 = self.░a17(░v7) + ░v45 = self.░a13(░v45) + ░v45 = self.░a15(░v45) + return ░v7 + ░v45, ░v37 + + +class ░Cls4(nn.Module): + ░v26: dict + + def setup(self): + self.░a9 = _get(self, "■") + + def __call__(self, ░v46): + ░v47 = self.░a9 + return ░v47[░v46] * jnp.sqrt(float(self.░a5["■"])), ░v47 + + +class Transformer(nn.Module): + ░v26: dict + + def setup(self): + ░v5 = self.░a5["■"] + ░v48 = ░fn1(░v5) + self.░a6 = ░Cls4(░v26=self.░a5) + self.░a11 = [ + ░Cls5(░v26=self.░a5, ░v28=░v48[░v49]) for ░v49 in range(░v5) + ] + self.░a7 = ░Cls1() + + def __call__(self, ░v50, ░v29=None): + ░v19 = self.░a5["■"] + ░v5 = self.░a5["■"] + ░v48 = ░fn1(░v5) + + ░v7, ░v51 = self.░a6(░v50) + + if ░v29 is None: + ░v18 = ░v50.shape[■] + ░v8 = jnp.arange(░v18)[None, :] + ░v52 = ░fn2(░v18, ░v19) + else: + ░v53 = ░v29[■][■].shape[■] + ░v8 = jnp.array([[░v53]]) + ░v52 = ░fn3(░v53, ░v19) + + ░v37 = [] + for ░v49 in range(░v5): + ░v54 = ░v29[░v49] if ░v29 is not None else None + ░v55 = self.░a11[░v49] + ░v7, ░v56 = ░v55(░v7, ░v8, ░v52[░v48[░v49]], ░v54) + ░v37 = append_to(░v37, ░v56) + + ░v7 = self.░a7(░v7) + ░v41 = ░v7 @ transpose(░v51) + return ░v41, ░v37 + + +# ── Weight loading ───────────────────────────────────────────────────────── + + +def nestify(flat): + """Convert Orbax flat dict to nested dict for Flax.""" + nested = {} + for flat_key, param_dict in flat.items(): + parts = flat_key.split("/") + d = nested + for part in parts[:-1]: + d = d.setdefault(part, {}) + d[parts[-1]] = param_dict + return nested + + +def load_params(weights_dir, cfg): + """Load Orbax checkpoint and return Flax-compatible params dict.""" + ckpt_path = os.path.join(weights_dir, cfg["ckpt_subdir"]) + raw = ocp.PyTreeCheckpointer().restore(ckpt_path) + return {"params": nestify(raw)["transformer"]} + + +# ── Setup (convenience entry point) ─────────────────────────────────────── + + +def setup(weights_dir): + """Configure model, load weights and tokenizer. + + Returns (model, tokenizer, params). + """ + params = load_params(weights_dir, CONFIG) + model = Transformer(cfg=CONFIG) + sp = load_tokenizer(weights_dir) + return model, sp, params + + +# ── Tokenizer + generation ───────────────────────────────────────────────── + + +def load_tokenizer(weights_dir): + """Load SentencePiece tokenizer from weights directory.""" + sp = spm.SentencePieceProcessor() + sp.Load(os.path.join(weights_dir, "tokenizer.model")) + return sp + + +def format_chat(prompt): + """Wrap prompt in Gemma's chat template.""" + return f"user\n{prompt}\nmodel\n" + + +def sample_token(logits, temperature=0.8, top_k=40): + """Temperature-scaled top-k sampling. Greedy when temperature=0.""" + if temperature == 0: + return int(jnp.argmax(logits)) + logits = logits / temperature + top_k_logits, top_k_ids = jax.lax.top_k(logits, top_k) + probs = jax.nn.softmax(top_k_logits) + idx = jax.random.categorical( + jax.random.PRNGKey(int(jnp.sum(logits) * 1e6) % 2**31), + jnp.log(probs), + ) + return int(top_k_ids[idx]) + + +def generate(model, params, sp, prompt, max_new_tokens=200, temperature=0.8, top_k=40): + """Autoregressive generation with KV cache and chat template. + + Returns (response_text, stats_dict). + """ + chat_input = format_chat(prompt) + token_ids = [sp.bos_id()] + sp.EncodeAsIds(chat_input) + prompt_tokens = jnp.array([token_ids], dtype=jnp.int32) + prompt_text = sp.Decode(token_ids) + generated_ids = list(token_ids) + + t_prefill = time.time() + logits, cache = model.apply(params, prompt_tokens) + ttft = time.time() - t_prefill + + t_decode = time.time() + decode_tokens = 0 + + for _ in range(max_new_tokens): + next_id = sample_token(logits[0, -1], temperature, top_k) + if next_id == sp.eos_id(): + break + + sp.Decode(generated_ids) + generated_ids.append(next_id) + new_text = sp.Decode(generated_ids) + + response_so_far = new_text[len(prompt_text) :] + if "" in response_so_far: + break + + decode_tokens += 1 + logits, cache = model.apply( + params, jnp.array([[next_id]], dtype=jnp.int32), cache=cache + ) + + decode_elapsed = time.time() - t_decode + decode_tps = decode_tokens / decode_elapsed if decode_elapsed > 0 else 0 + + full = sp.Decode(generated_ids) + response_start = full.find("model\n") + if response_start != -1: + response = full[response_start + len("model\n") :] + response = response.replace("", "").strip() + else: + response = full + + stats = { + "ttft": ttft, + "decode_tps": decode_tps, + "decode_tokens": decode_tokens, + "decode_elapsed": decode_elapsed, + "prompt_tokens": len(token_ids), + } + return response, stats diff --git a/packages/syft-verifuscate/examples/gemma_inference.py b/packages/syft-verifuscate/examples/gemma_inference.py new file mode 100644 index 00000000000..f451dca40c1 --- /dev/null +++ b/packages/syft-verifuscate/examples/gemma_inference.py @@ -0,0 +1,414 @@ +"""Gemma 3 IT — Flax Inference Module + +Standalone inference engine for Gemma 3 instruction-tuned models using Flax. +Module hierarchy mirrors google-deepmind/gemma so checkpoint param names map +1:1 to Flax sub-module names. + +Supports: 270m, 1b, 4b, 12b, 27b. + +Adapted from: github.com/anthropics/beach-notebooks/gemma/local_PoC.py +""" + +import os +import time + +import jax +import jax.numpy as jnp +import orbax.checkpoint as ocp +import sentencepiece as spm +from flax import linen as nn + + +# ── Model config ───────────────────────────────────────────────────────────── +# Active model config. Comment/uncomment to switch sizes. +CONFIG = dict( + num_layers=18, + embed_dim=640, + hidden_dim=2048, + num_heads=4, + num_kv_heads=1, + head_dim=256, + sliding_window=512, + kaggle_handle="google/gemma-3/flax/gemma-3-270m-it", + ckpt_subdir="gemma-3-270m-it", +) + +# CONFIG = dict( # 1b +# num_layers=26, +# embed_dim=1152, +# hidden_dim=6912, +# num_heads=4, +# num_kv_heads=1, +# head_dim=256, +# sliding_window=512, +# kaggle_handle="google/gemma-3/flax/gemma3-1b-it", +# ckpt_subdir="gemma3-1b-it", +# ) + +# CONFIG = dict( # 4b +# num_layers=34, +# embed_dim=2560, +# hidden_dim=10240, +# num_heads=8, +# num_kv_heads=4, +# head_dim=256, +# sliding_window=1024, +# kaggle_handle="google/gemma-3/flax/gemma3-4b-it", +# ckpt_subdir="gemma3-4b-it", +# ) + +# CONFIG = dict( # 12b +# num_layers=48, +# embed_dim=3840, +# hidden_dim=15360, +# num_heads=16, +# num_kv_heads=8, +# head_dim=256, +# sliding_window=1024, +# kaggle_handle="google/gemma-3/flax/gemma3-12b-it", +# ckpt_subdir="gemma3-12b-it", +# ) + +# CONFIG = dict( # 27b +# num_layers=62, +# embed_dim=5376, +# hidden_dim=21504, +# num_heads=32, +# num_kv_heads=16, +# head_dim=128, +# sliding_window=1024, +# kaggle_handle="google/gemma-3/flax/gemma3-27b-it", +# ckpt_subdir="gemma3-27b-it", +# ) + +# ── Shared constants (identical across all Gemma 3 sizes) ───────────────── +VOCAB_SIZE = 262144 +LOCAL_ROPE_BASE = 10_000 +GLOBAL_ROPE_BASE = 1_000_000 +K_MASK = -2.3819763e38 # Google's masking constant (≈ float32 -inf) + + +def _attn_types(num_layers): + pattern = ("local",) * 5 + ("global",) + return (pattern * ((num_layers + 5) // 6))[:num_layers] + + +# ── Standalone helpers ──────────────────────────────────────────────────── + + +def apply_rope(x, positions, base_freq): + """Rotary position embeddings (split-half rotation).""" + half = x.shape[-1] // 2 + freq_exp = (2.0 / x.shape[-1]) * jnp.arange(half, dtype=jnp.float32) + timescale = base_freq**freq_exp + angles = positions[..., None, None] / timescale + sin, cos = jnp.sin(angles), jnp.cos(angles) + x1, x2 = x[..., :half], x[..., half:] + return jnp.concatenate([x1 * cos - x2 * sin, x2 * cos + x1 * sin], axis=-1) + + +def make_masks(seq_len, sliding_window): + """Causal masks — local layers also clip to a sliding window.""" + causal = jnp.tril(jnp.ones((seq_len, seq_len), dtype=jnp.bool_)) + window = jnp.triu( + jnp.ones((seq_len, seq_len), dtype=jnp.bool_), k=-(sliding_window - 1) + ) + return { + "local": (causal & window)[None, None], + "global": causal[None, None], + } + + +def make_decode_masks(pos, sliding_window): + """Masks for single-token decode.""" + total_len = pos + 1 + positions = jnp.arange(total_len) + return { + "local": (positions >= pos - sliding_window + 1)[None, None, None, :], + "global": jnp.ones((1, 1, 1, total_len), dtype=jnp.bool_), + } + + +# ── Flax modules ─────────────────────────────────────────────────────────── + + +def _get(module, name): + """Read a pre-loaded param without shape checking.""" + return module.variable("params", name, lambda: None).value + + +def transpose(a): + """Visible wrapper: matrix transpose (a value transform, not a metadata read).""" + return a.T + + +def append_to(lst, item): + """Visible wrapper: append to a Python list (a named method on a value).""" + lst.append(item) + return lst + + +class Einsum(nn.Module): + def setup(self): + self.w = _get(self, "w") + + def __call__(self, equation, x): + return jnp.einsum(equation, x, self.w) + + +class RMSNorm(nn.Module): + def setup(self): + self.scale = _get(self, "scale") + + def __call__(self, x): + var = jnp.mean(jnp.square(x), axis=-1, keepdims=True) + return x * jax.lax.rsqrt(var + 1e-6) * (1 + self.scale) + + +class Attention(nn.Module): + cfg: dict + + def setup(self): + self.q_einsum = Einsum() + self.kv_einsum = Einsum() + self._query_norm = RMSNorm() + self._key_norm = RMSNorm() + self.attn_vec_einsum = Einsum() + + def __call__(self, x, positions, mask, attn_type, cache=None): + q = self.q_einsum("bsd,ndh->bsnh", x) + kv = self.kv_einsum("bsd,ckdh->cbskh", x) + k, v = kv[0], kv[1] + + q = self._query_norm(q) + k = self._key_norm(k) + + base = LOCAL_ROPE_BASE if attn_type == "local" else GLOBAL_ROPE_BASE + q = apply_rope(q, positions, base) + k = apply_rope(k, positions, base) + + if cache is not None: + cached_k, cached_v = cache + k = jnp.concatenate([cached_k, k], axis=1) + v = jnp.concatenate([cached_v, v], axis=1) + new_cache = (k, v) + + q = q * (self.cfg["head_dim"] ** -0.5) + + repeats = self.cfg["num_heads"] // self.cfg["num_kv_heads"] + k_exp = jnp.repeat(k, repeats, axis=2) + v_exp = jnp.repeat(v, repeats, axis=2) + + logits = jnp.einsum("bsnh,btnh->bnst", q, k_exp) + logits = jnp.where(mask, logits, K_MASK) + weights = jax.nn.softmax(logits, axis=-1) + + out = jnp.einsum("bnst,btnh->bsnh", weights, v_exp) + return self.attn_vec_einsum("bsnh,nhd->bsd", out), new_cache + + +class FeedForward(nn.Module): + def setup(self): + self.gating_einsum = Einsum() + self.linear = Einsum() + + def __call__(self, x): + gate = self.gating_einsum("bsf,nhf->bsnh", x) + h = jax.nn.gelu(gate[:, :, 0, :]) * gate[:, :, 1, :] + return self.linear("bsh,hf->bsf", h) + + +class Block(nn.Module): + cfg: dict + attn_type: str = "local" + + def setup(self): + self.pre_attention_norm = RMSNorm() + self.attn = Attention(cfg=self.cfg) + self.post_attention_norm = RMSNorm() + self.pre_ffw_norm = RMSNorm() + self.mlp = FeedForward() + self.post_ffw_norm = RMSNorm() + + def __call__(self, x, positions, mask, cache=None): + h = self.pre_attention_norm(x) + h, new_cache = self.attn(h, positions, mask, self.attn_type, cache) + h = self.post_attention_norm(h) + x = x + h + h = self.pre_ffw_norm(x) + h = self.mlp(h) + h = self.post_ffw_norm(h) + return x + h, new_cache + + +class Embedder(nn.Module): + cfg: dict + + def setup(self): + self.input_embedding = _get(self, "input_embedding") + + def __call__(self, token_ids): + table = self.input_embedding + return table[token_ids] * jnp.sqrt(float(self.cfg["embed_dim"])), table + + +class Transformer(nn.Module): + cfg: dict + + def setup(self): + num_layers = self.cfg["num_layers"] + attn_types = _attn_types(num_layers) + self.embedder = Embedder(cfg=self.cfg) + self.layer = [ + Block(cfg=self.cfg, attn_type=attn_types[i]) for i in range(num_layers) + ] + self.final_norm = RMSNorm() + + def __call__(self, tokens, cache=None): + sliding_window = self.cfg["sliding_window"] + num_layers = self.cfg["num_layers"] + attn_types = _attn_types(num_layers) + + x, embed_table = self.embedder(tokens) + + if cache is None: + seq_len = tokens.shape[1] + positions = jnp.arange(seq_len)[None, :] + masks = make_masks(seq_len, sliding_window) + else: + cache_len = cache[0][0].shape[1] + positions = jnp.array([[cache_len]]) + masks = make_decode_masks(cache_len, sliding_window) + + new_cache = [] + for i in range(num_layers): + layer_cache = cache[i] if cache is not None else None + block = self.layer[i] + x, layer_new_cache = block(x, positions, masks[attn_types[i]], layer_cache) + new_cache = append_to(new_cache, layer_new_cache) + + x = self.final_norm(x) + logits = x @ transpose(embed_table) + return logits, new_cache + + +# ── Weight loading ───────────────────────────────────────────────────────── + + +def nestify(flat): + """Convert Orbax flat dict to nested dict for Flax.""" + nested = {} + for flat_key, param_dict in flat.items(): + parts = flat_key.split("/") + d = nested + for part in parts[:-1]: + d = d.setdefault(part, {}) + d[parts[-1]] = param_dict + return nested + + +def load_params(weights_dir, cfg): + """Load Orbax checkpoint and return Flax-compatible params dict.""" + ckpt_path = os.path.join(weights_dir, cfg["ckpt_subdir"]) + raw = ocp.PyTreeCheckpointer().restore(ckpt_path) + return {"params": nestify(raw)["transformer"]} + + +# ── Setup (convenience entry point) ─────────────────────────────────────── + + +def setup(weights_dir): + """Configure model, load weights and tokenizer. + + Returns (model, tokenizer, params). + """ + params = load_params(weights_dir, CONFIG) + model = Transformer(cfg=CONFIG) + sp = load_tokenizer(weights_dir) + return model, sp, params + + +# ── Tokenizer + generation ───────────────────────────────────────────────── + + +def load_tokenizer(weights_dir): + """Load SentencePiece tokenizer from weights directory.""" + sp = spm.SentencePieceProcessor() + sp.Load(os.path.join(weights_dir, "tokenizer.model")) + return sp + + +def format_chat(prompt): + """Wrap prompt in Gemma's chat template.""" + return f"user\n{prompt}\nmodel\n" + + +def sample_token(logits, temperature=0.8, top_k=40): + """Temperature-scaled top-k sampling. Greedy when temperature=0.""" + if temperature == 0: + return int(jnp.argmax(logits)) + logits = logits / temperature + top_k_logits, top_k_ids = jax.lax.top_k(logits, top_k) + probs = jax.nn.softmax(top_k_logits) + idx = jax.random.categorical( + jax.random.PRNGKey(int(jnp.sum(logits) * 1e6) % 2**31), + jnp.log(probs), + ) + return int(top_k_ids[idx]) + + +def generate(model, params, sp, prompt, max_new_tokens=200, temperature=0.8, top_k=40): + """Autoregressive generation with KV cache and chat template. + + Returns (response_text, stats_dict). + """ + chat_input = format_chat(prompt) + token_ids = [sp.bos_id()] + sp.EncodeAsIds(chat_input) + prompt_tokens = jnp.array([token_ids], dtype=jnp.int32) + prompt_text = sp.Decode(token_ids) + generated_ids = list(token_ids) + + t_prefill = time.time() + logits, cache = model.apply(params, prompt_tokens) + ttft = time.time() - t_prefill + + t_decode = time.time() + decode_tokens = 0 + + for _ in range(max_new_tokens): + next_id = sample_token(logits[0, -1], temperature, top_k) + if next_id == sp.eos_id(): + break + + sp.Decode(generated_ids) + generated_ids.append(next_id) + new_text = sp.Decode(generated_ids) + + response_so_far = new_text[len(prompt_text) :] + if "" in response_so_far: + break + + decode_tokens += 1 + logits, cache = model.apply( + params, jnp.array([[next_id]], dtype=jnp.int32), cache=cache + ) + + decode_elapsed = time.time() - t_decode + decode_tps = decode_tokens / decode_elapsed if decode_elapsed > 0 else 0 + + full = sp.Decode(generated_ids) + response_start = full.find("model\n") + if response_start != -1: + response = full[response_start + len("model\n") :] + response = response.replace("", "").strip() + else: + response = full + + stats = { + "ttft": ttft, + "decode_tps": decode_tps, + "decode_tokens": decode_tokens, + "decode_elapsed": decode_elapsed, + "prompt_tokens": len(token_ids), + } + return response, stats