From 2e182695c20e9b087041e55f7466cca9b659521a Mon Sep 17 00:00:00 2001 From: Hisku Date: Tue, 30 Jun 2026 13:09:51 +0100 Subject: [PATCH 1/2] feat: add Tier3Verdict.usage for TS 0.7.2 parity Let Tier 3 providers report prompt/completion/total token counts so callers can attribute LLM spend per request. Parses camelCase and snake_case usage dicts in verdict validation. Co-authored-by: Cursor --- src/stackone_defender/core/prompt_defense.py | 17 ++++++ src/stackone_defender/types.py | 10 ++++ tests/test_tier3.py | 63 +++++++++++++++++++- uv.lock | 2 +- 4 files changed, 89 insertions(+), 3 deletions(-) diff --git a/src/stackone_defender/core/prompt_defense.py b/src/stackone_defender/core/prompt_defense.py index 9f804a6..5e4d29a 100644 --- a/src/stackone_defender/core/prompt_defense.py +++ b/src/stackone_defender/core/prompt_defense.py @@ -30,6 +30,7 @@ Tier3Provider, Tier3Result, Tier3Skip, + Tier3TokenUsage, Tier3Verdict, ) from .tool_result_sanitizer import ToolResultSanitizer, create_tool_result_sanitizer @@ -292,6 +293,21 @@ def is_tier2_ready(self) -> bool: def _resolve_tier3_provider(self) -> Tier3Provider | None: return self._tier3_custom_provider or get_default_tier3_provider() + @staticmethod + def _parse_tier3_usage(usage: Any) -> Tier3TokenUsage | None: + if usage is None or not isinstance(usage, dict): + return None + prompt_tokens = usage.get("prompt_tokens", usage.get("promptTokens")) + completion_tokens = usage.get("completion_tokens", usage.get("completionTokens")) + total_tokens = usage.get("total_tokens", usage.get("totalTokens")) + if not all(isinstance(value, int) for value in (prompt_tokens, completion_tokens, total_tokens)): + return None + return Tier3TokenUsage( + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=total_tokens, + ) + @staticmethod def _validate_tier3_verdict(verdict: Any) -> Tier3Verdict | Tier3Skip: if isinstance(verdict, Tier3Verdict): @@ -317,6 +333,7 @@ def _validate_tier3_verdict(verdict: Any) -> Tier3Verdict | Tier3Skip: score=verdict.get("score"), raw=verdict.get("raw"), latency_ms=verdict.get("latency_ms", verdict.get("latencyMs")), + usage=PromptDefense._parse_tier3_usage(verdict.get("usage")), ) @staticmethod diff --git a/src/stackone_defender/types.py b/src/stackone_defender/types.py index c131338..cc4434a 100644 --- a/src/stackone_defender/types.py +++ b/src/stackone_defender/types.py @@ -65,6 +65,15 @@ class Tier1Result: latency_ms: float +@dataclass +class Tier3TokenUsage: + """Token usage reported by a Tier 3 provider (e.g. vLLM or OpenAI ``usage``).""" + + prompt_tokens: int + completion_tokens: int + total_tokens: int + + @dataclass class Tier3Verdict: """Authoritative block/allow decision from a Tier 3 provider.""" @@ -73,6 +82,7 @@ class Tier3Verdict: score: float | None = None raw: Any = None latency_ms: float | None = None + usage: Tier3TokenUsage | None = None @dataclass diff --git a/tests/test_tier3.py b/tests/test_tier3.py index e54ef39..6678986 100644 --- a/tests/test_tier3.py +++ b/tests/test_tier3.py @@ -3,7 +3,7 @@ from __future__ import annotations import asyncio -from unittest.mock import MagicMock, patch +from unittest.mock import AsyncMock, MagicMock, patch import pytest @@ -12,7 +12,8 @@ get_default_tier3_provider, set_default_tier3_provider, ) -from stackone_defender.types import Tier3Skip, Tier3Verdict +from stackone_defender.core.prompt_defense import PromptDefense +from stackone_defender.types import Tier3Skip, Tier3TokenUsage, Tier3Verdict def _make_provider(decision: str) -> MagicMock: @@ -378,3 +379,61 @@ def test_sync_batch_with_tier3_uses_async_path(self): results = defense.defend_tool_results(items) assert len(results) == 1 assert isinstance(results[0].tier3, Tier3Verdict) + + +class TestTier3UsagePropagation: + def test_passes_provider_reported_usage_through_to_result_tier3(self): + provider = MagicMock() + provider.classify = AsyncMock( + return_value={ + "decision": "allow", + "latencyMs": 42, + "usage": { + "promptTokens": 311, + "completionTokens": 17, + "totalTokens": 328, + }, + } + ) + defense = create_prompt_defense( + enable_tier1=False, + enable_tier2=False, + enable_tier3=True, + defender_mode="tier3_only", + tier3={"provider": provider}, + ) + + result = asyncio.run(defense.defend_tool_result_async({"body": "test"}, "test_tool")) + + assert isinstance(result.tier3, Tier3Verdict) + assert result.tier3.usage == Tier3TokenUsage( + prompt_tokens=311, + completion_tokens=17, + total_tokens=328, + ) + assert result.tier3.latency_ms == 42 + + def test_validate_tier3_verdict_parses_snake_case_usage(self): + validated = PromptDefense._validate_tier3_verdict( + { + "decision": "allow", + "usage": { + "prompt_tokens": 10, + "completion_tokens": 5, + "total_tokens": 15, + }, + } + ) + assert isinstance(validated, Tier3Verdict) + assert validated.usage == Tier3TokenUsage( + prompt_tokens=10, + completion_tokens=5, + total_tokens=15, + ) + + def test_validate_tier3_verdict_preserves_usage_on_dataclass_instance(self): + verdict = Tier3Verdict( + decision="allow", + usage=Tier3TokenUsage(prompt_tokens=1, completion_tokens=2, total_tokens=3), + ) + assert PromptDefense._validate_tier3_verdict(verdict) is verdict diff --git a/uv.lock b/uv.lock index 1a42afe..858f896 100644 --- a/uv.lock +++ b/uv.lock @@ -493,7 +493,7 @@ wheels = [ [[package]] name = "stackone-defender" -version = "0.7.0" +version = "0.7.1" source = { editable = "." } [package.optional-dependencies] From 56d72a61cbbc55963853c6e393ea32359d26844a Mon Sep 17 00:00:00 2001 From: Hisku Date: Tue, 30 Jun 2026 14:35:00 +0100 Subject: [PATCH 2/2] refactor: address Copilot PR review on Tier3TokenUsage export Export Tier3TokenUsage from the package root and cover usage parsing via the public defend_tool_result_async API instead of private helpers. Co-authored-by: Cursor --- src/stackone_defender/__init__.py | 2 ++ tests/test_tier3.py | 42 +++++++++++++++++++++++-------- 2 files changed, 33 insertions(+), 11 deletions(-) diff --git a/src/stackone_defender/__init__.py b/src/stackone_defender/__init__.py index c6089f8..a7a8375 100644 --- a/src/stackone_defender/__init__.py +++ b/src/stackone_defender/__init__.py @@ -29,6 +29,7 @@ RiskLevel, Tier1Result, Tier3Provider, + Tier3TokenUsage, Tier3Verdict, ) from .utils.boundary import contains_boundary_patterns, generate_boundary_instructions @@ -44,6 +45,7 @@ "SfePreprocessResult", "Tier1Result", "Tier3Provider", + "Tier3TokenUsage", "Tier3Verdict", "contains_boundary_patterns", "create_prompt_defense", diff --git a/tests/test_tier3.py b/tests/test_tier3.py index 6678986..589c426 100644 --- a/tests/test_tier3.py +++ b/tests/test_tier3.py @@ -12,7 +12,6 @@ get_default_tier3_provider, set_default_tier3_provider, ) -from stackone_defender.core.prompt_defense import PromptDefense from stackone_defender.types import Tier3Skip, Tier3TokenUsage, Tier3Verdict @@ -413,9 +412,10 @@ def test_passes_provider_reported_usage_through_to_result_tier3(self): ) assert result.tier3.latency_ms == 42 - def test_validate_tier3_verdict_parses_snake_case_usage(self): - validated = PromptDefense._validate_tier3_verdict( - { + def test_passes_snake_case_usage_through_public_api(self): + provider = MagicMock() + provider.classify = AsyncMock( + return_value={ "decision": "allow", "usage": { "prompt_tokens": 10, @@ -424,16 +424,36 @@ def test_validate_tier3_verdict_parses_snake_case_usage(self): }, } ) - assert isinstance(validated, Tier3Verdict) - assert validated.usage == Tier3TokenUsage( + defense = create_prompt_defense( + enable_tier1=False, + enable_tier2=False, + enable_tier3=True, + defender_mode="tier3_only", + tier3={"provider": provider}, + ) + + result = asyncio.run(defense.defend_tool_result_async({"body": "test"}, "test_tool")) + + assert isinstance(result.tier3, Tier3Verdict) + assert result.tier3.usage == Tier3TokenUsage( prompt_tokens=10, completion_tokens=5, total_tokens=15, ) - def test_validate_tier3_verdict_preserves_usage_on_dataclass_instance(self): - verdict = Tier3Verdict( - decision="allow", - usage=Tier3TokenUsage(prompt_tokens=1, completion_tokens=2, total_tokens=3), + def test_preserves_usage_when_provider_returns_tier3_verdict_instance(self): + usage = Tier3TokenUsage(prompt_tokens=1, completion_tokens=2, total_tokens=3) + provider = MagicMock() + provider.classify = AsyncMock(return_value=Tier3Verdict(decision="allow", usage=usage)) + defense = create_prompt_defense( + enable_tier1=False, + enable_tier2=False, + enable_tier3=True, + defender_mode="tier3_only", + tier3={"provider": provider}, ) - assert PromptDefense._validate_tier3_verdict(verdict) is verdict + + result = asyncio.run(defense.defend_tool_result_async({"body": "test"}, "test_tool")) + + assert isinstance(result.tier3, Tier3Verdict) + assert result.tier3.usage == usage