Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyrit/cli/_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,7 @@ async def print_scenario_result_async(*, result_dict: dict[str, Any]) -> None:
Args:
result_dict: ``ScenarioResult.to_dict()`` payload from the REST API.
"""
from pyrit.models.scenario_result import ScenarioResult
from pyrit.models import ScenarioResult
from pyrit.output.scenario_result.pretty import PrettyScenarioResultMemoryPrinter

scenario_result = ScenarioResult.from_dict(result_dict)
Expand Down
1 change: 0 additions & 1 deletion pyrit/memory/memory_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@
SeedSimulatedConversation,
SeedType,
)
from pyrit.models.scenario_result import ScenarioRunState

logger = logging.getLogger(__name__)

Expand Down
14 changes: 7 additions & 7 deletions pyrit/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,6 @@
from typing import TYPE_CHECKING, Any

from pyrit.common.deprecation import print_deprecation_message
from pyrit.models.chat_message import (
ALLOWED_CHAT_MESSAGE_ROLES,
ChatMessage,
ChatMessagesDataset,
)
from pyrit.models.conversation_reference import ConversationReference, ConversationType
from pyrit.models.conversation_stats import ConversationStats
from pyrit.models.embeddings import EmbeddingData, EmbeddingResponse, EmbeddingSupport, EmbeddingUsageInformation
from pyrit.models.harm_definition import HarmDefinition, ScaleDescription, get_all_harm_definitions
Expand Down Expand Up @@ -79,11 +73,17 @@
group_message_pieces_into_conversations,
sort_message_pieces,
)
from pyrit.models.messages.chat_message import (
ALLOWED_CHAT_MESSAGE_ROLES,
ChatMessage,
ChatMessagesDataset,
)
from pyrit.models.messages.conversation_reference import ConversationReference, ConversationType
from pyrit.models.question_answering import QuestionAnsweringDataset, QuestionAnsweringEntry, QuestionChoice
from pyrit.models.results.attack_result import AttackOutcome, AttackResult, AttackResultT
from pyrit.models.results.scenario_result import ScenarioIdentifier, ScenarioResult, ScenarioRunState
from pyrit.models.results.strategy_result import StrategyResult, StrategyResultT
from pyrit.models.retry_event import RetryEvent
from pyrit.models.scenario_result import ScenarioIdentifier, ScenarioResult, ScenarioRunState
from pyrit.models.score import Score, ScoreType, UnvalidatedScore

# Seeds - import from new seeds submodule for forward compatibility
Expand Down
16 changes: 16 additions & 0 deletions pyrit/models/messages/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,20 @@

- MessagePiece: A single piece of a message exchanged with a target.
- Message: One request/response to a target, made up of one or more pieces.
- ChatMessage: OpenAI-style wire shape consumed/emitted by prompt targets.
- Conversation: Conversation-scoped metadata shared by every piece.
- ConversationReference: Immutable reference to a conversation in an attack.
- conversations: Free functions that operate on collections of messages/pieces.
"""

