diff --git a/doc/api.rst b/doc/api.rst index 8fbec2512..c774deca0 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -279,6 +279,7 @@ API Reference IdentifierType ScorerIdentifier snake_case_to_class_name + TargetIdentifier :py:mod:`pyrit.memory` ====================== diff --git a/pyrit/executor/attack/multi_turn/crescendo.py b/pyrit/executor/attack/multi_turn/crescendo.py index 7a274053c..b005e2121 100644 --- a/pyrit/executor/attack/multi_turn/crescendo.py +++ b/pyrit/executor/attack/multi_turn/crescendo.py @@ -586,7 +586,7 @@ async def _send_prompt_to_objective_target_async( Raises: ValueError: If no response is received from the objective target. """ - objective_target_type = self._objective_target.get_identifier()["__type__"] + objective_target_type = self._objective_target.get_identifier().class_name # Send the generated prompt to the objective target prompt_preview = attack_message.get_value()[:100] if attack_message.get_value() else "" diff --git a/pyrit/identifiers/__init__.py b/pyrit/identifiers/__init__.py index a6e0d759b..30c501894 100644 --- a/pyrit/identifiers/__init__.py +++ b/pyrit/identifiers/__init__.py @@ -14,6 +14,7 @@ IdentifierType, ) from pyrit.identifiers.scorer_identifier import ScorerIdentifier +from pyrit.identifiers.target_identifier import TargetIdentifier __all__ = [ "class_name_to_snake_case", @@ -25,4 +26,5 @@ "LegacyIdentifiable", "ScorerIdentifier", "snake_case_to_class_name", + "TargetIdentifier", ] diff --git a/pyrit/identifiers/target_identifier.py b/pyrit/identifiers/target_identifier.py new file mode 100644 index 000000000..f08ad8709 --- /dev/null +++ b/pyrit/identifiers/target_identifier.py @@ -0,0 +1,55 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Dict, Optional, Type, cast + +from pyrit.identifiers.identifier import Identifier + + +@dataclass(frozen=True) +class TargetIdentifier(Identifier): + """ + Identifier for PromptTarget instances. + + This frozen dataclass extends Identifier with target-specific fields. + It provides a stable, hashable identifier for prompt targets that can be + used for scorer evaluation, registry tracking, and memory storage. + """ + + endpoint: str = "" + """The target endpoint URL.""" + + model_name: str = "" + """The model or deployment name.""" + + temperature: Optional[float] = None + """The temperature parameter for generation.""" + + top_p: Optional[float] = None + """The top_p parameter for generation.""" + + max_requests_per_minute: Optional[int] = None + """Maximum number of requests per minute.""" + + target_specific_params: Optional[Dict[str, Any]] = None + """Additional target-specific parameters.""" + + @classmethod + def from_dict(cls: Type["TargetIdentifier"], data: dict[str, Any]) -> "TargetIdentifier": + """ + Create a TargetIdentifier from a dictionary (e.g., retrieved from database). + + Extends the base Identifier.from_dict() to handle legacy key mappings. + + Args: + data: The dictionary representation. + + Returns: + TargetIdentifier: A new TargetIdentifier instance. + """ + # Delegate to parent class for standard processing + result = Identifier.from_dict.__func__(cls, data) # type: ignore[attr-defined] + return cast(TargetIdentifier, result) diff --git a/pyrit/memory/memory_models.py b/pyrit/memory/memory_models.py index 3a1191af5..bc8b9f5f3 100644 --- a/pyrit/memory/memory_models.py +++ b/pyrit/memory/memory_models.py @@ -32,7 +32,7 @@ import pyrit from pyrit.common.utils import to_sha256 -from pyrit.identifiers import ConverterIdentifier, ScorerIdentifier +from pyrit.identifiers import ConverterIdentifier, ScorerIdentifier, TargetIdentifier from pyrit.models import ( AttackOutcome, AttackResult, @@ -216,7 +216,10 @@ def __init__(self, *, entry: MessagePiece): self.prompt_metadata = entry.prompt_metadata self.targeted_harm_categories = entry.targeted_harm_categories self.converter_identifiers = [conv.to_dict() for conv in entry.converter_identifiers] - self.prompt_target_identifier = entry.prompt_target_identifier + # Normalize prompt_target_identifier and convert to dict for JSON serialization + self.prompt_target_identifier = ( + entry.prompt_target_identifier.to_dict() if entry.prompt_target_identifier else {} + ) self.attack_identifier = entry.attack_identifier self.original_value = entry.original_value @@ -247,6 +250,12 @@ def get_message_piece(self) -> MessagePiece: ConverterIdentifier.from_dict({**c, "pyrit_version": stored_version}) for c in self.converter_identifiers ] + + # Reconstruct TargetIdentifier with the stored pyrit_version + target_id: Optional[TargetIdentifier] = None + if self.prompt_target_identifier: + target_id = TargetIdentifier.from_dict({**self.prompt_target_identifier, "pyrit_version": stored_version}) + message_piece = MessagePiece( role=self.role, original_value=self.original_value, @@ -260,7 +269,7 @@ def get_message_piece(self) -> MessagePiece: prompt_metadata=self.prompt_metadata, targeted_harm_categories=self.targeted_harm_categories, converter_identifiers=converter_ids, - prompt_target_identifier=self.prompt_target_identifier, + prompt_target_identifier=target_id, attack_identifier=self.attack_identifier, original_value_data_type=self.original_value_data_type, converted_value_data_type=self.converted_value_data_type, @@ -279,7 +288,11 @@ def __str__(self) -> str: str: Formatted string representation of the memory entry. """ if self.prompt_target_identifier: - return f"{self.prompt_target_identifier['__type__']}: {self.role}: {self.converted_value}" + # prompt_target_identifier is stored as dict in the database + class_name = self.prompt_target_identifier.get("class_name") or self.prompt_target_identifier.get( + "__type__", "Unknown" + ) + return f"{class_name}: {self.role}: {self.converted_value}" return f": {self.role}: {self.converted_value}" @@ -902,7 +915,8 @@ def __init__(self, *, entry: ScenarioResult): self.scenario_version = entry.scenario_identifier.version self.pyrit_version = entry.scenario_identifier.pyrit_version self.scenario_init_data = entry.scenario_identifier.init_data - self.objective_target_identifier = entry.objective_target_identifier + # Convert TargetIdentifier to dict for JSON storage + self.objective_target_identifier = entry.objective_target_identifier.to_dict() # Convert ScorerIdentifier to dict for JSON storage self.objective_scorer_identifier = ( entry.objective_scorer_identifier.to_dict() if entry.objective_scorer_identifier else None @@ -952,10 +966,13 @@ def get_scenario_result(self) -> ScenarioResult: {**self.objective_scorer_identifier, "pyrit_version": stored_version} ) + # Convert dict back to TargetIdentifier for reconstruction + target_identifier = TargetIdentifier.from_dict(self.objective_target_identifier) + return ScenarioResult( id=self.id, scenario_identifier=scenario_identifier, - objective_target_identifier=self.objective_target_identifier, + objective_target_identifier=target_identifier, attack_results=attack_results, objective_scorer_identifier=scorer_identifier, scenario_run_state=self.scenario_run_state, diff --git a/pyrit/models/message_piece.py b/pyrit/models/message_piece.py index af16fe612..460529475 100644 --- a/pyrit/models/message_piece.py +++ b/pyrit/models/message_piece.py @@ -5,11 +5,10 @@ import uuid from datetime import datetime -from typing import Dict, List, Literal, Optional, Union, get_args +from typing import Any, Dict, List, Literal, Optional, Union, get_args from uuid import uuid4 -from pyrit.common.deprecation import print_deprecation_message -from pyrit.identifiers import ConverterIdentifier, ScorerIdentifier +from pyrit.identifiers import ConverterIdentifier, ScorerIdentifier, TargetIdentifier from pyrit.models.literals import ChatMessageRole, PromptDataType, PromptResponseError from pyrit.models.score import Score @@ -39,7 +38,7 @@ def __init__( labels: Optional[Dict[str, str]] = None, prompt_metadata: Optional[Dict[str, Union[str, int]]] = None, converter_identifiers: Optional[List[Union[ConverterIdentifier, Dict[str, str]]]] = None, - prompt_target_identifier: Optional[Dict[str, str]] = None, + prompt_target_identifier: Optional[Union[TargetIdentifier, Dict[str, Any]]] = None, attack_identifier: Optional[Dict[str, str]] = None, scorer_identifier: Optional[Union[ScorerIdentifier, Dict[str, str]]] = None, original_value_data_type: PromptDataType = "text", @@ -107,35 +106,24 @@ def __init__( self.labels = labels or {} self.prompt_metadata = prompt_metadata or {} - # Handle converter_identifiers: convert dicts to ConverterIdentifier with deprecation warning - self.converter_identifiers: List[ConverterIdentifier] = [] - if converter_identifiers: - for conv_id in converter_identifiers: - if isinstance(conv_id, dict): - print_deprecation_message( - old_item="dict for converter_identifiers", - new_item="ConverterIdentifier", - removed_in="0.14.0", - ) - self.converter_identifiers.append(ConverterIdentifier.from_dict(conv_id)) - else: - self.converter_identifiers.append(conv_id) - - self.prompt_target_identifier = prompt_target_identifier or {} + # Handle converter_identifiers: normalize to ConverterIdentifier (handles dict with deprecation warning) + self.converter_identifiers: List[ConverterIdentifier] = ( + [ConverterIdentifier.normalize(conv_id) for conv_id in converter_identifiers] + if converter_identifiers + else [] + ) + + # Handle prompt_target_identifier: normalize to TargetIdentifier (handles dict with deprecation warning) + self.prompt_target_identifier: Optional[TargetIdentifier] = ( + TargetIdentifier.normalize(prompt_target_identifier) if prompt_target_identifier else None + ) + self.attack_identifier = attack_identifier or {} - # Handle scorer_identifier: convert dict to ScorerIdentifier with deprecation warning - if scorer_identifier is None: - self.scorer_identifier: Optional[ScorerIdentifier] = None - elif isinstance(scorer_identifier, dict): - print_deprecation_message( - old_item="dict for scorer_identifier", - new_item="ScorerIdentifier", - removed_in="0.13.0", - ) - self.scorer_identifier = ScorerIdentifier.from_dict(scorer_identifier) - else: - self.scorer_identifier = scorer_identifier + # Handle scorer_identifier: normalize to ScorerIdentifier (handles dict with deprecation warning) + self.scorer_identifier: Optional[ScorerIdentifier] = ( + ScorerIdentifier.normalize(scorer_identifier) if scorer_identifier else None + ) self.original_value = original_value @@ -292,7 +280,9 @@ def to_dict(self) -> dict[str, object]: "targeted_harm_categories": self.targeted_harm_categories if self.targeted_harm_categories else None, "prompt_metadata": self.prompt_metadata, "converter_identifiers": [conv.to_dict() for conv in self.converter_identifiers], - "prompt_target_identifier": self.prompt_target_identifier, + "prompt_target_identifier": ( + self.prompt_target_identifier.to_dict() if self.prompt_target_identifier else None + ), "attack_identifier": self.attack_identifier, "scorer_identifier": self.scorer_identifier.to_dict() if self.scorer_identifier else None, "original_value_data_type": self.original_value_data_type, @@ -308,7 +298,8 @@ def to_dict(self) -> dict[str, object]: } def __str__(self) -> str: - return f"{self.prompt_target_identifier}: {self._role}: {self.converted_value}" + target_str = self.prompt_target_identifier.class_name if self.prompt_target_identifier else "Unknown" + return f"{target_str}: {self._role}: {self.converted_value}" __repr__ = __str__ diff --git a/pyrit/models/scenario_result.py b/pyrit/models/scenario_result.py index a052ecaec..ab0ba0e5e 100644 --- a/pyrit/models/scenario_result.py +++ b/pyrit/models/scenario_result.py @@ -10,7 +10,7 @@ from pyrit.models import AttackOutcome, AttackResult if TYPE_CHECKING: - from pyrit.identifiers import ScorerIdentifier + from pyrit.identifiers import ScorerIdentifier, TargetIdentifier from pyrit.score import Scorer from pyrit.score.scorer_evaluation.scorer_metrics import ScorerMetrics @@ -59,7 +59,7 @@ def __init__( self, *, scenario_identifier: ScenarioIdentifier, - objective_target_identifier: dict[str, str], + objective_target_identifier: Union[Dict[str, Any], "TargetIdentifier"], attack_results: dict[str, List[AttackResult]], objective_scorer_identifier: Union[Dict[str, Any], "ScorerIdentifier"], scenario_run_state: ScenarioRunState = "CREATED", @@ -71,11 +71,13 @@ def __init__( objective_scorer: Optional["Scorer"] = None, ) -> None: from pyrit.common import print_deprecation_message - from pyrit.identifiers import ScorerIdentifier + from pyrit.identifiers import ScorerIdentifier, TargetIdentifier self.id = id if id is not None else uuid.uuid4() self.scenario_identifier = scenario_identifier - self.objective_target_identifier = objective_target_identifier + + # Normalize objective_target_identifier to TargetIdentifier + self.objective_target_identifier = TargetIdentifier.normalize(objective_target_identifier) # Handle deprecated objective_scorer parameter if objective_scorer is not None: diff --git a/pyrit/prompt_converter/prompt_converter.py b/pyrit/prompt_converter/prompt_converter.py index f7e0a6a2c..756d3972e 100644 --- a/pyrit/prompt_converter/prompt_converter.py +++ b/pyrit/prompt_converter/prompt_converter.py @@ -209,11 +209,13 @@ def _create_identifier( target_info: Optional[Dict[str, Any]] = None if converter_target: target_id = converter_target.get_identifier() - # Extract standard fields for converter identification - target_info = {} - for key in ["__type__", "model_name", "temperature", "top_p"]: - if key in target_id: - target_info[key] = target_id[key] + # Extract standard fields for converter identification using attribute access + target_info = { + "class_name": target_id.class_name, + "model_name": target_id.model_name, + "temperature": target_id.temperature, + "top_p": target_id.top_p, + } return ConverterIdentifier( class_name=self.__class__.__name__, diff --git a/pyrit/prompt_target/azure_blob_storage_target.py b/pyrit/prompt_target/azure_blob_storage_target.py index edbb0419b..91b1f21da 100644 --- a/pyrit/prompt_target/azure_blob_storage_target.py +++ b/pyrit/prompt_target/azure_blob_storage_target.py @@ -12,6 +12,7 @@ from pyrit.auth import AzureStorageAuth from pyrit.common import default_values +from pyrit.identifiers import TargetIdentifier from pyrit.models import Message, construct_response_from_request from pyrit.prompt_target.common.prompt_target import PromptTarget from pyrit.prompt_target.common.utils import limit_requests_per_minute @@ -79,6 +80,20 @@ def __init__( super().__init__(endpoint=self._container_url, max_requests_per_minute=max_requests_per_minute) + def _build_identifier(self) -> TargetIdentifier: + """ + Build the identifier with Azure Blob Storage-specific parameters. + + Returns: + TargetIdentifier: The identifier for this target instance. + """ + return self._create_identifier( + target_specific_params={ + "container_url": self._container_url, + "blob_content_type": self._blob_content_type, + }, + ) + async def _create_container_client_async(self) -> None: """ Create an asynchronous ContainerClient for Azure Storage. If a SAS token is provided via the diff --git a/pyrit/prompt_target/azure_ml_chat_target.py b/pyrit/prompt_target/azure_ml_chat_target.py index 7734acdd0..bc0ba056d 100644 --- a/pyrit/prompt_target/azure_ml_chat_target.py +++ b/pyrit/prompt_target/azure_ml_chat_target.py @@ -13,6 +13,7 @@ handle_bad_request_exception, pyrit_target_retry, ) +from pyrit.identifiers import TargetIdentifier from pyrit.message_normalizer import ChatMessageNormalizer, MessageListNormalizer from pyrit.models import ( Message, @@ -103,6 +104,23 @@ def __init__( self._repetition_penalty = repetition_penalty self._extra_parameters = param_kwargs + def _build_identifier(self) -> TargetIdentifier: + """ + Build the identifier with Azure ML-specific parameters. + + Returns: + TargetIdentifier: The identifier for this target instance. + """ + return self._create_identifier( + temperature=self._temperature, + top_p=self._top_p, + target_specific_params={ + "max_new_tokens": self._max_new_tokens, + "repetition_penalty": self._repetition_penalty, + "message_normalizer": self.message_normalizer.__class__.__name__, + }, + ) + def _initialize_vars(self, endpoint: Optional[str] = None, api_key: Optional[str] = None) -> None: """ Set the endpoint and key for accessing the Azure ML model. Use this function to manually diff --git a/pyrit/prompt_target/common/prompt_chat_target.py b/pyrit/prompt_target/common/prompt_chat_target.py index fbfa3f93f..5eac0209f 100644 --- a/pyrit/prompt_target/common/prompt_chat_target.py +++ b/pyrit/prompt_target/common/prompt_chat_target.py @@ -121,7 +121,7 @@ def _get_json_response_config(self, *, message_piece: MessagePiece) -> _JsonResp config = _JsonResponseConfig.from_metadata(metadata=message_piece.prompt_metadata) if config.enabled and not self.is_json_response_supported(): - target_name = self.get_identifier()["__type__"] + target_name = self.get_identifier().class_name raise ValueError(f"This target {target_name} does not support JSON response format.") return config diff --git a/pyrit/prompt_target/common/prompt_target.py b/pyrit/prompt_target/common/prompt_target.py index cbc0cc079..8cd80f47d 100644 --- a/pyrit/prompt_target/common/prompt_target.py +++ b/pyrit/prompt_target/common/prompt_target.py @@ -3,16 +3,16 @@ import abc import logging -from typing import Any, Dict, List, Optional +from typing import Any, List, Optional -from pyrit.identifiers import LegacyIdentifiable +from pyrit.identifiers import Identifiable, TargetIdentifier from pyrit.memory import CentralMemory, MemoryInterface from pyrit.models import Message logger = logging.getLogger(__name__) -class PromptTarget(LegacyIdentifiable): +class PromptTarget(Identifiable[TargetIdentifier]): """ Abstract base class for prompt targets. @@ -26,6 +26,8 @@ class PromptTarget(LegacyIdentifiable): #: An empty list implies that the prompt target supports all converters. supported_converters: List[Any] + _identifier: Optional[TargetIdentifier] = None + def __init__( self, verbose: bool = False, @@ -91,36 +93,59 @@ def dispose_db_engine(self) -> None: """ self._memory.dispose_engine() - def get_identifier(self) -> Dict[str, Any]: + def _create_identifier( + self, + *, + temperature: Optional[float] = None, + top_p: Optional[float] = None, + target_specific_params: Optional[dict[str, Any]] = None, + ) -> TargetIdentifier: """ - Get an identifier dictionary for this prompt target. + Construct the target identifier. - This includes essential attributes needed for scorer evaluation and registry tracking. - Subclasses should override this method to include additional relevant attributes - (e.g., temperature, top_p) when available. + Subclasses should call this method in their _build_identifier() implementation + to set the identifier with their specific parameters. - Returns: - Dict[str, Any]: A dictionary containing identification attributes. + Args: + temperature (Optional[float]): The temperature parameter for generation. Defaults to None. + top_p (Optional[float]): The top_p parameter for generation. Defaults to None. + target_specific_params (Optional[dict[str, Any]]): Additional target-specific parameters + that should be included in the identifier. Defaults to None. - Note: - If the `self._underlying_model` is specified, either passed in during instantiation - or via environment variable, it is used as the "model_name" for the identifier. - Otherwise, `self._model_name` (which is often the deployment name in Azure) is used. + Returns: + TargetIdentifier: The identifier for this prompt target. """ - public_attributes: Dict[str, Any] = {} - public_attributes["__type__"] = self.__class__.__name__ - public_attributes["__module__"] = self.__class__.__module__ - if self._endpoint: - public_attributes["endpoint"] = self._endpoint - # if the underlying model is specified, use it as the model name for identification - # otherwise, use self._model_name (which is often the deployment name in Azure) + # Determine the model name to use + model_name = "" if self._underlying_model: - public_attributes["model_name"] = self._underlying_model + model_name = self._underlying_model elif self._model_name: - public_attributes["model_name"] = self._model_name - # Include temperature and top_p if available (set by subclasses) - if hasattr(self, "_temperature") and self._temperature is not None: - public_attributes["temperature"] = self._temperature - if hasattr(self, "_top_p") and self._top_p is not None: - public_attributes["top_p"] = self._top_p - return public_attributes + model_name = self._model_name + + return TargetIdentifier( + class_name=self.__class__.__name__, + class_module=self.__class__.__module__, + class_description=" ".join(self.__class__.__doc__.split()) if self.__class__.__doc__ else "", + identifier_type="instance", + endpoint=self._endpoint, + model_name=model_name, + temperature=temperature, + top_p=top_p, + max_requests_per_minute=self._max_requests_per_minute, + target_specific_params=target_specific_params, + ) + + def _build_identifier(self) -> TargetIdentifier: + """ + Build the identifier for this target. + + Subclasses can override this method to call _create_identifier() with + their specific parameters (temperature, top_p, target_specific_params). + + The base implementation calls _create_identifier() with no parameters, + which works for targets that don't have model-specific settings. + + Returns: + TargetIdentifier: The identifier for this prompt target. + """ + return self._create_identifier() diff --git a/pyrit/prompt_target/gandalf_target.py b/pyrit/prompt_target/gandalf_target.py index 4819b6de9..e823926f3 100644 --- a/pyrit/prompt_target/gandalf_target.py +++ b/pyrit/prompt_target/gandalf_target.py @@ -7,6 +7,7 @@ from typing import Optional from pyrit.common import net_utility +from pyrit.identifiers import TargetIdentifier from pyrit.models import Message, construct_response_from_request from pyrit.prompt_target.common.prompt_target import PromptTarget from pyrit.prompt_target.common.utils import limit_requests_per_minute @@ -57,6 +58,19 @@ def __init__( self._defender = level.value + def _build_identifier(self) -> TargetIdentifier: + """ + Build the identifier with Gandalf-specific parameters. + + Returns: + TargetIdentifier: The identifier for this target instance. + """ + return self._create_identifier( + target_specific_params={ + "level": self._defender, + }, + ) + @limit_requests_per_minute async def send_prompt_async(self, *, message: Message) -> list[Message]: """ diff --git a/pyrit/prompt_target/http_target/http_target.py b/pyrit/prompt_target/http_target/http_target.py index 419f22633..17acfe881 100644 --- a/pyrit/prompt_target/http_target/http_target.py +++ b/pyrit/prompt_target/http_target/http_target.py @@ -9,6 +9,7 @@ import httpx +from pyrit.identifiers import TargetIdentifier from pyrit.models import ( Message, MessagePiece, @@ -82,6 +83,21 @@ def __init__( if client and httpx_client_kwargs: raise ValueError("Cannot provide both a pre-configured client and additional httpx client kwargs.") + def _build_identifier(self) -> TargetIdentifier: + """ + Build the identifier with HTTP target-specific parameters. + + Returns: + TargetIdentifier: The identifier for this target instance. + """ + return self._create_identifier( + target_specific_params={ + "use_tls": self.use_tls, + "prompt_regex_string": self.prompt_regex_string, + "callback_function": getattr(self.callback_function, "__name__", None), + }, + ) + @classmethod def with_client( cls, diff --git a/pyrit/prompt_target/hugging_face/hugging_face_chat_target.py b/pyrit/prompt_target/hugging_face/hugging_face_chat_target.py index c38f5858b..5fe1104a8 100644 --- a/pyrit/prompt_target/hugging_face/hugging_face_chat_target.py +++ b/pyrit/prompt_target/hugging_face/hugging_face_chat_target.py @@ -16,6 +16,7 @@ from pyrit.common import default_values from pyrit.common.download_hf_model import download_specific_files from pyrit.exceptions import EmptyResponseException, pyrit_target_retry +from pyrit.identifiers import TargetIdentifier from pyrit.models import Message, construct_response_from_request from pyrit.prompt_target.common.prompt_chat_target import PromptChatTarget from pyrit.prompt_target.common.utils import limit_requests_per_minute @@ -135,6 +136,28 @@ def __init__( self.load_model_and_tokenizer_task = asyncio.create_task(self.load_model_and_tokenizer()) + def _build_identifier(self) -> TargetIdentifier: + """ + Build the identifier with HuggingFace chat-specific parameters. + + Returns: + TargetIdentifier: The identifier for this target instance. + """ + return self._create_identifier( + temperature=self._temperature, + top_p=self._top_p, + target_specific_params={ + "max_new_tokens": self.max_new_tokens, + "skip_special_tokens": self.skip_special_tokens, + "use_cuda": self.use_cuda, + "tensor_format": self.tensor_format, + "trust_remote_code": self.trust_remote_code, + "device_map": self.device_map, + "torch_dtype": str(self.torch_dtype) if self.torch_dtype else None, + "attn_implementation": self.attn_implementation, + }, + ) + def _load_from_path(self, path: str, **kwargs: Any) -> None: """ Load the model and tokenizer from a given path. diff --git a/pyrit/prompt_target/hugging_face/hugging_face_endpoint_target.py b/pyrit/prompt_target/hugging_face/hugging_face_endpoint_target.py index ebaf10c3e..2407d6918 100644 --- a/pyrit/prompt_target/hugging_face/hugging_face_endpoint_target.py +++ b/pyrit/prompt_target/hugging_face/hugging_face_endpoint_target.py @@ -5,6 +5,7 @@ from typing import Optional from pyrit.common.net_utility import make_request_and_raise_if_error_async +from pyrit.identifiers import TargetIdentifier from pyrit.models import Message, construct_response_from_request from pyrit.prompt_target.common.prompt_target import PromptTarget from pyrit.prompt_target.common.utils import limit_requests_per_minute, validate_temperature, validate_top_p @@ -61,6 +62,21 @@ def __init__( self._temperature = temperature self._top_p = top_p + def _build_identifier(self) -> TargetIdentifier: + """ + Build the identifier with HuggingFace endpoint-specific parameters. + + Returns: + TargetIdentifier: The identifier for this target instance. + """ + return self._create_identifier( + temperature=self._temperature, + top_p=self._top_p, + target_specific_params={ + "max_tokens": self.max_tokens, + }, + ) + @limit_requests_per_minute async def send_prompt_async(self, *, message: Message) -> list[Message]: """ diff --git a/pyrit/prompt_target/openai/openai_chat_target.py b/pyrit/prompt_target/openai/openai_chat_target.py index 544046b59..87ffa26f4 100644 --- a/pyrit/prompt_target/openai/openai_chat_target.py +++ b/pyrit/prompt_target/openai/openai_chat_target.py @@ -12,6 +12,7 @@ PyritException, pyrit_target_retry, ) +from pyrit.identifiers import TargetIdentifier from pyrit.models import ( ChatMessage, DataTypeSerializer, @@ -163,6 +164,26 @@ def __init__( self._extra_body_parameters = extra_body_parameters + def _build_identifier(self) -> TargetIdentifier: + """ + Build the identifier with OpenAI chat-specific parameters. + + Returns: + TargetIdentifier: The identifier for this target instance. + """ + return self._create_identifier( + temperature=self._temperature, + top_p=self._top_p, + target_specific_params={ + "max_completion_tokens": self._max_completion_tokens, + "max_tokens": self._max_tokens, + "frequency_penalty": self._frequency_penalty, + "presence_penalty": self._presence_penalty, + "seed": self._seed, + "n": self._n, + }, + ) + def _set_openai_env_configuration_vars(self) -> None: """ Set deployment_environment_variable, endpoint_environment_variable, diff --git a/pyrit/prompt_target/openai/openai_completion_target.py b/pyrit/prompt_target/openai/openai_completion_target.py index 9180466e1..00800d72e 100644 --- a/pyrit/prompt_target/openai/openai_completion_target.py +++ b/pyrit/prompt_target/openai/openai_completion_target.py @@ -7,6 +7,7 @@ from pyrit.exceptions.exception_classes import ( pyrit_target_retry, ) +from pyrit.identifiers import TargetIdentifier from pyrit.models import Message, construct_response_from_request from pyrit.prompt_target.common.utils import limit_requests_per_minute from pyrit.prompt_target.openai.openai_target import OpenAITarget @@ -72,6 +73,24 @@ def __init__( self._presence_penalty = presence_penalty self._n = n + def _build_identifier(self) -> TargetIdentifier: + """ + Build the identifier with OpenAI completion-specific parameters. + + Returns: + TargetIdentifier: The identifier for this target instance. + """ + return self._create_identifier( + temperature=self._temperature, + top_p=self._top_p, + target_specific_params={ + "max_tokens": self._max_tokens, + "frequency_penalty": self._frequency_penalty, + "presence_penalty": self._presence_penalty, + "n": self._n, + }, + ) + def _set_openai_env_configuration_vars(self) -> None: self.model_name_environment_variable = "OPENAI_COMPLETION_MODEL" self.endpoint_environment_variable = "OPENAI_COMPLETION_ENDPOINT" diff --git a/pyrit/prompt_target/openai/openai_image_target.py b/pyrit/prompt_target/openai/openai_image_target.py index 925dbb002..6249aa5ef 100644 --- a/pyrit/prompt_target/openai/openai_image_target.py +++ b/pyrit/prompt_target/openai/openai_image_target.py @@ -10,6 +10,7 @@ EmptyResponseException, pyrit_target_retry, ) +from pyrit.identifiers import TargetIdentifier from pyrit.models import ( Message, construct_response_from_request, @@ -101,6 +102,21 @@ def _get_provider_examples(self) -> dict[str, str]: "api.openai.com": "https://api.openai.com/v1", } + def _build_identifier(self) -> TargetIdentifier: + """ + Build the identifier with image generation-specific parameters. + + Returns: + TargetIdentifier: The identifier for this target instance. + """ + return self._create_identifier( + target_specific_params={ + "image_size": self.image_size, + "quality": self.quality, + "style": self.style, + }, + ) + @limit_requests_per_minute @pyrit_target_retry async def send_prompt_async( diff --git a/pyrit/prompt_target/openai/openai_realtime_target.py b/pyrit/prompt_target/openai/openai_realtime_target.py index faea8c56b..9e5ef778d 100644 --- a/pyrit/prompt_target/openai/openai_realtime_target.py +++ b/pyrit/prompt_target/openai/openai_realtime_target.py @@ -15,6 +15,7 @@ pyrit_target_retry, ) from pyrit.exceptions.exception_classes import ServerErrorException +from pyrit.identifiers import TargetIdentifier from pyrit.models import ( Message, construct_response_from_request, @@ -118,6 +119,19 @@ def _get_provider_examples(self) -> dict[str, str]: "api.openai.com": "wss://api.openai.com/v1", } + def _build_identifier(self) -> TargetIdentifier: + """ + Build the identifier with Realtime API-specific parameters. + + Returns: + TargetIdentifier: The identifier for this target instance. + """ + return self._create_identifier( + target_specific_params={ + "voice": self.voice, + }, + ) + def _validate_url_for_target(self, endpoint_url: str) -> None: """ Validate URL for Realtime API with websocket-specific checks. diff --git a/pyrit/prompt_target/openai/openai_response_target.py b/pyrit/prompt_target/openai/openai_response_target.py index 9f1868f5c..bc709fa4b 100644 --- a/pyrit/prompt_target/openai/openai_response_target.py +++ b/pyrit/prompt_target/openai/openai_response_target.py @@ -21,6 +21,7 @@ PyritException, pyrit_target_retry, ) +from pyrit.identifiers import TargetIdentifier from pyrit.models import ( Message, MessagePiece, @@ -156,6 +157,21 @@ def __init__( logger.debug("Detected grammar tool: %s", tool_name) self._grammar_name = tool_name + def _build_identifier(self) -> TargetIdentifier: + """ + Build the identifier with OpenAI response-specific parameters. + + Returns: + TargetIdentifier: The identifier for this target instance. + """ + return self._create_identifier( + temperature=self._temperature, + top_p=self._top_p, + target_specific_params={ + "max_output_tokens": self._max_output_tokens, + }, + ) + def _set_openai_env_configuration_vars(self) -> None: self.model_name_environment_variable = "OPENAI_RESPONSES_MODEL" self.endpoint_environment_variable = "OPENAI_RESPONSES_ENDPOINT" diff --git a/pyrit/prompt_target/openai/openai_tts_target.py b/pyrit/prompt_target/openai/openai_tts_target.py index d745f3e26..94194d78b 100644 --- a/pyrit/prompt_target/openai/openai_tts_target.py +++ b/pyrit/prompt_target/openai/openai_tts_target.py @@ -7,6 +7,7 @@ from pyrit.exceptions import ( pyrit_target_retry, ) +from pyrit.identifiers import TargetIdentifier from pyrit.models import ( Message, construct_response_from_request, @@ -81,6 +82,22 @@ def _get_provider_examples(self) -> dict[str, str]: "api.openai.com": "https://api.openai.com/v1", } + def _build_identifier(self) -> TargetIdentifier: + """ + Build the identifier with TTS-specific parameters. + + Returns: + TargetIdentifier: The identifier for this target instance. + """ + return self._create_identifier( + target_specific_params={ + "voice": self._voice, + "response_format": self._response_format, + "language": self._language, + "speed": self._speed, + }, + ) + @limit_requests_per_minute @pyrit_target_retry async def send_prompt_async(self, *, message: Message) -> list[Message]: diff --git a/pyrit/prompt_target/openai/openai_video_target.py b/pyrit/prompt_target/openai/openai_video_target.py index def844cb5..f6915c027 100644 --- a/pyrit/prompt_target/openai/openai_video_target.py +++ b/pyrit/prompt_target/openai/openai_video_target.py @@ -7,6 +7,7 @@ from pyrit.exceptions import ( pyrit_target_retry, ) +from pyrit.identifiers import TargetIdentifier from pyrit.models import ( Message, MessagePiece, @@ -95,6 +96,20 @@ def _get_provider_examples(self) -> dict[str, str]: "api.openai.com": "https://api.openai.com/v1", } + def _build_identifier(self) -> TargetIdentifier: + """ + Build the identifier with video generation-specific parameters. + + Returns: + TargetIdentifier: The identifier for this target instance. + """ + return self._create_identifier( + target_specific_params={ + "resolution": self._size, + "n_seconds": self._n_seconds, + }, + ) + def _validate_resolution(self, *, resolution_dimensions: str) -> str: """ Validate resolution dimensions. diff --git a/pyrit/prompt_target/playwright_copilot_target.py b/pyrit/prompt_target/playwright_copilot_target.py index 2610e78d8..c9a791601 100644 --- a/pyrit/prompt_target/playwright_copilot_target.py +++ b/pyrit/prompt_target/playwright_copilot_target.py @@ -8,6 +8,7 @@ from enum import Enum from typing import TYPE_CHECKING, Any, List, Tuple, Union +from pyrit.identifiers import TargetIdentifier from pyrit.models import ( Message, MessagePiece, @@ -128,6 +129,19 @@ def __init__(self, *, page: "Page", copilot_type: CopilotType = CopilotType.CONS if page and self.M365_URL_IDENTIFIER not in page.url and copilot_type == CopilotType.M365: raise ValueError("The provided page URL does not indicate M365 Copilot, but the type is set to m365.") + def _build_identifier(self) -> TargetIdentifier: + """ + Build the identifier with Copilot-specific parameters. + + Returns: + TargetIdentifier: The identifier for this target instance. + """ + return self._create_identifier( + target_specific_params={ + "copilot_type": self._type.value, + }, + ) + def _get_selectors(self) -> CopilotSelectors: """ Get the appropriate selectors for the current Copilot type. diff --git a/pyrit/prompt_target/prompt_shield_target.py b/pyrit/prompt_target/prompt_shield_target.py index 35e30f358..b2da771ad 100644 --- a/pyrit/prompt_target/prompt_shield_target.py +++ b/pyrit/prompt_target/prompt_shield_target.py @@ -6,6 +6,7 @@ from typing import Any, Callable, Literal, Optional, Sequence from pyrit.common import default_values, net_utility +from pyrit.identifiers import TargetIdentifier from pyrit.models import ( Message, MessagePiece, @@ -93,6 +94,20 @@ def __init__( self._force_entry_field: PromptShieldEntryField = field + def _build_identifier(self) -> TargetIdentifier: + """ + Build the identifier with Prompt Shield-specific parameters. + + Returns: + TargetIdentifier: The identifier for this target instance. + """ + return self._create_identifier( + target_specific_params={ + "api_version": self._api_version, + "force_entry_field": self._force_entry_field if self._force_entry_field else None, + }, + ) + @limit_requests_per_minute async def send_prompt_async(self, *, message: Message) -> list[Message]: """ diff --git a/pyrit/prompt_target/websocket_copilot_target.py b/pyrit/prompt_target/websocket_copilot_target.py index 103c7efc1..4cabc58c3 100644 --- a/pyrit/prompt_target/websocket_copilot_target.py +++ b/pyrit/prompt_target/websocket_copilot_target.py @@ -16,6 +16,7 @@ EmptyResponseException, pyrit_target_retry, ) +from pyrit.identifiers import TargetIdentifier from pyrit.models import Message, construct_response_from_request from pyrit.prompt_target import PromptTarget, limit_requests_per_minute @@ -115,6 +116,19 @@ def __init__( model_name=model_name, ) + def _build_identifier(self) -> TargetIdentifier: + """ + Build the identifier with WebSocketCopilot-specific parameters. + + Returns: + TargetIdentifier: The identifier for this target instance. + """ + return self._create_identifier( + target_specific_params={ + "response_timeout_seconds": self._response_timeout_seconds, + }, + ) + @staticmethod def _dict_to_websocket(data: dict[str, Any]) -> str: """ diff --git a/pyrit/registry/instance_registries/scorer_registry.py b/pyrit/registry/instance_registries/scorer_registry.py index 7585c4378..9b5e5f59f 100644 --- a/pyrit/registry/instance_registries/scorer_registry.py +++ b/pyrit/registry/instance_registries/scorer_registry.py @@ -13,7 +13,6 @@ from typing import TYPE_CHECKING, Optional from pyrit.identifiers import ScorerIdentifier -from pyrit.identifiers.class_name_utils import class_name_to_snake_case from pyrit.registry.instance_registries.base_instance_registry import ( BaseInstanceRegistry, ) @@ -65,10 +64,7 @@ def register_instance( (e.g., SelfAskRefusalScorer -> self_ask_refusal_abc123). """ if name is None: - base_name = class_name_to_snake_case(scorer.__class__.__name__, suffix="Scorer") - # Append identifier hash if available for uniqueness - identifier_hash = scorer.get_identifier().hash[:8] - name = f"{base_name}_{identifier_hash}" + name = scorer.get_identifier().unique_name self.register(scorer, name=name) logger.debug(f"Registered scorer instance: {name} ({scorer.__class__.__name__})") diff --git a/pyrit/scenario/core/scenario.py b/pyrit/scenario/core/scenario.py index fcd94baee..7d9b3c3a3 100644 --- a/pyrit/scenario/core/scenario.py +++ b/pyrit/scenario/core/scenario.py @@ -19,6 +19,7 @@ from pyrit.common import REQUIRED_VALUE, apply_defaults from pyrit.executor.attack.single_turn.prompt_sending import PromptSendingAttack +from pyrit.identifiers import TargetIdentifier from pyrit.memory import CentralMemory from pyrit.memory.memory_models import ScenarioResultEntry from pyrit.models import AttackResult @@ -94,7 +95,7 @@ def __init__( # These will be set in initialize_async self._objective_target: Optional[PromptTarget] = None - self._objective_target_identifier: Optional[Dict[str, str]] = None + self._objective_target_identifier: Optional[TargetIdentifier] = None self._memory_labels: Dict[str, str] = {} self._max_concurrency: int = 1 self._max_retries: int = 0 diff --git a/pyrit/scenario/printer/console_printer.py b/pyrit/scenario/printer/console_printer.py index 3a248e0e0..45787916e 100644 --- a/pyrit/scenario/printer/console_printer.py +++ b/pyrit/scenario/printer/console_printer.py @@ -115,9 +115,10 @@ async def print_summary_async(self, result: ScenarioResult) -> None: # Target information print() self._print_colored(f"{self._indent}🎯 Target Information", Style.BRIGHT) - target_type = result.objective_target_identifier.get("__type__", "Unknown") - target_model = result.objective_target_identifier.get("model_name", "Unknown") - target_endpoint = result.objective_target_identifier.get("endpoint", "Unknown") + target_id = result.objective_target_identifier + target_type = target_id.class_name if target_id else "Unknown" + target_model = target_id.model_name if target_id else "Unknown" + target_endpoint = target_id.endpoint if target_id else "Unknown" self._print_colored(f"{self._indent * 2}• Target Type: {target_type}", Fore.CYAN) self._print_colored(f"{self._indent * 2}• Target Model: {target_model}", Fore.CYAN) diff --git a/pyrit/score/scorer.py b/pyrit/score/scorer.py index 61c1d466e..83cd795ec 100644 --- a/pyrit/score/scorer.py +++ b/pyrit/score/scorer.py @@ -132,16 +132,19 @@ def _create_identifier( target_info: Optional[Dict[str, Any]] = None if prompt_target: target_id = prompt_target.get_identifier() - # Extract standard fields for scorer evaluation - target_info = {} - for key in ["__type__", "model_name", "temperature", "top_p"]: - if key in target_id: - target_info[key] = target_id[key] + # Extract standard fields for scorer evaluation, excluding None values + target_info = {"class_name": target_id.class_name} + if target_id.model_name: + target_info["model_name"] = target_id.model_name + if target_id.temperature is not None: + target_info["temperature"] = target_id.temperature + if target_id.top_p is not None: + target_info["top_p"] = target_id.top_p return ScorerIdentifier( class_name=self.__class__.__name__, class_module=self.__class__.__module__, - class_description=self.__class__.__doc__ or "", + class_description=" ".join(self.__class__.__doc__.split()) if self.__class__.__doc__ else "", identifier_type="instance", scorer_type=self.scorer_type, system_prompt_template=system_prompt_template, diff --git a/tests/unit/converter/test_denylist_converter.py b/tests/unit/converter/test_denylist_converter.py index bc7afb896..c75d8f920 100644 --- a/tests/unit/converter/test_denylist_converter.py +++ b/tests/unit/converter/test_denylist_converter.py @@ -3,7 +3,7 @@ from unittest.mock import AsyncMock, MagicMock import pytest -from unit.mocks import MockPromptTarget +from unit.mocks import MockPromptTarget, get_mock_target_identifier from pyrit.models import Message, MessagePiece, SeedPrompt from pyrit.prompt_converter import DenylistConverter @@ -28,6 +28,7 @@ def mock_target() -> MockPromptTarget: ] ) target.send_prompt_async = AsyncMock(return_value=[response]) + target.get_identifier.return_value = get_mock_target_identifier("MockDenylistTarget") return target diff --git a/tests/unit/converter/test_generic_llm_converter.py b/tests/unit/converter/test_generic_llm_converter.py index 8c2a912ab..403294cab 100644 --- a/tests/unit/converter/test_generic_llm_converter.py +++ b/tests/unit/converter/test_generic_llm_converter.py @@ -4,6 +4,7 @@ from unittest.mock import AsyncMock, MagicMock import pytest +from unit.mocks import get_mock_target_identifier from pyrit.models import Message, MessagePiece from pyrit.prompt_converter import ( @@ -28,6 +29,7 @@ def mock_target() -> PromptTarget: ] ) target.send_prompt_async = AsyncMock(return_value=[response]) + target.get_identifier.return_value = get_mock_target_identifier("MockLLMTarget") return target diff --git a/tests/unit/converter/test_math_prompt_converter.py b/tests/unit/converter/test_math_prompt_converter.py index 9203f5950..703000a73 100644 --- a/tests/unit/converter/test_math_prompt_converter.py +++ b/tests/unit/converter/test_math_prompt_converter.py @@ -4,6 +4,7 @@ from unittest.mock import AsyncMock, MagicMock import pytest +from unit.mocks import get_mock_target_identifier from pyrit.models import Message, MessagePiece, SeedPrompt from pyrit.prompt_converter import ConverterResult @@ -15,6 +16,7 @@ async def test_math_prompt_converter_convert_async(): # Mock the converter target - use MagicMock for synchronous methods mock_converter_target = MagicMock() mock_converter_target.send_prompt_async = AsyncMock() + mock_converter_target.get_identifier.return_value = get_mock_target_identifier("MockMathTarget") # Specify parameters=['prompt'] to match the placeholder in the template template_value = "Solve the following problem: {{ prompt }}" dataset_name = "dataset_1" @@ -72,6 +74,7 @@ async def test_math_prompt_converter_handles_disallowed_content(): # Mock the converter target - use MagicMock for synchronous methods mock_converter_target = MagicMock() mock_converter_target.send_prompt_async = AsyncMock() + mock_converter_target.get_identifier.return_value = get_mock_target_identifier("MockMathTarget") # Specify parameters=['prompt'] to match the placeholder in the template template_value = "Encode this instruction: {{ prompt }}" dataset_name = "dataset_1" @@ -126,6 +129,7 @@ async def test_math_prompt_converter_invalid_input_type(): # Mock the converter target - use MagicMock for synchronous methods mock_converter_target = MagicMock() mock_converter_target.send_prompt_async = AsyncMock() + mock_converter_target.get_identifier.return_value = get_mock_target_identifier("MockMathTarget") # Specify parameters=['prompt'] to match the placeholder in the template template_value = "Encode this instruction: {{ prompt }}" dataset_name = "dataset_1" @@ -148,6 +152,7 @@ async def test_math_prompt_converter_error_handling(): # Mock the converter target - use MagicMock for synchronous methods mock_converter_target = MagicMock() mock_converter_target.send_prompt_async = AsyncMock() + mock_converter_target.get_identifier.return_value = get_mock_target_identifier("MockMathTarget") # Specify parameters=['prompt'] to match the placeholder in the template template_value = "Encode this instruction: {{ prompt }}" dataset_name = "dataset_1" diff --git a/tests/unit/converter/test_random_translation_converter.py b/tests/unit/converter/test_random_translation_converter.py index e9c9ed4e9..a71fb86fe 100644 --- a/tests/unit/converter/test_random_translation_converter.py +++ b/tests/unit/converter/test_random_translation_converter.py @@ -4,6 +4,7 @@ from unittest.mock import AsyncMock, MagicMock import pytest +from unit.mocks import get_mock_target_identifier from pyrit.models import Message, MessagePiece from pyrit.prompt_converter import RandomTranslationConverter @@ -27,6 +28,7 @@ def mock_target() -> PromptTarget: ] ) target.send_prompt_async = AsyncMock(return_value=[response]) + target.get_identifier.return_value = get_mock_target_identifier("MockTranslationTarget") return target diff --git a/tests/unit/converter/test_toxic_sentence_generator_converter.py b/tests/unit/converter/test_toxic_sentence_generator_converter.py index ed0d1c867..1e75ef89b 100644 --- a/tests/unit/converter/test_toxic_sentence_generator_converter.py +++ b/tests/unit/converter/test_toxic_sentence_generator_converter.py @@ -4,6 +4,7 @@ from unittest.mock import AsyncMock, MagicMock, patch import pytest +from unit.mocks import get_mock_target_identifier from pyrit.models import MessagePiece, SeedPrompt from pyrit.prompt_converter import ToxicSentenceGeneratorConverter @@ -20,7 +21,7 @@ def mock_target(): conversation_id="test-conversation", ).to_message() mock.send_prompt_async = AsyncMock(return_value=[response]) - mock.get_identifier = MagicMock(return_value="mock_target") + mock.get_identifier.return_value = get_mock_target_identifier("MockToxicTarget") return mock diff --git a/tests/unit/executor/attack/component/test_conversation_manager.py b/tests/unit/executor/attack/component/test_conversation_manager.py index a3bc0474f..b874169d4 100644 --- a/tests/unit/executor/attack/component/test_conversation_manager.py +++ b/tests/unit/executor/attack/component/test_conversation_manager.py @@ -22,6 +22,7 @@ from unittest.mock import AsyncMock, MagicMock import pytest +from unit.mocks import get_mock_scorer_identifier from pyrit.executor.attack import ConversationManager, ConversationState from pyrit.executor.attack.component import PrependedConversationConfig @@ -33,10 +34,21 @@ ) from pyrit.executor.attack.core import AttackContext from pyrit.executor.attack.core.attack_parameters import AttackParameters +from pyrit.identifiers import TargetIdentifier from pyrit.models import Message, MessagePiece, Score from pyrit.prompt_normalizer import PromptConverterConfiguration, PromptNormalizer from pyrit.prompt_target import PromptChatTarget, PromptTarget -from tests.unit.mocks import get_mock_scorer_identifier + + +def _mock_target_id(name: str = "MockTarget") -> TargetIdentifier: + """Helper to create TargetIdentifier for tests.""" + return TargetIdentifier( + class_name=name, + class_module="test_module", + class_description="", + identifier_type="instance", + ) + # ============================================================================= # Test Context Class @@ -78,7 +90,7 @@ def mock_chat_target() -> MagicMock: """Create a mock chat target for testing.""" target = MagicMock(spec=PromptChatTarget) target.set_system_prompt = MagicMock() - target.get_identifier.return_value = {"id": "mock_chat_target_id"} + target.get_identifier.return_value = _mock_target_id("MockChatTarget") return target @@ -86,7 +98,7 @@ def mock_chat_target() -> MagicMock: def mock_prompt_target() -> MagicMock: """Create a mock prompt target (non-chat) for testing.""" target = MagicMock(spec=PromptTarget) - target.get_identifier.return_value = {"id": "mock_target_id"} + target.get_identifier.return_value = _mock_target_id("MockTarget") return target diff --git a/tests/unit/executor/attack/component/test_simulated_conversation.py b/tests/unit/executor/attack/component/test_simulated_conversation.py index 2e59fb386..201692f9f 100644 --- a/tests/unit/executor/attack/component/test_simulated_conversation.py +++ b/tests/unit/executor/attack/component/test_simulated_conversation.py @@ -11,6 +11,7 @@ from pyrit.executor.attack.multi_turn.simulated_conversation import ( generate_simulated_conversation_async, ) +from pyrit.identifiers import ScorerIdentifier, TargetIdentifier from pyrit.models import ( AttackOutcome, AttackResult, @@ -25,13 +26,33 @@ from pyrit.score import TrueFalseScorer +def _mock_target_id(name: str = "MockTarget") -> TargetIdentifier: + """Helper to create TargetIdentifier for tests.""" + return TargetIdentifier( + class_name=name, + class_module="test_module", + class_description="", + identifier_type="instance", + ) + + +def _mock_scorer_id(name: str = "MockScorer") -> ScorerIdentifier: + """Helper to create ScorerIdentifier for tests.""" + return ScorerIdentifier( + class_name=name, + class_module="test_module", + class_description="", + identifier_type="instance", + ) + + @pytest.fixture def mock_adversarial_chat() -> MagicMock: """Create a mock adversarial chat target for testing.""" chat = MagicMock(spec=PromptChatTarget) chat.send_prompt_async = AsyncMock() chat.set_system_prompt = MagicMock() - chat.get_identifier.return_value = {"__type__": "MockAdversarialChat", "__module__": "test_module"} + chat.get_identifier.return_value = _mock_target_id("MockAdversarialChat") return chat @@ -40,7 +61,7 @@ def mock_objective_scorer() -> MagicMock: """Create a mock objective scorer for testing.""" scorer = MagicMock(spec=TrueFalseScorer) scorer.score_async = AsyncMock() - scorer.get_identifier.return_value = {"__type__": "MockScorer", "__module__": "test_module"} + scorer.get_identifier.return_value = _mock_scorer_id("MockScorer") return scorer diff --git a/tests/unit/executor/attack/core/test_attack_strategy.py b/tests/unit/executor/attack/core/test_attack_strategy.py index dcb9db2d6..918b9c753 100644 --- a/tests/unit/executor/attack/core/test_attack_strategy.py +++ b/tests/unit/executor/attack/core/test_attack_strategy.py @@ -13,6 +13,7 @@ _DefaultAttackStrategyEventHandler, ) from pyrit.executor.core import StrategyEvent, StrategyEventData +from pyrit.identifiers import TargetIdentifier from pyrit.memory.central_memory import CentralMemory from pyrit.models import ( AttackOutcome, @@ -22,6 +23,16 @@ from pyrit.prompt_target import PromptTarget +def _mock_target_id(name: str = "MockTarget") -> TargetIdentifier: + """Helper to create TargetIdentifier for tests.""" + return TargetIdentifier( + class_name=name, + class_module="test", + class_description="", + identifier_type="instance", + ) + + @pytest.fixture def mock_memory(): """Mock CentralMemory instance""" @@ -34,7 +45,7 @@ def mock_memory(): def mock_objective_target(): """Mock PromptTarget instance""" target = MagicMock(spec=PromptTarget) - target.get_identifier.return_value = {"__type__": "MockTarget", "__module__": "test"} + target.get_identifier.return_value = _mock_target_id("MockTarget") return target diff --git a/tests/unit/executor/attack/multi_turn/test_crescendo.py b/tests/unit/executor/attack/multi_turn/test_crescendo.py index 1ba71af60..73d4909b3 100644 --- a/tests/unit/executor/attack/multi_turn/test_crescendo.py +++ b/tests/unit/executor/attack/multi_turn/test_crescendo.py @@ -24,7 +24,7 @@ CrescendoAttackContext, CrescendoAttackResult, ) -from pyrit.identifiers import ScorerIdentifier +from pyrit.identifiers import ScorerIdentifier, TargetIdentifier from pyrit.models import ( AttackOutcome, ChatMessageRole, @@ -50,6 +50,16 @@ def _mock_scorer_id(name: str = "MockScorer") -> ScorerIdentifier: ) +def _mock_target_id(name: str = "MockTarget") -> TargetIdentifier: + """Helper to create TargetIdentifier for tests.""" + return TargetIdentifier( + class_name=name, + class_module="test_module", + class_description="", + identifier_type="instance", + ) + + def create_mock_chat_target(*, name: str = "MockChatTarget") -> MagicMock: """Create a mock chat target with common setup. @@ -59,7 +69,7 @@ def create_mock_chat_target(*, name: str = "MockChatTarget") -> MagicMock: target = MagicMock(spec=PromptChatTarget) target.send_prompt_async = AsyncMock() target.set_system_prompt = MagicMock() - target.get_identifier.return_value = {"__type__": name, "__module__": "test_module"} + target.get_identifier.return_value = _mock_target_id(name) return target diff --git a/tests/unit/executor/attack/multi_turn/test_multi_prompt_sending.py b/tests/unit/executor/attack/multi_turn/test_multi_prompt_sending.py index b11a7a61e..bdf70a77d 100644 --- a/tests/unit/executor/attack/multi_turn/test_multi_prompt_sending.py +++ b/tests/unit/executor/attack/multi_turn/test_multi_prompt_sending.py @@ -15,7 +15,7 @@ MultiPromptSendingAttackParameters, MultiTurnAttackContext, ) -from pyrit.identifiers import ScorerIdentifier +from pyrit.identifiers import ScorerIdentifier, TargetIdentifier from pyrit.models import ( AttackOutcome, AttackResult, @@ -39,12 +39,22 @@ def _mock_scorer_id(name: str = "MockScorer") -> ScorerIdentifier: ) +def _mock_target_id(name: str = "MockTarget") -> TargetIdentifier: + """Helper to create TargetIdentifier for tests.""" + return TargetIdentifier( + class_name=name, + class_module="test_module", + class_description="", + identifier_type="instance", + ) + + @pytest.fixture def mock_target(): """Create a mock prompt target for testing""" target = MagicMock(spec=PromptTarget) target.send_prompt_async = AsyncMock() - target.get_identifier.return_value = {"id": "mock_target_id"} + target.get_identifier.return_value = _mock_target_id("MockTarget") return target diff --git a/tests/unit/executor/attack/multi_turn/test_red_teaming.py b/tests/unit/executor/attack/multi_turn/test_red_teaming.py index 7e1b7b269..340dc10c6 100644 --- a/tests/unit/executor/attack/multi_turn/test_red_teaming.py +++ b/tests/unit/executor/attack/multi_turn/test_red_teaming.py @@ -19,7 +19,7 @@ RedTeamingAttack, RTASystemPromptPaths, ) -from pyrit.identifiers import ScorerIdentifier +from pyrit.identifiers import ScorerIdentifier, TargetIdentifier from pyrit.models import ( AttackOutcome, AttackResult, @@ -45,11 +45,21 @@ def _mock_scorer_id(name: str = "MockScorer") -> ScorerIdentifier: ) +def _mock_target_id(name: str = "MockTarget") -> TargetIdentifier: + """Helper to create TargetIdentifier for tests.""" + return TargetIdentifier( + class_name=name, + class_module="test_module", + class_description="", + identifier_type="instance", + ) + + @pytest.fixture def mock_objective_target() -> MagicMock: target = MagicMock(spec=PromptTarget) target.send_prompt_async = AsyncMock() - target.get_identifier.return_value = {"__type__": "MockTarget", "__module__": "test_module"} + target.get_identifier.return_value = _mock_target_id("MockTarget") return target @@ -58,7 +68,7 @@ def mock_adversarial_chat() -> MagicMock: chat = MagicMock(spec=PromptChatTarget) chat.send_prompt_async = AsyncMock() chat.set_system_prompt = MagicMock() - chat.get_identifier.return_value = {"__type__": "MockChatTarget", "__module__": "test_module"} + chat.get_identifier.return_value = _mock_target_id("MockChatTarget") return chat @@ -534,10 +544,7 @@ async def test_max_turns_validation_with_prepended_conversation( mock_chat_objective_target = MagicMock(spec=PromptChatTarget) mock_chat_objective_target.send_prompt_async = AsyncMock() mock_chat_objective_target.set_system_prompt = MagicMock() - mock_chat_objective_target.get_identifier.return_value = { - "__type__": "MockChatTarget", - "__module__": "test_module", - } + mock_chat_objective_target.get_identifier.return_value = _mock_target_id("MockChatTarget") adversarial_config = AttackAdversarialConfig(target=mock_adversarial_chat) scoring_config = AttackScoringConfig(objective_scorer=mock_objective_scorer) diff --git a/tests/unit/executor/attack/multi_turn/test_tree_of_attacks.py b/tests/unit/executor/attack/multi_turn/test_tree_of_attacks.py index f6f99e173..141598b8e 100644 --- a/tests/unit/executor/attack/multi_turn/test_tree_of_attacks.py +++ b/tests/unit/executor/attack/multi_turn/test_tree_of_attacks.py @@ -26,7 +26,7 @@ TAPAttackScoringConfig, _TreeOfAttacksNode, ) -from pyrit.identifiers import ScorerIdentifier +from pyrit.identifiers import ScorerIdentifier, TargetIdentifier from pyrit.models import ( AttackOutcome, ConversationReference, @@ -223,7 +223,12 @@ def build(self) -> TreeOfAttacksWithPruningAttack: def _create_mock_target() -> PromptTarget: target = MagicMock(spec=PromptTarget) target.send_prompt_async = AsyncMock(return_value=None) - target.get_identifier.return_value = {"__type__": "MockTarget", "__module__": "test_module"} + target.get_identifier.return_value = TargetIdentifier( + class_name="MockTarget", + class_module="test_module", + class_description="", + identifier_type="instance", + ) return cast(PromptTarget, target) @staticmethod @@ -231,7 +236,12 @@ def _create_mock_chat() -> PromptChatTarget: chat = MagicMock(spec=PromptChatTarget) chat.send_prompt_async = AsyncMock(return_value=None) chat.set_system_prompt = MagicMock() - chat.get_identifier.return_value = {"__type__": "MockChatTarget", "__module__": "test_module"} + chat.get_identifier.return_value = TargetIdentifier( + class_name="MockChatTarget", + class_module="test_module", + class_description="", + identifier_type="instance", + ) return cast(PromptChatTarget, chat) @staticmethod diff --git a/tests/unit/executor/attack/single_turn/test_context_compliance.py b/tests/unit/executor/attack/single_turn/test_context_compliance.py index 7871e4483..3908cb89f 100644 --- a/tests/unit/executor/attack/single_turn/test_context_compliance.py +++ b/tests/unit/executor/attack/single_turn/test_context_compliance.py @@ -15,6 +15,7 @@ ContextComplianceAttack, SingleTurnAttackContext, ) +from pyrit.identifiers import TargetIdentifier from pyrit.models import ( Message, MessagePiece, @@ -26,12 +27,22 @@ from pyrit.score import TrueFalseScorer +def _mock_target_id(name: str = "MockTarget") -> TargetIdentifier: + """Helper to create TargetIdentifier for tests.""" + return TargetIdentifier( + class_name=name, + class_module="test_module", + class_description="", + identifier_type="instance", + ) + + @pytest.fixture def mock_objective_target(): """Create a mock PromptChatTarget for testing""" target = MagicMock(spec=PromptChatTarget) target.send_prompt_async = AsyncMock() - target.get_identifier.return_value = {"id": "mock_target_id"} + target.get_identifier.return_value = _mock_target_id("MockTarget") return target @@ -40,7 +51,7 @@ def mock_adversarial_chat(): """Create a mock adversarial chat target for testing""" target = MagicMock(spec=PromptChatTarget) target.send_prompt_async = AsyncMock() - target.get_identifier.return_value = {"id": "mock_adversarial_id"} + target.get_identifier.return_value = _mock_target_id("MockAdversarialTarget") return target diff --git a/tests/unit/executor/attack/single_turn/test_flip_attack.py b/tests/unit/executor/attack/single_turn/test_flip_attack.py index 51982a86e..faffd7709 100644 --- a/tests/unit/executor/attack/single_turn/test_flip_attack.py +++ b/tests/unit/executor/attack/single_turn/test_flip_attack.py @@ -13,6 +13,7 @@ FlipAttack, SingleTurnAttackContext, ) +from pyrit.identifiers import TargetIdentifier from pyrit.models import ( AttackOutcome, AttackResult, @@ -23,12 +24,22 @@ from pyrit.score import TrueFalseScorer +def _mock_target_id(name: str = "MockTarget") -> TargetIdentifier: + """Helper to create TargetIdentifier for tests.""" + return TargetIdentifier( + class_name=name, + class_module="test_module", + class_description="", + identifier_type="instance", + ) + + @pytest.fixture def mock_objective_target(): """Create a mock PromptChatTarget for testing""" target = MagicMock(spec=PromptChatTarget) target.send_prompt_async = AsyncMock() - target.get_identifier.return_value = {"id": "mock_target_id"} + target.get_identifier.return_value = _mock_target_id("MockTarget") return target diff --git a/tests/unit/executor/attack/single_turn/test_many_shot_jailbreak.py b/tests/unit/executor/attack/single_turn/test_many_shot_jailbreak.py index 428918fa6..9f9877c32 100644 --- a/tests/unit/executor/attack/single_turn/test_many_shot_jailbreak.py +++ b/tests/unit/executor/attack/single_turn/test_many_shot_jailbreak.py @@ -13,6 +13,7 @@ ManyShotJailbreakAttack, SingleTurnAttackContext, ) +from pyrit.identifiers import TargetIdentifier from pyrit.models import ( AttackOutcome, AttackResult, @@ -24,12 +25,22 @@ from pyrit.score import TrueFalseScorer +def _mock_target_id(name: str = "MockTarget") -> TargetIdentifier: + """Helper to create TargetIdentifier for tests.""" + return TargetIdentifier( + class_name=name, + class_module="test_module", + class_description="", + identifier_type="instance", + ) + + @pytest.fixture def mock_objective_target(): """Create a mock PromptTarget for testing""" target = MagicMock(spec=PromptTarget) target.send_prompt_async = AsyncMock() - target.get_identifier.return_value = {"id": "mock_target_id"} + target.get_identifier.return_value = _mock_target_id("MockTarget") return target diff --git a/tests/unit/executor/attack/single_turn/test_prompt_sending.py b/tests/unit/executor/attack/single_turn/test_prompt_sending.py index 6553fd8a4..b55da7cf3 100644 --- a/tests/unit/executor/attack/single_turn/test_prompt_sending.py +++ b/tests/unit/executor/attack/single_turn/test_prompt_sending.py @@ -5,6 +5,7 @@ from unittest.mock import AsyncMock, MagicMock, patch import pytest +from unit.mocks import get_mock_scorer_identifier, get_mock_target_identifier from pyrit.executor.attack import ( AttackConverterConfig, @@ -27,7 +28,6 @@ from pyrit.prompt_normalizer import PromptConverterConfiguration, PromptNormalizer from pyrit.prompt_target import PromptTarget from pyrit.score import Scorer, TrueFalseScorer -from tests.unit.mocks import get_mock_scorer_identifier @pytest.fixture @@ -35,7 +35,7 @@ def mock_target(): """Create a mock prompt target for testing""" target = MagicMock(spec=PromptTarget) target.send_prompt_async = AsyncMock() - target.get_identifier.return_value = {"id": "mock_target_id"} + target.get_identifier.return_value = get_mock_target_identifier("MockTarget") return target diff --git a/tests/unit/executor/attack/single_turn/test_role_play.py b/tests/unit/executor/attack/single_turn/test_role_play.py index 79b04853f..98218e3c9 100644 --- a/tests/unit/executor/attack/single_turn/test_role_play.py +++ b/tests/unit/executor/attack/single_turn/test_role_play.py @@ -8,6 +8,7 @@ import pytest import yaml +from unit.mocks import get_mock_scorer_identifier, get_mock_target_identifier from pyrit.executor.attack import ( AttackConverterConfig, @@ -25,7 +26,6 @@ from pyrit.prompt_normalizer import PromptConverterConfiguration from pyrit.prompt_target import PromptChatTarget from pyrit.score import Scorer, TrueFalseScorer -from tests.unit.mocks import get_mock_scorer_identifier @pytest.fixture @@ -33,7 +33,7 @@ def mock_objective_target(): """Create a mock prompt target for testing""" target = MagicMock(spec=PromptChatTarget) target.send_prompt_async = AsyncMock() - target.get_identifier.return_value = {"id": "mock_target_id"} + target.get_identifier.return_value = get_mock_target_identifier("MockTarget") return target @@ -42,7 +42,7 @@ def mock_adversarial_chat_target(): """Create a mock adversarial chat target for testing""" target = MagicMock(spec=PromptChatTarget) target.send_prompt_async = AsyncMock() - target.get_identifier.return_value = {"id": "mock_adversarial_chat_id"} + target.get_identifier.return_value = get_mock_target_identifier("MockAdversarialChat") return target diff --git a/tests/unit/executor/attack/single_turn/test_skeleton_key.py b/tests/unit/executor/attack/single_turn/test_skeleton_key.py index 963645e43..41278e9bd 100644 --- a/tests/unit/executor/attack/single_turn/test_skeleton_key.py +++ b/tests/unit/executor/attack/single_turn/test_skeleton_key.py @@ -6,6 +6,7 @@ from unittest.mock import AsyncMock, MagicMock, patch import pytest +from unit.mocks import get_mock_scorer_identifier, get_mock_target_identifier from pyrit.executor.attack import ( AttackConverterConfig, @@ -24,7 +25,6 @@ from pyrit.prompt_normalizer import PromptNormalizer from pyrit.prompt_target import PromptTarget from pyrit.score import TrueFalseScorer -from tests.unit.mocks import get_mock_scorer_identifier @pytest.fixture @@ -32,7 +32,7 @@ def mock_target(): """Create a mock prompt target for testing""" target = MagicMock(spec=PromptTarget) target.send_prompt_async = AsyncMock() - target.get_identifier.return_value = {"id": "mock_target_id"} + target.get_identifier.return_value = get_mock_target_identifier("MockTarget") return target diff --git a/tests/unit/executor/attack/test_attack_parameter_consistency.py b/tests/unit/executor/attack/test_attack_parameter_consistency.py index b0d545076..deaa91cb1 100644 --- a/tests/unit/executor/attack/test_attack_parameter_consistency.py +++ b/tests/unit/executor/attack/test_attack_parameter_consistency.py @@ -23,7 +23,7 @@ TreeOfAttacksWithPruningAttack, ) from pyrit.executor.attack.multi_turn.tree_of_attacks import TAPAttackScoringConfig -from pyrit.identifiers import ScorerIdentifier +from pyrit.identifiers import ScorerIdentifier, TargetIdentifier from pyrit.memory import CentralMemory from pyrit.models import ( ChatMessageRole, @@ -47,6 +47,16 @@ def _mock_scorer_id(name: str = "MockScorer") -> ScorerIdentifier: ) +def _mock_target_id(name: str = "MockTarget") -> TargetIdentifier: + """Helper to create TargetIdentifier for tests.""" + return TargetIdentifier( + class_name=name, + class_module="test_module", + class_description="", + identifier_type="instance", + ) + + # ============================================================================= # Multi-Modal Message Fixtures # ============================================================================= @@ -137,7 +147,7 @@ def mock_chat_target() -> MagicMock: target = MagicMock(spec=PromptChatTarget) target.send_prompt_async = AsyncMock() target.set_system_prompt = MagicMock() - target.get_identifier.return_value = {"__type__": "MockChatTarget", "__module__": "test_module"} + target.get_identifier.return_value = _mock_target_id("MockChatTarget") return target @@ -146,7 +156,7 @@ def mock_non_chat_target() -> MagicMock: """Create a mock PromptTarget (non-chat) with common setup.""" target = MagicMock(spec=PromptTarget) target.send_prompt_async = AsyncMock() - target.get_identifier.return_value = {"__type__": "MockTarget", "__module__": "test_module"} + target.get_identifier.return_value = _mock_target_id("MockTarget") return target @@ -156,7 +166,7 @@ def mock_adversarial_chat() -> MagicMock: target = MagicMock(spec=PromptChatTarget) target.send_prompt_async = AsyncMock() target.set_system_prompt = MagicMock() - target.get_identifier.return_value = {"__type__": "MockAdversarialChat", "__module__": "test_module"} + target.get_identifier.return_value = _mock_target_id("MockAdversarialChat") return target diff --git a/tests/unit/executor/attack/test_error_skip_scoring.py b/tests/unit/executor/attack/test_error_skip_scoring.py index 15825baeb..781406b46 100644 --- a/tests/unit/executor/attack/test_error_skip_scoring.py +++ b/tests/unit/executor/attack/test_error_skip_scoring.py @@ -20,18 +20,39 @@ ) from pyrit.executor.attack.core import AttackAdversarialConfig, AttackScoringConfig from pyrit.executor.attack.multi_turn.tree_of_attacks import TAPAttackScoringConfig +from pyrit.identifiers import ScorerIdentifier, TargetIdentifier from pyrit.models import Message, MessagePiece, SeedGroup, SeedPrompt from pyrit.prompt_target import PromptTarget from pyrit.prompt_target.common.prompt_chat_target import PromptChatTarget from pyrit.score import FloatScaleThresholdScorer, TrueFalseScorer +def _mock_target_id(name: str = "MockTarget") -> TargetIdentifier: + """Helper to create TargetIdentifier for tests.""" + return TargetIdentifier( + class_name=name, + class_module="test_module", + class_description="", + identifier_type="instance", + ) + + +def _mock_scorer_id(name: str = "MockScorer") -> ScorerIdentifier: + """Helper to create ScorerIdentifier for tests.""" + return ScorerIdentifier( + class_name=name, + class_module="test_module", + class_description="", + identifier_type="instance", + ) + + @pytest.fixture def mock_target(): """Create a mock prompt target for testing""" target = MagicMock(spec=PromptTarget) target.send_prompt_async = AsyncMock() - target.get_identifier.return_value = {"id": "mock_target_id"} + target.get_identifier.return_value = _mock_target_id("MockTarget") return target @@ -40,7 +61,7 @@ def mock_scorer(): """Create a mock scorer for testing""" scorer = MagicMock(spec=TrueFalseScorer) scorer.score_async = AsyncMock() - scorer.get_identifier.return_value = {"id": "mock_scorer_id"} + scorer.get_identifier.return_value = _mock_scorer_id("MockScorer") return scorer @@ -150,7 +171,7 @@ async def test_attack_executor_skips_scoring_on_error( if attack_class == TreeOfAttacksWithPruningAttack: tap_scorer = MagicMock(spec=FloatScaleThresholdScorer) tap_scorer.score_async = AsyncMock() - tap_scorer.get_identifier.return_value = {"id": "mock_tap_scorer_id"} + tap_scorer.get_identifier.return_value = _mock_scorer_id("MockTapScorer") tap_scorer.threshold = 0.7 attack_scoring_config = TAPAttackScoringConfig( objective_scorer=tap_scorer, @@ -171,7 +192,7 @@ async def test_attack_executor_skips_scoring_on_error( adversarial_target = MagicMock(spec=PromptTarget) adversarial_target.send_prompt_async = AsyncMock() - adversarial_target.get_identifier.return_value = {"id": "adversarial_target_id"} + adversarial_target.get_identifier.return_value = _mock_target_id("AdversarialTarget") attack_adversarial_config = AttackAdversarialConfig( target=adversarial_target, @@ -182,7 +203,7 @@ async def test_attack_executor_skips_scoring_on_error( if attack_class == CrescendoAttack: refusal_scorer = MagicMock(spec=TrueFalseScorer) refusal_scorer.score_async = AsyncMock(return_value=[]) - refusal_scorer.get_identifier.return_value = {"id": "refusal_scorer_id"} + refusal_scorer.get_identifier.return_value = _mock_scorer_id("RefusalScorer") attack_scoring_config.refusal_scorer = refusal_scorer # Create attack with proper configuration diff --git a/tests/unit/executor/promptgen/fuzzer/test_fuzzer.py b/tests/unit/executor/promptgen/fuzzer/test_fuzzer.py index 12028448d..7b78eed2d 100644 --- a/tests/unit/executor/promptgen/fuzzer/test_fuzzer.py +++ b/tests/unit/executor/promptgen/fuzzer/test_fuzzer.py @@ -5,7 +5,7 @@ from unittest.mock import MagicMock, patch import pytest -from unit.mocks import MockPromptTarget +from unit.mocks import MockPromptTarget, get_mock_scorer_identifier from pyrit.common.path import DATASETS_PATH from pyrit.datasets import TextJailBreak @@ -29,7 +29,6 @@ Scorer, TrueFalseScorer, ) -from tests.unit.mocks import get_mock_scorer_identifier @pytest.mark.usefixtures("patch_central_database") diff --git a/tests/unit/identifiers/test_target_identifier.py b/tests/unit/identifiers/test_target_identifier.py new file mode 100644 index 000000000..0541b36be --- /dev/null +++ b/tests/unit/identifiers/test_target_identifier.py @@ -0,0 +1,549 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Tests for TargetIdentifier-specific functionality. + +Note: Base Identifier functionality (hash computation, to_dict/from_dict basics, +frozen/hashable properties) is tested via ScorerIdentifier in test_scorer_identifier.py. +These tests focus on target-specific fields and behaviors. +""" + +import pytest + +from pyrit.identifiers import TargetIdentifier + + +class TestTargetIdentifierBasic: + """Test basic TargetIdentifier functionality.""" + + def test_target_identifier_creation_minimal(self): + """Test creating a TargetIdentifier with only required fields.""" + identifier = TargetIdentifier( + class_name="TestTarget", + class_module="pyrit.prompt_target.test_target", + class_description="A test target", + identifier_type="instance", + ) + + assert identifier.class_name == "TestTarget" + assert identifier.class_module == "pyrit.prompt_target.test_target" + assert identifier.endpoint == "" + assert identifier.model_name == "" + assert identifier.temperature is None + assert identifier.top_p is None + assert identifier.max_requests_per_minute is None + assert identifier.target_specific_params is None + assert identifier.hash is not None + assert len(identifier.hash) == 64 # SHA256 hex digest length + + def test_target_identifier_unique_name_minimal(self): + """Test that unique_name is auto-computed with minimal fields.""" + identifier = TargetIdentifier( + class_name="OpenAIChatTarget", + class_module="pyrit.prompt_target.openai.openai_chat_target", + class_description="OpenAI chat target", + identifier_type="instance", + ) + + # unique_name format: {snake_case_class_name}::{hash[:8]} + assert identifier.unique_name.startswith("open_ai_chat_target::") + assert len(identifier.unique_name.split("::")[1]) == 8 + + def test_target_identifier_creation_all_fields(self): + """Test creating a TargetIdentifier with all fields.""" + identifier = TargetIdentifier( + class_name="OpenAIChatTarget", + class_module="pyrit.prompt_target.openai.openai_chat_target", + class_description="OpenAI chat target", + identifier_type="instance", + endpoint="https://api.openai.com/v1", + model_name="gpt-4o", + temperature=0.7, + top_p=0.9, + max_requests_per_minute=100, + target_specific_params={"max_tokens": 1000, "headers": {}}, + ) + + assert identifier.endpoint == "https://api.openai.com/v1" + assert identifier.model_name == "gpt-4o" + assert identifier.temperature == 0.7 + assert identifier.top_p == 0.9 + assert identifier.max_requests_per_minute == 100 + assert identifier.target_specific_params["max_tokens"] == 1000 + + +class TestTargetIdentifierSpecificFields: + """Test TargetIdentifier-specific fields: endpoint, model_name, temperature, top_p, target_specific_params.""" + + def test_endpoint_stored_correctly(self): + """Test that endpoint is stored correctly.""" + identifier = TargetIdentifier( + class_name="TestTarget", + class_module="pyrit.prompt_target.test_target", + class_description="A test target", + identifier_type="instance", + endpoint="https://example.com/api", + ) + + assert identifier.endpoint == "https://example.com/api" + + def test_model_name_stored_correctly(self): + """Test that model_name is stored correctly.""" + identifier = TargetIdentifier( + class_name="TestTarget", + class_module="pyrit.prompt_target.test_target", + class_description="A test target", + identifier_type="instance", + model_name="gpt-4o-mini", + ) + + assert identifier.model_name == "gpt-4o-mini" + + def test_temperature_and_top_p_stored_correctly(self): + """Test that temperature and top_p are stored correctly.""" + identifier = TargetIdentifier( + class_name="TestTarget", + class_module="pyrit.prompt_target.test_target", + class_description="A test target", + identifier_type="instance", + temperature=0.5, + top_p=0.95, + ) + + assert identifier.temperature == 0.5 + assert identifier.top_p == 0.95 + + def test_target_specific_params_stored_correctly(self): + """Test that target_specific_params are stored correctly.""" + identifier = TargetIdentifier( + class_name="TestTarget", + class_module="pyrit.prompt_target.test_target", + class_description="A test target", + identifier_type="instance", + target_specific_params={ + "max_tokens": 2000, + "frequency_penalty": 0.5, + "presence_penalty": 0.3, + }, + ) + + assert identifier.target_specific_params["max_tokens"] == 2000 + assert identifier.target_specific_params["frequency_penalty"] == 0.5 + assert identifier.target_specific_params["presence_penalty"] == 0.3 + + def test_max_requests_per_minute_stored_correctly(self): + """Test that max_requests_per_minute is stored correctly.""" + identifier = TargetIdentifier( + class_name="TestTarget", + class_module="pyrit.prompt_target.test_target", + class_description="A test target", + identifier_type="instance", + max_requests_per_minute=60, + ) + + assert identifier.max_requests_per_minute == 60 + + +class TestTargetIdentifierHash: + """Test hash computation for TargetIdentifier.""" + + def test_hash_deterministic(self): + """Test that hash is the same for identical configurations.""" + identifier1 = TargetIdentifier( + class_name="TestTarget", + class_module="pyrit.prompt_target.test_target", + class_description="A test target", + identifier_type="instance", + endpoint="https://api.example.com", + model_name="test-model", + ) + identifier2 = TargetIdentifier( + class_name="TestTarget", + class_module="pyrit.prompt_target.test_target", + class_description="A test target", + identifier_type="instance", + endpoint="https://api.example.com", + model_name="test-model", + ) + + assert identifier1.hash == identifier2.hash + assert len(identifier1.hash) == 64 # SHA256 hex digest length + + def test_hash_different_for_different_endpoints(self): + """Test that different endpoints produce different hashes.""" + base_args = { + "class_name": "TestTarget", + "class_module": "pyrit.prompt_target.test_target", + "class_description": "A test target", + "identifier_type": "instance", + } + + identifier1 = TargetIdentifier(endpoint="https://api1.example.com", **base_args) + identifier2 = TargetIdentifier(endpoint="https://api2.example.com", **base_args) + + assert identifier1.hash != identifier2.hash + + def test_hash_different_for_different_model_names(self): + """Test that different model names produce different hashes.""" + base_args = { + "class_name": "TestTarget", + "class_module": "pyrit.prompt_target.test_target", + "class_description": "A test target", + "identifier_type": "instance", + } + + identifier1 = TargetIdentifier(model_name="gpt-4o", **base_args) + identifier2 = TargetIdentifier(model_name="gpt-4o-mini", **base_args) + + assert identifier1.hash != identifier2.hash + + def test_hash_different_for_different_temperature(self): + """Test that different temperature values produce different hashes.""" + base_args = { + "class_name": "TestTarget", + "class_module": "pyrit.prompt_target.test_target", + "class_description": "A test target", + "identifier_type": "instance", + } + + identifier1 = TargetIdentifier(temperature=0.7, **base_args) + identifier2 = TargetIdentifier(temperature=0.9, **base_args) + + assert identifier1.hash != identifier2.hash + + def test_hash_different_for_different_top_p(self): + """Test that different top_p values produce different hashes.""" + base_args = { + "class_name": "TestTarget", + "class_module": "pyrit.prompt_target.test_target", + "class_description": "A test target", + "identifier_type": "instance", + } + + identifier1 = TargetIdentifier(top_p=0.9, **base_args) + identifier2 = TargetIdentifier(top_p=0.95, **base_args) + + assert identifier1.hash != identifier2.hash + + def test_hash_different_for_different_target_specific_params(self): + """Test that different target_specific_params produce different hashes.""" + base_args = { + "class_name": "TestTarget", + "class_module": "pyrit.prompt_target.test_target", + "class_description": "A test target", + "identifier_type": "instance", + } + + identifier1 = TargetIdentifier(target_specific_params={"max_tokens": 100}, **base_args) + identifier2 = TargetIdentifier(target_specific_params={"max_tokens": 200}, **base_args) + + assert identifier1.hash != identifier2.hash + + +class TestTargetIdentifierToDict: + """Test to_dict method for TargetIdentifier.""" + + def test_to_dict_basic(self): + """Test basic to_dict output.""" + identifier = TargetIdentifier( + class_name="TestTarget", + class_module="pyrit.prompt_target.test_target", + class_description="A test target", + identifier_type="instance", + ) + + result = identifier.to_dict() + + assert result["class_name"] == "TestTarget" + assert result["class_module"] == "pyrit.prompt_target.test_target" + assert result["hash"] == identifier.hash + # class_description, identifier_type, and unique_name are excluded from storage + assert "class_description" not in result + assert "identifier_type" not in result + assert "unique_name" not in result + + def test_to_dict_includes_endpoint_and_model_name(self): + """Test that endpoint and model_name are included in to_dict.""" + identifier = TargetIdentifier( + class_name="OpenAIChatTarget", + class_module="pyrit.prompt_target.openai.openai_chat_target", + class_description="OpenAI chat target", + identifier_type="instance", + endpoint="https://api.openai.com/v1", + model_name="gpt-4o", + ) + + result = identifier.to_dict() + + assert result["endpoint"] == "https://api.openai.com/v1" + assert result["model_name"] == "gpt-4o" + + def test_to_dict_includes_temperature_and_top_p_when_set(self): + """Test that temperature and top_p are included when set.""" + identifier = TargetIdentifier( + class_name="TestTarget", + class_module="pyrit.prompt_target.test_target", + class_description="A test target", + identifier_type="instance", + temperature=0.7, + top_p=0.9, + ) + + result = identifier.to_dict() + + assert result["temperature"] == 0.7 + assert result["top_p"] == 0.9 + + def test_to_dict_excludes_none_values(self): + """Test that None values are excluded from to_dict.""" + identifier = TargetIdentifier( + class_name="TestTarget", + class_module="pyrit.prompt_target.test_target", + class_description="A test target", + identifier_type="instance", + # temperature, top_p, max_requests_per_minute, target_specific_params are None + ) + + result = identifier.to_dict() + + assert "temperature" not in result + assert "top_p" not in result + assert "max_requests_per_minute" not in result + assert "target_specific_params" not in result + + def test_to_dict_includes_target_specific_params(self): + """Test that target_specific_params are included in to_dict.""" + identifier = TargetIdentifier( + class_name="TestTarget", + class_module="pyrit.prompt_target.test_target", + class_description="A test target", + identifier_type="instance", + target_specific_params={"max_tokens": 1000, "seed": 42}, + ) + + result = identifier.to_dict() + + assert result["target_specific_params"] == {"max_tokens": 1000, "seed": 42} + + +class TestTargetIdentifierFromDict: + """Test from_dict method for TargetIdentifier.""" + + def test_from_dict_basic(self): + """Test creating TargetIdentifier from a basic dict.""" + data = { + "class_name": "TestTarget", + "class_module": "pyrit.prompt_target.test_target", + "class_description": "A test target", + "identifier_type": "instance", + } + + identifier = TargetIdentifier.from_dict(data) + + assert identifier.class_name == "TestTarget" + # unique_name is auto-computed + assert identifier.unique_name.startswith("test_target::") + + def test_from_dict_with_all_target_fields(self): + """Test creating TargetIdentifier from dict with all target-specific fields.""" + data = { + "class_name": "OpenAIChatTarget", + "class_module": "pyrit.prompt_target.openai.openai_chat_target", + "class_description": "OpenAI chat target", + "identifier_type": "instance", + "endpoint": "https://api.openai.com/v1", + "model_name": "gpt-4o", + "temperature": 0.7, + "top_p": 0.9, + "max_requests_per_minute": 60, + "target_specific_params": {"max_tokens": 1000}, + } + + identifier = TargetIdentifier.from_dict(data) + + assert identifier.endpoint == "https://api.openai.com/v1" + assert identifier.model_name == "gpt-4o" + assert identifier.temperature == 0.7 + assert identifier.top_p == 0.9 + assert identifier.max_requests_per_minute == 60 + assert identifier.target_specific_params["max_tokens"] == 1000 + + def test_from_dict_handles_legacy_type_key(self): + """Test that from_dict handles legacy '__type__' key.""" + data = { + "__type__": "TestTarget", # Legacy key + "class_module": "pyrit.prompt_target.test_target", + "class_description": "A test target", + "identifier_type": "instance", + } + + identifier = TargetIdentifier.from_dict(data) + + assert identifier.class_name == "TestTarget" + + def test_from_dict_handles_deprecated_type_key(self): + """Test that from_dict handles deprecated 'type' key with warning.""" + data = { + "type": "TestTarget", # Deprecated key + "class_module": "pyrit.prompt_target.test_target", + "class_description": "A test target", + "identifier_type": "instance", + } + + with pytest.warns(DeprecationWarning, match="'type' key in Identifier dict is deprecated"): + identifier = TargetIdentifier.from_dict(data) + + assert identifier.class_name == "TestTarget" + + def test_from_dict_ignores_unknown_fields(self): + """Test that from_dict ignores fields not in the dataclass.""" + data = { + "class_name": "TestTarget", + "class_module": "pyrit.prompt_target.test_target", + "class_description": "A test target", + "identifier_type": "instance", + "unknown_field": "should be ignored", + "hash": "abc123stored_hash_preserved", + "unique_name": "stored_name_ignored_because_recomputed", + } + + identifier = TargetIdentifier.from_dict(data) + + assert identifier.class_name == "TestTarget" + # hash is preserved from dict (not recomputed) to handle truncated fields + assert identifier.hash == "abc123stored_hash_preserved" + # unique_name is recomputed from hash + assert identifier.unique_name == f"test_target::{identifier.hash[:8]}" + + def test_from_dict_roundtrip(self): + """Test that to_dict -> from_dict roundtrip works.""" + original = TargetIdentifier( + class_name="OpenAIChatTarget", + class_module="pyrit.prompt_target.openai.openai_chat_target", + class_description="OpenAI chat target", + identifier_type="instance", + endpoint="https://api.openai.com/v1", + model_name="gpt-4o", + temperature=0.7, + target_specific_params={"max_tokens": 500}, + ) + + storage_dict = original.to_dict() + # Add back the excluded fields for reconstruction + storage_dict["class_description"] = "OpenAI chat target" + storage_dict["identifier_type"] = "instance" + + reconstructed = TargetIdentifier.from_dict(storage_dict) + + assert reconstructed.class_name == original.class_name + assert reconstructed.endpoint == original.endpoint + assert reconstructed.model_name == original.model_name + assert reconstructed.temperature == original.temperature + assert reconstructed.target_specific_params == original.target_specific_params + # Hash should match since config is the same + assert reconstructed.hash == original.hash + + def test_from_dict_provides_defaults_for_missing_fields(self): + """Test that from_dict provides defaults for missing optional fields.""" + data = { + "class_name": "LegacyTarget", + "class_module": "pyrit.prompt_target.legacy", + "class_description": "A legacy target", + "identifier_type": "instance", + # Missing endpoint, model_name, temperature, top_p, etc. + } + + identifier = TargetIdentifier.from_dict(data) + + assert identifier.endpoint == "" + assert identifier.model_name == "" + assert identifier.temperature is None + assert identifier.top_p is None + assert identifier.max_requests_per_minute is None + assert identifier.target_specific_params is None + + +class TestTargetIdentifierFrozen: + """Test that TargetIdentifier is immutable (frozen).""" + + def test_cannot_modify_fields(self): + """Test that attempting to modify fields raises an error.""" + identifier = TargetIdentifier( + class_name="TestTarget", + class_module="pyrit.prompt_target.test_target", + class_description="A test target", + identifier_type="instance", + endpoint="https://api.example.com", + ) + + with pytest.raises(AttributeError): + identifier.class_name = "ModifiedTarget" + + with pytest.raises(AttributeError): + identifier.endpoint = "https://modified.example.com" + + with pytest.raises(AttributeError): + identifier.temperature = 0.5 + + def test_can_use_as_dict_key(self): + """Test that frozen identifier can be used as dict key (hashable).""" + identifier = TargetIdentifier( + class_name="TestTarget", + class_module="pyrit.prompt_target.test_target", + class_description="A test target", + identifier_type="instance", + ) + + # Should be hashable and usable as dict key + d = {identifier: "value"} + assert d[identifier] == "value" + + +class TestTargetIdentifierNormalize: + """Test the normalize class method for TargetIdentifier.""" + + def test_normalize_returns_target_identifier_unchanged(self): + """Test that normalize returns a TargetIdentifier unchanged.""" + original = TargetIdentifier( + class_name="TestTarget", + class_module="pyrit.prompt_target.test_target", + class_description="A test target", + identifier_type="instance", + endpoint="https://api.example.com", + ) + + result = TargetIdentifier.normalize(original) + + assert result is original + assert result.endpoint == "https://api.example.com" + + def test_normalize_converts_dict_to_target_identifier(self): + """Test that normalize converts a dict to TargetIdentifier with deprecation warning.""" + data = { + "class_name": "TestTarget", + "class_module": "pyrit.prompt_target.test_target", + "class_description": "A test target", + "identifier_type": "instance", + "endpoint": "https://api.example.com", + "model_name": "gpt-4o", + } + + with pytest.warns(DeprecationWarning, match="dict for TargetIdentifier is deprecated"): + result = TargetIdentifier.normalize(data) + + assert isinstance(result, TargetIdentifier) + assert result.class_name == "TestTarget" + assert result.endpoint == "https://api.example.com" + assert result.model_name == "gpt-4o" + + def test_normalize_raises_for_invalid_type(self): + """Test that normalize raises TypeError for invalid input types.""" + with pytest.raises(TypeError, match="Expected TargetIdentifier or dict"): + TargetIdentifier.normalize("invalid") + + with pytest.raises(TypeError, match="Expected TargetIdentifier or dict"): + TargetIdentifier.normalize(123) + + with pytest.raises(TypeError, match="Expected TargetIdentifier or dict"): + TargetIdentifier.normalize(["list", "of", "values"]) diff --git a/tests/unit/memory/memory_interface/test_interface_scenario_results.py b/tests/unit/memory/memory_interface/test_interface_scenario_results.py index 171c47eba..4fd86a413 100644 --- a/tests/unit/memory/memory_interface/test_interface_scenario_results.py +++ b/tests/unit/memory/memory_interface/test_interface_scenario_results.py @@ -5,6 +5,7 @@ from typing import Optional import pytest +from unit.mocks import get_mock_scorer_identifier from pyrit.identifiers import ScorerIdentifier from pyrit.memory import MemoryInterface @@ -14,7 +15,6 @@ ScenarioIdentifier, ScenarioResult, ) -from tests.unit.mocks import get_mock_scorer_identifier @pytest.fixture @@ -294,7 +294,7 @@ def test_preserves_metadata(sqlite_instance: MemoryInterface): assert retrieved.scenario_identifier.description == "A test scenario with metadata" assert retrieved.scenario_identifier.version == 3 assert retrieved.scenario_identifier.init_data == {"param1": "value1", "param2": 42} - assert retrieved.objective_target_identifier == {"target": "test_target", "endpoint": "https://example.com"} + assert retrieved.objective_target_identifier.endpoint == "https://example.com" # objective_scorer_identifier is now a ScorerIdentifier, check its properties assert retrieved.objective_scorer_identifier.class_name == "TestScorer" assert retrieved.objective_scorer_identifier.class_module == "test.module" @@ -630,4 +630,4 @@ def test_combined_filters(sqlite_instance: MemoryInterface): ) assert len(results) == 1 assert results[0].scenario_identifier.pyrit_version == "0.5.0" - assert "gpt-4" in results[0].objective_target_identifier["model_name"] + assert "gpt-4" in results[0].objective_target_identifier.model_name diff --git a/tests/unit/memory/test_azure_sql_memory.py b/tests/unit/memory/test_azure_sql_memory.py index 0460febd6..a9e4db4f9 100644 --- a/tests/unit/memory/test_azure_sql_memory.py +++ b/tests/unit/memory/test_azure_sql_memory.py @@ -207,7 +207,7 @@ def test_get_memories_with_json_properties(memory_interface: AzureSQLMemory): assert converter_identifiers[0].class_name == "Base64Converter" prompt_target = retrieved_entry.prompt_target_identifier - assert prompt_target["__type__"] == "TextTarget" + assert prompt_target.class_name == "TextTarget" labels = retrieved_entry.labels assert labels["normalizer_id"] == "id1" diff --git a/tests/unit/memory/test_sqlite_memory.py b/tests/unit/memory/test_sqlite_memory.py index f52f04198..eb345bb2f 100644 --- a/tests/unit/memory/test_sqlite_memory.py +++ b/tests/unit/memory/test_sqlite_memory.py @@ -376,7 +376,7 @@ def test_get_memories_with_json_properties(sqlite_instance): assert converter_identifiers[0].class_name == "Base64Converter" prompt_target = retrieved_entry.prompt_target_identifier - assert prompt_target["__type__"] == "TextTarget" + assert prompt_target.class_name == "TextTarget" labels = retrieved_entry.labels assert labels["normalizer_id"] == "id1" diff --git a/tests/unit/mocks.py b/tests/unit/mocks.py index 928bc0fba..ee1fdd6be 100644 --- a/tests/unit/mocks.py +++ b/tests/unit/mocks.py @@ -9,7 +9,7 @@ from typing import Generator, MutableSequence, Optional, Sequence from unittest.mock import MagicMock, patch -from pyrit.identifiers import ScorerIdentifier +from pyrit.identifiers import ScorerIdentifier, TargetIdentifier from pyrit.memory import AzureSQLMemory, CentralMemory, PromptMemoryEntry from pyrit.models import Message, MessagePiece from pyrit.prompt_target import PromptChatTarget, limit_requests_per_minute @@ -28,6 +28,26 @@ def get_mock_scorer_identifier() -> ScorerIdentifier: ) +def get_mock_target_identifier(name: str = "MockTarget", module: str = "tests.unit.mocks") -> TargetIdentifier: + """ + Returns a mock TargetIdentifier for use in tests where the specific + target identity doesn't matter. + + Args: + name: The class name for the mock target. Defaults to "MockTarget". + module: The module path for the mock target. Defaults to "tests.unit.mocks". + + Returns: + A TargetIdentifier configured with the provided name and module. + """ + return TargetIdentifier( + class_name=name, + class_module=module, + class_description="Mock target for testing", + identifier_type="instance", + ) + + class MockHttpPostAsync(AbstractAsyncContextManager): def __init__(self, url, headers=None, json=None, params=None, ssl=None): self.status = 200 diff --git a/tests/unit/models/test_message_piece.py b/tests/unit/models/test_message_piece.py index fced766ab..78ecb6008 100644 --- a/tests/unit/models/test_message_piece.py +++ b/tests/unit/models/test_message_piece.py @@ -78,8 +78,8 @@ def test_prompt_targets_serialize(patch_central_database): prompt_target_identifier=target.get_identifier(), ) assert patch_central_database.called - assert entry.prompt_target_identifier["__type__"] == "MockPromptTarget" - assert entry.prompt_target_identifier["__module__"] == "unit.mocks" + assert entry.prompt_target_identifier.class_name == "MockPromptTarget" + assert entry.prompt_target_identifier.class_module == "unit.mocks" def test_executors_serialize(): @@ -745,7 +745,7 @@ def test_message_piece_to_dict(): assert result["targeted_harm_categories"] == entry.targeted_harm_categories assert result["prompt_metadata"] == entry.prompt_metadata assert result["converter_identifiers"] == [conv.to_dict() for conv in entry.converter_identifiers] - assert result["prompt_target_identifier"] == entry.prompt_target_identifier + assert result["prompt_target_identifier"] == entry.prompt_target_identifier.to_dict() assert result["attack_identifier"] == entry.attack_identifier assert result["scorer_identifier"] == entry.scorer_identifier.to_dict() assert result["original_value_data_type"] == entry.original_value_data_type @@ -781,7 +781,7 @@ def test_message_piece_scorer_identifier_dict_backward_compatibility(): # Check that a deprecation warning was issued assert len(w) == 1 assert "deprecated" in str(w[0].message).lower() - assert "0.13.0" in str(w[0].message) + assert "0.14.0" in str(w[0].message) # Check that scorer_identifier is now a ScorerIdentifier assert isinstance(entry.scorer_identifier, ScorerIdentifier) diff --git a/tests/unit/prompt_normalizer/test_prompt_normalizer.py b/tests/unit/prompt_normalizer/test_prompt_normalizer.py index e70aaf646..cc185592c 100644 --- a/tests/unit/prompt_normalizer/test_prompt_normalizer.py +++ b/tests/unit/prompt_normalizer/test_prompt_normalizer.py @@ -7,7 +7,7 @@ from uuid import uuid4 import pytest -from unit.mocks import MockPromptTarget, get_image_message_piece +from unit.mocks import MockPromptTarget, get_image_message_piece, get_mock_target_identifier from pyrit.exceptions import ( ComponentRole, @@ -136,6 +136,7 @@ async def test_send_prompt_async_empty_response_exception_handled(mock_memory_in # Use MagicMock with send_prompt_async as AsyncMock to avoid coroutine warnings on other methods prompt_target = MagicMock() prompt_target.send_prompt_async = AsyncMock(side_effect=EmptyResponseException(message="Empty response")) + prompt_target.get_identifier.return_value = get_mock_target_identifier("MockTarget") normalizer = PromptNormalizer() message = Message.from_prompt(prompt=seed_group.prompts[0].value, role="user") @@ -215,8 +216,9 @@ async def test_send_prompt_async_exception(mock_memory_instance, seed_group): @pytest.mark.asyncio async def test_send_prompt_async_empty_exception(mock_memory_instance, seed_group): - prompt_target = AsyncMock() + prompt_target = MagicMock() prompt_target.send_prompt_async = AsyncMock(side_effect=Exception("")) + prompt_target.get_identifier.return_value = get_mock_target_identifier("MockTarget") normalizer = PromptNormalizer() message = Message.from_prompt(prompt=seed_group.prompts[0].value, role="user") @@ -418,6 +420,7 @@ async def test_convert_response_values_type(mock_memory_instance, response: Mess async def test_send_prompt_async_exception_conv_id(mock_memory_instance, seed_group): prompt_target = MagicMock(PromptTarget) prompt_target.send_prompt_async = AsyncMock(side_effect=Exception("Test Exception")) + prompt_target.get_identifier.return_value = get_mock_target_identifier("MockTarget") normalizer = PromptNormalizer() message = Message.from_prompt(prompt=seed_group.prompts[0].value, role="user") diff --git a/tests/unit/scenarios/test_content_harms.py b/tests/unit/scenarios/test_content_harms.py index 66598c204..bfd0fbf8c 100644 --- a/tests/unit/scenarios/test_content_harms.py +++ b/tests/unit/scenarios/test_content_harms.py @@ -9,7 +9,7 @@ import pytest from pyrit.common.path import DATASETS_PATH -from pyrit.identifiers import ScorerIdentifier +from pyrit.identifiers import ScorerIdentifier, TargetIdentifier from pyrit.models import SeedAttackGroup, SeedObjective, SeedPrompt from pyrit.prompt_target import PromptTarget from pyrit.prompt_target.common.prompt_chat_target import PromptChatTarget @@ -34,11 +34,21 @@ def _mock_scorer_id(name: str = "MockObjectiveScorer") -> ScorerIdentifier: ) +def _mock_target_id(name: str = "MockTarget") -> TargetIdentifier: + """Helper to create TargetIdentifier for tests.""" + return TargetIdentifier( + class_name=name, + class_module="test", + class_description="", + identifier_type="instance", + ) + + @pytest.fixture def mock_objective_target(): """Create a mock objective target for testing.""" mock = MagicMock(spec=PromptTarget) - mock.get_identifier.return_value = {"__type__": "MockObjectiveTarget", "__module__": "test"} + mock.get_identifier.return_value = _mock_target_id("MockObjectiveTarget") return mock @@ -46,7 +56,7 @@ def mock_objective_target(): def mock_adversarial_target(): """Create a mock adversarial target for testing.""" mock = MagicMock(spec=PromptChatTarget) - mock.get_identifier.return_value = {"__type__": "MockAdversarialTarget", "__module__": "test"} + mock.get_identifier.return_value = _mock_target_id("MockAdversarialTarget") return mock diff --git a/tests/unit/scenarios/test_cyber.py b/tests/unit/scenarios/test_cyber.py index d84785190..37eed916d 100644 --- a/tests/unit/scenarios/test_cyber.py +++ b/tests/unit/scenarios/test_cyber.py @@ -12,7 +12,7 @@ from pyrit.common.path import DATASETS_PATH from pyrit.executor.attack import PromptSendingAttack, RedTeamingAttack from pyrit.executor.attack.core.attack_config import AttackScoringConfig -from pyrit.identifiers import ScorerIdentifier +from pyrit.identifiers import ScorerIdentifier, TargetIdentifier from pyrit.models import SeedAttackGroup, SeedDataset, SeedObjective from pyrit.prompt_target import OpenAIChatTarget, PromptChatTarget, PromptTarget from pyrit.scenario import DatasetConfiguration @@ -30,6 +30,16 @@ def _mock_scorer_id(name: str = "MockObjectiveScorer") -> ScorerIdentifier: ) +def _mock_target_id(name: str = "MockTarget") -> TargetIdentifier: + """Helper to create TargetIdentifier for tests.""" + return TargetIdentifier( + class_name=name, + class_module="test", + class_description="", + identifier_type="instance", + ) + + @pytest.fixture def mock_memory_seed_groups(): """Create mock seed groups that _get_default_seed_groups() would return.""" @@ -86,7 +96,7 @@ def mock_runtime_env(): def mock_objective_target(): """Create a mock objective target for testing.""" mock = MagicMock(spec=PromptTarget) - mock.get_identifier.return_value = {"__type__": "MockObjectiveTarget", "__module__": "test"} + mock.get_identifier.return_value = _mock_target_id("MockObjectiveTarget") return mock @@ -102,7 +112,7 @@ def mock_objective_scorer(): def mock_adversarial_target(): """Create a mock adversarial target for testing.""" mock = MagicMock(spec=PromptChatTarget) - mock.get_identifier.return_value = {"__type__": "MockAdversarialTarget", "__module__": "test"} + mock.get_identifier.return_value = _mock_target_id("MockAdversarialTarget") return mock @@ -169,7 +179,7 @@ def test_init_default_adversarial_chat(self, mock_objective_scorer, mock_memory_ def test_init_with_adversarial_chat(self, mock_objective_scorer, mock_memory_seed_groups): """Test initialization with adversarial chat (for red teaming attack variation).""" adversarial_chat = MagicMock(OpenAIChatTarget) - adversarial_chat.get_identifier.return_value = {"type": "CustomAdversary"} + adversarial_chat.get_identifier.return_value = _mock_target_id("CustomAdversary") with patch.object(Cyber, "_resolve_seed_groups", return_value=mock_memory_seed_groups): scenario = Cyber( diff --git a/tests/unit/scenarios/test_encoding.py b/tests/unit/scenarios/test_encoding.py index 70c4274fb..699304aa1 100644 --- a/tests/unit/scenarios/test_encoding.py +++ b/tests/unit/scenarios/test_encoding.py @@ -8,7 +8,7 @@ import pytest from pyrit.executor.attack import PromptSendingAttack -from pyrit.identifiers import ScorerIdentifier +from pyrit.identifiers import ScorerIdentifier, TargetIdentifier from pyrit.models import SeedAttackGroup, SeedObjective, SeedPrompt from pyrit.prompt_converter import Base64Converter from pyrit.prompt_target import PromptTarget @@ -28,6 +28,16 @@ def _mock_scorer_id(name: str = "MockObjectiveScorer") -> ScorerIdentifier: ) +def _mock_target_id(name: str = "MockTarget") -> TargetIdentifier: + """Helper to create TargetIdentifier for tests.""" + return TargetIdentifier( + class_name=name, + class_module="test", + class_description="", + identifier_type="instance", + ) + + @pytest.fixture def mock_memory_seeds(): """Create mock seed prompts that memory.get_seeds() would return.""" @@ -67,7 +77,7 @@ def mock_dataset_config(mock_seed_attack_groups): def mock_objective_target(): """Create a mock objective target for testing.""" mock = MagicMock(spec=PromptTarget) - mock.get_identifier.return_value = {"__type__": "MockObjectiveTarget", "__module__": "test"} + mock.get_identifier.return_value = _mock_target_id("MockObjectiveTarget") return mock diff --git a/tests/unit/scenarios/test_foundry.py b/tests/unit/scenarios/test_foundry.py index aef8fde09..59bcca446 100644 --- a/tests/unit/scenarios/test_foundry.py +++ b/tests/unit/scenarios/test_foundry.py @@ -10,7 +10,7 @@ from pyrit.executor.attack.core.attack_config import AttackScoringConfig from pyrit.executor.attack.multi_turn.crescendo import CrescendoAttack from pyrit.executor.attack.single_turn.prompt_sending import PromptSendingAttack -from pyrit.identifiers import ScorerIdentifier +from pyrit.identifiers import ScorerIdentifier, TargetIdentifier from pyrit.models import SeedAttackGroup, SeedObjective from pyrit.prompt_converter import Base64Converter from pyrit.prompt_target import PromptTarget @@ -30,6 +30,16 @@ def _mock_scorer_id(name: str = "MockObjectiveScorer") -> ScorerIdentifier: ) +def _mock_target_id(name: str = "MockTarget") -> TargetIdentifier: + """Helper to create TargetIdentifier for tests.""" + return TargetIdentifier( + class_name=name, + class_module="test", + class_description="", + identifier_type="instance", + ) + + @pytest.fixture def mock_memory_seed_groups(): """Create mock seed groups that _get_default_seed_groups() would return.""" @@ -56,7 +66,7 @@ def mock_dataset_config(mock_memory_seed_groups): def mock_objective_target(): """Create a mock objective target for testing.""" mock = MagicMock(spec=PromptTarget) - mock.get_identifier.return_value = {"__type__": "MockObjectiveTarget", "__module__": "test"} + mock.get_identifier.return_value = _mock_target_id("MockObjectiveTarget") return mock @@ -64,7 +74,7 @@ def mock_objective_target(): def mock_adversarial_target(): """Create a mock adversarial target for testing.""" mock = MagicMock(spec=PromptChatTarget) - mock.get_identifier.return_value = {"__type__": "MockAdversarialTarget", "__module__": "test"} + mock.get_identifier.return_value = _mock_target_id("MockAdversarialTarget") return mock diff --git a/tests/unit/scenarios/test_leakage_scenario.py b/tests/unit/scenarios/test_leakage_scenario.py index 1d795a7e2..29ec0b0bf 100644 --- a/tests/unit/scenarios/test_leakage_scenario.py +++ b/tests/unit/scenarios/test_leakage_scenario.py @@ -12,7 +12,7 @@ from pyrit.common.path import DATASETS_PATH from pyrit.executor.attack import CrescendoAttack, PromptSendingAttack, RolePlayAttack from pyrit.executor.attack.core.attack_config import AttackScoringConfig -from pyrit.identifiers import ScorerIdentifier +from pyrit.identifiers import ScorerIdentifier, TargetIdentifier from pyrit.models import SeedAttackGroup, SeedDataset, SeedObjective from pyrit.prompt_target import OpenAIChatTarget, PromptChatTarget, PromptTarget from pyrit.scenario import DatasetConfiguration @@ -30,6 +30,16 @@ def _mock_scorer_id(name: str = "MockObjectiveScorer") -> ScorerIdentifier: ) +def _mock_target_id(name: str = "MockTarget") -> TargetIdentifier: + """Helper to create TargetIdentifier for tests.""" + return TargetIdentifier( + class_name=name, + class_module="test", + class_description="", + identifier_type="instance", + ) + + @pytest.fixture def mock_memory_seeds(): leakage_path = pathlib.Path(DATASETS_PATH) / "seed_datasets" / "local" / "airt" @@ -95,7 +105,7 @@ def mock_runtime_env(): @pytest.fixture def mock_objective_target(): mock = MagicMock(spec=PromptTarget) - mock.get_identifier.return_value = {"__type__": "MockObjectiveTarget", "__module__": "test"} + mock.get_identifier.return_value = _mock_target_id("MockObjectiveTarget") return mock @@ -109,7 +119,7 @@ def mock_objective_scorer(): @pytest.fixture def mock_adversarial_target(): mock = MagicMock(spec=PromptChatTarget) - mock.get_identifier.return_value = {"__type__": "MockAdversarialTarget", "__module__": "test"} + mock.get_identifier.return_value = _mock_target_id("MockAdversarialTarget") return mock @@ -182,7 +192,7 @@ def test_init_default_adversarial_chat(self, mock_objective_scorer, mock_memory_ def test_init_with_adversarial_chat(self, mock_objective_scorer, mock_memory_seeds): """Test initialization with adversarial chat (for multi-turn attack variations).""" adversarial_chat = MagicMock(OpenAIChatTarget) - adversarial_chat.get_identifier.return_value = {"type": "CustomAdversary"} + adversarial_chat.get_identifier.return_value = _mock_target_id("CustomAdversary") with patch.object( LeakageScenario, "_get_default_objectives", return_value=[seed.value for seed in mock_memory_seeds] diff --git a/tests/unit/scenarios/test_scam.py b/tests/unit/scenarios/test_scam.py index f8fcd7975..a4573f551 100644 --- a/tests/unit/scenarios/test_scam.py +++ b/tests/unit/scenarios/test_scam.py @@ -16,7 +16,7 @@ RolePlayAttack, ) from pyrit.executor.attack.core.attack_config import AttackScoringConfig -from pyrit.identifiers import ScorerIdentifier +from pyrit.identifiers import ScorerIdentifier, TargetIdentifier from pyrit.models import SeedAttackGroup, SeedDataset, SeedGroup, SeedObjective from pyrit.prompt_target import OpenAIChatTarget, PromptChatTarget, PromptTarget from pyrit.scenario import DatasetConfiguration @@ -37,6 +37,16 @@ def _mock_scorer_id(name: str = "MockObjectiveScorer") -> ScorerIdentifier: ) +def _mock_target_id(name: str = "MockTarget") -> TargetIdentifier: + """Helper to create TargetIdentifier for tests.""" + return TargetIdentifier( + class_name=name, + class_module="test", + class_description="", + identifier_type="instance", + ) + + @pytest.fixture def mock_memory_seed_groups() -> List[SeedGroup]: """Create mock seed groups that _get_default_seed_groups() would return.""" @@ -94,7 +104,7 @@ def mock_runtime_env(): @pytest.fixture def mock_objective_target() -> PromptTarget: mock = MagicMock(spec=PromptTarget) - mock.get_identifier.return_value = {"__type__": "MockObjectiveTarget", "__module__": "test"} + mock.get_identifier.return_value = _mock_target_id("MockObjectiveTarget") return mock @@ -108,7 +118,7 @@ def mock_objective_scorer() -> TrueFalseCompositeScorer: @pytest.fixture def mock_adversarial_target() -> PromptChatTarget: mock = MagicMock(spec=PromptChatTarget) - mock.get_identifier.return_value = {"__type__": "MockAdversarialTarget", "__module__": "test"} + mock.get_identifier.return_value = _mock_target_id("MockAdversarialTarget") return mock @@ -181,7 +191,7 @@ def test_init_with_adversarial_chat( self, *, mock_objective_scorer: TrueFalseCompositeScorer, mock_memory_seed_groups: List[SeedGroup] ) -> None: adversarial_chat = MagicMock(OpenAIChatTarget) - adversarial_chat.get_identifier.return_value = {"type": "CustomAdversary"} + adversarial_chat.get_identifier.return_value = _mock_target_id("CustomAdversary") with patch.object(Scam, "_resolve_seed_groups", return_value=mock_memory_seed_groups): scenario = Scam( diff --git a/tests/unit/scenarios/test_scenario.py b/tests/unit/scenarios/test_scenario.py index 9b13d8730..f7e0e6fe4 100644 --- a/tests/unit/scenarios/test_scenario.py +++ b/tests/unit/scenarios/test_scenario.py @@ -8,7 +8,7 @@ import pytest from pyrit.executor.attack.core import AttackExecutorResult -from pyrit.identifiers import ScorerIdentifier +from pyrit.identifiers import ScorerIdentifier, TargetIdentifier from pyrit.memory import CentralMemory from pyrit.models import AttackOutcome, AttackResult from pyrit.scenario import DatasetConfiguration, ScenarioIdentifier, ScenarioResult @@ -71,10 +71,12 @@ def mock_atomic_attacks(): def mock_objective_target(): """Create a mock objective target for testing.""" target = MagicMock() - target.get_identifier.return_value = { - "__type__": "MockTarget", - "__module__": "test", - } + target.get_identifier.return_value = TargetIdentifier( + class_name="MockTarget", + class_module="test", + class_description="", + identifier_type="instance", + ) return target @@ -230,10 +232,9 @@ async def test_initialize_async_sets_objective_target(self, mock_objective_targe await scenario.initialize_async(objective_target=mock_objective_target) assert scenario._objective_target == mock_objective_target - assert scenario._objective_target_identifier == { - "__type__": "MockTarget", - "__module__": "test", - } + # Verify it's a TargetIdentifier with the expected class_name + assert scenario._objective_target_identifier.class_name == "MockTarget" + assert scenario._objective_target_identifier.class_module == "test" @pytest.mark.asyncio async def test_initialize_async_requires_objective_target(self): diff --git a/tests/unit/scenarios/test_scenario_partial_results.py b/tests/unit/scenarios/test_scenario_partial_results.py index 5c0c67934..1886c3639 100644 --- a/tests/unit/scenarios/test_scenario_partial_results.py +++ b/tests/unit/scenarios/test_scenario_partial_results.py @@ -8,7 +8,7 @@ import pytest from pyrit.executor.attack.core import AttackExecutorResult -from pyrit.identifiers import ScorerIdentifier +from pyrit.identifiers import ScorerIdentifier, TargetIdentifier from pyrit.memory import CentralMemory from pyrit.models import AttackOutcome, AttackResult from pyrit.scenario import DatasetConfiguration, ScenarioResult @@ -29,7 +29,12 @@ def _mock_scorer_id(name: str = "MockScorer") -> ScorerIdentifier: def mock_objective_target(): """Create a mock objective target for testing.""" target = MagicMock() - target.get_identifier.return_value = {"__type__": "MockTarget", "__module__": "test"} + target.get_identifier.return_value = TargetIdentifier( + class_name="MockTarget", + class_module="test", + class_description="", + identifier_type="instance", + ) return target diff --git a/tests/unit/scenarios/test_scenario_retry.py b/tests/unit/scenarios/test_scenario_retry.py index 8ffc8c6de..69bcecef0 100644 --- a/tests/unit/scenarios/test_scenario_retry.py +++ b/tests/unit/scenarios/test_scenario_retry.py @@ -8,7 +8,7 @@ import pytest from pyrit.executor.attack.core import AttackExecutorResult -from pyrit.identifiers import ScorerIdentifier +from pyrit.identifiers import ScorerIdentifier, TargetIdentifier from pyrit.memory import CentralMemory from pyrit.models import AttackOutcome, AttackResult from pyrit.scenario import DatasetConfiguration, ScenarioResult @@ -195,7 +195,12 @@ def mock_atomic_attacks(): def mock_objective_target(): """Create a mock objective target for testing.""" target = MagicMock() - target.get_identifier.return_value = {"__type__": "MockTarget", "__module__": TEST_MODULE} + target.get_identifier.return_value = TargetIdentifier( + class_name="MockTarget", + class_module=TEST_MODULE, + class_description="", + identifier_type="instance", + ) return target diff --git a/tests/unit/score/test_gandalf_scorer.py b/tests/unit/score/test_gandalf_scorer.py index dd4e62115..abad3a306 100644 --- a/tests/unit/score/test_gandalf_scorer.py +++ b/tests/unit/score/test_gandalf_scorer.py @@ -6,6 +6,7 @@ from unittest.mock import AsyncMock, MagicMock, patch import pytest +from unit.mocks import get_mock_target_identifier from pyrit.exceptions.exception_classes import PyritException from pyrit.memory.memory_interface import MemoryInterface @@ -52,6 +53,7 @@ async def test_gandalf_scorer_score( mocked_post, sqlite_instance: MemoryInterface, level: GandalfLevel, password_correct: bool ): chat_target = MagicMock() + chat_target.get_identifier.return_value = get_mock_target_identifier("MockChatTarget") conversation_id = str(uuid.uuid4()) sqlite_instance.add_message_to_memory(request=generate_request(conversation_id=conversation_id)) @@ -95,6 +97,7 @@ async def test_gandalf_scorer_set_system_prompt( sqlite_instance.add_message_to_memory(request=response) chat_target = MagicMock() + chat_target.get_identifier.return_value = get_mock_target_identifier("MockChatTarget") chat_target.send_prompt_async = AsyncMock(return_value=[response]) scorer = GandalfScorer(chat_target=chat_target, level=level) @@ -119,6 +122,7 @@ async def test_gandalf_scorer_adds_to_memory(mocked_post, level: GandalfLevel, s sqlite_instance.add_message_to_memory(request=response) chat_target = MagicMock() + chat_target.get_identifier.return_value = get_mock_target_identifier("MockChatTarget") chat_target.send_prompt_async = AsyncMock(return_value=[response]) # Mock the requests.post call to return a successful response @@ -139,6 +143,7 @@ async def test_gandalf_scorer_runtime_error_retries(level: GandalfLevel, sqlite_ sqlite_instance.add_message_to_memory(request=response) chat_target = MagicMock() + chat_target.get_identifier.return_value = get_mock_target_identifier("MockChatTarget") chat_target.send_prompt_async = AsyncMock(side_effect=[RuntimeError("Error"), response]) scorer = GandalfScorer(level=level, chat_target=chat_target) diff --git a/tests/unit/score/test_general_float_scale_scorer.py b/tests/unit/score/test_general_float_scale_scorer.py index 0933646d7..7ee85404d 100644 --- a/tests/unit/score/test_general_float_scale_scorer.py +++ b/tests/unit/score/test_general_float_scale_scorer.py @@ -5,6 +5,7 @@ from unittest.mock import AsyncMock, MagicMock import pytest +from unit.mocks import get_mock_target_identifier from pyrit.models import Message, MessagePiece from pyrit.score.float_scale.self_ask_general_float_scale_scorer import ( @@ -31,6 +32,7 @@ def general_float_scorer_response() -> Message: @pytest.mark.asyncio async def test_general_float_scorer_score_async(patch_central_database, general_float_scorer_response: Message): chat_target = MagicMock() + chat_target.get_identifier.return_value = get_mock_target_identifier("MockChatTarget") chat_target.send_prompt_async = AsyncMock(return_value=[general_float_scorer_response]) scorer = SelfAskGeneralFloatScaleScorer( @@ -54,6 +56,7 @@ async def test_general_float_scorer_score_async_with_prompt_f_string( general_float_scorer_response: Message, patch_central_database ): chat_target = MagicMock() + chat_target.get_identifier.return_value = get_mock_target_identifier("MockChatTarget") chat_target.send_prompt_async = AsyncMock(return_value=[general_float_scorer_response]) scorer = SelfAskGeneralFloatScaleScorer( @@ -77,6 +80,7 @@ async def test_general_float_scorer_score_async_with_prompt_f_string( @pytest.mark.asyncio async def test_general_float_scorer_score_async_handles_custom_keys(patch_central_database): chat_target = MagicMock() + chat_target.get_identifier.return_value = get_mock_target_identifier("MockChatTarget") assert chat_target json_response = ( @@ -114,6 +118,7 @@ async def test_general_float_scorer_score_async_handles_custom_keys(patch_centra @pytest.mark.asyncio async def test_general_float_scorer_score_async_min_max_scale(patch_central_database): chat_target = MagicMock() + chat_target.get_identifier.return_value = get_mock_target_identifier("MockChatTarget") json_response = ( dedent( """ @@ -145,6 +150,7 @@ async def test_general_float_scorer_score_async_min_max_scale(patch_central_data def test_general_float_scorer_init_invalid_min_max(): chat_target = MagicMock() + chat_target.get_identifier.return_value = get_mock_target_identifier("MockChatTarget") with pytest.raises(ValueError): SelfAskGeneralFloatScaleScorer( chat_target=chat_target, diff --git a/tests/unit/score/test_general_true_false_scorer.py b/tests/unit/score/test_general_true_false_scorer.py index 85deb92e5..49e4b9839 100644 --- a/tests/unit/score/test_general_true_false_scorer.py +++ b/tests/unit/score/test_general_true_false_scorer.py @@ -5,6 +5,7 @@ from unittest.mock import AsyncMock, MagicMock import pytest +from unit.mocks import get_mock_target_identifier from pyrit.models import Message, MessagePiece from pyrit.score import SelfAskGeneralTrueFalseScorer @@ -30,6 +31,7 @@ def general_scorer_response() -> Message: @pytest.mark.asyncio async def test_general_scorer_score_async(patch_central_database, general_scorer_response: Message): chat_target = MagicMock() + chat_target.get_identifier.return_value = get_mock_target_identifier("MockChatTarget") chat_target.send_prompt_async = AsyncMock(return_value=[general_scorer_response]) scorer = SelfAskGeneralTrueFalseScorer( @@ -54,6 +56,7 @@ async def test_general_scorer_score_async_with_prompt_f_string( general_scorer_response: Message, patch_central_database ): chat_target = MagicMock() + chat_target.get_identifier.return_value = get_mock_target_identifier("MockChatTarget") chat_target.send_prompt_async = AsyncMock(return_value=[general_scorer_response]) scorer = SelfAskGeneralTrueFalseScorer( @@ -79,6 +82,7 @@ async def test_general_scorer_score_async_with_prompt_f_string( @pytest.mark.asyncio async def test_general_scorer_score_async_handles_custom_keys(patch_central_database): chat_target = MagicMock() + chat_target.get_identifier.return_value = get_mock_target_identifier("MockChatTarget") assert chat_target json_response = ( diff --git a/tests/unit/score/test_scorer.py b/tests/unit/score/test_scorer.py index 5cbb53940..8fc2f5545 100644 --- a/tests/unit/score/test_scorer.py +++ b/tests/unit/score/test_scorer.py @@ -7,6 +7,7 @@ from unittest.mock import AsyncMock, MagicMock, patch import pytest +from unit.mocks import get_mock_target_identifier from pyrit.exceptions import InvalidJsonException, remove_markdown_json from pyrit.identifiers import ScorerIdentifier @@ -151,6 +152,7 @@ def get_scorer_metrics(self): @pytest.mark.parametrize("bad_json", [BAD_JSON, KEY_ERROR_JSON, KEY_ERROR2_JSON]) async def test_scorer_send_chat_target_async_bad_json_exception_retries(bad_json: str): chat_target = MagicMock(PromptChatTarget) + chat_target.get_identifier.return_value = get_mock_target_identifier("MockChatTarget") bad_json_resp = Message( message_pieces=[MessagePiece(role="assistant", original_value=bad_json, conversation_id="test-convo")] ) @@ -174,6 +176,7 @@ async def test_scorer_send_chat_target_async_bad_json_exception_retries(bad_json @pytest.mark.asyncio async def test_scorer_score_value_with_llm_exception_display_prompt_id(): chat_target = MagicMock(PromptChatTarget) + chat_target.get_identifier.return_value = get_mock_target_identifier("MockChatTarget") chat_target.send_prompt_async = AsyncMock(side_effect=Exception("Test exception")) scorer = MockScorer() @@ -198,6 +201,7 @@ async def test_scorer_score_value_with_llm_use_provided_attack_identifier(good_j message_pieces=[MessagePiece(role="assistant", original_value=good_json, conversation_id="test-convo")] ) chat_target = MagicMock(PromptChatTarget) + chat_target.get_identifier.return_value = get_mock_target_identifier("MockChatTarget") chat_target.send_prompt_async = AsyncMock(return_value=[message]) chat_target.set_system_prompt = MagicMock() @@ -233,6 +237,7 @@ async def test_scorer_score_value_with_llm_does_not_add_score_prompt_id_for_empt message_pieces=[MessagePiece(role="assistant", original_value=good_json, conversation_id="test-convo")] ) chat_target = MagicMock(PromptChatTarget) + chat_target.get_identifier.return_value = get_mock_target_identifier("MockChatTarget") chat_target.send_prompt_async = AsyncMock(return_value=[message]) chat_target.set_system_prompt = MagicMock() @@ -259,6 +264,7 @@ async def test_scorer_score_value_with_llm_does_not_add_score_prompt_id_for_empt @pytest.mark.asyncio async def test_scorer_send_chat_target_async_good_response(good_json): chat_target = MagicMock(PromptChatTarget) + chat_target.get_identifier.return_value = get_mock_target_identifier("MockChatTarget") good_json_resp = Message( message_pieces=[MessagePiece(role="assistant", original_value=good_json, conversation_id="test-convo")] @@ -283,6 +289,7 @@ async def test_scorer_send_chat_target_async_good_response(good_json): @pytest.mark.asyncio async def test_scorer_remove_markdown_json_called(good_json): chat_target = MagicMock(PromptChatTarget) + chat_target.get_identifier.return_value = get_mock_target_identifier("MockChatTarget") good_json_resp = Message( message_pieces=[MessagePiece(role="assistant", original_value=good_json, conversation_id="test-convo")] ) @@ -308,6 +315,7 @@ async def test_scorer_remove_markdown_json_called(good_json): async def test_score_value_with_llm_prepended_text_message_piece_creates_multipiece_message(good_json): """Test that prepended_text_message_piece creates a multi-piece message (text context + main content).""" chat_target = MagicMock(PromptChatTarget) + chat_target.get_identifier.return_value = get_mock_target_identifier("MockChatTarget") good_json_resp = Message( message_pieces=[MessagePiece(role="assistant", original_value=good_json, conversation_id="test-convo")] ) @@ -351,6 +359,7 @@ async def test_score_value_with_llm_prepended_text_message_piece_creates_multipi async def test_score_value_with_llm_no_prepended_text_creates_single_piece_message(good_json): """Test that without prepended_text_message_piece, only a single piece message is created.""" chat_target = MagicMock(PromptChatTarget) + chat_target.get_identifier.return_value = get_mock_target_identifier("MockChatTarget") good_json_resp = Message( message_pieces=[MessagePiece(role="assistant", original_value=good_json, conversation_id="test-convo")] ) @@ -386,6 +395,7 @@ async def test_score_value_with_llm_no_prepended_text_creates_single_piece_messa async def test_score_value_with_llm_prepended_text_works_with_audio(good_json): """Test that prepended_text_message_piece works with audio content (type-independent).""" chat_target = MagicMock(PromptChatTarget) + chat_target.get_identifier.return_value = get_mock_target_identifier("MockChatTarget") good_json_resp = Message( message_pieces=[MessagePiece(role="assistant", original_value=good_json, conversation_id="test-convo")] ) diff --git a/tests/unit/score/test_self_ask_category.py b/tests/unit/score/test_self_ask_category.py index 68038ffac..390788d1e 100644 --- a/tests/unit/score/test_self_ask_category.py +++ b/tests/unit/score/test_self_ask_category.py @@ -5,6 +5,7 @@ from unittest.mock import AsyncMock, MagicMock, patch import pytest +from unit.mocks import get_mock_target_identifier from pyrit.exceptions.exception_classes import InvalidJsonException from pyrit.memory import CentralMemory @@ -50,6 +51,7 @@ def scorer_category_response_false() -> Message: def test_category_scorer_set_no_category_found(): chat_target = MagicMock() + chat_target.get_identifier.return_value = get_mock_target_identifier("MockChatTarget") scorer = SelfAskCategoryScorer( chat_target=chat_target, content_classifier_path=ContentClassifierPaths.HARMFUL_CONTENT_CLASSIFIER.value, @@ -63,6 +65,7 @@ def test_category_scorer_set_no_category_found(): @pytest.mark.asyncio async def test_category_scorer_set_system_prompt(scorer_category_response_bullying: Message, patch_central_database): chat_target = MagicMock() + chat_target.get_identifier.return_value = get_mock_target_identifier("MockChatTarget") chat_target.send_prompt_async = AsyncMock(return_value=[scorer_category_response_bullying]) scorer = SelfAskCategoryScorer( @@ -78,6 +81,7 @@ async def test_category_scorer_set_system_prompt(scorer_category_response_bullyi @pytest.mark.asyncio async def test_category_scorer_score(scorer_category_response_bullying: Message, patch_central_database): chat_target = MagicMock() + chat_target.get_identifier.return_value = get_mock_target_identifier("MockChatTarget") chat_target.send_prompt_async = AsyncMock(return_value=[scorer_category_response_bullying]) @@ -100,6 +104,7 @@ async def test_category_scorer_score(scorer_category_response_bullying: Message, @pytest.mark.asyncio async def test_category_scorer_score_false(scorer_category_response_false: Message, patch_central_database): chat_target = MagicMock() + chat_target.get_identifier.return_value = get_mock_target_identifier("MockChatTarget") chat_target.send_prompt_async = AsyncMock(return_value=[scorer_category_response_false]) @@ -122,6 +127,7 @@ async def test_category_scorer_score_false(scorer_category_response_false: Messa async def test_category_scorer_adds_to_memory(scorer_category_response_false: Message, patch_central_database): memory = MagicMock(MemoryInterface) chat_target = MagicMock() + chat_target.get_identifier.return_value = get_mock_target_identifier("MockChatTarget") chat_target.send_prompt_async = AsyncMock(return_value=[scorer_category_response_false]) with patch.object(CentralMemory, "get_memory_instance", return_value=memory): scorer = SelfAskCategoryScorer( @@ -137,6 +143,7 @@ async def test_category_scorer_adds_to_memory(scorer_category_response_false: Me @pytest.mark.asyncio async def test_self_ask_objective_scorer_bad_json_exception_retries(patch_central_database): chat_target = MagicMock() + chat_target.get_identifier.return_value = get_mock_target_identifier("MockChatTarget") bad_json_resp = Message(message_pieces=[MessagePiece(role="assistant", original_value="this is not a json")]) chat_target.send_prompt_async = AsyncMock(return_value=[bad_json_resp]) @@ -155,6 +162,7 @@ async def test_self_ask_objective_scorer_bad_json_exception_retries(patch_centra @pytest.mark.asyncio async def test_self_ask_objective_scorer_json_missing_key_exception_retries(patch_central_database): chat_target = MagicMock() + chat_target.get_identifier.return_value = get_mock_target_identifier("MockChatTarget") json_response = ( dedent( @@ -192,6 +200,7 @@ async def test_score_prompts_batch_async( patch_central_database, ): chat_target = MagicMock() + chat_target.get_identifier.return_value = get_mock_target_identifier("MockChatTarget") chat_target.send_prompt_async = AsyncMock() chat_target._max_requests_per_minute = max_requests_per_minute with patch.object(CentralMemory, "get_memory_instance", return_value=MagicMock()): diff --git a/tests/unit/score/test_self_ask_likert.py b/tests/unit/score/test_self_ask_likert.py index e194f5c2b..6e0db8188 100644 --- a/tests/unit/score/test_self_ask_likert.py +++ b/tests/unit/score/test_self_ask_likert.py @@ -5,6 +5,7 @@ from unittest.mock import AsyncMock, MagicMock, patch import pytest +from unit.mocks import get_mock_target_identifier from pyrit.exceptions.exception_classes import InvalidJsonException from pyrit.memory import CentralMemory, MemoryInterface @@ -36,6 +37,7 @@ async def test_likert_scorer_set_system_prompt(scorer_likert_response: Message): memory = MagicMock(MemoryInterface) with patch.object(CentralMemory, "get_memory_instance", return_value=memory): chat_target = MagicMock() + chat_target.get_identifier.return_value = get_mock_target_identifier("MockChatTarget") chat_target.send_prompt_async = AsyncMock(return_value=[scorer_likert_response]) scorer = SelfAskLikertScorer(chat_target=chat_target, likert_scale=LikertScalePaths.CYBER_SCALE) @@ -58,6 +60,7 @@ async def test_likert_scorer_set_system_prompt(scorer_likert_response: Message): async def test_likert_scorer_adds_to_memory(scorer_likert_response: Message): memory = MagicMock(MemoryInterface) chat_target = MagicMock() + chat_target.get_identifier.return_value = get_mock_target_identifier("MockChatTarget") chat_target.send_prompt_async = AsyncMock(return_value=[scorer_likert_response]) with patch.object(CentralMemory, "get_memory_instance", return_value=memory): scorer = SelfAskLikertScorer(chat_target=chat_target, likert_scale=LikertScalePaths.CYBER_SCALE) @@ -70,6 +73,7 @@ async def test_likert_scorer_adds_to_memory(scorer_likert_response: Message): @pytest.mark.asyncio async def test_likert_scorer_score(patch_central_database, scorer_likert_response: Message): chat_target = MagicMock() + chat_target.get_identifier.return_value = get_mock_target_identifier("MockChatTarget") chat_target.send_prompt_async = AsyncMock(return_value=[scorer_likert_response]) @@ -91,6 +95,7 @@ async def test_likert_scorer_score(patch_central_database, scorer_likert_respons @pytest.mark.asyncio async def test_self_ask_scorer_bad_json_exception_retries(): chat_target = MagicMock() + chat_target.get_identifier.return_value = get_mock_target_identifier("MockChatTarget") bad_json_resp = Message(message_pieces=[MessagePiece(role="assistant", original_value="this is not a json")]) chat_target.send_prompt_async = AsyncMock(return_value=[bad_json_resp]) @@ -105,6 +110,7 @@ async def test_self_ask_scorer_bad_json_exception_retries(): @pytest.mark.asyncio async def test_self_ask_likert_scorer_json_missing_key_exception_retries(): chat_target = MagicMock() + chat_target.get_identifier.return_value = get_mock_target_identifier("MockChatTarget") json_response = ( dedent( diff --git a/tests/unit/score/test_self_ask_refusal.py b/tests/unit/score/test_self_ask_refusal.py index 4d99e1263..75d3fbc65 100644 --- a/tests/unit/score/test_self_ask_refusal.py +++ b/tests/unit/score/test_self_ask_refusal.py @@ -6,6 +6,7 @@ from unittest.mock import AsyncMock, MagicMock, patch import pytest +from unit.mocks import get_mock_target_identifier from pyrit.exceptions.exception_classes import InvalidJsonException from pyrit.memory import CentralMemory @@ -32,6 +33,7 @@ def scorer_true_false_response() -> Message: @pytest.mark.asyncio async def test_refusal_scorer_score(scorer_true_false_response: Message, patch_central_database): chat_target = MagicMock() + chat_target.get_identifier.return_value = get_mock_target_identifier("MockChatTarget") chat_target.send_prompt_async = AsyncMock(return_value=[scorer_true_false_response]) scorer = SelfAskRefusalScorer(chat_target=chat_target) @@ -48,6 +50,7 @@ async def test_refusal_scorer_score(scorer_true_false_response: Message, patch_c @pytest.mark.asyncio async def test_refusal_scorer_set_system_prompt(scorer_true_false_response: Message, patch_central_database): chat_target = MagicMock() + chat_target.get_identifier.return_value = get_mock_target_identifier("MockChatTarget") chat_target.send_prompt_async = AsyncMock(return_value=[scorer_true_false_response]) scorer = SelfAskRefusalScorer(chat_target=chat_target) @@ -59,6 +62,7 @@ async def test_refusal_scorer_set_system_prompt(scorer_true_false_response: Mess @pytest.mark.asyncio async def test_refusal_scorer_no_task(scorer_true_false_response: Message, patch_central_database): chat_target = MagicMock() + chat_target.get_identifier.return_value = get_mock_target_identifier("MockChatTarget") chat_target.send_prompt_async = AsyncMock(return_value=[scorer_true_false_response]) scorer = SelfAskRefusalScorer(chat_target=chat_target) @@ -71,6 +75,7 @@ async def test_refusal_scorer_no_task(scorer_true_false_response: Message, patch @pytest.mark.asyncio async def test_refusal_scorer_with_task(scorer_true_false_response: Message, patch_central_database): chat_target = MagicMock() + chat_target.get_identifier.return_value = get_mock_target_identifier("MockChatTarget") chat_target.send_prompt_async = AsyncMock(return_value=[scorer_true_false_response]) scorer = SelfAskRefusalScorer(chat_target=chat_target) @@ -85,6 +90,7 @@ async def test_refusal_scorer_with_task(scorer_true_false_response: Message, pat @pytest.mark.asyncio async def test_refusal_scorer_image_non_block(scorer_true_false_response: Message, patch_central_database): chat_target = MagicMock() + chat_target.get_identifier.return_value = get_mock_target_identifier("MockChatTarget") chat_target.send_prompt_async = AsyncMock(return_value=[scorer_true_false_response]) scorer = SelfAskRefusalScorer(chat_target=chat_target) @@ -101,6 +107,7 @@ async def test_refusal_scorer_image_non_block(scorer_true_false_response: Messag async def test_refusal_scorer_adds_to_memory(scorer_true_false_response: Message): memory = MagicMock(MemoryInterface) chat_target = MagicMock() + chat_target.get_identifier.return_value = get_mock_target_identifier("MockChatTarget") chat_target.send_prompt_async = AsyncMock(return_value=[scorer_true_false_response]) with patch.object(CentralMemory, "get_memory_instance", return_value=memory): scorer = SelfAskRefusalScorer(chat_target=chat_target) @@ -112,6 +119,7 @@ async def test_refusal_scorer_adds_to_memory(scorer_true_false_response: Message @pytest.mark.asyncio async def test_refusal_scorer_bad_json_exception_retries(patch_central_database): chat_target = MagicMock() + chat_target.get_identifier.return_value = get_mock_target_identifier("MockChatTarget") bad_json_resp = Message(message_pieces=[MessagePiece(role="assistant", original_value="this is not a json")]) chat_target.send_prompt_async = AsyncMock(return_value=[bad_json_resp]) @@ -127,6 +135,7 @@ async def test_refusal_scorer_bad_json_exception_retries(patch_central_database) @pytest.mark.asyncio async def test_self_ask_objective_scorer_bad_json_exception_retries(patch_central_database): chat_target = MagicMock() + chat_target.get_identifier.return_value = get_mock_target_identifier("MockChatTarget") json_response = ( dedent( @@ -155,6 +164,7 @@ async def test_self_ask_objective_scorer_bad_json_exception_retries(patch_centra async def test_score_async_filtered_response(patch_central_database): memory = CentralMemory.get_memory_instance() chat_target = MagicMock() + chat_target.get_identifier.return_value = get_mock_target_identifier("MockChatTarget") scorer = SelfAskRefusalScorer(chat_target=chat_target) request = MessagePiece(role="assistant", original_value="blocked response", response_error="blocked").to_message() diff --git a/tests/unit/score/test_self_ask_scale.py b/tests/unit/score/test_self_ask_scale.py index 07c485056..2a437e09c 100644 --- a/tests/unit/score/test_self_ask_scale.py +++ b/tests/unit/score/test_self_ask_scale.py @@ -7,6 +7,7 @@ from unittest.mock import AsyncMock, MagicMock import pytest +from unit.mocks import get_mock_target_identifier from pyrit.identifiers import ScorerIdentifier from pyrit.models import Message, MessagePiece, UnvalidatedScore @@ -39,8 +40,10 @@ def scorer_scale_response() -> Message: @pytest.fixture def scale_scorer(patch_central_database) -> SelfAskScaleScorer: + chat_target = MagicMock() + chat_target.get_identifier.return_value = get_mock_target_identifier("MockChatTarget") return SelfAskScaleScorer( - chat_target=MagicMock(), + chat_target=chat_target, scale_arguments_path=SelfAskScaleScorer.ScalePaths.TREE_OF_ATTACKS_SCALE.value, system_prompt_path=SelfAskScaleScorer.SystemPaths.GENERAL_SYSTEM_PROMPT.value, ) @@ -65,6 +68,7 @@ async def test_scale_scorer_set_system_prompt( patch_central_database, ): chat_target = MagicMock() + chat_target.get_identifier.return_value = get_mock_target_identifier("MockChatTarget") chat_target.send_prompt_async = AsyncMock(return_value=[scorer_scale_response]) scorer = SelfAskScaleScorer( @@ -86,6 +90,7 @@ async def test_scale_scorer_set_system_prompt( def test_scale_scorer_invalid_scale_file_contents(): chat_target = MagicMock() + chat_target.get_identifier.return_value = get_mock_target_identifier("MockChatTarget") # When using a YAML with wrong keys the Scale constructor will raise an exception. with pytest.raises(ValueError, match="Missing key in scale_args:"): SelfAskScaleScorer( @@ -135,6 +140,7 @@ def test_validate_scale_arguments_missing_args_raises_value_error(scale_args, sc @pytest.mark.asyncio async def test_scale_scorer_score(scorer_scale_response: Message, patch_central_database): chat_target = MagicMock() + chat_target.get_identifier.return_value = get_mock_target_identifier("MockChatTarget") chat_target.send_prompt_async = AsyncMock(return_value=[scorer_scale_response]) @@ -161,6 +167,7 @@ async def test_scale_scorer_score(scorer_scale_response: Message, patch_central_ @pytest.mark.asyncio async def test_scale_scorer_score_custom_scale(scorer_scale_response: Message, patch_central_database): chat_target = MagicMock() + chat_target.get_identifier.return_value = get_mock_target_identifier("MockChatTarget") # set a higher score to test the scaling scorer_scale_response.message_pieces[0].original_value = scorer_scale_response.message_pieces[ @@ -197,6 +204,7 @@ async def test_scale_scorer_score_custom_scale(scorer_scale_response: Message, p @pytest.mark.asyncio async def test_scale_scorer_score_calls_send_chat(patch_central_database): chat_target = MagicMock() + chat_target.get_identifier.return_value = get_mock_target_identifier("MockChatTarget") scorer = SelfAskScaleScorer( chat_target=chat_target, diff --git a/tests/unit/score/test_self_ask_true_false.py b/tests/unit/score/test_self_ask_true_false.py index 70747285f..80cc21ec2 100644 --- a/tests/unit/score/test_self_ask_true_false.py +++ b/tests/unit/score/test_self_ask_true_false.py @@ -5,6 +5,7 @@ from unittest.mock import AsyncMock, MagicMock, patch import pytest +from unit.mocks import get_mock_target_identifier from pyrit.exceptions.exception_classes import InvalidJsonException from pyrit.memory.central_memory import CentralMemory @@ -31,6 +32,7 @@ def scorer_true_false_response() -> Message: @pytest.mark.asyncio async def test_true_false_scorer_score(patch_central_database, scorer_true_false_response: Message): chat_target = MagicMock() + chat_target.get_identifier.return_value = get_mock_target_identifier("MockChatTarget") chat_target.send_prompt_async = AsyncMock(return_value=[scorer_true_false_response]) scorer = SelfAskTrueFalseScorer( @@ -49,6 +51,7 @@ async def test_true_false_scorer_score(patch_central_database, scorer_true_false @pytest.mark.asyncio async def test_true_false_scorer_set_system_prompt(patch_central_database, scorer_true_false_response: Message): chat_target = MagicMock() + chat_target.get_identifier.return_value = get_mock_target_identifier("MockChatTarget") chat_target.send_prompt_async = AsyncMock(return_value=[scorer_true_false_response]) scorer = SelfAskTrueFalseScorer( @@ -68,6 +71,7 @@ async def test_true_false_scorer_set_system_prompt(patch_central_database, score async def test_true_false_scorer_adds_to_memory(scorer_true_false_response: Message): memory = MagicMock(MemoryInterface) chat_target = MagicMock() + chat_target.get_identifier.return_value = get_mock_target_identifier("MockChatTarget") chat_target.send_prompt_async = AsyncMock(return_value=[scorer_true_false_response]) with patch.object(CentralMemory, "get_memory_instance", return_value=memory): scorer = SelfAskTrueFalseScorer( @@ -82,6 +86,7 @@ async def test_true_false_scorer_adds_to_memory(scorer_true_false_response: Mess @pytest.mark.asyncio async def test_self_ask_scorer_bad_json_exception_retries(patch_central_database): chat_target = MagicMock() + chat_target.get_identifier.return_value = get_mock_target_identifier("MockChatTarget") bad_json_resp = Message(message_pieces=[MessagePiece(role="assistant", original_value="this is not a json")]) chat_target.send_prompt_async = AsyncMock(return_value=[bad_json_resp]) @@ -111,6 +116,7 @@ async def test_self_ask_objective_scorer_bad_json_exception_retries(patch_centra ) bad_json_resp = Message(message_pieces=[MessagePiece(role="assistant", original_value=json_response)]) + chat_target.get_identifier.return_value = get_mock_target_identifier("MockChatTarget") chat_target.send_prompt_async = AsyncMock(return_value=[bad_json_resp]) scorer = SelfAskTrueFalseScorer( @@ -127,6 +133,7 @@ async def test_self_ask_objective_scorer_bad_json_exception_retries(patch_centra def test_self_ask_true_false_scorer_identifier_has_system_prompt_template(patch_central_database): """Test that identifier includes system_prompt_template.""" chat_target = MagicMock() + chat_target.get_identifier.return_value = get_mock_target_identifier("MockChatTarget") scorer = SelfAskTrueFalseScorer( chat_target=chat_target, true_false_question_path=TrueFalseQuestionPaths.GROUNDED.value @@ -143,6 +150,7 @@ def test_self_ask_true_false_scorer_identifier_has_system_prompt_template(patch_ def test_self_ask_true_false_get_identifier_type(patch_central_database): """Test that get_identifier returns correct class_name.""" chat_target = MagicMock() + chat_target.get_identifier.return_value = get_mock_target_identifier("MockChatTarget") scorer = SelfAskTrueFalseScorer( chat_target=chat_target, true_false_question_path=TrueFalseQuestionPaths.GROUNDED.value @@ -158,6 +166,7 @@ def test_self_ask_true_false_get_identifier_type(patch_central_database): def test_self_ask_true_false_get_identifier_long_prompt_hashed(patch_central_database): """Test that long system prompts are truncated when serialized via to_dict().""" chat_target = MagicMock() + chat_target.get_identifier.return_value = get_mock_target_identifier("MockChatTarget") scorer = SelfAskTrueFalseScorer( chat_target=chat_target, true_false_question_path=TrueFalseQuestionPaths.GROUNDED.value diff --git a/tests/unit/score/test_video_scorer.py b/tests/unit/score/test_video_scorer.py index 27de6693a..c15913128 100644 --- a/tests/unit/score/test_video_scorer.py +++ b/tests/unit/score/test_video_scorer.py @@ -8,6 +8,7 @@ import numpy as np import pytest +from unit.mocks import get_mock_scorer_identifier from pyrit.identifiers import ScorerIdentifier from pyrit.models import MessagePiece, Score @@ -16,7 +17,6 @@ from pyrit.score.scorer_prompt_validator import ScorerPromptValidator from pyrit.score.true_false.true_false_scorer import TrueFalseScorer from pyrit.score.true_false.video_true_false_scorer import VideoTrueFalseScorer -from tests.unit.mocks import get_mock_scorer_identifier def is_opencv_installed(): diff --git a/tests/unit/target/test_crucible_target.py b/tests/unit/target/test_crucible_target.py index a2e71e82e..9469d8e7f 100644 --- a/tests/unit/target/test_crucible_target.py +++ b/tests/unit/target/test_crucible_target.py @@ -20,7 +20,7 @@ def test_crucible_initializes(crucible_target: CrucibleTarget): def test_crucible_sets_endpoint_and_rate_limit(): target = CrucibleTarget(endpoint="https://crucible", api_key="abc", max_requests_per_minute=10) identifier = target.get_identifier() - assert identifier["endpoint"] == "https://crucible" + assert identifier.endpoint == "https://crucible" assert target._max_requests_per_minute == 10 diff --git a/tests/unit/target/test_gandalf_target.py b/tests/unit/target/test_gandalf_target.py index d8e3ffdb9..ffd7a8274 100644 --- a/tests/unit/target/test_gandalf_target.py +++ b/tests/unit/target/test_gandalf_target.py @@ -20,7 +20,7 @@ def test_gandalf_initializes(gandalf_target: GandalfTarget): def test_gandalf_sets_endpoint_and_rate_limit(): target = GandalfTarget(level=GandalfLevel.LEVEL_1, max_requests_per_minute=15) identifier = target.get_identifier() - assert identifier["endpoint"] == "https://gandalf-api.lakera.ai/api/send-message" + assert identifier.endpoint == "https://gandalf-api.lakera.ai/api/send-message" assert target._max_requests_per_minute == 15 diff --git a/tests/unit/target/test_http_target.py b/tests/unit/target/test_http_target.py index 6f5b5842d..5d49702b0 100644 --- a/tests/unit/target/test_http_target.py +++ b/tests/unit/target/test_http_target.py @@ -59,7 +59,7 @@ def test_http_target_sets_endpoint_and_rate_limit(mock_callback_function, sqlite max_requests_per_minute=25, ) identifier = target.get_identifier() - assert identifier["endpoint"] == "https://example.com/" + assert identifier.endpoint == "https://example.com/" assert target._max_requests_per_minute == 25 @@ -67,7 +67,7 @@ def test_http_target_sets_endpoint_and_rate_limit(mock_callback_function, sqlite @patch("httpx.AsyncClient.request") async def test_send_prompt_async(mock_request, mock_http_target, mock_http_response): message = MagicMock() - message.message_pieces = [MagicMock(converted_value="test_prompt")] + message.message_pieces = [MagicMock(converted_value="test_prompt", prompt_target_identifier=None)] mock_request.return_value = mock_http_response response = await mock_http_target.send_prompt_async(message=message) assert len(response) == 1 @@ -113,7 +113,7 @@ async def test_send_prompt_async_client_kwargs(): # Use **httpx_client_kwargs to pass them as keyword arguments http_target = HTTPTarget(http_request=sample_request, **httpx_client_kwargs) message = MagicMock() - message.message_pieces = [MagicMock(converted_value="")] + message.message_pieces = [MagicMock(converted_value="", prompt_target_identifier=None)] mock_response = MagicMock() mock_response.content = b"Response content" mock_request.return_value = mock_response @@ -148,7 +148,7 @@ async def test_send_prompt_regex_parse_async(mock_request, mock_http_target): mock_http_target.callback_function = callback_function message = MagicMock() - message.message_pieces = [MagicMock(converted_value="test_prompt")] + message.message_pieces = [MagicMock(converted_value="test_prompt", prompt_target_identifier=None)] mock_response = MagicMock() mock_response.content = b"Match: 1234" @@ -175,7 +175,7 @@ async def test_send_prompt_async_keeps_original_template(mock_request, mock_http # Send first prompt message = MagicMock() - message.message_pieces = [MagicMock(converted_value="test_prompt")] + message.message_pieces = [MagicMock(converted_value="test_prompt", prompt_target_identifier=None)] response = await mock_http_target.send_prompt_async(message=message) assert len(response) == 1 @@ -193,7 +193,7 @@ async def test_send_prompt_async_keeps_original_template(mock_request, mock_http # Send second prompt second_message = MagicMock() - second_message.message_pieces = [MagicMock(converted_value="second_test_prompt")] + second_message.message_pieces = [MagicMock(converted_value="second_test_prompt", prompt_target_identifier=None)] await mock_http_target.send_prompt_async(message=second_message) # Assert that the original template is still the same @@ -241,7 +241,7 @@ async def test_http_target_with_injected_client(): mock_request.return_value = mock_response message = MagicMock() - message.message_pieces = [MagicMock(converted_value="test_prompt")] + message.message_pieces = [MagicMock(converted_value="test_prompt", prompt_target_identifier=None)] response = await target.send_prompt_async(message=message) diff --git a/tests/unit/target/test_hugging_face_endpoint_target.py b/tests/unit/target/test_hugging_face_endpoint_target.py index 790654204..2bc4de352 100644 --- a/tests/unit/target/test_hugging_face_endpoint_target.py +++ b/tests/unit/target/test_hugging_face_endpoint_target.py @@ -29,7 +29,7 @@ def test_hugging_face_endpoint_sets_endpoint_and_rate_limit(): max_requests_per_minute=30, ) identifier = target.get_identifier() - assert identifier["endpoint"] == "https://api-inference.huggingface.co/models/test-model" + assert identifier.endpoint == "https://api-inference.huggingface.co/models/test-model" assert target._max_requests_per_minute == 30 diff --git a/tests/unit/target/test_huggingface_chat_target.py b/tests/unit/target/test_huggingface_chat_target.py index c100f8d8e..ec1593c79 100644 --- a/tests/unit/target/test_huggingface_chat_target.py +++ b/tests/unit/target/test_huggingface_chat_target.py @@ -329,7 +329,7 @@ async def test_is_json_response_supported(): @pytest.mark.skipif(not is_torch_installed(), reason="torch is not installed") @pytest.mark.asyncio -async def test_hugging_face_chat_sets_endpoint_and_rate_limit(): +async def test_hugging_face_chat_sets_endpoint_and_rate_limit(patch_central_database): target = HuggingFaceChatTarget( model_id="test_model", use_cuda=False, @@ -338,6 +338,6 @@ async def test_hugging_face_chat_sets_endpoint_and_rate_limit(): # Await the background task to prevent warnings await target.load_model_and_tokenizer_task identifier = target.get_identifier() - # HuggingFaceChatTarget doesn't set an endpoint (it's local), so it shouldn't be in identifier - assert "endpoint" not in identifier + # HuggingFaceChatTarget doesn't set an endpoint (it's local), so it should be empty + assert not identifier.endpoint assert target._max_requests_per_minute == 30 diff --git a/tests/unit/target/test_openai_chat_target.py b/tests/unit/target/test_openai_chat_target.py index 58a10cfe3..6ba29eb2e 100644 --- a/tests/unit/target/test_openai_chat_target.py +++ b/tests/unit/target/test_openai_chat_target.py @@ -1073,16 +1073,18 @@ async def test_construct_message_from_response(target: OpenAIChatTarget, dummy_t def test_get_identifier_uses_model_name_when_no_underlying_model(patch_central_database): """Test that get_identifier uses model_name when underlying_model is not provided.""" - target = OpenAIChatTarget( - model_name="my-deployment", - endpoint="https://mock.azure.com/", - api_key="mock-api-key", - ) + # Clear the environment variable to ensure it doesn't interfere with the test + with patch.dict(os.environ, {"OPENAI_CHAT_UNDERLYING_MODEL": ""}, clear=False): + target = OpenAIChatTarget( + model_name="my-deployment", + endpoint="https://mock.azure.com/", + api_key="mock-api-key", + ) - identifier = target.get_identifier() + identifier = target.get_identifier() - assert identifier["model_name"] == "my-deployment" - assert identifier["__type__"] == "OpenAIChatTarget" + assert identifier.model_name == "my-deployment" + assert identifier.class_name == "OpenAIChatTarget" def test_get_identifier_uses_underlying_model_when_provided_as_param(patch_central_database): @@ -1096,8 +1098,8 @@ def test_get_identifier_uses_underlying_model_when_provided_as_param(patch_centr identifier = target.get_identifier() - assert identifier["model_name"] == "gpt-4o" - assert identifier["__type__"] == "OpenAIChatTarget" + assert identifier.model_name == "gpt-4o" + assert identifier.class_name == "OpenAIChatTarget" def test_get_identifier_uses_underlying_model_from_env_var(patch_central_database): @@ -1111,7 +1113,7 @@ def test_get_identifier_uses_underlying_model_from_env_var(patch_central_databas identifier = target.get_identifier() - assert identifier["model_name"] == "gpt-4o" + assert identifier.model_name == "gpt-4o" def test_underlying_model_param_takes_precedence_over_env_var(patch_central_database): @@ -1126,7 +1128,7 @@ def test_underlying_model_param_takes_precedence_over_env_var(patch_central_data identifier = target.get_identifier() - assert identifier["model_name"] == "gpt-4o-from-param" + assert identifier.model_name == "gpt-4o-from-param" def test_get_identifier_includes_endpoint(patch_central_database): @@ -1139,7 +1141,7 @@ def test_get_identifier_includes_endpoint(patch_central_database): identifier = target.get_identifier() - assert identifier["endpoint"] == "https://mock.azure.com/" + assert identifier.endpoint == "https://mock.azure.com/" def test_get_identifier_includes_temperature_when_set(patch_central_database): @@ -1153,7 +1155,7 @@ def test_get_identifier_includes_temperature_when_set(patch_central_database): identifier = target.get_identifier() - assert identifier["temperature"] == 0.7 + assert identifier.temperature == 0.7 def test_get_identifier_includes_top_p_when_set(patch_central_database): @@ -1167,7 +1169,7 @@ def test_get_identifier_includes_top_p_when_set(patch_central_database): identifier = target.get_identifier() - assert identifier["top_p"] == 0.9 + assert identifier.top_p == 0.9 # ============================================================================