diff --git a/.pyrit_conf_example b/.pyrit_conf_example new file mode 100644 index 000000000..70c60f74f --- /dev/null +++ b/.pyrit_conf_example @@ -0,0 +1,70 @@ +# PyRIT Configuration File Example +# ================================ +# Copy this file to ~/.pyrit/.pyrit_conf or specify a custom path when loading. +# +# For documentation on configuration options, see: +# https://github.com/Azure/PyRIT/blob/main/doc/setup/configuration.md + +# Memory Database Type +# -------------------- +# Specifies which database backend to use for storing prompts and results. +# Options: in_memory, sqlite, azure_sql (case-insensitive) +# - in_memory: Temporary in-memory database (data lost on exit) +# - sqlite: Persistent local SQLite database (default) +# - azure_sql: Azure SQL database (requires connection string in env vars) +memory_db_type: sqlite + +# Initializers +# ------------ +# List of built-in initializers to run during PyRIT initialization. +# Initializers configure default values for converters, scorers, and targets. +# Names are normalized to snake_case (e.g., "SimpleInitializer" -> "simple"). +# +# Available initializers: +# - simple: Basic OpenAI configuration (requires OPENAI_CHAT_* env vars) +# - airt: AI Red Team setup with Azure OpenAI (requires AZURE_OPENAI_* env vars) +# - load_default_datasets: Loads default datasets for all registered scenarios +# - objective_list: Sets default objectives for scenarios +# - openai_objective_target: Sets up OpenAI target for scenarios +# +# Each initializer can be specified as: +# - A simple string (name only) +# - A dictionary with 'name' and optional 'args' for constructor arguments +# +# Example: +# initializers: +# - simple +# - name: airt +# args: +# some_param: value +initializers: + - simple + +# Initialization Scripts +# ---------------------- +# List of paths to custom Python scripts containing PyRITInitializer subclasses. +# Paths can be absolute or relative to the current working directory. 
+# +# Example: +# initialization_scripts: +# - /path/to/my_custom_initializer.py +# - ./local_initializer.py +initialization_scripts: [] + +# Environment Files +# ----------------- +# List of .env file paths to load during initialization. +# Later files override values from earlier files. +# If not specified, PyRIT loads ~/.pyrit/.env and ~/.pyrit/.env.local by default. +# +# Example: +# env_files: +# - /path/to/.env +# - /path/to/.env.local +env_files: [] + +# Silent Mode +# ----------- +# If true, suppresses print statements during initialization. +# Useful for non-interactive environments or when embedding PyRIT in other tools. +silent: false diff --git a/pyrit/cli/frontend_core.py b/pyrit/cli/frontend_core.py index 6cfc8b5a4..b8d4b9b95 100644 --- a/pyrit/cli/frontend_core.py +++ b/pyrit/cli/frontend_core.py @@ -46,6 +46,7 @@ def cprint(text: str, color: str = None, attrs: list = None) -> None: # type: i ScenarioMetadata, ScenarioRegistry, ) + from pyrit.setup import ConfigurationLoader logger = logging.getLogger(__name__) @@ -66,16 +67,23 @@ class FrontendCore: def __init__( self, *, - database: str = SQLITE, + config_file: Optional[Path] = None, + database: Optional[str] = None, initialization_scripts: Optional[list[Path]] = None, initializer_names: Optional[list[str]] = None, env_files: Optional[list[Path]] = None, - log_level: str = "WARNING", + log_level: Optional[str] = None, ): """ Initialize PyRIT context. + Configuration is loaded in the following order (later values override earlier): + 1. Default config file (~/.pyrit/.pyrit_conf) if it exists + 2. Explicit config_file argument if provided + 3. Individual CLI arguments (database, initializers, etc.) + Args: + config_file: Optional path to a YAML configuration file. database: Database type (InMemory, SQLite, or AzureSQL). initialization_scripts: Optional list of initialization script paths. initializer_names: Optional list of built-in initializer names to run. 
@@ -83,14 +91,34 @@ def __init__( log_level: Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL). Defaults to WARNING. Raises: - ValueError: If database or log_level are invalid. + ValueError: If database or log_level are invalid, or if config file is invalid. + FileNotFoundError: If an explicitly specified config_file does not exist. """ - # Validate inputs - self._database = validate_database(database=database) - self._initialization_scripts = initialization_scripts - self._initializer_names = initializer_names - self._env_files = env_files - self._log_level = validate_log_level(log_level=log_level) + from pyrit.setup import ConfigurationLoader + + # Load configuration from files and merge with CLI arguments + config = self._load_and_merge_config( + config_file=config_file, + database=database, + initialization_scripts=initialization_scripts, + initializer_names=initializer_names, + env_files=env_files, + ) + + # Store the merged configuration + self._config = config + + # Extract values from config for internal use + # Map snake_case db type back to PascalCase for backward compatibility + db_type_map = {"in_memory": IN_MEMORY, "sqlite": SQLITE, "azure_sql": AZURE_SQL} + self._database = db_type_map[config.memory_db_type] + self._initialization_scripts = config._resolve_initialization_scripts() + self._initializer_names = [ic.name for ic in config._initializer_configs] if config._initializer_configs else None + self._env_files = config._resolve_env_files() + + # Log level comes from CLI arg (not in config file), default to WARNING + effective_log_level = log_level if log_level is not None else "WARNING" + self._log_level = validate_log_level(log_level=effective_log_level) # Lazy-loaded registries self._scenario_registry: Optional[ScenarioRegistry] = None @@ -100,6 +128,93 @@ def __init__( # Configure logging logging.basicConfig(level=getattr(logging, self._log_level)) + def _load_and_merge_config( + self, + *, + config_file: Optional[Path], + database: 
Optional[str], + initialization_scripts: Optional[list[Path]], + initializer_names: Optional[list[str]], + env_files: Optional[list[Path]], + ) -> "ConfigurationLoader": + """ + Load configuration from files and merge with CLI arguments. + + Precedence (later overrides earlier): + 1. Default config file (~/.pyrit/.pyrit_conf) if it exists + 2. Explicit config_file argument if provided + 3. Individual CLI arguments + + Args: + config_file: Optional explicit config file path. + database: Optional database type from CLI. + initialization_scripts: Optional scripts from CLI. + initializer_names: Optional initializer names from CLI. + env_files: Optional env files from CLI. + + Returns: + Merged ConfigurationLoader instance. + """ + from pyrit.setup import ConfigurationLoader + + # Start with defaults + config_data: dict = { + "memory_db_type": "sqlite", + "initializers": [], + "initialization_scripts": [], + "env_files": [], + } + + # 1. Try loading default config file if it exists + default_config_path = ConfigurationLoader.get_default_config_path() + if default_config_path.exists(): + try: + default_config = ConfigurationLoader.from_yaml_file(default_config_path) + config_data["memory_db_type"] = default_config.memory_db_type + config_data["initializers"] = [ + {"name": ic.name, "args": ic.args} if ic.args else ic.name + for ic in default_config._initializer_configs + ] + config_data["initialization_scripts"] = default_config.initialization_scripts + config_data["env_files"] = default_config.env_files + except Exception as e: + logger.warning(f"Failed to load default config file {default_config_path}: {e}") + + # 2. 
Load explicit config file if provided (overrides default) + if config_file is not None: + if not config_file.exists(): + raise FileNotFoundError(f"Configuration file not found: {config_file}") + explicit_config = ConfigurationLoader.from_yaml_file(config_file) + config_data["memory_db_type"] = explicit_config.memory_db_type + config_data["initializers"] = [ + {"name": ic.name, "args": ic.args} if ic.args else ic.name + for ic in explicit_config._initializer_configs + ] + config_data["initialization_scripts"] = explicit_config.initialization_scripts + config_data["env_files"] = explicit_config.env_files + + # 3. Apply CLI overrides (non-None values take precedence) + if database is not None: + # Normalize to snake_case for ConfigurationLoader + normalized_db = database.lower().replace("-", "_") + # Handle PascalCase inputs + if normalized_db == "inmemory": + normalized_db = "in_memory" + elif normalized_db == "azuresql": + normalized_db = "azure_sql" + config_data["memory_db_type"] = normalized_db + + if initialization_scripts is not None: + config_data["initialization_scripts"] = [str(p) for p in initialization_scripts] + + if initializer_names is not None: + config_data["initializers"] = initializer_names + + if env_files is not None: + config_data["env_files"] = [str(p) for p in env_files] + + return ConfigurationLoader.from_dict(config_data) + async def initialize_async(self) -> None: """Initialize PyRIT and load registries (heavy operation).""" if self._initialized: @@ -734,6 +849,11 @@ async def print_initializers_list_async(*, context: FrontendCore, discovery_path # Shared argument help text ARG_HELP = { + "config_file": ( + "Path to a YAML configuration file. Allows specifying database, initializers (with args), " + "initialization scripts, and env files. CLI arguments override config file values. " + "If not specified, ~/.pyrit/.pyrit_conf is loaded if it exists." 
+ ), "initializers": "Built-in initializer names to run before the scenario (e.g., openai_objective_target)", "initialization_scripts": "Paths to custom Python initialization scripts to run before the scenario", "env_files": "Paths to environment files to load in order (e.g., .env.production .env.local). Later files " diff --git a/pyrit/cli/pyrit_scan.py b/pyrit/cli/pyrit_scan.py index df342ce0c..cf27df684 100644 --- a/pyrit/cli/pyrit_scan.py +++ b/pyrit/cli/pyrit_scan.py @@ -10,6 +10,7 @@ import asyncio import sys from argparse import ArgumentParser, Namespace, RawDescriptionHelpFormatter +from pathlib import Path from typing import Optional from pyrit.cli import frontend_core @@ -34,6 +35,9 @@ def parse_args(args: Optional[list[str]] = None) -> Namespace: # Run a scenario with built-in initializers pyrit_scan foundry --initializers openai_objective_target load_default_datasets + # Run with a configuration file (recommended for complex setups) + pyrit_scan foundry --config-file ./my_config.yaml + # Run with custom initialization scripts pyrit_scan garak.encoding --initialization-scripts ./my_config.py @@ -45,6 +49,12 @@ def parse_args(args: Optional[list[str]] = None) -> Namespace: formatter_class=RawDescriptionHelpFormatter, ) + parser.add_argument( + "--config-file", + type=Path, + help=frontend_core.ARG_HELP["config_file"], + ) + parser.add_argument( "--log-level", type=frontend_core.validate_log_level_argparse, @@ -176,12 +186,14 @@ def main(args: Optional[list[str]] = None) -> int: env_files = None if parsed_args.env_files: try: - env_files = frontend_core.resolve_env_files(env_file_paths=parsed_args.env_files) + env_files = frontend_core.resolve_env_files( + env_file_paths=parsed_args.env_files) except ValueError as e: print(f"Error: {e}") return 1 context = frontend_core.FrontendCore( + config_file=parsed_args.config_file, database=parsed_args.database, initialization_scripts=initialization_scripts, env_files=env_files, @@ -194,7 +206,10 @@ def main(args: 
Optional[list[str]] = None) -> int: # Discover from scenarios directory scenarios_path = frontend_core.get_default_initializer_discovery_path() - context = frontend_core.FrontendCore(log_level=parsed_args.log_level) + context = frontend_core.FrontendCore( + config_file=parsed_args.config_file, + log_level=parsed_args.log_level, + ) return asyncio.run(frontend_core.print_initializers_list_async(context=context, discovery_path=scenarios_path)) # Verify scenario was provided @@ -214,10 +229,12 @@ def main(args: Optional[list[str]] = None) -> int: # Collect environment files env_files = None if parsed_args.env_files: - env_files = frontend_core.resolve_env_files(env_file_paths=parsed_args.env_files) + env_files = frontend_core.resolve_env_files( + env_file_paths=parsed_args.env_files) # Create context with initializers context = frontend_core.FrontendCore( + config_file=parsed_args.config_file, database=parsed_args.database, initialization_scripts=initialization_scripts, initializer_names=parsed_args.initializers, @@ -228,7 +245,8 @@ def main(args: Optional[list[str]] = None) -> int: # Parse memory labels if provided memory_labels = None if parsed_args.memory_labels: - memory_labels = frontend_core.parse_memory_labels(json_string=parsed_args.memory_labels) + memory_labels = frontend_core.parse_memory_labels( + json_string=parsed_args.memory_labels) # Run scenario asyncio.run( diff --git a/pyrit/cli/pyrit_shell.py b/pyrit/cli/pyrit_shell.py index bcb074342..a18b4ab14 100644 --- a/pyrit/cli/pyrit_shell.py +++ b/pyrit/cli/pyrit_shell.py @@ -105,7 +105,8 @@ def __init__( self._scenario_history: list[tuple[str, ScenarioResult]] = [] # Initialize PyRIT in background thread for faster startup - self._init_thread = threading.Thread(target=self._background_init, daemon=True) + self._init_thread = threading.Thread( + target=self._background_init, daemon=True) self._init_complete = threading.Event() self._init_thread.start() @@ -125,7 +126,8 @@ def do_list_scenarios(self, arg: 
str) -> None: """List all available scenarios.""" self._ensure_initialized() try: - asyncio.run(frontend_core.print_scenarios_list_async(context=self.context)) + asyncio.run(frontend_core.print_scenarios_list_async( + context=self.context)) except Exception as e: print(f"Error listing scenarios: {e}") @@ -136,7 +138,8 @@ def do_list_initializers(self, arg: str) -> None: # Discover from scenarios directory by default (same as scan) discovery_path = frontend_core.get_default_initializer_discovery_path() asyncio.run( - frontend_core.print_initializers_list_async(context=self.context, discovery_path=discovery_path) + frontend_core.print_initializers_list_async( + context=self.context, discovery_path=discovery_path) ) except Exception as e: print(f"Error listing initializers: {e}") @@ -179,14 +182,19 @@ def do_run(self, line: str) -> None: print("\nUsage: run [options]") print("\nNote: Every scenario requires an initializer.") print("\nOptions:") - print(f" --initializers ... {frontend_core.ARG_HELP['initializers']} (REQUIRED)") + print( + f" --initializers ... {frontend_core.ARG_HELP['initializers']} (REQUIRED)") print( f" --initialization-scripts <...> {frontend_core.ARG_HELP['initialization_scripts']} (alternative to --initializers)" ) - print(f" --strategies, -s ... {frontend_core.ARG_HELP['scenario_strategies']}") - print(f" --max-concurrency {frontend_core.ARG_HELP['max_concurrency']}") - print(f" --max-retries {frontend_core.ARG_HELP['max_retries']}") - print(f" --memory-labels {frontend_core.ARG_HELP['memory_labels']}") + print( + f" --strategies, -s ... 
{frontend_core.ARG_HELP['scenario_strategies']}") + print( + f" --max-concurrency {frontend_core.ARG_HELP['max_concurrency']}") + print( + f" --max-retries {frontend_core.ARG_HELP['max_retries']}") + print( + f" --memory-labels {frontend_core.ARG_HELP['memory_labels']}") print( f" --database Override default database ({frontend_core.IN_MEMORY}, {frontend_core.SQLITE}, {frontend_core.AZURE_SQL})" ) @@ -194,7 +202,8 @@ def do_run(self, line: str) -> None: f" --log-level Override default log level (DEBUG, INFO, WARNING, ERROR, CRITICAL)" ) print("\nExample:") - print(" run foundry --initializers openai_objective_target load_default_datasets") + print( + " run foundry --initializers openai_objective_target load_default_datasets") print("\nType 'help run' for more details and examples") return @@ -220,7 +229,8 @@ def do_run(self, line: str) -> None: resolved_env_files = None if args["env_files"]: try: - resolved_env_files = frontend_core.resolve_env_files(env_file_paths=args["env_files"]) + resolved_env_files = frontend_core.resolve_env_files( + env_file_paths=args["env_files"]) except ValueError as e: print(f"Error: {e}") return @@ -283,7 +293,8 @@ def do_scenario_history(self, arg: str) -> None: print(f"{idx}) {command}") print("=" * 80) print(f"\nTotal runs: {len(self._scenario_history)}") - print("\nUse 'print-scenario ' to view detailed results for a specific run.") + print( + "\nUse 'print-scenario ' to view detailed results for a specific run.") print("Use 'print-scenario' to view detailed results for all runs.") def do_print_scenario(self, arg: str) -> None: @@ -325,7 +336,8 @@ def do_print_scenario(self, arg: str) -> None: try: scenario_num = int(arg) if scenario_num < 1 or scenario_num > len(self._scenario_history): - print(f"Error: Scenario number must be between 1 and {len(self._scenario_history)}") + print( + f"Error: Scenario number must be between 1 and {len(self._scenario_history)}") return command, result = self._scenario_history[scenario_num - 1] @@ 
-338,7 +350,8 @@ def do_print_scenario(self, arg: str) -> None: printer = ConsoleScenarioResultPrinter() asyncio.run(printer.print_summary_async(result)) except ValueError: - print(f"Error: Invalid scenario number '{arg}'. Must be an integer.") + print( + f"Error: Invalid scenario number '{arg}'. Must be an integer.") def do_help(self, arg: str) -> None: """Show help. Usage: help [command].""" @@ -351,12 +364,14 @@ def do_help(self, arg: str) -> None: print(" --database ") print(" Default database type: InMemory, SQLite, or AzureSQL") print(" Default: SQLite") - print(" Can be overridden per-run with 'run --database '") + print( + " Can be overridden per-run with 'run --database '") print() print(" --log-level ") print(" Default logging level: DEBUG, INFO, WARNING, ERROR, CRITICAL") print(" Default: WARNING") - print(" Can be overridden per-run with 'run --log-level '") + print( + " Can be overridden per-run with 'run --log-level '") print() print("=" * 70) print("Run Command Options (specified when running scenarios):") @@ -364,9 +379,11 @@ def do_help(self, arg: str) -> None: print(" --initializers [ ...] (REQUIRED)") print(f" {frontend_core.ARG_HELP['initializers']}") print(" Every scenario requires at least one initializer") - print(" Example: run foundry --initializers openai_objective_target load_default_datasets") + print( + " Example: run foundry --initializers openai_objective_target load_default_datasets") print() - print(" --initialization-scripts [ ...] (Alternative to --initializers)") + print( + " --initialization-scripts [ ...] 
(Alternative to --initializers)") print(f" {frontend_core.ARG_HELP['initialization_scripts']}") print(" Example: run foundry --initialization-scripts ./my_init.py") print() @@ -382,7 +399,8 @@ def do_help(self, arg: str) -> None: print() print(" --memory-labels ") print(f" {frontend_core.ARG_HELP['memory_labels']}") - print(' Example: run foundry --memory-labels \'{"env":"test"}\'') + print( + ' Example: run foundry --memory-labels \'{"env":"test"}\'') print() print("Start the shell like:") print(" pyrit_shell") @@ -453,9 +471,16 @@ def main() -> int: description="PyRIT Interactive Shell - Load modules once, run commands instantly", ) + parser.add_argument( + "--config-file", + type=frontend_core.Path, + help=frontend_core.ARG_HELP["config_file"], + ) + parser.add_argument( "--database", - choices=[frontend_core.IN_MEMORY, frontend_core.SQLITE, frontend_core.AZURE_SQL], + choices=[frontend_core.IN_MEMORY, + frontend_core.SQLITE, frontend_core.AZURE_SQL], default=frontend_core.SQLITE, help=f"Default database type to use ({frontend_core.IN_MEMORY}, {frontend_core.SQLITE}, {frontend_core.AZURE_SQL}) (default: {frontend_core.SQLITE}, can be overridden per-run)", ) @@ -481,13 +506,15 @@ def main() -> int: env_files = None if args.env_files: try: - env_files = frontend_core.resolve_env_files(env_file_paths=args.env_files) + env_files = frontend_core.resolve_env_files( + env_file_paths=args.env_files) except ValueError as e: print(f"Error: {e}") return 1 # Create context (initializers are specified per-run, not at startup) context = frontend_core.FrontendCore( + config_file=args.config_file, database=args.database, initialization_scripts=None, initializer_names=None, diff --git a/pyrit/common/path.py b/pyrit/common/path.py index 4094ba8a4..b61eb09d9 100644 --- a/pyrit/common/path.py +++ b/pyrit/common/path.py @@ -33,6 +33,10 @@ def in_git_repo() -> bool: CONFIGURATION_DIRECTORY_PATH = pathlib.Path.home() / ".pyrit" +# Default configuration file name and path 
+DEFAULT_CONFIG_FILENAME = ".pyrit_conf" +DEFAULT_CONFIG_PATH = CONFIGURATION_DIRECTORY_PATH / DEFAULT_CONFIG_FILENAME + # Points to the root of the project HOME_PATH = pathlib.Path(PYRIT_PATH, "..").resolve() @@ -54,22 +58,30 @@ def in_git_repo() -> bool: DATASETS_PATH = pathlib.Path(PYRIT_PATH, "datasets").resolve() EXECUTOR_SEED_PROMPT_PATH = pathlib.Path(DATASETS_PATH, "executors").resolve() -EXECUTOR_RED_TEAM_PATH = pathlib.Path(EXECUTOR_SEED_PROMPT_PATH, "red_teaming").resolve() -EXECUTOR_SIMULATED_TARGET_PATH = pathlib.Path(EXECUTOR_SEED_PROMPT_PATH, "simulated_target").resolve() -CONVERTER_SEED_PROMPT_PATH = pathlib.Path(DATASETS_PATH, "prompt_converters").resolve() +EXECUTOR_RED_TEAM_PATH = pathlib.Path( + EXECUTOR_SEED_PROMPT_PATH, "red_teaming").resolve() +EXECUTOR_SIMULATED_TARGET_PATH = pathlib.Path( + EXECUTOR_SEED_PROMPT_PATH, "simulated_target").resolve() +CONVERTER_SEED_PROMPT_PATH = pathlib.Path( + DATASETS_PATH, "prompt_converters").resolve() SCORER_SEED_PROMPT_PATH = pathlib.Path(DATASETS_PATH, "score").resolve() -SCORER_CONTENT_CLASSIFIERS_PATH = pathlib.Path(SCORER_SEED_PROMPT_PATH, "content_classifiers").resolve() +SCORER_CONTENT_CLASSIFIERS_PATH = pathlib.Path( + SCORER_SEED_PROMPT_PATH, "content_classifiers").resolve() SCORER_LIKERT_PATH = pathlib.Path(SCORER_SEED_PROMPT_PATH, "likert").resolve() SCORER_SCALES_PATH = pathlib.Path(SCORER_SEED_PROMPT_PATH, "scales").resolve() HARM_DEFINITION_PATH = pathlib.Path(DATASETS_PATH, "harm_definition").resolve() -JAILBREAK_TEMPLATES_PATH = pathlib.Path(DATASETS_PATH, "jailbreak", "templates").resolve() +JAILBREAK_TEMPLATES_PATH = pathlib.Path( + DATASETS_PATH, "jailbreak", "templates").resolve() SCORER_EVALS_PATH = pathlib.Path(DATASETS_PATH, "scorer_evals").resolve() SCORER_EVALS_HARM_PATH = pathlib.Path(SCORER_EVALS_PATH, "harm").resolve() -SCORER_EVALS_OBJECTIVE_PATH = pathlib.Path(SCORER_EVALS_PATH, "objective").resolve() -SCORER_EVALS_REFUSAL_SCORER_PATH = pathlib.Path(SCORER_EVALS_PATH, 
"refusal_scorer").resolve() -SCORER_EVALS_TRUE_FALSE_PATH = pathlib.Path(SCORER_EVALS_PATH, "true_false").resolve() +SCORER_EVALS_OBJECTIVE_PATH = pathlib.Path( + SCORER_EVALS_PATH, "objective").resolve() +SCORER_EVALS_REFUSAL_SCORER_PATH = pathlib.Path( + SCORER_EVALS_PATH, "refusal_scorer").resolve() +SCORER_EVALS_TRUE_FALSE_PATH = pathlib.Path( + SCORER_EVALS_PATH, "true_false").resolve() SCORER_EVALS_LIKERT_PATH = pathlib.Path(SCORER_EVALS_PATH, "likert").resolve() diff --git a/pyrit/setup/__init__.py b/pyrit/setup/__init__.py index 4ecdbd9d4..2929a59ea 100644 --- a/pyrit/setup/__init__.py +++ b/pyrit/setup/__init__.py @@ -3,6 +3,7 @@ """Module containing initialization PyRIT.""" +from pyrit.setup.configuration_loader import ConfigurationLoader, initialize_from_config_async from pyrit.setup.initialization import AZURE_SQL, IN_MEMORY, SQLITE, MemoryDatabaseType, initialize_pyrit_async __all__ = [ @@ -10,5 +11,7 @@ "SQLITE", "IN_MEMORY", "initialize_pyrit_async", + "initialize_from_config_async", "MemoryDatabaseType", + "ConfigurationLoader", ] diff --git a/pyrit/setup/configuration_loader.py b/pyrit/setup/configuration_loader.py new file mode 100644 index 000000000..2da6eda6b --- /dev/null +++ b/pyrit/setup/configuration_loader.py @@ -0,0 +1,319 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Configuration loader for PyRIT initialization. + +This module provides the ConfigurationLoader class that loads PyRIT configuration +from YAML files and initializes PyRIT accordingly. 
+""" + +import pathlib +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union + +from pyrit.common.path import DEFAULT_CONFIG_PATH +from pyrit.common.yaml_loadable import YamlLoadable +from pyrit.identifiers.class_name_utils import class_name_to_snake_case +from pyrit.setup.initialization import ( + AZURE_SQL, + IN_MEMORY, + SQLITE, + initialize_pyrit_async, +) + +if TYPE_CHECKING: + from pyrit.setup.initializers.pyrit_initializer import PyRITInitializer + + +# Type alias for YAML-serializable values that can be passed as initializer args +# This matches what YAML can represent: primitives, lists, and nested dicts +YamlPrimitive = Union[str, int, float, bool, None] +YamlValue = Union[YamlPrimitive, List["YamlValue"], Dict[str, "YamlValue"]] + +# Mapping from snake_case config values to internal constants +_MEMORY_DB_TYPE_MAP: Dict[str, str] = { + "in_memory": IN_MEMORY, + "sqlite": SQLITE, + "azure_sql": AZURE_SQL, +} + + +@dataclass +class InitializerConfig: + """ + Configuration for a single initializer. + + Attributes: + name: The name of the initializer (must be registered in InitializerRegistry). + args: Optional dictionary of YAML-serializable arguments to pass to the initializer constructor. + """ + + name: str + args: Optional[Dict[str, YamlValue]] = None + + +@dataclass +class ConfigurationLoader(YamlLoadable): + """ + Loader for PyRIT configuration from YAML files. + + This class loads configuration from a YAML file and provides methods to + initialize PyRIT with the loaded configuration. + + Attributes: + memory_db_type: The type of memory database (in_memory, sqlite, azure_sql). + initializers: List of initializer configurations (name + optional args). + initialization_scripts: List of paths to custom initialization scripts. + env_files: List of environment file paths to load. + silent: Whether to suppress initialization messages. 
+ + Example YAML configuration: + memory_db_type: sqlite + + initializers: + - simple + - name: airt + args: + some_param: value + + initialization_scripts: + - /path/to/custom_initializer.py + + env_files: + - /path/to/.env + - /path/to/.env.local + + silent: false + """ + + memory_db_type: str = "sqlite" + initializers: List[Union[str, Dict[str, Any]] + ] = field(default_factory=list) + initialization_scripts: List[str] = field(default_factory=list) + env_files: List[str] = field(default_factory=list) + silent: bool = False + + def __post_init__(self) -> None: + """Validate and normalize the configuration after loading.""" + self._normalize_memory_db_type() + self._normalize_initializers() + + def _normalize_memory_db_type(self) -> None: + """ + Normalize and validate memory_db_type. + + Converts the input to lowercase snake_case and validates against known types. + Stores the normalized snake_case value for config consistency, but maps + to internal constants when initializing. + """ + # Normalize to lowercase + normalized = self.memory_db_type.lower().replace("-", "_") + + # Also handle PascalCase inputs (e.g., "InMemory" -> "in_memory") + if normalized not in _MEMORY_DB_TYPE_MAP: + # Try converting from PascalCase + normalized = class_name_to_snake_case(self.memory_db_type) + + if normalized not in _MEMORY_DB_TYPE_MAP: + valid_types = list(_MEMORY_DB_TYPE_MAP.keys()) + raise ValueError( + f"Invalid memory_db_type '{self.memory_db_type}'. " + f"Must be one of: {', '.join(valid_types)}" + ) + + # Store normalized snake_case value + self.memory_db_type = normalized + + def _normalize_initializers(self) -> None: + """ + Normalize initializer entries to InitializerConfig objects. + + Converts initializer names to snake_case for consistent registry lookup. 
+ """ + normalized: List[InitializerConfig] = [] + for entry in self.initializers: + if isinstance(entry, str): + # Simple string entry: normalize name to snake_case + name = class_name_to_snake_case(entry) + normalized.append(InitializerConfig(name=name)) + elif isinstance(entry, dict): + # Dict entry: name and optional args + if "name" not in entry: + raise ValueError( + f"Initializer configuration must have a 'name' field. Got: {entry}" + ) + name = class_name_to_snake_case(entry["name"]) + normalized.append( + InitializerConfig( + name=name, + args=entry.get("args"), + ) + ) + else: + raise ValueError( + f"Initializer entry must be a string or dict, got: {type(entry).__name__}" + ) + self._initializer_configs = normalized + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "ConfigurationLoader": + """ + Create a ConfigurationLoader from a dictionary. + + Args: + data: Dictionary containing configuration values. + + Returns: + A new ConfigurationLoader instance. + """ + # Filter out None values and empty lists to use defaults + filtered_data = { + k: v for k, v in data.items() + if v is not None and v != [] + } + return cls(**filtered_data) + + @classmethod + def get_default_config_path(cls) -> pathlib.Path: + """ + Get the default configuration file path. + + Returns: + Path to the default config file in ~/.pyrit/.pyrit_conf + """ + return DEFAULT_CONFIG_PATH + + def _resolve_initializers(self) -> Sequence["PyRITInitializer"]: + """ + Resolve initializer names to PyRITInitializer instances. + + Uses the InitializerRegistry to look up initializer classes by name + and instantiate them with optional arguments. + + Returns: + Sequence of PyRITInitializer instances. + + Raises: + ValueError: If an initializer name is not found in the registry. 
+ """ + from pyrit.registry import InitializerRegistry + from pyrit.setup.initializers.pyrit_initializer import PyRITInitializer + + if not self._initializer_configs: + return [] + + registry = InitializerRegistry() + resolved: List[PyRITInitializer] = [] + + for config in self._initializer_configs: + initializer_class = registry.get_class(config.name) + if initializer_class is None: + available = ", ".join(sorted(registry.get_names())) + raise ValueError( + f"Initializer '{config.name}' not found in registry.\n" + f"Available initializers: {available}" + ) + + # Instantiate with args if provided + if config.args: + instance = initializer_class(**config.args) + else: + instance = initializer_class() + + resolved.append(instance) + + return resolved + + def _resolve_initialization_scripts(self) -> Optional[Sequence[pathlib.Path]]: + """ + Resolve initialization script paths. + + Returns: + Sequence of Path objects, or None if no scripts configured. + """ + if not self.initialization_scripts: + return None + + resolved: List[pathlib.Path] = [] + for script_str in self.initialization_scripts: + script_path = pathlib.Path(script_str) + if not script_path.is_absolute(): + script_path = pathlib.Path.cwd() / script_path + resolved.append(script_path) + + return resolved + + def _resolve_env_files(self) -> Optional[Sequence[pathlib.Path]]: + """ + Resolve environment file paths. + + Returns: + Sequence of Path objects, or None if no env files configured. + """ + if not self.env_files: + return None + + resolved: List[pathlib.Path] = [] + for env_str in self.env_files: + env_path = pathlib.Path(env_str) + if not env_path.is_absolute(): + env_path = pathlib.Path.cwd() / env_path + resolved.append(env_path) + + return resolved + + async def initialize_pyrit_async(self) -> None: + """ + Initialize PyRIT with the loaded configuration. + + This method resolves all initializer names to instances and calls + the core initialize_pyrit_async function. 
+ + Raises: + ValueError: If configuration is invalid or initializers cannot be resolved. + """ + resolved_initializers = self._resolve_initializers() + resolved_scripts = self._resolve_initialization_scripts() + resolved_env_files = self._resolve_env_files() + + # Map snake_case memory_db_type to internal constant + internal_memory_db_type = _MEMORY_DB_TYPE_MAP[self.memory_db_type] + + await initialize_pyrit_async( + memory_db_type=internal_memory_db_type, + initialization_scripts=resolved_scripts, + initializers=resolved_initializers if resolved_initializers else None, + env_files=resolved_env_files, + silent=self.silent, + ) + + +async def initialize_from_config_async( + config_path: Optional[Union[str, pathlib.Path]] = None, +) -> ConfigurationLoader: + """ + Initialize PyRIT from a configuration file. + + This is a convenience function that loads a ConfigurationLoader from + a YAML file and initializes PyRIT. + + Args: + config_path: Path to the configuration file. If None, uses the default + path (~/.pyrit/.pyrit_conf). Can be a string or pathlib.Path. + + Returns: + The loaded ConfigurationLoader instance. + + Raises: + FileNotFoundError: If the configuration file does not exist. + ValueError: If the configuration is invalid. + """ + if config_path is None: + config_path = ConfigurationLoader.get_default_config_path() + elif isinstance(config_path, str): + config_path = pathlib.Path(config_path) + + config = ConfigurationLoader.from_yaml_file(config_path) + await config.initialize_pyrit_async() + return config diff --git a/tests/unit/setup/test_configuration_loader.py b/tests/unit/setup/test_configuration_loader.py new file mode 100644 index 000000000..d652969e3 --- /dev/null +++ b/tests/unit/setup/test_configuration_loader.py @@ -0,0 +1,364 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 

import pathlib
import tempfile
from unittest import mock

import pytest

from pyrit.setup.configuration_loader import (
    ConfigurationLoader,
    InitializerConfig,
    initialize_from_config_async,
)


class TestInitializerConfig:
    """Tests for InitializerConfig dataclass."""

    def test_initializer_config_with_name_only(self):
        """Test creating InitializerConfig with just a name."""
        config = InitializerConfig(name="simple")
        assert config.name == "simple"
        assert config.args is None

    def test_initializer_config_with_args(self):
        """Test creating InitializerConfig with name and args."""
        config = InitializerConfig(name="custom", args={"param1": "value1"})
        assert config.name == "custom"
        assert config.args == {"param1": "value1"}


class TestConfigurationLoader:
    """Tests for ConfigurationLoader class."""

    def test_default_values(self):
        """Test default configuration values."""
        config = ConfigurationLoader()
        assert config.memory_db_type == "sqlite"
        assert config.initializers == []
        assert config.initialization_scripts == []
        assert config.env_files == []
        assert config.silent is False

    @pytest.mark.parametrize("db_type", ["in_memory", "sqlite", "azure_sql"])
    def test_valid_memory_db_types_snake_case(self, db_type):
        """Test all valid memory database types in snake_case."""
        config = ConfigurationLoader(memory_db_type=db_type)
        assert config.memory_db_type == db_type

    @pytest.mark.parametrize(
        "raw, expected",
        [
            ("InMemory", "in_memory"),
            ("SQLite", "sqlite"),
            ("AzureSQL", "azure_sql"),
        ],
    )
    def test_memory_db_type_normalization_from_pascal_case(self, raw, expected):
        """Test that PascalCase memory_db_type is normalized to snake_case."""
        config = ConfigurationLoader(memory_db_type=raw)
        assert config.memory_db_type == expected

    @pytest.mark.parametrize(
        "raw, expected",
        [
            ("SQLITE", "sqlite"),
            ("In_Memory", "in_memory"),
        ],
    )
    def test_memory_db_type_normalization_case_insensitive(self, raw, expected):
        """Test that memory_db_type normalization is case-insensitive."""
        config = ConfigurationLoader(memory_db_type=raw)
        assert config.memory_db_type == expected

    def test_invalid_memory_db_type_raises_error(self):
        """Test that invalid memory_db_type raises ValueError."""
        with pytest.raises(ValueError, match="Invalid memory_db_type"):
            ConfigurationLoader(memory_db_type="InvalidType")

    def test_initializer_as_string(self):
        """Test initializers specified as simple strings."""
        config = ConfigurationLoader(initializers=["simple", "airt"])
        assert len(config._initializer_configs) == 2
        assert config._initializer_configs[0].name == "simple"
        assert config._initializer_configs[0].args is None
        assert config._initializer_configs[1].name == "airt"

    def test_initializer_as_dict_with_name_only(self):
        """Test initializers specified as dicts with only name."""
        config = ConfigurationLoader(initializers=[{"name": "simple"}])
        assert len(config._initializer_configs) == 1
        assert config._initializer_configs[0].name == "simple"
        assert config._initializer_configs[0].args is None

    def test_initializer_as_dict_with_args(self):
        """Test initializers specified as dicts with name and args."""
        expected_args = {"param1": "value1", "param2": 42}
        config = ConfigurationLoader(
            initializers=[{"name": "custom", "args": {"param1": "value1", "param2": 42}}]
        )
        assert len(config._initializer_configs) == 1
        assert config._initializer_configs[0].name == "custom"
        assert config._initializer_configs[0].args == expected_args

    def test_mixed_initializer_formats(self):
        """Test initializers with mixed string and dict formats."""
        config = ConfigurationLoader(
            initializers=[
                "simple",
                {"name": "airt"},
                {"name": "custom", "args": {"key": "value"}},
            ]
        )
        assert len(config._initializer_configs) == 3
        assert config._initializer_configs[0].name == "simple"
        assert config._initializer_configs[1].name == "airt"
        assert config._initializer_configs[2].name == "custom"
        assert config._initializer_configs[2].args == {"key": "value"}

    def test_initializer_name_normalization_from_pascal_case(self):
        """Test that PascalCase initializer names are normalized to snake_case."""
        config = ConfigurationLoader(initializers=["SimpleInitializer", "AIRTInitializer"])
        assert config._initializer_configs[0].name == "simple_initializer"
        assert config._initializer_configs[1].name == "airt_initializer"

    def test_initializer_name_normalization_preserves_snake_case(self):
        """Test that snake_case names are preserved."""
        config = ConfigurationLoader(initializers=["simple_initializer", "airt_init"])
        assert config._initializer_configs[0].name == "simple_initializer"
        assert config._initializer_configs[1].name == "airt_init"

    def test_initializer_name_already_snake_case(self):
        """Test that snake_case names remain unchanged."""
        config = ConfigurationLoader(initializers=["load_default_datasets", "objective_list"])
        assert config._initializer_configs[0].name == "load_default_datasets"
        assert config._initializer_configs[1].name == "objective_list"

    def test_initializer_dict_without_name_raises_error(self):
        """Test that dict initializer without 'name' raises ValueError."""
        with pytest.raises(ValueError, match="must have a 'name' field"):
            ConfigurationLoader(initializers=[{"args": {"key": "value"}}])

    def test_initializer_invalid_type_raises_error(self):
        """Test that invalid initializer type raises ValueError."""
        with pytest.raises(ValueError, match="must be a string or dict"):
            ConfigurationLoader(initializers=[123])  # type: ignore

    def test_from_dict_with_all_fields(self):
        """Test from_dict with all configuration fields."""
        data = {
            "memory_db_type": "InMemory",
            "initializers": ["simple"],
            "initialization_scripts": ["/path/to/script.py"],
            "env_files": ["/path/to/.env"],
            "silent": True,
        }
        config = ConfigurationLoader.from_dict(data)
        assert config.memory_db_type == "in_memory"  # Normalized to snake_case
        assert config.initializers == ["simple"]
        assert config.initialization_scripts == ["/path/to/script.py"]
        assert config.env_files == ["/path/to/.env"]
        assert config.silent is True

    def test_from_dict_filters_none_values(self):
        """Test that from_dict filters out None values."""
        data = {
            "memory_db_type": "SQLite",
            "initializers": None,
            "env_files": [],
        }
        config = ConfigurationLoader.from_dict(data)
        assert config.memory_db_type == "sqlite"  # Normalized to snake_case
        assert config.initializers == []  # Uses default, not None

    def test_from_yaml_file(self):
        """Test loading configuration from a YAML file."""
        yaml_content = """
