From c82f8ed8e25b85fa19afe4d99e41371debf8be34 Mon Sep 17 00:00:00 2001 From: jsong468 Date: Wed, 28 Jan 2026 14:08:54 -0800 Subject: [PATCH 01/10] first_commit --- doc/api.rst | 1 + pyrit/executor/attack/multi_turn/crescendo.py | 2 +- pyrit/identifiers/__init__.py | 2 + pyrit/identifiers/identifier.py | 2 +- pyrit/identifiers/target_identifier.py | 60 ++++++++++++ pyrit/memory/memory_models.py | 21 +++- pyrit/models/message_piece.py | 36 ++++--- pyrit/models/scenario_result.py | 10 +- .../azure_blob_storage_target.py | 8 ++ pyrit/prompt_target/azure_ml_chat_target.py | 11 +++ .../common/prompt_chat_target.py | 2 +- pyrit/prompt_target/common/prompt_target.py | 97 +++++++++++++------ pyrit/prompt_target/gandalf_target.py | 8 ++ .../prompt_target/http_target/http_target.py | 9 ++ .../hugging_face/hugging_face_chat_target.py | 12 +++ .../hugging_face_endpoint_target.py | 10 ++ .../openai/openai_chat_target.py | 15 +++ .../openai/openai_completion_target.py | 13 +++ .../openai/openai_image_target.py | 10 ++ .../openai/openai_realtime_target.py | 8 ++ .../openai/openai_response_target.py | 10 ++ .../prompt_target/openai/openai_tts_target.py | 11 +++ .../openai/openai_video_target.py | 9 ++ .../playwright_copilot_target.py | 8 ++ pyrit/prompt_target/prompt_shield_target.py | 9 ++ pyrit/scenario/core/scenario.py | 3 +- pyrit/scenario/printer/console_printer.py | 7 +- pyrit/score/scorer.py | 20 ++-- .../unit/converter/test_denylist_converter.py | 3 +- .../converter/test_generic_llm_converter.py | 2 + .../converter/test_math_prompt_converter.py | 5 + .../test_random_translation_converter.py | 2 + ...test_toxic_sentence_generator_converter.py | 3 +- .../component/test_conversation_manager.py | 17 +++- .../component/test_simulated_conversation.py | 25 ++++- .../attack/core/test_attack_strategy.py | 13 ++- .../attack/multi_turn/test_crescendo.py | 14 ++- .../multi_turn/test_multi_prompt_sending.py | 14 ++- .../attack/multi_turn/test_red_teaming.py | 21 ++-- .../attack/multi_turn/test_tree_of_attacks.py | 16 ++- .../single_turn/test_context_compliance.py | 15 ++- .../attack/single_turn/test_flip_attack.py | 13 ++- .../single_turn/test_many_shot_jailbreak.py | 13 ++- .../attack/single_turn/test_prompt_sending.py | 4 +- .../attack/single_turn/test_role_play.py | 6 +- .../attack/single_turn/test_skeleton_key.py | 4 +- .../test_attack_parameter_consistency.py | 18 +++- .../attack/test_error_skip_scoring.py | 31 +++++- .../executor/promptgen/fuzzer/test_fuzzer.py | 2 +- .../test_interface_scenario_results.py | 6 +- tests/unit/memory/test_azure_sql_memory.py | 2 +- tests/unit/memory/test_sqlite_memory.py | 2 +- tests/unit/mocks.py | 22 ++++- tests/unit/models/test_message_piece.py | 6 +- .../test_prompt_normalizer.py | 7 +- tests/unit/scenarios/test_content_harms.py | 16 ++- tests/unit/scenarios/test_cyber.py | 18 +++- tests/unit/scenarios/test_encoding.py | 14 ++- tests/unit/scenarios/test_foundry.py | 16 ++- tests/unit/scenarios/test_leakage_scenario.py | 18 +++- tests/unit/scenarios/test_scam.py | 18 +++- tests/unit/scenarios/test_scenario.py | 13 ++- .../test_scenario_partial_results.py | 9 +- tests/unit/scenarios/test_scenario_retry.py | 9 +- tests/unit/score/test_gandalf_scorer.py | 5 + .../score/test_general_float_scale_scorer.py | 6 ++ .../score/test_general_true_false_scorer.py | 4 + tests/unit/score/test_scorer.py | 10 ++ tests/unit/score/test_self_ask_category.py | 9 ++ tests/unit/score/test_self_ask_likert.py | 6 ++ tests/unit/score/test_self_ask_refusal.py | 10 ++ tests/unit/score/test_self_ask_scale.py | 10 +- tests/unit/score/test_self_ask_true_false.py | 9 ++ tests/unit/score/test_video_scorer.py | 2 +- tests/unit/target/test_crucible_target.py | 2 +- tests/unit/target/test_gandalf_target.py | 2 +- tests/unit/target/test_http_target.py | 14 +-- .../test_hugging_face_endpoint_target.py | 2 +- .../target/test_huggingface_chat_target.py | 6 +- tests/unit/target/test_openai_chat_target.py | 32 +++--- 80 files changed, 777 insertions(+), 183 deletions(-) create mode 100644 pyrit/identifiers/target_identifier.py diff --git a/doc/api.rst b/doc/api.rst index 09bae3a1a..bf604f8a3 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -276,6 +276,7 @@ API Reference IdentifierType ScorerIdentifier snake_case_to_class_name + TargetIdentifier :py:mod:`pyrit.memory` ====================== diff --git a/pyrit/executor/attack/multi_turn/crescendo.py b/pyrit/executor/attack/multi_turn/crescendo.py index ac937fb54..dd521c194 100644 --- a/pyrit/executor/attack/multi_turn/crescendo.py +++ b/pyrit/executor/attack/multi_turn/crescendo.py @@ -576,7 +576,7 @@ async def _send_prompt_to_objective_target_async( Raises: ValueError: If no response is received from the objective target. """ - objective_target_type = self._objective_target.get_identifier()["__type__"] + objective_target_type = self._objective_target.get_identifier().class_name # Send the generated prompt to the objective target prompt_preview = attack_message.get_value()[:100] if attack_message.get_value() else "" diff --git a/pyrit/identifiers/__init__.py b/pyrit/identifiers/__init__.py index 8ca875ca3..644e03829 100644 --- a/pyrit/identifiers/__init__.py +++ b/pyrit/identifiers/__init__.py @@ -13,6 +13,7 @@ IdentifierType, ) from pyrit.identifiers.scorer_identifier import ScorerIdentifier +from pyrit.identifiers.target_identifier import TargetIdentifier __all__ = [ "class_name_to_snake_case", @@ -23,4 +24,5 @@ "LegacyIdentifiable", "ScorerIdentifier", "snake_case_to_class_name", + "TargetIdentifier", ] diff --git a/pyrit/identifiers/identifier.py b/pyrit/identifiers/identifier.py index 8e5265d4b..17dedc171 100644 --- a/pyrit/identifiers/identifier.py +++ b/pyrit/identifiers/identifier.py @@ -208,7 +208,7 @@ def normalize(cls: Type[T], value: T | dict[str, Any]) -> T: print_deprecation_message( old_item=f"dict for {cls.__name__}", new_item=cls.__name__, - removed_in="0.14.0", + removed_in="0.13.0", ) return cls.from_dict(value) diff --git a/pyrit/identifiers/target_identifier.py b/pyrit/identifiers/target_identifier.py new file mode 100644 index 000000000..9d4bc9330 --- /dev/null +++ b/pyrit/identifiers/target_identifier.py @@ -0,0 +1,60 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Dict, Optional, Type, cast + +from pyrit.identifiers.identifier import Identifier + + +@dataclass(frozen=True) +class TargetIdentifier(Identifier): + """ + Identifier for PromptTarget instances. + + This frozen dataclass extends Identifier with target-specific fields. + It provides a stable, hashable identifier for prompt targets that can be + used for scorer evaluation, registry tracking, and memory storage. + + Attributes: + endpoint: The target endpoint URL, if applicable. + model_name: The model or deployment name. Uses underlying_model if specified, + otherwise falls back to the deployment name. + temperature: The temperature parameter for generation, if applicable. + top_p: The top_p parameter for generation, if applicable. + target_specific_params: Additional target-specific parameters. + """ + + endpoint: str = "" + """The target endpoint URL.""" + + model_name: str = "" + """The model or deployment name.""" + + temperature: Optional[float] = None + """The temperature parameter for generation.""" + + top_p: Optional[float] = None + """The top_p parameter for generation.""" + + target_specific_params: Optional[Dict[str, Any]] = None + """Additional target-specific parameters.""" + + @classmethod + def from_dict(cls: Type["TargetIdentifier"], data: dict[str, Any]) -> "TargetIdentifier": + """ + Create a TargetIdentifier from a dictionary (e.g., retrieved from database). + + Extends the base Identifier.from_dict() to handle legacy key mappings. + + Args: + data: The dictionary representation. + + Returns: + TargetIdentifier: A new TargetIdentifier instance. + """ + # Delegate to parent class for standard processing + result = Identifier.from_dict.__func__(cls, data) # type: ignore[attr-defined] + return cast(TargetIdentifier, result) diff --git a/pyrit/memory/memory_models.py b/pyrit/memory/memory_models.py index 1401c2e24..b653ee20e 100644 --- a/pyrit/memory/memory_models.py +++ b/pyrit/memory/memory_models.py @@ -31,7 +31,7 @@ from sqlalchemy.types import Uuid from pyrit.common.utils import to_sha256 -from pyrit.identifiers import ScorerIdentifier +from pyrit.identifiers import ScorerIdentifier, TargetIdentifier from pyrit.models import ( AttackOutcome, AttackResult, @@ -208,7 +208,12 @@ def __init__(self, *, entry: MessagePiece): self.prompt_metadata = entry.prompt_metadata self.targeted_harm_categories = entry.targeted_harm_categories self.converter_identifiers = entry.converter_identifiers - self.prompt_target_identifier = entry.prompt_target_identifier + # Normalize prompt_target_identifier and convert to dict for JSON serialization + self.prompt_target_identifier = ( + TargetIdentifier.normalize(entry.prompt_target_identifier).to_dict() + if entry.prompt_target_identifier + else {} + ) self.attack_identifier = entry.attack_identifier self.original_value = entry.original_value @@ -262,7 +267,9 @@ def __str__(self) -> str: str: Formatted string representation of the memory entry. """ if self.prompt_target_identifier: - return f"{self.prompt_target_identifier['__type__']}: {self.role}: {self.converted_value}" + # prompt_target_identifier is stored as dict in the database + class_name = self.prompt_target_identifier.get("class_name") or self.prompt_target_identifier.get("__type__", "Unknown") + return f"{class_name}: {self.role}: {self.converted_value}" return f": {self.role}: {self.converted_value}" @@ -874,7 +881,8 @@ def __init__(self, *, entry: ScenarioResult): self.scenario_version = entry.scenario_identifier.version self.pyrit_version = entry.scenario_identifier.pyrit_version self.scenario_init_data = entry.scenario_identifier.init_data - self.objective_target_identifier = entry.objective_target_identifier + # Convert TargetIdentifier to dict for JSON storage + self.objective_target_identifier = entry.objective_target_identifier.to_dict() # Convert ScorerIdentifier to dict for JSON storage self.objective_scorer_identifier = ( entry.objective_scorer_identifier.to_dict() if entry.objective_scorer_identifier else None @@ -921,10 +929,13 @@ def get_scenario_result(self) -> ScenarioResult: ScorerIdentifier.from_dict(self.objective_scorer_identifier) if self.objective_scorer_identifier else None ) + # Convert dict back to TargetIdentifier for reconstruction + target_identifier = TargetIdentifier.from_dict(self.objective_target_identifier) + return ScenarioResult( id=self.id, scenario_identifier=scenario_identifier, - objective_target_identifier=self.objective_target_identifier, + objective_target_identifier=target_identifier, attack_results=attack_results, objective_scorer_identifier=scorer_identifier, scenario_run_state=self.scenario_run_state, diff --git a/pyrit/models/message_piece.py b/pyrit/models/message_piece.py index d0dc3e6df..9ce7ab9be 100644 --- a/pyrit/models/message_piece.py +++ b/pyrit/models/message_piece.py @@ -5,11 +5,10 @@ import uuid from datetime import datetime -from typing import Dict, List, Literal, Optional, Union, get_args +from typing import Any, Dict, List, Literal, Optional, Union, get_args from uuid import uuid4 -from pyrit.common.deprecation import print_deprecation_message -from pyrit.identifiers import ScorerIdentifier +from pyrit.identifiers import ScorerIdentifier, TargetIdentifier from pyrit.models.literals import ChatMessageRole, PromptDataType, PromptResponseError from pyrit.models.score import Score @@ -39,7 +38,7 @@ def __init__( labels: Optional[Dict[str, str]] = None, prompt_metadata: Optional[Dict[str, Union[str, int]]] = None, converter_identifiers: Optional[List[Dict[str, str]]] = None, - prompt_target_identifier: Optional[Dict[str, str]] = None, + prompt_target_identifier: Optional[Union[TargetIdentifier, Dict[str, Any]]] = None, attack_identifier: Optional[Dict[str, str]] = None, scorer_identifier: Optional[Union[ScorerIdentifier, Dict[str, str]]] = None, original_value_data_type: PromptDataType = "text", @@ -108,21 +107,17 @@ def __init__( self.converter_identifiers = converter_identifiers if converter_identifiers else [] - self.prompt_target_identifier = prompt_target_identifier or {} + # Handle prompt_target_identifier: normalize to TargetIdentifier (handles dict with deprecation warning) + self.prompt_target_identifier: Optional[TargetIdentifier] = ( + TargetIdentifier.normalize(prompt_target_identifier) if prompt_target_identifier else None + ) + self.attack_identifier = attack_identifier or {} - # Handle scorer_identifier: convert dict to ScorerIdentifier with deprecation warning - if scorer_identifier is None: - self.scorer_identifier: Optional[ScorerIdentifier] = None - elif isinstance(scorer_identifier, dict): - print_deprecation_message( - old_item="dict for scorer_identifier", - new_item="ScorerIdentifier", - removed_in="0.13.0", - ) - self.scorer_identifier = ScorerIdentifier.from_dict(scorer_identifier) - else: - self.scorer_identifier = scorer_identifier + # Handle scorer_identifier: normalize to ScorerIdentifier (handles dict with deprecation warning) + self.scorer_identifier: Optional[ScorerIdentifier] = ( + ScorerIdentifier.normalize(scorer_identifier) if scorer_identifier else None + ) self.original_value = original_value @@ -279,7 +274,9 @@ def to_dict(self) -> dict[str, object]: "targeted_harm_categories": self.targeted_harm_categories if self.targeted_harm_categories else None, "prompt_metadata": self.prompt_metadata, "converter_identifiers": self.converter_identifiers, - "prompt_target_identifier": self.prompt_target_identifier, + "prompt_target_identifier": ( + self.prompt_target_identifier.to_dict() if self.prompt_target_identifier else None + ), "attack_identifier": self.attack_identifier, "scorer_identifier": self.scorer_identifier.to_dict() if self.scorer_identifier else None, "original_value_data_type": self.original_value_data_type, @@ -295,7 +292,8 @@ def to_dict(self) -> dict[str, object]: } def __str__(self) -> str: - return f"{self.prompt_target_identifier}: {self._role}: {self.converted_value}" + target_str = self.prompt_target_identifier.class_name if self.prompt_target_identifier else "Unknown" + return f"{target_str}: {self._role}: {self.converted_value}" __repr__ = __str__ diff --git a/pyrit/models/scenario_result.py b/pyrit/models/scenario_result.py index a052ecaec..ab0ba0e5e 100644 --- a/pyrit/models/scenario_result.py +++ b/pyrit/models/scenario_result.py @@ -10,7 +10,7 @@ from pyrit.models import AttackOutcome, AttackResult if TYPE_CHECKING: - from pyrit.identifiers import ScorerIdentifier + from pyrit.identifiers import ScorerIdentifier, TargetIdentifier from pyrit.score import Scorer from pyrit.score.scorer_evaluation.scorer_metrics import ScorerMetrics @@ -59,7 +59,7 @@ def __init__( self, *, scenario_identifier: ScenarioIdentifier, - objective_target_identifier: dict[str, str], + objective_target_identifier: Union[Dict[str, Any], "TargetIdentifier"], attack_results: dict[str, List[AttackResult]], objective_scorer_identifier: Union[Dict[str, Any], "ScorerIdentifier"], scenario_run_state: ScenarioRunState = "CREATED", @@ -71,11 +71,13 @@ def __init__( objective_scorer: Optional["Scorer"] = None, ) -> None: from pyrit.common import print_deprecation_message - from pyrit.identifiers import ScorerIdentifier + from pyrit.identifiers import ScorerIdentifier, TargetIdentifier self.id = id if id is not None else uuid.uuid4() self.scenario_identifier = scenario_identifier - self.objective_target_identifier = objective_target_identifier + + # Normalize objective_target_identifier to TargetIdentifier + self.objective_target_identifier = TargetIdentifier.normalize(objective_target_identifier) # Handle deprecated objective_scorer parameter if objective_scorer is not None: diff --git a/pyrit/prompt_target/azure_blob_storage_target.py b/pyrit/prompt_target/azure_blob_storage_target.py index edbb0419b..505c463d5 100644 --- a/pyrit/prompt_target/azure_blob_storage_target.py +++ b/pyrit/prompt_target/azure_blob_storage_target.py @@ -79,6 +79,14 @@ def __init__( super().__init__(endpoint=self._container_url, max_requests_per_minute=max_requests_per_minute) + def _build_identifier(self) -> None: + """Build the identifier with Azure Blob Storage-specific parameters.""" + self._set_identifier( + target_specific_params={ + "blob_content_type": self._blob_content_type, + }, + ) + async def _create_container_client_async(self) -> None: """ Create an asynchronous ContainerClient for Azure Storage. If a SAS token is provided via the diff --git a/pyrit/prompt_target/azure_ml_chat_target.py b/pyrit/prompt_target/azure_ml_chat_target.py index 7734acdd0..2aed246a3 100644 --- a/pyrit/prompt_target/azure_ml_chat_target.py +++ b/pyrit/prompt_target/azure_ml_chat_target.py @@ -103,6 +103,17 @@ def __init__( self._repetition_penalty = repetition_penalty self._extra_parameters = param_kwargs + def _build_identifier(self) -> None: + """Build the identifier with Azure ML-specific parameters.""" + self._set_identifier( + temperature=self._temperature, + top_p=self._top_p, + target_specific_params={ + "max_new_tokens": self._max_new_tokens, + "repetition_penalty": self._repetition_penalty, + }, + ) + def _initialize_vars(self, endpoint: Optional[str] = None, api_key: Optional[str] = None) -> None: """ Set the endpoint and key for accessing the Azure ML model. Use this function to manually diff --git a/pyrit/prompt_target/common/prompt_chat_target.py b/pyrit/prompt_target/common/prompt_chat_target.py index fbfa3f93f..5eac0209f 100644 --- a/pyrit/prompt_target/common/prompt_chat_target.py +++ b/pyrit/prompt_target/common/prompt_chat_target.py @@ -121,7 +121,7 @@ def _get_json_response_config(self, *, message_piece: MessagePiece) -> _JsonResp config = _JsonResponseConfig.from_metadata(metadata=message_piece.prompt_metadata) if config.enabled and not self.is_json_response_supported(): - target_name = self.get_identifier()["__type__"] + target_name = self.get_identifier().class_name raise ValueError(f"This target {target_name} does not support JSON response format.") return config diff --git a/pyrit/prompt_target/common/prompt_target.py b/pyrit/prompt_target/common/prompt_target.py index cbc0cc079..7095e6c47 100644 --- a/pyrit/prompt_target/common/prompt_target.py +++ b/pyrit/prompt_target/common/prompt_target.py @@ -5,14 +5,14 @@ import logging from typing import Any, Dict, List, Optional -from pyrit.identifiers import LegacyIdentifiable +from pyrit.identifiers import Identifiable, TargetIdentifier from pyrit.memory import CentralMemory, MemoryInterface from pyrit.models import Message logger = logging.getLogger(__name__) -class PromptTarget(LegacyIdentifiable): +class PromptTarget(Identifiable[TargetIdentifier]): """ Abstract base class for prompt targets. @@ -26,6 +26,8 @@ class PromptTarget(LegacyIdentifiable): #: An empty list implies that the prompt target supports all converters. supported_converters: List[Any] + _identifier: Optional[TargetIdentifier] = None + def __init__( self, verbose: bool = False, @@ -91,36 +93,73 @@ def dispose_db_engine(self) -> None: """ self._memory.dispose_engine() - def get_identifier(self) -> Dict[str, Any]: + def _set_identifier( + self, + *, + temperature: Optional[float] = None, + top_p: Optional[float] = None, + target_specific_params: Optional[dict[str, Any]] = None, + ) -> None: + """ + Construct the target identifier. + + Subclasses should call this method in their _build_identifier() implementation + to set the identifier with their specific parameters. + + Args: + temperature (Optional[float]): The temperature parameter for generation. Defaults to None. + top_p (Optional[float]): The top_p parameter for generation. Defaults to None. + target_specific_params (Optional[dict[str, Any]]): Additional target-specific parameters + that should be included in the identifier. Defaults to None. + """ + # Determine the model name to use + model_name = "" + if self._underlying_model: + model_name = self._underlying_model + elif self._model_name: + model_name = self._model_name + + self._identifier = TargetIdentifier( + class_name=self.__class__.__name__, + class_module=self.__class__.__module__, + class_description=" ".join(self.__class__.__doc__.split()) if self.__class__.__doc__ else "", + identifier_type="instance", + endpoint=self._endpoint, + model_name=model_name, + temperature=temperature, + top_p=top_p, + target_specific_params=target_specific_params, + ) + + def _build_identifier(self) -> None: + """ + Build the identifier for this target. + + Subclasses should override this method to call _set_identifier() with + their specific parameters (temperature, top_p, target_specific_params). + + The base implementation calls _set_identifier() with no parameters, + which works for targets that don't have model-specific settings. """ - Get an identifier dictionary for this prompt target. + self._set_identifier() - This includes essential attributes needed for scorer evaluation and registry tracking. - Subclasses should override this method to include additional relevant attributes - (e.g., temperature, top_p) when available. + def get_identifier(self) -> TargetIdentifier: + """ + Get the target identifier. Built lazily on first access. Returns: - Dict[str, Any]: A dictionary containing identification attributes. + TargetIdentifier: The identifier containing all configuration parameters. Note: - If the `self._underlying_model` is specified, either passed in during instantiation - or via environment variable, it is used as the "model_name" for the identifier. - Otherwise, `self._model_name` (which is often the deployment name in Azure) is used. - """ - public_attributes: Dict[str, Any] = {} - public_attributes["__type__"] = self.__class__.__name__ - public_attributes["__module__"] = self.__class__.__module__ - if self._endpoint: - public_attributes["endpoint"] = self._endpoint - # if the underlying model is specified, use it as the model name for identification - # otherwise, use self._model_name (which is often the deployment name in Azure) - if self._underlying_model: - public_attributes["model_name"] = self._underlying_model - elif self._model_name: - public_attributes["model_name"] = self._model_name - # Include temperature and top_p if available (set by subclasses) - if hasattr(self, "_temperature") and self._temperature is not None: - public_attributes["temperature"] = self._temperature - if hasattr(self, "_top_p") and self._top_p is not None: - public_attributes["top_p"] = self._top_p - return public_attributes + If `self._underlying_model` is specified (via instantiation or environment + variable), it is used as the "model_name". Otherwise, `self._model_name` + (which is often the deployment name in Azure) is used. + + For storage in memory/database, call `.to_dict()` on the returned + identifier to get a dictionary suitable for JSON serialization. + """ + if self._identifier is None: + self._build_identifier() + if self._identifier is None: + raise RuntimeError("_build_identifier must set _identifier") + return self._identifier diff --git a/pyrit/prompt_target/gandalf_target.py b/pyrit/prompt_target/gandalf_target.py index 4819b6de9..2a18e3919 100644 --- a/pyrit/prompt_target/gandalf_target.py +++ b/pyrit/prompt_target/gandalf_target.py @@ -57,6 +57,14 @@ def __init__( self._defender = level.value + def _build_identifier(self) -> None: + """Build the identifier with Gandalf-specific parameters.""" + self._set_identifier( + target_specific_params={ + "level": self._defender, + }, + ) + @limit_requests_per_minute async def send_prompt_async(self, *, message: Message) -> list[Message]: """ diff --git a/pyrit/prompt_target/http_target/http_target.py b/pyrit/prompt_target/http_target/http_target.py index 419f22633..e7361ad61 100644 --- a/pyrit/prompt_target/http_target/http_target.py +++ b/pyrit/prompt_target/http_target/http_target.py @@ -82,6 +82,15 @@ def __init__( if client and httpx_client_kwargs: raise ValueError("Cannot provide both a pre-configured client and additional httpx client kwargs.") + def _build_identifier(self) -> None: + """Build the identifier with HTTP target-specific parameters.""" + self._set_identifier( + target_specific_params={ + "use_tls": self.use_tls, + "prompt_regex_string": self.prompt_regex_string, + }, + ) + @classmethod def with_client( cls, diff --git a/pyrit/prompt_target/hugging_face/hugging_face_chat_target.py b/pyrit/prompt_target/hugging_face/hugging_face_chat_target.py index c38f5858b..950a4def8 100644 --- a/pyrit/prompt_target/hugging_face/hugging_face_chat_target.py +++ b/pyrit/prompt_target/hugging_face/hugging_face_chat_target.py @@ -135,6 +135,18 @@ def __init__( self.load_model_and_tokenizer_task = asyncio.create_task(self.load_model_and_tokenizer()) + def _build_identifier(self) -> None: + """Build the identifier with HuggingFace chat-specific parameters.""" + self._set_identifier( + temperature=self._temperature, + top_p=self._top_p, + target_specific_params={ + "max_new_tokens": self.max_new_tokens, + "skip_special_tokens": self.skip_special_tokens, + "use_cuda": self.use_cuda, + }, + ) + def _load_from_path(self, path: str, **kwargs: Any) -> None: """ Load the model and tokenizer from a given path. diff --git a/pyrit/prompt_target/hugging_face/hugging_face_endpoint_target.py b/pyrit/prompt_target/hugging_face/hugging_face_endpoint_target.py index ebaf10c3e..f4aa7cd55 100644 --- a/pyrit/prompt_target/hugging_face/hugging_face_endpoint_target.py +++ b/pyrit/prompt_target/hugging_face/hugging_face_endpoint_target.py @@ -61,6 +61,16 @@ def __init__( self._temperature = temperature self._top_p = top_p + def _build_identifier(self) -> None: + """Build the identifier with HuggingFace endpoint-specific parameters.""" + self._set_identifier( + temperature=self._temperature, + top_p=self._top_p, + target_specific_params={ + "max_tokens": self.max_tokens, + }, + ) + @limit_requests_per_minute async def send_prompt_async(self, *, message: Message) -> list[Message]: """ diff --git a/pyrit/prompt_target/openai/openai_chat_target.py b/pyrit/prompt_target/openai/openai_chat_target.py index 544046b59..3e4a9c16b 100644 --- a/pyrit/prompt_target/openai/openai_chat_target.py +++ b/pyrit/prompt_target/openai/openai_chat_target.py @@ -163,6 +163,21 @@ def __init__( self._extra_body_parameters = extra_body_parameters + def _build_identifier(self) -> None: + """Build the identifier with OpenAI chat-specific parameters.""" + self._set_identifier( + temperature=self._temperature, + top_p=self._top_p, + target_specific_params={ + "max_completion_tokens": self._max_completion_tokens, + "max_tokens": self._max_tokens, + "frequency_penalty": self._frequency_penalty, + "presence_penalty": self._presence_penalty, + "seed": self._seed, + "n": self._n, + }, + ) + def _set_openai_env_configuration_vars(self) -> None: """ Set deployment_environment_variable, endpoint_environment_variable, diff --git a/pyrit/prompt_target/openai/openai_completion_target.py b/pyrit/prompt_target/openai/openai_completion_target.py index 9180466e1..376b82bb2 100644 --- a/pyrit/prompt_target/openai/openai_completion_target.py +++ b/pyrit/prompt_target/openai/openai_completion_target.py @@ -72,6 +72,19 @@ def __init__( self._presence_penalty = presence_penalty self._n = n + def _build_identifier(self) -> None: + """Build the identifier with OpenAI completion-specific parameters.""" + self._set_identifier( + temperature=self._temperature, + top_p=self._top_p, + target_specific_params={ + "max_tokens": self._max_tokens, + "frequency_penalty": self._frequency_penalty, + "presence_penalty": self._presence_penalty, + "n": self._n, + }, + ) + def _set_openai_env_configuration_vars(self) -> None: self.model_name_environment_variable = "OPENAI_COMPLETION_MODEL" self.endpoint_environment_variable = "OPENAI_COMPLETION_ENDPOINT" diff --git a/pyrit/prompt_target/openai/openai_image_target.py b/pyrit/prompt_target/openai/openai_image_target.py index a010d27db..774555790 100644 --- a/pyrit/prompt_target/openai/openai_image_target.py +++ b/pyrit/prompt_target/openai/openai_image_target.py @@ -85,6 +85,16 @@ def _get_provider_examples(self) -> dict[str, str]: "api.openai.com": "https://api.openai.com/v1", } + def _build_identifier(self) -> None: + """Build the identifier with image generation-specific parameters.""" + self._set_identifier( + target_specific_params={ + "image_size": self.image_size, + "quality": self.quality, + "style": self.style, + }, + ) + @limit_requests_per_minute @pyrit_target_retry async def send_prompt_async( diff --git a/pyrit/prompt_target/openai/openai_realtime_target.py b/pyrit/prompt_target/openai/openai_realtime_target.py index faea8c56b..17c89f8c8 100644 --- a/pyrit/prompt_target/openai/openai_realtime_target.py +++ b/pyrit/prompt_target/openai/openai_realtime_target.py @@ -118,6 +118,14 @@ def _get_provider_examples(self) -> dict[str, str]: "api.openai.com": "wss://api.openai.com/v1", } + def _build_identifier(self) -> None: + """Build the identifier with Realtime API-specific parameters.""" + self._set_identifier( + target_specific_params={ + "voice": self.voice, + }, + ) + def _validate_url_for_target(self, endpoint_url: str) -> None: """ Validate URL for Realtime API with websocket-specific checks. diff --git a/pyrit/prompt_target/openai/openai_response_target.py b/pyrit/prompt_target/openai/openai_response_target.py index 9f1868f5c..6cefa23e3 100644 --- a/pyrit/prompt_target/openai/openai_response_target.py +++ b/pyrit/prompt_target/openai/openai_response_target.py @@ -156,6 +156,16 @@ def __init__( logger.debug("Detected grammar tool: %s", tool_name) self._grammar_name = tool_name + def _build_identifier(self) -> None: + """Build the identifier with OpenAI response-specific parameters.""" + self._set_identifier( + temperature=self._temperature, + top_p=self._top_p, + target_specific_params={ + "max_output_tokens": self._max_output_tokens, + }, + ) + def _set_openai_env_configuration_vars(self) -> None: self.model_name_environment_variable = "OPENAI_RESPONSES_MODEL" self.endpoint_environment_variable = "OPENAI_RESPONSES_ENDPOINT" diff --git a/pyrit/prompt_target/openai/openai_tts_target.py b/pyrit/prompt_target/openai/openai_tts_target.py index d745f3e26..6d7f27d07 100644 --- a/pyrit/prompt_target/openai/openai_tts_target.py +++ b/pyrit/prompt_target/openai/openai_tts_target.py @@ -81,6 +81,17 @@ def _get_provider_examples(self) -> dict[str, str]: "api.openai.com": "https://api.openai.com/v1", } + def _build_identifier(self) -> None: + """Build the identifier with TTS-specific parameters.""" + self._set_identifier( + target_specific_params={ + "voice": self._voice, + "response_format": self._response_format, + "language": self._language, + "speed": self._speed, + }, + ) + @limit_requests_per_minute @pyrit_target_retry async def send_prompt_async(self, *, message: Message) -> list[Message]: diff --git a/pyrit/prompt_target/openai/openai_video_target.py b/pyrit/prompt_target/openai/openai_video_target.py index def844cb5..8b3ba4eb5 100644 --- a/pyrit/prompt_target/openai/openai_video_target.py +++ b/pyrit/prompt_target/openai/openai_video_target.py @@ -95,6 +95,15 @@ def _get_provider_examples(self) -> dict[str, str]: "api.openai.com": "https://api.openai.com/v1", } + def _build_identifier(self) -> None: + """Build the identifier with video generation-specific parameters.""" + self._set_identifier( + target_specific_params={ + "resolution": self._size, + "n_seconds": self._n_seconds, + }, + ) + def _validate_resolution(self, *, resolution_dimensions: str) -> str: """ Validate resolution dimensions. diff --git a/pyrit/prompt_target/playwright_copilot_target.py b/pyrit/prompt_target/playwright_copilot_target.py index 2610e78d8..aa6a7327c 100644 --- a/pyrit/prompt_target/playwright_copilot_target.py +++ b/pyrit/prompt_target/playwright_copilot_target.py @@ -128,6 +128,14 @@ def __init__(self, *, page: "Page", copilot_type: CopilotType = CopilotType.CONS if page and self.M365_URL_IDENTIFIER not in page.url and copilot_type == CopilotType.M365: raise ValueError("The provided page URL does not indicate M365 Copilot, but the type is set to m365.") + def _build_identifier(self) -> None: + """Build the identifier with Copilot-specific parameters.""" + self._set_identifier( + target_specific_params={ + "copilot_type": self._type.value, + }, + ) + def _get_selectors(self) -> CopilotSelectors: """ Get the appropriate selectors for the current Copilot type. diff --git a/pyrit/prompt_target/prompt_shield_target.py b/pyrit/prompt_target/prompt_shield_target.py index 35e30f358..e64ae7576 100644 --- a/pyrit/prompt_target/prompt_shield_target.py +++ b/pyrit/prompt_target/prompt_shield_target.py @@ -93,6 +93,15 @@ def __init__( self._force_entry_field: PromptShieldEntryField = field + def _build_identifier(self) -> None: + """Build the identifier with Prompt Shield-specific parameters.""" + self._set_identifier( + target_specific_params={ + "api_version": self._api_version, + "force_entry_field": self._force_entry_field if self._force_entry_field else None, + }, + ) + @limit_requests_per_minute async def send_prompt_async(self, *, message: Message) -> list[Message]: """ diff --git a/pyrit/scenario/core/scenario.py b/pyrit/scenario/core/scenario.py index 8446f5535..804647574 100644 --- a/pyrit/scenario/core/scenario.py +++ b/pyrit/scenario/core/scenario.py @@ -19,6 +19,7 @@ from pyrit.common import REQUIRED_VALUE, apply_defaults from pyrit.executor.attack.single_turn.prompt_sending import PromptSendingAttack +from pyrit.identifiers import TargetIdentifier from pyrit.memory import CentralMemory from pyrit.memory.memory_models import ScenarioResultEntry from pyrit.models import AttackResult @@ -90,7 +91,7 @@ def __init__( # These will be set in initialize_async self._objective_target: Optional[PromptTarget] = None - self._objective_target_identifier: Optional[Dict[str, str]] = None + self._objective_target_identifier: Optional[TargetIdentifier] = None self._memory_labels: Dict[str, str] = {} self._max_concurrency: int = 1 self._max_retries: int = 0 diff --git a/pyrit/scenario/printer/console_printer.py b/pyrit/scenario/printer/console_printer.py index 3a248e0e0..45787916e 100644 --- a/pyrit/scenario/printer/console_printer.py +++ b/pyrit/scenario/printer/console_printer.py @@ -115,9 +115,10 @@ async def print_summary_async(self, result: ScenarioResult) -> None: # Target information print() self._print_colored(f"{self._indent}🎯 Target Information", Style.BRIGHT) - target_type = result.objective_target_identifier.get("__type__", "Unknown") - target_model = result.objective_target_identifier.get("model_name", "Unknown") - target_endpoint = result.objective_target_identifier.get("endpoint", "Unknown") + target_id = result.objective_target_identifier + target_type = target_id.class_name if target_id else "Unknown" + target_model = target_id.model_name if target_id else "Unknown" + target_endpoint = target_id.endpoint if target_id else "Unknown" self._print_colored(f"{self._indent * 2}• Target Type: {target_type}", Fore.CYAN) self._print_colored(f"{self._indent * 2}• Target Model: {target_model}", Fore.CYAN) diff --git a/pyrit/score/scorer.py b/pyrit/score/scorer.py index c9fe2a9d5..9e22d1268 100644 --- a/pyrit/score/scorer.py +++ b/pyrit/score/scorer.py @@ -104,7 +104,8 @@ def get_identifier(self) -> ScorerIdentifier: """ if self._identifier is None: self._build_identifier() - assert self._identifier is not None, "_build_identifier must set _identifier" + if self._identifier is None: + raise RuntimeError("_build_identifier must set _identifier") return self._identifier @property @@ -122,7 +123,7 @@ def _set_identifier( prompt_target: Optional[PromptTarget] = None, ) -> None: """ - Construct the scorer evaluation identifier. + Construct the scorer identifier. Args: system_prompt_template (Optional[str]): The system prompt template used by this scorer. Defaults to None. @@ -141,16 +142,19 @@ def _set_identifier( target_info: Optional[Dict[str, Any]] = None if prompt_target: target_id = prompt_target.get_identifier() - # Extract standard fields for scorer evaluation - target_info = {} - for key in ["__type__", "model_name", "temperature", "top_p"]: - if key in target_id: - target_info[key] = target_id[key] + # Extract standard fields for scorer evaluation, excluding None values + target_info = {"class_name": target_id.class_name} + if target_id.model_name: + target_info["model_name"] = target_id.model_name + if target_id.temperature is not None: + target_info["temperature"] = target_id.temperature + if target_id.top_p is not None: + target_info["top_p"] = target_id.top_p self._identifier = ScorerIdentifier( class_name=self.__class__.__name__, class_module=self.__class__.__module__, - class_description=self.__class__.__doc__ or "", + class_description=" ".join(self.__class__.__doc__.split()) if self.__class__.__doc__ else "", identifier_type="instance", scorer_type=self.scorer_type, system_prompt_template=system_prompt_template, diff --git a/tests/unit/converter/test_denylist_converter.py b/tests/unit/converter/test_denylist_converter.py index bc7afb896..8bbf06c58 100644 --- a/tests/unit/converter/test_denylist_converter.py +++ b/tests/unit/converter/test_denylist_converter.py @@ -3,7 +3,7 @@ from unittest.mock import AsyncMock, MagicMock import pytest -from unit.mocks import MockPromptTarget +from unit.mocks import get_mock_target_identifier, MockPromptTarget from pyrit.models import Message, MessagePiece, SeedPrompt from pyrit.prompt_converter import DenylistConverter @@ -28,6 +28,7 @@ def mock_target() -> MockPromptTarget: ] ) target.send_prompt_async = AsyncMock(return_value=[response]) + target.get_identifier.return_value = get_mock_target_identifier("MockDenylistTarget") return target diff --git a/tests/unit/converter/test_generic_llm_converter.py b/tests/unit/converter/test_generic_llm_converter.py index 8c2a912ab..403294cab 100644 --- a/tests/unit/converter/test_generic_llm_converter.py +++ b/tests/unit/converter/test_generic_llm_converter.py @@ -4,6 +4,7 @@ from unittest.mock import AsyncMock, MagicMock import pytest +from unit.mocks import get_mock_target_identifier from pyrit.models import Message, MessagePiece from pyrit.prompt_converter import ( @@ -28,6 +29,7 @@ def mock_target() -> PromptTarget: ] ) target.send_prompt_async = AsyncMock(return_value=[response]) + target.get_identifier.return_value = get_mock_target_identifier("MockLLMTarget") return target diff --git a/tests/unit/converter/test_math_prompt_converter.py b/tests/unit/converter/test_math_prompt_converter.py index 9203f5950..703000a73 100644 --- a/tests/unit/converter/test_math_prompt_converter.py +++ b/tests/unit/converter/test_math_prompt_converter.py @@ -4,6 +4,7 @@ from unittest.mock import AsyncMock, MagicMock import pytest +from unit.mocks import get_mock_target_identifier from pyrit.models import Message, MessagePiece, SeedPrompt from pyrit.prompt_converter import ConverterResult @@ -15,6 +16,7 @@ async def test_math_prompt_converter_convert_async(): # Mock the converter target - use MagicMock for synchronous methods mock_converter_target = MagicMock() mock_converter_target.send_prompt_async = AsyncMock() + mock_converter_target.get_identifier.return_value = get_mock_target_identifier("MockMathTarget") # Specify parameters=['prompt'] to match the placeholder in the template template_value = "Solve the following problem: {{ prompt }}" dataset_name = "dataset_1" @@ -72,6 +74,7 @@ async def test_math_prompt_converter_handles_disallowed_content(): # Mock the converter target - use MagicMock for synchronous methods mock_converter_target = MagicMock() mock_converter_target.send_prompt_async = AsyncMock() + mock_converter_target.get_identifier.return_value = get_mock_target_identifier("MockMathTarget") # Specify parameters=['prompt'] to match the placeholder in the template template_value = "Encode this instruction: {{ prompt }}" dataset_name = "dataset_1" @@ -126,6 +129,7 @@ async def test_math_prompt_converter_invalid_input_type(): # Mock the converter target - use MagicMock for synchronous methods mock_converter_target = MagicMock() mock_converter_target.send_prompt_async = AsyncMock() + mock_converter_target.get_identifier.return_value = get_mock_target_identifier("MockMathTarget") # Specify parameters=['prompt'] to match the placeholder in the template template_value = "Encode this instruction: {{ prompt }}" dataset_name = "dataset_1" @@ -148,6 +152,7 @@ async def test_math_prompt_converter_error_handling(): # Mock the converter target - use MagicMock for synchronous methods mock_converter_target = MagicMock() mock_converter_target.send_prompt_async = AsyncMock() + mock_converter_target.get_identifier.return_value = get_mock_target_identifier("MockMathTarget") # Specify parameters=['prompt'] to match the placeholder in the template template_value = "Encode this instruction: {{ prompt }}" dataset_name = "dataset_1" diff --git a/tests/unit/converter/test_random_translation_converter.py b/tests/unit/converter/test_random_translation_converter.py index e9c9ed4e9..a71fb86fe 100644 --- a/tests/unit/converter/test_random_translation_converter.py +++ b/tests/unit/converter/test_random_translation_converter.py @@ -4,6 +4,7 @@ from unittest.mock import AsyncMock, MagicMock import pytest +from unit.mocks import get_mock_target_identifier from pyrit.models import Message, MessagePiece from pyrit.prompt_converter import RandomTranslationConverter @@ -27,6 +28,7 @@ def mock_target() -> PromptTarget: ] ) target.send_prompt_async = AsyncMock(return_value=[response]) + target.get_identifier.return_value = get_mock_target_identifier("MockTranslationTarget") return target diff --git a/tests/unit/converter/test_toxic_sentence_generator_converter.py b/tests/unit/converter/test_toxic_sentence_generator_converter.py index ed0d1c867..1e75ef89b 100644 --- a/tests/unit/converter/test_toxic_sentence_generator_converter.py +++ b/tests/unit/converter/test_toxic_sentence_generator_converter.py @@ -4,6 +4,7 @@ from unittest.mock import AsyncMock, MagicMock, patch import pytest +from unit.mocks import get_mock_target_identifier from pyrit.models import MessagePiece, SeedPrompt from pyrit.prompt_converter import ToxicSentenceGeneratorConverter @@ -20,7 +21,7 @@ def mock_target(): conversation_id="test-conversation", ).to_message() mock.send_prompt_async = AsyncMock(return_value=[response]) - mock.get_identifier = MagicMock(return_value="mock_target") + mock.get_identifier.return_value = get_mock_target_identifier("MockToxicTarget") return mock diff --git a/tests/unit/executor/attack/component/test_conversation_manager.py b/tests/unit/executor/attack/component/test_conversation_manager.py index a3bc0474f..5c7f6c50c 100644 --- a/tests/unit/executor/attack/component/test_conversation_manager.py +++ b/tests/unit/executor/attack/component/test_conversation_manager.py @@ -33,10 +33,21 @@ ) from pyrit.executor.attack.core import AttackContext from pyrit.executor.attack.core.attack_parameters import AttackParameters +from pyrit.identifiers import TargetIdentifier from pyrit.models import Message, MessagePiece, Score from pyrit.prompt_normalizer import PromptConverterConfiguration, PromptNormalizer from pyrit.prompt_target import PromptChatTarget, PromptTarget -from tests.unit.mocks import get_mock_scorer_identifier +from unit.mocks import get_mock_scorer_identifier + + +def _mock_target_id(name: str = "MockTarget") -> TargetIdentifier: + """Helper to create TargetIdentifier for tests.""" + return TargetIdentifier( + class_name=name, + class_module="test_module", + class_description="", + identifier_type="instance", + ) # ============================================================================= # Test Context Class @@ -78,7 +89,7 @@ def mock_chat_target() -> MagicMock: """Create a mock chat target for testing.""" target = MagicMock(spec=PromptChatTarget) target.set_system_prompt = MagicMock() - target.get_identifier.return_value = {"id": "mock_chat_target_id"} + target.get_identifier.return_value = _mock_target_id("MockChatTarget") return target @@ -86,7 +97,7 @@ def mock_chat_target() -> MagicMock: def mock_prompt_target() -> MagicMock: """Create a mock prompt target (non-chat) for testing.""" target = MagicMock(spec=PromptTarget) - target.get_identifier.return_value = {"id": "mock_target_id"} + target.get_identifier.return_value = _mock_target_id("MockTarget") return target diff --git a/tests/unit/executor/attack/component/test_simulated_conversation.py b/tests/unit/executor/attack/component/test_simulated_conversation.py index 2e59fb386..201692f9f 100644 --- a/tests/unit/executor/attack/component/test_simulated_conversation.py +++ b/tests/unit/executor/attack/component/test_simulated_conversation.py @@ -11,6 +11,7 @@ from pyrit.executor.attack.multi_turn.simulated_conversation import ( generate_simulated_conversation_async, ) +from pyrit.identifiers import ScorerIdentifier, TargetIdentifier from pyrit.models import ( AttackOutcome, AttackResult, @@ -25,13 +26,33 @@ from pyrit.score import TrueFalseScorer +def _mock_target_id(name: str = "MockTarget") -> TargetIdentifier: + """Helper to create TargetIdentifier for tests.""" + return TargetIdentifier( + class_name=name, + class_module="test_module", + class_description="", + identifier_type="instance", + ) + + +def _mock_scorer_id(name: str = "MockScorer") -> ScorerIdentifier: + """Helper to create ScorerIdentifier for tests.""" + return ScorerIdentifier( + class_name=name, + class_module="test_module", + class_description="", + identifier_type="instance", + ) + + @pytest.fixture def mock_adversarial_chat() -> MagicMock: """Create a mock adversarial chat target for testing.""" chat = MagicMock(spec=PromptChatTarget) chat.send_prompt_async = AsyncMock() chat.set_system_prompt = MagicMock() - chat.get_identifier.return_value = {"__type__": "MockAdversarialChat", "__module__": "test_module"} + chat.get_identifier.return_value = _mock_target_id("MockAdversarialChat") return chat @@ -40,7 +61,7 @@ def mock_objective_scorer() -> MagicMock: """Create a mock objective scorer for testing.""" scorer = MagicMock(spec=TrueFalseScorer) scorer.score_async = AsyncMock() - scorer.get_identifier.return_value = {"__type__": "MockScorer", "__module__": "test_module"} + scorer.get_identifier.return_value = _mock_scorer_id("MockScorer") return scorer diff --git a/tests/unit/executor/attack/core/test_attack_strategy.py b/tests/unit/executor/attack/core/test_attack_strategy.py index dcb9db2d6..918b9c753 100644 --- a/tests/unit/executor/attack/core/test_attack_strategy.py +++ b/tests/unit/executor/attack/core/test_attack_strategy.py @@ -13,6 +13,7 @@ _DefaultAttackStrategyEventHandler, ) from pyrit.executor.core import StrategyEvent, StrategyEventData +from pyrit.identifiers import TargetIdentifier from pyrit.memory.central_memory import CentralMemory from pyrit.models import ( AttackOutcome, @@ -22,6 +23,16 @@ from pyrit.prompt_target import PromptTarget +def _mock_target_id(name: str = "MockTarget") -> TargetIdentifier: + """Helper to create TargetIdentifier for tests.""" + return TargetIdentifier( + class_name=name, + class_module="test", + class_description="", + identifier_type="instance", + ) + + @pytest.fixture def mock_memory(): """Mock CentralMemory instance""" @@ -34,7 +45,7 @@ def mock_memory(): def mock_objective_target(): """Mock PromptTarget instance""" target = MagicMock(spec=PromptTarget) - target.get_identifier.return_value = {"__type__": "MockTarget", "__module__": "test"} + target.get_identifier.return_value = _mock_target_id("MockTarget") return target diff --git a/tests/unit/executor/attack/multi_turn/test_crescendo.py b/tests/unit/executor/attack/multi_turn/test_crescendo.py index 1ba71af60..73d4909b3 100644 --- a/tests/unit/executor/attack/multi_turn/test_crescendo.py +++ b/tests/unit/executor/attack/multi_turn/test_crescendo.py @@ -24,7 +24,7 @@ CrescendoAttackContext, CrescendoAttackResult, ) -from pyrit.identifiers import ScorerIdentifier +from pyrit.identifiers import ScorerIdentifier, TargetIdentifier from pyrit.models import ( AttackOutcome, ChatMessageRole, @@ -50,6 +50,16 @@ def _mock_scorer_id(name: str = "MockScorer") -> ScorerIdentifier: ) +def _mock_target_id(name: str = "MockTarget") -> TargetIdentifier: + """Helper to create TargetIdentifier for tests.""" + return TargetIdentifier( + class_name=name, + class_module="test_module", + class_description="", + identifier_type="instance", + ) + + def create_mock_chat_target(*, name: str = "MockChatTarget") -> MagicMock: """Create a mock chat target with common setup. @@ -59,7 +69,7 @@ def create_mock_chat_target(*, name: str = "MockChatTarget") -> MagicMock: target = MagicMock(spec=PromptChatTarget) target.send_prompt_async = AsyncMock() target.set_system_prompt = MagicMock() - target.get_identifier.return_value = {"__type__": name, "__module__": "test_module"} + target.get_identifier.return_value = _mock_target_id(name) return target diff --git a/tests/unit/executor/attack/multi_turn/test_multi_prompt_sending.py b/tests/unit/executor/attack/multi_turn/test_multi_prompt_sending.py index b11a7a61e..bdf70a77d 100644 --- a/tests/unit/executor/attack/multi_turn/test_multi_prompt_sending.py +++ b/tests/unit/executor/attack/multi_turn/test_multi_prompt_sending.py @@ -15,7 +15,7 @@ MultiPromptSendingAttackParameters, MultiTurnAttackContext, ) -from pyrit.identifiers import ScorerIdentifier +from pyrit.identifiers import ScorerIdentifier, TargetIdentifier from pyrit.models import ( AttackOutcome, AttackResult, @@ -39,12 +39,22 @@ def _mock_scorer_id(name: str = "MockScorer") -> ScorerIdentifier: ) +def _mock_target_id(name: str = "MockTarget") -> TargetIdentifier: + """Helper to create TargetIdentifier for tests.""" + return TargetIdentifier( + class_name=name, + class_module="test_module", + class_description="", + identifier_type="instance", + ) + + @pytest.fixture def mock_target(): """Create a mock prompt target for testing""" target = MagicMock(spec=PromptTarget) target.send_prompt_async = AsyncMock() - target.get_identifier.return_value = {"id": "mock_target_id"} + target.get_identifier.return_value = _mock_target_id("MockTarget") return target diff --git a/tests/unit/executor/attack/multi_turn/test_red_teaming.py b/tests/unit/executor/attack/multi_turn/test_red_teaming.py index 7e1b7b269..340dc10c6 100644 --- a/tests/unit/executor/attack/multi_turn/test_red_teaming.py +++ b/tests/unit/executor/attack/multi_turn/test_red_teaming.py @@ -19,7 +19,7 @@ RedTeamingAttack, RTASystemPromptPaths, ) -from pyrit.identifiers import ScorerIdentifier +from pyrit.identifiers import ScorerIdentifier, TargetIdentifier from pyrit.models import ( AttackOutcome, AttackResult, @@ -45,11 +45,21 @@ def _mock_scorer_id(name: str = "MockScorer") -> ScorerIdentifier: ) +def _mock_target_id(name: str = "MockTarget") -> TargetIdentifier: + """Helper to create TargetIdentifier for tests.""" + return TargetIdentifier( + class_name=name, + class_module="test_module", + class_description="", + identifier_type="instance", + ) + + @pytest.fixture def mock_objective_target() -> MagicMock: target = MagicMock(spec=PromptTarget) target.send_prompt_async = AsyncMock() - target.get_identifier.return_value = {"__type__": "MockTarget", "__module__": "test_module"} + target.get_identifier.return_value = _mock_target_id("MockTarget") return target @@ -58,7 +68,7 @@ def mock_adversarial_chat() -> MagicMock: chat = MagicMock(spec=PromptChatTarget) chat.send_prompt_async = AsyncMock() chat.set_system_prompt = MagicMock() - chat.get_identifier.return_value = {"__type__": "MockChatTarget", "__module__": "test_module"} + chat.get_identifier.return_value = _mock_target_id("MockChatTarget") return chat @@ -534,10 +544,7 @@ async def test_max_turns_validation_with_prepended_conversation( mock_chat_objective_target = MagicMock(spec=PromptChatTarget) mock_chat_objective_target.send_prompt_async = AsyncMock() mock_chat_objective_target.set_system_prompt = MagicMock() - mock_chat_objective_target.get_identifier.return_value = { - "__type__": "MockChatTarget", - "__module__": "test_module", - } + mock_chat_objective_target.get_identifier.return_value = _mock_target_id("MockChatTarget") adversarial_config = AttackAdversarialConfig(target=mock_adversarial_chat) scoring_config = AttackScoringConfig(objective_scorer=mock_objective_scorer) diff --git a/tests/unit/executor/attack/multi_turn/test_tree_of_attacks.py b/tests/unit/executor/attack/multi_turn/test_tree_of_attacks.py index c134abc3f..11a29f993 100644 --- a/tests/unit/executor/attack/multi_turn/test_tree_of_attacks.py +++ b/tests/unit/executor/attack/multi_turn/test_tree_of_attacks.py @@ -26,7 +26,7 @@ TAPAttackScoringConfig, _TreeOfAttacksNode, ) -from pyrit.identifiers import ScorerIdentifier +from pyrit.identifiers import ScorerIdentifier, TargetIdentifier from pyrit.models import ( AttackOutcome, ConversationReference, @@ -223,7 +223,12 @@ def build(self) -> TreeOfAttacksWithPruningAttack: def _create_mock_target() -> PromptTarget: target = MagicMock(spec=PromptTarget) target.send_prompt_async = AsyncMock(return_value=None) - target.get_identifier.return_value = {"__type__": "MockTarget", "__module__": "test_module"} + target.get_identifier.return_value = TargetIdentifier( + class_name="MockTarget", + class_module="test_module", + class_description="", + identifier_type="instance", + ) return cast(PromptTarget, target) @staticmethod @@ -231,7 +236,12 @@ def _create_mock_chat() -> PromptChatTarget: chat = MagicMock(spec=PromptChatTarget) chat.send_prompt_async = AsyncMock(return_value=None) chat.set_system_prompt = MagicMock() - chat.get_identifier.return_value = {"__type__": "MockChatTarget", "__module__": "test_module"} + chat.get_identifier.return_value = TargetIdentifier( + class_name="MockChatTarget", + class_module="test_module", + class_description="", + identifier_type="instance", + ) return cast(PromptChatTarget, chat) @staticmethod diff --git a/tests/unit/executor/attack/single_turn/test_context_compliance.py b/tests/unit/executor/attack/single_turn/test_context_compliance.py index 7871e4483..3908cb89f 100644 --- a/tests/unit/executor/attack/single_turn/test_context_compliance.py +++ b/tests/unit/executor/attack/single_turn/test_context_compliance.py @@ -15,6 +15,7 @@ ContextComplianceAttack, SingleTurnAttackContext, ) +from pyrit.identifiers import TargetIdentifier from pyrit.models import ( Message, MessagePiece, @@ -26,12 +27,22 @@ from pyrit.score import TrueFalseScorer +def _mock_target_id(name: str = "MockTarget") -> TargetIdentifier: + """Helper to create TargetIdentifier for tests.""" + return TargetIdentifier( + class_name=name, + class_module="test_module", + class_description="", + identifier_type="instance", + ) + + @pytest.fixture def mock_objective_target(): """Create a mock PromptChatTarget for testing""" target = MagicMock(spec=PromptChatTarget) target.send_prompt_async = AsyncMock() - target.get_identifier.return_value = {"id": "mock_target_id"} + target.get_identifier.return_value = _mock_target_id("MockTarget") return target @@ -40,7 +51,7 @@ def mock_adversarial_chat(): """Create a mock adversarial chat target for testing""" target = MagicMock(spec=PromptChatTarget) target.send_prompt_async = AsyncMock() - target.get_identifier.return_value = {"id": "mock_adversarial_id"} + target.get_identifier.return_value = _mock_target_id("MockAdversarialTarget") return target diff --git a/tests/unit/executor/attack/single_turn/test_flip_attack.py b/tests/unit/executor/attack/single_turn/test_flip_attack.py index 51982a86e..faffd7709 100644 --- a/tests/unit/executor/attack/single_turn/test_flip_attack.py +++ b/tests/unit/executor/attack/single_turn/test_flip_attack.py @@ -13,6 +13,7 @@ FlipAttack, SingleTurnAttackContext, ) +from pyrit.identifiers import TargetIdentifier from pyrit.models import ( AttackOutcome, AttackResult, @@ -23,12 +24,22 @@ from pyrit.score import TrueFalseScorer +def _mock_target_id(name: str = "MockTarget") -> TargetIdentifier: + """Helper to create TargetIdentifier for tests.""" + return TargetIdentifier( + class_name=name, + class_module="test_module", + class_description="", + identifier_type="instance", + ) + + @pytest.fixture def mock_objective_target(): """Create a mock PromptChatTarget for testing""" target = MagicMock(spec=PromptChatTarget) target.send_prompt_async = AsyncMock() - target.get_identifier.return_value = {"id": "mock_target_id"} + target.get_identifier.return_value = _mock_target_id("MockTarget") return target diff --git a/tests/unit/executor/attack/single_turn/test_many_shot_jailbreak.py b/tests/unit/executor/attack/single_turn/test_many_shot_jailbreak.py index 428918fa6..9f9877c32 100644 --- a/tests/unit/executor/attack/single_turn/test_many_shot_jailbreak.py +++ b/tests/unit/executor/attack/single_turn/test_many_shot_jailbreak.py @@ -13,6 +13,7 @@ ManyShotJailbreakAttack, SingleTurnAttackContext, ) +from pyrit.identifiers import TargetIdentifier from pyrit.models import ( AttackOutcome, AttackResult, @@ -24,12 +25,22 @@ from pyrit.score import TrueFalseScorer +def _mock_target_id(name: str = "MockTarget") -> TargetIdentifier: + """Helper to create TargetIdentifier for tests.""" + return TargetIdentifier( + class_name=name, + class_module="test_module", + class_description="", + identifier_type="instance", + ) + + @pytest.fixture def mock_objective_target(): """Create a mock PromptTarget for testing""" target = MagicMock(spec=PromptTarget) target.send_prompt_async = AsyncMock() - target.get_identifier.return_value = {"id": "mock_target_id"} + target.get_identifier.return_value = _mock_target_id("MockTarget") return target diff --git a/tests/unit/executor/attack/single_turn/test_prompt_sending.py b/tests/unit/executor/attack/single_turn/test_prompt_sending.py index 6553fd8a4..b34e1c27d 100644 --- a/tests/unit/executor/attack/single_turn/test_prompt_sending.py +++ b/tests/unit/executor/attack/single_turn/test_prompt_sending.py @@ -27,7 +27,7 @@ from pyrit.prompt_normalizer import PromptConverterConfiguration, PromptNormalizer from pyrit.prompt_target import PromptTarget from pyrit.score import Scorer, TrueFalseScorer -from tests.unit.mocks import get_mock_scorer_identifier +from unit.mocks import get_mock_scorer_identifier, get_mock_target_identifier @pytest.fixture @@ -35,7 +35,7 @@ def mock_target(): """Create a mock prompt target for testing""" target = MagicMock(spec=PromptTarget) target.send_prompt_async = AsyncMock() - target.get_identifier.return_value = {"id": "mock_target_id"} + target.get_identifier.return_value = get_mock_target_identifier("MockTarget") return target diff --git a/tests/unit/executor/attack/single_turn/test_role_play.py b/tests/unit/executor/attack/single_turn/test_role_play.py index 79b04853f..c88037a57 100644 --- a/tests/unit/executor/attack/single_turn/test_role_play.py +++ b/tests/unit/executor/attack/single_turn/test_role_play.py @@ -25,7 +25,7 @@ from pyrit.prompt_normalizer import PromptConverterConfiguration from pyrit.prompt_target import PromptChatTarget from pyrit.score import Scorer, TrueFalseScorer -from tests.unit.mocks import get_mock_scorer_identifier +from unit.mocks import get_mock_scorer_identifier, get_mock_target_identifier @pytest.fixture @@ -33,7 +33,7 @@ def mock_objective_target(): """Create a mock prompt target for testing""" target = MagicMock(spec=PromptChatTarget) target.send_prompt_async = AsyncMock() - target.get_identifier.return_value = {"id": "mock_target_id"} + target.get_identifier.return_value = get_mock_target_identifier("MockTarget") return target @@ -42,7 +42,7 @@ def mock_adversarial_chat_target(): """Create a mock adversarial chat target for testing""" target = MagicMock(spec=PromptChatTarget) target.send_prompt_async = AsyncMock() - target.get_identifier.return_value = {"id": "mock_adversarial_chat_id"} + target.get_identifier.return_value = get_mock_target_identifier("MockAdversarialChat") return target diff --git a/tests/unit/executor/attack/single_turn/test_skeleton_key.py b/tests/unit/executor/attack/single_turn/test_skeleton_key.py index 963645e43..7560b6be7 100644 --- a/tests/unit/executor/attack/single_turn/test_skeleton_key.py +++ b/tests/unit/executor/attack/single_turn/test_skeleton_key.py @@ -24,7 +24,7 @@ from pyrit.prompt_normalizer import PromptNormalizer from pyrit.prompt_target import PromptTarget from pyrit.score import TrueFalseScorer -from tests.unit.mocks import get_mock_scorer_identifier +from unit.mocks import get_mock_scorer_identifier, get_mock_target_identifier @pytest.fixture @@ -32,7 +32,7 @@ def mock_target(): """Create a mock prompt target for testing""" target = MagicMock(spec=PromptTarget) target.send_prompt_async = AsyncMock() - target.get_identifier.return_value = {"id": "mock_target_id"} + target.get_identifier.return_value = get_mock_target_identifier("MockTarget") return target diff --git a/tests/unit/executor/attack/test_attack_parameter_consistency.py b/tests/unit/executor/attack/test_attack_parameter_consistency.py index b0d545076..deaa91cb1 100644 --- a/tests/unit/executor/attack/test_attack_parameter_consistency.py +++ b/tests/unit/executor/attack/test_attack_parameter_consistency.py @@ -23,7 +23,7 @@ TreeOfAttacksWithPruningAttack, ) from pyrit.executor.attack.multi_turn.tree_of_attacks import TAPAttackScoringConfig -from pyrit.identifiers import ScorerIdentifier +from pyrit.identifiers import ScorerIdentifier, TargetIdentifier from pyrit.memory import CentralMemory from pyrit.models import ( ChatMessageRole, @@ -47,6 +47,16 @@ def _mock_scorer_id(name: str = "MockScorer") -> ScorerIdentifier: ) +def _mock_target_id(name: str = "MockTarget") -> TargetIdentifier: + """Helper to create TargetIdentifier for tests.""" + return TargetIdentifier( + class_name=name, + class_module="test_module", + class_description="", + identifier_type="instance", + ) + + # ============================================================================= # Multi-Modal Message Fixtures # ============================================================================= @@ -137,7 +147,7 @@ def mock_chat_target() -> MagicMock: target = MagicMock(spec=PromptChatTarget) target.send_prompt_async = AsyncMock() target.set_system_prompt = MagicMock() - target.get_identifier.return_value = {"__type__": "MockChatTarget", "__module__": "test_module"} + target.get_identifier.return_value = _mock_target_id("MockChatTarget") return target @@ -146,7 +156,7 @@ def mock_non_chat_target() -> MagicMock: """Create a mock PromptTarget (non-chat) with common setup.""" target = MagicMock(spec=PromptTarget) target.send_prompt_async = AsyncMock() - target.get_identifier.return_value = {"__type__": "MockTarget", "__module__": "test_module"} + target.get_identifier.return_value = _mock_target_id("MockTarget") return target @@ -156,7 +166,7 @@ def mock_adversarial_chat() -> MagicMock: target = MagicMock(spec=PromptChatTarget) target.send_prompt_async = AsyncMock() target.set_system_prompt = MagicMock() - target.get_identifier.return_value = {"__type__": "MockAdversarialChat", "__module__": "test_module"} + target.get_identifier.return_value = _mock_target_id("MockAdversarialChat") return target diff --git a/tests/unit/executor/attack/test_error_skip_scoring.py b/tests/unit/executor/attack/test_error_skip_scoring.py index 15825baeb..781406b46 100644 --- a/tests/unit/executor/attack/test_error_skip_scoring.py +++ b/tests/unit/executor/attack/test_error_skip_scoring.py @@ -20,18 +20,39 @@ ) from pyrit.executor.attack.core import AttackAdversarialConfig, AttackScoringConfig from pyrit.executor.attack.multi_turn.tree_of_attacks import TAPAttackScoringConfig +from pyrit.identifiers import ScorerIdentifier, TargetIdentifier from pyrit.models import Message, MessagePiece, SeedGroup, SeedPrompt from pyrit.prompt_target import PromptTarget from pyrit.prompt_target.common.prompt_chat_target import PromptChatTarget from pyrit.score import FloatScaleThresholdScorer, TrueFalseScorer +def _mock_target_id(name: str = "MockTarget") -> TargetIdentifier: + """Helper to create TargetIdentifier for tests.""" + return TargetIdentifier( + class_name=name, + class_module="test_module", + class_description="", + identifier_type="instance", + ) + + +def _mock_scorer_id(name: str = "MockScorer") -> ScorerIdentifier: + """Helper to create ScorerIdentifier for tests.""" + return ScorerIdentifier( + class_name=name, + class_module="test_module", + class_description="", + identifier_type="instance", + ) + + @pytest.fixture def mock_target(): """Create a mock prompt target for testing""" target = MagicMock(spec=PromptTarget) target.send_prompt_async = AsyncMock() - target.get_identifier.return_value = {"id": "mock_target_id"} + target.get_identifier.return_value = _mock_target_id("MockTarget") return target @@ -40,7 +61,7 @@ def mock_scorer(): """Create a mock scorer for testing""" scorer = MagicMock(spec=TrueFalseScorer) scorer.score_async = AsyncMock() - scorer.get_identifier.return_value = {"id": "mock_scorer_id"} + scorer.get_identifier.return_value = _mock_scorer_id("MockScorer") return scorer @@ -150,7 +171,7 @@ async def test_attack_executor_skips_scoring_on_error( if attack_class == TreeOfAttacksWithPruningAttack: tap_scorer = MagicMock(spec=FloatScaleThresholdScorer) tap_scorer.score_async = AsyncMock() - tap_scorer.get_identifier.return_value = {"id": "mock_tap_scorer_id"} + tap_scorer.get_identifier.return_value = _mock_scorer_id("MockTapScorer") tap_scorer.threshold = 0.7 attack_scoring_config = TAPAttackScoringConfig( objective_scorer=tap_scorer, @@ -171,7 +192,7 @@ async def test_attack_executor_skips_scoring_on_error( adversarial_target = MagicMock(spec=PromptTarget) adversarial_target.send_prompt_async = AsyncMock() - adversarial_target.get_identifier.return_value = {"id": "adversarial_target_id"} + adversarial_target.get_identifier.return_value = _mock_target_id("AdversarialTarget") attack_adversarial_config = AttackAdversarialConfig( target=adversarial_target, @@ -182,7 +203,7 @@ async def test_attack_executor_skips_scoring_on_error( if attack_class == CrescendoAttack: refusal_scorer = MagicMock(spec=TrueFalseScorer) refusal_scorer.score_async = AsyncMock(return_value=[]) - refusal_scorer.get_identifier.return_value = {"id": "refusal_scorer_id"} + refusal_scorer.get_identifier.return_value = _mock_scorer_id("RefusalScorer") attack_scoring_config.refusal_scorer = refusal_scorer # Create attack with proper configuration diff --git a/tests/unit/executor/promptgen/fuzzer/test_fuzzer.py b/tests/unit/executor/promptgen/fuzzer/test_fuzzer.py index 12028448d..9095a9af5 100644 --- a/tests/unit/executor/promptgen/fuzzer/test_fuzzer.py +++ b/tests/unit/executor/promptgen/fuzzer/test_fuzzer.py @@ -29,7 +29,7 @@ Scorer, TrueFalseScorer, ) -from tests.unit.mocks import get_mock_scorer_identifier +from unit.mocks import get_mock_scorer_identifier @pytest.mark.usefixtures("patch_central_database") diff --git a/tests/unit/memory/memory_interface/test_interface_scenario_results.py b/tests/unit/memory/memory_interface/test_interface_scenario_results.py index 171c47eba..970980ea8 100644 --- a/tests/unit/memory/memory_interface/test_interface_scenario_results.py +++ b/tests/unit/memory/memory_interface/test_interface_scenario_results.py @@ -14,7 +14,7 @@ ScenarioIdentifier, ScenarioResult, ) -from tests.unit.mocks import get_mock_scorer_identifier +from unit.mocks import get_mock_scorer_identifier @pytest.fixture @@ -294,7 +294,7 @@ def test_preserves_metadata(sqlite_instance: MemoryInterface): assert retrieved.scenario_identifier.description == "A test scenario with metadata" assert retrieved.scenario_identifier.version == 3 assert retrieved.scenario_identifier.init_data == {"param1": "value1", "param2": 42} - assert retrieved.objective_target_identifier == {"target": "test_target", "endpoint": "https://example.com"} + assert retrieved.objective_target_identifier.endpoint == "https://example.com" # objective_scorer_identifier is now a ScorerIdentifier, check its properties assert retrieved.objective_scorer_identifier.class_name == "TestScorer" assert retrieved.objective_scorer_identifier.class_module == "test.module" @@ -630,4 +630,4 @@ def test_combined_filters(sqlite_instance: MemoryInterface): ) assert len(results) == 1 assert results[0].scenario_identifier.pyrit_version == "0.5.0" - assert "gpt-4" in results[0].objective_target_identifier["model_name"] + assert "gpt-4" in results[0].objective_target_identifier.model_name diff --git a/tests/unit/memory/test_azure_sql_memory.py b/tests/unit/memory/test_azure_sql_memory.py index 0f6033edb..42be1f244 100644 --- a/tests/unit/memory/test_azure_sql_memory.py +++ b/tests/unit/memory/test_azure_sql_memory.py @@ -207,7 +207,7 @@ def test_get_memories_with_json_properties(memory_interface: AzureSQLMemory): assert converter_identifiers[0]["__type__"] == "Base64Converter" prompt_target = retrieved_entry.prompt_target_identifier - assert prompt_target["__type__"] == "TextTarget" + assert prompt_target.class_name == "TextTarget" labels = retrieved_entry.labels assert labels["normalizer_id"] == "id1" diff --git a/tests/unit/memory/test_sqlite_memory.py b/tests/unit/memory/test_sqlite_memory.py index a18e84e7e..ec825feff 100644 --- a/tests/unit/memory/test_sqlite_memory.py +++ b/tests/unit/memory/test_sqlite_memory.py @@ -376,7 +376,7 @@ def test_get_memories_with_json_properties(sqlite_instance): assert converter_identifiers[0]["__type__"] == "Base64Converter" prompt_target = retrieved_entry.prompt_target_identifier - assert prompt_target["__type__"] == "TextTarget" + assert prompt_target.class_name == "TextTarget" labels = retrieved_entry.labels assert labels["normalizer_id"] == "id1" diff --git a/tests/unit/mocks.py b/tests/unit/mocks.py index 928bc0fba..ee1fdd6be 100644 --- a/tests/unit/mocks.py +++ b/tests/unit/mocks.py @@ -9,7 +9,7 @@ from typing import Generator, MutableSequence, Optional, Sequence from unittest.mock import MagicMock, patch -from pyrit.identifiers import ScorerIdentifier +from pyrit.identifiers import ScorerIdentifier, TargetIdentifier from pyrit.memory import AzureSQLMemory, CentralMemory, PromptMemoryEntry from pyrit.models import Message, MessagePiece from pyrit.prompt_target import PromptChatTarget, limit_requests_per_minute @@ -28,6 +28,26 @@ def get_mock_scorer_identifier() -> ScorerIdentifier: ) +def get_mock_target_identifier(name: str = "MockTarget", module: str = "tests.unit.mocks") -> TargetIdentifier: + """ + Returns a mock TargetIdentifier for use in tests where the specific + target identity doesn't matter. + + Args: + name: The class name for the mock target. Defaults to "MockTarget". + module: The module path for the mock target. Defaults to "tests.unit.mocks". + + Returns: + A TargetIdentifier configured with the provided name and module. + """ + return TargetIdentifier( + class_name=name, + class_module=module, + class_description="Mock target for testing", + identifier_type="instance", + ) + + class MockHttpPostAsync(AbstractAsyncContextManager): def __init__(self, url, headers=None, json=None, params=None, ssl=None): self.status = 200 diff --git a/tests/unit/models/test_message_piece.py b/tests/unit/models/test_message_piece.py index c4aa7d81e..cfc906d2d 100644 --- a/tests/unit/models/test_message_piece.py +++ b/tests/unit/models/test_message_piece.py @@ -78,8 +78,8 @@ def test_prompt_targets_serialize(patch_central_database): prompt_target_identifier=target.get_identifier(), ) assert patch_central_database.called - assert entry.prompt_target_identifier["__type__"] == "MockPromptTarget" - assert entry.prompt_target_identifier["__module__"] == "unit.mocks" + assert entry.prompt_target_identifier.class_name == "MockPromptTarget" + assert entry.prompt_target_identifier.class_module == "unit.mocks" def test_executors_serialize(): @@ -745,7 +745,7 @@ def test_message_piece_to_dict(): assert result["targeted_harm_categories"] == entry.targeted_harm_categories assert result["prompt_metadata"] == entry.prompt_metadata assert result["converter_identifiers"] == entry.converter_identifiers - assert result["prompt_target_identifier"] == entry.prompt_target_identifier + assert result["prompt_target_identifier"] == entry.prompt_target_identifier.to_dict() assert result["attack_identifier"] == entry.attack_identifier assert result["scorer_identifier"] == entry.scorer_identifier.to_dict() assert result["original_value_data_type"] == entry.original_value_data_type diff --git a/tests/unit/prompt_normalizer/test_prompt_normalizer.py b/tests/unit/prompt_normalizer/test_prompt_normalizer.py index 07319440b..4ff2914f4 100644 --- a/tests/unit/prompt_normalizer/test_prompt_normalizer.py +++ b/tests/unit/prompt_normalizer/test_prompt_normalizer.py @@ -7,7 +7,7 @@ from uuid import uuid4 import pytest -from unit.mocks import MockPromptTarget, get_image_message_piece +from unit.mocks import MockPromptTarget, get_image_message_piece, get_mock_target_identifier from pyrit.exceptions import EmptyResponseException from pyrit.memory import CentralMemory @@ -130,6 +130,7 @@ async def test_send_prompt_async_empty_response_exception_handled(mock_memory_in # Use MagicMock with send_prompt_async as AsyncMock to avoid coroutine warnings on other methods prompt_target = MagicMock() prompt_target.send_prompt_async = AsyncMock(side_effect=EmptyResponseException(message="Empty response")) + prompt_target.get_identifier.return_value = get_mock_target_identifier("MockTarget") normalizer = PromptNormalizer() message = Message.from_prompt(prompt=seed_group.prompts[0].value, role="user") @@ -209,8 +210,9 @@ async def test_send_prompt_async_exception(mock_memory_instance, seed_group): @pytest.mark.asyncio async def test_send_prompt_async_empty_exception(mock_memory_instance, seed_group): - prompt_target = AsyncMock() + prompt_target = MagicMock() prompt_target.send_prompt_async = AsyncMock(side_effect=Exception("")) + prompt_target.get_identifier.return_value = get_mock_target_identifier("MockTarget") normalizer = PromptNormalizer() message = Message.from_prompt(prompt=seed_group.prompts[0].value, role="user") @@ -412,6 +414,7 @@ async def test_convert_response_values_type(mock_memory_instance, response: Mess async def test_send_prompt_async_exception_conv_id(mock_memory_instance, seed_group): prompt_target = MagicMock(PromptTarget) prompt_target.send_prompt_async = AsyncMock(side_effect=Exception("Test Exception")) + prompt_target.get_identifier.return_value = get_mock_target_identifier("MockTarget") normalizer = PromptNormalizer() message = Message.from_prompt(prompt=seed_group.prompts[0].value, role="user") diff --git a/tests/unit/scenarios/test_content_harms.py b/tests/unit/scenarios/test_content_harms.py index 66598c204..bfd0fbf8c 100644 --- a/tests/unit/scenarios/test_content_harms.py +++ b/tests/unit/scenarios/test_content_harms.py @@ -9,7 +9,7 @@ import pytest from pyrit.common.path import DATASETS_PATH -from pyrit.identifiers import ScorerIdentifier +from pyrit.identifiers import ScorerIdentifier, TargetIdentifier from pyrit.models import SeedAttackGroup, SeedObjective, SeedPrompt from pyrit.prompt_target import PromptTarget from pyrit.prompt_target.common.prompt_chat_target import PromptChatTarget @@ -34,11 +34,21 @@ def _mock_scorer_id(name: str = "MockObjectiveScorer") -> ScorerIdentifier: ) +def _mock_target_id(name: str = "MockTarget") -> TargetIdentifier: + """Helper to create TargetIdentifier for tests.""" + return TargetIdentifier( + class_name=name, + class_module="test", + class_description="", + identifier_type="instance", + ) + + @pytest.fixture def mock_objective_target(): """Create a mock objective target for testing.""" mock = MagicMock(spec=PromptTarget) - mock.get_identifier.return_value = {"__type__": "MockObjectiveTarget", "__module__": "test"} + mock.get_identifier.return_value = _mock_target_id("MockObjectiveTarget") return mock @@ -46,7 +56,7 @@ def mock_objective_target(): def mock_adversarial_target(): """Create a mock adversarial target for testing.""" mock = MagicMock(spec=PromptChatTarget) - mock.get_identifier.return_value = {"__type__": "MockAdversarialTarget", "__module__": "test"} + mock.get_identifier.return_value = _mock_target_id("MockAdversarialTarget") return mock diff --git a/tests/unit/scenarios/test_cyber.py b/tests/unit/scenarios/test_cyber.py index 1730e7036..a19065ff8 100644 --- a/tests/unit/scenarios/test_cyber.py +++ b/tests/unit/scenarios/test_cyber.py @@ -12,7 +12,7 @@ from pyrit.common.path import DATASETS_PATH from pyrit.executor.attack import PromptSendingAttack, RedTeamingAttack from pyrit.executor.attack.core.attack_config import AttackScoringConfig -from pyrit.identifiers import ScorerIdentifier +from pyrit.identifiers import ScorerIdentifier, TargetIdentifier from pyrit.models import SeedAttackGroup, SeedDataset, SeedObjective from pyrit.prompt_target import OpenAIChatTarget, PromptChatTarget, PromptTarget from pyrit.scenario.airt import Cyber, CyberStrategy @@ -29,6 +29,16 @@ def _mock_scorer_id(name: str = "MockObjectiveScorer") -> ScorerIdentifier: ) +def _mock_target_id(name: str = "MockTarget") -> TargetIdentifier: + """Helper to create TargetIdentifier for tests.""" + return TargetIdentifier( + class_name=name, + class_module="test", + class_description="", + identifier_type="instance", + ) + + @pytest.fixture def mock_memory_seed_groups(): """Create mock seed groups that _get_default_seed_groups() would return.""" @@ -75,7 +85,7 @@ def mock_runtime_env(): def mock_objective_target(): """Create a mock objective target for testing.""" mock = MagicMock(spec=PromptTarget) - mock.get_identifier.return_value = {"__type__": "MockObjectiveTarget", "__module__": "test"} + mock.get_identifier.return_value = _mock_target_id("MockObjectiveTarget") return mock @@ -91,7 +101,7 @@ def mock_objective_scorer(): def mock_adversarial_target(): """Create a mock adversarial target for testing.""" mock = MagicMock(spec=PromptChatTarget) - mock.get_identifier.return_value = {"__type__": "MockAdversarialTarget", "__module__": "test"} + mock.get_identifier.return_value = _mock_target_id("MockAdversarialTarget") return mock @@ -158,7 +168,7 @@ def test_init_default_adversarial_chat(self, mock_objective_scorer, mock_memory_ def test_init_with_adversarial_chat(self, mock_objective_scorer, mock_memory_seed_groups): """Test initialization with adversarial chat (for red teaming attack variation).""" adversarial_chat = MagicMock(OpenAIChatTarget) - adversarial_chat.get_identifier.return_value = {"type": "CustomAdversary"} + adversarial_chat.get_identifier.return_value = _mock_target_id("CustomAdversary") with patch.object(Cyber, "_resolve_seed_groups", return_value=mock_memory_seed_groups): scenario = Cyber( diff --git a/tests/unit/scenarios/test_encoding.py b/tests/unit/scenarios/test_encoding.py index a3624029f..7e08da6db 100644 --- a/tests/unit/scenarios/test_encoding.py +++ b/tests/unit/scenarios/test_encoding.py @@ -8,7 +8,7 @@ import pytest from pyrit.executor.attack import PromptSendingAttack -from pyrit.identifiers import ScorerIdentifier +from pyrit.identifiers import ScorerIdentifier, TargetIdentifier from pyrit.models import SeedPrompt from pyrit.prompt_converter import Base64Converter from pyrit.prompt_target import PromptTarget @@ -26,6 +26,16 @@ def _mock_scorer_id(name: str = "MockObjectiveScorer") -> ScorerIdentifier: ) +def _mock_target_id(name: str = "MockTarget") -> TargetIdentifier: + """Helper to create TargetIdentifier for tests.""" + return TargetIdentifier( + class_name=name, + class_module="test", + class_description="", + identifier_type="instance", + ) + + @pytest.fixture def mock_memory_seeds(): """Create mock seed prompts that memory.get_seeds() would return.""" @@ -41,7 +51,7 @@ def mock_memory_seeds(): def mock_objective_target(): """Create a mock objective target for testing.""" mock = MagicMock(spec=PromptTarget) - mock.get_identifier.return_value = {"__type__": "MockObjectiveTarget", "__module__": "test"} + mock.get_identifier.return_value = _mock_target_id("MockObjectiveTarget") return mock diff --git a/tests/unit/scenarios/test_foundry.py b/tests/unit/scenarios/test_foundry.py index 9dc960fcc..d67ace2f7 100644 --- a/tests/unit/scenarios/test_foundry.py +++ b/tests/unit/scenarios/test_foundry.py @@ -10,7 +10,7 @@ from pyrit.executor.attack.core.attack_config import AttackScoringConfig from pyrit.executor.attack.multi_turn.crescendo import CrescendoAttack from pyrit.executor.attack.single_turn.prompt_sending import PromptSendingAttack -from pyrit.identifiers import ScorerIdentifier +from pyrit.identifiers import ScorerIdentifier, TargetIdentifier from pyrit.models import SeedAttackGroup, SeedObjective from pyrit.prompt_converter import Base64Converter from pyrit.prompt_target import PromptTarget @@ -30,6 +30,16 @@ def _mock_scorer_id(name: str = "MockObjectiveScorer") -> ScorerIdentifier: ) +def _mock_target_id(name: str = "MockTarget") -> TargetIdentifier: + """Helper to create TargetIdentifier for tests.""" + return TargetIdentifier( + class_name=name, + class_module="test", + class_description="", + identifier_type="instance", + ) + + @pytest.fixture def mock_memory_seed_groups(): """Create mock seed groups that _get_default_seed_groups() would return.""" @@ -46,7 +56,7 @@ def mock_memory_seed_groups(): def mock_objective_target(): """Create a mock objective target for testing.""" mock = MagicMock(spec=PromptTarget) - mock.get_identifier.return_value = {"__type__": "MockObjectiveTarget", "__module__": "test"} + mock.get_identifier.return_value = _mock_target_id("MockObjectiveTarget") return mock @@ -54,7 +64,7 @@ def mock_objective_target(): def mock_adversarial_target(): """Create a mock adversarial target for testing.""" mock = MagicMock(spec=PromptChatTarget) - mock.get_identifier.return_value = {"__type__": "MockAdversarialTarget", "__module__": "test"} + mock.get_identifier.return_value = _mock_target_id("MockAdversarialTarget") return mock diff --git a/tests/unit/scenarios/test_leakage_scenario.py b/tests/unit/scenarios/test_leakage_scenario.py index 3de049232..553187f21 100644 --- a/tests/unit/scenarios/test_leakage_scenario.py +++ b/tests/unit/scenarios/test_leakage_scenario.py @@ -12,7 +12,7 @@ from pyrit.common.path import DATASETS_PATH from pyrit.executor.attack import CrescendoAttack, PromptSendingAttack, RolePlayAttack from pyrit.executor.attack.core.attack_config import AttackScoringConfig -from pyrit.identifiers import ScorerIdentifier +from pyrit.identifiers import ScorerIdentifier, TargetIdentifier from pyrit.models import SeedDataset, SeedObjective from pyrit.prompt_target import OpenAIChatTarget, PromptChatTarget, PromptTarget from pyrit.scenario.airt import LeakageScenario, LeakageStrategy @@ -29,6 +29,16 @@ def _mock_scorer_id(name: str = "MockObjectiveScorer") -> ScorerIdentifier: ) +def _mock_target_id(name: str = "MockTarget") -> TargetIdentifier: + """Helper to create TargetIdentifier for tests.""" + return TargetIdentifier( + class_name=name, + class_module="test", + class_description="", + identifier_type="instance", + ) + + @pytest.fixture def mock_memory_seeds(): leakage_path = pathlib.Path(DATASETS_PATH) / "seed_datasets" / "local" / "airt" @@ -83,7 +93,7 @@ def mock_runtime_env(): @pytest.fixture def mock_objective_target(): mock = MagicMock(spec=PromptTarget) - mock.get_identifier.return_value = {"__type__": "MockObjectiveTarget", "__module__": "test"} + mock.get_identifier.return_value = _mock_target_id("MockObjectiveTarget") return mock @@ -97,7 +107,7 @@ def mock_objective_scorer(): @pytest.fixture def mock_adversarial_target(): mock = MagicMock(spec=PromptChatTarget) - mock.get_identifier.return_value = {"__type__": "MockAdversarialTarget", "__module__": "test"} + mock.get_identifier.return_value = _mock_target_id("MockAdversarialTarget") return mock @@ -170,7 +180,7 @@ def test_init_default_adversarial_chat(self, mock_objective_scorer, mock_memory_ def test_init_with_adversarial_chat(self, mock_objective_scorer, mock_memory_seeds): """Test initialization with adversarial chat (for multi-turn attack variations).""" adversarial_chat = MagicMock(OpenAIChatTarget) - adversarial_chat.get_identifier.return_value = {"type": "CustomAdversary"} + adversarial_chat.get_identifier.return_value = _mock_target_id("CustomAdversary") with patch.object( LeakageScenario, "_get_default_objectives", return_value=[seed.value for seed in mock_memory_seeds] diff --git a/tests/unit/scenarios/test_scam.py b/tests/unit/scenarios/test_scam.py index 8d0aa2f93..99201c4cd 100644 --- a/tests/unit/scenarios/test_scam.py +++ b/tests/unit/scenarios/test_scam.py @@ -16,7 +16,7 @@ RolePlayAttack, ) from pyrit.executor.attack.core.attack_config import AttackScoringConfig -from pyrit.identifiers import ScorerIdentifier +from pyrit.identifiers import ScorerIdentifier, TargetIdentifier from pyrit.models import SeedDataset, SeedGroup, SeedObjective from pyrit.prompt_target import OpenAIChatTarget, PromptChatTarget, PromptTarget from pyrit.scenario.scenarios.airt.scam import Scam, ScamStrategy @@ -36,6 +36,16 @@ def _mock_scorer_id(name: str = "MockObjectiveScorer") -> ScorerIdentifier: ) +def _mock_target_id(name: str = "MockTarget") -> TargetIdentifier: + """Helper to create TargetIdentifier for tests.""" + return TargetIdentifier( + class_name=name, + class_module="test", + class_description="", + identifier_type="instance", + ) + + @pytest.fixture def mock_memory_seed_groups() -> List[SeedGroup]: """Create mock seed groups that _get_default_seed_groups() would return.""" @@ -76,7 +86,7 @@ def mock_runtime_env(): @pytest.fixture def mock_objective_target() -> PromptTarget: mock = MagicMock(spec=PromptTarget) - mock.get_identifier.return_value = {"__type__": "MockObjectiveTarget", "__module__": "test"} + mock.get_identifier.return_value = _mock_target_id("MockObjectiveTarget") return mock @@ -90,7 +100,7 @@ def mock_objective_scorer() -> TrueFalseCompositeScorer: @pytest.fixture def mock_adversarial_target() -> PromptChatTarget: mock = MagicMock(spec=PromptChatTarget) - mock.get_identifier.return_value = {"__type__": "MockAdversarialTarget", "__module__": "test"} + mock.get_identifier.return_value = _mock_target_id("MockAdversarialTarget") return mock @@ -163,7 +173,7 @@ def test_init_with_adversarial_chat( self, *, mock_objective_scorer: TrueFalseCompositeScorer, mock_memory_seed_groups: List[SeedGroup] ) -> None: adversarial_chat = MagicMock(OpenAIChatTarget) - adversarial_chat.get_identifier.return_value = {"type": "CustomAdversary"} + adversarial_chat.get_identifier.return_value = _mock_target_id("CustomAdversary") with patch.object(Scam, "_resolve_seed_groups", return_value=mock_memory_seed_groups): scenario = Scam( diff --git a/tests/unit/scenarios/test_scenario.py b/tests/unit/scenarios/test_scenario.py index 909038cdf..20777bd68 100644 --- a/tests/unit/scenarios/test_scenario.py +++ b/tests/unit/scenarios/test_scenario.py @@ -8,7 +8,7 @@ import pytest from pyrit.executor.attack.core import AttackExecutorResult -from pyrit.identifiers import ScorerIdentifier +from pyrit.identifiers import ScorerIdentifier, TargetIdentifier from pyrit.memory import CentralMemory from pyrit.models import AttackOutcome, AttackResult from pyrit.scenario import DatasetConfiguration, ScenarioIdentifier, ScenarioResult @@ -71,7 +71,12 @@ def mock_atomic_attacks(): def mock_objective_target(): """Create a mock objective target for testing.""" target = MagicMock() - target.get_identifier.return_value = {"__type__": "MockTarget", "__module__": "test"} + target.get_identifier.return_value = TargetIdentifier( + class_name="MockTarget", + class_module="test", + class_description="", + identifier_type="instance", + ) return target @@ -223,7 +228,9 @@ async def test_initialize_async_sets_objective_target(self, mock_objective_targe await scenario.initialize_async(objective_target=mock_objective_target) assert scenario._objective_target == mock_objective_target - assert scenario._objective_target_identifier == {"__type__": "MockTarget", "__module__": "test"} + # Verify it's a TargetIdentifier with the expected class_name + assert scenario._objective_target_identifier.class_name == "MockTarget" + assert scenario._objective_target_identifier.class_module == "test" @pytest.mark.asyncio async def test_initialize_async_requires_objective_target(self): diff --git a/tests/unit/scenarios/test_scenario_partial_results.py b/tests/unit/scenarios/test_scenario_partial_results.py index 5c0c67934..1886c3639 100644 --- a/tests/unit/scenarios/test_scenario_partial_results.py +++ b/tests/unit/scenarios/test_scenario_partial_results.py @@ -8,7 +8,7 @@ import pytest from pyrit.executor.attack.core import AttackExecutorResult -from pyrit.identifiers import ScorerIdentifier +from pyrit.identifiers import ScorerIdentifier, TargetIdentifier from pyrit.memory import CentralMemory from pyrit.models import AttackOutcome, AttackResult from pyrit.scenario import DatasetConfiguration, ScenarioResult @@ -29,7 +29,12 @@ def _mock_scorer_id(name: str = "MockScorer") -> ScorerIdentifier: def mock_objective_target(): """Create a mock objective target for testing.""" target = MagicMock() - target.get_identifier.return_value = {"__type__": "MockTarget", "__module__": "test"} + target.get_identifier.return_value = TargetIdentifier( + class_name="MockTarget", + class_module="test", + class_description="", + identifier_type="instance", + ) return target diff --git a/tests/unit/scenarios/test_scenario_retry.py b/tests/unit/scenarios/test_scenario_retry.py index 8ffc8c6de..69bcecef0 100644 --- a/tests/unit/scenarios/test_scenario_retry.py +++ b/tests/unit/scenarios/test_scenario_retry.py @@ -8,7 +8,7 @@ import pytest from pyrit.executor.attack.core import AttackExecutorResult -from pyrit.identifiers import ScorerIdentifier +from pyrit.identifiers import ScorerIdentifier, TargetIdentifier from pyrit.memory import CentralMemory from pyrit.models import AttackOutcome, AttackResult from pyrit.scenario import DatasetConfiguration, ScenarioResult @@ -195,7 +195,12 @@ def mock_atomic_attacks(): def mock_objective_target(): """Create a mock objective target for testing.""" target = MagicMock() - target.get_identifier.return_value = {"__type__": "MockTarget", "__module__": TEST_MODULE} + target.get_identifier.return_value = TargetIdentifier( + class_name="MockTarget", + class_module=TEST_MODULE, + class_description="", + identifier_type="instance", + ) return target diff --git a/tests/unit/score/test_gandalf_scorer.py b/tests/unit/score/test_gandalf_scorer.py index dd4e62115..728bec9c9 100644 --- a/tests/unit/score/test_gandalf_scorer.py +++ b/tests/unit/score/test_gandalf_scorer.py @@ -7,6 +7,7 @@ import pytest +from unit.mocks import get_mock_target_identifier from pyrit.exceptions.exception_classes import PyritException from pyrit.memory.memory_interface import MemoryInterface from pyrit.models import Message, MessagePiece @@ -52,6 +53,7 @@ async def test_gandalf_scorer_score( mocked_post, sqlite_instance: MemoryInterface, level: GandalfLevel, password_correct: bool ): chat_target = MagicMock() + chat_target.get_identifier.return_value = get_mock_target_identifier("MockChatTarget") conversation_id = str(uuid.uuid4()) sqlite_instance.add_message_to_memory(request=generate_request(conversation_id=conversation_id)) @@ -95,6 +97,7 @@ async def test_gandalf_scorer_set_system_prompt( sqlite_instance.add_message_to_memory(request=response) chat_target = MagicMock() + chat_target.get_identifier.return_value = get_mock_target_identifier("MockChatTarget") chat_target.send_prompt_async = AsyncMock(return_value=[response]) scorer = GandalfScorer(chat_target=chat_target, level=level) @@ -119,6 +122,7 @@ async def test_gandalf_scorer_adds_to_memory(mocked_post, level: GandalfLevel, s sqlite_instance.add_message_to_memory(request=response) chat_target = MagicMock() + chat_target.get_identifier.return_value = get_mock_target_identifier("MockChatTarget") chat_target.send_prompt_async = AsyncMock(return_value=[response]) # Mock the requests.post call to return a successful response @@ -139,6 +143,7 @@ async def test_gandalf_scorer_runtime_error_retries(level: GandalfLevel, sqlite_ sqlite_instance.add_message_to_memory(request=response) chat_target = MagicMock() + chat_target.get_identifier.return_value = get_mock_target_identifier("MockChatTarget") chat_target.send_prompt_async = AsyncMock(side_effect=[RuntimeError("Error"), response]) scorer = GandalfScorer(level=level, chat_target=chat_target) diff --git a/tests/unit/score/test_general_float_scale_scorer.py b/tests/unit/score/test_general_float_scale_scorer.py index 0933646d7..86f750726 100644 --- a/tests/unit/score/test_general_float_scale_scorer.py +++ b/tests/unit/score/test_general_float_scale_scorer.py @@ -6,6 +6,7 @@ import pytest +from unit.mocks import get_mock_target_identifier from pyrit.models import Message, MessagePiece from pyrit.score.float_scale.self_ask_general_float_scale_scorer import ( SelfAskGeneralFloatScaleScorer, @@ -31,6 +32,7 @@ def general_float_scorer_response() -> Message: @pytest.mark.asyncio async def test_general_float_scorer_score_async(patch_central_database, general_float_scorer_response: Message): chat_target = MagicMock() + chat_target.get_identifier.return_value = get_mock_target_identifier("MockChatTarget") chat_target.send_prompt_async = AsyncMock(return_value=[general_float_scorer_response]) scorer = SelfAskGeneralFloatScaleScorer( @@ -54,6 +56,7 @@ async def test_general_float_scorer_score_async_with_prompt_f_string( general_float_scorer_response: Message, patch_central_database ): chat_target = MagicMock() + chat_target.get_identifier.return_value = get_mock_target_identifier("MockChatTarget") chat_target.send_prompt_async = AsyncMock(return_value=[general_float_scorer_response]) scorer = SelfAskGeneralFloatScaleScorer( @@ -77,6 +80,7 @@ async def test_general_float_scorer_score_async_with_prompt_f_string( @pytest.mark.asyncio async def test_general_float_scorer_score_async_handles_custom_keys(patch_central_database): chat_target = MagicMock() + chat_target.get_identifier.return_value = get_mock_target_identifier("MockChatTarget") assert chat_target json_response = ( @@ -114,6 +118,7 @@ async def test_general_float_scorer_score_async_handles_custom_keys(patch_centra @pytest.mark.asyncio async def test_general_float_scorer_score_async_min_max_scale(patch_central_database): chat_target = MagicMock() + chat_target.get_identifier.return_value = get_mock_target_identifier("MockChatTarget") json_response = ( dedent( """ @@ -145,6 +150,7 @@ async def test_general_float_scorer_score_async_min_max_scale(patch_central_data def test_general_float_scorer_init_invalid_min_max(): chat_target = MagicMock() + chat_target.get_identifier.return_value = get_mock_target_identifier("MockChatTarget") with pytest.raises(ValueError): SelfAskGeneralFloatScaleScorer( chat_target=chat_target, diff --git a/tests/unit/score/test_general_true_false_scorer.py b/tests/unit/score/test_general_true_false_scorer.py index 85deb92e5..8b963ad68 100644 --- a/tests/unit/score/test_general_true_false_scorer.py +++ b/tests/unit/score/test_general_true_false_scorer.py @@ -6,6 +6,7 @@ import pytest +from unit.mocks import get_mock_target_identifier from pyrit.models import Message, MessagePiece from pyrit.score import SelfAskGeneralTrueFalseScorer @@ -30,6 +31,7 @@ def general_scorer_response() -> Message: @pytest.mark.asyncio async def test_general_scorer_score_async(patch_central_database, general_scorer_response: Message): chat_target = MagicMock() + chat_target.get_identifier.return_value = get_mock_target_identifier("MockChatTarget") chat_target.send_prompt_async = AsyncMock(return_value=[general_scorer_response]) scorer = SelfAskGeneralTrueFalseScorer( @@ -54,6 +56,7 @@ async def test_general_scorer_score_async_with_prompt_f_string( general_scorer_response: Message, patch_central_database ): chat_target = MagicMock() + chat_target.get_identifier.return_value = get_mock_target_identifier("MockChatTarget") chat_target.send_prompt_async = AsyncMock(return_value=[general_scorer_response]) scorer = SelfAskGeneralTrueFalseScorer( @@ -79,6 +82,7 @@ async def test_general_scorer_score_async_with_prompt_f_string( @pytest.mark.asyncio async def test_general_scorer_score_async_handles_custom_keys(patch_central_database): chat_target = MagicMock() + chat_target.get_identifier.return_value = get_mock_target_identifier("MockChatTarget") assert chat_target json_response = ( diff --git a/tests/unit/score/test_scorer.py b/tests/unit/score/test_scorer.py index fcd441aa2..39dd60bfb 100644 --- a/tests/unit/score/test_scorer.py +++ b/tests/unit/score/test_scorer.py @@ -8,6 +8,7 @@ import pytest +from unit.mocks import get_mock_target_identifier from pyrit.exceptions import InvalidJsonException, remove_markdown_json from pyrit.memory import CentralMemory from pyrit.models import Message, MessagePiece, Score @@ -150,6 +151,7 @@ def get_scorer_metrics(self): @pytest.mark.parametrize("bad_json", [BAD_JSON, KEY_ERROR_JSON, KEY_ERROR2_JSON]) async def test_scorer_send_chat_target_async_bad_json_exception_retries(bad_json: str): chat_target = MagicMock(PromptChatTarget) + chat_target.get_identifier.return_value = get_mock_target_identifier("MockChatTarget") bad_json_resp = Message( message_pieces=[MessagePiece(role="assistant", original_value=bad_json, conversation_id="test-convo")] ) @@ -173,6 +175,7 @@ async def test_scorer_send_chat_target_async_bad_json_exception_retries(bad_json @pytest.mark.asyncio async def test_scorer_score_value_with_llm_exception_display_prompt_id(): chat_target = MagicMock(PromptChatTarget) + chat_target.get_identifier.return_value = get_mock_target_identifier("MockChatTarget") chat_target.send_prompt_async = AsyncMock(side_effect=Exception("Test exception")) scorer = MockScorer() @@ -197,6 +200,7 @@ async def test_scorer_score_value_with_llm_use_provided_attack_identifier(good_j message_pieces=[MessagePiece(role="assistant", original_value=good_json, conversation_id="test-convo")] ) chat_target = MagicMock(PromptChatTarget) + chat_target.get_identifier.return_value = get_mock_target_identifier("MockChatTarget") chat_target.send_prompt_async = AsyncMock(return_value=[message]) chat_target.set_system_prompt = MagicMock() @@ -232,6 +236,7 @@ async def test_scorer_score_value_with_llm_does_not_add_score_prompt_id_for_empt message_pieces=[MessagePiece(role="assistant", original_value=good_json, conversation_id="test-convo")] ) chat_target = MagicMock(PromptChatTarget) + chat_target.get_identifier.return_value = get_mock_target_identifier("MockChatTarget") chat_target.send_prompt_async = AsyncMock(return_value=[message]) chat_target.set_system_prompt = MagicMock() @@ -258,6 +263,7 @@ async def test_scorer_score_value_with_llm_does_not_add_score_prompt_id_for_empt @pytest.mark.asyncio async def test_scorer_send_chat_target_async_good_response(good_json): chat_target = MagicMock(PromptChatTarget) + chat_target.get_identifier.return_value = get_mock_target_identifier("MockChatTarget") good_json_resp = Message( message_pieces=[MessagePiece(role="assistant", original_value=good_json, conversation_id="test-convo")] @@ -282,6 +288,7 @@ async def test_scorer_send_chat_target_async_good_response(good_json): @pytest.mark.asyncio async def test_scorer_remove_markdown_json_called(good_json): chat_target = MagicMock(PromptChatTarget) + chat_target.get_identifier.return_value = get_mock_target_identifier("MockChatTarget") good_json_resp = Message( message_pieces=[MessagePiece(role="assistant", original_value=good_json, conversation_id="test-convo")] ) @@ -307,6 +314,7 @@ async def test_scorer_remove_markdown_json_called(good_json): async def test_score_value_with_llm_prepended_text_message_piece_creates_multipiece_message(good_json): """Test that prepended_text_message_piece creates a multi-piece message (text context + main content).""" chat_target = MagicMock(PromptChatTarget) + chat_target.get_identifier.return_value = get_mock_target_identifier("MockChatTarget") good_json_resp = Message( message_pieces=[MessagePiece(role="assistant", original_value=good_json, conversation_id="test-convo")] ) @@ -350,6 +358,7 @@ async def test_score_value_with_llm_prepended_text_message_piece_creates_multipi async def test_score_value_with_llm_no_prepended_text_creates_single_piece_message(good_json): """Test that without prepended_text_message_piece, only a single piece message is created.""" chat_target = MagicMock(PromptChatTarget) + chat_target.get_identifier.return_value = get_mock_target_identifier("MockChatTarget") good_json_resp = Message( message_pieces=[MessagePiece(role="assistant", original_value=good_json, conversation_id="test-convo")] ) @@ -385,6 +394,7 @@ async def test_score_value_with_llm_no_prepended_text_creates_single_piece_messa async def test_score_value_with_llm_prepended_text_works_with_audio(good_json): """Test that prepended_text_message_piece works with audio content (type-independent).""" chat_target = MagicMock(PromptChatTarget) + chat_target.get_identifier.return_value = get_mock_target_identifier("MockChatTarget") good_json_resp = Message( message_pieces=[MessagePiece(role="assistant", original_value=good_json, conversation_id="test-convo")] ) diff --git a/tests/unit/score/test_self_ask_category.py b/tests/unit/score/test_self_ask_category.py index 68038ffac..795e930fc 100644 --- a/tests/unit/score/test_self_ask_category.py +++ b/tests/unit/score/test_self_ask_category.py @@ -6,6 +6,7 @@ import pytest +from unit.mocks import get_mock_target_identifier from pyrit.exceptions.exception_classes import InvalidJsonException from pyrit.memory import CentralMemory from pyrit.memory.memory_interface import MemoryInterface @@ -50,6 +51,7 @@ def scorer_category_response_false() -> Message: def test_category_scorer_set_no_category_found(): chat_target = MagicMock() + chat_target.get_identifier.return_value = get_mock_target_identifier("MockChatTarget") scorer = SelfAskCategoryScorer( chat_target=chat_target, content_classifier_path=ContentClassifierPaths.HARMFUL_CONTENT_CLASSIFIER.value, @@ -63,6 +65,7 @@ def test_category_scorer_set_no_category_found(): @pytest.mark.asyncio async def test_category_scorer_set_system_prompt(scorer_category_response_bullying: Message, patch_central_database): chat_target = MagicMock() + chat_target.get_identifier.return_value = get_mock_target_identifier("MockChatTarget") chat_target.send_prompt_async = AsyncMock(return_value=[scorer_category_response_bullying]) scorer = SelfAskCategoryScorer( @@ -78,6 +81,7 @@ async def test_category_scorer_set_system_prompt(scorer_category_response_bullyi @pytest.mark.asyncio async def test_category_scorer_score(scorer_category_response_bullying: Message, patch_central_database): chat_target = MagicMock() + chat_target.get_identifier.return_value = get_mock_target_identifier("MockChatTarget") chat_target.send_prompt_async = AsyncMock(return_value=[scorer_category_response_bullying]) @@ -100,6 +104,7 @@ async def test_category_scorer_score(scorer_category_response_bullying: Message, @pytest.mark.asyncio async def test_category_scorer_score_false(scorer_category_response_false: Message, patch_central_database): chat_target = MagicMock() + chat_target.get_identifier.return_value = get_mock_target_identifier("MockChatTarget") chat_target.send_prompt_async = AsyncMock(return_value=[scorer_category_response_false]) @@ -122,6 +127,7 @@ async def test_category_scorer_score_false(scorer_category_response_false: Messa async def test_category_scorer_adds_to_memory(scorer_category_response_false: Message, patch_central_database): memory = MagicMock(MemoryInterface) chat_target = MagicMock() + chat_target.get_identifier.return_value = get_mock_target_identifier("MockChatTarget") chat_target.send_prompt_async = AsyncMock(return_value=[scorer_category_response_false]) with patch.object(CentralMemory, "get_memory_instance", return_value=memory): scorer = SelfAskCategoryScorer( @@ -137,6 +143,7 @@ async def test_category_scorer_adds_to_memory(scorer_category_response_false: Me @pytest.mark.asyncio async def test_self_ask_objective_scorer_bad_json_exception_retries(patch_central_database): chat_target = MagicMock() + chat_target.get_identifier.return_value = get_mock_target_identifier("MockChatTarget") bad_json_resp = Message(message_pieces=[MessagePiece(role="assistant", original_value="this is not a json")]) chat_target.send_prompt_async = AsyncMock(return_value=[bad_json_resp]) @@ -155,6 +162,7 @@ async def test_self_ask_objective_scorer_bad_json_exception_retries(patch_centra @pytest.mark.asyncio async def test_self_ask_objective_scorer_json_missing_key_exception_retries(patch_central_database): chat_target = MagicMock() + chat_target.get_identifier.return_value = get_mock_target_identifier("MockChatTarget") json_response = ( dedent( @@ -192,6 +200,7 @@ async def test_score_prompts_batch_async( patch_central_database, ): chat_target = MagicMock() + chat_target.get_identifier.return_value = get_mock_target_identifier("MockChatTarget") chat_target.send_prompt_async = AsyncMock() chat_target._max_requests_per_minute = max_requests_per_minute with patch.object(CentralMemory, "get_memory_instance", return_value=MagicMock()): diff --git a/tests/unit/score/test_self_ask_likert.py b/tests/unit/score/test_self_ask_likert.py index e194f5c2b..f907f7951 100644 --- a/tests/unit/score/test_self_ask_likert.py +++ b/tests/unit/score/test_self_ask_likert.py @@ -6,6 +6,7 @@ import pytest +from unit.mocks import get_mock_target_identifier from pyrit.exceptions.exception_classes import InvalidJsonException from pyrit.memory import CentralMemory, MemoryInterface from pyrit.models import Message, MessagePiece @@ -36,6 +37,7 @@ async def test_likert_scorer_set_system_prompt(scorer_likert_response: Message): memory = MagicMock(MemoryInterface) with patch.object(CentralMemory, "get_memory_instance", return_value=memory): chat_target = MagicMock() + chat_target.get_identifier.return_value = get_mock_target_identifier("MockChatTarget") chat_target.send_prompt_async = AsyncMock(return_value=[scorer_likert_response]) scorer = SelfAskLikertScorer(chat_target=chat_target, likert_scale=LikertScalePaths.CYBER_SCALE) @@ -58,6 +60,7 @@ async def test_likert_scorer_set_system_prompt(scorer_likert_response: Message): async def test_likert_scorer_adds_to_memory(scorer_likert_response: Message): memory = MagicMock(MemoryInterface) chat_target = MagicMock() + chat_target.get_identifier.return_value = get_mock_target_identifier("MockChatTarget") chat_target.send_prompt_async = AsyncMock(return_value=[scorer_likert_response]) with patch.object(CentralMemory, "get_memory_instance", return_value=memory): scorer = SelfAskLikertScorer(chat_target=chat_target, likert_scale=LikertScalePaths.CYBER_SCALE) @@ -70,6 +73,7 @@ async def test_likert_scorer_adds_to_memory(scorer_likert_response: Message): @pytest.mark.asyncio async def test_likert_scorer_score(patch_central_database, scorer_likert_response: Message): chat_target = MagicMock() + chat_target.get_identifier.return_value = get_mock_target_identifier("MockChatTarget") chat_target.send_prompt_async = AsyncMock(return_value=[scorer_likert_response]) @@ -91,6 +95,7 @@ async def test_likert_scorer_score(patch_central_database, scorer_likert_respons @pytest.mark.asyncio async def test_self_ask_scorer_bad_json_exception_retries(): chat_target = MagicMock() + chat_target.get_identifier.return_value = get_mock_target_identifier("MockChatTarget") bad_json_resp = Message(message_pieces=[MessagePiece(role="assistant", original_value="this is not a json")]) chat_target.send_prompt_async = AsyncMock(return_value=[bad_json_resp]) @@ -105,6 +110,7 @@ async def test_self_ask_scorer_bad_json_exception_retries(): @pytest.mark.asyncio async def test_self_ask_likert_scorer_json_missing_key_exception_retries(): chat_target = MagicMock() + chat_target.get_identifier.return_value = get_mock_target_identifier("MockChatTarget") json_response = ( dedent( diff --git a/tests/unit/score/test_self_ask_refusal.py b/tests/unit/score/test_self_ask_refusal.py index 4d99e1263..8e0cbe5aa 100644 --- a/tests/unit/score/test_self_ask_refusal.py +++ b/tests/unit/score/test_self_ask_refusal.py @@ -7,6 +7,7 @@ import pytest +from unit.mocks import get_mock_target_identifier from pyrit.exceptions.exception_classes import InvalidJsonException from pyrit.memory import CentralMemory from pyrit.memory.memory_interface import MemoryInterface @@ -32,6 +33,7 @@ def scorer_true_false_response() -> Message: @pytest.mark.asyncio async def test_refusal_scorer_score(scorer_true_false_response: Message, patch_central_database): chat_target = MagicMock() + chat_target.get_identifier.return_value = get_mock_target_identifier("MockChatTarget") chat_target.send_prompt_async = AsyncMock(return_value=[scorer_true_false_response]) scorer = SelfAskRefusalScorer(chat_target=chat_target) @@ -48,6 +50,7 @@ async def test_refusal_scorer_score(scorer_true_false_response: Message, patch_c @pytest.mark.asyncio async def test_refusal_scorer_set_system_prompt(scorer_true_false_response: Message, patch_central_database): chat_target = MagicMock() + chat_target.get_identifier.return_value = get_mock_target_identifier("MockChatTarget") chat_target.send_prompt_async = AsyncMock(return_value=[scorer_true_false_response]) scorer = SelfAskRefusalScorer(chat_target=chat_target) @@ -59,6 +62,7 @@ async def test_refusal_scorer_set_system_prompt(scorer_true_false_response: Mess @pytest.mark.asyncio async def test_refusal_scorer_no_task(scorer_true_false_response: Message, patch_central_database): chat_target = MagicMock() + chat_target.get_identifier.return_value = get_mock_target_identifier("MockChatTarget") chat_target.send_prompt_async = AsyncMock(return_value=[scorer_true_false_response]) scorer = SelfAskRefusalScorer(chat_target=chat_target) @@ -71,6 +75,7 @@ async def test_refusal_scorer_no_task(scorer_true_false_response: Message, patch @pytest.mark.asyncio async def test_refusal_scorer_with_task(scorer_true_false_response: Message, patch_central_database): chat_target = MagicMock() + chat_target.get_identifier.return_value = get_mock_target_identifier("MockChatTarget") chat_target.send_prompt_async = AsyncMock(return_value=[scorer_true_false_response]) scorer = SelfAskRefusalScorer(chat_target=chat_target) @@ -85,6 +90,7 @@ async def test_refusal_scorer_with_task(scorer_true_false_response: Message, pat @pytest.mark.asyncio async def test_refusal_scorer_image_non_block(scorer_true_false_response: Message, patch_central_database): chat_target = MagicMock() + chat_target.get_identifier.return_value = get_mock_target_identifier("MockChatTarget") chat_target.send_prompt_async = AsyncMock(return_value=[scorer_true_false_response]) scorer = SelfAskRefusalScorer(chat_target=chat_target) @@ -101,6 +107,7 @@ async def test_refusal_scorer_image_non_block(scorer_true_false_response: Messag async def test_refusal_scorer_adds_to_memory(scorer_true_false_response: Message): memory = MagicMock(MemoryInterface) chat_target = MagicMock() + chat_target.get_identifier.return_value = get_mock_target_identifier("MockChatTarget") chat_target.send_prompt_async = AsyncMock(return_value=[scorer_true_false_response]) with patch.object(CentralMemory, "get_memory_instance", return_value=memory): scorer = SelfAskRefusalScorer(chat_target=chat_target) @@ -112,6 +119,7 @@ async def test_refusal_scorer_adds_to_memory(scorer_true_false_response: Message @pytest.mark.asyncio async def test_refusal_scorer_bad_json_exception_retries(patch_central_database): chat_target = MagicMock() + chat_target.get_identifier.return_value = get_mock_target_identifier("MockChatTarget") bad_json_resp = Message(message_pieces=[MessagePiece(role="assistant", original_value="this is not a json")]) chat_target.send_prompt_async = AsyncMock(return_value=[bad_json_resp]) @@ -127,6 +135,7 @@ async def test_refusal_scorer_bad_json_exception_retries(patch_central_database) @pytest.mark.asyncio async def test_self_ask_objective_scorer_bad_json_exception_retries(patch_central_database): chat_target = MagicMock() + chat_target.get_identifier.return_value = get_mock_target_identifier("MockChatTarget") json_response = ( dedent( @@ -155,6 +164,7 @@ async def test_self_ask_objective_scorer_bad_json_exception_retries(patch_centra async def test_score_async_filtered_response(patch_central_database): memory = CentralMemory.get_memory_instance() chat_target = MagicMock() + chat_target.get_identifier.return_value = get_mock_target_identifier("MockChatTarget") scorer = SelfAskRefusalScorer(chat_target=chat_target) request = MessagePiece(role="assistant", original_value="blocked response", response_error="blocked").to_message() diff --git a/tests/unit/score/test_self_ask_scale.py b/tests/unit/score/test_self_ask_scale.py index 07c485056..e5b6d723f 100644 --- a/tests/unit/score/test_self_ask_scale.py +++ b/tests/unit/score/test_self_ask_scale.py @@ -8,6 +8,7 @@ import pytest +from unit.mocks import get_mock_target_identifier from pyrit.identifiers import ScorerIdentifier from pyrit.models import Message, MessagePiece, UnvalidatedScore from pyrit.score import ContentClassifierPaths, SelfAskScaleScorer @@ -39,8 +40,10 @@ def scorer_scale_response() -> Message: @pytest.fixture def scale_scorer(patch_central_database) -> SelfAskScaleScorer: + chat_target = MagicMock() + chat_target.get_identifier.return_value = get_mock_target_identifier("MockChatTarget") return SelfAskScaleScorer( - chat_target=MagicMock(), + chat_target=chat_target, scale_arguments_path=SelfAskScaleScorer.ScalePaths.TREE_OF_ATTACKS_SCALE.value, system_prompt_path=SelfAskScaleScorer.SystemPaths.GENERAL_SYSTEM_PROMPT.value, ) @@ -65,6 +68,7 @@ async def test_scale_scorer_set_system_prompt( patch_central_database, ): chat_target = MagicMock() + chat_target.get_identifier.return_value = get_mock_target_identifier("MockChatTarget") chat_target.send_prompt_async = AsyncMock(return_value=[scorer_scale_response]) scorer = SelfAskScaleScorer( @@ -86,6 +90,7 @@ async def test_scale_scorer_set_system_prompt( def test_scale_scorer_invalid_scale_file_contents(): chat_target = MagicMock() + chat_target.get_identifier.return_value = get_mock_target_identifier("MockChatTarget") # When using a YAML with wrong keys the Scale constructor will raise an exception. with pytest.raises(ValueError, match="Missing key in scale_args:"): SelfAskScaleScorer( @@ -135,6 +140,7 @@ def test_validate_scale_arguments_missing_args_raises_value_error(scale_args, sc @pytest.mark.asyncio async def test_scale_scorer_score(scorer_scale_response: Message, patch_central_database): chat_target = MagicMock() + chat_target.get_identifier.return_value = get_mock_target_identifier("MockChatTarget") chat_target.send_prompt_async = AsyncMock(return_value=[scorer_scale_response]) @@ -161,6 +167,7 @@ async def test_scale_scorer_score(scorer_scale_response: Message, patch_central_ @pytest.mark.asyncio async def test_scale_scorer_score_custom_scale(scorer_scale_response: Message, patch_central_database): chat_target = MagicMock() + chat_target.get_identifier.return_value = get_mock_target_identifier("MockChatTarget") # set a higher score to test the scaling scorer_scale_response.message_pieces[0].original_value = scorer_scale_response.message_pieces[ @@ -197,6 +204,7 @@ async def test_scale_scorer_score_custom_scale(scorer_scale_response: Message, p @pytest.mark.asyncio async def test_scale_scorer_score_calls_send_chat(patch_central_database): chat_target = MagicMock() + chat_target.get_identifier.return_value = get_mock_target_identifier("MockChatTarget") scorer = SelfAskScaleScorer( chat_target=chat_target, diff --git a/tests/unit/score/test_self_ask_true_false.py b/tests/unit/score/test_self_ask_true_false.py index 70747285f..6b381d05d 100644 --- a/tests/unit/score/test_self_ask_true_false.py +++ b/tests/unit/score/test_self_ask_true_false.py @@ -6,6 +6,7 @@ import pytest +from unit.mocks import get_mock_target_identifier from pyrit.exceptions.exception_classes import InvalidJsonException from pyrit.memory.central_memory import CentralMemory from pyrit.memory.memory_interface import MemoryInterface @@ -31,6 +32,7 @@ def scorer_true_false_response() -> Message: @pytest.mark.asyncio async def test_true_false_scorer_score(patch_central_database, scorer_true_false_response: Message): chat_target = MagicMock() + chat_target.get_identifier.return_value = get_mock_target_identifier("MockChatTarget") chat_target.send_prompt_async = AsyncMock(return_value=[scorer_true_false_response]) scorer = SelfAskTrueFalseScorer( @@ -49,6 +51,7 @@ async def test_true_false_scorer_score(patch_central_database, scorer_true_false @pytest.mark.asyncio async def test_true_false_scorer_set_system_prompt(patch_central_database, scorer_true_false_response: Message): chat_target = MagicMock() + chat_target.get_identifier.return_value = get_mock_target_identifier("MockChatTarget") chat_target.send_prompt_async = AsyncMock(return_value=[scorer_true_false_response]) scorer = SelfAskTrueFalseScorer( @@ -68,6 +71,7 @@ async def test_true_false_scorer_set_system_prompt(patch_central_database, score async def test_true_false_scorer_adds_to_memory(scorer_true_false_response: Message): memory = MagicMock(MemoryInterface) chat_target = MagicMock() + chat_target.get_identifier.return_value = get_mock_target_identifier("MockChatTarget") chat_target.send_prompt_async = AsyncMock(return_value=[scorer_true_false_response]) with patch.object(CentralMemory, "get_memory_instance", return_value=memory): scorer = SelfAskTrueFalseScorer( @@ -82,6 +86,7 @@ async def test_true_false_scorer_adds_to_memory(scorer_true_false_response: Mess @pytest.mark.asyncio async def test_self_ask_scorer_bad_json_exception_retries(patch_central_database): chat_target = MagicMock() + chat_target.get_identifier.return_value = get_mock_target_identifier("MockChatTarget") bad_json_resp = Message(message_pieces=[MessagePiece(role="assistant", original_value="this is not a json")]) chat_target.send_prompt_async = AsyncMock(return_value=[bad_json_resp]) @@ -111,6 +116,7 @@ async def test_self_ask_objective_scorer_bad_json_exception_retries(patch_centra ) bad_json_resp = Message(message_pieces=[MessagePiece(role="assistant", original_value=json_response)]) + chat_target.get_identifier.return_value = get_mock_target_identifier("MockChatTarget") chat_target.send_prompt_async = AsyncMock(return_value=[bad_json_resp]) scorer = SelfAskTrueFalseScorer( @@ -127,6 +133,7 @@ async def test_self_ask_objective_scorer_bad_json_exception_retries(patch_centra def test_self_ask_true_false_scorer_identifier_has_system_prompt_template(patch_central_database): """Test that identifier includes system_prompt_template.""" chat_target = MagicMock() + chat_target.get_identifier.return_value = get_mock_target_identifier("MockChatTarget") scorer = SelfAskTrueFalseScorer( chat_target=chat_target, true_false_question_path=TrueFalseQuestionPaths.GROUNDED.value @@ -143,6 +150,7 @@ def test_self_ask_true_false_scorer_identifier_has_system_prompt_template(patch_ def test_self_ask_true_false_get_identifier_type(patch_central_database): """Test that get_identifier returns correct class_name.""" chat_target = MagicMock() + chat_target.get_identifier.return_value = get_mock_target_identifier("MockChatTarget") scorer = SelfAskTrueFalseScorer( chat_target=chat_target, true_false_question_path=TrueFalseQuestionPaths.GROUNDED.value @@ -158,6 +166,7 @@ def test_self_ask_true_false_get_identifier_type(patch_central_database): def test_self_ask_true_false_get_identifier_long_prompt_hashed(patch_central_database): """Test that long system prompts are truncated when serialized via to_dict().""" chat_target = MagicMock() + chat_target.get_identifier.return_value = get_mock_target_identifier("MockChatTarget") scorer = SelfAskTrueFalseScorer( chat_target=chat_target, true_false_question_path=TrueFalseQuestionPaths.GROUNDED.value diff --git a/tests/unit/score/test_video_scorer.py b/tests/unit/score/test_video_scorer.py index acab4b162..0429dbc4a 100644 --- a/tests/unit/score/test_video_scorer.py +++ b/tests/unit/score/test_video_scorer.py @@ -15,7 +15,7 @@ from pyrit.score.scorer_prompt_validator import ScorerPromptValidator from pyrit.score.true_false.true_false_scorer import TrueFalseScorer from pyrit.score.true_false.video_true_false_scorer import VideoTrueFalseScorer -from tests.unit.mocks import get_mock_scorer_identifier +from unit.mocks import get_mock_scorer_identifier def is_opencv_installed(): diff --git a/tests/unit/target/test_crucible_target.py b/tests/unit/target/test_crucible_target.py index a2e71e82e..9469d8e7f 100644 --- a/tests/unit/target/test_crucible_target.py +++ b/tests/unit/target/test_crucible_target.py @@ -20,7 +20,7 @@ def test_crucible_initializes(crucible_target: CrucibleTarget): def test_crucible_sets_endpoint_and_rate_limit(): target = CrucibleTarget(endpoint="https://crucible", api_key="abc", max_requests_per_minute=10) identifier = target.get_identifier() - assert identifier["endpoint"] == "https://crucible" + assert identifier.endpoint == "https://crucible" assert target._max_requests_per_minute == 10 diff --git a/tests/unit/target/test_gandalf_target.py b/tests/unit/target/test_gandalf_target.py index d8e3ffdb9..ffd7a8274 100644 --- a/tests/unit/target/test_gandalf_target.py +++ b/tests/unit/target/test_gandalf_target.py @@ -20,7 +20,7 @@ def test_gandalf_initializes(gandalf_target: GandalfTarget): def test_gandalf_sets_endpoint_and_rate_limit(): target = GandalfTarget(level=GandalfLevel.LEVEL_1, max_requests_per_minute=15) identifier = target.get_identifier() - assert identifier["endpoint"] == "https://gandalf-api.lakera.ai/api/send-message" + assert identifier.endpoint == "https://gandalf-api.lakera.ai/api/send-message" assert target._max_requests_per_minute == 15 diff --git a/tests/unit/target/test_http_target.py b/tests/unit/target/test_http_target.py index 6f5b5842d..5d49702b0 100644 --- a/tests/unit/target/test_http_target.py +++ b/tests/unit/target/test_http_target.py @@ -59,7 +59,7 @@ def test_http_target_sets_endpoint_and_rate_limit(mock_callback_function, sqlite max_requests_per_minute=25, ) identifier = target.get_identifier() - assert identifier["endpoint"] == "https://example.com/" + assert identifier.endpoint == "https://example.com/" assert target._max_requests_per_minute == 25 @@ -67,7 +67,7 @@ def test_http_target_sets_endpoint_and_rate_limit(mock_callback_function, sqlite @patch("httpx.AsyncClient.request") async def test_send_prompt_async(mock_request, mock_http_target, mock_http_response): message = MagicMock() - message.message_pieces = [MagicMock(converted_value="test_prompt")] + message.message_pieces = [MagicMock(converted_value="test_prompt", prompt_target_identifier=None)] mock_request.return_value = mock_http_response response = await mock_http_target.send_prompt_async(message=message) assert len(response) == 1 @@ -113,7 +113,7 @@ async def test_send_prompt_async_client_kwargs(): # Use **httpx_client_kwargs to pass them as keyword arguments http_target = HTTPTarget(http_request=sample_request, **httpx_client_kwargs) message = MagicMock() - message.message_pieces = [MagicMock(converted_value="")] + message.message_pieces = [MagicMock(converted_value="", prompt_target_identifier=None)] mock_response = MagicMock() mock_response.content = b"Response content" mock_request.return_value = mock_response @@ -148,7 +148,7 @@ async def test_send_prompt_regex_parse_async(mock_request, mock_http_target): mock_http_target.callback_function = callback_function message = MagicMock() - message.message_pieces = [MagicMock(converted_value="test_prompt")] + message.message_pieces = [MagicMock(converted_value="test_prompt", prompt_target_identifier=None)] mock_response = MagicMock() mock_response.content = b"Match: 1234" @@ -175,7 +175,7 @@ async def test_send_prompt_async_keeps_original_template(mock_request, mock_http # Send first prompt message = MagicMock() - message.message_pieces = [MagicMock(converted_value="test_prompt")] + message.message_pieces = [MagicMock(converted_value="test_prompt", prompt_target_identifier=None)] response = await mock_http_target.send_prompt_async(message=message) assert len(response) == 1 @@ -193,7 +193,7 @@ async def test_send_prompt_async_keeps_original_template(mock_request, mock_http # Send second prompt second_message = MagicMock() - second_message.message_pieces = [MagicMock(converted_value="second_test_prompt")] + second_message.message_pieces = [MagicMock(converted_value="second_test_prompt", prompt_target_identifier=None)] await mock_http_target.send_prompt_async(message=second_message) # Assert that the original template is still the same @@ -241,7 +241,7 @@ async def test_http_target_with_injected_client(): mock_request.return_value = mock_response message = MagicMock() - message.message_pieces = [MagicMock(converted_value="test_prompt")] + message.message_pieces = [MagicMock(converted_value="test_prompt", prompt_target_identifier=None)] response = await target.send_prompt_async(message=message) diff --git a/tests/unit/target/test_hugging_face_endpoint_target.py b/tests/unit/target/test_hugging_face_endpoint_target.py index 790654204..2bc4de352 100644 --- a/tests/unit/target/test_hugging_face_endpoint_target.py +++ b/tests/unit/target/test_hugging_face_endpoint_target.py @@ -29,7 +29,7 @@ def test_hugging_face_endpoint_sets_endpoint_and_rate_limit(): max_requests_per_minute=30, ) identifier = target.get_identifier() - assert identifier["endpoint"] == "https://api-inference.huggingface.co/models/test-model" + assert identifier.endpoint == "https://api-inference.huggingface.co/models/test-model" assert target._max_requests_per_minute == 30 diff --git a/tests/unit/target/test_huggingface_chat_target.py b/tests/unit/target/test_huggingface_chat_target.py index c100f8d8e..ec1593c79 100644 --- a/tests/unit/target/test_huggingface_chat_target.py +++ b/tests/unit/target/test_huggingface_chat_target.py @@ -329,7 +329,7 @@ async def test_is_json_response_supported(): @pytest.mark.skipif(not is_torch_installed(), reason="torch is not installed") @pytest.mark.asyncio -async def test_hugging_face_chat_sets_endpoint_and_rate_limit(): +async def test_hugging_face_chat_sets_endpoint_and_rate_limit(patch_central_database): target = HuggingFaceChatTarget( model_id="test_model", use_cuda=False, @@ -338,6 +338,6 @@ async def test_hugging_face_chat_sets_endpoint_and_rate_limit(): # Await the background task to prevent warnings await target.load_model_and_tokenizer_task identifier = target.get_identifier() - # HuggingFaceChatTarget doesn't set an endpoint (it's local), so it shouldn't be in identifier - assert "endpoint" not in identifier + # HuggingFaceChatTarget doesn't set an endpoint (it's local), so it should be empty + assert not identifier.endpoint assert target._max_requests_per_minute == 30 diff --git a/tests/unit/target/test_openai_chat_target.py b/tests/unit/target/test_openai_chat_target.py index 7e798c06c..21069883e 100644 --- a/tests/unit/target/test_openai_chat_target.py +++ b/tests/unit/target/test_openai_chat_target.py @@ -1069,16 +1069,18 @@ async def test_construct_message_from_response(target: OpenAIChatTarget, dummy_t def test_get_identifier_uses_model_name_when_no_underlying_model(patch_central_database): """Test that get_identifier uses model_name when underlying_model is not provided.""" - target = OpenAIChatTarget( - model_name="my-deployment", - endpoint="https://mock.azure.com/", - api_key="mock-api-key", - ) + # Clear the environment variable to ensure it doesn't interfere with the test + with patch.dict(os.environ, {"OPENAI_CHAT_UNDERLYING_MODEL": ""}, clear=False): + target = OpenAIChatTarget( + model_name="my-deployment", + endpoint="https://mock.azure.com/", + api_key="mock-api-key", + ) - identifier = target.get_identifier() + identifier = target.get_identifier() - assert identifier["model_name"] == "my-deployment" - assert identifier["__type__"] == "OpenAIChatTarget" + assert identifier.model_name == "my-deployment" + assert identifier.class_name == "OpenAIChatTarget" def test_get_identifier_uses_underlying_model_when_provided_as_param(patch_central_database): @@ -1092,8 +1094,8 @@ def test_get_identifier_uses_underlying_model_when_provided_as_param(patch_centr identifier = target.get_identifier() - assert identifier["model_name"] == "gpt-4o" - assert identifier["__type__"] == "OpenAIChatTarget" + assert identifier.model_name == "gpt-4o" + assert identifier.class_name == "OpenAIChatTarget" def test_get_identifier_uses_underlying_model_from_env_var(patch_central_database): @@ -1107,7 +1109,7 @@ def test_get_identifier_uses_underlying_model_from_env_var(patch_central_databas identifier = target.get_identifier() - assert identifier["model_name"] == "gpt-4o" + assert identifier.model_name == "gpt-4o" def test_underlying_model_param_takes_precedence_over_env_var(patch_central_database): @@ -1122,7 +1124,7 @@ def test_underlying_model_param_takes_precedence_over_env_var(patch_central_data identifier = target.get_identifier() - assert identifier["model_name"] == "gpt-4o-from-param" + assert identifier.model_name == "gpt-4o-from-param" def test_get_identifier_includes_endpoint(patch_central_database): @@ -1135,7 +1137,7 @@ def test_get_identifier_includes_endpoint(patch_central_database): identifier = target.get_identifier() - assert identifier["endpoint"] == "https://mock.azure.com/" + assert identifier.endpoint == "https://mock.azure.com/" def test_get_identifier_includes_temperature_when_set(patch_central_database): @@ -1149,7 +1151,7 @@ def test_get_identifier_includes_temperature_when_set(patch_central_database): identifier = target.get_identifier() - assert identifier["temperature"] == 0.7 + assert identifier.temperature == 0.7 def test_get_identifier_includes_top_p_when_set(patch_central_database): @@ -1163,7 +1165,7 @@ def test_get_identifier_includes_top_p_when_set(patch_central_database): identifier = target.get_identifier() - assert identifier["top_p"] == 0.9 + assert identifier.top_p == 0.9 # ============================================================================ From 4428df9e4f45ed3c4b82d935e5ca9f9991305c8c Mon Sep 17 00:00:00 2001 From: jsong468 Date: Wed, 28 Jan 2026 14:09:12 -0800 Subject: [PATCH 02/10] first_commitv2 --- pyrit/memory/memory_models.py | 4 +++- pyrit/prompt_target/common/prompt_target.py | 2 +- tests/unit/converter/test_denylist_converter.py | 2 +- .../executor/attack/component/test_conversation_manager.py | 3 ++- tests/unit/executor/attack/single_turn/test_prompt_sending.py | 2 +- tests/unit/executor/attack/single_turn/test_role_play.py | 2 +- tests/unit/executor/attack/single_turn/test_skeleton_key.py | 2 +- tests/unit/executor/promptgen/fuzzer/test_fuzzer.py | 3 +-- .../memory_interface/test_interface_scenario_results.py | 2 +- tests/unit/score/test_gandalf_scorer.py | 2 +- tests/unit/score/test_general_float_scale_scorer.py | 2 +- tests/unit/score/test_general_true_false_scorer.py | 2 +- tests/unit/score/test_scorer.py | 2 +- tests/unit/score/test_self_ask_category.py | 2 +- tests/unit/score/test_self_ask_likert.py | 2 +- tests/unit/score/test_self_ask_refusal.py | 2 +- tests/unit/score/test_self_ask_scale.py | 2 +- tests/unit/score/test_self_ask_true_false.py | 2 +- tests/unit/score/test_video_scorer.py | 2 +- 19 files changed, 22 insertions(+), 20 deletions(-) diff --git a/pyrit/memory/memory_models.py b/pyrit/memory/memory_models.py index b653ee20e..2783fa5ca 100644 --- a/pyrit/memory/memory_models.py +++ b/pyrit/memory/memory_models.py @@ -268,7 +268,9 @@ def __str__(self) -> str: """ if self.prompt_target_identifier: # prompt_target_identifier is stored as dict in the database - class_name = self.prompt_target_identifier.get("class_name") or self.prompt_target_identifier.get("__type__", "Unknown") + class_name = self.prompt_target_identifier.get("class_name") or self.prompt_target_identifier.get( + "__type__", "Unknown" + ) return f"{class_name}: {self.role}: {self.converted_value}" return f": {self.role}: {self.converted_value}" diff --git a/pyrit/prompt_target/common/prompt_target.py b/pyrit/prompt_target/common/prompt_target.py index 7095e6c47..019c9cfa5 100644 --- a/pyrit/prompt_target/common/prompt_target.py +++ b/pyrit/prompt_target/common/prompt_target.py @@ -3,7 +3,7 @@ import abc import logging -from typing import Any, Dict, List, Optional +from typing import Any, List, Optional from pyrit.identifiers import Identifiable, TargetIdentifier from pyrit.memory import CentralMemory, MemoryInterface diff --git a/tests/unit/converter/test_denylist_converter.py b/tests/unit/converter/test_denylist_converter.py index 8bbf06c58..c75d8f920 100644 --- a/tests/unit/converter/test_denylist_converter.py +++ b/tests/unit/converter/test_denylist_converter.py @@ -3,7 +3,7 @@ from unittest.mock import AsyncMock, MagicMock import pytest -from unit.mocks import get_mock_target_identifier, MockPromptTarget +from unit.mocks import MockPromptTarget, get_mock_target_identifier from pyrit.models import Message, MessagePiece, SeedPrompt from pyrit.prompt_converter import DenylistConverter diff --git a/tests/unit/executor/attack/component/test_conversation_manager.py b/tests/unit/executor/attack/component/test_conversation_manager.py index 5c7f6c50c..b874169d4 100644 --- a/tests/unit/executor/attack/component/test_conversation_manager.py +++ b/tests/unit/executor/attack/component/test_conversation_manager.py @@ -22,6 +22,7 @@ from unittest.mock import AsyncMock, MagicMock import pytest +from unit.mocks import get_mock_scorer_identifier from pyrit.executor.attack import ConversationManager, ConversationState from pyrit.executor.attack.component import PrependedConversationConfig @@ -37,7 +38,6 @@ from pyrit.models import Message, MessagePiece, Score from pyrit.prompt_normalizer import PromptConverterConfiguration, PromptNormalizer from pyrit.prompt_target import PromptChatTarget, PromptTarget -from unit.mocks import get_mock_scorer_identifier def _mock_target_id(name: str = "MockTarget") -> TargetIdentifier: @@ -49,6 +49,7 @@ def _mock_target_id(name: str = "MockTarget") -> TargetIdentifier: identifier_type="instance", ) + # ============================================================================= # Test Context Class # ============================================================================= diff --git a/tests/unit/executor/attack/single_turn/test_prompt_sending.py b/tests/unit/executor/attack/single_turn/test_prompt_sending.py index b34e1c27d..b55da7cf3 100644 --- a/tests/unit/executor/attack/single_turn/test_prompt_sending.py +++ b/tests/unit/executor/attack/single_turn/test_prompt_sending.py @@ -5,6 +5,7 @@ from unittest.mock import AsyncMock, MagicMock, patch import pytest +from unit.mocks import get_mock_scorer_identifier, get_mock_target_identifier from pyrit.executor.attack import ( AttackConverterConfig, @@ -27,7 +28,6 @@ from pyrit.prompt_normalizer import PromptConverterConfiguration, PromptNormalizer from pyrit.prompt_target import PromptTarget from pyrit.score import Scorer, TrueFalseScorer -from unit.mocks import get_mock_scorer_identifier, get_mock_target_identifier @pytest.fixture diff --git a/tests/unit/executor/attack/single_turn/test_role_play.py b/tests/unit/executor/attack/single_turn/test_role_play.py index c88037a57..98218e3c9 100644 --- a/tests/unit/executor/attack/single_turn/test_role_play.py +++ b/tests/unit/executor/attack/single_turn/test_role_play.py @@ -8,6 +8,7 @@ import pytest import yaml +from unit.mocks import get_mock_scorer_identifier, get_mock_target_identifier from pyrit.executor.attack import ( AttackConverterConfig, @@ -25,7 +26,6 @@ from pyrit.prompt_normalizer import PromptConverterConfiguration from pyrit.prompt_target import PromptChatTarget from pyrit.score import Scorer, TrueFalseScorer -from unit.mocks import get_mock_scorer_identifier, get_mock_target_identifier @pytest.fixture diff --git a/tests/unit/executor/attack/single_turn/test_skeleton_key.py b/tests/unit/executor/attack/single_turn/test_skeleton_key.py index 7560b6be7..41278e9bd 100644 --- a/tests/unit/executor/attack/single_turn/test_skeleton_key.py +++ b/tests/unit/executor/attack/single_turn/test_skeleton_key.py @@ -6,6 +6,7 @@ from unittest.mock import AsyncMock, MagicMock, patch import pytest +from unit.mocks import get_mock_scorer_identifier, get_mock_target_identifier from pyrit.executor.attack import ( AttackConverterConfig, @@ -24,7 +25,6 @@ from pyrit.prompt_normalizer import PromptNormalizer from pyrit.prompt_target import PromptTarget from pyrit.score import TrueFalseScorer -from unit.mocks import get_mock_scorer_identifier, get_mock_target_identifier @pytest.fixture diff --git a/tests/unit/executor/promptgen/fuzzer/test_fuzzer.py b/tests/unit/executor/promptgen/fuzzer/test_fuzzer.py index 9095a9af5..7b78eed2d 100644 --- a/tests/unit/executor/promptgen/fuzzer/test_fuzzer.py +++ b/tests/unit/executor/promptgen/fuzzer/test_fuzzer.py @@ -5,7 +5,7 @@ from unittest.mock import MagicMock, patch import pytest -from unit.mocks import MockPromptTarget +from unit.mocks import MockPromptTarget, get_mock_scorer_identifier from pyrit.common.path import DATASETS_PATH from pyrit.datasets import TextJailBreak @@ -29,7 +29,6 @@ Scorer, TrueFalseScorer, ) -from unit.mocks import get_mock_scorer_identifier @pytest.mark.usefixtures("patch_central_database") diff --git a/tests/unit/memory/memory_interface/test_interface_scenario_results.py b/tests/unit/memory/memory_interface/test_interface_scenario_results.py index 970980ea8..4fd86a413 100644 --- a/tests/unit/memory/memory_interface/test_interface_scenario_results.py +++ b/tests/unit/memory/memory_interface/test_interface_scenario_results.py @@ -5,6 +5,7 @@ from typing import Optional import pytest +from unit.mocks import get_mock_scorer_identifier from pyrit.identifiers import ScorerIdentifier from pyrit.memory import MemoryInterface @@ -14,7 +15,6 @@ ScenarioIdentifier, ScenarioResult, ) -from unit.mocks import get_mock_scorer_identifier @pytest.fixture diff --git a/tests/unit/score/test_gandalf_scorer.py b/tests/unit/score/test_gandalf_scorer.py index 728bec9c9..abad3a306 100644 --- a/tests/unit/score/test_gandalf_scorer.py +++ b/tests/unit/score/test_gandalf_scorer.py @@ -6,8 +6,8 @@ from unittest.mock import AsyncMock, MagicMock, patch import pytest - from unit.mocks import get_mock_target_identifier + from pyrit.exceptions.exception_classes import PyritException from pyrit.memory.memory_interface import MemoryInterface from pyrit.models import Message, MessagePiece diff --git a/tests/unit/score/test_general_float_scale_scorer.py b/tests/unit/score/test_general_float_scale_scorer.py index 86f750726..7ee85404d 100644 --- a/tests/unit/score/test_general_float_scale_scorer.py +++ b/tests/unit/score/test_general_float_scale_scorer.py @@ -5,8 +5,8 @@ from unittest.mock import AsyncMock, MagicMock import pytest - from unit.mocks import get_mock_target_identifier + from pyrit.models import Message, MessagePiece from pyrit.score.float_scale.self_ask_general_float_scale_scorer import ( SelfAskGeneralFloatScaleScorer, diff --git a/tests/unit/score/test_general_true_false_scorer.py b/tests/unit/score/test_general_true_false_scorer.py index 8b963ad68..49e4b9839 100644 --- a/tests/unit/score/test_general_true_false_scorer.py +++ b/tests/unit/score/test_general_true_false_scorer.py @@ -5,8 +5,8 @@ from unittest.mock import AsyncMock, MagicMock import pytest - from unit.mocks import get_mock_target_identifier + from pyrit.models import Message, MessagePiece from pyrit.score import SelfAskGeneralTrueFalseScorer diff --git a/tests/unit/score/test_scorer.py b/tests/unit/score/test_scorer.py index 39dd60bfb..c2475d907 100644 --- a/tests/unit/score/test_scorer.py +++ b/tests/unit/score/test_scorer.py @@ -7,8 +7,8 @@ from unittest.mock import AsyncMock, MagicMock, patch import pytest - from unit.mocks import get_mock_target_identifier + from pyrit.exceptions import InvalidJsonException, remove_markdown_json from pyrit.memory import CentralMemory from pyrit.models import Message, MessagePiece, Score diff --git a/tests/unit/score/test_self_ask_category.py b/tests/unit/score/test_self_ask_category.py index 795e930fc..390788d1e 100644 --- a/tests/unit/score/test_self_ask_category.py +++ b/tests/unit/score/test_self_ask_category.py @@ -5,8 +5,8 @@ from unittest.mock import AsyncMock, MagicMock, patch import pytest - from unit.mocks import get_mock_target_identifier + from pyrit.exceptions.exception_classes import InvalidJsonException from pyrit.memory import CentralMemory from pyrit.memory.memory_interface import MemoryInterface diff --git a/tests/unit/score/test_self_ask_likert.py b/tests/unit/score/test_self_ask_likert.py index f907f7951..6e0db8188 100644 --- a/tests/unit/score/test_self_ask_likert.py +++ b/tests/unit/score/test_self_ask_likert.py @@ -5,8 +5,8 @@ from unittest.mock import AsyncMock, MagicMock, patch import pytest - from unit.mocks import get_mock_target_identifier + from pyrit.exceptions.exception_classes import InvalidJsonException from pyrit.memory import CentralMemory, MemoryInterface from pyrit.models import Message, MessagePiece diff --git a/tests/unit/score/test_self_ask_refusal.py b/tests/unit/score/test_self_ask_refusal.py index 8e0cbe5aa..75d3fbc65 100644 --- a/tests/unit/score/test_self_ask_refusal.py +++ b/tests/unit/score/test_self_ask_refusal.py @@ -6,8 +6,8 @@ from unittest.mock import AsyncMock, MagicMock, patch import pytest - from unit.mocks import get_mock_target_identifier + from pyrit.exceptions.exception_classes import InvalidJsonException from pyrit.memory import CentralMemory from pyrit.memory.memory_interface import MemoryInterface diff --git a/tests/unit/score/test_self_ask_scale.py b/tests/unit/score/test_self_ask_scale.py index e5b6d723f..2a437e09c 100644 --- a/tests/unit/score/test_self_ask_scale.py +++ b/tests/unit/score/test_self_ask_scale.py @@ -7,8 +7,8 @@ from unittest.mock import AsyncMock, MagicMock import pytest - from unit.mocks import get_mock_target_identifier + from pyrit.identifiers import ScorerIdentifier from pyrit.models import Message, MessagePiece, UnvalidatedScore from pyrit.score import ContentClassifierPaths, SelfAskScaleScorer diff --git a/tests/unit/score/test_self_ask_true_false.py b/tests/unit/score/test_self_ask_true_false.py index 6b381d05d..80cc21ec2 100644 --- a/tests/unit/score/test_self_ask_true_false.py +++ b/tests/unit/score/test_self_ask_true_false.py @@ -5,8 +5,8 @@ from unittest.mock import AsyncMock, MagicMock, patch import pytest - from unit.mocks import get_mock_target_identifier + from pyrit.exceptions.exception_classes import InvalidJsonException from pyrit.memory.central_memory import CentralMemory from pyrit.memory.memory_interface import MemoryInterface diff --git a/tests/unit/score/test_video_scorer.py b/tests/unit/score/test_video_scorer.py index 0429dbc4a..95138df82 100644 --- a/tests/unit/score/test_video_scorer.py +++ b/tests/unit/score/test_video_scorer.py @@ -8,6 +8,7 @@ import numpy as np import pytest +from unit.mocks import get_mock_scorer_identifier from pyrit.models import MessagePiece, Score from pyrit.score.float_scale.float_scale_scorer import FloatScaleScorer @@ -15,7 +16,6 @@ from pyrit.score.scorer_prompt_validator import ScorerPromptValidator from pyrit.score.true_false.true_false_scorer import TrueFalseScorer from pyrit.score.true_false.video_true_false_scorer import VideoTrueFalseScorer -from unit.mocks import get_mock_scorer_identifier def is_opencv_installed(): From 6e279a96a40b4ba2aca476b4796bf3ddfeade7e9 Mon Sep 17 00:00:00 2001 From: jsong468 Date: Wed, 28 Jan 2026 16:07:09 -0800 Subject: [PATCH 03/10] tests --- pyrit/identifiers/target_identifier.py | 3 + .../azure_blob_storage_target.py | 13 +- pyrit/prompt_target/azure_ml_chat_target.py | 13 +- pyrit/prompt_target/common/prompt_target.py | 38 +- pyrit/prompt_target/gandalf_target.py | 12 +- .../prompt_target/http_target/http_target.py | 13 +- .../hugging_face/hugging_face_chat_target.py | 17 +- .../hugging_face_endpoint_target.py | 12 +- .../openai/openai_chat_target.py | 12 +- .../openai/openai_completion_target.py | 12 +- .../openai/openai_image_target.py | 12 +- .../openai/openai_realtime_target.py | 12 +- .../openai/openai_response_target.py | 12 +- .../prompt_target/openai/openai_tts_target.py | 12 +- .../openai/openai_video_target.py | 12 +- .../playwright_copilot_target.py | 12 +- pyrit/prompt_target/prompt_shield_target.py | 12 +- .../prompt_target/websocket_copilot_target.py | 14 + .../identifiers/test_target_identifier.py | 550 ++++++++++++++++++ 19 files changed, 722 insertions(+), 71 deletions(-) create mode 100644 tests/unit/identifiers/test_target_identifier.py diff --git a/pyrit/identifiers/target_identifier.py b/pyrit/identifiers/target_identifier.py index 9d4bc9330..700110a39 100644 --- a/pyrit/identifiers/target_identifier.py +++ b/pyrit/identifiers/target_identifier.py @@ -39,6 +39,9 @@ class TargetIdentifier(Identifier): top_p: Optional[float] = None """The top_p parameter for generation.""" + max_requests_per_minute: Optional[int] = None + """Maximum number of requests per minute.""" + target_specific_params: Optional[Dict[str, Any]] = None """Additional target-specific parameters.""" diff --git a/pyrit/prompt_target/azure_blob_storage_target.py b/pyrit/prompt_target/azure_blob_storage_target.py index 505c463d5..91b1f21da 100644 --- a/pyrit/prompt_target/azure_blob_storage_target.py +++ b/pyrit/prompt_target/azure_blob_storage_target.py @@ -12,6 +12,7 @@ from pyrit.auth import AzureStorageAuth from pyrit.common import default_values +from pyrit.identifiers import TargetIdentifier from pyrit.models import Message, construct_response_from_request from pyrit.prompt_target.common.prompt_target import PromptTarget from pyrit.prompt_target.common.utils import limit_requests_per_minute @@ -79,10 +80,16 @@ def __init__( super().__init__(endpoint=self._container_url, max_requests_per_minute=max_requests_per_minute) - def _build_identifier(self) -> None: - """Build the identifier with Azure Blob Storage-specific parameters.""" - self._set_identifier( + def _build_identifier(self) -> TargetIdentifier: + """ + Build the identifier with Azure Blob Storage-specific parameters. + + Returns: + TargetIdentifier: The identifier for this target instance. + """ + return self._create_identifier( target_specific_params={ + "container_url": self._container_url, "blob_content_type": self._blob_content_type, }, ) diff --git a/pyrit/prompt_target/azure_ml_chat_target.py b/pyrit/prompt_target/azure_ml_chat_target.py index 2aed246a3..bc0ba056d 100644 --- a/pyrit/prompt_target/azure_ml_chat_target.py +++ b/pyrit/prompt_target/azure_ml_chat_target.py @@ -13,6 +13,7 @@ handle_bad_request_exception, pyrit_target_retry, ) +from pyrit.identifiers import TargetIdentifier from pyrit.message_normalizer import ChatMessageNormalizer, MessageListNormalizer from pyrit.models import ( Message, @@ -103,14 +104,20 @@ def __init__( self._repetition_penalty = repetition_penalty self._extra_parameters = param_kwargs - def _build_identifier(self) -> None: - """Build the identifier with Azure ML-specific parameters.""" - self._set_identifier( + def _build_identifier(self) -> TargetIdentifier: + """ + Build the identifier with Azure ML-specific parameters. + + Returns: + TargetIdentifier: The identifier for this target instance. + """ + return self._create_identifier( temperature=self._temperature, top_p=self._top_p, target_specific_params={ "max_new_tokens": self._max_new_tokens, "repetition_penalty": self._repetition_penalty, + "message_normalizer": self.message_normalizer.__class__.__name__, }, ) diff --git a/pyrit/prompt_target/common/prompt_target.py b/pyrit/prompt_target/common/prompt_target.py index 019c9cfa5..f82d37733 100644 --- a/pyrit/prompt_target/common/prompt_target.py +++ b/pyrit/prompt_target/common/prompt_target.py @@ -93,13 +93,13 @@ def dispose_db_engine(self) -> None: """ self._memory.dispose_engine() - def _set_identifier( + def _create_identifier( self, *, temperature: Optional[float] = None, top_p: Optional[float] = None, target_specific_params: Optional[dict[str, Any]] = None, - ) -> None: + ) -> TargetIdentifier: """ Construct the target identifier. @@ -111,6 +111,9 @@ def _set_identifier( top_p (Optional[float]): The top_p parameter for generation. Defaults to None. target_specific_params (Optional[dict[str, Any]]): Additional target-specific parameters that should be included in the identifier. Defaults to None. + + Returns: + TargetIdentifier: The identifier for this prompt target. """ # Determine the model name to use model_name = "" @@ -119,7 +122,7 @@ def _set_identifier( elif self._model_name: model_name = self._model_name - self._identifier = TargetIdentifier( + return TargetIdentifier( class_name=self.__class__.__name__, class_module=self.__class__.__module__, class_description=" ".join(self.__class__.__doc__.split()) if self.__class__.__doc__ else "", @@ -128,38 +131,21 @@ def _set_identifier( model_name=model_name, temperature=temperature, top_p=top_p, + max_requests_per_minute=self._max_requests_per_minute, target_specific_params=target_specific_params, ) - def _build_identifier(self) -> None: + def _build_identifier(self) -> TargetIdentifier: """ Build the identifier for this target. - Subclasses should override this method to call _set_identifier() with + Subclasses can override this method to call _create_identifier() with their specific parameters (temperature, top_p, target_specific_params). - The base implementation calls _set_identifier() with no parameters, + The base implementation calls _create_identifier() with no parameters, which works for targets that don't have model-specific settings. - """ - self._set_identifier() - - def get_identifier(self) -> TargetIdentifier: - """ - Get the target identifier. Built lazily on first access. Returns: - TargetIdentifier: The identifier containing all configuration parameters. - - Note: - If `self._underlying_model` is specified (via instantiation or environment - variable), it is used as the "model_name". Otherwise, `self._model_name` - (which is often the deployment name in Azure) is used. - - For storage in memory/database, call `.to_dict()` on the returned - identifier to get a dictionary suitable for JSON serialization. + TargetIdentifier: The identifier for this prompt target. """ - if self._identifier is None: - self._build_identifier() - if self._identifier is None: - raise RuntimeError("_build_identifier must set _identifier") - return self._identifier + return self._create_identifier() diff --git a/pyrit/prompt_target/gandalf_target.py b/pyrit/prompt_target/gandalf_target.py index 2a18e3919..e823926f3 100644 --- a/pyrit/prompt_target/gandalf_target.py +++ b/pyrit/prompt_target/gandalf_target.py @@ -7,6 +7,7 @@ from typing import Optional from pyrit.common import net_utility +from pyrit.identifiers import TargetIdentifier from pyrit.models import Message, construct_response_from_request from pyrit.prompt_target.common.prompt_target import PromptTarget from pyrit.prompt_target.common.utils import limit_requests_per_minute @@ -57,9 +58,14 @@ def __init__( self._defender = level.value - def _build_identifier(self) -> None: - """Build the identifier with Gandalf-specific parameters.""" - self._set_identifier( + def _build_identifier(self) -> TargetIdentifier: + """ + Build the identifier with Gandalf-specific parameters. + + Returns: + TargetIdentifier: The identifier for this target instance. + """ + return self._create_identifier( target_specific_params={ "level": self._defender, }, diff --git a/pyrit/prompt_target/http_target/http_target.py b/pyrit/prompt_target/http_target/http_target.py index e7361ad61..17acfe881 100644 --- a/pyrit/prompt_target/http_target/http_target.py +++ b/pyrit/prompt_target/http_target/http_target.py @@ -9,6 +9,7 @@ import httpx +from pyrit.identifiers import TargetIdentifier from pyrit.models import ( Message, MessagePiece, @@ -82,12 +83,18 @@ def __init__( if client and httpx_client_kwargs: raise ValueError("Cannot provide both a pre-configured client and additional httpx client kwargs.") - def _build_identifier(self) -> None: - """Build the identifier with HTTP target-specific parameters.""" - self._set_identifier( + def _build_identifier(self) -> TargetIdentifier: + """ + Build the identifier with HTTP target-specific parameters. + + Returns: + TargetIdentifier: The identifier for this target instance. + """ + return self._create_identifier( target_specific_params={ "use_tls": self.use_tls, "prompt_regex_string": self.prompt_regex_string, + "callback_function": getattr(self.callback_function, "__name__", None), }, ) diff --git a/pyrit/prompt_target/hugging_face/hugging_face_chat_target.py b/pyrit/prompt_target/hugging_face/hugging_face_chat_target.py index 950a4def8..5fe1104a8 100644 --- a/pyrit/prompt_target/hugging_face/hugging_face_chat_target.py +++ b/pyrit/prompt_target/hugging_face/hugging_face_chat_target.py @@ -16,6 +16,7 @@ from pyrit.common import default_values from pyrit.common.download_hf_model import download_specific_files from pyrit.exceptions import EmptyResponseException, pyrit_target_retry +from pyrit.identifiers import TargetIdentifier from pyrit.models import Message, construct_response_from_request from pyrit.prompt_target.common.prompt_chat_target import PromptChatTarget from pyrit.prompt_target.common.utils import limit_requests_per_minute @@ -135,15 +136,25 @@ def __init__( self.load_model_and_tokenizer_task = asyncio.create_task(self.load_model_and_tokenizer()) - def _build_identifier(self) -> None: - """Build the identifier with HuggingFace chat-specific parameters.""" - self._set_identifier( + def _build_identifier(self) -> TargetIdentifier: + """ + Build the identifier with HuggingFace chat-specific parameters. + + Returns: + TargetIdentifier: The identifier for this target instance. + """ + return self._create_identifier( temperature=self._temperature, top_p=self._top_p, target_specific_params={ "max_new_tokens": self.max_new_tokens, "skip_special_tokens": self.skip_special_tokens, "use_cuda": self.use_cuda, + "tensor_format": self.tensor_format, + "trust_remote_code": self.trust_remote_code, + "device_map": self.device_map, + "torch_dtype": str(self.torch_dtype) if self.torch_dtype else None, + "attn_implementation": self.attn_implementation, }, ) diff --git a/pyrit/prompt_target/hugging_face/hugging_face_endpoint_target.py b/pyrit/prompt_target/hugging_face/hugging_face_endpoint_target.py index f4aa7cd55..2407d6918 100644 --- a/pyrit/prompt_target/hugging_face/hugging_face_endpoint_target.py +++ b/pyrit/prompt_target/hugging_face/hugging_face_endpoint_target.py @@ -5,6 +5,7 @@ from typing import Optional from pyrit.common.net_utility import make_request_and_raise_if_error_async +from pyrit.identifiers import TargetIdentifier from pyrit.models import Message, construct_response_from_request from pyrit.prompt_target.common.prompt_target import PromptTarget from pyrit.prompt_target.common.utils import limit_requests_per_minute, validate_temperature, validate_top_p @@ -61,9 +62,14 @@ def __init__( self._temperature = temperature self._top_p = top_p - def _build_identifier(self) -> None: - """Build the identifier with HuggingFace endpoint-specific parameters.""" - self._set_identifier( + def _build_identifier(self) -> TargetIdentifier: + """ + Build the identifier with HuggingFace endpoint-specific parameters. + + Returns: + TargetIdentifier: The identifier for this target instance. + """ + return self._create_identifier( temperature=self._temperature, top_p=self._top_p, target_specific_params={ diff --git a/pyrit/prompt_target/openai/openai_chat_target.py b/pyrit/prompt_target/openai/openai_chat_target.py index 3e4a9c16b..87ffa26f4 100644 --- a/pyrit/prompt_target/openai/openai_chat_target.py +++ b/pyrit/prompt_target/openai/openai_chat_target.py @@ -12,6 +12,7 @@ PyritException, pyrit_target_retry, ) +from pyrit.identifiers import TargetIdentifier from pyrit.models import ( ChatMessage, DataTypeSerializer, @@ -163,9 +164,14 @@ def __init__( self._extra_body_parameters = extra_body_parameters - def _build_identifier(self) -> None: - """Build the identifier with OpenAI chat-specific parameters.""" - self._set_identifier( + def _build_identifier(self) -> TargetIdentifier: + """ + Build the identifier with OpenAI chat-specific parameters. + + Returns: + TargetIdentifier: The identifier for this target instance. + """ + return self._create_identifier( temperature=self._temperature, top_p=self._top_p, target_specific_params={ diff --git a/pyrit/prompt_target/openai/openai_completion_target.py b/pyrit/prompt_target/openai/openai_completion_target.py index 376b82bb2..00800d72e 100644 --- a/pyrit/prompt_target/openai/openai_completion_target.py +++ b/pyrit/prompt_target/openai/openai_completion_target.py @@ -7,6 +7,7 @@ from pyrit.exceptions.exception_classes import ( pyrit_target_retry, ) +from pyrit.identifiers import TargetIdentifier from pyrit.models import Message, construct_response_from_request from pyrit.prompt_target.common.utils import limit_requests_per_minute from pyrit.prompt_target.openai.openai_target import OpenAITarget @@ -72,9 +73,14 @@ def __init__( self._presence_penalty = presence_penalty self._n = n - def _build_identifier(self) -> None: - """Build the identifier with OpenAI completion-specific parameters.""" - self._set_identifier( + def _build_identifier(self) -> TargetIdentifier: + """ + Build the identifier with OpenAI completion-specific parameters. + + Returns: + TargetIdentifier: The identifier for this target instance. + """ + return self._create_identifier( temperature=self._temperature, top_p=self._top_p, target_specific_params={ diff --git a/pyrit/prompt_target/openai/openai_image_target.py b/pyrit/prompt_target/openai/openai_image_target.py index 774555790..9aca0a3e0 100644 --- a/pyrit/prompt_target/openai/openai_image_target.py +++ b/pyrit/prompt_target/openai/openai_image_target.py @@ -10,6 +10,7 @@ EmptyResponseException, pyrit_target_retry, ) +from pyrit.identifiers import TargetIdentifier from pyrit.models import ( Message, construct_response_from_request, @@ -85,9 +86,14 @@ def _get_provider_examples(self) -> dict[str, str]: "api.openai.com": "https://api.openai.com/v1", } - def _build_identifier(self) -> None: - """Build the identifier with image generation-specific parameters.""" - self._set_identifier( + def _build_identifier(self) -> TargetIdentifier: + """ + Build the identifier with image generation-specific parameters. + + Returns: + TargetIdentifier: The identifier for this target instance. + """ + return self._create_identifier( target_specific_params={ "image_size": self.image_size, "quality": self.quality, diff --git a/pyrit/prompt_target/openai/openai_realtime_target.py b/pyrit/prompt_target/openai/openai_realtime_target.py index 17c89f8c8..9e5ef778d 100644 --- a/pyrit/prompt_target/openai/openai_realtime_target.py +++ b/pyrit/prompt_target/openai/openai_realtime_target.py @@ -15,6 +15,7 @@ pyrit_target_retry, ) from pyrit.exceptions.exception_classes import ServerErrorException +from pyrit.identifiers import TargetIdentifier from pyrit.models import ( Message, construct_response_from_request, @@ -118,9 +119,14 @@ def _get_provider_examples(self) -> dict[str, str]: "api.openai.com": "wss://api.openai.com/v1", } - def _build_identifier(self) -> None: - """Build the identifier with Realtime API-specific parameters.""" - self._set_identifier( + def _build_identifier(self) -> TargetIdentifier: + """ + Build the identifier with Realtime API-specific parameters. + + Returns: + TargetIdentifier: The identifier for this target instance. + """ + return self._create_identifier( target_specific_params={ "voice": self.voice, }, diff --git a/pyrit/prompt_target/openai/openai_response_target.py b/pyrit/prompt_target/openai/openai_response_target.py index 6cefa23e3..bc709fa4b 100644 --- a/pyrit/prompt_target/openai/openai_response_target.py +++ b/pyrit/prompt_target/openai/openai_response_target.py @@ -21,6 +21,7 @@ PyritException, pyrit_target_retry, ) +from pyrit.identifiers import TargetIdentifier from pyrit.models import ( Message, MessagePiece, @@ -156,9 +157,14 @@ def __init__( logger.debug("Detected grammar tool: %s", tool_name) self._grammar_name = tool_name - def _build_identifier(self) -> None: - """Build the identifier with OpenAI response-specific parameters.""" - self._set_identifier( + def _build_identifier(self) -> TargetIdentifier: + """ + Build the identifier with OpenAI response-specific parameters. + + Returns: + TargetIdentifier: The identifier for this target instance. + """ + return self._create_identifier( temperature=self._temperature, top_p=self._top_p, target_specific_params={ diff --git a/pyrit/prompt_target/openai/openai_tts_target.py b/pyrit/prompt_target/openai/openai_tts_target.py index 6d7f27d07..94194d78b 100644 --- a/pyrit/prompt_target/openai/openai_tts_target.py +++ b/pyrit/prompt_target/openai/openai_tts_target.py @@ -7,6 +7,7 @@ from pyrit.exceptions import ( pyrit_target_retry, ) +from pyrit.identifiers import TargetIdentifier from pyrit.models import ( Message, construct_response_from_request, @@ -81,9 +82,14 @@ def _get_provider_examples(self) -> dict[str, str]: "api.openai.com": "https://api.openai.com/v1", } - def _build_identifier(self) -> None: - """Build the identifier with TTS-specific parameters.""" - self._set_identifier( + def _build_identifier(self) -> TargetIdentifier: + """ + Build the identifier with TTS-specific parameters. + + Returns: + TargetIdentifier: The identifier for this target instance. + """ + return self._create_identifier( target_specific_params={ "voice": self._voice, "response_format": self._response_format, diff --git a/pyrit/prompt_target/openai/openai_video_target.py b/pyrit/prompt_target/openai/openai_video_target.py index 8b3ba4eb5..f6915c027 100644 --- a/pyrit/prompt_target/openai/openai_video_target.py +++ b/pyrit/prompt_target/openai/openai_video_target.py @@ -7,6 +7,7 @@ from pyrit.exceptions import ( pyrit_target_retry, ) +from pyrit.identifiers import TargetIdentifier from pyrit.models import ( Message, MessagePiece, @@ -95,9 +96,14 @@ def _get_provider_examples(self) -> dict[str, str]: "api.openai.com": "https://api.openai.com/v1", } - def _build_identifier(self) -> None: - """Build the identifier with video generation-specific parameters.""" - self._set_identifier( + def _build_identifier(self) -> TargetIdentifier: + """ + Build the identifier with video generation-specific parameters. + + Returns: + TargetIdentifier: The identifier for this target instance. + """ + return self._create_identifier( target_specific_params={ "resolution": self._size, "n_seconds": self._n_seconds, diff --git a/pyrit/prompt_target/playwright_copilot_target.py b/pyrit/prompt_target/playwright_copilot_target.py index aa6a7327c..c9a791601 100644 --- a/pyrit/prompt_target/playwright_copilot_target.py +++ b/pyrit/prompt_target/playwright_copilot_target.py @@ -8,6 +8,7 @@ from enum import Enum from typing import TYPE_CHECKING, Any, List, Tuple, Union +from pyrit.identifiers import TargetIdentifier from pyrit.models import ( Message, MessagePiece, @@ -128,9 +129,14 @@ def __init__(self, *, page: "Page", copilot_type: CopilotType = CopilotType.CONS if page and self.M365_URL_IDENTIFIER not in page.url and copilot_type == CopilotType.M365: raise ValueError("The provided page URL does not indicate M365 Copilot, but the type is set to m365.") - def _build_identifier(self) -> None: - """Build the identifier with Copilot-specific parameters.""" - self._set_identifier( + def _build_identifier(self) -> TargetIdentifier: + """ + Build the identifier with Copilot-specific parameters. + + Returns: + TargetIdentifier: The identifier for this target instance. + """ + return self._create_identifier( target_specific_params={ "copilot_type": self._type.value, }, diff --git a/pyrit/prompt_target/prompt_shield_target.py b/pyrit/prompt_target/prompt_shield_target.py index e64ae7576..b2da771ad 100644 --- a/pyrit/prompt_target/prompt_shield_target.py +++ b/pyrit/prompt_target/prompt_shield_target.py @@ -6,6 +6,7 @@ from typing import Any, Callable, Literal, Optional, Sequence from pyrit.common import default_values, net_utility +from pyrit.identifiers import TargetIdentifier from pyrit.models import ( Message, MessagePiece, @@ -93,9 +94,14 @@ def __init__( self._force_entry_field: PromptShieldEntryField = field - def _build_identifier(self) -> None: - """Build the identifier with Prompt Shield-specific parameters.""" - self._set_identifier( + def _build_identifier(self) -> TargetIdentifier: + """ + Build the identifier with Prompt Shield-specific parameters. + + Returns: + TargetIdentifier: The identifier for this target instance. + """ + return self._create_identifier( target_specific_params={ "api_version": self._api_version, "force_entry_field": self._force_entry_field if self._force_entry_field else None, diff --git a/pyrit/prompt_target/websocket_copilot_target.py b/pyrit/prompt_target/websocket_copilot_target.py index 103c7efc1..da9f16432 100644 --- a/pyrit/prompt_target/websocket_copilot_target.py +++ b/pyrit/prompt_target/websocket_copilot_target.py @@ -16,6 +16,7 @@ EmptyResponseException, pyrit_target_retry, ) +from pyrit.identifiers.target_identifier import TargetIdentifier from pyrit.models import Message, construct_response_from_request from pyrit.prompt_target import PromptTarget, limit_requests_per_minute @@ -115,6 +116,19 @@ def __init__( model_name=model_name, ) + def _build_identifier(self) -> TargetIdentifier: + """ + Build the identifier with WebSocketCopilot-specific parameters. + + Returns: + TargetIdentifier: The identifier for this target instance. + """ + return self._create_identifier( + target_specific_params={ + "response_timeout_seconds": self._response_timeout_seconds, + }, + ) + @staticmethod def _dict_to_websocket(data: dict[str, Any]) -> str: """ diff --git a/tests/unit/identifiers/test_target_identifier.py b/tests/unit/identifiers/test_target_identifier.py new file mode 100644 index 000000000..3621c1dcb --- /dev/null +++ b/tests/unit/identifiers/test_target_identifier.py @@ -0,0 +1,550 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Tests for TargetIdentifier-specific functionality. + +Note: Base Identifier functionality (hash computation, to_dict/from_dict basics, +frozen/hashable properties) is tested via ScorerIdentifier in test_scorer_identifier.py. +These tests focus on target-specific fields and behaviors. +""" + +import pytest + +from pyrit.identifiers import TargetIdentifier + + +class TestTargetIdentifierBasic: + """Test basic TargetIdentifier functionality.""" + + def test_target_identifier_creation_minimal(self): + """Test creating a TargetIdentifier with only required fields.""" + identifier = TargetIdentifier( + class_name="TestTarget", + class_module="pyrit.prompt_target.test_target", + class_description="A test target", + identifier_type="instance", + ) + + assert identifier.class_name == "TestTarget" + assert identifier.class_module == "pyrit.prompt_target.test_target" + assert identifier.endpoint == "" + assert identifier.model_name == "" + assert identifier.temperature is None + assert identifier.top_p is None + assert identifier.max_requests_per_minute is None + assert identifier.target_specific_params is None + assert identifier.hash is not None + assert len(identifier.hash) == 64 # SHA256 hex digest length + + def test_target_identifier_unique_name_auto_computed(self): + """Test that unique_name is auto-computed from class_name and hash.""" + identifier = TargetIdentifier( + class_name="OpenAIChatTarget", + class_module="pyrit.prompt_target.openai.openai_chat_target", + class_description="OpenAI chat target", + identifier_type="instance", + ) + + # unique_name format: {snake_case_class_name}::{hash[:8]} + assert identifier.unique_name.startswith("open_ai_chat_target::") + assert len(identifier.unique_name.split("::")[1]) == 8 + assert identifier.unique_name == f"open_ai_chat_target::{identifier.hash[:8]}" + + def test_target_identifier_creation_all_fields(self): + """Test creating a TargetIdentifier with all fields.""" + identifier = TargetIdentifier( + class_name="OpenAIChatTarget", + class_module="pyrit.prompt_target.openai.openai_chat_target", + class_description="OpenAI chat target", + identifier_type="instance", + endpoint="https://api.openai.com/v1", + model_name="gpt-4o", + temperature=0.7, + top_p=0.9, + max_requests_per_minute=100, + target_specific_params={"max_tokens": 1000, "headers": {}}, + ) + + assert identifier.endpoint == "https://api.openai.com/v1" + assert identifier.model_name == "gpt-4o" + assert identifier.temperature == 0.7 + assert identifier.top_p == 0.9 + assert identifier.max_requests_per_minute == 100 + assert identifier.target_specific_params["max_tokens"] == 1000 + + +class TestTargetIdentifierSpecificFields: + """Test TargetIdentifier-specific fields: endpoint, model_name, temperature, top_p, target_specific_params.""" + + def test_endpoint_stored_correctly(self): + """Test that endpoint is stored correctly.""" + identifier = TargetIdentifier( + class_name="TestTarget", + class_module="pyrit.prompt_target.test_target", + class_description="A test target", + identifier_type="instance", + endpoint="https://example.com/api", + ) + + assert identifier.endpoint == "https://example.com/api" + + def test_model_name_stored_correctly(self): + """Test that model_name is stored correctly.""" + identifier = TargetIdentifier( + class_name="TestTarget", + class_module="pyrit.prompt_target.test_target", + class_description="A test target", + identifier_type="instance", + model_name="gpt-4o-mini", + ) + + assert identifier.model_name == "gpt-4o-mini" + + def test_temperature_and_top_p_stored_correctly(self): + """Test that temperature and top_p are stored correctly.""" + identifier = TargetIdentifier( + class_name="TestTarget", + class_module="pyrit.prompt_target.test_target", + class_description="A test target", + identifier_type="instance", + temperature=0.5, + top_p=0.95, + ) + + assert identifier.temperature == 0.5 + assert identifier.top_p == 0.95 + + def test_target_specific_params_stored_correctly(self): + """Test that target_specific_params are stored correctly.""" + identifier = TargetIdentifier( + class_name="TestTarget", + class_module="pyrit.prompt_target.test_target", + class_description="A test target", + identifier_type="instance", + target_specific_params={ + "max_tokens": 2000, + "frequency_penalty": 0.5, + "presence_penalty": 0.3, + }, + ) + + assert identifier.target_specific_params["max_tokens"] == 2000 + assert identifier.target_specific_params["frequency_penalty"] == 0.5 + assert identifier.target_specific_params["presence_penalty"] == 0.3 + + def test_max_requests_per_minute_stored_correctly(self): + """Test that max_requests_per_minute is stored correctly.""" + identifier = TargetIdentifier( + class_name="TestTarget", + class_module="pyrit.prompt_target.test_target", + class_description="A test target", + identifier_type="instance", + max_requests_per_minute=60, + ) + + assert identifier.max_requests_per_minute == 60 + + +class TestTargetIdentifierHash: + """Test hash computation for TargetIdentifier.""" + + def test_hash_deterministic(self): + """Test that hash is the same for identical configurations.""" + identifier1 = TargetIdentifier( + class_name="TestTarget", + class_module="pyrit.prompt_target.test_target", + class_description="A test target", + identifier_type="instance", + endpoint="https://api.example.com", + model_name="test-model", + ) + identifier2 = TargetIdentifier( + class_name="TestTarget", + class_module="pyrit.prompt_target.test_target", + class_description="A test target", + identifier_type="instance", + endpoint="https://api.example.com", + model_name="test-model", + ) + + assert identifier1.hash == identifier2.hash + assert len(identifier1.hash) == 64 # SHA256 hex digest length + + def test_hash_different_for_different_endpoints(self): + """Test that different endpoints produce different hashes.""" + base_args = { + "class_name": "TestTarget", + "class_module": "pyrit.prompt_target.test_target", + "class_description": "A test target", + "identifier_type": "instance", + } + + identifier1 = TargetIdentifier(endpoint="https://api1.example.com", **base_args) + identifier2 = TargetIdentifier(endpoint="https://api2.example.com", **base_args) + + assert identifier1.hash != identifier2.hash + + def test_hash_different_for_different_model_names(self): + """Test that different model names produce different hashes.""" + base_args = { + "class_name": "TestTarget", + "class_module": "pyrit.prompt_target.test_target", + "class_description": "A test target", + "identifier_type": "instance", + } + + identifier1 = TargetIdentifier(model_name="gpt-4o", **base_args) + identifier2 = TargetIdentifier(model_name="gpt-4o-mini", **base_args) + + assert identifier1.hash != identifier2.hash + + def test_hash_different_for_different_temperature(self): + """Test that different temperature values produce different hashes.""" + base_args = { + "class_name": "TestTarget", + "class_module": "pyrit.prompt_target.test_target", + "class_description": "A test target", + "identifier_type": "instance", + } + + identifier1 = TargetIdentifier(temperature=0.7, **base_args) + identifier2 = TargetIdentifier(temperature=0.9, **base_args) + + assert identifier1.hash != identifier2.hash + + def test_hash_different_for_different_top_p(self): + """Test that different top_p values produce different hashes.""" + base_args = { + "class_name": "TestTarget", + "class_module": "pyrit.prompt_target.test_target", + "class_description": "A test target", + "identifier_type": "instance", + } + + identifier1 = TargetIdentifier(top_p=0.9, **base_args) + identifier2 = TargetIdentifier(top_p=0.95, **base_args) + + assert identifier1.hash != identifier2.hash + + def test_hash_different_for_different_target_specific_params(self): + """Test that different target_specific_params produce different hashes.""" + base_args = { + "class_name": "TestTarget", + "class_module": "pyrit.prompt_target.test_target", + "class_description": "A test target", + "identifier_type": "instance", + } + + identifier1 = TargetIdentifier(target_specific_params={"max_tokens": 100}, **base_args) + identifier2 = TargetIdentifier(target_specific_params={"max_tokens": 200}, **base_args) + + assert identifier1.hash != identifier2.hash + + +class TestTargetIdentifierToDict: + """Test to_dict method for TargetIdentifier.""" + + def test_to_dict_basic(self): + """Test basic to_dict output.""" + identifier = TargetIdentifier( + class_name="TestTarget", + class_module="pyrit.prompt_target.test_target", + class_description="A test target", + identifier_type="instance", + ) + + result = identifier.to_dict() + + assert result["class_name"] == "TestTarget" + assert result["class_module"] == "pyrit.prompt_target.test_target" + assert result["hash"] == identifier.hash + assert result["unique_name"] == identifier.unique_name + # class_description and identifier_type should be excluded + assert "class_description" not in result + assert "identifier_type" not in result + + def test_to_dict_includes_endpoint_and_model_name(self): + """Test that endpoint and model_name are included in to_dict.""" + identifier = TargetIdentifier( + class_name="OpenAIChatTarget", + class_module="pyrit.prompt_target.openai.openai_chat_target", + class_description="OpenAI chat target", + identifier_type="instance", + endpoint="https://api.openai.com/v1", + model_name="gpt-4o", + ) + + result = identifier.to_dict() + + assert result["endpoint"] == "https://api.openai.com/v1" + assert result["model_name"] == "gpt-4o" + + def test_to_dict_includes_temperature_and_top_p_when_set(self): + """Test that temperature and top_p are included when set.""" + identifier = TargetIdentifier( + class_name="TestTarget", + class_module="pyrit.prompt_target.test_target", + class_description="A test target", + identifier_type="instance", + temperature=0.7, + top_p=0.9, + ) + + result = identifier.to_dict() + + assert result["temperature"] == 0.7 + assert result["top_p"] == 0.9 + + def test_to_dict_excludes_none_values(self): + """Test that None values are excluded from to_dict.""" + identifier = TargetIdentifier( + class_name="TestTarget", + class_module="pyrit.prompt_target.test_target", + class_description="A test target", + identifier_type="instance", + # temperature, top_p, max_requests_per_minute, target_specific_params are None + ) + + result = identifier.to_dict() + + assert "temperature" not in result + assert "top_p" not in result + assert "max_requests_per_minute" not in result + assert "target_specific_params" not in result + + def test_to_dict_includes_target_specific_params(self): + """Test that target_specific_params are included in to_dict.""" + identifier = TargetIdentifier( + class_name="TestTarget", + class_module="pyrit.prompt_target.test_target", + class_description="A test target", + identifier_type="instance", + target_specific_params={"max_tokens": 1000, "seed": 42}, + ) + + result = identifier.to_dict() + + assert result["target_specific_params"] == {"max_tokens": 1000, "seed": 42} + + +class TestTargetIdentifierFromDict: + """Test from_dict method for TargetIdentifier.""" + + def test_from_dict_basic(self): + """Test creating TargetIdentifier from a basic dict.""" + data = { + "class_name": "TestTarget", + "class_module": "pyrit.prompt_target.test_target", + "class_description": "A test target", + "identifier_type": "instance", + } + + identifier = TargetIdentifier.from_dict(data) + + assert identifier.class_name == "TestTarget" + # unique_name is auto-computed + assert identifier.unique_name.startswith("test_target::") + + def test_from_dict_with_all_target_fields(self): + """Test creating TargetIdentifier from dict with all target-specific fields.""" + data = { + "class_name": "OpenAIChatTarget", + "class_module": "pyrit.prompt_target.openai.openai_chat_target", + "class_description": "OpenAI chat target", + "identifier_type": "instance", + "endpoint": "https://api.openai.com/v1", + "model_name": "gpt-4o", + "temperature": 0.7, + "top_p": 0.9, + "max_requests_per_minute": 60, + "target_specific_params": {"max_tokens": 1000}, + } + + identifier = TargetIdentifier.from_dict(data) + + assert identifier.endpoint == "https://api.openai.com/v1" + assert identifier.model_name == "gpt-4o" + assert identifier.temperature == 0.7 + assert identifier.top_p == 0.9 + assert identifier.max_requests_per_minute == 60 + assert identifier.target_specific_params["max_tokens"] == 1000 + + def test_from_dict_handles_legacy_type_key(self): + """Test that from_dict handles legacy '__type__' key.""" + data = { + "__type__": "TestTarget", # Legacy key + "class_module": "pyrit.prompt_target.test_target", + "class_description": "A test target", + "identifier_type": "instance", + } + + identifier = TargetIdentifier.from_dict(data) + + assert identifier.class_name == "TestTarget" + + def test_from_dict_handles_deprecated_type_key(self): + """Test that from_dict handles deprecated 'type' key with warning.""" + data = { + "type": "TestTarget", # Deprecated key + "class_module": "pyrit.prompt_target.test_target", + "class_description": "A test target", + "identifier_type": "instance", + } + + with pytest.warns(DeprecationWarning, match="'type' key in Identifier dict is deprecated"): + identifier = TargetIdentifier.from_dict(data) + + assert identifier.class_name == "TestTarget" + + def test_from_dict_ignores_unknown_fields(self): + """Test that from_dict ignores fields not in the dataclass.""" + data = { + "class_name": "TestTarget", + "class_module": "pyrit.prompt_target.test_target", + "class_description": "A test target", + "identifier_type": "instance", + "unknown_field": "should be ignored", + "hash": "abc123stored_hash_preserved", + "unique_name": "stored_name_ignored_because_recomputed", + } + + identifier = TargetIdentifier.from_dict(data) + + assert identifier.class_name == "TestTarget" + # hash is preserved from dict (not recomputed) to handle truncated fields + assert identifier.hash == "abc123stored_hash_preserved" + # unique_name is recomputed from hash + assert identifier.unique_name == f"test_target::{identifier.hash[:8]}" + + def test_from_dict_roundtrip(self): + """Test that to_dict -> from_dict roundtrip works.""" + original = TargetIdentifier( + class_name="OpenAIChatTarget", + class_module="pyrit.prompt_target.openai.openai_chat_target", + class_description="OpenAI chat target", + identifier_type="instance", + endpoint="https://api.openai.com/v1", + model_name="gpt-4o", + temperature=0.7, + target_specific_params={"max_tokens": 500}, + ) + + storage_dict = original.to_dict() + # Add back the excluded fields for reconstruction + storage_dict["class_description"] = "OpenAI chat target" + storage_dict["identifier_type"] = "instance" + + reconstructed = TargetIdentifier.from_dict(storage_dict) + + assert reconstructed.class_name == original.class_name + assert reconstructed.endpoint == original.endpoint + assert reconstructed.model_name == original.model_name + assert reconstructed.temperature == original.temperature + assert reconstructed.target_specific_params == original.target_specific_params + # Hash should match since config is the same + assert reconstructed.hash == original.hash + + def test_from_dict_provides_defaults_for_missing_fields(self): + """Test that from_dict provides defaults for missing optional fields.""" + data = { + "class_name": "LegacyTarget", + "class_module": "pyrit.prompt_target.legacy", + "class_description": "A legacy target", + "identifier_type": "instance", + # Missing endpoint, model_name, temperature, top_p, etc. + } + + identifier = TargetIdentifier.from_dict(data) + + assert identifier.endpoint == "" + assert identifier.model_name == "" + assert identifier.temperature is None + assert identifier.top_p is None + assert identifier.max_requests_per_minute is None + assert identifier.target_specific_params is None + + +class TestTargetIdentifierFrozen: + """Test that TargetIdentifier is immutable (frozen).""" + + def test_cannot_modify_fields(self): + """Test that attempting to modify fields raises an error.""" + identifier = TargetIdentifier( + class_name="TestTarget", + class_module="pyrit.prompt_target.test_target", + class_description="A test target", + identifier_type="instance", + endpoint="https://api.example.com", + ) + + with pytest.raises(AttributeError): + identifier.class_name = "ModifiedTarget" + + with pytest.raises(AttributeError): + identifier.endpoint = "https://modified.example.com" + + with pytest.raises(AttributeError): + identifier.temperature = 0.5 + + def test_can_use_as_dict_key(self): + """Test that frozen identifier can be used as dict key (hashable).""" + identifier = TargetIdentifier( + class_name="TestTarget", + class_module="pyrit.prompt_target.test_target", + class_description="A test target", + identifier_type="instance", + ) + + # Should be hashable and usable as dict key + d = {identifier: "value"} + assert d[identifier] == "value" + + +class TestTargetIdentifierNormalize: + """Test the normalize class method for TargetIdentifier.""" + + def test_normalize_returns_target_identifier_unchanged(self): + """Test that normalize returns a TargetIdentifier unchanged.""" + original = TargetIdentifier( + class_name="TestTarget", + class_module="pyrit.prompt_target.test_target", + class_description="A test target", + identifier_type="instance", + endpoint="https://api.example.com", + ) + + result = TargetIdentifier.normalize(original) + + assert result is original + assert result.endpoint == "https://api.example.com" + + def test_normalize_converts_dict_to_target_identifier(self): + """Test that normalize converts a dict to TargetIdentifier with deprecation warning.""" + data = { + "class_name": "TestTarget", + "class_module": "pyrit.prompt_target.test_target", + "class_description": "A test target", + "identifier_type": "instance", + "endpoint": "https://api.example.com", + "model_name": "gpt-4o", + } + + with pytest.warns(DeprecationWarning, match="dict for TargetIdentifier is deprecated"): + result = TargetIdentifier.normalize(data) + + assert isinstance(result, TargetIdentifier) + assert result.class_name == "TestTarget" + assert result.endpoint == "https://api.example.com" + assert result.model_name == "gpt-4o" + + def test_normalize_raises_for_invalid_type(self): + """Test that normalize raises TypeError for invalid input types.""" + with pytest.raises(TypeError, match="Expected TargetIdentifier or dict"): + TargetIdentifier.normalize("invalid") + + with pytest.raises(TypeError, match="Expected TargetIdentifier or dict"): + TargetIdentifier.normalize(123) + + with pytest.raises(TypeError, match="Expected TargetIdentifier or dict"): + TargetIdentifier.normalize(["list", "of", "values"]) From f12377046b2674f2068ed0d4426f6aea47d05716 Mon Sep 17 00:00:00 2001 From: jsong468 Date: Wed, 28 Jan 2026 16:12:37 -0800 Subject: [PATCH 04/10] deprecation 0.14.0 --- pyrit/identifiers/identifier.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyrit/identifiers/identifier.py b/pyrit/identifiers/identifier.py index 3c90d2c89..3d698e717 100644 --- a/pyrit/identifiers/identifier.py +++ b/pyrit/identifiers/identifier.py @@ -208,7 +208,7 @@ def normalize(cls: Type[T], value: T | dict[str, Any]) -> T: print_deprecation_message( old_item=f"dict for {cls.__name__}", new_item=cls.__name__, - removed_in="0.13.0", + removed_in="0.14.0", ) return cls.from_dict(value) From 08719aef2d05fc739f489016144dad2e8c228331 Mon Sep 17 00:00:00 2001 From: jsong468 Date: Wed, 28 Jan 2026 16:20:40 -0800 Subject: [PATCH 05/10] use normalize method --- pyrit/memory/memory_models.py | 2 +- pyrit/models/message_piece.py | 17 ++++------------- 2 files changed, 5 insertions(+), 14 deletions(-) diff --git a/pyrit/memory/memory_models.py b/pyrit/memory/memory_models.py index 7bdec26a2..ad8c69e66 100644 --- a/pyrit/memory/memory_models.py +++ b/pyrit/memory/memory_models.py @@ -210,7 +210,7 @@ def __init__(self, *, entry: MessagePiece): self.converter_identifiers = [conv.to_dict() for conv in entry.converter_identifiers] # Normalize prompt_target_identifier and convert to dict for JSON serialization self.prompt_target_identifier = ( - TargetIdentifier.normalize(entry.prompt_target_identifier).to_dict() + entry.prompt_target_identifier.to_dict() if entry.prompt_target_identifier else {} ) diff --git a/pyrit/models/message_piece.py b/pyrit/models/message_piece.py index 5daa325bc..cccca68d0 100644 --- a/pyrit/models/message_piece.py +++ b/pyrit/models/message_piece.py @@ -106,19 +106,10 @@ def __init__( self.labels = labels or {} self.prompt_metadata = prompt_metadata or {} - # Handle converter_identifiers: convert dicts to ConverterIdentifier with deprecation warning - self.converter_identifiers: List[ConverterIdentifier] = [] - if converter_identifiers: - for conv_id in converter_identifiers: - if isinstance(conv_id, dict): - print_deprecation_message( - old_item="dict for converter_identifiers", - new_item="ConverterIdentifier", - removed_in="0.14.0", - ) - self.converter_identifiers.append(ConverterIdentifier.from_dict(conv_id)) - else: - self.converter_identifiers.append(conv_id) + # Handle converter_identifiers: normalize to ConverterIdentifier (handles dict with deprecation warning) + self.converter_identifiers: List[ConverterIdentifier] = [ + ConverterIdentifier.normalize(conv_id) for conv_id in converter_identifiers + ] if converter_identifiers else [] # Handle prompt_target_identifier: normalize to TargetIdentifier (handles dict with deprecation warning) self.prompt_target_identifier: Optional[TargetIdentifier] = ( From 6bc02c6e7d6319116a30eff0094701f38c126ec4 Mon Sep 17 00:00:00 2001 From: jsong468 Date: Wed, 28 Jan 2026 16:26:29 -0800 Subject: [PATCH 06/10] scorer_registry change --- pyrit/registry/instance_registries/scorer_registry.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/pyrit/registry/instance_registries/scorer_registry.py b/pyrit/registry/instance_registries/scorer_registry.py index 7585c4378..d9fa4f6c5 100644 --- a/pyrit/registry/instance_registries/scorer_registry.py +++ b/pyrit/registry/instance_registries/scorer_registry.py @@ -65,10 +65,7 @@ def register_instance( (e.g., SelfAskRefusalScorer -> self_ask_refusal_abc123). """ if name is None: - base_name = class_name_to_snake_case(scorer.__class__.__name__, suffix="Scorer") - # Append identifier hash if available for uniqueness - identifier_hash = scorer.get_identifier().hash[:8] - name = f"{base_name}_{identifier_hash}" + name = scorer.get_identifier().unique_name self.register(scorer, name=name) logger.debug(f"Registered scorer instance: {name} ({scorer.__class__.__name__})") From ad464624539f5ce08bed8f1026b893a505c8138f Mon Sep 17 00:00:00 2001 From: jsong468 Date: Wed, 28 Jan 2026 16:34:20 -0800 Subject: [PATCH 07/10] fix pre-commit and import --- pyrit/memory/memory_models.py | 4 +--- pyrit/models/message_piece.py | 8 +++++--- pyrit/prompt_target/common/prompt_target.py | 2 +- pyrit/prompt_target/websocket_copilot_target.py | 2 +- pyrit/registry/instance_registries/scorer_registry.py | 1 - 5 files changed, 8 insertions(+), 9 deletions(-) diff --git a/pyrit/memory/memory_models.py b/pyrit/memory/memory_models.py index ad8c69e66..dcec433ed 100644 --- a/pyrit/memory/memory_models.py +++ b/pyrit/memory/memory_models.py @@ -210,9 +210,7 @@ def __init__(self, *, entry: MessagePiece): self.converter_identifiers = [conv.to_dict() for conv in entry.converter_identifiers] # Normalize prompt_target_identifier and convert to dict for JSON serialization self.prompt_target_identifier = ( - entry.prompt_target_identifier.to_dict() - if entry.prompt_target_identifier - else {} + entry.prompt_target_identifier.to_dict() if entry.prompt_target_identifier else {} ) self.attack_identifier = entry.attack_identifier diff --git a/pyrit/models/message_piece.py b/pyrit/models/message_piece.py index cccca68d0..460529475 100644 --- a/pyrit/models/message_piece.py +++ b/pyrit/models/message_piece.py @@ -107,9 +107,11 @@ def __init__( self.prompt_metadata = prompt_metadata or {} # Handle converter_identifiers: normalize to ConverterIdentifier (handles dict with deprecation warning) - self.converter_identifiers: List[ConverterIdentifier] = [ - ConverterIdentifier.normalize(conv_id) for conv_id in converter_identifiers - ] if converter_identifiers else [] + self.converter_identifiers: List[ConverterIdentifier] = ( + [ConverterIdentifier.normalize(conv_id) for conv_id in converter_identifiers] + if converter_identifiers + else [] + ) # Handle prompt_target_identifier: normalize to TargetIdentifier (handles dict with deprecation warning) self.prompt_target_identifier: Optional[TargetIdentifier] = ( diff --git a/pyrit/prompt_target/common/prompt_target.py b/pyrit/prompt_target/common/prompt_target.py index f82d37733..8cd80f47d 100644 --- a/pyrit/prompt_target/common/prompt_target.py +++ b/pyrit/prompt_target/common/prompt_target.py @@ -111,7 +111,7 @@ def _create_identifier( top_p (Optional[float]): The top_p parameter for generation. Defaults to None. target_specific_params (Optional[dict[str, Any]]): Additional target-specific parameters that should be included in the identifier. Defaults to None. - + Returns: TargetIdentifier: The identifier for this prompt target. """ diff --git a/pyrit/prompt_target/websocket_copilot_target.py b/pyrit/prompt_target/websocket_copilot_target.py index da9f16432..4cabc58c3 100644 --- a/pyrit/prompt_target/websocket_copilot_target.py +++ b/pyrit/prompt_target/websocket_copilot_target.py @@ -16,7 +16,7 @@ EmptyResponseException, pyrit_target_retry, ) -from pyrit.identifiers.target_identifier import TargetIdentifier +from pyrit.identifiers import TargetIdentifier from pyrit.models import Message, construct_response_from_request from pyrit.prompt_target import PromptTarget, limit_requests_per_minute diff --git a/pyrit/registry/instance_registries/scorer_registry.py b/pyrit/registry/instance_registries/scorer_registry.py index d9fa4f6c5..9b5e5f59f 100644 --- a/pyrit/registry/instance_registries/scorer_registry.py +++ b/pyrit/registry/instance_registries/scorer_registry.py @@ -13,7 +13,6 @@ from typing import TYPE_CHECKING, Optional from pyrit.identifiers import ScorerIdentifier -from pyrit.identifiers.class_name_utils import class_name_to_snake_case from pyrit.registry.instance_registries.base_instance_registry import ( BaseInstanceRegistry, ) From b8b39114835c62ffafe6e19b3e64f4e241ce5adc Mon Sep 17 00:00:00 2001 From: jsong468 Date: Wed, 28 Jan 2026 16:46:29 -0800 Subject: [PATCH 08/10] fix converter unit tests --- pyrit/identifiers/target_identifier.py | 8 -------- pyrit/prompt_converter/prompt_converter.py | 12 +++++++----- tests/unit/models/test_message_piece.py | 2 +- 3 files changed, 8 insertions(+), 14 deletions(-) diff --git a/pyrit/identifiers/target_identifier.py b/pyrit/identifiers/target_identifier.py index 700110a39..f08ad8709 100644 --- a/pyrit/identifiers/target_identifier.py +++ b/pyrit/identifiers/target_identifier.py @@ -17,14 +17,6 @@ class TargetIdentifier(Identifier): This frozen dataclass extends Identifier with target-specific fields. It provides a stable, hashable identifier for prompt targets that can be used for scorer evaluation, registry tracking, and memory storage. - - Attributes: - endpoint: The target endpoint URL, if applicable. - model_name: The model or deployment name. Uses underlying_model if specified, - otherwise falls back to the deployment name. - temperature: The temperature parameter for generation, if applicable. - top_p: The top_p parameter for generation, if applicable. - target_specific_params: Additional target-specific parameters. """ endpoint: str = "" diff --git a/pyrit/prompt_converter/prompt_converter.py b/pyrit/prompt_converter/prompt_converter.py index f7e0a6a2c..756d3972e 100644 --- a/pyrit/prompt_converter/prompt_converter.py +++ b/pyrit/prompt_converter/prompt_converter.py @@ -209,11 +209,13 @@ def _create_identifier( target_info: Optional[Dict[str, Any]] = None if converter_target: target_id = converter_target.get_identifier() - # Extract standard fields for converter identification - target_info = {} - for key in ["__type__", "model_name", "temperature", "top_p"]: - if key in target_id: - target_info[key] = target_id[key] + # Extract standard fields for converter identification using attribute access + target_info = { + "class_name": target_id.class_name, + "model_name": target_id.model_name, + "temperature": target_id.temperature, + "top_p": target_id.top_p, + } return ConverterIdentifier( class_name=self.__class__.__name__, diff --git a/tests/unit/models/test_message_piece.py b/tests/unit/models/test_message_piece.py index c4bed0817..78ecb6008 100644 --- a/tests/unit/models/test_message_piece.py +++ b/tests/unit/models/test_message_piece.py @@ -781,7 +781,7 @@ def test_message_piece_scorer_identifier_dict_backward_compatibility(): # Check that a deprecation warning was issued assert len(w) == 1 assert "deprecated" in str(w[0].message).lower() - assert "0.13.0" in str(w[0].message) + assert "0.14.0" in str(w[0].message) # Check that scorer_identifier is now a ScorerIdentifier assert isinstance(entry.scorer_identifier, ScorerIdentifier) From 30fc2a1d518b0f420eba12e0170776148352e0b7 Mon Sep 17 00:00:00 2001 From: jsong468 Date: Thu, 29 Jan 2026 15:35:06 -0800 Subject: [PATCH 09/10] unique_name --- pyrit/identifiers/target_identifier.py | 32 ++++++++++++ pyrit/memory/memory_models.py | 8 ++- .../identifiers/test_target_identifier.py | 52 ++++++++++++++++--- 3 files changed, 84 insertions(+), 8 deletions(-) diff --git a/pyrit/identifiers/target_identifier.py b/pyrit/identifiers/target_identifier.py index f08ad8709..753c8b51b 100644 --- a/pyrit/identifiers/target_identifier.py +++ b/pyrit/identifiers/target_identifier.py @@ -5,6 +5,7 @@ from dataclasses import dataclass from typing import Any, Dict, Optional, Type, cast +from urllib.parse import urlparse from pyrit.identifiers.identifier import Identifier @@ -37,6 +38,37 @@ class TargetIdentifier(Identifier): target_specific_params: Optional[Dict[str, Any]] = None """Additional target-specific parameters.""" + def __post_init__(self) -> None: + """ + Compute derived fields with target-specific unique_name format. + + Overrides the base Identifier to include model_name and endpoint in unique_name. + Format: {snake_name}::{model_name}::{endpoint_host}::{hash[:8]} + Only includes model_name and endpoint if they have values. + """ + # Call parent to set up snake_class_name and hash + super().__post_init__() + + # Build unique_name with model_name and endpoint if available + parts = [self.snake_class_name] + + if self.model_name: + parts.append(self.model_name) + + if self.endpoint: + # Simplify endpoint to just the host for readability + try: + parsed = urlparse(self.endpoint) + host = parsed.netloc or self.endpoint + parts.append(host) + except Exception: + # Fallback: truncate if parsing fails + parts.append(self.endpoint[:20] if len(self.endpoint) > 20 else self.endpoint) + + parts.append(self.hash[:8]) + + object.__setattr__(self, "unique_name", "::".join(parts)) + @classmethod def from_dict(cls: Type["TargetIdentifier"], data: dict[str, Any]) -> "TargetIdentifier": """ diff --git a/pyrit/memory/memory_models.py b/pyrit/memory/memory_models.py index a66c2b9e7..bc8b9f5f3 100644 --- a/pyrit/memory/memory_models.py +++ b/pyrit/memory/memory_models.py @@ -250,6 +250,12 @@ def get_message_piece(self) -> MessagePiece: ConverterIdentifier.from_dict({**c, "pyrit_version": stored_version}) for c in self.converter_identifiers ] + + # Reconstruct TargetIdentifier with the stored pyrit_version + target_id: Optional[TargetIdentifier] = None + if self.prompt_target_identifier: + target_id = TargetIdentifier.from_dict({**self.prompt_target_identifier, "pyrit_version": stored_version}) + message_piece = MessagePiece( role=self.role, original_value=self.original_value, @@ -263,7 +269,7 @@ def get_message_piece(self) -> MessagePiece: prompt_metadata=self.prompt_metadata, targeted_harm_categories=self.targeted_harm_categories, converter_identifiers=converter_ids, - prompt_target_identifier=self.prompt_target_identifier, + prompt_target_identifier=target_id, attack_identifier=self.attack_identifier, original_value_data_type=self.original_value_data_type, converted_value_data_type=self.converted_value_data_type, diff --git a/tests/unit/identifiers/test_target_identifier.py b/tests/unit/identifiers/test_target_identifier.py index 3621c1dcb..93b6f9629 100644 --- a/tests/unit/identifiers/test_target_identifier.py +++ b/tests/unit/identifiers/test_target_identifier.py @@ -37,8 +37,8 @@ def test_target_identifier_creation_minimal(self): assert identifier.hash is not None assert len(identifier.hash) == 64 # SHA256 hex digest length - def test_target_identifier_unique_name_auto_computed(self): - """Test that unique_name is auto-computed from class_name and hash.""" + def test_target_identifier_unique_name_minimal(self): + """Test that unique_name is auto-computed with minimal fields (no model_name or endpoint).""" identifier = TargetIdentifier( class_name="OpenAIChatTarget", class_module="pyrit.prompt_target.openai.openai_chat_target", @@ -46,11 +46,49 @@ def test_target_identifier_unique_name_auto_computed(self): identifier_type="instance", ) - # unique_name format: {snake_case_class_name}::{hash[:8]} - assert identifier.unique_name.startswith("open_ai_chat_target::") - assert len(identifier.unique_name.split("::")[1]) == 8 + # unique_name format with no model/endpoint: {snake_class_name}::{hash[:8]} assert identifier.unique_name == f"open_ai_chat_target::{identifier.hash[:8]}" + def test_target_identifier_unique_name_with_model(self): + """Test that unique_name includes model_name when provided.""" + identifier = TargetIdentifier( + class_name="OpenAIChatTarget", + class_module="pyrit.prompt_target.openai.openai_chat_target", + class_description="OpenAI chat target", + identifier_type="instance", + model_name="gpt-4o", + ) + + # unique_name format: {snake_class_name}::{model_name}::{hash[:8]} + assert identifier.unique_name == f"open_ai_chat_target::gpt-4o::{identifier.hash[:8]}" + + def test_target_identifier_unique_name_with_endpoint(self): + """Test that unique_name includes endpoint host when provided.""" + identifier = TargetIdentifier( + class_name="OpenAIChatTarget", + class_module="pyrit.prompt_target.openai.openai_chat_target", + class_description="OpenAI chat target", + identifier_type="instance", + endpoint="https://api.openai.com/v1/chat/completions", + ) + + # unique_name format: {snake_class_name}::{endpoint_host}::{hash[:8]} + assert identifier.unique_name == f"open_ai_chat_target::api.openai.com::{identifier.hash[:8]}" + + def test_target_identifier_unique_name_with_model_and_endpoint(self): + """Test that unique_name includes both model_name and endpoint when provided.""" + identifier = TargetIdentifier( + class_name="OpenAIChatTarget", + class_module="pyrit.prompt_target.openai.openai_chat_target", + class_description="OpenAI chat target", + identifier_type="instance", + model_name="gpt-4o", + endpoint="https://api.openai.com/v1/chat/completions", + ) + + # unique_name format: {snake_class_name}::{model_name}::{endpoint_host}::{hash[:8]} + assert identifier.unique_name == f"open_ai_chat_target::gpt-4o::api.openai.com::{identifier.hash[:8]}" + def test_target_identifier_creation_all_fields(self): """Test creating a TargetIdentifier with all fields.""" identifier = TargetIdentifier( @@ -259,10 +297,10 @@ def test_to_dict_basic(self): assert result["class_name"] == "TestTarget" assert result["class_module"] == "pyrit.prompt_target.test_target" assert result["hash"] == identifier.hash - assert result["unique_name"] == identifier.unique_name - # class_description and identifier_type should be excluded + # class_description, identifier_type, and unique_name are excluded from storage assert "class_description" not in result assert "identifier_type" not in result + assert "unique_name" not in result def test_to_dict_includes_endpoint_and_model_name(self): """Test that endpoint and model_name are included in to_dict.""" From 3e5f5917a81cab78c216a93fb6964f673b5d6786 Mon Sep 17 00:00:00 2001 From: jsong468 Date: Fri, 30 Jan 2026 10:21:33 -0800 Subject: [PATCH 10/10] revert unique name change --- pyrit/identifiers/target_identifier.py | 32 ------------- .../identifiers/test_target_identifier.py | 47 ++----------------- 2 files changed, 4 insertions(+), 75 deletions(-) diff --git a/pyrit/identifiers/target_identifier.py b/pyrit/identifiers/target_identifier.py index 753c8b51b..f08ad8709 100644 --- a/pyrit/identifiers/target_identifier.py +++ b/pyrit/identifiers/target_identifier.py @@ -5,7 +5,6 @@ from dataclasses import dataclass from typing import Any, Dict, Optional, Type, cast -from urllib.parse import urlparse from pyrit.identifiers.identifier import Identifier @@ -38,37 +37,6 @@ class TargetIdentifier(Identifier): target_specific_params: Optional[Dict[str, Any]] = None """Additional target-specific parameters.""" - def __post_init__(self) -> None: - """ - Compute derived fields with target-specific unique_name format. - - Overrides the base Identifier to include model_name and endpoint in unique_name. - Format: {snake_name}::{model_name}::{endpoint_host}::{hash[:8]} - Only includes model_name and endpoint if they have values. - """ - # Call parent to set up snake_class_name and hash - super().__post_init__() - - # Build unique_name with model_name and endpoint if available - parts = [self.snake_class_name] - - if self.model_name: - parts.append(self.model_name) - - if self.endpoint: - # Simplify endpoint to just the host for readability - try: - parsed = urlparse(self.endpoint) - host = parsed.netloc or self.endpoint - parts.append(host) - except Exception: - # Fallback: truncate if parsing fails - parts.append(self.endpoint[:20] if len(self.endpoint) > 20 else self.endpoint) - - parts.append(self.hash[:8]) - - object.__setattr__(self, "unique_name", "::".join(parts)) - @classmethod def from_dict(cls: Type["TargetIdentifier"], data: dict[str, Any]) -> "TargetIdentifier": """ diff --git a/tests/unit/identifiers/test_target_identifier.py b/tests/unit/identifiers/test_target_identifier.py index 93b6f9629..0541b36be 100644 --- a/tests/unit/identifiers/test_target_identifier.py +++ b/tests/unit/identifiers/test_target_identifier.py @@ -38,7 +38,7 @@ def test_target_identifier_creation_minimal(self): assert len(identifier.hash) == 64 # SHA256 hex digest length def test_target_identifier_unique_name_minimal(self): - """Test that unique_name is auto-computed with minimal fields (no model_name or endpoint).""" + """Test that unique_name is auto-computed with minimal fields.""" identifier = TargetIdentifier( class_name="OpenAIChatTarget", class_module="pyrit.prompt_target.openai.openai_chat_target", @@ -46,48 +46,9 @@ def test_target_identifier_unique_name_minimal(self): identifier_type="instance", ) - # unique_name format with no model/endpoint: {snake_class_name}::{hash[:8]} - assert identifier.unique_name == f"open_ai_chat_target::{identifier.hash[:8]}" - - def test_target_identifier_unique_name_with_model(self): - """Test that unique_name includes model_name when provided.""" - identifier = TargetIdentifier( - class_name="OpenAIChatTarget", - class_module="pyrit.prompt_target.openai.openai_chat_target", - class_description="OpenAI chat target", - identifier_type="instance", - model_name="gpt-4o", - ) - - # unique_name format: {snake_class_name}::{model_name}::{hash[:8]} - assert identifier.unique_name == f"open_ai_chat_target::gpt-4o::{identifier.hash[:8]}" - - def test_target_identifier_unique_name_with_endpoint(self): - """Test that unique_name includes endpoint host when provided.""" - identifier = TargetIdentifier( - class_name="OpenAIChatTarget", - class_module="pyrit.prompt_target.openai.openai_chat_target", - class_description="OpenAI chat target", - identifier_type="instance", - endpoint="https://api.openai.com/v1/chat/completions", - ) - - # unique_name format: {snake_class_name}::{endpoint_host}::{hash[:8]} - assert identifier.unique_name == f"open_ai_chat_target::api.openai.com::{identifier.hash[:8]}" - - def test_target_identifier_unique_name_with_model_and_endpoint(self): - """Test that unique_name includes both model_name and endpoint when provided.""" - identifier = TargetIdentifier( - class_name="OpenAIChatTarget", - class_module="pyrit.prompt_target.openai.openai_chat_target", - class_description="OpenAI chat target", - identifier_type="instance", - model_name="gpt-4o", - endpoint="https://api.openai.com/v1/chat/completions", - ) - - # unique_name format: {snake_class_name}::{model_name}::{endpoint_host}::{hash[:8]} - assert identifier.unique_name == f"open_ai_chat_target::gpt-4o::api.openai.com::{identifier.hash[:8]}" + # unique_name format: {snake_case_class_name}::{hash[:8]} + assert identifier.unique_name.startswith("open_ai_chat_target::") + assert len(identifier.unique_name.split("::")[1]) == 8 def test_target_identifier_creation_all_fields(self): """Test creating a TargetIdentifier with all fields."""