diff --git a/pyrit/cli/_output.py b/pyrit/cli/_output.py index 3580c8e5b6..507b14629f 100644 --- a/pyrit/cli/_output.py +++ b/pyrit/cli/_output.py @@ -284,7 +284,7 @@ async def print_scenario_result_async(*, result_dict: dict[str, Any]) -> None: Args: result_dict: ``ScenarioResult.to_dict()`` payload from the REST API. """ - from pyrit.models.scenario_result import ScenarioResult + from pyrit.models import ScenarioResult from pyrit.output.scenario_result.pretty import PrettyScenarioResultMemoryPrinter scenario_result = ScenarioResult.from_dict(result_dict) diff --git a/pyrit/memory/memory_models.py b/pyrit/memory/memory_models.py index ab467c31a7..f3a20abdd4 100644 --- a/pyrit/memory/memory_models.py +++ b/pyrit/memory/memory_models.py @@ -55,7 +55,6 @@ SeedSimulatedConversation, SeedType, ) -from pyrit.models.scenario_result import ScenarioRunState logger = logging.getLogger(__name__) diff --git a/pyrit/models/__init__.py b/pyrit/models/__init__.py index b5ffe6be7b..6e3cb93a0a 100644 --- a/pyrit/models/__init__.py +++ b/pyrit/models/__init__.py @@ -21,12 +21,6 @@ from typing import TYPE_CHECKING, Any from pyrit.common.deprecation import print_deprecation_message -from pyrit.models.chat_message import ( - ALLOWED_CHAT_MESSAGE_ROLES, - ChatMessage, - ChatMessagesDataset, -) -from pyrit.models.conversation_reference import ConversationReference, ConversationType from pyrit.models.conversation_stats import ConversationStats from pyrit.models.embeddings import EmbeddingData, EmbeddingResponse, EmbeddingSupport, EmbeddingUsageInformation from pyrit.models.harm_definition import HarmDefinition, ScaleDescription, get_all_harm_definitions @@ -79,11 +73,17 @@ group_message_pieces_into_conversations, sort_message_pieces, ) +from pyrit.models.messages.chat_message import ( + ALLOWED_CHAT_MESSAGE_ROLES, + ChatMessage, + ChatMessagesDataset, +) +from pyrit.models.messages.conversation_reference import ConversationReference, ConversationType from pyrit.models.question_answering import QuestionAnsweringDataset, QuestionAnsweringEntry, QuestionChoice from pyrit.models.results.attack_result import AttackOutcome, AttackResult, AttackResultT +from pyrit.models.results.scenario_result import ScenarioIdentifier, ScenarioResult, ScenarioRunState from pyrit.models.results.strategy_result import StrategyResult, StrategyResultT from pyrit.models.retry_event import RetryEvent -from pyrit.models.scenario_result import ScenarioIdentifier, ScenarioResult, ScenarioRunState from pyrit.models.score import Score, ScoreType, UnvalidatedScore # Seeds - import from new seeds submodule for forward compatibility diff --git a/pyrit/models/messages/__init__.py b/pyrit/models/messages/__init__.py index 58c9f1a63e..fc0817579c 100644 --- a/pyrit/models/messages/__init__.py +++ b/pyrit/models/messages/__init__.py @@ -6,10 +6,20 @@ - MessagePiece: A single piece of a message exchanged with a target. - Message: One request/response to a target, made up of one or more pieces. +- ChatMessage: OpenAI-style wire shape consumed/emitted by prompt targets. +- Conversation: Conversation-scoped metadata shared by every piece. +- ConversationReference: Immutable reference to a conversation in an attack. - conversations: Free functions that operate on collections of messages/pieces. """ +from pyrit.models.messages.chat_message import ( + ALLOWED_CHAT_MESSAGE_ROLES, + ChatMessage, + ChatMessagesDataset, + ToolCall, +) from pyrit.models.messages.conversation import Conversation +from pyrit.models.messages.conversation_reference import ConversationReference, ConversationType from pyrit.models.messages.conversations import ( construct_response_from_request, flatten_to_message_pieces, @@ -21,9 +31,15 @@ from pyrit.models.messages.message_piece import MessagePiece, sort_message_pieces __all__ = [ + "ALLOWED_CHAT_MESSAGE_ROLES", + "ChatMessage", + "ChatMessagesDataset", "Conversation", + "ConversationReference", + "ConversationType", "Message", "MessagePiece", + "ToolCall", "construct_response_from_request", "flatten_to_message_pieces", "get_all_values", diff --git a/pyrit/models/chat_message.py b/pyrit/models/messages/chat_message.py similarity index 65% rename from pyrit/models/chat_message.py rename to pyrit/models/messages/chat_message.py index 99939704a5..57e2b0dbab 100644 --- a/pyrit/models/chat_message.py +++ b/pyrit/models/messages/chat_message.py @@ -1,6 +1,20 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. +""" +OpenAI-format chat message types. + +``ChatMessage`` is the OpenAI Chat Completions wire shape — a ``role`` plus a +string-or-multipart ``content``, with the OpenAI ``name`` / ``tool_calls`` / +``tool_call_id`` fields. Prompt targets that speak the OpenAI API (and the many +providers that mirror it) consume and emit these objects directly. + +It is intentionally distinct from the PyRIT domain ``Message`` / ``MessagePiece`` +types in this same package: those model a persisted request/response exchange, +whereas ``ChatMessage`` is the lightweight OpenAI-shaped transport representation +handed to a model API. +""" + from typing import Any from pydantic import BaseModel, ConfigDict @@ -21,9 +35,9 @@ class ToolCall(BaseModel): class ChatMessage(BaseModel): """ - Represents a chat message for API consumption. + Represents a single OpenAI Chat Completions message. - The content field can be: + Mirrors the OpenAI message schema. The content field can be: - A simple string for single-part text messages - A list of dicts for multipart messages (e.g., text + images) """ diff --git a/pyrit/models/conversation_reference.py b/pyrit/models/messages/conversation_reference.py similarity index 100% rename from pyrit/models/conversation_reference.py rename to pyrit/models/messages/conversation_reference.py diff --git a/pyrit/models/results/__init__.py b/pyrit/models/results/__init__.py index 4bcc2f8848..b57cb1ef37 100644 --- a/pyrit/models/results/__init__.py +++ b/pyrit/models/results/__init__.py @@ -2,20 +2,31 @@ # Licensed under the MIT license. """ -Results module - strategy and attack result types for PyRIT. +Results module - strategy, attack, and scenario result types for PyRIT. - StrategyResult: Base class for all strategy results. - AttackResult: Result of an attack execution, with conversation/scoring evidence. - AttackOutcome: Enum of possible attack outcomes. +- ScenarioResult: Aggregate result of a scenario run. +- ScenarioIdentifier: Identifier describing the executed scenario. +- ScenarioRunState: Lifecycle state of a scenario run. """ from pyrit.models.results.attack_result import AttackOutcome, AttackResult, AttackResultT +from pyrit.models.results.scenario_result import ( + ScenarioIdentifier, + ScenarioResult, + ScenarioRunState, +) from pyrit.models.results.strategy_result import StrategyResult, StrategyResultT __all__ = [ "AttackOutcome", "AttackResult", "AttackResultT", + "ScenarioIdentifier", + "ScenarioResult", + "ScenarioRunState", "StrategyResult", "StrategyResultT", ] diff --git a/pyrit/models/results/attack_result.py b/pyrit/models/results/attack_result.py index 648c837214..138d3dd38e 100644 --- a/pyrit/models/results/attack_result.py +++ b/pyrit/models/results/attack_result.py @@ -11,8 +11,8 @@ from pydantic import AwareDatetime, Field from pyrit.common.deprecation import print_deprecation_message -from pyrit.models.conversation_reference import ConversationReference, ConversationType from pyrit.models.identifiers.component_identifier import ComponentIdentifier +from pyrit.models.messages.conversation_reference import ConversationReference, ConversationType from pyrit.models.messages.message_piece import MessagePiece from pyrit.models.results.strategy_result import StrategyResult from pyrit.models.retry_event import RetryEvent diff --git a/pyrit/models/scenario_result.py b/pyrit/models/results/scenario_result.py similarity index 100% rename from pyrit/models/scenario_result.py rename to pyrit/models/results/scenario_result.py diff --git a/pyrit/output/helpers.py b/pyrit/output/helpers.py index 87923d5862..c1a4b21b15 100644 --- a/pyrit/output/helpers.py +++ b/pyrit/output/helpers.py @@ -11,8 +11,7 @@ import os -from pyrit.models import AttackResult, ComponentIdentifier, Message, Score -from pyrit.models.scenario_result import ScenarioResult +from pyrit.models import AttackResult, ComponentIdentifier, Message, ScenarioResult, Score from pyrit.output.attack_result.markdown import MarkdownAttackResultMemoryPrinter from pyrit.output.attack_result.pretty import PrettyAttackResultMemoryPrinter from pyrit.output.conversation.pretty import PrettyConversationMemoryPrinter diff --git a/pyrit/output/scenario_result/base.py b/pyrit/output/scenario_result/base.py index 579d480acc..13972d9ac5 100644 --- a/pyrit/output/scenario_result/base.py +++ b/pyrit/output/scenario_result/base.py @@ -4,7 +4,7 @@ from abc import abstractmethod from pyrit.common.deprecation import print_deprecation_message -from pyrit.models.scenario_result import ScenarioResult +from pyrit.models import ScenarioResult from pyrit.output.base import PrinterBase diff --git a/pyrit/output/scenario_result/pretty.py b/pyrit/output/scenario_result/pretty.py index 5abbc807ec..d8654c0bfd 100644 --- a/pyrit/output/scenario_result/pretty.py +++ b/pyrit/output/scenario_result/pretty.py @@ -6,8 +6,7 @@ from colorama import Fore, Style from pyrit.common.deprecation import print_deprecation_message -from pyrit.models import AttackOutcome -from pyrit.models.scenario_result import ScenarioResult +from pyrit.models import AttackOutcome, ScenarioResult from pyrit.output.scenario_result.base import ScenarioResultPrinterBase from pyrit.output.scorer.base import ScorerPrinterBase from pyrit.output.sink import Sink diff --git a/pyrit/scenario/__init__.py b/pyrit/scenario/__init__.py index a93f3098c1..2ef485abdf 100644 --- a/pyrit/scenario/__init__.py +++ b/pyrit/scenario/__init__.py @@ -19,7 +19,7 @@ from types import ModuleType from pyrit.common.parameter import Parameter -from pyrit.models.scenario_result import ScenarioIdentifier, ScenarioResult +from pyrit.models import ScenarioIdentifier, ScenarioResult from pyrit.scenario.core import ( AtomicAttack, AttackTechnique, diff --git a/pyrit/scenario/core/scenario.py b/pyrit/scenario/core/scenario.py index 4611128eda..dce77c3bbd 100644 --- a/pyrit/scenario/core/scenario.py +++ b/pyrit/scenario/core/scenario.py @@ -37,8 +37,14 @@ from pyrit.executor.attack.single_turn.prompt_sending import PromptSendingAttack from pyrit.memory import CentralMemory from pyrit.memory.memory_models import ScenarioResultEntry -from pyrit.models import AttackOutcome, AttackResult, SeedAttackGroup -from pyrit.models.scenario_result import ScenarioIdentifier, ScenarioResult, ScenarioRunState +from pyrit.models import ( + AttackOutcome, + AttackResult, + ScenarioIdentifier, + ScenarioResult, + ScenarioRunState, + SeedAttackGroup, +) from pyrit.prompt_target import PromptTarget from pyrit.prompt_target.common.target_requirements import TargetRequirements from pyrit.registry import ScorerRegistry diff --git a/tests/partner_integration/azure_ai_evaluation/test_foundry_scenario_contract.py b/tests/partner_integration/azure_ai_evaluation/test_foundry_scenario_contract.py index 3e806f03ef..4ff4d77b57 100644 --- a/tests/partner_integration/azure_ai_evaluation/test_foundry_scenario_contract.py +++ b/tests/partner_integration/azure_ai_evaluation/test_foundry_scenario_contract.py @@ -57,7 +57,7 @@ class TestScenarioResultContract: def test_scenario_result_importable(self): """ScenarioOrchestrator reads ScenarioResult.""" - from pyrit.models.scenario_result import ScenarioResult + from pyrit.models import ScenarioResult assert ScenarioResult is not None diff --git a/tests/unit/backend/test_attack_service.py b/tests/unit/backend/test_attack_service.py index 2749c6fd67..5f7bee1809 100644 --- a/tests/unit/backend/test_attack_service.py +++ b/tests/unit/backend/test_attack_service.py @@ -1757,7 +1757,7 @@ async def test_conversation_summary_formats_media_preview(self, attack_service, async def test_returns_main_and_related_conversations(self, attack_service, mock_memory): """Should return main and PRUNED conversations sorted by timestamp.""" - from pyrit.models.conversation_reference import ConversationReference, ConversationType + from pyrit.models import ConversationReference, ConversationType ar = make_attack_result(conversation_id="attack-1") ar.related_conversations.add( @@ -1910,7 +1910,7 @@ async def test_raises_when_conversation_not_part_of_attack(self, attack_service, async def test_swaps_main_conversation(self, attack_service, mock_memory): """Changing the main to a related conversation should swap it with the main.""" - from pyrit.models.conversation_reference import ConversationReference, ConversationType + from pyrit.models import ConversationReference, ConversationType ar = make_attack_result(conversation_id="attack-1") ar.related_conversations = { @@ -1950,7 +1950,7 @@ class TestAddMessageTargetConversation: async def test_stores_message_in_target_conversation(self, attack_service, mock_memory): """When target_conversation_id is set, messages should go to that conversation.""" from pyrit.backend.models.attacks import AttackSummary, ConversationMessagesResponse - from pyrit.models.conversation_reference import ConversationReference, ConversationType + from pyrit.models import ConversationReference, ConversationType ar = make_attack_result(conversation_id="attack-1") ar.related_conversations = { @@ -2017,7 +2017,7 @@ class TestConversationCount: async def test_list_attacks_includes_related_conversation_ids(self, attack_service, mock_memory): """Attacks with related conversations should expose them in the summary.""" - from pyrit.models.conversation_reference import ConversationReference, ConversationType + from pyrit.models import ConversationReference, ConversationType ar = make_attack_result(conversation_id="attack-1") ar.related_conversations = { @@ -2069,7 +2069,7 @@ async def test_create_conversation_increments_count(self, attack_service, mock_m async def test_create_second_conversation_preserves_first(self, attack_service, mock_memory): """Creating a second related conversation should keep the first one.""" from pyrit.backend.models.attacks import CreateConversationRequest - from pyrit.models.conversation_reference import ConversationReference, ConversationType + from pyrit.models import ConversationReference, ConversationType ar = make_attack_result(conversation_id="attack-1") ar.related_conversations = { @@ -2099,7 +2099,7 @@ class TestConversationSorting: async def test_conversations_sorted_by_created_at_earliest_first(self, attack_service, mock_memory): """Conversations should be sorted by created_at with earliest first.""" - from pyrit.models.conversation_reference import ConversationReference, ConversationType + from pyrit.models import ConversationReference, ConversationType ar = make_attack_result(conversation_id="attack-1") ar.related_conversations = { @@ -2127,7 +2127,7 @@ async def test_conversations_sorted_by_created_at_earliest_first(self, attack_se async def test_empty_conversations_sorted_last(self, attack_service, mock_memory): """Conversations with no timestamp should appear at the bottom.""" - from pyrit.models.conversation_reference import ConversationReference, ConversationType + from pyrit.models import ConversationReference, ConversationType ar = make_attack_result(conversation_id="attack-1") ar.related_conversations = { @@ -2153,7 +2153,7 @@ async def test_empty_conversations_sorted_last(self, attack_service, mock_memory async def test_empty_conversations_all_sort_last(self, attack_service, mock_memory): """Multiple empty conversations should all have created_at=None.""" - from pyrit.models.conversation_reference import ConversationReference, ConversationType + from pyrit.models import ConversationReference, ConversationType ar = make_attack_result(conversation_id="attack-1") ar.related_conversations = { diff --git a/tests/unit/backend/test_mappers.py b/tests/unit/backend/test_mappers.py index 9170f031df..1aba27d4b5 100644 --- a/tests/unit/backend/test_mappers.py +++ b/tests/unit/backend/test_mappers.py @@ -322,7 +322,7 @@ async def test_no_converters_returns_empty_list(self) -> None: async def test_related_conversation_ids_from_related_conversations(self) -> None: """Test that related_conversation_ids includes all related conversation IDs.""" - from pyrit.models.conversation_reference import ConversationReference, ConversationType + from pyrit.models import ConversationReference, ConversationType ar = _make_attack_result() ar.related_conversations = { diff --git a/tests/unit/backend/test_response_contracts.py b/tests/unit/backend/test_response_contracts.py index bb7f403df3..15ff045b6a 100644 --- a/tests/unit/backend/test_response_contracts.py +++ b/tests/unit/backend/test_response_contracts.py @@ -24,12 +24,13 @@ from pyrit.models import ( AttackResult, ComponentIdentifier, + ConversationReference, + ConversationType, MessagePiece, RetryEvent, Score, build_atomic_attack_identifier, ) -from pyrit.models.conversation_reference import ConversationReference, ConversationType def _make_score() -> Score: diff --git a/tests/unit/cli/test_output.py b/tests/unit/cli/test_output.py index 26f41cccc2..b1c67cbbc6 100644 --- a/tests/unit/cli/test_output.py +++ b/tests/unit/cli/test_output.py @@ -319,7 +319,7 @@ async def test_print_scenario_result_async_uses_pretty_printer(): fake_printer.write_async = AsyncMock() with ( - patch("pyrit.models.scenario_result.ScenarioResult.from_dict", return_value=fake_scenario) as from_dict_mock, + patch("pyrit.models.ScenarioResult.from_dict", return_value=fake_scenario) as from_dict_mock, patch( "pyrit.output.scenario_result.pretty.PrettyScenarioResultMemoryPrinter", return_value=fake_printer ) as printer_cls, @@ -339,8 +339,7 @@ async def test_print_scenario_result_async_roundtrip_with_real_payload(): """ from datetime import datetime, timezone - from pyrit.models import AttackOutcome, AttackResult, ComponentIdentifier - from pyrit.models.scenario_result import ScenarioIdentifier, ScenarioResult + from pyrit.models import AttackOutcome, AttackResult, ComponentIdentifier, ScenarioIdentifier, ScenarioResult identifier = ScenarioIdentifier(name="test.scenario", description="A test") target_identifier = ComponentIdentifier.from_dict( diff --git a/tests/unit/models/test_attack_result.py b/tests/unit/models/test_attack_result.py index 43e7522052..2b76191152 100644 --- a/tests/unit/models/test_attack_result.py +++ b/tests/unit/models/test_attack_result.py @@ -8,8 +8,7 @@ import pytest from pyrit.memory.memory_models import AttackResultEntry -from pyrit.models import ComponentIdentifier -from pyrit.models.conversation_reference import ConversationReference, ConversationType +from pyrit.models import ComponentIdentifier, ConversationReference, ConversationType from pyrit.models.messages.message_piece import MessagePiece from pyrit.models.results.attack_result import AttackOutcome, AttackResult from pyrit.models.retry_event import RetryEvent diff --git a/tests/unit/models/test_chat_message.py b/tests/unit/models/test_chat_message.py index c3475285b7..889ca9d6f6 100644 --- a/tests/unit/models/test_chat_message.py +++ b/tests/unit/models/test_chat_message.py @@ -6,11 +6,8 @@ import pytest from pydantic import ValidationError -from pyrit.models.chat_message import ( - ChatMessage, - ChatMessagesDataset, - ToolCall, -) +from pyrit.models import ChatMessage, ChatMessagesDataset +from pyrit.models.messages.chat_message import ToolCall def test_tool_call_init(): diff --git a/tests/unit/models/test_conversation.py b/tests/unit/models/test_conversation.py new file mode 100644 index 0000000000..b2576cd2e9 --- /dev/null +++ b/tests/unit/models/test_conversation.py @@ -0,0 +1,62 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import pytest +from pydantic import ValidationError + +from pyrit.models import ComponentIdentifier, Conversation + + +def test_init_requires_conversation_id(): + with pytest.raises(ValidationError): + Conversation() # type: ignore[call-arg] + + +def test_init_defaults_target_identifier_to_none(): + conversation = Conversation(conversation_id="conv-1") + assert conversation.conversation_id == "conv-1" + assert conversation.target_identifier is None + + +def test_init_forbids_extra_fields(): + with pytest.raises(ValidationError): + Conversation(conversation_id="conv-1", unexpected="value") # type: ignore[call-arg] + + +def test_init_accepts_component_identifier(): + identifier = ComponentIdentifier(class_name="OpenAIChatTarget", class_module="pyrit.prompt_target") + conversation = Conversation(conversation_id="conv-1", target_identifier=identifier) + assert conversation.target_identifier == identifier + + +def test_target_identifier_accepts_flat_dict(): + identifier = ComponentIdentifier(class_name="OpenAIChatTarget", class_module="pyrit.prompt_target") + conversation = Conversation(conversation_id="conv-1", target_identifier=identifier.model_dump()) + assert isinstance(conversation.target_identifier, ComponentIdentifier) + assert conversation.target_identifier.class_name == "OpenAIChatTarget" + + +def test_model_dump_serializes_target_identifier_to_flat_dict(): + identifier = ComponentIdentifier(class_name="OpenAIChatTarget", class_module="pyrit.prompt_target") + conversation = Conversation(conversation_id="conv-1", target_identifier=identifier) + + dumped = conversation.model_dump() + + assert dumped["conversation_id"] == "conv-1" + assert dumped["target_identifier"]["class_name"] == "OpenAIChatTarget" + assert dumped["target_identifier"]["class_module"] == "pyrit.prompt_target" + + +def test_model_dump_with_no_target_identifier(): + conversation = Conversation(conversation_id="conv-1") + assert conversation.model_dump()["target_identifier"] is None + + +def test_round_trips_through_model_validate(): + identifier = ComponentIdentifier(class_name="OpenAIChatTarget", class_module="pyrit.prompt_target") + conversation = Conversation(conversation_id="conv-1", target_identifier=identifier) + + restored = Conversation.model_validate(conversation.model_dump()) + + assert restored.conversation_id == "conv-1" + assert restored.target_identifier == identifier diff --git a/tests/unit/models/test_conversation_reference.py b/tests/unit/models/test_conversation_reference.py index de6263cd95..b229bc27eb 100644 --- a/tests/unit/models/test_conversation_reference.py +++ b/tests/unit/models/test_conversation_reference.py @@ -4,7 +4,7 @@ import pytest from pydantic import ValidationError -from pyrit.models.conversation_reference import ConversationReference, ConversationType +from pyrit.models import ConversationReference, ConversationType def test_conversation_type_values(): diff --git a/tests/unit/models/test_scenario_result.py b/tests/unit/models/test_scenario_result.py index e15abc53e7..ab2da79e8c 100644 --- a/tests/unit/models/test_scenario_result.py +++ b/tests/unit/models/test_scenario_result.py @@ -7,11 +7,15 @@ import pytest import pyrit -from pyrit.models import ComponentIdentifier -from pyrit.models.conversation_reference import ConversationReference, ConversationType +from pyrit.models import ( + ComponentIdentifier, + ConversationReference, + ConversationType, + ScenarioIdentifier, + ScenarioResult, +) from pyrit.models.results.attack_result import AttackOutcome, AttackResult from pyrit.models.retry_event import RetryEvent -from pyrit.models.scenario_result import ScenarioIdentifier, ScenarioResult def _make_scenario_identifier(**kwargs): diff --git a/tests/unit/output/attack_result/test_markdown.py b/tests/unit/output/attack_result/test_markdown.py index b61081867a..7ce0c8b265 100644 --- a/tests/unit/output/attack_result/test_markdown.py +++ b/tests/unit/output/attack_result/test_markdown.py @@ -11,13 +11,13 @@ AttackOutcome, AttackResult, ComponentIdentifier, + ConversationReference, ConversationType, Message, MessagePiece, Score, build_atomic_attack_identifier, ) -from pyrit.models.conversation_reference import ConversationReference from pyrit.output.attack_result.markdown import MarkdownAttackResultMemoryPrinter diff --git a/tests/unit/output/attack_result/test_pretty.py b/tests/unit/output/attack_result/test_pretty.py index dd7d02c6d7..7b9e79737d 100644 --- a/tests/unit/output/attack_result/test_pretty.py +++ b/tests/unit/output/attack_result/test_pretty.py @@ -10,13 +10,13 @@ AttackOutcome, AttackResult, ComponentIdentifier, + ConversationReference, ConversationType, Message, MessagePiece, Score, build_atomic_attack_identifier, ) -from pyrit.models.conversation_reference import ConversationReference from pyrit.output.attack_result.pretty import PrettyAttackResultMemoryPrinter diff --git a/tests/unit/output/scenario_result/test_base.py b/tests/unit/output/scenario_result/test_base.py index fe64b39b2d..0d3bb06413 100644 --- a/tests/unit/output/scenario_result/test_base.py +++ b/tests/unit/output/scenario_result/test_base.py @@ -5,7 +5,7 @@ import pytest -from pyrit.models.scenario_result import ScenarioResult +from pyrit.models import ScenarioResult from pyrit.output.scenario_result.base import ScenarioResultPrinterBase diff --git a/tests/unit/output/scenario_result/test_pretty.py b/tests/unit/output/scenario_result/test_pretty.py index b2f8cced9c..f1ba89c431 100644 --- a/tests/unit/output/scenario_result/test_pretty.py +++ b/tests/unit/output/scenario_result/test_pretty.py @@ -5,8 +5,7 @@ import pytest -from pyrit.models import AttackOutcome, AttackResult, ComponentIdentifier -from pyrit.models.scenario_result import ScenarioIdentifier, ScenarioResult +from pyrit.models import AttackOutcome, AttackResult, ComponentIdentifier, ScenarioIdentifier, ScenarioResult from pyrit.output.scenario_result.pretty import PrettyScenarioResultMemoryPrinter diff --git a/tests/unit/scenario/core/test_scenario_parameters.py b/tests/unit/scenario/core/test_scenario_parameters.py index 4a013f4365..17fee40f67 100644 --- a/tests/unit/scenario/core/test_scenario_parameters.py +++ b/tests/unit/scenario/core/test_scenario_parameters.py @@ -419,7 +419,7 @@ class TestResumeParameterValidation: @staticmethod def _make_stored_result(*, scenario_name: str, version: int, init_data): """Build a minimal ScenarioResult with a controlled identifier for resume tests.""" - from pyrit.models.scenario_result import ScenarioIdentifier, ScenarioResult + from pyrit.models import ScenarioIdentifier, ScenarioResult identifier = ScenarioIdentifier( name=scenario_name,