memory_db_type: in_memory
initializers:
  - simple
  - name: airt
    args:
      key: value
silent: true
"""
        # TemporaryDirectory removes the file even if writing or loading fails.
        with tempfile.TemporaryDirectory() as tmp_dir:
            yaml_path = pathlib.Path(tmp_dir) / "config.yaml"
            yaml_path.write_text(yaml_content, encoding="utf-8")

            config = ConfigurationLoader.from_yaml_file(str(yaml_path))

        assert config.memory_db_type == "in_memory"
        assert len(config._initializer_configs) == 2
        assert config._initializer_configs[0].name == "simple"
        assert config._initializer_configs[1].name == "airt"
        assert config._initializer_configs[1].args == {"key": "value"}
        assert config.silent is True

    def test_get_default_config_path(self):
        """Test get_default_config_path returns expected path."""
        default_path = ConfigurationLoader.get_default_config_path()
        assert default_path.name == ".pyrit_conf"
        # Check the parent directory explicitly: a substring check on the full
        # path is always true because ".pyrit_conf" itself contains ".pyrit".
        assert default_path.parent.name == ".pyrit"


class TestConfigurationLoaderResolvers:
    """Tests for ConfigurationLoader path resolution methods."""

    def test_resolve_initialization_scripts_empty(self):
        """Test that empty scripts returns None."""
        config = ConfigurationLoader()
        assert config._resolve_initialization_scripts() is None

    def test_resolve_initialization_scripts_absolute_path(self, tmp_path):
        """Test resolving absolute script paths."""
        # tmp_path is absolute on every platform, unlike a "/..." literal.
        script = tmp_path / "script.py"
        config = ConfigurationLoader(initialization_scripts=[str(script)])
        resolved = config._resolve_initialization_scripts()
        assert resolved is not None
        assert len(resolved) == 1
        assert resolved[0] == script

    def test_resolve_initialization_scripts_relative_path(self):
        """Test resolving relative script paths (converted to absolute)."""
        config = ConfigurationLoader(initialization_scripts=["relative/script.py"])
        resolved = config._resolve_initialization_scripts()
        assert resolved is not None
        assert len(resolved) == 1
        assert resolved[0].is_absolute()
        # Compare Path objects so the check does not depend on the separator.
        assert resolved[0] == pathlib.Path.cwd() / "relative" / "script.py"

    def test_resolve_env_files_empty(self):
        """Test that empty env files returns None."""
        config = ConfigurationLoader()
        assert config._resolve_env_files() is None

    def test_resolve_env_files_absolute_path(self, tmp_path):
        """Test resolving absolute env file paths."""
        env_file = tmp_path / ".env"
        config = ConfigurationLoader(env_files=[str(env_file)])
        resolved = config._resolve_env_files()
        assert resolved is not None
        assert len(resolved) == 1
        assert resolved[0] == env_file

    def test_resolve_env_files_relative_path(self):
        """Test resolving relative env file paths (converted to absolute)."""
        config = ConfigurationLoader(env_files=["relative/.env"])
        resolved = config._resolve_env_files()
        assert resolved is not None
        assert len(resolved) == 1
        assert resolved[0] == pathlib.Path.cwd() / "relative" / ".env"


@pytest.mark.usefixtures("patch_central_database")
class TestConfigurationLoaderInitialization:
    """Tests for ConfigurationLoader.initialize_pyrit_async method."""

    @pytest.mark.asyncio
    @mock.patch("pyrit.setup.configuration_loader.initialize_pyrit_async")
    async def test_initialize_pyrit_async_basic(self, mock_init):
        """Test basic initialization with minimal configuration."""
        config = ConfigurationLoader(memory_db_type="in_memory")
        await config.initialize_pyrit_async()

        mock_init.assert_called_once()
        call_kwargs = mock_init.call_args.kwargs
        # Should map snake_case to internal constant
        assert call_kwargs["memory_db_type"] == "InMemory"
        assert call_kwargs["initialization_scripts"] is None
        assert call_kwargs["initializers"] is None
        assert call_kwargs["env_files"] is None
        assert call_kwargs["silent"] is False

    @pytest.mark.asyncio
    @mock.patch("pyrit.setup.configuration_loader.initialize_pyrit_async")
    @mock.patch("pyrit.registry.InitializerRegistry")
    async def test_initialize_pyrit_async_with_initializers(self, mock_registry_cls, mock_init):
        """Test initialization with initializers resolved from registry."""
        # Setup mock registry
        mock_registry = mock.MagicMock()
        mock_registry_cls.return_value = mock_registry

        # Mock an initializer class
        mock_initializer_class = mock.MagicMock()
        mock_initializer_instance = mock.MagicMock()
        mock_initializer_class.return_value = mock_initializer_instance
        mock_registry.get_class.return_value = mock_initializer_class

        config = ConfigurationLoader(
            memory_db_type="in_memory",
            initializers=["simple"],
        )
        await config.initialize_pyrit_async()

        # Verify registry was used to resolve initializer
        mock_registry.get_class.assert_called_once_with("simple")
        mock_initializer_class.assert_called_once_with()

        # Verify initialize was called with resolved initializers
        mock_init.assert_called_once()
        call_kwargs = mock_init.call_args.kwargs
        assert call_kwargs["initializers"] == [mock_initializer_instance]

    @pytest.mark.asyncio
    @mock.patch("pyrit.registry.InitializerRegistry")
    async def test_initialize_pyrit_async_unknown_initializer_raises_error(self, mock_registry_cls):
        """Test that unknown initializer name raises ValueError."""
        mock_registry = mock.MagicMock()
        mock_registry_cls.return_value = mock_registry
        mock_registry.get_class.return_value = None
        mock_registry.get_names.return_value = ["simple", "airt"]

        config = ConfigurationLoader(
            memory_db_type="in_memory",
            initializers=["unknown_initializer"],
        )

        with pytest.raises(ValueError, match="not found in registry"):
            await config.initialize_pyrit_async()


@pytest.mark.usefixtures("patch_central_database")
class TestInitializeFromConfigAsync:
    """Tests for initialize_from_config_async function."""

    @pytest.mark.asyncio
    @mock.patch("pyrit.setup.configuration_loader.ConfigurationLoader.from_yaml_file")
    @mock.patch("pyrit.setup.configuration_loader.ConfigurationLoader.initialize_pyrit_async")
    async def test_initialize_from_config_with_path(self, mock_init, mock_from_yaml):
        """Test initialize_from_config_async with an explicit pathlib.Path."""
        mock_config = ConfigurationLoader()
        mock_from_yaml.return_value = mock_config
        config_path = pathlib.Path("/path/to/config.yaml")

        # Pass a Path object so the non-string branch is exercised.
        result = await initialize_from_config_async(config_path)

        mock_from_yaml.assert_called_once_with(config_path)
        mock_init.assert_called_once()
        assert result is mock_config

    @pytest.mark.asyncio
    @mock.patch("pyrit.setup.configuration_loader.ConfigurationLoader.from_yaml_file")
    @mock.patch("pyrit.setup.configuration_loader.ConfigurationLoader.initialize_pyrit_async")
    async def test_initialize_from_config_with_string_path(self, mock_init, mock_from_yaml):
        """Test initialize_from_config_async with string path."""
        mock_config = ConfigurationLoader()
        mock_from_yaml.return_value = mock_config

        await initialize_from_config_async("/path/to/config.yaml")

        # Should convert string to Path
        call_args = mock_from_yaml.call_args[0][0]
        assert isinstance(call_args, pathlib.Path)
        assert call_args == pathlib.Path("/path/to/config.yaml")

    @pytest.mark.asyncio
    @mock.patch("pyrit.setup.configuration_loader.ConfigurationLoader.get_default_config_path")
    @mock.patch("pyrit.setup.configuration_loader.ConfigurationLoader.from_yaml_file")
    @mock.patch("pyrit.setup.configuration_loader.ConfigurationLoader.initialize_pyrit_async")
    async def test_initialize_from_config_default_path(self, mock_init, mock_from_yaml, mock_default_path):
        """Test initialize_from_config_async uses default path when none specified."""
        mock_config = ConfigurationLoader()
        mock_from_yaml.return_value = mock_config
        mock_default_path.return_value = pathlib.Path("/default/path/.pyrit_conf")

        await initialize_from_config_async()

        mock_default_path.assert_called_once()
        mock_from_yaml.assert_called_once_with(pathlib.Path("/default/path/.pyrit_conf"))
        mock_init.assert_called_once()