from pyrit.models.messages.chat_message import (
ALLOWED_CHAT_MESSAGE_ROLES,
ChatMessage,
ChatMessagesDataset,
ToolCall,
)
from pyrit.models.messages.conversation import Conversation
from pyrit.models.messages.conversation_reference import ConversationReference, ConversationType
from pyrit.models.messages.conversations import (
construct_response_from_request,
flatten_to_message_pieces,
Expand All @@ -21,9 +31,15 @@
from pyrit.models.messages.message_piece import MessagePiece, sort_message_pieces

__all__ = [
"ALLOWED_CHAT_MESSAGE_ROLES",
"ChatMessage",
"ChatMessagesDataset",
"Conversation",
"ConversationReference",
"ConversationType",
"Message",
"MessagePiece",
"ToolCall",
"construct_response_from_request",
"flatten_to_message_pieces",
"get_all_values",
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,20 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

"""
OpenAI-format chat message types.

``ChatMessage`` is the OpenAI Chat Completions wire shape — a ``role`` plus a
string-or-multipart ``content``, with the OpenAI ``name`` / ``tool_calls`` /
``tool_call_id`` fields. Prompt targets that speak the OpenAI API (and the many
providers that mirror it) consume and emit these objects directly.

It is intentionally distinct from the PyRIT domain ``Message`` / ``MessagePiece``
types in this same package: those model a persisted request/response exchange,
whereas ``ChatMessage`` is the lightweight OpenAI-shaped transport representation
handed to a model API.
"""

from typing import Any

from pydantic import BaseModel, ConfigDict
Expand All @@ -21,9 +35,9 @@ class ToolCall(BaseModel):

class ChatMessage(BaseModel):
"""
Represents a chat message for API consumption.
Represents a single OpenAI Chat Completions message.

The content field can be:
Mirrors the OpenAI message schema. The content field can be:
- A simple string for single-part text messages
- A list of dicts for multipart messages (e.g., text + images)
"""
Expand Down
13 changes: 12 additions & 1 deletion pyrit/models/results/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,31 @@
# Licensed under the MIT license.

"""
Results module - strategy and attack result types for PyRIT.
Results module - strategy, attack, and scenario result types for PyRIT.

- StrategyResult: Base class for all strategy results.
- AttackResult: Result of an attack execution, with conversation/scoring evidence.
- AttackOutcome: Enum of possible attack outcomes.
- ScenarioResult: Aggregate result of a scenario run.
- ScenarioIdentifier: Identifier describing the executed scenario.
- ScenarioRunState: Lifecycle state of a scenario run.
"""

from pyrit.models.results.attack_result import AttackOutcome, AttackResult, AttackResultT
from pyrit.models.results.scenario_result import (
ScenarioIdentifier,
ScenarioResult,
ScenarioRunState,
)
from pyrit.models.results.strategy_result import StrategyResult, StrategyResultT

__all__ = [
"AttackOutcome",
"AttackResult",
"AttackResultT",
"ScenarioIdentifier",
"ScenarioResult",
"ScenarioRunState",
"StrategyResult",
"StrategyResultT",
]
2 changes: 1 addition & 1 deletion pyrit/models/results/attack_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
from pydantic import AwareDatetime, Field

from pyrit.common.deprecation import print_deprecation_message
from pyrit.models.conversation_reference import ConversationReference, ConversationType
from pyrit.models.identifiers.component_identifier import ComponentIdentifier
from pyrit.models.messages.conversation_reference import ConversationReference, ConversationType
from pyrit.models.messages.message_piece import MessagePiece
from pyrit.models.results.strategy_result import StrategyResult
from pyrit.models.retry_event import RetryEvent
Expand Down
3 changes: 1 addition & 2 deletions pyrit/output/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,7 @@

import os

from pyrit.models import AttackResult, ComponentIdentifier, Message, Score
from pyrit.models.scenario_result import ScenarioResult
from pyrit.models import AttackResult, ComponentIdentifier, Message, ScenarioResult, Score
from pyrit.output.attack_result.markdown import MarkdownAttackResultMemoryPrinter
from pyrit.output.attack_result.pretty import PrettyAttackResultMemoryPrinter
from pyrit.output.conversation.pretty import PrettyConversationMemoryPrinter
Expand Down
2 changes: 1 addition & 1 deletion pyrit/output/scenario_result/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from abc import abstractmethod

from pyrit.common.deprecation import print_deprecation_message
from pyrit.models.scenario_result import ScenarioResult
from pyrit.models import ScenarioResult
from pyrit.output.base import PrinterBase


Expand Down
3 changes: 1 addition & 2 deletions pyrit/output/scenario_result/pretty.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@
from colorama import Fore, Style

from pyrit.common.deprecation import print_deprecation_message
from pyrit.models import AttackOutcome
from pyrit.models.scenario_result import ScenarioResult
from pyrit.models import AttackOutcome, ScenarioResult
from pyrit.output.scenario_result.base import ScenarioResultPrinterBase
from pyrit.output.scorer.base import ScorerPrinterBase
from pyrit.output.sink import Sink
Expand Down
2 changes: 1 addition & 1 deletion pyrit/scenario/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from types import ModuleType

from pyrit.common.parameter import Parameter
from pyrit.models.scenario_result import ScenarioIdentifier, ScenarioResult
from pyrit.models import ScenarioIdentifier, ScenarioResult
from pyrit.scenario.core import (
AtomicAttack,
AttackTechnique,
Expand Down
10 changes: 8 additions & 2 deletions pyrit/scenario/core/scenario.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,14 @@
from pyrit.executor.attack.single_turn.prompt_sending import PromptSendingAttack
from pyrit.memory import CentralMemory
from pyrit.memory.memory_models import ScenarioResultEntry
from pyrit.models import AttackOutcome, AttackResult, SeedAttackGroup
from pyrit.models.scenario_result import ScenarioIdentifier, ScenarioResult, ScenarioRunState
from pyrit.models import (
AttackOutcome,
AttackResult,
ScenarioIdentifier,
ScenarioResult,
ScenarioRunState,
SeedAttackGroup,
)
from pyrit.prompt_target import PromptTarget
from pyrit.prompt_target.common.target_requirements import TargetRequirements
from pyrit.registry import ScorerRegistry
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ class TestScenarioResultContract:

def test_scenario_result_importable(self):
"""ScenarioOrchestrator reads ScenarioResult."""
from pyrit.models.scenario_result import ScenarioResult
from pyrit.models import ScenarioResult

assert ScenarioResult is not None

Expand Down
16 changes: 8 additions & 8 deletions tests/unit/backend/test_attack_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -1757,7 +1757,7 @@ async def test_conversation_summary_formats_media_preview(self, attack_service,

async def test_returns_main_and_related_conversations(self, attack_service, mock_memory):
"""Should return main and PRUNED conversations sorted by timestamp."""
from pyrit.models.conversation_reference import ConversationReference, ConversationType
from pyrit.models import ConversationReference, ConversationType

ar = make_attack_result(conversation_id="attack-1")
ar.related_conversations.add(
Expand Down Expand Up @@ -1910,7 +1910,7 @@ async def test_raises_when_conversation_not_part_of_attack(self, attack_service,

async def test_swaps_main_conversation(self, attack_service, mock_memory):
"""Changing the main to a related conversation should swap it with the main."""
from pyrit.models.conversation_reference import ConversationReference, ConversationType
from pyrit.models import ConversationReference, ConversationType

ar = make_attack_result(conversation_id="attack-1")
ar.related_conversations = {
Expand Down Expand Up @@ -1950,7 +1950,7 @@ class TestAddMessageTargetConversation:
async def test_stores_message_in_target_conversation(self, attack_service, mock_memory):
"""When target_conversation_id is set, messages should go to that conversation."""
from pyrit.backend.models.attacks import AttackSummary, ConversationMessagesResponse
from pyrit.models.conversation_reference import ConversationReference, ConversationType
from pyrit.models import ConversationReference, ConversationType

ar = make_attack_result(conversation_id="attack-1")
ar.related_conversations = {
Expand Down Expand Up @@ -2017,7 +2017,7 @@ class TestConversationCount:

async def test_list_attacks_includes_related_conversation_ids(self, attack_service, mock_memory):
"""Attacks with related conversations should expose them in the summary."""
from pyrit.models.conversation_reference import ConversationReference, ConversationType
from pyrit.models import ConversationReference, ConversationType

ar = make_attack_result(conversation_id="attack-1")
ar.related_conversations = {
Expand Down Expand Up @@ -2069,7 +2069,7 @@ async def test_create_conversation_increments_count(self, attack_service, mock_m
async def test_create_second_conversation_preserves_first(self, attack_service, mock_memory):
"""Creating a second related conversation should keep the first one."""
from pyrit.backend.models.attacks import CreateConversationRequest
from pyrit.models.conversation_reference import ConversationReference, ConversationType
from pyrit.models import ConversationReference, ConversationType

ar = make_attack_result(conversation_id="attack-1")
ar.related_conversations = {
Expand Down Expand Up @@ -2099,7 +2099,7 @@ class TestConversationSorting:

async def test_conversations_sorted_by_created_at_earliest_first(self, attack_service, mock_memory):
"""Conversations should be sorted by created_at with earliest first."""
from pyrit.models.conversation_reference import ConversationReference, ConversationType
from pyrit.models import ConversationReference, ConversationType

ar = make_attack_result(conversation_id="attack-1")
ar.related_conversations = {
Expand Down Expand Up @@ -2127,7 +2127,7 @@ async def test_conversations_sorted_by_created_at_earliest_first(self, attack_se

async def test_empty_conversations_sorted_last(self, attack_service, mock_memory):
"""Conversations with no timestamp should appear at the bottom."""
from pyrit.models.conversation_reference import ConversationReference, ConversationType
from pyrit.models import ConversationReference, ConversationType

ar = make_attack_result(conversation_id="attack-1")
ar.related_conversations = {
Expand All @@ -2153,7 +2153,7 @@ async def test_empty_conversations_sorted_last(self, attack_service, mock_memory

async def test_empty_conversations_all_sort_last(self, attack_service, mock_memory):
"""Multiple empty conversations should all have created_at=None."""
from pyrit.models.conversation_reference import ConversationReference, ConversationType
from pyrit.models import ConversationReference, ConversationType

ar = make_attack_result(conversation_id="attack-1")
ar.related_conversations = {
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/backend/test_mappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,7 @@ async def test_no_converters_returns_empty_list(self) -> None:

async def test_related_conversation_ids_from_related_conversations(self) -> None:
"""Test that related_conversation_ids includes all related conversation IDs."""
from pyrit.models.conversation_reference import ConversationReference, ConversationType
from pyrit.models import ConversationReference, ConversationType

ar = _make_attack_result()
ar.related_conversations = {
Expand Down
3 changes: 2 additions & 1 deletion tests/unit/backend/test_response_contracts.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,13 @@
from pyrit.models import (
AttackResult,
ComponentIdentifier,
ConversationReference,
ConversationType,
MessagePiece,
RetryEvent,
Score,
build_atomic_attack_identifier,
)
from pyrit.models.conversation_reference import ConversationReference, ConversationType


def _make_score() -> Score:
Expand Down
5 changes: 2 additions & 3 deletions tests/unit/cli/test_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,7 @@ async def test_print_scenario_result_async_uses_pretty_printer():
fake_printer.write_async = AsyncMock()

with (
patch("pyrit.models.scenario_result.ScenarioResult.from_dict", return_value=fake_scenario) as from_dict_mock,
patch("pyrit.models.ScenarioResult.from_dict", return_value=fake_scenario) as from_dict_mock,
patch(
"pyrit.output.scenario_result.pretty.PrettyScenarioResultMemoryPrinter", return_value=fake_printer
) as printer_cls,
Expand All @@ -339,8 +339,7 @@ async def test_print_scenario_result_async_roundtrip_with_real_payload():
"""
from datetime import datetime, timezone

from pyrit.models import AttackOutcome, AttackResult, ComponentIdentifier
from pyrit.models.scenario_result import ScenarioIdentifier, ScenarioResult
from pyrit.models import AttackOutcome, AttackResult, ComponentIdentifier, ScenarioIdentifier, ScenarioResult

identifier = ScenarioIdentifier(name="test.scenario", description="A test")
target_identifier = ComponentIdentifier.from_dict(
Expand Down
3 changes: 1 addition & 2 deletions tests/unit/models/test_attack_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@
import pytest

from pyrit.memory.memory_models import AttackResultEntry
from pyrit.models import ComponentIdentifier
from pyrit.models.conversation_reference import ConversationReference, ConversationType
from pyrit.models import ComponentIdentifier, ConversationReference, ConversationType
from pyrit.models.messages.message_piece import MessagePiece
from pyrit.models.results.attack_result import AttackOutcome, AttackResult
from pyrit.models.retry_event import RetryEvent
Expand Down
7 changes: 2 additions & 5 deletions tests/unit/models/test_chat_message.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,8 @@
import pytest
from pydantic import ValidationError

from pyrit.models.chat_message import (
ChatMessage,
ChatMessagesDataset,
ToolCall,
)
from pyrit.models import ChatMessage, ChatMessagesDataset
from pyrit.models.messages.chat_message import ToolCall


def test_tool_call_init():
Expand Down
Loading
Loading