Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/stackone_defender/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
RiskLevel,
Tier1Result,
Tier3Provider,
Tier3TokenUsage,
Tier3Verdict,
)
from .utils.boundary import contains_boundary_patterns, generate_boundary_instructions
Expand All @@ -44,6 +45,7 @@
"SfePreprocessResult",
"Tier1Result",
"Tier3Provider",
"Tier3TokenUsage",
"Tier3Verdict",
"contains_boundary_patterns",
"create_prompt_defense",
Expand Down
17 changes: 17 additions & 0 deletions src/stackone_defender/core/prompt_defense.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
Tier3Provider,
Tier3Result,
Tier3Skip,
Tier3TokenUsage,
Tier3Verdict,
)
from .tool_result_sanitizer import ToolResultSanitizer, create_tool_result_sanitizer
Expand Down Expand Up @@ -292,6 +293,21 @@ def is_tier2_ready(self) -> bool:
def _resolve_tier3_provider(self) -> Tier3Provider | None:
return self._tier3_custom_provider or get_default_tier3_provider()

@staticmethod
def _parse_tier3_usage(usage: Any) -> Tier3TokenUsage | None:
if usage is None or not isinstance(usage, dict):
return None
prompt_tokens = usage.get("prompt_tokens", usage.get("promptTokens"))
completion_tokens = usage.get("completion_tokens", usage.get("completionTokens"))
total_tokens = usage.get("total_tokens", usage.get("totalTokens"))
if not all(isinstance(value, int) for value in (prompt_tokens, completion_tokens, total_tokens)):
return None
return Tier3TokenUsage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=total_tokens,
)

@staticmethod
def _validate_tier3_verdict(verdict: Any) -> Tier3Verdict | Tier3Skip:
if isinstance(verdict, Tier3Verdict):
Expand All @@ -317,6 +333,7 @@ def _validate_tier3_verdict(verdict: Any) -> Tier3Verdict | Tier3Skip:
score=verdict.get("score"),
raw=verdict.get("raw"),
latency_ms=verdict.get("latency_ms", verdict.get("latencyMs")),
usage=PromptDefense._parse_tier3_usage(verdict.get("usage")),
)

@staticmethod
Expand Down
10 changes: 10 additions & 0 deletions src/stackone_defender/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,15 @@ class Tier1Result:
latency_ms: float


@dataclass
class Tier3TokenUsage:
"""Token usage reported by a Tier 3 provider (e.g. vLLM or OpenAI ``usage``)."""

prompt_tokens: int
completion_tokens: int
total_tokens: int

Comment thread
hiskudin marked this conversation as resolved.

@dataclass
class Tier3Verdict:
"""Authoritative block/allow decision from a Tier 3 provider."""
Expand All @@ -73,6 +82,7 @@ class Tier3Verdict:
score: float | None = None
raw: Any = None
latency_ms: float | None = None
usage: Tier3TokenUsage | None = None


@dataclass
Expand Down
83 changes: 81 additions & 2 deletions tests/test_tier3.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from __future__ import annotations

import asyncio
from unittest.mock import MagicMock, patch
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

Expand All @@ -12,7 +12,7 @@
get_default_tier3_provider,
set_default_tier3_provider,
)
from stackone_defender.types import Tier3Skip, Tier3Verdict
from stackone_defender.types import Tier3Skip, Tier3TokenUsage, Tier3Verdict


def _make_provider(decision: str) -> MagicMock:
Expand Down Expand Up @@ -378,3 +378,82 @@ def test_sync_batch_with_tier3_uses_async_path(self):
results = defense.defend_tool_results(items)
assert len(results) == 1
assert isinstance(results[0].tier3, Tier3Verdict)


class TestTier3UsagePropagation:
def test_passes_provider_reported_usage_through_to_result_tier3(self):
provider = MagicMock()
provider.classify = AsyncMock(
return_value={
"decision": "allow",
"latencyMs": 42,
"usage": {
"promptTokens": 311,
"completionTokens": 17,
"totalTokens": 328,
},
}
)
defense = create_prompt_defense(
enable_tier1=False,
enable_tier2=False,
enable_tier3=True,
defender_mode="tier3_only",
tier3={"provider": provider},
)

result = asyncio.run(defense.defend_tool_result_async({"body": "test"}, "test_tool"))

assert isinstance(result.tier3, Tier3Verdict)
assert result.tier3.usage == Tier3TokenUsage(
prompt_tokens=311,
completion_tokens=17,
total_tokens=328,
)
assert result.tier3.latency_ms == 42

def test_passes_snake_case_usage_through_public_api(self):
provider = MagicMock()
provider.classify = AsyncMock(
return_value={
"decision": "allow",
"usage": {
"prompt_tokens": 10,
"completion_tokens": 5,
"total_tokens": 15,
},
}
)
defense = create_prompt_defense(
enable_tier1=False,
enable_tier2=False,
enable_tier3=True,
defender_mode="tier3_only",
tier3={"provider": provider},
)

result = asyncio.run(defense.defend_tool_result_async({"body": "test"}, "test_tool"))

assert isinstance(result.tier3, Tier3Verdict)
assert result.tier3.usage == Tier3TokenUsage(
prompt_tokens=10,
completion_tokens=5,
total_tokens=15,
)

def test_preserves_usage_when_provider_returns_tier3_verdict_instance(self):
usage = Tier3TokenUsage(prompt_tokens=1, completion_tokens=2, total_tokens=3)
provider = MagicMock()
provider.classify = AsyncMock(return_value=Tier3Verdict(decision="allow", usage=usage))
defense = create_prompt_defense(
enable_tier1=False,
enable_tier2=False,
enable_tier3=True,
defender_mode="tier3_only",
tier3={"provider": provider},
)

result = asyncio.run(defense.defend_tool_result_async({"body": "test"}, "test_tool"))

assert isinstance(result.tier3, Tier3Verdict)
assert result.tier3.usage == usage
2 changes: 1 addition & 1 deletion uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading