diff --git a/.env_example b/.env_example index 2d63d66913..281b3db223 100644 --- a/.env_example +++ b/.env_example @@ -19,27 +19,45 @@ PLATFORM_OPENAI_CHAT_GPT4O_MODEL="gpt-4o" AZURE_OPENAI_GPT4O_ENDPOINT="https://xxxx.openai.azure.com/openai/v1" AZURE_OPENAI_GPT4O_KEY="xxxxx" AZURE_OPENAI_GPT4O_MODEL="deployment-name" -# Since deployment name may be custom and differ from the actual underlying model, -# you can specify the underlying model for identifier purposes +# Since Azure deployment name may be custom and differ from the actual underlying model, +# you can specify the underlying model for identifier purposes. If not specified, +# identifiers will default to the value of the standard MODEL environment variable. AZURE_OPENAI_GPT4O_UNDERLYING_MODEL="gpt-4o" AZURE_OPENAI_INTEGRATION_TEST_ENDPOINT="https://xxxxx.openai.azure.com/openai/v1" AZURE_OPENAI_INTEGRATION_TEST_KEY="xxxxx" AZURE_OPENAI_INTEGRATION_TEST_MODEL="deployment-name" +AZURE_OPENAI_INTEGRATION_TEST_UNDERLYING_MODEL="" AZURE_OPENAI_GPT3_5_CHAT_ENDPOINT="https://xxxxx.openai.azure.com/openai/v1" AZURE_OPENAI_GPT3_5_CHAT_KEY="xxxxx" AZURE_OPENAI_GPT3_5_CHAT_MODEL="deployment-name" +AZURE_OPENAI_GPT3_5_CHAT_UNDERLYING_MODEL="" AZURE_OPENAI_GPT4_CHAT_ENDPOINT="https://xxxxx.openai.azure.com/openai/v1" AZURE_OPENAI_GPT4_CHAT_KEY="xxxxx" AZURE_OPENAI_GPT4_CHAT_MODEL="deployment-name" +AZURE_OPENAI_GPT4_CHAT_UNDERLYING_MODEL="" + +# Endpoints that host models with fewer safety mechanisms (e.g. via adversarial fine tuning +# or content filters turned off) can be defined below and used in adversarial attack testing scenarios. +AZURE_OPENAI_GPT4O_UNSAFE_CHAT_ENDPOINT="https://xxxxx.openai.azure.com/openai/v1" +AZURE_OPENAI_GPT4O_UNSAFE_CHAT_KEY="xxxxx" +AZURE_OPENAI_GPT4O_UNSAFE_CHAT_MODEL="deployment-name" +AZURE_OPENAI_GPT4O_UNSAFE_CHAT_UNDERLYING_MODEL="" + +AZURE_OPENAI_GPT4O_UNSAFE_CHAT_ENDPOINT2="https://xxxxx.openai.azure.com/openai/v1" +AZURE_OPENAI_GPT4O_UNSAFE_CHAT_KEY2="xxxxx" +AZURE_OPENAI_GPT4O_UNSAFE_CHAT_MODEL2="deployment-name" +AZURE_OPENAI_GPT4O_UNSAFE_CHAT_UNDERLYING_MODEL2="" AZURE_FOUNDRY_DEEPSEEK_ENDPOINT="https://xxxxx.eastus2.models.ai.azure.com" AZURE_FOUNDRY_DEEPSEEK_KEY="xxxxx" +AZURE_FOUNDRY_DEEPSEEK_MODEL="" AZURE_FOUNDRY_PHI4_ENDPOINT="https://xxxxx.models.ai.azure.com" AZURE_CHAT_PHI4_KEY="xxxxx" +AZURE_FOUNDRY_PHI4_MODEL="" AZURE_FOUNDRY_MISTRAL_LARGE_ENDPOINT="https://xxxxx.services.ai.azure.com/openai/v1/" AZURE_FOUNDRY_MISTRAL_LARGE_KEY="xxxxx" @@ -75,6 +93,7 @@ AZURE_OPENAI_GPT5_RESPONSES_ENDPOINT="https://xxxxxxxxx.azure.com/openai/v1" AZURE_OPENAI_GPT5_COMPLETION_ENDPOINT="https://xxxxxxxxx.azure.com/openai/v1" AZURE_OPENAI_GPT5_KEY="xxxxxxx" AZURE_OPENAI_GPT5_MODEL="gpt-5" +AZURE_OPENAI_GPT5_UNDERLYING_MODEL="gpt-5" PLATFORM_OPENAI_RESPONSES_ENDPOINT="https://api.openai.com/v1" PLATFORM_OPENAI_RESPONSES_KEY="sk-xxxxx" @@ -83,6 +102,7 @@ PLATFORM_OPENAI_RESPONSES_MODEL="o4-mini" AZURE_OPENAI_RESPONSES_ENDPOINT="https://xxxxx.openai.azure.com/openai/v1" AZURE_OPENAI_RESPONSES_KEY="xxxxx" AZURE_OPENAI_RESPONSES_MODEL="o4-mini" +AZURE_OPENAI_RESPONSES_UNDERLYING_MODEL="o4-mini" OPENAI_RESPONSES_ENDPOINT=${PLATFORM_OPENAI_RESPONSES_ENDPOINT} OPENAI_RESPONSES_KEY=${PLATFORM_OPENAI_RESPONSES_KEY} @@ -103,6 +123,7 @@ PLATFORM_OPENAI_REALTIME_MODEL="gpt-4o-realtime-preview" AZURE_OPENAI_REALTIME_ENDPOINT = "wss://xxxx.openai.azure.com/openai/v1" AZURE_OPENAI_REALTIME_API_KEY = "xxxxx" AZURE_OPENAI_REALTIME_MODEL = "gpt-4o-realtime-preview" +AZURE_OPENAI_REALTIME_UNDERLYING_MODEL = 
"gpt-4o-realtime-preview" OPENAI_REALTIME_ENDPOINT = ${PLATFORM_OPENAI_REALTIME_ENDPOINT} OPENAI_REALTIME_API_KEY = ${PLATFORM_OPENAI_REALTIME_API_KEY} @@ -119,10 +140,12 @@ OPENAI_REALTIME_UNDERLYING_MODEL = "" OPENAI_IMAGE_ENDPOINT1 = "https://xxxxx.openai.azure.com/openai/v1" OPENAI_IMAGE_API_KEY1 = "xxxxxx" OPENAI_IMAGE_MODEL1 = "deployment-name" +OPENAI_IMAGE_UNDERLYING_MODEL1 = "dall-e-3" OPENAI_IMAGE_ENDPOINT2 = "https://api.openai.com/v1" OPENAI_IMAGE_API_KEY2 = "sk-xxxxx" OPENAI_IMAGE_MODEL2 = "dall-e-3" +OPENAI_IMAGE_UNDERLYING_MODEL2 = "dall-e-3" OPENAI_IMAGE_ENDPOINT = ${OPENAI_IMAGE_ENDPOINT2} OPENAI_IMAGE_API_KEY = ${OPENAI_IMAGE_API_KEY2} @@ -140,10 +163,12 @@ OPENAI_IMAGE_UNDERLYING_MODEL = "" OPENAI_TTS_ENDPOINT1 = "https://xxxxx.openai.azure.com/openai/v1" OPENAI_TTS_KEY1 = "xxxxxxx" OPENAI_TTS_MODEL1 = "tts" +OPENAI_TTS_UNDERLYING_MODEL1 = "tts" OPENAI_TTS_ENDPOINT2 = "https://api.openai.com/v1" OPENAI_TTS_KEY2 = "xxxxxx" OPENAI_TTS_MODEL2 = "tts-1" +OPENAI_TTS_UNDERLYING_MODEL2 = "tts-1" OPENAI_TTS_ENDPOINT = ${OPENAI_TTS_ENDPOINT2} OPENAI_TTS_KEY = ${OPENAI_TTS_KEY2} @@ -161,6 +186,7 @@ OPENAI_TTS_UNDERLYING_MODEL = "" AZURE_OPENAI_VIDEO_ENDPOINT="https://xxxxx.cognitiveservices.azure.com/openai/v1" AZURE_OPENAI_VIDEO_KEY="xxxxxxx" AZURE_OPENAI_VIDEO_MODEL="sora-2" +AZURE_OPENAI_VIDEO_UNDERLYING_MODEL="sora-2" OPENAI_VIDEO_ENDPOINT = ${AZURE_OPENAI_VIDEO_ENDPOINT} OPENAI_VIDEO_KEY = ${AZURE_OPENAI_VIDEO_KEY} diff --git a/doc/api.rst b/doc/api.rst index c774deca08..99273143fb 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -703,6 +703,7 @@ API Reference PyRITInitializer AIRTInitializer + AIRTTargetInitializer SimpleInitializer LoadDefaultDatasets ScenarioObjectiveListInitializer diff --git a/doc/code/registry/2_instance_registry.ipynb b/doc/code/registry/2_instance_registry.ipynb index 24a8b1bb68..52ce374054 100644 --- a/doc/code/registry/2_instance_registry.ipynb +++ b/doc/code/registry/2_instance_registry.ipynb @@ -35,10 +35,9 @@ "name": "stdout", "output_type": "stream", "text": [ - "Found default environment files: ['C:\\\\Users\\\\rlundeen\\\\.pyrit\\\\.env', 'C:\\\\Users\\\\rlundeen\\\\.pyrit\\\\.env.local']\n", - "Loaded environment file: C:\\Users\\rlundeen\\.pyrit\\.env\n", - "Loaded environment file: C:\\Users\\rlundeen\\.pyrit\\.env.local\n", - "Registered scorers: ['self_ask_refusal_d9007ba2']\n" + "Found default environment files: ['C:\\\\Users\\\\songjustin\\\\.pyrit\\\\.env']\n", + "Loaded environment file: C:\\Users\\songjustin\\.pyrit\\.env\n", + "Registered scorers: ['self_ask_refusal_scorer::94a582f5']\n" ] } ], @@ -83,7 +82,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "Retrieved scorer: \n", + "Retrieved scorer: \n", "Scorer type: SelfAskRefusalScorer\n" ] } @@ -118,7 +117,7 @@ "output_type": "stream", "text": [ "\n", - "self_ask_refusal_d9007ba2:\n", + "self_ask_refusal_scorer::94a582f5:\n", " Class: SelfAskRefusalScorer\n", " Type: true_false\n", " Description: A self-ask scorer that detects refusal in AI responses. 
This...\n", @@ -126,7 +125,7 @@ "\u001b[1m 📊 Scorer Information\u001b[0m\n", "\u001b[37m ▸ Scorer Identifier\u001b[0m\n", "\u001b[36m • Scorer Type: SelfAskRefusalScorer\u001b[0m\n", - "\u001b[36m • Target Model: gpt-40\u001b[0m\n", + "\u001b[36m • Target Model: gpt-4o\u001b[0m\n", "\u001b[36m • Temperature: None\u001b[0m\n", "\u001b[36m • Score Aggregator: OR_\u001b[0m\n", "\n", @@ -141,12 +140,12 @@ "# Get metadata for all registered scorers\n", "metadata = registry.list_metadata()\n", "for item in metadata:\n", - " print(f\"\\n{item.name}:\")\n", + " print(f\"\\n{item.unique_name}:\")\n", " print(f\" Class: {item.class_name}\")\n", " print(f\" Type: {item.scorer_type}\")\n", - " print(f\" Description: {item.description[:60]}...\")\n", + " print(f\" Description: {item.class_description[:60]}...\")\n", "\n", - " ConsoleScorerPrinter().print_objective_scorer(scorer_identifier=item.scorer_identifier)" + " ConsoleScorerPrinter().print_objective_scorer(scorer_identifier=item)" ] }, { @@ -169,26 +168,69 @@ "name": "stdout", "output_type": "stream", "text": [ - "True/False scorers: ['self_ask_refusal_d9007ba2']\n", - "Refusal scorers: ['self_ask_refusal_d9007ba2']\n", - "True/False refusal scorers: ['self_ask_refusal_d9007ba2']\n" + "True/False scorers: ['self_ask_refusal_scorer::94a582f5']\n", + "Refusal scorers: ['self_ask_refusal_scorer::94a582f5']\n", + "True/False refusal scorers: ['self_ask_refusal_scorer::94a582f5']\n" ] } ], "source": [ "# Filter by scorer_type (based on isinstance check against TrueFalseScorer/FloatScaleScorer)\n", "true_false_scorers = registry.list_metadata(include_filters={\"scorer_type\": \"true_false\"})\n", - "print(f\"True/False scorers: {[m.name for m in true_false_scorers]}\")\n", + "print(f\"True/False scorers: {[m.unique_name for m in true_false_scorers]}\")\n", "\n", "# Filter by class_name\n", "refusal_scorers = registry.list_metadata(include_filters={\"class_name\": \"SelfAskRefusalScorer\"})\n", - "print(f\"Refusal scorers: {[m.name for m in refusal_scorers]}\")\n", + "print(f\"Refusal scorers: {[m.unique_name for m in refusal_scorers]}\")\n", "\n", "# Combine multiple filters (AND logic)\n", "specific_scorers = registry.list_metadata(\n", " include_filters={\"scorer_type\": \"true_false\", \"class_name\": \"SelfAskRefusalScorer\"}\n", ")\n", - "print(f\"True/False refusal scorers: {[m.name for m in specific_scorers]}\")" + "print(f\"True/False refusal scorers: {[m.unique_name for m in specific_scorers]}\")" + ] + }, + { + "cell_type": "markdown", + "id": "9", + "metadata": {}, + "source": [ + "## Using Target Initializer\n", + "\n", + "You can optionally use the `AIRTTargetInitializer` to automatically configure and register targets that use commonly used environment variables (from `.env_example`). This initializer does not strictly require any environment variables - it simply registers whatever endpoints are available." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "10", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Found default environment files: ['C:\\\\Users\\\\songjustin\\\\.pyrit\\\\.env']\n", + "Loaded environment file: C:\\Users\\songjustin\\.pyrit\\.env\n", + "Registered targets after AIRT initialization: ['azure_content_safety', 'azure_gpt4o_unsafe_chat', 'azure_gpt4o_unsafe_chat2', 'default_openai_frontend', 'openai_chat', 'openai_image', 'openai_realtime', 'openai_responses', 'openai_tts', 'openai_video']\n" + ] + } + ], + "source": [ + "from pyrit.registry import TargetRegistry\n", + "from pyrit.setup import initialize_pyrit_async\n", + "from pyrit.setup.initializers import AIRTTargetInitializer\n", + "\n", + "# Using built-in initializer\n", + "await initialize_pyrit_async( # type: ignore\n", + " memory_db_type=\"InMemory\", initializers=[AIRTTargetInitializer()]\n", + ")\n", + "\n", + "# Get the registry singleton\n", + "registry = TargetRegistry.get_registry_singleton()\n", + "# List registered targets\n", + "target_names = registry.get_names()\n", + "print(f\"Registered targets after AIRT initialization: {target_names}\")" ] } ], @@ -203,7 +245,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.13.5" + "version": "3.11.9" } }, "nbformat": 4, diff --git a/doc/code/registry/2_instance_registry.py b/doc/code/registry/2_instance_registry.py index c20755730c..d645529f25 100644 --- a/doc/code/registry/2_instance_registry.py +++ b/doc/code/registry/2_instance_registry.py @@ -5,11 +5,15 @@ # extension: .py # format_name: percent # format_version: '1.3' -# jupytext_version: 1.18.1 +# jupytext_version: 1.17.2 +# kernelspec: +# display_name: pyrit-dev +# language: python +# name: python3 # --- # %% [markdown] -# ## Why Instance Registries? +# # Why Instance Registries? # # Some components need configuration that can't easily be passed at instantiation time. For example, scorers often need: # - A configured `chat_target` for LLM-based scoring @@ -19,7 +23,7 @@ # Instance registries let initializers register fully-configured instances that are ready to use. # %% [markdown] -# # Listing Available Instances +# ## Listing Available Instances # # Use `get_names()` to see registered instances, or `list_metadata()` for details. 
@@ -67,12 +71,12 @@ # Get metadata for all registered scorers metadata = registry.list_metadata() for item in metadata: - print(f"\n{item.name}:") + print(f"\n{item.unique_name}:") print(f" Class: {item.class_name}") print(f" Type: {item.scorer_type}") - print(f" Description: {item.description[:60]}...") + print(f" Description: {item.class_description[:60]}...") - ConsoleScorerPrinter().print_objective_scorer(scorer_identifier=item.scorer_identifier) + ConsoleScorerPrinter().print_objective_scorer(scorer_identifier=item) # %% [markdown] # ## Filtering @@ -82,14 +86,35 @@ # %% # Filter by scorer_type (based on isinstance check against TrueFalseScorer/FloatScaleScorer) true_false_scorers = registry.list_metadata(include_filters={"scorer_type": "true_false"}) -print(f"True/False scorers: {[m.name for m in true_false_scorers]}") +print(f"True/False scorers: {[m.unique_name for m in true_false_scorers]}") # Filter by class_name refusal_scorers = registry.list_metadata(include_filters={"class_name": "SelfAskRefusalScorer"}) -print(f"Refusal scorers: {[m.name for m in refusal_scorers]}") +print(f"Refusal scorers: {[m.unique_name for m in refusal_scorers]}") # Combine multiple filters (AND logic) specific_scorers = registry.list_metadata( include_filters={"scorer_type": "true_false", "class_name": "SelfAskRefusalScorer"} ) -print(f"True/False refusal scorers: {[m.name for m in specific_scorers]}") +print(f"True/False refusal scorers: {[m.unique_name for m in specific_scorers]}") + +# %% [markdown] +# ## Using Target Initializer +# +# You can optionally use the `AIRTTargetInitializer` to automatically configure and register targets based on commonly used environment variables (from `.env_example`). This initializer does not strictly require any environment variables; it simply registers whatever endpoints are available. 
+ +# %% +from pyrit.registry import TargetRegistry +from pyrit.setup import initialize_pyrit_async +from pyrit.setup.initializers import AIRTTargetInitializer + +# Using built-in initializer +await initialize_pyrit_async( # type: ignore + memory_db_type="InMemory", initializers=[AIRTTargetInitializer()] +) + +# Get the registry singleton +registry = TargetRegistry.get_registry_singleton() +# List registered targets +target_names = registry.get_names() +print(f"Registered targets after AIRT initialization: {target_names}") diff --git a/pyrit/identifiers/target_identifier.py b/pyrit/identifiers/target_identifier.py index f08ad8709d..b8924fb0c0 100644 --- a/pyrit/identifiers/target_identifier.py +++ b/pyrit/identifiers/target_identifier.py @@ -34,6 +34,9 @@ class TargetIdentifier(Identifier): max_requests_per_minute: Optional[int] = None """Maximum number of requests per minute.""" + supports_conversation_history: bool = False + """Whether the target supports explicit setting of conversation history (is a PromptChatTarget).""" + target_specific_params: Optional[Dict[str, Any]] = None """Additional target-specific parameters.""" diff --git a/pyrit/prompt_target/common/prompt_target.py b/pyrit/prompt_target/common/prompt_target.py index 8cd80f47d4..653d008e65 100644 --- a/pyrit/prompt_target/common/prompt_target.py +++ b/pyrit/prompt_target/common/prompt_target.py @@ -122,6 +122,9 @@ def _create_identifier( elif self._model_name: model_name = self._model_name + # Late import to avoid circular dependency + from pyrit.prompt_target.common.prompt_chat_target import PromptChatTarget + return TargetIdentifier( class_name=self.__class__.__name__, class_module=self.__class__.__module__, @@ -132,6 +135,7 @@ def _create_identifier( temperature=temperature, top_p=top_p, max_requests_per_minute=self._max_requests_per_minute, + supports_conversation_history=isinstance(self, PromptChatTarget), target_specific_params=target_specific_params, ) diff --git a/pyrit/registry/__init__.py b/pyrit/registry/__init__.py index 209ec6c146..5f2fe7536f 100644 --- a/pyrit/registry/__init__.py +++ b/pyrit/registry/__init__.py @@ -21,6 +21,7 @@ from pyrit.registry.instance_registries import ( BaseInstanceRegistry, ScorerRegistry, + TargetRegistry, ) __all__ = [ @@ -39,4 +40,5 @@ "ScenarioMetadata", "ScenarioRegistry", "ScorerRegistry", + "TargetRegistry", ] diff --git a/pyrit/registry/instance_registries/__init__.py b/pyrit/registry/instance_registries/__init__.py index eab870f0e1..2cf50693cf 100644 --- a/pyrit/registry/instance_registries/__init__.py +++ b/pyrit/registry/instance_registries/__init__.py @@ -17,10 +17,14 @@ from pyrit.registry.instance_registries.scorer_registry import ( ScorerRegistry, ) +from pyrit.registry.instance_registries.target_registry import ( + TargetRegistry, +) __all__ = [ # Base class "BaseInstanceRegistry", # Concrete registries "ScorerRegistry", + "TargetRegistry", ] diff --git a/pyrit/registry/instance_registries/target_registry.py b/pyrit/registry/instance_registries/target_registry.py new file mode 100644 index 0000000000..3fcdbb3160 --- /dev/null +++ b/pyrit/registry/instance_registries/target_registry.py @@ -0,0 +1,97 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Target registry for discovering and managing PyRIT prompt targets. + +Targets are registered explicitly via initializers as pre-configured instances. 
+""" + +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING, Optional + +from pyrit.identifiers import TargetIdentifier +from pyrit.registry.instance_registries.base_instance_registry import ( + BaseInstanceRegistry, +) + +if TYPE_CHECKING: + from pyrit.prompt_target import PromptTarget + +logger = logging.getLogger(__name__) + + +class TargetRegistry(BaseInstanceRegistry["PromptTarget", TargetIdentifier]): + """ + Registry for managing available prompt target instances. + + This registry stores pre-configured PromptTarget instances (not classes). + Targets are registered explicitly via initializers after being instantiated + with their required parameters (e.g., endpoint, API keys). + + Targets are identified by their snake_case name derived from the class name, + or a custom name provided during registration. + """ + + @classmethod + def get_registry_singleton(cls) -> "TargetRegistry": + """ + Get the singleton instance of the TargetRegistry. + + Returns: + The singleton TargetRegistry instance. + """ + return super().get_registry_singleton() # type: ignore[return-value] + + def register_instance( + self, + target: "PromptTarget", + *, + name: Optional[str] = None, + ) -> None: + """ + Register a target instance. + + Note: Unlike ScenarioRegistry and InitializerRegistry which register classes, + TargetRegistry registers pre-configured instances. + + Args: + target: The pre-configured target instance (not a class). + name: Optional custom registry name. If not provided, + derived from class name with identifier hash appended + (e.g., OpenAIChatTarget -> openai_chat_abc123). + """ + if name is None: + name = target.get_identifier().unique_name + + self.register(target, name=name) + logger.debug(f"Registered target instance: {name} ({target.__class__.__name__})") + + def get_instance_by_name(self, name: str) -> Optional["PromptTarget"]: + """ + Get a registered target instance by name. + + Note: This returns an already-instantiated target, not a class. + + Args: + name: The registry name of the target. + + Returns: + The target instance, or None if not found. + """ + return self.get(name) + + def _build_metadata(self, name: str, instance: "PromptTarget") -> TargetIdentifier: + """ + Build metadata for a target instance. + + Args: + name: The registry name of the target. + instance: The target instance. + + Returns: + TargetIdentifier describing the target. 
+ """ + return instance.get_identifier() diff --git a/pyrit/setup/initializers/__init__.py b/pyrit/setup/initializers/__init__.py index 1c0cbd4683..6b1c63c484 100644 --- a/pyrit/setup/initializers/__init__.py +++ b/pyrit/setup/initializers/__init__.py @@ -4,6 +4,7 @@ """PyRIT initializers package.""" from pyrit.setup.initializers.airt import AIRTInitializer +from pyrit.setup.initializers.airt_targets import AIRTTargetInitializer from pyrit.setup.initializers.pyrit_initializer import PyRITInitializer from pyrit.setup.initializers.scenarios.load_default_datasets import LoadDefaultDatasets from pyrit.setup.initializers.scenarios.objective_list import ScenarioObjectiveListInitializer @@ -13,6 +14,7 @@ __all__ = [ "PyRITInitializer", "AIRTInitializer", + "AIRTTargetInitializer", "SimpleInitializer", "LoadDefaultDatasets", "ScenarioObjectiveListInitializer", diff --git a/pyrit/setup/initializers/airt_targets.py b/pyrit/setup/initializers/airt_targets.py new file mode 100644 index 0000000000..f421c53c6e --- /dev/null +++ b/pyrit/setup/initializers/airt_targets.py @@ -0,0 +1,422 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +AIRT Target Initializer for registering pre-configured targets from environment variables. + +This module provides the AIRTTargetInitializer class that registers available +targets into the TargetRegistry based on environment variable configuration. + +Note: This module only includes PRIMARY endpoint configurations from .env_example. + Alias configurations (those using ${...} syntax) are excluded since they + reference other primary configurations. +""" + +import logging +import os +from dataclasses import dataclass +from typing import Any, List, Optional, Type + +from pyrit.prompt_target import ( + AzureMLChatTarget, + OpenAIChatTarget, + OpenAICompletionTarget, + OpenAIImageTarget, + OpenAIResponseTarget, + OpenAITTSTarget, + OpenAIVideoTarget, + PromptShieldTarget, + PromptTarget, + RealtimeTarget, +) +from pyrit.registry import TargetRegistry +from pyrit.setup.initializers.pyrit_initializer import PyRITInitializer + +logger = logging.getLogger(__name__) + + +@dataclass +class TargetConfig: + """Configuration for a target to be registered.""" + + registry_name: str + target_class: Type[PromptTarget] + endpoint_var: str + key_var: str = "" # Empty string means no auth required + model_var: Optional[str] = None + underlying_model_var: Optional[str] = None + + +# Define all supported target configurations. +# Only PRIMARY configurations are included here - alias configurations that use ${...} +# syntax in .env_example are excluded since they reference other primary configurations. 
+TARGET_CONFIGS: List[TargetConfig] = [ + # ============================================ + # OpenAI Chat Targets (OpenAIChatTarget) + # ============================================ + TargetConfig( + registry_name="platform_openai_chat", + target_class=OpenAIChatTarget, + endpoint_var="PLATFORM_OPENAI_CHAT_ENDPOINT", + key_var="PLATFORM_OPENAI_CHAT_API_KEY", + model_var="PLATFORM_OPENAI_CHAT_GPT4O_MODEL", + ), + TargetConfig( + registry_name="azure_openai_gpt4o", + target_class=OpenAIChatTarget, + endpoint_var="AZURE_OPENAI_GPT4O_ENDPOINT", + key_var="AZURE_OPENAI_GPT4O_KEY", + model_var="AZURE_OPENAI_GPT4O_MODEL", + underlying_model_var="AZURE_OPENAI_GPT4O_UNDERLYING_MODEL", + ), + TargetConfig( + registry_name="azure_openai_integration_test", + target_class=OpenAIChatTarget, + endpoint_var="AZURE_OPENAI_INTEGRATION_TEST_ENDPOINT", + key_var="AZURE_OPENAI_INTEGRATION_TEST_KEY", + model_var="AZURE_OPENAI_INTEGRATION_TEST_MODEL", + underlying_model_var="AZURE_OPENAI_INTEGRATION_TEST_UNDERLYING_MODEL", + ), + TargetConfig( + registry_name="azure_openai_gpt35_chat", + target_class=OpenAIChatTarget, + endpoint_var="AZURE_OPENAI_GPT3_5_CHAT_ENDPOINT", + key_var="AZURE_OPENAI_GPT3_5_CHAT_KEY", + model_var="AZURE_OPENAI_GPT3_5_CHAT_MODEL", + underlying_model_var="AZURE_OPENAI_GPT3_5_CHAT_UNDERLYING_MODEL", + ), + TargetConfig( + registry_name="azure_openai_gpt4_chat", + target_class=OpenAIChatTarget, + endpoint_var="AZURE_OPENAI_GPT4_CHAT_ENDPOINT", + key_var="AZURE_OPENAI_GPT4_CHAT_KEY", + model_var="AZURE_OPENAI_GPT4_CHAT_MODEL", + underlying_model_var="AZURE_OPENAI_GPT4_CHAT_UNDERLYING_MODEL", + ), + TargetConfig( + registry_name="azure_gpt4o_unsafe_chat", + target_class=OpenAIChatTarget, + endpoint_var="AZURE_OPENAI_GPT4O_UNSAFE_CHAT_ENDPOINT", + key_var="AZURE_OPENAI_GPT4O_UNSAFE_CHAT_KEY", + model_var="AZURE_OPENAI_GPT4O_UNSAFE_CHAT_MODEL", + underlying_model_var="AZURE_OPENAI_GPT4O_UNSAFE_CHAT_UNDERLYING_MODEL", + ), + TargetConfig( + registry_name="azure_gpt4o_unsafe_chat2", + target_class=OpenAIChatTarget, + endpoint_var="AZURE_OPENAI_GPT4O_UNSAFE_CHAT_ENDPOINT2", + key_var="AZURE_OPENAI_GPT4O_UNSAFE_CHAT_KEY2", + model_var="AZURE_OPENAI_GPT4O_UNSAFE_CHAT_MODEL2", + underlying_model_var="AZURE_OPENAI_GPT4O_UNSAFE_CHAT_UNDERLYING_MODEL2", + ), + TargetConfig( + registry_name="azure_foundry_deepseek", + target_class=OpenAIChatTarget, + endpoint_var="AZURE_FOUNDRY_DEEPSEEK_ENDPOINT", + key_var="AZURE_FOUNDRY_DEEPSEEK_KEY", + model_var="AZURE_FOUNDRY_DEEPSEEK_MODEL", + ), + TargetConfig( + registry_name="azure_foundry_phi4", + target_class=OpenAIChatTarget, + endpoint_var="AZURE_FOUNDRY_PHI4_ENDPOINT", + key_var="AZURE_CHAT_PHI4_KEY", + model_var="AZURE_FOUNDRY_PHI4_MODEL", + ), + TargetConfig( + registry_name="azure_foundry_mistral_large", + target_class=OpenAIChatTarget, + endpoint_var="AZURE_FOUNDRY_MISTRAL_LARGE_ENDPOINT", + key_var="AZURE_FOUNDRY_MISTRAL_LARGE_KEY", + model_var="AZURE_FOUNDRY_MISTRAL_LARGE_MODEL", + ), + TargetConfig( + registry_name="groq", + target_class=OpenAIChatTarget, + endpoint_var="GROQ_ENDPOINT", + key_var="GROQ_KEY", + model_var="GROQ_LLAMA_MODEL", + ), + TargetConfig( + registry_name="open_router", + target_class=OpenAIChatTarget, + endpoint_var="OPEN_ROUTER_ENDPOINT", + key_var="OPEN_ROUTER_KEY", + model_var="OPEN_ROUTER_CLAUDE_MODEL", + ), + TargetConfig( + registry_name="ollama", + target_class=OpenAIChatTarget, + endpoint_var="OLLAMA_CHAT_ENDPOINT", + model_var="OLLAMA_MODEL", + ), + TargetConfig( + registry_name="google_gemini", + 
target_class=OpenAIChatTarget, + endpoint_var="GOOGLE_GEMINI_ENDPOINT", + key_var="GOOGLE_GEMINI_API_KEY", + model_var="GOOGLE_GEMINI_MODEL", + ), + # ============================================ + # OpenAI Responses Targets (OpenAIResponseTarget) + # ============================================ + TargetConfig( + registry_name="azure_openai_gpt5_responses", + target_class=OpenAIResponseTarget, + endpoint_var="AZURE_OPENAI_GPT5_RESPONSES_ENDPOINT", + key_var="AZURE_OPENAI_GPT5_KEY", + model_var="AZURE_OPENAI_GPT5_MODEL", + underlying_model_var="AZURE_OPENAI_GPT5_UNDERLYING_MODEL", + ), + TargetConfig( + registry_name="platform_openai_responses", + target_class=OpenAIResponseTarget, + endpoint_var="PLATFORM_OPENAI_RESPONSES_ENDPOINT", + key_var="PLATFORM_OPENAI_RESPONSES_KEY", + model_var="PLATFORM_OPENAI_RESPONSES_MODEL", + ), + TargetConfig( + registry_name="azure_openai_responses", + target_class=OpenAIResponseTarget, + endpoint_var="AZURE_OPENAI_RESPONSES_ENDPOINT", + key_var="AZURE_OPENAI_RESPONSES_KEY", + model_var="AZURE_OPENAI_RESPONSES_MODEL", + underlying_model_var="AZURE_OPENAI_RESPONSES_UNDERLYING_MODEL", + ), + # ============================================ + # Realtime Targets (RealtimeTarget) + # ============================================ + TargetConfig( + registry_name="platform_openai_realtime", + target_class=RealtimeTarget, + endpoint_var="PLATFORM_OPENAI_REALTIME_ENDPOINT", + key_var="PLATFORM_OPENAI_REALTIME_API_KEY", + model_var="PLATFORM_OPENAI_REALTIME_MODEL", + ), + TargetConfig( + registry_name="azure_openai_realtime", + target_class=RealtimeTarget, + endpoint_var="AZURE_OPENAI_REALTIME_ENDPOINT", + key_var="AZURE_OPENAI_REALTIME_API_KEY", + model_var="AZURE_OPENAI_REALTIME_MODEL", + underlying_model_var="AZURE_OPENAI_REALTIME_UNDERLYING_MODEL", + ), + # ============================================ + # Image Targets (OpenAIImageTarget) + # ============================================ + TargetConfig( + registry_name="openai_image_azure", + target_class=OpenAIImageTarget, + endpoint_var="OPENAI_IMAGE_ENDPOINT1", + key_var="OPENAI_IMAGE_API_KEY1", + model_var="OPENAI_IMAGE_MODEL1", + underlying_model_var="OPENAI_IMAGE_UNDERLYING_MODEL1", + ), + TargetConfig( + registry_name="openai_image_platform", + target_class=OpenAIImageTarget, + endpoint_var="OPENAI_IMAGE_ENDPOINT2", + key_var="OPENAI_IMAGE_API_KEY2", + model_var="OPENAI_IMAGE_MODEL2", + underlying_model_var="OPENAI_IMAGE_UNDERLYING_MODEL2", + ), + # ============================================ + # TTS Targets (OpenAITTSTarget) + # ============================================ + TargetConfig( + registry_name="openai_tts_azure", + target_class=OpenAITTSTarget, + endpoint_var="OPENAI_TTS_ENDPOINT1", + key_var="OPENAI_TTS_KEY1", + model_var="OPENAI_TTS_MODEL1", + underlying_model_var="OPENAI_TTS_UNDERLYING_MODEL1", + ), + TargetConfig( + registry_name="openai_tts_platform", + target_class=OpenAITTSTarget, + endpoint_var="OPENAI_TTS_ENDPOINT2", + key_var="OPENAI_TTS_KEY2", + model_var="OPENAI_TTS_MODEL2", + underlying_model_var="OPENAI_TTS_UNDERLYING_MODEL2", + ), + # ============================================ + # Video Targets (OpenAIVideoTarget) + # ============================================ + TargetConfig( + registry_name="azure_openai_video", + target_class=OpenAIVideoTarget, + endpoint_var="AZURE_OPENAI_VIDEO_ENDPOINT", + key_var="AZURE_OPENAI_VIDEO_KEY", + model_var="AZURE_OPENAI_VIDEO_MODEL", + underlying_model_var="AZURE_OPENAI_VIDEO_UNDERLYING_MODEL", + ), + # ============================================ 
+ # Completion Targets (OpenAICompletionTarget) + # ============================================ + TargetConfig( + registry_name="openai_completion", + target_class=OpenAICompletionTarget, + endpoint_var="OPENAI_COMPLETION_ENDPOINT", + key_var="OPENAI_COMPLETION_API_KEY", + model_var="OPENAI_COMPLETION_MODEL", + ), + # ============================================ + # Azure ML Targets (AzureMLChatTarget) + # ============================================ + TargetConfig( + registry_name="azure_ml_phi", + target_class=AzureMLChatTarget, + endpoint_var="AZURE_ML_PHI_ENDPOINT", + key_var="AZURE_ML_PHI_KEY", + ), + # ============================================ + # Safety Targets (PromptShieldTarget) + # ============================================ + TargetConfig( + registry_name="azure_content_safety", + target_class=PromptShieldTarget, + endpoint_var="AZURE_CONTENT_SAFETY_API_ENDPOINT", + key_var="AZURE_CONTENT_SAFETY_API_KEY", + ), +] + + +class AIRTTargetInitializer(PyRITInitializer): + """ + AIRT Target Initializer for registering pre-configured targets. + + This initializer scans for known endpoint environment variables and registers + the corresponding targets into the TargetRegistry. It only includes PRIMARY + endpoint configurations - alias configurations (those using ${...} syntax in + .env_example) are excluded since they reference other primary configurations. + + Supported Endpoints by Category: + + **OpenAI Chat Targets (OpenAIChatTarget):** + - PLATFORM_OPENAI_CHAT_* - Platform OpenAI Chat API + - AZURE_OPENAI_GPT4O_* - Azure OpenAI GPT-4o + - AZURE_OPENAI_INTEGRATION_TEST_* - Integration test endpoint + - AZURE_OPENAI_GPT3_5_CHAT_* - Azure OpenAI GPT-3.5 + - AZURE_OPENAI_GPT4_CHAT_* - Azure OpenAI GPT-4 + - AZURE_OPENAI_GPT4O_UNSAFE_CHAT_* - Azure OpenAI GPT-4o unsafe + - AZURE_OPENAI_GPT4O_UNSAFE_CHAT_*2 - Azure OpenAI GPT-4o unsafe secondary + - AZURE_FOUNDRY_DEEPSEEK_* - Azure AI Foundry DeepSeek + - AZURE_FOUNDRY_PHI4_* - Azure AI Foundry Phi-4 + - AZURE_FOUNDRY_MISTRAL_LARGE_* - Azure AI Foundry Mistral Large + - GROQ_* - Groq API + - OPEN_ROUTER_* - OpenRouter API + - OLLAMA_* - Ollama local + - GOOGLE_GEMINI_* - Google Gemini (OpenAI-compatible) + + **OpenAI Responses Targets (OpenAIResponseTarget):** + - AZURE_OPENAI_GPT5_RESPONSES_* - Azure OpenAI GPT-5 Responses + - PLATFORM_OPENAI_RESPONSES_* - Platform OpenAI Responses + - AZURE_OPENAI_RESPONSES_* - Azure OpenAI Responses + + **Realtime Targets (RealtimeTarget):** + - PLATFORM_OPENAI_REALTIME_* - Platform OpenAI Realtime + - AZURE_OPENAI_REALTIME_* - Azure OpenAI Realtime + + **Image Targets (OpenAIImageTarget):** + - OPENAI_IMAGE_*1 - Azure OpenAI Image + - OPENAI_IMAGE_*2 - Platform OpenAI Image + + **TTS Targets (OpenAITTSTarget):** + - OPENAI_TTS_*1 - Azure OpenAI TTS + - OPENAI_TTS_*2 - Platform OpenAI TTS + + **Video Targets (OpenAIVideoTarget):** + - AZURE_OPENAI_VIDEO_* - Azure OpenAI Video + + **Completion Targets (OpenAICompletionTarget):** + - OPENAI_COMPLETION_* - OpenAI Completion + + **Azure ML Targets (AzureMLChatTarget):** + - AZURE_ML_PHI_* - Azure ML Phi + + **Safety Targets (PromptShieldTarget):** + - AZURE_CONTENT_SAFETY_* - Azure Content Safety + + Example: + initializer = AIRTTargetInitializer() + await initializer.initialize_async() + """ + + def __init__(self) -> None: + """Initialize the AIRT Target Initializer.""" + super().__init__() + + @property + def name(self) -> str: + """Get the name of this initializer.""" + return "AIRT Target Initializer" + + @property + def description(self) -> str: 
+ """Get the description of this initializer.""" + return ( + "Instantiates a collection of (AI Red Team suggested) targets from " + "available environment variables and adds them to the TargetRegistry" + ) + + @property + def required_env_vars(self) -> List[str]: + """ + Get list of required environment variables. + + Returns empty list since this initializer is optional - it registers + whatever endpoints are available without requiring any. + """ + return [] + + async def initialize_async(self) -> None: + """ + Register available targets based on environment variables. + + Scans for known endpoint environment variables and registers the + corresponding targets into the TargetRegistry. + """ + for config in TARGET_CONFIGS: + self._register_target(config) + + def _register_target(self, config: TargetConfig) -> None: + """ + Register a target if its required environment variables are set. + + Args: + config: The target configuration specifying env vars and target class. + """ + endpoint = os.getenv(config.endpoint_var) + if not endpoint: + return + + # If key_var is empty, use placeholder (for targets like Ollama that don't require auth) + # If key_var is set, look up the env var and skip registration if not found + if config.key_var: + api_key = os.getenv(config.key_var) + if not api_key: + return + else: + api_key = "not-needed" + + model_name = os.getenv(config.model_var) if config.model_var else None + underlying_model = os.getenv(config.underlying_model_var) if config.underlying_model_var else None + + # Build kwargs for the target constructor + kwargs: dict[str, Any] = { + "endpoint": endpoint, + "api_key": api_key, + } + + # Only add model_name if the target supports it (PromptShieldTarget doesn't) + if model_name is not None: + kwargs["model_name"] = model_name + + # Add underlying_model if specified (for Azure deployments where name differs from model) + if underlying_model is not None: + kwargs["underlying_model"] = underlying_model + + target = config.target_class(**kwargs) + registry = TargetRegistry.get_registry_singleton() + registry.register_instance(target, name=config.registry_name) + logger.info(f"Registered target: {config.registry_name}") diff --git a/tests/unit/identifiers/test_target_identifier.py b/tests/unit/identifiers/test_target_identifier.py index 0541b36be5..148c60983d 100644 --- a/tests/unit/identifiers/test_target_identifier.py +++ b/tests/unit/identifiers/test_target_identifier.py @@ -500,6 +500,63 @@ def test_can_use_as_dict_key(self): assert d[identifier] == "value" +class TestTargetIdentifierSupportsConversationHistory: + """Test the supports_conversation_history field in TargetIdentifier.""" + + def test_supports_conversation_history_defaults_to_false(self): + """Test that supports_conversation_history defaults to False.""" + identifier = TargetIdentifier( + class_name="TestTarget", + class_module="pyrit.prompt_target.test_target", + class_description="A test target", + identifier_type="instance", + ) + + assert identifier.supports_conversation_history is False + + def test_supports_conversation_history_included_in_hash(self): + """Test that supports_conversation_history affects the hash.""" + base_args = { + "class_name": "TestTarget", + "class_module": "pyrit.prompt_target.test_target", + "class_description": "A test target", + "identifier_type": "instance", + } + + identifier1 = TargetIdentifier(supports_conversation_history=False, **base_args) + identifier2 = TargetIdentifier(supports_conversation_history=True, **base_args) + + assert 
identifier1.hash != identifier2.hash + + def test_supports_conversation_history_in_to_dict(self): + """Test that supports_conversation_history is included in to_dict.""" + identifier = TargetIdentifier( + class_name="TestChatTarget", + class_module="pyrit.prompt_target.test_chat_target", + class_description="A test chat target", + identifier_type="instance", + supports_conversation_history=True, + ) + + result = identifier.to_dict() + + assert result["supports_conversation_history"] is True + + def test_supports_conversation_history_from_dict(self): + """Test that supports_conversation_history is restored from dict.""" + data = { + "class_name": "TestChatTarget", + "class_module": "pyrit.prompt_target.test_chat_target", + "class_description": "A test chat target", + "identifier_type": "instance", + "supports_conversation_history": True, + } + + identifier = TargetIdentifier.from_dict(data) + + assert identifier.supports_conversation_history is True + + class TestTargetIdentifierNormalize: """Test the normalize class method for TargetIdentifier.""" diff --git a/tests/unit/registry/test_target_registry.py b/tests/unit/registry/test_target_registry.py new file mode 100644 index 0000000000..8e32411b89 --- /dev/null +++ b/tests/unit/registry/test_target_registry.py @@ -0,0 +1,277 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + + +import pytest + +from pyrit.identifiers import TargetIdentifier +from pyrit.models import Message, MessagePiece +from pyrit.prompt_target import PromptTarget +from pyrit.prompt_target.common.prompt_chat_target import PromptChatTarget +from pyrit.registry.instance_registries.target_registry import TargetRegistry + + +class MockPromptTarget(PromptTarget): + """Mock PromptTarget for testing.""" + + def __init__(self, *, model_name: str = "mock_model") -> None: + super().__init__(model_name=model_name) + + async def send_prompt_async( + self, + *, + message: Message, + ) -> list[Message]: + return [ + MessagePiece( + role="assistant", + original_value="mock response", + ).to_message() + ] + + def _validate_request(self, *, message: Message) -> None: + pass + + +class MockPromptChatTarget(PromptChatTarget): + """Mock PromptChatTarget for testing conversation history support.""" + + def __init__(self, *, model_name: str = "mock_chat_model", endpoint: str = "http://chat-test") -> None: + super().__init__(model_name=model_name, endpoint=endpoint) + + async def send_prompt_async( + self, + *, + message: Message, + ) -> list[Message]: + return [ + MessagePiece( + role="assistant", + original_value="chat response", + ).to_message() + ] + + def _validate_request(self, *, message: Message) -> None: + pass + + def is_json_response_supported(self) -> bool: + return False + + +class TestTargetRegistrySingleton: + """Tests for the singleton pattern in TargetRegistry.""" + + def setup_method(self): + """Reset the singleton before each test.""" + TargetRegistry.reset_instance() + + def teardown_method(self): + """Reset the singleton after each test.""" + TargetRegistry.reset_instance() + + def test_get_registry_singleton_returns_same_instance(self): + """Test that get_registry_singleton returns the same singleton each time.""" + instance1 = TargetRegistry.get_registry_singleton() + instance2 = TargetRegistry.get_registry_singleton() + + assert instance1 is instance2 + + def test_get_registry_singleton_returns_target_registry_type(self): + """Test that get_registry_singleton returns a TargetRegistry instance.""" + instance = 
TargetRegistry.get_registry_singleton() + assert isinstance(instance, TargetRegistry) + + def test_reset_instance_clears_singleton(self): + """Test that reset_instance clears the singleton.""" + instance1 = TargetRegistry.get_registry_singleton() + TargetRegistry.reset_instance() + instance2 = TargetRegistry.get_registry_singleton() + + assert instance1 is not instance2 + + +@pytest.mark.usefixtures("patch_central_database") +class TestTargetRegistryRegisterInstance: + """Tests for register_instance functionality in TargetRegistry.""" + + def setup_method(self): + """Reset and get a fresh registry for each test.""" + TargetRegistry.reset_instance() + self.registry = TargetRegistry.get_registry_singleton() + + def teardown_method(self): + """Reset the singleton after each test.""" + TargetRegistry.reset_instance() + + def test_register_instance_with_custom_name(self): + """Test registering a target with a custom name.""" + target = MockPromptTarget() + self.registry.register_instance(target, name="custom_target") + + assert "custom_target" in self.registry + assert self.registry.get("custom_target") is target + + def test_register_instance_generates_name_from_class(self): + """Test that register_instance generates a name from class name when not provided.""" + target = MockPromptTarget() + self.registry.register_instance(target) + + # Name should be derived from class name with hash suffix + names = self.registry.get_names() + assert len(names) == 1 + assert names[0].startswith("mock_prompt_") + + def test_register_instance_multiple_targets_unique_names(self): + """Test registering multiple targets generates unique names.""" + target1 = MockPromptTarget() + target2 = MockPromptChatTarget() + + self.registry.register_instance(target1) + self.registry.register_instance(target2) + + assert len(self.registry) == 2 + + def test_register_instance_same_target_type_different_config(self): + """Test that same target class with different configs can be registered.""" + target1 = MockPromptTarget(model_name="model_a") + target2 = MockPromptTarget(model_name="model_b") + + # Register with explicit names + self.registry.register_instance(target1, name="target_1") + self.registry.register_instance(target2, name="target_2") + + assert len(self.registry) == 2 + + +@pytest.mark.usefixtures("patch_central_database") +class TestTargetRegistryGetInstanceByName: + """Tests for get_instance_by_name functionality in TargetRegistry.""" + + def setup_method(self): + """Reset and get a fresh registry for each test.""" + TargetRegistry.reset_instance() + self.registry = TargetRegistry.get_registry_singleton() + self.target = MockPromptTarget() + self.registry.register_instance(self.target, name="test_target") + + def teardown_method(self): + """Reset the singleton after each test.""" + TargetRegistry.reset_instance() + + def test_get_instance_by_name_returns_target(self): + """Test getting a registered target by name.""" + result = self.registry.get_instance_by_name("test_target") + assert result is self.target + + def test_get_instance_by_name_nonexistent_returns_none(self): + """Test that getting a non-existent target returns None.""" + result = self.registry.get_instance_by_name("nonexistent") + assert result is None + + +@pytest.mark.usefixtures("patch_central_database") +class TestTargetRegistryBuildMetadata: + """Tests for _build_metadata functionality in TargetRegistry.""" + + def setup_method(self): + """Reset and get a fresh registry for each test.""" + TargetRegistry.reset_instance() + self.registry = 
TargetRegistry.get_registry_singleton() + + def teardown_method(self): + """Reset the singleton after each test.""" + TargetRegistry.reset_instance() + + def test_build_metadata_includes_class_name(self): + """Test that metadata (TargetIdentifier) includes the class name.""" + target = MockPromptTarget() + self.registry.register_instance(target, name="mock_target") + + metadata = self.registry.list_metadata() + assert len(metadata) == 1 + assert isinstance(metadata[0], TargetIdentifier) + assert metadata[0].class_name == "MockPromptTarget" + + def test_build_metadata_includes_model_name(self): + """Test that metadata includes the model_name.""" + target = MockPromptTarget(model_name="test_model") + self.registry.register_instance(target, name="mock_target") + + metadata = self.registry.list_metadata() + assert metadata[0].model_name == "test_model" + + def test_build_metadata_description_from_docstring(self): + """Test that class_description is derived from the target's docstring.""" + target = MockPromptTarget() + self.registry.register_instance(target, name="mock_target") + + metadata = self.registry.list_metadata() + # MockPromptTarget has a docstring + assert "Mock PromptTarget for testing" in metadata[0].class_description + + +@pytest.mark.usefixtures("patch_central_database") +class TestTargetRegistryListMetadata: + """Tests for list_metadata in TargetRegistry.""" + + def setup_method(self): + """Reset and get a fresh registry with multiple targets.""" + TargetRegistry.reset_instance() + self.registry = TargetRegistry.get_registry_singleton() + + self.target1 = MockPromptTarget(model_name="model_a") + self.target2 = MockPromptTarget(model_name="model_b") + self.chat_target = MockPromptChatTarget() + + self.registry.register_instance(self.target1, name="target_1") + self.registry.register_instance(self.target2, name="target_2") + self.registry.register_instance(self.chat_target, name="chat_target") + + def teardown_method(self): + """Reset the singleton after each test.""" + TargetRegistry.reset_instance() + + def test_list_metadata_returns_all_registered(self): + """Test that list_metadata returns metadata for all registered targets.""" + metadata = self.registry.list_metadata() + assert len(metadata) == 3 + + def test_list_metadata_filter_by_class_name(self): + """Test filtering metadata by class_name.""" + mock_metadata = self.registry.list_metadata(include_filters={"class_name": "MockPromptTarget"}) + + assert len(mock_metadata) == 2 + for m in mock_metadata: + assert m.class_name == "MockPromptTarget" + + +@pytest.mark.usefixtures("patch_central_database") +class TestTargetRegistrySupportsConversationHistory: + """Tests for supports_conversation_history field in TargetIdentifier.""" + + def setup_method(self): + """Reset and get a fresh registry for each test.""" + TargetRegistry.reset_instance() + self.registry = TargetRegistry.get_registry_singleton() + + def teardown_method(self): + """Reset the singleton after each test.""" + TargetRegistry.reset_instance() + + def test_registered_chat_target_has_supports_conversation_history_true(self): + """Test that registered chat targets have supports_conversation_history=True in metadata.""" + chat_target = MockPromptChatTarget() + self.registry.register_instance(chat_target, name="chat_target") + + metadata = self.registry.list_metadata() + assert len(metadata) == 1 + assert metadata[0].supports_conversation_history is True + + def test_registered_non_chat_target_has_supports_conversation_history_false(self): + """Test that registered 
non-chat targets have supports_conversation_history=False in metadata.""" + target = MockPromptTarget() + self.registry.register_instance(target, name="prompt_target") + + metadata = self.registry.list_metadata() + assert len(metadata) == 1 + assert metadata[0].supports_conversation_history is False diff --git a/tests/unit/setup/test_airt_targets_initializer.py b/tests/unit/setup/test_airt_targets_initializer.py new file mode 100644 index 0000000000..356a6388d5 --- /dev/null +++ b/tests/unit/setup/test_airt_targets_initializer.py @@ -0,0 +1,221 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import os + +import pytest + +from pyrit.registry import TargetRegistry +from pyrit.setup.initializers import AIRTTargetInitializer +from pyrit.setup.initializers.airt_targets import TARGET_CONFIGS + + +class TestAIRTTargetInitializerBasic: + """Tests for AIRTTargetInitializer class - basic functionality.""" + + def test_can_be_created(self): + """Test that AIRTTargetInitializer can be instantiated.""" + init = AIRTTargetInitializer() + assert init is not None + assert init.name == "AIRT Target Initializer" + assert init.execution_order == 1 + + def test_required_env_vars_is_empty(self): + """Test that no env vars are required (initializer is optional).""" + init = AIRTTargetInitializer() + assert init.required_env_vars == [] + + +@pytest.mark.usefixtures("patch_central_database") +class TestAIRTTargetInitializerInitialize: + """Tests for AIRTTargetInitializer.initialize_async method.""" + + def setup_method(self) -> None: + """Reset registry before each test.""" + TargetRegistry.reset_instance() + # Clear all target-related env vars + self._clear_env_vars() + + def teardown_method(self) -> None: + """Clean up after each test.""" + TargetRegistry.reset_instance() + self._clear_env_vars() + + def _clear_env_vars(self) -> None: + """Clear all environment variables used by TARGET_CONFIGS.""" + for config in TARGET_CONFIGS: + for var in [config.endpoint_var, config.key_var, config.model_var, config.underlying_model_var]: + if var and var in os.environ: + del os.environ[var] + + @pytest.mark.asyncio + async def test_initialize_runs_without_error_no_env_vars(self): + """Test that initialize runs without errors when no env vars are set.""" + init = AIRTTargetInitializer() + await init.initialize_async() + + # No targets should be registered + registry = TargetRegistry.get_registry_singleton() + assert len(registry) == 0 + + @pytest.mark.asyncio + async def test_registers_target_when_env_vars_set(self): + """Test that a target is registered when its env vars are set.""" + os.environ["PLATFORM_OPENAI_CHAT_ENDPOINT"] = "https://api.openai.com/v1" + os.environ["PLATFORM_OPENAI_CHAT_API_KEY"] = "test_key" + os.environ["PLATFORM_OPENAI_CHAT_GPT4O_MODEL"] = "gpt-4o" + + init = AIRTTargetInitializer() + await init.initialize_async() + + registry = TargetRegistry.get_registry_singleton() + assert "platform_openai_chat" in registry + target = registry.get_instance_by_name("platform_openai_chat") + assert target is not None + assert target._model_name == "gpt-4o" + + @pytest.mark.asyncio + async def test_does_not_register_target_without_endpoint(self): + """Test that target is not registered if endpoint is missing.""" + # Only set key, not endpoint + os.environ["PLATFORM_OPENAI_CHAT_API_KEY"] = "test_key" + os.environ["PLATFORM_OPENAI_CHAT_GPT4O_MODEL"] = "gpt-4o" + + init = AIRTTargetInitializer() + await init.initialize_async() + + registry = TargetRegistry.get_registry_singleton() + 
assert "platform_openai_chat" not in registry + + @pytest.mark.asyncio + async def test_does_not_register_target_without_api_key(self): + """Test that target is not registered if api_key env var is missing.""" + # Only set endpoint, not key + os.environ["PLATFORM_OPENAI_CHAT_ENDPOINT"] = "https://api.openai.com/v1" + os.environ["PLATFORM_OPENAI_CHAT_GPT4O_MODEL"] = "gpt-4o" + + init = AIRTTargetInitializer() + await init.initialize_async() + + registry = TargetRegistry.get_registry_singleton() + assert "platform_openai_chat" not in registry + + @pytest.mark.asyncio + async def test_registers_multiple_targets(self): + """Test that multiple targets are registered when their env vars are set.""" + # Set up platform_openai_chat + os.environ["PLATFORM_OPENAI_CHAT_ENDPOINT"] = "https://api.openai.com/v1" + os.environ["PLATFORM_OPENAI_CHAT_API_KEY"] = "test_key" + os.environ["PLATFORM_OPENAI_CHAT_GPT4O_MODEL"] = "gpt-4o" + + # Set up openai_image_platform (uses ENDPOINT2/KEY2/MODEL2) + os.environ["OPENAI_IMAGE_ENDPOINT2"] = "https://api.openai.com/v1" + os.environ["OPENAI_IMAGE_API_KEY2"] = "test_image_key" + os.environ["OPENAI_IMAGE_MODEL2"] = "dall-e-3" + + init = AIRTTargetInitializer() + await init.initialize_async() + + registry = TargetRegistry.get_registry_singleton() + assert len(registry) == 2 + assert "platform_openai_chat" in registry + assert "openai_image_platform" in registry + + @pytest.mark.asyncio + async def test_registers_azure_content_safety_without_model(self): + """Test that PromptShieldTarget is registered without model_name (it doesn't use one).""" + os.environ["AZURE_CONTENT_SAFETY_API_ENDPOINT"] = "https://test.cognitiveservices.azure.com" + os.environ["AZURE_CONTENT_SAFETY_API_KEY"] = "test_safety_key" + + init = AIRTTargetInitializer() + await init.initialize_async() + + registry = TargetRegistry.get_registry_singleton() + assert "azure_content_safety" in registry + + @pytest.mark.asyncio + async def test_underlying_model_passed_when_set(self): + """Test that underlying_model is passed to target when env var is set.""" + os.environ["AZURE_OPENAI_GPT4O_ENDPOINT"] = "https://my-deployment.openai.azure.com" + os.environ["AZURE_OPENAI_GPT4O_KEY"] = "test_key" + os.environ["AZURE_OPENAI_GPT4O_MODEL"] = "my-deployment-name" + os.environ["AZURE_OPENAI_GPT4O_UNDERLYING_MODEL"] = "gpt-4o" + + init = AIRTTargetInitializer() + await init.initialize_async() + + registry = TargetRegistry.get_registry_singleton() + target = registry.get_instance_by_name("azure_openai_gpt4o") + assert target is not None + assert target._model_name == "my-deployment-name" + assert target._underlying_model == "gpt-4o" + + @pytest.mark.asyncio + async def test_registers_ollama_without_api_key(self): + """Test that Ollama target is registered without requiring an API key.""" + os.environ["OLLAMA_CHAT_ENDPOINT"] = "http://127.0.0.1:11434/v1" + os.environ["OLLAMA_MODEL"] = "llama2" + + init = AIRTTargetInitializer() + await init.initialize_async() + + registry = TargetRegistry.get_registry_singleton() + assert "ollama" in registry + target = registry.get_instance_by_name("ollama") + assert target is not None + assert target._model_name == "llama2" + + +@pytest.mark.usefixtures("patch_central_database") +class TestAIRTTargetInitializerTargetConfigs: + """Tests verifying TARGET_CONFIGS covers expected targets.""" + + def test_target_configs_not_empty(self): + """Test that TARGET_CONFIGS has configurations defined.""" + assert len(TARGET_CONFIGS) > 0 + + def test_all_configs_have_required_fields(self): + 
"""Test that all TARGET_CONFIGS have required fields (key_var is optional for some).""" + for config in TARGET_CONFIGS: + assert config.registry_name, f"Config missing registry_name" + assert config.target_class, f"Config {config.registry_name} missing target_class" + assert config.endpoint_var, f"Config {config.registry_name} missing endpoint_var" + # key_var is optional for targets like Ollama that don't require auth + + def test_expected_targets_in_configs(self): + """Test that expected target names are in TARGET_CONFIGS.""" + registry_names = [config.registry_name for config in TARGET_CONFIGS] + + # Verify key targets are configured (using new primary config names) + assert "platform_openai_chat" in registry_names + assert "azure_openai_gpt4o" in registry_names + assert "openai_image_platform" in registry_names + assert "openai_tts_platform" in registry_names + assert "azure_content_safety" in registry_names + assert "ollama" in registry_names + assert "groq" in registry_names + assert "google_gemini" in registry_names + + +class TestAIRTTargetInitializerGetInfo: + """Tests for AIRTTargetInitializer.get_info_async method.""" + + @pytest.mark.asyncio + async def test_get_info_returns_expected_structure(self): + """Test that get_info_async returns expected structure.""" + info = await AIRTTargetInitializer.get_info_async() + + assert isinstance(info, dict) + assert info["name"] == "AIRT Target Initializer" + assert info["class"] == "AIRTTargetInitializer" + assert "description" in info + assert isinstance(info["description"], str) + + @pytest.mark.asyncio + async def test_get_info_required_env_vars_empty_or_not_present(self): + """Test that get_info has empty or no required_env_vars (since none are required).""" + info = await AIRTTargetInitializer.get_info_async() + + # required_env_vars may be omitted or empty since this initializer has no requirements + if "required_env_vars" in info: + assert info["required_env_vars"] == []