From d7cba8f217590fe179a087a8c2eb14e244909ede Mon Sep 17 00:00:00 2001 From: Andrei Date: Sat, 9 May 2026 02:10:10 +0300 Subject: [PATCH 01/11] feat: auto-load HF ONNX artifacts on CPU --- AGENTS.md | 9 + .../test_text_classifier_inference_api.py | 27 + .../text_classifier_inference_api.py | 10 + .../default_inference/nlp/th_hf_model_base.py | 622 +++++++++++++++++- extensions/serving/test_th_hf_model_base.py | 232 ++++++- 5 files changed, 885 insertions(+), 15 deletions(-) diff --git a/AGENTS.md b/AGENTS.md index ca6dcbb57..dccc2061c 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -686,3 +686,12 @@ Entry format: - Details: Added repo purpose/runtime constraints, ownership table, safe-edit boundaries, required verification matrix, role-based agent cards, A2A-style task contract, mandatory handoff envelope, single-agent loop, actor-critic workflow, reusable lessons-learned section, worked examples, and explicit AGENTS review triggers. Critic concerns addressed in the update: keep single-agent as default to avoid unnecessary delegation, require executable evidence for actor-vs-critic disputes, and keep memory logging critical-only instead of turning the file into an activity log. - Verification: `sed -n '1,260p' AGENTS.md`; `rg -n "Module And File Ownership|Safe-Edit Boundaries|Required Verification Commands|Agent Cards|A2A-Style Task Contract|Actor-Critic|AGENTS Review Triggers|ML-20260317-001" AGENTS.md` - Links: `AGENTS.md` + +- ID: `ML-20260508-001` +- Timestamp: `2026-05-08T23:08:05Z` +- Type: `change` +- Summary: Shared HF text serving now auto-selects ONNX artifacts for CPU-only runtime when the HF repo declares a compatible runtime manifest. +- Criticality: Serving runtime architecture change affecting model loading, artifact downloads, output decoding, and API metadata across generic text-classification deployments. +- Details: `ThHfModelBase` keeps Transformers/PT as the default GPU and fallback path, but CPU-only `HF_RUNTIME=auto` now loads `artifact_manifest.json`, selects a declared ONNX Runtime artifact, downloads only safe allow-patterns, loads schema and contract decoder from HF artifacts, and exposes the decoded artifact contract through the existing text-classifier flow. Business API response shaping now passes through generic model/runtime metadata emitted by serving. +- Verification: `python3 -m unittest extensions.serving.test_th_hf_model_base extensions.serving.test_th_text_classifier extensions.serving.test_th_privacy_filter extensions.business.edge_inference_api.test_text_classifier_inference_api extensions.business.edge_inference_api.test_privacy_filter_inference_api`; `python3 -m py_compile extensions/serving/default_inference/nlp/th_hf_model_base.py extensions/business/edge_inference_api/text_classifier_inference_api.py`; required serving gate `python3 -m unittest extensions.serving.model_testing.test_llm_servings` currently fails at import with `ImportError: cannot import name 'Logger' from 'naeural_core'`. +- Links: `extensions/serving/default_inference/nlp/th_hf_model_base.py`, `extensions/business/edge_inference_api/text_classifier_inference_api.py`, `extensions/serving/test_th_hf_model_base.py` diff --git a/extensions/business/edge_inference_api/test_text_classifier_inference_api.py b/extensions/business/edge_inference_api/test_text_classifier_inference_api.py index 6e3c8085a..aa945ea3f 100644 --- a/extensions/business/edge_inference_api/test_text_classifier_inference_api.py +++ b/extensions/business/edge_inference_api/test_text_classifier_inference_api.py @@ -152,6 +152,33 @@ def test_build_result_from_inference_preserves_classifier_output(self): self.assertEqual(result_payload["model_name"], "openai/privacy-filter") self.assertEqual(result_payload["pipeline_task"], "token-classification") + def test_build_result_from_inference_preserves_runtime_model_metadata(self): + plugin = TextClassifierInferenceApiPlugin() + + result_payload = plugin._build_result_from_inference( # pylint: disable=protected-access + request_id="req-onnx", + inference={ + "REQUEST_ID": "req-onnx", + "TEXT": "example text", + "result": {"prediction": "safe"}, + "MODEL": {"key": "generic_text_classifier", "version": "2026.05.09"}, + "MODEL_VERSION": "2026.05.09", + "HF_RUNTIME": "onnx_fp32", + "RUNTIME": "onnxruntime", + }, + metadata={}, + request_data={"metadata": {}, "parameters": {"text": "example text"}}, + ) + + self.assertEqual(result_payload["classification"], {"prediction": "safe"}) + self.assertEqual( + result_payload["model"], + {"key": "generic_text_classifier", "version": "2026.05.09"}, + ) + self.assertEqual(result_payload["model_version"], "2026.05.09") + self.assertEqual(result_payload["hf_runtime"], "onnx_fp32") + self.assertEqual(result_payload["runtime"], "onnxruntime") + def test_handle_inferences_falls_back_to_payload_request_id(self): plugin = TextClassifierInferenceApiPlugin() plugin._requests = {"req-1": {"status": "pending"}} # pylint: disable=protected-access diff --git a/extensions/business/edge_inference_api/text_classifier_inference_api.py b/extensions/business/edge_inference_api/text_classifier_inference_api.py index 867f1ff4b..2de5aa1ae 100644 --- a/extensions/business/edge_inference_api/text_classifier_inference_api.py +++ b/extensions/business/edge_inference_api/text_classifier_inference_api.py @@ -405,6 +405,16 @@ def _build_result_from_inference( result_payload["tokenizer_name"] = inference["TOKENIZER_NAME"] if "PIPELINE_TASK" in inference: result_payload["pipeline_task"] = inference["PIPELINE_TASK"] + if "MODEL" in inference: + result_payload["model"] = inference["MODEL"] + if "MODEL_VERSION" in inference: + result_payload["model_version"] = inference["MODEL_VERSION"] + if "MODEL_REVISION" in inference: + result_payload["model_revision"] = inference["MODEL_REVISION"] + if "HF_RUNTIME" in inference: + result_payload["hf_runtime"] = inference["HF_RUNTIME"] + if "RUNTIME" in inference: + result_payload["runtime"] = inference["RUNTIME"] return result_payload def handle_inference_for_request( diff --git a/extensions/serving/default_inference/nlp/th_hf_model_base.py b/extensions/serving/default_inference/nlp/th_hf_model_base.py index 436febf0a..db9350254 100644 --- a/extensions/serving/default_inference/nlp/th_hf_model_base.py +++ b/extensions/serving/default_inference/nlp/th_hf_model_base.py @@ -6,6 +6,11 @@ input/output handling. """ +import importlib.util +import inspect +import json +from pathlib import Path + import torch as th from transformers import BitsAndBytesConfig, pipeline as hf_pipeline @@ -24,6 +29,11 @@ "MODEL_NAME": None, "TOKENIZER_NAME": None, "PIPELINE_TASK": None, + "MODEL_REVISION": None, + "HF_RUNTIME": "auto", + "HF_ARTIFACT_MANIFEST": "artifact_manifest.json", + "HF_ONNX_RUNTIME_KEY": "onnx_fp32", + "HF_ONNX_ALLOW_PATTERNS": None, "TEXT_KEYS": ["text", "email_text", "content", "request", "body"], "REQUEST_ID_KEYS": ["request_id", "REQUEST_ID"], "MAX_LENGTH": 512, @@ -44,6 +54,164 @@ } +class HfOnnxArtifactPipeline: + """Callable adapter that exposes an ONNX artifact as a pipeline-like object.""" + + def __init__( + self, + repo_id, + runtime_key, + runtime_config, + tokenizer, + session, + schema, + decoder, + task=None, + max_length=None, + ): + self.repo_id = repo_id + self.runtime_key = runtime_key + self.runtime_config = runtime_config or {} + self.tokenizer = tokenizer + self.session = session + self.schema = schema or {} + self.decoder = decoder + self.task = task + self.framework = "onnxruntime" + self.max_length = max_length + return + + def __call__(self, texts, **kwargs): + """Run one or more text inputs through the ONNX artifact.""" + is_single_text = isinstance(texts, str) + text_items = [texts] if is_single_text else list(texts or []) + results = [ + self._run_single_text(text=text, inference_kwargs=kwargs) + for text in text_items + ] + return results[0] if is_single_text or len(results) == 1 else results + + def _get_max_length(self, inference_kwargs): + max_length = inference_kwargs.get("max_length") + if max_length is not None: + return max_length + if self.max_length is not None: + return self.max_length + schema_max_length = self.schema.get("max_length") + return schema_max_length if schema_max_length is not None else None + + def _tokenize(self, text, inference_kwargs): + tokenize_kwargs = { + "return_tensors": "np", + "truncation": bool(inference_kwargs.get("truncation", True)), + } + max_length = self._get_max_length(inference_kwargs) + if max_length is not None: + tokenize_kwargs["max_length"] = max_length + if "padding" in inference_kwargs: + tokenize_kwargs["padding"] = inference_kwargs["padding"] + return self.tokenizer(text, **tokenize_kwargs) + + def _input_specs(self): + inputs = self.schema.get("inputs") + if isinstance(inputs, list): + return inputs + if isinstance(inputs, dict): + return [ + {"name": name, **(spec if isinstance(spec, dict) else {})} + for name, spec in inputs.items() + ] + return [ + {"name": "input_ids", "dtype": "int64"}, + {"name": "attention_mask", "dtype": "int64"}, + ] + + def _output_names(self): + output_names = self.runtime_config.get("output_names") + if isinstance(output_names, list) and output_names: + return output_names + output_order = self.schema.get("output_order") + if isinstance(output_order, list) and output_order: + return output_order + outputs = self.schema.get("outputs") + if isinstance(outputs, list): + names = [] + for output in outputs: + if isinstance(output, dict) and output.get("name"): + names.append(output["name"]) + elif isinstance(output, str): + names.append(output) + if names: + return names + if hasattr(self.session, "get_outputs"): + session_output_names = [ + output.name for output in self.session.get_outputs() + if getattr(output, "name", None) + ] + if session_output_names: + return session_output_names + return None + + def _prepare_session_inputs(self, encoded): + session_inputs = {} + for input_spec in self._input_specs(): + if isinstance(input_spec, dict): + input_name = input_spec.get("name") + dtype = input_spec.get("dtype") + else: + input_name = str(input_spec) + dtype = None + if not input_name or input_name not in encoded: + continue + value = encoded[input_name] + if dtype is not None and hasattr(value, "astype"): + value = value.astype(dtype) + session_inputs[input_name] = value + if not session_inputs and hasattr(encoded, "items"): + session_inputs = dict(encoded.items()) + return session_inputs + + def _build_output_map(self, raw_outputs, output_names): + if output_names is None: + output_names = [f"output_{idx}" for idx in range(len(raw_outputs))] + return { + output_name: output_value + for output_name, output_value in zip(output_names, raw_outputs) + } + + def _call_decoder(self, outputs_by_name, text): + if self.decoder is None: + return outputs_by_name + decoder_kwargs = { + "runtime": self.runtime_key, + "runtime_key": self.runtime_key, + "text": text, + "repo_id": self.repo_id, + } + try: + signature = inspect.signature(self.decoder) + accepts_var_kwargs = any( + param.kind == inspect.Parameter.VAR_KEYWORD + for param in signature.parameters.values() + ) + if not accepts_var_kwargs: + decoder_kwargs = { + key: value for key, value in decoder_kwargs.items() + if key in signature.parameters + } + except (TypeError, ValueError): + pass + return self.decoder(outputs_by_name, self.schema, **decoder_kwargs) + + def _run_single_text(self, text, inference_kwargs): + encoded = self._tokenize(text=text, inference_kwargs=inference_kwargs) + session_inputs = self._prepare_session_inputs(encoded) + output_names = self._output_names() + raw_outputs = self.session.run(output_names, session_inputs) + outputs_by_name = self._build_output_map(raw_outputs, output_names) + return self._call_decoder(outputs_by_name=outputs_by_name, text=text) + + class ThHfModelBase(BaseServingProcess): CONFIG = _CONFIG @@ -57,6 +225,9 @@ def __init__(self, **kwargs): """ self.classifier = None self.device = None + self.hf_runtime = "pt" + self.hf_runtime_config = {} + self.hf_artifact_manifest = None super(ThHfModelBase, self).__init__(**kwargs) return @@ -102,6 +273,16 @@ def get_pipeline_task(self): """ return self.cfg_pipeline_task + def get_model_revision(self): + """Return the optional Hugging Face model revision. + + Returns + ------- + str or None + Configured `MODEL_REVISION`, or `None` when unset. + """ + return getattr(self, "cfg_model_revision", None) + @property def cache_dir(self): """Return the local cache directory for Hugging Face artifacts. @@ -255,6 +436,346 @@ def _get_model_load_config(self): cache_dir=self.cache_dir, ) + def _requested_hf_runtime(self): + """Return the normalized requested HF runtime selector.""" + requested = getattr(self, "cfg_hf_runtime", "auto") + if requested is None: + return "auto" + requested = str(requested).strip().lower() + if requested in {"", "auto"}: + return "auto" + if requested in {"pt", "torch", "pytorch", "transformers"}: + return "pt" + if requested == "onnx": + return "onnx" + return requested + + def _should_load_hf_artifact_manifest(self, requested_runtime): + """Return whether startup needs the HF artifact manifest.""" + if requested_runtime == "pt": + return False + if requested_runtime == "auto": + return self.device == -1 + return True + + def _download_hf_artifact_file(self, filename): + """Download one HF artifact file and return its local path.""" + from huggingface_hub import hf_hub_download + + return hf_hub_download( + repo_id=self.get_model_name(), + filename=filename, + revision=self.get_model_revision(), + token=self.hf_token, + cache_dir=self.cache_dir, + repo_type="model", + ) + + def _load_hf_artifact_manifest(self): + """Load the optional artifact manifest from the configured HF model repo.""" + manifest_name = getattr(self, "cfg_hf_artifact_manifest", None) + if not manifest_name: + return None + try: + manifest_path = self._download_hf_artifact_file(manifest_name) + return json.loads(Path(manifest_path).read_text(encoding="utf-8")) + except Exception as exc: + if self._requested_hf_runtime() != "auto": + raise + self.P( + f"HF artifact manifest {manifest_name} not available for {self.get_model_name()}: {exc}", + color="y", + ) + return None + + def _get_hf_manifest_runtimes(self, manifest): + """Extract runtime definitions from an artifact manifest.""" + if not isinstance(manifest, dict): + return {} + runtimes = manifest.get("runtimes") + return runtimes if isinstance(runtimes, dict) else {} + + def _runtime_is_onnx(self, runtime_key, runtime_config): + """Return whether a manifest runtime is backed by ONNX Runtime.""" + runtime_config = runtime_config or {} + runtime_name = str(runtime_config.get("runtime", "")).lower() + entrypoint = str(runtime_config.get("entrypoint", "")).lower() + runtime_key = str(runtime_key or "").lower() + return ( + "onnxruntime" in runtime_name + or "onnxruntime" in entrypoint + or runtime_key.startswith("onnx") + ) + + def _resolve_hf_onnx_runtime_key(self, runtimes): + """Find the preferred ONNX runtime key from manifest runtimes.""" + preferred = getattr(self, "cfg_hf_onnx_runtime_key", None) + if preferred in runtimes and self._runtime_is_onnx(preferred, runtimes[preferred]): + return preferred + for runtime_key, runtime_config in runtimes.items(): + if self._runtime_is_onnx(runtime_key, runtime_config): + return runtime_key + return None + + def _select_hf_runtime(self, manifest): + """Select the runtime to load for this startup.""" + requested_runtime = self._requested_hf_runtime() + runtimes = self._get_hf_manifest_runtimes(manifest) + if requested_runtime == "pt": + return "pt", runtimes.get("pt", {}) + if requested_runtime == "auto": + if self.device == -1: + runtime_key = self._resolve_hf_onnx_runtime_key(runtimes) + if runtime_key is not None: + return runtime_key, runtimes[runtime_key] + return "pt", runtimes.get("pt", {}) + if requested_runtime in runtimes: + return requested_runtime, runtimes[requested_runtime] + if requested_runtime == "onnx": + runtime_key = self._resolve_hf_onnx_runtime_key(runtimes) + if runtime_key is not None: + return runtime_key, runtimes[runtime_key] + manifest_name = getattr(self, "cfg_hf_artifact_manifest", "artifact_manifest.json") + raise ValueError( + f"HF runtime {requested_runtime!r} is not declared in {manifest_name!r} for {self.get_model_name()}." + ) + + def _blocked_hf_weight_pattern(self, pattern): + """Return whether a download pattern could pull framework weight files.""" + pattern = str(pattern) + blocked_suffixes = ( + ".safetensors", + "pytorch_model.bin", + "tf_model.h5", + "flax_model.msgpack", + ) + blocked_wildcards = ("*.safetensors", "*.bin", "*.h5", "*.msgpack") + return pattern.endswith(blocked_suffixes) or pattern in blocked_wildcards + + def _build_hf_runtime_allow_patterns(self, runtime_config): + """Build safe HF snapshot allow-patterns for an ONNX runtime.""" + configured_patterns = getattr(self, "cfg_hf_onnx_allow_patterns", None) + if configured_patterns: + patterns = configured_patterns + else: + patterns = runtime_config.get("recommended_allow_patterns") or runtime_config.get("files") + if not patterns: + model_file = runtime_config.get("model") + patterns = [ + model_file, + "*.onnx", + "**/*.onnx", + "onnx/*", + "onnx/**", + "*.json", + "*.py", + "*.txt", + "*.model", + "*.tiktoken", + ] + if isinstance(patterns, str): + patterns = [patterns] + safe_patterns = [] + for pattern in patterns or []: + if not pattern or self._blocked_hf_weight_pattern(pattern): + continue + if pattern not in safe_patterns: + safe_patterns.append(pattern) + if not safe_patterns: + raise ValueError("HF ONNX runtime download has no safe allow patterns.") + return safe_patterns + + def _download_hf_runtime_snapshot(self, runtime_key, runtime_config, allow_patterns): + """Download the minimal HF snapshot needed for a selected runtime.""" + from huggingface_hub import snapshot_download + + self.P( + f"Downloading HF runtime {runtime_key} artifacts for {self.get_model_name()}...", + color="y", + ) + return snapshot_download( + repo_id=self.get_model_name(), + revision=self.get_model_revision(), + token=self.hf_token, + cache_dir=self.cache_dir, + allow_patterns=allow_patterns, + repo_type="model", + ) + + def _runtime_file_list(self, runtime_config): + files = runtime_config.get("files") if isinstance(runtime_config, dict) else None + return files if isinstance(files, list) else [] + + def _first_manifest_file_with_suffix(self, runtime_config, suffixes): + """Return the first exact manifest file path ending with any suffix.""" + for file_path in self._runtime_file_list(runtime_config): + file_path = str(file_path) + if any(file_path.endswith(suffix) for suffix in suffixes): + return file_path + return None + + def _resolve_manifest_file_path(self, model_dir, manifest, runtime_config, keys, suffixes): + """Resolve a model-repo file path declared directly or inferred by suffix.""" + for key in keys: + value = runtime_config.get(key) if isinstance(runtime_config, dict) else None + if value is None and isinstance(manifest, dict): + value = manifest.get(key) + if value: + return Path(model_dir) / str(value) + inferred = self._first_manifest_file_with_suffix(runtime_config, suffixes) + if inferred: + return Path(model_dir) / inferred + return None + + def _load_hf_schema(self, model_dir, manifest, runtime_config): + """Load the JSON schema declared by the selected HF runtime.""" + schema_path = self._resolve_manifest_file_path( + model_dir=model_dir, + manifest=manifest, + runtime_config=runtime_config, + keys=("schema", "schema_file", "contract_schema"), + suffixes=("_schema.json", "schema.json"), + ) + if schema_path is None or not schema_path.exists(): + raise ValueError(f"HF runtime {self.hf_runtime} does not declare a usable schema file.") + return json.loads(schema_path.read_text(encoding="utf-8")) + + def _load_hf_contract_decoder(self, model_dir, manifest, runtime_config): + """Load the artifact decoder function declared by the selected HF runtime.""" + decoder_path = self._resolve_manifest_file_path( + model_dir=model_dir, + manifest=manifest, + runtime_config=runtime_config, + keys=("decoder", "decoder_file", "contract", "contract_file"), + suffixes=("_contract.py", "contract.py"), + ) + if decoder_path is None or not decoder_path.exists(): + raise ValueError(f"HF runtime {self.hf_runtime} does not declare a usable contract decoder.") + module_name = f"hf_artifact_contract_{abs(hash(str(decoder_path)))}" + spec = importlib.util.spec_from_file_location(module_name, decoder_path) + if spec is None or spec.loader is None: + raise ValueError(f"Could not load HF contract decoder from {decoder_path}.") + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + + function_name = None + if isinstance(runtime_config, dict): + function_name = runtime_config.get("decoder_function") + if function_name is None and isinstance(manifest, dict): + function_name = manifest.get("decoder_function") + if function_name is None and callable(getattr(module, "decode_outputs", None)): + function_name = "decode_outputs" + if function_name is None: + decode_functions = [ + name for name in dir(module) + if name.startswith("decode_") + and name.endswith("_outputs") + and callable(getattr(module, name, None)) + ] + if len(decode_functions) == 1: + function_name = decode_functions[0] + decoder = getattr(module, function_name, None) if function_name else None + if not callable(decoder): + raise ValueError(f"Could not resolve a decoder function in {decoder_path}.") + return decoder + + def _resolve_hf_onnx_model_path(self, model_dir, runtime_key, runtime_config, schema): + """Resolve the ONNX model file for the selected runtime.""" + for key in ("model", "model_file", "path"): + value = runtime_config.get(key) if isinstance(runtime_config, dict) else None + if value: + return Path(model_dir) / str(value) + models = schema.get("models") if isinstance(schema, dict) else None + if isinstance(models, dict): + candidates = [ + runtime_key, + str(runtime_key).replace("_", "-"), + str(runtime_key).replace("-", "_"), + ] + for candidate in candidates: + value = models.get(candidate) + if value: + if isinstance(value, dict): + value = value.get("path") or value.get("file") or value.get("model") + if not value: + continue + return Path(model_dir) / str(value) + model_file = self._first_manifest_file_with_suffix(runtime_config, (".onnx",)) + if model_file: + return Path(model_dir) / model_file + raise ValueError(f"HF runtime {runtime_key} does not declare an ONNX model file.") + + def _resolve_hf_tokenizer_dir(self, model_dir, manifest, runtime_config, schema): + """Resolve tokenizer directory for the selected artifact runtime.""" + tokenizer_dir = None + for source in (runtime_config, schema, manifest): + if isinstance(source, dict) and source.get("tokenizer_dir"): + tokenizer_dir = source["tokenizer_dir"] + break + return Path(model_dir) / str(tokenizer_dir or ".") + + def _load_hf_onnx_tokenizer(self, model_dir, runtime_config): + """Load the tokenizer for an ONNX HF artifact.""" + from transformers import AutoTokenizer + + return AutoTokenizer.from_pretrained( + str(model_dir), + token=self.hf_token, + trust_remote_code=bool(runtime_config.get("trust_remote_code", False)), + ) + + def _create_hf_onnx_session(self, model_path, providers): + """Create an ONNX Runtime inference session.""" + import onnxruntime as ort + + return ort.InferenceSession(str(model_path), providers=providers) + + def _build_hf_onnx_artifact_pipeline(self, model_dir, runtime_key, runtime_config, manifest): + """Build a callable ONNX artifact pipeline from downloaded HF files.""" + schema = self._load_hf_schema( + model_dir=model_dir, + manifest=manifest, + runtime_config=runtime_config, + ) + decoder = self._load_hf_contract_decoder( + model_dir=model_dir, + manifest=manifest, + runtime_config=runtime_config, + ) + tokenizer_dir = self._resolve_hf_tokenizer_dir( + model_dir=model_dir, + manifest=manifest, + runtime_config=runtime_config, + schema=schema, + ) + tokenizer = self._load_hf_onnx_tokenizer( + model_dir=tokenizer_dir, + runtime_config=runtime_config, + ) + model_path = self._resolve_hf_onnx_model_path( + model_dir=model_dir, + runtime_key=runtime_key, + runtime_config=runtime_config, + schema=schema, + ) + provider = runtime_config.get("provider") or "CPUExecutionProvider" + providers = runtime_config.get("providers") or [provider] + session = self._create_hf_onnx_session( + model_path=model_path, + providers=providers, + ) + return HfOnnxArtifactPipeline( + repo_id=self.get_model_name(), + runtime_key=runtime_key, + runtime_config=runtime_config, + tokenizer=tokenizer, + session=session, + schema=schema, + decoder=decoder, + task=runtime_config.get("pipeline_task") or manifest.get("pipeline_task") or self.get_pipeline_task(), + max_length=self.cfg_max_length, + ) + def _normalize_pipeline_runtime_contract(self): """Patch known gaps in custom remote-code pipeline initialization. @@ -299,19 +820,9 @@ def _run_startup_warmup(self): ) return - def startup(self): - """Load the Hugging Face pipeline and prepare it for inference. - - Raises - ------ - ValueError - If `MODEL_NAME` is not configured. - """ + def _startup_transformers_pipeline(self): + """Load the standard Transformers pipeline runtime.""" model_name = self.get_model_name() - if not model_name: - raise ValueError(f"{self.__class__.__name__} serving requires MODEL_NAME.") - - self.device = self._resolve_pipeline_device() model_load_params, quantization_params = self._get_model_load_config() pipeline_kwargs = self.build_pipeline_kwargs() model_kwargs = { @@ -334,12 +845,84 @@ def startup(self): trust_remote_code=bool(self.cfg_trust_remote_code), device=self.device, model_kwargs=model_kwargs, + revision=self.get_model_revision(), **pipeline_kwargs, ) self._normalize_pipeline_runtime_contract() + return + + def _startup_hf_onnx_artifact(self, runtime_key, runtime_config, manifest): + """Load the selected ONNX artifact runtime from the HF repository.""" + allow_patterns = self._build_hf_runtime_allow_patterns(runtime_config) + model_dir = self._download_hf_runtime_snapshot( + runtime_key=runtime_key, + runtime_config=runtime_config, + allow_patterns=allow_patterns, + ) + self.classifier = self._build_hf_onnx_artifact_pipeline( + model_dir=model_dir, + runtime_key=runtime_key, + runtime_config=runtime_config, + manifest=manifest or {}, + ) + return + + def startup(self): + """Load the Hugging Face runtime and prepare it for inference. + + Raises + ------ + ValueError + If `MODEL_NAME` is not configured. + """ + model_name = self.get_model_name() + if not model_name: + raise ValueError(f"{self.__class__.__name__} serving requires MODEL_NAME.") + + self.device = self._resolve_pipeline_device() + requested_runtime = self._requested_hf_runtime() + manifest = None + if self._should_load_hf_artifact_manifest(requested_runtime=requested_runtime): + manifest = self._load_hf_artifact_manifest() + runtime_key, runtime_config = self._select_hf_runtime(manifest=manifest) + self.hf_runtime = runtime_key + self.hf_runtime_config = dict(runtime_config or {}) + self.hf_artifact_manifest = manifest if isinstance(manifest, dict) else None + + if self._runtime_is_onnx(runtime_key=runtime_key, runtime_config=runtime_config): + self._startup_hf_onnx_artifact( + runtime_key=runtime_key, + runtime_config=runtime_config, + manifest=manifest, + ) + else: + self._startup_transformers_pipeline() self._run_startup_warmup() return + def _get_hf_artifact_model_metadata(self): + """Return model metadata declared by the loaded artifact.""" + metadata = {} + has_artifact_metadata = False + for source in (self.hf_artifact_manifest, getattr(self.classifier, "schema", None)): + if not isinstance(source, dict): + continue + for key in ( + "repo_id", + "repo_key", + "model_key", + "model_version", + "release_channel", + "release_alias_of", + "source_repo_id", + ): + if key not in metadata and source.get(key) is not None: + metadata[key] = source[key] + has_artifact_metadata = True + if has_artifact_metadata and self.hf_runtime: + metadata["runtime"] = self.hf_runtime + return metadata + def get_additional_metadata(self): """Return model metadata attached to decoded predictions. @@ -349,11 +932,24 @@ def get_additional_metadata(self): Model name, tokenizer name, and pipeline task metadata. """ pipeline_task = getattr(self.classifier, "task", None) if self.classifier is not None else None - return { + metadata = { "MODEL_NAME": self.get_model_name(), "TOKENIZER_NAME": self.get_tokenizer_name(), "PIPELINE_TASK": pipeline_task or self.get_pipeline_task(), + "HF_RUNTIME": self.hf_runtime, + "RUNTIME": self.hf_runtime_config.get("runtime") or ( + "onnxruntime" if self._runtime_is_onnx(self.hf_runtime, self.hf_runtime_config) else "transformers" + ), } + model_revision = self.get_model_revision() + if model_revision is not None: + metadata["MODEL_REVISION"] = model_revision + artifact_model_metadata = self._get_hf_artifact_model_metadata() + if artifact_model_metadata: + metadata["MODEL"] = artifact_model_metadata + if artifact_model_metadata.get("model_version") is not None: + metadata["MODEL_VERSION"] = artifact_model_metadata["model_version"] + return metadata def _extract_serving_target(self, struct_payload): """Extract the reserved serving-target metadata from a payload. diff --git a/extensions/serving/test_th_hf_model_base.py b/extensions/serving/test_th_hf_model_base.py index e69c4c0e4..6692bc8e8 100644 --- a/extensions/serving/test_th_hf_model_base.py +++ b/extensions/serving/test_th_hf_model_base.py @@ -2,6 +2,7 @@ import unittest from pathlib import Path +from tempfile import TemporaryDirectory ROOT = Path(__file__).resolve().parents[2] @@ -14,6 +15,11 @@ def __init__(self, **kwargs): self.cfg_model_name = kwargs.get("MODEL_NAME") self.cfg_tokenizer_name = kwargs.get("TOKENIZER_NAME") self.cfg_pipeline_task = kwargs.get("PIPELINE_TASK") + self.cfg_model_revision = kwargs.get("MODEL_REVISION") + self.cfg_hf_runtime = kwargs.get("HF_RUNTIME", "auto") + self.cfg_hf_artifact_manifest = kwargs.get("HF_ARTIFACT_MANIFEST", "artifact_manifest.json") + self.cfg_hf_onnx_runtime_key = kwargs.get("HF_ONNX_RUNTIME_KEY", "onnx_fp32") + self.cfg_hf_onnx_allow_patterns = kwargs.get("HF_ONNX_ALLOW_PATTERNS") self.cfg_max_length = kwargs.get("MAX_LENGTH", 512) self.cfg_model_weights_size = kwargs.get("MODEL_WEIGHTS_SIZE") self.cfg_hf_token = kwargs.get("HF_TOKEN") @@ -77,6 +83,7 @@ def __init__(self, **kwargs): class _FakePipeline: def __init__(self, task=None): self.task = task + self.framework = "pt" self.inference_calls = [] def __call__(self, text, **kwargs): @@ -104,6 +111,37 @@ def is_available(): return True +class _FakeEncodedValue: + def __init__(self, value): + self.value = value + self.dtype = None + + def astype(self, dtype): + self.dtype = dtype + return self + + +class _FakeTokenizer: + def __init__(self): + self.calls = [] + + def __call__(self, text, **kwargs): + self.calls.append((text, kwargs)) + return { + "input_ids": _FakeEncodedValue([1, 2, 3]), + "attention_mask": _FakeEncodedValue([1, 1, 1]), + } + + +class _FakeOrtSession: + def __init__(self): + self.calls = [] + + def run(self, output_names, inputs): + self.calls.append((output_names, inputs)) + return [[0.25, 0.75]] + + def _load_base_class(): factory = _PipelineFactory() source_path = ROOT / "extensions" / "serving" / "default_inference" / "nlp" / "th_hf_model_base.py" @@ -125,10 +163,10 @@ def _load_base_class(): "__name__": "loaded_th_hf_model_base", } exec(compile(source, str(source_path), "exec"), namespace) # noqa: S102 - return namespace["ThHfModelBase"], factory + return namespace["ThHfModelBase"], namespace["HfOnnxArtifactPipeline"], factory -ThHfModelBase, _PIPELINE_FACTORY = _load_base_class() +ThHfModelBase, HfOnnxArtifactPipeline, _PIPELINE_FACTORY = _load_base_class() class _ConcreteHfModel(ThHfModelBase): @@ -136,6 +174,11 @@ class _ConcreteHfModel(ThHfModelBase): class ThHfModelBaseTests(unittest.TestCase): + def setUp(self): + _PIPELINE_FACTORY.calls = [] + _PIPELINE_FACTORY.instance.inference_calls = [] + return + def test_hf_serving_raises_default_wait_time_above_generic_base(self): self.assertEqual(_ConcreteHfModel.CONFIG["MAX_WAIT_TIME"], 60) @@ -224,6 +267,191 @@ def test_startup_can_disable_warmup(self): self.assertEqual(after_calls, before_calls) + def test_forced_pt_runtime_passes_model_revision_to_transformers_pipeline(self): + plugin = _ConcreteHfModel( + MODEL_NAME="test/model", + MODEL_REVISION="rev-123", + HF_RUNTIME="pt", + PIPELINE_TASK="text-classification", + WARMUP_ENABLED=False, + ) + + plugin.startup() + + _args, kwargs = _PIPELINE_FACTORY.calls[-1] + self.assertEqual(kwargs["revision"], "rev-123") + self.assertEqual(plugin.hf_runtime, "pt") + + def test_auto_runtime_uses_onnx_artifact_on_cpu_only(self): + manifest = { + "model_key": "generic_text_classifier", + "model_version": "2026.05.09", + "pipeline_task": "text-classification", + "runtimes": { + "onnx_fp32": { + "runtime": "onnxruntime", + "entrypoint": "onnxruntime.InferenceSession", + "files": [ + "model.onnx", + "tokenizer.json", + "contract.py", + "schema.json", + "model.safetensors", + ], + } + }, + } + plugin = _ConcreteHfModel( + MODEL_NAME="test/model", + DEVICE="cpu", + PIPELINE_TASK="text-classification", + WARMUP_ENABLED=False, + ) + download_calls = [] + plugin._load_hf_artifact_manifest = lambda: manifest # pylint: disable=protected-access + + def fake_download(runtime_key, runtime_config, allow_patterns): + download_calls.append((runtime_key, runtime_config, allow_patterns)) + return "/tmp/models/test-model" + + plugin._download_hf_runtime_snapshot = fake_download # pylint: disable=protected-access + plugin._build_hf_onnx_artifact_pipeline = ( # pylint: disable=protected-access + lambda model_dir, runtime_key, runtime_config, manifest: _FakePipeline(task="text-classification") + ) + + plugin.startup() + + self.assertEqual(plugin.hf_runtime, "onnx_fp32") + self.assertEqual(len(_PIPELINE_FACTORY.calls), 0) + self.assertEqual(download_calls[0][0], "onnx_fp32") + self.assertIn("model.onnx", download_calls[0][2]) + self.assertNotIn("model.safetensors", download_calls[0][2]) + + def test_auto_runtime_keeps_transformers_pipeline_when_gpu_available(self): + plugin = _ConcreteHfModel( + MODEL_NAME="test/model", + PIPELINE_TASK="text-classification", + WARMUP_ENABLED=False, + ) + plugin._load_hf_artifact_manifest = lambda: { # pylint: disable=protected-access + "runtimes": { + "onnx_fp32": { + "runtime": "onnxruntime", + "entrypoint": "onnxruntime.InferenceSession", + "files": ["model.onnx"], + } + }, + } + + plugin.startup() + + self.assertEqual(plugin.device, 0) + self.assertEqual(plugin.hf_runtime, "pt") + self.assertEqual(len(_PIPELINE_FACTORY.calls), 1) + + def test_forced_onnx_runtime_uses_manifest_runtime_without_hardcoded_key(self): + manifest = { + "runtimes": { + "cpu_artifact": { + "runtime": "onnxruntime", + "entrypoint": "onnxruntime.InferenceSession", + "files": ["model.onnx", "schema.json", "contract.py"], + } + }, + } + plugin = _ConcreteHfModel( + MODEL_NAME="test/model", + DEVICE="cpu", + HF_RUNTIME="onnx", + PIPELINE_TASK="text-classification", + WARMUP_ENABLED=False, + ) + plugin._load_hf_artifact_manifest = lambda: manifest # pylint: disable=protected-access + plugin._download_hf_runtime_snapshot = ( # pylint: disable=protected-access + lambda runtime_key, runtime_config, allow_patterns: "/tmp/models/test-model" + ) + plugin._build_hf_onnx_artifact_pipeline = ( # pylint: disable=protected-access + lambda model_dir, runtime_key, runtime_config, manifest: _FakePipeline(task="text-classification") + ) + + plugin.startup() + + self.assertEqual(plugin.hf_runtime, "cpu_artifact") + self.assertEqual(len(_PIPELINE_FACTORY.calls), 0) + + def test_onnx_artifact_pipeline_uses_hf_contract_decoder(self): + plugin = _ConcreteHfModel( + MODEL_NAME="test/model", + DEVICE="cpu", + PIPELINE_TASK="text-classification", + WARMUP_ENABLED=False, + ) + fake_tokenizer = _FakeTokenizer() + fake_session = _FakeOrtSession() + created_sessions = [] + plugin._load_hf_onnx_tokenizer = lambda model_dir, runtime_config: fake_tokenizer # pylint: disable=protected-access + + def fake_create_session(model_path, providers): + created_sessions.append((model_path, providers)) + return fake_session + + plugin._create_hf_onnx_session = fake_create_session # pylint: disable=protected-access + + with TemporaryDirectory() as tmpdir: + model_dir = Path(tmpdir) + (model_dir / "model.onnx").write_text("fake", encoding="utf-8") + (model_dir / "schema.json").write_text( + ( + '{"inputs":[{"name":"input_ids","dtype":"int64"},' + '{"name":"attention_mask","dtype":"int64"}],' + '"outputs":[{"name":"scores"}],' + '"models":{"onnx_fp32":{"path":"model.onnx"}}}' + ), + encoding="utf-8", + ) + (model_dir / "contract.py").write_text( + ( + "def decode_generic_outputs(outputs, schema, **kwargs):\n" + " return {\n" + " 'contract': 'hf',\n" + " 'outputs': outputs,\n" + " 'repo_id': kwargs.get('repo_id'),\n" + " 'runtime': kwargs.get('runtime_key'),\n" + " }\n" + ), + encoding="utf-8", + ) + manifest = { + "pipeline_task": "text-classification", + "runtimes": { + "onnx_fp32": { + "runtime": "onnxruntime", + "files": ["model.onnx", "schema.json", "contract.py"], + } + }, + } + + pipeline = plugin._build_hf_onnx_artifact_pipeline( # pylint: disable=protected-access + model_dir=str(model_dir), + runtime_key="onnx_fp32", + runtime_config=manifest["runtimes"]["onnx_fp32"], + manifest=manifest, + ) + result = pipeline("hello world") + batched_single_result = pipeline(["hello world"]) + + self.assertIsInstance(pipeline, HfOnnxArtifactPipeline) + self.assertEqual(result["contract"], "hf") + self.assertEqual(batched_single_result["contract"], "hf") + self.assertEqual(result["outputs"], {"scores": [0.25, 0.75]}) + self.assertEqual(result["repo_id"], "test/model") + self.assertEqual(result["runtime"], "onnx_fp32") + self.assertEqual(Path(created_sessions[0][0]).name, "model.onnx") + output_names, inputs = fake_session.calls[-1] + self.assertEqual(output_names, ["scores"]) + self.assertEqual(inputs["input_ids"].dtype, "int64") + self.assertEqual(fake_tokenizer.calls[-1][1]["return_tensors"], "np") + if __name__ == "__main__": unittest.main() From e5441964bdf64fa0babaf0d175ec5b94044ec670 Mon Sep 17 00:00:00 2001 From: Cristi Bleotiu Date: Mon, 11 May 2026 12:57:31 +0300 Subject: [PATCH 02/11] fix: harden hf onnx runtime fallback What changed: - Make auto ONNX startup opportunistic and fall back to Transformers/PT on ONNX init or warmup failure. - Keep explicit ONNX runtimes fail-fast while explicit PT skips manifest lookup. - Gate decoder and tokenizer remote code on global and runtime trust flags. - Confine manifest-declared artifact paths to the downloaded HF snapshot and filter broad/framework-weight allow patterns. - Forward runtime metadata consistently for privacy-filter responses and add focused regression coverage. Why: - Preserve seamless CPU ONNX when available without breaking Transformers fallback or weakening remote-code/path safety. --- .../privacy_filter_inference_api.py | 10 + .../test_privacy_filter_inference_api.py | 13 + .../default_inference/nlp/th_hf_model_base.py | 100 ++++-- extensions/serving/test_th_hf_model_base.py | 306 +++++++++++++++++- 4 files changed, 403 insertions(+), 26 deletions(-) diff --git a/extensions/business/edge_inference_api/privacy_filter_inference_api.py b/extensions/business/edge_inference_api/privacy_filter_inference_api.py index 295d8c00e..3cf0bd727 100644 --- a/extensions/business/edge_inference_api/privacy_filter_inference_api.py +++ b/extensions/business/edge_inference_api/privacy_filter_inference_api.py @@ -95,4 +95,14 @@ def _build_result_from_inference( # pylint: disable=arguments-differ result_payload["tokenizer_name"] = inference["TOKENIZER_NAME"] if "PIPELINE_TASK" in inference: result_payload["pipeline_task"] = inference["PIPELINE_TASK"] + if "MODEL" in inference: + result_payload["model"] = inference["MODEL"] + if "MODEL_VERSION" in inference: + result_payload["model_version"] = inference["MODEL_VERSION"] + if "MODEL_REVISION" in inference: + result_payload["model_revision"] = inference["MODEL_REVISION"] + if "HF_RUNTIME" in inference: + result_payload["hf_runtime"] = inference["HF_RUNTIME"] + if "RUNTIME" in inference: + result_payload["runtime"] = inference["RUNTIME"] return result_payload diff --git a/extensions/business/edge_inference_api/test_privacy_filter_inference_api.py b/extensions/business/edge_inference_api/test_privacy_filter_inference_api.py index d5219e83f..ecb3bec75 100644 --- a/extensions/business/edge_inference_api/test_privacy_filter_inference_api.py +++ b/extensions/business/edge_inference_api/test_privacy_filter_inference_api.py @@ -55,6 +55,11 @@ def test_build_result_from_inference_uses_findings_key(self): "FINDINGS_COUNT": 1, "MODEL_NAME": "openai/privacy-filter", "PIPELINE_TASK": "token-classification", + "MODEL": {"model_key": "privacy_filter", "model_version": "2026.05.09"}, + "MODEL_VERSION": "2026.05.09", + "MODEL_REVISION": "rev-privacy", + "HF_RUNTIME": "pt", + "RUNTIME": "transformers", }, metadata={}, request_data={"metadata": {}, "parameters": {"text": "example text"}}, @@ -73,6 +78,14 @@ def test_build_result_from_inference_uses_findings_key(self): self.assertEqual(result_payload["findings_count"], 1) self.assertEqual(result_payload["model_name"], "openai/privacy-filter") self.assertEqual(result_payload["pipeline_task"], "token-classification") + self.assertEqual( + result_payload["model"], + {"model_key": "privacy_filter", "model_version": "2026.05.09"}, + ) + self.assertEqual(result_payload["model_version"], "2026.05.09") + self.assertEqual(result_payload["model_revision"], "rev-privacy") + self.assertEqual(result_payload["hf_runtime"], "pt") + self.assertEqual(result_payload["runtime"], "transformers") if __name__ == "__main__": diff --git a/extensions/serving/default_inference/nlp/th_hf_model_base.py b/extensions/serving/default_inference/nlp/th_hf_model_base.py index db9350254..bca48e19d 100644 --- a/extensions/serving/default_inference/nlp/th_hf_model_base.py +++ b/extensions/serving/default_inference/nlp/th_hf_model_base.py @@ -179,14 +179,16 @@ def _build_output_map(self, raw_outputs, output_names): for output_name, output_value in zip(output_names, raw_outputs) } - def _call_decoder(self, outputs_by_name, text): + def _call_decoder(self, outputs_by_name, text, inference_kwargs): if self.decoder is None: return outputs_by_name decoder_kwargs = { + **dict(inference_kwargs or {}), "runtime": self.runtime_key, "runtime_key": self.runtime_key, "text": text, "repo_id": self.repo_id, + "inference_kwargs": dict(inference_kwargs or {}), } try: signature = inspect.signature(self.decoder) @@ -209,7 +211,11 @@ def _run_single_text(self, text, inference_kwargs): output_names = self._output_names() raw_outputs = self.session.run(output_names, session_inputs) outputs_by_name = self._build_output_map(raw_outputs, output_names) - return self._call_decoder(outputs_by_name=outputs_by_name, text=text) + return self._call_decoder( + outputs_by_name=outputs_by_name, + text=text, + inference_kwargs=inference_kwargs, + ) class ThHfModelBase(BaseServingProcess): @@ -545,12 +551,13 @@ def _blocked_hf_weight_pattern(self, pattern): pattern = str(pattern) blocked_suffixes = ( ".safetensors", - "pytorch_model.bin", - "tf_model.h5", - "flax_model.msgpack", + ".bin", + ".h5", + ".msgpack", ) - blocked_wildcards = ("*.safetensors", "*.bin", "*.h5", "*.msgpack") - return pattern.endswith(blocked_suffixes) or pattern in blocked_wildcards + blocked_wildcards = ("*", "**/*", "*.safetensors", "*.bin", "*.h5", "*.msgpack") + blocked_directory_globs = pattern.endswith("/*") or pattern.endswith("/**") + return pattern.endswith(blocked_suffixes) or pattern in blocked_wildcards or blocked_directory_globs def _build_hf_runtime_allow_patterns(self, runtime_config): """Build safe HF snapshot allow-patterns for an ONNX runtime.""" @@ -565,8 +572,6 @@ def _build_hf_runtime_allow_patterns(self, runtime_config): model_file, "*.onnx", "**/*.onnx", - "onnx/*", - "onnx/**", "*.json", "*.py", "*.txt", @@ -606,6 +611,19 @@ def _runtime_file_list(self, runtime_config): files = runtime_config.get("files") if isinstance(runtime_config, dict) else None return files if isinstance(files, list) else [] + def _resolve_hf_snapshot_path(self, model_dir, file_path): + """Resolve a manifest path while keeping it inside the downloaded snapshot.""" + path = Path(str(file_path)) + if path.is_absolute(): + raise ValueError(f"HF artifact path {file_path!r} must be relative to the model snapshot.") + snapshot_dir = Path(model_dir).resolve() + resolved_path = (snapshot_dir / path).resolve() + try: + resolved_path.relative_to(snapshot_dir) + except ValueError as exc: + raise ValueError(f"HF artifact path {file_path!r} escapes the model snapshot.") from exc + return resolved_path + def _first_manifest_file_with_suffix(self, runtime_config, suffixes): """Return the first exact manifest file path ending with any suffix.""" for file_path in self._runtime_file_list(runtime_config): @@ -621,10 +639,10 @@ def _resolve_manifest_file_path(self, model_dir, manifest, runtime_config, keys, if value is None and isinstance(manifest, dict): value = manifest.get(key) if value: - return Path(model_dir) / str(value) + return self._resolve_hf_snapshot_path(model_dir=model_dir, file_path=value) inferred = self._first_manifest_file_with_suffix(runtime_config, suffixes) if inferred: - return Path(model_dir) / inferred + return self._resolve_hf_snapshot_path(model_dir=model_dir, file_path=inferred) return None def _load_hf_schema(self, model_dir, manifest, runtime_config): @@ -640,6 +658,13 @@ def _load_hf_schema(self, model_dir, manifest, runtime_config): raise ValueError(f"HF runtime {self.hf_runtime} does not declare a usable schema file.") return json.loads(schema_path.read_text(encoding="utf-8")) + def _runtime_allows_remote_code(self, manifest, runtime_config): + """Return whether the selected runtime explicitly allows Python artifact code.""" + for source in (runtime_config, manifest): + if isinstance(source, dict) and "trust_remote_code" in source: + return bool(source.get("trust_remote_code")) + return False + def _load_hf_contract_decoder(self, model_dir, manifest, runtime_config): """Load the artifact decoder function declared by the selected HF runtime.""" decoder_path = self._resolve_manifest_file_path( @@ -651,6 +676,14 @@ def _load_hf_contract_decoder(self, model_dir, manifest, runtime_config): ) if decoder_path is None or not decoder_path.exists(): raise ValueError(f"HF runtime {self.hf_runtime} does not declare a usable contract decoder.") + if not bool(self.cfg_trust_remote_code) or not self._runtime_allows_remote_code( + manifest=manifest, + runtime_config=runtime_config, + ): + raise ValueError( + "HF ONNX artifact decoder requires global TRUST_REMOTE_CODE=True and runtime " + f"trust_remote_code=True because it executes Python code from {decoder_path}." + ) module_name = f"hf_artifact_contract_{abs(hash(str(decoder_path)))}" spec = importlib.util.spec_from_file_location(module_name, decoder_path) if spec is None or spec.loader is None: @@ -684,7 +717,7 @@ def _resolve_hf_onnx_model_path(self, model_dir, runtime_key, runtime_config, sc for key in ("model", "model_file", "path"): value = runtime_config.get(key) if isinstance(runtime_config, dict) else None if value: - return Path(model_dir) / str(value) + return self._resolve_hf_snapshot_path(model_dir=model_dir, file_path=value) models = schema.get("models") if isinstance(schema, dict) else None if isinstance(models, dict): candidates = [ @@ -699,10 +732,10 @@ def _resolve_hf_onnx_model_path(self, model_dir, runtime_key, runtime_config, sc value = value.get("path") or value.get("file") or value.get("model") if not value: continue - return Path(model_dir) / str(value) + return self._resolve_hf_snapshot_path(model_dir=model_dir, file_path=value) model_file = self._first_manifest_file_with_suffix(runtime_config, (".onnx",)) if model_file: - return Path(model_dir) / model_file + return self._resolve_hf_snapshot_path(model_dir=model_dir, file_path=model_file) raise ValueError(f"HF runtime {runtime_key} does not declare an ONNX model file.") def _resolve_hf_tokenizer_dir(self, model_dir, manifest, runtime_config, schema): @@ -712,16 +745,19 @@ def _resolve_hf_tokenizer_dir(self, model_dir, manifest, runtime_config, schema) if isinstance(source, dict) and source.get("tokenizer_dir"): tokenizer_dir = source["tokenizer_dir"] break - return Path(model_dir) / str(tokenizer_dir or ".") + return self._resolve_hf_snapshot_path(model_dir=model_dir, file_path=tokenizer_dir or ".") - def _load_hf_onnx_tokenizer(self, model_dir, runtime_config): + def _load_hf_onnx_tokenizer(self, model_dir, runtime_config, manifest=None): """Load the tokenizer for an ONNX HF artifact.""" from transformers import AutoTokenizer return AutoTokenizer.from_pretrained( str(model_dir), token=self.hf_token, - trust_remote_code=bool(runtime_config.get("trust_remote_code", False)), + trust_remote_code=bool(self.cfg_trust_remote_code) and self._runtime_allows_remote_code( + manifest=manifest, + runtime_config=runtime_config, + ), ) def _create_hf_onnx_session(self, model_path, providers): @@ -751,6 +787,7 @@ def _build_hf_onnx_artifact_pipeline(self, model_dir, runtime_key, runtime_confi tokenizer = self._load_hf_onnx_tokenizer( model_dir=tokenizer_dir, runtime_config=runtime_config, + manifest=manifest, ) model_path = self._resolve_hf_onnx_model_path( model_dir=model_dir, @@ -889,15 +926,32 @@ def startup(self): self.hf_runtime_config = dict(runtime_config or {}) self.hf_artifact_manifest = manifest if isinstance(manifest, dict) else None + run_warmup = True if self._runtime_is_onnx(runtime_key=runtime_key, runtime_config=runtime_config): - self._startup_hf_onnx_artifact( - runtime_key=runtime_key, - runtime_config=runtime_config, - manifest=manifest, - ) + try: + self._startup_hf_onnx_artifact( + runtime_key=runtime_key, + runtime_config=runtime_config, + manifest=manifest, + ) + self._run_startup_warmup() + run_warmup = False + except Exception as exc: + if requested_runtime != "auto": + raise + self.P( + f"HF auto runtime could not start ONNX artifact {runtime_key!r} for " + f"{self.get_model_name()}: {exc}. Falling back to Transformers/PT.", + color="y", + ) + self.hf_runtime = "pt" + self.hf_runtime_config = {} + self.hf_artifact_manifest = None + self._startup_transformers_pipeline() else: self._startup_transformers_pipeline() - self._run_startup_warmup() + if run_warmup: + self._run_startup_warmup() return def _get_hf_artifact_model_metadata(self): diff --git a/extensions/serving/test_th_hf_model_base.py b/extensions/serving/test_th_hf_model_base.py index 6692bc8e8..f541a9bb8 100644 --- a/extensions/serving/test_th_hf_model_base.py +++ b/extensions/serving/test_th_hf_model_base.py @@ -1,3 +1,4 @@ +import sys import types import unittest @@ -282,6 +283,45 @@ def test_forced_pt_runtime_passes_model_revision_to_transformers_pipeline(self): self.assertEqual(kwargs["revision"], "rev-123") self.assertEqual(plugin.hf_runtime, "pt") + def test_forced_pt_runtime_on_cpu_skips_manifest_lookup(self): + plugin = _ConcreteHfModel( + MODEL_NAME="test/model", + DEVICE="cpu", + HF_RUNTIME="pt", + PIPELINE_TASK="text-classification", + WARMUP_ENABLED=False, + ) + plugin._load_hf_artifact_manifest = ( # pylint: disable=protected-access + lambda: (_ for _ in ()).throw(AssertionError("manifest should not be loaded")) + ) + + plugin.startup() + + self.assertEqual(plugin.device, -1) + self.assertEqual(plugin.hf_runtime, "pt") + self.assertEqual(len(_PIPELINE_FACTORY.calls), 1) + + def test_onnx_allow_patterns_reject_framework_weights_and_broad_downloads(self): + plugin = _ConcreteHfModel(MODEL_NAME="test/model") + + allow_patterns = plugin._build_hf_runtime_allow_patterns({ # pylint: disable=protected-access + "files": [ + "*", + "**/*", + "onnx/*", + "onnx/**", + "model.onnx", + "tokenizer.json", + "contract.py", + "pytorch_model-00001-of-00002.bin", + "model.safetensors", + "tf_model.h5", + "flax_model.msgpack", + ], + }) + + self.assertEqual(allow_patterns, ["model.onnx", "tokenizer.json", "contract.py"]) + def test_auto_runtime_uses_onnx_artifact_on_cpu_only(self): manifest = { "model_key": "generic_text_classifier", @@ -379,6 +419,259 @@ def test_forced_onnx_runtime_uses_manifest_runtime_without_hardcoded_key(self): self.assertEqual(plugin.hf_runtime, "cpu_artifact") self.assertEqual(len(_PIPELINE_FACTORY.calls), 0) + def test_auto_runtime_falls_back_to_transformers_when_onnx_startup_fails(self): + plugin = _ConcreteHfModel( + MODEL_NAME="test/model", + DEVICE="cpu", + PIPELINE_TASK="text-classification", + WARMUP_ENABLED=False, + ) + plugin._load_hf_artifact_manifest = lambda: { # pylint: disable=protected-access + "runtimes": { + "onnx_fp32": { + "runtime": "onnxruntime", + "entrypoint": "onnxruntime.InferenceSession", + "files": ["model.onnx", "schema.json", "contract.py"], + } + }, + } + + def fail_onnx_startup(runtime_key, runtime_config, manifest): # pylint: disable=unused-argument + raise RuntimeError("onnxruntime is not installed") + + plugin._startup_hf_onnx_artifact = fail_onnx_startup # pylint: disable=protected-access + + plugin.startup() + + self.assertEqual(plugin.hf_runtime, "pt") + self.assertEqual(plugin.hf_runtime_config, {}) + self.assertIsNone(plugin.hf_artifact_manifest) + self.assertEqual(len(_PIPELINE_FACTORY.calls), 1) + self.assertTrue( + any("Falling back to Transformers/PT" in message[0][0] for message in plugin.logged_messages) + ) + + def test_forced_onnx_runtime_does_not_fallback_after_startup_failure(self): + plugin = _ConcreteHfModel( + MODEL_NAME="test/model", + DEVICE="cpu", + HF_RUNTIME="onnx", + PIPELINE_TASK="text-classification", + WARMUP_ENABLED=False, + ) + plugin._load_hf_artifact_manifest = lambda: { # pylint: disable=protected-access + "runtimes": { + "onnx_fp32": { + "runtime": "onnxruntime", + "entrypoint": "onnxruntime.InferenceSession", + "files": ["model.onnx", "schema.json", "contract.py"], + } + }, + } + plugin._startup_hf_onnx_artifact = ( # pylint: disable=protected-access + lambda runtime_key, runtime_config, manifest: (_ for _ in ()).throw(RuntimeError("bad onnx")) + ) + + with self.assertRaisesRegex(RuntimeError, "bad onnx"): + plugin.startup() + + self.assertEqual(len(_PIPELINE_FACTORY.calls), 0) + + def test_named_onnx_runtime_does_not_fallback_after_startup_failure(self): + plugin = _ConcreteHfModel( + MODEL_NAME="test/model", + DEVICE="cpu", + HF_RUNTIME="onnx_fp32", + PIPELINE_TASK="text-classification", + WARMUP_ENABLED=False, + ) + plugin._load_hf_artifact_manifest = lambda: { # pylint: disable=protected-access + "runtimes": { + "onnx_fp32": { + "runtime": "onnxruntime", + "entrypoint": "onnxruntime.InferenceSession", + "files": ["model.onnx", "schema.json", "contract.py"], + } + }, + } + plugin._startup_hf_onnx_artifact = ( # pylint: disable=protected-access + lambda runtime_key, runtime_config, manifest: (_ for _ in ()).throw(RuntimeError("bad named onnx")) + ) + + with self.assertRaisesRegex(RuntimeError, "bad named onnx"): + plugin.startup() + + self.assertEqual(len(_PIPELINE_FACTORY.calls), 0) + + def test_auto_runtime_falls_back_to_transformers_when_onnx_warmup_fails(self): + class _FailingWarmupPipeline: + task = "text-classification" + schema = {} + + def __call__(self, text, **kwargs): # pylint: disable=unused-argument + raise RuntimeError("onnx warmup failed") + + plugin = _ConcreteHfModel( + MODEL_NAME="test/model", + DEVICE="cpu", + PIPELINE_TASK="text-classification", + ) + plugin._load_hf_artifact_manifest = lambda: { # pylint: disable=protected-access + "runtimes": { + "onnx_fp32": { + "runtime": "onnxruntime", + "entrypoint": "onnxruntime.InferenceSession", + "files": ["model.onnx", "schema.json", "contract.py"], + } + }, + } + + def set_failing_pipeline(runtime_key, runtime_config, manifest): # pylint: disable=unused-argument + plugin.classifier = _FailingWarmupPipeline() + return + + plugin._startup_hf_onnx_artifact = set_failing_pipeline # pylint: disable=protected-access + + plugin.startup() + + self.assertEqual(plugin.hf_runtime, "pt") + self.assertEqual(len(_PIPELINE_FACTORY.calls), 1) + self.assertEqual(_PIPELINE_FACTORY.instance.inference_calls[-1][0], "Warmup request.") + self.assertTrue( + any("Falling back to Transformers/PT" in message[0][0] for message in plugin.logged_messages) + ) + + def test_hf_contract_decoder_requires_global_trust_remote_code(self): + plugin = _ConcreteHfModel( + MODEL_NAME="test/model", + DEVICE="cpu", + TRUST_REMOTE_CODE=False, + ) + plugin.hf_runtime = "onnx_fp32" + + with TemporaryDirectory() as tmpdir: + model_dir = Path(tmpdir) + (model_dir / "contract.py").write_text( + "def decode_outputs(outputs, schema):\n return outputs\n", + encoding="utf-8", + ) + + with self.assertRaisesRegex(ValueError, "TRUST_REMOTE_CODE=True"): + plugin._load_hf_contract_decoder( # pylint: disable=protected-access + model_dir=str(model_dir), + manifest={}, + runtime_config={"decoder": "contract.py"}, + ) + + def test_hf_contract_decoder_requires_runtime_trust_remote_code(self): + plugin = _ConcreteHfModel( + MODEL_NAME="test/model", + DEVICE="cpu", + TRUST_REMOTE_CODE=True, + ) + plugin.hf_runtime = "onnx_fp32" + + with TemporaryDirectory() as tmpdir: + model_dir = Path(tmpdir) + (model_dir / "contract.py").write_text( + "def decode_outputs(outputs, schema):\n return outputs\n", + encoding="utf-8", + ) + + with self.assertRaisesRegex(ValueError, "runtime trust_remote_code=True"): + plugin._load_hf_contract_decoder( # pylint: disable=protected-access + model_dir=str(model_dir), + manifest={}, + runtime_config={"decoder": "contract.py", "trust_remote_code": False}, + ) + + def test_hf_artifact_paths_must_stay_inside_snapshot(self): + plugin = _ConcreteHfModel(MODEL_NAME="test/model") + plugin.hf_runtime = "onnx_fp32" + + with TemporaryDirectory() as tmpdir: + model_dir = Path(tmpdir) + (model_dir / "schema.json").write_text("{}", encoding="utf-8") + + with self.assertRaisesRegex(ValueError, "escapes the model snapshot"): + plugin._load_hf_schema( # pylint: disable=protected-access + model_dir=str(model_dir), + manifest={}, + runtime_config={"schema": "../schema.json"}, + ) + + with self.assertRaisesRegex(ValueError, "must be relative"): + plugin._load_hf_contract_decoder( # pylint: disable=protected-access + model_dir=str(model_dir), + manifest={}, + runtime_config={"decoder": str((model_dir / "contract.py").resolve())}, + ) + + with self.assertRaisesRegex(ValueError, "escapes the model snapshot"): + plugin._resolve_hf_onnx_model_path( # pylint: disable=protected-access + model_dir=str(model_dir), + runtime_key="onnx_fp32", + runtime_config={"model": "../model.onnx"}, + schema={}, + ) + + with self.assertRaisesRegex(ValueError, "escapes the model snapshot"): + plugin._resolve_hf_tokenizer_dir( # pylint: disable=protected-access + model_dir=str(model_dir), + manifest={}, + runtime_config={"tokenizer_dir": "../tokenizer"}, + schema={}, + ) + + def test_onnx_tokenizer_remote_code_requires_global_trust_remote_code(self): + calls = [] + + class _FakeAutoTokenizer: + @staticmethod + def from_pretrained(model_dir, **kwargs): + calls.append((model_dir, kwargs)) + return _FakeTokenizer() + + fake_transformers = types.SimpleNamespace(AutoTokenizer=_FakeAutoTokenizer) + original_transformers = sys.modules.get("transformers") + sys.modules["transformers"] = fake_transformers + try: + plugin = _ConcreteHfModel( + MODEL_NAME="test/model", + DEVICE="cpu", + TRUST_REMOTE_CODE=False, + ) + plugin._load_hf_onnx_tokenizer( # pylint: disable=protected-access + model_dir="/tmp/model", + runtime_config={"trust_remote_code": True}, + ) + + self.assertFalse(calls[-1][1]["trust_remote_code"]) + + plugin = _ConcreteHfModel( + MODEL_NAME="test/model", + DEVICE="cpu", + TRUST_REMOTE_CODE=True, + ) + plugin._load_hf_onnx_tokenizer( # pylint: disable=protected-access + model_dir="/tmp/model", + runtime_config={"trust_remote_code": True}, + ) + + self.assertTrue(calls[-1][1]["trust_remote_code"]) + + plugin._load_hf_onnx_tokenizer( # pylint: disable=protected-access + model_dir="/tmp/model", + runtime_config={"trust_remote_code": False}, + ) + + self.assertFalse(calls[-1][1]["trust_remote_code"]) + finally: + if original_transformers is None: + sys.modules.pop("transformers", None) + else: + sys.modules["transformers"] = original_transformers + def test_onnx_artifact_pipeline_uses_hf_contract_decoder(self): plugin = _ConcreteHfModel( MODEL_NAME="test/model", @@ -389,7 +682,9 @@ def test_onnx_artifact_pipeline_uses_hf_contract_decoder(self): fake_tokenizer = _FakeTokenizer() fake_session = _FakeOrtSession() created_sessions = [] - plugin._load_hf_onnx_tokenizer = lambda model_dir, runtime_config: fake_tokenizer # pylint: disable=protected-access + plugin._load_hf_onnx_tokenizer = ( # pylint: disable=protected-access + lambda model_dir, runtime_config, manifest=None: fake_tokenizer + ) def fake_create_session(model_path, providers): created_sessions.append((model_path, providers)) @@ -411,12 +706,14 @@ def fake_create_session(model_path, providers): ) (model_dir / "contract.py").write_text( ( - "def decode_generic_outputs(outputs, schema, **kwargs):\n" + "def decode_generic_outputs(outputs, schema, aggregation_strategy=None, inference_kwargs=None, **kwargs):\n" " return {\n" " 'contract': 'hf',\n" " 'outputs': outputs,\n" " 'repo_id': kwargs.get('repo_id'),\n" " 'runtime': kwargs.get('runtime_key'),\n" + " 'aggregation_strategy': aggregation_strategy,\n" + " 'inference_kwargs': inference_kwargs,\n" " }\n" ), encoding="utf-8", @@ -426,6 +723,7 @@ def fake_create_session(model_path, providers): "runtimes": { "onnx_fp32": { "runtime": "onnxruntime", + "trust_remote_code": True, "files": ["model.onnx", "schema.json", "contract.py"], } }, @@ -437,7 +735,7 @@ def fake_create_session(model_path, providers): runtime_config=manifest["runtimes"]["onnx_fp32"], manifest=manifest, ) - result = pipeline("hello world") + result = pipeline("hello world", aggregation_strategy="simple", threshold=0.7) batched_single_result = pipeline(["hello world"]) self.assertIsInstance(pipeline, HfOnnxArtifactPipeline) @@ -446,6 +744,8 @@ def fake_create_session(model_path, providers): self.assertEqual(result["outputs"], {"scores": [0.25, 0.75]}) self.assertEqual(result["repo_id"], "test/model") self.assertEqual(result["runtime"], "onnx_fp32") + self.assertEqual(result["aggregation_strategy"], "simple") + self.assertEqual(result["inference_kwargs"]["threshold"], 0.7) self.assertEqual(Path(created_sessions[0][0]).name, "model.onnx") output_names, inputs = fake_session.calls[-1] self.assertEqual(output_names, ["scores"]) From 9c317ade87013b0096d013fa7814333c83972dbd Mon Sep 17 00:00:00 2001 From: Cristi Bleotiu Date: Mon, 11 May 2026 12:58:48 +0300 Subject: [PATCH 03/11] fix: require runtime trust for hf onnx code What changed: - Require selected ONNX runtime config trust_remote_code=True before executing artifact decoder or tokenizer remote code. - Add regression coverage proving a top-level manifest trust flag cannot enable runtime code execution by itself. Why: - Avoid remote-code trust bypasses from broad manifest metadata; the selected runtime must explicitly opt in. --- .../default_inference/nlp/th_hf_model_base.py | 5 +---- extensions/serving/test_th_hf_model_base.py | 22 +++++++++++++++++++ 2 files changed, 23 insertions(+), 4 deletions(-) diff --git a/extensions/serving/default_inference/nlp/th_hf_model_base.py b/extensions/serving/default_inference/nlp/th_hf_model_base.py index bca48e19d..aabc9af94 100644 --- a/extensions/serving/default_inference/nlp/th_hf_model_base.py +++ b/extensions/serving/default_inference/nlp/th_hf_model_base.py @@ -660,10 +660,7 @@ def _load_hf_schema(self, model_dir, manifest, runtime_config): def _runtime_allows_remote_code(self, manifest, runtime_config): """Return whether the selected runtime explicitly allows Python artifact code.""" - for source in (runtime_config, manifest): - if isinstance(source, dict) and "trust_remote_code" in source: - return bool(source.get("trust_remote_code")) - return False + return isinstance(runtime_config, dict) and bool(runtime_config.get("trust_remote_code")) def _load_hf_contract_decoder(self, model_dir, manifest, runtime_config): """Load the artifact decoder function declared by the selected HF runtime.""" diff --git a/extensions/serving/test_th_hf_model_base.py b/extensions/serving/test_th_hf_model_base.py index f541a9bb8..da73e9a48 100644 --- a/extensions/serving/test_th_hf_model_base.py +++ b/extensions/serving/test_th_hf_model_base.py @@ -585,6 +585,28 @@ def test_hf_contract_decoder_requires_runtime_trust_remote_code(self): runtime_config={"decoder": "contract.py", "trust_remote_code": False}, ) + def test_top_level_manifest_trust_remote_code_does_not_enable_runtime_decoder(self): + plugin = _ConcreteHfModel( + MODEL_NAME="test/model", + DEVICE="cpu", + TRUST_REMOTE_CODE=True, + ) + plugin.hf_runtime = "onnx_fp32" + + with TemporaryDirectory() as tmpdir: + model_dir = Path(tmpdir) + (model_dir / "contract.py").write_text( + "def decode_outputs(outputs, schema):\n return outputs\n", + encoding="utf-8", + ) + + with self.assertRaisesRegex(ValueError, "runtime trust_remote_code=True"): + plugin._load_hf_contract_decoder( # pylint: disable=protected-access + model_dir=str(model_dir), + manifest={"trust_remote_code": True}, + runtime_config={"decoder": "contract.py"}, + ) + def test_hf_artifact_paths_must_stay_inside_snapshot(self): plugin = _ConcreteHfModel(MODEL_NAME="test/model") plugin.hf_runtime = "onnx_fp32" From 9fd9a16a724e537d83941980d22086479091ca7d Mon Sep 17 00:00:00 2001 From: Codex Date: Mon, 11 May 2026 17:33:40 +0300 Subject: [PATCH 04/11] feat: support privacy-filter onnx fallback What changed: - Added subclass ONNX fallback hooks in the HF serving base. - Added local privacy-filter ONNX discovery and BIOES/Viterbi span decoding. - Covered fallback runtime selection and privacy-filter decoder behavior with tests. Why: - Allow openai/privacy-filter ONNX artifacts to run without a remote artifact manifest or remote Python decoder code. --- .../default_inference/nlp/th_hf_model_base.py | 51 ++- .../nlp/th_privacy_filter.py | 348 ++++++++++++++++++ extensions/serving/test_th_hf_model_base.py | 70 ++++ extensions/serving/test_th_privacy_filter.py | 131 +++++++ 4 files changed, 596 insertions(+), 4 deletions(-) diff --git a/extensions/serving/default_inference/nlp/th_hf_model_base.py b/extensions/serving/default_inference/nlp/th_hf_model_base.py index aabc9af94..951469a7b 100644 --- a/extensions/serving/default_inference/nlp/th_hf_model_base.py +++ b/extensions/serving/default_inference/nlp/th_hf_model_base.py @@ -105,6 +105,10 @@ def _tokenize(self, text, inference_kwargs): "return_tensors": "np", "truncation": bool(inference_kwargs.get("truncation", True)), } + for source in (self.schema, self.runtime_config): + extra_tokenize_kwargs = source.get("tokenizer_kwargs") if isinstance(source, dict) else None + if isinstance(extra_tokenize_kwargs, dict): + tokenize_kwargs.update(extra_tokenize_kwargs) max_length = self._get_max_length(inference_kwargs) if max_length is not None: tokenize_kwargs["max_length"] = max_length @@ -179,7 +183,7 @@ def _build_output_map(self, raw_outputs, output_names): for output_name, output_value in zip(output_names, raw_outputs) } - def _call_decoder(self, outputs_by_name, text, inference_kwargs): + def _call_decoder(self, outputs_by_name, text, encoded, inference_kwargs): if self.decoder is None: return outputs_by_name decoder_kwargs = { @@ -188,6 +192,8 @@ def _call_decoder(self, outputs_by_name, text, inference_kwargs): "runtime_key": self.runtime_key, "text": text, "repo_id": self.repo_id, + "tokenizer_output": encoded, + "encoded": encoded, "inference_kwargs": dict(inference_kwargs or {}), } try: @@ -214,6 +220,7 @@ def _run_single_text(self, text, inference_kwargs): return self._call_decoder( outputs_by_name=outputs_by_name, text=text, + encoded=encoded, inference_kwargs=inference_kwargs, ) @@ -477,15 +484,32 @@ def _download_hf_artifact_file(self, filename): repo_type="model", ) + def _get_hf_onnx_fallback_manifest(self): + """Return a subclass-provided ONNX manifest when the repo has no manifest. + + This hook lets dedicated serving classes support standard HF ONNX layouts + without requiring remote Python artifact code or model-specific logic in + the shared base class. + """ + return None + def _load_hf_artifact_manifest(self): """Load the optional artifact manifest from the configured HF model repo.""" manifest_name = getattr(self, "cfg_hf_artifact_manifest", None) if not manifest_name: - return None + return self._get_hf_onnx_fallback_manifest() try: manifest_path = self._download_hf_artifact_file(manifest_name) return json.loads(Path(manifest_path).read_text(encoding="utf-8")) except Exception as exc: + fallback_manifest = self._get_hf_onnx_fallback_manifest() + if isinstance(fallback_manifest, dict): + self.P( + f"HF artifact manifest {manifest_name} not available for {self.get_model_name()}; " + "using subclass ONNX fallback manifest.", + color="y", + ) + return fallback_manifest if self._requested_hf_runtime() != "auto": raise self.P( @@ -647,6 +671,9 @@ def _resolve_manifest_file_path(self, model_dir, manifest, runtime_config, keys, def _load_hf_schema(self, model_dir, manifest, runtime_config): """Load the JSON schema declared by the selected HF runtime.""" + inline_schema = runtime_config.get("inline_schema") if isinstance(runtime_config, dict) else None + if isinstance(inline_schema, dict): + return inline_schema schema_path = self._resolve_manifest_file_path( model_dir=model_dir, manifest=manifest, @@ -709,6 +736,22 @@ def _load_hf_contract_decoder(self, model_dir, manifest, runtime_config): raise ValueError(f"Could not resolve a decoder function in {decoder_path}.") return decoder + def _get_hf_onnx_artifact_schema(self, model_dir, manifest, runtime_config): + """Return the schema used by an ONNX artifact runtime.""" + return self._load_hf_schema( + model_dir=model_dir, + manifest=manifest, + runtime_config=runtime_config, + ) + + def _get_hf_onnx_artifact_decoder(self, model_dir, manifest, runtime_config): + """Return the decoder used by an ONNX artifact runtime.""" + return self._load_hf_contract_decoder( + model_dir=model_dir, + manifest=manifest, + runtime_config=runtime_config, + ) + def _resolve_hf_onnx_model_path(self, model_dir, runtime_key, runtime_config, schema): """Resolve the ONNX model file for the selected runtime.""" for key in ("model", "model_file", "path"): @@ -765,12 +808,12 @@ def _create_hf_onnx_session(self, model_path, providers): def _build_hf_onnx_artifact_pipeline(self, model_dir, runtime_key, runtime_config, manifest): """Build a callable ONNX artifact pipeline from downloaded HF files.""" - schema = self._load_hf_schema( + schema = self._get_hf_onnx_artifact_schema( model_dir=model_dir, manifest=manifest, runtime_config=runtime_config, ) - decoder = self._load_hf_contract_decoder( + decoder = self._get_hf_onnx_artifact_decoder( model_dir=model_dir, manifest=manifest, runtime_config=runtime_config, diff --git a/extensions/serving/default_inference/nlp/th_privacy_filter.py b/extensions/serving/default_inference/nlp/th_privacy_filter.py index c4ce806ac..e7a13ea31 100644 --- a/extensions/serving/default_inference/nlp/th_privacy_filter.py +++ b/extensions/serving/default_inference/nlp/th_privacy_filter.py @@ -7,6 +7,9 @@ - redaction-friendly post-processing metadata """ +import json +import math + from extensions.serving.default_inference.nlp.th_hf_model_base import ( _CONFIG as BASE_HF_MODEL_CONFIG, ThHfModelBase, @@ -30,11 +33,356 @@ FIXED_CENSOR_SIZE = 4 +PRIVACY_FILTER_ONNX_RUNTIME_KEY = "onnx_fp32" +PRIVACY_FILTER_ONNX_MODEL_FILE = "onnx/model.onnx" +PRIVACY_FILTER_VITERBI_FILE = "viterbi_calibration.json" class ThPrivacyFilter(ThHfModelBase): CONFIG = _CONFIG + def _get_hf_onnx_fallback_manifest(self): + """Declare the public HF ONNX layout when no artifact manifest exists.""" + if self.get_model_name() != "openai/privacy-filter": + return None + return { + "model_key": "openai_privacy_filter", + "source_repo_id": "openai/privacy-filter", + "pipeline_task": "token-classification", + "runtimes": { + PRIVACY_FILTER_ONNX_RUNTIME_KEY: { + "runtime": "onnxruntime", + "entrypoint": "onnxruntime.InferenceSession", + "pipeline_task": "token-classification", + "model": PRIVACY_FILTER_ONNX_MODEL_FILE, + "decoder_type": "privacy_filter_span_decoder", + "files": [ + "config.json", + "tokenizer.json", + "tokenizer_config.json", + PRIVACY_FILTER_VITERBI_FILE, + PRIVACY_FILTER_ONNX_MODEL_FILE, + "onnx/model.onnx_data", + "onnx/model.onnx_data_1", + "onnx/model.onnx_data_2", + ], + "recommended_allow_patterns": [ + "config.json", + "tokenizer.json", + "tokenizer_config.json", + PRIVACY_FILTER_VITERBI_FILE, + PRIVACY_FILTER_ONNX_MODEL_FILE, + "onnx/model.onnx_data", + "onnx/model.onnx_data_1", + "onnx/model.onnx_data_2", + ], + "providers": ["CPUExecutionProvider"], + }, + }, + } + + def _get_hf_onnx_artifact_schema(self, model_dir, manifest, runtime_config): + """Build a local schema for the privacy-filter ONNX artifacts.""" + if runtime_config.get("decoder_type") != "privacy_filter_span_decoder": + return super()._get_hf_onnx_artifact_schema( + model_dir=model_dir, + manifest=manifest, + runtime_config=runtime_config, + ) + config_path = self._resolve_hf_snapshot_path(model_dir=model_dir, file_path="config.json") + config = json.loads(config_path.read_text(encoding="utf-8")) + calibration = {} + calibration_path = self._resolve_hf_snapshot_path( + model_dir=model_dir, + file_path=PRIVACY_FILTER_VITERBI_FILE, + ) + if calibration_path.exists(): + calibration = json.loads(calibration_path.read_text(encoding="utf-8")) + return { + "inputs": [ + {"name": "input_ids", "dtype": "int64"}, + {"name": "attention_mask", "dtype": "int64"}, + ], + "outputs": [{"name": "logits"}], + "output_order": ["logits"], + "id2label": config.get("id2label", {}), + "tokenizer_kwargs": {"return_offsets_mapping": True}, + "viterbi_calibration": calibration, + } + + def _get_hf_onnx_artifact_decoder(self, model_dir, manifest, runtime_config): + """Use the local privacy-filter decoder instead of remote Python code.""" + if runtime_config.get("decoder_type") == "privacy_filter_span_decoder": + return self._decode_privacy_filter_onnx_outputs + return super()._get_hf_onnx_artifact_decoder( + model_dir=model_dir, + manifest=manifest, + runtime_config=runtime_config, + ) + + def _to_plain_list(self, value): + """Convert tensors/arrays to plain Python lists for decoder logic.""" + if hasattr(value, "tolist"): + return value.tolist() + return value + + def _first_batch_item(self, value): + """Return the first batch element from a tensor-like value.""" + value = self._to_plain_list(value) + if isinstance(value, list) and len(value) == 1 and isinstance(value[0], list): + return value[0] + return value + + def _get_tokenizer_field(self, tokenizer_output, field_name): + if not hasattr(tokenizer_output, "get"): + return None + return self._first_batch_item(tokenizer_output.get(field_name)) + + def _get_privacy_filter_id2label(self, schema): + raw_id2label = schema.get("id2label") if isinstance(schema, dict) else None + if not isinstance(raw_id2label, dict) or len(raw_id2label) == 0: + raise ValueError("Privacy-filter ONNX schema must provide id2label.") + labels_by_id = { + int(label_id): label + for label_id, label in raw_id2label.items() + } + return [ + labels_by_id[idx] + for idx in range(max(labels_by_id) + 1) + ] + + def _split_privacy_filter_label(self, label): + if not isinstance(label, str) or label == "O": + return "O", None + if "-" not in label: + return label, None + prefix, entity = label.split("-", 1) + return prefix, entity + + def _get_privacy_filter_transition_biases(self, schema): + calibration = schema.get("viterbi_calibration") if isinstance(schema, dict) else None + operating_points = calibration.get("operating_points") if isinstance(calibration, dict) else None + default_point = operating_points.get("default") if isinstance(operating_points, dict) else None + biases = default_point.get("biases") if isinstance(default_point, dict) else None + return biases if isinstance(biases, dict) else {} + + def _privacy_filter_transition_is_valid(self, previous_label, current_label): + current_prefix, current_entity = self._split_privacy_filter_label(current_label) + previous_prefix, previous_entity = self._split_privacy_filter_label(previous_label) + if previous_label is None: + return current_prefix in {"O", "B", "S"} + if previous_prefix in {"O", "E", "S"}: + return current_prefix in {"O", "B", "S"} + if previous_prefix in {"B", "I"}: + return current_prefix in {"I", "E"} and current_entity == previous_entity + return False + + def _privacy_filter_terminal_is_valid(self, label): + prefix, _entity = self._split_privacy_filter_label(label) + return prefix in {"O", "E", "S"} + + def _privacy_filter_transition_bias(self, previous_label, current_label, biases): + if previous_label is None: + return 0.0 + previous_prefix, previous_entity = self._split_privacy_filter_label(previous_label) + current_prefix, current_entity = self._split_privacy_filter_label(current_label) + if previous_prefix == "O" and current_prefix == "O": + return float(biases.get("transition_bias_background_stay", 0.0)) + if previous_prefix == "O" and current_prefix in {"B", "S"}: + return float(biases.get("transition_bias_background_to_start", 0.0)) + if previous_prefix in {"E", "S"} and current_prefix == "O": + return float(biases.get("transition_bias_end_to_background", 0.0)) + if previous_prefix in {"E", "S"} and current_prefix in {"B", "S"}: + return float(biases.get("transition_bias_end_to_start", 0.0)) + if ( + previous_prefix in {"B", "I"} + and current_prefix == "I" + and current_entity == previous_entity + ): + return float(biases.get("transition_bias_inside_to_continue", 0.0)) + if ( + previous_prefix in {"B", "I"} + and current_prefix == "E" + and current_entity == previous_entity + ): + return float(biases.get("transition_bias_inside_to_end", 0.0)) + return 0.0 + + def _softmax(self, values): + if not values: + return [] + max_value = max(values) + exps = [math.exp(value - max_value) for value in values] + total = sum(exps) + if total == 0: + return [0.0 for _ in values] + return [value / total for value in exps] + + def _decode_privacy_filter_label_ids(self, logits, labels, offsets, attention_mask, schema): + """Run constrained BIOES Viterbi decoding over token logits.""" + o_label_id = labels.index("O") if "O" in labels else 0 + biases = self._get_privacy_filter_transition_biases(schema) + previous_scores = None + backpointers = [] + selected_probabilities = [] + probabilities_by_token = [] + invalid_score = -1e9 + for token_idx, token_logits in enumerate(logits): + token_logits = [float(value) for value in token_logits] + probabilities_by_token.append(self._softmax(token_logits)) + is_content_token = True + if attention_mask is not None and token_idx < len(attention_mask): + is_content_token = bool(attention_mask[token_idx]) + if offsets is not None and token_idx < len(offsets): + start, end = offsets[token_idx] + if int(start) == int(end): + is_content_token = False + if not is_content_token: + token_logits = [ + 0.0 if label_idx == o_label_id else invalid_score + for label_idx, _label in enumerate(labels) + ] + current_scores = [] + current_backpointers = [] + for label_idx, label in enumerate(labels): + emission_score = token_logits[label_idx] + if previous_scores is None: + if self._privacy_filter_transition_is_valid(None, label): + current_scores.append(emission_score) + current_backpointers.append(None) + else: + current_scores.append(invalid_score) + current_backpointers.append(None) + continue + best_score = invalid_score + best_previous_idx = 0 + for previous_idx, previous_label in enumerate(labels): + if not self._privacy_filter_transition_is_valid(previous_label, label): + continue + score = ( + previous_scores[previous_idx] + + self._privacy_filter_transition_bias(previous_label, label, biases) + + emission_score + ) + if score > best_score: + best_score = score + best_previous_idx = previous_idx + current_scores.append(best_score) + current_backpointers.append(best_previous_idx) + previous_scores = current_scores + backpointers.append(current_backpointers) + if not previous_scores: + return [], [] + terminal_scores = [ + score if self._privacy_filter_terminal_is_valid(labels[idx]) else invalid_score + for idx, score in enumerate(previous_scores) + ] + if max(terminal_scores) > invalid_score: + previous_scores = terminal_scores + best_label_idx = max(range(len(previous_scores)), key=lambda idx: previous_scores[idx]) + label_ids = [] + for token_idx in range(len(backpointers) - 1, -1, -1): + label_ids.append(best_label_idx) + previous_idx = backpointers[token_idx][best_label_idx] + best_label_idx = previous_idx if previous_idx is not None else o_label_id + label_ids.reverse() + for token_idx, label_idx in enumerate(label_ids): + probabilities = probabilities_by_token[token_idx] + selected_probabilities.append(probabilities[label_idx] if label_idx < len(probabilities) else 0.0) + return label_ids, selected_probabilities + + def _build_privacy_filter_spans(self, text, labels, label_ids, probabilities, offsets): + spans = [] + current_span = None + for token_idx, label_id in enumerate(label_ids): + if offsets is None or token_idx >= len(offsets): + continue + start, end = offsets[token_idx] + start = int(start) + end = int(end) + if start == end: + continue + label = labels[label_id] + prefix, entity = self._split_privacy_filter_label(label) + token_score = probabilities[token_idx] if token_idx < len(probabilities) else 0.0 + if prefix == "O": + if current_span is not None: + spans.append(current_span) + current_span = None + continue + if prefix == "S": + if current_span is not None: + spans.append(current_span) + current_span = None + spans.append({ + "entity_group": entity, + "entity": entity, + "score": token_score, + "word": text[start:end], + "start": start, + "end": end, + }) + continue + if prefix == "B" or current_span is None or current_span["entity_group"] != entity: + if current_span is not None: + spans.append(current_span) + current_span = { + "entity_group": entity, + "entity": entity, + "score": token_score, + "word": text[start:end], + "start": start, + "end": end, + "_scores": [token_score], + } + if prefix == "E": + current_span["_scores"].append(token_score) + current_span["end"] = end + current_span["word"] = text[current_span["start"]:current_span["end"]] + spans.append(current_span) + current_span = None + continue + current_span["end"] = end + current_span["word"] = text[current_span["start"]:current_span["end"]] + current_span["_scores"].append(token_score) + current_span["score"] = sum(current_span["_scores"]) / len(current_span["_scores"]) + if prefix == "E": + spans.append(current_span) + current_span = None + if current_span is not None: + spans.append(current_span) + for span in spans: + span.pop("_scores", None) + return spans + + def _decode_privacy_filter_onnx_outputs(self, outputs, schema, text=None, tokenizer_output=None, **kwargs): + """Decode ONNX token logits into privacy-filter span dictionaries.""" + logits = outputs.get("logits") if isinstance(outputs, dict) else None + if logits is None and isinstance(outputs, dict) and outputs: + logits = next(iter(outputs.values())) + logits = self._first_batch_item(logits) + if not isinstance(logits, list): + raise ValueError("Privacy-filter ONNX decoder expected logits output.") + offsets = self._get_tokenizer_field(tokenizer_output, "offset_mapping") + if offsets is None: + raise ValueError("Privacy-filter ONNX decoder requires tokenizer offset_mapping.") + attention_mask = self._get_tokenizer_field(tokenizer_output, "attention_mask") + labels = self._get_privacy_filter_id2label(schema) + label_ids, probabilities = self._decode_privacy_filter_label_ids( + logits=logits, + labels=labels, + offsets=offsets, + attention_mask=attention_mask, + schema=schema, + ) + return self._build_privacy_filter_spans( + text=text or "", + labels=labels, + label_ids=label_ids, + probabilities=probabilities, + offsets=offsets, + ) + def _extract_struct_payload(self, payload): """Extract the structured payload used by the privacy filter. diff --git a/extensions/serving/test_th_hf_model_base.py b/extensions/serving/test_th_hf_model_base.py index da73e9a48..9e9db7ddc 100644 --- a/extensions/serving/test_th_hf_model_base.py +++ b/extensions/serving/test_th_hf_model_base.py @@ -174,6 +174,22 @@ class _ConcreteHfModel(ThHfModelBase): pass +class _FallbackManifestHfModel(ThHfModelBase): + def _get_hf_onnx_fallback_manifest(self): + return { + "runtimes": { + "onnx_fp32": { + "runtime": "onnxruntime", + "files": ["model.onnx"], + "inline_schema": { + "inputs": [{"name": "input_ids", "dtype": "int64"}], + "outputs": [{"name": "scores"}], + }, + }, + }, + } + + class ThHfModelBaseTests(unittest.TestCase): def setUp(self): _PIPELINE_FACTORY.calls = [] @@ -367,6 +383,58 @@ def fake_download(runtime_key, runtime_config, allow_patterns): self.assertIn("model.onnx", download_calls[0][2]) self.assertNotIn("model.safetensors", download_calls[0][2]) + def test_auto_runtime_uses_subclass_onnx_fallback_manifest_when_hf_manifest_missing(self): + plugin = _FallbackManifestHfModel( + MODEL_NAME="test/model", + DEVICE="cpu", + PIPELINE_TASK="text-classification", + WARMUP_ENABLED=False, + ) + download_calls = [] + plugin._download_hf_artifact_file = ( # pylint: disable=protected-access + lambda filename: (_ for _ in ()).throw(RuntimeError("not found")) + ) + plugin._download_hf_runtime_snapshot = ( # pylint: disable=protected-access + lambda runtime_key, runtime_config, allow_patterns: download_calls.append( + (runtime_key, runtime_config, allow_patterns) + ) or "/tmp/models/test-model" + ) + plugin._build_hf_onnx_artifact_pipeline = ( # pylint: disable=protected-access + lambda model_dir, runtime_key, runtime_config, manifest: _FakePipeline(task="text-classification") + ) + + plugin.startup() + + self.assertEqual(plugin.hf_runtime, "onnx_fp32") + self.assertEqual(download_calls[0][0], "onnx_fp32") + self.assertEqual(len(_PIPELINE_FACTORY.calls), 0) + self.assertTrue( + any("using subclass ONNX fallback manifest" in message[0][0] for message in plugin.logged_messages) + ) + + def test_forced_onnx_runtime_uses_subclass_fallback_manifest_when_hf_manifest_missing(self): + plugin = _FallbackManifestHfModel( + MODEL_NAME="test/model", + DEVICE="cpu", + HF_RUNTIME="onnx", + PIPELINE_TASK="text-classification", + WARMUP_ENABLED=False, + ) + plugin._download_hf_artifact_file = ( # pylint: disable=protected-access + lambda filename: (_ for _ in ()).throw(RuntimeError("not found")) + ) + plugin._download_hf_runtime_snapshot = ( # pylint: disable=protected-access + lambda runtime_key, runtime_config, allow_patterns: "/tmp/models/test-model" + ) + plugin._build_hf_onnx_artifact_pipeline = ( # pylint: disable=protected-access + lambda model_dir, runtime_key, runtime_config, manifest: _FakePipeline(task="text-classification") + ) + + plugin.startup() + + self.assertEqual(plugin.hf_runtime, "onnx_fp32") + self.assertEqual(len(_PIPELINE_FACTORY.calls), 0) + def test_auto_runtime_keeps_transformers_pipeline_when_gpu_available(self): plugin = _ConcreteHfModel( MODEL_NAME="test/model", @@ -722,6 +790,7 @@ def fake_create_session(model_path, providers): '{"inputs":[{"name":"input_ids","dtype":"int64"},' '{"name":"attention_mask","dtype":"int64"}],' '"outputs":[{"name":"scores"}],' + '"tokenizer_kwargs":{"return_offsets_mapping":true},' '"models":{"onnx_fp32":{"path":"model.onnx"}}}' ), encoding="utf-8", @@ -773,6 +842,7 @@ def fake_create_session(model_path, providers): self.assertEqual(output_names, ["scores"]) self.assertEqual(inputs["input_ids"].dtype, "int64") self.assertEqual(fake_tokenizer.calls[-1][1]["return_tensors"], "np") + self.assertTrue(fake_tokenizer.calls[-1][1]["return_offsets_mapping"]) if __name__ == "__main__": diff --git a/extensions/serving/test_th_privacy_filter.py b/extensions/serving/test_th_privacy_filter.py index 266b308b3..417a4a8ba 100644 --- a/extensions/serving/test_th_privacy_filter.py +++ b/extensions/serving/test_th_privacy_filter.py @@ -92,6 +92,14 @@ def _payload_matches_current_serving(self, struct_payload): return False return True + def _resolve_hf_snapshot_path(self, model_dir, file_path): + path = Path(str(file_path)) + if path.is_absolute(): + raise ValueError("path must be relative") + resolved = (Path(model_dir).resolve() / path).resolve() + resolved.relative_to(Path(model_dir).resolve()) + return resolved + def _load_plugin_class(): source_path = ROOT / "extensions" / "serving" / "default_inference" / "nlp" / "th_privacy_filter.py" @@ -126,6 +134,129 @@ def test_config_pins_privacy_filter_defaults(self): "simple", ) + def test_privacy_filter_declares_local_onnx_fallback_manifest(self): + plugin = ThPrivacyFilter(MODEL_NAME="openai/privacy-filter") + + manifest = plugin._get_hf_onnx_fallback_manifest() # pylint: disable=protected-access + runtime = manifest["runtimes"]["onnx_fp32"] + + self.assertEqual(runtime["runtime"], "onnxruntime") + self.assertEqual(runtime["decoder_type"], "privacy_filter_span_decoder") + self.assertIn("onnx/model.onnx", runtime["files"]) + self.assertIn("onnx/model.onnx_data_2", runtime["files"]) + self.assertIn("viterbi_calibration.json", runtime["recommended_allow_patterns"]) + + def test_privacy_filter_does_not_declare_onnx_fallback_for_other_models(self): + plugin = ThPrivacyFilter(MODEL_NAME="other/privacy-filter") + + self.assertIsNone(plugin._get_hf_onnx_fallback_manifest()) # pylint: disable=protected-access + + def test_privacy_filter_builds_local_onnx_schema_from_hf_files(self): + plugin = ThPrivacyFilter(MODEL_NAME="openai/privacy-filter") + + from tempfile import TemporaryDirectory + + with TemporaryDirectory() as tmpdir: + model_dir = Path(tmpdir) + (model_dir / "config.json").write_text( + '{"id2label":{"0":"O","1":"S-private_email"}}', + encoding="utf-8", + ) + (model_dir / "viterbi_calibration.json").write_text( + '{"operating_points":{"default":{"biases":{"transition_bias_background_stay":0.0}}}}', + encoding="utf-8", + ) + + schema = plugin._get_hf_onnx_artifact_schema( # pylint: disable=protected-access + model_dir=str(model_dir), + manifest={}, + runtime_config={"decoder_type": "privacy_filter_span_decoder"}, + ) + + self.assertEqual(schema["output_order"], ["logits"]) + self.assertEqual(schema["id2label"]["1"], "S-private_email") + self.assertTrue(schema["tokenizer_kwargs"]["return_offsets_mapping"]) + self.assertIn("viterbi_calibration", schema) + + def test_privacy_filter_local_onnx_decoder_emits_spans_from_bioes_logits(self): + plugin = ThPrivacyFilter(MODEL_NAME="openai/privacy-filter") + schema = { + "id2label": { + "0": "O", + "1": "B-private_email", + "2": "I-private_email", + "3": "E-private_email", + "4": "S-private_email", + }, + "viterbi_calibration": { + "operating_points": { + "default": { + "biases": {}, + }, + }, + }, + } + outputs = { + "logits": [[ + [8.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 8.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 8.0, 0.0], + [8.0, 0.0, 0.0, 0.0, 0.0], + ]], + } + tokenizer_output = { + "offset_mapping": [[[0, 0], [0, 5], [5, 17], [0, 0]]], + "attention_mask": [[1, 1, 1, 1]], + } + + spans = plugin._decode_privacy_filter_onnx_outputs( # pylint: disable=protected-access + outputs=outputs, + schema=schema, + text="alice@example.com", + tokenizer_output=tokenizer_output, + ) + + self.assertEqual(len(spans), 1) + self.assertEqual(spans[0]["entity_group"], "private_email") + self.assertEqual(spans[0]["word"], "alice@example.com") + self.assertEqual(spans[0]["start"], 0) + self.assertEqual(spans[0]["end"], 17) + self.assertGreater(spans[0]["score"], 0.9) + + def test_privacy_filter_viterbi_decoder_rejects_invalid_terminal_inside_label(self): + plugin = ThPrivacyFilter(MODEL_NAME="openai/privacy-filter") + schema = { + "id2label": { + "0": "O", + "1": "B-private_email", + "2": "I-private_email", + "3": "E-private_email", + "4": "S-private_email", + }, + "viterbi_calibration": {}, + } + outputs = { + "logits": [[ + [0.0, 8.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 8.0, 7.5, 0.0], + ]], + } + tokenizer_output = { + "offset_mapping": [[[0, 5], [5, 17]]], + "attention_mask": [[1, 1]], + } + + spans = plugin._decode_privacy_filter_onnx_outputs( # pylint: disable=protected-access + outputs=outputs, + schema=schema, + text="alice@example.com", + tokenizer_output=tokenizer_output, + ) + + self.assertEqual(len(spans), 1) + self.assertEqual(spans[0]["entity_group"], "private_email") + self.assertEqual(spans[0]["end"], 17) + def test_post_process_emits_redaction_friendly_fields(self): plugin = ThPrivacyFilter() From fb95019730c45f4f0d6b1bbcc5016bde1da6d104 Mon Sep 17 00:00:00 2001 From: Codex Date: Mon, 11 May 2026 18:24:51 +0300 Subject: [PATCH 05/11] fix: support hf snapshot symlink artifacts What changed: - Keep HF artifact path traversal checks lexical so valid snapshot symlinks into the cache blob store are accepted. - Merge exact manifest files with recommended ONNX allow patterns after filtering broad or framework-weight downloads. - Add regression coverage for both behaviors. Why: - Live PR image validation showed Sentinel and privacy-filter ONNX startup falling back because valid HF snapshot files were rejected as escaping the snapshot. --- .../default_inference/nlp/th_hf_model_base.py | 32 ++++++++++++------- extensions/serving/test_th_hf_model_base.py | 31 +++++++++++++++++- 2 files changed, 50 insertions(+), 13 deletions(-) diff --git a/extensions/serving/default_inference/nlp/th_hf_model_base.py b/extensions/serving/default_inference/nlp/th_hf_model_base.py index 951469a7b..50b4aac5f 100644 --- a/extensions/serving/default_inference/nlp/th_hf_model_base.py +++ b/extensions/serving/default_inference/nlp/th_hf_model_base.py @@ -9,7 +9,7 @@ import importlib.util import inspect import json -from pathlib import Path +from pathlib import Path, PurePosixPath import torch as th @@ -589,11 +589,19 @@ def _build_hf_runtime_allow_patterns(self, runtime_config): if configured_patterns: patterns = configured_patterns else: - patterns = runtime_config.get("recommended_allow_patterns") or runtime_config.get("files") + patterns = [] + for source_patterns in ( + runtime_config.get("recommended_allow_patterns"), + runtime_config.get("files"), + [runtime_config.get("model")] if runtime_config.get("model") else None, + ): + if not source_patterns: + continue + if isinstance(source_patterns, str): + source_patterns = [source_patterns] + patterns.extend(source_patterns) if not patterns: - model_file = runtime_config.get("model") patterns = [ - model_file, "*.onnx", "**/*.onnx", "*.json", @@ -637,16 +645,16 @@ def _runtime_file_list(self, runtime_config): def _resolve_hf_snapshot_path(self, model_dir, file_path): """Resolve a manifest path while keeping it inside the downloaded snapshot.""" - path = Path(str(file_path)) + raw_path = str(file_path) + path = PurePosixPath(raw_path) if path.is_absolute(): raise ValueError(f"HF artifact path {file_path!r} must be relative to the model snapshot.") - snapshot_dir = Path(model_dir).resolve() - resolved_path = (snapshot_dir / path).resolve() - try: - resolved_path.relative_to(snapshot_dir) - except ValueError as exc: - raise ValueError(f"HF artifact path {file_path!r} escapes the model snapshot.") from exc - return resolved_path + if ".." in path.parts: + raise ValueError(f"HF artifact path {file_path!r} escapes the model snapshot.") + # Hugging Face snapshots commonly symlink files into the shared cache + # blob store. A resolved containment check would reject valid snapshots, + # so keep the traversal guard lexical and return the snapshot path itself. + return Path(model_dir) / Path(*path.parts) def _first_manifest_file_with_suffix(self, runtime_config, suffixes): """Return the first exact manifest file path ending with any suffix.""" diff --git a/extensions/serving/test_th_hf_model_base.py b/extensions/serving/test_th_hf_model_base.py index 9e9db7ddc..0d7569a38 100644 --- a/extensions/serving/test_th_hf_model_base.py +++ b/extensions/serving/test_th_hf_model_base.py @@ -321,6 +321,10 @@ def test_onnx_allow_patterns_reject_framework_weights_and_broad_downloads(self): plugin = _ConcreteHfModel(MODEL_NAME="test/model") allow_patterns = plugin._build_hf_runtime_allow_patterns({ # pylint: disable=protected-access + "recommended_allow_patterns": [ + "onnx/*", + "schema.json", + ], "files": [ "*", "**/*", @@ -334,9 +338,13 @@ def test_onnx_allow_patterns_reject_framework_weights_and_broad_downloads(self): "tf_model.h5", "flax_model.msgpack", ], + "model": "model.onnx", }) - self.assertEqual(allow_patterns, ["model.onnx", "tokenizer.json", "contract.py"]) + self.assertEqual( + allow_patterns, + ["schema.json", "model.onnx", "tokenizer.json", "contract.py"], + ) def test_auto_runtime_uses_onnx_artifact_on_cpu_only(self): manifest = { @@ -713,6 +721,27 @@ def test_hf_artifact_paths_must_stay_inside_snapshot(self): schema={}, ) + def test_hf_artifact_paths_allow_snapshot_symlink_targets_outside_snapshot(self): + plugin = _ConcreteHfModel(MODEL_NAME="test/model") + plugin.hf_runtime = "onnx_fp32" + + with TemporaryDirectory() as tmpdir: + root_dir = Path(tmpdir) + model_dir = root_dir / "snapshot" + blob_dir = root_dir / "blobs" + model_dir.mkdir() + blob_dir.mkdir() + (blob_dir / "schema.json").write_text('{"inputs": []}', encoding="utf-8") + (model_dir / "schema.json").symlink_to(blob_dir / "schema.json") + + schema = plugin._load_hf_schema( # pylint: disable=protected-access + model_dir=str(model_dir), + manifest={}, + runtime_config={"schema": "schema.json"}, + ) + + self.assertEqual(schema, {"inputs": []}) + def test_onnx_tokenizer_remote_code_requires_global_trust_remote_code(self): calls = [] From d27b9e8429756293748c1de1f1e73afa1cfd4d1e Mon Sep 17 00:00:00 2001 From: Codex Date: Mon, 11 May 2026 20:15:58 +0300 Subject: [PATCH 06/11] fix: allow legacy trusted onnx decoders What changed: - Temporarily allow ONNX artifact decoders without runtime-level trust_remote_code to inherit global TRUST_REMOTE_CODE=True. - Keep explicit runtime trust_remote_code=False as a hard block. - Add a TODO documenting the security concern and declarative decoder replacement path. Why: - The current Sentinel ONNX artifact predates runtime-level trust metadata and uses a reviewed contract decoder, so it needs a compatibility path until the artifact moves to declarative decoding. --- .../default_inference/nlp/th_hf_model_base.py | 13 ++++++++++--- extensions/serving/test_th_hf_model_base.py | 17 +++++++++-------- 2 files changed, 19 insertions(+), 11 deletions(-) diff --git a/extensions/serving/default_inference/nlp/th_hf_model_base.py b/extensions/serving/default_inference/nlp/th_hf_model_base.py index 50b4aac5f..a0605870c 100644 --- a/extensions/serving/default_inference/nlp/th_hf_model_base.py +++ b/extensions/serving/default_inference/nlp/th_hf_model_base.py @@ -695,7 +695,14 @@ def _load_hf_schema(self, model_dir, manifest, runtime_config): def _runtime_allows_remote_code(self, manifest, runtime_config): """Return whether the selected runtime explicitly allows Python artifact code.""" - return isinstance(runtime_config, dict) and bool(runtime_config.get("trust_remote_code")) + if isinstance(runtime_config, dict) and "trust_remote_code" in runtime_config: + return bool(runtime_config.get("trust_remote_code")) + # TODO: replace this temporary compatibility path with declarative ONNX + # decoders (for example multihead_classification_v1) so artifact Python + # does not execute unless each runtime explicitly opts into remote code. + # This currently preserves legacy Sentinel ONNX artifacts whose decoder is + # a reviewed contract file but whose manifest predates runtime-level trust. + return bool(self.cfg_trust_remote_code) def _load_hf_contract_decoder(self, model_dir, manifest, runtime_config): """Load the artifact decoder function declared by the selected HF runtime.""" @@ -713,8 +720,8 @@ def _load_hf_contract_decoder(self, model_dir, manifest, runtime_config): runtime_config=runtime_config, ): raise ValueError( - "HF ONNX artifact decoder requires global TRUST_REMOTE_CODE=True and runtime " - f"trust_remote_code=True because it executes Python code from {decoder_path}." + "HF ONNX artifact decoder requires TRUST_REMOTE_CODE=True and no explicit " + f"runtime trust_remote_code=False because it executes Python code from {decoder_path}." ) module_name = f"hf_artifact_contract_{abs(hash(str(decoder_path)))}" spec = importlib.util.spec_from_file_location(module_name, decoder_path) diff --git a/extensions/serving/test_th_hf_model_base.py b/extensions/serving/test_th_hf_model_base.py index 0d7569a38..c5f8d30d2 100644 --- a/extensions/serving/test_th_hf_model_base.py +++ b/extensions/serving/test_th_hf_model_base.py @@ -654,14 +654,14 @@ def test_hf_contract_decoder_requires_runtime_trust_remote_code(self): encoding="utf-8", ) - with self.assertRaisesRegex(ValueError, "runtime trust_remote_code=True"): + with self.assertRaisesRegex(ValueError, "runtime trust_remote_code=False"): plugin._load_hf_contract_decoder( # pylint: disable=protected-access model_dir=str(model_dir), manifest={}, runtime_config={"decoder": "contract.py", "trust_remote_code": False}, ) - def test_top_level_manifest_trust_remote_code_does_not_enable_runtime_decoder(self): + def test_missing_runtime_trust_remote_code_temporarily_inherits_global_trust(self): plugin = _ConcreteHfModel( MODEL_NAME="test/model", DEVICE="cpu", @@ -676,12 +676,13 @@ def test_top_level_manifest_trust_remote_code_does_not_enable_runtime_decoder(se encoding="utf-8", ) - with self.assertRaisesRegex(ValueError, "runtime trust_remote_code=True"): - plugin._load_hf_contract_decoder( # pylint: disable=protected-access - model_dir=str(model_dir), - manifest={"trust_remote_code": True}, - runtime_config={"decoder": "contract.py"}, - ) + decoder = plugin._load_hf_contract_decoder( # pylint: disable=protected-access + model_dir=str(model_dir), + manifest={}, + runtime_config={"decoder": "contract.py"}, + ) + + self.assertEqual(decoder({"ok": True}, {}), {"ok": True}) def test_hf_artifact_paths_must_stay_inside_snapshot(self): plugin = _ConcreteHfModel(MODEL_NAME="test/model") From 1026dcc5d189e6fc2265ff0bf28f4a12c2ecf3d3 Mon Sep 17 00:00:00 2001 From: Codex Date: Mon, 11 May 2026 20:20:52 +0300 Subject: [PATCH 07/11] fix: let global trust gate legacy onnx decoders What changed: - Split ONNX remote-code trust between tokenizer/model loading and decoder execution. - Keep tokenizer/model loading tied to runtime-level trust_remote_code. - Temporarily allow Python decoder execution when global TRUST_REMOTE_CODE=True, even for legacy runtimes that mark ONNX trust_remote_code=False. Why: - Current Sentinel ONNX artifacts use trust_remote_code=False for tokenizer/model loading but still declare a Python contract decoder. This keeps the temporary compatibility path narrow until declarative decoding replaces it. --- .../default_inference/nlp/th_hf_model_base.py | 17 +++++++++++------ extensions/serving/test_th_hf_model_base.py | 15 ++++++++------- 2 files changed, 19 insertions(+), 13 deletions(-) diff --git a/extensions/serving/default_inference/nlp/th_hf_model_base.py b/extensions/serving/default_inference/nlp/th_hf_model_base.py index a0605870c..55d4b6d77 100644 --- a/extensions/serving/default_inference/nlp/th_hf_model_base.py +++ b/extensions/serving/default_inference/nlp/th_hf_model_base.py @@ -695,13 +695,18 @@ def _load_hf_schema(self, model_dir, manifest, runtime_config): def _runtime_allows_remote_code(self, manifest, runtime_config): """Return whether the selected runtime explicitly allows Python artifact code.""" - if isinstance(runtime_config, dict) and "trust_remote_code" in runtime_config: - return bool(runtime_config.get("trust_remote_code")) + return isinstance(runtime_config, dict) and bool(runtime_config.get("trust_remote_code")) + + def _runtime_allows_decoder_remote_code(self, manifest, runtime_config): + """Return whether the selected runtime may execute Python decoder code.""" # TODO: replace this temporary compatibility path with declarative ONNX # decoders (for example multihead_classification_v1) so artifact Python # does not execute unless each runtime explicitly opts into remote code. # This currently preserves legacy Sentinel ONNX artifacts whose decoder is - # a reviewed contract file but whose manifest predates runtime-level trust. + # a reviewed contract file but whose manifest marks the ONNX runtime as + # trust_remote_code=False because tokenizer/model loading does not need HF + # remote code. The decoder still executes Python, so this is intentionally + # gated by global TRUST_REMOTE_CODE and should be removed after repackaging. return bool(self.cfg_trust_remote_code) def _load_hf_contract_decoder(self, model_dir, manifest, runtime_config): @@ -715,13 +720,13 @@ def _load_hf_contract_decoder(self, model_dir, manifest, runtime_config): ) if decoder_path is None or not decoder_path.exists(): raise ValueError(f"HF runtime {self.hf_runtime} does not declare a usable contract decoder.") - if not bool(self.cfg_trust_remote_code) or not self._runtime_allows_remote_code( + if not bool(self.cfg_trust_remote_code) or not self._runtime_allows_decoder_remote_code( manifest=manifest, runtime_config=runtime_config, ): raise ValueError( - "HF ONNX artifact decoder requires TRUST_REMOTE_CODE=True and no explicit " - f"runtime trust_remote_code=False because it executes Python code from {decoder_path}." + "HF ONNX artifact decoder requires TRUST_REMOTE_CODE=True because it executes " + f"Python code from {decoder_path}." ) module_name = f"hf_artifact_contract_{abs(hash(str(decoder_path)))}" spec = importlib.util.spec_from_file_location(module_name, decoder_path) diff --git a/extensions/serving/test_th_hf_model_base.py b/extensions/serving/test_th_hf_model_base.py index c5f8d30d2..ff017b5ed 100644 --- a/extensions/serving/test_th_hf_model_base.py +++ b/extensions/serving/test_th_hf_model_base.py @@ -639,7 +639,7 @@ def test_hf_contract_decoder_requires_global_trust_remote_code(self): runtime_config={"decoder": "contract.py"}, ) - def test_hf_contract_decoder_requires_runtime_trust_remote_code(self): + def test_runtime_trust_remote_code_false_temporarily_inherits_global_trust(self): plugin = _ConcreteHfModel( MODEL_NAME="test/model", DEVICE="cpu", @@ -654,12 +654,13 @@ def test_hf_contract_decoder_requires_runtime_trust_remote_code(self): encoding="utf-8", ) - with self.assertRaisesRegex(ValueError, "runtime trust_remote_code=False"): - plugin._load_hf_contract_decoder( # pylint: disable=protected-access - model_dir=str(model_dir), - manifest={}, - runtime_config={"decoder": "contract.py", "trust_remote_code": False}, - ) + decoder = plugin._load_hf_contract_decoder( # pylint: disable=protected-access + model_dir=str(model_dir), + manifest={}, + runtime_config={"decoder": "contract.py", "trust_remote_code": False}, + ) + + self.assertEqual(decoder({"ok": True}, {}), {"ok": True}) def test_missing_runtime_trust_remote_code_temporarily_inherits_global_trust(self): plugin = _ConcreteHfModel( From 550fbadec10057218fcc2d40bbc05dd46c2bb672 Mon Sep 17 00:00:00 2001 From: Codex Date: Mon, 11 May 2026 20:30:20 +0300 Subject: [PATCH 08/11] fix: materialize hf onnx external data What changed: - Prepare HF ONNX artifacts in an edge-node-owned materialized cache before creating ONNX Runtime sessions. - Hardlink resolved HF cache blobs when possible and copy as fallback. - Preserve runtime relative layout for .onnx and external data sidecars. - Add regression coverage for symlinked external data files. Why: - ONNX Runtime rejects HF snapshot symlinks for external data because resolved sidecars can escape the model directory. --- .../default_inference/nlp/th_hf_model_base.py | 68 ++++++++++++++++++ extensions/serving/test_th_hf_model_base.py | 72 +++++++++++++++++++ 2 files changed, 140 insertions(+) diff --git a/extensions/serving/default_inference/nlp/th_hf_model_base.py b/extensions/serving/default_inference/nlp/th_hf_model_base.py index 55d4b6d77..5b530c42c 100644 --- a/extensions/serving/default_inference/nlp/th_hf_model_base.py +++ b/extensions/serving/default_inference/nlp/th_hf_model_base.py @@ -9,6 +9,8 @@ import importlib.util import inspect import json +import os +import shutil from pathlib import Path, PurePosixPath import torch as th @@ -798,6 +800,65 @@ def _resolve_hf_onnx_model_path(self, model_dir, runtime_key, runtime_config, sc return self._resolve_hf_snapshot_path(model_dir=model_dir, file_path=model_file) raise ValueError(f"HF runtime {runtime_key} does not declare an ONNX model file.") + def _hf_onnx_materialized_root(self, model_dir, runtime_key): + """Return the local directory used for ORT-compatible ONNX artifacts.""" + snapshot_name = Path(model_dir).name + model_key = str(self.get_model_name()).replace("/", "--") + return Path(self.cache_dir) / "_onnx_materialized" / model_key / snapshot_name / str(runtime_key) + + def _materialize_hf_onnx_file(self, source_path, destination_path): + """Materialize one HF snapshot file as a real local file or hardlink.""" + source_path = Path(source_path) + destination_path = Path(destination_path) + destination_path.parent.mkdir(parents=True, exist_ok=True) + resolved_source = source_path.resolve() + if destination_path.exists(): + try: + if not destination_path.is_symlink() and destination_path.stat().st_size == resolved_source.stat().st_size: + return + except OSError: + pass + destination_path.unlink() + try: + os.link(resolved_source, destination_path) + except OSError: + shutil.copy2(resolved_source, destination_path) + return + + def _materialize_hf_onnx_artifact(self, model_dir, runtime_key, runtime_config, schema, model_path): + """Prepare an ONNX artifact outside the HF symlink snapshot for ORT.""" + root_dir = self._hf_onnx_materialized_root(model_dir=model_dir, runtime_key=runtime_key) + materialized_paths = [] + model_snapshot_dir = Path(model_dir) + model_relative_path = Path(model_path).relative_to(model_snapshot_dir) + file_paths = [] + for file_path in self._runtime_file_list(runtime_config): + file_path = str(file_path) + if file_path.endswith(".onnx") or ".onnx_data" in file_path or file_path.endswith(".onnx.data"): + file_paths.append(file_path) + if str(model_relative_path) not in file_paths: + file_paths.append(str(model_relative_path)) + for file_path in file_paths: + source_path = self._resolve_hf_snapshot_path(model_dir=model_dir, file_path=file_path) + if not source_path.exists(): + continue + destination_path = root_dir / Path(file_path) + self._materialize_hf_onnx_file( + source_path=source_path, + destination_path=destination_path, + ) + materialized_paths.append(destination_path) + materialized_model_path = root_dir / model_relative_path + if not materialized_model_path.exists(): + raise ValueError(f"Could not materialize ONNX model file {model_relative_path!s}.") + if materialized_paths: + self.P( + f"Materialized HF ONNX artifact {runtime_key} with {len(materialized_paths)} file(s) " + f"under {root_dir}.", + color="y", + ) + return materialized_model_path + def _resolve_hf_tokenizer_dir(self, model_dir, manifest, runtime_config, schema): """Resolve tokenizer directory for the selected artifact runtime.""" tokenizer_dir = None @@ -855,6 +916,13 @@ def _build_hf_onnx_artifact_pipeline(self, model_dir, runtime_key, runtime_confi runtime_config=runtime_config, schema=schema, ) + model_path = self._materialize_hf_onnx_artifact( + model_dir=model_dir, + runtime_key=runtime_key, + runtime_config=runtime_config, + schema=schema, + model_path=model_path, + ) provider = runtime_config.get("provider") or "CPUExecutionProvider" providers = runtime_config.get("providers") or [provider] session = self._create_hf_onnx_session( diff --git a/extensions/serving/test_th_hf_model_base.py b/extensions/serving/test_th_hf_model_base.py index ff017b5ed..5c2077dce 100644 --- a/extensions/serving/test_th_hf_model_base.py +++ b/extensions/serving/test_th_hf_model_base.py @@ -875,6 +875,78 @@ def fake_create_session(model_path, providers): self.assertEqual(fake_tokenizer.calls[-1][1]["return_tensors"], "np") self.assertTrue(fake_tokenizer.calls[-1][1]["return_offsets_mapping"]) + def test_onnx_artifact_pipeline_materializes_symlinked_external_data(self): + plugin = _ConcreteHfModel( + MODEL_NAME="test/model", + DEVICE="cpu", + TRUST_REMOTE_CODE=True, + PIPELINE_TASK="text-classification", + WARMUP_ENABLED=False, + ) + fake_tokenizer = _FakeTokenizer() + fake_session = _FakeOrtSession() + created_sessions = [] + plugin._load_hf_onnx_tokenizer = ( # pylint: disable=protected-access + lambda model_dir, runtime_config, manifest=None: fake_tokenizer + ) + plugin._create_hf_onnx_session = ( # pylint: disable=protected-access + lambda model_path, providers: created_sessions.append((Path(model_path), providers)) or fake_session + ) + + with TemporaryDirectory() as tmpdir: + root_dir = Path(tmpdir) + model_dir = root_dir / "snapshot" + blob_dir = root_dir / "blobs" + cache_dir = root_dir / "models-cache" + onnx_dir = model_dir / "onnx" + blob_dir.mkdir(parents=True) + onnx_dir.mkdir(parents=True) + cache_dir.mkdir() + plugin.log.get_models_folder = lambda: str(cache_dir) + (blob_dir / "model.onnx").write_text("onnx", encoding="utf-8") + (blob_dir / "model.onnx_data").write_text("weights", encoding="utf-8") + (onnx_dir / "model.onnx").symlink_to(blob_dir / "model.onnx") + (onnx_dir / "model.onnx_data").symlink_to(blob_dir / "model.onnx_data") + (model_dir / "schema.json").write_text( + '{"outputs":[{"name":"scores"}],"models":{"onnx_fp32":{"path":"onnx/model.onnx"}}}', + encoding="utf-8", + ) + (model_dir / "contract.py").write_text( + "def decode_outputs(outputs, schema, **kwargs):\n return outputs\n", + encoding="utf-8", + ) + manifest = { + "pipeline_task": "text-classification", + "runtimes": { + "onnx_fp32": { + "runtime": "onnxruntime", + "trust_remote_code": False, + "files": [ + "onnx/model.onnx", + "onnx/model.onnx_data", + "schema.json", + "contract.py", + ], + } + }, + } + + plugin._build_hf_onnx_artifact_pipeline( # pylint: disable=protected-access + model_dir=str(model_dir), + runtime_key="onnx_fp32", + runtime_config=manifest["runtimes"]["onnx_fp32"], + manifest=manifest, + ) + + materialized_model_path = created_sessions[0][0] + materialized_sidecar_path = materialized_model_path.parent / "model.onnx_data" + self.assertTrue(materialized_model_path.exists()) + self.assertTrue(materialized_sidecar_path.exists()) + self.assertFalse(materialized_model_path.is_symlink()) + self.assertFalse(materialized_sidecar_path.is_symlink()) + self.assertEqual(materialized_sidecar_path.read_text(encoding="utf-8"), "weights") + self.assertIn("_onnx_materialized", str(materialized_model_path)) + if __name__ == "__main__": unittest.main() From dd008107ffd231a225b2166688971ee8f857f1b3 Mon Sep 17 00:00:00 2001 From: Cristi Bleotiu Date: Wed, 20 May 2026 17:35:11 +0300 Subject: [PATCH 09/11] fix: clone worker app configured branch What changed: - Use the configured WorkerAppRunner branch when building git clone commands. - Add regression coverage for branch-aware and branchless clone behavior. Why: - WorkerAppRunner already monitored VCS_DATA.BRANCH for updates, but initial setup could still clone the repository default branch. --- .../container_apps/test_worker_app_runner.py | 39 +++++++++++++++++++ .../container_apps/worker_app_runner.py | 12 +++++- 2 files changed, 50 insertions(+), 1 deletion(-) diff --git a/extensions/business/container_apps/test_worker_app_runner.py b/extensions/business/container_apps/test_worker_app_runner.py index 8512b0b15..c024cd3ba 100644 --- a/extensions/business/container_apps/test_worker_app_runner.py +++ b/extensions/business/container_apps/test_worker_app_runner.py @@ -89,7 +89,23 @@ def _install_dummy_base_plugin(): sys.modules['naeural_core.business.base.web_app.base_tunnel_engine_plugin'] = base_tunnel_mod +def _install_dummy_docker_module(): + docker_mod = types.ModuleType('docker') + docker_types_mod = types.ModuleType('docker.types') + + class _DummyDeviceRequest: + def __init__(self, *args, **kwargs): + self.args = args + self.kwargs = kwargs + + docker_types_mod.DeviceRequest = _DummyDeviceRequest + docker_mod.types = docker_types_mod + sys.modules.setdefault('docker', docker_mod) + sys.modules.setdefault('docker.types', docker_types_mod) + + _install_dummy_base_plugin() +_install_dummy_docker_module() from extensions.business.container_apps.worker_app_runner import WorkerAppRunnerPlugin from extensions.business.container_apps import container_utils @@ -172,6 +188,29 @@ def test_configure_repo_url_token_only(self): plugin._configure_repo_url() self.assertEqual(plugin.repo_url, "https://token@github.com/ratio1/demo.git") + def test_git_clone_command_uses_configured_branch(self): + """Test repository setup clones the same branch monitored for updates.""" + plugin = self._make_plugin() + plugin.repo_url = "https://github.com/ratio1/demo.git" + plugin.branch = "develop" + + command = plugin._build_git_clone_command("/app") + + self.assertEqual( + command, + "git clone --branch develop --single-branch https://github.com/ratio1/demo.git /app", + ) + + def test_git_clone_command_allows_branchless_clone(self): + """Test repository setup keeps the old default-branch behavior if branch is absent.""" + plugin = self._make_plugin() + plugin.repo_url = "https://github.com/ratio1/demo.git" + plugin.branch = None + + command = plugin._build_git_clone_command("/app") + + self.assertEqual(command, "git clone https://github.com/ratio1/demo.git /app") + def test_check_image_updates_respects_autoupdate_flag(self): """Test that image update checks respect the AUTOUPDATE flag.""" plugin = self._make_plugin() diff --git a/extensions/business/container_apps/worker_app_runner.py b/extensions/business/container_apps/worker_app_runner.py index 04366d2e1..6a0fe1166 100644 --- a/extensions/business/container_apps/worker_app_runner.py +++ b/extensions/business/container_apps/worker_app_runner.py @@ -10,6 +10,8 @@ - Streams logs and manages tunnel lifecycle through the base runner """ +import shlex + import requests from urllib.parse import urlsplit @@ -211,6 +213,14 @@ def _build_git_bootstrap_command(self): inner_block = " ".join(f"{part};" for part in checks) + " fi;" return f"if ! command -v git >/dev/null 2>&1; then {inner_block} fi" + def _build_git_clone_command(self, repo_path): + repo_url = shlex.quote(self.repo_url) + repo_path = shlex.quote(repo_path) + if self.branch: + branch = shlex.quote(self.branch) + return f"git clone --branch {branch} --single-branch {repo_url} {repo_path}" + return f"git clone {repo_url} {repo_path}" + def _collect_exec_commands(self): """ Collect commands to execute inside container. @@ -241,7 +251,7 @@ def _collect_exec_commands(self): ] if self.cfg_setup_repo: commands.append(f"rm -rf {repo_path}") - commands.append(f"git clone {self.repo_url} {repo_path}") + commands.append(self._build_git_clone_command(repo_path)) # endif # last_commit = commit commands.extend([f"cd {repo_path} && {cmd}" for cmd in base_commands]) From 586605a3a1397d59db651a53f1da6a1bfce53357 Mon Sep 17 00:00:00 2001 From: Codex Date: Thu, 21 May 2026 12:49:49 +0300 Subject: [PATCH 10/11] fix: prefer quantized privacy filter onnx runtime What changed: - Default privacy-filter ONNX selection to the quantized runtime and declare public ONNX variants. - Expose serving timing metadata through serving outputs and inference API responses. - Add configurable multi-text warmup support and opt-in HF runtime profiling coverage. - Document branch-limited WorkerAppRunner clone behavior. Why: - Existing deployed privacy-filter configs should get the lower-latency CPU runtime without per-instance config changes, while preserving explicit runtime overrides and benchmark reproducibility. --- .../container_apps/worker_app_runner.py | 2 + .../privacy_filter_inference_api.py | 2 + .../test_privacy_filter_inference_api.py | 13 + .../test_text_classifier_inference_api.py | 13 + .../text_classifier_inference_api.py | 2 + .../default_inference/nlp/th_hf_model_base.py | 92 ++- .../nlp/th_privacy_filter.py | 129 +++- .../nlp/th_text_classifier.py | 22 +- extensions/serving/test_th_hf_model_base.py | 20 + .../serving/test_th_hf_runtime_profile.py | 574 ++++++++++++++++++ extensions/serving/test_th_privacy_filter.py | 78 +++ extensions/serving/test_th_text_classifier.py | 16 + 12 files changed, 915 insertions(+), 48 deletions(-) create mode 100644 extensions/serving/test_th_hf_runtime_profile.py diff --git a/extensions/business/container_apps/worker_app_runner.py b/extensions/business/container_apps/worker_app_runner.py index 6a0fe1166..aafacf9eb 100644 --- a/extensions/business/container_apps/worker_app_runner.py +++ b/extensions/business/container_apps/worker_app_runner.py @@ -218,6 +218,8 @@ def _build_git_clone_command(self, repo_path): repo_path = shlex.quote(repo_path) if self.branch: branch = shlex.quote(self.branch) + # Fetch only the configured branch; this keeps startup clones smaller + # and avoids accidentally running code from a different branch. return f"git clone --branch {branch} --single-branch {repo_url} {repo_path}" return f"git clone {repo_url} {repo_path}" diff --git a/extensions/business/edge_inference_api/privacy_filter_inference_api.py b/extensions/business/edge_inference_api/privacy_filter_inference_api.py index 3cf0bd727..eeb836ccb 100644 --- a/extensions/business/edge_inference_api/privacy_filter_inference_api.py +++ b/extensions/business/edge_inference_api/privacy_filter_inference_api.py @@ -105,4 +105,6 @@ def _build_result_from_inference( # pylint: disable=arguments-differ result_payload["hf_runtime"] = inference["HF_RUNTIME"] if "RUNTIME" in inference: result_payload["runtime"] = inference["RUNTIME"] + if "SERVING_TIMINGS" in inference: + result_payload["serving_timings"] = inference["SERVING_TIMINGS"] return result_payload diff --git a/extensions/business/edge_inference_api/test_privacy_filter_inference_api.py b/extensions/business/edge_inference_api/test_privacy_filter_inference_api.py index ecb3bec75..2bc03a721 100644 --- a/extensions/business/edge_inference_api/test_privacy_filter_inference_api.py +++ b/extensions/business/edge_inference_api/test_privacy_filter_inference_api.py @@ -60,6 +60,11 @@ def test_build_result_from_inference_uses_findings_key(self): "MODEL_REVISION": "rev-privacy", "HF_RUNTIME": "pt", "RUNTIME": "transformers", + "SERVING_TIMINGS": { + "model_pipeline_elapsed_s": 0.456, + "active_payloads": 1, + "batch_size": 1, + }, }, metadata={}, request_data={"metadata": {}, "parameters": {"text": "example text"}}, @@ -86,6 +91,14 @@ def test_build_result_from_inference_uses_findings_key(self): self.assertEqual(result_payload["model_revision"], "rev-privacy") self.assertEqual(result_payload["hf_runtime"], "pt") self.assertEqual(result_payload["runtime"], "transformers") + self.assertEqual( + result_payload["serving_timings"], + { + "model_pipeline_elapsed_s": 0.456, + "active_payloads": 1, + "batch_size": 1, + }, + ) if __name__ == "__main__": diff --git a/extensions/business/edge_inference_api/test_text_classifier_inference_api.py b/extensions/business/edge_inference_api/test_text_classifier_inference_api.py index aa945ea3f..68add73c8 100644 --- a/extensions/business/edge_inference_api/test_text_classifier_inference_api.py +++ b/extensions/business/edge_inference_api/test_text_classifier_inference_api.py @@ -165,6 +165,11 @@ def test_build_result_from_inference_preserves_runtime_model_metadata(self): "MODEL_VERSION": "2026.05.09", "HF_RUNTIME": "onnx_fp32", "RUNTIME": "onnxruntime", + "SERVING_TIMINGS": { + "model_pipeline_elapsed_s": 0.123, + "active_payloads": 1, + "batch_size": 1, + }, }, metadata={}, request_data={"metadata": {}, "parameters": {"text": "example text"}}, @@ -178,6 +183,14 @@ def test_build_result_from_inference_preserves_runtime_model_metadata(self): self.assertEqual(result_payload["model_version"], "2026.05.09") self.assertEqual(result_payload["hf_runtime"], "onnx_fp32") self.assertEqual(result_payload["runtime"], "onnxruntime") + self.assertEqual( + result_payload["serving_timings"], + { + "model_pipeline_elapsed_s": 0.123, + "active_payloads": 1, + "batch_size": 1, + }, + ) def test_handle_inferences_falls_back_to_payload_request_id(self): plugin = TextClassifierInferenceApiPlugin() diff --git a/extensions/business/edge_inference_api/text_classifier_inference_api.py b/extensions/business/edge_inference_api/text_classifier_inference_api.py index 2de5aa1ae..106b31b8f 100644 --- a/extensions/business/edge_inference_api/text_classifier_inference_api.py +++ b/extensions/business/edge_inference_api/text_classifier_inference_api.py @@ -415,6 +415,8 @@ def _build_result_from_inference( result_payload["hf_runtime"] = inference["HF_RUNTIME"] if "RUNTIME" in inference: result_payload["runtime"] = inference["RUNTIME"] + if "SERVING_TIMINGS" in inference: + result_payload["serving_timings"] = inference["SERVING_TIMINGS"] return result_payload def handle_inference_for_request( diff --git a/extensions/serving/default_inference/nlp/th_hf_model_base.py b/extensions/serving/default_inference/nlp/th_hf_model_base.py index 5b530c42c..53ff5c5c1 100644 --- a/extensions/serving/default_inference/nlp/th_hf_model_base.py +++ b/extensions/serving/default_inference/nlp/th_hf_model_base.py @@ -12,6 +12,7 @@ import os import shutil from pathlib import Path, PurePosixPath +from time import perf_counter import torch as th @@ -48,6 +49,7 @@ "INFERENCE_KWARGS": {}, "WARMUP_ENABLED": True, "WARMUP_TEXT": "Warmup request.", + "WARMUP_TEXTS": None, "WARMUP_INFERENCE_KWARGS": {}, "RUNS_ON_EMPTY_INPUT": False, "VALIDATION_RULES": { @@ -81,18 +83,42 @@ def __init__( self.task = task self.framework = "onnxruntime" self.max_length = max_length + self.last_call_timings = None return def __call__(self, texts, **kwargs): """Run one or more text inputs through the ONNX artifact.""" is_single_text = isinstance(texts, str) text_items = [texts] if is_single_text else list(texts or []) - results = [ - self._run_single_text(text=text, inference_kwargs=kwargs) - for text in text_items - ] + call_started = perf_counter() + results = [] + item_timings = [] + for text in text_items: + result, timings = self._run_single_text(text=text, inference_kwargs=kwargs) + results.append(result) + item_timings.append(timings) + self.last_call_timings = self._aggregate_item_timings( + item_timings=item_timings, + total_s=perf_counter() - call_started, + ) return results[0] if is_single_text or len(results) == 1 else results + def _aggregate_item_timings(self, item_timings, total_s): + """Aggregate per-text ONNX timings for the last pipeline call.""" + totals = { + "onnx_pipeline_total_s": total_s, + "onnx_items": len(item_timings), + } + for key in ( + "onnx_tokenize_s", + "onnx_prepare_inputs_s", + "onnx_session_run_s", + "onnx_decode_s", + "onnx_single_total_s", + ): + totals[key] = sum(item.get(key, 0.0) for item in item_timings) + return totals + def _get_max_length(self, inference_kwargs): max_length = inference_kwargs.get("max_length") if max_length is not None: @@ -214,17 +240,33 @@ def _call_decoder(self, outputs_by_name, text, encoded, inference_kwargs): return self.decoder(outputs_by_name, self.schema, **decoder_kwargs) def _run_single_text(self, text, inference_kwargs): + total_started = perf_counter() + started = perf_counter() encoded = self._tokenize(text=text, inference_kwargs=inference_kwargs) + tokenize_s = perf_counter() - started + started = perf_counter() session_inputs = self._prepare_session_inputs(encoded) + prepare_inputs_s = perf_counter() - started output_names = self._output_names() + started = perf_counter() raw_outputs = self.session.run(output_names, session_inputs) + session_run_s = perf_counter() - started + started = perf_counter() outputs_by_name = self._build_output_map(raw_outputs, output_names) - return self._call_decoder( + decoded = self._call_decoder( outputs_by_name=outputs_by_name, text=text, encoded=encoded, inference_kwargs=inference_kwargs, ) + decode_s = perf_counter() - started + return decoded, { + "onnx_tokenize_s": tokenize_s, + "onnx_prepare_inputs_s": prepare_inputs_s, + "onnx_session_run_s": session_run_s, + "onnx_decode_s": decode_s, + "onnx_single_total_s": perf_counter() - total_started, + } class ThHfModelBase(BaseServingProcess): @@ -409,6 +451,22 @@ def get_warmup_text(self): return warmup_text.strip() return None + def get_warmup_texts(self): + """Return startup warmup texts, preserving `WARMUP_TEXT` compatibility.""" + warmup_texts = getattr(self, "cfg_warmup_texts", None) + if isinstance(warmup_texts, str): + warmup_texts = [warmup_texts] + if isinstance(warmup_texts, (list, tuple)): + texts = [ + text.strip() + for text in warmup_texts + if isinstance(text, str) and len(text.strip()) > 0 + ] + if texts: + return texts + warmup_text = self.get_warmup_text() + return [warmup_text] if warmup_text is not None else [] + def build_warmup_inference_kwargs(self): """Build keyword arguments used by the startup warmup call. @@ -423,6 +481,11 @@ def build_warmup_inference_kwargs(self): **dict(self.cfg_warmup_inference_kwargs or {}), } + def get_last_pipeline_timings(self): + """Return optional stage timings exposed by the loaded pipeline adapter.""" + timings = getattr(self.classifier, "last_call_timings", None) + return dict(timings) if isinstance(timings, dict) else {} + def _get_device_map(self): """Return the model-loading device map for helper configuration. @@ -963,22 +1026,25 @@ def _run_startup_warmup(self): Notes ----- Warmup is intentionally skipped when the pipeline is missing, disabled, or - configured with an empty warmup text. + configured without valid warmup texts. """ if not self.cfg_warmup_enabled or self.classifier is None: return - warmup_text = self.get_warmup_text() - if warmup_text is None: + warmup_texts = self.get_warmup_texts() + if not warmup_texts: return warmup_started_at = self.time() self.P( - f"Running startup warmup for {self.get_model_name()} on device {self.device}...", + f"Running startup warmup for {self.get_model_name()} on device {self.device} " + f"with {len(warmup_texts)} text(s)...", color="y", ) - self.classifier( - warmup_text, - **self.build_warmup_inference_kwargs(), - ) + warmup_inference_kwargs = self.build_warmup_inference_kwargs() + for warmup_text in warmup_texts: + self.classifier( + warmup_text, + **warmup_inference_kwargs, + ) self.P( "Startup warmup completed in {:.3f}s".format(self.time() - warmup_started_at), color="g", diff --git a/extensions/serving/default_inference/nlp/th_privacy_filter.py b/extensions/serving/default_inference/nlp/th_privacy_filter.py index e7a13ea31..5c2246a87 100644 --- a/extensions/serving/default_inference/nlp/th_privacy_filter.py +++ b/extensions/serving/default_inference/nlp/th_privacy_filter.py @@ -9,6 +9,7 @@ import json import math +from time import perf_counter from extensions.serving.default_inference.nlp.th_hf_model_base import ( _CONFIG as BASE_HF_MODEL_CONFIG, @@ -26,16 +27,89 @@ "TRUST_REMOTE_CODE": False, "EXPECTED_AI_ENGINES": ["privacy_filter"], "MAX_LENGTH": None, + "HF_ONNX_RUNTIME_KEY": "onnx_quantized", "INFERENCE_KWARGS": { "aggregation_strategy": "simple", }, + # Multi-text privacy-filter warmup is configurable, but should not be the + # built-in default: on CPU the representative warmup corpus took minutes and + # still did not remove varied-input ONNX Runtime shape tails. + "WARMUP_TEXTS": None, } FIXED_CENSOR_SIZE = 4 -PRIVACY_FILTER_ONNX_RUNTIME_KEY = "onnx_fp32" +PRIVACY_FILTER_ONNX_RUNTIME_KEY = "onnx_quantized" PRIVACY_FILTER_ONNX_MODEL_FILE = "onnx/model.onnx" PRIVACY_FILTER_VITERBI_FILE = "viterbi_calibration.json" +PRIVACY_FILTER_ONNX_COMMON_FILES = [ + "config.json", + "tokenizer.json", + "tokenizer_config.json", + PRIVACY_FILTER_VITERBI_FILE, +] +PRIVACY_FILTER_ONNX_RUNTIME_SPECS = { + # Keep the preferred CPU runtime first as a defensive fallback for configs + # that request generic "onnx" without setting HF_ONNX_RUNTIME_KEY. + "onnx_quantized": { + "model": "onnx/model_quantized.onnx", + "sidecars": ["onnx/model_quantized.onnx_data"], + "precision": "quantized", + "stability": "experimental_cpu", + }, + "onnx_fp32": { + "model": PRIVACY_FILTER_ONNX_MODEL_FILE, + "sidecars": [ + "onnx/model.onnx_data", + "onnx/model.onnx_data_1", + "onnx/model.onnx_data_2", + ], + "precision": "fp32", + "stability": "default", + }, + "onnx_fp16": { + "model": "onnx/model_fp16.onnx", + "sidecars": [ + "onnx/model_fp16.onnx_data", + "onnx/model_fp16.onnx_data_1", + ], + "precision": "fp16", + "stability": "experimental_cpu", + }, + "onnx_q4": { + "model": "onnx/model_q4.onnx", + "sidecars": ["onnx/model_q4.onnx_data"], + "precision": "q4", + "stability": "experimental_cpu", + }, + "onnx_q4f16": { + "model": "onnx/model_q4f16.onnx", + "sidecars": ["onnx/model_q4f16.onnx_data"], + "precision": "q4f16", + "stability": "experimental_cpu", + }, +} + + +def _build_privacy_filter_onnx_runtime_config(spec): + """Build a manifest runtime entry for one public Privacy Filter ONNX file.""" + files = [ + *PRIVACY_FILTER_ONNX_COMMON_FILES, + spec["model"], + *spec.get("sidecars", []), + ] + return { + "runtime": "onnxruntime", + "entrypoint": "onnxruntime.InferenceSession", + "pipeline_task": "token-classification", + "model": spec["model"], + "decoder_type": "privacy_filter_span_decoder", + "files": files, + "recommended_allow_patterns": list(files), + "providers": ["CPUExecutionProvider"], + "precision": spec.get("precision"), + "stability": spec.get("stability"), + } class ThPrivacyFilter(ThHfModelBase): @@ -50,34 +124,8 @@ def _get_hf_onnx_fallback_manifest(self): "source_repo_id": "openai/privacy-filter", "pipeline_task": "token-classification", "runtimes": { - PRIVACY_FILTER_ONNX_RUNTIME_KEY: { - "runtime": "onnxruntime", - "entrypoint": "onnxruntime.InferenceSession", - "pipeline_task": "token-classification", - "model": PRIVACY_FILTER_ONNX_MODEL_FILE, - "decoder_type": "privacy_filter_span_decoder", - "files": [ - "config.json", - "tokenizer.json", - "tokenizer_config.json", - PRIVACY_FILTER_VITERBI_FILE, - PRIVACY_FILTER_ONNX_MODEL_FILE, - "onnx/model.onnx_data", - "onnx/model.onnx_data_1", - "onnx/model.onnx_data_2", - ], - "recommended_allow_patterns": [ - "config.json", - "tokenizer.json", - "tokenizer_config.json", - PRIVACY_FILTER_VITERBI_FILE, - PRIVACY_FILTER_ONNX_MODEL_FILE, - "onnx/model.onnx_data", - "onnx/model.onnx_data_1", - "onnx/model.onnx_data_2", - ], - "providers": ["CPUExecutionProvider"], - }, + runtime_key: _build_privacy_filter_onnx_runtime_config(spec) + for runtime_key, spec in PRIVACY_FILTER_ONNX_RUNTIME_SPECS.items() }, } @@ -546,10 +594,23 @@ def predict(self, preprocessed_inputs): "max_length": self.cfg_max_length, **inference_kwargs, } - outputs = [] if not texts else self.classifier(texts, **inference_kwargs) + model_pipeline_elapsed_s = 0.0 + if texts: + model_pipeline_started = perf_counter() + outputs = self.classifier(texts, **inference_kwargs) + model_pipeline_elapsed_s = perf_counter() - model_pipeline_started + else: + outputs = [] + serving_timings = { + "model_pipeline_elapsed_s": model_pipeline_elapsed_s, + "active_payloads": len(texts), + "batch_size": len(texts), + **self.get_last_pipeline_timings(), + } return { "payloads": preprocessed_inputs, "outputs": outputs, + "serving_timings": serving_timings, } def _is_privacy_span(self, item): @@ -717,6 +778,7 @@ def post_process(self, predictions): output_iter = iter(normalized_outputs) decoded = [] additional_metadata = self.get_additional_metadata() + serving_timings = predictions.get("serving_timings") for payload_info in predictions["payloads"]: if payload_info.get("ignored"): decoded.append([]) @@ -731,7 +793,7 @@ def post_process(self, predictions): label = self._extract_span_label(span) if label is not None and label not in detected_labels: detected_labels.append(label) - decoded.append({ + decoded_output = { "REQUEST_ID": payload_info["request_id"], "TEXT": payload_info["text"], "result": findings, @@ -741,5 +803,8 @@ def post_process(self, predictions): "DETECTED_ENTITY_GROUPS": detected_labels, "FINDINGS_COUNT": len(findings), **additional_metadata, - }) + } + if isinstance(serving_timings, dict): + decoded_output["SERVING_TIMINGS"] = dict(serving_timings) + decoded.append(decoded_output) return decoded diff --git a/extensions/serving/default_inference/nlp/th_text_classifier.py b/extensions/serving/default_inference/nlp/th_text_classifier.py index 43dabffdb..2b204917a 100644 --- a/extensions/serving/default_inference/nlp/th_text_classifier.py +++ b/extensions/serving/default_inference/nlp/th_text_classifier.py @@ -6,6 +6,8 @@ remote-code models usable by specifying only the Hugging Face model id. """ +from time import perf_counter + from extensions.serving.default_inference.nlp.th_hf_model_base import ( _CONFIG as BASE_HF_MODEL_CONFIG, ThHfModelBase, @@ -199,7 +201,9 @@ def predict(self, preprocessed_inputs): **dict(self.cfg_inference_kwargs or {}), } outputs = [] + model_pipeline_elapsed_s = 0.0 if texts: + model_pipeline_started = perf_counter() try: outputs = self.classifier(texts, **inference_kwargs) except AttributeError as exc: @@ -209,9 +213,17 @@ def predict(self, preprocessed_inputs): self.classifier(text, **inference_kwargs) for text in texts ] + model_pipeline_elapsed_s = perf_counter() - model_pipeline_started + serving_timings = { + "model_pipeline_elapsed_s": model_pipeline_elapsed_s, + "active_payloads": len(texts), + "batch_size": len(texts), + **self.get_last_pipeline_timings(), + } return { "payloads": preprocessed_inputs, "outputs": outputs, + "serving_timings": serving_timings, } def _normalize_outputs(self, outputs, expected_count): @@ -249,7 +261,7 @@ def _normalize_outputs(self, outputs, expected_count): f"Pipeline returned a scalar output for {expected_count} payloads." ) - def _default_decode_outputs(self, outputs, payloads): + def _default_decode_outputs(self, outputs, payloads, serving_timings=None): """Decode raw model outputs into the serving response contract. Parameters @@ -277,13 +289,16 @@ def _default_decode_outputs(self, outputs, payloads): serving_target = None if isinstance(payload_info.get("struct_payload"), dict): serving_target = payload_info["struct_payload"].get("__SERVING_TARGET__") - decoded.append({ + decoded_output = { "REQUEST_ID": payload_info["request_id"], "TEXT": payload_info["text"], "result": model_output, "SERVING_TARGET": serving_target, **additional_metadata, - }) + } + if isinstance(serving_timings, dict): + decoded_output["SERVING_TIMINGS"] = dict(serving_timings) + decoded.append(decoded_output) return decoded def post_process(self, predictions): @@ -304,4 +319,5 @@ def post_process(self, predictions): return self._default_decode_outputs( outputs=predictions["outputs"], payloads=predictions["payloads"], + serving_timings=predictions.get("serving_timings"), ) diff --git a/extensions/serving/test_th_hf_model_base.py b/extensions/serving/test_th_hf_model_base.py index 5c2077dce..ade651b3e 100644 --- a/extensions/serving/test_th_hf_model_base.py +++ b/extensions/serving/test_th_hf_model_base.py @@ -31,6 +31,7 @@ def __init__(self, **kwargs): self.cfg_inference_kwargs = kwargs.get("INFERENCE_KWARGS", {}) self.cfg_warmup_enabled = kwargs.get("WARMUP_ENABLED", True) self.cfg_warmup_text = kwargs.get("WARMUP_TEXT", "Warmup request.") + self.cfg_warmup_texts = kwargs.get("WARMUP_TEXTS") self.cfg_warmup_inference_kwargs = kwargs.get("WARMUP_INFERENCE_KWARGS", {}) self.cfg_model_instance_id = kwargs.get("MODEL_INSTANCE_ID") self.os_environ = {} @@ -217,6 +218,21 @@ def test_startup_runs_default_warmup(self): True, ) + def test_startup_runs_configured_warmup_texts(self): + plugin = _ConcreteHfModel( + MODEL_NAME="test/model", + PIPELINE_TASK="text-classification", + WARMUP_TEXTS=["short warmup", "longer warmup text"], + ) + + plugin.startup() + + calls = _PIPELINE_FACTORY.instance.inference_calls + self.assertEqual([call[0] for call in calls[-2:]], ["short warmup", "longer warmup text"]) + for _text, kwargs in calls[-2:]: + self.assertEqual(kwargs["max_length"], 512) + self.assertEqual(kwargs["truncation"], True) + def test_startup_adds_4bit_quantization_config(self): plugin = _ConcreteHfModel( MODEL_NAME="test/model", @@ -874,6 +890,10 @@ def fake_create_session(model_path, providers): self.assertEqual(inputs["input_ids"].dtype, "int64") self.assertEqual(fake_tokenizer.calls[-1][1]["return_tensors"], "np") self.assertTrue(fake_tokenizer.calls[-1][1]["return_offsets_mapping"]) + self.assertEqual(pipeline.last_call_timings["onnx_items"], 1) + self.assertIn("onnx_tokenize_s", pipeline.last_call_timings) + self.assertIn("onnx_session_run_s", pipeline.last_call_timings) + self.assertIn("onnx_decode_s", pipeline.last_call_timings) def test_onnx_artifact_pipeline_materializes_symlinked_external_data(self): plugin = _ConcreteHfModel( diff --git a/extensions/serving/test_th_hf_runtime_profile.py b/extensions/serving/test_th_hf_runtime_profile.py new file mode 100644 index 000000000..04638e529 --- /dev/null +++ b/extensions/serving/test_th_hf_runtime_profile.py @@ -0,0 +1,574 @@ +import json +import os +import random +import shutil +import statistics +import tempfile +import time +import unittest + +from pathlib import Path + + +RUN_PROFILE_TESTS = os.getenv("EE_RUN_HF_PROFILE_TESTS") == "1" + + +DEFAULT_PROFILE_TEXTS = [ + "Short status: all clear.", + "Please classify this short customer support message about a delayed shipment.", + "The robot cell stopped after station 3 reported a missing safety acknowledgement at 06:41 UTC.", + ( + "A technician reports intermittent failures in the packaging line. " + "The PLC shows normal voltage, but the camera trigger sometimes arrives " + "late and causes the downstream reject gate to miss the product window." + ), + ( + "Contact john.doe@example.com about order RF-2026-05-11. " + "The payload includes an address, a phone number, and an internal ticket id." + ), + ( + "We need a concise operational summary for a factory shift handoff. " + "Include the machine status, recent alerts, inferred severity, and whether " + "the next operator should escalate to maintenance before restarting." + ), + ( + "Long diagnostic note: the API gateway accepted a request from the local " + "RobotFactory client, forwarded it to the text classifier, then waited for " + "the asynchronous result. During that interval the conveyor supervisor " + "published two state changes, one warning, and one final recovery message." + ), + ( + "Privacy review text with mixed content: Maria Ionescu, phone +40 721 123 456, " + "email maria.ionescu@example.org, visited the Bucharest site on 2026-05-11 " + "and reported that badge number BC-4421 failed twice at the entrance." + ), + ( + "Factory note: camera CAM-07 saw three rejected parts after the gripper changed " + "speed from 70 percent to 92 percent during the night shift." + ), + ( + "Billing escalation for Acme Robotics: invoice INV-2026-0512 was disputed by " + "accounts-payable@example.net after the VAT identifier changed." + ), + ( + "Health and safety report: worker badge BC-4421 entered Zone C at 2026-05-11 " + "09:32, then left before the evacuation drill started." + ), + ( + "Plain operational prose with no obvious private entities. The scheduler queued " + "five maintenance jobs and left two optional inspections for tomorrow." + ), + ( + "Longer diagnostic paragraph: the API gateway accepted a request, submitted it " + "to the privacy filter, polled for completion, and returned a redacted response " + "after the downstream inference service completed its model call." + ), + ( + "Mixed identifier sample: user alex.smith@example.org, phone 555-0188, customer " + "ID CUST-88019, device serial RF-AX-0091, and address 24 Industrial Way." + ), + ( + "Legal-style note: the contractor named Jordan Blake signed amendment A-17 on " + "May 9, 2026, but requested that their home address remain confidential." + ), + ( + "Small multilingual-looking ASCII text: Ana Popescu said bonjour to Carlos " + "before sending order number FR-2026-771 to logistics." + ), + ( + "Stack trace excerpt: ValueError at worker.py line 184 while handling request " + "req_01HX9, no email address or person name is expected in this text." + ), + ( + "Customer support message: Priya asked support to call +1 415 555 0134 after " + "her access token expired during login from office IP 203.0.113.42." + ), + "Very short PII sample: Sam, 555-0101.", + ( + "A medium privacy-heavy example lists Jane Roe, employee E-3812, jane.roe@corp.example, " + "passport P1234567, and a meeting room booking for Floor 14." + ), + ( + "A longer privacy-heavy example mentions Michael Turner, phone +49 30 1234 5678, " + "email michael.turner@example.de, bank reference DE89 3704 0044 0532 0130 00, " + "and shipment route Berlin to Cluj for audit review." + ), + ( + "A long non-private factory narrative describes conveyor speeds, motor current, " + "temperature drift, retry counters, watchdog resets, operator acknowledgements, " + "and a planned calibration window with no customer or employee identifiers." + ), +] + + +def _env_int(name, default): + value = os.getenv(name) + if value is None: + return default + return int(value) + + +def _env_float(name, default=None): + value = os.getenv(name) + if value is None or value == "": + return default + return float(value) + + +def _env_bool(name, default=False): + value = os.getenv(name) + if value is None: + return default + return value.strip().lower() in {"1", "true", "yes", "on"} + + +def _latency_summary_ms(latencies): + ordered = sorted(latencies) + p95_index = min(len(ordered) - 1, int(len(ordered) * 0.95)) + return { + "runs": len(latencies), + "mean_ms": statistics.fmean(latencies) * 1000.0, + "median_ms": statistics.median(latencies) * 1000.0, + "p95_ms": ordered[p95_index] * 1000.0, + "min_ms": ordered[0] * 1000.0, + "max_ms": ordered[-1] * 1000.0, + } + + +def _print_profile_summary(label, load_seconds, stage_latencies, input_count=None, seed=None): + input_info = "" + if input_count is not None: + input_info = f" inputs={input_count}" + if seed is not None: + input_info = f"{input_info} seed={seed}" + print(f"{label} profile: load={load_seconds:.3f}s{input_info}") + + summaries = {} + for stage_name, latencies in stage_latencies.items(): + if not latencies: + continue + summary = _latency_summary_ms(latencies) + summaries[stage_name] = summary + print( + " {}: runs={} mean={:.3f}ms median={:.3f}ms p95={:.3f}ms min={:.3f}ms max={:.3f}ms".format( + stage_name, + summary["runs"], + summary["mean_ms"], + summary["median_ms"], + summary["p95_ms"], + summary["min_ms"], + summary["max_ms"], + ) + ) + return summaries + + +def _comma_list_env(name, default): + value = os.getenv(name) + if value is None: + return list(default) + return [item.strip() for item in value.split(",") if item.strip()] + + +def _split_text_env(value): + if not value: + return [] + return [item.strip() for item in value.split("|||") if item.strip()] + + +def _profile_texts_from_env(): + text_file = os.getenv("EE_HF_PROFILE_TEXT_FILE") + if text_file: + return [ + line.strip() + for line in Path(text_file).read_text(encoding="utf-8").splitlines() + if line.strip() + ] + texts = _split_text_env(os.getenv("EE_HF_PROFILE_TEXTS")) + if texts: + return texts + text = os.getenv("EE_HF_PROFILE_TEXT") + if text: + return [text] + return list(DEFAULT_PROFILE_TEXTS) + + +def _build_profile_input_sequence(runs, seed): + texts = _profile_texts_from_env() + if not texts: + raise ValueError("At least one profiling text must be configured.") + allow_repeats = _env_bool("EE_HF_PROFILE_ALLOW_REPEATS", default=False) + if not allow_repeats and runs > len(texts): + raise ValueError( + f"EE_HF_PROFILE_RUNS={runs} requires {runs} unique texts, but only " + f"{len(texts)} are available. Provide EE_HF_PROFILE_TEXTS, " + "EE_HF_PROFILE_TEXT_FILE, or set EE_HF_PROFILE_ALLOW_REPEATS=1." + ) + if allow_repeats: + repeats = (runs + len(texts) - 1) // len(texts) + sequence = (texts * repeats)[:runs] + else: + sequence = list(texts) + random.Random(seed).shuffle(sequence) + return sequence[:runs] + + +def _warmup_texts_from_env(default_sequence, warmup_runs): + warmup_texts = _split_text_env(os.getenv("EE_HF_PROFILE_WARMUP_TEXTS")) + if warmup_texts: + return warmup_texts + if warmup_runs <= 0: + return [] + return default_sequence[:warmup_runs] or default_sequence[:1] + + +def _onnx_allow_patterns(model_file): + model_path = Path(model_file) + model_dir = model_path.parent.as_posix() + model_name = model_path.name + sidecar_prefix = f"{model_file}_data" + if model_dir == ".": + sidecar_prefix = f"{model_name}_data" + return _comma_list_env( + "EE_HF_PROFILE_ONNX_ALLOW_PATTERNS", + [ + model_file, + f"{sidecar_prefix}*", + "config.json", + "tokenizer.json", + "tokenizer_config.json", + "special_tokens_map.json", + "vocab.txt", + "merges.txt", + "sentencepiece.bpe.model", + "spiece.model", + "viterbi_calibration.json", + ], + ) + + +def _copy_or_link(source, destination): + destination.parent.mkdir(parents=True, exist_ok=True) + source = source.resolve() + if destination.exists(): + return + try: + os.link(source, destination) + except OSError: + shutil.copy2(source, destination) + return + + +def _materialize_onnx_for_runtime(snapshot_dir, model_file, tmpdir): + snapshot_dir = Path(snapshot_dir) + source_model = snapshot_dir / model_file + materialized_root = Path(tmpdir) / "materialized_onnx" + materialized_model = materialized_root / model_file + _copy_or_link(source_model, materialized_model) + + sidecar_globs = [ + f"{source_model.name}_data*", + f"{source_model.name}.data*", + ] + for sidecar_glob in sidecar_globs: + for sidecar in source_model.parent.glob(sidecar_glob): + if sidecar.is_file() or sidecar.is_symlink(): + relative_sidecar = sidecar.relative_to(snapshot_dir) + _copy_or_link(sidecar, materialized_root / relative_sidecar) + return str(materialized_model) + + +def _session_inputs(session, encoded): + inputs = {} + encoded_items = dict(encoded.items()) if hasattr(encoded, "items") else dict(encoded) + for input_meta in session.get_inputs(): + input_name = input_meta.name + if input_name not in encoded_items: + continue + value = encoded_items[input_name] + input_type = getattr(input_meta, "type", "") + if "int64" in input_type and hasattr(value, "astype"): + value = value.astype("int64") + elif "int32" in input_type and hasattr(value, "astype"): + value = value.astype("int32") + inputs[input_name] = value + if inputs: + return inputs + return encoded_items + + +def _resolve_decoder_kind(model_name): + decoder_kind = os.getenv("EE_HF_PROFILE_ONNX_DECODER", "auto").strip().lower() + if decoder_kind in {"", "none", "null", "off", "false", "0"}: + return None + if decoder_kind == "auto": + if model_name == "openai/privacy-filter": + return "privacy_filter" + return None + return decoder_kind + + +def _build_privacy_filter_decoder(snapshot_dir): + from extensions.serving.default_inference.nlp.th_privacy_filter import ThPrivacyFilter + + snapshot_dir = Path(snapshot_dir) + config = json.loads((snapshot_dir / "config.json").read_text(encoding="utf-8")) + calibration_path = snapshot_dir / "viterbi_calibration.json" + calibration = {} + if calibration_path.exists(): + calibration = json.loads(calibration_path.read_text(encoding="utf-8")) + schema = { + "id2label": config.get("id2label", {}), + "viterbi_calibration": calibration, + } + decoder_owner = object.__new__(ThPrivacyFilter) + + def decode(outputs_by_name, text, encoded): + return decoder_owner._decode_privacy_filter_onnx_outputs( + outputs_by_name, + schema, + text=text, + tokenizer_output=encoded, + ) + + return decode + + +def _build_onnx_decoder(model_name, snapshot_dir): + decoder_kind = _resolve_decoder_kind(model_name=model_name) + if decoder_kind is None: + return None + if decoder_kind == "privacy_filter": + return _build_privacy_filter_decoder(snapshot_dir=snapshot_dir) + raise ValueError(f"Unsupported EE_HF_PROFILE_ONNX_DECODER={decoder_kind!r}.") + + +def _run_onnx_once(tokenizer, session, output_names, tokenize_kwargs, text, decoder=None): + total_started = time.perf_counter() + + started = time.perf_counter() + encoded = tokenizer(text, **tokenize_kwargs) + tokenize_seconds = time.perf_counter() - started + + started = time.perf_counter() + inputs = _session_inputs(session=session, encoded=encoded) + prepare_seconds = time.perf_counter() - started + + started = time.perf_counter() + raw_outputs = session.run(output_names, inputs) + session_seconds = time.perf_counter() - started + + started = time.perf_counter() + outputs_by_name = { + output_name: output_value + for output_name, output_value in zip(output_names, raw_outputs) + } + if decoder is None: + decoded = outputs_by_name + else: + decoded = decoder(outputs_by_name=outputs_by_name, text=text, encoded=encoded) + decode_seconds = time.perf_counter() - started + + return { + "decoded": decoded, + "total": time.perf_counter() - total_started, + "tokenize": tokenize_seconds, + "prepare_inputs": prepare_seconds, + "session_run": session_seconds, + "decode": decode_seconds, + } + + +@unittest.skipUnless( + RUN_PROFILE_TESTS, + "Set EE_RUN_HF_PROFILE_TESTS=1 to run real HF runtime profiling tests.", +) +class ThHfRuntimeProfileTests(unittest.TestCase): + """Opt-in profiling checks for real Transformers/PyTorch and ONNX Runtime paths.""" + + def setUp(self): + self.runs = _env_int("EE_HF_PROFILE_RUNS", 10) + self.warmup_runs = _env_int("EE_HF_PROFILE_WARMUP_RUNS", 2) + self.seed = _env_int("EE_HF_PROFILE_SHUFFLE_SEED", 12345) + if self.runs <= 0: + raise ValueError("EE_HF_PROFILE_RUNS must be greater than 0.") + if self.warmup_runs < 0: + raise ValueError("EE_HF_PROFILE_WARMUP_RUNS must not be negative.") + self.profile_texts = _build_profile_input_sequence(runs=self.runs, seed=self.seed) + self.warmup_texts = _warmup_texts_from_env( + default_sequence=self.profile_texts, + warmup_runs=self.warmup_runs, + ) + return + + def _assert_optional_threshold(self, label, summary, env_name): + threshold_ms = _env_float(env_name) + if threshold_ms is not None: + self.assertLessEqual( + summary["mean_ms"], + threshold_ms, + f"{label} mean latency exceeded {env_name}={threshold_ms}ms", + ) + return + + def test_profile_transformers_torch_pipeline(self): + model_name = os.getenv("EE_HF_PROFILE_TORCH_MODEL_NAME") + if not model_name: + self.skipTest("Set EE_HF_PROFILE_TORCH_MODEL_NAME to profile the PyTorch/Transformers runtime.") + + try: + from transformers import pipeline + except ImportError as exc: + self.skipTest(f"transformers is not installed: {exc}") + + task = os.getenv("EE_HF_PROFILE_TORCH_TASK") + if task is None and model_name == "openai/privacy-filter": + task = "token-classification" + if task is None: + task = "text-classification" + if task.strip().lower() in {"", "none", "null"}: + task = None + device = _env_int("EE_HF_PROFILE_TORCH_DEVICE", -1) + trust_remote_code = os.getenv("EE_HF_PROFILE_TORCH_TRUST_REMOTE_CODE") == "1" + + load_started = time.perf_counter() + classifier = pipeline( + task=task, + model=model_name, + tokenizer=os.getenv("EE_HF_PROFILE_TORCH_TOKENIZER_NAME") or model_name, + device=device, + trust_remote_code=trust_remote_code, + token=os.getenv("EE_HF_TOKEN") or os.getenv("HF_TOKEN"), + ) + if getattr(classifier, "framework", None) is None: + classifier.framework = "pt" + load_seconds = time.perf_counter() - load_started + + for warmup_text in self.warmup_texts: + classifier(warmup_text) + + stage_latencies = {"pipeline_total": []} + result = None + for text in self.profile_texts: + started = time.perf_counter() + result = classifier(text) + stage_latencies["pipeline_total"].append(time.perf_counter() - started) + + self.assertIsNotNone(result) + summaries = _print_profile_summary( + "torch", + load_seconds, + stage_latencies, + input_count=len(set(self.profile_texts)), + seed=self.seed, + ) + self._assert_optional_threshold( + label="torch.pipeline_total", + summary=summaries["pipeline_total"], + env_name="EE_HF_PROFILE_TORCH_MAX_MEAN_MS", + ) + return + + def test_profile_onnx_runtime_pipeline(self): + model_name = os.getenv("EE_HF_PROFILE_ONNX_MODEL_NAME") + if not model_name: + self.skipTest("Set EE_HF_PROFILE_ONNX_MODEL_NAME to profile the ONNX Runtime path.") + + try: + import onnxruntime as ort + from huggingface_hub import snapshot_download + from transformers import AutoTokenizer + except ImportError as exc: + self.skipTest(f"ONNX profiling dependencies are not installed: {exc}") + + model_file = os.getenv("EE_HF_PROFILE_ONNX_MODEL_FILE", "onnx/model.onnx") + trust_remote_code = os.getenv("EE_HF_PROFILE_ONNX_TRUST_REMOTE_CODE") == "1" + providers = _comma_list_env( + "EE_HF_PROFILE_ONNX_PROVIDERS", + ["CPUExecutionProvider"], + ) + + with tempfile.TemporaryDirectory(prefix="hf_onnx_profile_") as tmpdir: + snapshot_dir = snapshot_download( + repo_id=model_name, + revision=os.getenv("EE_HF_PROFILE_ONNX_REVISION") or None, + token=os.getenv("EE_HF_TOKEN") or os.getenv("HF_TOKEN"), + cache_dir=os.getenv("EE_HF_PROFILE_CACHE_DIR") or None, + allow_patterns=_onnx_allow_patterns(model_file=model_file), + repo_type="model", + ) + materialized_model = _materialize_onnx_for_runtime( + snapshot_dir=snapshot_dir, + model_file=model_file, + tmpdir=tmpdir, + ) + + load_started = time.perf_counter() + tokenizer = AutoTokenizer.from_pretrained( + snapshot_dir, + trust_remote_code=trust_remote_code, + ) + session = ort.InferenceSession(materialized_model, providers=providers) + decoder = _build_onnx_decoder(model_name=model_name, snapshot_dir=snapshot_dir) + load_seconds = time.perf_counter() - load_started + + tokenize_kwargs = { + "return_tensors": "np", + "truncation": True, + } + if decoder is not None: + tokenize_kwargs["return_offsets_mapping"] = True + max_length = os.getenv("EE_HF_PROFILE_ONNX_MAX_LENGTH") + if max_length: + tokenize_kwargs["max_length"] = int(max_length) + output_names = [output.name for output in session.get_outputs()] + + for warmup_text in self.warmup_texts: + _run_onnx_once( + tokenizer=tokenizer, + session=session, + output_names=output_names, + tokenize_kwargs=tokenize_kwargs, + text=warmup_text, + decoder=decoder, + ) + + stage_latencies = { + "pipeline_total": [], + "tokenize": [], + "prepare_inputs": [], + "session_run": [], + "decode": [], + } + result = None + for text in self.profile_texts: + result = _run_onnx_once( + tokenizer=tokenizer, + session=session, + output_names=output_names, + tokenize_kwargs=tokenize_kwargs, + text=text, + decoder=decoder, + ) + for stage_name in stage_latencies: + stage_latencies[stage_name].append(result[stage_name]) + + self.assertIsNotNone(result) + self.assertIsNotNone(result["decoded"]) + summaries = _print_profile_summary( + "onnx", + load_seconds, + stage_latencies, + input_count=len(set(self.profile_texts)), + seed=self.seed, + ) + self._assert_optional_threshold( + label="onnx.pipeline_total", + summary=summaries["pipeline_total"], + env_name="EE_HF_PROFILE_ONNX_MAX_MEAN_MS", + ) + return + + +if __name__ == "__main__": + unittest.main() diff --git a/extensions/serving/test_th_privacy_filter.py b/extensions/serving/test_th_privacy_filter.py index 417a4a8ba..9388a1a37 100644 --- a/extensions/serving/test_th_privacy_filter.py +++ b/extensions/serving/test_th_privacy_filter.py @@ -22,6 +22,9 @@ def __init__(self, **kwargs): self.cfg_expected_ai_engines = kwargs.get("EXPECTED_AI_ENGINES", getattr(self, "CONFIG", {}).get("EXPECTED_AI_ENGINES")) self.cfg_model_instance_id = kwargs.get("MODEL_INSTANCE_ID", getattr(self, "CONFIG", {}).get("MODEL_INSTANCE_ID")) self.cfg_model_name = kwargs.get("MODEL_NAME", getattr(self, "CONFIG", {}).get("MODEL_NAME")) + self.cfg_max_length = kwargs.get("MAX_LENGTH", getattr(self, "CONFIG", {}).get("MAX_LENGTH")) + self.cfg_inference_kwargs = kwargs.get("INFERENCE_KWARGS", getattr(self, "CONFIG", {}).get("INFERENCE_KWARGS", {})) + self.classifier = None self.logged_messages = [] def P(self, *args, **kwargs): @@ -44,6 +47,9 @@ def get_additional_metadata(self): "PIPELINE_TASK": self.get_pipeline_task(), } + def get_last_pipeline_timings(self): + return {} + def get_tokenizer_name(self): return self.cfg_model_name @@ -123,17 +129,46 @@ def _load_plugin_class(): ThPrivacyFilter = _load_plugin_class() +class _FakePrivacyPipeline: + def __init__(self): + self.calls = [] + + def __call__(self, texts, **kwargs): + self.calls.append((texts, kwargs)) + return [ + [{ + "entity_group": "private_person", + "score": 0.99, + "word": text, + "start": 0, + "end": len(text), + }] + for text in texts + ] + + class ThPrivacyFilterTests(unittest.TestCase): def test_config_pins_privacy_filter_defaults(self): self.assertEqual(ThPrivacyFilter.CONFIG["MODEL_NAME"], "openai/privacy-filter") self.assertEqual(ThPrivacyFilter.CONFIG["PIPELINE_TASK"], "token-classification") self.assertFalse(ThPrivacyFilter.CONFIG["TRUST_REMOTE_CODE"]) self.assertIsNone(ThPrivacyFilter.CONFIG["MAX_LENGTH"]) + self.assertEqual(ThPrivacyFilter.CONFIG["HF_ONNX_RUNTIME_KEY"], "onnx_quantized") self.assertEqual( ThPrivacyFilter.CONFIG["INFERENCE_KWARGS"]["aggregation_strategy"], "simple", ) + def test_privacy_filter_prefers_quantized_onnx_runtime_by_default(self): + plugin = ThPrivacyFilter(MODEL_NAME="openai/privacy-filter") + + manifest = plugin._get_hf_onnx_fallback_manifest() # pylint: disable=protected-access + runtime_key = next(iter(manifest["runtimes"])) + runtime_config = manifest["runtimes"][runtime_key] + + self.assertEqual(runtime_key, "onnx_quantized") + self.assertEqual(runtime_config["model"], "onnx/model_quantized.onnx") + def test_privacy_filter_declares_local_onnx_fallback_manifest(self): plugin = ThPrivacyFilter(MODEL_NAME="openai/privacy-filter") @@ -146,6 +181,32 @@ def test_privacy_filter_declares_local_onnx_fallback_manifest(self): self.assertIn("onnx/model.onnx_data_2", runtime["files"]) self.assertIn("viterbi_calibration.json", runtime["recommended_allow_patterns"]) + def test_privacy_filter_declares_additional_onnx_runtime_variants(self): + plugin = ThPrivacyFilter(MODEL_NAME="openai/privacy-filter") + + manifest = plugin._get_hf_onnx_fallback_manifest() # pylint: disable=protected-access + runtimes = manifest["runtimes"] + + expected_models = { + "onnx_fp32": "onnx/model.onnx", + "onnx_fp16": "onnx/model_fp16.onnx", + "onnx_q4": "onnx/model_q4.onnx", + "onnx_q4f16": "onnx/model_q4f16.onnx", + "onnx_quantized": "onnx/model_quantized.onnx", + } + self.assertEqual(set(expected_models), set(runtimes)) + for runtime_key, model_file in expected_models.items(): + runtime = runtimes[runtime_key] + self.assertEqual(runtime["runtime"], "onnxruntime") + self.assertEqual(runtime["decoder_type"], "privacy_filter_span_decoder") + self.assertEqual(runtime["model"], model_file) + self.assertIn(model_file, runtime["files"]) + self.assertIn(model_file, runtime["recommended_allow_patterns"]) + self.assertIn("config.json", runtime["files"]) + self.assertIn("tokenizer.json", runtime["files"]) + self.assertIn("viterbi_calibration.json", runtime["files"]) + self.assertEqual(runtime["providers"], ["CPUExecutionProvider"]) + def test_privacy_filter_does_not_declare_onnx_fallback_for_other_models(self): plugin = ThPrivacyFilter(MODEL_NAME="other/privacy-filter") @@ -257,6 +318,23 @@ def test_privacy_filter_viterbi_decoder_rejects_invalid_terminal_inside_label(se self.assertEqual(spans[0]["entity_group"], "private_email") self.assertEqual(spans[0]["end"], 17) + def test_predict_adds_serving_timings_to_post_processed_output(self): + plugin = ThPrivacyFilter(MODEL_NAME="openai/privacy-filter") + plugin.classifier = _FakePrivacyPipeline() + prepared = [{ + "request_id": "req-a", + "text": "Alice", + "ignored": False, + }] + + predictions = plugin.predict(prepared) + decoded = plugin.post_process(predictions) + + self.assertEqual(predictions["serving_timings"]["active_payloads"], 1) + self.assertEqual(predictions["serving_timings"]["batch_size"], 1) + self.assertGreaterEqual(predictions["serving_timings"]["model_pipeline_elapsed_s"], 0.0) + self.assertEqual(decoded[0]["SERVING_TIMINGS"], predictions["serving_timings"]) + def test_post_process_emits_redaction_friendly_fields(self): plugin = ThPrivacyFilter() diff --git a/extensions/serving/test_th_text_classifier.py b/extensions/serving/test_th_text_classifier.py index 08f5892af..6731ece46 100644 --- a/extensions/serving/test_th_text_classifier.py +++ b/extensions/serving/test_th_text_classifier.py @@ -61,6 +61,9 @@ def get_additional_metadata(self): "PIPELINE_TASK": pipeline_task or self.get_pipeline_task(), } + def get_last_pipeline_timings(self): + return {} + def get_expected_ai_engines(self): expected = self.cfg_expected_ai_engines if expected is None: @@ -210,6 +213,16 @@ def test_predict_uses_pipeline_with_inference_kwargs(self): self.assertEqual(kwargs["max_length"], 512) self.assertEqual(kwargs["batch_size"], 4) self.assertEqual(predictions["outputs"][0]["label"], "ok") + self.assertEqual(predictions["serving_timings"]["active_payloads"], 1) + self.assertEqual(predictions["serving_timings"]["batch_size"], 1) + self.assertGreaterEqual(predictions["serving_timings"]["model_pipeline_elapsed_s"], 0.0) + + decoded = plugin.post_process(predictions) + + self.assertEqual( + decoded[0]["SERVING_TIMINGS"], + predictions["serving_timings"], + ) def test_predict_falls_back_to_sequential_for_broken_custom_batch_pipeline(self): plugin = ThTextClassifier(MODEL_NAME="org/generic-text-classifier") @@ -224,6 +237,9 @@ def test_predict_falls_back_to_sequential_for_broken_custom_batch_pipeline(self) self.assertEqual(len(predictions["outputs"]), 2) self.assertEqual(predictions["outputs"][0]["label"], "ok") self.assertEqual(predictions["outputs"][1]["label"], "ok") + self.assertEqual(predictions["serving_timings"]["active_payloads"], 2) + self.assertEqual(predictions["serving_timings"]["batch_size"], 2) + self.assertGreaterEqual(predictions["serving_timings"]["model_pipeline_elapsed_s"], 0.0) self.assertEqual(plugin.classifier.calls[0][0], ["hello", "world"]) self.assertEqual(plugin.classifier.calls[1][0], "hello") self.assertEqual(plugin.classifier.calls[2][0], "world") From e30099488da06e9644698991ac9ba31f2ae1cf7b Mon Sep 17 00:00:00 2001 From: Codex Date: Thu, 21 May 2026 14:50:20 +0300 Subject: [PATCH 11/11] chore: increment version What changed: - Bumped edge_node version from 2.10.240 to 2.10.241. Why: - Keep the PR one version ahead of develop before merge. --- ver.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ver.py b/ver.py index 4dca92aa1..1154d2db1 100644 --- a/ver.py +++ b/ver.py @@ -1 +1 @@ -__VER__ = '2.10.240' +__VER__ = '2.10.241'