Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions doc/code/memory/1_sqlite_memory.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,7 @@
" timestamp: DATETIME NOT NULL\n",
" labels: JSON NOT NULL\n",
" prompt_metadata: JSON NOT NULL\n",
" targeted_harm_categories: JSON NULL\n",
" converter_identifiers: JSON NULL\n",
" prompt_target_identifier: JSON NOT NULL\n",
" attack_identifier: JSON NOT NULL\n",
" response_error: VARCHAR NULL\n",
" original_value_data_type: VARCHAR NOT NULL\n",
" original_value: VARCHAR NOT NULL\n",
Expand All @@ -49,6 +46,12 @@
" original_prompt_id: CHAR(32) NOT NULL\n",
" pyrit_version: VARCHAR NULL\n",
"\n",
"Table: Conversations\n",
"--------------------\n",
" conversation_id: VARCHAR NOT NULL\n",
" target_identifier: JSON NULL\n",
" pyrit_version: VARCHAR NULL\n",
"\n",
"Table: EmbeddingData\n",
"--------------------\n",
" id: CHAR(32) NOT NULL\n",
Expand Down Expand Up @@ -98,7 +101,6 @@
" id: CHAR(32) NOT NULL\n",
" conversation_id: VARCHAR NOT NULL\n",
" objective: VARCHAR NOT NULL\n",
" attack_identifier: JSON NOT NULL\n",
" atomic_attack_identifier: JSON NULL\n",
" objective_sha256: VARCHAR NULL\n",
" last_response_id: CHAR(32) NULL\n",
Expand All @@ -109,6 +111,7 @@
" outcome_reason: VARCHAR NULL\n",
" attack_metadata: JSON NULL\n",
" labels: JSON NULL\n",
" targeted_harm_categories: JSON NULL\n",
" pruned_conversation_ids: JSON NULL\n",
" adversarial_chat_conversation_ids: JSON NULL\n",
" timestamp: DATETIME NOT NULL\n",
Expand Down
5 changes: 1 addition & 4 deletions doc/code/memory/3_memory_data_types.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,7 @@ One of the most fundamental data structures in PyRIT is [MessagePiece](../../../
- **`labels`**: Dictionary of labels for categorization and filtering
- **`prompt_metadata`**: Component-specific metadata (e.g., blob URIs, document types)
- **`converter_identifiers`**: List of converters applied to transform the prompt
- **`scorer_identifier`**: Information about the scorer that evaluated this prompt
- **`response_error`**: Error status (e.g., `none`, `blocked`, `processing`)
- **`originator`**: Source of the prompt (`attack`, `converter`, `scorer`, `undefined`)
- **`scores`**: List of `Score` objects associated with this piece
- **`targeted_harm_categories`**: Harm categories associated with the prompt
- **`timestamp`**: When the piece was created

This rich context allows PyRIT to track the full lifecycle of each interaction, including transformations, targeting, scoring, and error handling.
Expand Down Expand Up @@ -135,6 +131,7 @@ Scores enable automated evaluation of attack success, content harmfulness, and o
- **`outcome_reason`**: Optional explanation for the outcome
- **`related_conversations`**: Set of related conversation references
- **`metadata`**: Arbitrary metadata about the attack execution
- **`targeted_harm_categories`**: Harm categories this attack targeted, auto-populated from the attack's seed group

`AttackResult` objects provide comprehensive reporting on attack campaigns, enabling analysis of red teaming effectiveness and vulnerability identification.

Expand Down
4 changes: 4 additions & 0 deletions pyrit/analytics/technique_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ def compute_technique_stats(
*,
technique_eval_hashes: Sequence[str],
scenario_result_id: str | None = None,
targeted_harm_categories: Sequence[str] | None = None,
memory: MemoryInterface | None = None,
) -> dict[str, AttackStats]:
"""
Expand All @@ -39,6 +40,8 @@ def compute_technique_stats(
Returned dict is keyed by these.
scenario_result_id (str | None): Restrict to a single scenario run.
Defaults to ``None`` (aggregate across all runs).
targeted_harm_categories (Sequence[str] | None): Restrict to results
whose attack targeted any of these harm categories. Defaults to ``None``
(no harm-category filter).
memory (MemoryInterface | None): Memory backend to query. Defaults to
``CentralMemory.get_memory_instance()``.

Expand All @@ -54,6 +57,7 @@ def compute_technique_stats(
results = memory.get_attack_results(
atomic_attack_eval_hashes=list(technique_eval_hashes),
scenario_result_id=scenario_result_id,
targeted_harm_categories=targeted_harm_categories,
)

requested = set(technique_eval_hashes)
Expand Down
10 changes: 9 additions & 1 deletion pyrit/executor/attack/core/attack_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from pyrit.models import Message, SeedAttackGroup, SeedGroup

if TYPE_CHECKING:
from pyrit.models import SeedUnion
from pyrit.prompt_target import PromptTarget
from pyrit.score import TrueFalseScorer

Expand Down Expand Up @@ -41,6 +42,10 @@ class AttackParameters:
# Additional labels that can be applied to the prompts throughout the attack
memory_labels: dict[str, str] | None = field(default_factory=dict)

# Harm categories targeted by this attack, derived from the seed group's
# seeds. Stamped onto the produced AttackResult.
targeted_harm_categories: list[str] = field(default_factory=list)

def __str__(self) -> str:
"""Return a nicely formatted string representation of the attack parameters."""
lines = [f"{self.__class__.__name__}:"]
Expand Down Expand Up @@ -138,6 +143,9 @@ async def from_seed_group_async(
if "memory_labels" in valid_fields:
params["memory_labels"] = {}

if "targeted_harm_categories" in valid_fields:
params["targeted_harm_categories"] = list(seed_group.harm_categories)

# Determine which group to use for extracting prepended_conversation/next_message
extraction_group: SeedGroup = seed_group

Expand All @@ -164,7 +172,7 @@ async def from_seed_group_async(
)

# Merge simulated prompts with existing static prompts from the seed_group
all_prompts = list(seed_group.prompts) + simulated_prompts
all_prompts: list[SeedUnion] = [*seed_group.prompts, *simulated_prompts]

# Create a temporary prompts-only SeedGroup for extraction
# This group contains only prompts (no objective, no simulated config)
Expand Down
26 changes: 26 additions & 0 deletions pyrit/executor/attack/core/attack_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,7 @@ async def _on_post_execute_async(
# AttackResultEntry row records its lineage. Outside an orchestrator
# _attribution is None and both attribution fields stay None.
self._apply_attribution(context=event_data.context, result=event_data.result)
self._apply_targeted_harm_categories(context=event_data.context, result=event_data.result)

self._logger.debug(f"Attack execution completed in {execution_time_ms}ms")

Expand Down Expand Up @@ -275,6 +276,30 @@ def _apply_attribution(
attribution_data["parent_eval_hash"] = attribution.parent_eval_hash
result.attribution_data = attribution_data

@staticmethod
def _apply_targeted_harm_categories(
    *,
    context: AttackStrategyContextT,
    result: AttackResult,
) -> None:
    """
    Stamp the targeted harm categories from the attack parameters onto the result.

    Looks up ``context.params.targeted_harm_categories`` and, when it holds at
    least one entry, assigns a fresh list copy to
    ``result.targeted_harm_categories``. Both lookups fall back to ``None``,
    so a context without ``params`` (or parameters without the field) leaves
    the result untouched.

    Args:
        context: The per-task attack context whose ``params`` are inspected.
        result: The AttackResult that is annotated in place.
    """
    attack_params = getattr(context, "params", None)
    categories = getattr(attack_params, "targeted_harm_categories", None)
    if not categories:
        # Missing attribute, ``None`` or an empty collection: nothing to stamp.
        return
    result.targeted_harm_categories = list(categories)

def _log_attack_outcome(self, result: AttackResult) -> None:
"""
Log the outcome of the attack.
Expand Down Expand Up @@ -342,6 +367,7 @@ async def _on_error_async(
# Stamp attribution onto the error result so it is locatable via the
# attribution_parent_id foreign key on resume.
self._apply_attribution(context=context, result=error_result)
self._apply_targeted_harm_categories(context=context, result=error_result)

self._memory.add_attack_results_to_memory(attack_results=[error_result])

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

"""
Add targeted_harm_categories to Attack Results.

Adds a nullable JSON ``targeted_harm_categories`` column to the
``AttackResultEntries`` table. No backfill.

Revision ID: c3d5e7f9a1b2
Revises: b2f4c6a8d1e3
Create Date: 2026-06-11 17:55:00.000000
"""

from collections.abc import Sequence

import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision: str = "c3d5e7f9a1b2"
down_revision: str | None = "b2f4c6a8d1e3"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None


def upgrade() -> None:
    """Add the nullable JSON ``targeted_harm_categories`` column to ``AttackResultEntries``."""
    harm_categories_column = sa.Column("targeted_harm_categories", sa.JSON(), nullable=True)
    op.add_column("AttackResultEntries", harm_categories_column)


def downgrade() -> None:
    """Drop the ``targeted_harm_categories`` column from ``AttackResultEntries``."""
    op.drop_column(table_name="AttackResultEntries", column_name="targeted_harm_categories")
19 changes: 19 additions & 0 deletions pyrit/memory/memory_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -1890,6 +1890,7 @@ def get_attack_results(
converter_classes_match: Literal["all", "any"] = "all",
has_converters: bool | None = None,
labels: dict[str, str | Sequence[str]] | None = None,
targeted_harm_categories: Sequence[str] | None = None,
identifier_filters: Sequence[IdentifierFilter] | None = None,
scenario_result_id: str | None = None,
) -> Sequence[AttackResult]:
Expand Down Expand Up @@ -1933,6 +1934,12 @@ def get_attack_results(
["roakey_op_a", "roakey_op_b"]}`` matches attacks where ``operator ==
"roakey"`` AND (``operation == "roakey_op_a"`` OR ``operation ==
"roakey_op_b"``). Defaults to None.
targeted_harm_categories (Sequence[str] | None, optional): Filter results by the
harm categories targeted by the attack (stored on
``AttackResultEntry.targeted_harm_categories``, auto-populated from the
attack's SeedGroup). Returns attacks targeting ANY of the listed categories
(OR logic, case-insensitive). An empty sequence applies no filter. Defaults
to None.
identifier_filters (Sequence[IdentifierFilter] | None, optional):
A sequence of IdentifierFilter objects that allows filtering by various attack identifier
JSON properties. Defaults to None.
Expand Down Expand Up @@ -2049,6 +2056,18 @@ def get_attack_results(
# Use database-specific JSON query method
conditions.append(self._get_attack_result_label_condition(labels=effective_labels))

if targeted_harm_categories:
# Match attacks whose targeted_harm_categories array contains ANY of the
# requested categories.
conditions.append(
self._get_condition_json_array_match(
json_column=AttackResultEntry.targeted_harm_categories,
property_path="$",
array_to_match=list(targeted_harm_categories),
match_mode="any",
)
)

if identifier_filters:
conditions.extend(
self._build_identifier_filter_conditions(
Expand Down
4 changes: 4 additions & 0 deletions pyrit/memory/memory_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -852,6 +852,7 @@ class AttackResultEntry(Base):
outcome_reason (str): Optional reason for the outcome, providing additional context.
attack_metadata (dict[str, Any]): Metadata can be included as key-value pairs to provide extra context.
labels (dict[str, str]): Optional labels associated with the attack result entry.
targeted_harm_categories (list[str]): Harm categories this attack targeted.
pruned_conversation_ids (list[str]): List of conversation IDs that were pruned from the attack.
adversarial_chat_conversation_ids (list[str]): List of conversation IDs used for adversarial chat.
timestamp (DateTime): The timestamp of the attack result entry.
Expand Down Expand Up @@ -883,6 +884,7 @@ class AttackResultEntry(Base):
outcome_reason = mapped_column(String, nullable=True)
attack_metadata: Mapped[dict[str, str | int | float | bool] | None] = mapped_column(JSON, nullable=True)
labels: Mapped[dict[str, str] | None] = mapped_column(JSON, nullable=True)
targeted_harm_categories: Mapped[list[str] | None] = mapped_column(JSON, nullable=True)
pruned_conversation_ids: Mapped[list[str] | None] = mapped_column(JSON, nullable=True)
adversarial_chat_conversation_ids: Mapped[list[str] | None] = mapped_column(JSON, nullable=True)
timestamp = mapped_column(UTCDateTime, nullable=False)
Expand Down Expand Up @@ -949,6 +951,7 @@ def __init__(self, *, entry: AttackResult) -> None:
self.outcome_reason = entry.outcome_reason
self.attack_metadata = self.filter_json_serializable_metadata(entry.metadata)
self.labels = entry.labels or {}
self.targeted_harm_categories = entry.targeted_harm_categories or None

# Persist conversation references by type
self.pruned_conversation_ids = [
Expand Down Expand Up @@ -1076,6 +1079,7 @@ def get_attack_result(self) -> AttackResult:
metadata=self.attack_metadata or {},
timestamp=self.timestamp or datetime.now(tz=timezone.utc),
labels=self.labels or {},
targeted_harm_categories=self.targeted_harm_categories or [],
error_message=self.error_message,
error_type=self.error_type,
error_traceback=self.error_traceback,
Expand Down
7 changes: 7 additions & 0 deletions pyrit/models/results/attack_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,11 @@ class AttackResult(StrategyResult):
# labels associated with this attack result
labels: dict[str, str] = Field(default_factory=dict)

# Harm categories this attack targeted. Auto-populated from the attack's
# SeedGroup (the deduplicated union of its seeds' harm_categories) when the
# result is produced by an attack strategy.
targeted_harm_categories: list[str] = Field(default_factory=list)

# Error information (populated when attack fails with exception)
error_message: str | None = None
error_type: str | None = None
Expand Down Expand Up @@ -244,6 +249,7 @@ def to_dict(self) -> dict[str, Any]:
),
"metadata": self.metadata,
"labels": self.labels,
"targeted_harm_categories": self.targeted_harm_categories,
"error_message": self.error_message,
"error_type": self.error_type,
"error_traceback": self.error_traceback,
Expand Down Expand Up @@ -294,6 +300,7 @@ def from_dict(cls, data: dict[str, Any]) -> AttackResult:
},
metadata=data.get("metadata", {}),
labels=data.get("labels", {}),
targeted_harm_categories=data.get("targeted_harm_categories", []),
error_message=data.get("error_message"),
error_type=data.get("error_type"),
error_traceback=data.get("error_traceback"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ async def select_async(
stats = compute_technique_stats(
technique_eval_hashes=technique_list,
scenario_result_id=effective_run_id,
targeted_harm_categories=self._scope.targeted_harm_categories,
)

chosen: list[str] = []
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,9 @@ class SelectorScope:
queries when estimating technique success rates.

All fields default to "no restriction"; combine fields to narrow the
scope (e.g. current run only). Filter values flow through
``compute_technique_stats`` to ``MemoryInterface.get_attack_results``.
scope (e.g. current run only, same harm category). Filter values flow
through ``compute_technique_stats`` to
``MemoryInterface.get_attack_results``.

The scope is held by the selector at construction time. The per-call
``scenario_result_id`` is supplied by the dispatcher and is forwarded
Expand All @@ -37,6 +38,10 @@ class SelectorScope:
"""Restrict to the dispatcher-supplied ``scenario_result_id`` for the
in-flight run. When ``False`` (default), query across all runs."""

targeted_harm_categories: Sequence[str] | None = None
"""Filter to results whose attack targeted any of these harm categories.
``None`` (the default) or an empty sequence means no harm-category filter."""

@classmethod
def all_runs(cls) -> SelectorScope:
"""
Expand Down
10 changes: 10 additions & 0 deletions tests/unit/analytics/test_technique_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ def test_passes_eval_hashes_to_memory_query(self, _patch_memory):
call_kwargs = _patch_memory.get_attack_results.call_args[1]
assert call_kwargs["atomic_attack_eval_hashes"] == ["x", "y"]
assert call_kwargs["scenario_result_id"] is None
assert call_kwargs["targeted_harm_categories"] is None

def test_passes_scenario_result_id_to_memory_query(self, _patch_memory):
compute_technique_stats(technique_eval_hashes=["x"], scenario_result_id="run-123")
Expand Down Expand Up @@ -123,6 +124,15 @@ def test_success_rate_computed(self, _patch_memory):

assert stats["a"].success_rate == pytest.approx(0.5)

def test_passes_harm_categories_to_memory_query(self, _patch_memory):
    """The harm-category filter is forwarded unchanged to ``get_attack_results``."""
    requested_categories = ["misinformation", "hate"]

    compute_technique_stats(
        technique_eval_hashes=["x"],
        targeted_harm_categories=requested_categories,
    )

    # ``call_args.kwargs`` is the keyword half of the recorded (args, kwargs) pair.
    forwarded_kwargs = _patch_memory.get_attack_results.call_args.kwargs
    assert forwarded_kwargs["targeted_harm_categories"] == ["misinformation", "hate"]

def test_injected_memory_bypasses_central_memory(self, _patch_memory):
injected = MagicMock()
injected.get_attack_results.return_value = [
Expand Down
18 changes: 18 additions & 0 deletions tests/unit/executor/attack/core/test_attack_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,24 @@ async def test_extracts_objective_from_seed_group(self, seed_group_with_objectiv

assert params.objective == "Test objective"

async def test_extracts_targeted_harm_categories_from_seed_group(self) -> None:
    """The parameters carry the distinct harm categories declared across the group's seeds."""
    seeds = [
        SeedObjective(value="Test objective", harm_categories=["violence"]),
        SeedPrompt(value="Test prompt", data_type="text", role="user", harm_categories=["hate", "violence"]),
    ]

    built_params = await AttackParameters.from_seed_group_async(seed_group=SeedAttackGroup(seeds=seeds))

    # "violence" is declared on both seeds but is expected only once.
    assert sorted(built_params.targeted_harm_categories) == ["hate", "violence"]

async def test_targeted_harm_categories_empty_when_seed_group_has_none(
    self, seed_group_with_objective: SeedAttackGroup
) -> None:
    """Without harm categories declared on the seeds, the parameters hold an empty list."""
    built_params = await AttackParameters.from_seed_group_async(seed_group=seed_group_with_objective)

    harm_categories = built_params.targeted_harm_categories
    assert harm_categories == []

async def test_raises_when_no_objective(self) -> None:
"""Test that ValueError is raised when SeedAttackGroup has no objective."""
# SeedAttackGroup now validates exactly one objective at construction
Expand Down
Loading
Loading